├── data └── .gitignore ├── libs └── .gitignore ├── logs └── .gitignore ├── output └── .gitignore ├── requirements.txt ├── data_prep_tools ├── README ├── extract_plain_bioasq_dataset.py ├── extract_plain_reuters21578_dataset.py ├── run_reuters_preparation.sh ├── run_bioasq_preparation.sh ├── run_rcv1_preparation.sh └── extract_plain_rcv1_dataset.py ├── README.md ├── setup_exp_env.sh ├── create_rcv1_hdf5.sh ├── create_bioasq_hdf5.sh ├── create_reuters2178_hdf5.sh ├── wordemb_pretrain ├── train_word_emb_gensim.py ├── prepare_full_wiki.sh ├── extract_title_and_body.py └── WikiExtractor.py ├── exp_scripts ├── bioasq │ ├── run_bioasq_encdec_exp.sh │ ├── run_bioasq_rnnm_exp.sh │ ├── run_bioasq_rnnm_reverse_exp.sh │ ├── config_bioasq_rnnm.json │ ├── config_bioasq_rnnm_reverse.json │ ├── config_bioasq_encdec.json │ ├── config_bioasq_rnnm_topsort.json │ └── config_bioasq_rnnm_reverse_topsort.json ├── rcv1 │ ├── run_rcv1_rnnb_exp.sh │ ├── run_rcv1_rnnb_reverse_exp.sh │ ├── run_rcv1_rnnm_exp.sh │ ├── run_rcv1_encdec_exp.sh │ ├── run_rcv1_rnnm_reverse_exp.sh │ ├── run_rcv1_encdec_reverse_exp.sh │ ├── config_rcv1_rnnb.json │ ├── config_rcv1_rnnb_reverse.json │ ├── config_rcv1_rnnm.json │ ├── config_rcv1_rnnm_reverse.json │ ├── config_rcv1_encdec.json │ ├── config_rcv1_encdec_reverse.json │ ├── config_rcv1_rnnm_topsort.json │ └── config_rcv1_rnnm_reverse_topsort.json └── reuters │ ├── run_reuters_rnnb_exp.sh │ ├── run_reuters_rnnm_exp.sh │ ├── run_reuters_encdec_exp.sh │ ├── run_reuters_rnnb_reverse_exp.sh │ ├── run_reuters_rnnm_reverse_exp.sh │ ├── run_reuters_encdec_reverse_exp.sh │ ├── config_reuters_rnnb.json │ ├── config_reuters_rnnb_reverse.json │ ├── config_reuters_rnnm.json │ ├── config_reuters_encdec.json │ ├── config_reuters_rnnm_reverse.json │ └── config_reuters_encdec_reverse.json ├── proc_util ├── create_character_vocab.py ├── cut_long_sentences.py ├── delete_instances.py └── sort_labels.py ├── run_eval_script.py ├── generate_rcv1_label_vocab.py ├── misc.py ├── generate_bioasq_label_vocab.py ├── predict.py ├── optimizers.py ├── create_hdf5_dataset.py ├── evals.py ├── data_iterator.py ├── utils.py └── layers.py /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /libs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /output/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nose2 2 | flake8 3 | toolz==0.8.0 4 | git+git://github.com/mila-udem/fuel.git@0.2.0 5 | git+git://github.com/Theano/Theano.git@rel-0.8.2 6 | 
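The pinned requirements above (Theano 0.8.2, fuel 0.2.0, toolz 0.8.0) are normally installed by setup_exp_env.sh below, which also unpacks a pre-built HDF5 into libs/hdf5. A minimal sketch of the manual equivalent, assuming the HDF5 archive has already been unpacked there:

```bash
# point packages that build against HDF5 (h5py, pulled in by fuel) at the
# locally unpacked library, then install the pinned dependencies per-user
export HDF5_DIR=$PWD/libs/hdf5
pip install -r requirements.txt --user
```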
-------------------------------------------------------------------------------- /data_prep_tools/README: -------------------------------------------------------------------------------- 1 | This directory contains scripts that extract plain-text versions of the datasets for the subsequent preprocessing steps. 2 | 3 | 1) Reuters21578 4 | 5 | ./run_reuters_preparation.sh 6 | 7 | 2) RCV1 8 | 9 | ./run_rcv1_preparation.sh 10 | 11 | 3) BioASQ 12 | 13 | ./run_bioasq_preparation.sh 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Implementation of the following paper: 2 | 3 | ``` 4 | @incollection{jinseok2017maximizing, 5 | author = {Jinseok Nam and Eneldo Loza Menc{\'i}a and Hyunwoo J. Kim and Johannes F{\"u}rnkranz}, 6 | title = {Maximizing Subset Accuracy with Recurrent Neural Networks in Multi-label Classification}, 7 | booktitle = {Advances in Neural Information Processing Systems 30}, 8 | year = {2017} 9 | } 10 | ``` 11 | -------------------------------------------------------------------------------- /setup_exp_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | WORKING_PATH=$PWD 4 | LIBS_DIR=$WORKING_PATH/libs 5 | 6 | mkdir -p $LIBS_DIR 7 | 8 | # download pre-compiled HDF5 binary 9 | wget https://www.hdfgroup.org/ftp/HDF5/releases/hdf5-1.8.12/bin/linux-x86_64/hdf5-1.8.12-linux-x86_64-shared.tar.gz -O $LIBS_DIR/hdf5.tar.gz 10 | mkdir -p $LIBS_DIR/hdf5 11 | tar xzf $LIBS_DIR/hdf5.tar.gz -C $LIBS_DIR/hdf5 --strip-components=1 12 | export HDF5_DIR=$LIBS_DIR/hdf5 13 | 14 | # install required packages 15 | pip install -r requirements.txt --user 16 | -------------------------------------------------------------------------------- /create_rcv1_hdf5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=data/RCV1 4 | WORD_EMB_PATH=output/gensim_pretrained_word_emb.model 5 | OUTPUT_FILENAME=rcv1_dataset.hdf5 6 | 7 | python create_hdf5_dataset.py \ 8 | --trd ${DATA_PATH}/trd.delete.tok.max_300.lc.txt \ 9 | --trl ${DATA_PATH}/trl.delete.txt \ 10 | --vad ${DATA_PATH}/vad.delete.tok.max_300.lc.txt \ 11 | --val ${DATA_PATH}/val.delete.txt \ 12 | --tsd ${DATA_PATH}/tsd.delete.tok.max_300.lc.txt \ 13 | --tsl ${DATA_PATH}/tsl.delete.txt \ 14 | --label_vocab ${DATA_PATH}/label_vocab.txt \ 15 | --word_emb ${WORD_EMB_PATH} \ 16 | --output ${DATA_PATH}/${OUTPUT_FILENAME} 17 | -------------------------------------------------------------------------------- /create_bioasq_hdf5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=data/BioASQ 4 | WORD_EMB_PATH=output/gensim_pretrained_word_emb.model 5 | OUTPUT_FILENAME=bioasq_dataset.hdf5 6 | 7 | python create_hdf5_dataset.py \ 8 | --trd ${DATA_PATH}/trd.delete.tok.max_300.lc.txt \ 9 | --trl ${DATA_PATH}/trl.delete.txt \ 10 | --vad ${DATA_PATH}/vad.delete.tok.max_300.lc.txt \ 11 | --val ${DATA_PATH}/val.delete.txt \ 12 | --tsd ${DATA_PATH}/tsd.delete.tok.max_300.lc.txt \ 13 | --tsl ${DATA_PATH}/tsl.delete.txt \ 14 | --label_vocab ${DATA_PATH}/label_vocab.txt \ 15 | --word_emb ${WORD_EMB_PATH} \ 16 | --output ${DATA_PATH}/${OUTPUT_FILENAME} 17 | -------------------------------------------------------------------------------- /create_reuters2178_hdf5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 |
DATA_PATH=data/Reuters21578 4 | WORD_EMB_PATH=output/gensim_pretrained_word_emb.model 5 | OUTPUT_FILENAME=reuters21578_dataset.hdf5 6 | 7 | python create_hdf5_dataset.py \ 8 | --trd ${DATA_PATH}/trd.delete.tok.max_300.lc.txt \ 9 | --trl ${DATA_PATH}/trl.delete.txt \ 10 | --vad ${DATA_PATH}/vad.delete.tok.max_300.lc.txt \ 11 | --val ${DATA_PATH}/val.delete.txt \ 12 | --tsd ${DATA_PATH}/tsd.delete.tok.max_300.lc.txt \ 13 | --tsl ${DATA_PATH}/tsl.delete.txt \ 14 | --label_vocab ${DATA_PATH}/label_vocab.txt \ 15 | --word_emb ${WORD_EMB_PATH} \ 16 | --output ${DATA_PATH}/${OUTPUT_FILENAME} 17 | -------------------------------------------------------------------------------- /wordemb_pretrain/train_word_emb_gensim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | import argparse 5 | 6 | from gensim.models import word2vec 7 | 8 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', 9 | level=logging.INFO) 10 | 11 | if __name__ == '__main__': 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--corpus', type=str, required=True) 15 | parser.add_argument('--output', type=str, required=True) 16 | 17 | args = parser.parse_args() 18 | 19 | sentences = word2vec.LineSentence(args.corpus) 20 | model = word2vec.Word2Vec(sentences, 21 | size=512, 22 | sg=1, 23 | workers=4, 24 | min_count=30, 25 | iter=5) 26 | 27 | model.save(args.output) # pickle dump 28 | -------------------------------------------------------------------------------- /exp_scripts/bioasq/run_bioasq_encdec_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/BioASQ 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=bioasq 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_encdec.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu0,floatX=float32,scan.allow_gc=True' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/bioasq/run_bioasq_rnnm_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/BioASQ 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=bioasq 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnm.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 
'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=cuda0,floatX=float32,scan.allow_gc=True' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/rcv1/run_rcv1_rnnb_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/RCV1 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=rcv1 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnb.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu0,floatX=float32,scan.allow_gc=True' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/rcv1/run_rcv1_rnnb_reverse_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/RCV1 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=rcv1 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnb_reverse.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu1,floatX=float32,scan.allow_gc=True' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/rcv1/run_rcv1_rnnm_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/RCV1 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=rcv1 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnm.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 |
OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu0,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/rcv1/run_rcv1_encdec_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/RCV1 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=rcv1 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_encdec.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu1,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/bioasq/run_bioasq_rnnm_reverse_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/BioASQ 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=bioasq 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnm_reverse.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=cuda1,floatX=float32,scan.allow_gc=True' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/rcv1/run_rcv1_rnnm_reverse_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/RCV1 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=rcv1 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnm_reverse.json 9 | 10 | 
if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu1,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/reuters/run_reuters_rnnb_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/Reuters21578 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=reuters 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnb.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu0,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/reuters/run_reuters_rnnm_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/Reuters21578 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=reuters 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnm.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu0,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/rcv1/run_rcv1_encdec_reverse_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 
BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/RCV1 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=rcv1 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_encdec_reverse.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu0,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/reuters/run_reuters_encdec_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/Reuters21578 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=reuters 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_encdec.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu0,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/reuters/run_reuters_rnnb_reverse_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/Reuters21578 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=reuters 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnb_reverse.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu1,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | 
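All of the runner scripts follow the same pattern: pick an experiment ID (a random ten-character string unless one is passed as the first argument), inject the path of the best checkpoint into the config's management.reload_from field with jq, and launch mlc2seq_single.py under the requested THEANO_FLAGS. A minimal usage sketch from the repository root, with a hypothetical experiment ID:

```bash
# launch (or later resume) the Reuters RNN-b experiment under a fixed ID,
# so the generated config and best-model checkpoint are easy to locate
./exp_scripts/reuters/run_reuters_rnnb_exp.sh my_rnnb_run

# the jq step performed inside the script, shown standalone: write the
# checkpoint path into a pretty-printed, experiment-specific config copy
jq -c 'setpath(["management", "reload_from"]; "/data/learned_models/mlc2seq/output/reuters_my_rnnb_run.model.best.npz")' \
    exp_scripts/reuters/config_reuters_rnnb.json \
    | python -m simplejson.tool > reuters_my_rnnb_run.config.json
```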
-------------------------------------------------------------------------------- /exp_scripts/reuters/run_reuters_rnnm_reverse_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/Reuters21578 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=reuters 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_rnnm_reverse.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu1,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /exp_scripts/reuters/run_reuters_encdec_reverse_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_PATH=$PWD 4 | 5 | DATASET_PATH=${BASE_PATH}/data/Reuters21578 6 | WORK_SCRATCH=/data/learned_models/mlc2seq 7 | DATASET=reuters 8 | INPUT_CONFIG_PATH=${BASE_PATH}/exp_scripts/${DATASET}/config_${DATASET}_encdec_reverse.json 9 | 10 | if [ $# -eq 0 ] 11 | then 12 | EXP_ID=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 10 | head -n 1) 13 | else 14 | EXP_ID=$1 15 | fi 16 | 17 | OUTPUT_MODEL_DIR=${WORK_SCRATCH}/output 18 | OUTPUT_PREFIX=${DATASET}_${EXP_ID} 19 | CONFIG_PATH=${OUTPUT_PREFIX}.config.json 20 | 21 | MODEL_PATH=${OUTPUT_MODEL_DIR}/${OUTPUT_PREFIX}.model.best.npz 22 | 23 | jq -c 'setpath(["management", "reload_from"]; "'${MODEL_PATH}'")' ${INPUT_CONFIG_PATH} | python -m simplejson.tool > tmp.$$.json && \ 24 | mv tmp.$$.json ${CONFIG_PATH} 25 | 26 | THEANO_FLAGS='device=gpu1,floatX=float32,scan.allow_gc=True,lib.cnmem=0.95' python mlc2seq_single.py \ 27 | --base_datapath ${DATASET_PATH} \ 28 | --base_outputpath ${OUTPUT_MODEL_DIR} \ 29 | --config ${CONFIG_PATH} \ 30 | --experiment_id ${OUTPUT_PREFIX} 31 | -------------------------------------------------------------------------------- /proc_util/create_character_vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | import operator 4 | 5 | 6 | def generate_character_vocab(input_filepath, output_filepath): 7 | 8 | with io.open(input_filepath, encoding='utf-8') as fin,\ 9 | io.open(output_filepath, 'w', encoding='utf-8') as fout: 10 | 11 | char_dict = dict() 12 | 13 | for line in fin: 14 | for c in line.strip(): 15 | if c not in char_dict: 16 | char_dict[c] = 0 17 | 18 | char_dict[c] += 1 19 | 20 | sorted_vocab = sorted(char_dict.items(), key=operator.itemgetter(1), 21 | reverse=True) 22 | 23 | for ch, val in sorted_vocab: 24 | fout.write('%d\t%s\n' % (val, ch)) 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--input', type=str, required=True) 30 | parser.add_argument('--output', type=str, required=True) 31 | 32 
| args = parser.parse_args() 33 | 34 | generate_character_vocab(args.input, args.output) 35 | -------------------------------------------------------------------------------- /wordemb_pretrain/prepare_full_wiki.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | WIKI_DUMP=$1 # path to the wikipedia dump 4 | CLEANED_WIKI=$2 # output path 5 | 6 | TMP_DIR=$(mktemp -d) 7 | TMP_MERGE_FILE=$(mktemp) 8 | 9 | # extract plain text 10 | cat ${WIKI_DUMP} | python WikiExtractor.py -o ${TMP_DIR} 11 | 12 | find ${TMP_DIR} -type f -name 'wiki_*' -print0 | while IFS= read -r -d '' file 13 | do 14 | cat "$file" 15 | done > ${TMP_MERGE_FILE} 16 | 17 | TMP_BODY=$(mktemp) 18 | TMP_TITLE=$(mktemp) 19 | 20 | python extract_title_and_body.py --wiki_data ${TMP_MERGE_FILE} \ 21 | --text_body_output ${TMP_BODY} \ 22 | --title_output ${TMP_TITLE} 23 | 24 | # tokenize & lowercase 25 | cat ${TMP_BODY} | ../data_prep_tools/mosesdecoder/scripts/tokenizer/remove-non-printing-char.perl | \ 26 | ../data_prep_tools/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l en | \ 27 | ../data_prep_tools/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l en | \ 28 | ../data_prep_tools/mosesdecoder/scripts/generic/ph_numbers.perl -c | \ 29 | ../data_prep_tools/mosesdecoder/scripts/tokenizer/lowercase.perl -l en > ${CLEANED_WIKI} 30 | 31 | # clean up temporary files 32 | rm ${TMP_TITLE} 33 | rm ${TMP_BODY} 34 | rm ${TMP_MERGE_FILE} 35 | rm -rf ${TMP_DIR} 36 | -------------------------------------------------------------------------------- /proc_util/cut_long_sentences.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | 4 | 5 | def cut_sentences(input_filepath, output_filepath, max_length, level='word'): 6 | 7 | with io.open(input_filepath, encoding='utf-8') as fin, \ 8 | io.open(output_filepath, 'w', encoding='utf-8') as fout: 9 | 10 | for line in fin: 11 | line = line.replace('\r\n', '').strip() 12 | if level == 'word': 13 | line = line.split() 14 | 15 | if len(line) > max_length: 16 | output = ' '.join(line[:max_length]) if level == 'word' \ 17 | else line[:max_length].strip() 18 | else: 19 | output = ' '.join(line) if level == 'word' else line 20 | 21 | fout.write('%s\n' % output) 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--input', type=str, required=True) 27 | parser.add_argument('--output', type=str, required=True) 28 | parser.add_argument('--level', type=str, default='character') 29 | parser.add_argument('--max', type=int, required=True) 30 | 31 | args = parser.parse_args() 32 | 33 | output_filepath = args.output 34 | cut_sentences(args.input, args.output, args.max, args.level) 35 | -------------------------------------------------------------------------------- /run_eval_script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import pprint 4 | import logging 5 | import argparse 6 | 7 | from misc import loadz 8 | from evals import list2sparse, compute_all_measures 9 | 10 | FORMAT = '[%(asctime)s] %(levelname)s - %(message)s' 11 | logging.basicConfig(level=logging.INFO, format=FORMAT) 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | def main(result_output_path): 16 | LOGGER.info('Loading data from {}'.format(result_output_path)) 17 | result = loadz(result_output_path) 18 | 19 | LOGGER.info('Converting targets and predictions into sparse matrices') 20 | 
n_labels = len(result['label_vocab']) 21 | preds = list2sparse(result['predictions'], n_labels=n_labels) 22 | targets = list2sparse(result['targets'], n_labels=n_labels) 23 | LOGGER.info('Done') 24 | 25 | eval_ret = compute_all_measures(targets, preds, mb_sz=10000, verbose=True) 26 | 27 | return eval_ret 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--load', type=str, required=True) 33 | parser.add_argument('--output', type=str) 34 | 35 | args = parser.parse_args() 36 | 37 | eval_results = main(args.load) 38 | if args.output is not None: 39 | with open(args.output, 'w') as f: 40 | pprint.pprint(list(eval_results.items()), f) 41 | else: 42 | pprint.pprint(list(eval_results.items())) 43 | -------------------------------------------------------------------------------- /generate_rcv1_label_vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import io 6 | import argparse 7 | from collections import OrderedDict 8 | from six.moves import xrange 9 | 10 | from misc import dfs_topsort 11 | 12 | 13 | def extract_hierarchy(path_hierarchy): 14 | graph = {} 15 | child_desc = {} 16 | with io.open(path_hierarchy, encoding='utf8') as f: 17 | for line in f: 18 | _, parent, _, child = line.strip().split()[:4] 19 | desc = '_'.join(line.strip().split()[5:]) 20 | 21 | if parent not in graph: 22 | graph[parent] = [] 23 | graph[parent].append(child) 24 | 25 | if child not in graph: 26 | graph[child] = [] 27 | 28 | if child not in child_desc: 29 | child_desc[child] = desc 30 | 31 | return graph, child_desc 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--path_to_hierarchy', type=str, required=True) 37 | parser.add_argument('--save_label_vocab', type=str, required=True) 38 | 39 | args = parser.parse_args() 40 | 41 | hs, label_desc = extract_hierarchy(args.path_to_hierarchy) 42 | 43 | order = dfs_topsort(hs, root='Root') 44 | 45 | label_vocab = OrderedDict([(order[idx], idx) for idx in xrange(len(order))]) 46 | 47 | with io.open(args.save_label_vocab, 'w', encoding='utf8') as f: 48 | for label, label_idx in label_vocab.items(): 49 | f.write(u'{}\t{}\n'.format(label_idx, label)) 50 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | 3 | 4 | def savez(obj, filename, protocol=0): 5 | """Saves a compressed object to disk 6 | 7 | """ 8 | joblib.dump(obj, filename, compress=True) 9 | 10 | 11 | def loadz(filename): 12 | """Loads a compressed object from disk 13 | """ 14 | return joblib.load(filename) 15 | 16 | 17 | def dfs_topsort(graph, root=None): 18 | """ Perform topological sort using a modified depth-first search 19 | 20 | Parameters 21 | ---------- 22 | graph: dict 23 | A directed graph 24 | 25 | root: str, int or None 26 | We perform topological sort for a subgraph of the given node. 27 | If `root` is None, the entire graph is considered. 
28 | 29 | 30 | Return 31 | ------ 32 | L: list 33 | a sorted list of nodes in the input graph 34 | 35 | """ 36 | 37 | L = [] 38 | color = {u: "white" for u in graph} 39 | found_cycle = [False] 40 | 41 | subgraph = graph 42 | if root is not None: 43 | try: 44 | subgraph = graph[root] 45 | except KeyError as e: 46 | raise(e) 47 | 48 | for u in subgraph: 49 | if color[u] == "white": 50 | dfs_visit(graph, u, color, L, found_cycle) 51 | if found_cycle[0]: 52 | break 53 | 54 | if found_cycle[0]: 55 | L = [] 56 | 57 | L.reverse() 58 | return L 59 | 60 | 61 | def dfs_visit(graph, u, color, L, found_cycle): 62 | if found_cycle[0]: 63 | return 64 | 65 | color[u] = "gray" 66 | for v in graph[u]: 67 | if color[v] == "gray": 68 | print('Found cycle by {}'.format(u)) 69 | found_cycle[0] = True 70 | return 71 | if color[v] == "white": 72 | dfs_visit(graph, v, color, L, found_cycle) 73 | color[u] = "black" 74 | L.append(u) 75 | -------------------------------------------------------------------------------- /proc_util/delete_instances.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | import itertools 4 | 5 | 6 | def delete_instances(input_data_path, input_label_path, 7 | output_data_path, output_label_path, label_vocab_path): 8 | 9 | label_vocab = set() 10 | 11 | with io.open(label_vocab_path, encoding='utf-8') as fin: 12 | for line in fin: 13 | label = line.strip().split('\t')[1] 14 | if label not in label_vocab: 15 | label_vocab.add(label) 16 | 17 | with io.open(input_data_path, encoding='utf-8') as fin_data, \ 18 | io.open(input_label_path, encoding='utf-8') as fin_label, \ 19 | io.open(output_data_path, 'w', encoding='utf-8') as fout_data, \ 20 | io.open(output_label_path, 'w', encoding='utf-8') as fout_label: 21 | 22 | def check_labels(fin_data, fin_label, fout_data, fout_label): 23 | for dd, ll in itertools.izip(fin_data, fin_label): 24 | ll_ = [l for l in ll.split() if l in label_vocab] 25 | if len(ll_) == 0: 26 | continue 27 | 28 | fout_data.write('%s\n' % dd.strip()) 29 | fout_label.write('%s\n' % ' '.join(ll_)) 30 | 31 | check_labels(fin_data, fin_label, fout_data, fout_label) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--data', type=str, required=True) 37 | parser.add_argument('--label', type=str, required=True) 38 | parser.add_argument('--out_data', type=str, required=True) 39 | parser.add_argument('--out_label', type=str, required=True) 40 | parser.add_argument('--label_vocab', type=str, required=True) 41 | 42 | args = parser.parse_args() 43 | 44 | delete_instances(args.data, args.label, args.out_data, args.out_label, 45 | args.label_vocab) 46 | -------------------------------------------------------------------------------- /proc_util/sort_labels.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | from itertools import count 4 | 5 | 6 | def sort_labels(input_filepath, label_vocab_path, output_filepath, reverse): 7 | 8 | dict_size = 1000000 9 | with io.open(label_vocab_path, encoding='utf8') as f: 10 | label_vocab = dict() 11 | if dict_size > 0: 12 | indices = range(len(label_vocab), dict_size) 13 | else: 14 | indices = count(len(label_vocab)) 15 | label_vocab.update(zip(map( 16 | lambda x: x.rstrip('\n').split('\t')[-1], f), indices)) 17 | 18 | with io.open(input_filepath, encoding='utf-8') as fin,\ 19 | io.open(output_filepath, 'w', encoding='utf-8') as fout: 20 | 21 | for line in fin: 
22 | labels = line.strip().split() 23 | 24 | label_index_pairs = [ 25 | (l_, label_vocab[l_]) for l_ in labels if l_ in label_vocab 26 | ] 27 | ''' 28 | print type(label_index_pairs) 29 | print type(label_index_pairs[0]) 30 | print type(label_index_pairs[0][0]) 31 | print type(label_index_pairs[0][0]) 32 | ''' 33 | sorted_labels = sorted(label_index_pairs, 34 | key=lambda x: x[1], reverse=reverse) 35 | 36 | sorted_labels = [ 37 | lp[0] for lp in sorted_labels 38 | ] 39 | 40 | fout.write('%s\n' % u' '.join(sorted_labels)) 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--input', type=str, required=True) 46 | parser.add_argument('--label_vocab', type=str, required=True) 47 | parser.add_argument('--output', type=str, required=True) 48 | parser.add_argument('--reverse', dest='reverse', action='store_true') 49 | parser.set_defaults(reverse=False) 50 | 51 | args = parser.parse_args() 52 | 53 | sort_labels(args.input, args.label_vocab, args.output, args.reverse) 54 | -------------------------------------------------------------------------------- /wordemb_pretrain/extract_title_and_body.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import io 5 | import argparse 6 | 7 | 8 | def extract_title(wikipedia_header): 9 | title = wikipedia_header.strip().replace( 10 | '<', '').replace('>', '').split('title=')[-1] 11 | title_length = len(title) 12 | return re.sub('\s+', '_', title[1:title_length - 1]).strip() 13 | 14 | 15 | def main(markup_data_path, body_path, title_path): 16 | with io.open(markup_data_path, encoding='utf-8') as fin, \ 17 | io.open(body_path, 'w', encoding='utf-8') as fout_body, \ 18 | io.open(title_path, 'w', encoding='utf-8') as fout_title: 19 | 20 | start_writing = False 21 | num_docs = 0 22 | 23 | while True: 24 | line = fin.readline() 25 | if not line: 26 | break 27 | 28 | if line.startswith('= split_year: 55 | fout_tsl.write('%s\n' % mesh_str) 56 | fout_tsl.flush() 57 | 58 | # doc = re.sub(r"((\d+([\.\,]\d*)?)|\.\d+)", "0", doc) 59 | if year < split_year: 60 | fout_trd.write('%s\n' % doc) 61 | fout_trd.flush() 62 | elif year >= split_year: 63 | fout_tsd.write('%s\n' % doc) 64 | fout_tsd.flush() 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--input', type=str, required=True) 70 | parser.add_argument('--traindata_output', type=str, required=True) 71 | parser.add_argument('--trainlabel_output', type=str, required=True) 72 | parser.add_argument('--testdata_output', type=str, required=True) 73 | parser.add_argument('--testlabel_output', type=str, required=True) 74 | parser.add_argument('--split_year', type=int, required=True) 75 | 76 | args = parser.parse_args() 77 | 78 | processing(args.input, 79 | args.traindata_output, args.trainlabel_output, 80 | args.testdata_output, args.testlabel_output, args.split_year) 81 | -------------------------------------------------------------------------------- /data_prep_tools/extract_plain_reuters21578_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import io 5 | import os 6 | import re 7 | import sys 8 | 9 | import bs4 10 | from bs4 import BeautifulSoup 11 | 12 | 13 | def process_single_document(doc): 14 | assert isinstance(doc, bs4.element.Tag) 15 | 16 | selected_part = '' 17 | if doc.title: 18 | selected_part += 
doc.title.string + '.' 19 | if doc.body or doc.text: 20 | body_txt = doc.body.string if doc.body else doc.text 21 | try: 22 | splitted = body_txt.split() 23 | if splitted[-1] == 'Reuter' or \ 24 | splitted[-1] == 'REUTER': 25 | selected_part += ' ' + ' '.join(splitted[:-1]) 26 | else: 27 | selected_part += ' ' + ' '.join(splitted) 28 | 29 | except UnicodeEncodeError: 30 | print(doc.body.string) 31 | sys.exit(0) 32 | 33 | replaced = re.sub(' +', ' ', 34 | selected_part.replace('\n', ' ')) 35 | 36 | return replaced 37 | 38 | 39 | def main(data_path, trd_path, trl_path, tsd_path, tsl_path): 40 | train_set = {} 41 | test_set = {} 42 | train_label_set = set() 43 | test_label_set = set() 44 | train_doc_id, test_doc_id = 0, 0 45 | for filename in os.listdir(data_path): 46 | if not filename.endswith('.sgm'): 47 | continue 48 | 49 | with open('/'.join([data_path, filename])) as fin: 50 | reuter_docs = BeautifulSoup(fin.read(), 'html.parser').find_all('reuters') 51 | 52 | for doc in reuter_docs: 53 | labels = [label.string 54 | for label in doc.topics.find_all('d')] 55 | if len(labels) == 0: 56 | continue 57 | 58 | if doc['lewissplit'].lower() == 'train' and \ 59 | doc['topics'].lower() == 'yes': 60 | 61 | train_set[train_doc_id] = (doc, labels) 62 | train_label_set |= set(labels) 63 | train_doc_id += 1 64 | 65 | if doc['lewissplit'].lower() == 'test' and \ 66 | doc['topics'].lower() == 'yes': 67 | 68 | test_set[test_doc_id] = (doc, labels) 69 | test_label_set |= set(labels) 70 | test_doc_id += 1 71 | 72 | # delete labels 73 | common_labels = train_label_set.intersection(test_label_set) 74 | 75 | def filter_labels(labels, common_label_set): 76 | return [label for label in labels if label in common_label_set] 77 | 78 | def write_to_file(text_body_path, label_path, dataset, common_labels): 79 | with io.open(text_body_path, 'w', encoding='utf-8') as f_body, \ 80 | io.open(label_path, 'w', encoding='utf-8') as f_label: 81 | 82 | for doc_id, (doc, labels) in dataset.items(): 83 | output = process_single_document(doc) 84 | labels = filter_labels(labels, common_labels) 85 | 86 | if len(labels) > 0: 87 | try: 88 | f_body.write('%s\n' % output) 89 | except TypeError as e: 90 | print(doc) 91 | print(output) 92 | sys.exit(0) 93 | 94 | f_label.write('%s\n' % ' '.join(labels)) 95 | 96 | write_to_file(trd_path, trl_path, train_set, common_labels) 97 | write_to_file(tsd_path, tsl_path, test_set, common_labels) 98 | 99 | 100 | if __name__ == '__main__': 101 | 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('--data_path', type=str, required=True) 104 | parser.add_argument('--trd_path', type=str, required=True) 105 | parser.add_argument('--trl_path', type=str, required=True) 106 | parser.add_argument('--tsd_path', type=str, required=True) 107 | parser.add_argument('--tsl_path', type=str, required=True) 108 | 109 | args = parser.parse_args() 110 | 111 | main(args.data_path, 112 | args.trd_path, args.trl_path, 113 | args.tsd_path, args.tsl_path) 114 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import sys 5 | import io 6 | import json 7 | import os 8 | import logging 9 | from utils import mul2bin 10 | 11 | FORMAT = '[%(asctime)s] %(levelname)s - %(message)s' 12 | logging.basicConfig(level=logging.INFO, format=FORMAT) 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | try: 16 | import numpy 17 | from 
six.moves import xrange 18 | import six 19 | 20 | import theano 21 | from theano.sandbox.rng_mrg import MRG_RandomStreams 22 | from data_iterator import (load_test_data, load_dict) 23 | from mlc2seq_base import (build_sampler, gen_sample, init_params) 24 | from utils import (load_params, init_tparams) 25 | from misc import savez 26 | except ImportError as e: 27 | EXIT_STATUS = 64 28 | print('Failed to import: %s' % str(e), file=sys.stderr) 29 | sys.exit(EXIT_STATUS) 30 | 31 | 32 | def main(model_path, data_base_path, option_path, saveto, k): 33 | 34 | # load model_options 35 | with io.open(option_path, encoding='utf8') as f: 36 | config = json.load(f) 37 | 38 | model_options = config['model'] 39 | test_data_options = config['testdata'] 40 | label_vocab_size = test_data_options['n_labels'] 41 | assert 'reverse_labels' in config['data'] 42 | reverse_labels = config['data']['reverse_labels'] 43 | 44 | def join_data_base_path(data_base, options): 45 | for kk, vv in six.iteritems(options): 46 | if kk in ['src', 'trg', 'input_vocab', 'label_vocab']: 47 | options[kk] = os.path.join(data_base, options[kk]) 48 | 49 | return options 50 | 51 | test_data_options = join_data_base_path(data_base_path, test_data_options) 52 | dicts_r, test_stream = load_test_data(**test_data_options) 53 | 54 | word_vocab = load_dict(test_data_options['input_vocab']) 55 | iword_vocab = dict((vv, kk) for kk, vv in six.iteritems(word_vocab)) 56 | label_vocab = load_dict(test_data_options['label_vocab'], 57 | dict_size=label_vocab_size, 58 | include_unk=False, reverse=reverse_labels) 59 | ilabel_vocab = dict((vv, kk) for kk, vv in six.iteritems(label_vocab)) 60 | 61 | model_options['n_labels'] = len(label_vocab) 62 | 63 | LOGGER.info('Building model') 64 | params = init_params(model_options) 65 | 66 | LOGGER.info('Loading parameters from {}'.format(model_path)) 67 | params = load_params(model_path, params) 68 | 69 | LOGGER.info('Initializing parameters') 70 | tparams = init_tparams(params) 71 | 72 | # use_noise is for dropout 73 | use_noise = theano.shared(numpy.float32(0.)) 74 | trng = MRG_RandomStreams(1234) 75 | 76 | n_samples = 0 77 | 78 | LOGGER.info('Building sampler') 79 | f_sample_inits, f_sample_nexts \ 80 | = build_sampler(tparams, model_options, trng, use_noise) 81 | 82 | results = dict() 83 | results['input_vocab'] = iword_vocab 84 | results['label_vocab'] = ilabel_vocab 85 | results['src'] = dict() 86 | results['predictions'] = dict() 87 | results['targets'] = dict() 88 | results['alignments'] = dict() 89 | 90 | for x, x_mask, y, y_mask in test_stream.get_epoch_iterator(): 91 | orig_x = x 92 | if model_options['label_type'] == 'binary': 93 | y, y_mask = mul2bin(y, y_mask, model_options['n_bins']) 94 | 95 | x, x_mask = x.T, x_mask.T 96 | 97 | if model_options['enc_dir'] == 'none': 98 | x_mask[(x == 0) | (x == 1)] = 0. 
99 | 100 | for jj in xrange(x.shape[1]): 101 | sample_encoder_inps = [ 102 | x[:, jj][:, None], 103 | x_mask[:, jj][:, None] 104 | ] 105 | 106 | solutions = gen_sample(tparams, 107 | f_sample_inits, 108 | f_sample_nexts, 109 | sample_encoder_inps, 110 | model_options, 111 | trng=trng, 112 | k=k, 113 | max_label_len=50, 114 | argmax=False) 115 | 116 | samples = solutions['samples'] 117 | alignment = solutions['alignments'] 118 | scores = solutions['scores'] 119 | 120 | scores = scores / numpy.array([len(s) for s in samples]) 121 | best_sample = samples[scores.argmin()] 122 | best_alignment = alignment[scores.argmin()] 123 | 124 | results['src'][n_samples + jj] = orig_x[jj] 125 | results['predictions'][n_samples + jj] = best_sample 126 | results['alignments'][n_samples + jj] = numpy.array(best_alignment) 127 | results['targets'][n_samples + jj] = y[jj, y_mask[jj] == 1] 128 | 129 | n_samples += x.shape[1] 130 | LOGGER.info('Number of processed instances: {}'.format(n_samples)) 131 | 132 | LOGGER.info('Making predictions successfully on {} instances'.format( 133 | n_samples)) 134 | savez(results, saveto, 2) 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('-k', type=int, default=5) 140 | parser.add_argument('--model', type=str, required=True) 141 | parser.add_argument('--config', type=str, required=True) 142 | parser.add_argument('--base_datapath', type=str) 143 | parser.add_argument('--saveto', type=str, required=True) 144 | 145 | args = parser.parse_args() 146 | 147 | if args.base_datapath: 148 | data_base_path = os.path.realpath(args.base_datapath) 149 | else: 150 | data_base_path = os.getcwd() 151 | 152 | main(args.model, data_base_path, args.config, 153 | args.saveto, k=args.k) 154 | -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | import theano 2 | from theano import tensor 3 | 4 | import numpy as np 5 | 6 | import six 7 | from utils import itemlist 8 | 9 | 10 | # optimizers 11 | # name(hyperp, tparams, grads, inputs (list), cost) = \ 12 | # f_grad_shared, f_update 13 | def adam(lr, tparams, grads, inp, cost): 14 | gshared = [theano.shared(p.get_value() * 0., 15 | name='%s_grad' % k) 16 | for k, p in six.iteritems(tparams)] 17 | gsup = [(gs, g) for gs, g in zip(gshared, grads)] 18 | 19 | f_grad_shared = theano.function(inp, cost, updates=gsup, name='adam') 20 | 21 | # lr0 = 0.0002 22 | lr0 = lr 23 | b1 = 0.1 24 | b2 = 0.001 25 | e = 1e-6 26 | 27 | updates = [] 28 | 29 | i = theano.shared(np.float32(0.), name='adam_i') 30 | i_t = i + 1. 31 | fix1 = 1. - b1 ** (i_t) 32 | fix2 = 1. - b2 ** (i_t) 33 | lr_t = lr0 * (tensor.sqrt(fix2) / fix1) 34 | 35 | state = [i] 36 | 37 | for p, g in zip(tparams.values(), gshared): 38 | m = theano.shared(p.get_value() * 0., name='%s_m' % p.name) 39 | v = theano.shared(p.get_value() * 0., name='%s_v' % p.name) 40 | state.extend([m, v]) 41 | m_t = (b1 * g) + ((1. - b1) * m) 42 | v_t = (b2 * tensor.sqr(g)) + ((1. 
- b2) * v) 43 | g_t = m_t / (tensor.sqrt(v_t) + e) 44 | p_t = p - (lr_t * g_t) 45 | updates.append((m, m_t)) 46 | updates.append((v, v_t)) 47 | updates.append((p, p_t)) 48 | updates.append((i, i_t)) 49 | 50 | f_update = theano.function([lr], 51 | [], 52 | updates=updates, 53 | on_unused_input='ignore') 54 | 55 | return f_grad_shared, f_update, state 56 | 57 | 58 | def adadelta(lr, tparams, grads, inp, cost): 59 | zipped_grads = [theano.shared(p.get_value() * np.float32(0.), 60 | name='%s_grad' % k) 61 | for k, p in six.iteritems(tparams)] 62 | running_up2 = [theano.shared(p.get_value() * np.float32(0.), 63 | name='%s_rup2' % k) 64 | for k, p in six.iteritems(tparams)] 65 | running_grads2 = [theano.shared(p.get_value() * np.float32(0.), 66 | name='%s_rgrad2' % k) 67 | for k, p in six.iteritems(tparams)] 68 | 69 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)] 70 | rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2)) 71 | for rg2, g in zip(running_grads2, grads)] 72 | 73 | f_grad_shared = theano.function(inp, 74 | cost, 75 | updates=zgup + rg2up, 76 | name='adadelta') 77 | 78 | updir = [-tensor.sqrt(ru2 + 1e-6) / tensor.sqrt(rg2 + 1e-6) * zg 79 | for zg, ru2, rg2 in zip(zipped_grads, running_up2, running_grads2) 80 | ] 81 | ru2up = [(ru2, 0.95 * ru2 + 0.05 * (ud ** 2)) 82 | for ru2, ud in zip(running_up2, updir)] 83 | param_up = [(p, p + ud) for p, ud in zip(itemlist(tparams), updir)] 84 | 85 | f_update = theano.function([lr], 86 | [], 87 | updates=ru2up + param_up, 88 | on_unused_input='ignore') 89 | 90 | return f_grad_shared, f_update, running_up2 + running_grads2 91 | 92 | 93 | def rmsprop(lr, tparams, grads, inp, cost): 94 | zipped_grads = [theano.shared(p.get_value() * np.float32(0.), 95 | name='%s_grad' % k) 96 | for k, p in six.iteritems(tparams)] 97 | running_grads = [theano.shared(p.get_value() * np.float32(0.), 98 | name='%s_rgrad' % k) 99 | for k, p in six.iteritems(tparams)] 100 | running_grads2 = [theano.shared(p.get_value() * np.float32(0.), 101 | name='%s_rgrad2' % k) 102 | for k, p in six.iteritems(tparams)] 103 | 104 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)] 105 | rgup = [(rg, 0.95 * rg + 0.05 * g) for rg, g in zip(running_grads, grads)] 106 | rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2)) 107 | for rg2, g in zip(running_grads2, grads)] 108 | 109 | f_grad_shared = theano.function(inp, 110 | cost, 111 | updates=zgup + rgup + rg2up, 112 | name='rmsprop') 113 | 114 | updir = [theano.shared(p.get_value() * np.float32(0.), 115 | name='%s_updir' % k) 116 | for k, p in six.iteritems(tparams)] 117 | updir_new = [(ud, 0.9 * ud - lr * zg / tensor.sqrt(rg2 - rg ** 2 + 1e-4)) 118 | for ud, zg, rg, rg2 in zip(updir, zipped_grads, running_grads, 119 | running_grads2)] 120 | param_up = [(p, p + udn[1]) for p, udn in zip( 121 | itemlist(tparams), updir_new)] 122 | f_update = theano.function([lr], 123 | [], 124 | updates=updir_new + param_up, 125 | on_unused_input='ignore') 126 | 127 | return f_grad_shared, f_update, running_grads + running_grads2 + updir 128 | 129 | 130 | def sgd(lr, tparams, grads, inp, cost): 131 | gshared = [theano.shared(p.get_value() * 0., 132 | name='%s_grad' % k) 133 | for k, p in six.iteritems(tparams)] 134 | gsup = [(gs, g) for gs, g in zip(gshared, grads)] 135 | 136 | f_grad_shared = theano.function( 137 | inp, 138 | cost, 139 | updates=gsup, 140 | name='sgd') 141 | 142 | pup = [(p, p - lr * g) for p, g in zip(itemlist(tparams), gshared)] 143 | f_update = theano.function([lr], [], updates=pup) 144 | 145 | return f_grad_shared, f_update, [] 146 | 
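Each optimizer factory above returns f_grad_shared (one forward/backward pass that stores the gradients in shared variables), f_update (which applies the parameter update for a given learning rate), and the optimizer's internal state. A minimal sketch of how they are wired into a Theano training loop, using a toy softmax model with hypothetical variable names rather than the repository's actual model-building code:

```python
from collections import OrderedDict

import numpy as np
import theano
from theano import tensor

from optimizers import adam

# toy parameters; the real code builds `tparams` from the model config
W = theano.shared(np.zeros((5, 3), dtype='float32'), name='W')
b = theano.shared(np.zeros(3, dtype='float32'), name='b')
tparams = OrderedDict([('W', W), ('b', b)])

x = tensor.matrix('x', dtype='float32')
y = tensor.ivector('y')
probs = tensor.nnet.softmax(tensor.dot(x, W) + b)
cost = -tensor.log(probs[tensor.arange(y.shape[0]), y]).mean()

# gradients w.r.t. all trainable parameters
grads = tensor.grad(cost, wrt=list(tparams.values()))

lr = tensor.scalar(name='lr')
f_grad_shared, f_update, opt_state = adam(lr, tparams, grads, [x, y], cost)

# one training step per minibatch: accumulate gradients, then update
batch_x = np.random.rand(8, 5).astype('float32')
batch_y = np.random.randint(0, 3, size=8).astype('int32')
minibatch_cost = f_grad_shared(batch_x, batch_y)
f_update(np.float32(0.0002))  # learning rate, cf. the commented-out lr0 above
```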
-------------------------------------------------------------------------------- /data_prep_tools/run_reuters_preparation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function create_dir { 4 | if [[ ! -e $1 ]]; then 5 | mkdir $1 6 | elif [[ ! -d $1 ]]; then 7 | echo "$1 exists but not a directory" 1>&2 8 | exit 9 | fi 10 | } 11 | 12 | CWD=$PWD 13 | SRC_DIR=${CWD}/data/Reuters21578/original_dataset 14 | OUTPUT_BASE=${CWD}/data/Reuters21578/MLC2SEQ 15 | 16 | create_dir $OUTPUT_BASE 17 | 18 | trd_base=${OUTPUT_BASE}/trd 19 | trl_base=${OUTPUT_BASE}/trl 20 | vad_base=${OUTPUT_BASE}/vad 21 | val_base=${OUTPUT_BASE}/val 22 | tsd_base=${OUTPUT_BASE}/tsd 23 | tsl_base=${OUTPUT_BASE}/tsl 24 | 25 | TRD=$trd_base.txt 26 | TRL=$trl_base.txt 27 | VAD=$vad_base.txt 28 | VAL=$val_base.txt 29 | TSD=$tsd_base.txt 30 | TSL=$tsl_base.txt 31 | CHAR_VOCAB=${OUTPUT_BASE}/char_vocab.txt 32 | WORD_VOCAB=${OUTPUT_BASE}/word_vocab.txt 33 | LABEL_VOCAB=${OUTPUT_BASE}/label_vocab.txt 34 | SRC=${CWD}/proc_util 35 | MOSESDEC_PATH=${CWD}/data_prep_tools/mosesdecoder 36 | LABEL_MIN_CNT=1 37 | WORD_MIN_CNT=1 38 | MAX_SENT_LEN=300 39 | N_THREADS=4 40 | 41 | python ${CWD}/data_prep_tools/extract_plain_reuters21578_dataset.py \ 42 | --data_path ${SRC_DIR} \ 43 | --trd_path ${TRD} \ 44 | --trl_path ${TRL} \ 45 | --tsd_path ${TSD} \ 46 | --tsl_path ${TSL} 47 | 48 | # shuffle data to break down connection 49 | TRD_SHUF=$trd_base.shuf.txt 50 | TRL_SHUF=$trl_base.shuf.txt 51 | dd if=/dev/urandom of=rand count=$((128*1024)) status=none 52 | shuf --random-source=rand ${TRD} > ${TRD_SHUF} 53 | shuf --random-source=rand ${TRL} > ${TRL_SHUF} 54 | mv ${TRD_SHUF} ${TRD} 55 | mv ${TRL_SHUF} ${TRL} 56 | rm rand 57 | 58 | TMP_TRD=${OUTPUT_BASE}/tmp_trd.txt 59 | TMP_TRL=${OUTPUT_BASE}/tmp_trl.txt 60 | TMP_VAL=${OUTPUT_BASE}/tmp_val.txt 61 | TMP_TSL=${OUTPUT_BASE}/tmp_tsl.txt 62 | 63 | VAD_SIZE=777 64 | 65 | # split the original train data into the train and validation sets 66 | PREV_TRD_NUM_LINE="$(wc -l ${TRD} | cut -d' ' -f 1)" 67 | cat ${TRD} | head -n ${VAD_SIZE} > ${VAD} 68 | cat ${TRL} | head -n ${VAD_SIZE} > ${VAL} 69 | cat ${TRD} | tail -n $((PREV_TRD_NUM_LINE-VAD_SIZE)) > ${TMP_TRD} 70 | cat ${TRL} | tail -n $((PREV_TRD_NUM_LINE-VAD_SIZE)) > ${TMP_TRL} 71 | mv ${TMP_TRD} ${TRD} 72 | mv ${TMP_TRL} ${TRL} 73 | 74 | # create vocabularies 75 | python ${SRC}/create_character_vocab.py --input ${TRD} --output ${CHAR_VOCAB} 76 | awk -v OFS="\t" -v LABEL_MIN_CNT=$LABEL_MIN_CNT '{ for(i=1; i<=NF; i++) w[$i]++ } END {for(i in w) { if(w[i] >= LABEL_MIN_CNT) {print w[i], i}} }' ${TRL} | sort -k 1nr > ${LABEL_VOCAB} 77 | 78 | # sort labels of each instance by label frequency 79 | python ${SRC}/sort_labels.py --input ${TRL} --label_vocab ${LABEL_VOCAB} --output ${TMP_TRL} 80 | python ${SRC}/sort_labels.py --input ${VAL} --label_vocab ${LABEL_VOCAB} --output ${TMP_VAL} 81 | python ${SRC}/sort_labels.py --input ${TSL} --label_vocab ${LABEL_VOCAB} --output ${TMP_TSL} 82 | mv ${TMP_TRL} ${TRL} 83 | mv ${TMP_VAL} ${VAL} 84 | mv ${TMP_TSL} ${TSL} 85 | 86 | # delete instances which have empty label set 87 | python ${SRC}/delete_instances.py --data ${TRD} --label ${TRL} --out_data $trd_base.delete.txt --out_label $trl_base.delete.txt --label_vocab ${LABEL_VOCAB} 88 | python ${SRC}/delete_instances.py --data ${VAD} --label ${VAL} --out_data $vad_base.delete.txt --out_label $val_base.delete.txt --label_vocab ${LABEL_VOCAB} 89 | python ${SRC}/delete_instances.py --data ${TSD} --label ${TSL} 
--out_data $tsd_base.delete.txt --out_label $tsl_base.delete.txt --label_vocab ${LABEL_VOCAB} 90 | 91 | trd_base=$trd_base.delete 92 | trl_base=$trl_base.delete 93 | vad_base=$vad_base.delete 94 | val_base=$val_base.delete 95 | tsd_base=$tsd_base.delete 96 | tsl_base=$tsl_base.delete 97 | 98 | # tokenize 99 | cat $trd_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 100 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 101 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $trd_base.tok.txt 102 | 103 | cat $vad_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 104 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 105 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $vad_base.tok.txt 106 | 107 | cat $tsd_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 108 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 109 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $tsd_base.tok.txt 110 | 111 | trd_base=$trd_base.tok 112 | trl_base=$trl_base.tok 113 | vad_base=$vad_base.tok 114 | val_base=$val_base.tok 115 | tsd_base=$tsd_base.tok 116 | tsl_base=$tsl_base.tok 117 | 118 | # limit the length of each document 119 | python ${SRC}/cut_long_sentences.py --input $trd_base.txt --output $trd_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 120 | python ${SRC}/cut_long_sentences.py --input $vad_base.txt --output $vad_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 121 | python ${SRC}/cut_long_sentences.py --input $tsd_base.txt --output $tsd_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 122 | 123 | trd_base=$trd_base.max_${MAX_SENT_LEN} 124 | trl_base=$trl_base.max_${MAX_SENT_LEN} 125 | vad_base=$vad_base.max_${MAX_SENT_LEN} 126 | val_base=$val_base.max_${MAX_SENT_LEN} 127 | tsd_base=$tsd_base.max_${MAX_SENT_LEN} 128 | tsl_base=$tsl_base.max_${MAX_SENT_LEN} 129 | 130 | # lowercase 131 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $trd_base.txt > $trd_base.lc.txt 132 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $vad_base.txt > $vad_base.lc.txt 133 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $tsd_base.txt > $tsd_base.lc.txt 134 | 135 | trd_base=$trd_base.lc 136 | trl_base=$trl_base.lc 137 | vad_base=$vad_base.lc 138 | val_base=$val_base.lc 139 | tsd_base=$tsd_base.lc 140 | tsl_base=$tsl_base.lc 141 | 142 | # create the word vocabulary 143 | awk -v OFS="\t" -v WORD_MIN_CNT=$WORD_MIN_CNT '{ for(i=1; i<=NF; i++) w[$i]++ } END {for(i in w) { if(w[i] >= WORD_MIN_CNT) {print w[i], i}} }' $trd_base.txt | sort -k 1nr > ${WORD_VOCAB} 144 | 145 | -------------------------------------------------------------------------------- /data_prep_tools/run_bioasq_preparation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function create_dir { 4 | if [[ ! -e $1 ]]; then 5 | mkdir $1 6 | elif [[ ! 
-d $1 ]]; then 7 | echo "$1 exists but not a directory" 1>&2 8 | exit 9 | fi 10 | } 11 | 12 | CWD=$PWD 13 | BASE_PATH=${CWD}/data/BioASQ 14 | OUTPUT_BASE=${CWD}/data/BioASQ/MLC2SEQ 15 | 16 | create_dir $OUTPUT_BASE 17 | 18 | RAW_JSON=${BASE_PATH}/allMeSH.json 19 | trd_base=${OUTPUT_BASE}/trd 20 | trl_base=${OUTPUT_BASE}/trl 21 | vad_base=${OUTPUT_BASE}/vad 22 | val_base=${OUTPUT_BASE}/val 23 | tsd_base=${OUTPUT_BASE}/tsd 24 | tsl_base=${OUTPUT_BASE}/tsl 25 | 26 | TRD=$trd_base.txt 27 | TRL=$trl_base.txt 28 | VAD=$vad_base.txt 29 | VAL=$val_base.txt 30 | TSD=$tsd_base.txt 31 | TSL=$tsl_base.txt 32 | CHAR_VOCAB=${OUTPUT_BASE}/char_vocab.txt 33 | WORD_VOCAB=${OUTPUT_BASE}/word_vocab.txt 34 | LABEL_VOCAB=${OUTPUT_BASE}/label_vocab.txt 35 | SPLIT_YEAR=2014 36 | SRC=${CWD}/proc_util 37 | MOSESDEC_PATH=${CWD}/data_prep_tools/mosesdecoder 38 | MAX_SENT_LEN=300 39 | WORD_MIN_CNT=1 40 | LABEL_MIN_CNT=1 41 | N_THREADS=4 42 | 43 | python ${CWD}/data_prep_tools/extract_plain_bioasq_dataset.py \ 44 | --input ${RAW_JSON} \ 45 | --traindata_output ${TRD} \ 46 | --trainlabel_output ${TRL} \ 47 | --testdata_output ${TSD} \ 48 | --testlabel_output ${TSL} \ 49 | --split_year ${SPLIT_YEAR} 50 | 51 | # shuffle data to break down connection 52 | TRD_SHUF=$trd_base.shuf.txt 53 | TRL_SHUF=$trl_base.shuf.txt 54 | dd if=/dev/urandom of=rand count=$((128*1024)) status=none 55 | shuf --random-source=rand ${TRD} > ${TRD_SHUF} 56 | shuf --random-source=rand ${TRL} > ${TRL_SHUF} 57 | mv ${TRD_SHUF} ${TRD} 58 | mv ${TRL_SHUF} ${TRL} 59 | rm rand 60 | 61 | TMP_TRD=${OUTPUT_BASE}/tmp_trd.txt 62 | TMP_TRL=${OUTPUT_BASE}/tmp_trl.txt 63 | TMP_VAL=${OUTPUT_BASE}/tmp_val.txt 64 | TMP_TSL=${OUTPUT_BASE}/tmp_tsl.txt 65 | 66 | VAD_SIZE=50000 67 | 68 | # split the original train data into the train and validation sets 69 | PREV_TRD_NUM_LINE="$(wc -l ${TRD} | cut -d' ' -f 1)" 70 | cat ${TRD} | head -n ${VAD_SIZE} > ${VAD} 71 | cat ${TRL} | head -n ${VAD_SIZE} > ${VAL} 72 | cat ${TRD} | tail -n $((PREV_TRD_NUM_LINE-VAD_SIZE)) > ${TMP_TRD} 73 | cat ${TRL} | tail -n $((PREV_TRD_NUM_LINE-VAD_SIZE)) > ${TMP_TRL} 74 | mv ${TMP_TRD} ${TRD} 75 | mv ${TMP_TRL} ${TRL} 76 | 77 | # create vocabularies 78 | python ${SRC}/create_character_vocab.py --input ${TRD} --output ${CHAR_VOCAB} 79 | awk -v OFS="\t" -v LABEL_MIN_CNT=$LABEL_MIN_CNT '{ for(i=1; i<=NF; i++) w[$i]++ } END {for(i in w) { if(w[i] >= LABEL_MIN_CNT) {print w[i], i}} }' ${TRL} | sort -k 1nr > ${LABEL_VOCAB} 80 | 81 | # sort labels of each instance by label frequency 82 | python ${SRC}/sort_labels.py --input ${TRL} --label_vocab ${LABEL_VOCAB} --output ${TMP_TRL} 83 | python ${SRC}/sort_labels.py --input ${VAL} --label_vocab ${LABEL_VOCAB} --output ${TMP_VAL} 84 | python ${SRC}/sort_labels.py --input ${TSL} --label_vocab ${LABEL_VOCAB} --output ${TMP_TSL} 85 | mv ${TMP_TRL} ${TRL} 86 | mv ${TMP_VAL} ${VAL} 87 | mv ${TMP_TSL} ${TSL} 88 | 89 | # delete instances which have empty label set 90 | python ${SRC}/delete_instances.py --data ${TRD} --label ${TRL} --out_data $trd_base.delete.txt --out_label $trl_base.delete.txt --label_vocab ${LABEL_VOCAB} 91 | python ${SRC}/delete_instances.py --data ${VAD} --label ${VAL} --out_data $vad_base.delete.txt --out_label $val_base.delete.txt --label_vocab ${LABEL_VOCAB} 92 | python ${SRC}/delete_instances.py --data ${TSD} --label ${TSL} --out_data $tsd_base.delete.txt --out_label $tsl_base.delete.txt --label_vocab ${LABEL_VOCAB} 93 | 94 | trd_base=$trd_base.delete 95 | trl_base=$trl_base.delete 96 | vad_base=$vad_base.delete 97 | 
val_base=$val_base.delete 98 | tsd_base=$tsd_base.delete 99 | tsl_base=$tsl_base.delete 100 | 101 | # tokenize 102 | cat $trd_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 103 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 104 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $trd_base.tok.txt 105 | 106 | cat $vad_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 107 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 108 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $vad_base.tok.txt 109 | 110 | cat $tsd_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 111 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 112 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $tsd_base.tok.txt 113 | 114 | trd_base=$trd_base.tok 115 | trl_base=$trl_base.tok 116 | vad_base=$vad_base.tok 117 | val_base=$val_base.tok 118 | tsd_base=$tsd_base.tok 119 | tsl_base=$tsl_base.tok 120 | 121 | # limit the length of each document 122 | python ${SRC}/cut_long_sentences.py --input $trd_base.txt --output $trd_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 123 | python ${SRC}/cut_long_sentences.py --input $vad_base.txt --output $vad_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 124 | python ${SRC}/cut_long_sentences.py --input $tsd_base.txt --output $tsd_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 125 | 126 | trd_base=$trd_base.max_${MAX_SENT_LEN} 127 | trl_base=$trl_base.max_${MAX_SENT_LEN} 128 | vad_base=$vad_base.max_${MAX_SENT_LEN} 129 | val_base=$val_base.max_${MAX_SENT_LEN} 130 | tsd_base=$tsd_base.max_${MAX_SENT_LEN} 131 | tsl_base=$tsl_base.max_${MAX_SENT_LEN} 132 | 133 | # lowercase 134 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $trd_base.txt > $trd_base.lc.txt 135 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $vad_base.txt > $vad_base.lc.txt 136 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $tsd_base.txt > $tsd_base.lc.txt 137 | 138 | trd_base=$trd_base.lc 139 | trl_base=$trl_base.lc 140 | vad_base=$vad_base.lc 141 | val_base=$val_base.lc 142 | tsd_base=$tsd_base.lc 143 | tsl_base=$tsl_base.lc 144 | 145 | # create the word vocabulary 146 | awk -v OFS="\t" -v WORD_MIN_CNT=$WORD_MIN_CNT '{ for(i=1; i<=NF; i++) w[$i]++ } END {for(i in w) { if(w[i] >= WORD_MIN_CNT) {print w[i], i}} }' $trd_base.txt | sort -k 1nr > ${WORD_VOCAB} 147 | -------------------------------------------------------------------------------- /data_prep_tools/run_rcv1_preparation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CWD=$PWD 4 | ORIGINAL_DATA_PATH=${CWD}/data/RCV1 5 | OUTPUT_BASE=${CWD}/data/RCV1/MLC2SEQ 6 | IS_SWAP=True # sawp train and test split if true 7 | 8 | if [ ! -d "$ORIGINAL_DATA_PATH/CD1" ] || [ ! -d "$ORIGINAL_DATA_PATH/CD2" ]; then 9 | echo "Make sure you have RCV1 CD1 and CD2 directories under ${ORIGINAL_DATA_PATH}." 10 | exit 1 11 | fi 12 | 13 | TMP_DIR=$(mktemp -d) 14 | re='^[0-9]+$' 15 | 16 | echo "Uncompress zipped files.." 
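# unpack every zip archive found under CD1/CD2 whose base name is purely
# numeric (presumably the date-named news archives) into the temporary
# directory created above; any other zip files on the discs are skipped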
17 | find ${ORIGINAL_DATA_PATH} \( -path "*/CD1/*" -o -path "*/CD2/*" \) -type f -name '*.zip' -print0 | while IFS= read -r -d '' file 18 | do 19 | file_no_ext=${file%%.*} 20 | file_base="`basename ${file_no_ext}`" 21 | if [[ $file_base =~ $re ]] ; then 22 | unzip -qq $file -d ${TMP_DIR} 23 | fi 24 | done 25 | echo 'Done' 26 | 27 | trd_base=${OUTPUT_BASE}/trd 28 | trl_base=${OUTPUT_BASE}/trl 29 | vad_base=${OUTPUT_BASE}/vad 30 | val_base=${OUTPUT_BASE}/val 31 | tsd_base=${OUTPUT_BASE}/tsd 32 | tsl_base=${OUTPUT_BASE}/tsl 33 | TRD=$trd_base.txt 34 | TRL=$trl_base.txt 35 | VAD=$vad_base.txt 36 | VAL=$val_base.txt 37 | TSD=$tsd_base.txt 38 | TSL=$tsl_base.txt 39 | CHAR_VOCAB=${OUTPUT_BASE}/char_vocab.txt 40 | WORD_VOCAB=${OUTPUT_BASE}/word_vocab.txt 41 | LABEL_VOCAB=${OUTPUT_BASE}/label_vocab.txt 42 | SRC=${CWD}/proc_util 43 | MOSESDEC_PATH=${CWD}/data_prep_tools/mosesdecoder 44 | MAX_SENT_LEN=300 45 | WORD_MIN_CNT=1 46 | LABEL_MIN_CNT=1 47 | N_THREADS=4 48 | 49 | python extract_plain_rcv1_dataset.py \ 50 | --src_dir ${TMP_DIR} \ 51 | --output_dir ${OUTPUT_BASE} \ 52 | --swap ${IS_SWAP} 53 | 54 | rm -rf $TMP_DIR 55 | 56 | # shuffle data to break down connection 57 | TRD_SHUF=$trd_base.shuf.txt 58 | TRL_SHUF=$trl_base.shuf.txt 59 | dd if=/dev/urandom of=rand count=$((128*1024)) status=none 60 | shuf --random-source=rand ${TRD} > ${TRD_SHUF} 61 | shuf --random-source=rand ${TRL} > ${TRL_SHUF} 62 | mv ${TRD_SHUF} ${TRD} 63 | mv ${TRL_SHUF} ${TRL} 64 | rm rand 65 | 66 | TMP_TRD=${OUTPUT_BASE}/tmp_trd.txt 67 | TMP_TRL=${OUTPUT_BASE}/tmp_trl.txt 68 | TMP_VAL=${OUTPUT_BASE}/tmp_val.txt 69 | TMP_TSL=${OUTPUT_BASE}/tmp_tsl.txt 70 | 71 | # split the original train data into the train and validation sets 72 | PREV_TRD_NUM_LINE="$(wc -l ${TRD} | cut -d' ' -f 1)" 73 | VAD_SIZE=$(expr ${PREV_TRD_NUM_LINE} / 10) 74 | cat ${TRD} | head -n ${VAD_SIZE} > ${VAD} 75 | cat ${TRL} | head -n ${VAD_SIZE} > ${VAL} 76 | cat ${TRD} | tail -n $((PREV_TRD_NUM_LINE-VAD_SIZE)) > ${TMP_TRD} 77 | cat ${TRL} | tail -n $((PREV_TRD_NUM_LINE-VAD_SIZE)) > ${TMP_TRL} 78 | mv ${TMP_TRD} ${TRD} 79 | mv ${TMP_TRL} ${TRL} 80 | 81 | # create vocabularies 82 | python ${SRC}/create_character_vocab.py --input ${TRD} --output ${CHAR_VOCAB} 83 | awk -v OFS="\t" -v LABEL_MIN_CNT=$LABEL_MIN_CNT '{ for(i=1; i<=NF; i++) w[$i]++ } END {for(i in w) { if(w[i] >= LABEL_MIN_CNT) {print w[i], i}} }' ${TRL} | sort -k 1nr > ${LABEL_VOCAB} 84 | 85 | # sort labels of each instance by label frequency 86 | python ${SRC}/sort_labels.py --input ${TRL} --label_vocab ${LABEL_VOCAB} --output ${TMP_TRL} 87 | python ${SRC}/sort_labels.py --input ${VAL} --label_vocab ${LABEL_VOCAB} --output ${TMP_VAL} 88 | python ${SRC}/sort_labels.py --input ${TSL} --label_vocab ${LABEL_VOCAB} --output ${TMP_TSL} 89 | mv ${TMP_TRL} ${TRL} 90 | mv ${TMP_VAL} ${VAL} 91 | mv ${TMP_TSL} ${TSL} 92 | 93 | # delete instances which have empty label set 94 | python ${SRC}/delete_instances.py --data ${TRD} --label ${TRL} --out_data $trd_base.delete.txt --out_label $trl_base.delete.txt --label_vocab ${LABEL_VOCAB} 95 | python ${SRC}/delete_instances.py --data ${VAD} --label ${VAL} --out_data $vad_base.delete.txt --out_label $val_base.delete.txt --label_vocab ${LABEL_VOCAB} 96 | python ${SRC}/delete_instances.py --data ${TSD} --label ${TSL} --out_data $tsd_base.delete.txt --out_label $tsl_base.delete.txt --label_vocab ${LABEL_VOCAB} 97 | 98 | trd_base=$trd_base.delete 99 | trl_base=$trl_base.delete 100 | vad_base=$vad_base.delete 101 | val_base=$val_base.delete 102 | 
tsd_base=$tsd_base.delete 103 | tsl_base=$tsl_base.delete 104 | 105 | # tokenize 106 | cat $trd_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 107 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 108 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $trd_base.tok.txt 109 | 110 | cat $vad_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 111 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 112 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $vad_base.tok.txt 113 | 114 | cat $tsd_base.txt | ${MOSESDEC_PATH}/scripts/tokenizer/normalize-punctuation.perl -l en | \ 115 | ${MOSESDEC_PATH}/scripts/tokenizer/tokenizer.perl -a -l en -threads ${N_THREADS} | \ 116 | ${MOSESDEC_PATH}/scripts/generic/ph_numbers.perl -c > $tsd_base.tok.txt 117 | 118 | trd_base=$trd_base.tok 119 | trl_base=$trl_base.tok 120 | vad_base=$vad_base.tok 121 | val_base=$val_base.tok 122 | tsd_base=$tsd_base.tok 123 | tsl_base=$tsl_base.tok 124 | 125 | # limit the length of each document 126 | python ${SRC}/cut_long_sentences.py --input $trd_base.txt --output $trd_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 127 | python ${SRC}/cut_long_sentences.py --input $vad_base.txt --output $vad_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 128 | python ${SRC}/cut_long_sentences.py --input $tsd_base.txt --output $tsd_base.max_${MAX_SENT_LEN}.txt --max ${MAX_SENT_LEN} --level word 129 | 130 | trd_base=$trd_base.max_${MAX_SENT_LEN} 131 | trl_base=$trl_base.max_${MAX_SENT_LEN} 132 | vad_base=$vad_base.max_${MAX_SENT_LEN} 133 | val_base=$val_base.max_${MAX_SENT_LEN} 134 | tsd_base=$tsd_base.max_${MAX_SENT_LEN} 135 | tsl_base=$tsl_base.max_${MAX_SENT_LEN} 136 | 137 | # lowercase 138 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $trd_base.txt > $trd_base.lc.txt 139 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $vad_base.txt > $vad_base.lc.txt 140 | ${MOSESDEC_PATH}/scripts/tokenizer/lowercase.perl -l en < $tsd_base.txt > $tsd_base.lc.txt 141 | 142 | trd_base=$trd_base.lc 143 | trl_base=$trl_base.lc 144 | vad_base=$vad_base.lc 145 | val_base=$val_base.lc 146 | tsd_base=$tsd_base.lc 147 | tsl_base=$tsl_base.lc 148 | 149 | # create the word vocabulary 150 | awk -v OFS="\t" -v WORD_MIN_CNT=$WORD_MIN_CNT '{ for(i=1; i<=NF; i++) w[$i]++ } END {for(i in w) { if(w[i] >= WORD_MIN_CNT) {print w[i], i}} }' $trd_base.txt | sort -k 1nr > ${WORD_VOCAB} 151 | -------------------------------------------------------------------------------- /data_prep_tools/extract_plain_rcv1_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from lxml import etree 4 | import os 5 | import argparse 6 | from tempfile import NamedTemporaryFile 7 | import urllib2 8 | import gzip 9 | import io 10 | 11 | 12 | class RCV1Parser(object): 13 | def __init__(self, document_path): 14 | self.body_text = None 15 | self.labels = None 16 | self.itemid = None 17 | 18 | with io.open(document_path, encoding='iso-8859-1') as f: 19 | self.contents = etree.parse(f) 20 | 21 | self.handleID(self.contents.getroot()) 22 | self.handleText(self.contents.findall("text")[0]) 23 | self.handleCodes(self.contents.findall("metadata")[0]) 24 | 25 | def handleID(self, newsitem): 26 | self.itemid = int(newsitem.attrib["itemid"]) 27 | 28 | def handleText(self, text): 29 | self.body_text = ' 
'.join([line.text.strip() for line in text]) 30 | 31 | def handleCodes(self, metadata): 32 | extracted_codes = [] 33 | for codes in metadata.findall("codes"): 34 | if codes.attrib['class'] == "bip:topics:1.0": 35 | extracted_codes = [ 36 | code.attrib['code'] for code in codes.findall("code")] 37 | 38 | if len(extracted_codes) > 0: 39 | self.labels = ' '.join(extracted_codes) 40 | 41 | def getID(self): 42 | return self.itemid 43 | 44 | def getBodyText(self): 45 | return self.body_text 46 | 47 | def getLabels(self): 48 | return self.labels 49 | 50 | 51 | def get_plain_text(src_dir): 52 | filenames = [os.path.join(src_dir, f) for f in os.listdir(src_dir) 53 | if os.path.isfile(os.path.join(src_dir, f)) and 54 | f.endswith('.xml')] 55 | 56 | body_label_pairs = {} 57 | for file_index, filename in enumerate(filenames): 58 | if (file_index + 1) % 1000 == 0: 59 | print('{} / {}\r'.format(file_index+1, len(filenames)), end='') 60 | 61 | pa = RCV1Parser(filename) 62 | 63 | assert pa.getID() not in body_label_pairs 64 | assert type(pa.getID()) == int 65 | 66 | if pa.getLabels() and pa.getBodyText(): 67 | body_label_pairs[pa.getID()] = (pa.getBodyText(), pa.getLabels()) 68 | 69 | print('') 70 | return body_label_pairs 71 | 72 | 73 | def download_rcv1_token_files(): 74 | 75 | base_url = ('http://jmlr.csail.mit.edu/papers/' 76 | 'volume5/lewis04a/a12-token-files') 77 | 78 | def get_filename(dataset_split, number=None): 79 | if dataset_split == 'train': 80 | assert number is None 81 | return 'lyrl2004_tokens_{}.dat.gz'.format(dataset_split) 82 | elif dataset_split == 'test': 83 | assert number is not None 84 | return 'lyrl2004_tokens_{}_pt{}.dat.gz'.format( 85 | dataset_split, number) 86 | 87 | # download the train file 88 | def download_file(url): 89 | handle = urllib2.urlopen(url) 90 | 91 | return handle.read() 92 | 93 | def uncompress_file(zipped_contents): 94 | with NamedTemporaryFile(suffix='.gz', dir='/tmp') as f: 95 | f.write(zipped_contents) 96 | f.flush() 97 | 98 | with gzip.open(f.name) as fin: 99 | uncompressed_contents = fin.read() 100 | 101 | return uncompressed_contents 102 | 103 | train_contents = uncompress_file(download_file( 104 | '/'.join([base_url, get_filename('train')]))) 105 | 106 | urls = ['/'.join([base_url, get_filename('test', subset_id)]) 107 | for subset_id in range(4)] 108 | test_contents = '\n'.join( 109 | [uncompress_file(download_file(url)) for url in urls]) 110 | 111 | def collect_doc_ids(data): 112 | ids = [] 113 | for line in data.split('\n'): 114 | if line.startswith('.I'): 115 | t, i = line.strip().split() 116 | ids.append(int(i)) 117 | 118 | return ids 119 | 120 | train_ids = {idx: 1 for idx in collect_doc_ids(train_contents)} 121 | test_ids = {idx: 1 for idx in collect_doc_ids(test_contents)} 122 | 123 | return (train_ids, test_ids) 124 | 125 | 126 | def split_data(body_label_dict, train_ids, test_ids): 127 | train_set = [] 128 | test_set = [] 129 | 130 | for doc_id, (body, label) in body_label_dict.items(): 131 | 132 | if doc_id in train_ids: 133 | train_set.append((body, label)) 134 | elif doc_id in test_ids: 135 | test_set.append((body, label)) 136 | 137 | return train_set, test_set 138 | 139 | 140 | def store_processed_data(output_dir, train_set, test_set): 141 | trd_path = os.path.join(args.output_dir, 'trd.txt') 142 | trl_path = os.path.join(args.output_dir, 'trl.txt') 143 | tsd_path = os.path.join(args.output_dir, 'tsd.txt') 144 | tsl_path = os.path.join(args.output_dir, 'tsl.txt') 145 | 146 | with io.open(trd_path, 'w', encoding='utf-8') as f_trd,\ 147 | 
io.open(trl_path, 'w', encoding='utf-8') as f_trl: 148 | for (body, labels) in train_set: 149 | f_trd.write(u'{}\n'.format(body.strip())) 150 | f_trl.write(u'{}\n'.format(labels.strip())) 151 | 152 | with io.open(tsd_path, 'w', encoding='utf-8') as f_tsd, \ 153 | io.open(tsl_path, 'w', encoding='utf-8') as f_tsl: 154 | for (body, labels) in test_set: 155 | f_tsd.write(u'{}\n'.format(body.strip())) 156 | f_tsl.write(u'{}\n'.format(labels.strip())) 157 | 158 | 159 | if __name__ == '__main__': 160 | parser = argparse.ArgumentParser() 161 | 162 | parser.add_argument('--src_dir', type=str, required=True) 163 | parser.add_argument('--output_dir', type=str, required=True) 164 | parser.add_argument('--swap', type=bool, default=False) 165 | 166 | args = parser.parse_args() 167 | 168 | if not os.path.exists(args.output_dir): 169 | os.makedirs(args.output_dir) 170 | 171 | print("Downloading RCV1 files to obtain train / test split") 172 | train_ids, test_ids = download_rcv1_token_files() 173 | 174 | print("Extracting plain text from xml files...") 175 | body_label_dict = get_plain_text(args.src_dir) 176 | print("Done") 177 | 178 | print("Split data according to the train / test split") 179 | train_set, test_set = split_data(body_label_dict, train_ids, test_ids) 180 | 181 | print("Store them into the output directory") 182 | if args.swap: 183 | store_processed_data(args.output_dir, test_set, train_set) 184 | else: 185 | store_processed_data(args.output_dir, train_set, test_set) 186 | -------------------------------------------------------------------------------- /create_hdf5_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | import io 5 | import argparse 6 | import numpy 7 | import h5py 8 | 9 | from fuel.datasets.hdf5 import H5PYDataset 10 | 11 | from utils import load_pretrained_embeddings 12 | 13 | FORMAT = '[%(asctime)s] %(levelname)s - %(message)s' 14 | logging.basicConfig(level=logging.INFO, format=FORMAT) 15 | LOGGER = logging.getLogger(__name__) 16 | 17 | 18 | def count_num_lines(doc_path): 19 | with io.open(doc_path, encoding='utf-8') as f: 20 | for i, line in enumerate(f): 21 | pass 22 | return i + 1 23 | 24 | 25 | def convert_documents(doc_path, pretrained_word_emb, _mb_sz): 26 | vocab = {word: index for (index, word) in pretrained_word_emb['vocab']} 27 | Wemb = pretrained_word_emb['Wemb'] 28 | 29 | n_max_docs = count_num_lines(doc_path) 30 | mb_sz = min(n_max_docs, _mb_sz) 31 | word_dim = Wemb.shape[1] 32 | 33 | docs = [] 34 | 35 | with open(doc_path) as f: 36 | for doc_id, doc in enumerate(f): 37 | words = doc.strip().split() 38 | idx_list = [] 39 | for word in words: 40 | if word in vocab: 41 | idx_list.append(vocab[word]) 42 | 43 | assert len(idx_list) > 0, '{}'.format(doc_id) 44 | 45 | docs.append(idx_list) 46 | 47 | if len(docs) == mb_sz: 48 | doc_vectors = numpy.zeros([mb_sz, word_dim]) 49 | 50 | for idx_in_mb, word_indices in enumerate(docs): 51 | doc_vectors[idx_in_mb] = Wemb[word_indices].mean(axis=0) 52 | 53 | yield doc_vectors[:] 54 | 55 | docs = [] 56 | 57 | if len(docs) > 0: 58 | doc_vectors = numpy.zeros([len(docs), word_dim]) 59 | 60 | for idx_in_mb, word_indices in enumerate(docs): 61 | doc_vectors[idx_in_mb] = Wemb[word_indices].mean(axis=0) 62 | 63 | yield doc_vectors[:] 64 | 65 | pass 66 | 67 | 68 | def load_vocab(path): 69 | with io.open(path, encoding='utf-8') as f: 70 | items = [line.strip().split()[1] for line in f] 71 | 72 | vocab = {item: idx for idx, item in 
enumerate(items)} 73 | 74 | return vocab 75 | 76 | 77 | def convert_label_sets(label_path, label_vocab): 78 | n_sets = count_num_lines(label_path) 79 | 80 | target_vectors = [None] * n_sets 81 | 82 | with open(label_path) as f: 83 | for label_set_id, label_set in enumerate(f): 84 | labels = label_set.strip().split() 85 | 86 | label_indices = [label_vocab[label] 87 | for label in labels if label in label_vocab] 88 | 89 | target_vectors[label_set_id] = label_indices 90 | 91 | return target_vectors 92 | 93 | 94 | def main(trd_path, trl_path, vad_path, val_path, tsd_path, tsl_path, 95 | label_vocab_path, word_emb_path, output_path): 96 | 97 | label_vocab = load_vocab(label_vocab_path) 98 | pretrained_emb = load_pretrained_embeddings(word_emb_path) 99 | 100 | LOGGER.info('Converting labels...') 101 | train_label = convert_label_sets(trl_path, label_vocab) 102 | valid_label = convert_label_sets(val_path, label_vocab) 103 | test_label = convert_label_sets(tsl_path, label_vocab) 104 | LOGGER.info('Done') 105 | 106 | n_trd, n_vad, n_tsd = \ 107 | len(train_label), len(valid_label), len(test_label) 108 | 109 | LOGGER.info('Number of train label sets: {}'.format(n_trd)) 110 | LOGGER.info('Number of valid label sets: {}'.format(n_vad)) 111 | LOGGER.info('Number of test label sets: {}'.format(n_tsd)) 112 | 113 | n_total_docs = n_trd + n_vad + n_tsd 114 | n_dim = pretrained_emb['Wemb'].shape[1] 115 | mb_sz = 500000 116 | 117 | with h5py.File(output_path, mode='w') as f: 118 | features = f.create_dataset('features', (n_total_docs, n_dim), 119 | dtype='float32') 120 | n_processed = 0 121 | LOGGER.info('Converting train documents ...') 122 | for train_data in convert_documents(trd_path, pretrained_emb, mb_sz): 123 | features[n_processed: n_processed+train_data.shape[0]] = train_data 124 | n_processed += train_data.shape[0] 125 | LOGGER.info('{} / {}'.format(n_processed, n_total_docs)) 126 | assert n_processed == n_trd 127 | LOGGER.info('Done') 128 | 129 | LOGGER.info('Converting valid documents ...') 130 | for valid_data in convert_documents(vad_path, pretrained_emb, mb_sz): 131 | features[n_processed: n_processed+valid_data.shape[0]] = valid_data 132 | n_processed += valid_data.shape[0] 133 | LOGGER.info('{} / {}'.format(n_processed, n_total_docs)) 134 | assert n_processed == n_vad + n_trd 135 | LOGGER.info('Done') 136 | 137 | LOGGER.info('Converting test documents ...') 138 | for test_data in convert_documents(tsd_path, pretrained_emb, mb_sz): 139 | features[n_processed: n_processed+test_data.shape[0]] = test_data 140 | n_processed += test_data.shape[0] 141 | LOGGER.info('{} / {}'.format(n_processed, n_total_docs)) 142 | assert n_processed == n_total_docs 143 | LOGGER.info('Done') 144 | 145 | _dtype = h5py.special_dtype(vlen=numpy.dtype('uint16')) 146 | targets = f.create_dataset('targets', (n_total_docs,), dtype=_dtype) 147 | all_target_labels = train_label + valid_label + test_label 148 | 149 | assert n_total_docs == len(all_target_labels) 150 | 151 | targets[...] = numpy.array(all_target_labels) 152 | 153 | # assign labels to the dataset 154 | features.dims[0].label = 'batch' 155 | features.dims[1].label = 'feature' 156 | targets.dims[0].label = 'batch' 157 | 158 | targets_shapes = f.create_dataset( 159 | 'targets_shapes', (n_total_docs, 1), dtype='int32') 160 | targets_shapes[...] 
= numpy.array( 161 | [len(labels) for labels in all_target_labels])[:, None] 162 | 163 | targets.dims.create_scale(targets_shapes, 'shapes') 164 | targets.dims[0].attach_scale(targets_shapes) 165 | 166 | targets_shape_labels = f.create_dataset( 167 | 'targets_shape_labels', (1,), dtype='S6') 168 | targets_shape_labels[...] = ['length'.encode('utf8')] 169 | 170 | targets.dims.create_scale(targets_shape_labels, 'shape_labels') 171 | targets.dims[0].attach_scale(targets_shape_labels) 172 | 173 | split_dict = { 174 | 'train': {'features': (0, n_trd), 'targets': (0, n_trd)}, 175 | 'valid': {'features': (n_trd, n_trd + n_vad), 176 | 'targets': (n_trd, n_trd + n_vad)}, 177 | 'test': {'features': (n_trd + n_vad, n_total_docs), 178 | 'targets': (n_trd + n_vad, n_total_docs)}} 179 | 180 | f.attrs['split'] = H5PYDataset.create_split_array(split_dict) 181 | f.flush() 182 | 183 | 184 | if __name__ == '__main__': 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument('--trd', type=str, required=True) 187 | parser.add_argument('--trl', type=str, required=True) 188 | parser.add_argument('--vad', type=str, required=True) 189 | parser.add_argument('--val', type=str, required=True) 190 | parser.add_argument('--tsd', type=str, required=True) 191 | parser.add_argument('--tsl', type=str, required=True) 192 | parser.add_argument('--label_vocab', type=str, required=True) 193 | parser.add_argument('--word_emb', type=str, required=True) 194 | parser.add_argument('--output', type=str, required=True) 195 | 196 | args = parser.parse_args() 197 | 198 | main(args.trd, args.trl, args.vad, args.val, args.tsd, args.tsl, 199 | args.label_vocab, args.word_emb, args.output) 200 | -------------------------------------------------------------------------------- /evals.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import scipy.sparse as sp 3 | import logging 4 | from six.moves import xrange 5 | from collections import OrderedDict 6 | 7 | FORMAT = '[%(asctime)s] %(levelname)s - %(message)s' 8 | logging.basicConfig(level=logging.INFO, format=FORMAT) 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | def list2sparse(A, n_labels=None): 13 | if n_labels is None: 14 | n_labels_ = 0 15 | for a in A: 16 | if n_labels_ < numpy.max(a): 17 | n_labels_ = numpy.max(a) 18 | n_labels = n_labels_ 19 | 20 | n_samples = len(A) 21 | mat = sp.dok_matrix((n_samples, n_labels)) 22 | for idx in xrange(n_samples): 23 | for item in A[idx]: 24 | mat[idx, item] = 1 25 | 26 | return mat.tocsr() 27 | 28 | 29 | def is_sparse(matrix): 30 | return sp.issparse(matrix) 31 | 32 | 33 | def is_binary_matrix(matrix): 34 | return numpy.all(numpy.logical_xor(matrix != 1, matrix != 0)) 35 | 36 | 37 | def sparse2dense(sparse_matrix): 38 | """ convert a sparse matrix into a dense matrix of 0 or 1. 
39 | 40 | """ 41 | assert sp.issparse(sparse_matrix) 42 | 43 | return numpy.asarray(sparse_matrix.toarray()) 44 | 45 | 46 | def prepare_evaluation(targets, preds): 47 | if is_sparse(targets): 48 | targets = sparse2dense(targets) 49 | 50 | if is_sparse(preds): 51 | preds = sparse2dense(preds) 52 | 53 | assert numpy.array_equal(targets.shape, preds.shape) 54 | assert is_binary_matrix(targets) 55 | assert is_binary_matrix(preds) 56 | 57 | return (targets, preds) 58 | 59 | 60 | def subset_accuracy(true_targets, predictions, per_sample=False, axis=0): 61 | 62 | result = numpy.all(true_targets == predictions, axis=axis) 63 | 64 | if not per_sample: 65 | result = numpy.mean(result) 66 | 67 | return result 68 | 69 | 70 | def hamming_loss(true_targets, predictions, per_sample=False, axis=0): 71 | 72 | result = numpy.mean(numpy.logical_xor(true_targets, predictions), 73 | axis=axis) 74 | 75 | if not per_sample: 76 | result = numpy.mean(result) 77 | 78 | return result 79 | 80 | 81 | def compute_tp_fp_fn(true_targets, predictions, axis=0): 82 | # axis: axis for instance 83 | 84 | tp = numpy.sum(true_targets * predictions, axis=axis).astype('float32') 85 | fp = numpy.sum(numpy.logical_not(true_targets) * predictions, 86 | axis=axis).astype('float32') 87 | fn = numpy.sum(true_targets * numpy.logical_not(predictions), 88 | axis=axis).astype('float32') 89 | 90 | return (tp, fp, fn) 91 | 92 | 93 | def example_f1_score(true_targets, predictions, per_sample=False, axis=0): 94 | tp, fp, fn = compute_tp_fp_fn(true_targets, predictions, axis=axis) 95 | example_f1 = 2*tp / (2*tp + fp + fn) 96 | 97 | if per_sample: 98 | f1 = example_f1 99 | else: 100 | f1 = numpy.mean(example_f1) 101 | 102 | return f1 103 | 104 | 105 | def f1_score_from_stats(tp, fp, fn, average='micro'): 106 | assert len(tp) == len(fp) 107 | assert len(fp) == len(fn) 108 | 109 | if average not in set(['micro', 'macro']): 110 | raise ValueError("Specify micro or macro") 111 | 112 | if average == 'micro': 113 | f1 = 2*numpy.sum(tp) / \ 114 | float(2*numpy.sum(tp) + numpy.sum(fp) + numpy.sum(fn)) 115 | 116 | elif average == 'macro': 117 | 118 | def safe_div(a, b): 119 | """ ignore / 0, div0( [-1, 0, 1], 0 ) -> [0, 0, 0] """ 120 | with numpy.errstate(divide='ignore', invalid='ignore'): 121 | c = numpy.true_divide(a, b) 122 | return c[numpy.isfinite(c)] 123 | 124 | f1 = numpy.mean(safe_div(2*tp, 2*tp + fp + fn)) 125 | 126 | return f1 127 | 128 | 129 | def f1_score(true_targets, predictions, average='micro', axis=0): 130 | """ 131 | average: str 132 | 'micro' or 'macro' 133 | axis: 0 or 1 134 | label axis 135 | """ 136 | if average not in set(['micro', 'macro']): 137 | raise ValueError("Specify micro or macro") 138 | 139 | tp, fp, fn = compute_tp_fp_fn(true_targets, predictions, axis=axis) 140 | f1 = f1_score_from_stats(tp, fp, fn, average=average) 141 | 142 | return f1 143 | 144 | 145 | def average_precision(true_targets, predictions, per_sample=False, axis=0): 146 | pass 147 | 148 | 149 | def compute_all_measures(targets, preds, mb_sz=5000, verbose=0): 150 | """ 151 | Evaluates the model performance with respect to the following measures: 152 | Subset accuracy 153 | Hamming accuracy 154 | Example-based F1 155 | Label-based Micro F1 156 | Label-based Macro F1 157 | 158 | Parameters 159 | ---------- 160 | 161 | targets: sparse matrix of shape (n_instances, n_labels) 162 | Ground truth 163 | 164 | preds: sparse matrix of shape (n_instances, n_labels) 165 | Binary predictions by the model 166 | 167 | Returns 168 | ------- 169 | 170 | eval_ret: 
OrderedDict 171 | A dictionary that contains evaluation results 172 | 173 | """ 174 | assert targets.shape == preds.shape 175 | 176 | # excluding the label 177 | targets = targets[:, 1:] 178 | preds = preds[:, 1:] 179 | 180 | n_instances, n_labels = targets.shape 181 | _mb_sz = mb_sz 182 | 183 | acc_, hl_, exf1_ = [], [], [] 184 | total_tp = numpy.zeros((n_labels,)) 185 | total_fp = numpy.zeros((n_labels,)) 186 | total_fn = numpy.zeros((n_labels,)) 187 | 188 | if verbose: 189 | LOGGER.info('Started to evaluate the predictions') 190 | 191 | for idx in xrange(0, n_instances, _mb_sz): 192 | if idx + _mb_sz >= n_instances: 193 | _mb_sz = n_instances - idx 194 | 195 | trg = targets[idx:idx+_mb_sz, :] 196 | pred = preds[idx:idx+_mb_sz, :] 197 | assert trg.shape == pred.shape 198 | 199 | trg, pred = prepare_evaluation(trg, pred) 200 | 201 | acc_ += list(subset_accuracy(trg, pred, axis=1, per_sample=True)) 202 | hl_ += list(hamming_loss(trg, pred, axis=1, per_sample=True)) 203 | exf1_ += list(example_f1_score(trg, pred, axis=1, per_sample=True)) 204 | 205 | tp, fp, fn = compute_tp_fp_fn(trg, pred, axis=0) 206 | total_tp += tp 207 | total_fp += fp 208 | total_fn += fn 209 | 210 | if verbose: 211 | LOGGER.info('Evaluated {} / {} instances'.format(idx + _mb_sz, 212 | n_instances)) 213 | del trg, pred 214 | 215 | del targets, preds 216 | 217 | assert len(acc_) == n_instances 218 | 219 | acc = numpy.mean(acc_) 220 | hl = numpy.mean(hl_) 221 | exf1 = numpy.mean(exf1_) 222 | mif1 = f1_score_from_stats(tp, fp, fn, average='micro') 223 | maf1 = f1_score_from_stats(tp, fp, fn, average='macro') 224 | 225 | eval_ret = OrderedDict([('Subset accuracy', acc), 226 | ('Hamming accuracy', 1 - hl), 227 | ('Example-based F1', exf1), 228 | ('Label-based Micro F1', mif1), 229 | ('Label-based Macro F1', maf1)]) 230 | 231 | return eval_ret 232 | 233 | 234 | if __name__ == '__main__': 235 | A = numpy.array([[1, 1, 0], [1, 0, 0]]) 236 | B = numpy.array([[0, 1, 0], [1, 1, 1]]) 237 | 238 | instance_axis = 0 239 | label_axis = 1 240 | 241 | print('Example-based F1') 242 | print(example_f1_score(A, B, per_sample=True, axis=label_axis)) 243 | 244 | print('Micro F1') 245 | print(f1_score(A, B, average='micro', axis=instance_axis)) 246 | 247 | print('Macro F1') 248 | print(f1_score(A, B, average='macro', axis=instance_axis)) 249 | 250 | print('Subset accuracy') 251 | print(subset_accuracy(A, B, axis=label_axis)) 252 | 253 | print('Hamming loss') 254 | print(hamming_loss(A, B, axis=label_axis)) 255 | -------------------------------------------------------------------------------- /data_iterator.py: -------------------------------------------------------------------------------- 1 | import io 2 | import logging 3 | from itertools import count 4 | 5 | import numpy 6 | import six 7 | 8 | from fuel.datasets.text import TextFile 9 | from fuel.transformers import Merge 10 | from fuel.schemes import ConstantScheme 11 | from fuel.transformers import (Batch, Cache, Mapping, SortMapping, Padding, 12 | Filter, Transformer) 13 | 14 | FORMAT = '[%(asctime)s] %(levelname)s - %(message)s' 15 | logging.basicConfig(level=logging.INFO, format=FORMAT) 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | EOS_TOKEN = '' # 0 19 | UNK_TOKEN = '' # 1 20 | EOW_TOKEN = ' ' 21 | 22 | 23 | class Shuffle(Transformer): 24 | def __init__(self, data_stream, buffer_size, **kwargs): 25 | if kwargs.get('iteration_scheme') is not None: 26 | raise ValueError 27 | super(Shuffle, self).__init__( 28 | data_stream, produces_examples=data_stream.produces_examples, 29 | 
**kwargs) 30 | self.buffer_size = buffer_size 31 | self.cache = [[] for _ in self.sources] 32 | 33 | def get_data(self, request=None): 34 | if request is not None: 35 | raise ValueError 36 | if not self.cache[0]: 37 | self._cache() 38 | return tuple(cache.pop() for cache in self.cache) 39 | 40 | def _cache(self): 41 | temp_caches = [[] for _ in self.sources] 42 | for i in range(self.buffer_size): 43 | try: 44 | for temp_cache, data in zip(temp_caches, 45 | next(self.child_epoch_iterator)): 46 | temp_cache.append(data) 47 | except StopIteration: 48 | if i: 49 | pass 50 | else: 51 | raise 52 | shuffled_indices = numpy.random.permutation(len(temp_caches[0])) 53 | for i in shuffled_indices: 54 | for temp_cache, cache in zip(temp_caches, self.cache): 55 | cache.append(temp_cache[i]) 56 | 57 | 58 | class SortLabels(Transformer): 59 | def __init__(self, data_stream, **kwargs): 60 | if kwargs.get('iteration_scheme') is not None: 61 | raise ValueError 62 | super(SortLabels, self).__init__( 63 | data_stream, produces_examples=data_stream.produces_examples, 64 | **kwargs) 65 | 66 | def transform_example(self, example): 67 | if 'target_labels' in self.sources: 68 | example = list(example) 69 | 70 | index = self.sources.index('target_labels') 71 | labels = example[index] 72 | example[index] = sorted(labels[:-1]) + [0] 73 | 74 | example = tuple(example) 75 | 76 | return example 77 | 78 | 79 | def _source_length(sentence_pair): 80 | """Returns the length of the second element of a sequence. 81 | 82 | This function is used to sort sentence pairs by the length of the 83 | target sentence. 84 | 85 | """ 86 | return len(sentence_pair[0]) 87 | 88 | 89 | def load_dict(filename, dict_size=0, include_unk=True, reverse=False): 90 | """Load vocab from TSV with words in last column.""" 91 | assert type(reverse) is bool 92 | 93 | dict_ = {EOS_TOKEN: 0} 94 | if include_unk: 95 | dict_[UNK_TOKEN] = 1 96 | 97 | with io.open(filename, encoding='utf8') as f: 98 | if dict_size > 0: 99 | indices = range(dict_size + len(dict_) - 1, len(dict_) - 1, -1) \ 100 | if reverse else range(len(dict_), dict_size + len(dict_)) 101 | else: 102 | indices = count(len(dict_)) 103 | dict_.update(zip(map(lambda x: x.rstrip('\n').split('\t')[-1], f), 104 | indices)) 105 | return dict_ 106 | 107 | 108 | def get_stream(source, target, source_input_dict, target_label_dict, batch_size, 109 | buffer_multiplier=100, input_token_level='word', 110 | n_input_tokens=0, n_labels=0, reverse_labels=False, 111 | max_input_length=None, max_label_length=None, pad_labels=True, 112 | is_sort=True): 113 | """Returns a stream over sentence pairs. 114 | 115 | Parameters 116 | ---------- 117 | source : list 118 | A list of files to read source languages from. 119 | target : list 120 | A list of corresponding files in the target language. 121 | source_word_dict : str 122 | Path to a tab-delimited text file whose last column contains the 123 | vocabulary. 124 | target_label_dict : str 125 | See `source_char_dict`. 126 | batch_size : int 127 | The minibatch size. 128 | buffer_multiplier : int 129 | The number of batches to load, concatenate, sort by length of 130 | source sentence, and split again; this makes batches more uniform 131 | in their sentence length and hence more computationally efficient. 132 | n_source_words : int 133 | The number of words in the source vocabulary. Pass 0 (default) to 134 | use the entire vocabulary. 135 | n_target_labels : int 136 | See `n_chars_source`. 
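    (In the current signature, source_word_dict / source_char_dict correspond
    to the source_input_dict argument, and n_source_words / n_target_labels
    to n_input_tokens and n_labels.)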
137 | 138 | """ 139 | if len(source) != len(target): 140 | raise ValueError("number of source and target files don't match") 141 | 142 | # Read the dictionaries 143 | dicts = [load_dict(source_input_dict, dict_size=n_input_tokens), 144 | load_dict(target_label_dict, dict_size=n_labels, 145 | reverse=reverse_labels, include_unk=False)] 146 | 147 | # Open the two sets of files and merge them 148 | streams = [ 149 | TextFile(source, dicts[0], level=input_token_level, bos_token=None, 150 | eos_token=EOS_TOKEN, encoding='utf-8').get_example_stream(), 151 | TextFile(target, dicts[1], level='word', bos_token=None, 152 | unk_token=None, 153 | eos_token=EOS_TOKEN, encoding='utf-8').get_example_stream() 154 | ] 155 | merged = Merge(streams, ('source_input_tokens', 'target_labels')) 156 | if reverse_labels: 157 | merged = SortLabels(merged) 158 | 159 | # Filter sentence lengths 160 | if max_input_length or max_label_length: 161 | def filter_pair(pair): 162 | src_input_tokens, trg_labels = pair 163 | src_input_ok = (not max_input_length) or \ 164 | len(src_input_tokens) <= (max_input_length + 1) 165 | trg_label_ok = (not max_label_length) or \ 166 | len(trg_labels) <= (max_label_length + 1) 167 | 168 | return src_input_ok and trg_label_ok 169 | 170 | merged = Filter(merged, filter_pair) 171 | 172 | # Batches of approximately uniform size 173 | large_batches = Batch( 174 | merged, 175 | iteration_scheme=ConstantScheme(batch_size * buffer_multiplier) 176 | ) 177 | # sorted_batches = Mapping(large_batches, SortMapping(_source_length)) 178 | # batches = Cache(sorted_batches, ConstantScheme(batch_size)) 179 | # shuffled_batches = Shuffle(batches, buffer_multiplier) 180 | # masked_batches = Padding(shuffled_batches, 181 | # mask_sources=('source_chars', 'target_labels')) 182 | if is_sort: 183 | sorted_batches = Mapping(large_batches, SortMapping(_source_length)) 184 | else: 185 | sorted_batches = large_batches 186 | batches = Cache(sorted_batches, ConstantScheme(batch_size)) 187 | mask_sources = ('source_input_tokens', 'target_labels') 188 | masked_batches = Padding(batches, mask_sources=mask_sources) 189 | 190 | return masked_batches 191 | 192 | 193 | def load_data(src, trg, 194 | valid_src, valid_trg, 195 | input_vocab, 196 | label_vocab, 197 | n_input_tokens, 198 | n_labels, 199 | reverse_labels, 200 | input_token_level, 201 | batch_size, valid_batch_size, 202 | max_input_length, max_label_length): 203 | LOGGER.info('Loading data') 204 | 205 | dictionaries = [input_vocab, label_vocab] 206 | datasets = [src, trg] 207 | valid_datasets = [valid_src, valid_trg] 208 | 209 | # load dictionaries and invert them 210 | vocabularies = [None] * len(dictionaries) 211 | vocabularies_r = [None] * len(dictionaries) 212 | vocab_size = [n_input_tokens, n_labels] 213 | for ii, dd in enumerate(dictionaries): 214 | vocabularies[ii] = load_dict(dd, dict_size=vocab_size[ii]) if ii == 0 \ 215 | else load_dict(dd, dict_size=vocab_size[ii], 216 | reverse=reverse_labels, 217 | include_unk=False) 218 | vocabularies_r[ii] = dict() 219 | for kk, vv in six.iteritems(vocabularies[ii]): 220 | vocabularies_r[ii][vv] = kk 221 | 222 | train_stream = get_stream([datasets[0]], 223 | [datasets[1]], 224 | dictionaries[0], 225 | dictionaries[1], 226 | n_input_tokens=n_input_tokens, 227 | n_labels=n_labels, 228 | reverse_labels=reverse_labels, 229 | input_token_level=input_token_level, 230 | batch_size=batch_size, 231 | max_input_length=max_input_length, 232 | max_label_length=max_label_length) 233 | valid_stream = 
get_stream([valid_datasets[0]], 234 | [valid_datasets[1]], 235 | dictionaries[0], 236 | dictionaries[1], 237 | n_input_tokens=n_input_tokens, 238 | n_labels=n_labels, 239 | reverse_labels=reverse_labels, 240 | input_token_level=input_token_level, 241 | max_input_length=max_input_length, 242 | batch_size=valid_batch_size) 243 | 244 | return vocabularies_r, train_stream, valid_stream 245 | 246 | 247 | def load_test_data(src, 248 | trg, 249 | input_vocab, 250 | label_vocab, 251 | n_input_tokens, 252 | n_labels, 253 | reverse_labels, 254 | input_token_level, 255 | batch_size, 256 | max_input_length): 257 | LOGGER.info('Loading test data') 258 | 259 | dictionaries = [input_vocab, label_vocab] 260 | datasets = [src, trg] 261 | 262 | # load dictionaries and invert them 263 | vocabularies = [None] * len(dictionaries) 264 | vocabularies_r = [None] * len(dictionaries) 265 | vocab_size = [n_input_tokens, n_labels] 266 | for ii, dd in enumerate(dictionaries): 267 | vocabularies[ii] = load_dict(dd, dict_size=vocab_size[ii]) if ii == 0 \ 268 | else load_dict(dd, dict_size=vocab_size[ii], 269 | reverse=reverse_labels, 270 | include_unk=False) 271 | vocabularies_r[ii] = dict() 272 | for kk, vv in six.iteritems(vocabularies[ii]): 273 | vocabularies_r[ii][vv] = kk 274 | 275 | test_stream = get_stream([datasets[0]], 276 | [datasets[1]], 277 | dictionaries[0], 278 | dictionaries[1], 279 | n_input_tokens=n_input_tokens, 280 | n_labels=n_labels, 281 | reverse_labels=reverse_labels, 282 | input_token_level=input_token_level, 283 | batch_size=batch_size, 284 | max_input_length=max_input_length, 285 | max_label_length=None, 286 | pad_labels=False, 287 | is_sort=False) 288 | 289 | return vocabularies_r, test_stream 290 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import theano 4 | from theano import tensor 5 | import warnings 6 | import six 7 | from six.moves import xrange 8 | import itertools 9 | import copy 10 | 11 | import numpy 12 | from threading import Timer 13 | from collections import OrderedDict 14 | 15 | 16 | # push parameters to Theano shared variables 17 | def zipp(params, tparams): 18 | for kk, vv in six.iteritems(params): 19 | tparams[kk].set_value(vv) 20 | 21 | 22 | # pull parameters from Theano shared variables 23 | def unzip(zipped, params=None): 24 | if not params: 25 | new_params = OrderedDict() 26 | else: 27 | new_params = params 28 | 29 | for kk, vv in six.iteritems(zipped): 30 | new_params[kk] = vv.get_value() 31 | return new_params 32 | 33 | 34 | # Turn list of objects with .name attribute into dict 35 | def name_dict(lst): 36 | d = OrderedDict() 37 | for obj in lst: 38 | d[obj.name] = obj 39 | return d 40 | 41 | 42 | # get the list of parameters: Note that tparams must be OrderedDict 43 | def itemlist(tparams): 44 | return [vv for kk, vv in six.iteritems(tparams)] 45 | 46 | 47 | # dropout 48 | def dropout_layer(state_before, use_noise, p, trng): 49 | proj = tensor.switch(use_noise, 50 | state_before * 51 | trng.binomial(state_before.shape, 52 | p=1-p, 53 | n=1, 54 | dtype=state_before.dtype) / (1. 
- p), 55 | state_before) 56 | return proj 57 | 58 | 59 | # initialize Theano shared variables according to the initial parameters 60 | def init_tparams(params): 61 | tparams = OrderedDict() 62 | for kk, pp in six.iteritems(params): 63 | tparams[kk] = theano.shared(params[kk], name=kk) 64 | return tparams 65 | 66 | 67 | # load parameters 68 | def load_params(path, params, theano_var=False): 69 | pp = numpy.load(path) 70 | for kk, vv in six.iteritems(params): 71 | if kk not in pp: 72 | warnings.warn('%s is not in the archive' % kk) 73 | continue 74 | if theano_var: 75 | params[kk].set_value(pp[kk]) 76 | else: 77 | params[kk] = pp[kk] 78 | 79 | return params 80 | 81 | 82 | def load_pretrained_embeddings(path, gensim_model=True): 83 | if gensim_model: 84 | import gensim 85 | gensim_model = gensim.models.Word2Vec.load(path) 86 | pretrained_emb = dict() 87 | pretrained_emb['vocab'] = \ 88 | [(index, word) for (index, word) 89 | in enumerate(gensim_model.index2word)] 90 | pretrained_emb['Wemb'] = gensim_model.syn0 91 | else: 92 | pretrained_emb = numpy.load(path) 93 | 94 | return pretrained_emb 95 | 96 | 97 | def merge_dicts(*dict_args): 98 | ''' 99 | Given any number of dicts, shallow copy and merge into a new dict, 100 | precedence goes to key value pairs in latter dicts. 101 | ''' 102 | result = {} 103 | for dictionary in dict_args: 104 | result.update(dictionary) 105 | return result 106 | 107 | 108 | # some utilities 109 | def ortho_weight(ndim): 110 | W = numpy.random.randn(ndim, ndim) 111 | u, s, v = numpy.linalg.svd(W) 112 | return u.astype('float32') 113 | 114 | 115 | def norm_weight(nin, nout=None, scale=0.01, ortho=True): 116 | if nout is None: 117 | nout = nin 118 | if nout == nin and ortho: 119 | W = ortho_weight(nin) 120 | else: 121 | W = scale * numpy.random.randn(nin, nout) 122 | return W.astype('float32') 123 | 124 | 125 | def uniform_weight(nin, nout, scale=None): 126 | if scale is None: 127 | scale = numpy.sqrt(6. / (nin + nout)) 128 | 129 | W = numpy.random.uniform(low=-scale, high=scale, size=(nin, nout)) 130 | return W.astype('float32') 131 | 132 | 133 | def concatenate(tensor_list, axis=0): 134 | """ 135 | Alternative implementation of `theano.tensor.concatenate`. 136 | This function does exactly the same thing, but contrary to Theano's own 137 | implementation, the gradient is implemented on the GPU. 138 | Backpropagating through `theano.tensor.concatenate` yields slowdowns 139 | because the inverse operation (splitting) needs to be done on the CPU. 140 | This implementation does not have that problem. 141 | :usage: 142 | >>> x, y = theano.tensor.matrices('x', 'y') 143 | >>> c = concatenate([x, y], axis=1) 144 | :parameters: 145 | - tensor_list : list 146 | list of Theano tensor expressions that should be concatenated. 147 | - axis : int 148 | the tensors will be joined along this axis. 149 | :returns: 150 | - out : tensor 151 | the concatenated tensor expression. 
152 | """ 153 | concat_size = sum(tt.shape[axis] for tt in tensor_list) 154 | 155 | output_shape = () 156 | for k in range(axis): 157 | output_shape += (tensor_list[0].shape[k], ) 158 | output_shape += (concat_size, ) 159 | for k in range(axis + 1, tensor_list[0].ndim): 160 | output_shape += (tensor_list[0].shape[k], ) 161 | 162 | out = tensor.zeros(output_shape) 163 | offset = 0 164 | for tt in tensor_list: 165 | indices = () 166 | for k in range(axis): 167 | indices += (slice(None), ) 168 | indices += (slice(offset, offset + tt.shape[axis]), ) 169 | for k in range(axis + 1, tensor_list[0].ndim): 170 | indices += (slice(None), ) 171 | 172 | out = tensor.set_subtensor(out[indices], tt) 173 | offset += tt.shape[axis] 174 | 175 | return out 176 | 177 | 178 | class RepeatedTimer(object): 179 | def __init__(self, interval, function, return_queue, 180 | *args, **kwargs): 181 | self._timer = None 182 | self._interval = interval 183 | self.function = function # function bound to the timer 184 | # put return values of the function 185 | self._ret_queue = return_queue 186 | self.args = args 187 | self.kwargs = kwargs 188 | self._is_running = False # Is the timer running? 189 | self._is_func_running = False 190 | 191 | def _run(self): 192 | self._is_running = False 193 | self.start() # set a new Timer with pre-specified interval 194 | 195 | # check if the function is running 196 | if not self._is_func_running: 197 | self._is_func_running = True 198 | try: 199 | ret = self.function(*self.args, **self.kwargs) 200 | except Exception as err: 201 | ret = [err] 202 | finally: 203 | self._ret_queue.put(ret) 204 | self._is_func_running = False 205 | 206 | def start(self): 207 | if not self._is_running: 208 | self._timer = Timer(self._interval, self._run) 209 | self._timer.start() 210 | self._is_running = True # timer is running 211 | 212 | def stop(self): 213 | self._timer.cancel() 214 | self._is_running = False 215 | self._is_func_running = False 216 | 217 | 218 | def mul2bin(data, mask, num_dims): 219 | assert data.ndim == 2 220 | n_examples = data.shape[0] 221 | 222 | new_data = numpy.zeros((n_examples, num_dims)).astype('int32') 223 | new_mask = numpy.ones_like(new_data).astype('float32') 224 | for inst_id in xrange(n_examples): 225 | nnz = int(mask[inst_id, :].sum()) # number of nonzeros 226 | indices = data[inst_id, :nnz] 227 | new_data[inst_id, indices] = 1 228 | 229 | return new_data, new_mask 230 | 231 | 232 | def prepare_character_tensor(cx): 233 | 234 | def isplit(iterable, splitters): 235 | return [list(g) for k, g in itertools.groupby(iterable, 236 | lambda x:x in splitters) if not k] 237 | 238 | # index of 'white space' is 2 239 | # sents = [isplit(sent, (2,)) + [[0]] for sent in cx] 240 | total_lengths = [numpy.sum(sent != 0) for sent in cx] 241 | sents = [isplit(sent[:length], (2,)) + [[0]] 242 | for sent, length in zip(cx, total_lengths)] 243 | num_sents = len(cx) 244 | num_words = numpy.max([len(sent) for sent in sents]) 245 | 246 | # word lengths in a batch of sentences 247 | word_lengths = \ 248 | [ 249 | # assume the end of word token 250 | [len(word) for word in sent] 251 | for sent in sents 252 | ] 253 | 254 | max_word_len = numpy.max( 255 | [ 256 | w_len for w_lengths in word_lengths 257 | for w_len in w_lengths 258 | ]) 259 | 260 | max_word_len = min(50, max_word_len) 261 | 262 | chars = numpy.zeros( 263 | [ 264 | max_word_len, 265 | num_words, 266 | num_sents 267 | ], dtype='int64') 268 | 269 | chars_mask = numpy.zeros( 270 | [ 271 | max_word_len, 272 | num_words, 273 | num_sents 
274 | ], dtype='float32') 275 | 276 | for sent_idx, sent in enumerate(sents): 277 | for word_idx, word in enumerate(sent): 278 | word_len = min(len(sents[sent_idx][word_idx]), max_word_len) 279 | 280 | chars[:word_len, word_idx, sent_idx] = \ 281 | sents[sent_idx][word_idx][:word_len] 282 | 283 | chars_mask[:word_len, word_idx, sent_idx] = 1. 284 | 285 | return chars, chars_mask 286 | 287 | 288 | def beam_search(solutions, hypotheses, bs_state, k=1, 289 | decode_char=False, level='word', fixed_length=False): 290 | """Performs beam search. 291 | 292 | Parameters: 293 | ---------- 294 | solutions : dict 295 | See 296 | 297 | hypotheses : dict 298 | See 299 | 300 | bs_state : list 301 | State of beam search 302 | 303 | k : int 304 | Size of beam 305 | 306 | decode_char : boolean 307 | Character generation 308 | 309 | Returns: 310 | ------- 311 | updated_solutions : dict 312 | 313 | updated_hypotheses : dict 314 | """ 315 | 316 | assert len(bs_state) >= 2 317 | 318 | next_state, next_p = bs_state[0], bs_state[1] 319 | 320 | if level == 'word': 321 | next_alphas = bs_state[2] 322 | 323 | if decode_char: 324 | next_word_ctxs, prev_word_inps = \ 325 | bs_state[3], bs_state[4] 326 | 327 | # NLL: the lower, the better 328 | cand_scores = hypotheses['scores'][:, None] - numpy.log(next_p) 329 | cand_flat = cand_scores.flatten() 330 | # select (k - dead_k) best words or characters 331 | # argsort's default order: ascending 332 | ranks_flat = cand_flat.argsort()[:(k - solutions['num_samples'])] 333 | costs = cand_flat[ranks_flat] 334 | 335 | voc_size = next_p.shape[1] 336 | # translation candidate indices 337 | trans_indices = (ranks_flat / voc_size).astype('int64') 338 | word_indices = ranks_flat % voc_size 339 | 340 | new_hyp_samples = [] 341 | new_hyp_scores = numpy.zeros( 342 | k - solutions['num_samples']).astype('float32') 343 | new_hyp_states = [] 344 | 345 | if level == 'word': 346 | new_hyp_alignment = [] 347 | new_hyp_char_samples = [] 348 | new_hyp_prev_word_inps = [] 349 | new_hyp_word_ctxs = [] 350 | 351 | for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)): 352 | new_hyp_samples.append(hypotheses['samples'][ti] + [wi]) 353 | new_hyp_scores[idx] = copy.copy(costs[idx]) 354 | new_hyp_states.append(copy.copy(next_state[ti])) 355 | 356 | if level == 'word': 357 | new_hyp_alignment.append( 358 | hypotheses['alignments'][ti] + 359 | [copy.copy(next_alphas[ti])] 360 | ) 361 | if decode_char: 362 | # NOTE just copy of character sequences generated previously 363 | new_hyp_char_samples.append( 364 | copy.copy(hypotheses['character_samples'][ti])) 365 | new_hyp_prev_word_inps.append(copy.copy(prev_word_inps[ti])) 366 | new_hyp_word_ctxs.append(copy.copy(next_word_ctxs[ti])) 367 | 368 | # check the finished samples 369 | updated_hypotheses = OrderedDict([ 370 | ('num_samples', 0), 371 | ('samples', []), 372 | ('scores', []), 373 | ('states', []), 374 | ]) 375 | 376 | if level == 'word': 377 | updated_hypotheses['word_trg_gates'] = [] 378 | updated_hypotheses['alignments'] = [] 379 | 380 | if decode_char: 381 | updated_hypotheses['character_samples'] = [] 382 | updated_hypotheses['prev_word_inps'] = [] 383 | updated_hypotheses['word_ctxs'] = [] 384 | 385 | for idx in xrange(len(new_hyp_samples)): 386 | if (not fixed_length) and new_hyp_samples[idx][-1] == 0: 387 | # if the last word is the EOS token 388 | solutions['num_samples'] += 1 389 | 390 | solutions['samples'].append(new_hyp_samples[idx]) 391 | solutions['scores'].append(new_hyp_scores[idx]) 392 | 393 | if level == 'word': 394 | 
solutions['alignments'].append(new_hyp_alignment[idx]) 395 | 396 | if decode_char: 397 | solutions['character_samples'].append( 398 | new_hyp_char_samples[idx]) 399 | else: 400 | updated_hypotheses['num_samples'] += 1 401 | 402 | updated_hypotheses['samples'].append(new_hyp_samples[idx]) 403 | updated_hypotheses['scores'].append(new_hyp_scores[idx]) 404 | updated_hypotheses['states'].append(new_hyp_states[idx]) 405 | 406 | if level == 'word': 407 | updated_hypotheses['alignments'].append(new_hyp_alignment[idx]) 408 | if decode_char: 409 | updated_hypotheses['character_samples'].append( 410 | new_hyp_char_samples[idx]) 411 | updated_hypotheses['prev_word_inps'].append( 412 | new_hyp_prev_word_inps[idx]) 413 | updated_hypotheses['word_ctxs'].append( 414 | new_hyp_word_ctxs[idx]) 415 | 416 | if fixed_length: 417 | assert ((updated_hypotheses['num_samples'] + 418 | solutions['num_samples']) <= k), '{}, {}, {}, {}'.format( 419 | len(new_hyp_samples), updated_hypotheses['num_samples'], 420 | solutions['num_samples'], k) 421 | else: 422 | assert ((updated_hypotheses['num_samples'] + 423 | solutions['num_samples']) == k), '{}, {}, {}, {}'.format( 424 | len(new_hyp_samples), updated_hypotheses['num_samples'], 425 | solutions['num_samples'], k) 426 | 427 | updated_hypotheses['scores'] = numpy.array(updated_hypotheses['scores']) 428 | 429 | return solutions, updated_hypotheses 430 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from six.moves import xrange 3 | 4 | import theano 5 | from theano import tensor 6 | 7 | from utils import uniform_weight, ortho_weight, norm_weight, dropout_layer 8 | 9 | 10 | def zero_vector(length): 11 | return numpy.zeros((length, )).astype('float32') 12 | 13 | 14 | # utility function to slice a tensor 15 | def _slice(_x, n, dim): 16 | if _x.ndim == 3: 17 | return _x[:, :, n * dim:(n + 1) * dim] 18 | return _x[:, n * dim:(n + 1) * dim] 19 | 20 | 21 | def _conv(x, tparams, filter_sizes, param_prefix): 22 | # assert x.ndim == 4 23 | 24 | # TODO refactor this function to handle both 3D and 4D tensor properly 25 | n_chars = x.shape[0] 26 | n_words = x.shape[1] 27 | 28 | input_ndim = x.ndim 29 | if x.ndim == 3: 30 | n_samples = 1 31 | char_dim = x.shape[2] 32 | elif x.ndim == 4: 33 | n_samples = x.shape[2] 34 | char_dim = x.shape[3] 35 | 36 | x = x.reshape([n_chars, n_words * n_samples, char_dim]) 37 | x = x.transpose((1, 0, 2)) 38 | 39 | total_proj = [] 40 | for filter_ in filter_sizes: 41 | convolved = tensor.nnet.conv2d( 42 | x[:, None, :, :], 43 | tparams['%s_W%d' % (param_prefix, filter_)], 44 | border_mode=(filter_/2, 0)) 45 | 46 | num_filters = \ 47 | tparams['%s_W%d' % (param_prefix, filter_)].shape[0] 48 | conv_shp = convolved.shape 49 | proj = convolved.reshape([conv_shp[0], conv_shp[1], conv_shp[2]]) 50 | proj = proj.reshape([n_words, n_samples, num_filters, n_chars]) 51 | proj = proj.transpose((3, 0, 1, 2)) 52 | proj = proj + tparams['%s_b%d' % (param_prefix, filter_)] 53 | total_proj.append(proj) 54 | 55 | proj = tensor.concatenate(total_proj, axis=total_proj[0].ndim-1) 56 | # n_chars x n_words x n_samples x num_filters 57 | # proj = tensor.stack(total_proj).mean(0) 58 | 59 | # XXX what was the reasoning to the following lines? 
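# (Presumed reasoning, inferred from the code above, not stated by the original
# authors: when the input to _conv is 3-D there is no explicit sample axis, so
# n_samples is fixed to 1 and the convolution output carries a singleton sample
# dimension of shape n_chars x n_words x 1 x num_filters; the reshape below
# squeezes that axis away so the output has the same number of dimensions as
# the input.)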
60 | if input_ndim == 3 and proj.ndim == 4: 61 | prj_shp = proj.shape 62 | proj = proj.reshape([prj_shp[0], prj_shp[1], prj_shp[3]]) 63 | 64 | return proj 65 | 66 | 67 | def _gru(mask, x_t2gates, x_t2prpsl, h_tm1, U, Ux, activ=tensor.tanh): 68 | 69 | dim = U.shape[0] # dimension of hidden states 70 | 71 | # concatenated activations of the gates in a GRU 72 | activ_gates = tensor.nnet.sigmoid(x_t2gates + tensor.dot(h_tm1, U)) 73 | 74 | # reset and update gates 75 | reset_gate = _slice(activ_gates, 0, dim) 76 | update_gate = _slice(activ_gates, 1, dim) 77 | 78 | # compute the hidden state proposal 79 | in_prpsl = x_t2prpsl + reset_gate * tensor.dot(h_tm1, Ux) 80 | h_prpsl = activ(in_prpsl) if activ else in_prpsl 81 | 82 | # leaky integrate and obtain next hidden state 83 | h_t = update_gate * h_tm1 + (1. - update_gate) * h_prpsl 84 | 85 | # if this time step is not valid, discard the current hidden states 86 | # obtained above and copy the previous hidden states to the current ones. 87 | if mask.ndim == 1: 88 | h_t = mask[:, None] * h_t + (1. - mask)[:, None] * h_tm1 89 | elif mask.ndim == 2: 90 | h_t = mask[:, :, None] * h_t + (1. - mask)[:, :, None] * h_tm1 91 | 92 | return h_t 93 | 94 | 95 | def _compute_alignment(h_tm1, # s_{i-1} 96 | prj_annot, # proj annotations: U_a * h_j for all j 97 | Wd_att, U_att, 98 | context_mask=None): 99 | 100 | # W_a * s_{i-1} 101 | prj_h_tm1 = tensor.dot(h_tm1, Wd_att) 102 | 103 | # tanh(W_a * s_{i-1} + U_a * h_j) for all j 104 | nonlin_proj = tensor.tanh(prj_h_tm1[None, :, :] + prj_annot) 105 | 106 | # v_a^{T} * tanh(.) 107 | alpha = tensor.dot(nonlin_proj, U_att) 108 | alpha = alpha.reshape([alpha.shape[0], alpha.shape[1]]) 109 | alpha = tensor.exp(alpha - alpha.max(0, keepdims=True)) 110 | if context_mask: 111 | alpha = alpha * context_mask 112 | alpha = alpha / alpha.sum(0, keepdims=True) 113 | 114 | return alpha 115 | 116 | 117 | def _cond_gru(m_, x_, xx_, h_, ctx_, alpha_, pctx_, cc_, ctx_mask, 118 | U, Wc, W_comb_att, U_att, Ux, Wcx): 119 | 120 | # attention 121 | alpha = _compute_alignment(h_, pctx_, 122 | W_comb_att, U_att, 123 | context_mask=ctx_mask) 124 | 125 | ctx_ = (cc_ * alpha[:, :, None]).sum(0) # current context 126 | 127 | new_x_ = x_ + tensor.dot(ctx_, Wc) 128 | new_xx_ = xx_ + tensor.dot(ctx_, Wcx) 129 | 130 | h = _gru(m_, new_x_, new_xx_, h_, U, Ux) 131 | 132 | return h, ctx_, alpha.T # pstate_, preact, preactx, r, u 133 | 134 | 135 | # feedforward layer: affine transformation + point-wise nonlinearity 136 | def param_init_fflayer(options, 137 | param, 138 | prefix='ff', 139 | nin=None, 140 | nout=None, 141 | ortho=True): 142 | 143 | if type(nin) is int and type(nout) is int: 144 | param[prefix + '_W'] = uniform_weight(nin, nout) 145 | param[prefix + '_b'] = zero_vector(nout) 146 | else: 147 | assert type(nout) is list 148 | 149 | if type(nin) is int: 150 | nin = [nin] + nout[:-1] 151 | elif type(nin) is list: 152 | assert len(nin) == len(nout) 153 | 154 | for l, (in_dim, out_dim) in enumerate(zip(nin, nout)): 155 | prefix_ = prefix + '_%d' % l 156 | param[prefix_ + '_W'] = uniform_weight(in_dim, out_dim) 157 | param[prefix_ + '_b'] = zero_vector(out_dim) 158 | 159 | return param 160 | 161 | 162 | def fflayer(tparams, 163 | state_below, 164 | options, 165 | prefix='rconv', 166 | activ=tensor.tanh, 167 | **kwargs): 168 | if type(state_below) is list: 169 | n_layers = len(state_below) 170 | h = [None] * n_layers 171 | for l in xrange(n_layers): 172 | prefix_ = prefix + '_%d' % l 173 | h[l] = (tensor.dot(state_below[l], tparams[prefix_ + 
'_W']) + 174 | tparams[prefix_ + '_b']) 175 | h[l] = activ(h[l]) if activ else h[l] 176 | if options['use_dropout'] and ('use_noise' in kwargs and 177 | 'dp' in kwargs and 178 | 'trng' in kwargs): 179 | h[l] = dropout_layer(h[l], kwargs['use_noise'], 180 | kwargs['dp'], kwargs['trng']) 181 | else: 182 | h = (tensor.dot(state_below, tparams[prefix + '_W']) + 183 | tparams[prefix + '_b']) 184 | h = activ(h) if activ else h 185 | 186 | if options['use_dropout'] and ('use_noise' in kwargs and 187 | 'dp' in kwargs and 188 | 'trng' in kwargs): 189 | h = dropout_layer(h, kwargs['use_noise'], 190 | kwargs['dp'], kwargs['trng']) 191 | 192 | return h 193 | 194 | 195 | def param_init_conv(options, 196 | param, 197 | prefix='conv', 198 | nin=None, 199 | nout=None, 200 | filter_sizes=None): 201 | 202 | num_filters = nout 203 | for filter_ in filter_sizes: 204 | W = norm_weight(filter_*nin, num_filters) 205 | W = W.reshape([num_filters, 1, filter_, nin]) 206 | b = zero_vector(num_filters) 207 | param[prefix + '_W%d' % filter_] = W 208 | param[prefix + '_b%d' % filter_] = b 209 | 210 | return param 211 | 212 | 213 | def conv_layer(tparams, 214 | state_below, 215 | options, 216 | filter_sizes, 217 | prefix='conv', 218 | activ=tensor.tanh, 219 | **kwargs): 220 | 221 | h = _conv(state_below, tparams, filter_sizes, param_prefix=prefix) 222 | h = activ(h) if activ else h 223 | 224 | return h 225 | 226 | 227 | # GRU layer 228 | def param_init_gru(options, param, prefix='gru', nin=None, dim=None, 229 | input_conv=False): 230 | 231 | def _init_gru(in_dim, hid_dim, prefix_): 232 | param[prefix_ + '_W'] = numpy.concatenate( 233 | [ 234 | uniform_weight(in_dim, hid_dim), 235 | uniform_weight(in_dim, hid_dim) 236 | ], 237 | axis=1) 238 | param[prefix_ + '_Wx'] = uniform_weight(in_dim, hid_dim) 239 | 240 | param[prefix_ + '_U'] = numpy.concatenate( 241 | [ 242 | ortho_weight(hid_dim), ortho_weight(hid_dim) 243 | ], 244 | axis=1) 245 | param[prefix_ + '_b'] = zero_vector(2 * hid_dim) 246 | 247 | param[prefix_ + '_Ux'] = ortho_weight(hid_dim) 248 | param[prefix_ + '_bx'] = zero_vector(hid_dim) 249 | 250 | assert type(nin) is int 251 | 252 | if type(dim) is int: 253 | _init_gru(nin, dim, prefix) 254 | elif type(dim) is list: 255 | in_dim = nin 256 | for l, hid_dim in enumerate(dim): 257 | prefix_ = prefix + '_%d' % l 258 | _init_gru(in_dim, hid_dim, prefix_) 259 | in_dim = hid_dim 260 | 261 | return param 262 | 263 | 264 | def gru_layer(tparams, 265 | state_below, 266 | dims, 267 | options, 268 | prefix='gru', 269 | mask=None, 270 | input_conv=False, 271 | one_step=False, 272 | init_state=None, 273 | **kwargs): 274 | 275 | if one_step: 276 | assert init_state, 'previous state must be provided' 277 | 278 | nsteps = state_below.shape[0] 279 | if state_below.ndim == 4: 280 | n_samples = state_below.shape[2] 281 | elif state_below.ndim == 3: 282 | n_samples = state_below.shape[1] 283 | else: 284 | n_samples = 1 285 | 286 | assert type(dims) is list 287 | 288 | n_layers = len(dims) 289 | 290 | # state_below is the input word embeddings 291 | hidden_states = [None] * n_layers 292 | if init_state is None: 293 | init_state = [None] * n_layers 294 | 295 | if mask is None: 296 | if one_step: 297 | mask = tensor.alloc(1., state_below.shape[0]) 298 | else: 299 | mask = tensor.alloc(1., state_below.shape[0], 1) 300 | 301 | def _prepare_initial_state(input_x, n_samples, hid_dim, init_state_=None): 302 | if init_state_: 303 | assert input_x.ndim - 1 == init_state_.ndim, \ 304 | ('The provided initial state is assumed to have ' 305 | 
'one less dimension than the input. (%d - 1) != %d') \ 306 | % (input_x.ndim, init_state_.ndim) 307 | state = [init_state_] 308 | else: 309 | if input_x.ndim == 4: 310 | state = [tensor.alloc(0., 311 | input_x.shape[1], 312 | n_samples, hid_dim)] 313 | else: 314 | state = [tensor.alloc(0., n_samples, hid_dim)] 315 | 316 | return state 317 | 318 | input_x = state_below 319 | for l, hid_dim in enumerate(dims): 320 | prefix_ = prefix + '_%d' % l 321 | 322 | proj_x_ = (tensor.dot(input_x, tparams[prefix_ + '_W']) + 323 | tparams[prefix_ + '_b']) 324 | proj_xx = (tensor.dot(input_x, tparams[prefix_ + '_Wx']) + 325 | tparams[prefix_ + '_bx']) 326 | 327 | # prepare scan arguments 328 | _step = _gru 329 | initial_state = _prepare_initial_state(input_x, n_samples, 330 | hid_dim, init_state[l]) 331 | seqs = [mask, proj_x_, proj_xx] 332 | shared_vars = [tparams[prefix_ + '_U'], tparams[prefix_ + '_Ux']] 333 | 334 | if one_step: 335 | rval = _step(*(seqs + initial_state + shared_vars)) 336 | else: 337 | rval, updates = theano.scan(_step, 338 | sequences=seqs, 339 | outputs_info=initial_state, 340 | non_sequences=shared_vars, 341 | name=prefix_ + '_layer', 342 | n_steps=nsteps, 343 | strict=True) 344 | 345 | if options['use_dropout'] and ('use_noise' in kwargs and 346 | 'dp' in kwargs and 347 | 'trng' in kwargs): 348 | rval = dropout_layer(rval, kwargs['use_noise'], 349 | kwargs['dp'], kwargs['trng']) 350 | 351 | hidden_states[l] = [rval] 352 | input_x = rval 353 | 354 | return hidden_states 355 | 356 | 357 | # Conditional GRU layer with Attention 358 | def param_init_gru_cond(options, 359 | param, 360 | prefix='gru_cond', 361 | nin=None, 362 | dim=None, 363 | dimctx=None): 364 | 365 | assert type(dim) is list 366 | last_dim = dim[-1] 367 | 368 | param = param_init_gru(options, param, prefix=prefix, nin=nin, dim=dim) 369 | 370 | prefix_ = prefix + '_%d' % (len(dim) - 1) 371 | # context to LSTM 372 | param[prefix_ + '_Wc'] = numpy.concatenate( 373 | [ 374 | uniform_weight(dimctx, last_dim), uniform_weight(dimctx, last_dim) 375 | ], axis=1 376 | ) 377 | param[prefix_ + '_Wcx'] = uniform_weight(dimctx, last_dim) 378 | 379 | # attention: combined -> hidden 380 | param[prefix_ + '_W_comb_att'] = uniform_weight(last_dim, dimctx) 381 | 382 | # attention: context -> hidden 383 | param[prefix_ + '_Wc_att'] = uniform_weight(dimctx, dimctx) 384 | 385 | # attention: hidden bias 386 | param[prefix_ + '_b_att'] = zero_vector(dimctx) 387 | 388 | # attention: 389 | param[prefix_ + '_U_att'] = uniform_weight(dimctx, 1) 390 | 391 | return param 392 | 393 | 394 | def gru_cond_layer(tparams, 395 | state_below, 396 | dims, 397 | options, 398 | prefix='gru_cond', 399 | mask=None, 400 | context=None, 401 | one_step=False, 402 | init_memory=None, 403 | init_state=None, 404 | context_mask=None, 405 | **kwargs): 406 | 407 | assert context, 'Context must be provided' 408 | 409 | if one_step: 410 | assert init_state, 'previous state must be provided' 411 | 412 | if init_state: 413 | assert type(init_state) is list 414 | 415 | nsteps = state_below.shape[0] 416 | n_layers = len(dims) 417 | hidden_states = [None] * n_layers 418 | 419 | if state_below.ndim == 3: 420 | n_samples = state_below.shape[1] 421 | else: 422 | n_samples = 1 423 | 424 | # mask 425 | if mask is None: 426 | if one_step: 427 | mask = tensor.alloc(1., nsteps) 428 | else: 429 | mask = tensor.alloc(1., nsteps, 1) 430 | 431 | # projected context 432 | assert context.ndim == 3, \ 433 | 'Context must be 3-d: #annotation x #sample x dim: %d' % context.ndim 434 | 435 | 
prefix_ = prefix + '_%d' % (n_layers - 1) 436 | pctx_ = (tensor.dot(context, tparams[prefix_ + '_Wc_att']) + 437 | tparams[prefix_ + '_b_att']) 438 | 439 | def _prepare_initial_state(input_x, n_samples, hid_dim, 440 | init_state_=None, attention=False): 441 | if init_state_: 442 | """ 443 | assert input_x.ndim - 1 == init_state_.ndim, \ 444 | ('The provided initial state is assumed to have ' 445 | 'one less dimension than the input. (%d - 1) != %d') \ 446 | % (input_x.ndim, init_state_.ndim) 447 | """ 448 | state = [init_state_] 449 | else: 450 | state = [tensor.alloc(0., n_samples, hid_dim)] 451 | 452 | if attention: 453 | if one_step: 454 | state += [None, None] 455 | else: 456 | state += [ 457 | tensor.alloc(0., n_samples, context.shape[2]), 458 | tensor.alloc(1., n_samples, context.shape[0]) 459 | ] 460 | 461 | return state 462 | 463 | input_x = state_below 464 | for l, hid_dim in enumerate(dims): 465 | prefix_ = prefix + '_%d' % l 466 | proj_x_ = (tensor.dot(input_x, tparams[prefix_ + '_W']) + 467 | tparams[prefix_ + '_b']) 468 | proj_xx = (tensor.dot(input_x, tparams[prefix_ + '_Wx']) + 469 | tparams[prefix_ + '_bx']) 470 | 471 | seqs = [mask, proj_x_, proj_xx] 472 | initial_state = _prepare_initial_state( 473 | input_x, n_samples, hid_dim, init_state[l], l == n_layers-1 474 | ) 475 | 476 | if l < n_layers - 1: 477 | _step = _gru 478 | shared_vars = [tparams[prefix_ + '_U'], tparams[prefix_ + '_Ux']] 479 | non_seqs = [] 480 | else: 481 | _step = _cond_gru 482 | shared_vars = [tparams[prefix_ + '_U'], tparams[prefix_ + '_Wc'], 483 | tparams[prefix_ + '_W_comb_att'], 484 | tparams[prefix_ + '_U_att'], 485 | tparams[prefix_ + '_Ux'], tparams[prefix_ + '_Wcx']] 486 | non_seqs = [pctx_, context, context_mask] 487 | 488 | if one_step: 489 | rval = _step(*( 490 | seqs + initial_state + non_seqs + shared_vars)) 491 | else: 492 | rval, updates = theano.scan( 493 | _step, 494 | sequences=seqs, 495 | outputs_info=initial_state, 496 | non_sequences=non_seqs + shared_vars, 497 | name=prefix + '_layers', 498 | n_steps=nsteps, 499 | strict=True) 500 | 501 | if l < n_layers - 1: 502 | rval = [rval] 503 | 504 | if options['use_dropout'] and ('use_noise' in kwargs and 505 | 'dp' in kwargs and 506 | 'trng' in kwargs): 507 | rval[0] = dropout_layer(rval[0], kwargs['use_noise'], 508 | kwargs['dp'], kwargs['trng']) 509 | hidden_states[l] = rval[0] 510 | input_x = rval[0] 511 | 512 | assert len(rval) == 3 513 | 514 | hidden_states = tensor.stack(hidden_states) 515 | 516 | return hidden_states, rval[1], rval[2] 517 | 518 | 519 | # layers: 'name': ('parameter initializer', 'feedforward') 520 | layers = {'ff': (param_init_fflayer, fflayer), 521 | 'conv': (param_init_conv, conv_layer), 522 | 'gru': (param_init_gru, gru_layer), 523 | 'gru_cond': (param_init_gru_cond, gru_cond_layer)} 524 | 525 | 526 | def get_layer(name): 527 | param_init, layer = layers[name] 528 | return param_init, layer 529 | -------------------------------------------------------------------------------- /wordemb_pretrain/WikiExtractor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # ============================================================================= 5 | # Version: 2.6 (Oct 14, 2013) 6 | # Author: Giuseppe Attardi (attardi@di.unipi.it), University of Pisa 7 | # Antonio Fuschetto (fuschett@di.unipi.it), University of Pisa 8 | # 9 | # Contributors: 10 | # Leonardo Souza (lsouza@amtera.com.br) 11 | # Juan Manuel Caicedo 
(juan@cavorite.com) 12 | # Humberto Pereira (begini@gmail.com) 13 | # Siegfried-A. Gevatter (siegfried@gevatter.com) 14 | # Pedro Assis (pedroh2306@gmail.com) 15 | # 16 | # ============================================================================= 17 | # Copyright (c) 2009. Giuseppe Attardi (attardi@di.unipi.it). 18 | # ============================================================================= 19 | # This file is part of Tanl. 20 | # 21 | # Tanl is free software; you can redistribute it and/or modify it 22 | # under the terms of the GNU General Public License, version 3, 23 | # as published by the Free Software Foundation. 24 | # 25 | # Tanl is distributed in the hope that it will be useful, 26 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 27 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 28 | # GNU General Public License for more details. 29 | # 30 | # You should have received a copy of the GNU General Public License 31 | # along with this program. If not, see . 32 | # ============================================================================= 33 | 34 | """Wikipedia Extractor: 35 | Extracts and cleans text from Wikipedia database dump and stores output in a 36 | number of files of similar size in a given directory. 37 | Each file contains several documents in Tanl document format: 38 | 39 | ... 40 | 41 | 42 | Usage: 43 | WikiExtractor.py [options] 44 | 45 | Options: 46 | -c, --compress : compress output files using bzip 47 | -b, --bytes= n[KM] : put specified bytes per output file (default 500K) 48 | -B, --base= URL : base URL for the Wikipedia pages 49 | -l, --link : preserve links 50 | -n NS, --ns NS : accepted namespaces (separated by commas) 51 | -o, --output= dir : place output files in specified directory (default 52 | current) 53 | -s, --sections : preserve sections 54 | -h, --help : display this help and exit 55 | """ 56 | 57 | import sys 58 | import gc 59 | import getopt 60 | import urllib 61 | import re 62 | import bz2 63 | import os.path 64 | from htmlentitydefs import name2codepoint 65 | 66 | ### PARAMS #################################################################### 67 | 68 | # This is obtained from the dump itself 69 | prefix = None 70 | 71 | ## 72 | # Whether to preseve links in output 73 | # 74 | keepLinks = False 75 | 76 | ## 77 | # Whether to transform sections into HTML 78 | # 79 | keepSections = False 80 | 81 | ## 82 | # Recognize only these namespaces 83 | # w: Internal links to the Wikipedia 84 | # wiktionary: Wiki dictionry 85 | # wikt: shortcut for Wikctionry 86 | # 87 | acceptedNamespaces = set(['w', 'wiktionary', 'wikt']) 88 | 89 | ## 90 | # Drop these elements from article text 91 | # 92 | discardElements = set([ 93 | 'gallery', 'timeline', 'noinclude', 'pre', 94 | 'table', 'tr', 'td', 'th', 'caption', 95 | 'form', 'input', 'select', 'option', 'textarea', 96 | 'ul', 'li', 'ol', 'dl', 'dt', 'dd', 'menu', 'dir', 97 | 'ref', 'references', 'img', 'imagemap', 'source' 98 | ]) 99 | 100 | #========================================================================= 101 | # 102 | # MediaWiki Markup Grammar 103 | 104 | # Template = "{{" [ "msg:" | "msgnw:" ] PageName { "|" [ ParameterName "=" AnyText | AnyText ] } "}}" ; 105 | # Extension = "<" ? extension ? ">" AnyText "" ; 106 | # NoWiki = "" | "" ( InlineText | BlockText ) "" ; 107 | # Parameter = "{{{" ParameterName { Parameter } [ "|" { AnyText | Parameter } ] "}}}" ; 108 | # Comment = "" | "" ; 109 | # 110 | # ParameterName = ? 
uppercase, lowercase, numbers, no spaces, some special chars ? ; 111 | # 112 | #=========================================================================== 113 | 114 | # Program version 115 | version = '2.5' 116 | 117 | ##### Main function ########################################################### 118 | 119 | def WikiDocument(out, id, title, text): 120 | url = get_url(id, prefix) 121 | header = '\n' % (id, url, title) 122 | # Separate header from text with a newline. 123 | #header += title + '\n' 124 | header = header.encode('utf-8') 125 | text = clean(text) 126 | footer = "\n" 127 | out.reserve(len(header) + len(text) + len(footer)) 128 | print >> out, header 129 | for line in compact(text): 130 | print >> out, line.encode('utf-8') 131 | print >> out, footer 132 | 133 | def get_url(id, prefix): 134 | return "%s?curid=%s" % (prefix, id) 135 | 136 | #------------------------------------------------------------------------------ 137 | 138 | selfClosingTags = [ 'br', 'hr', 'nobr', 'ref', 'references' ] 139 | 140 | # handle 'a' separetely, depending on keepLinks 141 | ignoredTags = [ 142 | 'b', 'big', 'blockquote', 'center', 'cite', 'div', 'em', 143 | 'font', 'h1', 'h2', 'h3', 'h4', 'hiero', 'i', 'kbd', 'nowiki', 144 | 'p', 'plaintext', 's', 'small', 'span', 'strike', 'strong', 145 | 'sub', 'sup', 'tt', 'u', 'var', 146 | ] 147 | 148 | placeholder_tags = {'math':'formula', 'code':'codice'} 149 | 150 | ## 151 | # Normalize title 152 | def normalizeTitle(title): 153 | # remove leading whitespace and underscores 154 | title = title.strip(' _') 155 | # replace sequences of whitespace and underscore chars with a single space 156 | title = re.compile(r'[\s_]+').sub(' ', title) 157 | 158 | m = re.compile(r'([^:]*):(\s*)(\S(?:.*))').match(title) 159 | if m: 160 | prefix = m.group(1) 161 | if m.group(2): 162 | optionalWhitespace = ' ' 163 | else: 164 | optionalWhitespace = '' 165 | rest = m.group(3) 166 | 167 | ns = prefix.capitalize() 168 | if ns in acceptedNamespaces: 169 | # If the prefix designates a known namespace, then it might be 170 | # followed by optional whitespace that should be removed to get 171 | # the canonical page name 172 | # (e.g., "Category: Births" should become "Category:Births"). 173 | title = ns + ":" + rest.capitalize() 174 | else: 175 | # No namespace, just capitalize first letter. 176 | # If the part before the colon is not a known namespace, then we must 177 | # not remove the space after the colon (if any), e.g., 178 | # "3001: The_Final_Odyssey" != "3001:The_Final_Odyssey". 179 | # However, to get the canonical page name we must contract multiple 180 | # spaces into one, because 181 | # "3001: The_Final_Odyssey" != "3001: The_Final_Odyssey". 182 | title = prefix.capitalize() + ":" + optionalWhitespace + rest 183 | else: 184 | # no namespace, just capitalize first letter 185 | title = title.capitalize(); 186 | return title 187 | 188 | ## 189 | # Removes HTML or XML character references and entities from a text string. 190 | # 191 | # @param text The HTML (or XML) source text. 192 | # @return The plain text, as a Unicode string, if necessary. 
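# Illustrative examples (added; assuming a Python 2 interpreter, as used by the
# rest of this script):
#   unescape(u'caf&eacute;')  ->  u'café'   (named entity)
#   unescape(u'&#233;')       ->  u'é'      (decimal character reference)
#   unescape(u'&#x41;')       ->  u'A'      (hexadecimal character reference)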
193 | 194 | def unescape(text): 195 | def fixup(m): 196 | text = m.group(0) 197 | code = m.group(1) 198 | try: 199 | if text[1] == "#": # character reference 200 | if text[2] == "x": 201 | return unichr(int(code[1:], 16)) 202 | else: 203 | return unichr(int(code)) 204 | else: # named entity 205 | return unichr(name2codepoint[code]) 206 | except: 207 | return text # leave as is 208 | 209 | return re.sub("&#?(\w+);", fixup, text) 210 | 211 | # Match HTML comments 212 | comment = re.compile(r'', re.DOTALL) 213 | 214 | # Match elements to ignore 215 | discard_element_patterns = [] 216 | for tag in discardElements: 217 | pattern = re.compile(r'<\s*%s\b[^>]*>.*?<\s*/\s*%s>' % (tag, tag), re.DOTALL | re.IGNORECASE) 218 | discard_element_patterns.append(pattern) 219 | 220 | # Match ignored tags 221 | ignored_tag_patterns = [] 222 | def ignoreTag(tag): 223 | left = re.compile(r'<\s*%s\b[^>]*>' % tag, re.IGNORECASE) 224 | right = re.compile(r'<\s*/\s*%s>' % tag, re.IGNORECASE) 225 | ignored_tag_patterns.append((left, right)) 226 | 227 | for tag in ignoredTags: 228 | ignoreTag(tag) 229 | 230 | # Match selfClosing HTML tags 231 | selfClosing_tag_patterns = [] 232 | for tag in selfClosingTags: 233 | pattern = re.compile(r'<\s*%s\b[^/]*/\s*>' % tag, re.DOTALL | re.IGNORECASE) 234 | selfClosing_tag_patterns.append(pattern) 235 | 236 | # Match HTML placeholder tags 237 | placeholder_tag_patterns = [] 238 | for tag, repl in placeholder_tags.items(): 239 | pattern = re.compile(r'<\s*%s(\s*| [^>]+?)>.*?<\s*/\s*%s\s*>' % (tag, tag), re.DOTALL | re.IGNORECASE) 240 | placeholder_tag_patterns.append((pattern, repl)) 241 | 242 | # Match preformatted lines 243 | preformatted = re.compile(r'^ .*?$', re.MULTILINE) 244 | 245 | # Match external links (space separates second optional parameter) 246 | externalLink = re.compile(r'\[\w+.*? (.*?)\]') 247 | externalLinkNoAnchor = re.compile(r'\[\w+[&\]]*\]') 248 | 249 | # Matches bold/italic 250 | bold_italic = re.compile(r"'''''([^']*?)'''''") 251 | bold = re.compile(r"'''(.*?)'''") 252 | italic_quote = re.compile(r"''\"(.*?)\"''") 253 | italic = re.compile(r"''([^']*)''") 254 | quote_quote = re.compile(r'""(.*?)""') 255 | 256 | # Matches space 257 | spaces = re.compile(r' {2,}') 258 | 259 | # Matches dots 260 | dots = re.compile(r'\.{4,}') 261 | 262 | # A matching function for nested expressions, e.g. namespaces and tables. 
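# Illustrative example (added): dropping a template span, including nested
# braces, with the function defined below:
#   dropNested('x {{cite|a {{b}} c}} y', r'{{', r'}}')  ->  'x  y'
# (the matched outer span is removed and the surrounding text is concatenated,
# leaving the two spaces that framed it).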
263 | def dropNested(text, openDelim, closeDelim): 264 | openRE = re.compile(openDelim) 265 | closeRE = re.compile(closeDelim) 266 | # partition text in separate blocks { } { } 267 | matches = [] # pairs (s, e) for each partition 268 | nest = 0 # nesting level 269 | start = openRE.search(text, 0) 270 | if not start: 271 | return text 272 | end = closeRE.search(text, start.end()) 273 | next = start 274 | while end: 275 | next = openRE.search(text, next.end()) 276 | if not next: # termination 277 | while nest: # close all pending 278 | nest -=1 279 | end0 = closeRE.search(text, end.end()) 280 | if end0: 281 | end = end0 282 | else: 283 | break 284 | matches.append((start.start(), end.end())) 285 | break 286 | while end.end() < next.start(): 287 | # { } { 288 | if nest: 289 | nest -= 1 290 | # try closing more 291 | last = end.end() 292 | end = closeRE.search(text, end.end()) 293 | if not end: # unbalanced 294 | if matches: 295 | span = (matches[0][0], last) 296 | else: 297 | span = (start.start(), last) 298 | matches = [span] 299 | break 300 | else: 301 | matches.append((start.start(), end.end())) 302 | # advance start, find next close 303 | start = next 304 | end = closeRE.search(text, next.end()) 305 | break # { } 306 | if next != start: 307 | # { { } 308 | nest += 1 309 | # collect text outside partitions 310 | res = '' 311 | start = 0 312 | for s, e in matches: 313 | res += text[start:s] 314 | start = e 315 | res += text[start:] 316 | return res 317 | 318 | def dropSpans(matches, text): 319 | """Drop from text the blocks identified in matches""" 320 | matches.sort() 321 | res = '' 322 | start = 0 323 | for s, e in matches: 324 | res += text[start:s] 325 | start = e 326 | res += text[start:] 327 | return res 328 | 329 | # Match interwiki links, | separates parameters. 330 | # First parameter is displayed, also trailing concatenated text included 331 | # in display, e.g. s for plural). 332 | # 333 | # Can be nested [[File:..|..[[..]]..|..]], [[Category:...]], etc. 334 | # We first expand inner ones, than remove enclosing ones. 
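# Illustrative example (added): with keepLinks left at False, the substitution
#   wikiLink.sub(make_anchor_tag, 'see [[wiki markup|markup]]s here')
# performed in clean() below yields 'see markups here', while links whose
# namespace prefix is not in acceptedNamespaces (e.g. '[[File:Photo.png]]')
# are dropped entirely.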
335 | # 336 | wikiLink = re.compile(r'\[\[([^[]*?)(?:\|([^[]*?))?\]\](\w*)') 337 | 338 | parametrizedLink = re.compile(r'\[\[.*?\]\]') 339 | 340 | # Function applied to wikiLinks 341 | def make_anchor_tag(match): 342 | global keepLinks 343 | link = match.group(1) 344 | colon = link.find(':') 345 | if colon > 0 and link[:colon] not in acceptedNamespaces: 346 | return '' 347 | trail = match.group(3) 348 | anchor = match.group(2) 349 | if not anchor: 350 | anchor = link 351 | anchor += trail 352 | if keepLinks: 353 | return '%s' % (link, anchor) 354 | else: 355 | return anchor 356 | 357 | def clean(text): 358 | 359 | # FIXME: templates should be expanded 360 | # Drop transclusions (template, parser functions) 361 | # See: http://www.mediawiki.org/wiki/Help:Templates 362 | text = dropNested(text, r'{{', r'}}') 363 | 364 | # Drop tables 365 | text = dropNested(text, r'{\|', r'\|}') 366 | 367 | # Expand links 368 | text = wikiLink.sub(make_anchor_tag, text) 369 | # Drop all remaining ones 370 | text = parametrizedLink.sub('', text) 371 | 372 | # Handle external links 373 | text = externalLink.sub(r'\1', text) 374 | text = externalLinkNoAnchor.sub('', text) 375 | 376 | # Handle bold/italic/quote 377 | text = bold_italic.sub(r'\1', text) 378 | text = bold.sub(r'\1', text) 379 | text = italic_quote.sub(r'"\1"', text) 380 | text = italic.sub(r'"\1"', text) 381 | text = quote_quote.sub(r'\1', text) 382 | text = text.replace("'''", '').replace("''", '"') 383 | 384 | ################ Process HTML ############### 385 | 386 | # turn into HTML 387 | text = unescape(text) 388 | # do it again (&nbsp;) 389 | text = unescape(text) 390 | 391 | # Collect spans 392 | 393 | matches = [] 394 | # Drop HTML comments 395 | for m in comment.finditer(text): 396 | matches.append((m.start(), m.end())) 397 | 398 | # Drop self-closing tags 399 | for pattern in selfClosing_tag_patterns: 400 | for m in pattern.finditer(text): 401 | matches.append((m.start(), m.end())) 402 | 403 | # Drop ignored tags 404 | for left, right in ignored_tag_patterns: 405 | for m in left.finditer(text): 406 | matches.append((m.start(), m.end())) 407 | for m in right.finditer(text): 408 | matches.append((m.start(), m.end())) 409 | 410 | # Bulk remove all spans 411 | text = dropSpans(matches, text) 412 | 413 | # Cannot use dropSpan on these since they may be nested 414 | # Drop discarded elements 415 | for pattern in discard_element_patterns: 416 | text = pattern.sub('', text) 417 | 418 | # Expand placeholders 419 | for pattern, placeholder in placeholder_tag_patterns: 420 | index = 1 421 | for match in pattern.finditer(text): 422 | text = text.replace(match.group(), '%s_%d' % (placeholder, index)) 423 | index += 1 424 | 425 | text = text.replace('<<', u'«').replace('>>', u'»') 426 | 427 | ############################################# 428 | 429 | # Drop preformatted 430 | # This can't be done before since it may remove tags 431 | text = preformatted.sub('', text) 432 | 433 | # Cleanup text 434 | text = text.replace('\t', ' ') 435 | text = spaces.sub(' ', text) 436 | text = dots.sub('...', text) 437 | text = re.sub(u' (,:\.\)\]»)', r'\1', text) 438 | text = re.sub(u'(\[\(«) ', r'\1', text) 439 | text = re.sub(r'\n\W+?\n', '\n', text) # lines with only punctuations 440 | text = text.replace(',,', ',').replace(',.', '.') 441 | return text 442 | 443 | section = re.compile(r'(==+)\s*(.*?)\s*\1') 444 | 445 | def compact(text): 446 | """Deal with headers, lists, empty sections, residuals of tables""" 447 | page = [] # list of paragraph 448 | headers = {} # 
Headers for unfilled sections 449 | emptySection = False # empty sections are discarded 450 | inList = False # whether a <ul> list is currently open
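# The loop below classifies each line of the cleaned-up text: section headings
# ('== ... =='), page titles ('++...'), list items ('*', '#', ':', ';'),
# residual table/list markup (dropped), and ordinary text.  Heading levels are
# tracked in `headers` (stored as empty placeholders) and flushed only once a
# later content line shows that the section is not empty.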
451 | 452 | for line in text.split('\n'): 453 | 454 | if not line: 455 | continue 456 | # Handle section titles 457 | m = section.match(line) 458 | if m: 459 | title = m.group(2) 460 | lev = len(m.group(1)) 461 | if keepSections: 462 | page.append("<h%d>%s</h%d>" % (lev, title, lev)) 463 | if title and title[-1] not in '!?': 464 | title += '.' 465 | #headers[lev] = title 466 | headers[lev] = '' 467 | # drop previous headers 468 | for i in headers.keys(): 469 | if i > lev: 470 | del headers[i] 471 | emptySection = True 472 | continue 473 | # Handle page title 474 | if line.startswith('++'): 475 | title = line[2:-2] 476 | if title: 477 | if title[-1] not in '!?': 478 | title += '.' 479 | page.append(title) 480 | # handle lists 481 | elif line[0] in '*#:;': 482 | if keepSections: 483 | page.append("
<li>%s</li>
  • " % line[1:]) 484 | else: 485 | continue 486 | # Drop residuals of lists 487 | elif line[0] in '{|' or line[-1] in '}': 488 | continue 489 | # Drop irrelevant lines 490 | elif (line[0] == '(' and line[-1] == ')') or line.strip('.-') == '': 491 | continue 492 | elif len(headers): 493 | items = headers.items() 494 | items.sort() 495 | for (i, v) in items: 496 | page.append(v) 497 | headers.clear() 498 | page.append(line) # first line 499 | emptySection = False 500 | elif not emptySection: 501 | page.append(line) 502 | 503 | return page 504 | 505 | def handle_unicode(entity): 506 | numeric_code = int(entity[2:-1]) 507 | if numeric_code >= 0x10000: return '' 508 | return unichr(numeric_code) 509 | 510 | #------------------------------------------------------------------------------ 511 | 512 | class OutputSplitter: 513 | def __init__(self, compress, max_file_size, path_name): 514 | self.dir_index = 0 515 | self.file_index = -1 516 | self.compress = compress 517 | self.max_file_size = max_file_size 518 | self.path_name = path_name 519 | self.out_file = self.open_next_file() 520 | 521 | def reserve(self, size): 522 | cur_file_size = self.out_file.tell() 523 | if cur_file_size + size > self.max_file_size: 524 | self.close() 525 | self.out_file = self.open_next_file() 526 | 527 | def write(self, text): 528 | self.out_file.write(text) 529 | 530 | def close(self): 531 | self.out_file.close() 532 | 533 | def open_next_file(self): 534 | self.file_index += 1 535 | if self.file_index == 100: 536 | self.dir_index += 1 537 | self.file_index = 0 538 | dir_name = self.dir_name() 539 | if not os.path.isdir(dir_name): 540 | os.makedirs(dir_name) 541 | file_name = os.path.join(dir_name, self.file_name()) 542 | if self.compress: 543 | return bz2.BZ2File(file_name + '.bz2', 'w') 544 | else: 545 | return open(file_name, 'w') 546 | 547 | def dir_name(self): 548 | char1 = self.dir_index % 26 549 | char2 = self.dir_index / 26 % 26 550 | return os.path.join(self.path_name, '%c%c' % (ord('A') + char2, ord('A') + char1)) 551 | 552 | def file_name(self): 553 | return 'wiki_%02d' % self.file_index 554 | 555 | ### READER ################################################################### 556 | 557 | tagRE = re.compile(r'(.*?)<(/?\w+)[^>]*>(?:([^<]*)(<.*?>)?)?') 558 | 559 | def process_data(input, output): 560 | global prefix 561 | 562 | page = [] 563 | id = None 564 | inText = False 565 | redirect = False 566 | for line in input: 567 | line = line.decode('utf-8') 568 | tag = '' 569 | if '<' in line: 570 | m = tagRE.search(line) 571 | if m: 572 | tag = m.group(2) 573 | if tag == 'page': 574 | page = [] 575 | redirect = False 576 | elif tag == 'id' and not id: 577 | id = m.group(3) 578 | elif tag == 'title': 579 | title = m.group(3) 580 | elif tag == 'redirect': 581 | redirect = True 582 | elif tag == 'text': 583 | inText = True 584 | line = line[m.start(3):m.end(3)] + '\n' 585 | page.append(line) 586 | if m.lastindex == 4: # open-close 587 | inText = False 588 | elif tag == '/text': 589 | if m.group(1): 590 | page.append(m.group(1) + '\n') 591 | inText = False 592 | elif inText: 593 | page.append(line) 594 | elif tag == '/page': 595 | colon = title.find(':') 596 | if (colon < 0 or title[:colon] in acceptedNamespaces) and \ 597 | not redirect: 598 | print id, title.encode('utf-8') 599 | sys.stdout.flush() 600 | WikiDocument(output, id, title, ''.join(page)) 601 | id = None 602 | page = [] 603 | elif tag == 'base': 604 | # discover prefix from the xml dump file 605 | # /mediawiki/siteinfo/base 606 | base = m.group(3) 607 
| prefix = base[:base.rfind("/")] 608 | 609 | ### CL INTERFACE ############################################################ 610 | 611 | def show_help(): 612 | print >> sys.stdout, __doc__, 613 | 614 | def show_usage(script_name): 615 | print >> sys.stderr, 'Usage: %s [options]' % script_name 616 | 617 | ## 618 | # Minimum size of output files 619 | minFileSize = 200 * 1024 620 | 621 | def main(): 622 | global keepLinks, keepSections, prefix, acceptedNamespaces 623 | script_name = os.path.basename(sys.argv[0]) 624 | 625 | try: 626 | long_opts = ['help', 'compress', 'bytes=', 'basename=', 'links', 'ns=', 'sections', 'output=', 'version'] 627 | opts, args = getopt.gnu_getopt(sys.argv[1:], 'cb:hln:o:B:sv', long_opts) 628 | except getopt.GetoptError: 629 | show_usage(script_name) 630 | sys.exit(1) 631 | 632 | compress = False 633 | file_size = 500 * 1024 634 | output_dir = '.' 635 | 636 | for opt, arg in opts: 637 | if opt in ('-h', '--help'): 638 | show_help() 639 | sys.exit() 640 | elif opt in ('-c', '--compress'): 641 | compress = True 642 | elif opt in ('-l', '--links'): 643 | keepLinks = True 644 | elif opt in ('-s', '--sections'): 645 | keepSections = True 646 | elif opt in ('-B', '--base'): 647 | prefix = arg 648 | elif opt in ('-b', '--bytes'): 649 | try: 650 | if arg[-1] in 'kK': 651 | file_size = int(arg[:-1]) * 1024 652 | elif arg[-1] in 'mM': 653 | file_size = int(arg[:-1]) * 1024 * 1024 654 | else: 655 | file_size = int(arg) 656 | if file_size < minFileSize: raise ValueError() 657 | except ValueError: 658 | print >> sys.stderr, \ 659 | '%s: %s: Insufficient or invalid size' % (script_name, arg) 660 | sys.exit(2) 661 | elif opt in ('-n', '--ns'): 662 | acceptedNamespaces = set(arg.split(',')) 663 | elif opt in ('-o', '--output'): 664 | output_dir = arg 665 | elif opt in ('-v', '--version'): 666 | print 'WikiExtractor.py version:', version 667 | sys.exit(0) 668 | 669 | if len(args) > 0: 670 | show_usage(script_name) 671 | sys.exit(4) 672 | 673 | if not os.path.isdir(output_dir): 674 | try: 675 | os.makedirs(output_dir) 676 | except: 677 | print >> sys.stderr, 'Could not create: ', output_dir 678 | return 679 | 680 | if not keepLinks: 681 | ignoreTag('a') 682 | 683 | output_splitter = OutputSplitter(compress, file_size, output_dir) 684 | process_data(sys.stdin, output_splitter) 685 | output_splitter.close() 686 | 687 | if __name__ == '__main__': 688 | main() 689 | --------------------------------------------------------------------------------