├── ch_dp_gan
│   ├── util
│   │   ├── __init__.py
│   │   ├── model_util.py
│   │   ├── ltp.py
│   │   ├── berkely.py
│   │   └── file_util.py
│   ├── treebuilder
│   │   ├── __init__.py
│   │   └── dp_nr
│   │       ├── __init__.py
│   │       └── parser.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── cdtb.py
│   ├── structure
│   │   ├── __init__.py
│   │   └── vocab.py
│   ├── .gitmodules
│   ├── requirements.txt
│   ├── config.py
│   ├── interface.py
│   ├── new_ctb.py
│   ├── .gitignore
│   ├── berkeleyparser
│   │   └── README
│   └── pre_process.py
├── en_dp_gan
│   ├── model
│   │   ├── __init__.py
│   │   ├── stacked_parser_tdt_gan3
│   │   │   ├── __init__.py
│   │   │   └── parser.py
│   │   └── parser.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── model_util.py
│   │   ├── DL_tricks.py
│   │   ├── patterns.py
│   │   ├── file_util.py
│   │   ├── drawer.py
│   │   ├── eval.py
│   │   └── data_builder.py
│   ├── structure
│   │   ├── __init__.py
│   │   ├── rst.py
│   │   ├── tree_obj.py
│   │   └── vocab.py
│   ├── requirements.txt
│   ├── main.py
│   ├── config.py
│   └── path_config.py
├── ch_dp_gan_xlnet
│   ├── util
│   │   ├── __init__.py
│   │   ├── model_util.py
│   │   ├── ltp.py
│   │   ├── berkely.py
│   │   └── file_util.py
│   ├── treebuilder
│   │   ├── __init__.py
│   │   └── dp_nr
│   │       ├── __init__.py
│   │       └── parser.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── cdtb.py
│   ├── structure
│   │   ├── __init__.py
│   │   └── vocab.py
│   ├── .gitmodules
│   ├── requirements.txt
│   ├── config.py
│   ├── interface.py
│   ├── new_ctb.py
│   ├── .gitignore
│   ├── berkeleyparser
│   │   └── README
│   └── pre_process.py
├── en_dp_gan_xlnet
│   ├── model
│   │   ├── __init__.py
│   │   ├── stacked_parser_tdt_xlnet
│   │   │   ├── __init__.py
│   │   │   └── parser.py
│   │   └── parser.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── DL_tricks.py
│   │   ├── patterns.py
│   │   ├── con2dep.py
│   │   ├── drawer.py
│   │   ├── file_util.py
│   │   └── eval.py
│   ├── structure
│   │   ├── __init__.py
│   │   ├── rst.py
│   │   ├── rst_xl.py
│   │   ├── tree_obj.py
│   │   ├── vocab.py
│   │   └── cdtb.py
│   ├── requirements.txt
│   ├── main.py
│   ├── config.py
│   └── path_config.py
└── README.md

/ch_dp_gan/util/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------

/en_dp_gan/model/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------

/en_dp_gan/util/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------

/ch_dp_gan/treebuilder/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------

/ch_dp_gan_xlnet/util/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/model/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/util/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/treebuilder/__init__.py:
--------------------------------------------------------------------------------
(empty file)
--------------------------------------------------------------------------------

/ch_dp_gan/dataset/__init__.py:
--------------------------------------------------------------------------------
# coding: UTF-8

from .cdtb import CDTB
--------------------------------------------------------------------------------

/ch_dp_gan/structure/__init__.py:
--------------------------------------------------------------------------------
# coding: UTF-8

from .nodes import *
--------------------------------------------------------------------------------

/en_dp_gan/structure/__init__.py:
--------------------------------------------------------------------------------
# coding: UTF-8

from .nodes import *
--------------------------------------------------------------------------------

/ch_dp_gan_xlnet/dataset/__init__.py:
--------------------------------------------------------------------------------
# coding: UTF-8

from .cdtb import CDTB
--------------------------------------------------------------------------------

/ch_dp_gan_xlnet/structure/__init__.py:
--------------------------------------------------------------------------------
# coding: UTF-8

from .nodes import *
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/structure/__init__.py:
--------------------------------------------------------------------------------
# coding: UTF-8

from .nodes import *
--------------------------------------------------------------------------------

/en_dp_gan/requirements.txt:
--------------------------------------------------------------------------------
nltk
torch>=0.4
thulac
tqdm
gensim
numpy
tensorboardX
matplotlib
--------------------------------------------------------------------------------
/ch_dp_gan/treebuilder/dp_nr/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/requirements.txt:
--------------------------------------------------------------------------------
nltk
torch>=0.4
thulac
tqdm
gensim
numpy
tensorboardX
matplotlib
--------------------------------------------------------------------------------

/ch_dp_gan/.gitmodules:
--------------------------------------------------------------------------------
[submodule "berkeleyparser"]
    path = berkeleyparser
    url = git@github.com:slavpetrov/berkeleyparser.git
--------------------------------------------------------------------------------

/ch_dp_gan_xlnet/.gitmodules:
--------------------------------------------------------------------------------
[submodule "berkeleyparser"]
    path = berkeleyparser
    url = git@github.com:slavpetrov/berkeleyparser.git
--------------------------------------------------------------------------------

/ch_dp_gan_xlnet/treebuilder/dp_nr/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
--------------------------------------------------------------------------------

/en_dp_gan/model/stacked_parser_tdt_gan3/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/model/stacked_parser_tdt_xlnet/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
--------------------------------------------------------------------------------
/ch_dp_gan/requirements.txt:
--------------------------------------------------------------------------------
nltk
torch==0.4.1
thulac
tqdm
gensim
numpy
tensorboardX
matplotlib
sklearn==0.20.0
pyltp==0.2.1
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/requirements.txt:
--------------------------------------------------------------------------------
nltk
torch==0.4.1
thulac
tqdm
gensim
numpy
tensorboardX
matplotlib
sklearn==0.20.0
pyltp==0.2.1
--------------------------------------------------------------------------------
/ch_dp_gan/util/model_util.py:
--------------------------------------------------------------------------------

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Total: ' + str(total_num), 'Trainable: ' + str(trainable_num))
--------------------------------------------------------------------------------

/en_dp_gan/util/model_util.py:
--------------------------------------------------------------------------------

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Total: ' + str(total_num), 'Trainable: ' + str(trainable_num))
--------------------------------------------------------------------------------

/ch_dp_gan_xlnet/util/model_util.py:
--------------------------------------------------------------------------------

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Total: ' + str(total_num), 'Trainable: ' + str(trainable_num))
--------------------------------------------------------------------------------
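All three copies of model_util.py are identical. A minimal usage sketch (not part
of the repository; it assumes PyTorch is installed and one of the project roots
is the working directory):

    import torch.nn as nn
    from util.model_util import get_parameter_number

    # Toy two-layer model: (10*20 + 20) + (20*2 + 2) = 262 parameters.
    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
    get_parameter_number(model)  # prints: Total: 262 Trainable: 262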
/en_dp_gan/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
import logging
from sys import argv
from model.stacked_parser_tdt_gan3.trainer import Trainer


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    test_desc = argv[1] if len(argv) >= 2 else "no message."
    trainer = Trainer()
    trainer.train(test_desc)
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
import logging
from sys import argv
from model.stacked_parser_tdt_xlnet.trainer import Trainer


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    test_desc = argv[1] if len(argv) >= 2 else "no message."
    trainer = Trainer()
    trainer.train(test_desc)
--------------------------------------------------------------------------------

/en_dp_gan/util/DL_tricks.py:
--------------------------------------------------------------------------------
"""
Author: Lynx Zhang
Date: 2020.
Email: zzlynx@outlook.com
"""
import torch


def weights_init_normal(m):
    classname = m.__class__.__name__
    if (classname.find("Conv") != -1) or (classname.find("Linear") != -1):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def log(x):
    return torch.log(x + 1e-8)
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/util/DL_tricks.py:
--------------------------------------------------------------------------------
"""
Author: Lynx Zhang
Date: 2020.
Email: zzlynx@outlook.com
"""
import torch


def weights_init_normal(m):
    classname = m.__class__.__name__
    if (classname.find("Conv") != -1) or (classname.find("Linear") != -1):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def log(x):
    return torch.log(x + 1e-8)
--------------------------------------------------------------------------------
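weights_init_normal above implements the usual DCGAN-style initialisation:
N(0, 0.02) for Conv*/Linear weights and N(1, 0.02) with zero bias for
BatchNorm2d. A hypothetical sketch of the standard wiring via Module.apply,
which visits every submodule recursively (the toy layers are illustrative only
and are never forwarded):

    import torch.nn as nn
    from util.DL_tricks import weights_init_normal

    net = nn.Sequential(nn.Conv2d(1, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 4))
    net.apply(weights_init_normal)  # re-initialises each matching layer in place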
/ch_dp_gan/config.py:
--------------------------------------------------------------------------------
# UTF-8
# Author: Longyin Zhang
# Date:

from util.file_util import *

SET = 66
nr2ids = load_data("data/nr2ids.pkl")
ids2nr = load_data("data/ids2nr.pkl")
SEED = 19  # 17, 7, 14, 2, 13, 19, 21, 22
NR_LABEL = 46
USE_GAN = True
MAX_H, MAX_W = 15, 25
in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 16, 3, MAX_W, 1
p_w_G, p_h_G = 3, 1
MAX_POOLING = True
WARM_UP_EP = 7 if USE_GAN else 20
LABEL_SWITCH, SWITCH_ITE = False, 4
USE_BOUND, BOUND_INFO_SIZE = False, 30
LAYER_NORM_USE = True
CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = False, 128, 2
--------------------------------------------------------------------------------

/en_dp_gan/util/patterns.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
import re

leaf_parttern = r' *\( \w+ \(leaf .+'
leaf_re = re.compile(leaf_parttern)
node_parttern = r' *\( \w+ \(span .+'
node_re = re.compile(node_parttern)
end_parttern = r'\s*\)\s*'
end_re = re.compile(end_parttern)
nodetype_parttern = r' *\( (\w+) .+'
type_re = re.compile(nodetype_parttern)
rel_parttern = r' *\( \w+ \(.+\) \(rel2par ([\w-]+).+'
rel_re = re.compile(rel_parttern)
node_leaf_parttern = r' *\( \w+ \((\w+) \d+.*\).+'
node_leaf_re = re.compile(node_leaf_parttern)
upper_parttern = r'[a-z].*'
upper_re = re.compile(upper_parttern)
--------------------------------------------------------------------------------

/en_dp_gan_xlnet/util/patterns.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: Lyzhang
@Date:
@Description:
"""
import re

leaf_parttern = r' *\( \w+ \(leaf .+'
leaf_re = re.compile(leaf_parttern)
node_parttern = r' *\( \w+ \(span .+'
node_re = re.compile(node_parttern)
end_parttern = r'\s*\)\s*'
end_re = re.compile(end_parttern)
nodetype_parttern = r' *\( (\w+) .+'
type_re = re.compile(nodetype_parttern)
rel_parttern = r' *\( \w+ \(.+\) \(rel2par ([\w-]+).+'
rel_re = re.compile(rel_parttern)
node_leaf_parttern = r' *\( \w+ \((\w+) \d+.*\).+'
node_leaf_re = re.compile(node_leaf_parttern)
upper_parttern = r'[a-z].*'
upper_re = re.compile(upper_parttern)
--------------------------------------------------------------------------------
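These compiled patterns target the bracketed *.dis tree format of the RST-DT
corpus. A hypothetical sketch against a hand-written line (illustrative, not
taken from the corpus):

    from util.patterns import leaf_re, rel_re, type_re

    line = '  ( Satellite (leaf 2) (rel2par attribution) (text _!he said!_) )'
    print(leaf_re.match(line) is not None)  # True: this node is a leaf (an EDU)
    print(type_re.match(line).group(1))     # 'Satellite': nuclearity of the node
    print(rel_re.match(line).group(1))      # 'attribution': relation to the parent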
re.compile(nodetype_parttern) 18 | rel_parttern = r' *\( \w+ \(.+\) \(rel2par ([\w-]+).+' 19 | rel_re = re.compile(rel_parttern) 20 | node_leaf_parttern = r' *\( \w+ \((\w+) \d+.*\).+' 21 | node_leaf_re = re.compile(node_leaf_parttern) 22 | upper_parttern = r'[a-z].*' 23 | upper_re = re.compile(upper_parttern) 24 | -------------------------------------------------------------------------------- /en_dp_gan/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 16 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 16 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/config.py: -------------------------------------------------------------------------------- 1 | # UTF-8 2 | # Author: Longyin Zhang 3 | # Date: 4 | 5 | from util.file_util import * 6 | 7 | SET = 65 8 | nr2ids = load_data("data/nr2ids.pkl") 9 | ids2nr = load_data("data/ids2nr.pkl") 10 | SEED = 19 # 17, 7, 14, 2, 13, 19, 21, 22 11 | NR_LABEL = 46 12 | USE_GAN = True 13 | MAX_H, MAX_W = 15, 25 14 | in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 16, 3, MAX_W, 1 15 | p_w_G, p_h_G = 3, 1 16 | MAX_POOLING = True 17 | WARM_UP_EP = 7 if USE_GAN else 50 18 | LABEL_SWITCH, SWITCH_ITE = False, 4 19 | USE_BOUND, BOUND_INFO_SIZE = False, 30 20 | LAYER_NORM_USE = True 21 | CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = False, 128, 2 22 | 23 | EDU_ENCODE_V = 1 24 | Finetune_XL = True 25 | XLNET_PATH = "data/pytorch_model.bin" 26 | XLNET_SIZE, CHUNK_SIZE = 768, 512 27 | -------------------------------------------------------------------------------- /ch_dp_gan/util/ltp.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import os 4 | import pyltp 5 | 6 | LTP_DATA_DIR = 'pub/pyltp_models' 7 | path_to_tagger = os.path.join(LTP_DATA_DIR, 'pos.model') 8 | path_to_parser = os.path.join(LTP_DATA_DIR, "parser.model") 9 | 10 | 11 | class LTPParser: 12 | def __init__(self): 13 | self.tagger = pyltp.Postagger() 14 | self.parser = pyltp.Parser() 15 | self.tagger.load(path_to_tagger) 16 | self.parser.load(path_to_parser) 17 | 18 | def __enter__(self): 19 | self.tagger.load(path_to_tagger) 20 | self.parser.load(path_to_parser) 21 | return self 22 | 23 | def __exit__(self, exc_type, exc_val, exc_tb): 24 | self.tagger.release() 25 | self.parser.release() 26 | 27 | def parse(self, words): 28 | tags = self.tagger.postag(words) 29 | parse = self.parser.parse(words, tags) 30 | return parse 31 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/util/ltp.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import os 4 | import pyltp 5 | 6 | LTP_DATA_DIR = 'pub/pyltp_models' 7 | path_to_tagger = os.path.join(LTP_DATA_DIR, 'pos.model') 8 | path_to_parser = os.path.join(LTP_DATA_DIR, "parser.model") 9 | 10 | 11 | class LTPParser: 12 | def __init__(self): 13 | self.tagger = pyltp.Postagger() 14 | self.parser = pyltp.Parser() 15 | self.tagger.load(path_to_tagger) 16 | self.parser.load(path_to_parser) 17 | 18 | def __enter__(self): 19 | self.tagger.load(path_to_tagger) 20 | self.parser.load(path_to_parser) 21 | return self 22 | 23 | def __exit__(self, exc_type, exc_val, exc_tb): 24 | 
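# Illustrative usage sketch for the LTPParser wrapper above (not from the
# original file; it assumes the pyltp models have been downloaded to
# pub/pyltp_models as configured at the top). Because the class defines
# __enter__/__exit__, the `with` form guarantees the models are released:
#
#     from util.ltp import LTPParser
#
#     with LTPParser() as ltp:
#         arcs = ltp.parse(["他", "喜欢", "自然", "语言", "处理"])
#         for i, arc in enumerate(arcs, start=1):
#             print(i, arc.head, arc.relation)  # dependency head index and label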
self.tagger.release() 25 | self.parser.release() 26 | 27 | def parse(self, words): 28 | tags = self.tagger.postag(words) 29 | parse = self.parser.parse(words, tags) 30 | return parse 31 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/util/con2dep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Lyzhang 5 | @Date: 6 | @Description: 7 | """ 8 | 9 | dep_tuples, dep_idx = [], -1 10 | 11 | 12 | def con2dep(root): 13 | global dep_tuples, dep_idx 14 | dep_tuples, dep_idx = [], -1 15 | recursive_one(root) 16 | return dep_tuples 17 | 18 | 19 | def recursive_one(root): 20 | global dep_tuples, dep_idx 21 | if root.left_child is not None: 22 | idx1 = recursive_one(root.left_child) 23 | idx2 = recursive_one(root.right_child) 24 | later = max(idx1, idx2) 25 | pre = min(idx1, idx2) 26 | if root.child_NS_rel == "SN": 27 | dep_tuples.append((later, pre, 0, root.child_rel)) 28 | return idx2 29 | else: 30 | dep_tuples.append((later, pre, 1, root.child_rel)) 31 | return idx1 32 | else: 33 | dep_idx += 1 34 | return dep_idx 35 | -------------------------------------------------------------------------------- /ch_dp_gan/interface.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | from abc import abstractmethod 3 | 4 | from typing import List 5 | from structure.nodes import Sentence, Paragraph, EDU 6 | 7 | 8 | class SegmenterI: 9 | @abstractmethod 10 | def cut(self, text: str) -> Paragraph: 11 | raise NotImplementedError() 12 | 13 | @abstractmethod 14 | def cut_sent(self, text: str, sid=None) -> List[Sentence]: 15 | raise NotImplementedError() 16 | 17 | @abstractmethod 18 | def cut_edu(self, sent: Sentence) -> List[EDU]: 19 | raise NotImplementedError() 20 | 21 | 22 | class ParserI: 23 | @abstractmethod 24 | def parse(self, para: Paragraph) -> Paragraph: 25 | raise NotImplementedError() 26 | 27 | 28 | class PipelineI: 29 | @abstractmethod 30 | def cut_sent(self, text: str) -> Paragraph: 31 | raise NotImplementedError() 32 | 33 | @abstractmethod 34 | def cut_edu(self, para: Paragraph) -> Paragraph: 35 | raise NotImplementedError() 36 | 37 | @abstractmethod 38 | def parse(self, para: Paragraph) -> Paragraph: 39 | raise NotImplementedError() 40 | 41 | @abstractmethod 42 | def full_parse(self, text: str): 43 | raise NotImplementedError() 44 | 45 | def __call__(self, text: str): 46 | return self.full_parse(text) 47 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/interface.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | from abc import abstractmethod 3 | 4 | from typing import List 5 | from structure.nodes import Sentence, Paragraph, EDU 6 | 7 | 8 | class SegmenterI: 9 | @abstractmethod 10 | def cut(self, text: str) -> Paragraph: 11 | raise NotImplementedError() 12 | 13 | @abstractmethod 14 | def cut_sent(self, text: str, sid=None) -> List[Sentence]: 15 | raise NotImplementedError() 16 | 17 | @abstractmethod 18 | def cut_edu(self, sent: Sentence) -> List[EDU]: 19 | raise NotImplementedError() 20 | 21 | 22 | class ParserI: 23 | @abstractmethod 24 | def parse(self, para: Paragraph) -> Paragraph: 25 | raise NotImplementedError() 26 | 27 | 28 | class PipelineI: 29 | @abstractmethod 30 | def cut_sent(self, text: str) -> Paragraph: 31 | raise NotImplementedError() 32 | 33 | @abstractmethod 34 | def cut_edu(self, para: Paragraph) -> Paragraph: 35 | raise NotImplementedError() 36 | 37 |
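# Illustrative sketch (hypothetical class, not from this repo): a concrete
# PipelineI chains a SegmenterI with a ParserI, so full_parse reduces to
# parse(cut_edu(cut_sent(text))). Paragraph is assumed here to wrap an
# iterable of sentences/EDUs.
#
#     class SimplePipeline(PipelineI):
#         def __init__(self, segmenter: SegmenterI, parser: ParserI):
#             self.segmenter, self.parser = segmenter, parser
#
#         def cut_sent(self, text: str) -> Paragraph:
#             return Paragraph(self.segmenter.cut_sent(text))
#
#         def cut_edu(self, para: Paragraph) -> Paragraph:
#             return Paragraph([edu for sent in para for edu in self.segmenter.cut_edu(sent)])
#
#         def parse(self, para: Paragraph) -> Paragraph:
#             return self.parser.parse(para)
#
#         def full_parse(self, text: str):
#             return self.parse(self.cut_edu(self.cut_sent(text)))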
@abstractmethod 38 | def parse(self, para: Paragraph) -> Paragraph: 39 | raise NotImplementedError() 40 | 41 | @abstractmethod 42 | def full_parse(self, text: str): 43 | raise NotImplementedError() 44 | 45 | def __call__(self, text: str): 46 | return self.full_parse(text) 47 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/structure/rst.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Lyzhang 5 | @Date: 6 | @Description: 7 | """ 8 | from config import * 9 | from util.file_util import * 10 | 11 | 12 | class RST: 13 | def __init__(self, train_set=TRAIN_SET, dev_set=DEV_SET, test_set=TEST_SET): 14 | self.train = load_data(train_set) 15 | # self.train_rst = load_data(RST_TRAIN_TREES_RST) 16 | self.dev = load_data(dev_set) 17 | # self.dev_rst = load_data(RST_DEV_TREES_RST) 18 | self.test = load_data(test_set) 19 | # self.test_rst = load_data(RST_TEST_TREES_RST) 20 | 21 | self.tree_obj_train = load_data(RST_TRAIN_TREES) 22 | self.tree_obj_dev = load_data(RST_DEV_TREES) 23 | self.tree_obj_test = load_data(RST_TEST_TREES) 24 | self.edu_tree_obj_test = load_data(RST_TEST_EDUS_TREES) 25 | if TRAIN_Ext: 26 | ext_trees = load_data(EXT_TREES) 27 | ext_set = load_data(EXT_SET) 28 | self.train = self.train + ext_set 29 | self.tree_obj_train = self.tree_obj_train + ext_trees 30 | 31 | @staticmethod 32 | def get_voc_labels(): 33 | word2ids = load_data(VOC_WORD2IDS_PATH) 34 | pos2ids = load_data(POS_word2ids_PATH) 35 | return word2ids, pos2ids 36 | 37 | @staticmethod 38 | def get_dev(): 39 | return load_data(RST_DEV_TREES) 40 | 41 | @staticmethod 42 | def get_test(): 43 | if USE_AE: 44 | return load_data(RST_TEST_EDUS_TREES_) 45 | else: 46 | return load_data(RST_TEST_TREES) 47 | -------------------------------------------------------------------------------- /en_dp_gan/structure/rst.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Lyzhang 5 | @Date: 6 | @Description: 7 | """ 8 | from config import * 9 | from util.file_util import * 10 | 11 | 12 | class RST: 13 | def __init__(self, train_set=TRAIN_SET, dev_set=DEV_SET, test_set=TEST_SET): 14 | self.train = load_data(train_set) 15 | # self.train_rst = load_data(RST_TRAIN_TREES_RST) 16 | self.dev = load_data(dev_set) 17 | # self.dev_rst = load_data(RST_DEV_TREES_RST) 18 | self.test = load_data(test_set) 19 | # self.test_rst = load_data(RST_TEST_TREES_RST) 20 | 21 | # tree_obj 22 | self.tree_obj_train = load_data(RST_TRAIN_TREES) 23 | self.tree_obj_dev = load_data(RST_DEV_TREES) 24 | self.tree_obj_test = load_data(RST_TEST_TREES) 25 | self.edu_tree_obj_test = load_data(RST_TEST_EDUS_TREES) 26 | 27 | if TRAIN_Ext: 28 | ext_trees = load_data(EXT_TREES) 29 | ext_set = load_data(EXT_SET) 30 | self.train = self.train + ext_set 31 | self.tree_obj_train = self.tree_obj_train + ext_trees 32 | 33 | @staticmethod 34 | def get_voc_labels(): 35 | word2ids = load_data(VOC_WORD2IDS_PATH) 36 | pos2ids = load_data(POS_word2ids_PATH) 37 | return word2ids, pos2ids 38 | 39 | @staticmethod 40 | def get_dev(): 41 | return load_data(RST_DEV_TREES) 42 | 43 | @staticmethod 44 | def get_test(): 45 | if USE_AE: 46 | return load_data(RST_TEST_EDUS_TREES_) 47 | else: 48 | return load_data(RST_TEST_TREES) 49 | -------------------------------------------------------------------------------- /ch_dp_gan/new_ctb.py:
-------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import re 3 | import os 4 | from nltk.tree import Tree as ParseTree 5 | from util.berkely import BerkeleyParser 6 | from tqdm import tqdm 7 | 8 | 9 | if __name__ == '__main__': 10 | ctb_dir = "data/CTB" 11 | save_dir = "data/CTB_auto" 12 | encoding = "UTF-8" 13 | ctb = {} 14 | s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL) 15 | parser = BerkeleyParser() 16 | for file in tqdm(os.listdir(ctb_dir)): 17 | if os.path.isfile(os.path.join(save_dir, file)): 18 | continue 19 | print(file) 20 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd: 21 | doc = fd.read() 22 | parses = [] 23 | for match in s_pat.finditer(doc): 24 | sid = match.group("sid") 25 | sparse = ParseTree.fromstring(match.group("sparse")) 26 | pairs = [(node[0], node.label()) for node in sparse.subtrees() 27 | if node.height() == 2 and node.label() != "-NONE-"] 28 | words, tags = list(zip(*pairs)) 29 | print(sid, " ".join(words)) 30 | if sid == "5133": 31 | parse = sparse 32 | else: 33 | parse = parser.parse(words, timeout=2000) 34 | parses.append((sid, parse)) 35 | with open(os.path.join(save_dir, file), "w+", encoding=encoding) as save_fd: 36 | for sid, parse in parses: 37 | save_fd.write("<S ID=%s>\n" % sid) 38 | save_fd.write(str(parse)) 39 | save_fd.write("</S>\n\n") 40 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/new_ctb.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import re 3 | import os 4 | from nltk.tree import Tree as ParseTree 5 | from util.berkely import BerkeleyParser 6 | from tqdm import tqdm 7 | 8 | 9 | if __name__ == '__main__': 10 | ctb_dir = "data/CTB" 11 | save_dir = "data/CTB_auto" 12 | encoding = "UTF-8" 13 | ctb = {} 14 | s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL) 15 | parser = BerkeleyParser() 16 | for file in tqdm(os.listdir(ctb_dir)): 17 | if os.path.isfile(os.path.join(save_dir, file)): 18 | continue 19 | print(file) 20 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd: 21 | doc = fd.read() 22 | parses = [] 23 | for match in s_pat.finditer(doc): 24 | sid = match.group("sid") 25 | sparse = ParseTree.fromstring(match.group("sparse")) 26 | pairs = [(node[0], node.label()) for node in sparse.subtrees() 27 | if node.height() == 2 and node.label() != "-NONE-"] 28 | words, tags = list(zip(*pairs)) 29 | print(sid, " ".join(words)) 30 | if sid == "5133": 31 | parse = sparse 32 | else: 33 | parse = parser.parse(words, timeout=2000) 34 | parses.append((sid, parse)) 35 | with open(os.path.join(save_dir, file), "w+", encoding=encoding) as save_fd: 36 | for sid, parse in parses: 37 | save_fd.write("<S ID=%s>\n" % sid) 38 | save_fd.write(str(parse)) 39 | save_fd.write("</S>\n\n") 40 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/structure/rst_xl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Lyzhang 5 | @Date: 6 | @Description: 7 | """ 8 | from config import * 9 | from util.file_util import * 10 | 11 | 12 | class RST: 13 | def __init__(self, train_set=TRAIN_SET_XL, dev_set=DEV_SET_XL, test_set=TEST_SET_XL): 14 | self.train = load_data(train_set) 15 | self.train_rst = load_data(RST_TRAIN_TREES_RST) 16 | self.dev = load_data(dev_set) 17 | self.dev_rst = load_data(RST_DEV_TREES_RST) 18 | self.test = load_data(test_set) 19 |
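# Illustrative sketch (assumes the pickled splits from path_config.py have
# already been generated by pre-processing): downstream code only has to
# instantiate RST to get the train/dev/test samples plus their tree_obj
# views, e.g.
#
#     from structure.rst_xl import RST
#
#     rst = RST()
#     print(len(rst.train), len(rst.dev), len(rst.test))
#     for tree in rst.tree_obj_train[:2]:
#         print(tree.file_name, len(tree.edus))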
self.test_rst = load_data(RST_TEST_TREES_RST) 20 | 21 | self.tree_obj_train = load_data(RST_TRAIN_TREES) 22 | self.tree_obj_dev = load_data(RST_DEV_TREES) 23 | self.tree_obj_test = load_data(RST_TEST_TREES) 24 | self.edu_tree_obj_test = load_data(RST_TEST_EDUS_TREES) 25 | 26 | if TRAIN_Ext: 27 | ext_trees = load_data(EXT_TREES) 28 | ext_trees_rst = load_data(EXT_TREES_RST) 29 | ext_set = load_data(EXT_SET_XL) 30 | self.train = self.train + ext_set 31 | self.train_rst = self.train_rst + ext_trees_rst 32 | self.tree_obj_train = self.tree_obj_train + ext_trees 33 | 34 | @staticmethod 35 | def get_voc_labels(): 36 | word2ids = load_data(VOC_WORD2IDS_PATH) 37 | pos2ids = load_data(POS_word2ids_PATH) 38 | return word2ids, pos2ids 39 | 40 | @staticmethod 41 | def get_dev(): 42 | return load_data(RST_DEV_TREES) 43 | 44 | @staticmethod 45 | def get_test(): 46 | if USE_AE: 47 | return load_data(RST_TEST_EDUS_TREES) 48 | else: 49 | return load_data(RST_TEST_TREES) 50 | -------------------------------------------------------------------------------- /en_dp_gan/structure/tree_obj.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: lyzhang 5 | @Date: 2018/4/5 6 | @Description: tree.edus in the form of ids 7 | tree.nodes in the form of rst_tree objective 8 | """ 9 | import copy 10 | 11 | 12 | class tree_obj: 13 | def __init__(self, tree=None): 14 | self.edus = list() 15 | self.nodes = list() 16 | self.sents_edus = list() 17 | if tree is not None: 18 | self.file_name = tree.file_name 19 | self.tree = tree 20 | self.pre_traverse(tree) 21 | 22 | def __copy__(self): 23 | tree_ = copy.copy(self.tree) 24 | edus_ = [copy.copy(edu) for edu in self.edus] 25 | nodes_ = [copy.copy(node) for node in self.nodes] 26 | t_o = tree_obj(tree_) 27 | t_o.edus = edus_ 28 | t_o.nodes = nodes_ 29 | return t_o 30 | 31 | def assign_edus(self, edus_list): 32 | for edu in edus_list: 33 | self.edus.append(edu) 34 | 35 | def pre_traverse(self, root): 36 | if root is None: 37 | return 38 | self.pre_traverse(root.left_child) 39 | self.pre_traverse(root.right_child) 40 | # judge if nodes 41 | if root.left_child is None and root.right_child is None: 42 | self.edus.append(root) 43 | self.nodes.append(root) 44 | else: 45 | self.nodes.append(root) 46 | 47 | def get_sents_edus(self): 48 | tmp_sent_edus = list() 49 | for edu in self.edus: 50 | tmp_sent_edus.append(edu) 51 | if edu.edu_node_boundary: 52 | self.sents_edus.append(tmp_sent_edus) 53 | tmp_sent_edus = list() 54 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/structure/tree_obj.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: lyzhang 5 | @Date: 2018/4/5 6 | @Description: tree.edus in the form of ids 7 | tree.nodes in the form of rst_tree objective 8 | """ 9 | import copy 10 | 11 | 12 | class tree_obj: 13 | def __init__(self, tree=None): 14 | self.edus = list() 15 | self.nodes = list() 16 | self.sents_edus = list() 17 | if tree is not None: 18 | self.file_name = tree.file_name 19 | # init by tree 20 | self.tree = tree 21 | self.pre_traverse(tree) 22 | # self.get_sents_edus() 23 | 24 | def __copy__(self): 25 | tree_ = copy.copy(self.tree) 26 | edus_ = [copy.copy(edu) for edu in self.edus] 27 | nodes_ = [copy.copy(node) for node in self.nodes] 28 | t_o = tree_obj(tree_) 29 | t_o.edus = edus_ 30 | t_o.nodes = nodes_ 31 | return t_o 32 | 33 | def assign_edus(self, 
edus_list): 34 | for edu in edus_list: 35 | self.edus.append(edu) 36 | 37 | def pre_traverse(self, root): 38 | if root is None: 39 | return 40 | self.pre_traverse(root.left_child) 41 | self.pre_traverse(root.right_child) 42 | if root.left_child is None and root.right_child is None: 43 | self.edus.append(root) 44 | self.nodes.append(root) 45 | else: 46 | self.nodes.append(root) 47 | 48 | def get_sents_edus(self): 49 | tmp_sent_edus = list() 50 | for edu in self.edus: 51 | tmp_sent_edus.append(edu) 52 | if edu.edu_node_boundary: 53 | self.sents_edus.append(tmp_sent_edus) 54 | tmp_sent_edus = list() 55 | -------------------------------------------------------------------------------- /ch_dp_gan/util/berkely.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import subprocess 3 | import thulac 4 | import threading 5 | import os 6 | from nltk.tree import Tree 7 | 8 | 9 | BERKELEY_JAR = "berkeleyparser/BerkeleyParser-1.7.jar" 10 | BERKELEY_GRAMMAR = "berkeleyparser/chn_sm5.gr" 11 | 12 | 13 | class BerkeleyParser(object): 14 | def __init__(self): 15 | self.tokenizer = thulac.thulac() 16 | self.cmd = ['java', '-Xmx1024m', '-jar', BERKELEY_JAR, '-gr', BERKELEY_GRAMMAR] 17 | self.process = self.start() 18 | 19 | def start(self): 20 | return subprocess.Popen(self.cmd, env=dict(os.environ), universal_newlines=True, shell=False, bufsize=0, 21 | stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, errors='ignore') 22 | 23 | def stop(self): 24 | if self.process: 25 | self.process.terminate() 26 | 27 | def restart(self): 28 | self.stop() 29 | self.process = self.start() 30 | 31 | def parse_thread(self, text, results): 32 | text = text.replace("(", '-LRB-') 33 | text = text.replace(")", '-RRB-') 34 | self.process.stdin.write(text + '\n') 35 | self.process.stdin.flush() 36 | ret = self.process.stdout.readline().strip() 37 | results.append(ret) 38 | 39 | def parse(self, words, timeout=20000): 40 | # words, _ = list(zip(*self.tokenizer.cut(text))) 41 | results = [] 42 | t = threading.Thread(target=self.parse_thread, kwargs={'text': " ".join(words), 'results': results}) 43 | t.setDaemon(True) 44 | t.start() 45 | t.join(timeout) 46 | 47 | if not results: 48 | self.restart() 49 | raise TimeoutError() 50 | else: 51 | return Tree.fromstring(results[0]) 52 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/util/berkely.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import subprocess 3 | import thulac 4 | import threading 5 | import os 6 | from nltk.tree import Tree 7 | 8 | 9 | BERKELEY_JAR = "berkeleyparser/BerkeleyParser-1.7.jar" 10 | BERKELEY_GRAMMAR = "berkeleyparser/chn_sm5.gr" 11 | 12 | 13 | class BerkeleyParser(object): 14 | def __init__(self): 15 | self.tokenizer = thulac.thulac() 16 | self.cmd = ['java', '-Xmx1024m', '-jar', BERKELEY_JAR, '-gr', BERKELEY_GRAMMAR] 17 | self.process = self.start() 18 | 19 | def start(self): 20 | return subprocess.Popen(self.cmd, env=dict(os.environ), universal_newlines=True, shell=False, bufsize=0, 21 | stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, errors='ignore') 22 | 23 | def stop(self): 24 | if self.process: 25 | self.process.terminate() 26 | 27 | def restart(self): 28 | self.stop() 29 | self.process = self.start() 30 | 31 | def parse_thread(self, text, results): 32 | text = text.replace("(", '-LRB-') 33 | text = text.replace(")", '-RRB-') 34 | 
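# Illustrative usage sketch for the wrapper above (assumes
# BerkeleyParser-1.7.jar and chn_sm5.gr are in place): parse() pushes the
# request through a daemon thread, so a hung JVM costs at most `timeout`
# seconds of join() before the subprocess is restarted and TimeoutError is
# raised.
#
#     parser = BerkeleyParser()
#     try:
#         tree = parser.parse(["我们", "热爱", "和平"], timeout=60)
#         tree.pretty_print()
#     except TimeoutError:
#         print("Berkeley parser hung; process was restarted")
#     finally:
#         parser.stop()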
self.process.stdin.write(text + '\n') 35 | self.process.stdin.flush() 36 | ret = self.process.stdout.readline().strip() 37 | results.append(ret) 38 | 39 | def parse(self, words, timeout=20000): 40 | # words, _ = list(zip(*self.tokenizer.cut(text))) 41 | results = [] 42 | t = threading.Thread(target=self.parse_thread, kwargs={'text': " ".join(words), 'results': results}) 43 | t.setDaemon(True) 44 | t.start() 45 | t.join(timeout) 46 | 47 | if not results: 48 | self.restart() 49 | raise TimeoutError() 50 | else: 51 | return Tree.fromstring(results[0]) 52 | -------------------------------------------------------------------------------- /ch_dp_gan/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # pycharm 107 | .idea/ 108 | 109 | # data 110 | data/ 111 | 112 | # pyltp models 113 | pub/pyltp_models 114 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # pycharm 107 | .idea/ 108 | 109 | # data 110 | data/ 111 | 112 | # pyltp models 113 | pub/pyltp_models 114 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/util/drawer.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def draw_line(x, y_s, colors, labels, shapes, begins, lines_d, l_colors, step=1, x_name="STEP", y_name="LOSS"): 5 | fig = plt.figure(figsize=(6, 4), facecolor="white") 6 | ax1 = fig.add_subplot(1, 1, 1) 7 | ax1.set_xlabel(x_name) 8 | ax1.set_ylabel(y_name) 9 | for idx, y in enumerate(y_s): 10 | x_ = x[begins[idx]:] 11 | y_ = y[begins[idx]:] 12 | ax1.scatter(x_[::step], y_[::step], color=colors[idx], marker=shapes[idx], edgecolors=colors[idx], s=20, label=labels[idx]) 13 | if lines_d[idx]: 14 | line_general = 20 15 | x_n = x_[::line_general] 16 | y_n = [] 17 | sum_ = 0. 18 | sum_idx_num = 0 19 | for idx__, y__ in enumerate(y_): 20 | sum_idx_num += 1 21 | sum_ += y__ 22 | if idx__ > 0 and idx__ % line_general == 0: 23 | y_n.append(sum_ / line_general) 24 | sum_ = 0. 25 | sum_idx_num = 0 26 | if sum_ > 0.: 27 | y_n.append(sum_ / sum_idx_num) 28 | if idx == 0: 29 | ax1.plot(x_n[:29], y_n[:29], color="PaleGreen", linewidth=2.6, linestyle="-") 30 | ax1.plot(x_n[28:], y_n[28:], color="Green", linewidth=2.6, linestyle="-") 31 | else: 32 | ax1.plot(x_n, y_n, color=l_colors[idx], linewidth=2.6, linestyle="-") 33 | plt.vlines(542, 0, 2, colors="black", linestyles="dashed", linewidth=2) 34 | plt.annotate("warm up", xy=(300, 0.45), xytext=(250, 0.15), arrowprops=dict(facecolor="white", headlength=4, 35 | headwidth=13, width=4)) 36 | plt.legend(loc="upper left") 37 | plt.grid(linestyle='-') 38 | plt.show() 39 | 40 | 41 | if __name__ == "__main__": 42 | draw_line([1, 2, 3], [[3, 2, 1]], ["red"], ["A"]) 43 | -------------------------------------------------------------------------------- /en_dp_gan/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Since I had done many different experiments based on this project, some model settings below could be useless. 
4 | Longyin Zhang 5 | """ 6 | from path_config import * 7 | from util.file_util import * 8 | 9 | SAVE_MODEL = True 10 | VERSION, SET = 5, 97 11 | USE_CUDA, CUDA_ID = True, 3 12 | USE_BOUND, USE_GAN, USE_S_STACK = True, False, False 13 | 14 | WARM_UP_EP = 7 if USE_GAN else 20 15 | MAX_W, MAX_H = 80, 20 16 | LEARN_RATE, WD, D_LR = 0.001, 1e-4, 0.0001 17 | EPOCH, BATCH_SIZE, LOG_EVERY, VALIDATE_EVERY = 20, 5, 1, 5 18 | in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 8, 3, MAX_W, 1 19 | p_w_G, p_h_G = 3, 1 20 | METype = 1 21 | USE_AE = False 22 | TRAIN_Ext = False 23 | 24 | MAX_POOLING = True 25 | RANDOM_MASK_LEARN, RMR = False, 0.1 26 | LABEL_SWITCH, SWITCH_ITE = False, 4 27 | USE_POSE, POS_SIZE = True, 30 28 | USE_ELMo, EMBED_LEARN = True, False 29 | EMBED_SIZE, HIDDEN_SIZE = (1024 if USE_ELMo else 300), 256 30 | BOUND_INFO_SIZE = 30 # 30 31 | GATE_V = 1 32 | USE_LEAF, Re_DROP, GateDrop = True, 0.3, 0.2 33 | TEST_ONLY = False 34 | OPT_ATTN = False 35 | EDU_ATT, ML_ATT_HIDDEN_e, HEADS_e = False, 128, 2 36 | CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = False, 128, 2 37 | SPLIT_MLP_SIZE, NR_MLP_SIZE = 128, 128 38 | FEAT_W_SIZE, FEAT_P_SIZE, FEAT_Head_SIZE = 50, 30, 20 39 | LEN_SIZE, CENTER_SIZE, INNER_SIZE = 10, 30, 30 40 | FEAT_H_SIZE = 140 41 | USE_CNN, KERNEL_SIZE, PADDING_SIZE = True, 2, 1 42 | NUCL_MLP_SIZE, REL_MLP_SIZE = 64, 128 # Total basic. 43 | LAYER_NORM_USE = True 44 | ALPHA_SPAN, ALPHA_NR = 0.3, 1.0 45 | SEED = 36 46 | USE_R_ADAM = False 47 | WARM_UP = 20 48 | L2, DROP_OUT = 1e-5, 0.2 49 | BETA1, BETA2 = 0.9, 0.999 50 | TRAN_LABEL_NUM, NR_LABEL_NUM, NUCL_LABEL_NUM, REL_LABEL_NUM = 1, 42, 3, 18 51 | SHIFT, REDUCE = "SHIFT", "REDUCE" 52 | REDUCE_NN, REDUCE_NS, REDUCE_SN = "REDUCE-NN", "REDUCE-NS", "REDUCE-SN" 53 | NN, NS, SN = "NN", "NS", "SN" 54 | PAD, PAD_ids = "", 0 55 | UNK, UNK_ids = "", 1 56 | action2ids = {SHIFT: 0, REDUCE: 1} 57 | ids2action = {0: SHIFT, 1: REDUCE} 58 | nucl2ids = {NN: 0, NS: 1, SN: 2} if METype == 1 else {"N": 0, "S": 1} 59 | ids2nucl = {0: NN, 1: NS, 2: SN} if METype == 1 else {0: "N", 1: "S"} 60 | ns_dict = {"Satellite": 0, "Nucleus": 1, "Root": 2} 61 | ns_dict_ = {0: "Satellite", 1: "Nucleus", 2: "Root"} 62 | coarse2ids = load_data(REL_coarse2ids) 63 | ids2coarse = load_data(REL_ids2coarse) 64 | nr2ids = load_data(LABEL2IDS) 65 | ids2nr = load_data(IDS2LABEL) 66 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from path_config import * 3 | from util.file_util import * 4 | SAVE_MODEL = True 5 | VERSION, SET = 9, 147 6 | USE_CUDA, CUDA_ID = True, 0 7 | EDU_ENCODE_VERSION, SPLIT_V = 2, 1 8 | CHUNK_SIZE = 768 # 64 9 | MAX_LEN = 32 10 | 11 | USE_BOUND, USE_GAN, USE_S_STACK = True, True, False 12 | TRAIN_XLNET, XL_FINE, Joint_EDU_R, XLNET_TYPE, XLNET_SIZE = True, True, False, "xlnet-base-cased", 768 13 | 14 | MAX_W, MAX_H = 80, 20 15 | LEARN_RATE, WD, D_LR = 0.0001, 1e-4, 0.0001 16 | EPOCH, BATCH_SIZE, LOG_EVERY, VALIDATE_EVERY = 50, 1, 5, 20 17 | UPDATE_ITE = 32 18 | WARM_UP_EP = 7 if USE_GAN else EPOCH 19 | # CNN FEATURE 20 | in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 32, 3, MAX_W // 2, 1 21 | p_w_G, p_h_G = 3, 3 22 | METype = 1 23 | USE_AE = False 24 | TRAIN_Ext = False 25 | 26 | MAX_POOLING = True 27 | RANDOM_MASK_LEARN, RMR = False, 0.1 28 | LABEL_SWITCH, SWITCH_ITE = False, 4 29 | USE_POSE, POS_SIZE = False, 30 30 | USE_ELMo, EMBED_LEARN = False, False 31 | 
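# Illustrative helper (our addition, not an original config symbol): the
# SEED value defined further down this file is typically threaded through
# random/numpy/torch exactly once at trainer start-up, roughly like this.
import random
import numpy as np
import torch

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)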
EMBED_SIZE, HIDDEN_SIZE = (1024 if USE_ELMo else 300), 384 32 | BOUND_INFO_SIZE = 30 # 30 33 | 34 | GATE_V = 1 35 | USE_LEAF, Re_DROP, GateDrop = True, 0.3, 0.2 36 | 37 | TEST_ONLY = False 38 | OPT_ATTN = False 39 | 40 | # ATTN 41 | EDU_ATT, ML_ATT_HIDDEN_e, HEADS_e = False, 128, 2 42 | CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = True, 128, 2 43 | 44 | # MLP 45 | SPLIT_MLP_SIZE, NR_MLP_SIZE = 128, 128 46 | 47 | # feature engineer (feat_e_len, feat_e_w, feat_e_p, feat_e_h, feat_e_c, feat_e_inner) 48 | FEAT_W_SIZE, FEAT_P_SIZE, FEAT_Head_SIZE = 50, 30, 20 49 | LEN_SIZE, CENTER_SIZE, INNER_SIZE = 10, 30, 30 50 | FEAT_H_SIZE = 140 51 | USE_CNN, KERNEL_SIZE, PADDING_SIZE = True, 2, 1 52 | NUCL_MLP_SIZE, REL_MLP_SIZE = 64, 128 # Total basic. 53 | LAYER_NORM_USE = True 54 | ALPHA_SPAN, ALPHA_NR = 0.3, 1.0 55 | SEED = 19 56 | USE_R_ADAM = False 57 | WARM_UP = 20 58 | L2, DROP_OUT = 1e-5, 0.2 59 | BETA1, BETA2 = 0.9, 0.999 60 | TRAN_LABEL_NUM, NR_LABEL_NUM, NUCL_LABEL_NUM, REL_LABEL_NUM = 1, 42, 3, 18 61 | SHIFT, REDUCE = "SHIFT", "REDUCE" 62 | REDUCE_NN, REDUCE_NS, REDUCE_SN = "REDUCE-NN", "REDUCE-NS", "REDUCE-SN" 63 | NN, NS, SN = "NN", "NS", "SN" 64 | PAD, PAD_ids = "", 0 65 | UNK, UNK_ids = "", 1 66 | action2ids = {SHIFT: 0, REDUCE: 1} 67 | ids2action = {0: SHIFT, 1: REDUCE} 68 | nucl2ids = {NN: 0, NS: 1, SN: 2} if METype == 1 else {"N": 0, "S": 1} 69 | ids2nucl = {0: NN, 1: NS, 2: SN} if METype == 1 else {0: "N", 1: "S"} 70 | ns_dict = {"Satellite": 0, "Nucleus": 1, "Root": 2} 71 | ns_dict_ = {0: "Satellite", 1: "Nucleus", 2: "Root"} 72 | coarse2ids = load_data(REL_coarse2ids) 73 | ids2coarse = load_data(REL_ids2coarse) 74 | nr2ids = load_data(LABEL2IDS) 75 | ids2nr = load_data(IDS2LABEL) 76 | NR_LOSS_OPT = False 77 | -------------------------------------------------------------------------------- /en_dp_gan/path_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Lyzhang 5 | @Date: 6 | @Description: 7 | """ 8 | # RST-DT 9 | RST_DT_PATH = "data/rst_dt" 10 | RAW_RST_DT_PATH = "data/rst_dt/RST_DT_RAW" 11 | 12 | # EDUs of train & dev 13 | RST_EDUs_ori_tr = "data/EDUs/ori_train_edus.tsv" 14 | RST_EDUs_ori_de = "data/EDUs/ori_dev_edus.tsv" 15 | # EN EDUs of train & dev 16 | RST_EDUs_google_tr = "data/EDUs/translate_train.txt" 17 | RST_EDUs_google_de = "data/EDUs/translate_dev.tsv" 18 | 19 | EXT_TREES = "data/rst_dt/ext_trees.pkl" 20 | EXT_TREES_RST = "data/rst_dt/ext_trees_rst.pkl" 21 | EXT_SET = "data/rst_dt/ext_set.pkl" 22 | 23 | # RST-DT Generated 24 | RST_TRAIN_TREES = "data/rst_dt/train_trees.pkl" 25 | RST_TRAIN_TREES_RST = "data/rst_dt/train_trees_rst.pkl" 26 | TRAIN_SET = "data/rst_dt/train_set.pkl" 27 | RST_TEST_TREES = "data/rst_dt/test_trees.pkl" 28 | RST_TEST_TREES_RST = "data/rst_dt/test_trees_rst.pkl" 29 | RST_TEST_EDUS_TREES = "data/rst_dt/edus_test_trees.pkl" 30 | RST_TEST_EDUS_TREES_ = "data/rst_dt/edus_test_trees_.pkl" 31 | TEST_SET = "data/rst_dt/test_set.pkl" 32 | RST_DEV_TREES = "data/rst_dt/dev_trees.pkl" 33 | RST_DEV_TREES_RST = "data/rst_dt/dev_trees_rst.pkl" 34 | DEV_SET = "data/rst_dt/dev_set.pkl" 35 | 36 | TRAIN_SET_NR = "data/rst_dt/nr_train_set.pkl" 37 | TEST_SET_NR = "data/rst_dt/nr_test_set.pkl" 38 | DEV_SET_NR = "data/rst_dt/nr_dev_set.pkl" 39 | 40 | RST_TEST_TREES_N = "data/rst_dt/test_trees_n.pkl" 41 | RST_TEST_TREES_RST_N = "data/rst_dt/test_trees_rst_n.pkl" 42 | RST_TEST_EDUS_TREES_N = "data/rst_dt/edus_test_trees_n.pkl" 43 | TEST_SET_N = "data/rst_dt/test_n.set.pkl" 44 | 
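# Hedged usage sketch (our addition; the helper name is ours): these path
# constants are consumed with the pickle helpers from util/file_util.py once
# pre-processing has produced the files, e.g. to materialize the vocabulary
# lookups declared further down this module.
from util.file_util import load_data

def peek_vocab():
    word2ids = load_data(VOC_WORD2IDS_PATH)
    pos2ids = load_data(POS_word2ids_PATH)
    print("words:", len(word2ids), "pos tags:", len(pos2ids))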
TEST_SET_NR_N = "data/rst_dt/nr_test_n.set.pkl" 45 | 46 | # VOC 47 | REL_raw2coarse = "data/voc/raw2coarse.pkl" 48 | VOC_WORD2IDS_PATH = "data/voc/word2ids.pkl" 49 | POS_word2ids_PATH = "data/voc/pos2ids.pkl" 50 | REL_coarse2ids = "data/voc/coarse2ids.pkl" 51 | REL_ids2coarse = "data/voc/ids2coarse.pkl" 52 | VOC_VEC_PATH = "data/voc/ids2vec.pkl" 53 | LABEL2IDS = "data/voc/label2ids.pkl" 54 | IDS2LABEL = "data/voc/ids2label.pkl" 55 | options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/" \ 56 | "elmo_2x4096_512_2048cnn_2xhighway_options.json" 57 | weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo" \ 58 | "_2x4096_512_2048cnn_2xhighway_weights.hdf5" 59 | path_to_jar = 'stanford-corenlp-full-2018-02-27' 60 | MODELS2SAVE = "data/models_saved" 61 | MODEL_SAVE = "data/models/treebuilder.partptr.basic" 62 | LOG_ALL = "data/logs" 63 | PRE_DEPTH = "data/pre_depth.pkl." 64 | GOLD_DEPTH = "data/gold_depth.pkl" 65 | DRAW_DT = "data/drawer.pkl" 66 | LOSS_PATH = "data/loss/" 67 | EXT_SET_NEG = "data/rst_dt/neg_ext_set.pkl" 68 | TRAIN_SET_NEG = "data/rst_dt/neg_train_set.pkl" 69 | TEST_SET_NEG = "data/rst_dt/neg_test_set.pkl" 70 | DEV_SET_NEG = "data/rst_dt/neg_dev_set.pkl" 71 | -------------------------------------------------------------------------------- /ch_dp_gan/util/file_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: lyzhang 5 | @Date: 2018/3/13 6 | @Description: 7 | """ 8 | 9 | import urllib 10 | import gzip 11 | import os 12 | import shutil 13 | import pickle as pkl 14 | 15 | 16 | def safe_mkdir(path): 17 | """ Create a directory if there isn't one already. """ 18 | try: 19 | os.mkdir(path) 20 | except OSError: 21 | pass 22 | 23 | 24 | def safe_mkdirs(path_list): 25 | """ Create a directory if there isn't one already. """ 26 | for path in path_list: 27 | try: 28 | os.mkdir(path) 29 | except OSError: 30 | pass 31 | 32 | 33 | def download_one_file(download_url, 34 | local_dest, 35 | expected_byte=None, 36 | unzip_and_remove=False): 37 | """ 38 | Download the file from download_url into local_dest 39 | if the file doesn't already exists. 40 | If expected_byte is provided, check if 41 | the downloaded file has the same number of bytes. 
42 | If unzip_and_remove is True, unzip the file and remove the zip file 43 | """ 44 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]): 45 | print('%s already exists' % local_dest) 46 | else: 47 | print('Downloading %s' % download_url) 48 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest) 49 | file_stat = os.stat(local_dest) 50 | if expected_byte: 51 | if file_stat.st_size == expected_byte: 52 | print('Successfully downloaded %s' % local_dest) 53 | if unzip_and_remove: 54 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out: 55 | shutil.copyfileobj(f_in, f_out) 56 | os.remove(local_dest) 57 | else: 58 | print('The downloaded file has unexpected number of bytes') 59 | 60 | 61 | def save_data(obj, path, append=False): 62 | if append: 63 | with open(path, "wb+") as f: 64 | pkl.dump(obj, f) 65 | else: 66 | with open(path, "wb") as f: 67 | pkl.dump(obj, f) 68 | 69 | 70 | def load_data(path): 71 | with open(path, "rb") as f: 72 | obj = pkl.load(f) 73 | return obj 74 | 75 | 76 | def write_iterate(ite, file_path): 77 | with open(file_path, "w") as f: 78 | for line in ite: 79 | f.write(line + "\n") 80 | 81 | 82 | def write_append(txt, file_path): 83 | with open(file_path, "a") as f: 84 | f.write(txt + "\n") 85 | 86 | 87 | def write_over(txt, file_path): 88 | with open(file_path, "w") as f: 89 | f.write(txt + "\n") 90 | 91 | 92 | def print_(str_, log_file, write_=False): 93 | if write_: 94 | write_over(str_, log_file) 95 | else: 96 | write_append(str_, log_file) 97 | print(str_) 98 | -------------------------------------------------------------------------------- /en_dp_gan/util/file_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: lyzhang 5 | @Date: 2018/3/13 6 | @Description: 7 | """ 8 | 9 | import urllib 10 | import gzip 11 | import os 12 | import shutil 13 | import pickle as pkl 14 | 15 | 16 | def safe_mkdir(path): 17 | """ Create a directory if there isn't one already. """ 18 | try: 19 | os.mkdir(path) 20 | except OSError: 21 | pass 22 | 23 | 24 | def safe_mkdirs(path_list): 25 | """ Create a directory if there isn't one already. """ 26 | for path in path_list: 27 | try: 28 | os.mkdir(path) 29 | except OSError: 30 | pass 31 | 32 | 33 | def download_one_file(download_url, 34 | local_dest, 35 | expected_byte=None, 36 | unzip_and_remove=False): 37 | """ 38 | Download the file from download_url into local_dest 39 | if the file doesn't already exists. 40 | If expected_byte is provided, check if 41 | the downloaded file has the same number of bytes. 
42 | If unzip_and_remove is True, unzip the file and remove the zip file 43 | """ 44 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]): 45 | print('%s already exists' % local_dest) 46 | else: 47 | print('Downloading %s' % download_url) 48 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest) 49 | file_stat = os.stat(local_dest) 50 | if expected_byte: 51 | if file_stat.st_size == expected_byte: 52 | print('Successfully downloaded %s' % local_dest) 53 | if unzip_and_remove: 54 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out: 55 | shutil.copyfileobj(f_in, f_out) 56 | os.remove(local_dest) 57 | else: 58 | print('The downloaded file has unexpected number of bytes') 59 | 60 | 61 | def save_data(obj, path, append=False): 62 | if append: 63 | with open(path, "wb+") as f: 64 | pkl.dump(obj, f) 65 | else: 66 | with open(path, "wb") as f: 67 | pkl.dump(obj, f) 68 | 69 | 70 | def load_data(path): 71 | with open(path, "rb") as f: 72 | obj = pkl.load(f) 73 | return obj 74 | 75 | 76 | def write_iterate(ite, file_path): 77 | with open(file_path, "w") as f: 78 | for line in ite: 79 | f.write(line + "\n") 80 | 81 | 82 | def write_append(txt, file_path): 83 | with open(file_path, "a") as f: 84 | f.write(txt + "\n") 85 | 86 | 87 | def write_over(txt, file_path): 88 | with open(file_path, "w") as f: 89 | f.write(txt + "\n") 90 | 91 | 92 | def print_(str_, log_file, write_=False): 93 | if write_: 94 | write_over(str_, log_file) 95 | else: 96 | write_append(str_, log_file) 97 | print(str_) 98 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/util/file_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: lyzhang 5 | @Date: 2018/3/13 6 | @Description: 7 | """ 8 | 9 | import urllib 10 | import gzip 11 | import os 12 | import shutil 13 | import pickle as pkl 14 | 15 | 16 | def safe_mkdir(path): 17 | """ Create a directory if there isn't one already. """ 18 | try: 19 | os.mkdir(path) 20 | except OSError: 21 | pass 22 | 23 | 24 | def safe_mkdirs(path_list): 25 | """ Create a directory if there isn't one already. """ 26 | for path in path_list: 27 | try: 28 | os.mkdir(path) 29 | except OSError: 30 | pass 31 | 32 | 33 | def download_one_file(download_url, 34 | local_dest, 35 | expected_byte=None, 36 | unzip_and_remove=False): 37 | """ 38 | Download the file from download_url into local_dest 39 | if the file doesn't already exists. 40 | If expected_byte is provided, check if 41 | the downloaded file has the same number of bytes. 
42 | If unzip_and_remove is True, unzip the file and remove the zip file 43 | """ 44 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]): 45 | print('%s already exists' % local_dest) 46 | else: 47 | print('Downloading %s' % download_url) 48 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest) 49 | file_stat = os.stat(local_dest) 50 | if expected_byte: 51 | if file_stat.st_size == expected_byte: 52 | print('Successfully downloaded %s' % local_dest) 53 | if unzip_and_remove: 54 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out: 55 | shutil.copyfileobj(f_in, f_out) 56 | os.remove(local_dest) 57 | else: 58 | print('The downloaded file has unexpected number of bytes') 59 | 60 | 61 | def save_data(obj, path, append=False): 62 | if append: 63 | with open(path, "wb+") as f: 64 | pkl.dump(obj, f) 65 | else: 66 | with open(path, "wb") as f: 67 | pkl.dump(obj, f) 68 | 69 | 70 | def load_data(path): 71 | with open(path, "rb") as f: 72 | obj = pkl.load(f) 73 | return obj 74 | 75 | 76 | def write_iterate(ite, file_path): 77 | with open(file_path, "w") as f: 78 | for line in ite: 79 | f.write(line + "\n") 80 | 81 | 82 | def write_append(txt, file_path): 83 | with open(file_path, "a") as f: 84 | f.write(txt + "\n") 85 | 86 | 87 | def write_over(txt, file_path): 88 | with open(file_path, "w") as f: 89 | f.write(txt + "\n") 90 | 91 | 92 | def print_(str_, log_file, write_=False): 93 | if write_: 94 | write_over(str_, log_file) 95 | else: 96 | write_append(str_, log_file) 97 | print(str_) 98 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/util/file_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: lyzhang 5 | @Date: 2018/3/13 6 | """ 7 | 8 | import urllib 9 | import gzip 10 | import os 11 | import shutil 12 | import pickle as pkl 13 | 14 | 15 | def safe_mkdir(path): 16 | """ Create a directory if there isn't one already. """ 17 | try: 18 | os.mkdir(path) 19 | except OSError: 20 | pass 21 | 22 | 23 | def safe_mkdirs(path_list): 24 | """ Create a directory if there isn't one already. """ 25 | for path in path_list: 26 | try: 27 | os.mkdir(path) 28 | except OSError: 29 | pass 30 | 31 | 32 | def download_one_file(download_url, 33 | local_dest, 34 | expected_byte=None, 35 | unzip_and_remove=False): 36 | """ 37 | Download the file from download_url into local_dest 38 | if the file doesn't already exists. 39 | If expected_byte is provided, check if 40 | the downloaded file has the same number of bytes. 
41 | If unzip_and_remove is True, unzip the file and remove the zip file 42 | """ 43 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]): 44 | print('%s already exists' % local_dest) 45 | else: 46 | print('Downloading %s' % download_url) 47 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest) 48 | file_stat = os.stat(local_dest) 49 | if expected_byte: 50 | if file_stat.st_size == expected_byte: 51 | print('Successfully downloaded %s' % local_dest) 52 | if unzip_and_remove: 53 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out: 54 | shutil.copyfileobj(f_in, f_out) 55 | os.remove(local_dest) 56 | else: 57 | print('The downloaded file has unexpected number of bytes') 58 | 59 | 60 | def save_data(obj, path, append=False): 61 | if append: 62 | with open(path, "wb+") as f: 63 | pkl.dump(obj, f) 64 | else: 65 | with open(path, "wb") as f: 66 | pkl.dump(obj, f) 67 | 68 | 69 | def load_data(path): 70 | with open(path, "rb") as f: 71 | obj = pkl.load(f) 72 | return obj 73 | 74 | 75 | def write_iterate(ite, file_path): 76 | with open(file_path, "w") as f: 77 | for line in ite: 78 | f.write(line + "\n") 79 | 80 | 81 | def write_append(txt, file_path): 82 | with open(file_path, "a") as f: 83 | f.write(txt + "\n") 84 | 85 | 86 | def write_over(txt, file_path): 87 | with open(file_path, "w") as f: 88 | f.write(txt + "\n") 89 | 90 | 91 | def print_(str_, log_file, write_=False): 92 | if write_: 93 | write_over(str_, log_file) 94 | else: 95 | write_append(str_, log_file) 96 | print(str_) 97 | 98 | 99 | def get_files_all(path_): 100 | lines_all = [] 101 | for f_name in os.listdir(path_): 102 | tmp_n = os.path.join(path_, f_name) 103 | with open(tmp_n, "r") as f: 104 | lines_all += f.readlines() 105 | return lines_all 106 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/path_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | @Author: Lyzhang 5 | @Date: 6 | @Description: 7 | """ 8 | # RST-DT 9 | RST_DT_PATH = "data/rst_dt" 10 | RAW_RST_DT_PATH = "data/rst_dt/RST_DT_RAW" 11 | 12 | # EDUs of train & dev 13 | RST_EDUs_ori_tr = "data/EDUs/ori_train_edus.tsv" 14 | RST_EDUs_ori_de = "data/EDUs/ori_dev_edus.tsv" 15 | # EN EDUs of train & dev 16 | RST_EDUs_google_tr = "data/EDUs/translate_train.txt" 17 | RST_EDUs_google_de = "data/EDUs/translate_dev.tsv" 18 | 19 | EXT_TREES = "data/rst_dt/ext_trees.pkl" 20 | EXT_TREES_RST = "data/rst_dt/ext_trees_rst.pkl" 21 | EXT_SET = "data/rst_dt/ext_set.pkl" 22 | 23 | # RST-DT Generated 24 | RST_TRAIN_TREES = "data/rst_dt/train_trees.pkl" 25 | RST_TRAIN_TREES_RST = "data/rst_dt/train_trees_rst.pkl" 26 | TRAIN_SET = "data/rst_dt/train_set.pkl" 27 | RST_TEST_TREES = "data/rst_dt/test_trees.pkl" 28 | RST_TEST_TREES_RST = "data/rst_dt/test_trees_rst.pkl" 29 | RST_TEST_EDUS_TREES = "data/rst_dt/edus_test_trees.pkl" 30 | RST_TEST_EDUS_TREES_ = "data/rst_dt/edus_test_trees_.pkl" 31 | TEST_SET = "data/rst_dt/test_set.pkl" 32 | RST_DEV_TREES = "data/rst_dt/dev_trees.pkl" 33 | RST_DEV_TREES_RST = "data/rst_dt/dev_trees_rst.pkl" 34 | DEV_SET = "data/rst_dt/dev_set.pkl" 35 | TRAIN_SET_XL = "data/rst_dt/train_set_xl.pkl" 36 | TEST_SET_XL = "data/rst_dt/test_set_xl.pkl" 37 | DEV_SET_XL = "data/rst_dt/dev_set_xl.pkl" 38 | EXT_SET_XL = "data/rst_dt/ext_set_xl.pkl" 39 | 40 | TRAIN_SET_XLM = "data/rst_dt/marcu/train_set_xlm.pkl" # Marcu-style learning 41 | TEST_SET_XLM = "data/rst_dt/marcu/test_set_xlm.pkl" 42 |
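# Hedged round-trip sketch (our addition; the scratch path and helper name
# are hypothetical): every *.pkl constant in this module pairs with
# save_data/load_data from util/file_util.py, which are thin pickle wrappers.
from util.file_util import save_data, load_data

def _pickle_round_trip_demo():
    save_data({"n_edus": 42}, "data/rst_dt/marcu/_scratch.pkl")
    assert load_data("data/rst_dt/marcu/_scratch.pkl")["n_edus"] == 42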
DEV_SET_XLM = "data/rst_dt/marcu/dev_set_xlm.pkl" 43 | EXT_SET_XLM = "data/rst_dt/marcu/ext_set_xlm.pkl" 44 | NR2ids_marcu = "data/rst_dt/marcu/nr2ids.pkl" 45 | IDS2nr_marcu = "data/rst_dt/marcu/ids2nr.pkl" 46 | 47 | TRAIN_SET_NR = "data/rst_dt/nr_train_set.pkl" # N and R kept separate 48 | TEST_SET_NR = "data/rst_dt/nr_test_set.pkl" 49 | DEV_SET_NR = "data/rst_dt/nr_dev_set.pkl" 50 | 51 | RST_TEST_TREES_N = "data/rst_dt/test_trees_n.pkl" 52 | RST_TEST_TREES_RST_N = "data/rst_dt/test_trees_rst_n.pkl" 53 | RST_TEST_EDUS_TREES_N = "data/rst_dt/edus_test_trees_n.pkl" 54 | TEST_SET_N = "data/rst_dt/test_n.set.pkl" 55 | TEST_SET_NR_N = "data/rst_dt/nr_test_n.set.pkl" 56 | 57 | # VOC 58 | REL_raw2coarse = "data/voc/raw2coarse.pkl" 59 | VOC_WORD2IDS_PATH = "data/voc/word2ids.pkl" 60 | POS_word2ids_PATH = "data/voc/pos2ids.pkl" 61 | REL_coarse2ids = "data/voc/coarse2ids.pkl" 62 | REL_ids2coarse = "data/voc/ids2coarse.pkl" 63 | VOC_VEC_PATH = "data/voc/ids2vec.pkl" 64 | LABEL2IDS = "data/voc/label2ids.pkl" 65 | IDS2LABEL = "data/voc/ids2label.pkl" 66 | 67 | options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/" \ 68 | "elmo_2x4096_512_2048cnn_2xhighway_options.json" 69 | weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo" \ 70 | "_2x4096_512_2048cnn_2xhighway_weights.hdf5" 71 | path_to_jar = 'stanford-corenlp-full-2018-02-27' 72 | MODELS2SAVE = "data/models_saved" 73 | MODEL_SAVE = "data/models/treebuilder.partptr.basic" 74 | LOG_ALL = "data/logs" 75 | PRE_DEPTH = "data/pre_depth.pkl." 76 | GOLD_DEPTH = "data/gold_depth.pkl" 77 | DRAW_DT = "data/drawer.pkl" 78 | LOSS_PATH = "data/loss/" 79 | 80 | # ADDITIONAL 81 | EXT_SET_NEG = "data/rst_dt/neg_ext_set.pkl" 82 | TRAIN_SET_NEG = "data/rst_dt/neg_train_set.pkl" 83 | TEST_SET_NEG = "data/rst_dt/neg_test_set.pkl" 84 | DEV_SET_NEG = "data/rst_dt/neg_dev_set.pkl" 85 | -------------------------------------------------------------------------------- /ch_dp_gan/structure/vocab.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import logging 4 | from collections import defaultdict 5 | from gensim.models import KeyedVectors 6 | import numpy as np 7 | import torch 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | PAD_TAG = "" 12 | UNK_TAG = "" 13 | 14 | 15 | class Vocab: 16 | def __init__(self, name, counter, min_occur=1): 17 | self.name = name 18 | self.s2id = defaultdict(int) 19 | self.s2id[UNK_TAG] = 0 20 | self.id2s = [UNK_TAG] 21 | 22 | self.counter = counter 23 | for tag, freq in counter.items(): 24 | if tag not in self.s2id and freq >= min_occur: 25 | self.s2id[tag] = len(self.s2id) 26 | self.id2s.append(tag) 27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id))) 28 | 29 | def __getitem__(self, item): 30 | return self.s2id[item] 31 | 32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False): 33 | if dim is not None and pretrained is not None: 34 | raise Warning("dim should not be given if pretrained weights are assigned") 35 | 36 | if dim is None and pretrained is None: 37 | raise Warning("one of dim or pretrained should be assigned") 38 | 39 | if pretrained: 40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary) 41 | dim = w2v.vector_size 42 | scale = np.sqrt(3.0 / dim) 43 | weights = np.empty([len(self), dim], dtype=np.float32) 44 | oov_count = 0 45 | all_count = 0 46 | for tag, i in self.s2id.items(): 47 | if tag in w2v.vocab:
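# (Annotation, illustrative only.) In-vocabulary tags copy their pretrained
# word2vec row below; OOV tags are counted by token frequency and
# re-initialized to zeros (freeze) or uniform noise in
# [-sqrt(3/dim), sqrt(3/dim)] (trainable). Minimal use of the random-init
# path, assuming dim=50:
#
#     from collections import Counter
#     from structure.vocab import Vocab
#
#     voc = Vocab("word", Counter(["the", "the", "a"]))
#     emb = voc.embedding(dim=50)   # row 0 is reserved for UNK_TAG
#     print(emb.weight.shape)       # torch.Size([3, 50])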
48 | weights[i] = w2v[tag].astype(np.float32) 49 | else: 50 | oov_count += self.counter[tag] 51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \ 52 | np.random.uniform(-scale, scale, dim).astype(np.float32) 53 | all_count += self.counter[tag] 54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" % 55 | (self.name, oov_count, all_count, oov_count/all_count*100)) 56 | else: 57 | scale = np.sqrt(3.0 / dim) 58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32) 59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \ 60 | np.random.uniform(-scale, scale, dim).astype(np.float32) 61 | weights = torch.from_numpy(weights) 62 | if use_gpu: 63 | weights = weights.cuda() 64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze) 65 | return embedding 66 | 67 | def __len__(self): 68 | return len(self.id2s) 69 | 70 | 71 | class Label: 72 | def __init__(self, name, counter, specials=None): 73 | self.name = name 74 | self.counter = counter.copy() 75 | self.label2id = {} 76 | self.id2label = [] 77 | if specials: 78 | for label in specials: 79 | del self.counter[label] 80 | self.label2id[label] = len(self.label2id) 81 | self.id2label.append(label) 82 | 83 | for label, freq in self.counter.items(): 84 | if label not in self.label2id: 85 | self.label2id[label] = len(self.label2id) 86 | self.id2label.append(label) 87 | logger.info("label %s size %d" % (name, len(self))) 88 | 89 | def __getitem__(self, item): 90 | return self.label2id[item] 91 | 92 | def __len__(self): 93 | return len(self.id2label) 94 | -------------------------------------------------------------------------------- /en_dp_gan/structure/vocab.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import logging 4 | from collections import defaultdict 5 | from gensim.models import KeyedVectors 6 | import numpy as np 7 | import torch 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | PAD_TAG = "" 12 | UNK_TAG = "" 13 | 14 | 15 | class Vocab: 16 | def __init__(self, name, counter, min_occur=1): 17 | self.name = name 18 | self.s2id = defaultdict(int) 19 | self.s2id[UNK_TAG] = 0 20 | self.id2s = [UNK_TAG] 21 | 22 | self.counter = counter 23 | for tag, freq in counter.items(): 24 | if tag not in self.s2id and freq >= min_occur: 25 | self.s2id[tag] = len(self.s2id) 26 | self.id2s.append(tag) 27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id))) 28 | 29 | def __getitem__(self, item): 30 | return self.s2id[item] 31 | 32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False): 33 | if dim is not None and pretrained is not None: 34 | raise Warning("dim should not given if pretraiained weights are assigned") 35 | 36 | if dim is None and pretrained is None: 37 | raise Warning("one of dim or pretrained should be assigned") 38 | 39 | if pretrained: 40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary) 41 | dim = w2v.vector_size 42 | scale = np.sqrt(3.0 / dim) 43 | weights = np.empty([len(self), dim], dtype=np.float32) 44 | oov_count = 0 45 | all_count = 0 46 | for tag, i in self.s2id.items(): 47 | if tag in w2v.vocab: 48 | weights[i] = w2v[tag].astype(np.float32) 49 | else: 50 | oov_count += self.counter[tag] 51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \ 52 | np.random.uniform(-scale, scale, dim).astype(np.float32) 53 | all_count += self.counter[tag] 54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" % 55 | 
(self.name, oov_count, all_count, oov_count/all_count*100)) 56 | else: 57 | scale = np.sqrt(3.0 / dim) 58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32) 59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \ 60 | np.random.uniform(-scale, scale, dim).astype(np.float32) 61 | weights = torch.from_numpy(weights) 62 | if use_gpu: 63 | weights = weights.cuda() 64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze) 65 | return embedding 66 | 67 | def __len__(self): 68 | return len(self.id2s) 69 | 70 | 71 | class Label: 72 | def __init__(self, name, counter, specials=None): 73 | self.name = name 74 | self.counter = counter.copy() 75 | self.label2id = {} 76 | self.id2label = [] 77 | if specials: 78 | for label in specials: 79 | del self.counter[label] 80 | self.label2id[label] = len(self.label2id) 81 | self.id2label.append(label) 82 | 83 | for label, freq in self.counter.items(): 84 | if label not in self.label2id: 85 | self.label2id[label] = len(self.label2id) 86 | self.id2label.append(label) 87 | logger.info("label %s size %d" % (name, len(self))) 88 | 89 | def __getitem__(self, item): 90 | return self.label2id[item] 91 | 92 | def __len__(self): 93 | return len(self.id2label) 94 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/structure/vocab.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import logging 4 | from collections import defaultdict 5 | from gensim.models import KeyedVectors 6 | import numpy as np 7 | import torch 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | PAD_TAG = "" 12 | UNK_TAG = "" 13 | 14 | 15 | class Vocab: 16 | def __init__(self, name, counter, min_occur=1): 17 | self.name = name 18 | self.s2id = defaultdict(int) 19 | self.s2id[UNK_TAG] = 0 20 | self.id2s = [UNK_TAG] 21 | 22 | self.counter = counter 23 | for tag, freq in counter.items(): 24 | if tag not in self.s2id and freq >= min_occur: 25 | self.s2id[tag] = len(self.s2id) 26 | self.id2s.append(tag) 27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id))) 28 | 29 | def __getitem__(self, item): 30 | return self.s2id[item] 31 | 32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False): 33 | if dim is not None and pretrained is not None: 34 | raise Warning("dim should not given if pretraiained weights are assigned") 35 | 36 | if dim is None and pretrained is None: 37 | raise Warning("one of dim or pretrained should be assigned") 38 | 39 | if pretrained: 40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary) 41 | dim = w2v.vector_size 42 | scale = np.sqrt(3.0 / dim) 43 | weights = np.empty([len(self), dim], dtype=np.float32) 44 | oov_count = 0 45 | all_count = 0 46 | for tag, i in self.s2id.items(): 47 | if tag in w2v.vocab: 48 | weights[i] = w2v[tag].astype(np.float32) 49 | else: 50 | oov_count += self.counter[tag] 51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \ 52 | np.random.uniform(-scale, scale, dim).astype(np.float32) 53 | all_count += self.counter[tag] 54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" % 55 | (self.name, oov_count, all_count, oov_count/all_count*100)) 56 | else: 57 | scale = np.sqrt(3.0 / dim) 58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32) 59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \ 60 | np.random.uniform(-scale, scale, dim).astype(np.float32) 61 | 
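# (Annotation, illustrative only.) The numpy matrix is handed to torch just
# below; torch.nn.Embedding.from_pretrained copies it into a lookup table,
# and freeze=True simply pins weight.requires_grad to False, e.g.:
#
#     import torch
#     w = torch.randn(3, 4)
#     emb = torch.nn.Embedding.from_pretrained(w, freeze=True)
#     print(emb.weight.requires_grad)   # False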
weights = torch.from_numpy(weights) 62 | if use_gpu: 63 | weights = weights.cuda() 64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze) 65 | return embedding 66 | 67 | def __len__(self): 68 | return len(self.id2s) 69 | 70 | 71 | class Label: 72 | def __init__(self, name, counter, specials=None): 73 | self.name = name 74 | self.counter = counter.copy() 75 | self.label2id = {} 76 | self.id2label = [] 77 | if specials: 78 | for label in specials: 79 | del self.counter[label] 80 | self.label2id[label] = len(self.label2id) 81 | self.id2label.append(label) 82 | 83 | for label, freq in self.counter.items(): 84 | if label not in self.label2id: 85 | self.label2id[label] = len(self.label2id) 86 | self.id2label.append(label) 87 | logger.info("label %s size %d" % (name, len(self))) 88 | 89 | def __getitem__(self, item): 90 | return self.label2id[item] 91 | 92 | def __len__(self): 93 | return len(self.id2label) 94 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/structure/vocab.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import logging 4 | from collections import defaultdict 5 | from gensim.models import KeyedVectors 6 | import numpy as np 7 | import torch 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | PAD_TAG = "" 12 | UNK_TAG = "" 13 | 14 | 15 | class Vocab: 16 | def __init__(self, name, counter, min_occur=1): 17 | self.name = name 18 | self.s2id = defaultdict(int) 19 | self.s2id[UNK_TAG] = 0 20 | self.id2s = [UNK_TAG] 21 | 22 | self.counter = counter 23 | for tag, freq in counter.items(): 24 | if tag not in self.s2id and freq >= min_occur: 25 | self.s2id[tag] = len(self.s2id) 26 | self.id2s.append(tag) 27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id))) 28 | 29 | def __getitem__(self, item): 30 | return self.s2id[item] 31 | 32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False): 33 | if dim is not None and pretrained is not None: 34 | raise Warning("dim should not given if pretraiained weights are assigned") 35 | 36 | if dim is None and pretrained is None: 37 | raise Warning("one of dim or pretrained should be assigned") 38 | 39 | if pretrained: 40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary) 41 | dim = w2v.vector_size 42 | scale = np.sqrt(3.0 / dim) 43 | weights = np.empty([len(self), dim], dtype=np.float32) 44 | oov_count = 0 45 | all_count = 0 46 | for tag, i in self.s2id.items(): 47 | if tag in w2v.vocab: 48 | weights[i] = w2v[tag].astype(np.float32) 49 | else: 50 | oov_count += self.counter[tag] 51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \ 52 | np.random.uniform(-scale, scale, dim).astype(np.float32) 53 | all_count += self.counter[tag] 54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" % 55 | (self.name, oov_count, all_count, oov_count/all_count*100)) 56 | else: 57 | scale = np.sqrt(3.0 / dim) 58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32) 59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \ 60 | np.random.uniform(-scale, scale, dim).astype(np.float32) 61 | weights = torch.from_numpy(weights) 62 | if use_gpu: 63 | weights = weights.cuda() 64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze) 65 | return embedding 66 | 67 | def __len__(self): 68 | return len(self.id2s) 69 | 70 | 71 | class Label: 72 | def __init__(self, name, counter, specials=None): 73 | 
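# Label maps relation/nuclearity tags to contiguous ids. Entries listed in
# `specials` are removed from the counter and assigned the lowest ids first;
# the remaining labels are then added in counter-iteration order. Unlike
# Vocab, there is no UNK fallback, so __getitem__ raises a KeyError for an
# unseen label.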
self.name = name 74 | self.counter = counter.copy() 75 | self.label2id = {} 76 | self.id2label = [] 77 | if specials: 78 | for label in specials: 79 | del self.counter[label] 80 | self.label2id[label] = len(self.label2id) 81 | self.id2label.append(label) 82 | 83 | for label, freq in self.counter.items(): 84 | if label not in self.label2id: 85 | self.label2id[label] = len(self.label2id) 86 | self.id2label.append(label) 87 | logger.info("label %s size %d" % (name, len(self))) 88 | 89 | def __getitem__(self, item): 90 | return self.label2id[item] 91 | 92 | def __len__(self): 93 | return len(self.id2label) 94 | -------------------------------------------------------------------------------- /en_dp_gan/model/parser.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | from config import ids2nr 4 | from structure.rst_tree import rst_tree 5 | 6 | 7 | class PartitionPtrParser: 8 | def __init__(self, model): 9 | self.model = model 10 | self.ids2nr = ids2nr 11 | 12 | def parse(self, edus, ret_session=False): 13 | session = self.model.init_session(edus) 14 | while not session.terminate(): 15 | split_score, nr_score, state = self.model(session) 16 | split = split_score.argmax() 17 | nr = self.ids2nr[nr_score[split].argmax()] 18 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:]) 19 | session = session.forward(split_score, state, split, nuclear, relation) 20 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:]) 21 | if ret_session: 22 | return tree_parsed, session 23 | else: 24 | return tree_parsed 25 | 26 | def build_rst_tree(self, edus, splits, nuclear, relations): 27 | left, split, right = splits.pop(0) 28 | nucl = nuclear.pop(0) 29 | rel = relations.pop(0) 30 | if right - split == 1: 31 | # leaf node 32 | right_node = edus[split] 33 | else: 34 | # non leaf 35 | right_node = self.build_rst_tree(edus, splits, nuclear, relations) 36 | if split - left == 1: 37 | # leaf node 38 | left_node = edus[left] 39 | else: 40 | # none leaf 41 | left_node = self.build_rst_tree(edus, splits, nuclear, relations) 42 | root = rst_tree(l_ch=left_node, r_ch=right_node, ch_ns_rel=nucl, child_rel=rel, 43 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1])) 44 | return root 45 | 46 | def traverse_tree(self, root): 47 | if root.left_child is not None: 48 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel) 49 | self.traverse_tree(root.right_child) 50 | self.traverse_tree(root.left_child) 51 | 52 | def draw_scores_matrix(self): 53 | scores = self.model.scores 54 | self.draw_decision_hot_map(scores) 55 | 56 | @staticmethod 57 | def draw_decision_hot_map(scores): 58 | import matplotlib 59 | import matplotlib.pyplot as plt 60 | text_colors = ["black", "white"] 61 | c_map = "YlGn" 62 | y_label = "split score" 63 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])] 64 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)] 65 | fig, ax = plt.subplots() 66 | im = ax.imshow(scores, cmap=c_map) 67 | c_bar = ax.figure.colorbar(im, ax=ax) 68 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom") 69 | ax.set_xticks(np.arange(scores.shape[1])) 70 | ax.set_yticks(np.arange(scores.shape[0])) 71 | ax.set_xticklabels(col_labels) 72 | ax.set_yticklabels(row_labels) 73 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) 74 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") 
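# Hide the axes frame and draw a white minor grid offset by half a cell, so
# each heatmap cell renders as a visually separate tile.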
75 | for edge, spine in ax.spines.items(): 76 | spine.set_visible(False) 77 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True) 78 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True) 79 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 80 | ax.tick_params(which="minor", bottom=False, left=False) 81 | threshold = im.norm(scores.max()) / 2. 82 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}") 83 | texts = [] 84 | kw = dict(horizontalalignment="center", verticalalignment="center") 85 | for i in range(scores.shape[0]): 86 | for j in range(scores.shape[1]): 87 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold]) 88 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw) 89 | texts.append(text) 90 | fig.tight_layout() 91 | plt.show() 92 | -------------------------------------------------------------------------------- /en_dp_gan/util/drawer.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def draw_line(x, y_s, colors, labels, shapes, begins, lines_d, l_colors, step=1, x_name="STEP", y_name="LOSS"): 5 | fig = plt.figure(figsize=(5, 3.8), facecolor="white") 6 | ax1 = fig.add_subplot(1, 1, 1) 7 | ax1.set_xlabel(x_name) 8 | ax1.set_ylabel(y_name) 9 | for idx, y in enumerate(y_s): 10 | x_ = x[begins[idx]:] 11 | y_ = y[begins[idx]:] 12 | # ax1.plot(x[:], y[:], color=colors[idx], linewidth=3, marker="o", linestyle="", label=labels[idx]) 13 | ax1.scatter(x_[::step], y_[::step], color=colors[idx], marker=shapes[idx], edgecolors=colors[idx], s=20, label=labels[idx]) 14 | if lines_d[idx]: 15 | line_general = 20 16 | x_n = x_[::line_general] 17 | y_n = [] 18 | sum_ = 0. 19 | sum_idx_num = 0 20 | for idx__, y__ in enumerate(y_): 21 | sum_idx_num += 1 22 | sum_ += y__ 23 | if idx__ > 0 and idx__ % line_general == 0: 24 | y_n.append(sum_ / line_general) 25 | sum_ = 0. 26 | sum_idx_num = 0 27 | if sum_ > 0.: 28 | y_n.append(sum_ / sum_idx_num) 29 | if idx == 0: 30 | ax1.plot(x_n[:10], y_n[:10], color="PaleGreen", linewidth=2.6, linestyle="-") 31 | ax1.plot(x_n[9:], y_n[9:], color="Green", linewidth=2.6, linestyle="-") 32 | else: 33 | ax1.plot(x_n, y_n, color=l_colors[idx], linewidth=2.6, linestyle="-") 34 | plt.vlines(188, 0, 2, colors="black", linestyles="dashed", linewidth=2) 35 | plt.annotate("warm up", xy=(100, 0.45), xytext=(60, 0.15), arrowprops=dict(facecolor="white", headlength=4, 36 | headwidth=10, width=4)) 37 | plt.legend(loc="upper left") 38 | plt.grid(linestyle='-') 39 | plt.show() 40 | 41 | 42 | def draw_all(x_y_s, colors, labels, shapes, begins, lines_d, l_colors, step=1, x_name="STEP", y_name="LOSS"): 43 | fig = plt.figure(figsize=(4, 10), facecolor="white") 44 | for idx_, x_y in enumerate(x_y_s): 45 | x_a, y_a = x_y 46 | ax1 = fig.add_subplot(3, 1, idx_ + 1) 47 | ax1.set_xlabel(x_name) 48 | ax1.set_ylabel(y_name) 49 | for idx, y in enumerate(y_a): 50 | x_ = x_a[begins[idx]:] 51 | y_ = y[begins[idx]:] 52 | # ax1.plot(x[:], y[:], color=colors[idx], linewidth=3, marker="o", linestyle="", label=labels[idx]) 53 | ax1.scatter(x_[::step], y_[::step], color=colors[idx], marker=shapes[idx], edgecolors=colors[idx], s=20, 54 | label=labels[idx]) 55 | if lines_d[idx]: 56 | line_general = 20 57 | x_n = x_[::line_general] 58 | y_n = [] 59 | sum_ = 0. 60 | sum_idx_num = 0 61 | for idx__, y__ in enumerate(y_): 62 | sum_idx_num += 1 63 | sum_ += y__ 64 | if idx__ > 0 and idx__ % line_general == 0: 65 | y_n.append(sum_ / line_general) 66 | sum_ = 0. 
67 | sum_idx_num = 0 68 | if sum_ > 0.: 69 | y_n.append(sum_ / sum_idx_num) 70 | if idx == 0: 71 | ax1.plot(x_n[:10], y_n[:10], color="PaleGreen", linewidth=2.6, linestyle="-") 72 | ax1.plot(x_n[9:], y_n[9:], color="Green", linewidth=2.6, linestyle="-") 73 | else: 74 | ax1.plot(x_n, y_n, color=l_colors[idx], linewidth=2.6, linestyle="-") 75 | plt.vlines(188, 0, 2, colors="black", linestyles="dashed", linewidth=2) 76 | plt.annotate("warm up", xy=(100, 0.45), xytext=(60, 0.15), arrowprops=dict(facecolor="white", headlength=4, 77 | headwidth=10, width=4)) 78 | plt.legend(loc="upper left") 79 | plt.grid(linestyle='-') 80 | plt.show() 81 | 82 | 83 | if __name__ == "__main__": 84 | pass 85 | -------------------------------------------------------------------------------- /en_dp_gan/util/eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from structure import EDU, Sentence, Relation 3 | import numpy as np 4 | 5 | 6 | def factorize_tree(tree, binarize=True): 7 | quads = set() # (span, nuclear, coarse relation, fine relation) 8 | 9 | def factorize(root, offset=0): 10 | if isinstance(root, EDU): 11 | return [(offset, offset+len(root.text))] 12 | elif isinstance(root, Sentence): 13 | children_spans = [] 14 | for child in root: 15 | spans = factorize(child, offset) 16 | children_spans.extend(spans) 17 | offset = spans[-1][1] 18 | return children_spans 19 | elif isinstance(root, Relation): 20 | children_spans = [] 21 | for child in root: 22 | spans = factorize(child, offset) 23 | children_spans.extend(spans) 24 | offset = spans[-1][1] 25 | if binarize: 26 | while len(children_spans) >= 2: 27 | right = children_spans.pop() 28 | left = children_spans.pop() 29 | quads.add(((left, right), root.nuclear, root.ctype, root.ftype)) 30 | children_spans.append((left[0], right[1])) 31 | else: 32 | quads.add((tuple(children_spans), root.nuclear, root.ctype, root.ftype)) 33 | return [(children_spans[0][0], children_spans[-1][1])] 34 | 35 | factorize(tree.root_relation()) 36 | return quads 37 | 38 | 39 | def f1_score(num_corr, num_gold, num_pred, treewise_average=True): 40 | num_corr = num_corr.astype(np.float) 41 | num_gold = num_gold.astype(np.float) 42 | num_pred = num_pred.astype(np.float) 43 | 44 | if treewise_average: 45 | precision = (np.nan_to_num(num_corr / num_pred)).mean() 46 | recall = (np.nan_to_num(num_corr / num_gold)).mean() 47 | else: 48 | precision = np.nan_to_num(num_corr.sum() / num_pred.sum()) 49 | recall = np.nan_to_num(num_corr.sum() / num_gold.sum()) 50 | 51 | if precision + recall == 0: 52 | f1 = 0. 53 | else: 54 | f1 = 2. 
* precision * recall / (precision + recall) 55 | return precision, recall, f1 56 | 57 | 58 | def evaluation_trees(parses, golds, binarize=True, treewise_avearge=True): 59 | num_gold = np.zeros(len(golds)) 60 | num_parse = np.zeros(len(parses)) 61 | num_corr_span = np.zeros(len(parses)) 62 | num_corr_nuc = np.zeros(len(parses)) 63 | num_corr_ctype = np.zeros(len(parses)) 64 | num_corr_ftype = np.zeros(len(parses)) 65 | 66 | for i, (parse, gold) in enumerate(zip(parses, golds)): 67 | parse_factorized = factorize_tree(parse, binarize=binarize) 68 | gold_factorized = factorize_tree(gold, binarize=binarize) 69 | num_parse[i] = len(parse_factorized) 70 | num_gold[i] = len(gold_factorized) 71 | 72 | # index quads by child spans 73 | parse_dict = {quad[0]: quad for quad in parse_factorized} 74 | gold_dict = {quad[0]: quad for quad in gold_factorized} 75 | 76 | # find correct spans 77 | gold_spans = set(gold_dict.keys()) 78 | parse_spans = set(parse_dict.keys()) 79 | corr_spans = gold_spans & parse_spans 80 | num_corr_span[i] = len(corr_spans) 81 | 82 | # quad [span, nuclear, ftype, ctype] 83 | for span in corr_spans: 84 | # count correct nuclear 85 | num_corr_nuc[i] += 1 if parse_dict[span][1] == gold_dict[span][1] else 0 86 | # count correct ctype 87 | num_corr_ctype[i] += 1 if parse_dict[span][2] == gold_dict[span][2] else 0 88 | # count correct ftype 89 | num_corr_ftype[i] += 1 if parse_dict[span][3] == gold_dict[span][3] else 0 90 | 91 | span_score = f1_score(num_corr_span, num_gold, num_parse, treewise_average=treewise_avearge) 92 | nuc_score = f1_score(num_corr_nuc, num_gold, num_parse, treewise_average=treewise_avearge) 93 | ctype_score = f1_score(num_corr_ctype, num_gold, num_parse, treewise_average=treewise_avearge) 94 | ftype_score = f1_score(num_corr_ftype, num_gold, num_parse, treewise_average=treewise_avearge) 95 | return span_score, nuc_score, ctype_score, ftype_score 96 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/model/parser.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | from config import ids2nr 4 | from structure.rst_tree import rst_tree 5 | 6 | 7 | class PartitionPtrParser: 8 | def __init__(self, model): 9 | self.model = model 10 | self.ids2nr = ids2nr 11 | 12 | def parse(self, edus, ret_session=False): 13 | # TODO implement beam search 14 | session = self.model.init_session(edus) 15 | while not session.terminate(): 16 | split_score, nr_score, state = self.model(session) 17 | split = split_score.argmax() 18 | nr = self.ids2nr[nr_score[split].argmax()] 19 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:]) 20 | session = session.forward(split_score, state, split, nuclear, relation) 21 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:]) 22 | if ret_session: 23 | return tree_parsed, session 24 | else: 25 | return tree_parsed 26 | 27 | def build_rst_tree(self, edus, splits, nuclear, relations): 28 | left, split, right = splits.pop(0) 29 | nucl = nuclear.pop(0) 30 | rel = relations.pop(0) 31 | if right - split == 1: 32 | # leaf node 33 | right_node = edus[split] 34 | else: 35 | # non leaf 36 | right_node = self.build_rst_tree(edus, splits, nuclear, relations) 37 | if split - left == 1: 38 | # leaf node 39 | left_node = edus[left] 40 | else: 41 | # none leaf 42 | left_node = self.build_rst_tree(edus, splits, nuclear, relations) 43 | root = rst_tree(l_ch=left_node, r_ch=right_node, 
ch_ns_rel=nucl, child_rel=rel, 44 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1])) 45 | return root 46 | 47 | def traverse_tree(self, root): 48 | if root.left_child is not None: 49 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel) 50 | self.traverse_tree(root.right_child) 51 | self.traverse_tree(root.left_child) 52 | 53 | def draw_scores_matrix(self): 54 | scores = self.model.scores 55 | self.draw_decision_hot_map(scores) 56 | 57 | @staticmethod 58 | def draw_decision_hot_map(scores): 59 | import matplotlib 60 | import matplotlib.pyplot as plt 61 | text_colors = ["black", "white"] 62 | c_map = "YlGn" 63 | y_label = "split score" 64 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])] 65 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)] 66 | fig, ax = plt.subplots() 67 | im = ax.imshow(scores, cmap=c_map) 68 | c_bar = ax.figure.colorbar(im, ax=ax) 69 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom") 70 | ax.set_xticks(np.arange(scores.shape[1])) 71 | ax.set_yticks(np.arange(scores.shape[0])) 72 | ax.set_xticklabels(col_labels) 73 | ax.set_yticklabels(row_labels) 74 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) 75 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") 76 | for edge, spine in ax.spines.items(): 77 | spine.set_visible(False) 78 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True) 79 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True) 80 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 81 | ax.tick_params(which="minor", bottom=False, left=False) 82 | threshold = im.norm(scores.max()) / 2. 83 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}") 84 | texts = [] 85 | kw = dict(horizontalalignment="center", verticalalignment="center") 86 | for i in range(scores.shape[0]): 87 | for j in range(scores.shape[1]): 88 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold]) 89 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw) 90 | texts.append(text) 91 | fig.tight_layout() 92 | plt.show() 93 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/util/eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from structure import EDU, Sentence, Relation 3 | import numpy as np 4 | 5 | 6 | def factorize_tree(tree, binarize=True): 7 | quads = set() # (span, nuclear, coarse relation, fine relation) 8 | 9 | def factorize(root, offset=0): 10 | if isinstance(root, EDU): 11 | return [(offset, offset+len(root.text))] 12 | elif isinstance(root, Sentence): 13 | children_spans = [] 14 | for child in root: 15 | spans = factorize(child, offset) 16 | children_spans.extend(spans) 17 | offset = spans[-1][1] 18 | return children_spans 19 | elif isinstance(root, Relation): 20 | children_spans = [] 21 | for child in root: 22 | spans = factorize(child, offset) 23 | children_spans.extend(spans) 24 | offset = spans[-1][1] 25 | if binarize: 26 | while len(children_spans) >= 2: 27 | right = children_spans.pop() 28 | left = children_spans.pop() 29 | quads.add(((left, right), root.nuclear, root.ctype, root.ftype)) 30 | children_spans.append((left[0], right[1])) 31 | else: 32 | quads.add((tuple(children_spans), root.nuclear, root.ctype, root.ftype)) 33 | return [(children_spans[0][0], children_spans[-1][1])] 34 | 35 | factorize(tree.root_relation()) 36 | return quads 37 | 38 | 39 | def 
f1_score(num_corr, num_gold, num_pred, treewise_average=True): 40 | num_corr = num_corr.astype(np.float) 41 | num_gold = num_gold.astype(np.float) 42 | num_pred = num_pred.astype(np.float) 43 | 44 | if treewise_average: 45 | precision = (np.nan_to_num(num_corr / num_pred)).mean() 46 | recall = (np.nan_to_num(num_corr / num_gold)).mean() 47 | else: 48 | precision = np.nan_to_num(num_corr.sum() / num_pred.sum()) 49 | recall = np.nan_to_num(num_corr.sum() / num_gold.sum()) 50 | 51 | if precision + recall == 0: 52 | f1 = 0. 53 | else: 54 | f1 = 2. * precision * recall / (precision + recall) 55 | return precision, recall, f1 56 | 57 | 58 | def evaluation_trees(parses, golds, binarize=True, treewise_avearge=True): 59 | num_gold = np.zeros(len(golds)) 60 | num_parse = np.zeros(len(parses)) 61 | num_corr_span = np.zeros(len(parses)) 62 | num_corr_nuc = np.zeros(len(parses)) 63 | num_corr_ctype = np.zeros(len(parses)) 64 | num_corr_ftype = np.zeros(len(parses)) 65 | 66 | for i, (parse, gold) in enumerate(zip(parses, golds)): 67 | parse_factorized = factorize_tree(parse, binarize=binarize) 68 | gold_factorized = factorize_tree(gold, binarize=binarize) 69 | num_parse[i] = len(parse_factorized) 70 | num_gold[i] = len(gold_factorized) 71 | 72 | # index quads by child spans 73 | parse_dict = {quad[0]: quad for quad in parse_factorized} 74 | gold_dict = {quad[0]: quad for quad in gold_factorized} 75 | 76 | # find correct spans 77 | gold_spans = set(gold_dict.keys()) 78 | parse_spans = set(parse_dict.keys()) 79 | corr_spans = gold_spans & parse_spans 80 | num_corr_span[i] = len(corr_spans) 81 | 82 | # quad [span, nuclear, ftype, ctype] 83 | for span in corr_spans: 84 | # count correct nuclear 85 | num_corr_nuc[i] += 1 if parse_dict[span][1] == gold_dict[span][1] else 0 86 | # count correct ctype 87 | num_corr_ctype[i] += 1 if parse_dict[span][2] == gold_dict[span][2] else 0 88 | # count correct ftype 89 | num_corr_ftype[i] += 1 if parse_dict[span][3] == gold_dict[span][3] else 0 90 | 91 | span_score = f1_score(num_corr_span, num_gold, num_parse, treewise_average=treewise_avearge) 92 | nuc_score = f1_score(num_corr_nuc, num_gold, num_parse, treewise_average=treewise_avearge) 93 | ctype_score = f1_score(num_corr_ctype, num_gold, num_parse, treewise_average=treewise_avearge) 94 | ftype_score = f1_score(num_corr_ftype, num_gold, num_parse, treewise_average=treewise_avearge) 95 | return span_score, nuc_score, ctype_score, ftype_score 96 | -------------------------------------------------------------------------------- /en_dp_gan/model/stacked_parser_tdt_gan3/parser.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | import torch 4 | from config import ids2nr, SEED 5 | from structure.rst_tree import rst_tree 6 | import random 7 | random.seed(SEED) 8 | torch.manual_seed(SEED) 9 | np.random.seed(SEED) 10 | 11 | 12 | class PartitionPtrParser: 13 | def __init__(self, model): 14 | self.model = model 15 | self.ids2nr = ids2nr 16 | 17 | def parse(self, edus, ret_session=False): 18 | # TODO implement beam search 19 | session = self.model.init_session(edus) 20 | d_masks, splits = None, [] 21 | while not session.terminate(): 22 | split_score, nr_score, state, d_mask = self.model(session) 23 | d_masks = d_mask if d_masks is None else torch.cat((d_masks, d_mask), 1) 24 | split = split_score.argmax() 25 | nr = self.ids2nr[nr_score[split].argmax()] 26 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:]) 27 | session = 
session.forward(split_score, state, split, nuclear, relation) 28 | 29 | # build tree by splits (left, split, right) 30 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:]) 31 | if ret_session: 32 | return tree_parsed, session 33 | else: 34 | return tree_parsed 35 | 36 | def build_rst_tree(self, edus, splits, nuclear, relations, type_="Root", rel_=None): 37 | left, split, right = splits.pop(0) 38 | nucl = nuclear.pop(0) 39 | rel = relations.pop(0) 40 | left_n, right_n = nucl[0], nucl[1] 41 | left_rel = rel if left_n == "N" else "span" 42 | right_rel = rel if right_n == "N" else "span" 43 | if right - split == 0: 44 | # leaf node 45 | right_node = edus[split + 1] 46 | else: 47 | # non leaf 48 | right_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=right_n, rel_=right_rel) 49 | if split - left == 0: 50 | # leaf node 51 | left_node = edus[split] 52 | else: 53 | # none leaf 54 | left_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=left_n, rel_=left_rel) 55 | node_height = max(left_node.node_height, right_node.node_height) + 1 56 | root = rst_tree(l_ch=left_node, r_ch=right_node, ch_ns_rel=nucl, child_rel=rel, 57 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1]), 58 | node_height=node_height, type_=type_, rel=rel_) 59 | return root 60 | 61 | def traverse_tree(self, root): 62 | if root.left_child is not None: 63 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel) 64 | self.traverse_tree(root.right_child) 65 | self.traverse_tree(root.left_child) 66 | 67 | def draw_scores_matrix(self): 68 | scores = self.model.scores 69 | self.draw_decision_hot_map(scores) 70 | 71 | @staticmethod 72 | def draw_decision_hot_map(scores): 73 | import matplotlib 74 | import matplotlib.pyplot as plt 75 | text_colors = ["black", "white"] 76 | c_map = "YlGn" 77 | y_label = "split score" 78 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])] 79 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)] 80 | fig, ax = plt.subplots() 81 | im = ax.imshow(scores, cmap=c_map) 82 | c_bar = ax.figure.colorbar(im, ax=ax) 83 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom") 84 | ax.set_xticks(np.arange(scores.shape[1])) 85 | ax.set_yticks(np.arange(scores.shape[0])) 86 | ax.set_xticklabels(col_labels) 87 | ax.set_yticklabels(row_labels) 88 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) 89 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") 90 | for edge, spine in ax.spines.items(): 91 | spine.set_visible(False) 92 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True) 93 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True) 94 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 95 | ax.tick_params(which="minor", bottom=False, left=False) 96 | threshold = im.norm(scores.max()) / 2. 
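# Cells whose normalized value exceeds half of the maximum are drawn in dark
# green by the "YlGn" colormap, so they receive white annotation text; lighter
# cells receive black text (see the text_colors lookup below).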
97 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}") 98 | texts = [] 99 | kw = dict(horizontalalignment="center", verticalalignment="center") 100 | for i in range(scores.shape[0]): 101 | for j in range(scores.shape[1]): 102 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold]) 103 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw) 104 | texts.append(text) 105 | fig.tight_layout() 106 | plt.show() 107 | -------------------------------------------------------------------------------- /ch_dp_gan/berkeleyparser/README: -------------------------------------------------------------------------------- 1 | "THE BERKELEY PARSER" 2 | release 1.1 3 | migrated from Google Code to GitHub July 2015 4 | 5 | This package contains the Berkeley Parser as described in 6 | 7 | "Learning Accurate, Compact, and Interpretable Tree Annotation" 8 | Slav Petrov, Leon Barrett, Romain Thibaux and Dan Klein 9 | in COLING-ACL 2006 10 | 11 | and 12 | 13 | "Improved Inference for Unlexicalized Parsing" 14 | Slav Petrov and Dan Klein 15 | in HLT-NAACL 2007 16 | 17 | If you use this code in your research and would like to acknowledge it, please refer to one of those publications. Note that the jar-archive also contains all source files. For questions please contact Slav Petrov (petrov@cs.berkeley.edu). 18 | 19 | 20 | * PARSING 21 | The main class of the jar-archive is the parser. By default, it will read in PTB tokenized sentences from STDIN (one per line) and write parse trees to STDOUT. It can be evoked with: 22 | 23 | java -jar berkeleyParser.jar -gr 24 | 25 | The parser can produce k-best lists and parse in parallel using multiple threads. Several additional options are also available (return binarized and/or annotated trees, produce an image of the parse tree, tokenize the input, run in fast/accurate mode, print out tree likelihoods, etc.). Starting the parser without supplying a grammar file will print a list of all options. 26 | 27 | * ADDITIONAL TOOLS 28 | A tool for annotating parse trees with their most likely Viterbi derivation over refined categories and scoring the subtrees can be started with 29 | 30 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA/TreeLabeler -gr 31 | 32 | This tool reads in parse trees from STDIN, annotates them as specified and prints them out to STDOUT. You can use 33 | 34 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.TreeScorer -gr 35 | 36 | to compute the (log-)likelihood of a parse tree. 37 | 38 | 39 | * GRAMMARS 40 | Included are grammars for English, German and Chinese. For parsing English text which is not from the Wall Street Journal, we recommend that you use the English grammar after 5 split&merge iterations as experiments suggest that the 6 split&merge iterations grammars are overfitting the Wall Street Journal. Because of the coarse-to-fine method used by the parser, there is essentially no difference in parsing time between the different grammars. 41 | 42 | 43 | * LEARNING NEW GRAMMARS 44 | You will need a treebank in order to learn new grammars. The package contains code for reading in some of the standard treebanks. 
To learn a grammar from the Wall Street Journal section of the Penn Treebank, you can execute 45 | 46 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path -out 47 | 48 | To learn a grammar from trees that are contained in a single file use the -treebank option, e.g.: 49 | 50 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path -out -treebank SINGLEFILE 51 | 52 | This will read in the WSJ training set and do 6 iterations of split, merge, smooth. An intermediate grammar file will be written to disk once in a while and you can expect the final grammar to be written to after 15-20 hours. The GrammarTrainer accepts a variety of options which have been set to reasonable default values. Most of the options should be self-explaining and you are encouraged to experiment with them. Note that since EM is a local method each run will produce slightly different results. Furthermore, the default settings prune away rules with probability below a certain threshold, which greatly speeds up the training, but increases the variance. To train grammars on other training sets (e.g. for other languages), consult edu.berkeley.nlp.PCFGLA.Corpus.java and supply the correct language option to the trainer. 53 | To the test the performance of a grammar you can use 54 | 55 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTester -path -in 56 | 57 | 58 | * WRITING GRAMMARS TO TEXT FILES 59 | The parser reads and writes grammar files as serialized java classes. To view the grammars, you can export them to text format with: 60 | 61 | java -cp berkeleyParser.jar edu/berkeley/nlp/PCFGLA/WriteGrammarToTextFile 62 | 63 | This will create three text files. outname.grammar and outname.lexicon contain the respective rule scores and outname.words should be used with the included perl script to map words to their signatures. 64 | 65 | * UNKNOWN WORDS 66 | The lexicon contains arrays with scores for each (tag,signature) pair. The array entries correspond to the scores for the respective tag substates. The signatures of known words are the words themselves. Unknown words, in contrast, are classified into a set of unknown word categories. The perl script "getSignature" takes as first argument a file containing the known words (presumably produced by WriteGrammarToFile). It then reads words from STDIN and returns the word signature to STDOUT. The signatures should be used to look up the tagging probabilities of words in the lexicon file. 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/berkeleyparser/README: -------------------------------------------------------------------------------- 1 | "THE BERKELEY PARSER" 2 | release 1.1 3 | migrated from Google Code to GitHub July 2015 4 | 5 | This package contains the Berkeley Parser as described in 6 | 7 | "Learning Accurate, Compact, and Interpretable Tree Annotation" 8 | Slav Petrov, Leon Barrett, Romain Thibaux and Dan Klein 9 | in COLING-ACL 2006 10 | 11 | and 12 | 13 | "Improved Inference for Unlexicalized Parsing" 14 | Slav Petrov and Dan Klein 15 | in HLT-NAACL 2007 16 | 17 | If you use this code in your research and would like to acknowledge it, please refer to one of those publications. Note that the jar-archive also contains all source files. For questions please contact Slav Petrov (petrov@cs.berkeley.edu). 18 | 19 | 20 | * PARSING 21 | The main class of the jar-archive is the parser. 
By default, it will read in PTB tokenized sentences from STDIN (one per line) and write parse trees to STDOUT. It can be evoked with: 22 | 23 | java -jar berkeleyParser.jar -gr 24 | 25 | The parser can produce k-best lists and parse in parallel using multiple threads. Several additional options are also available (return binarized and/or annotated trees, produce an image of the parse tree, tokenize the input, run in fast/accurate mode, print out tree likelihoods, etc.). Starting the parser without supplying a grammar file will print a list of all options. 26 | 27 | * ADDITIONAL TOOLS 28 | A tool for annotating parse trees with their most likely Viterbi derivation over refined categories and scoring the subtrees can be started with 29 | 30 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA/TreeLabeler -gr 31 | 32 | This tool reads in parse trees from STDIN, annotates them as specified and prints them out to STDOUT. You can use 33 | 34 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.TreeScorer -gr 35 | 36 | to compute the (log-)likelihood of a parse tree. 37 | 38 | 39 | * GRAMMARS 40 | Included are grammars for English, German and Chinese. For parsing English text which is not from the Wall Street Journal, we recommend that you use the English grammar after 5 split&merge iterations as experiments suggest that the 6 split&merge iterations grammars are overfitting the Wall Street Journal. Because of the coarse-to-fine method used by the parser, there is essentially no difference in parsing time between the different grammars. 41 | 42 | 43 | * LEARNING NEW GRAMMARS 44 | You will need a treebank in order to learn new grammars. The package contains code for reading in some of the standard treebanks. To learn a grammar from the Wall Street Journal section of the Penn Treebank, you can execute 45 | 46 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path -out 47 | 48 | To learn a grammar from trees that are contained in a single file use the -treebank option, e.g.: 49 | 50 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path -out -treebank SINGLEFILE 51 | 52 | This will read in the WSJ training set and do 6 iterations of split, merge, smooth. An intermediate grammar file will be written to disk once in a while and you can expect the final grammar to be written to after 15-20 hours. The GrammarTrainer accepts a variety of options which have been set to reasonable default values. Most of the options should be self-explaining and you are encouraged to experiment with them. Note that since EM is a local method each run will produce slightly different results. Furthermore, the default settings prune away rules with probability below a certain threshold, which greatly speeds up the training, but increases the variance. To train grammars on other training sets (e.g. for other languages), consult edu.berkeley.nlp.PCFGLA.Corpus.java and supply the correct language option to the trainer. 53 | To the test the performance of a grammar you can use 54 | 55 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTester -path -in 56 | 57 | 58 | * WRITING GRAMMARS TO TEXT FILES 59 | The parser reads and writes grammar files as serialized java classes. To view the grammars, you can export them to text format with: 60 | 61 | java -cp berkeleyParser.jar edu/berkeley/nlp/PCFGLA/WriteGrammarToTextFile 62 | 63 | This will create three text files. 
outname.grammar and outname.lexicon contain the respective rule scores and outname.words should be used with the included perl script to map words to their signatures. 64 | 65 | * UNKNOWN WORDS 66 | The lexicon contains arrays with scores for each (tag,signature) pair. The array entries correspond to the scores for the respective tag substates. The signatures of known words are the words themselves. Unknown words, in contrast, are classified into a set of unknown word categories. The perl script "getSignature" takes as first argument a file containing the known words (presumably produced by WriteGrammarToFile). It then reads words from STDIN and returns the word signature to STDOUT. The signatures should be used to look up the tagging probabilities of words in the lexicon file. 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/model/stacked_parser_tdt_xlnet/parser.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import numpy as np 3 | import torch 4 | from config import ids2nr, SEED 5 | from structure.rst_tree import rst_tree 6 | import random 7 | random.seed(SEED) 8 | torch.manual_seed(SEED) 9 | np.random.seed(SEED) 10 | 11 | 12 | class PartitionPtrParser: 13 | def __init__(self, model): 14 | self.ids2nr = ids2nr 15 | 16 | def parse(self, edus, ret_session=False, model=None, model_xl=None, tokenizer_xl=None): 17 | # TODO implement beam search 18 | session = model.init_session(edus, model_xl, tokenizer_xl) # 初始化状态 19 | d_masks, splits = None, [] 20 | while not session.terminate(): 21 | split_score, nr_score, state, d_mask = model.parse_predict(session) 22 | d_masks = d_mask if d_masks is None else torch.cat((d_masks, d_mask), 1) 23 | split = split_score.argmax() 24 | nr = self.ids2nr[nr_score[split].argmax()] 25 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:]) 26 | session = session.forward(split_score, state, split, nuclear, relation) 27 | # build tree by splits (left, split, right) 28 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:]) 29 | if ret_session: 30 | return tree_parsed, session 31 | else: 32 | return tree_parsed 33 | 34 | def build_rst_tree(self, edus, splits, nuclear, relations, type_="Root", rel_=None): 35 | left, split, right = splits.pop(0) 36 | nucl = nuclear.pop(0) 37 | rel = relations.pop(0) 38 | left_n, right_n = nucl[0], nucl[1] 39 | left_rel = rel if left_n == "N" else "span" 40 | right_rel = rel if right_n == "N" else "span" 41 | if right - split == 0: 42 | # leaf node 43 | right_node = edus[split + 1] 44 | else: 45 | # non leaf 46 | right_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=right_n, rel_=right_rel) 47 | if split - left == 0: 48 | # leaf node 49 | left_node = edus[split] 50 | else: 51 | # none leaf 52 | left_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=left_n, rel_=left_rel) 53 | node_height = max(left_node.node_height, right_node.node_height) + 1 54 | root = rst_tree(l_ch=left_node, r_ch=right_node, ch_ns_rel=nucl, child_rel=rel, 55 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1]), 56 | node_height=node_height, type_=type_, rel=rel_) 57 | return root 58 | 59 | def traverse_tree(self, root): 60 | if root.left_child is not None: 61 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel) 62 | self.traverse_tree(root.right_child) 63 | self.traverse_tree(root.left_child) 
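# Note: the `model` argument received by __init__ above is not stored; this
# XLNet variant instead threads `model`, `model_xl`, and `tokenizer_xl`
# through parse(). Decoding is greedy: at each step the argmax split is taken
# and then the argmax nuclearity-relation given that split (the TODO above
# marks beam search as unimplemented). build_rst_tree then consumes the
# recorded (left, split, right) triples front-to-back, expanding the right
# subtree before the left one to mirror the top-down split order.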
64 | 65 | def draw_scores_matrix(self, model): 66 | scores = model.scores 67 | self.draw_decision_hot_map(scores) 68 | 69 | @staticmethod 70 | def draw_decision_hot_map(scores): 71 | import matplotlib 72 | import matplotlib.pyplot as plt 73 | text_colors = ["black", "white"] 74 | c_map = "YlGn" 75 | y_label = "split score" 76 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])] 77 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)] 78 | fig, ax = plt.subplots() 79 | im = ax.imshow(scores, cmap=c_map) 80 | c_bar = ax.figure.colorbar(im, ax=ax) 81 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom") 82 | ax.set_xticks(np.arange(scores.shape[1])) 83 | ax.set_yticks(np.arange(scores.shape[0])) 84 | ax.set_xticklabels(col_labels) 85 | ax.set_yticklabels(row_labels) 86 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) 87 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") 88 | for edge, spine in ax.spines.items(): 89 | spine.set_visible(False) 90 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True) 91 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True) 92 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 93 | ax.tick_params(which="minor", bottom=False, left=False) 94 | threshold = im.norm(scores.max()) / 2. 95 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}") 96 | texts = [] 97 | kw = dict(horizontalalignment="center", verticalalignment="center") 98 | for i in range(scores.shape[0]): 99 | for j in range(scores.shape[1]): 100 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold]) 101 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw) 102 | texts.append(text) 103 | fig.tight_layout() 104 | plt.show() 105 | -------------------------------------------------------------------------------- /ch_dp_gan/dataset/cdtb.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import os 4 | import re 5 | import pickle 6 | import gzip 7 | import hashlib 8 | import thulac 9 | import tqdm 10 | import logging 11 | from itertools import chain 12 | from nltk.tree import Tree as ParseTree 13 | from structure import Discourse, Sentence, EDU, TEXT, Connective, node_type_filter 14 | 15 | 16 | # note: this module will print a junk line "Model loaded succeed" to stdio when initializing 17 | thulac = thulac.thulac() 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class CDTB: 22 | def __init__(self, cdtb_dir, train, validate, test, encoding="UTF-8", 23 | ctb_dir=None, ctb_encoding="UTF-8", cache_dir=None, preprocess=False): 24 | if cache_dir: 25 | hash_figure = "-".join(map(str, [cdtb_dir, train, validate, test, ctb_dir, preprocess])) 26 | hash_key = hashlib.md5(hash_figure.encode()).hexdigest() 27 | cache = os.path.join(cache_dir, hash_key + ".gz") 28 | else: 29 | cache = None 30 | 31 | if cache is not None and os.path.isfile(cache): 32 | with gzip.open(cache, "rb") as cache_fd: 33 | self.preprocess, self.ctb = pickle.load(cache_fd) 34 | self.train = pickle.load(cache_fd) 35 | self.validate = pickle.load(cache_fd) 36 | self.test = pickle.load(cache_fd) 37 | logger.info("load cached dataset from %s" % cache) 38 | return 39 | 40 | self.preprocess = preprocess 41 | self.ctb = self.load_ctb(ctb_dir, ctb_encoding) if ctb_dir else {} 42 | self.train = self.load_dir(cdtb_dir, train, encoding=encoding) 43 | self.validate = self.load_dir(cdtb_dir, validate, encoding=encoding) 44 | self.test = 
self.load_dir(cdtb_dir, test, encoding=encoding) 45 | 46 | if preprocess: 47 | for discourse in tqdm.tqdm(chain(self.train, self.validate, self.test), desc="preprocessing"): 48 | self.preprocessing(discourse) 49 | 50 | if cache is not None: 51 | with gzip.open(cache, "wb") as cache_fd: 52 | pickle.dump((self.preprocess, self.ctb), cache_fd) 53 | pickle.dump(self.train, cache_fd) 54 | pickle.dump(self.validate, cache_fd) 55 | pickle.dump(self.test, cache_fd) 56 | logger.info("saved cached dataset to %s" % cache) 57 | 58 | def report(self): 59 | # TODO 60 | raise NotImplementedError() 61 | 62 | def preprocessing(self, discourse): 63 | for paragraph in discourse: 64 | for sentence in paragraph.iterfind(filter=node_type_filter(Sentence)): 65 | if self.ctb and (sentence.sid is not None) and (sentence.sid in self.ctb): 66 | parse = self.ctb[sentence.sid] 67 | pairs = [(node[0], node.label()) for node in parse.subtrees() 68 | if node.height() == 2 and node.label() != "-NONE-"] 69 | words, tags = list(zip(*pairs)) 70 | else: 71 | words, tags = list(zip(*thulac.cut(sentence.text))) 72 | setattr(sentence, "words", list(words)) 73 | setattr(sentence, "tags", list(tags)) 74 | 75 | offset = 0 76 | for textnode in sentence.iterfind(filter=node_type_filter([TEXT, Connective, EDU]), 77 | terminal=node_type_filter([TEXT, Connective, EDU])): 78 | if isinstance(textnode, EDU): 79 | edu_words = [] 80 | edu_tags = [] 81 | cur = 0 82 | for word, tag in zip(sentence.words, sentence.tags): 83 | if offset <= cur < cur + len(word) <= offset + len(textnode.text): 84 | edu_words.append(word) 85 | edu_tags.append(tag) 86 | cur += len(word) 87 | setattr(textnode, "words", edu_words) 88 | setattr(textnode, "tags", edu_tags) 89 | offset += len(textnode.text) 90 | return discourse 91 | 92 | @staticmethod 93 | def load_dir(path, sub, encoding="UTF-8"): 94 | train_path = os.path.join(path, sub) 95 | discourses = [] 96 | for file in os.listdir(train_path): 97 | file = os.path.join(train_path, file) 98 | discourse = Discourse.from_xml(file, encoding=encoding) 99 | discourses.append(discourse) 100 | return discourses 101 | 102 | @staticmethod 103 | def load_ctb(ctb_dir, encoding="UTF-8"): 104 | ctb = {} 105 | s_pat = re.compile("\S+?)>(?P.*?)", re.M | re.DOTALL) 106 | for file in os.listdir(ctb_dir): 107 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd: 108 | doc = fd.read() 109 | for match in s_pat.finditer(doc): 110 | sid = match.group("sid") 111 | sparse = ParseTree.fromstring(match.group("sparse")) 112 | ctb[sid] = sparse 113 | return ctb 114 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/dataset/cdtb.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import os 4 | import re 5 | import pickle 6 | import gzip 7 | import hashlib 8 | import thulac 9 | import tqdm 10 | import logging 11 | from itertools import chain 12 | from nltk.tree import Tree as ParseTree 13 | from structure import Discourse, Sentence, EDU, TEXT, Connective, node_type_filter 14 | 15 | 16 | # note: this module will print a junk line "Model loaded succeed" to stdio when initializing 17 | thulac = thulac.thulac() 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class CDTB: 22 | def __init__(self, cdtb_dir, train, validate, test, encoding="UTF-8", 23 | ctb_dir=None, ctb_encoding="UTF-8", cache_dir=None, preprocess=False): 24 | if cache_dir: 25 | hash_figure = "-".join(map(str, [cdtb_dir, train, validate, test, 
ctb_dir, preprocess])) 26 | hash_key = hashlib.md5(hash_figure.encode()).hexdigest() 27 | cache = os.path.join(cache_dir, hash_key + ".gz") 28 | else: 29 | cache = None 30 | 31 | if cache is not None and os.path.isfile(cache): 32 | with gzip.open(cache, "rb") as cache_fd: 33 | self.preprocess, self.ctb = pickle.load(cache_fd) 34 | self.train = pickle.load(cache_fd) 35 | self.validate = pickle.load(cache_fd) 36 | self.test = pickle.load(cache_fd) 37 | logger.info("load cached dataset from %s" % cache) 38 | return 39 | 40 | self.preprocess = preprocess 41 | self.ctb = self.load_ctb(ctb_dir, ctb_encoding) if ctb_dir else {} 42 | self.train = self.load_dir(cdtb_dir, train, encoding=encoding) 43 | self.validate = self.load_dir(cdtb_dir, validate, encoding=encoding) 44 | self.test = self.load_dir(cdtb_dir, test, encoding=encoding) 45 | 46 | if preprocess: 47 | for discourse in tqdm.tqdm(chain(self.train, self.validate, self.test), desc="preprocessing"): 48 | self.preprocessing(discourse) 49 | 50 | if cache is not None: 51 | with gzip.open(cache, "wb") as cache_fd: 52 | pickle.dump((self.preprocess, self.ctb), cache_fd) 53 | pickle.dump(self.train, cache_fd) 54 | pickle.dump(self.validate, cache_fd) 55 | pickle.dump(self.test, cache_fd) 56 | logger.info("saved cached dataset to %s" % cache) 57 | 58 | def report(self): 59 | # TODO 60 | raise NotImplementedError() 61 | 62 | def preprocessing(self, discourse): 63 | for paragraph in discourse: 64 | for sentence in paragraph.iterfind(filter=node_type_filter(Sentence)): 65 | if self.ctb and (sentence.sid is not None) and (sentence.sid in self.ctb): 66 | parse = self.ctb[sentence.sid] 67 | pairs = [(node[0], node.label()) for node in parse.subtrees() 68 | if node.height() == 2 and node.label() != "-NONE-"] 69 | words, tags = list(zip(*pairs)) 70 | else: 71 | words, tags = list(zip(*thulac.cut(sentence.text))) 72 | setattr(sentence, "words", list(words)) 73 | setattr(sentence, "tags", list(tags)) 74 | 75 | offset = 0 76 | for textnode in sentence.iterfind(filter=node_type_filter([TEXT, Connective, EDU]), 77 | terminal=node_type_filter([TEXT, Connective, EDU])): 78 | if isinstance(textnode, EDU): 79 | edu_words = [] 80 | edu_tags = [] 81 | cur = 0 82 | for word, tag in zip(sentence.words, sentence.tags): 83 | if offset <= cur < cur + len(word) <= offset + len(textnode.text): 84 | edu_words.append(word) 85 | edu_tags.append(tag) 86 | cur += len(word) 87 | setattr(textnode, "words", edu_words) 88 | setattr(textnode, "tags", edu_tags) 89 | offset += len(textnode.text) 90 | return discourse 91 | 92 | @staticmethod 93 | def load_dir(path, sub, encoding="UTF-8"): 94 | train_path = os.path.join(path, sub) 95 | discourses = [] 96 | for file in os.listdir(train_path): 97 | file = os.path.join(train_path, file) 98 | discourse = Discourse.from_xml(file, encoding=encoding) 99 | discourses.append(discourse) 100 | return discourses 101 | 102 | @staticmethod 103 | def load_ctb(ctb_dir, encoding="UTF-8"): 104 | ctb = {} 105 | s_pat = re.compile("\S+?)>(?P.*?)", re.M | re.DOTALL) 106 | for file in os.listdir(ctb_dir): 107 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd: 108 | doc = fd.read() 109 | for match in s_pat.finditer(doc): 110 | sid = match.group("sid") 111 | sparse = ParseTree.fromstring(match.group("sparse")) 112 | ctb[sid] = sparse 113 | return ctb 114 | -------------------------------------------------------------------------------- /en_dp_gan_xlnet/structure/cdtb.py: 
-------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | 3 | import os 4 | import re 5 | import pickle 6 | import gzip 7 | import hashlib 8 | import thulac 9 | import tqdm 10 | import logging 11 | from itertools import chain 12 | from nltk.tree import Tree as ParseTree 13 | from structure import Discourse, Sentence, EDU, TEXT, Connective, node_type_filter 14 | 15 | 16 | # note: this module will print a junk line "Model loaded succeed" to stdio when initializing 17 | thulac = thulac.thulac() 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class CDTB: 22 | def __init__(self, cdtb_dir, train, validate, test, encoding="UTF-8", 23 | ctb_dir=None, ctb_encoding="UTF-8", cache_dir=None, preprocess=False): 24 | if cache_dir: 25 | hash_figure = "-".join(map(str, [cdtb_dir, train, validate, test, ctb_dir, preprocess])) 26 | hash_key = hashlib.md5(hash_figure.encode()).hexdigest() 27 | cache = os.path.join(cache_dir, hash_key + ".gz") 28 | else: 29 | cache = None 30 | 31 | if cache is not None and os.path.isfile(cache): 32 | with gzip.open(cache, "rb") as cache_fd: 33 | self.preprocess, self.ctb = pickle.load(cache_fd) 34 | self.train = pickle.load(cache_fd) 35 | self.validate = pickle.load(cache_fd) 36 | self.test = pickle.load(cache_fd) 37 | logger.info("load cached dataset from %s" % cache) 38 | return 39 | 40 | self.preprocess = preprocess 41 | self.ctb = self.load_ctb(ctb_dir, ctb_encoding) if ctb_dir else {} 42 | self.train = self.load_dir(cdtb_dir, train, encoding=encoding) 43 | self.validate = self.load_dir(cdtb_dir, validate, encoding=encoding) 44 | self.test = self.load_dir(cdtb_dir, test, encoding=encoding) 45 | 46 | if preprocess: 47 | for discourse in tqdm.tqdm(chain(self.train, self.validate, self.test), desc="preprocessing"): 48 | self.preprocessing(discourse) 49 | 50 | if cache is not None: 51 | with gzip.open(cache, "wb") as cache_fd: 52 | pickle.dump((self.preprocess, self.ctb), cache_fd) 53 | pickle.dump(self.train, cache_fd) 54 | pickle.dump(self.validate, cache_fd) 55 | pickle.dump(self.test, cache_fd) 56 | logger.info("saved cached dataset to %s" % cache) 57 | 58 | def report(self): 59 | # TODO 60 | raise NotImplementedError() 61 | 62 | def preprocessing(self, discourse): 63 | for paragraph in discourse: 64 | for sentence in paragraph.iterfind(filter=node_type_filter(Sentence)): 65 | if self.ctb and (sentence.sid is not None) and (sentence.sid in self.ctb): 66 | parse = self.ctb[sentence.sid] 67 | pairs = [(node[0], node.label()) for node in parse.subtrees() 68 | if node.height() == 2 and node.label() != "-NONE-"] 69 | words, tags = list(zip(*pairs)) 70 | else: 71 | words, tags = list(zip(*thulac.cut(sentence.text))) 72 | setattr(sentence, "words", list(words)) 73 | setattr(sentence, "tags", list(tags)) 74 | 75 | offset = 0 76 | for textnode in sentence.iterfind(filter=node_type_filter([TEXT, Connective, EDU]), 77 | terminal=node_type_filter([TEXT, Connective, EDU])): 78 | if isinstance(textnode, EDU): 79 | edu_words = [] 80 | edu_tags = [] 81 | cur = 0 82 | for word, tag in zip(sentence.words, sentence.tags): 83 | if offset <= cur < cur + len(word) <= offset + len(textnode.text): 84 | edu_words.append(word) 85 | edu_tags.append(tag) 86 | cur += len(word) 87 | setattr(textnode, "words", edu_words) 88 | setattr(textnode, "tags", edu_tags) 89 | offset += len(textnode.text) 90 | return discourse 91 | 92 | @staticmethod 93 | def load_dir(path, sub, encoding="UTF-8"): 94 | train_path = os.path.join(path, sub) 95 | discourses = 
[] 96 | for file in os.listdir(train_path): 97 | file = os.path.join(train_path, file) 98 | discourse = Discourse.from_xml(file, encoding=encoding) 99 | discourses.append(discourse) 100 | return discourses 101 | 102 | @staticmethod 103 | def load_ctb(ctb_dir, encoding="UTF-8"): 104 | ctb = {} 105 | s_pat = re.compile("\S+?)>(?P.*?)", re.M | re.DOTALL) 106 | for file in os.listdir(ctb_dir): 107 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd: 108 | doc = fd.read() 109 | for match in s_pat.finditer(doc): 110 | sid = match.group("sid") 111 | sparse = ParseTree.fromstring(match.group("sparse")) 112 | ctb[sid] = sparse 113 | return ctb 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | This project presents the recently proposed GAN-based DRS parser in 4 | **Adversarial Learning for Discourse Rhetorical Structure Parsing (ACL-IJCNLP2021)**. 5 | For any questions please directly send e-mails to zzlynx@outlook.com (Longyin Zhang). 6 | One may find that English parser is not included here, because of the close correlation 7 | between our various studies, we will release the English system in September 22. 8 | 9 | #### Installation 10 | - Python 3.6.10 11 | - transformers 3.0.2 12 | - pytorch 1.5.0 13 | - numpy 1.19.1 14 | - cudatoolkit 9.2 cudnn 7.6.5 15 | - other necessary python packages (etc.) 16 | 17 | #### Project Structure 18 | ``` 19 | ---DP_GAN 20 | |-ch_dp_gan Chinese discourse parser based on Qiu-W2V 21 | |-ch_dp_gan_xlnet Chinese discourse parser based on XLNet 22 | |-en_dp_gan English discourse parser based on GloVe and ELMo 23 | |-en_dp_gan_xlnet English discourse parser based on XLNet 24 | ``` 25 | We do not provide data in this project for the data is too large to be uploaded to Github. Please prepare 26 | the data by yourself if you want to train a parser. By the way, we will provide a pre-trained end2end DRS 27 | parser (automatic EDU segmentation and discourse parsing) in the project **sota_end2end_parser** for other 28 | researchers to directly apply it to other NLP tasks. 29 | 30 | We have explained in paper that there are some differences between the XLNet-based systems and the other 31 | systems. I re-checked the system and found that there are still two points that I forgot to mention: 32 | **(i)** We did not use the sentence and paragraph boundaries in the Chinese system, because the performance 33 | is bad; 34 | 35 | **(ii)** In the English XLNet-based parser, we did not input the EDU vectors into RNNs before split point 36 | encoding. 37 | 38 | #### Performance Evaluation 39 | 40 | As stated in the paper, we employ the *original Parseval (Morey et al. 2017)* to evaluate our English DRS 41 | parser and report the **micro-averaged F1-score** as performance. We did not report the results based on Marcu's 42 | RST Parseval because the metric will overestimate the performance level of DRS parsing. 43 | 44 | As we know, when using *RST Parseval*, we actually have 19 relation categories considered for evaluation, i.e., 45 | the 18 rhetorical relations and the **SPAN** tag. Among these tags, the SPAN relation accounts for more than a 46 | half of the relation labels in RST-DT, which may enlarge the uncertainty of performance comparison. 
Here we report the results of (Yu et al. 2018) and ours on **SPAN** for reference:
```
--- system ---------- P ---- R ---- F
Yu et al. (2018)     60.9   63.7   62.3   (the parsing results are from Yu Nan)
Ours                 46.1   43.1   44.5
```

Obviously, when using RST Parseval it is hard to judge whether performance improvements come from
the 18 rhetorical relations or from the fake relation "SPAN". For a clearer performance comparison,
we explicitly recommend that other DRS researchers use the original Parseval to evaluate their
parsers.

For Chinese DRS parsing, we use a stricter method for performance evaluation; one can refer to
https://github.com/NLP-Discourse-SoochowU/t2d_discourseparser for details.
Recently, some researchers asked me, "Why use different metrics for CDTB parsing in the 2020 and
2021 parsers?" I think it necessary to give an answer here. Our work on top-down RST parsing was
actually finished in 2019; we submitted it to ACL 2019 and it was rejected. There is a two-year gap
between the two works, and during that time we obtained some new findings on parsing evaluation.
We found that RST Parseval and the original Parseval consider only span boundaries for structure
analysis and ignore the split points, which is not strict enough, so we proposed a new evaluation
method and tried it in CDTB analysis. Besides, when we submitted this work to ACL 2021, we thought
we were using the original Parseval for evaluation, but we later found that we had actually used
the new evaluation method. Therefore, we re-report the results of our 2020 work under the new,
stricter evaluation method to give inspiration to following RST researchers.
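The difference between the two structure metrics can be illustrated with a toy sketch (an
illustration of the principle only, with hypothetical trees; see the t2d_discourseparser
repository for the exact implementation). A boundary-based metric credits any node whose span
matches, while the stricter metric additionally requires the split point inside the span to match:
```
# Each internal node of a binary tree is one (left, split, right) decision
# over EDU indices; the two trees below are hypothetical.
gold = {(0, 2, 4), (0, 1, 2), (2, 3, 4)}
pred = {(0, 1, 4), (1, 2, 4), (2, 3, 4)}

# Boundary-based evaluation (RST-/original-Parseval style): ignore the split.
gold_spans = {(l, r) for l, _, r in gold}
pred_spans = {(l, r) for l, _, r in pred}
span_f1 = 2 * len(gold_spans & pred_spans) / (len(gold_spans) + len(pred_spans))

# Split-point evaluation: the whole (left, split, right) triple must match.
split_f1 = 2 * len(gold & pred) / (len(gold) + len(pred))

print(span_f1)   # ~0.67 -- the (0, 4) and (2, 4) boundaries both match
print(split_f1)  # ~0.33 -- only the (2, 3, 4) decision is truly identical
```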
#### Model Training
In this project, we tuned the hyper-parameters for best performance, and the details are given in
the ACL paper. Although we tried our best to check the correctness of the paper content, we still
found one inaccuracy: **we trained the XLNet systems for 50 rounds instead of the 30 written in
the Appendix.**

To train the parsing models, run the following commands:
```
(Chinese) python -m treebuilder.dp_nr.train
(English) python main.py
```

Due to the use of GAN nets, the parsing performance fluctuates somewhat; the fluctuation is
related to the hardware and software environment you run on. It should also be noted that some
researchers use a preprocessed version of the RST-DT corpus for experiments, which can differ
slightly from the original data; we recommend experimenting with the original, official RST-DT
corpus.

We will provide a pre-trained end-to-end parser at
https://github.com/NLP-Discourse-SoochowU/sota_end2end_parser, and one can directly apply it to
downstream NLP tasks.

#### Tips
RST-style discourse parsing has long been known to be complicated, which is why the system we
provide contains so many experimental settings. We conducted a set of experimental comparisons in
this system and found that some of the details could be helpful for your own system, e.g., the EDU
attention, the context attention, the residual connections, and the parameter initialization
method (a small example follows below). These code details can hardly bring great research value
at this point, but they will make your system more stable.
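As one concrete example of the last point, a typical initialization scheme of this kind looks like
the sketch below (a generic PyTorch sketch, not the exact code of this repository):
```
import torch.nn as nn

def init_params(model):
    """Orthogonal init for recurrent weights, Xavier for other matrices."""
    for name, param in model.named_parameters():
        if param.dim() >= 2:
            if "weight_hh" in name:      # recurrent (GRU/LSTM) weight matrices
                nn.init.orthogonal_(param)
            else:                        # linear / embedding weight matrices
                nn.init.xavier_normal_(param)
        elif "bias" in name:
            nn.init.constant_(param, 0.0)

# usage: call init_params(parser_model) once, right after construction
```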
#### License
```
Copyright (c) 2019, Soochow University NLP research group. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that
the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
following disclaimer in the documentation and/or other materials provided with the distribution.
```
--------------------------------------------------------------------------------
/en_dp_gan/util/data_builder.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
@Author: lyzhang
@Date: 2018.4.5
@Description: build tree objects for RST-DT from .dis and .edus files.
"""
import progressbar
from structure.tree_obj import tree_obj
from stanfordcorenlp import StanfordCoreNLP
from config import *
from path_config import path_to_jar
from structure.rst_tree import rst_tree
from util.rst_utils import get_edus_info


class Builder:
    def __init__(self):
        self.nlp = StanfordCoreNLP(path_to_jar)
        self.max_len_of_edu = 0

    @staticmethod
    def update_edu_trees():
        # re-align the gold trees' ELMo token embeddings onto the auto-segmented EDUs
        edu_trees = load_data(RST_TEST_EDUS_TREES)
        test_trees = load_data(RST_TEST_TREES)
        tree_num = len(test_trees)
        new_edu_trees = []
        for idx in range(tree_num):
            tmp_gold_tree = test_trees[idx]
            tmp_token_vec = []
            for edu in tmp_gold_tree.edus:
                tmp_token_vec.extend(edu.temp_edu_emlo_emb)
            tmp_edu_tree = edu_trees[idx]
            new_edus = []
            for edu in tmp_edu_tree.edus:
                edu_len = len(edu.temp_edu_ids)
                edu.temp_edu_emlo_emb = tmp_token_vec[:edu_len]
                tmp_token_vec = tmp_token_vec[edu_len:]
                new_edus.append(edu)
            new_tree = tree_obj()
            new_tree.assign_edus(new_edus)
            new_edu_trees.append(new_tree)
        save_data(new_edu_trees, RST_TEST_EDUS_TREES)

    def form_trees(self):
        # train_dir = os.path.join(RAW_RST_DT_PATH, "train")
        # train_tree_p = os.path.join(RST_DT_PATH, "train_trees.pkl")
        # train_trees_rst_p = os.path.join(RST_DT_PATH, "train_trees_rst.pkl")
        # self.build_specific_trees(train_dir, train_tree_p, train_trees_rst_p)
        #
        # dev_dir = os.path.join(RAW_RST_DT_PATH, "dev")
        # dev_tree_p = os.path.join(RST_DT_PATH, "dev_trees.pkl")
        # dev_trees_rst_p = os.path.join(RST_DT_PATH, "dev_trees_rst.pkl")
        # self.build_specific_trees(dev_dir, dev_tree_p, dev_trees_rst_p)

        test_dir = os.path.join(RAW_RST_DT_PATH, "test")
        test_tree_p = os.path.join(RST_DT_PATH, "test_trees.pkl")
        test_trees_rst_p = os.path.join(RST_DT_PATH, "test_trees_rst.pkl")
        edus_tree_p = os.path.join(RST_DT_PATH, "edus_test_trees.pkl")
        self.build_specific_trees(test_dir, test_tree_p, test_trees_rst_p, edus_tree_p)

        # test_dir = os.path.join(RAW_RST_DT_PATH, "test")
        # test_tree_p_n = os.path.join(RST_DT_PATH, "test_trees_n.pkl")
        # test_trees_rst_p_n = os.path.join(RST_DT_PATH, "test_trees_rst_n.pkl")
        # edus_tree_p_n = os.path.join(RST_DT_PATH, "edus_test_trees_n.pkl")
        # self.build_specific_trees(test_dir, test_tree_p_n, test_trees_rst_p_n, edus_tree_p_n)

    def build_specific_trees(self, raw_p=None, tree_p=None, tree_rst_p=None, edus_tree_p=None):
        p = progressbar.ProgressBar()
        p.start(len(os.listdir(raw_p)))
        pro_idx = 1
        trees_list, trees_rst_list = [], []
        edu_trees_list = []
        for file_name in os.listdir(raw_p):
            p.update(pro_idx)
            pro_idx += 1
            if file_name.endswith('.out.dis'):
                root, edus_list = self.build_one_tree(raw_p, file_name, edus_tree_p)
                trees_rst_list.append(root)
                tree_obj_ = tree_obj(root)
                trees_list.append(tree_obj_)
                if edus_tree_p is not None:
                    edus_tree_obj_ = tree_obj()
                    edus_tree_obj_.assign_edus(edus_list)
                    edu_trees_list.append(edus_tree_obj_)
        p.finish()
        save_data(trees_list, tree_p)
        save_data(trees_rst_list, tree_rst_p)
        if edus_tree_p is not None:
            save_data(edu_trees_list, edus_tree_p)

    def build_one_tree(self, raw_p, file_name, edus_tree_p):
        """ build edu_list and edu_ids according to the EDU files.
        """
        temp_path = os.path.join(raw_p, file_name)
        e_bound_list, e_list, e_span_list, e_ids_list, e_tags_list, e_headwords_list, e_cent_word_list, e_emlo_list =\
            get_edus_info(temp_path.replace(".dis", ".edus"), temp_path.replace(".out.dis", ".out"), nlp=self.nlp)
        with open(temp_path, 'r') as fd:
            lines_list = fd.readlines()
        root = rst_tree(type_="Root", lines_list=lines_list, temp_line=lines_list[0], file_name=file_name, rel="span")
        root.create_tree(temp_line_num=1, p_node_=root)
        root.config_edus(temp_node=root, e_list=e_list, e_span_list=e_span_list, e_ids_list=e_ids_list,
                         e_tags_list=e_tags_list, e_headwords_list=e_headwords_list,
                         e_cent_word_list=e_cent_word_list, total_e_num=len(e_ids_list),
                         e_bound_list=e_bound_list, e_emlo_list=e_emlo_list)

        edus_list = []
        if edus_tree_p is not None:
            e_bound_list, e_list, e_span_list, e_ids_list, e_tags_list, e_headwords_list, e_cent_word_list, \
                e_emlo_list = get_edus_info(temp_path.replace(".dis", ".auto.edus"),
                                            temp_path.replace(".out.dis", ".out"), nlp=self.nlp)
            edu_num = len(e_bound_list)
            for idx in range(edu_num):
                tmp_edu = rst_tree(temp_edu_boundary=e_bound_list[idx], temp_edu=e_list[idx],
                                   temp_edu_span=e_span_list[idx], temp_edu_ids=e_ids_list[idx],
                                   temp_pos_ids=e_tags_list[idx], temp_edu_heads=e_headwords_list[idx],
                                   temp_edu_has_center_word=e_cent_word_list[idx],
                                   tmp_edu_emlo_emb=e_emlo_list[idx])
                edus_list.append(tmp_edu)
        return root, edus_list

    def build_tree_obj_list(self):
        raw = ""  # directory of raw documents (*.out) to parse; left unset in the release
        parse_trees = []
        for file_name in os.listdir(raw):
            if file_name.endswith(".out"):
                tmp_edus_list = []
                sent_path = os.path.join(raw, file_name)
                edu_path = sent_path + ".edu"
                edus_boundary_list, edus_list, edu_span_list, edus_ids_list, edus_tag_ids_list, edus_conns_list, \
                    edu_headword_ids_list, edu_has_center_word_list = get_edus_info(edu_path, sent_path, nlp=self.nlp)
                for _ in range(len(edus_list)):
                    tmp_edu = rst_tree()
tmp_edu.temp_edu = edus_list.pop(0) 133 | tmp_edu.temp_edu_span = edu_span_list.pop(0) 134 | tmp_edu.temp_edu_ids = edus_ids_list.pop(0) 135 | tmp_edu.temp_pos_ids = edus_tag_ids_list.pop(0) 136 | tmp_edu.temp_edu_conn_ids = edus_conns_list.pop(0) 137 | tmp_edu.temp_edu_heads = edu_headword_ids_list.pop(0) 138 | tmp_edu.temp_edu_has_center_word = edu_has_center_word_list.pop(0) 139 | tmp_edu.edu_node_boundary = edus_boundary_list.pop(0) 140 | tmp_edu.inner_sent = True 141 | tmp_edus_list.append(tmp_edu) 142 | tmp_tree_obj = tree_obj() 143 | tmp_tree_obj.file_name = file_name 144 | tmp_tree_obj.assign_edus(tmp_edus_list) 145 | parse_trees.append(tmp_tree_obj) 146 | return parse_trees 147 | -------------------------------------------------------------------------------- /ch_dp_gan/pre_process.py: -------------------------------------------------------------------------------- 1 | # UTF-8 2 | # Author: Longyin Zhang 3 | # Date: 2020.10.9 4 | import argparse 5 | import logging 6 | import numpy as np 7 | from dataset import CDTB 8 | from collections import Counter 9 | from itertools import chain 10 | from structure.vocab import Vocab, Label 11 | from structure.nodes import node_type_filter, EDU, Relation, Sentence, TEXT 12 | import progressbar 13 | from util.file_util import * 14 | p = progressbar.ProgressBar() 15 | 16 | 17 | def build_vocab(dataset): 18 | word_freq = Counter() 19 | pos_freq = Counter() 20 | nuc_freq = Counter() 21 | rel_freq = Counter() 22 | for paragraph in chain(*dataset): 23 | for node in paragraph.iterfind(filter=node_type_filter([EDU, Relation])): 24 | if isinstance(node, EDU): 25 | word_freq.update(node.words) 26 | pos_freq.update(node.tags) 27 | elif isinstance(node, Relation): 28 | nuc_freq[node.nuclear] += 1 29 | rel_freq[node.ftype] += 1 30 | 31 | word_vocab = Vocab("word", word_freq) 32 | pos_vocab = Vocab("part of speech", pos_freq) 33 | nuc_label = Label("nuclear", nuc_freq) 34 | rel_label = Label("relation", rel_freq) 35 | return word_vocab, pos_vocab, nuc_label, rel_label 36 | 37 | 38 | def gen_decoder_data(root, edu2ids): 39 | # splits s0 s1 s2 s3 s4 s5 s6 40 | # edus s/ e0 e1 e2 e3 e4 e5 /s 41 | splits = [] # [(0, 3, 6, NS), (0, 2, 3, SN), ...] 
42 | child_edus = [] # [edus] 43 | 44 | if isinstance(root, EDU): 45 | child_edus.append(root) 46 | elif isinstance(root, Sentence): 47 | for child in root: 48 | _child_edus, _splits = gen_decoder_data(child, edu2ids) 49 | child_edus.extend(_child_edus) 50 | splits.extend(_splits) 51 | elif isinstance(root, Relation): 52 | children = [gen_decoder_data(child, edu2ids) for child in root] 53 | if len(children) < 2: 54 | raise ValueError("relation node should have at least 2 children") 55 | 56 | while children: 57 | left_child_edus, left_child_splits = children.pop(0) 58 | if children: 59 | last_child_edus, _ = children[-1] 60 | start = edu2ids[left_child_edus[0]] 61 | split = edu2ids[left_child_edus[-1]] + 1 62 | end = edu2ids[last_child_edus[-1]] + 1 63 | nuc = root.nuclear 64 | rel = root.ftype 65 | splits.append((start, split, end, nuc, rel)) 66 | child_edus.extend(left_child_edus) 67 | splits.extend(left_child_splits) 68 | return child_edus, splits 69 | 70 | 71 | def numericalize(dataset, word_vocab, pos_vocab, nuc_label, rel_label): 72 | instances = [] 73 | for paragraph in filter(lambda d: d.root_relation(), chain(*dataset)): 74 | encoder_inputs = [] 75 | decoder_inputs = [] 76 | pred_splits = [] 77 | pred_nucs = [] 78 | pred_rels = [] 79 | edus = list(paragraph.edus()) 80 | for edu in edus: 81 | edu_word_ids = [word_vocab[word] for word in edu.words] 82 | edu_pos_ids = [pos_vocab[pos] for pos in edu.tags] 83 | encoder_inputs.append((edu_word_ids, edu_pos_ids)) 84 | edu2ids = {edu: i for i, edu in enumerate(edus)} 85 | _, splits = gen_decoder_data(paragraph.root_relation(), edu2ids) 86 | for start, split, end, nuc, rel in splits: 87 | decoder_inputs.append((start, end)) 88 | pred_splits.append(split) 89 | pred_nucs.append(nuc_label[nuc]) 90 | pred_rels.append(rel_label[rel]) 91 | instances.append((encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels)) 92 | return instances 93 | 94 | 95 | nr2ids = dict() 96 | ids2nr = dict() 97 | nr_idx = 0 98 | 99 | 100 | def gen_batch_iter(instances, batch_size, use_gpu=False): 101 | global nr2ids, nr_idx, ids2nr 102 | random_instances = np.random.permutation(instances) 103 | num_instances = len(instances) 104 | offset = 0 105 | p.start(num_instances) 106 | while offset < num_instances: 107 | p.update(offset) 108 | batch = random_instances[offset: min(num_instances, offset+batch_size)] 109 | for batchi, (_, _, _, pred_nucs, pred_rels) in enumerate(batch): 110 | for nuc, rel in zip(pred_nucs, pred_rels): 111 | nr = str(nuc) + "-" + str(rel) 112 | if nr not in nr2ids.keys(): 113 | nr2ids[nr] = nr_idx 114 | ids2nr[nr_idx] = nr 115 | nr_idx += 1 116 | offset = offset + batch_size 117 | p.finish() 118 | 119 | 120 | def main(args): 121 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir) 122 | word_vocab, pos_vocab, nuc_label, rel_label = build_vocab(cdtb.train) 123 | trainset = numericalize(cdtb.train, word_vocab, pos_vocab, nuc_label, rel_label) 124 | test = numericalize(cdtb.test, word_vocab, pos_vocab, nuc_label, rel_label) 125 | dev = numericalize(cdtb.validate, word_vocab, pos_vocab, nuc_label, rel_label) 126 | for idx in range(1, 8): 127 | gen_batch_iter(trainset, args.batch_size, args.use_gpu) 128 | gen_batch_iter(test, args.batch_size, args.use_gpu) 129 | gen_batch_iter(dev, args.batch_size, args.use_gpu) 130 | print(nr2ids) 131 | save_data(nr2ids, "data/nr2ids.pkl") 132 | save_data(ids2nr, "data/ids2nr.pkl") 133 | 134 | 135 | if __name__ == '__main__': 136 | 
logging.basicConfig(level=logging.INFO) 137 | arg_parser = argparse.ArgumentParser() 138 | 139 | # dataset parameters 140 | arg_parser.add_argument("--data", default="data/CDTB") 141 | arg_parser.add_argument("--ctb_dir", default="data/CTB") 142 | arg_parser.add_argument("--cache_dir", default="data/cache") 143 | 144 | # model parameters 145 | arg_parser.add_argument("-hidden_size", default=512, type=int) 146 | arg_parser.add_argument("-dropout", default=0.33, type=float) 147 | # w2v_group = arg_parser.add_mutually_exclusive_group(required=True) 148 | arg_parser.add_argument("-w2v_size", default=300, type=int) 149 | arg_parser.add_argument("-pos_size", default=30, type=int) 150 | arg_parser.add_argument("-split_mlp_size", default=64, type=int) 151 | arg_parser.add_argument("-nuc_mlp_size", default=32, type=int) 152 | arg_parser.add_argument("-rel_mlp_size", default=128, type=int) 153 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true") 154 | arg_parser.add_argument("-pretrained", default="data/pretrained/sgns.renmin.word") 155 | arg_parser.set_defaults(w2v_freeze=True) 156 | 157 | # train parameters 158 | arg_parser.add_argument("-epoch", default=20, type=int) 159 | arg_parser.add_argument("-batch_size", default=64, type=int) 160 | arg_parser.add_argument("-lr", default=0.001, type=float) 161 | arg_parser.add_argument("-l2", default=0.0, type=float) 162 | arg_parser.add_argument("-log_every", default=10, type=int) 163 | arg_parser.add_argument("-validate_every", default=10, type=int) 164 | arg_parser.add_argument("-a_split_loss", default=0.3, type=float) 165 | arg_parser.add_argument("-a_nuclear_loss", default=1.0, type=float) 166 | arg_parser.add_argument("-a_relation_loss", default=1.0, type=float) 167 | arg_parser.add_argument("-log_dir", default="data/log") 168 | arg_parser.add_argument("-model_save", default="data/models/treebuilder.partptr.model") 169 | arg_parser.add_argument("--seed", default=7, type=int) 170 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true") 171 | arg_parser.set_defaults(use_gpu=True) 172 | 173 | main(arg_parser.parse_args()) 174 | 175 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/pre_process.py: -------------------------------------------------------------------------------- 1 | # UTF-8 2 | # Author: Longyin Zhang 3 | # Date: 2020.10.9 4 | import argparse 5 | import logging 6 | import numpy as np 7 | from dataset import CDTB 8 | from collections import Counter 9 | from itertools import chain 10 | from structure.vocab import Vocab, Label 11 | from structure.nodes import node_type_filter, EDU, Relation, Sentence, TEXT 12 | import progressbar 13 | from util.file_util import * 14 | p = progressbar.ProgressBar() 15 | 16 | 17 | def build_vocab(dataset): 18 | word_freq = Counter() 19 | pos_freq = Counter() 20 | nuc_freq = Counter() 21 | rel_freq = Counter() 22 | for paragraph in chain(*dataset): 23 | for node in paragraph.iterfind(filter=node_type_filter([EDU, Relation])): 24 | if isinstance(node, EDU): 25 | word_freq.update(node.words) 26 | pos_freq.update(node.tags) 27 | elif isinstance(node, Relation): 28 | nuc_freq[node.nuclear] += 1 29 | rel_freq[node.ftype] += 1 30 | 31 | word_vocab = Vocab("word", word_freq) 32 | pos_vocab = Vocab("part of speech", pos_freq) 33 | nuc_label = Label("nuclear", nuc_freq) 34 | rel_label = Label("relation", rel_freq) 35 | return word_vocab, pos_vocab, nuc_label, rel_label 36 | 37 | 38 | def gen_decoder_data(root, edu2ids): 39 | 
# splits s0 s1 s2 s3 s4 s5 s6 40 | # edus s/ e0 e1 e2 e3 e4 e5 /s 41 | splits = [] # [(0, 3, 6, NS), (0, 2, 3, SN), ...] 42 | child_edus = [] # [edus] 43 | 44 | if isinstance(root, EDU): 45 | child_edus.append(root) 46 | elif isinstance(root, Sentence): 47 | for child in root: 48 | _child_edus, _splits = gen_decoder_data(child, edu2ids) 49 | child_edus.extend(_child_edus) 50 | splits.extend(_splits) 51 | elif isinstance(root, Relation): 52 | children = [gen_decoder_data(child, edu2ids) for child in root] 53 | if len(children) < 2: 54 | raise ValueError("relation node should have at least 2 children") 55 | 56 | while children: 57 | left_child_edus, left_child_splits = children.pop(0) 58 | if children: 59 | last_child_edus, _ = children[-1] 60 | start = edu2ids[left_child_edus[0]] 61 | split = edu2ids[left_child_edus[-1]] + 1 62 | end = edu2ids[last_child_edus[-1]] + 1 63 | nuc = root.nuclear 64 | rel = root.ftype 65 | splits.append((start, split, end, nuc, rel)) 66 | child_edus.extend(left_child_edus) 67 | splits.extend(left_child_splits) 68 | return child_edus, splits 69 | 70 | 71 | def numericalize(dataset, word_vocab, pos_vocab, nuc_label, rel_label): 72 | instances = [] 73 | for paragraph in filter(lambda d: d.root_relation(), chain(*dataset)): 74 | encoder_inputs = [] 75 | decoder_inputs = [] 76 | pred_splits = [] 77 | pred_nucs = [] 78 | pred_rels = [] 79 | edus = list(paragraph.edus()) 80 | for edu in edus: 81 | edu_word_ids = [word_vocab[word] for word in edu.words] 82 | edu_pos_ids = [pos_vocab[pos] for pos in edu.tags] 83 | encoder_inputs.append((edu_word_ids, edu_pos_ids)) 84 | edu2ids = {edu: i for i, edu in enumerate(edus)} 85 | _, splits = gen_decoder_data(paragraph.root_relation(), edu2ids) 86 | for start, split, end, nuc, rel in splits: 87 | decoder_inputs.append((start, end)) 88 | pred_splits.append(split) 89 | pred_nucs.append(nuc_label[nuc]) 90 | pred_rels.append(rel_label[rel]) 91 | instances.append((encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels)) 92 | return instances 93 | 94 | 95 | nr2ids = dict() 96 | ids2nr = dict() 97 | nr_idx = 0 98 | 99 | 100 | def gen_batch_iter(instances, batch_size, use_gpu=False): 101 | """ generate graphs for global optimization 102 | """ 103 | global nr2ids, nr_idx, ids2nr 104 | random_instances = np.random.permutation(instances) 105 | num_instances = len(instances) 106 | offset = 0 107 | p.start(num_instances) 108 | while offset < num_instances: 109 | p.update(offset) 110 | batch = random_instances[offset: min(num_instances, offset+batch_size)] 111 | for batchi, (_, _, _, pred_nucs, pred_rels) in enumerate(batch): 112 | for nuc, rel in zip(pred_nucs, pred_rels): 113 | nr = str(nuc) + "-" + str(rel) 114 | if nr not in nr2ids.keys(): 115 | nr2ids[nr] = nr_idx 116 | ids2nr[nr_idx] = nr 117 | nr_idx += 1 118 | offset = offset + batch_size 119 | p.finish() 120 | 121 | 122 | def main(args): 123 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir) 124 | word_vocab, pos_vocab, nuc_label, rel_label = build_vocab(cdtb.train) 125 | trainset = numericalize(cdtb.train, word_vocab, pos_vocab, nuc_label, rel_label) 126 | test = numericalize(cdtb.test, word_vocab, pos_vocab, nuc_label, rel_label) 127 | dev = numericalize(cdtb.validate, word_vocab, pos_vocab, nuc_label, rel_label) 128 | for idx in range(1, 8): 129 | gen_batch_iter(trainset, args.batch_size, args.use_gpu) 130 | gen_batch_iter(test, args.batch_size, args.use_gpu) 131 | gen_batch_iter(dev, args.batch_size, 
args.use_gpu) 132 | print(nr2ids) 133 | save_data(nr2ids, "data/nr2ids.pkl") 134 | save_data(ids2nr, "data/ids2nr.pkl") 135 | 136 | 137 | if __name__ == '__main__': 138 | logging.basicConfig(level=logging.INFO) 139 | arg_parser = argparse.ArgumentParser() 140 | 141 | # dataset parameters 142 | arg_parser.add_argument("--data", default="data/CDTB") 143 | arg_parser.add_argument("--ctb_dir", default="data/CTB") 144 | arg_parser.add_argument("--cache_dir", default="data/cache") 145 | 146 | # model parameters 147 | arg_parser.add_argument("-hidden_size", default=512, type=int) 148 | arg_parser.add_argument("-dropout", default=0.33, type=float) 149 | # w2v_group = arg_parser.add_mutually_exclusive_group(required=True) 150 | arg_parser.add_argument("-w2v_size", default=300, type=int) 151 | arg_parser.add_argument("-pos_size", default=30, type=int) 152 | arg_parser.add_argument("-split_mlp_size", default=64, type=int) 153 | arg_parser.add_argument("-nuc_mlp_size", default=32, type=int) 154 | arg_parser.add_argument("-rel_mlp_size", default=128, type=int) 155 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true") 156 | arg_parser.add_argument("-pretrained", default="data/pretrained/sgns.renmin.word") 157 | arg_parser.set_defaults(w2v_freeze=True) 158 | 159 | # train parameters 160 | arg_parser.add_argument("-epoch", default=20, type=int) 161 | arg_parser.add_argument("-batch_size", default=64, type=int) 162 | arg_parser.add_argument("-lr", default=0.001, type=float) 163 | arg_parser.add_argument("-l2", default=0.0, type=float) 164 | arg_parser.add_argument("-log_every", default=10, type=int) 165 | arg_parser.add_argument("-validate_every", default=10, type=int) 166 | arg_parser.add_argument("-a_split_loss", default=0.3, type=float) 167 | arg_parser.add_argument("-a_nuclear_loss", default=1.0, type=float) 168 | arg_parser.add_argument("-a_relation_loss", default=1.0, type=float) 169 | arg_parser.add_argument("-log_dir", default="data/log") 170 | arg_parser.add_argument("-model_save", default="data/models/treebuilder.partptr.model") 171 | arg_parser.add_argument("--seed", default=7, type=int) 172 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true") 173 | arg_parser.set_defaults(use_gpu=True) 174 | 175 | main(arg_parser.parse_args()) 176 | 177 | -------------------------------------------------------------------------------- /ch_dp_gan/treebuilder/dp_nr/parser.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import numpy as np 4 | from structure.nodes import Paragraph, Relation, rev_relationmap 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | from interface import ParserI 8 | from structure.nodes import EDU, TEXT 9 | from config import * 10 | 11 | 12 | class PartPtrParser(ParserI): 13 | def __init__(self, model): 14 | self.parser = PartitionPtrParser(model) 15 | 16 | def parse(self, para): 17 | edus = [] 18 | for edu in para.edus(): 19 | edu_copy = EDU([TEXT(edu.text)]) 20 | setattr(edu_copy, "words", edu.words) 21 | setattr(edu_copy, "tags", edu.tags) 22 | edus.append(edu_copy) 23 | return self.parser.parse(edus) 24 | 25 | 26 | class PartitionPtrParser: 27 | def __init__(self, model): 28 | self.model = model 29 | 30 | def parse(self, edus, ret_session=False): 31 | if len(edus) < 2: 32 | return Paragraph(edus) 33 | 34 | # TODO implement beam search 35 | session = self.init_session(edus) 36 | while not session.terminate(): 37 | split_scores, nr_score, state = 
self.decode(session) 38 | split = split_scores.argmax() 39 | nr_id = nr_score[split].argmax() 40 | nr = ids2nr[nr_id] 41 | nucl_id, rel_id = int(nr.split("-")[0]), int(nr.split("-")[1]) 42 | nuclear = self.model.nuc_label.id2label[nucl_id] 43 | relation = self.model.rel_label.id2label[rel_id] 44 | session = session.forward(split_scores, state, split, nuclear, relation) 45 | root_relation = self.build_tree(edus, session.splits[:], session.nuclears[:], session.relations[:]) 46 | discourse = Paragraph([root_relation]) 47 | if ret_session: 48 | return discourse, session 49 | else: 50 | return discourse 51 | 52 | def init_session(self, edus): 53 | edu_words = [edu.words for edu in edus] 54 | edu_poses = [edu.tags for edu in edus] 55 | max_word_seqlen = max(len(words) for words in edu_words) 56 | edu_seqlen = len(edu_words) 57 | 58 | e_input_words = np.zeros((1, edu_seqlen, max_word_seqlen), dtype=np.long) 59 | e_boundaries = np.zeros([1, edu_seqlen + 1], dtype=np.long) 60 | e_input_poses = np.zeros_like(e_input_words) 61 | e_input_masks = np.zeros_like(e_input_words, dtype=np.uint8) 62 | 63 | for i, (words, poses) in enumerate(zip(edu_words, edu_poses)): 64 | e_input_words[0, i, :len(words)] = [self.model.word_vocab[word] for word in words] 65 | e_input_poses[0, i, :len(poses)] = [self.model.pos_vocab[pos] for pos in poses] 66 | e_input_masks[0, i, :len(words)] = 1 67 | if self.model.word_vocab[words[-1]] in [39, 3015, 178]: 68 | e_boundaries[0][i + 1] = 1 69 | 70 | e_input_words = torch.from_numpy(e_input_words).long() 71 | e_boundaries = torch.from_numpy(e_boundaries).long() 72 | e_input_poses = torch.from_numpy(e_input_poses).long() 73 | e_input_masks = torch.from_numpy(e_input_masks).byte() 74 | 75 | if self.model.use_gpu: 76 | e_input_words = e_input_words.cuda() 77 | e_boundaries = e_boundaries.cuda() 78 | e_input_poses = e_input_poses.cuda() 79 | e_input_masks = e_input_masks.cuda() 80 | 81 | edu_encoded, e_masks = self.model.encode_edus((e_input_words, e_input_poses, e_input_masks, e_boundaries)) 82 | memory, _, context = self.model.encoder(edu_encoded, e_masks, e_boundaries) 83 | state = self.model.context_dense(context).unsqueeze(0) 84 | return Session(memory, state) 85 | 86 | def decode(self, session): 87 | left, right = session.stack[-1] 88 | return self.model(left, right, session.memory, session.state) 89 | 90 | def build_tree(self, edus, splits, nuclears, relations): 91 | left, split, right = splits.pop(0) 92 | nuclear = nuclears.pop(0) 93 | ftype = relations.pop(0) 94 | ctype = rev_relationmap[ftype] 95 | if split - left == 1: 96 | left_node = edus[left] 97 | else: 98 | left_node = self.build_tree(edus, splits, nuclears, relations) 99 | 100 | if right - split == 1: 101 | right_node = edus[split] 102 | else: 103 | right_node = self.build_tree(edus, splits, nuclears, relations) 104 | 105 | relation = Relation([left_node, right_node], nuclear=nuclear, ftype=ftype, ctype=ctype) 106 | return relation 107 | 108 | 109 | class Session: 110 | def __init__(self, memory, state): 111 | self.n = memory.size(1) - 2 112 | self.step = 0 113 | self.memory = memory 114 | self.state = state 115 | self.stack = [(0, self.n + 1)] 116 | self.scores = np.zeros((self.n, self.n+2), dtype=np.float) 117 | self.splits = [] 118 | self.nuclears = [] 119 | self.relations = [] 120 | 121 | def forward(self, score, state, split, nuclear, relation): 122 | left, right = self.stack.pop() 123 | if right - split > 1: 124 | self.stack.append((split, right)) 125 | if split - left > 1: 126 | self.stack.append((left, 
split)) 127 | self.splits.append((left, split, right)) 128 | self.nuclears.append(nuclear) 129 | self.relations.append(relation) 130 | self.state = state 131 | self.scores[self.step] = score 132 | self.step += 1 133 | return self 134 | 135 | def terminate(self): 136 | return self.step >= self.n 137 | 138 | def draw_decision_hotmap(self): 139 | textcolors = ["black", "white"] 140 | cmap = "YlGn" 141 | ylabel = "split score" 142 | col_labels = ["split %d" % i for i in range(0, self.scores.shape[1])] 143 | row_labels = ["step %d" % i for i in range(1, self.scores.shape[0] + 1)] 144 | fig, ax = plt.subplots() 145 | im = ax.imshow(self.scores, cmap=cmap) 146 | cbar = ax.figure.colorbar(im, ax=ax) 147 | cbar.ax.set_ylabel(ylabel, rotation=-90, va="bottom") 148 | ax.set_xticks(np.arange(self.scores.shape[1])) 149 | ax.set_yticks(np.arange(self.scores.shape[0])) 150 | ax.set_xticklabels(col_labels) 151 | ax.set_yticklabels(row_labels) 152 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) 153 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") 154 | for edge, spine in ax.spines.items(): 155 | spine.set_visible(False) 156 | ax.set_xticks(np.arange(self.scores.shape[1] + 1) - .5, minor=True) 157 | ax.set_yticks(np.arange(self.scores.shape[0] + 1) - .5, minor=True) 158 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 159 | ax.tick_params(which="minor", bottom=False, left=False) 160 | threshold = im.norm(self.scores.max()) / 2. 161 | valfmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}") 162 | texts = [] 163 | kw = dict(horizontalalignment="center", verticalalignment="center") 164 | for i in range(self.scores.shape[0]): 165 | for j in range(self.scores.shape[1]): 166 | kw.update(color=textcolors[im.norm(self.scores[i, j]) > threshold]) 167 | text = im.axes.text(j, i, valfmt(self.scores[i, j], None), **kw) 168 | texts.append(text) 169 | fig.tight_layout() 170 | plt.show() 171 | 172 | def __repr__(self): 173 | return "[step %d]memory size: %s, state size: %s\n stack:\n%s\n, scores:\n %s" % \ 174 | (self.step, str(self.memory.size()), str(self.state.size()), 175 | "\n".join(map(str, self.stack)) or "[]", 176 | str(self.scores)) 177 | 178 | def __str__(self): 179 | return repr(self) 180 | -------------------------------------------------------------------------------- /ch_dp_gan_xlnet/treebuilder/dp_nr/parser.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import torch 3 | import numpy as np 4 | from structure.nodes import Paragraph, Relation, rev_relationmap 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | from interface import ParserI 8 | from structure.nodes import EDU, TEXT 9 | from config import * 10 | 11 | 12 | class PartPtrParser(ParserI): 13 | def __init__(self, model): 14 | self.parser = PartitionPtrParser(model) 15 | 16 | def parse(self, para): 17 | edus = [] 18 | for edu in para.edus(): 19 | edu_copy = EDU([TEXT(edu.text)]) 20 | setattr(edu_copy, "words", edu.words) 21 | setattr(edu_copy, "tags", edu.tags) 22 | edus.append(edu_copy) 23 | return self.parser.parse(edus) 24 | 25 | 26 | class PartitionPtrParser: 27 | def __init__(self, model): 28 | self.model = model 29 | 30 | def parse(self, edus, ret_session=False, tokenizer_xl=None, model_xl=None): 31 | if len(edus) < 2: 32 | return Paragraph(edus) 33 | 34 | # TODO implement beam search 35 | session = self.init_session(edus, tokenizer_xl, model_xl) 36 | while not session.terminate(): 37 | split_scores, 
nr_score, state = self.decode(session)
            split = split_scores.argmax()
            nr_id = nr_score[split].argmax()
            nr = ids2nr[nr_id]
            nucl_id, rel_id = int(nr.split("-")[0]), int(nr.split("-")[1])
            nuclear = self.model.nuc_label.id2label[nucl_id]
            relation = self.model.rel_label.id2label[rel_id]
            session = session.forward(split_scores, state, split, nuclear, relation)
        root_relation = self.build_tree(edus, session.splits[:], session.nuclears[:], session.relations[:])
        discourse = Paragraph([root_relation])
        if ret_session:
            return discourse, session
        else:
            return discourse

    def init_session(self, edus, tokenizer_xl, model_xl):
        edu_words = [edu.words for edu in edus]
        edu_txts = ["".join(edu.words) for edu in edus]
        edu_poses = [edu.tags for edu in edus]
        max_word_seqlen = max(len(words) for words in edu_words)
        edu_seqlen = len(edu_words)

        e_input_words = np.zeros((1, edu_seqlen, max_word_seqlen), dtype=np.long)
        e_txt = [edu_txts]
        e_boundaries = np.zeros([1, edu_seqlen + 1], dtype=np.long)
        e_input_poses = np.zeros_like(e_input_words)
        e_input_masks = np.zeros_like(e_input_words, dtype=np.uint8)

        for i, (words, poses) in enumerate(zip(edu_words, edu_poses)):
            e_input_words[0, i, :len(words)] = [self.model.word_vocab[word] for word in words]
            e_input_poses[0, i, :len(poses)] = [self.model.pos_vocab[pos] for pos in poses]
            e_input_masks[0, i, :len(words)] = 1
            if self.model.word_vocab[words[-1]] in [39, 3015, 178]:
                e_boundaries[0][i + 1] = 1  # sentence boundary

        e_input_words = torch.from_numpy(e_input_words).long()
        e_boundaries = torch.from_numpy(e_boundaries).long()
        e_input_poses = torch.from_numpy(e_input_poses).long()
        e_input_masks = torch.from_numpy(e_input_masks).byte()

        if self.model.use_gpu:
            e_input_words = e_input_words.cuda()
            e_boundaries = e_boundaries.cuda()
            e_input_poses = e_input_poses.cuda()
            e_input_masks = e_input_masks.cuda()

        edu_encoded, e_masks = self.model.encode_edus((e_input_words, e_input_poses, e_input_masks, e_boundaries,
                                                       e_txt), tokenizer_xl, model_xl)
        memory, _, context = self.model.encoder(edu_encoded, e_masks, e_boundaries)
        state = self.model.context_dense(context).unsqueeze(0)
        return Session(memory, state)

    def decode(self, session):
        left, right = session.stack[-1]
        return self.model(left, right, session.memory, session.state)

    def build_tree(self, edus, splits, nuclears, relations):
        left, split, right = splits.pop(0)
        nuclear = nuclears.pop(0)
        ftype = relations.pop(0)
        ctype = rev_relationmap[ftype]
        if split - left == 1:
            left_node = edus[left]
        else:
            left_node = self.build_tree(edus, splits, nuclears, relations)

        if right - split == 1:
            right_node = edus[split]
        else:
            right_node = self.build_tree(edus, splits, nuclears, relations)

        relation = Relation([left_node, right_node], nuclear=nuclear, ftype=ftype, ctype=ctype)
        return relation


class Session:
    def __init__(self, memory, state):
        self.n = memory.size(1) - 2
        self.step = 0
        self.memory = memory
        self.state = state
        self.stack = [(0, self.n + 1)]
        self.scores = np.zeros((self.n, self.n + 2), dtype=np.float)
        self.splits = []
        self.nuclears = []
        self.relations = []

    def forward(self, score, state, split, nuclear, relation):
125 | left, right = self.stack.pop() 126 | if right - split > 1: 127 | self.stack.append((split, right)) 128 | if split - left > 1: 129 | self.stack.append((left, split)) 130 | self.splits.append((left, split, right)) 131 | self.nuclears.append(nuclear) 132 | self.relations.append(relation) 133 | self.state = state 134 | self.scores[self.step] = score 135 | self.step += 1 136 | return self 137 | 138 | def terminate(self): 139 | return self.step >= self.n 140 | 141 | def draw_decision_hotmap(self): 142 | textcolors = ["black", "white"] 143 | cmap = "YlGn" 144 | ylabel = "split score" 145 | col_labels = ["split %d" % i for i in range(0, self.scores.shape[1])] 146 | row_labels = ["step %d" % i for i in range(1, self.scores.shape[0] + 1)] 147 | fig, ax = plt.subplots() 148 | im = ax.imshow(self.scores, cmap=cmap) 149 | cbar = ax.figure.colorbar(im, ax=ax) 150 | cbar.ax.set_ylabel(ylabel, rotation=-90, va="bottom") 151 | ax.set_xticks(np.arange(self.scores.shape[1])) 152 | ax.set_yticks(np.arange(self.scores.shape[0])) 153 | ax.set_xticklabels(col_labels) 154 | ax.set_yticklabels(row_labels) 155 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) 156 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") 157 | for edge, spine in ax.spines.items(): 158 | spine.set_visible(False) 159 | ax.set_xticks(np.arange(self.scores.shape[1] + 1) - .5, minor=True) 160 | ax.set_yticks(np.arange(self.scores.shape[0] + 1) - .5, minor=True) 161 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 162 | ax.tick_params(which="minor", bottom=False, left=False) 163 | threshold = im.norm(self.scores.max()) / 2. 164 | valfmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}") 165 | texts = [] 166 | kw = dict(horizontalalignment="center", verticalalignment="center") 167 | for i in range(self.scores.shape[0]): 168 | for j in range(self.scores.shape[1]): 169 | kw.update(color=textcolors[im.norm(self.scores[i, j]) > threshold]) 170 | text = im.axes.text(j, i, valfmt(self.scores[i, j], None), **kw) 171 | texts.append(text) 172 | fig.tight_layout() 173 | plt.show() 174 | 175 | def __repr__(self): 176 | return "[step %d]memory size: %s, state size: %s\n stack:\n%s\n, scores:\n %s" % \ 177 | (self.step, str(self.memory.size()), str(self.state.size()), 178 | "\n".join(map(str, self.stack)) or "[]", 179 | str(self.scores)) 180 | 181 | def __str__(self): 182 | return repr(self) 183 | --------------------------------------------------------------------------------