├── .gitignore ├── LICENSE ├── README.md ├── bin └── reddit.sh ├── conf ├── hred_large.yml ├── hred_small.yml ├── taware_large.yml ├── taware_small.yml ├── taware_xlarge.yml ├── thred_large.yml ├── thred_medium.yml ├── vanilla_large.yml ├── vanilla_small.yml ├── vanilla_xlarge.yml └── word_embeddings.yml ├── requirements.txt ├── setup.py ├── thred ├── __init__.py ├── __main__.py ├── corpora │ ├── __init__.py │ ├── corpus_toolkit.py │ └── reddit │ │ ├── __init__.py │ │ ├── profanity_words.txt │ │ ├── reddit_bots.txt │ │ ├── reddit_dialogue.py │ │ ├── reddit_parser.py │ │ ├── reddit_utils.py │ │ ├── sanitizer.py │ │ └── subreddit_whitelist.txt ├── main.py ├── models │ ├── __init__.py │ ├── attention_helper.py │ ├── base.py │ ├── data_utils.py │ ├── hierarchical_base.py │ ├── hred │ │ ├── __init__.py │ │ ├── hred_helper.py │ │ ├── hred_iterators.py │ │ ├── hred_model.py │ │ └── hred_wrapper.py │ ├── model_factory.py │ ├── model_helper.py │ ├── ncm_utils.py │ ├── thred │ │ ├── __init__.py │ │ ├── thred_helper.py │ │ ├── thred_iterators.py │ │ ├── thred_model.py │ │ └── thred_wrapper.py │ ├── topic_aware │ │ ├── __init__.py │ │ ├── taware_decoder.py │ │ ├── taware_helper.py │ │ ├── taware_iterators.py │ │ ├── taware_layer.py │ │ ├── taware_model.py │ │ └── taware_wrapper.py │ ├── topical_base.py │ └── vanilla │ │ ├── __init__.py │ │ ├── bleu.py │ │ ├── eval_metric.py │ │ ├── vanilla_helper.py │ │ ├── vanilla_iterators.py │ │ ├── vanilla_model.py │ │ └── vanilla_wrapper.py ├── topic_model │ ├── __init__.py │ ├── analyzer.py │ └── lda.py └── util │ ├── __init__.py │ ├── chartable.py │ ├── config.py │ ├── device.py │ ├── dull_responses.txt │ ├── embed.py │ ├── emots.txt │ ├── fs.py │ ├── kv.py │ ├── log.py │ ├── misc.py │ ├── nlp.py │ ├── rnn_factory.py │ ├── summary_statistics.py │ ├── twitter_nlp_emoticons.py │ ├── twokenize.py │ ├── vocab.py │ └── wget.py └── thred_env.yml /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.out 3 | .idea 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Nouha Dziri 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | This repository hosts the implementation of the paper "[Augmenting Neural Response Generation with Context-Aware Topical
2 | Attention](https://arxiv.org/abs/1811.01063)".
3 |
4 | # Topical Hierarchical Recurrent Encoder Decoder (THRED)
5 | THRED is a multi-turn response generation system intended to produce contextual and topic-aware responses.
6 | The codebase evolved from the Tensorflow [NMT](https://github.com/tensorflow/nmt) repository.
7 |
8 | __TL;DR__ Steps to create a dialogue agent using this framework:
9 | 1. Download the Reddit Conversation Corpus from [here](https://drive.google.com/file/d/1HyaMz6bz3ju0qwyUKXQPY01r97y7eWD9/view?usp=sharing) (2.5GB download / 7.8GB after uncompressing; it contains triples extracted from Reddit). Please report errors/inappropriate content in the data [here](https://forms.gle/1WfWw5ABHx9GAaVV6).
10 | 2. Install the dependencies using `conda env create -f thred_env.yml` (to use `pip`, see [Dependencies](#dependencies)).
11 | 3. Train the model using the following command (pretrained models will be published soon). Note that `MODEL_DIR` is the directory the model will be saved into. We recommend training on at least 2 GPUs; otherwise, you can reduce the data size (by omitting conversations from the training file) and the model size (by modifying the config file).
12 | ```
13 | python -m thred --mode train --config conf/thred_medium.yml --model_dir <MODEL_DIR> \
14 | --train_data <TRAIN_DATA> --dev_data <DEV_DATA> --test_data <TEST_DATA>
15 | ```
16 | 4. Chat with the trained model using:
17 | ```
18 | python -m thred --mode interactive --model_dir <MODEL_DIR>
19 | ```
20 |
21 | ## Dependencies
22 | - Python >= 3.5 (Recommended: 3.6)
23 | - Tensorflow == 1.15.0
24 | - Tensorflow-Hub
25 | - SpaCy >= 2.1.0
26 | - pymagnitude
27 | - tqdm
28 | - redis¹
29 | - mistune¹
30 | - emot¹
31 | - Gensim¹
32 | - prompt-toolkit²
33 |
34 | ¹ *packages required only for parsing and cleaning the Reddit data.*
35 | ² *used only for testing dialogue models in command-line interactive mode*
36 |
37 | To install the dependencies using `pip`, run `pip install -r requirements.txt`.
38 | For Anaconda, run `conda env create -f thred_env.yml` (recommended).
39 | Once done with the dependencies, run `pip install -e .` to install the thred package.
40 |
41 | ## Data
42 | Our Reddit dataset, which we call the Reddit Conversation Corpus (RCC), is collected from 95 selected subreddits (listed [here](thred/corpora/reddit/subreddit_whitelist.txt)).
43 | We processed 20 months of Reddit data, from November 2016 through August 2018 (excluding June 2017 and July 2017; we used these two months, along with the October 2016 data, to train an LDA model). Please see [here](thred/corpora/reddit) for details on how the Reddit dataset is built, including pre-processing and cleaning of the raw Reddit files.
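Once a data file is downloaded and uncompressed, it can be consumed with a few lines of Python. The snippet below is a minimal sketch rather than part of the codebase: it assumes the format described under the table that follows (TAB-separated utterances, with the topic words of the topic-annotated files occupying the final TAB field), and the file name is a placeholder.

```python
def read_rcc(path, with_topics=False):
    """Yield conversations from an RCC data file.

    Each line is one conversation with TAB-separated utterances; in the
    topic-annotated files, the topic words follow the last utterance. They
    are treated here as a single space-delimited field, which is an
    assumption about the exact layout.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if with_topics:
                *utterances, topic_words = fields
                yield utterances, topic_words.split()
            else:
                yield fields, []


# Example: treat the last utterance of each conversation as the response
# to be generated ("rcc_3turns.train.txt" is a placeholder name).
for turns, _ in read_rcc("rcc_3turns.train.txt"):
    context, response = turns[:-1], turns[-1]
```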
The following table summarizes the RCC information:
44 |
45 | | Corpus | #train | #dev | #test | Download | Download with topic words |
46 | |---------- |:-----:|:-----:|:-----:|:-----------|:-----------|
47 | | 3 turns per line | 9.2M | 508K | 406K | [download](https://drive.google.com/file/d/1jV0L7QhFHN7etNknE_wqN1dt__Ia_K_y/view?usp=sharing) (773MB) | [download](https://drive.google.com/file/d/1HyaMz6bz3ju0qwyUKXQPY01r97y7eWD9/view?usp=sharing) (2.5GB) |
48 | | 4 turns per line | 4M | 223K | 178K | [download](https://drive.google.com/file/d/1GbRLmtHFZlV4mCrcYCexfnw3uWN66qDe/view?usp=sharing) (442MB) | [download](https://drive.google.com/file/d/1xXLyi30E0GD7Qig7GGj8JAYuonwdxcCu/view?usp=sharing) (1.2GB) |
49 | | 5 turns per line | 1.8M | 100K | 80K | [download](https://drive.google.com/file/d/1Mu3NXw4Af-Ivz9U_Zk_P8UJfoWCrIl8S/view?usp=sharing) (242MB) | [download](https://bit.ly/2JFmYKO) (594MB) |
50 |
51 | In the data files, each line corresponds to a single conversation, with utterances separated by TABs. The topic words appear after the last utterance, also separated by a TAB.
52 |
53 | Note that the 3-turns/4-turns/5-turns files contain similar content, albeit with a different number of utterances per line. They are all extracted from the same source. If you find any error or any inappropriate utterance in the data, please report your concerns [here](https://forms.gle/1WfWw5ABHx9GAaVV6).
54 |
55 | ### Embeddings
56 | In the model config files (i.e., the YAML files in [conf](conf)), the embedding type can be any of the following: `glove840B`, `fastText`, `word2vec`, and `hub_word2vec`. For handling the pre-trained embedding vectors, we leverage [Pymagnitude](https://github.com/plasticityai/magnitude/) and [Tensorflow-Hub](https://tfhub.dev/).
57 | Note that you can also use `random300` (300 refers to the dimension of the embedding vectors and can be replaced by any arbitrary value) to learn vectors during training of the response generation models. The settings related to the embedding models are provided in [word_embeddings.yml](conf/word_embeddings.yml).
58 |
59 |
60 | ## Train
61 | The training configuration should be defined in a YAML file, similar to Tensorflow NMT.
62 | Sample configurations for THRED and the other baselines are provided [here](conf).
63 |
64 | The implemented models are [Seq2Seq](https://arxiv.org/abs/1409.3215), [HRED](https://arxiv.org/abs/1605.06069), [Topic Aware-Seq2Seq](https://arxiv.org/abs/1606.08340), and THRED.
65 |
66 | Note that while most of the parameters are common among the different models, some models may have additional parameters
67 | (e.g., the topical models have `topic_words_per_utterance` and `boost_topic_gen_prob` parameters).
68 |
69 | To train a model, run the following command:
70 | ```bash
71 | python main.py --mode train --config <CONFIG_FILE> \
72 | --train_data <TRAIN_DATA> --dev_data <DEV_DATA> --test_data <TEST_DATA> \
73 | --model_dir <MODEL_DIR>
74 | ```
75 | Vocabulary files and Tensorflow model files are stored in `<MODEL_DIR>`. Training can be resumed by executing:
76 | ```bash
77 | python main.py --mode train --model_dir <MODEL_DIR>
78 | ```
79 |
80 | ## Test
81 | The model can be tested against the test dataset with the following command:
82 |
83 | ```bash
84 | python main.py --mode test --model_dir <MODEL_DIR> --test_data <TEST_DATA>
85 | ```
86 | It is possible to override inference parameters at test time.
87 | These parameters are: beam width (`--beam_width`),
88 | length penalty weight (`--length_penalty_weight`), and sampling temperature (`--sampling_temperature`).
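For example, a test run that overrides all three inference parameters might look like this (the values are illustrative, not tuned recommendations):
```bash
python main.py --mode test --model_dir <MODEL_DIR> --test_data <TEST_DATA> \
    --beam_width 10 --length_penalty_weight 0.8 --sampling_temperature 0.0
```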
89 |
90 | A simple command-line interface is provided that allows you to converse with the learned model (as in test mode, the inference parameters can be overridden here too):
91 | ```bash
92 | python main.py --mode interactive --model_dir <MODEL_DIR>
93 | ```
94 | In the interactive mode, a pre-trained LDA model is required to feed the inferred topic words into the model. We trained an LDA model using Gensim on a Reddit corpus collected for this purpose.
95 | It can be downloaded from [here](https://drive.google.com/file/d/1B3GZplM4YFV0A4l0rKte1rTl5ldeQWD5/view?usp=sharing).
96 | The downloaded file should be uncompressed and passed to the program via `--lda_model_dir <LDA_MODEL_DIR>`.
97 |
98 | ## Citation
99 | Please cite the following paper if you use our work in your research:
100 | ```text
101 | @article{dziri2018augmenting,
102 | title={Augmenting Neural Response Generation with Context-Aware Topical Attention},
103 | author={Dziri, Nouha and Kamalloo, Ehsan and Mathewson, Kory W and Zaiane, Osmar R},
104 | journal={arXiv preprint arXiv:1811.01063},
105 | year={2018}
106 | }
107 | ```
108 |
-------------------------------------------------------------------------------- /bin/reddit.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | BOLD='\033[1m'
4 | RED='\033[0;31m'
5 | BROWN='\033[0;33m'
6 | NORMAL='\033[0m'
7 |
8 |
9 | while [[ $# -gt 1 ]]
10 | do
11 | key="$1"
12 |
13 | case $key in
14 | -b|--batch-size)
15 | BATCH_SIZE="$2"
16 | shift # past argument
17 | ;;
18 | -o|--out-dir)
19 | OUT_DIR="$2"
20 | shift # past argument
21 | ;;
22 | --max-words)
23 | MAX_WORDS="$2"
24 | shift # past argument
25 | ;;
26 | --max-chars)
27 | MAX_CHARS="$2"
28 | shift # past argument
29 | ;;
30 | --min-words)
31 | MIN_WORDS="$2"
32 | shift # past argument
33 | ;;
34 | --min-chars)
35 | MIN_CHARS="$2"
36 | shift # past argument
37 | ;;
38 | -t|--subreddits)
39 | SUBREDDIT_FILE="$2"
40 | shift # past argument
41 | ;;
42 | -c|--comments)
43 | COMMENTS_FILE="$2"
44 | shift # past argument
45 | ;;
46 | -s|--submissions)
47 | SUBMISSIONS_FILE="$2"
48 | shift # past argument
49 | ;;
50 | -k|--skip-lines)
51 | SKIP_LINES="$2"
52 | shift # past argument
53 | ;;
54 | -l|--log)
55 | LOG_FILE="$2"
56 | shift # past argument
57 | ;;
58 | *)
59 | # unknown option
60 | ;;
61 | esac
62 | shift # past argument or value
63 | done
64 |
65 | # Check requirements
66 | if [ -z "$COMMENTS_FILE" ] && [ -z "$SUBMISSIONS_FILE" ]
67 | then
68 | echo -e "${RED}An input compressed file must be provided either using -c (for comments) or -s (for submissions)${NORMAL}"
69 | exit 1
70 | elif [ ! -z "$COMMENTS_FILE" ]; then
71 | filename="${COMMENTS_FILE##*/}"
72 | ext="${COMMENTS_FILE##*.}"
73 | elif [ ! -z "$SUBMISSIONS_FILE" ]; then
74 | filename="${SUBMISSIONS_FILE##*/}"
75 | ext="${SUBMISSIONS_FILE##*.}"
76 | fi
77 |
78 | filename="${filename%.*}"
79 | ext=${ext,,}
80 | if [ "$ext" != "bz2" ] && [ "$ext" != "xz" ] && [ "$ext" != "bzip2" ]; then
81 | echo -e "${RED}The input file must be either bz2 or xz, but it is ${ext}${NORMAL}"
82 | exit 1
83 | fi
84 |
85 | PYTHON_CMD="python3"
86 | cmd_check=$(command -v python3)
87 | if [ !
-z "$cmd_check" ]; then 88 | version=$(python3 --version 2>&1 | cut -f2 -d ' ') 89 | version=${version%%.*} 90 | if [ "$version" != "3" ]; then 91 | PYTHON_CMD="" 92 | fi 93 | fi 94 | 95 | if [ -z "$PYTHON_CMD" ]; then 96 | command -v python >/dev/null 2>&1 || { echo -e "${RED}Python command (version 3) not found${NORMAL}"; exit 1; } 97 | version=$(python --version 2>&1 | cut -f2 -d ' ') 98 | version=${version%%.*} 99 | if [ "$version" != "3" ]; then 100 | echo -e "${RED}Make sure you have a Python version 3 command in the PATH${NORMAL}" 101 | exit 1 102 | fi 103 | PYTHON_CMD="python" 104 | fi 105 | 106 | if [ ! -f "thred/corpora/reddit/reddit_parser.py" ] 107 | then 108 | if [ -f "../thred/corpora/reddit/reddit_parser.py" ] 109 | then 110 | cd .. 111 | else 112 | echo -e "${RED}Please go to the project base directory.${NORMAL}" 113 | exit 1 114 | fi 115 | fi 116 | 117 | 118 | # Optional args 119 | BATCH_ARG="" 120 | if [ ! -z "$BATCH_SIZE" ]; then 121 | BATCH_ARG="--batch_size $BATCH_SIZE" 122 | fi 123 | 124 | MAXW_ARG="" 125 | if [ ! -z "$MAX_WORDS" ]; then 126 | MAXW_ARG="--max_words $MAX_WORDS" 127 | fi 128 | 129 | MINW_ARG="" 130 | if [ ! -z "$MIN_WORDS" ]; then 131 | MINW_ARG="--min_words $MIN_WORDS" 132 | fi 133 | 134 | MAXC_ARG="" 135 | if [ ! -z "$MAX_CHARS" ]; then 136 | MAXC_ARG="--max_chars $MAX_CHARS" 137 | fi 138 | 139 | MINC_ARG="" 140 | if [ ! -z "$MIN_CHARS" ]; then 141 | MINC_ARG="--min_chars $MIN_CHARS" 142 | fi 143 | 144 | SKIP_ARG="" 145 | if [ ! -z "$SKIP_LINES" ]; then 146 | SKIP_ARG="--skip_lines $SKIP_LINES" 147 | fi 148 | 149 | if [ -z "$SUBREDDIT_FILE" ]; then 150 | SUBREDDIT_FILE="thred/corpora/reddit/subreddit_whitelist.txt" 151 | fi 152 | 153 | if [ -f "$LOG_FILE" ]; then 154 | rm -f "$LOG_FILE" 155 | echo -e "${BROWN}Existing log file removed${NORMAL}" 156 | fi 157 | 158 | COMMENTS_ARG="" 159 | SUBMISSIONS_ARG="" 160 | XZ_FILE="" 161 | if [ ! -z "$COMMENTS_FILE" ] 162 | then 163 | if [ "$ext" == "xz" ]; then 164 | COMMENTS_ARG="--comments_stream" 165 | XZ_FILE="$COMMENTS_FILE" 166 | else 167 | COMMENTS_ARG="--comments_file $COMMENTS_FILE" 168 | fi 169 | 170 | # finding the containing directory is taken from https://stackoverflow.com/a/40700120 171 | if [ -z "$OUT_DIR" ]; then 172 | OUT_DIR="$(dirname -- "$(readlink -f -- "$COMMENTS_FILE")")" 173 | fi 174 | else 175 | if [ "$ext" == "xz" ]; then 176 | SUBMISSIONS_ARG="--submissions_stream" 177 | XZ_FILE="$SUBMISSIONS_FILE" 178 | else 179 | SUBMISSIONS_ARG="--submissions_file $SUBMISSIONS_FILE" 180 | fi 181 | 182 | if [ -z "$OUT_DIR" ]; then 183 | OUT_DIR="$(dirname -- "$(readlink -f -- "$SUBMISSIONS_FILE")")" 184 | fi 185 | fi 186 | 187 | export PYTHONPATH="." 188 | 189 | rnd=$(date | md5sum | head -c 4) 190 | crash_file=".reddit_crash.$rnd" 191 | 192 | if [ "$ext" == "xz" ]; then 193 | nohup sh -c "xz -d $XZ_FILE -c | $PYTHON_CMD -u thred/corpora/reddit/reddit_parser.py --out_dir $OUT_DIR --output_prefix $filename --subreddits $SUBREDDIT_FILE -r $crash_file $COMMENTS_ARG $SUBMISSIONS_ARG $BATCH_ARG $SKIP_ARG $MAXW_ARG $MAXC_ARG $MINW_ARG $MINC_ARG" >$LOG_FILE 2>&1 < /dev/null & 194 | else 195 | nohup $PYTHON_CMD -u thred/corpora/reddit/reddit_parser.py --out_dir $OUT_DIR --output_prefix $filename --subreddits $SUBREDDIT_FILE -r $crash_file $COMMENTS_ARG $SUBMISSIONS_ARG $BATCH_ARG $SKIP_ARG $MAXW_ARG $MAXC_ARG $MINW_ARG $MINC_ARG >$LOG_FILE 2>&1 < /dev/null & 196 | fi 197 | 198 | echo -e "${BOLD}Crash file set to ${crash_file}${NORMAL}" 199 | echo -e "${BOLD}Running... 
Check out log file ${LOG_FILE}${NORMAL}" 200 | -------------------------------------------------------------------------------- /conf/hred_large.yml: -------------------------------------------------------------------------------- 1 | type: hred 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 50000 5 | num_layers: 4 6 | residual: True 7 | embedding_type: 'glove840B' 8 | cell_type: 'gru' 9 | encoder_type: 'bi' 10 | context_type: 'bi' 11 | hidden_units: 512 12 | optimizer: 'adam' 13 | batch_size: 128 14 | 15 | num_buckets: 5 16 | src_max_len: 30 17 | tgt_max_len: 30 18 | 19 | num_turns: 3 20 | 21 | num_train_epochs: 15 22 | steps_per_stats: 10 23 | steps_per_eval: 3000 24 | # Early Stopping parameters, inspired from VHRED impl 25 | patience: 50 26 | degrade_threshold: 1.003 27 | 28 | encoder_dropout_rate: 0.2 29 | context_dropout_rate: 0.2 30 | decoder_dropout_rate: 0.2 31 | 32 | decoding_length_factor: 2.0 33 | 34 | learning_rate_decay_scheme: luong234 35 | #start_decay_step: 0 36 | #decay_steps: 10000 37 | #decay_factor: 0.5 38 | 39 | scheduled_sampling_prob: 0.0 40 | scheduled_sampling_decay_scheme: none 41 | 42 | beam_width: 5 43 | length_penalty_weight: 1.0 44 | sampling_temperature: 0.0 45 | infer_batch_size: 128 46 | 47 | log_device: False 48 | -------------------------------------------------------------------------------- /conf/hred_small.yml: -------------------------------------------------------------------------------- 1 | type: hred 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 40000 5 | num_layers: 2 6 | residual: True 7 | embedding_type: 'glove840B' 8 | cell_type: 'gru' 9 | encoder_type: 'uni' 10 | context_type: 'uni' 11 | hidden_units: 512 12 | optimizer: 'adam' 13 | batch_size: 64 14 | 15 | num_buckets: 5 16 | src_max_len: 30 17 | tgt_max_len: 30 18 | 19 | num_turns: 3 20 | 21 | num_train_epochs: 15 22 | steps_per_stats: 10 23 | steps_per_eval: 3000 24 | # Early Stopping parameters, inspired from VHRED impl 25 | patience: 50 26 | degrade_threshold: 1.003 27 | 28 | encoder_dropout_rate: 0.2 29 | context_dropout_rate: 0.2 30 | decoder_dropout_rate: 0.2 31 | 32 | decoding_length_factor: 2.0 33 | 34 | learning_rate_decay_scheme: luong234 35 | #start_decay_step: 0 36 | #decay_steps: 10000 37 | #decay_factor: 0.5 38 | 39 | scheduled_sampling_prob: 0.0 40 | scheduled_sampling_decay_scheme: luong234 41 | 42 | beam_width: 5 43 | length_penalty_weight: 1.0 44 | sampling_temperature: 0.0 45 | infer_batch_size: 128 46 | 47 | log_device: False 48 | -------------------------------------------------------------------------------- /conf/taware_large.yml: -------------------------------------------------------------------------------- 1 | type: 'topic_aware' 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 50000 5 | num_layers: 2 6 | residual: True 7 | embedding_type: 'fastText' 8 | cell_type: 'gru' 9 | encoder_type: 'bi' 10 | hidden_units: 800 11 | attention_type: 'bahdanau' 12 | optimizer: 'adam' 13 | batch_size: 128 14 | 15 | num_buckets: 5 16 | topic_words_per_utterance: 50 17 | src_max_len: 30 18 | tgt_max_len: 30 19 | 20 | num_train_epochs: 15 21 | steps_per_stats: 10 22 | steps_per_eval: 3000 23 | # Early Stopping parameters, inspired from VHRED impl 24 | patience: 50 25 | degrade_threshold: 1.003 26 | 27 | encoder_dropout_rate: 0.2 28 | decoder_dropout_rate: 0.2 29 | 30 | decoding_length_factor: 2.0 31 | 32 | learning_rate_decay_scheme: luong234 33 | #start_decay_step: 0 34 | #decay_steps: 10000 35 | #decay_factor: 0.5 
36 | 37 | scheduled_sampling_prob: 0.0 38 | scheduled_sampling_decay_scheme: none 39 | 40 | beam_width: 5 41 | length_penalty_weight: 1.0 42 | sampling_temperature: 0.0 43 | infer_batch_size: 128 44 | 45 | log_device: False 46 | 47 | boost_topic_gen_prob: True 48 | -------------------------------------------------------------------------------- /conf/taware_small.yml: -------------------------------------------------------------------------------- 1 | type: 'topic_aware' 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 40000 5 | num_layers: 1 6 | residual: False 7 | embedding_type: 'fastText' 8 | cell_type: 'gru' 9 | encoder_type: 'uni' 10 | hidden_units: 1024 11 | attention_type: 'bahdanau' 12 | optimizer: 'adam' 13 | batch_size: 128 14 | 15 | num_buckets: 5 16 | topic_words_per_utterance: 100 17 | src_max_len: 30 18 | tgt_max_len: 30 19 | 20 | num_train_epochs: 10 21 | steps_per_stats: 10 22 | steps_per_eval: 3000 23 | # Early Stopping parameters, inspired from VHRED impl 24 | patience: 50 25 | degrade_threshold: 1.003 26 | 27 | encoder_dropout_rate: 0.2 28 | decoder_dropout_rate: 0.2 29 | 30 | decoding_length_factor: 2.0 31 | 32 | learning_rate_decay_scheme: luong234 33 | #start_decay_step: 0 34 | #decay_steps: 10000 35 | #decay_factor: 0.5 36 | 37 | scheduled_sampling_prob: 0.0 38 | scheduled_sampling_decay_scheme: none 39 | 40 | beam_width: 5 41 | length_penalty_weight: 1.0 42 | sampling_temperature: 0.0 43 | infer_batch_size: 128 44 | 45 | log_device: False 46 | 47 | boost_topic_gen_prob: True 48 | -------------------------------------------------------------------------------- /conf/taware_xlarge.yml: -------------------------------------------------------------------------------- 1 | type: 'topic_aware' 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 50000 5 | num_layers: 4 6 | embedding_type: 'glove840B' 7 | cell_type: 'gru' 8 | encoder_type: 'bi' 9 | hidden_units: 1024 10 | attention_type: 'bahdanau' 11 | optimizer: 'adam' 12 | batch_size: 128 13 | 14 | num_buckets: 5 15 | topic_words_per_utterance: 100 16 | src_max_len: 30 17 | tgt_max_len: 30 18 | 19 | num_train_epochs: 15 20 | steps_per_stats: 30 21 | steps_per_eval: 3000 22 | # Early Stopping parameters, inspired from VHRED impl 23 | patience: 50 24 | degrade_threshold: 1.003 25 | 26 | encoder_dropout_rate: 0.2 27 | decoder_dropout_rate: 0.2 28 | 29 | decoding_length_factor: 2.0 30 | 31 | learning_rate_decay_scheme: luong234 32 | #start_decay_step: 0 33 | #decay_steps: 10000 34 | #decay_factor: 0.5 35 | 36 | scheduled_sampling_prob: 0.0 37 | scheduled_sampling_decay_scheme: none 38 | 39 | beam_width: 5 40 | length_penalty_weight: 1.0 41 | sampling_temperature: 0.0 42 | infer_batch_size: 128 43 | 44 | log_device: False 45 | 46 | boost_topic_gen_prob: True 47 | -------------------------------------------------------------------------------- /conf/thred_large.yml: -------------------------------------------------------------------------------- 1 | type: 'thred' 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 60000 5 | num_layers: 4 6 | residual: True 7 | embedding_type: 'fastText' 8 | cell_type: 'gru' 9 | encoder_type: 'bi' 10 | context_type: 'bi' 11 | hidden_units: 512 12 | attention_type: 'normed_bahdanau' 13 | optimizer: 'adam' 14 | batch_size: 128 15 | 16 | num_buckets: 5 17 | topic_words_per_utterance: 50 18 | src_max_len: 30 19 | tgt_max_len: 30 20 | 21 | num_turns: 3 22 | 23 | num_train_epochs: 15 24 | steps_per_stats: 10 25 | steps_per_eval: 3000 26 | # Early 
Stopping parameters, inspired from VHRED impl 27 | patience: 50 28 | degrade_threshold: 1.003 29 | 30 | encoder_dropout_rate: 0.2 31 | context_dropout_rate: 0.2 32 | decoder_dropout_rate: 0.2 33 | 34 | decoding_length_factor: 2.0 35 | 36 | learning_rate_decay_scheme: luong234 37 | #start_decay_step: 0 38 | #decay_steps: 10000 39 | #decay_factor: 0.5 40 | 41 | scheduled_sampling_prob: 0.0 42 | scheduled_sampling_decay_scheme: none 43 | 44 | beam_width: 5 45 | length_penalty_weight: 1.0 46 | sampling_temperature: 0.0 47 | infer_batch_size: 128 48 | 49 | log_device: False 50 | 51 | boost_topic_gen_prob: True 52 | -------------------------------------------------------------------------------- /conf/thred_medium.yml: -------------------------------------------------------------------------------- 1 | type: 'thred' 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 40000 5 | num_layers: 2 6 | residual: True 7 | embedding_type: 'hub_word2vec' 8 | cell_type: 'gru' 9 | encoder_type: 'bi' 10 | context_type: 'bi' 11 | hidden_units: 512 12 | attention_type: 'normed_bahdanau' 13 | optimizer: 'adam' 14 | batch_size: 128 15 | 16 | num_buckets: 5 17 | topic_words_per_utterance: 100 18 | src_max_len: 30 19 | tgt_max_len: 30 20 | 21 | num_turns: 3 22 | 23 | num_train_epochs: 15 24 | steps_per_stats: 10 25 | steps_per_eval: 3000 26 | # Early Stopping parameters, inspired from VHRED impl 27 | patience: 50 28 | degrade_threshold: 1.003 29 | 30 | encoder_dropout_rate: 0.2 31 | context_dropout_rate: 0.2 32 | decoder_dropout_rate: 0.2 33 | 34 | decoding_length_factor: 2.0 35 | 36 | learning_rate_decay_scheme: luong234 37 | #start_decay_step: 0 38 | #decay_steps: 10000 39 | #decay_factor: 0.5 40 | 41 | scheduled_sampling_prob: 0.0 42 | scheduled_sampling_decay_scheme: luong234 43 | 44 | beam_width: 5 45 | length_penalty_weight: 1.0 46 | sampling_temperature: 0.0 47 | infer_batch_size: 128 48 | 49 | log_device: False 50 | 51 | boost_topic_gen_prob: True 52 | -------------------------------------------------------------------------------- /conf/vanilla_large.yml: -------------------------------------------------------------------------------- 1 | type: vanilla 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 50000 5 | num_layers: 4 6 | residual: True 7 | embedding_type: 'glove840B' 8 | cell_type: 'lstm' 9 | encoder_type: 'bi' 10 | hidden_units: 512 11 | attention_type: 'normed_bahdanau' 12 | optimizer: 'adam' 13 | batch_size: 128 14 | 15 | num_buckets: 5 16 | src_max_len: 30 17 | tgt_max_len: 30 18 | 19 | num_train_epochs: 12 20 | steps_per_stats: 10 21 | steps_per_eval: 3000 22 | # Early Stopping parameters, inspired from VHRED impl 23 | patience: 50 24 | degrade_threshold: 1.003 25 | 26 | encoder_dropout_rate: 0.2 27 | decoder_dropout_rate: 0.2 28 | 29 | decoding_length_factor: 2.0 30 | 31 | learning_rate_decay_scheme: luong234 32 | #start_decay_step: 160000 33 | #decay_steps: 16000 34 | #decay_factor: 0.5 35 | 36 | scheduled_sampling_prob: 0.0 37 | scheduled_sampling_decay_scheme: none 38 | 39 | beam_width: 5 40 | length_penalty_weight: 1.0 41 | sampling_temperature: 0.0 42 | infer_batch_size: 128 43 | 44 | log_device: False 45 | -------------------------------------------------------------------------------- /conf/vanilla_small.yml: -------------------------------------------------------------------------------- 1 | type: vanilla 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 40000 5 | num_layers: 2 6 | residual: False 7 | embedding_type: 'fastText' 8 
| cell_type: 'lstm' 9 | encoder_type: 'bi' 10 | hidden_units: 512 11 | attention_type: 'normed_bahdanau' 12 | optimizer: 'adam' 13 | batch_size: 128 14 | 15 | num_buckets: 5 16 | src_max_len: 30 17 | tgt_max_len: 30 18 | 19 | num_train_epochs: 12 20 | steps_per_stats: 10 21 | steps_per_eval: 3000 22 | # Early Stopping parameters, inspired from VHRED impl 23 | patience: 50 24 | degrade_threshold: 1.003 25 | 26 | encoder_dropout_rate: 0.2 27 | decoder_dropout_rate: 0.2 28 | 29 | decoding_length_factor: 2.0 30 | 31 | learning_rate_decay_scheme: luong234 32 | #start_decay_step: 160000 33 | #decay_steps: 16000 34 | #decay_factor: 0.5 35 | 36 | scheduled_sampling_prob: 0.0 37 | scheduled_sampling_decay_scheme: luong234 38 | 39 | beam_width: 5 40 | length_penalty_weight: 1.0 41 | sampling_temperature: 0.0 42 | infer_batch_size: 256 43 | 44 | log_device: False 45 | -------------------------------------------------------------------------------- /conf/vanilla_xlarge.yml: -------------------------------------------------------------------------------- 1 | type: vanilla 2 | learning_rate: 0.0002 3 | max_gradient_norm: 5.0 4 | vocab_size: 50000 5 | num_layers: 8 6 | residual: True 7 | embedding_type: 'hub_word2vec' 8 | cell_type: 'lstm' 9 | encoder_type: 'bi' 10 | hidden_units: 512 11 | attention_type: 'normed_bahdanau' 12 | optimizer: 'adam' 13 | batch_size: 128 14 | 15 | num_buckets: 5 16 | src_max_len: 30 17 | tgt_max_len: 30 18 | 19 | num_train_epochs: 15 20 | steps_per_stats: 10 21 | steps_per_eval: 3000 22 | # Early Stopping parameters, inspired from VHRED impl 23 | patience: 50 24 | degrade_threshold: 1.003 25 | 26 | encoder_dropout_rate: 0.2 27 | decoder_dropout_rate: 0.2 28 | 29 | decoding_length_factor: 2.0 30 | 31 | learning_rate_decay_scheme: luong234 32 | #start_decay_step: 160000 33 | #decay_steps: 16000 34 | #decay_factor: 0.5 35 | 36 | scheduled_sampling_prob: 0.0 37 | scheduled_sampling_decay_scheme: none 38 | 39 | beam_width: 5 40 | length_penalty_weight: 1.0 41 | sampling_temperature: 0.0 42 | infer_batch_size: 128 43 | 44 | log_device: False 45 | -------------------------------------------------------------------------------- /conf/word_embeddings.yml: -------------------------------------------------------------------------------- 1 | # word2vec, glove840B, glove6B_200, glove6B_100, glove6B_50, fastText, hub_word2vec 2 | 3 | glove6B_50: 4 | url: "http://magnitude.plasticity.ai/glove/medium/glove.6B.50d.magnitude" 5 | dim: 50 6 | src_type: "magnitude" 7 | 8 | glove6B_100: 9 | url: "http://magnitude.plasticity.ai/glove/medium/glove.6B.100d.magnitude" 10 | dim: 100 11 | src_type: "magnitude" 12 | 13 | glove6B_200: 14 | url: "http://magnitude.plasticity.ai/glove/medium/glove.6B.200d.magnitude" 15 | dim: 200 16 | src_type: "magnitude" 17 | 18 | glove840B: 19 | url: "http://magnitude.plasticity.ai/glove/medium/glove.840B.300d.magnitude" 20 | dim: 300 21 | src_type: "magnitude" 22 | 23 | word2vec: 24 | url: "http://magnitude.plasticity.ai/word2vec/medium/GoogleNews-vectors-negative300.magnitude" 25 | dim: 300 26 | src_type: "magnitude" 27 | 28 | fastText: 29 | url: "http://magnitude.plasticity.ai/fasttext/medium/wiki-news-300d-1M-subword.magnitude" 30 | dim: 300 31 | src_type: "magnitude" 32 | 33 | hub_word2vec: 34 | url: "https://tfhub.dev/google/Wiki-words-500/1" 35 | dim: 500 36 | src_type: "tfhub" 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
beautifulsoup4==4.6.0 2 | emot==1.0 3 | gensim>=3.4.0 4 | mistune>=0.8.0 5 | numpy<1.19.0,>=1.16.0 6 | praw>=6.0.0 7 | prompt-toolkit>=1.0.15 8 | PyYAML>=3.12,<4.0 9 | redis>=3.0.0 10 | tqdm>=4.20.0 11 | scipy>=1.0.0,<2.0.0 12 | tensorflow_gpu==1.15.0 13 | tensorflow-hub==0.2.0 14 | spacy>=2.1.0,<2.2.0 15 | https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-2.1.0/en_core_web_lg-2.1.0.tar.gz#egg=en_core_web_lg 16 | pymagnitude>=0.1.120 17 | lz4 18 | xxhash 19 | annoy 20 | fasteners -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="thred", 5 | version="0.1.2", 6 | author="Nouha Dziri, Ehsan Kamalloo, Kory Mathewson", 7 | author_email="dziri@cs.ualberta.ca", 8 | description="Neural Response Generation Framework", 9 | long_description=open("README.md", "r", encoding='utf-8').read(), 10 | long_description_content_type="text/markdown", 11 | keywords='dialogue-generation sequence-to-sequence tensorflow', 12 | url="https://github.com/nouhadziri/THRED", 13 | packages=find_packages(exclude=["*.tests", "*.tests.*", 14 | "tests.*", "tests"]), 15 | install_requires=['tensorflow_gpu==1.15.0', 16 | 'tensorflow-hub==0.2.0', 17 | 'spacy>=2.1.0,<2.2.0', 18 | 'scipy>=1.0.0,<2.0.0', 19 | 'pymagnitude', 20 | 'redis', 21 | 'PyYAML', 22 | 'gensim>=3.4.0', 23 | 'mistune>=0.8.0', 24 | 'emot==1.0', 25 | 'tqdm'], 26 | python_requires='>=3.5.0', 27 | tests_require=['pytest'], 28 | ) 29 | -------------------------------------------------------------------------------- /thred/__init__.py: -------------------------------------------------------------------------------- 1 | __version_info__ = ('0', '1', '2') 2 | __version__ = '.'.join(__version_info__) -------------------------------------------------------------------------------- /thred/__main__.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .main import main as thred_main 3 | 4 | 5 | if __name__ == '__main__': 6 | tf.app.run(main=thred_main) 7 | -------------------------------------------------------------------------------- /thred/corpora/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/corpora/__init__.py -------------------------------------------------------------------------------- /thred/corpora/reddit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/corpora/reddit/__init__.py -------------------------------------------------------------------------------- /thred/corpora/reddit/profanity_words.txt: -------------------------------------------------------------------------------- 1 | 4r5e 2 | 5h1t 3 | 5hit 4 | a55 5 | anal 6 | anus 7 | ar5e 8 | arrse 9 | arse 10 | ass 11 | ass-fucker 12 | asses 13 | assfucker 14 | assfukka 15 | asshole 16 | assholes 17 | asswhole 18 | a_s_s 19 | b!tch 20 | b00bs 21 | b17ch 22 | b1tch 23 | ballbag 24 | balls 25 | ballsack 26 | bastard 27 | beastial 28 | beastiality 29 | bellend 30 | bestial 31 | bestiality 32 | bi+ch 33 | biatch 34 | bitch 35 | bitcher 36 | bitchers 37 | bitches 38 | bitchin 39 | bitching 40 | bloody 41 | blow job 42 | blowjob 43 | blowjobs 44 | 
boiolas 45 | bollock 46 | bollok 47 | boner 48 | boob 49 | boobs 50 | booobs 51 | boooobs 52 | booooobs 53 | booooooobs 54 | breasts 55 | buceta 56 | bugger 57 | bum 58 | bunny fucker 59 | butt 60 | butthole 61 | buttmuch 62 | buttplug 63 | c0ck 64 | c0cksucker 65 | carpet muncher 66 | cawk 67 | chink 68 | cipa 69 | cl1t 70 | clit 71 | clitoris 72 | clits 73 | cnut 74 | cock 75 | cock-sucker 76 | cockface 77 | cockhead 78 | cockmunch 79 | cockmuncher 80 | cocks 81 | cocksuck 82 | cocksucked 83 | cocksucker 84 | cocksucking 85 | cocksucks 86 | cocksuka 87 | cocksukka 88 | cok 89 | cokmuncher 90 | coksucka 91 | coon 92 | cox 93 | crap 94 | cum 95 | cummer 96 | cumming 97 | cums 98 | cumshot 99 | cunilingus 100 | cunillingus 101 | cunnilingus 102 | cunt 103 | cuntlick 104 | cuntlicker 105 | cuntlicking 106 | cunts 107 | cyalis 108 | cyberfuc 109 | cyberfuck 110 | cyberfucked 111 | cyberfucker 112 | cyberfuckers 113 | cyberfucking 114 | d1ck 115 | damn 116 | dick 117 | dickhead 118 | dildo 119 | dildos 120 | dink 121 | dinks 122 | dirsa 123 | dlck 124 | dog-fucker 125 | doggin 126 | dogging 127 | donkeyribber 128 | doosh 129 | duche 130 | dyke 131 | ejaculate 132 | ejaculated 133 | ejaculates 134 | ejaculating 135 | ejaculatings 136 | ejaculation 137 | ejakulate 138 | f u c k 139 | f u c k e r 140 | f4nny 141 | fag 142 | fagging 143 | faggitt 144 | faggot 145 | faggs 146 | fagot 147 | fagots 148 | fags 149 | fanny 150 | fannyflaps 151 | fannyfucker 152 | fanyy 153 | fatass 154 | fcuk 155 | fcuker 156 | fcuking 157 | feck 158 | fecker 159 | felching 160 | fellate 161 | fellatio 162 | fingerfuck 163 | fingerfucked 164 | fingerfucker 165 | fingerfuckers 166 | fingerfucking 167 | fingerfucks 168 | fistfuck 169 | fistfucked 170 | fistfucker 171 | fistfuckers 172 | fistfucking 173 | fistfuckings 174 | fistfucks 175 | flange 176 | fook 177 | fooker 178 | fuck 179 | fucka 180 | fucked 181 | fucker 182 | fuckers 183 | fuckhead 184 | fuckheads 185 | fuckin 186 | fucking 187 | fuckijg 188 | fuckingfuckpiss 189 | fuckings 190 | fuckingshitmotherfucker 191 | fuckme 192 | fucks 193 | fuckwhit 194 | fuckwit 195 | fudge packer 196 | fudgepacker 197 | fuk 198 | fuker 199 | fukker 200 | fukkin 201 | fuks 202 | fukwhit 203 | fukwit 204 | fux 205 | fux0r 206 | f_u_c_k 207 | gangbang 208 | gangbanged 209 | gangbangs 210 | gaylord 211 | gaysex 212 | goatse 213 | God 214 | god-dam 215 | god-damned 216 | goddamn 217 | goddamned 218 | hardcoresex 219 | hell 220 | heshe 221 | hoar 222 | hoare 223 | hoer 224 | homo 225 | hore 226 | horniest 227 | horny 228 | hotsex 229 | jack-off 230 | jackoff 231 | jap 232 | jerk-off 233 | jism 234 | jiz 235 | jizm 236 | jizz 237 | kawk 238 | knob 239 | knobead 240 | knobed 241 | knobend 242 | knobhead 243 | knobjocky 244 | knobjokey 245 | kock 246 | kondum 247 | kondums 248 | kum 249 | kummer 250 | kumming 251 | kums 252 | kunilingus 253 | l3i+ch 254 | l3itch 255 | labia 256 | lmfao 257 | lust 258 | lusting 259 | m0f0 260 | m0fo 261 | m45terbate 262 | ma5terb8 263 | ma5terbate 264 | masochist 265 | master-bate 266 | masterb8 267 | masterbat* 268 | masterbat3 269 | masterbate 270 | masterbation 271 | masterbations 272 | masturbate 273 | mo-fo 274 | mof0 275 | mofo 276 | mothafuck 277 | mothafucka 278 | mothafuckas 279 | mothafuckaz 280 | mothafucked 281 | mothafucker 282 | mothafuckers 283 | mothafuckin 284 | mothafucking 285 | mothafuckings 286 | mothafucks 287 | mother fucker 288 | motherfuck 289 | motherfucked 290 | motherfucker 291 | motherfuckers 292 | motherfuckin 293 | 
motherfucking 294 | motherfuckings 295 | motherfuckka 296 | motherfucks 297 | muff 298 | mutha 299 | muthafecker 300 | muthafuckker 301 | muther 302 | mutherfucker 303 | n1gga 304 | n1gger 305 | nazi 306 | nigg3r 307 | nigg4h 308 | nigga 309 | niggah 310 | niggas 311 | niggaz 312 | nigger 313 | niggers 314 | nob 315 | nob jokey 316 | nobhead 317 | nobjocky 318 | nobjokey 319 | numbnuts 320 | nutsack 321 | orgasim 322 | orgasims 323 | orgasm 324 | orgasms 325 | p0rn 326 | pawn 327 | pecker 328 | penis 329 | penisfucker 330 | phonesex 331 | phuck 332 | phuk 333 | phuked 334 | phuking 335 | phukked 336 | phukking 337 | phuks 338 | phuq 339 | pigfucker 340 | pimpis 341 | piss 342 | pissed 343 | pisser 344 | pissers 345 | pisses 346 | pissflaps 347 | pissin 348 | pissing 349 | pissoff 350 | poop 351 | porn 352 | porno 353 | pornography 354 | pornos 355 | prick 356 | pricks 357 | pron 358 | pube 359 | pusse 360 | pussi 361 | pussies 362 | pussy 363 | pussys 364 | rectum 365 | retard 366 | rimjaw 367 | rimming 368 | s hit 369 | s.o.b. 370 | sadist 371 | schlong 372 | screwing 373 | scroat 374 | scrote 375 | scrotum 376 | semen 377 | sex 378 | sh!+ 379 | sh!t 380 | sh1t 381 | shag 382 | shagger 383 | shaggin 384 | shagging 385 | shemale 386 | shi+ 387 | shit 388 | shitdick 389 | shite 390 | shited 391 | shitey 392 | shitfuck 393 | shitfull 394 | shithead 395 | shiting 396 | shitings 397 | shits 398 | shitted 399 | shitter 400 | shitters 401 | shitting 402 | shittings 403 | shitty 404 | skank 405 | slut 406 | sluts 407 | smegma 408 | smut 409 | snatch 410 | son-of-a-bitch 411 | spac 412 | spunk 413 | s_h_i_t 414 | t1tt1e5 415 | t1tties 416 | teets 417 | teez 418 | testical 419 | testicle 420 | tit 421 | titfuck 422 | tits 423 | titt 424 | tittie5 425 | tittiefucker 426 | titties 427 | tittyfuck 428 | tittywank 429 | titwank 430 | tosser 431 | turd 432 | tw4t 433 | twat 434 | twathead 435 | twatty 436 | twunt 437 | twunter 438 | v14gra 439 | v1gra 440 | vagina 441 | viagra 442 | vulva 443 | w00se 444 | wang 445 | wank 446 | wanker 447 | wanky 448 | whoar 449 | whore 450 | willies 451 | willy 452 | xrated 453 | xxx 454 | dafuck 455 | muslimfuckers 456 | fuckface 457 | shitload 458 | mother'fucking 459 | fuckbuddy 460 | fucknut 461 | fucknuts 462 | fuckah 463 | fuckery 464 | fuckload 465 | fucktard 466 | fucktwat 467 | fuckstick 468 | fuckwits 469 | fuckballs 470 | dumbfuck 471 | dumbfucks 472 | bumfuck 473 | ^^^fucking 474 | fucken 475 | fucky 476 | bumblefuck 477 | fuckig 478 | didisaytocommentonthesizeortofuckingclimbit 479 | muthafuckin 480 | muthafucka 481 | muthafuckas 482 | muffucka 483 | fuckerberg 484 | gifuckingjane 485 | fuckhin 486 | butfucked 487 | buttfuck 488 | bangpussy 489 | pussyfriednachos 490 | skullfuck 491 | fuckboy 492 | ratfucks 493 | unfuckable 494 | unfuckabee 495 | fuckwad 496 | fuck_youkaren 497 | fuckabee 498 | \fuck 499 | fuckton 500 | niggerboi 501 | niggeh 502 | mothafuckaaahaa 503 | fuckass 504 | fuckable 505 | idontknowwhatthefuckimdoingbutimdoingitsoshutupfam 506 | fuckorh 507 | fuckboi 508 | yanifucktrumpsdick 509 | fuckk 510 | .fucking 511 | motherfuckerrrr 512 | motherfucka 513 | fucktards 514 | beepbeepbeepnevermindremovethefuckingcardrightthisfuckinginstantbeeepbeeepbeeep 515 | buttfucker 516 | fuckinn 517 | fuckkkkkk 518 | fuckswithducks 519 | fuckfaces 520 | fuckwads 521 | buttfucked 522 | buttfucking 523 | fuckn 524 | fucko 525 | duckfucker 526 | fuckkk 527 | fuckkkk 528 | fuckkkkkkkkk 529 | fuckking 530 | horsefucker 531 | fuckeable 532 | fuckoff 533 | 
fuckall 534 | zuckerfuck 535 | fucktammy 536 | muthafuck 537 | tittyfucking 538 | assfuck 539 | fucksakes 540 | dumbfuckery 541 | fuck'em 542 | ohhfuck 543 | mo'fuckaz 544 | ricefuck 545 | fuckstain 546 | fuckbois 547 | fuckssakes 548 | starfucker 549 | fucksies 550 | ratfucked 551 | ^^fuck 552 | motherslumberfucker 553 | manafucked 554 | muthafucker 555 | ohfuckitimissedyou 556 | -fucking 557 | ^^^^^fuckyoureddit 558 | muhfucka 559 | muhfuckas 560 | fuckboys 561 | -fuck 562 | hatefucked 563 | thefuckup 564 | fuckups -------------------------------------------------------------------------------- /thred/corpora/reddit/reddit_bots.txt: -------------------------------------------------------------------------------- 1 | # Added manually 2 | PoliticsModeratorBot 3 | justgoodenough 4 | roydeanbjj 5 | _spiraling 6 | purplespengler 7 | eric_twinge 8 | DeltaBot3 9 | CMVModBot 10 | # extracted from https://www.reddit.com/r/autowikibot/wiki/redditbots 11 | A858DE45F56D9BC9 12 | AAbot 13 | ADHDbot 14 | ALTcointip 15 | AVR_Modbot 16 | A_random_gif 17 | AltCodeBot 18 | Antiracism_Bot 19 | ApiContraption 20 | AssHatBot 21 | AtheismModBot 22 | AutoInsult 23 | BELITipBot 24 | BadLinguisticsBot 25 | BanishedBot 26 | BeetusBot 27 | BensonTheBot 28 | Bible_Verses_Bot 29 | BlackjackBot 30 | BlockchainBot 31 | Brigade_Bot 32 | Bronze-Bot 33 | CAH_BLACK_BOT 34 | CHART_BOT 35 | CLOSING_PARENTHESIS 36 | CPTModBot 37 | Cakeday-Bot 38 | CalvinBot 39 | CaptionBot 40 | CarterDugSubLinkBot 41 | CasualMetricBot 42 | Chemistry_Bot 43 | ChristianityBot 44 | Codebreakerbreaker 45 | Comment_Codebreaker 46 | ComplimentingBot 47 | CreepierSmileBot 48 | CreepySmileBot 49 | CuteBot6969 50 | DDBotIndia 51 | DNotesTip 52 | DRKTipBot 53 | DefinitelyBot 54 | DeltaBot 55 | Dictionary__Bot 56 | DidSomeoneSayBoobs 57 | DogeLotteryModBot 58 | DogeTipStatsBot 59 | DogeWordCloudBot 60 | DotaCastingBot 61 | Downtotes_Plz 62 | DownvotesMcGoats 63 | DropBox_Bot 64 | EmmaBot 65 | Epic_Face_Bot 66 | EscapistVideoBot 67 | ExmoBot 68 | ExplanationBot 69 | FTFY_Cat6 70 | FTFY_Cat 71 | FedoraTipAutoBot 72 | FelineFacts 73 | Fixes_GrammerNazi_ 74 | FriendSafariBot 75 | FriendlyCamelCaseBot 76 | FrontpageWatch 77 | Frown_Bot 78 | GATSBOT 79 | GabenCoinTipBot 80 | GameDealsBot 81 | Gatherer_bot 82 | GeekWhackBot 83 | GiantBombBot 84 | GifAsHTML5 85 | GoneWildResearcher 86 | GooglePlusBot 87 | GotCrypto 88 | GrammerNazi_ 89 | GreasyBacon 90 | Grumbler_bot 91 | GunnersGifsBot 92 | GunnitBot 93 | HCE_Replacement_Bot 94 | HScard_display_bot 95 | Handy_Related_Sub 96 | HighResImageFinder 97 | HockeyGT_Bot 98 | HowIsThisBestOf_Bot 99 | IAgreeBot 100 | ICouldntCareLessBot 101 | IS_IT_SOLVED 102 | I_BITCOIN_CATS 103 | I_Say_No_ 104 | Insane_Photo_Bot 105 | IsItDownBot 106 | JiffyBot 107 | JotBot 108 | JumpToBot 109 | KSPortBot 110 | KarmaConspiracy_Bot 111 | LazyLinkerBot 112 | LinkFixerBotSnr 113 | Link_Correction_Bot 114 | Link_Demobilizer 115 | Link_Rectifier_Bot 116 | LinkedCommentBot 117 | LocationBot 118 | MAGNIFIER_BOT 119 | Makes_Small_Text_Bot 120 | Meta_Bot 121 | MetatasticBot 122 | MetricPleaseBot 123 | Metric_System_Bot 124 | MontrealBot 125 | MovieGuide 126 | MultiFunctionBot 127 | MumeBot 128 | NASCARThreadBot 129 | NFLVideoBot 130 | NSLbot 131 | Nazeem_Bot 132 | New_Small_Text_Bot 133 | Nidalee_Bot 134 | NightMirrorMoon 135 | NoSleepAutoMod 136 | NoSobStoryBot2 137 | NobodyDoesThis 138 | NotRedditEnough 139 | PHOTO_OF_CAPTAIN_RON 140 | PJRP_Bot 141 | PhoenixBot 142 | PigLatinsYourComment 143 | PlayStoreLinks_Bot 144 | PlaylisterBot 145 | 
PleaseRespectTables 146 | PloungeMafiaVoteBot 147 | PokemonFlairBot 148 | PoliteBot 149 | PoliticBot 150 | PonyTipBot 151 | PornOverlord 152 | Porygon-Bot 153 | PresidentObama___ 154 | ProselytizerBot 155 | PunknRollBot 156 | QUICHE-BOT 157 | RFootballBot 158 | Random-ComplimentBOT 159 | RandomTriviaBot 160 | Rangers_Bot 161 | Readdit_Bot 162 | Reads_Small_Text_Bot 163 | RealtechPostBot 164 | ReddCoinGoldBot 165 | Relevant_News_Bot 166 | RequirementsBot 167 | RfreebandzBOT 168 | RiskyClickBot 169 | SERIAL_JOKE_KILLER 170 | SMCTipBot 171 | SRD_Notifier 172 | SRS_History_Bot 173 | SRScreenshot 174 | SWTOR_Helper_Bot 175 | SakuraiBot_test 176 | SakuraiBot 177 | SatoshiTipBot 178 | ShadowBannedBot 179 | ShibeBot 180 | ShillForMonsanto 181 | Shiny-Bot 182 | ShittyGandhiQuotes 183 | ShittyImageBot 184 | SketchNotSkit 185 | SmallTextReader 186 | Smile_Bot 187 | Somalia_Bot 188 | Some_Bot 189 | StackBot 190 | StarboundBot 191 | StencilTemplateBOT 192 | StreetFightMirrorBot 193 | SuchModBot 194 | SurveyOfRedditBot 195 | TOP_COMMENT_OF_YORE 196 | Text_Reader_Bot 197 | TheSwedishBot 198 | TipMoonBot 199 | TitsOrGTFO_Bot 200 | TweetPoster 201 | Twitch2YouTube 202 | Unhandy_Related_Sub 203 | UnobtaniumTipBot 204 | UrbanDicBot 205 | UselessArithmeticBot 206 | UselessConversionBot 207 | VideoLinkBot 208 | VideopokerBot 209 | VsauceBot 210 | WWE_Network_Bot 211 | WeAppreciateYou 212 | Website_Mirror_Bot 213 | WeeaBot 214 | WhoWouldWinBot 215 | Wiki_Bot 216 | Wiki_FirstPara_bot 217 | WikipediaCitationBot 218 | Wink-Bot 219 | WordCloudBot2 220 | WritingPromptsBot 221 | X_BOT 222 | YT_Bot 223 | _Definition_Bot_ 224 | _FallacyBot_ 225 | _Rita_ 226 | __bot__ 227 | albumbot 228 | allinonebot 229 | annoying_yes_bot 230 | asmrspambot 231 | astro-bot 232 | auto-doge 233 | automoderator 234 | autourbanbot 235 | autowikibot 236 | bRMT_Bot 237 | bad_ball_ban_bot 238 | ban_pruner 239 | baseball_gif_bot 240 | beecointipbot 241 | bitcoinpartybot 242 | bitcointip 243 | bitofnewsbot 244 | bocketybot 245 | c5bot 246 | c5bot 247 | cRedditBot 248 | callfloodbot 249 | callibot 250 | canada_goose_tip_bot 251 | changetip 252 | cheesecointipbot 253 | chromabot 254 | classybot 255 | coinflipbot 256 | coinyetipper 257 | colorcodebot 258 | comment_copier_bot 259 | compilebot 260 | conspirobot 261 | creepiersmilebot 262 | cris9696 263 | cruise_bot 264 | d3posterbot 265 | define_bot 266 | demobilizer 267 | dgctipbot 268 | digitipbot 269 | disapprovalbot 270 | dogetipbot 271 | earthtipbot 272 | edmprobot 273 | elMatadero_bot 274 | elwh392 275 | expired_link_bot 276 | fa_mirror 277 | fact_check_bot 278 | faketipbot 279 | fedora_tip_bot 280 | fedoratips 281 | flappytip 282 | flips_title 283 | foreigneducationbot 284 | frytipbot 285 | fsctipbot 286 | gabenizer-bot 287 | gabentipbot 288 | gfy_bot 289 | gfycat-bot-sucksdick 290 | gifster_bot 291 | gives_you_boobies 292 | givesafuckbot 293 | gocougs_bot 294 | godwin_finder 295 | golferbot 296 | gracefulcharitybot 297 | gracefulclaritybot 298 | gregbot 299 | groompbot 300 | gunners_gif_bot 301 | haiku_robot 302 | havoc_bot 303 | hearing-aid_bot 304 | hearing_aid_bot 305 | hearingaid_bot 306 | hit_bot 307 | hockey_gif_bot 308 | howstat 309 | hwsbot 310 | imgurHostBot 311 | imgur_rehosting 312 | imgurtranscriber 313 | imirror_bot 314 | isitupbot 315 | jerkbot-3hunna 316 | keysteal_bot 317 | kittehcointipbot 318 | last_cakeday_bot 319 | linkfixerbot1 320 | linkfixerbot2 321 | linkfixerbot3 322 | loser_detector_bot 323 | luckoftheshibe 324 | makesTextSmall 325 | malen-shutup-bot 326 | 
matthewrobo 327 | meme_transcriber 328 | memedad-transcriber 329 | misconception_fixer 330 | mma_gif_bot 331 | moderator-bot 332 | nba_gif_bot 333 | new_eden_news_bot 334 | nhl_gif_bot 335 | not_alot_bot 336 | notoverticalvideo 337 | nyantip 338 | okc_rating_bot 339 | pandatipbot 340 | pandatips 341 | potdealer 342 | provides-id 343 | qznc_bot 344 | rSGSpolice 345 | r_PictureGame 346 | raddit-bot 347 | randnumbot 348 | rarchives 349 | readsmalltextbot 350 | redditbots 351 | redditreviewbot 352 | redditreviewbot 353 | reddtipbot 354 | relevantxkcd-bot 355 | request_bot 356 | rhiever-bot 357 | rightsbot 358 | rnfl_robot 359 | roger_bot 360 | rss_feed 361 | rubycointipbot 362 | rule_bot 363 | rusetipbot 364 | sentimentviewbot 365 | serendipitybot 366 | shadowbanbot 367 | slapbot 368 | slickwom-bot 369 | snapshot_bot 370 | soccer_gif_bot 371 | softwareswap_bot 372 | sports_gif_bot 373 | spursgifs_xposterbot 374 | stats-bot 375 | steam_bot 376 | subtext-bot 377 | synonym_flash 378 | tabledresser 379 | techobot 380 | tennis_gif_bot 381 | test_bot0x00 382 | tipmoonbot1 383 | tipmoonbot2 384 | tittietipbot 385 | topcoin_tip 386 | topredditbot 387 | totes_meta_bot 388 | ttumblrbots 389 | unitconvert 390 | valkyribot 391 | versebot 392 | vertcoinbot 393 | vertcointipbot 394 | wheres_the_karma_bot 395 | wooshbot 396 | xkcd_bot 397 | xkcd_number_bot 398 | xkcd_number_bot 399 | xkcd_number_bot 400 | xkcd_transcriber 401 | xkcdcomic_bot 402 | yes_it_is_weird 403 | yourebot 404 | # Added 32 bots on 2018-01-10 15:29:38 405 | WikiTextBot 406 | auto-xkcd37 407 | TheSwearBot 408 | opfeels 409 | anti-gif-bot 410 | Chick-fil-A_spellbot 411 | SideVoteBot 412 | autotldr 413 | Mentioned_Videos 414 | sneakpeekbot 415 | BOTS_RISE_UP 416 | robot_overloard 417 | imguralbumbot 418 | GoodMod_BadMod 419 | GoodBot_BadBot 420 | EyeBleachBot 421 | ClickableLinkBot 422 | gifv-bot 423 | stabbot 424 | PayRespects-Bot 425 | AreYouDeaf 426 | timee_bot 427 | HelperBot_ 428 | friendly-bot 429 | hug-bot 430 | _trailerbot_tester_ 431 | LinkReplyBot 432 | phonebatterylevelbot 433 | perrycohen 434 | ThisIsABotThatDoStuf 435 | PORTMANTEAU-BOT 436 | AnimalFactsBot 437 | # Added 21 bots on 2018-01-10 16:37:32 438 | theHelperdroid 439 | timestamp_bot 440 | ThisCatMightCheerYou 441 | LiveTwitchClips 442 | Sub_Corrector_Bot 443 | umnikos_bots 444 | table_it_bot 445 | jinx__bot 446 | MyRSSbot 447 | GitCommandBot 448 | image_linker_bot 449 | Subjunctive__Bot 450 | umadbrobot 451 | twinkiac 452 | tippr 453 | metric_robot 454 | BigLebowskiBot 455 | koja1234 456 | assume-gender-bot 457 | you_get_CMV_delta 458 | HieronymusBeta 459 | # Added 15 bots on 2018-01-17 14:48:24 460 | RedditSilverRobot 461 | HogwartsBot 462 | WhoaCorrection 463 | yogobot 464 | garlicbot 465 | here-have-some-sauce 466 | sukabot 467 | _youtubot_ 468 | Agrees_withyou 469 | RiskyClickerBot 470 | SmallSubBot 471 | shhbot 472 | Darnit_Bot 473 | headonbot_ 474 | DreamProcessor 475 | # Added 13 bots on 2018-02-21 19:01:38 476 | timezone_bot 477 | LimbRetrieval-Bot 478 | plsrespecttables 479 | garlictipsbot 480 | m32th4nks 481 | BotPaperScissors 482 | SubAutoCorrectBot 483 | test-bot23 484 | morejpeg_auto 485 | FatFingerHelperBot 486 | MaxImageBot 487 | HaikuBot9000 488 | DuplicatesBot 489 | # Added 29 bots on 2018-08-11 11:57:08 490 | ___alexa___ 491 | Confucius-Bot 492 | alternate-source-bot 493 | BananaFactBot 494 | agree-with-you 495 | CakeDayGIFt_Bot 496 | good-GHB_Bot 497 | BinaryNativeBot 498 | by-accident-bot 499 | resavr_bot 500 | Gyazo_Bot 501 | tupac_cares_bot 502 | 
dadjokes_bot 503 | TexasFactsBot 504 | tweettranscriberbot 505 | ultimatewikibot 506 | Burgandy_Bot 507 | Kn0ckKn0ckb0t 508 | as-opposed-to 509 | YTubeInfoBot 510 | oofed-bot 511 | WhyNotCollegeBoard 512 | bandalbumsong 513 | CommonMisspellingBot 514 | AlexaPlayBot 515 | ordinarybots 516 | xkcd_bot2000 517 | Bot_Metric 518 | PressFBot 519 | # Added 3 bots on 2018-08-12 23:54:05 520 | good-Human_Bot 521 | Banano_Tipbot 522 | no_string_bets 523 | # Added 15 bots on 2018-11-03 00:31:31 524 | thank_mr_skeltal_bot 525 | _whatbot_ 526 | ghost_of_dongerbot 527 | GoodBotBadAdmins 528 | drift_summary 529 | amp-is-watching-you 530 | JobsHelperBot 531 | vReddit_Player_Bot 532 | Shallyyy 533 | TitleToImageBot 534 | FunCicada 535 | RepliesNice 536 | vandjac 537 | moviescommentbot 538 | Link-Help-Bot 539 | # Added 14 bots on 2019-01-18 18:34:25 540 | EFTBot 541 | SlothFactsBot 542 | societybot 543 | HappyFriendlyBot 544 | icarebot 545 | cool-acronym-bot 546 | ackchyually_bot 547 | not_so_magic_8_ball 548 | MRS_0BAMA_GET_DOWN 549 | clichebot9000 550 | YoMommaJokeBot 551 | ThesaurizeThisBot 552 | multiplevideosbot 553 | botrickbateman 554 | -------------------------------------------------------------------------------- /thred/corpora/reddit/reddit_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from thred.util import fs 4 | 5 | 6 | class RedditBotHandler: 7 | 8 | def __init__(self): 9 | super(RedditBotHandler, self).__init__() 10 | 11 | self.bots = set() 12 | self.bots.add('[deleted]') 13 | with open(self.__get_bot_file(), 'r') as file: 14 | for username in file: 15 | if not username.startswith("#"): 16 | self.bots.add(username.strip().lower()) 17 | 18 | def __get_bot_file(self): 19 | return os.path.join(fs.get_current_dir(__file__), 'reddit_bots.txt') 20 | 21 | def is_bot(self, username): 22 | return username.lower() in self.bots 23 | -------------------------------------------------------------------------------- /thred/corpora/reddit/sanitizer.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import re 4 | 5 | from tqdm import tqdm 6 | 7 | from thred.util import fs 8 | 9 | 10 | def main(): 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-p', '--profanity_file', type=str, help='the profanity file') 15 | parser.add_argument('-f', '--data_file', type=str, required=True, help="the data file") 16 | 17 | params = parser.parse_args() 18 | 19 | if params.profanity_file is None: 20 | containing_dir, _, _ = fs.split3(os.path.abspath(__file__)) 21 | profanity_file = os.path.join(containing_dir, "profanity_words.txt") 22 | else: 23 | profanity_file = params.profanity_file 24 | 25 | one_word_profanities = set() 26 | multi_word_profanities = set() 27 | with codecs.getreader('utf-8')(open(profanity_file, 'rb')) as profanity_reader: 28 | for line in profanity_reader: 29 | line = line.strip() 30 | if not line: 31 | continue 32 | 33 | if ' ' in line: 34 | multi_word_profanities.add(line) 35 | else: 36 | one_word_profanities.add(line) 37 | 38 | print("Profanity words loaded ({} one words/{} multi words)".format(len(one_word_profanities), len(multi_word_profanities))) 39 | 40 | output_file = fs.replace_ext(params.data_file, 'filtered.txt') 41 | prof1, profN = 0, 0 42 | with codecs.getreader('utf-8')(open(params.data_file, 'rb')) as data_file, \ 43 | codecs.getwriter('utf-8')(open(output_file, 'wb')) as out_file: 44 | for line in 
tqdm(data_file): 45 | post = line.strip() 46 | utterances = post.split("\t") 47 | filtered = False 48 | for utterance in utterances: 49 | 50 | for word in utterance.split(): 51 | if word in one_word_profanities: 52 | filtered = True 53 | prof1 += 1 54 | break 55 | 56 | for profanity in multi_word_profanities: 57 | if profanity in utterance: 58 | filtered = True 59 | profN += 1 60 | break 61 | 62 | if not filtered: 63 | s, e = u"\U0001F100", u"\U0001F9FF" 64 | s2, e2 = u"\U00010000", u"\U0001342E" 65 | post = re.sub(r'[{}-{}{}-{}]'.format(s, e, s2, e2), '', post) 66 | post = re.sub(r"[\uFEFF\u2060\u2E18]", '', post) 67 | line = "\t".join([" ".join(w.strip() for w in u.split()) for u in post.split("\t")]) 68 | out_file.write(line + "\n") 69 | 70 | print("Filtered {} (One word {} / Multi word {})".format(prof1 + profN, prof1, profN)) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() -------------------------------------------------------------------------------- /thred/corpora/reddit/subreddit_whitelist.txt: -------------------------------------------------------------------------------- 1 | /r/askscience t5_2qm4e title 2 | /r/explainlikeimfive t5_2sokd title start$ELI5:,[ELI5],ELI5 -,ELI5-,ELI5 3 | /r/news t5_2qh3l 4 | /r/worldnews t5_2qh13 5 | /r/GetStudying t5_2tl44 6 | /r/LifeProTips t5_2s5oq title start$LPT:,LPT-,LPT -,LPT,LPT Request:,LPT REQUEST :,[LPT REQUEST] 7 | /r/GetMotivated t5_2rmfx title start,end$[Text],[Image],[Discussion],[Article],[Video],[Story],[Tool] 8 | /r/minimalism t5_2r0z9 title start,end$[arts],[lifestyle],[Academic],[fun],[meta] 9 | /r/sports t5_2qgzy title 10 | /r/movies t5_2qh3s 11 | /r/television t5_2qh6e 12 | /r/politics t5_2cneq 13 | /r/YouShouldKnow t5_2r94o title start$YSK,YSK:,YSK -, YSK-$You Should Know 14 | /r/AskReddit t5_2qh1i title 15 | /r/funny t5_2qh33 title 16 | /r/Jokes t5_2qh72 title 17 | /r/TalesFromRetail t5_2t2zt 18 | /r/nostalgia t5_2qnub title 19 | /r/bestof t5_2qh3v 20 | /r/todayilearned t5_2qqjc title start$TIL,TIL:,TIL -,TIL-$Today I learned 21 | /r/space t5_2qh87 22 | /r/trees t5_2r9vp title 23 | /r/food t5_2qh55 title start$[Homemade],[I Ate]$Homemade,I ate 24 | /r/canada t5_2qh68 25 | /r/toronto t5_2qi63 26 | /r/vancouver t5_2qhov 27 | /r/europe t5_2qh4j 28 | /r/worldpolitics t5_2qh9a 29 | /r/CanadaPolitics t5_2s4gt 30 | /r/worldevents t5_2riv9 31 | /r/announcements t5_2r0ij title 32 | /r/gadgets t5_2qgzt 33 | /r/education t5_2qhlm 34 | /r/business t5_2qgzg 35 | /r/CasualConversation t5_323oy title 36 | /r/AskScienceDiscussion t5_2vlah title 37 | /r/Showerthoughts t5_2szyo title 38 | /r/NoStupidQuestions t5_2w844 title 39 | /r/femalefashionadvice t5_2s8o5 title 40 | /r/malefashionadvice t5_2r65t 41 | /r/Fitness t5_2qhx4 title 42 | /r/Cooking t5_2qh7f title 43 | /r/solotravel t5_2rxxm title 44 | /r/DecidingToBeBetter t5_2tand title 45 | /r/productivity t5_2qh1k title 46 | /r/Anxiety t5_2qmij title 47 | /r/depression t5_2qqqf title 48 | /r/InsightfulQuestions t5_2smsq title 49 | /r/answers t5_2qkeh title 50 | /r/HealthyFood t5_2rhbm title 51 | /r/TrueAskReddit t5_2s91q title 52 | /r/askphilosophy t5_2sc5r title 53 | /r/AskEngineers t5_2sebk title 54 | /r/nutrition t5_2qoox title 55 | /r/vegetarian t5_2qm7x title 56 | /r/fitmeals t5_2sd23 title 57 | /r/running t5_2qlit title 58 | /r/bicycling t5_2qi0s title 59 | /r/C25K t5_2rgoq title 60 | /r/Advice t5_2qjdm title 61 | /r/Sneakers t5_2qrtt title 62 | /r/AskCulinary t5_2t82m title 63 | /r/cookingforbeginners t5_32u9b title 64 | /r/Coffee t5_2qhze title 65 | /r/tea t5_2qq5e title 66 | 
/r/Pizza t5_2qlhq title 67 | /r/BBQ t5_2qxww title 68 | /r/grilledcheese t5_2qyw8 title 69 | /r/cscareerquestions t5_2sdpm title 70 | /r/EngineeringStudents t5_2sh0b title 71 | /r/Entrepreneur t5_2qldo title 72 | /r/Economics t5_2qh1s 73 | /r/marketing t5_2qhmg title 74 | /r/smallbusiness t5_2qr34 title 75 | /r/environment t5_2qh1n 76 | /r/InternetIsBeautiful t5_2ul7u 77 | /r/AskScienceFiction t5_2slu2 78 | /r/books t5_2qh4i title 79 | /r/wikipedia t5_2qh3b 80 | /r/WikiLeaks t5_2qy11 81 | /r/Documentaries t5_2qhlh 82 | /r/marvelstudios t5_2uii8 83 | /r/Moviesinthemaking t5_2uiff 84 | /r/Guitar t5_2qi79 85 | /r/formula1 t5_2qimj 86 | /r/golf t5_2qhcs 87 | /r/skiing t5_2qig7 88 | /r/soccer t5_2qi58 89 | /r/hockey t5_2qiel 90 | /r/StrangerThings t5_3adlm 91 | /r/breakingbad t5_2rlw4 92 | /r/Accounting t5_2qw2b 93 | /r/cars t5_2qhl2 94 | /r/UpliftingNews t5_2u3ta 95 | /r/changemyview t5_2w2s8 title start$CMV: -------------------------------------------------------------------------------- /thred/main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import logging 3 | 4 | from .models import model_factory 5 | from .util.config import Config 6 | 7 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 8 | datefmt='%m/%d/%Y %H:%M:%S', 9 | level=logging.INFO) 10 | 11 | 12 | def main(_): 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--mode', type=str, required=True, choices=('train', 'interactive', 'test'), help='work mode') 17 | parser.add_argument('--model_dir', type=str, required=True, help='model directory') 18 | 19 | parser.add_argument('--config', type=str, help='config file containing parameters to configure the model') 20 | 21 | parser.add_argument('--train_data', type=str, help='training dataset') 22 | parser.add_argument('--dev_data', type=str, help='development dataset') 23 | parser.add_argument('--test_data', type=str, help='test dataset') 24 | 25 | parser.add_argument('--embed_conf', type=str, default="conf/word_embeddings.yml", help='embedding config file') 26 | parser.add_argument('--restart_training', action='store_true', help='remove saved models and logs in the model directory to start training from scratch') 27 | parser.add_argument('--eval_best_model', action='store_true', help='whether to evaluate the best model after training finishes') 28 | 29 | parser.add_argument('--num_gpus', type=int, default=4, help='number of GPUs to use') 30 | parser.add_argument('--n_responses', type=int, default=1, help='number of generated responses') 31 | parser.add_argument('--beam_width', type=int, help='beam width to override the value in config file') 32 | parser.add_argument('--length_penalty_weight', type=float, 33 | help='length penalty to override the value in config file') 34 | parser.add_argument('--sampling_temperature', type=float, 35 | help='sampling temperature to override the value in config file') 36 | parser.add_argument('--lda_model_dir', type=str, help='required only for testing with topical models (THRED and TA-Seq2Seq)') 37 | 38 | args = vars(parser.parse_args()) 39 | config = Config(**args) 40 | 41 | model = model_factory.create_model(config) 42 | 43 | if config.mode == 'train': 44 | model.train() 45 | elif config.mode == 'interactive': 46 | model.interactive() 47 | elif config.mode == 'test': 48 | model.test() 49 | 50 | 51 | if __name__ == "__main__": 52 | tf.app.run() 53 |
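The entry point above can also be driven programmatically. A minimal sketch, assuming `Config` accepts the argparse keys as plain keyword arguments (as `Config(**args)` in `main()` suggests) and tolerates omitted optional flags; all paths below are hypothetical placeholders:

```
from thred.models import model_factory
from thred.util.config import Config

# Equivalent of `python -m thred --mode test ...`; the keys mirror the
# argparse flags in thred/main.py above. Paths are hypothetical placeholders.
config = Config(mode="test",
                model_dir="checkpoints/thred_medium",
                test_data="data/reddit_test.txt",
                num_gpus=1,
                n_responses=1,
                beam_width=5,               # overrides the config-file value
                lda_model_dir="lda/model")  # required for topical models
model = model_factory.create_model(config)
model.test()
```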
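As an aside on the subreddit_whitelist.txt entries shown above (just before thred/main.py): each line appears to carry a subreddit path, its t5 fullname, optional field names such as `title`, and an optional `$`-separated directive listing title patterns and, in some entries, replacement strings. A minimal parsing sketch; the field semantics are inferred from the entries themselves and are not documented in the repository:

```
def parse_whitelist_entry(line):
    # Split off the "$"-directive first: its comma-separated patterns may
    # contain spaces (e.g. "TIL -"), so a plain whitespace split would break.
    head, _, directive = line.strip().partition("$")
    tokens = head.split()
    entry = {"subreddit": tokens[0], "fullname": tokens[1], "fields": tokens[2:]}
    if directive:
        patterns, _, replacements = directive.partition("$")
        entry["patterns"] = patterns.split(",")
        entry["replacements"] = replacements.split(",") if replacements else []
    return entry

# e.g. parse_whitelist_entry(
#     "/r/todayilearned t5_2qqjc title start$TIL,TIL:,TIL -,TIL-$Today I learned")
# -> {'subreddit': '/r/todayilearned', 'fullname': 't5_2qqjc',
#     'fields': ['title', 'start'], 'patterns': ['TIL', 'TIL:', 'TIL -', 'TIL-'],
#     'replacements': ['Today I learned']}
```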
-------------------------------------------------------------------------------- /thred/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/models/__init__.py -------------------------------------------------------------------------------- /thred/models/attention_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def is_attention_enabled(attention_option): 5 | return attention_option is not None and \ 6 | attention_option in ("luong", "scaled_luong", "bahdanau", "normed_bahdanau") 7 | 8 | 9 | def create_attention_mechanism(attention_option, num_units, memory, memory_length): 10 | """Create attention mechanism based on the attention_option.""" 11 | 12 | # Mechanism 13 | if attention_option == "luong": 14 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 15 | num_units, memory, memory_sequence_length=memory_length) 16 | elif attention_option == "scaled_luong": 17 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 18 | num_units, 19 | memory, 20 | memory_sequence_length=memory_length, 21 | scale=True) 22 | elif attention_option == "bahdanau": 23 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 24 | num_units, memory, memory_sequence_length=memory_length) 25 | elif attention_option == "normed_bahdanau": 26 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 27 | num_units, 28 | memory, 29 | memory_sequence_length=memory_length, 30 | normalize=True) 31 | else: 32 | raise ValueError("Unknown attention option %s" % attention_option) 33 | 34 | return attention_mechanism -------------------------------------------------------------------------------- /thred/models/hred/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/models/hred/__init__.py -------------------------------------------------------------------------------- /thred/models/hred/hred_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from thred.models.model_helper import TrainModel, EvalModel, InferModel 4 | from thred.models.hred import hred_iterators 5 | from thred.models.hred.hred_model import HierarchichalSeq2SeqModel 6 | from thred.util import vocab 7 | 8 | 9 | def create_train_model(hparams, scope=None, num_workers=1, jobid=0): 10 | """Create train graph, model, and iterator.""" 11 | 12 | graph = tf.Graph() 13 | 14 | vocab.create_vocabulary(hparams.vocab_file, hparams.train_data, hparams.vocab_size) 15 | 16 | with graph.as_default(), tf.container(scope or "train"): 17 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 18 | 19 | dataset = tf.data.TextLineDataset(hparams.train_data) 20 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) 21 | 22 | iterator = hred_iterators.get_iterator( 23 | dataset, 24 | vocab_table, 25 | hparams.batch_size, 26 | hparams.num_turns, 27 | hparams.num_buckets, 28 | hparams.src_max_len, 29 | hparams.tgt_max_len, 30 | skip_count=skip_count_placeholder, 31 | num_shards=num_workers, 32 | shard_index=jobid) 33 | 34 | # Note: One can set model_device_fn to 35 | # `tf.train.replica_device_setter(ps_tasks)` for distributed training. 
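# (annotation, not in the original source) For example,
#   model_device_fn = tf.train.replica_device_setter(ps_tasks=2)
# would pin variables to parameter-server tasks in a distributed TF 1.x
# cluster; the code below leaves it as None so placement stays on the
# default device.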
36 | model_device_fn = None 37 | # if extra_args: model_device_fn = extra_args.model_device_fn 38 | with tf.device(model_device_fn): 39 | model = HierarchichalSeq2SeqModel( 40 | mode=tf.contrib.learn.ModeKeys.TRAIN, 41 | iterator=iterator, 42 | num_turns=hparams.num_turns, 43 | params=hparams, 44 | scope=scope) 45 | 46 | return TrainModel(graph=graph, 47 | model=model, 48 | iterator=iterator, 49 | skip_count_placeholder=skip_count_placeholder) 50 | 51 | 52 | def create_pretrain_model(hparams, scope=None, num_workers=1, jobid=0): 53 | """Create train graph, model, and iterator.""" 54 | graph = tf.Graph() 55 | 56 | with graph.as_default(), tf.container(scope or "pretrain"): 57 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 58 | 59 | iterator = hred_iterators.get_iterator( 60 | hparams.pretrain_data, 61 | vocab_table, 62 | hparams.batch_size, 63 | hparams.num_pretrain_turns, 64 | hparams.num_buckets, 65 | hparams.src_max_len, 66 | hparams.tgt_max_len, 67 | num_shards=num_workers, 68 | shard_index=jobid) 69 | 70 | model = HierarchichalSeq2SeqModel(mode=tf.contrib.learn.ModeKeys.TRAIN, 71 | iterator=iterator, 72 | num_turns=hparams.num_pretrain_turns, 73 | params=hparams, 74 | scope=scope, 75 | log_trainables=False) 76 | 77 | return TrainModel( 78 | graph=graph, 79 | model=model, 80 | iterator=iterator, 81 | skip_count_placeholder=None) 82 | 83 | 84 | def create_eval_model(hparams, scope=None): 85 | """Create train graph, model, src/tgt file holders, and iterator.""" 86 | 87 | graph = tf.Graph() 88 | 89 | with graph.as_default(), tf.container(scope or "eval"): 90 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 91 | eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) 92 | 93 | eval_dataset = tf.data.TextLineDataset(eval_file_placeholder) 94 | iterator = hred_iterators.get_iterator( 95 | eval_dataset, 96 | vocab_table, 97 | hparams.batch_size, 98 | hparams.num_turns, 99 | hparams.num_buckets, 100 | hparams.src_max_len, 101 | hparams.tgt_max_len) 102 | 103 | model = HierarchichalSeq2SeqModel(mode=tf.contrib.learn.ModeKeys.EVAL, 104 | iterator=iterator, 105 | num_turns=hparams.num_turns, 106 | params=hparams, 107 | scope=scope, 108 | log_trainables=False) 109 | return EvalModel( 110 | graph=graph, 111 | model=model, 112 | eval_file_placeholder=eval_file_placeholder, 113 | iterator=iterator) 114 | 115 | 116 | def create_infer_model(hparams, scope=None): 117 | """Create inference model.""" 118 | graph = tf.Graph() 119 | 120 | with graph.as_default(), tf.container(scope or "infer"): 121 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 122 | reverse_vocab_table = vocab.create_rev_vocab_table(hparams.vocab_file) 123 | 124 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) 125 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) 126 | 127 | src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder) 128 | 129 | iterator = hred_iterators.get_infer_iterator( 130 | src_dataset, 131 | vocab_table, 132 | batch_size=batch_size_placeholder, 133 | num_turns=hparams.num_turns, 134 | src_max_len=hparams.src_max_len) 135 | 136 | model = HierarchichalSeq2SeqModel(mode=tf.contrib.learn.ModeKeys.INFER, 137 | iterator=iterator, 138 | num_turns=hparams.num_turns, 139 | params=hparams, 140 | rev_vocab_table=reverse_vocab_table, 141 | scope=scope) 142 | return InferModel( 143 | graph=graph, 144 | model=model, 145 | src_placeholder=src_placeholder, 146 | batch_size_placeholder=batch_size_placeholder, 147 | iterator=iterator) 148 | 
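Taken together, these helpers follow the TF-NMT pattern: each mode gets its own graph, iterator, and model, and the training data is consumed as one dialogue per line with tab-separated utterances, the last being the target (see `_tokenize_lambda` in hred_iterators below). A minimal training-loop sketch, assuming an `hparams` object carrying the fields referenced above (including a `model_dir`) and a per-step `model.train(sess)` call in the NMT style; the step API is an assumption, not verified here:

```
import tensorflow as tf

from thred.models import model_helper
from thred.models.hred import hred_helper


def run_one_epoch(hparams):
    """Sketch: one pass of HRED training over hparams.train_data."""
    train_model = hred_helper.create_train_model(hparams)
    with tf.Session(graph=train_model.graph,
                    config=model_helper.get_config_proto(False)) as sess:
        model, _ = model_helper.create_or_load_model(
            train_model.model, hparams.model_dir, sess, "train")
        sess.run(train_model.iterator.initializer,
                 feed_dict={train_model.skip_count_placeholder: 0})
        while True:
            try:
                model.train(sess)  # hypothetical per-step API
            except tf.errors.OutOfRangeError:
                break  # dataset exhausted: the epoch is done
```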
-------------------------------------------------------------------------------- /thred/models/hred/hred_iterators.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from thred.models.model_helper import BatchedInput 4 | from thred.util import vocab 5 | 6 | 7 | def get_iterator(dataset, 8 | vocab_table, 9 | batch_size, 10 | num_turns, 11 | num_buckets, 12 | src_max_len=None, 13 | tgt_max_len=None, 14 | random_seed=None, 15 | num_parallel_calls=4, 16 | output_buffer_size=None, 17 | skip_count=None, 18 | num_shards=1, 19 | shard_index=0): 20 | num_inputs = num_turns - 1 21 | 22 | if not output_buffer_size: 23 | output_buffer_size = batch_size * 1000 24 | 25 | eos_id = tf.constant(vocab.EOS_ID, dtype=tf.int32) 26 | sos_id = tf.constant(vocab.SOS_ID, dtype=tf.int32) 27 | 28 | src_tgt_dataset = dataset.shard(num_shards, shard_index) 29 | if skip_count is not None: 30 | src_tgt_dataset = src_tgt_dataset.skip(skip_count) 31 | 32 | src_tgt_dataset = src_tgt_dataset.shuffle(output_buffer_size, random_seed) 33 | 34 | def _tokenize_lambda(line): 35 | utterances = tf.string_split([line], delimiter="\t").values 36 | 37 | srcs = [tf.string_split([utterances[t]]).values for t in range(num_inputs)] 38 | tgt = tf.string_split([utterances[num_inputs]]).values 39 | 40 | tokenized_data = { 41 | 'tgt': tgt[:tgt_max_len] if tgt_max_len else tgt 42 | } 43 | 44 | for t in range(num_inputs): 45 | tokenized_data['src_%d' % t] = srcs[t][:src_max_len] if src_max_len else srcs[t] 46 | 47 | return tokenized_data 48 | 49 | src_tgt_dataset = src_tgt_dataset.map( 50 | _tokenize_lambda, 51 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 52 | 53 | def _lookup_lambda(data): 54 | tgt = tf.cast(vocab_table.lookup(data['tgt']), tf.int32) 55 | tgt_out = tf.concat((tgt, [eos_id]), 0) 56 | mapped_data = { 57 | 'tgt_in': tf.concat(([sos_id], tgt), 0), 58 | 'tgt_out': tgt_out, 59 | 'tgt_len': tf.size(tgt_out) 60 | } 61 | 62 | for t in range(num_inputs): 63 | src = tf.cast(vocab_table.lookup(data['src_%d' % t]), tf.int32) 64 | mapped_data['src_%d' % t] = src 65 | mapped_data['src_len_%d' % t] = tf.size(src) 66 | 67 | return mapped_data 68 | 69 | src_tgt_dataset = src_tgt_dataset.map( 70 | _lookup_lambda, 71 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 72 | # Create a tgt_input prefixed with and a tgt_output suffixed with . 73 | 74 | # Add in sequence lengths. 75 | # src_tgt_dataset = src_tgt_dataset.map( 76 | # lambda srcs, tgt_in, tgt_out: ( 77 | # srcs, tgt_in, tgt_out, 78 | # [tf.size(srcs[t]) for t in range(num_inputs)], tf.size(tgt_in)), 79 | # num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 80 | 81 | padded_shapes = {'tgt_in': tf.TensorShape([None]), 82 | 'tgt_out': tf.TensorShape([None]), 83 | 'tgt_len': tf.TensorShape([])} 84 | 85 | padded_values = {'tgt_in': eos_id, 86 | 'tgt_out': eos_id, 87 | 'tgt_len': 0} 88 | 89 | for t in range(num_inputs): 90 | padded_shapes['src_%d' % t] = tf.TensorShape([None]) 91 | padded_values['src_%d' % t] = eos_id 92 | padded_shapes['src_len_%d' % t] = tf.TensorShape([]) 93 | padded_values['src_len_%d' % t] = 0 94 | 95 | def _batching_lambda(x): 96 | return x.padded_batch( 97 | batch_size, 98 | # The first three entries are the source and target line rows; 99 | # these have unknown-length vectors. The last two entries are 100 | # the source and target row sizes; these are scalars. 
101 | padded_shapes=padded_shapes, 102 | # Pad the source and target sequences with eos tokens. 103 | # (Though notice we don't generally need to do this since 104 | # later on we will be masking out calculations past the true sequence. 105 | padding_values=padded_values) 106 | 107 | if num_buckets > 1: 108 | def key_func(data): 109 | # Calculate bucket_width by maximum source sequence length. 110 | # Pairs with length [0, bucket_width) go to bucket 0, length 111 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length 112 | # over ((num_bucket-1) * bucket_width) words all go into the last bucket. 113 | if src_max_len: 114 | bucket_width = (src_max_len + num_buckets - 1) // num_buckets 115 | else: 116 | bucket_width = 10 117 | 118 | # Bucket sentence pairs by the length of their source sentence and target 119 | # sentence. 120 | 121 | bucket_id = data['tgt_len'] // bucket_width 122 | for t in range(num_inputs): 123 | bucket_id = tf.maximum(data['src_len_%d' % t] // bucket_width, bucket_id) 124 | 125 | # bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width) 126 | return tf.to_int64(tf.minimum(num_buckets, bucket_id)) 127 | 128 | def reduce_func(unused_key, windowed_data): 129 | return _batching_lambda(windowed_data) 130 | 131 | batched_dataset = src_tgt_dataset.apply( 132 | tf.contrib.data.group_by_window( 133 | key_func=key_func, reduce_func=reduce_func, window_size=batch_size)) 134 | 135 | else: 136 | batched_dataset = _batching_lambda(src_tgt_dataset) 137 | 138 | batched_iter = batched_dataset.make_initializable_iterator() 139 | batched_data = batched_iter.get_next() 140 | 141 | return BatchedInput( 142 | initializer=batched_iter.initializer, 143 | sources=[batched_data['src_%d' % t] for t in range(num_inputs)], 144 | target_input=batched_data['tgt_in'], 145 | target_output=batched_data['tgt_out'], 146 | source_sequence_lengths=[batched_data['src_len_%d' % t] for t in range(num_inputs)], 147 | target_sequence_length=batched_data['tgt_len']) 148 | 149 | 150 | def get_infer_iterator(test_dataset, 151 | vocab_table, 152 | batch_size, 153 | num_turns, 154 | src_max_len=None): 155 | num_inputs = num_turns - 1 156 | 157 | eos_id = tf.constant(vocab.EOS_ID, dtype=tf.int32) 158 | 159 | def _parse_lambda(line): 160 | utterances = tf.string_split([line], delimiter="\t").values 161 | srcs = [tf.string_split([utterances[t]]).values for t in range(num_inputs)] 162 | 163 | parsed_data = {} 164 | 165 | for t in range(num_inputs): 166 | src = srcs[t][:src_max_len] if src_max_len else srcs[t] 167 | src = tf.cast(vocab_table.lookup(src), tf.int32) 168 | parsed_data['src_%d' % t] = src 169 | parsed_data['src_len_%d' % t] = tf.size(src) 170 | 171 | return parsed_data 172 | 173 | test_dataset = test_dataset.map(_parse_lambda) 174 | 175 | padded_shapes = {} 176 | padded_values = {} 177 | 178 | for t in range(num_inputs): 179 | padded_shapes['src_%d' % t] = tf.TensorShape([None]) 180 | padded_values['src_%d' % t] = eos_id 181 | padded_shapes['src_len_%d' % t] = tf.TensorShape([]) 182 | padded_values['src_len_%d' % t] = 0 183 | 184 | def batching_func(x): 185 | return x.padded_batch( 186 | batch_size, 187 | # The entry is the source line rows; 188 | # this has unknown-length vectors. The last entry is 189 | # the source row size; this is a scalar. 190 | padded_shapes=padded_shapes, 191 | # Pad the source sequences with eos tokens. 192 | # (Though notice we don't generally need to do this since 193 | # later on we will be masking out calculations past the true sequence. 
194 | padding_values=padded_values) 195 | 196 | batched_dataset = batching_func(test_dataset) 197 | batched_iter = batched_dataset.make_initializable_iterator() 198 | 199 | batched_data = batched_iter.get_next() 200 | 201 | return BatchedInput( 202 | initializer=batched_iter.initializer, 203 | sources=[batched_data['src_%d' % t] for t in range(num_inputs)], 204 | target_input=None, 205 | target_output=None, 206 | source_sequence_lengths=[batched_data['src_len_%d' % t] for t in range(num_inputs)], 207 | target_sequence_length=None) 208 | -------------------------------------------------------------------------------- /thred/models/hred/hred_wrapper.py: -------------------------------------------------------------------------------- 1 | from ..hierarchical_base import BaseHierarchicalEncoderDecoder 2 | from . import hred_helper 3 | 4 | 5 | class HierarchicalEncoderDecoder(BaseHierarchicalEncoderDecoder): 6 | def __init__(self, config): 7 | super(HierarchicalEncoderDecoder, self).__init__(config) 8 | 9 | def _get_model_helper(self): 10 | return hred_helper 11 | 12 | def _get_checkpoint_name(self): 13 | return "hred" 14 | -------------------------------------------------------------------------------- /thred/models/model_factory.py: -------------------------------------------------------------------------------- 1 | from .hred import hred_wrapper 2 | from .topic_aware import taware_wrapper 3 | from .thred import thred_wrapper 4 | from .vanilla import vanilla_wrapper 5 | 6 | 7 | def create_model(config): 8 | if config.type == 'vanilla': 9 | return vanilla_wrapper.VanillaNMTEncoderDecoder(config) 10 | elif config.type == 'hred': 11 | return hred_wrapper.HierarchicalEncoderDecoder(config) 12 | elif config.type == 'topic_aware': 13 | return taware_wrapper.TopicAwareNMTEncoderDecoder(config) 14 | elif config.type == 'thred': 15 | return thred_wrapper.TopicalHierarchicalEncoderDecoder(config) 16 | 17 | raise ValueError('unknown model: ' + config.type) 18 | -------------------------------------------------------------------------------- /thred/models/model_helper.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import time 3 | 4 | from tqdm import tqdm, trange 5 | import tensorflow as tf 6 | 7 | from ..util import misc 8 | from ..util import log 9 | 10 | 11 | class TopicalBatchedInput( 12 | collections.namedtuple("BatchedInput", 13 | ("initializer", "sources", 14 | "topic", "target_input", "target_output", 15 | "source_sequence_lengths", 16 | "topic_sequence_length", "target_sequence_length"))): 17 | pass 18 | 19 | 20 | class BatchedInput( 21 | collections.namedtuple("BatchedInput", 22 | ("initializer", "sources", "target_input", 23 | "target_output", "source_sequence_lengths", 24 | "target_sequence_length"))): 25 | pass 26 | 27 | 28 | class TrainModel( 29 | collections.namedtuple("TrainModel", ("graph", "model", "iterator", 30 | "skip_count_placeholder"))): 31 | pass 32 | 33 | 34 | class EvalModel( 35 | collections.namedtuple("EvalModel", 36 | ("graph", "model", 37 | "eval_file_placeholder", "iterator"))): 38 | pass 39 | 40 | 41 | class InferModel( 42 | collections.namedtuple("InferModel", 43 | ("graph", "model", "src_placeholder", 44 | "batch_size_placeholder", "iterator"))): 45 | pass 46 | 47 | 48 | def get_config_proto(log_device_placement): 49 | return tf.ConfigProto( 50 | log_device_placement=log_device_placement, 51 | allow_soft_placement=True, 52 | gpu_options=tf.GPUOptions(allow_growth=True)) 53 | 54 | 55 | def load_model(model, 
ckpt, session, name): 56 | start_time = time.time() 57 | model.saver.restore(session, ckpt) 58 | session.run(tf.tables_initializer()) 59 | log.print_out( 60 | " loaded %s model parameters from %s, time %.2fs" % 61 | (name, ckpt, time.time() - start_time)) 62 | return model 63 | 64 | 65 | def create_or_load_model(model, model_dir, session, name): 66 | """Create a model and initialize or load parameters in session.""" 67 | 68 | latest_ckpt = tf.train.latest_checkpoint(model_dir) 69 | if latest_ckpt: 70 | model = load_model(model, latest_ckpt, session, name) 71 | else: 72 | start_time = time.time() 73 | session.run(tf.global_variables_initializer()) 74 | session.run(tf.tables_initializer()) 75 | log.print_out(" created %s model with fresh parameters, time %.2fs" % 76 | (name, time.time() - start_time)) 77 | 78 | global_step = model.global_step.eval(session=session) 79 | return model, global_step 80 | 81 | 82 | def compute_perplexity(model, sess, name, data_size=None): 83 | """Compute perplexity of the output of the model. 84 | Args: 85 | model: model for compute perplexity. 86 | sess: tensorflow session to use. 87 | name: name of the batch. 88 | data_size: data size for showing the progress bar 89 | Returns: 90 | The perplexity of the eval outputs. 91 | """ 92 | total_loss = 0 93 | total_predict_count = 0 94 | start_time = time.time() 95 | step = 0 96 | 97 | if data_size: 98 | pbar = trange(data_size, desc=name) 99 | else: 100 | pbar = tqdm(desc=name) 101 | 102 | pbar.set_postfix(loss='inf', dev_ppl='inf') 103 | update_progress_every = 10 104 | 105 | while True: 106 | try: 107 | loss, predict_count, batch_size = model.eval(sess) 108 | total_loss += loss * batch_size 109 | total_predict_count += predict_count 110 | step += 1 111 | if step % update_progress_every == 0: 112 | ls = total_loss / total_predict_count 113 | ppl = misc.safe_exp(ls) 114 | pbar.set_postfix(loss=ls, dev_ppl="{:.3f}".format(ppl)) 115 | pbar.update(update_progress_every) 116 | except tf.errors.OutOfRangeError: 117 | break 118 | 119 | if data_size: 120 | pbar.n = data_size 121 | pbar.refresh() 122 | pbar.close() 123 | 124 | perplexity = misc.safe_exp(total_loss / total_predict_count) 125 | log.print_time("\n eval %s: perplexity %.2f" % (name, perplexity), start_time) 126 | return perplexity 127 | -------------------------------------------------------------------------------- /thred/models/ncm_utils.py: -------------------------------------------------------------------------------- 1 | from ..util import vocab 2 | 3 | 4 | def get_translation(ncm_outputs, sent_id): 5 | """Given batch decoding outputs, select a sentence and turn to text.""" 6 | # Select a sentence 7 | output = ncm_outputs[sent_id, :].tolist() if len(ncm_outputs.shape) > 1 else ncm_outputs.tolist() 8 | 9 | eos = vocab.EOS.encode("utf-8") 10 | 11 | # If there is an eos symbol in outputs, cut them at that point. 
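# (annotation, not in the original source) Worked example, assuming the EOS
# symbol encodes to b'</s>': [b'how', b'are', b'you', b'</s>', b'</s>'] is
# cut at the first b'</s>' and joined below into b'how are you'.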
12 | if eos in output: 13 | output = output[:output.index(eos)] 14 | 15 | return b" ".join(output) 16 | -------------------------------------------------------------------------------- /thred/models/thred/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/models/thred/__init__.py -------------------------------------------------------------------------------- /thred/models/thred/thred_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from ..model_helper import TrainModel, EvalModel, InferModel 4 | from .thred_iterators import get_iterator, get_infer_iterator 5 | from .thred_model import TopicAwareHierarchicalSeq2SeqModel 6 | from thred.util import vocab 7 | 8 | 9 | def create_train_model(hparams, scope=None, num_workers=1, jobid=0): 10 | """Create train graph, model, and iterator.""" 11 | 12 | graph = tf.Graph() 13 | 14 | vocab.create_vocabulary(hparams.vocab_file, hparams.train_data, hparams.vocab_size) 15 | 16 | with graph.as_default(), tf.container(scope or "train"): 17 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 18 | 19 | dataset = tf.data.TextLineDataset(hparams.train_data) 20 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) 21 | 22 | iterator = get_iterator( 23 | dataset, 24 | vocab_table, 25 | hparams.batch_size, 26 | hparams.num_turns, 27 | hparams.num_buckets, 28 | hparams.topic_words_per_utterance, 29 | hparams.src_max_len, 30 | hparams.tgt_max_len, 31 | skip_count=skip_count_placeholder, 32 | num_shards=num_workers, 33 | shard_index=jobid) 34 | 35 | # Note: One can set model_device_fn to 36 | # `tf.train.replica_device_setter(ps_tasks)` for distributed training. 
37 | model_device_fn = None 38 | # if extra_args: model_device_fn = extra_args.model_device_fn 39 | with tf.device(model_device_fn): 40 | model = TopicAwareHierarchicalSeq2SeqModel( 41 | mode=tf.contrib.learn.ModeKeys.TRAIN, 42 | iterator=iterator, 43 | num_turns=hparams.num_turns, 44 | params=hparams, 45 | scope=scope) 46 | 47 | return TrainModel(graph=graph, 48 | model=model, 49 | iterator=iterator, 50 | skip_count_placeholder=skip_count_placeholder) 51 | 52 | 53 | def create_pretrain_model(hparams, scope=None, num_workers=1, jobid=0): 54 | """Create train graph, model, and iterator.""" 55 | graph = tf.Graph() 56 | 57 | with graph.as_default(), tf.container(scope or "pretrain"): 58 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 59 | 60 | iterator = get_iterator( 61 | hparams.pretrain_data, 62 | vocab_table, 63 | hparams.batch_size, 64 | hparams.num_pretrain_turns, 65 | hparams.num_buckets, 66 | hparams.topic_words_per_utterance, 67 | hparams.src_max_len, 68 | hparams.tgt_max_len, 69 | num_shards=num_workers, 70 | shard_index=jobid) 71 | 72 | model = TopicAwareHierarchicalSeq2SeqModel( 73 | mode=tf.contrib.learn.ModeKeys.TRAIN, 74 | iterator=iterator, 75 | num_turns=hparams.num_pretrain_turns, 76 | params=hparams, 77 | scope=scope, 78 | log_trainables=False) 79 | 80 | return TrainModel( 81 | graph=graph, 82 | model=model, 83 | iterator=iterator, 84 | skip_count_placeholder=None) 85 | 86 | 87 | def create_eval_model(hparams, scope=None): 88 | """Create train graph, model, src/tgt file holders, and iterator.""" 89 | 90 | graph = tf.Graph() 91 | 92 | with graph.as_default(), tf.container(scope or "eval"): 93 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 94 | eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) 95 | 96 | eval_dataset = tf.data.TextLineDataset(eval_file_placeholder) 97 | iterator = get_iterator( 98 | eval_dataset, 99 | vocab_table, 100 | hparams.batch_size, 101 | hparams.num_turns, 102 | hparams.num_buckets, 103 | hparams.topic_words_per_utterance, 104 | hparams.src_max_len, 105 | hparams.tgt_max_len) 106 | 107 | model = TopicAwareHierarchicalSeq2SeqModel( 108 | mode=tf.contrib.learn.ModeKeys.EVAL, 109 | iterator=iterator, 110 | num_turns=hparams.num_turns, 111 | params=hparams, 112 | scope=scope, 113 | log_trainables=False) 114 | 115 | return EvalModel( 116 | graph=graph, 117 | model=model, 118 | eval_file_placeholder=eval_file_placeholder, 119 | iterator=iterator) 120 | 121 | 122 | def create_infer_model(hparams, scope=None): 123 | """Create inference model.""" 124 | graph = tf.Graph() 125 | 126 | with graph.as_default(), tf.container(scope or "infer"): 127 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 128 | reverse_vocab_table = vocab.create_rev_vocab_table(hparams.vocab_file) 129 | 130 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) 131 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) 132 | 133 | src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder) 134 | 135 | iterator = get_infer_iterator( 136 | src_dataset, 137 | vocab_table, 138 | batch_size=batch_size_placeholder, 139 | num_turns=hparams.num_turns, 140 | topic_words_per_utterance=hparams.topic_words_per_utterance, 141 | src_max_len=hparams.src_max_len) 142 | 143 | model = TopicAwareHierarchicalSeq2SeqModel( 144 | mode=tf.contrib.learn.ModeKeys.INFER, 145 | iterator=iterator, 146 | num_turns=hparams.num_turns, 147 | params=hparams, 148 | rev_vocab_table=reverse_vocab_table, 149 | scope=scope, 150 | 
log_trainables=False) 151 | 152 | return InferModel( 153 | graph=graph, 154 | model=model, 155 | src_placeholder=src_placeholder, 156 | batch_size_placeholder=batch_size_placeholder, 157 | iterator=iterator) 158 | -------------------------------------------------------------------------------- /thred/models/thred/thred_iterators.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from ..model_helper import TopicalBatchedInput as BatchedInput 4 | from thred.util import vocab 5 | 6 | 7 | def get_iterator(dataset, 8 | vocab_table, 9 | batch_size, 10 | num_turns, 11 | num_buckets, 12 | topic_words_per_utterance=None, 13 | src_max_len=None, 14 | tgt_max_len=None, 15 | random_seed=None, 16 | num_parallel_calls=4, 17 | output_buffer_size=None, 18 | skip_count=None, 19 | num_shards=1, 20 | shard_index=0): 21 | num_inputs = num_turns - 1 22 | 23 | if not output_buffer_size: 24 | output_buffer_size = batch_size * 1000 25 | 26 | eos_id = tf.constant(vocab.EOS_ID, dtype=tf.int32) 27 | sos_id = tf.constant(vocab.SOS_ID, dtype=tf.int32) 28 | 29 | src_tgt_dataset = dataset.shard(num_shards, shard_index) 30 | if skip_count is not None: 31 | src_tgt_dataset = src_tgt_dataset.skip(skip_count) 32 | 33 | src_tgt_dataset = src_tgt_dataset.shuffle(output_buffer_size, random_seed) 34 | 35 | def _tokenize_lambda(line): 36 | delimited_line = tf.string_split([line], delimiter="\t").values 37 | srcs = [tf.string_split([delimited_line[t]]).values for t in range(num_inputs)] 38 | tgt = tf.string_split([delimited_line[num_inputs]]).values 39 | topic = tf.string_split([delimited_line[-1]]).values 40 | 41 | tokenized_data = { 42 | 'tgt': tgt[:tgt_max_len] if tgt_max_len else tgt, 43 | 'topic': topic[:topic_words_per_utterance] if topic_words_per_utterance else topic, 44 | } 45 | 46 | for t in range(num_inputs): 47 | tokenized_data['src_%d' % t] = srcs[t][:src_max_len] if src_max_len else srcs[t] 48 | 49 | return tokenized_data 50 | 51 | src_tgt_dataset = src_tgt_dataset.map( 52 | _tokenize_lambda, 53 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 54 | 55 | def _lookup_lambda(data): 56 | tgt = tf.cast(vocab_table.lookup(data['tgt']), tf.int32) 57 | tgt_out = tf.concat((tgt, [eos_id]), 0) 58 | topic = tf.cast(vocab_table.lookup(data['topic']), tf.int32) 59 | 60 | mapped_data = { 61 | 'tgt_in': tf.concat(([sos_id], tgt), 0), 62 | 'tgt_out': tgt_out, 63 | 'tgt_len': tf.size(tgt_out), 64 | 'topic': topic, 65 | 'topic_len': tf.size(topic), 66 | } 67 | 68 | for t in range(num_inputs): 69 | src = tf.cast(vocab_table.lookup(data['src_%d' % t]), tf.int32) 70 | mapped_data['src_%d' % t] = src 71 | mapped_data['src_len_%d' % t] = tf.size(src) 72 | 73 | return mapped_data 74 | 75 | src_tgt_dataset = src_tgt_dataset.map( 76 | _lookup_lambda, 77 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 78 | # Create a tgt_input prefixed with and a tgt_output suffixed with . 79 | 80 | # Add in sequence lengths. 
81 | # src_tgt_dataset = src_tgt_dataset.map( 82 | # lambda srcs, tgt_in, tgt_out: ( 83 | # srcs, tgt_in, tgt_out, 84 | # [tf.size(srcs[t]) for t in range(num_inputs)], tf.size(tgt_in)), 85 | # num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 86 | 87 | padded_shapes = { 88 | 'topic': tf.TensorShape([None]), 89 | 'tgt_in': tf.TensorShape([None]), 90 | 'tgt_out': tf.TensorShape([None]), 91 | 'topic_len': tf.TensorShape([]), 92 | 'tgt_len': tf.TensorShape([]) 93 | } 94 | 95 | padded_values = { 96 | 'topic': eos_id, 97 | 'tgt_in': eos_id, 98 | 'tgt_out': eos_id, 99 | 'topic_len': 0, 100 | 'tgt_len': 0 101 | } 102 | 103 | for t in range(num_inputs): 104 | padded_shapes['src_%d' % t] = tf.TensorShape([None]) 105 | padded_values['src_%d' % t] = eos_id 106 | padded_shapes['src_len_%d' % t] = tf.TensorShape([]) 107 | padded_values['src_len_%d' % t] = 0 108 | 109 | def _batching_lambda(x): 110 | return x.padded_batch( 111 | batch_size, 112 | # The first three entries are the source and target line rows; 113 | # these have unknown-length vectors. The last two entries are 114 | # the source and target row sizes; these are scalars. 115 | padded_shapes=padded_shapes, 116 | # Pad the source and target sequences with eos tokens. 117 | # (Though notice we don't generally need to do this since 118 | # later on we will be masking out calculations past the true sequence. 119 | padding_values=padded_values) 120 | 121 | if num_buckets > 1: 122 | def key_func(data): 123 | # Calculate bucket_width by maximum source sequence length. 124 | # Pairs with length [0, bucket_width) go to bucket 0, length 125 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length 126 | # over ((num_bucket-1) * bucket_width) words all go into the last bucket. 127 | if src_max_len: 128 | bucket_width = (src_max_len + num_buckets - 1) // num_buckets 129 | else: 130 | bucket_width = 10 131 | 132 | # Bucket sentence pairs by the length of their source sentence and target 133 | # sentence. 
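# (annotation, not in the original source) Worked example: with
# src_max_len=50 and num_buckets=5, bucket_width = (50 + 5 - 1) // 5 = 10,
# so an example whose longest utterance (over the target and all source
# turns) has 23 tokens gets bucket_id = 23 // 10 = 2; tf.minimum below caps
# the id at num_buckets.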
134 | 135 | bucket_id = data['tgt_len'] // bucket_width 136 | for t in range(num_inputs): 137 | bucket_id = tf.maximum(data['src_len_%d' % t] // bucket_width, bucket_id) 138 | 139 | # bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width) 140 | return tf.to_int64(tf.minimum(num_buckets, bucket_id)) 141 | 142 | def reduce_func(_, windowed_data): 143 | return _batching_lambda(windowed_data) 144 | 145 | batched_dataset = src_tgt_dataset.apply( 146 | tf.contrib.data.group_by_window( 147 | key_func=key_func, reduce_func=reduce_func, window_size=batch_size)) 148 | 149 | else: 150 | batched_dataset = _batching_lambda(src_tgt_dataset) 151 | 152 | batched_iter = batched_dataset.make_initializable_iterator() 153 | batched_data = batched_iter.get_next() 154 | 155 | return BatchedInput( 156 | initializer=batched_iter.initializer, 157 | sources=[batched_data['src_%d' % t] for t in range(num_inputs)], 158 | topic=batched_data['topic'], 159 | target_input=batched_data['tgt_in'], 160 | target_output=batched_data['tgt_out'], 161 | source_sequence_lengths=[batched_data['src_len_%d' % t] for t in range(num_inputs)], 162 | topic_sequence_length=batched_data['topic_len'], 163 | target_sequence_length=batched_data['tgt_len']) 164 | 165 | 166 | def get_infer_iterator(test_dataset, 167 | vocab_table, 168 | batch_size, 169 | num_turns, 170 | topic_words_per_utterance=None, 171 | src_max_len=None): 172 | num_inputs = num_turns - 1 173 | 174 | eos_id = tf.constant(vocab.EOS_ID, dtype=tf.int32) 175 | 176 | def _parse_lambda(line): 177 | delimited_line = tf.string_split([line], delimiter="\t").values 178 | # utterances = tf.string_split([tf.py_func(lambda x: x.strip(), [delimited_line[0]], [tf.string])[0]], 179 | # delimiter="\t").values 180 | srcs = [tf.string_split([delimited_line[t]]).values for t in range(num_inputs)] 181 | # topic = tf.string_split([tf.py_func(lambda x: x.strip(), [delimited_line[1]], [tf.string])[0]]).values 182 | topic = tf.string_split([delimited_line[-1]]).values 183 | topic = topic[:topic_words_per_utterance] if topic_words_per_utterance else topic 184 | topic = tf.cast(vocab_table.lookup(topic), tf.int32) 185 | 186 | parsed_data = { 187 | 'topic': topic, 188 | 'topic_len': tf.size(topic) 189 | } 190 | 191 | for t in range(num_inputs): 192 | src = srcs[t][:src_max_len] if src_max_len else srcs[t] 193 | src = tf.cast(vocab_table.lookup(src), tf.int32) 194 | parsed_data['src_%d' % t] = src 195 | parsed_data['src_len_%d' % t] = tf.size(src) 196 | 197 | return parsed_data 198 | 199 | test_dataset = test_dataset.map(_parse_lambda) 200 | 201 | padded_shapes = {'topic': tf.TensorShape([None]), 202 | 'topic_len': tf.TensorShape([])} 203 | padded_values = {'topic': eos_id, 204 | 'topic_len': 0} 205 | 206 | for t in range(num_inputs): 207 | padded_shapes['src_%d' % t] = tf.TensorShape([None]) 208 | padded_values['src_%d' % t] = eos_id 209 | padded_shapes['src_len_%d' % t] = tf.TensorShape([]) 210 | padded_values['src_len_%d' % t] = 0 211 | 212 | def batching_func(x): 213 | return x.padded_batch( 214 | batch_size, 215 | # The entry is the source line rows; 216 | # this has unknown-length vectors. The last entry is 217 | # the source row size; this is a scalar. 218 | padded_shapes=padded_shapes, 219 | # Pad the source sequences with eos tokens. 220 | # (Though notice we don't generally need to do this since 221 | # later on we will be masking out calculations past the true sequence. 
222 | padding_values=padded_values) 223 | 224 | batched_dataset = batching_func(test_dataset) 225 | batched_iter = batched_dataset.make_initializable_iterator() 226 | 227 | batched_data = batched_iter.get_next() 228 | 229 | return BatchedInput( 230 | initializer=batched_iter.initializer, 231 | sources=[batched_data['src_%d' % t] for t in range(num_inputs)], 232 | topic=batched_data['topic'], 233 | target_input=None, 234 | target_output=None, 235 | source_sequence_lengths=[batched_data['src_len_%d' % t] for t in range(num_inputs)], 236 | topic_sequence_length=batched_data['topic_len'], 237 | target_sequence_length=None) 238 | -------------------------------------------------------------------------------- /thred/models/thred/thred_wrapper.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import numpy as np 4 | 5 | from thred.util import fs, log 6 | from thred.util.embed import EmbeddingUtil 7 | from . import thred_helper 8 | from .. import topical_base, ncm_utils 9 | from ..hierarchical_base import BaseHierarchicalEncoderDecoder 10 | 11 | 12 | class TopicalHierarchicalEncoderDecoder(BaseHierarchicalEncoderDecoder): 13 | def __init__(self, config): 14 | super(TopicalHierarchicalEncoderDecoder, self).__init__(config) 15 | 16 | def _get_model_helper(self): 17 | return thred_helper 18 | 19 | def _get_checkpoint_name(self): 20 | return "thred" 21 | 22 | def _pre_model_creation(self): 23 | self.config['topic_vocab_file'] = path.join(fs.get_current_dir(self.config.vocab_file), 'topic_vocab.in') 24 | self._vocab_table, self.__topic_vocab_table = topical_base.initialize_vocabulary(self.config) 25 | 26 | EmbeddingUtil(self.config.embed_conf).build_if_not_exists( 27 | self.config.embedding_type, self.config.vocab_pkl, self.config.vocab_file) 28 | 29 | if 'original_vocab_size' not in self.config: 30 | self.config['original_vocab_size'] = self.config.vocab_size 31 | 32 | self.config.vocab_size = len(self._vocab_table) 33 | self.config.topic_vocab_size = len(self.__topic_vocab_table) 34 | 35 | if self.config.mode == "interactive" and self.config.lda_model_dir is None: 36 | raise ValueError("In interactive mode, THRED requires a pretrained LDA model") 37 | 38 | def _sample_decode(self, 39 | model, global_step, sess, src_placeholder, batch_size_placeholder, eval_data, summary_writer): 40 | """Pick a sentence and decode.""" 41 | decode_ids = np.random.randint(low=0, high=len(eval_data) - 1, size=1) 42 | 43 | sample_data = [] 44 | for decode_id in decode_ids: 45 | sample_data.append(eval_data[decode_id]) 46 | 47 | iterator_feed_dict = { 48 | src_placeholder: sample_data, 49 | batch_size_placeholder: len(decode_ids), 50 | } 51 | 52 | sess.run(model.iterator.initializer, feed_dict=iterator_feed_dict) 53 | ncm_outputs, infer_summary = model.decode(sess) 54 | 55 | for i, decode_id in enumerate(decode_ids): 56 | log.print_out(" # {}".format(decode_id)) 57 | 58 | output = ncm_outputs[i] 59 | 60 | if self.config.beam_width > 0 and self._consider_beam(): 61 | # get the top translation. 
62 | output = output[0] 63 | 64 | translation = ncm_utils.get_translation(output, sent_id=0) 65 | delimited_sample = eval_data[decode_id].split("\t") 66 | utterances, topic = delimited_sample[:-1], delimited_sample[-1] 67 | sources, target = utterances[:-1], utterances[-1] 68 | 69 | log.print_out(" sources:") 70 | for t, src in enumerate(sources): 71 | log.print_out(" @{} {}".format(t + 1, src)) 72 | log.print_out(" topic: {}".format(topic)) 73 | log.print_out(" resp: {}".format(target)) 74 | log.print_out(b" generated: " + translation) 75 | 76 | # Summary 77 | if infer_summary is not None: 78 | summary_writer.add_summary(infer_summary, global_step) 79 | -------------------------------------------------------------------------------- /thred/models/topic_aware/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/models/topic_aware/__init__.py -------------------------------------------------------------------------------- /thred/models/topic_aware/taware_decoder.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import tensorflow as tf 4 | 5 | from tensorflow.python.framework import ops 6 | from tensorflow.python.ops import control_flow_ops 7 | from tensorflow.python.ops import math_ops 8 | from tensorflow.python.util import nest 9 | from tensorflow.contrib.seq2seq.python.ops.beam_search_decoder import _beam_search_step 10 | 11 | from . import taware_layer 12 | 13 | 14 | class ConservativeDecoderOutput( 15 | collections.namedtuple("ConservativeDecoderOutput", ("rnn_output", "prev_input", "context"))): 16 | pass 17 | 18 | 19 | class ConservativeBasicDecoder(tf.contrib.seq2seq.BasicDecoder): 20 | def __init__(self, cell, helper, initial_state, output_layer): 21 | # if not isinstance(output_layer, taware_layer.JointDenseLayer): 22 | # raise ValueError('Output layer must be of type: JointDenseLayer') 23 | 24 | self._current_context = None 25 | 26 | super(ConservativeBasicDecoder, self).__init__(cell, helper, initial_state, output_layer) 27 | 28 | def step(self, time, inputs, state, name=None): 29 | with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)): 30 | cell_outputs, cell_state = self._cell(inputs, state) 31 | if self._output_layer is not None: 32 | # My modification 33 | if isinstance(self._output_layer, taware_layer.JointDenseLayer): 34 | if self._current_context is not None: 35 | msg_attention, _ = tf.split(self._current_context, num_or_size_splits=2, axis=1) 36 | cell_outputs = self._output_layer(cell_outputs, input=inputs, context=msg_attention) 37 | else: 38 | cell_outputs = self._output_layer(cell_outputs, input=inputs) 39 | else: 40 | cell_outputs = self._output_layer(cell_outputs) 41 | 42 | sample_ids = self._helper.sample( 43 | time=time, outputs=cell_outputs, state=cell_state) 44 | (finished, next_inputs, next_state) = self._helper.next_inputs( 45 | time=time, 46 | outputs=cell_outputs, 47 | state=cell_state, 48 | sample_ids=sample_ids) 49 | # My modification 50 | self._current_context = cell_state.attention 51 | 52 | outputs = tf.contrib.seq2seq.BasicDecoderOutput(cell_outputs, sample_ids) 53 | return (outputs, next_state, next_inputs, finished) 54 | 55 | 56 | class ConservativeBeamSearchDecoder(tf.contrib.seq2seq.BeamSearchDecoder): 57 | def __init__(self, 58 | cell, 59 | embedding, 60 | start_tokens, 61 | end_token, 62 | initial_state, 63 | beam_width, 64 | 
output_layer, 65 | length_penalty_weight=0.0): 66 | # if not isinstance(output_layer, taware_layer.JointDenseLayer): 67 | # raise ValueError('Output layer must be of type: JointDenseLayer') 68 | 69 | self._current_context = None 70 | 71 | super(ConservativeBeamSearchDecoder, self).__init__(cell, 72 | embedding, 73 | start_tokens, 74 | end_token, 75 | initial_state, 76 | beam_width, 77 | output_layer, 78 | length_penalty_weight) 79 | 80 | def step(self, time, inputs, state, name=None): 81 | batch_size = self._batch_size 82 | beam_width = self._beam_width 83 | end_token = self._end_token 84 | length_penalty_weight = self._length_penalty_weight 85 | 86 | with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): 87 | cell_state = state.cell_state 88 | inputs = nest.map_structure( 89 | lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs) 90 | cell_state = nest.map_structure( 91 | self._maybe_merge_batch_beams, 92 | cell_state, self._cell.state_size) 93 | cell_outputs, next_cell_state = self._cell(inputs, cell_state) 94 | cell_outputs = nest.map_structure( 95 | lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs) 96 | next_cell_state = nest.map_structure( 97 | self._maybe_split_batch_beams, 98 | next_cell_state, self._cell.state_size) 99 | 100 | if self._output_layer is not None: 101 | # My modification 102 | if isinstance(self._output_layer, taware_layer.JointDenseLayer): 103 | reshaped_inputs = tf.reshape(inputs, [-1, beam_width, inputs.shape[-1]]) 104 | if self._current_context is not None: 105 | msg_attention, _ = tf.split(self._current_context, num_or_size_splits=2, axis=1) 106 | msg_attention = tf.reshape(msg_attention, [-1, beam_width, msg_attention.shape[-1]]) 107 | cell_outputs = self._output_layer(cell_outputs, input=reshaped_inputs, context=msg_attention) 108 | else: 109 | cell_outputs = self._output_layer(cell_outputs, input=reshaped_inputs) 110 | else: 111 | cell_outputs = self._output_layer(cell_outputs) 112 | 113 | beam_search_output, beam_search_state = _beam_search_step( 114 | time=time, 115 | logits=cell_outputs, 116 | next_cell_state=next_cell_state, 117 | beam_state=state, 118 | batch_size=batch_size, 119 | beam_width=beam_width, 120 | end_token=end_token, 121 | length_penalty_weight=length_penalty_weight, 122 | coverage_penalty_weight=0.0) 123 | 124 | finished = beam_search_state.finished 125 | sample_ids = beam_search_output.predicted_ids 126 | next_inputs = control_flow_ops.cond( 127 | math_ops.reduce_all(finished), lambda: self._start_inputs, 128 | lambda: self._embedding_fn(sample_ids)) 129 | 130 | # My modification 131 | self._current_context = cell_state.attention 132 | 133 | return (beam_search_output, beam_search_state, next_inputs, finished) 134 | -------------------------------------------------------------------------------- /thred/models/topic_aware/taware_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from ..model_helper import TrainModel, EvalModel, InferModel 4 | from . 
import taware_iterators 5 | from .taware_model import TopicAwareSeq2SeqModel 6 | from thred.util import vocab 7 | 8 | 9 | def create_train_model(hparams, scope=None, num_workers=1, jobid=0, extra_args=None): 10 | """Create train graph, model, and iterator.""" 11 | train_file = hparams.train_data 12 | 13 | graph = tf.Graph() 14 | 15 | with graph.as_default(), tf.container(scope or "train"): 16 | vocab_table = vocab.create_vocab_table(hparams.vocab_file) 17 | 18 | dataset = tf.data.TextLineDataset(train_file) 19 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) 20 | 21 | iterator = taware_iterators.get_iterator( 22 | dataset, 23 | vocab_table, 24 | batch_size=hparams.batch_size, 25 | num_buckets=hparams.num_buckets, 26 | topic_words_per_utterance=hparams.topic_words_per_utterance, 27 | src_max_len=hparams.src_max_len, 28 | tgt_max_len=hparams.tgt_max_len, 29 | skip_count=skip_count_placeholder, 30 | num_shards=num_workers, 31 | shard_index=jobid) 32 | 33 | # Note: One can set model_device_fn to 34 | # `tf.train.replica_device_setter(ps_tasks)` for distributed training. 35 | model_device_fn = None 36 | # if extra_args: model_device_fn = extra_args.model_device_fn 37 | with tf.device(model_device_fn): 38 | model = TopicAwareSeq2SeqModel( 39 | mode=tf.contrib.learn.ModeKeys.TRAIN, 40 | iterator=iterator, 41 | params=hparams, 42 | scope=scope) 43 | 44 | return TrainModel( 45 | graph=graph, 46 | model=model, 47 | iterator=iterator, 48 | skip_count_placeholder=skip_count_placeholder) 49 | 50 | 51 | def create_eval_model(hparams, scope=None): 52 | """Create train graph, model, src/tgt file holders, and iterator.""" 53 | vocab_file = hparams.vocab_file 54 | graph = tf.Graph() 55 | 56 | with graph.as_default(), tf.container(scope or "eval"): 57 | vocab_table = vocab.create_vocab_table(vocab_file) 58 | eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) 59 | 60 | eval_dataset = tf.data.TextLineDataset(eval_file_placeholder) 61 | iterator = taware_iterators.get_iterator( 62 | eval_dataset, 63 | vocab_table, 64 | hparams.batch_size, 65 | num_buckets=hparams.num_buckets, 66 | topic_words_per_utterance=hparams.topic_words_per_utterance, 67 | src_max_len=hparams.src_max_len, 68 | tgt_max_len=hparams.tgt_max_len) 69 | model = TopicAwareSeq2SeqModel( 70 | mode=tf.contrib.learn.ModeKeys.EVAL, 71 | iterator=iterator, 72 | params=hparams, 73 | scope=scope, 74 | log_trainables=False) 75 | return EvalModel( 76 | graph=graph, 77 | model=model, 78 | eval_file_placeholder=eval_file_placeholder, 79 | iterator=iterator) 80 | 81 | 82 | def create_infer_model(hparams, scope=None): 83 | """Create inference model.""" 84 | graph = tf.Graph() 85 | vocab_file = hparams.vocab_file 86 | 87 | with graph.as_default(), tf.container(scope or "infer"): 88 | vocab_table = vocab.create_vocab_table(vocab_file) 89 | reverse_vocab_table = vocab.create_rev_vocab_table(vocab_file) 90 | 91 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) 92 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) 93 | 94 | src_dataset = tf.data.Dataset.from_tensor_slices( 95 | src_placeholder) 96 | iterator = taware_iterators.get_infer_iterator( 97 | src_dataset, 98 | vocab_table, 99 | batch_size=batch_size_placeholder, 100 | topic_words_per_utterance=hparams.topic_words_per_utterance, 101 | src_max_len=hparams.src_max_len) 102 | model = TopicAwareSeq2SeqModel( 103 | mode=tf.contrib.learn.ModeKeys.INFER, 104 | iterator=iterator, 105 | params=hparams, 106 | rev_vocab_table=reverse_vocab_table, 107 
| scope=scope, 108 | log_trainables=False) 109 | return InferModel( 110 | graph=graph, 111 | model=model, 112 | src_placeholder=src_placeholder, 113 | batch_size_placeholder=batch_size_placeholder, 114 | iterator=iterator) 115 | -------------------------------------------------------------------------------- /thred/models/topic_aware/taware_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.python.eager import context 4 | from tensorflow.python.framework import ops 5 | from tensorflow.python.framework import tensor_shape 6 | from tensorflow.python.layers import base as layers_base 7 | from tensorflow.python.ops import init_ops 8 | from tensorflow.python.ops import variable_scope as vs 9 | 10 | 11 | class MultiDenseLayer(layers_base.Layer): 12 | def __init__(self, units, activation=None, use_bias=True, dtype=None, name=None, scope=None, **kwargs): 13 | super(MultiDenseLayer, self).__init__(name=name, dtype=dtype, **kwargs) 14 | self.units = units 15 | 16 | self.activation = activation 17 | self.use_bias = use_bias 18 | self._scope = scope 19 | 20 | self._built = False 21 | 22 | self._n_tensors = None 23 | self.kernels = None 24 | self.bias = None 25 | 26 | def build(self, input_shapes): 27 | if self._built: 28 | return 29 | 30 | self.kernels = [] 31 | self.bias = None 32 | self._n_tensors = len(input_shapes) 33 | 34 | with vs.variable_scope(self._scope or self._name) as scope: 35 | with ops.name_scope(scope.original_name_scope): 36 | for i, input_shape in enumerate(input_shapes): 37 | self.kernels.append(vs.get_variable('kernel_{}'.format(i), 38 | shape=[input_shape[-1], self.units], 39 | dtype=self._dtype, 40 | trainable=True)) 41 | 42 | if self.use_bias: 43 | self.bias = vs.get_variable('bias', 44 | shape=[self.units, ], 45 | initializer=init_ops.zeros_initializer(), 46 | dtype=self._dtype, trainable=True) 47 | 48 | self._built = True 49 | 50 | def __op(self, kernel, inputs, shape): 51 | if len(shape) > 2: 52 | # Broadcasting is required for the inputs. 53 | outputs = tf.tensordot(inputs, kernel, [[len(shape) - 1],[0]]) 54 | # Reshape the output back to the original ndim of the input. 
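# (annotation, not in the original source) e.g., inputs of shape
# [batch, beam, d_in] contracted with a [d_in, units] kernel over the last
# axis yield [batch, beam, units]; the set_shape call below restores the
# static shape information that tensordot drops in graph mode.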
55 | # if context.in_graph_mode(): 56 | # for tf > 1.5.0 57 | if not context.executing_eagerly(): 58 | output_shape = shape[:-1] + [self.units] 59 | outputs.set_shape(output_shape) 60 | else: 61 | outputs = tf.matmul(inputs, kernel) 62 | 63 | return outputs 64 | 65 | def __call__(self, inputs, *args, **kwargs): 66 | if not isinstance(inputs, list) and not isinstance(inputs, tuple): 67 | raise ValueError('input must be a list') 68 | 69 | if not inputs: 70 | raise ValueError('input cannot be empty') 71 | 72 | input_shapes = [inp.get_shape().as_list() for inp in inputs] 73 | 74 | self.build(input_shapes) 75 | 76 | outputs = None 77 | for i, kernel in enumerate(self.kernels): 78 | out = self.__op(kernel, inputs[i], input_shapes[i]) 79 | if outputs is None: 80 | outputs = out 81 | else: 82 | outputs = tf.add(out, outputs) 83 | 84 | if self.use_bias: 85 | outputs = tf.nn.bias_add(outputs, self.bias) 86 | 87 | if self.activation is not None: 88 | return self.activation(outputs) # pylint: disable=not-callable 89 | return outputs 90 | 91 | 92 | class JointDenseLayer(layers_base.Layer): 93 | def __init__(self, vocab_size, topic_vocab_size, dtype=None, name=None, scope=None, **kwargs): 94 | super(JointDenseLayer, self).__init__(name=name, dtype=dtype, **kwargs) 95 | self._vocab_size = vocab_size 96 | self._topic_vocab_size = topic_vocab_size 97 | self._units = vocab_size 98 | 99 | with tf.variable_scope(scope or "output_projection"): 100 | self._msg_layer = MultiDenseLayer( 101 | self._vocab_size, 102 | # activation=tf.nn.tanh, 103 | name="message_projection") 104 | self._topical_layer = MultiDenseLayer( 105 | self._topic_vocab_size, 106 | # activation=tf.nn.tanh, 107 | name="topical_projection") 108 | 109 | self._scope = scope 110 | 111 | self._built = False 112 | 113 | self._n_tensors = None 114 | self.kernels = None 115 | self.bias = None 116 | 117 | def compute_output_shape(self, input_shape): 118 | # remove "_" from the name for tf > 1.5.0 119 | input_shape = tensor_shape.TensorShape(input_shape) 120 | input_shape = input_shape.with_rank_at_least(2) 121 | if input_shape[-1].value is None: 122 | raise ValueError( 123 | 'The innermost dimension of input_shape must be defined, but saw: %s' 124 | % input_shape) 125 | return input_shape[:-1].concatenate(self._units) 126 | 127 | def __call__(self, rnn_output, *args, **kwargs): 128 | input = kwargs.get('input') 129 | context_attention = kwargs.get('context') 130 | 131 | message_outputs = self._msg_layer([rnn_output, input]) 132 | 133 | shape = rnn_output.get_shape().as_list() 134 | pad_dims = [[0] * 2 for _ in range(len(shape))] 135 | pad_dims[-1][0] = self._vocab_size - self._topic_vocab_size 136 | 137 | if context_attention is None: 138 | topical_layer_inputs = [rnn_output, input] 139 | else: 140 | topical_layer_inputs = [rnn_output, input, context_attention] 141 | 142 | topical_outputs = tf.pad(self._topical_layer(topical_layer_inputs), pad_dims) 143 | 144 | return tf.add(message_outputs, topical_outputs) 145 | -------------------------------------------------------------------------------- /thred/models/topical_base.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import re 3 | 4 | import tensorflow as tf 5 | 6 | from ..util import vocab 7 | 8 | 9 | def initialize_vocabulary(hparams): 10 | _create_vocabulary(hparams.vocab_file, hparams.topic_vocab_file, hparams.train_data, hparams.vocab_size) 11 | 12 | vocab_table = vocab.create_vocab_dict(hparams.vocab_file) 13 | topic_vocab_table = 
-------------------------------------------------------------------------------- /thred/models/topical_base.py: --------------------------------------------------------------------------------
1 | import codecs
2 | import re
3 | 
4 | import tensorflow as tf
5 | 
6 | from ..util import vocab
7 | 
8 | 
9 | def initialize_vocabulary(hparams):
10 | _create_vocabulary(hparams.vocab_file, hparams.topic_vocab_file, hparams.train_data, hparams.vocab_size)
11 | 
12 | vocab_table = vocab.create_vocab_dict(hparams.vocab_file)
13 | topic_vocab_table = vocab.create_vocab_dict(hparams.topic_vocab_file)
14 | 
15 | for w in topic_vocab_table:
16 | topic_vocab_table[w] = vocab_table[w]
17 | 
18 | return vocab_table, topic_vocab_table
19 | 
20 | 
21 | def _create_vocabulary(vocab_path, topic_vocab_path, data_path, max_vocabulary_size, normalize_digits=False):
22 | """A modified version of vocab.create_vocabulary
23 | """
24 | 
25 | if tf.gfile.Exists(vocab_path) and tf.gfile.Exists(topic_vocab_path):
26 | return
27 | 
28 | print("Creating vocabulary files from data %s" % data_path)
29 | dialog_vocab, topic_vocab = {}, {}
30 | 
31 | def normalize(word):
32 | if normalize_digits:
33 | if re.match(r'[\-+]?\d+(\.\d+)?', word):
34 | return ''
35 | 
36 | return word
37 | 
38 | with codecs.getreader('utf-8')(
39 | tf.gfile.GFile(data_path, mode="rb")) as f:
40 | counter = 0
41 | for line in f:
42 | counter += 1
43 | if counter % 100000 == 0:
44 | print(" processing line %d" % counter)
45 | 
46 | delimited_line = line.split("\t")
47 | topic_tokens = delimited_line[-1].strip().split()
48 | dialog_tokens = " ".join([delimited_line[i].strip() for i in range(len(delimited_line) - 1)]).split()
49 | 
50 | for word in dialog_tokens:
51 | word = normalize(word)
52 | 
53 | if word in dialog_vocab:
54 | dialog_vocab[word] += 1
55 | else:
56 | dialog_vocab[word] = 1
57 | 
58 | for word in topic_tokens:
59 | word = normalize(word)
60 | 
61 | if word in topic_vocab:
62 | topic_vocab[word] += 1
63 | else:
64 | topic_vocab[word] = 1
65 | 
66 | for word in topic_vocab:
67 | if word in dialog_vocab:
68 | topic_vocab[word] += dialog_vocab[word]
69 | 
70 | topic_vocab_list = sorted(topic_vocab, key=topic_vocab.get, reverse=True)
71 | with codecs.getwriter('utf-8')(
72 | tf.gfile.GFile(topic_vocab_path, mode="wb")) as topic_vocab_file:
73 | for w in topic_vocab_list:
74 | topic_vocab_file.write(w + "\n")
75 | 
76 | for reserved_word in vocab.RESERVED_WORDS:
77 | if reserved_word in dialog_vocab:
78 | dialog_vocab.pop(reserved_word)
79 | 
80 | dialog_vocab_list = vocab.RESERVED_WORDS + sorted(dialog_vocab, key=dialog_vocab.get, reverse=True)
81 | 
82 | if len(dialog_vocab_list) > max_vocabulary_size:
83 | dialog_vocab_list = dialog_vocab_list[:max_vocabulary_size]
84 | 
85 | for word in topic_vocab:
86 | if word in dialog_vocab_list and word not in vocab.RESERVED_WORDS:
87 | dialog_vocab_list.remove(word)
88 | 
89 | with codecs.getwriter('utf-8')(
90 | tf.gfile.GFile(vocab_path, mode="wb")) as vocab_file:
91 | for w in dialog_vocab_list:
92 | vocab_file.write(w + "\n")
93 | 
94 | for w in topic_vocab_list:
95 | vocab_file.write(w + "\n")
96 | 
97 | print("Topic vocabulary with {} words created".format(len(topic_vocab_list)))
98 | print("Vocabulary with {} words created".format(len(dialog_vocab_list) + len(topic_vocab_list)))
99 | 
100 | del topic_vocab
101 | del dialog_vocab
102 | 
-------------------------------------------------------------------------------- /thred/models/vanilla/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/models/vanilla/__init__.py
-------------------------------------------------------------------------------- /thred/models/vanilla/bleu.py: --------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Python implementation of BLEU and smooth-BLEU.
17 | 
18 | This module provides a Python implementation of BLEU and smooth-BLEU.
19 | Smooth BLEU is computed following the method outlined in the paper:
20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
21 | evaluation metrics for machine translation. COLING 2004.
22 | """
23 | 
24 | import collections
25 | import math
26 | 
27 | 
28 | def _get_ngrams(segment, max_order):
29 | """Extracts all n-grams up to a given maximum order from an input segment.
30 | 
31 | Args:
32 | segment: text segment from which n-grams will be extracted.
33 | max_order: maximum length in tokens of the n-grams returned by this
34 | method.
35 | 
36 | Returns:
37 | The Counter containing all n-grams up to max_order in segment
38 | with a count of how many times each n-gram occurred.
39 | """
40 | ngram_counts = collections.Counter()
41 | for order in range(1, max_order + 1):
42 | for i in range(0, len(segment) - order + 1):
43 | ngram = tuple(segment[i:i + order])
44 | ngram_counts[ngram] += 1
45 | return ngram_counts
46 | 
47 | 
48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4,
49 | smooth=False):
50 | """Computes BLEU score of translated segments against one or more references.
51 | 
52 | Args:
53 | reference_corpus: list of lists of references for each translation. Each
54 | reference should be tokenized into a list of tokens.
55 | translation_corpus: list of translations to score. Each translation
56 | should be tokenized into a list of tokens.
57 | max_order: Maximum n-gram order to use when computing BLEU score.
58 | smooth: Whether or not to apply Lin et al. 2004 smoothing.
59 | 
60 | Returns:
61 | 6-tuple with the BLEU score, the n-gram precisions, the brevity penalty,
62 | the length ratio, the translation length, and the reference length.
63 | """
64 | matches_by_order = [0] * max_order
65 | possible_matches_by_order = [0] * max_order
66 | reference_length = 0
67 | translation_length = 0
68 | for (references, translation) in zip(reference_corpus,
69 | translation_corpus):
70 | reference_length += min(len(r) for r in references)
71 | translation_length += len(translation)
72 | 
73 | merged_ref_ngram_counts = collections.Counter()
74 | for reference in references:
75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
76 | translation_ngram_counts = _get_ngrams(translation, max_order)
77 | overlap = translation_ngram_counts & merged_ref_ngram_counts
78 | for ngram in overlap:
79 | matches_by_order[len(ngram) - 1] += overlap[ngram]
80 | for order in range(1, max_order + 1):
81 | possible_matches = len(translation) - order + 1
82 | if possible_matches > 0:
83 | possible_matches_by_order[order - 1] += possible_matches
84 | 
85 | precisions = [0] * max_order
86 | for i in range(0, max_order):
87 | if smooth:
88 | precisions[i] = ((matches_by_order[i] + 1.) /
89 | (possible_matches_by_order[i] + 1.))
90 | else:
91 | if possible_matches_by_order[i] > 0:
92 | precisions[i] = (float(matches_by_order[i]) /
93 | possible_matches_by_order[i])
94 | else:
95 | precisions[i] = 0.0
96 | 
97 | if min(precisions) > 0:
98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
99 | geo_mean = math.exp(p_log_sum)
100 | else:
101 | geo_mean = 0
102 | 
103 | ratio = float(translation_length) / reference_length
104 | 
105 | if ratio > 1.0:
106 | bp = 1.
107 | else:
108 | bp = math.exp(1 - 1. / ratio)
109 | 
110 | bleu = geo_mean * bp
111 | 
112 | return (bleu, precisions, bp, ratio, translation_length, reference_length)
113 | 
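A minimal sketch of how `compute_bleu` is called (illustrative only; the toy sentences below are made up). Inputs must already be tokenized, and each translation is paired with a *list* of references:

```python
references = [[["the", "cat", "sat", "on", "the", "mat"]]]  # one segment, one reference
translations = [["the", "cat", "sat", "on", "mat"]]

# smooth=True applies the Lin & Och (2004) +1 smoothing to the precisions.
bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(
    references, translations, max_order=4, smooth=True)
print("BLEU = {:.4f}, brevity penalty = {:.4f}".format(bleu, bp))
```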
-------------------------------------------------------------------------------- /thred/models/vanilla/eval_metric.py: --------------------------------------------------------------------------------
1 | import codecs
2 | import time
3 | 
4 | import numpy as np
5 | import tensorflow as tf
6 | 
7 | from thred.util import log
8 | from . import bleu
9 | from ..ncm_utils import get_translation
10 | 
11 | 
12 | def decode_and_evaluate(name,
13 | model,
14 | sess,
15 | out_file,
16 | ref_file,
17 | metrics,
18 | beam_width,
19 | num_translations_per_input=1,
20 | decode=True):
21 | """Decode a test set and compute a score according to the evaluation task."""
22 | # Decode
23 | if decode:
24 | log.print_out(" decoding to output '{}'".format(out_file))
25 | 
26 | start_time = time.time()
27 | num_sentences = 0
28 | with codecs.getwriter("utf-8")(
29 | tf.gfile.GFile(out_file, mode="wb")) as trans_f:
30 | trans_f.write("") # Write empty string to ensure file is created.
31 | 
32 | num_translations_per_input = max(
33 | min(num_translations_per_input, beam_width), 1)
34 | 
35 | i = 0
36 | while True:
37 | i += 1
38 | try:
39 | 
40 | if i % 1000 == 0:
41 | log.print_out(" decoding step {}, num sentences {}".format(i, num_sentences))
42 | 
43 | ncm_outputs, _ = model.decode(sess)
44 | if beam_width == 0:
45 | ncm_outputs = np.expand_dims(ncm_outputs, 0)
46 | 
47 | batch_size = ncm_outputs.shape[1]
48 | num_sentences += batch_size
49 | 
50 | for sent_id in range(batch_size):
51 | translations = [get_translation(ncm_outputs[beam_id], sent_id)
52 | for beam_id in range(num_translations_per_input)]
53 | trans_f.write(b"\t".join(translations).decode("utf-8") + "\n")
54 | except tf.errors.OutOfRangeError:
55 | log.print_time(
56 | " Done, num sentences {}, num translations per input {}".format(
57 | num_sentences, num_translations_per_input), start_time)
58 | break
59 | 
60 | # Evaluation
61 | evaluation_scores = {}
62 | # if ref_file and tf.gfile.Exists(out_file):
63 | # for metric in metrics:
64 | # score = evaluate(ref_file, out_file, metric)
65 | # evaluation_scores[metric] = score
66 | # log.print_out(" %s %s: %.1f" % (metric, name, score))
67 | 
68 | return evaluation_scores
69 | 
70 | 
71 | def evaluate(ref_file, trans_file, metric):
72 | """Pick a metric and evaluate depending on task."""
73 | # BLEU scores for translation task
74 | if metric.lower() == "bleu":
75 | evaluation_score = _bleu(ref_file, trans_file)
76 | # Accuracy metrics
77 | elif metric.lower() == "accuracy":
78 | evaluation_score = _accuracy(ref_file, trans_file)
79 | elif metric.lower() == "word_accuracy":
80 | evaluation_score = _word_accuracy(ref_file, trans_file)
81 | else:
82 | raise ValueError("Unknown metric {}".format(metric))
83 | 
84 | return evaluation_score
85 | 
86 | 
87 | def _bleu(ref_file, trans_file):
88 | """Compute BLEU scores, handling BPE."""
89 | max_order = 4
90 | smooth = False
91 | 
92 | ref_files = [ref_file]
93 | reference_text = []
94 | for reference_filename in ref_files:
95 | with codecs.getreader("utf-8")(
96 | tf.gfile.GFile(reference_filename, "rb")) as fh:
97 | reference_text.append(fh.readlines())
98 | 
99 | per_segment_references = []
100 | for references in zip(*reference_text):
101 | reference_list = []
102 | for reference in references:
103 | reference = reference.strip()
104 | reference_list.append(reference.split(" "))
105 | per_segment_references.append(reference_list)
106 | 
107 | translations = []
108 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh:
109 | for line in fh:
110 | line = line.strip()
111 | translations.append(line.split(" "))
112 | 
113 | # bleu_score, precisions, bp, ratio, translation_length, reference_length
114 | bleu_score, _, _, _, _, _ = bleu.compute_bleu(
115 | per_segment_references, translations, max_order, smooth)
116 | return 100 * bleu_score
117 | 
118 | 
119 | def _accuracy(label_file, pred_file):
120 | """Compute accuracy, each line contains a label."""
121 | 
122 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh:
123 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh:
124 | count = 0.0
125 | match = 0.0
126 | for label in label_fh:
127 | label = label.strip()
128 | pred = pred_fh.readline().strip()
129 | if label == pred:
130 | match += 1
131 | count += 1
132 | return 100 * match / count
133 | 
134 | 
135 | def _word_accuracy(label_file, pred_file):
136 | """Compute accuracy on per word basis."""
137 | 
138 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "r")) as label_fh:
139 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "r")) as pred_fh:
140 | total_acc, total_count = 0., 0.
141 | for sentence in label_fh:
142 | labels = sentence.strip().split(" ")
143 | preds = pred_fh.readline().strip().split(" ")
144 | match = 0.0
145 | for pos in range(min(len(labels), len(preds))):
146 | label = labels[pos]
147 | pred = preds[pos]
148 | if label == pred:
149 | match += 1
150 | total_acc += 100 * match / max(len(labels), len(preds))
151 | total_count += 1
152 | return total_acc / total_count
153 | 
-------------------------------------------------------------------------------- /thred/models/vanilla/vanilla_helper.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | from ..model_helper import TrainModel, EvalModel, InferModel
4 | from . import vanilla_iterators
5 | from .vanilla_model import VanillaSeq2SeqModel
6 | from thred.util import vocab
7 | 
8 | 
9 | def create_train_model(hparams, scope=None, num_workers=1, jobid=0, extra_args=None):
10 | """Create train graph, model, and iterator."""
11 | train_file = hparams.train_data
12 | 
13 | graph = tf.Graph()
14 | 
15 | with graph.as_default(), tf.container(scope or "train"):
16 | vocab_table = vocab.create_vocab_table(hparams.vocab_file)
17 | 
18 | dataset = tf.data.TextLineDataset(train_file)
19 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
20 | 
21 | iterator = vanilla_iterators.get_iterator(
22 | dataset,
23 | vocab_table,
24 | batch_size=hparams.batch_size,
25 | num_buckets=hparams.num_buckets,
26 | src_max_len=hparams.src_max_len,
27 | tgt_max_len=hparams.tgt_max_len,
28 | skip_count=skip_count_placeholder,
29 | num_shards=num_workers,
30 | shard_index=jobid)
31 | 
32 | # Note: One can set model_device_fn to
33 | # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
34 | model_device_fn = None
35 | # if extra_args: model_device_fn = extra_args.model_device_fn
36 | with tf.device(model_device_fn):
37 | model = VanillaSeq2SeqModel(
38 | mode=tf.contrib.learn.ModeKeys.TRAIN,
39 | iterator=iterator,
40 | params=hparams,
41 | scope=scope)
42 | 
43 | return TrainModel(
44 | graph=graph,
45 | model=model,
46 | iterator=iterator,
47 | skip_count_placeholder=skip_count_placeholder)
48 | 
49 | 
50 | def create_eval_model(hparams, scope=None):
51 | """Create eval graph, model, eval file placeholder, and iterator."""
52 | vocab_file = hparams.vocab_file
53 | graph = tf.Graph()
54 | 
55 | with graph.as_default(), tf.container(scope or "eval"):
56 | vocab_table = vocab.create_vocab_table(vocab_file)
57 | eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
58 | 
59 | eval_dataset = tf.data.TextLineDataset(eval_file_placeholder)
60 | iterator = vanilla_iterators.get_iterator(
61 | eval_dataset,
62 | vocab_table,
63 | hparams.batch_size,
64 | num_buckets=hparams.num_buckets,
65 | src_max_len=hparams.src_max_len,
66 | tgt_max_len=hparams.tgt_max_len)
67 | model = VanillaSeq2SeqModel(
68 | mode=tf.contrib.learn.ModeKeys.EVAL,
69 | iterator=iterator,
70 | params=hparams,
71 | scope=scope,
72 | log_trainables=False)
73 | return EvalModel(
74 | graph=graph,
75 | model=model,
76 | eval_file_placeholder=eval_file_placeholder,
77 | iterator=iterator)
78 | 
79 | 
80 | def create_infer_model(hparams, scope=None):
81 | """Create inference model."""
82 | graph = tf.Graph()
83 | vocab_file = hparams.vocab_file
84 | 
85 | with graph.as_default(), tf.container(scope or "infer"):
86 | vocab_table = vocab.create_vocab_table(vocab_file)
87 | reverse_vocab_table = vocab.create_rev_vocab_table(vocab_file)
88 | 
89 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
90 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
91 | 
92 | src_dataset = tf.data.Dataset.from_tensor_slices(
93 | src_placeholder)
94 | iterator = vanilla_iterators.get_infer_iterator(
95 | src_dataset,
96 | vocab_table,
97 | batch_size=batch_size_placeholder,
98 | src_max_len=hparams.src_max_len)
99 | model = VanillaSeq2SeqModel(
100 | mode=tf.contrib.learn.ModeKeys.INFER,
101 | iterator=iterator,
102 | params=hparams,
103 | rev_vocab_table=reverse_vocab_table,
104 | scope=scope)
105 | return InferModel(
106 | graph=graph,
107 | model=model,
108 | src_placeholder=src_placeholder,
109 | batch_size_placeholder=batch_size_placeholder,
110 | iterator=iterator)
111 | 
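Each factory returns a named tuple that bundles a dedicated graph with its model and iterator. A minimal sketch of how the train model is typically driven (illustrative only; the real training loop lives elsewhere in the package, and `hparams` would come from a YAML config such as conf/vanilla_small.yml):

```python
train_model = create_train_model(hparams, scope="train")

# The initializer ops must be created inside the model's own graph.
with train_model.graph.as_default():
    init_ops = [tf.global_variables_initializer(), tf.tables_initializer()]

with tf.Session(graph=train_model.graph) as sess:
    sess.run(init_ops)
    sess.run(train_model.iterator.initializer,
             feed_dict={train_model.skip_count_placeholder: 0})
    # each sess.run over the model's update op now consumes one batch
```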
-------------------------------------------------------------------------------- /thred/models/vanilla/vanilla_iterators.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from ..model_helper import BatchedInput 4 | from thred.util import vocab 5 | 6 | 7 | def get_iterator(dataset, 8 | vocab_table, 9 | batch_size, 10 | num_buckets, 11 | random_seed=None, 12 | src_max_len=None, 13 | tgt_max_len=None, 14 | num_parallel_calls=4, 15 | output_buffer_size=None, 16 | skip_count=None, 17 | num_shards=1, 18 | shard_index=0): 19 | if not output_buffer_size: 20 | output_buffer_size = batch_size * 1000 21 | 22 | eos_id = tf.constant(vocab.EOS_ID, dtype=tf.int32) 23 | sos_id = tf.constant(vocab.SOS_ID, dtype=tf.int32) 24 | 25 | src_tgt_dataset = dataset.shard(num_shards, shard_index) 26 | if skip_count is not None: 27 | src_tgt_dataset = src_tgt_dataset.skip(skip_count) 28 | 29 | src_tgt_dataset = src_tgt_dataset.shuffle(output_buffer_size, random_seed) 30 | 31 | def tokenize(line): 32 | utterances = tf.string_split([line], delimiter="\t").values 33 | i, sp = tf.constant(0), tf.Variable([], dtype=tf.string) 34 | cond = lambda i, sp: tf.less(i, tf.size(utterances)-1) 35 | 36 | def loop_body(i, sp): 37 | splitted = tf.string_split([utterances[i]]).values 38 | if src_max_len: 39 | splitted = tf.cond(tf.less(i, tf.size(utterances)-2), 40 | lambda: splitted[:src_max_len - 1], 41 | lambda: splitted[:src_max_len]) 42 | 43 | splitted = tf.cond(tf.less(i, tf.size(utterances)-2), 44 | lambda: tf.concat([splitted, [vocab.SEP]], axis=0), 45 | lambda: splitted) 46 | 47 | return tf.add(i, 1), tf.concat([sp, splitted], axis=0) 48 | 49 | _, srcs = tf.while_loop(cond, loop_body, [i, sp], shape_invariants=[i.get_shape(), tf.TensorShape([None])]) 50 | 51 | 52 | # srcs = [tf.string_split([utterances[t]]).values for t in range(num_inputs)] 53 | tgt = tf.string_split([utterances[tf.size(utterances)-1]]).values 54 | aggregated_src = tf.reduce_join([srcs], axis=0, separator=" ") 55 | 56 | return aggregated_src, tgt[:tgt_max_len] if tgt_max_len else tgt 57 | 58 | src_tgt_dataset = src_tgt_dataset.map(tokenize, 59 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 60 | 61 | # Filter zero length input sequences. 62 | src_tgt_dataset = src_tgt_dataset.filter( 63 | lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) 64 | 65 | if src_max_len: 66 | src_tgt_dataset = src_tgt_dataset.map( 67 | lambda src, tgt: (src[:src_max_len], tgt), 68 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 69 | if tgt_max_len: 70 | src_tgt_dataset = src_tgt_dataset.map( 71 | lambda src, tgt: (src, tgt[:tgt_max_len]), 72 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 73 | 74 | # Convert the word strings to ids. Word strings that are not in the 75 | # vocab get the lookup table's default_value integer. 76 | src_tgt_dataset = src_tgt_dataset.map( 77 | lambda src, tgt: (tf.cast(vocab_table.lookup(src), tf.int32), 78 | tf.cast(vocab_table.lookup(tgt), tf.int32)), 79 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 80 | # Create a tgt_input prefixed with and a tgt_output suffixed with . 81 | src_tgt_dataset = src_tgt_dataset.map( 82 | lambda src, tgt: (src, 83 | tf.concat(([sos_id], tgt), 0), 84 | tf.concat((tgt, [eos_id]), 0)), 85 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) 86 | # Add in sequence lengths. 
87 | src_tgt_dataset = src_tgt_dataset.map(
88 | lambda src, tgt_in, tgt_out: (
89 | src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
90 | num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
91 | 
92 | # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...)
93 | def batching_func(x):
94 | return x.padded_batch(
95 | batch_size,
96 | # The first three entries are the source and target line rows;
97 | # these have unknown-length vectors. The last two entries are
98 | # the source and target row sizes; these are scalars.
99 | padded_shapes=(
100 | tf.TensorShape([None]), # src
101 | tf.TensorShape([None]), # tgt_input
102 | tf.TensorShape([None]), # tgt_output
103 | tf.TensorShape([]), # src_len
104 | tf.TensorShape([])), # tgt_len
105 | # Pad the source and target sequences with eos tokens.
106 | # (Though notice we don't generally need to do this since
107 | # later on we will be masking out calculations past the true sequence.)
108 | padding_values=(
109 | eos_id, # src
110 | eos_id, # tgt_input
111 | eos_id, # tgt_output
112 | 0, # src_len -- unused
113 | 0)) # tgt_len -- unused
114 | 
115 | if num_buckets > 1:
116 | 
117 | def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
118 | # Calculate bucket_width by maximum source sequence length.
119 | # Pairs with length [0, bucket_width) go to bucket 0, length
120 | # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length
121 | # over ((num_bucket-1) * bucket_width) words all go into the last bucket.
122 | if src_max_len:
123 | bucket_width = (src_max_len + num_buckets - 1) // num_buckets
124 | else:
125 | bucket_width = 10
126 | 
127 | # Bucket sentence pairs by the length of their source sentence and target
128 | # sentence.
129 | bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
130 | return tf.to_int64(tf.minimum(num_buckets, bucket_id))
131 | 
132 | def reduce_func(unused_key, windowed_data):
133 | return batching_func(windowed_data)
134 | 
135 | batched_dataset = src_tgt_dataset.apply(
136 | tf.contrib.data.group_by_window(
137 | key_func=key_func, reduce_func=reduce_func, window_size=batch_size))
138 | 
139 | else:
140 | batched_dataset = batching_func(src_tgt_dataset)
141 | batched_iter = batched_dataset.make_initializable_iterator()
142 | (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len,
143 | tgt_seq_len) = (batched_iter.get_next())
144 | return BatchedInput(
145 | initializer=batched_iter.initializer,
146 | sources=src_ids,
147 | target_input=tgt_input_ids,
148 | target_output=tgt_output_ids,
149 | source_sequence_lengths=src_seq_len,
150 | target_sequence_length=tgt_seq_len)
151 | 
152 | 
153 | def get_infer_iterator(test_dataset,
154 | vocab_table,
155 | batch_size,
156 | src_max_len=None):
157 | eos_id = tf.constant(vocab.EOS_ID, dtype=tf.int32)
158 | 
159 | def tokenize(line):
160 | utterances = tf.string_split([line], delimiter="\t").values
161 | i, sp = tf.constant(0), tf.Variable([], dtype=tf.string)
162 | cond = lambda i, sp: tf.less(i, tf.size(utterances) - 1)
163 | 
164 | def loop_body(i, sp):
165 | splitted = tf.string_split([utterances[i]]).values
166 | if src_max_len:
167 | splitted = tf.cond(tf.less(i, tf.size(utterances) - 2),
168 | lambda: splitted[:src_max_len - 1],
169 | lambda: splitted[:src_max_len])
170 | 
171 | splitted = tf.cond(tf.less(i, tf.size(utterances) - 2),
172 | lambda: tf.concat([splitted, [vocab.SEP]], axis=0),
173 | lambda: splitted)
174 | 
175 | return tf.add(i, 1), tf.concat([sp, splitted], axis=0)
176 | 
177 | _, srcs = tf.while_loop(cond, loop_body, [i, sp], shape_invariants=[i.get_shape(), tf.TensorShape([None])])
178 | 
179 | aggregated_src = tf.reduce_join([srcs], axis=0, separator=" ")
180 | return aggregated_src
181 | 
182 | test_dataset = test_dataset.map(tokenize)
183 | 
184 | if src_max_len:
185 | test_dataset = test_dataset.map(lambda src: src[:src_max_len])
186 | # Convert the word strings to ids
187 | test_dataset = test_dataset.map(lambda src: tf.cast(vocab_table.lookup(src), tf.int32))
188 | 
189 | # Add in the word counts.
190 | test_dataset = test_dataset.map(lambda src: (src, tf.size(src)))
191 | 
192 | def batching_func(x):
193 | return x.padded_batch(
194 | batch_size,
195 | # The entry is the source line rows;
196 | # this has unknown-length vectors. The last entry is
197 | # the source row size; this is a scalar.
198 | padded_shapes=(
199 | tf.TensorShape([None]), # src
200 | tf.TensorShape([])), # src_len
201 | # Pad the source sequences with eos tokens.
202 | # (Though notice we don't generally need to do this since
203 | # later on we will be masking out calculations past the true sequence.)
204 | padding_values=(
205 | eos_id, # src
206 | 0)) # src_len -- unused
207 | 
208 | batched_dataset = batching_func(test_dataset)
209 | batched_iter = batched_dataset.make_initializable_iterator()
210 | (src_ids, src_seq_len) = batched_iter.get_next()
211 | return BatchedInput(
212 | initializer=batched_iter.initializer,
213 | sources=src_ids,
214 | target_input=None,
215 | target_output=None,
216 | source_sequence_lengths=src_seq_len,
217 | target_sequence_length=None)
218 | 
-------------------------------------------------------------------------------- /thred/topic_model/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/topic_model/__init__.py
-------------------------------------------------------------------------------- /thred/topic_model/analyzer.py: --------------------------------------------------------------------------------
1 | import string
2 | 
3 | from ..util.nlp import NLPToolkit
4 | 
5 | enhanced_punctuation = string.punctuation + \
6 | u'\u2012' + u'\u2013' + u'\u2014' + u'\u2015' + u'\u2018' + u'\u2019' + u'\u201C' + u'\u201D' + \
7 | u'\u2212' + u'\u2026'
8 | translate_table = dict((ord(char), u'') for char in enhanced_punctuation)
9 | 
10 | contractions = {"'ll", "'ve", "'re", "n't", "doesn't", "don't", "i'm"}
11 | 
12 | 
13 | def normalize(word, min_length=None):
14 | """
15 | converts a term to lower case, rejects stop words, contractions, and
16 | too-short words (by raising a Warning), and strips punctuation
17 | """
18 | 
19 | term = word.lower()
20 | 
21 | if NLPToolkit.is_stopword(term) or term in contractions:
22 | raise Warning("stopwords are not normalized here")
23 | 
24 | if min_length is not None and len(word) < min_length:
25 | raise Warning("word is too short")
26 | 
27 | term = term.translate(translate_table)
28 | 
29 | if not term:
30 | raise Warning("after normalization word became empty")
31 | 
32 | return term
33 | 
34 | 
35 | def normalize_sequence(words, min_length=None):
36 | normalized = []
37 | for word in words:
38 | try:
39 | term = normalize(word, min_length)
40 | normalized.append(term)
41 | except Warning:
42 | continue
43 | 
44 | return normalized
45 | 
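To make the behaviour of `normalize_sequence` concrete, a tiny illustrative call (not from the repository; the exact output depends on the NLPToolkit stop-word list):

```python
tokens = "The Topic Model LEARNS latent TOPICS !!".split()
print(normalize_sequence(tokens))
# e.g. ['topic', 'model', 'learns', 'latent', 'topics']
# "The" is rejected as a stop word and "!!" becomes empty after
# punctuation stripping, so both are silently dropped.
```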
-------------------------------------------------------------------------------- /thred/topic_model/lda.py: --------------------------------------------------------------------------------
1 | import codecs
2 | import logging
3 | from os import listdir, mkdir
4 | from os.path import isdir, exists, join, abspath
5 | 
6 | import gensim
7 | import yaml
8 | from gensim import corpora
9 | 
10 | from . import analyzer
11 | from ..util import fs
12 | from ..util.misc import Stopwatch
13 | 
14 | 
15 | class LDAArgs(dict):
16 | def __init__(self, params=None, *args, **kwargs):
17 | super(LDAArgs, self).__init__(*args, **kwargs)
18 | self.update(params)
19 | self.__dict__ = self
20 | 
21 | def save(self, args_file):
22 | to_dump_dict = dict(self.__dict__)
23 | to_dump_dict['documents'] = abspath(to_dump_dict['documents'])
24 | 
25 | with codecs.getwriter("utf-8")(open(args_file, "wb")) as f:
26 | yaml.dump(to_dump_dict, f, default_flow_style=False)
27 | 
28 | @staticmethod
29 | def load(args_file):
30 | with codecs.getreader("utf-8")(open(args_file, "rb")) as f:
31 | params = yaml.safe_load(f)
32 | 
33 | return LDAArgs(params=params)
34 | 
35 | 
36 | def iter_corpus(documents, min_length=None):
37 | if not exists(documents):
38 | raise ValueError('The documents data does not exist: {}'.format(documents))
39 | 
40 | all_docs = []
41 | sw = Stopwatch()
42 | 
43 | if isdir(documents):
44 | print('Documents stored as files in directory "{}"'.format(documents))
45 | 
46 | files = listdir(documents)
47 | 
48 | for i, f in enumerate(files):
49 | if not f.endswith('.txt'):
50 | continue
51 | 
52 | file_path = join(documents, f)
53 | 
54 | doc = []
55 | with codecs.getreader("utf-8")(open(file_path, 'rb')) as doc_file:
56 | for line in doc_file:
57 | words = line.split()
58 | for word in words:
59 | try:
60 | term = analyzer.normalize(word, min_length)
61 | doc.append(term)
62 | except Warning:
63 | continue
64 | 
65 | all_docs.append(doc)
66 | if i % 1000 == 0:
67 | sw.print(' {} of {} iterated'.format(i, len(files)))
68 | else:
69 | print('Documents stored one per line in file "{}"'.format(documents))
70 | with codecs.getreader("utf-8")(open(documents, 'rb')) as f:
71 | for i, line in enumerate(f):
72 | doc = analyzer.normalize_sequence(line.split(), min_length)
73 | 
74 | if doc:
75 | all_docs.append(doc)
76 | 
77 | if i % 100000 == 0:
78 | sw.print(' {} lines iterated'.format(i))
79 | 
80 | sw.print('corpus built')
81 | return all_docs
82 | 
83 | 
84 | def train(model_dir, args):
85 | if not exists(model_dir):
86 | mkdir(model_dir)
87 | 
88 | corpus = iter_corpus(args.documents, args.min_length)
89 | dictionary = corpora.Dictionary(corpus)
90 | dictionary.filter_extremes(no_below=args.no_below)
91 | 
92 | mm_corpus_file = join(model_dir, 'corpus.mm')
93 | 
94 | if not exists(mm_corpus_file):
95 | print("corpus not found. 
Starting to build it...") 96 | 97 | class CorpusWrapper: 98 | 99 | def __init__(self, dictionary): 100 | self._dictionary = dictionary 101 | 102 | def __iter__(self): 103 | for tokens in corpus: 104 | yield self._dictionary.doc2bow(tokens) 105 | 106 | gensim.corpora.MmCorpus.serialize(mm_corpus_file, CorpusWrapper(dictionary)) 107 | 108 | mm_corpus = gensim.corpora.MmCorpus(mm_corpus_file) 109 | 110 | # generate LDA model 111 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 112 | ldamodel = gensim.models.LdaMulticore(mm_corpus, 113 | id2word=dictionary, 114 | alpha='asymmetric', eta='auto', 115 | num_topics=args.num_topics, 116 | passes=args.passes, 117 | eval_every=args.eval_every, 118 | batch=True, 119 | chunksize=args.chunksize, 120 | iterations=args.iterations) 121 | print("Saving LDA model...") 122 | ldamodel.save(join(model_dir, 'LDA.model')) 123 | 124 | print("Saving words for topics...") 125 | with open(join(model_dir, 'TopicWords.txt'), 'w') as topic_file: 126 | for i in range(args.num_topics): 127 | topic_file.write('Topic #{}:\n\t'.format(i)) 128 | topic_words_ids = [x[0] for x in ldamodel.get_topic_terms(i, topn=args.words_per_topic)] 129 | topic_file.write('\n\t'.join([dictionary[x] for x in topic_words_ids]) + '\n') 130 | 131 | args.save(join(model_dir, 'config.yml')) 132 | 133 | 134 | class TopicInferer: 135 | 136 | def __init__(self, model_dir, verbose=True): 137 | self._model_dir = model_dir 138 | self._verbose = verbose 139 | if self._verbose: 140 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 141 | self._params = LDAArgs.load(join(model_dir, 'config.yml')) 142 | self._ldamodel = gensim.models.LdaMulticore.load(join(model_dir, 'LDA.model')) 143 | 144 | def _init_words_per_topics(self, words_per_topic): 145 | topic_word_dict = {} 146 | for t_id in range(self._params.num_topics): 147 | topic_words_ids = [x[0] for x in 148 | self._ldamodel.get_topic_terms( 149 | t_id, 150 | topn=(words_per_topic or self._params.words_per_topic))] 151 | topic_word_dict[t_id] = [self._ldamodel.id2word[x] for x in topic_words_ids] 152 | 153 | return topic_word_dict 154 | 155 | def from_collection(self, test_collection, dialogue_as_doc=False, words_per_topic=None): 156 | topic_word_dict = self._init_words_per_topics(words_per_topic) 157 | output = [] 158 | for lno, sample in enumerate(test_collection): 159 | utterances = sample.strip().split('\t') 160 | 161 | if dialogue_as_doc: 162 | words = ' '.join(utterances).split() 163 | dialogue_terms = analyzer.normalize_sequence(words) 164 | 165 | doc = self._ldamodel.id2word.doc2bow(dialogue_terms) 166 | topic_ids = self._ldamodel.get_document_topics(doc) 167 | if len(topic_ids) > 0: 168 | t_id = sorted(topic_ids, key=lambda x: x[1], reverse=True)[0][0] 169 | output.append((t_id, topic_word_dict[t_id])) 170 | else: 171 | output.append((-1, [])) 172 | else: 173 | t_ids, t_words = [], [] 174 | for i, utterance in enumerate(utterances): 175 | utterance_terms = analyzer.normalize_sequence(utterance.split()) 176 | 177 | doc = self._ldamodel.id2word.doc2bow(utterance_terms) 178 | topic_ids = self._ldamodel.get_document_topics(doc) 179 | if len(topic_ids) > 0: 180 | t_id = sorted(topic_ids, key=lambda x: x[1], reverse=True)[0][0] 181 | t_ids.append(t_id) 182 | t_words.append(topic_word_dict[t_id]) 183 | else: 184 | t_ids.append(-1) 185 | t_words.append([]) 186 | 187 | output.append((t_ids, t_words)) 188 | 189 | return output 190 | 191 | def from_file(self, test_data, 
output_file, dialogue_as_doc=False, words_per_topic=None):
192 | topic_word_dict = self._init_words_per_topics(words_per_topic)
193 | 
194 | if output_file is None:
195 | output_file = fs.replace_ext(test_data, 'topical.txt')
196 | 
197 | sw = Stopwatch()
198 | 
199 | with codecs.getreader('utf-8')(open(test_data, 'rb')) as test_file:
200 | with codecs.getwriter('utf-8')(open(output_file, 'wb')) as out_file:
201 | for lno, line in enumerate(test_file):
202 | utterances = line.strip().split('\t')
203 | out_file.write(line.strip() + "\t")
204 | 
205 | if lno % 100000 == 0:
206 | sw.print(' {} lines inferred'.format(lno))
207 | 
208 | if dialogue_as_doc:
209 | words = ' '.join(utterances[:-1]).split()
210 | dialogue_terms = analyzer.normalize_sequence(words)
211 | 
212 | doc = self._ldamodel.id2word.doc2bow(dialogue_terms)
213 | topic_ids = self._ldamodel.get_document_topics(doc)
214 | if len(topic_ids) > 0:
215 | t_id = sorted(topic_ids, key=lambda x: x[1], reverse=True)[0][0]
216 | out_file.write(' '.join(topic_word_dict[t_id]))
217 | else:
218 | out_file.write('')
219 | else:
220 | for i, utterance in enumerate(utterances):
221 | utterance_terms = analyzer.normalize_sequence(utterance.split())
222 | 
223 | doc = self._ldamodel.id2word.doc2bow(utterance_terms)
224 | topic_ids = self._ldamodel.get_document_topics(doc)
225 | if len(topic_ids) > 0:
226 | t_id = sorted(topic_ids, key=lambda x: x[1], reverse=True)[0][0]
227 | out_file.write(
228 | ' '.join(topic_word_dict[t_id]) + ('\t' if i < len(utterances) - 1 else ''))
229 | else:
230 | out_file.write('')
231 | 
232 | out_file.write('\n')
233 | 
234 | sw.print('Done!!!')
235 | 
236 | 
237 | def main():
238 | import argparse
239 | 
240 | parser = argparse.ArgumentParser()
241 | parser.add_argument('--mode', type=str, required=True, choices=("train", "infer"), help='mode')
242 | parser.add_argument('--model_dir', type=str, required=True, help='model directory')
243 | parser.add_argument('--data', type=str,
244 | help='data (if directory, each document is a file, or else each document is a line)')
245 | parser.add_argument('--n_topics', type=int, default=200, help='number of topics')
246 | parser.add_argument('--no_below', type=int, default=600,
247 | help='terms with frequency lower than this argument will be dropped')
248 | parser.add_argument('--min_length', type=int, help='min length of words')
249 | parser.add_argument('--test_data', type=str, help='test data')
250 | parser.add_argument('--dialogue_as_doc', action='store_true', help='treats whole dialogue as document')
251 | parser.add_argument('--output', type=str, help='output file')
252 | 
253 | args = parser.parse_args()
254 | if args.mode == 'train':
255 | _params = {
256 | "num_topics": args.n_topics,
257 | "documents": args.data,
258 | "no_below": args.no_below,
259 | "min_length": args.min_length,
260 | "passes": 70,
261 | "eval_every": 10,
262 | "chunksize": 10000,
263 | "iterations": 1000,
264 | "words_per_topic": 100
265 | }
266 | 
267 | print("Training starts with arguments: {}".format(_params))
268 | train(args.model_dir, LDAArgs(_params))
269 | elif args.mode == 'infer':
270 | TopicInferer(args.model_dir).from_file(args.test_data, args.output, args.dialogue_as_doc)
271 | 
272 | 
273 | if __name__ == "__main__":
274 | main()
275 | 
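Besides the CLI above, `TopicInferer` can also be used programmatically. A minimal sketch (illustrative only; `lda_model_dir` stands in for a directory produced by `--mode train`):

```python
inferer = TopicInferer("lda_model_dir", verbose=False)

# Each sample is a tab-separated dialogue; with dialogue_as_doc=True the
# whole dialogue is treated as a single LDA document.
samples = ["any plans for the long weekend\tthinking about a camping trip"]
for topic_id, topic_words in inferer.from_collection(samples, dialogue_as_doc=True):
    print(topic_id, topic_words[:5])
```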
-------------------------------------------------------------------------------- /thred/util/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/nouhadziri/THRED/c451e801d6d36c21f52a42192b15f985241650b2/thred/util/__init__.py
-------------------------------------------------------------------------------- /thred/util/chartable.py: --------------------------------------------------------------------------------
1 | # ASCII equivalents, in order, of the fullwidth forms starting at U+FF01
2 | s = '!"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~'
3 | capital_alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
4 | 
5 | 
6 | def _build_halfwidth():
7 | charmap = {}
8 | 
9 | msb, mb, lsb = tuple('\uff01'.encode('utf-8'))
10 | for ch in s:
11 | charmap[bytes([msb, mb, lsb]).decode('utf-8')] = ch
12 | 
13 | lsb += 1
14 | if lsb == 0xc0:
15 | lsb = 0x80
16 | mb += 1
17 | 
18 | assert len(charmap) == len(s)
19 | return charmap
20 | 
21 | 
22 | def _build_enclosed_supplements():
23 | 
24 | charmap = {
25 | bytes([0xf0, 0x9f, 0x84, 0x80]).decode('utf-8'): '0.',
26 | bytes([0xf0, 0x9f, 0x84, 0x8B]).decode('utf-8'): '0',
27 | bytes([0xf0, 0x9f, 0x84, 0x8C]).decode('utf-8'): '0',
28 | bytes([0xf0, 0x9f, 0x84, 0xAA]).decode('utf-8'): '[S]',
29 | bytes([0xf0, 0x9f, 0x84, 0xAB]).decode('utf-8'): 'C',
30 | bytes([0xf0, 0x9f, 0x84, 0xAC]).decode('utf-8'): 'R',
31 | bytes([0xf0, 0x9f, 0x84, 0xAD]).decode('utf-8'): 'CD',
32 | bytes([0xf0, 0x9f, 0x84, 0xAE]).decode('utf-8'): 'Wz',
33 | bytes([0xf0, 0x9f, 0x85, 0x8A]).decode('utf-8'): 'HV',
34 | bytes([0xf0, 0x9f, 0x85, 0x8B]).decode('utf-8'): 'MV',
35 | bytes([0xf0, 0x9f, 0x85, 0x8C]).decode('utf-8'): 'SD',
36 | bytes([0xf0, 0x9f, 0x85, 0x8D]).decode('utf-8'): 'SS',
37 | bytes([0xf0, 0x9f, 0x85, 0x8E]).decode('utf-8'): 'PPV',
38 | bytes([0xf0, 0x9f, 0x85, 0x8F]).decode('utf-8'): 'WC',
39 | bytes([0xf0, 0x9f, 0x85, 0xAA]).decode('utf-8'): 'MC',
40 | bytes([0xf0, 0x9f, 0x85, 0xAB]).decode('utf-8'): 'MD',
41 | bytes([0xf0, 0x9f, 0x86, 0x8A]).decode('utf-8'): 'P',
42 | bytes([0xf0, 0x9f, 0x86, 0x8B]).decode('utf-8'): 'IC',
43 | bytes([0xf0, 0x9f, 0x86, 0x8C]).decode('utf-8'): 'PA',
44 | bytes([0xf0, 0x9f, 0x86, 0x8D]).decode('utf-8'): 'SA',
45 | bytes([0xf0, 0x9f, 0x86, 0x8E]).decode('utf-8'): 'AB',
46 | bytes([0xf0, 0x9f, 0x86, 0x8F]).decode('utf-8'): 'WC',
47 | bytes([0xf0, 0x9f, 0x86, 0x90]).decode('utf-8'): 'DJ',
48 | bytes([0xf0, 0x9f, 0x86, 0x91]).decode('utf-8'): 'CL',
49 | bytes([0xf0, 0x9f, 0x86, 0x92]).decode('utf-8'): 'COOL',
50 | bytes([0xf0, 0x9f, 0x86, 0x93]).decode('utf-8'): 'FREE',
51 | bytes([0xf0, 0x9f, 0x86, 0x94]).decode('utf-8'): 'ID',
52 | bytes([0xf0, 0x9f, 0x86, 0x95]).decode('utf-8'): 'NEW',
53 | bytes([0xf0, 0x9f, 0x86, 0x96]).decode('utf-8'): 'NG',
54 | bytes([0xf0, 0x9f, 0x86, 0x97]).decode('utf-8'): 'OK',
55 | bytes([0xf0, 0x9f, 0x86, 0x98]).decode('utf-8'): 'SOS',
56 | bytes([0xf0, 0x9f, 0x86, 0x99]).decode('utf-8'): 'UP!',
57 | bytes([0xf0, 0x9f, 0x86, 0x9A]).decode('utf-8'): 'VS',
58 | bytes([0xf0, 0x9f, 0x86, 0x9B]).decode('utf-8'): '3D',
59 | bytes([0xf0, 0x9f, 0x86, 0x9C]).decode('utf-8'): '3D',
60 | bytes([0xf0, 0x9f, 0x86, 0x9D]).decode('utf-8'): '2K',
61 | bytes([0xf0, 0x9f, 0x86, 0x9E]).decode('utf-8'): '4K',
62 | bytes([0xf0, 0x9f, 0x86, 0x9F]).decode('utf-8'): '8K',
63 | bytes([0xf0, 0x9f, 0x86, 0xA0]).decode('utf-8'): '5.1',
64 | bytes([0xf0, 0x9f, 0x86, 0xA1]).decode('utf-8'): '7.1',
65 | bytes([0xf0, 0x9f, 0x86, 0xA2]).decode('utf-8'): '22.2',
66 | bytes([0xf0, 0x9f, 0x86, 0xA3]).decode('utf-8'): '60P',
67 | bytes([0xf0, 0x9f, 0x86, 0xA4]).decode('utf-8'): '120P',
68 | bytes([0xf0, 0x9f, 0x86, 0xA5]).decode('utf-8'): 'd',
69 | bytes([0xf0, 0x9f, 0x86, 0xA6]).decode('utf-8'): 
'HC', 70 | bytes([0xf0, 0x9f, 0x86, 0xA7]).decode('utf-8'): 'HDR', 71 | bytes([0xf0, 0x9f, 0x86, 0xA8]).decode('utf-8'): 'Hi-Res', 72 | bytes([0xf0, 0x9f, 0x86, 0xA9]).decode('utf-8'): 'Lossless', 73 | bytes([0xf0, 0x9f, 0x86, 0xAA]).decode('utf-8'): 'SHV', 74 | bytes([0xf0, 0x9f, 0x86, 0xAB]).decode('utf-8'): 'UHD', 75 | bytes([0xf0, 0x9f, 0x86, 0xAC]).decode('utf-8'): 'VOD', 76 | } 77 | 78 | b3, b2, b1, b0 = (0xf0, 0x9f, 0x84, 0x80) 79 | for n in range(10): 80 | charmap[bytes([b3, b2, b1, b0]).decode('utf-8')] = '{},'.format(n) 81 | b0 += 1 82 | 83 | b3, b2, b1, b0 = (0xf0, 0x9f, 0x84, 0x90) 84 | alphas = [[0xf0, 0x9f, 0x84, 0xB0], [0xf0, 0x9f, 0x85, 0x90], [0xf0, 0x9f, 0x85, 0xB0], [0xf0, 0x9f, 0x87, 0xA6]] 85 | for l in capital_alphabet: 86 | charmap[bytes([b3, b2, b1, b0]).decode('utf-8')] = '({})'.format(l) 87 | b0 += 1 88 | for i, code in enumerate(alphas): 89 | charmap[bytes(code).decode('utf-8')] = l 90 | code[3] += 1 91 | if code[3] == 0xc0: 92 | code[3] = 0x80 93 | code[2] += 1 94 | 95 | return charmap 96 | 97 | 98 | def _build_extended_map(): 99 | charmap = { 100 | '\u1D00': 'A', '\u1D03': 'B', '\u1D04': 'C', '\u1D05': 'D', 101 | '\u1D07': 'E', '\u1D0A': 'J', '\u1D0B': 'K', '\u1D0C': 'L', 102 | '\u1D0D': 'M', '\u1D0F': 'O', '\u1D18': 'K', '\u1D1B': 'T', 103 | '\u1D1C': 'U', '\u1D20': 'V', '\u1D21': 'W', '\u1D22': 'Z', 104 | '\u1D29': 'P', '\u1D2C': 'A', '\u1D2E': 'B', '\u1D30': 'D', 105 | '\u1D31': 'E', '\u1D33': 'G', '\u1D34': 'H', '\u1D35': 'I', 106 | '\u1D36': 'J', '\u1D37': 'K', '\u1D38': 'L', '\u1D39': 'M', 107 | '\u1D3A': 'N', '\u1D3C': 'O', '\u1D3E': 'P', '\u1D3F': 'R', 108 | '\u1D40': 'T', '\u1D41': 'U', '\u1D42': 'W', '\u1D43': 'a', 109 | '\u1D45': 'a', '\u1D47': 'b', '\u1D48': 'd', '\u1D49': 'e', 110 | '\u1D4B': 'e', '\u1D4D': 'g', '\u1D4F': 'k', '\u1D50': 'm', 111 | '\u1D52': 'o', '\u1D56': 'p', '\u1D57': 't', '\u1D58': 'u', 112 | '\u1D5B': 'v', '\u1D63': 'r', '\u1D64': 'u', '\u1D65': 'v', 113 | '\u0391': 'A', '\u0392': 'B', '\u0395': 'E', '\u0396': 'Z', 114 | '\u0397': 'H', '\u0399': 'I', '\u039A': 'K', '\u039C': 'M', 115 | '\u039D': 'N', '\u039F': 'O', '\u03A1': 'P', '\u03A4': 'T', 116 | '\u03A5': 'Y', '\u03A7': 'X', '\u03BA': 'k', '\u03BD': 'v', 117 | '\u03BF': 'o', '\u03C7': 'X', '\u03DC': 'F', '\u03DD': 'F', 118 | '\u03E6': 'b', '\u03E4': 'q', '\u03E5': 'q', '\u03F2': 'c', 119 | '\u03F3': 'j', '\u03F9': 'C', '\u03FA': 'M', '\u03FB': 'c', 120 | '\u0251': 'a', '\u0256': 'd', '\u0257': 'd', '\u0260': 'g', 121 | '\u0261': 'g', '\u0262': 'G', '\u0266': 'h', '\u0267': 'h', 122 | '\u0268': 'i', '\u026A': 'I', '\u026B': 'l', '\u026C': 'l', 123 | '\u026D': 'l', '\u0271': 'm', '\u0272': 'n', '\u0273': 'n', 124 | '\u0274': 'N', '\u027C': 'r', '\u027D': 'r', '\u028D': 'm', 125 | '\u027E': 'r', '\u0280': 'R', '\u0282': 's', '\u0287': 't', 126 | '\u0288': 't', '\u0289': 'u', '\u028A': 'u', '\u028B': 'v', 127 | '\u028F': 'Y', '\u0290': 'z', '\u0291': 'z', '\u0297': 'C', 128 | '\u0299': 'B', '\u029C': 'H', '\u029D': 'j', '\u029F': 'L', 129 | '\u025B': 'e', '\u026E': 'b', '\u0284': 'f', 130 | '\u0531': 'u', '\u0532': 'f', '\u0533': 'q', 131 | '\u0535': 't', '\u0537': 't', '\u053A': 'd', '\u053B': 'r', 132 | '\u053C': 'L', '\u0544': 'U', '\u0548': 'n', '\u054D': 'U', 133 | '\u054F': 'S', '\u0550': 'r', '\u0555': 'O', '\u0556': 'S', 134 | '\u0559': "'", '\u055A': "'", '\u055B': '`', '\u055D': '`', 135 | '\u0562': "f", '\u0563': "q", '\u0564': 'n', '\u0565': 't', 136 | '\u0566': "q", '\u0567': "t", '\u0569': 'p', '\u056A': 'd', 137 | '\u056B': "h", '\u056C': "l", '\u0570': 'h', 
'\u0572': 'n', 138 | '\u0574': "u", '\u0575': "J", '\u0576': 'u', '\u0577': '2', 139 | '\u0578': "n", '\u057C': "n", '\u057D': 'u', '\u0580': 'n', 140 | '\u0581': "g", '\u0582': "L", '\u0584': 'p', '\u0585': 'o', 141 | '\u0586': 'S', '\u0587': "u", '\u0589': ":", 142 | '\u00DE': 'p', '\u00E0': 'a', '\u00E1': 'a', '\u00E2': 'a', 143 | '\u00E3': 'a', '\u00E4': 'a', '\u00E5': 'a', '\u00DF': 'b', 144 | '\u00E8': 'e', '\u00E9': 'e', '\u00EA': 'e', '\u00EB': 'e', 145 | '\u00EC': 'i', '\u00ED': 'i', '\u00EE': 'i', '\u00EF': 'i', 146 | '\u00FE': 'b', 147 | '\u0180': "b", '\u0181': "B", '\u0182': "b", '\u0183': "b", 148 | '\u0184': "b", '\u0185': "b", '\u0187': "C", '\u0188': "c", 149 | '\u0189': "D", '\u018A': "D", '\u0190': "e", '\u0191': "F", 150 | '\u0192': "f", '\u0193': "G", '\u0196': "I", '\u0197': "I", 151 | '\u0198': "K", '\u0199': "k", '\u019A': "l", '\u019D': "N", 152 | '\u019E': "n", '\u019F': "O", '\u01A0': "O", '\u01A1': "o", 153 | '\u01A4': "P", '\u01A5': "p", '\u01A6': "R", '\u01AC': "T", 154 | '\u01AD': "t", '\u01AE': "T", '\u01AF': "U", '\u01B0': "u", 155 | '\u01B1': "U", '\u01B2': "V", '\u01B3': "Y", '\u01B4': "y", 156 | '\u01B5': "Z", '\u01B6': "z", '\u01BC': "5", '\u01BD': "5", 157 | '\u01BB': "2", '\u01BA': "3", '\u01C3': "!", '\u01C4': "DZ", 158 | '\u01C5': "Dz", '\u01C6': "dz", '\u01C7': "LJ", '\u01C8': "Lj", 159 | '\u01C9': "lj", '\u01CA': "NJ", '\u01CB': "Nj", '\u01CC': "nj", 160 | '\u01CD': "A", '\u01CE': "a", '\u01CF': "I", '\u01D0': "i", 161 | '\u01D1': "O", '\u01D2': "o", '\u01D3': "U", '\u01D4': "u", 162 | '\u01D5': "U", '\u01D6': "u", '\u01D7': "U", '\u01D8': "u", 163 | '\u01D9': "U", '\u01DA': "u", '\u01DB': "U", '\u01DC': "u", 164 | '\u01DE': "A", '\u01DF': "a", '\u01E0': "A", '\u01E1': "a", 165 | '\u01E4': "G", '\u01E5': "g", '\u01E6': "G", '\u01E7': "g", 166 | '\u01E8': "K", '\u01E9': "k", '\u01EA': "Q", '\u01EB': "q", 167 | '\u01EC': "Q", '\u01ED': "q", '\u01F0': "j", '\u01F1': "DZ", 168 | '\u01F2': "Dz", '\u01F3': "dz", '\u01F4': "G", '\u01F5': "g", 169 | '\u01F6': "H", '\u01F8': "N", '\u01F9': "n", '\u01FA': "A", 170 | '\u01FB': "a", '\u01FE': "O", '\u01FF': "o", '\u0200': "A", 171 | '\u0201': "a", '\u0202': "A", '\u0203': "a", '\u0204': "E", 172 | '\u0205': "e", '\u0206': "E", '\u0207': "e", '\u0208': "I", 173 | '\u0209': "i", '\u020A': "I", '\u020B': "i", '\u020C': "O", 174 | '\u020D': "o", '\u020E': "O", '\u020F': "o", '\u0210': "R", 175 | '\u0211': "r", '\u0212': "R", '\u0213': "r", '\u0214': "U", 176 | '\u0215': "u", '\u0216': "U", '\u0217': "u", '\u0218': "S", 177 | '\u0219': "s", '\u021A': "T", '\u021B': "t", '\u021E': "H", 178 | '\u021F': "h", '\u0220': "n", '\u0221': "d", '\u0222': "8", 179 | '\u0223': "8", '\u0224': "Z", '\u0225': "z", '\u0226': "A", 180 | '\u0227': "a", '\u0228': "E", '\u0229': "e", '\u022A': "O", 181 | '\u022B': "o", '\u022C': "O", '\u022D': "o", '\u022E': "O", 182 | '\u022F': "o", '\u0230': "O", '\u0231': "o", '\u0232': "Y", 183 | '\u0233': "y", '\u0234': "l", '\u0235': "n", '\u0236': "t", 184 | '\u0237': "j", '\u023A': "A", '\u023B': "C", '\u023C': "c", 185 | '\u023D': "t", '\u023E': "T", '\u023F': "s", '\u0240': "z", 186 | '\u0241': "?", '\u0242': "?", '\u0243': "B", '\u0244': "U", 187 | '\u0245': "A", '\u0246': "E", '\u0247': "e", '\u0248': "J", 188 | '\u0249': "j", '\u024A': "q", '\u024B': "q", '\u024C': "R", 189 | '\u024D': "r", '\u024E': "Y", '\u024F': "y", 190 | '\u04C3': "K", '\u04C4': "k", 191 | } 192 | 193 | return charmap 194 | 195 | 196 | def get_table(): 197 | charmap = {} 198 | 199 | 
charmap.update(_build_halfwidth()) 200 | charmap.update(_build_enclosed_supplements()) 201 | charmap.update(_build_extended_map()) 202 | 203 | return charmap 204 | 205 | 206 | if __name__ == "__main__": 207 | get_table() 208 | -------------------------------------------------------------------------------- /thred/util/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | import yaml 6 | 7 | from . import fs, log 8 | 9 | 10 | class Config(dict): 11 | 12 | def __init__(self, *args, **kwargs) -> None: 13 | super(Config, self).__init__(*args, **kwargs) 14 | 15 | margs = self.__read_params() 16 | self.update(margs) 17 | 18 | pargs = dict(kwargs) 19 | empty_args = [arg for arg, val in pargs.items() if arg in margs and val is None] 20 | for arg in empty_args: 21 | pargs.pop(arg) 22 | 23 | # For same parameters, model arguments from config file are overwritten by program arguments 24 | self.update(pargs) 25 | self.__dict__ = self 26 | 27 | def __read_params(self): 28 | if self['mode'] == 'train': 29 | config_file = None 30 | 31 | if self['config'] is None and os.path.exists(self['model_dir']): 32 | for f in os.listdir(self['model_dir']): 33 | if f.endswith('_config.yml'): 34 | config_file = os.path.join(self['model_dir'], f) 35 | 36 | if config_file is None: 37 | missing_args = [] 38 | 39 | if self['train_data'] is None: 40 | missing_args.append('train_data') 41 | 42 | if self['dev_data'] is None: 43 | missing_args.append('dev_data') 44 | 45 | if self['config'] is None: 46 | missing_args.append('config') 47 | 48 | if missing_args: 49 | raise ValueError('In train mode, the following arguments are required: {}'.format( 50 | ', '.join(missing_args))) 51 | 52 | config_file = self['config'] 53 | 54 | if not os.path.exists(self['model_dir']): 55 | os.makedirs(self['model_dir']) 56 | elif self['restart_training']: 57 | _cleanup(self['model_dir']) 58 | else: 59 | if not os.path.exists(self['model_dir']): 60 | raise ValueError('model directory does not exist') 61 | 62 | config_file = None 63 | for f in os.listdir(self['model_dir']): 64 | if f.endswith('_config.yml'): 65 | config_file = os.path.join(self['model_dir'], f) 66 | 67 | if not config_file: 68 | raise ValueError('config file not found in model directory') 69 | 70 | with open(config_file, 'r') as file: 71 | model_args = yaml.safe_load(file) 72 | 73 | self._update_relative_paths(model_args) 74 | 75 | return model_args 76 | 77 | def get_infer_model_dir(self): 78 | for f in os.listdir(self.model_dir): 79 | if f == 'best_dev_ppl' and os.listdir(os.path.join(self.model_dir, f)): 80 | return os.path.join(self.model_dir, f) 81 | 82 | return self.model_dir 83 | 84 | def is_pretrain_enabled(self): 85 | return False 86 | 87 | def save(self): 88 | hparams_file = Path(self.model_dir) / "{}_config.yml".format(fs.file_name(self.config)) 89 | log.print_out(" Saving config to {}".format(hparams_file)) 90 | 91 | config_dict = dict(self.__dict__) 92 | 93 | # absolute paths 94 | if config_dict['train_data']: 95 | config_dict['train_data'] = Path(config_dict['train_data']).absolute().as_posix() 96 | if config_dict['test_data']: 97 | config_dict['test_data'] = Path(config_dict['test_data']).absolute().as_posix() 98 | if config_dict['dev_data']: 99 | config_dict['dev_data'] = Path(config_dict['dev_data']).absolute().as_posix() 100 | 101 | # relative paths 102 | if config_dict['vocab_file']: 103 | config_dict['vocab_file'] = Path(config_dict['vocab_file']).name 104 
| if config_dict['vocab_pkl']: 105 | config_dict['vocab_pkl'] = Path(config_dict['vocab_pkl']).name 106 | if config_dict['checkpoint_file']: 107 | config_dict['checkpoint_file'] = Path(config_dict['checkpoint_file']).name 108 | if config_dict.get('topic_vocab_file', ''): 109 | config_dict['topic_vocab_file'] = Path(config_dict['topic_vocab_file']).name 110 | if config_dict.get('best_dev_ppl_dir', ''): 111 | config_dict['best_dev_ppl_dir'] = Path(config_dict['best_dev_ppl_dir']).name 112 | 113 | with hparams_file.open("w", encoding="utf-8") as f: 114 | yaml.dump(config_dict, f, default_flow_style=False) 115 | 116 | def _update_relative_paths(self, args): 117 | model_path = Path(self["model_dir"]) 118 | if "vocab_file" in args: 119 | args["vocab_file"] = (model_path / args["vocab_file"]).as_posix() 120 | 121 | if "vocab_pkl" in args: 122 | args["vocab_pkl"] = (model_path / args["vocab_pkl"]).as_posix() 123 | 124 | if "checkpoint_file" in args: 125 | args["checkpoint_file"] = (model_path / args["checkpoint_file"]).as_posix() 126 | 127 | if "topic_vocab_file" in args: 128 | args["topic_vocab_file"] = (model_path / args["topic_vocab_file"]).as_posix() 129 | 130 | if "best_dev_ppl_dir" in args: 131 | args["best_dev_ppl_dir"] = (model_path / args["best_dev_ppl_dir"]).as_posix() 132 | 133 | 134 | def _cleanup(folder): 135 | for f in os.listdir(folder): 136 | file = os.path.join(folder, f) 137 | 138 | if '.ckpt' in f or f.startswith('log_') or \ 139 | f.lower() == 'checkpoint' or \ 140 | f.endswith('.shuf') or \ 141 | f.endswith('_config.yml'): 142 | os.remove(file) 143 | elif (f.endswith('_log') or f.startswith('best_')) and os.path.isdir(file): 144 | shutil.rmtree(file) 145 | 146 | log.print_out(" >> Heads up: model directory cleaned up!") 147 | 148 | 149 | if __name__ == "__main__": 150 | import argparse 151 | parser = argparse.ArgumentParser() 152 | args = parser.parse_args() 153 | config = Config(vars(args)) 154 | config.save() 155 | -------------------------------------------------------------------------------- /thred/util/device.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.client import device_lib 2 | 3 | from .misc import safe_mod 4 | 5 | 6 | class DeviceManager(object): 7 | """ 8 | Derived from https://gist.github.com/jovianlin/b5b2c4be45437992df58a7dd13dbafa7 9 | """ 10 | 11 | def __init__(self): 12 | local_device_protos = device_lib.list_local_devices() 13 | self.cpus = [] 14 | self.gpus = [] 15 | for dev in local_device_protos: 16 | if dev.device_type == 'GPU': 17 | self.gpus.append(dev.name) 18 | elif dev.device_type == 'CPU': 19 | self.cpus.append(dev.name) 20 | 21 | def get_default_device(self): 22 | if self.gpus: 23 | return self.gpus[0] 24 | 25 | return self.cpus[0] 26 | 27 | def num_available_gpus(self): 28 | return len(self.gpus) 29 | 30 | def gpu(self, index): 31 | return self.gpus[index] if self.gpus else self.get_default_device() 32 | 33 | def tail_gpu(self): 34 | return self.gpus[-1] if self.gpus else self.get_default_device() 35 | 36 | 37 | class RoundRobin(object): 38 | 39 | def __init__(self, device_manager): 40 | self.device_manager = device_manager 41 | 42 | def assign(self, n_devices, base=0): 43 | devices = [] 44 | 45 | for i in range(n_devices): 46 | devices.append(self.device_manager.gpu(safe_mod((base + i), self.device_manager.num_available_gpus()))) 47 | 48 | return devices 49 | 50 | 51 | if __name__ == "__main__": 52 | manager = DeviceManager() 53 | print(manager.gpus) 54 | 55 | round_robin = 
RoundRobin(manager) 56 | print(round_robin.assign(4)) 57 | -------------------------------------------------------------------------------- /thred/util/dull_responses.txt: -------------------------------------------------------------------------------- 1 | i do n't know . 2 | i do n't know what you 're talking about . 3 | i do n't know what you mean . 4 | i do n't know what you mean by that . 5 | i do n't know , i just do n't get it . 6 | i do n't know if you 're being sarcastic or not . 7 | i do n't know if you 're joking or not . 8 | i do n't know if it 's a joke $1 9 | i do n't know what you mean by that answer . 10 | i do n't know why you think that . 11 | i do n't know , i 've never heard of it . 12 | i 'm sorry , i 'm not sure what you mean . i 'm not sure what you 're saying . 13 | i 'm sorry , i do n't know what you 're talking about . 14 | i 'm not sure . 15 | i 'm not sure if i 'm being a $1 16 | i 'm not sure if it 's worth it , but i 'm not sure if it 's worth it . 17 | i 'm not sure if that 's a joke or not . 18 | i 'm not sure if that 's a good idea or not . 19 | i 'm not sure if you 're being sarcastic or not . 20 | i 'm not sure if you 're joking or not . 21 | i 'm not sure how that works , but i 'm not sure . 22 | i 'm not sure how you can do that . 23 | i 'm not sure how $1 24 | i 'm not sure what the point is . i 'm just saying that $1 25 | i 'm not sure what to do with this . 26 | i 'm not sure what you mean by $1 27 | i 'm not sure what you mean by that . i 'm just saying that $1 28 | i 'm not sure what you 're saying . 29 | i 'm not sure what you 're talking about . 30 | i 'm not sure what you 're trying to say . 31 | i 'm not sure what you mean by that . 32 | i 'm not sure what you mean by this . 33 | i 'm not sure what you mean . 34 | i 'm not sure , but i 'm not sure if it 's worth it . 35 | i 'm not sure . i 'm pretty sure it 's $1 36 | i 'm not even sure how to respond to this . 37 | i 'm not trying to be a dick , i 'm just trying to say $1 38 | i do n't understand what you mean . i 'm not saying $1 39 | i 'm not saying that , but i 'm saying that $1 40 | i 'm not saying it 's a bad thing , but it 's not a good thing . 41 | i 'm not saying it 's a bad thing , but i 'm not saying it 's $1 42 | i 'm not saying it 's $1 , but i 'm not saying $1 43 | i 'm not saying it 's not true . i 'm saying that $1 44 | i 'm not saying it 's not a bad thing , but i 'm saying that $1 45 | i 'm not saying you 're wrong , but i 'm saying that you 're $1 46 | i 'm in the same boat . 47 | i 'm not the only one who $1 48 | i 'm not going to be a dick , but i 'm not going to be a dick . 49 | i 'm not a huge fan of the $1, but $1 50 | i have no idea what you 're talking about , but i 'm not sure if $1 -------------------------------------------------------------------------------- /thred/util/embed.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import codecs 4 | from pathlib import Path 5 | from typing import Type, TypeVar, List, Dict, Any 6 | from os import environ, rename 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow_hub as hub 11 | import yaml 12 | from pymagnitude import Magnitude 13 | 14 | from . 
import fs, wget, vocab 15 | from .misc import Stopwatch 16 | 17 | T = TypeVar('T') 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class EmbeddingType( 22 | collections.namedtuple("EmbeddingType", ("name", "url", "dim", "src_type"))): 23 | pass 24 | 25 | 26 | class EmbeddingFactory: 27 | def __init__(self, embedding_type: EmbeddingType): 28 | self._embedding_type = embedding_type 29 | 30 | def build(self, vocab_list: List[str], **kwargs) -> (List[str], List[str], Dict[str, Dict[str, Any]]): 31 | """ 32 | :param vocab_list: The vocabulary list built upon the dataset 33 | :param kwargs: Additional parameters based on the factory type (e.g., 'init_weight', 'page_size') 34 | :return A tuple containing the Out-Of-Vocabulary words, the In-Vocabulary words, 35 | and a dict mapping each word to its embedding vector and a 'trainable' flag 36 | """ 37 | raise NotImplementedError() 38 | 39 | 40 | class RandomFactory(EmbeddingFactory): 41 | def __init__(self, embedding_type: EmbeddingType): 42 | super().__init__(embedding_type) 43 | 44 | def build(self, vocab_list: List[str], **kwargs) -> (List[str], List[str], Dict[str, Dict[str, Any]]): 45 | init_weight = kwargs.get('init_weight', 0.1) 46 | 47 | vec_dict = {} 48 | for w in vocab_list: 49 | vec = np.random.uniform(-init_weight, init_weight, size=self._embedding_type.dim) 50 | vec_dict[w] = { 51 | "vec": vec, 52 | "trainable": True 53 | } 54 | 55 | return vocab_list, [], vec_dict 56 | 57 | 58 | class MagnitudeFactory(EmbeddingFactory): 59 | def __init__(self, embedding_type: EmbeddingType): 60 | super().__init__(embedding_type) 61 | 62 | cache_dir = Path(fs.get_project_root_dir()) / ".magnitude" 63 | fs.mkdir_if_not_exists(cache_dir) 64 | embed_file = self._embedding_type.url[self._embedding_type.url.rfind("/") + 1:] 65 | compressed_file = Path(cache_dir) / embed_file 66 | if not compressed_file.exists(): 67 | logger.info(' Downloading magnitude file ("{}")...'.format(embed_file)) 68 | wget.download(self._embedding_type.url, compressed_file) 69 | 70 | self._embed_file = compressed_file 71 | logger.info(' Loading Magnitude module...') 72 | self._magnitude_vecs = Magnitude(self._embed_file) 73 | 74 | def build(self, vocab_list: List[str], **kwargs) -> (List[str], List[str], Dict[str, Dict[str, Any]]): 75 | oov, iov = [], [] 76 | 77 | vec_dict = {} 78 | for w in vocab_list: 79 | is_oov = w not in self._magnitude_vecs 80 | vec = self._magnitude_vecs.query(w) 81 | vec_dict[w] = { 82 | "vec": vec, 83 | "trainable": is_oov 84 | } 85 | 86 | if is_oov: 87 | oov.append(w) 88 | else: 89 | iov.append(w) 90 | 91 | return oov, iov, vec_dict 92 | 93 | 94 | class TfHubFactory(EmbeddingFactory): 95 | __cache = {} 96 | 97 | def __init__(self, embedding_type: EmbeddingType): 98 | super(TfHubFactory, self).__init__(embedding_type) 99 | 100 | def build(self, vocab_list: List[str], **kwargs) -> (List[str], List[str], Dict[str, Dict[str, Any]]): 101 | page_size = kwargs.get('page_size', 15000) 102 | init_weight = kwargs.get('init_weight', 0.1) 103 | 104 | logger.info(' Loading TensorFlow HUB module...') 105 | environ["TFHUB_CACHE_DIR"] = ".tfhub_modules" 106 | embedder = hub.Module(self._embedding_type.url) 107 | 108 | num_pages = len(vocab_list) // page_size 109 | oov, iov = [], [] 110 | vec_dict = {} 111 | with tf.Session() as sess: 112 | sess.run([tf.global_variables_initializer(), tf.tables_initializer()]) 113 | 114 | for i in range(num_pages + 1): 115 | lb = i * page_size 116 | ub = min((i + 1) * page_size, len(vocab_list)) 117 | 118 | page = vocab_list[lb:ub]
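if not page: continue # guard: the final page is empty whenever len(vocab_list) is an exact multiple of page_size; skipping it avoids running the embedder on an empty batch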
119 | embedding_vectors = sess.run(embedder(page)) 120 | 121 | for j, word in enumerate(page): # 'j' here so as not to shadow the page index 'i' of the outer loop 122 | is_oov = sum(embedding_vectors[j]) == 0 # an all-zero vector is assumed to signal a word unknown to the module 123 | if is_oov: 124 | vec = np.random.uniform(-init_weight, init_weight, self._embedding_type.dim) 125 | else: 126 | vec = embedding_vectors[j] 127 | 128 | vec_dict[word] = { 129 | "vec": vec, 130 | "trainable": is_oov 131 | } 132 | 133 | if is_oov: 134 | oov.append(word) 135 | else: 136 | iov.append(word) 137 | 138 | return oov, iov, vec_dict 139 | 140 | 141 | class EmbeddingUtil: 142 | def __init__(self, config_path: str='conf/word_embeddings.yml'): 143 | with open(config_path, 'r') as file: 144 | self._args = yaml.safe_load(file) # safe_load suffices for this plain-mapping config and avoids constructing arbitrary objects 145 | 146 | @classmethod 147 | def from_type(cls: Type[T], embedding_type: EmbeddingType) -> EmbeddingFactory: 148 | if embedding_type.src_type == "tfhub": 149 | return TfHubFactory(embedding_type) 150 | elif embedding_type.src_type == "magnitude": 151 | return MagnitudeFactory(embedding_type) 152 | elif embedding_type.src_type == "random": 153 | return RandomFactory(embedding_type) 154 | else: 155 | raise ValueError( 156 | "Unknown source type '{}' defined in the embedding config file".format(embedding_type.src_type)) 157 | 158 | @classmethod 159 | def load_vectors(cls: Type[T], vocab_pkl: str, vocab_file: str) -> (np.ndarray, np.ndarray): 160 | vocab_list, _ = vocab.load_vocab(vocab_file) 161 | 162 | reserved_words, other_words = [], [] 163 | vec_dict = fs.load_obj(vocab_pkl) 164 | for w in vocab_list: 165 | vec = vec_dict[w]["vec"] 166 | if w in vocab.RESERVED_WORDS: 167 | reserved_words.append(vec) 168 | 169 | other_words.append(vec) 170 | 171 | return np.asarray(reserved_words), np.asarray(other_words) 172 | 173 | def build_if_not_exists(self, embedding_type: str, vocab_pkl: str, vocab_file: str, overwrite: bool=False): 174 | if Path(vocab_pkl).exists() and not overwrite: 175 | return 176 | 177 | sw = Stopwatch() 178 | 179 | if embedding_type.lower().startswith("random"): 180 | try: 181 | dim = int(embedding_type[len("random"):]) 182 | except ValueError: 183 | dim = 300 184 | logger.warning("Unrecognizable dimension for random embedding. 
Set to default: {}".format(dim)) 185 | _embed_type = EmbeddingType(embedding_type, "", dim, "random") 186 | else: 187 | e = self._args[embedding_type] 188 | _embed_type = EmbeddingType(embedding_type, e["url"], e["dim"], e["src_type"]) 189 | 190 | vocab_list, _ = vocab.load_vocab(vocab_file) 191 | oov, iov, vec_dict = EmbeddingUtil.from_type(_embed_type).build(vocab_list) 192 | fs.save_obj(vec_dict, vocab_pkl) 193 | 194 | rename(vocab_file, fs.replace_ext(vocab_file, 'tf')) 195 | with codecs.getwriter("utf-8")(open(vocab_file, "wb")) as writer: 196 | for rw in vocab.RESERVED_WORDS: 197 | writer.write("{}\n".format(rw)) 198 | 199 | for w in oov: 200 | if w not in vocab.RESERVED_WORDS: 201 | writer.write("{}\n".format(w)) 202 | 203 | for w in iov: 204 | if w not in vocab.RESERVED_WORDS: 205 | writer.write("{}\n".format(w)) 206 | 207 | logger.info("Embedding vectors built from {} in {:.1f}s".format(embedding_type, sw.elapsed())) 208 | -------------------------------------------------------------------------------- /thred/util/emots.txt: -------------------------------------------------------------------------------- 1 | (งツ)ว 2 | ಥ_ಥ 3 | ಥ-ಥ 4 | ಥ╭╮ಥ 5 | (ಥ﹃ಥ) 6 | ಥ~ಥ 7 | ಥωಥ 8 | ಥ﹏ಥ 9 | ಠ‿ಥ 10 | ಥ_ʖಥ 11 | ಥ ͜ʖ ಥ 12 | ಥoಥ 13 | ಥ‿‿ಥ 14 | ಥ◡ಥ 15 | ಠ_ಥ 16 | ಠ╭╮ಥ 17 | ಥ益ಥ 18 | ಠಿ_ಥ 19 | ಥ ೧ ಥ 20 | ಥ_ಠ 21 | ಥ‿ಥ 22 | ಥ∀ಥ 23 | ಠωಥ 24 | ಥ,_」ಥ 25 | ಥヮಥ 26 | ಠಿ_ಠ 27 | (ಥ◞⊱​◟ಥ) 28 | (ಥ◞౪◟ಥ) 29 | (ಥ_◞ಥ) 30 | ᕦ(ಥ_ಥ)ᕤ 31 | (˵ ͠ಥ‿ ͠ಥ˵) 32 | ಇ(˵ಥ_ಥ˵)ಇ 33 | ೭(ಥ_ಥ)೨ 34 | (ง ͠ಥ_ಥ)ง 35 | (ಽ ͡ಥ ͜ʖ ಥ)ಽ 36 | ლ(ಥ益ಥლ) 37 | Σ(ಥ_ಥ) 38 | ヽ༼ಥ_ಥ༽ノ 39 | (☞ ͡ಥuಥ)☞ 40 | ( ಥ ʖ̫ ಥ) 41 | ( ͡ಥ‿ ಥ)━☆゚.*・。゚ 42 | (ಥ_ಥ)━☆゚.*・。゚ 43 | ಠ_ಠ 44 | (ಠ_ಠ) 45 | ಠ.ಠ 46 | ๏ ̯͡ ๏ 47 | (͠≖ ͜ʖ͠≖) 48 | 。^‿^。 49 | |(•◡•)| 50 | ¯\\(ツ)/¯ -------------------------------------------------------------------------------- /thred/util/fs.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import platform 4 | import shutil 5 | import subprocess 6 | import shlex 7 | import time 8 | import pickle 9 | import tarfile 10 | import zipfile 11 | 12 | from urllib.parse import urlparse 13 | 14 | 15 | def split3(path): 16 | fld, f = os.path.split(path) 17 | fname, ext = os.path.splitext(f) 18 | 19 | return fld, fname, ext 20 | 21 | 22 | def file_name(path): 23 | _, fname, _ = split3(path) 24 | return fname 25 | 26 | 27 | def get_current_dir(path): 28 | if os.path.isdir(path): 29 | return path 30 | else: 31 | current_dir, _, _ = split3(path) 32 | return current_dir 33 | 34 | 35 | def get_parent_dir(path): 36 | return os.path.abspath(os.path.join(get_current_dir(path), os.pardir)) 37 | 38 | 39 | def get_project_root_dir(): 40 | return get_parent_dir(get_parent_dir(get_current_dir(__file__))) 41 | 42 | 43 | def replace_ext(path, new_ext): 44 | dir, fname, ext = split3(path) 45 | return os.path.join(dir, "%s.%s" % (fname, new_ext)) 46 | 47 | 48 | def replace_dir(path, new_path, new_ext=None): 49 | _, fname, ext = split3(path) 50 | if new_ext is None: 51 | new_ext = ext 52 | else: 53 | new_ext = '.' 
+ new_ext 54 | return os.path.join(new_path, fname + new_ext) 55 | 56 | 57 | def mkdir_if_not_exists(dir): 58 | if not os.path.exists(dir): 59 | os.mkdir(dir) 60 | 61 | 62 | def rm_if_exists(filename): 63 | try: 64 | os.remove(filename) 65 | except OSError as e: 66 | if e.errno != errno.ENOENT: # errno.ENOENT = no such file or directory 67 | raise # re-raise exception if a different error occurred 68 | 69 | 70 | def is_url(url): 71 | return urlparse(url).scheme != "" 72 | 73 | 74 | def copy(src, dst): 75 | shutil.copy(src, dst) 76 | 77 | 78 | def rm_by_extension(working_dir, ext): 79 | matched_files = 0 80 | for f in os.listdir(working_dir): 81 | actual_f = os.path.join(working_dir, f) 82 | if f.endswith('.' + ext) and os.path.isfile(actual_f): 83 | os.remove(actual_f) 84 | matched_files += 1 85 | 86 | return matched_files 87 | 88 | 89 | def count_lines(file_path): 90 | # https://gist.github.com/zed/0ac760859e614cd03652 91 | with open(file_path, 'rb') as f: # binary read; the universal-newline 'U' flag cannot be combined with 'b' in Python 3 92 | return sum(1 for _ in f) 93 | 94 | 95 | def save_obj(obj, name): 96 | with open(name, 'wb') as f: 97 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 98 | 99 | 100 | def load_obj(name): 101 | with open(name, 'rb') as f: 102 | return pickle.load(f) 103 | 104 | 105 | def uncompress(compressed_file, out_path="."): 106 | _, _, compressed_ext = split3(compressed_file) 107 | compressed_ext = compressed_ext.lower() 108 | 109 | if compressed_file.lower().endswith(".tar.gz") or compressed_ext in (".tgz", ".gz"): 110 | with tarfile.open(compressed_file, 'r:gz') as tgz_ref: 111 | tgz_ref.extractall(out_path) 112 | elif compressed_ext == ".zip": 113 | with zipfile.ZipFile(compressed_file, 'r') as zip_ref: 114 | zip_ref.extractall(out_path) 115 | -------------------------------------------------------------------------------- /thred/util/kv.py: -------------------------------------------------------------------------------- 1 | """ Key-Value store (i.e., Redis) utilities 2 | """ 3 | import os 4 | import subprocess 5 | 6 | import redis 7 | 8 | 9 | class TinyRedis: 10 | """ TinyRedis is a thin wrapper around the redis-py client, supporting a subset of Redis commands. 11 | Upon construction, it pings the server, so an exception is raised immediately if the server is unreachable. 
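    Example (a hypothetical usage sketch; the port, key, and values are illustrative, and a Redis server must already be listening on the given port):
        with TinyRedis(port=6379, max_connections=10) as kv:
            kv.set('greeting', 'hello')
            assert kv.get('greeting') == 'hello'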
12 | """ 13 | def __init__(self, port, max_connections, host='localhost'): 14 | self.__r = redis.Redis(host=host, port=port, max_connections=max_connections, decode_responses=True) 15 | self.ping() 16 | 17 | def __enter__(self): 18 | return self 19 | 20 | def __exit__(self, exc_type, exc_val, exc_tb): 21 | self.close() 22 | 23 | def pipeline(self): 24 | return self.__r.pipeline(transaction=False) 25 | 26 | def ping(self): 27 | self.__r.ping() 28 | 29 | def exists(self, key): 30 | return self.__r.exists(key) 31 | 32 | def delete(self, key): 33 | self.__r.delete(key) 34 | 35 | def set(self, key, value): 36 | return self.__r.set(key, value) 37 | 38 | def get(self, key): 39 | return self.__r.get(key) 40 | 41 | def hscan(self, key): 42 | return self.__r.hscan_iter(key) 43 | 44 | def hget(self, key, field): 45 | return self.__r.hget(key, field) 46 | 47 | def hmget(self, key, *fields): 48 | return self.__r.hmget(key, fields) 49 | 50 | def pl_hincrby(self, key, mappings): 51 | p = self.__r.pipeline(transaction=False) 52 | for field, amount in mappings.items(): 53 | p.hincrby(key, field, amount) 54 | p.execute() 55 | 56 | def lrange(self, key, start_index=0, stop_index=-1): 57 | return self.__r.lrange(key, start=start_index, end=stop_index) 58 | 59 | def sadd(self, key, *members): 60 | return self.__r.sadd(key, *members) 61 | 62 | def smembers(self, key): 63 | return self.__r.smembers(key) 64 | 65 | def info(self, section=None): 66 | return self.__r.info(section) 67 | 68 | def pfadd(self, key, *elements): 69 | return self.__r.pfadd(key, *elements) 70 | 71 | def pfcount(self, key): 72 | return self.__r.pfcount(key) 73 | 74 | def scan(self, cursor=0, match=None, count=None): 75 | return self.__r.scan(cursor, match, count) 76 | 77 | def close(self): 78 | del self.__r 79 | 80 | 81 | def install_redis(install_path='', 82 | download_url='http://download.redis.io/releases/redis-5.0.3.tar.gz', 83 | port=6384, verbose=True): 84 | import tarfile 85 | import tempfile 86 | from urllib.error import URLError, HTTPError 87 | import urllib.request as url_request 88 | 89 | from redis.exceptions import ConnectionError 90 | from .fs import split3 91 | 92 | proceed_install = True 93 | r = redis.Redis(host='localhost', port=port) 94 | try: 95 | r.ping() 96 | proceed_install = False 97 | except ConnectionError: 98 | pass 99 | 100 | if proceed_install: 101 | if not install_path: 102 | tmp_dir = tempfile.mkdtemp(prefix='redis{}'.format(port)) 103 | install_path = os.path.join(tmp_dir, 'redis') 104 | 105 | if not os.path.exists(install_path): 106 | working_dir, redis_name, _ = split3(install_path) 107 | redis_tgzfile = os.path.join(working_dir, 'redis.tar.gz') 108 | 109 | if verbose: 110 | print('Downloading Redis...') 111 | 112 | try: 113 | with url_request.urlopen(download_url) as resp, \ 114 | open(redis_tgzfile, 'wb') as out_file: 115 | data = resp.read() # a `bytes` object 116 | out_file.write(data) 117 | except HTTPError as e: 118 | if verbose: 119 | if e.code == 404: 120 | print( 121 | 'The provided URL seems to be broken. 
Please provide a valid download URL for Redis') 122 | else: 123 | print('Error code: ', e.code) 124 | raise ValueError(e) 125 | except URLError as e: 126 | if verbose: 127 | print('URL error: ', e.reason) 128 | raise ValueError(e) 129 | 130 | if verbose: 131 | print('Extracting Redis...') 132 | with tarfile.open(redis_tgzfile, 'r:gz') as tgz_ref: 133 | tgz_ref.extractall(working_dir) 134 | 135 | os.remove(redis_tgzfile) 136 | 137 | redis_dir = None 138 | for f in os.listdir(working_dir): 139 | if f.lower().startswith('redis'): 140 | redis_dir = os.path.join(working_dir, f) 141 | break 142 | 143 | if not redis_dir: 144 | raise ValueError('could not locate the extracted Redis directory in {}'.format(working_dir)) 145 | 146 | os.rename(redis_dir, os.path.join(working_dir, redis_name)) 147 | 148 | if verbose: 149 | print('Installing Redis...') 150 | 151 | redis_conf_file = os.path.join(install_path, 'redis.conf') 152 | subprocess.call( 153 | ['sed', '-i', 's/tcp-backlog [0-9]\+$/tcp-backlog 3000/g', redis_conf_file]) 154 | subprocess.call( 155 | ['sed', '-i', 's/daemonize no$/daemonize yes/g', redis_conf_file]) 156 | subprocess.call( 157 | ['sed', '-i', 's/pidfile .*\.pid$/pidfile redis_{}.pid/g'.format(port), redis_conf_file]) 158 | subprocess.call( 159 | ['sed', '-i', 's/port 6379/port {}/g'.format(port), redis_conf_file]) 160 | subprocess.call( 161 | ['sed', '-i', 's/save 900 1/save 15000 1/g', redis_conf_file]) 162 | subprocess.call( 163 | ['sed', '-i', 's/save 300 10/#save 300 10/g', redis_conf_file]) 164 | subprocess.call( 165 | ['sed', '-i', 's/save 60 10000/#save 60 10000/g', redis_conf_file]) 166 | subprocess.call( 167 | ['sed', '-i', 's/logfile ""/logfile "redis_{}.log"/g'.format(port), redis_conf_file]) 168 | subprocess.call(['make'], cwd=install_path) 169 | 170 | if verbose: 171 | print('Running Redis on port {}...'.format(port)) 172 | subprocess.call(['src/redis-server', 'redis.conf'], cwd=install_path) 173 | 174 | return install_path 175 | 176 | 177 | def uninstall_redis(install_path, verbose=True): 178 | import shutil 179 | 180 | if not install_path or not os.path.exists(install_path): 181 | if verbose: 182 | print("Cannot uninstall because the installation path does not exist!") 183 | return 184 | 185 | redis_conf_file = os.path.join(install_path, 'redis.conf') 186 | p = subprocess.Popen(['grep', '-E', '^port [0-9]+$', redis_conf_file], stdout=subprocess.PIPE) 187 | out, _ = p.communicate() 188 | 189 | port = out.split()[1].decode('utf-8') 190 | 191 | if verbose: 192 | print("Shutting down Redis on port {}...".format(port)) 193 | 194 | subprocess.call(['src/redis-cli', '-p', port, 'shutdown'], cwd=install_path) 195 | 196 | shutil.rmtree(install_path) 197 | 198 | if verbose: 199 | print("Redis on port {} uninstalled...".format(port)) 200 | -------------------------------------------------------------------------------- /thred/util/log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | import tensorflow as tf 5 | 6 | 7 | def print_time(s, start_time): 8 | """Take a start time, print elapsed duration, and return a new time.""" 9 | 10 | print("%s, time %ds, %s." 
% (s, (time.time() - start_time), time.ctime())) 11 | # sys.stdout.flush() 12 | return time.time() 13 | 14 | 15 | def print_out(s, f=None, new_line=True, skip_stdout=False): 16 | """Similar to print but with support to flush and output to a file.""" 17 | 18 | if isinstance(s, bytes): 19 | s = s.decode("utf-8") 20 | 21 | if f: 22 | f.write(s.encode("utf-8")) 23 | if new_line: 24 | f.write(b"\n") 25 | 26 | # stdout 27 | if not skip_stdout: 28 | out_s = s.encode("utf-8") 29 | if not isinstance(out_s, str): 30 | out_s = out_s.decode("utf-8") 31 | print(out_s, end="", file=sys.stdout) 32 | 33 | if new_line: 34 | print() 35 | #sys.stdout.flush() 36 | 37 | 38 | def add_summary(summary_writer, global_step, tag, value): 39 | """Add a new summary to the current summary_writer. 40 | Useful to log things that are not part of the training graph, e.g., tag=BLEU. 41 | """ 42 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 43 | summary_writer.add_summary(summary, global_step) 44 | -------------------------------------------------------------------------------- /thred/util/misc.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import math 3 | import os 4 | import random 5 | import string 6 | import subprocess 7 | from signal import SIGTERM 8 | from time import time 9 | 10 | 11 | def safe_exp(value): 12 | """Exponentiation with catching of overflow error.""" 13 | try: 14 | ans = math.exp(value) 15 | except OverflowError: 16 | ans = float("inf") 17 | return ans 18 | 19 | 20 | def safe_div(dividend, divisor): 21 | return (dividend / divisor) if divisor != 0 else 0 22 | 23 | 24 | def safe_mod(dividend, divisor): 25 | return (dividend % divisor) if divisor != 0 else 0 26 | 27 | 28 | def generate_random_string(length=5): 29 | return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) 30 | 31 | 32 | def escRegex(term): 33 | return term.replace('\\', '\\\\') \ 34 | .replace('(', '\(').replace(')', '\)') \ 35 | .replace('*', '\*').replace('+', '\+') \ 36 | .replace('[', '\[').replace(']', '\]') 37 | 38 | 39 | def kill_java_process(process_name): 40 | jps_cmd = subprocess.Popen('jps', stdout=subprocess.PIPE) 41 | jps_output = jps_cmd.stdout.read() 42 | pids = set() 43 | for jps_pair in jps_output.split(b'\n'): 44 | _p = jps_pair.split() 45 | if len(_p) > 1: 46 | if _p[1].decode() == process_name: 47 | pids.add(int(_p[0])) 48 | 49 | jps_cmd.wait() 50 | if pids: 51 | for pid in pids: 52 | os.kill(pid, SIGTERM) 53 | 54 | 55 | def gunzip(gz_path): 56 | with gzip.open(gz_path, 'rb') as in_file: 57 | return in_file.read() 58 | 59 | 60 | class Stopwatch: 61 | def __init__(self): 62 | self.start() 63 | 64 | def start(self): 65 | self.__start = time() 66 | 67 | def elapsed(self): 68 | return round(time() - self.__start, 3) 69 | 70 | def print(self, log_text): 71 | print(log_text, 'elapsed: {}s'.format(self.elapsed())) 72 | -------------------------------------------------------------------------------- /thred/util/nlp.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import re 3 | import codecs 4 | from os.path import join 5 | 6 | import emot 7 | import spacy 8 | from spacy.lang.en.stop_words import STOP_WORDS 9 | 10 | from . 
import misc, fs 11 | from .twokenize import tokenize as tweet_tokenize 12 | from .twitter_nlp_emoticons import Emoticon_RE 13 | 14 | 15 | def _read_emots(): 16 | emots = set() 17 | with codecs.getreader('utf-8')(open(join(fs.get_current_dir(__file__), 'emots.txt'), 'rb')) as e: 18 | for line in e: 19 | emots.add(line.strip()) 20 | 21 | return emots 22 | 23 | 24 | UNCOMMON_EMOTICONS = _read_emots() 25 | 26 | 27 | class TaggedWord( 28 | collections.namedtuple("TaggedWord", ("index", "term", "lemma", "pos", "ner"))): 29 | 30 | def __repr__(self) -> str: 31 | if self.ner: 32 | return "{term}@{index}/{pos}~{ner}".format(term=self.term, 33 | index=self.index, 34 | pos=self.pos, 35 | ner=self.ner) 36 | else: 37 | return "{term}@{index}/{pos}".format(term=self.term, 38 | index=self.index, 39 | pos=self.pos) 40 | 41 | 42 | class NLPToolkit: 43 | def __init__(self, pipeline=None): 44 | self._nlp = spacy.load('en_core_web_lg') 45 | 46 | if pipeline is not None: 47 | for name in pipeline: 48 | component = self._nlp.create_pipe(name) 49 | self._nlp.add_pipe(component) 50 | 51 | def sent_tokenize(self, text): 52 | doc = self._nlp(text) 53 | return [sent.string.strip() for sent in doc.sents] 54 | 55 | def annotate(self, text): 56 | doc = self._nlp(text) 57 | return NLPToolkit._parse_doc(doc) 58 | 59 | def tokenize(self, text): 60 | doc = self._nlp(text) 61 | return [w.text for w in doc] 62 | 63 | @staticmethod 64 | def _parse_doc(doc): 65 | sentences = [] 66 | for sent in doc.sents: 67 | tagged_words = [] 68 | for token in sent: # tokens of this sentence only; iterating sent.doc would yield every token in the document for each sentence 69 | tagged_words.append(TaggedWord(index=token.i, 70 | term=NLPToolkit.replace_treebank_standards(token.text), 71 | lemma=token.lemma_, 72 | pos=token.pos_ if doc.is_tagged else None, 73 | ner=token.ent_type_)) 74 | sentences.append(tagged_words) 75 | return sentences 76 | 77 | @staticmethod 78 | def replace_treebank_standards(token): 79 | if token in ('``', "''"): 80 | return '"' 81 | elif token == '`': 82 | return "'" 83 | elif token == '-LRB-': 84 | return "(" 85 | elif token == '-RRB-': 86 | return ")" 87 | elif token == '-LCB-': 88 | return "{" 89 | elif token == '-RCB-': 90 | return "}" 91 | elif token == '-LSB-': 92 | return "[" 93 | elif token == '-RSB-': 94 | return "]" 95 | else: 96 | return token 97 | 98 | @staticmethod 99 | def is_stopword(word): 100 | return word in STOP_WORDS 101 | 102 | 103 | def normalize_entities(sentences, entities=None, decapitalize=True): 104 | if entities is None: 105 | entities = {'PERSON': '', 106 | 'URL': '', 107 | 'NUMBER': '', 108 | 'PERCENT': '', 109 | 'DURATION': '', 110 | 'MONEY': '', 111 | 'DATE': '