├── ch_dp_gan
│   ├── util
│   │   ├── __init__.py
│   │   ├── model_util.py
│   │   ├── ltp.py
│   │   ├── berkely.py
│   │   └── file_util.py
│   ├── treebuilder
│   │   ├── __init__.py
│   │   └── dp_nr
│   │       ├── __init__.py
│   │       └── parser.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── cdtb.py
│   ├── structure
│   │   ├── __init__.py
│   │   └── vocab.py
│   ├── .gitmodules
│   ├── requirements.txt
│   ├── config.py
│   ├── interface.py
│   ├── new_ctb.py
│   ├── .gitignore
│   ├── berkeleyparser
│   │   └── README
│   └── pre_process.py
├── en_dp_gan
│   ├── model
│   │   ├── __init__.py
│   │   ├── stacked_parser_tdt_gan3
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── model.cpython-36.pyc
│   │   │   │   ├── model.cpython-37.pyc
│   │   │   │   ├── parser.cpython-36.pyc
│   │   │   │   ├── parser.cpython-37.pyc
│   │   │   │   ├── trainer.cpython-36.pyc
│   │   │   │   ├── trainer.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── __init__.cpython-37.pyc
│   │   │   └── parser.py
│   │   ├── __pycache__
│   │   │   ├── metric.cpython-36.pyc
│   │   │   ├── metric.cpython-37.pyc
│   │   │   ├── parser.cpython-36.pyc
│   │   │   ├── parser.cpython-37.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── rst_tree.cpython-36.pyc
│   │   │   ├── tree_obj.cpython-36.pyc
│   │   │   ├── masked_gru.cpython-36.pyc
│   │   │   ├── split_attn.cpython-36.pyc
│   │   │   ├── biaffine_attn.cpython-36.pyc
│   │   │   └── multi_head_attn.cpython-36.pyc
│   │   └── parser.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── file_util.pyc
│   │   ├── __pycache__
│   │   │   ├── eval.cpython-36.pyc
│   │   │   ├── drawer.cpython-36.pyc
│   │   │   ├── drawer.cpython-37.pyc
│   │   │   ├── radam.cpython-36.pyc
│   │   │   ├── radam.cpython-37.pyc
│   │   │   ├── DL_tricks.cpython-36.pyc
│   │   │   ├── DL_tricks.cpython-37.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── file_util.cpython-36.pyc
│   │   │   ├── file_util.cpython-37.pyc
│   │   │   ├── patterns.cpython-36.pyc
│   │   │   ├── patterns.cpython-37.pyc
│   │   │   ├── rst_utils.cpython-36.pyc
│   │   │   ├── rst_utils.cpython-37.pyc
│   │   │   ├── data_builder.cpython-36.pyc
│   │   │   ├── data_builder.cpython-37.pyc
│   │   │   └── model_util.cpython-37.pyc
│   │   ├── model_util.py
│   │   ├── DL_tricks.py
│   │   ├── patterns.py
│   │   ├── file_util.py
│   │   ├── drawer.py
│   │   ├── eval.py
│   │   └── data_builder.py
│   ├── structure
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── rst.cpython-36.pyc
│   │   │   ├── rst.cpython-37.pyc
│   │   │   ├── nodes.cpython-36.pyc
│   │   │   ├── nodes.cpython-37.pyc
│   │   │   ├── vocab.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── rst_tree.cpython-36.pyc
│   │   │   ├── rst_tree.cpython-37.pyc
│   │   │   ├── tree_obj.cpython-36.pyc
│   │   │   └── tree_obj.cpython-37.pyc
│   │   ├── rst.py
│   │   ├── tree_obj.py
│   │   └── vocab.py
│   ├── requirements.txt
│   ├── __pycache__
│   │   ├── config.cpython-36.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── path_config.cpython-36.pyc
│   │   └── path_config.cpython-37.pyc
│   ├── .idea
│   │   ├── misc.xml
│   │   ├── modules.xml
│   │   ├── en_dp_gan.iml
│   │   └── inspectionProfiles
│   │       └── Project_Default.xml
│   ├── main.py
│   ├── config.py
│   └── path_config.py
├── ch_dp_gan_xlnet
│   ├── util
│   │   ├── __init__.py
│   │   ├── model_util.py
│   │   ├── ltp.py
│   │   ├── berkely.py
│   │   └── file_util.py
│   ├── treebuilder
│   │   ├── __init__.py
│   │   └── dp_nr
│   │       ├── __init__.py
│   │       └── parser.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── cdtb.py
│   ├── structure
│   │   ├── __init__.py
│   │   └── vocab.py
│   ├── .gitmodules
│   ├── requirements.txt
│   ├── config.py
│   ├── interface.py
│   ├── new_ctb.py
│   ├── .gitignore
│   ├── berkeleyparser
│   │   └── README
│   └── pre_process.py
├── en_dp_gan_xlnet
│   ├── model
│   │   ├── __init__.py
│   │   ├── .DS_Store
│   │   ├── stacked_parser_tdt_xlnet
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── model.cpython-36.pyc
│   │   │   │   ├── model.cpython-37.pyc
│   │   │   │   ├── parser.cpython-36.pyc
│   │   │   │   ├── parser.cpython-37.pyc
│   │   │   │   ├── trainer.cpython-36.pyc
│   │   │   │   ├── trainer.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── __init__.cpython-37.pyc
│   │   │   └── parser.py
│   │   ├── __pycache__
│   │   │   ├── metric.cpython-36.pyc
│   │   │   ├── metric.cpython-37.pyc
│   │   │   ├── parser.cpython-36.pyc
│   │   │   ├── parser.cpython-37.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── masked_gru.cpython-36.pyc
│   │   │   ├── rst_tree.cpython-36.pyc
│   │   │   ├── split_attn.cpython-36.pyc
│   │   │   ├── tree_obj.cpython-36.pyc
│   │   │   ├── biaffine_attn.cpython-36.pyc
│   │   │   └── multi_head_attn.cpython-36.pyc
│   │   └── parser.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── con2dep.pyc
│   │   ├── file_util.pyc
│   │   ├── __pycache__
│   │   │   ├── drawer.cpython-36.pyc
│   │   │   ├── drawer.cpython-37.pyc
│   │   │   ├── eval.cpython-36.pyc
│   │   │   ├── radam.cpython-36.pyc
│   │   │   ├── radam.cpython-37.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── con2dep.cpython-36.pyc
│   │   │   ├── patterns.cpython-36.pyc
│   │   │   ├── patterns.cpython-37.pyc
│   │   │   ├── DL_tricks.cpython-36.pyc
│   │   │   ├── DL_tricks.cpython-37.pyc
│   │   │   ├── file_util.cpython-36.pyc
│   │   │   ├── file_util.cpython-37.pyc
│   │   │   ├── rst_utils.cpython-36.pyc
│   │   │   ├── rst_utils.cpython-37.pyc
│   │   │   ├── data_builder.cpython-36.pyc
│   │   │   └── data_builder.cpython-37.pyc
│   │   ├── DL_tricks.py
│   │   ├── patterns.py
│   │   ├── con2dep.py
│   │   ├── drawer.py
│   │   ├── file_util.py
│   │   └── eval.py
│   ├── structure
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── rst.cpython-36.pyc
│   │   │   ├── rst.cpython-37.pyc
│   │   │   ├── nodes.cpython-36.pyc
│   │   │   ├── nodes.cpython-37.pyc
│   │   │   ├── rst_xl.cpython-36.pyc
│   │   │   ├── rst_xl.cpython-37.pyc
│   │   │   ├── vocab.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── rst_tree.cpython-36.pyc
│   │   │   ├── rst_tree.cpython-37.pyc
│   │   │   ├── tree_obj.cpython-36.pyc
│   │   │   └── tree_obj.cpython-37.pyc
│   │   ├── rst.py
│   │   ├── rst_xl.py
│   │   ├── tree_obj.py
│   │   ├── vocab.py
│   │   └── cdtb.py
│   ├── .DS_Store
│   ├── requirements.txt
│   ├── __pycache__
│   │   ├── config.cpython-36.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── path_config.cpython-36.pyc
│   │   └── path_config.cpython-37.pyc
│   ├── .idea
│   │   ├── misc.xml
│   │   ├── modules.xml
│   │   ├── en_dp_gan_xlnet.iml
│   │   └── inspectionProfiles
│   │       └── Project_Default.xml
│   ├── main.py
│   ├── config.py
│   └── path_config.py
├── .DS_Store
└── README.md
/ch_dp_gan/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/en_dp_gan/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/en_dp_gan/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ch_dp_gan/treebuilder/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/treebuilder/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ch_dp_gan/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .cdtb import CDTB
4 |
--------------------------------------------------------------------------------
/ch_dp_gan/structure/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .nodes import *
4 |
--------------------------------------------------------------------------------
/en_dp_gan/structure/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .nodes import *
4 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .cdtb import CDTB
4 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/structure/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .nodes import *
4 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/structure/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .nodes import *
4 |
--------------------------------------------------------------------------------
/en_dp_gan/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk
2 | torch>=0.4
3 | thulac
4 | tqdm
5 | gensim
6 | numpy
7 | tensorboardX
8 | matplotlib
9 |
--------------------------------------------------------------------------------
/ch_dp_gan/treebuilder/dp_nr/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk
2 | torch>=0.4
3 | thulac
4 | tqdm
5 | gensim
6 | numpy
7 | tensorboardX
8 | matplotlib
9 |
--------------------------------------------------------------------------------
/ch_dp_gan/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "berkeleyparser"]
2 | path = berkeleyparser
3 | url = git@github.com:slavpetrov/berkeleyparser.git
4 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "berkeleyparser"]
2 | path = berkeleyparser
3 | url = git@github.com:slavpetrov/berkeleyparser.git
4 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/treebuilder/dp_nr/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
--------------------------------------------------------------------------------
/en_dp_gan/model/stacked_parser_tdt_gan3/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/model/stacked_parser_tdt_xlnet/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
--------------------------------------------------------------------------------
/ch_dp_gan/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk
2 | torch==0.4.1
3 | thulac
4 | tqdm
5 | gensim
6 | numpy
7 | tensorboardX
8 | matplotlib
9 | scikit-learn==0.20.0
10 | pyltp==0.2.1
11 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk
2 | torch==0.4.1
3 | thulac
4 | tqdm
5 | gensim
6 | numpy
7 | tensorboardX
8 | matplotlib
9 | scikit-learn==0.20.0
10 | pyltp==0.2.1
11 |
--------------------------------------------------------------------------------
/ch_dp_gan/util/model_util.py:
--------------------------------------------------------------------------------
1 |
2 | def get_parameter_number(model):
3 |     """Print the total and the trainable parameter counts of a torch model."""
4 |     total_num = sum(p.numel() for p in model.parameters())
5 |     trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
6 |     print('Total: ' + str(total_num), 'Trainable: ' + str(trainable_num))
7 |
--------------------------------------------------------------------------------
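A minimal usage sketch for get_parameter_number above (the nn.Linear model below is a hypothetical stand-in, not a module from this repo):

    import torch.nn as nn
    from util.model_util import get_parameter_number

    model = nn.Linear(10, 2)     # 20 weights + 2 biases = 22 parameters
    get_parameter_number(model)  # prints: Total: 22 Trainable: 22

--------------------------------------------------------------------------------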
/en_dp_gan/util/model_util.py:
--------------------------------------------------------------------------------
1 |
2 | def get_parameter_number(model):
3 |     """Print the total and the trainable parameter counts of a torch model."""
4 |     total_num = sum(p.numel() for p in model.parameters())
5 |     trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
6 |     print('Total: ' + str(total_num), 'Trainable: ' + str(trainable_num))
7 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/util/model_util.py:
--------------------------------------------------------------------------------
1 |
2 | def get_parameter_number(model):
3 |     """Print the total and the trainable parameter counts of a torch model."""
4 |     total_num = sum(p.numel() for p in model.parameters())
5 |     trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
6 |     print('Total: ' + str(total_num), 'Trainable: ' + str(trainable_num))
7 |
--------------------------------------------------------------------------------
/en_dp_gan/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | import logging
9 | from sys import argv
10 | from model.stacked_parser_tdt_gan3.trainer import Trainer
11 |
12 |
13 | if __name__ == '__main__':
14 |     logging.basicConfig(level=logging.INFO)
15 |     test_desc = argv[1] if len(argv) >= 2 else "no message."
16 |     trainer = Trainer()
17 |     trainer.train(test_desc)
18 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | import logging
9 | from sys import argv
10 | from model.stacked_parser_tdt_xlnet.trainer import Trainer
11 |
12 |
13 | if __name__ == '__main__':
14 |     logging.basicConfig(level=logging.INFO)
15 |     test_desc = argv[1] if len(argv) >= 2 else "no message."
16 |     trainer = Trainer()
17 |     trainer.train(test_desc)
18 |
--------------------------------------------------------------------------------
/en_dp_gan/util/DL_tricks.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Lynx Zhang
3 | Date: 2020.
4 | Email: zzlynx@outlook.com
5 | """
6 | import torch
7 |
8 |
9 | def weights_init_normal(m):
10 |     """DCGAN-style init: N(0, 0.02) for Conv/Linear layers, N(1, 0.02) for BatchNorm2d."""
11 |     classname = m.__class__.__name__
12 |     if (classname.find("Conv") != -1) or (classname.find("Linear") != -1):
13 |         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
14 |     elif classname.find("BatchNorm2d") != -1:
15 |         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
16 |         torch.nn.init.constant_(m.bias.data, 0.0)
17 |
18 |
19 | def log(x):
20 |     """Numerically stable log: the 1e-8 epsilon keeps log(0) from returning -inf."""
21 |     return torch.log(x + 1e-8)
22 |
--------------------------------------------------------------------------------
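A short usage sketch for the two helpers above (the nn.Sequential network is a hypothetical example, not part of this repo):

    import torch
    import torch.nn as nn
    from util.DL_tricks import weights_init_normal, log

    # apply() visits every submodule; Conv/Linear/BatchNorm2d layers get re-initialised.
    net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
    net.apply(weights_init_normal)

    # The epsilon inside log() keeps a zero probability finite (log(1e-8) is about -18.4).
    safe = log(torch.zeros(3))

--------------------------------------------------------------------------------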
/en_dp_gan_xlnet/util/DL_tricks.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Lynx Zhang
3 | Date: 2020.
4 | Email: zzlynx@outlook.com
5 | """
6 | import torch
7 |
8 |
9 | def weights_init_normal(m):
10 | classname = m.__class__.__name__
11 | if (classname.find("Conv") != -1) or (classname.find("Linear") != -1):
12 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
13 | elif classname.find("BatchNorm2d") != -1:
14 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
15 | torch.nn.init.constant_(m.bias.data, 0.0)
16 |
17 |
18 | def log(x):
19 | return torch.log(x + 1e-8)
20 |
--------------------------------------------------------------------------------
/ch_dp_gan/config.py:
--------------------------------------------------------------------------------
1 | # UTF-8
2 | # Author: Longyin Zhang
3 | # Date:
4 |
5 | from util.file_util import *
6 |
7 | SET = 66
8 | nr2ids = load_data("data/nr2ids.pkl")
9 | ids2nr = load_data("data/ids2nr.pkl")
10 | SEED = 19 # 17, 7, 14, 2, 13, 19, 21, 22
11 | NR_LABEL = 46
12 | USE_GAN = True
13 | MAX_H, MAX_W = 15, 25
14 | in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 16, 3, MAX_W, 1
15 | p_w_G, p_h_G = 3, 1
16 | MAX_POOLING = True
17 | WARM_UP_EP = 7 if USE_GAN else 20
18 | LABEL_SWITCH, SWITCH_ITE = False, 4
19 | USE_BOUND, BOUND_INFO_SIZE = False, 30
20 | LAYER_NORM_USE = True
21 | CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = False, 128, 2
22 |
--------------------------------------------------------------------------------
/en_dp_gan/util/patterns.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | import re
9 |
10 | leaf_pattern = r' *\( \w+ \(leaf .+'
11 | leaf_re = re.compile(leaf_pattern)
12 | node_pattern = r' *\( \w+ \(span .+'
13 | node_re = re.compile(node_pattern)
14 | end_pattern = r'\s*\)\s*'
15 | end_re = re.compile(end_pattern)
16 | nodetype_pattern = r' *\( (\w+) .+'
17 | type_re = re.compile(nodetype_pattern)
18 | rel_pattern = r' *\( \w+ \(.+\) \(rel2par ([\w-]+).+'
19 | rel_re = re.compile(rel_pattern)
20 | node_leaf_pattern = r' *\( \w+ \((\w+) \d+.*\).+'
21 | node_leaf_re = re.compile(node_leaf_pattern)
22 | upper_pattern = r'[a-z].*'
23 | upper_re = re.compile(upper_pattern)
24 |
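A sketch of how these patterns match lines of an RST-DT *.dis tree file; the sample line below is illustrative:

    from util.patterns import leaf_re, rel_re

    line = '  ( Nucleus (leaf 1) (rel2par span) (text _!An example EDU.!_) )'
    assert leaf_re.match(line) is not None  # recognized as a leaf (EDU) line
    print(rel_re.match(line).group(1))      # -> span, the rel2par label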
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/util/patterns.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | import re
9 |
10 | leaf_pattern = r' *\( \w+ \(leaf .+'
11 | leaf_re = re.compile(leaf_pattern)
12 | node_pattern = r' *\( \w+ \(span .+'
13 | node_re = re.compile(node_pattern)
14 | end_pattern = r'\s*\)\s*'
15 | end_re = re.compile(end_pattern)
16 | nodetype_pattern = r' *\( (\w+) .+'
17 | type_re = re.compile(nodetype_pattern)
18 | rel_pattern = r' *\( \w+ \(.+\) \(rel2par ([\w-]+).+'
19 | rel_re = re.compile(rel_pattern)
20 | node_leaf_pattern = r' *\( \w+ \((\w+) \d+.*\).+'
21 | node_leaf_re = re.compile(node_leaf_pattern)
22 | upper_pattern = r'[a-z].*'
23 | upper_re = re.compile(upper_pattern)
24 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/config.py:
--------------------------------------------------------------------------------
1 | # UTF-8
2 | # Author: Longyin Zhang
3 | # Date:
4 |
5 | from util.file_util import *
6 |
7 | SET = 65
8 | nr2ids = load_data("data/nr2ids.pkl")
9 | ids2nr = load_data("data/ids2nr.pkl")
10 | SEED = 19 # 17, 7, 14, 2, 13, 19, 21, 22
11 | NR_LABEL = 46
12 | USE_GAN = True
13 | MAX_H, MAX_W = 15, 25
14 | in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 16, 3, MAX_W, 1
15 | p_w_G, p_h_G = 3, 1
16 | MAX_POOLING = True
17 | WARM_UP_EP = 7 if USE_GAN else 50
18 | LABEL_SWITCH, SWITCH_ITE = False, 4
19 | USE_BOUND, BOUND_INFO_SIZE = False, 30
20 | LAYER_NORM_USE = True
21 | CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = False, 128, 2
22 |
23 | EDU_ENCODE_V = 1
24 | Finetune_XL = True
25 | XLNET_PATH = "data/pytorch_model.bin"
26 | XLNET_SIZE, CHUNK_SIZE = 768, 512
27 |
--------------------------------------------------------------------------------
/ch_dp_gan/util/ltp.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import os
4 | import pyltp
5 |
6 | LTP_DATA_DIR = 'pub/pyltp_models'
7 | path_to_tagger = os.path.join(LTP_DATA_DIR, 'pos.model')
8 | path_to_parser = os.path.join(LTP_DATA_DIR, "parser.model")
9 |
10 |
11 | class LTPParser:
12 | def __init__(self):
13 | self.tagger = pyltp.Postagger()
14 | self.parser = pyltp.Parser()
15 | self.tagger.load(path_to_tagger)
16 | self.parser.load(path_to_parser)
17 |
18 |     def __enter__(self):
19 |         # the models were already loaded in __init__; loading them again
20 |         # here would only duplicate work, so just hand back the parser
21 |         return self
22 |
23 | def __exit__(self, exc_type, exc_val, exc_tb):
24 | self.tagger.release()
25 | self.parser.release()
26 |
27 | def parse(self, words):
28 | tags = self.tagger.postag(words)
29 | parse = self.parser.parse(words, tags)
30 | return parse
31 |
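A usage sketch, assuming the pyltp POS-tagging and dependency-parsing models have been downloaded to pub/pyltp_models; the word list is illustrative:

    from util.ltp import LTPParser

    with LTPParser() as ltp:
        arcs = ltp.parse(["中国", "进出口", "银行"])
        print([(arc.head, arc.relation) for arc in arcs])  # head index and relation per word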
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/util/ltp.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import os
4 | import pyltp
5 |
6 | LTP_DATA_DIR = 'pub/pyltp_models'
7 | path_to_tagger = os.path.join(LTP_DATA_DIR, 'pos.model')
8 | path_to_parser = os.path.join(LTP_DATA_DIR, "parser.model")
9 |
10 |
11 | class LTPParser:
12 | def __init__(self):
13 | self.tagger = pyltp.Postagger()
14 | self.parser = pyltp.Parser()
15 | self.tagger.load(path_to_tagger)
16 | self.parser.load(path_to_parser)
17 |
18 |     def __enter__(self):
19 |         # the models were already loaded in __init__; loading them again
20 |         # here would only duplicate work, so just hand back the parser
21 |         return self
22 |
23 | def __exit__(self, exc_type, exc_val, exc_tb):
24 | self.tagger.release()
25 | self.parser.release()
26 |
27 | def parse(self, words):
28 | tags = self.tagger.postag(words)
29 | parse = self.parser.parse(words, tags)
30 | return parse
31 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/util/con2dep.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 |
9 | dep_tuples, dep_idx = [], -1
10 |
11 |
12 | def con2dep(root):
13 | global dep_tuples, dep_idx
14 | dep_tuples, dep_idx = [], -1
15 | recursive_one(root)
16 | return dep_tuples
17 |
18 |
19 | def recursive_one(root):
20 | global dep_tuples, dep_idx
21 | if root.left_child is not None:
22 | idx1 = recursive_one(root.left_child)
23 | idx2 = recursive_one(root.right_child)
24 | later = max(idx1, idx2)
25 | pre = min(idx1, idx2)
26 | if root.child_NS_rel == "SN":
27 | dep_tuples.append((later, pre, 0, root.child_rel))
28 | return idx2
29 | else:
30 | dep_tuples.append((later, pre, 1, root.child_rel))
31 | return idx1
32 | else:
33 | dep_idx += 1
34 | return dep_idx
35 |
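A sketch of the conversion with duck-typed nodes (real inputs are rst_tree objects carrying the same attributes); it assumes the script runs from the project root:

    from types import SimpleNamespace as Node
    from util.con2dep import con2dep

    leaf = lambda: Node(left_child=None, right_child=None)  # an EDU
    root = Node(left_child=leaf(), right_child=leaf(),
                child_NS_rel="NS", child_rel="elaboration")
    # NS: the later EDU (1) attaches to the earlier nucleus (0)
    print(con2dep(root))  # -> [(1, 0, 1, 'elaboration')]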
--------------------------------------------------------------------------------
/ch_dp_gan/interface.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | from abc import abstractmethod
3 |
4 | from typing import List
5 | from structure.nodes import Sentence, Paragraph, EDU
6 |
7 |
8 | class SegmenterI:
9 |     @abstractmethod
10 |     def cut(self, text: str) -> Paragraph:
11 |         raise NotImplementedError()
12 | 
13 |     @abstractmethod
14 |     def cut_sent(self, text: str, sid=None) -> List[Sentence]:
15 |         raise NotImplementedError()
16 | 
17 |     @abstractmethod
18 |     def cut_edu(self, sent: Sentence) -> List[EDU]:
19 |         raise NotImplementedError()
20 | 
21 | 
22 | class ParserI:
23 |     @abstractmethod
24 |     def parse(self, para: Paragraph) -> Paragraph:
25 |         raise NotImplementedError()
26 | 
27 | 
28 | class PipelineI:
29 |     @abstractmethod
30 |     def cut_sent(self, text: str) -> Paragraph:
31 |         raise NotImplementedError()
32 | 
33 |     @abstractmethod
34 |     def cut_edu(self, para: Paragraph) -> Paragraph:
35 |         raise NotImplementedError()
36 | 
37 |     @abstractmethod
38 |     def parse(self, para: Paragraph) -> Paragraph:
39 |         raise NotImplementedError()
40 | 
41 |     @abstractmethod
42 |     def full_parse(self, text: str):
43 |         raise NotImplementedError()
44 | 
45 |     def __call__(self, text: str):
46 |         return self.full_parse(text)
47 |
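Note these interfaces do not use ABCMeta, so instantiation is not blocked; only calling an unimplemented method raises. A hypothetical sketch of the __call__ shortcut (assuming structure.nodes is importable):

    from interface import PipelineI

    class ToyPipeline(PipelineI):
        def full_parse(self, text: str):
            return "parsed: " + text  # a real pipeline returns a Paragraph

    print(ToyPipeline()("some text"))  # __call__ delegates to full_parse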
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/interface.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | from abc import abstractmethod
3 |
4 | from typing import List
5 | from structure.nodes import Sentence, Paragraph, EDU
6 |
7 |
8 | class SegmenterI:
9 |     @abstractmethod
10 |     def cut(self, text: str) -> Paragraph:
11 |         raise NotImplementedError()
12 | 
13 |     @abstractmethod
14 |     def cut_sent(self, text: str, sid=None) -> List[Sentence]:
15 |         raise NotImplementedError()
16 | 
17 |     @abstractmethod
18 |     def cut_edu(self, sent: Sentence) -> List[EDU]:
19 |         raise NotImplementedError()
20 | 
21 | 
22 | class ParserI:
23 |     @abstractmethod
24 |     def parse(self, para: Paragraph) -> Paragraph:
25 |         raise NotImplementedError()
26 | 
27 | 
28 | class PipelineI:
29 |     @abstractmethod
30 |     def cut_sent(self, text: str) -> Paragraph:
31 |         raise NotImplementedError()
32 | 
33 |     @abstractmethod
34 |     def cut_edu(self, para: Paragraph) -> Paragraph:
35 |         raise NotImplementedError()
36 | 
37 |     @abstractmethod
38 |     def parse(self, para: Paragraph) -> Paragraph:
39 |         raise NotImplementedError()
40 | 
41 |     @abstractmethod
42 |     def full_parse(self, text: str):
43 |         raise NotImplementedError()
44 | 
45 |     def __call__(self, text: str):
46 |         return self.full_parse(text)
47 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/structure/rst.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | from config import *
9 | from util.file_util import *
10 |
11 |
12 | class RST:
13 | def __init__(self, train_set=TRAIN_SET, dev_set=DEV_SET, test_set=TEST_SET):
14 | self.train = load_data(train_set)
15 | # self.train_rst = load_data(RST_TRAIN_TREES_RST)
16 | self.dev = load_data(dev_set)
17 | # self.dev_rst = load_data(RST_DEV_TREES_RST)
18 | self.test = load_data(test_set)
19 | # self.test_rst = load_data(RST_TEST_TREES_RST)
20 |
21 | self.tree_obj_train = load_data(RST_TRAIN_TREES)
22 | self.tree_obj_dev = load_data(RST_DEV_TREES)
23 | self.tree_obj_test = load_data(RST_TEST_TREES)
24 | self.edu_tree_obj_test = load_data(RST_TEST_EDUS_TREES)
25 | if TRAIN_Ext:
26 | ext_trees = load_data(EXT_TREES)
27 | ext_set = load_data(EXT_SET)
28 | self.train = self.train + ext_set
29 | self.tree_obj_train = self.tree_obj_train + ext_trees
30 |
31 | @staticmethod
32 | def get_voc_labels():
33 | word2ids = load_data(VOC_WORD2IDS_PATH)
34 | pos2ids = load_data(POS_word2ids_PATH)
35 | return word2ids, pos2ids
36 |
37 | @staticmethod
38 | def get_dev():
39 | return load_data(RST_DEV_TREES)
40 |
41 | @staticmethod
42 | def get_test():
43 | if USE_AE:
44 | return load_data(RST_TEST_EDUS_TREES_)
45 | else:
46 | return load_data(RST_TEST_TREES)
47 |
--------------------------------------------------------------------------------
/en_dp_gan/structure/rst.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | from config import *
9 | from util.file_util import *
10 |
11 |
12 | class RST:
13 | def __init__(self, train_set=TRAIN_SET, dev_set=DEV_SET, test_set=TEST_SET):
14 | self.train = load_data(train_set)
15 | # self.train_rst = load_data(RST_TRAIN_TREES_RST)
16 | self.dev = load_data(dev_set)
17 | # self.dev_rst = load_data(RST_DEV_TREES_RST)
18 | self.test = load_data(test_set)
19 | # self.test_rst = load_data(RST_TEST_TREES_RST)
20 |
21 | # tree_obj
22 | self.tree_obj_train = load_data(RST_TRAIN_TREES)
23 | self.tree_obj_dev = load_data(RST_DEV_TREES)
24 | self.tree_obj_test = load_data(RST_TEST_TREES)
25 | self.edu_tree_obj_test = load_data(RST_TEST_EDUS_TREES)
26 |
27 | if TRAIN_Ext:
28 | ext_trees = load_data(EXT_TREES)
29 | ext_set = load_data(EXT_SET)
30 | self.train = self.train + ext_set
31 | self.tree_obj_train = self.tree_obj_train + ext_trees
32 |
33 | @staticmethod
34 | def get_voc_labels():
35 | word2ids = load_data(VOC_WORD2IDS_PATH)
36 | pos2ids = load_data(POS_word2ids_PATH)
37 | return word2ids, pos2ids
38 |
39 | @staticmethod
40 | def get_dev():
41 | return load_data(RST_DEV_TREES)
42 |
43 | @staticmethod
44 | def get_test():
45 | if USE_AE:
46 | return load_data(RST_TEST_EDUS_TREES_)
47 | else:
48 | return load_data(RST_TEST_TREES)
49 |
--------------------------------------------------------------------------------
/ch_dp_gan/new_ctb.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import re
3 | import os
4 | from nltk.tree import Tree as ParseTree
5 | from util.berkely import BerkeleyParser
6 | from tqdm import tqdm
7 |
8 |
9 | if __name__ == '__main__':
10 | ctb_dir = "data/CTB"
11 | save_dir = "data/CTB_auto"
12 | encoding = "UTF-8"
13 | ctb = {}
14 |     s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL)
15 | parser = BerkeleyParser()
16 | for file in tqdm(os.listdir(ctb_dir)):
17 | if os.path.isfile(os.path.join(save_dir, file)):
18 | continue
19 | print(file)
20 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd:
21 | doc = fd.read()
22 | parses = []
23 | for match in s_pat.finditer(doc):
24 | sid = match.group("sid")
25 | sparse = ParseTree.fromstring(match.group("sparse"))
26 | pairs = [(node[0], node.label()) for node in sparse.subtrees()
27 | if node.height() == 2 and node.label() != "-NONE-"]
28 | words, tags = list(zip(*pairs))
29 | print(sid, " ".join(words))
30 | if sid == "5133":
31 | parse = sparse
32 | else:
33 | parse = parser.parse(words, timeout=2000)
34 | parses.append((sid, parse))
35 | with open(os.path.join(save_dir, file), "w+", encoding=encoding) as save_fd:
36 | for sid, parse in parses:
37 |                 save_fd.write("<S ID=%s>\n" % sid)
38 |                 save_fd.write(str(parse))
39 |                 save_fd.write("</S>\n\n")
40 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/new_ctb.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import re
3 | import os
4 | from nltk.tree import Tree as ParseTree
5 | from util.berkely import BerkeleyParser
6 | from tqdm import tqdm
7 |
8 |
9 | if __name__ == '__main__':
10 | ctb_dir = "data/CTB"
11 | save_dir = "data/CTB_auto"
12 | encoding = "UTF-8"
13 | ctb = {}
14 |     s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL)
15 | parser = BerkeleyParser()
16 | for file in tqdm(os.listdir(ctb_dir)):
17 | if os.path.isfile(os.path.join(save_dir, file)):
18 | continue
19 | print(file)
20 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd:
21 | doc = fd.read()
22 | parses = []
23 | for match in s_pat.finditer(doc):
24 | sid = match.group("sid")
25 | sparse = ParseTree.fromstring(match.group("sparse"))
26 | pairs = [(node[0], node.label()) for node in sparse.subtrees()
27 | if node.height() == 2 and node.label() != "-NONE-"]
28 | words, tags = list(zip(*pairs))
29 | print(sid, " ".join(words))
30 | if sid == "5133":
31 | parse = sparse
32 | else:
33 | parse = parser.parse(words, timeout=2000)
34 | parses.append((sid, parse))
35 | with open(os.path.join(save_dir, file), "w+", encoding=encoding) as save_fd:
36 | for sid, parse in parses:
37 |                 save_fd.write("<S ID=%s>\n" % sid)
38 |                 save_fd.write(str(parse))
39 |                 save_fd.write("</S>\n\n")
40 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/structure/rst_xl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | from config import *
9 | from util.file_util import *
10 |
11 |
12 | class RST:
13 | def __init__(self, train_set=TRAIN_SET_XL, dev_set=DEV_SET_XL, test_set=TEST_SET_XL):
14 | self.train = load_data(train_set)
15 | self.train_rst = load_data(RST_TRAIN_TREES_RST)
16 | self.dev = load_data(dev_set)
17 | self.dev_rst = load_data(RST_DEV_TREES_RST)
18 | self.test = load_data(test_set)
19 | self.test_rst = load_data(RST_TEST_TREES_RST)
20 |
21 | self.tree_obj_train = load_data(RST_TRAIN_TREES)
22 | self.tree_obj_dev = load_data(RST_DEV_TREES)
23 | self.tree_obj_test = load_data(RST_TEST_TREES)
24 | self.edu_tree_obj_test = load_data(RST_TEST_EDUS_TREES)
25 |
26 | if TRAIN_Ext:
27 | ext_trees = load_data(EXT_TREES)
28 | ext_trees_rst = load_data(EXT_TREES_RST)
29 | ext_set = load_data(EXT_SET_XL)
30 | self.train = self.train + ext_set
31 | self.train_rst = self.train_rst + ext_trees_rst
32 | self.tree_obj_train = self.tree_obj_train + ext_trees
33 |
34 | @staticmethod
35 | def get_voc_labels():
36 | word2ids = load_data(VOC_WORD2IDS_PATH)
37 | pos2ids = load_data(POS_word2ids_PATH)
38 | return word2ids, pos2ids
39 |
40 | @staticmethod
41 | def get_dev():
42 | return load_data(RST_DEV_TREES)
43 |
44 | @staticmethod
45 | def get_test():
46 | if USE_AE:
47 | return load_data(RST_TEST_EDUS_TREES)
48 | else:
49 | return load_data(RST_TEST_TREES)
50 |
--------------------------------------------------------------------------------
/en_dp_gan/structure/tree_obj.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: lyzhang
5 | @Date: 2018/4/5
6 | @Description: tree.edus in the form of ids
7 | tree.nodes in the form of rst_tree objective
8 | """
9 | import copy
10 |
11 |
12 | class tree_obj:
13 | def __init__(self, tree=None):
14 | self.edus = list()
15 | self.nodes = list()
16 | self.sents_edus = list()
17 | if tree is not None:
18 | self.file_name = tree.file_name
19 | self.tree = tree
20 | self.pre_traverse(tree)
21 |
22 | def __copy__(self):
23 | tree_ = copy.copy(self.tree)
24 | edus_ = [copy.copy(edu) for edu in self.edus]
25 | nodes_ = [copy.copy(node) for node in self.nodes]
26 | t_o = tree_obj(tree_)
27 | t_o.edus = edus_
28 | t_o.nodes = nodes_
29 | return t_o
30 |
31 | def assign_edus(self, edus_list):
32 | for edu in edus_list:
33 | self.edus.append(edu)
34 |
35 | def pre_traverse(self, root):
36 | if root is None:
37 | return
38 | self.pre_traverse(root.left_child)
39 | self.pre_traverse(root.right_child)
40 |         # post-order leaf check: a node with no children is an EDU
41 | if root.left_child is None and root.right_child is None:
42 | self.edus.append(root)
43 | self.nodes.append(root)
44 | else:
45 | self.nodes.append(root)
46 |
47 | def get_sents_edus(self):
48 | tmp_sent_edus = list()
49 | for edu in self.edus:
50 | tmp_sent_edus.append(edu)
51 | if edu.edu_node_boundary:
52 | self.sents_edus.append(tmp_sent_edus)
53 | tmp_sent_edus = list()
54 |
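A sketch of tree_obj over a duck-typed two-EDU tree (real inputs are rst_tree objects); the attribute names match pre_traverse above:

    from types import SimpleNamespace as N
    from structure.tree_obj import tree_obj

    leaf = lambda: N(left_child=None, right_child=None)
    root = N(left_child=leaf(), right_child=leaf(), file_name="toy.dis")
    t = tree_obj(root)
    print(len(t.edus), len(t.nodes))  # -> 2 3: two EDUs plus the root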
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/structure/tree_obj.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: lyzhang
5 | @Date: 2018/4/5
6 | @Description: tree.edus in the form of ids
7 | tree.nodes in the form of rst_tree objective
8 | """
9 | import copy
10 |
11 |
12 | class tree_obj:
13 | def __init__(self, tree=None):
14 | self.edus = list()
15 | self.nodes = list()
16 | self.sents_edus = list()
17 | if tree is not None:
18 | self.file_name = tree.file_name
19 | # init by tree
20 | self.tree = tree
21 | self.pre_traverse(tree)
22 | # self.get_sents_edus()
23 |
24 | def __copy__(self):
25 | tree_ = copy.copy(self.tree)
26 | edus_ = [copy.copy(edu) for edu in self.edus]
27 | nodes_ = [copy.copy(node) for node in self.nodes]
28 | t_o = tree_obj(tree_)
29 | t_o.edus = edus_
30 | t_o.nodes = nodes_
31 | return t_o
32 |
33 | def assign_edus(self, edus_list):
34 | for edu in edus_list:
35 | self.edus.append(edu)
36 |
37 | def pre_traverse(self, root):
38 | if root is None:
39 | return
40 | self.pre_traverse(root.left_child)
41 | self.pre_traverse(root.right_child)
42 | if root.left_child is None and root.right_child is None:
43 | self.edus.append(root)
44 | self.nodes.append(root)
45 | else:
46 | self.nodes.append(root)
47 |
48 | def get_sents_edus(self):
49 | tmp_sent_edus = list()
50 | for edu in self.edus:
51 | tmp_sent_edus.append(edu)
52 | if edu.edu_node_boundary:
53 | self.sents_edus.append(tmp_sent_edus)
54 | tmp_sent_edus = list()
55 |
--------------------------------------------------------------------------------
/ch_dp_gan/util/berkely.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import subprocess
3 | import thulac
4 | import threading
5 | import os
6 | from nltk.tree import Tree
7 |
8 |
9 | BERKELEY_JAR = "berkeleyparser/BerkeleyParser-1.7.jar"
10 | BERKELEY_GRAMMAR = "berkeleyparser/chn_sm5.gr"
11 |
12 |
13 | class BerkeleyParser(object):
14 | def __init__(self):
15 | self.tokenizer = thulac.thulac()
16 | self.cmd = ['java', '-Xmx1024m', '-jar', BERKELEY_JAR, '-gr', BERKELEY_GRAMMAR]
17 | self.process = self.start()
18 |
19 | def start(self):
20 | return subprocess.Popen(self.cmd, env=dict(os.environ), universal_newlines=True, shell=False, bufsize=0,
21 | stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, errors='ignore')
22 |
23 | def stop(self):
24 | if self.process:
25 | self.process.terminate()
26 |
27 | def restart(self):
28 | self.stop()
29 | self.process = self.start()
30 |
31 | def parse_thread(self, text, results):
32 | text = text.replace("(", '-LRB-')
33 | text = text.replace(")", '-RRB-')
34 | self.process.stdin.write(text + '\n')
35 | self.process.stdin.flush()
36 | ret = self.process.stdout.readline().strip()
37 | results.append(ret)
38 |
39 |     def parse(self, words, timeout=20000):  # timeout: seconds passed to Thread.join
40 |         # words, _ = list(zip(*self.tokenizer.cut(text)))
41 |         results = []
42 |         t = threading.Thread(target=self.parse_thread, kwargs={'text': " ".join(words), 'results': results})
43 |         t.daemon = True  # replaces the deprecated t.setDaemon(True)
44 | t.start()
45 | t.join(timeout)
46 |
47 | if not results:
48 | self.restart()
49 | raise TimeoutError()
50 | else:
51 | return Tree.fromstring(results[0])
52 |
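A usage sketch, assuming BerkeleyParser-1.7.jar and the chn_sm5.gr grammar sit under berkeleyparser/ and java is on PATH; the word list is illustrative:

    from util.berkely import BerkeleyParser

    parser = BerkeleyParser()
    try:
        tree = parser.parse(["我们", "讨论", "问题"], timeout=60)  # seconds
        print(tree)
    finally:
        parser.stop()  # terminate the background java process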
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/util/berkely.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import subprocess
3 | import thulac
4 | import threading
5 | import os
6 | from nltk.tree import Tree
7 |
8 |
9 | BERKELEY_JAR = "berkeleyparser/BerkeleyParser-1.7.jar"
10 | BERKELEY_GRAMMAR = "berkeleyparser/chn_sm5.gr"
11 |
12 |
13 | class BerkeleyParser(object):
14 | def __init__(self):
15 | self.tokenizer = thulac.thulac()
16 | self.cmd = ['java', '-Xmx1024m', '-jar', BERKELEY_JAR, '-gr', BERKELEY_GRAMMAR]
17 | self.process = self.start()
18 |
19 | def start(self):
20 | return subprocess.Popen(self.cmd, env=dict(os.environ), universal_newlines=True, shell=False, bufsize=0,
21 | stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, errors='ignore')
22 |
23 | def stop(self):
24 | if self.process:
25 | self.process.terminate()
26 |
27 | def restart(self):
28 | self.stop()
29 | self.process = self.start()
30 |
31 | def parse_thread(self, text, results):
32 | text = text.replace("(", '-LRB-')
33 | text = text.replace(")", '-RRB-')
34 | self.process.stdin.write(text + '\n')
35 | self.process.stdin.flush()
36 | ret = self.process.stdout.readline().strip()
37 | results.append(ret)
38 |
39 |     def parse(self, words, timeout=20000):  # timeout: seconds passed to Thread.join
40 |         # words, _ = list(zip(*self.tokenizer.cut(text)))
41 |         results = []
42 |         t = threading.Thread(target=self.parse_thread, kwargs={'text': " ".join(words), 'results': results})
43 |         t.daemon = True  # replaces the deprecated t.setDaemon(True)
44 | t.start()
45 | t.join(timeout)
46 |
47 | if not results:
48 | self.restart()
49 | raise TimeoutError()
50 | else:
51 | return Tree.fromstring(results[0])
52 |
--------------------------------------------------------------------------------
/ch_dp_gan/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # pycharm
107 | .idea/
108 |
109 | # data
110 | data/
111 |
112 | # pyltp models
113 | pub/pyltp_models
114 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # pycharm
107 | .idea/
108 |
109 | # data
110 | data/
111 |
112 | # pyltp models
113 | pub/pyltp_models
114 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/util/drawer.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
4 | def draw_line(x, y_s, colors, labels, shapes, begins, lines_d, l_colors, step=1, x_name="STEP", y_name="LOSS"):
5 | fig = plt.figure(figsize=(6, 4), facecolor="white")
6 | ax1 = fig.add_subplot(1, 1, 1)
7 | ax1.set_xlabel(x_name)
8 | ax1.set_ylabel(y_name)
9 | for idx, y in enumerate(y_s):
10 | x_ = x[begins[idx]:]
11 | y_ = y[begins[idx]:]
12 | ax1.scatter(x_[::step], y_[::step], color=colors[idx], marker=shapes[idx], edgecolors=colors[idx], s=20, label=labels[idx])
13 | if lines_d[idx]:
14 | line_general = 20
15 | x_n = x_[::line_general]
16 | y_n = []
17 | sum_ = 0.
18 | sum_idx_num = 0
19 | for idx__, y__ in enumerate(y_):
20 | sum_idx_num += 1
21 | sum_ += y__
22 | if idx__ > 0 and idx__ % line_general == 0:
23 | y_n.append(sum_ / line_general)
24 | sum_ = 0.
25 | sum_idx_num = 0
26 | if sum_ > 0.:
27 | y_n.append(sum_ / sum_idx_num)
28 | if idx == 0:
29 | ax1.plot(x_n[:29], y_n[:29], color="PaleGreen", linewidth=2.6, linestyle="-")
30 | ax1.plot(x_n[28:], y_n[28:], color="Green", linewidth=2.6, linestyle="-")
31 | else:
32 | ax1.plot(x_n, y_n, color=l_colors[idx], linewidth=2.6, linestyle="-")
33 | plt.vlines(542, 0, 2, colors="black", linestyles="dashed", linewidth=2)
34 | plt.annotate("warm up", xy=(300, 0.45), xytext=(250, 0.15), arrowprops=dict(facecolor="white", headlength=4,
35 | headwidth=13, width=4))
36 | plt.legend(loc="upper left")
37 | plt.grid(linestyle='-')
38 | plt.show()
39 |
40 |
41 | if __name__ == "__main__":
42 |     draw_line([1, 2, 3], [[3, 2, 1]], ["red"], ["A"], ["o"], [0], [False], ["red"])
43 |
--------------------------------------------------------------------------------
/en_dp_gan/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Since many different experiments were run on this project, some of the model settings below may be unused.
4 | Longyin Zhang
5 | """
6 | from path_config import *
7 | from util.file_util import *
8 |
9 | SAVE_MODEL = True
10 | VERSION, SET = 5, 97
11 | USE_CUDA, CUDA_ID = True, 3
12 | USE_BOUND, USE_GAN, USE_S_STACK = True, False, False
13 |
14 | WARM_UP_EP = 7 if USE_GAN else 20
15 | MAX_W, MAX_H = 80, 20
16 | LEARN_RATE, WD, D_LR = 0.001, 1e-4, 0.0001
17 | EPOCH, BATCH_SIZE, LOG_EVERY, VALIDATE_EVERY = 20, 5, 1, 5
18 | in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 8, 3, MAX_W, 1
19 | p_w_G, p_h_G = 3, 1
20 | METype = 1
21 | USE_AE = False
22 | TRAIN_Ext = False
23 |
24 | MAX_POOLING = True
25 | RANDOM_MASK_LEARN, RMR = False, 0.1
26 | LABEL_SWITCH, SWITCH_ITE = False, 4
27 | USE_POSE, POS_SIZE = True, 30
28 | USE_ELMo, EMBED_LEARN = True, False
29 | EMBED_SIZE, HIDDEN_SIZE = (1024 if USE_ELMo else 300), 256
30 | BOUND_INFO_SIZE = 30 # 30
31 | GATE_V = 1
32 | USE_LEAF, Re_DROP, GateDrop = True, 0.3, 0.2
33 | TEST_ONLY = False
34 | OPT_ATTN = False
35 | EDU_ATT, ML_ATT_HIDDEN_e, HEADS_e = False, 128, 2
36 | CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = False, 128, 2
37 | SPLIT_MLP_SIZE, NR_MLP_SIZE = 128, 128
38 | FEAT_W_SIZE, FEAT_P_SIZE, FEAT_Head_SIZE = 50, 30, 20
39 | LEN_SIZE, CENTER_SIZE, INNER_SIZE = 10, 30, 30
40 | FEAT_H_SIZE = 140
41 | USE_CNN, KERNEL_SIZE, PADDING_SIZE = True, 2, 1
42 | NUCL_MLP_SIZE, REL_MLP_SIZE = 64, 128 # Total basic.
43 | LAYER_NORM_USE = True
44 | ALPHA_SPAN, ALPHA_NR = 0.3, 1.0
45 | SEED = 36
46 | USE_R_ADAM = False
47 | WARM_UP = 20
48 | L2, DROP_OUT = 1e-5, 0.2
49 | BETA1, BETA2 = 0.9, 0.999
50 | TRAN_LABEL_NUM, NR_LABEL_NUM, NUCL_LABEL_NUM, REL_LABEL_NUM = 1, 42, 3, 18
51 | SHIFT, REDUCE = "SHIFT", "REDUCE"
52 | REDUCE_NN, REDUCE_NS, REDUCE_SN = "REDUCE-NN", "REDUCE-NS", "REDUCE-SN"
53 | NN, NS, SN = "NN", "NS", "SN"
54 | PAD, PAD_ids = "<PAD>", 0
55 | UNK, UNK_ids = "<UNK>", 1
56 | action2ids = {SHIFT: 0, REDUCE: 1}
57 | ids2action = {0: SHIFT, 1: REDUCE}
58 | nucl2ids = {NN: 0, NS: 1, SN: 2} if METype == 1 else {"N": 0, "S": 1}
59 | ids2nucl = {0: NN, 1: NS, 2: SN} if METype == 1 else {0: "N", 1: "S"}
60 | ns_dict = {"Satellite": 0, "Nucleus": 1, "Root": 2}
61 | ns_dict_ = {0: "Satellite", 1: "Nucleus", 2: "Root"}
62 | coarse2ids = load_data(REL_coarse2ids)
63 | ids2coarse = load_data(REL_ids2coarse)
64 | nr2ids = load_data(LABEL2IDS)
65 | ids2nr = load_data(IDS2LABEL)
66 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from path_config import *
3 | from util.file_util import *
4 | SAVE_MODEL = True
5 | VERSION, SET = 9, 147
6 | USE_CUDA, CUDA_ID = True, 0
7 | EDU_ENCODE_VERSION, SPLIT_V = 2, 1
8 | CHUNK_SIZE = 768 # 64
9 | MAX_LEN = 32
10 |
11 | USE_BOUND, USE_GAN, USE_S_STACK = True, True, False
12 | TRAIN_XLNET, XL_FINE, Joint_EDU_R, XLNET_TYPE, XLNET_SIZE = True, True, False, "xlnet-base-cased", 768
13 |
14 | MAX_W, MAX_H = 80, 20
15 | LEARN_RATE, WD, D_LR = 0.0001, 1e-4, 0.0001
16 | EPOCH, BATCH_SIZE, LOG_EVERY, VALIDATE_EVERY = 50, 1, 5, 20
17 | UPDATE_ITE = 32
18 | WARM_UP_EP = 7 if USE_GAN else EPOCH
19 | # CNN FEATURE
20 | in_channel_G, out_channel_G, ker_h_G, ker_w_G, strip_G = 2, 32, 3, MAX_W // 2, 1
21 | p_w_G, p_h_G = 3, 3
22 | METype = 1
23 | USE_AE = False
24 | TRAIN_Ext = False
25 |
26 | MAX_POOLING = True
27 | RANDOM_MASK_LEARN, RMR = False, 0.1
28 | LABEL_SWITCH, SWITCH_ITE = False, 4
29 | USE_POSE, POS_SIZE = False, 30
30 | USE_ELMo, EMBED_LEARN = False, False
31 | EMBED_SIZE, HIDDEN_SIZE = (1024 if USE_ELMo else 300), 384
32 | BOUND_INFO_SIZE = 30 # 30
33 |
34 | GATE_V = 1
35 | USE_LEAF, Re_DROP, GateDrop = True, 0.3, 0.2
36 |
37 | TEST_ONLY = False
38 | OPT_ATTN = False
39 |
40 | # ATTN
41 | EDU_ATT, ML_ATT_HIDDEN_e, HEADS_e = False, 128, 2
42 | CONTEXT_ATT, ML_ATT_HIDDEN, HEADS = True, 128, 2
43 |
44 | # MLP
45 | SPLIT_MLP_SIZE, NR_MLP_SIZE = 128, 128
46 |
47 | # feature engineer (feat_e_len, feat_e_w, feat_e_p, feat_e_h, feat_e_c, feat_e_inner)
48 | FEAT_W_SIZE, FEAT_P_SIZE, FEAT_Head_SIZE = 50, 30, 20
49 | LEN_SIZE, CENTER_SIZE, INNER_SIZE = 10, 30, 30
50 | FEAT_H_SIZE = 140
51 | USE_CNN, KERNEL_SIZE, PADDING_SIZE = True, 2, 1
52 | NUCL_MLP_SIZE, REL_MLP_SIZE = 64, 128 # Total basic.
53 | LAYER_NORM_USE = True
54 | ALPHA_SPAN, ALPHA_NR = 0.3, 1.0
55 | SEED = 19
56 | USE_R_ADAM = False
57 | WARM_UP = 20
58 | L2, DROP_OUT = 1e-5, 0.2
59 | BETA1, BETA2 = 0.9, 0.999
60 | TRAN_LABEL_NUM, NR_LABEL_NUM, NUCL_LABEL_NUM, REL_LABEL_NUM = 1, 42, 3, 18
61 | SHIFT, REDUCE = "SHIFT", "REDUCE"
62 | REDUCE_NN, REDUCE_NS, REDUCE_SN = "REDUCE-NN", "REDUCE-NS", "REDUCE-SN"
63 | NN, NS, SN = "NN", "NS", "SN"
64 | PAD, PAD_ids = "<PAD>", 0
65 | UNK, UNK_ids = "<UNK>", 1
66 | action2ids = {SHIFT: 0, REDUCE: 1}
67 | ids2action = {0: SHIFT, 1: REDUCE}
68 | nucl2ids = {NN: 0, NS: 1, SN: 2} if METype == 1 else {"N": 0, "S": 1}
69 | ids2nucl = {0: NN, 1: NS, 2: SN} if METype == 1 else {0: "N", 1: "S"}
70 | ns_dict = {"Satellite": 0, "Nucleus": 1, "Root": 2}
71 | ns_dict_ = {0: "Satellite", 1: "Nucleus", 2: "Root"}
72 | coarse2ids = load_data(REL_coarse2ids)
73 | ids2coarse = load_data(REL_ids2coarse)
74 | nr2ids = load_data(LABEL2IDS)
75 | ids2nr = load_data(IDS2LABEL)
76 | NR_LOSS_OPT = False
77 |
--------------------------------------------------------------------------------
/en_dp_gan/path_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | # RST-DT
9 | RST_DT_PATH = "data/rst_dt"
10 | RAW_RST_DT_PATH = "data/rst_dt/RST_DT_RAW"
11 |
12 | # EDUs of train & dev
13 | RST_EDUs_ori_tr = "data/EDUs/ori_train_edus.tsv"
14 | RST_EDUs_ori_de = "data/EDUs/ori_dev_edus.tsv"
15 | # EN EDUs of train & dev
16 | RST_EDUs_google_tr = "data/EDUs/translate_train.txt"
17 | RST_EDUs_google_de = "data/EDUs/translate_dev.tsv"
18 |
19 | EXT_TREES = "data/rst_dt/ext_trees.pkl"
20 | EXT_TREES_RST = "data/rst_dt/ext_trees_rst.pkl"
21 | EXT_SET = "data/rst_dt/ext_set.pkl"
22 |
23 | # RST-DT Generated
24 | RST_TRAIN_TREES = "data/rst_dt/train_trees.pkl"
25 | RST_TRAIN_TREES_RST = "data/rst_dt/train_trees_rst.pkl"
26 | TRAIN_SET = "data/rst_dt/train_set.pkl"
27 | RST_TEST_TREES = "data/rst_dt/test_trees.pkl"
28 | RST_TEST_TREES_RST = "data/rst_dt/test_trees_rst.pkl"
29 | RST_TEST_EDUS_TREES = "data/rst_dt/edus_test_trees.pkl"
30 | RST_TEST_EDUS_TREES_ = "data/rst_dt/edus_test_trees_.pkl"
31 | TEST_SET = "data/rst_dt/test_set.pkl"
32 | RST_DEV_TREES = "data/rst_dt/dev_trees.pkl"
33 | RST_DEV_TREES_RST = "data/rst_dt/dev_trees_rst.pkl"
34 | DEV_SET = "data/rst_dt/dev_set.pkl"
35 |
36 | TRAIN_SET_NR = "data/rst_dt/nr_train_set.pkl"
37 | TEST_SET_NR = "data/rst_dt/nr_test_set.pkl"
38 | DEV_SET_NR = "data/rst_dt/nr_dev_set.pkl"
39 |
40 | RST_TEST_TREES_N = "data/rst_dt/test_trees_n.pkl"
41 | RST_TEST_TREES_RST_N = "data/rst_dt/test_trees_rst_n.pkl"
42 | RST_TEST_EDUS_TREES_N = "data/rst_dt/edus_test_trees_n.pkl"
43 | TEST_SET_N = "data/rst_dt/test_n.set.pkl"
44 | TEST_SET_NR_N = "data/rst_dt/nr_test_n.set.pkl"
45 |
46 | # VOC
47 | REL_raw2coarse = "data/voc/raw2coarse.pkl"
48 | VOC_WORD2IDS_PATH = "data/voc/word2ids.pkl"
49 | POS_word2ids_PATH = "data/voc/pos2ids.pkl"
50 | REL_coarse2ids = "data/voc/coarse2ids.pkl"
51 | REL_ids2coarse = "data/voc/ids2coarse.pkl"
52 | VOC_VEC_PATH = "data/voc/ids2vec.pkl"
53 | LABEL2IDS = "data/voc/label2ids.pkl"
54 | IDS2LABEL = "data/voc/ids2label.pkl"
55 | options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/" \
56 | "elmo_2x4096_512_2048cnn_2xhighway_options.json"
57 | weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo" \
58 | "_2x4096_512_2048cnn_2xhighway_weights.hdf5"
59 | path_to_jar = 'stanford-corenlp-full-2018-02-27'
60 | MODELS2SAVE = "data/models_saved"
61 | MODEL_SAVE = "data/models/treebuilder.partptr.basic"
62 | LOG_ALL = "data/logs"
63 | PRE_DEPTH = "data/pre_depth.pkl."
64 | GOLD_DEPTH = "data/gold_depth.pkl"
65 | DRAW_DT = "data/drawer.pkl"
66 | LOSS_PATH = "data/loss/"
67 | EXT_SET_NEG = "data/rst_dt/neg_ext_set.pkl"
68 | TRAIN_SET_NEG = "data/rst_dt/neg_train_set.pkl"
69 | TEST_SET_NEG = "data/rst_dt/neg_test_set.pkl"
70 | DEV_SET_NEG = "data/rst_dt/neg_dev_set.pkl"
71 |
--------------------------------------------------------------------------------
/ch_dp_gan/util/file_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: lyzhang
5 | @Date: 2018/3/13
6 | @Description:
7 | """
8 |
9 | import urllib.request
10 | import gzip
11 | import os
12 | import shutil
13 | import pickle as pkl
14 |
15 |
16 | def safe_mkdir(path):
17 | """ Create a directory if there isn't one already. """
18 | try:
19 | os.mkdir(path)
20 | except OSError:
21 | pass
22 |
23 |
24 | def safe_mkdirs(path_list):
25 | """ Create a directory if there isn't one already. """
26 | for path in path_list:
27 | try:
28 | os.mkdir(path)
29 | except OSError:
30 | pass
31 |
32 |
33 | def download_one_file(download_url,
34 | local_dest,
35 | expected_byte=None,
36 | unzip_and_remove=False):
37 | """
38 | Download the file from download_url into local_dest
39 |     if the file doesn't already exist.
40 | If expected_byte is provided, check if
41 | the downloaded file has the same number of bytes.
42 | If unzip_and_remove is True, unzip the file and remove the zip file
43 | """
44 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]):
45 | print('%s already exists' % local_dest)
46 | else:
47 | print('Downloading %s' % download_url)
48 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest)
49 | file_stat = os.stat(local_dest)
50 | if expected_byte:
51 | if file_stat.st_size == expected_byte:
52 | print('Successfully downloaded %s' % local_dest)
53 | if unzip_and_remove:
54 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out:
55 | shutil.copyfileobj(f_in, f_out)
56 | os.remove(local_dest)
57 | else:
58 | print('The downloaded file has unexpected number of bytes')
59 |
60 |
61 | def save_data(obj, path, append=False):
62 |     if append:
63 |         with open(path, "ab") as f:  # "wb+" would truncate; "ab" truly appends
64 |             pkl.dump(obj, f)
65 | else:
66 | with open(path, "wb") as f:
67 | pkl.dump(obj, f)
68 |
69 |
70 | def load_data(path):
71 | with open(path, "rb") as f:
72 | obj = pkl.load(f)
73 | return obj
74 |
75 |
76 | def write_iterate(ite, file_path):
77 | with open(file_path, "w") as f:
78 | for line in ite:
79 | f.write(line + "\n")
80 |
81 |
82 | def write_append(txt, file_path):
83 | with open(file_path, "a") as f:
84 | f.write(txt + "\n")
85 |
86 |
87 | def write_over(txt, file_path):
88 | with open(file_path, "w") as f:
89 | f.write(txt + "\n")
90 |
91 |
92 | def print_(str_, log_file, write_=False):
93 | if write_:
94 | write_over(str_, log_file)
95 | else:
96 | write_append(str_, log_file)
97 | print(str_)
98 |
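A round-trip sketch for the pickle helpers; the path below is hypothetical and its directory must already exist:

    from util.file_util import save_data, load_data

    meta = {"epoch": 3, "dev_f1": 0.87}  # hypothetical payload
    save_data(meta, "data/checkpoint_meta.pkl")
    assert load_data("data/checkpoint_meta.pkl") == meta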
--------------------------------------------------------------------------------
/en_dp_gan/util/file_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: lyzhang
5 | @Date: 2018/3/13
6 | @Description:
7 | """
8 |
9 | import urllib.request
10 | import gzip
11 | import os
12 | import shutil
13 | import pickle as pkl
14 |
15 |
16 | def safe_mkdir(path):
17 | """ Create a directory if there isn't one already. """
18 | try:
19 | os.mkdir(path)
20 | except OSError:
21 | pass
22 |
23 |
24 | def safe_mkdirs(path_list):
25 | """ Create a directory if there isn't one already. """
26 | for path in path_list:
27 | try:
28 | os.mkdir(path)
29 | except OSError:
30 | pass
31 |
32 |
33 | def download_one_file(download_url,
34 | local_dest,
35 | expected_byte=None,
36 | unzip_and_remove=False):
37 | """
38 | Download the file from download_url into local_dest
39 |     if the file doesn't already exist.
40 | If expected_byte is provided, check if
41 | the downloaded file has the same number of bytes.
42 | If unzip_and_remove is True, unzip the file and remove the zip file
43 | """
44 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]):
45 | print('%s already exists' % local_dest)
46 | else:
47 | print('Downloading %s' % download_url)
48 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest)
49 | file_stat = os.stat(local_dest)
50 | if expected_byte:
51 | if file_stat.st_size == expected_byte:
52 | print('Successfully downloaded %s' % local_dest)
53 | if unzip_and_remove:
54 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out:
55 | shutil.copyfileobj(f_in, f_out)
56 | os.remove(local_dest)
57 | else:
58 | print('The downloaded file has unexpected number of bytes')
59 |
60 |
61 | def save_data(obj, path, append=False):
62 |     if append:
63 |         with open(path, "ab") as f:  # "wb+" would truncate; "ab" truly appends
64 |             pkl.dump(obj, f)
65 | else:
66 | with open(path, "wb") as f:
67 | pkl.dump(obj, f)
68 |
69 |
70 | def load_data(path):
71 | with open(path, "rb") as f:
72 | obj = pkl.load(f)
73 | return obj
74 |
75 |
76 | def write_iterate(ite, file_path):
77 | with open(file_path, "w") as f:
78 | for line in ite:
79 | f.write(line + "\n")
80 |
81 |
82 | def write_append(txt, file_path):
83 | with open(file_path, "a") as f:
84 | f.write(txt + "\n")
85 |
86 |
87 | def write_over(txt, file_path):
88 | with open(file_path, "w") as f:
89 | f.write(txt + "\n")
90 |
91 |
92 | def print_(str_, log_file, write_=False):
93 | if write_:
94 | write_over(str_, log_file)
95 | else:
96 | write_append(str_, log_file)
97 | print(str_)
98 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/util/file_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: lyzhang
5 | @Date: 2018/3/13
6 | @Description:
7 | """
8 |
9 | import urllib.request
10 | import gzip
11 | import os
12 | import shutil
13 | import pickle as pkl
14 |
15 |
16 | def safe_mkdir(path):
17 | """ Create a directory if there isn't one already. """
18 | try:
19 | os.mkdir(path)
20 | except OSError:
21 | pass
22 |
23 |
24 | def safe_mkdirs(path_list):
25 | """ Create a directory if there isn't one already. """
26 | for path in path_list:
27 | try:
28 | os.mkdir(path)
29 | except OSError:
30 | pass
31 |
32 |
33 | def download_one_file(download_url,
34 | local_dest,
35 | expected_byte=None,
36 | unzip_and_remove=False):
37 | """
38 | Download the file from download_url into local_dest
39 |     if the file doesn't already exist.
40 | If expected_byte is provided, check if
41 | the downloaded file has the same number of bytes.
42 | If unzip_and_remove is True, unzip the file and remove the zip file
43 | """
44 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]):
45 | print('%s already exists' % local_dest)
46 | else:
47 | print('Downloading %s' % download_url)
48 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest)
49 | file_stat = os.stat(local_dest)
50 | if expected_byte:
51 | if file_stat.st_size == expected_byte:
52 | print('Successfully downloaded %s' % local_dest)
53 | if unzip_and_remove:
54 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out:
55 | shutil.copyfileobj(f_in, f_out)
56 | os.remove(local_dest)
57 | else:
58 | print('The downloaded file has unexpected number of bytes')
59 |
60 |
61 | def save_data(obj, path, append=False):
62 |     if append:
63 |         with open(path, "ab") as f:  # "wb+" would truncate; "ab" truly appends
64 |             pkl.dump(obj, f)
65 | else:
66 | with open(path, "wb") as f:
67 | pkl.dump(obj, f)
68 |
69 |
70 | def load_data(path):
71 | with open(path, "rb") as f:
72 | obj = pkl.load(f)
73 | return obj
74 |
75 |
76 | def write_iterate(ite, file_path):
77 | with open(file_path, "w") as f:
78 | for line in ite:
79 | f.write(line + "\n")
80 |
81 |
82 | def write_append(txt, file_path):
83 | with open(file_path, "a") as f:
84 | f.write(txt + "\n")
85 |
86 |
87 | def write_over(txt, file_path):
88 | with open(file_path, "w") as f:
89 | f.write(txt + "\n")
90 |
91 |
92 | def print_(str_, log_file, write_=False):
93 | if write_:
94 | write_over(str_, log_file)
95 | else:
96 | write_append(str_, log_file)
97 | print(str_)
98 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/util/file_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: lyzhang
5 | @Date: 2018/3/13
6 | """
7 |
8 | import urllib.request
9 | import gzip
10 | import os
11 | import shutil
12 | import pickle as pkl
13 |
14 |
15 | def safe_mkdir(path):
16 | """ Create a directory if there isn't one already. """
17 | try:
18 | os.mkdir(path)
19 | except OSError:
20 | pass
21 |
22 |
23 | def safe_mkdirs(path_list):
24 | """ Create a directory if there isn't one already. """
25 | for path in path_list:
26 | try:
27 | os.mkdir(path)
28 | except OSError:
29 | pass
30 |
31 |
32 | def download_one_file(download_url,
33 | local_dest,
34 | expected_byte=None,
35 | unzip_and_remove=False):
36 | """
37 | Download the file from download_url into local_dest
38 |     if the file doesn't already exist.
39 | If expected_byte is provided, check if
40 | the downloaded file has the same number of bytes.
41 | If unzip_and_remove is True, unzip the file and remove the zip file
42 | """
43 | if os.path.exists(local_dest) or os.path.exists(local_dest[:-3]):
44 | print('%s already exists' % local_dest)
45 | else:
46 | print('Downloading %s' % download_url)
47 | local_file, _ = urllib.request.urlretrieve(download_url, local_dest)
48 | file_stat = os.stat(local_dest)
49 | if expected_byte:
50 | if file_stat.st_size == expected_byte:
51 | print('Successfully downloaded %s' % local_dest)
52 | if unzip_and_remove:
53 | with gzip.open(local_dest, 'rb') as f_in, open(local_dest[:-3], 'wb') as f_out:
54 | shutil.copyfileobj(f_in, f_out)
55 | os.remove(local_dest)
56 | else:
57 | print('The downloaded file has unexpected number of bytes')
58 |
59 |
60 | def save_data(obj, path, append=False):
61 |     if append:
62 |         with open(path, "ab") as f:  # "wb+" would truncate; "ab" truly appends
63 |             pkl.dump(obj, f)
64 | else:
65 | with open(path, "wb") as f:
66 | pkl.dump(obj, f)
67 |
68 |
69 | def load_data(path):
70 | with open(path, "rb") as f:
71 | obj = pkl.load(f)
72 | return obj
73 |
74 |
75 | def write_iterate(ite, file_path):
76 | with open(file_path, "w") as f:
77 | for line in ite:
78 | f.write(line + "\n")
79 |
80 |
81 | def write_append(txt, file_path):
82 | with open(file_path, "a") as f:
83 | f.write(txt + "\n")
84 |
85 |
86 | def write_over(txt, file_path):
87 | with open(file_path, "w") as f:
88 | f.write(txt + "\n")
89 |
90 |
91 | def print_(str_, log_file, write_=False):
92 | if write_:
93 | write_over(str_, log_file)
94 | else:
95 | write_append(str_, log_file)
96 | print(str_)
97 |
98 |
99 | def get_files_all(path_):
100 | lines_all = []
101 | for f_name in os.listdir(path_):
102 | tmp_n = os.path.join(path_, f_name)
103 | with open(tmp_n, "r") as f:
104 | lines_all += f.readlines()
105 | return lines_all
106 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/path_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: Lyzhang
5 | @Date:
6 | @Description:
7 | """
8 | # RST-DT
9 | RST_DT_PATH = "data/rst_dt"
10 | RAW_RST_DT_PATH = "data/rst_dt/RST_DT_RAW"
11 |
12 | # EDUs of train & dev
13 | RST_EDUs_ori_tr = "data/EDUs/ori_train_edus.tsv"
14 | RST_EDUs_ori_de = "data/EDUs/ori_dev_edus.tsv"
15 | # EN EDUs of train & dev
16 | RST_EDUs_google_tr = "data/EDUs/translate_train.txt"
17 | RST_EDUs_google_de = "data/EDUs/translate_dev.tsv"
18 |
19 | EXT_TREES = "data/rst_dt/ext_trees.pkl"
20 | EXT_TREES_RST = "data/rst_dt/ext_trees_rst.pkl"
21 | EXT_SET = "data/rst_dt/ext_set.pkl"
22 |
23 | # RST-DT Generated
24 | RST_TRAIN_TREES = "data/rst_dt/train_trees.pkl"
25 | RST_TRAIN_TREES_RST = "data/rst_dt/train_trees_rst.pkl"
26 | TRAIN_SET = "data/rst_dt/train_set.pkl"
27 | RST_TEST_TREES = "data/rst_dt/test_trees.pkl"
28 | RST_TEST_TREES_RST = "data/rst_dt/test_trees_rst.pkl"
29 | RST_TEST_EDUS_TREES = "data/rst_dt/edus_test_trees.pkl"
30 | RST_TEST_EDUS_TREES_ = "data/rst_dt/edus_test_trees_.pkl"
31 | TEST_SET = "data/rst_dt/test_set.pkl"
32 | RST_DEV_TREES = "data/rst_dt/dev_trees.pkl"
33 | RST_DEV_TREES_RST = "data/rst_dt/dev_trees_rst.pkl"
34 | DEV_SET = "data/rst_dt/dev_set.pkl"
35 | TRAIN_SET_XL = "data/rst_dt/train_set_xl.pkl"
36 | TEST_SET_XL = "data/rst_dt/test_set_xl.pkl"
37 | DEV_SET_XL = "data/rst_dt/dev_set_xl.pkl"
38 | EXT_SET_XL = "data/rst_dt/ext_set_xl.pkl"
39 |
40 | TRAIN_SET_XLM = "data/rst_dt/marcu/train_set_xlm.pkl"  # Marcu-style learning
41 | TEST_SET_XLM = "data/rst_dt/marcu/test_set_xlm.pkl"
42 | DEV_SET_XLM = "data/rst_dt/marcu/dev_set_xlm.pkl"
43 | EXT_SET_XLM = "data/rst_dt/marcu/ext_set_xlm.pkl"
44 | NR2ids_marcu = "data/rst_dt/marcu/nr2ids.pkl"
45 | IDS2nr_marcu = "data/rst_dt/marcu/ids2nr.pkl"
46 |
47 | TRAIN_SET_NR = "data/rst_dt/nr_train_set.pkl"  # N and R kept separate
48 | TEST_SET_NR = "data/rst_dt/nr_test_set.pkl"
49 | DEV_SET_NR = "data/rst_dt/nr_dev_set.pkl"
50 |
51 | RST_TEST_TREES_N = "data/rst_dt/test_trees_n.pkl"
52 | RST_TEST_TREES_RST_N = "data/rst_dt/test_trees_rst_n.pkl"
53 | RST_TEST_EDUS_TREES_N = "data/rst_dt/edus_test_trees_n.pkl"
54 | TEST_SET_N = "data/rst_dt/test_n.set.pkl"
55 | TEST_SET_NR_N = "data/rst_dt/nr_test_n.set.pkl"
56 |
57 | # VOC
58 | REL_raw2coarse = "data/voc/raw2coarse.pkl"
59 | VOC_WORD2IDS_PATH = "data/voc/word2ids.pkl"
60 | POS_word2ids_PATH = "data/voc/pos2ids.pkl"
61 | REL_coarse2ids = "data/voc/coarse2ids.pkl"
62 | REL_ids2coarse = "data/voc/ids2coarse.pkl"
63 | VOC_VEC_PATH = "data/voc/ids2vec.pkl"
64 | LABEL2IDS = "data/voc/label2ids.pkl"
65 | IDS2LABEL = "data/voc/ids2label.pkl"
66 |
67 | options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/" \
68 | "elmo_2x4096_512_2048cnn_2xhighway_options.json"
69 | weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo" \
70 | "_2x4096_512_2048cnn_2xhighway_weights.hdf5"
71 | path_to_jar = 'stanford-corenlp-full-2018-02-27'
72 | MODELS2SAVE = "data/models_saved"
73 | MODEL_SAVE = "data/models/treebuilder.partptr.basic"
74 | LOG_ALL = "data/logs"
75 | PRE_DEPTH = "data/pre_depth.pkl."
76 | GOLD_DEPTH = "data/gold_depth.pkl"
77 | DRAW_DT = "data/drawer.pkl"
78 | LOSS_PATH = "data/loss/"
79 |
80 | # ADDITIONAL
81 | EXT_SET_NEG = "data/rst_dt/neg_ext_set.pkl"
82 | TRAIN_SET_NEG = "data/rst_dt/neg_train_set.pkl"
83 | TEST_SET_NEG = "data/rst_dt/neg_test_set.pkl"
84 | DEV_SET_NEG = "data/rst_dt/neg_dev_set.pkl"
85 |
--------------------------------------------------------------------------------
/ch_dp_gan/structure/vocab.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import logging
4 | from collections import defaultdict
5 | from gensim.models import KeyedVectors
6 | import numpy as np
7 | import torch
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 | PAD_TAG = "<PAD>"
12 | UNK_TAG = "<UNK>"
13 |
14 |
15 | class Vocab:
16 | def __init__(self, name, counter, min_occur=1):
17 | self.name = name
18 | self.s2id = defaultdict(int)
19 | self.s2id[UNK_TAG] = 0
20 | self.id2s = [UNK_TAG]
21 |
22 | self.counter = counter
23 | for tag, freq in counter.items():
24 | if tag not in self.s2id and freq >= min_occur:
25 | self.s2id[tag] = len(self.s2id)
26 | self.id2s.append(tag)
27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id)))
28 |
29 | def __getitem__(self, item):
30 | return self.s2id[item]
31 |
32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False):
33 | if dim is not None and pretrained is not None:
34 |             raise Warning("dim should not be given if pretrained weights are assigned")
35 |
36 | if dim is None and pretrained is None:
37 | raise Warning("one of dim or pretrained should be assigned")
38 |
39 | if pretrained:
40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary)
41 | dim = w2v.vector_size
42 | scale = np.sqrt(3.0 / dim)
43 | weights = np.empty([len(self), dim], dtype=np.float32)
44 | oov_count = 0
45 | all_count = 0
46 | for tag, i in self.s2id.items():
47 | if tag in w2v.vocab:
48 | weights[i] = w2v[tag].astype(np.float32)
49 | else:
50 | oov_count += self.counter[tag]
51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \
52 | np.random.uniform(-scale, scale, dim).astype(np.float32)
53 | all_count += self.counter[tag]
54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" %
55 | (self.name, oov_count, all_count, oov_count/all_count*100))
56 | else:
57 | scale = np.sqrt(3.0 / dim)
58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32)
59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \
60 | np.random.uniform(-scale, scale, dim).astype(np.float32)
61 | weights = torch.from_numpy(weights)
62 | if use_gpu:
63 | weights = weights.cuda()
64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze)
65 | return embedding
66 |
67 | def __len__(self):
68 | return len(self.id2s)
69 |
70 |
71 | class Label:
72 | def __init__(self, name, counter, specials=None):
73 | self.name = name
74 | self.counter = counter.copy()
75 | self.label2id = {}
76 | self.id2label = []
77 | if specials:
78 | for label in specials:
79 | del self.counter[label]
80 | self.label2id[label] = len(self.label2id)
81 | self.id2label.append(label)
82 |
83 | for label, freq in self.counter.items():
84 | if label not in self.label2id:
85 | self.label2id[label] = len(self.label2id)
86 | self.id2label.append(label)
87 | logger.info("label %s size %d" % (name, len(self)))
88 |
89 | def __getitem__(self, item):
90 | return self.label2id[item]
91 |
92 | def __len__(self):
93 | return len(self.id2label)
94 |
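A sketch of building a Vocab from a frequency Counter and drawing a randomly initialized embedding (no pretrained vectors involved):

    from collections import Counter
    from structure.vocab import Vocab

    voc = Vocab("word", Counter({"discourse": 5, "tree": 3, "rare": 1}), min_occur=2)
    embed = voc.embedding(dim=50)        # uniform in (-sqrt(3/50), +sqrt(3/50))
    print(voc["discourse"], voc["oov"])  # -> 1 0: unseen words fall back to id 0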
--------------------------------------------------------------------------------
/en_dp_gan/structure/vocab.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import logging
4 | from collections import defaultdict
5 | from gensim.models import KeyedVectors
6 | import numpy as np
7 | import torch
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 | PAD_TAG = "<pad>"
12 | UNK_TAG = "<unk>"
13 |
14 |
15 | class Vocab:
16 | def __init__(self, name, counter, min_occur=1):
17 | self.name = name
18 | self.s2id = defaultdict(int)
19 | self.s2id[UNK_TAG] = 0
20 | self.id2s = [UNK_TAG]
21 |
22 | self.counter = counter
23 | for tag, freq in counter.items():
24 | if tag not in self.s2id and freq >= min_occur:
25 | self.s2id[tag] = len(self.s2id)
26 | self.id2s.append(tag)
27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id)))
28 |
29 | def __getitem__(self, item):
30 | return self.s2id[item]
31 |
32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False):
33 | if dim is not None and pretrained is not None:
34 | raise ValueError("dim should not be given if pretrained weights are assigned")
35 |
36 | if dim is None and pretrained is None:
37 | raise ValueError("one of dim or pretrained must be assigned")
38 |
39 | if pretrained:
40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary)
41 | dim = w2v.vector_size
42 | scale = np.sqrt(3.0 / dim)
43 | weights = np.empty([len(self), dim], dtype=np.float32)
44 | oov_count = 0
45 | all_count = 0
46 | for tag, i in self.s2id.items():
47 | if tag in w2v.vocab:
48 | weights[i] = w2v[tag].astype(np.float32)
49 | else:
50 | oov_count += self.counter[tag]
51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \
52 | np.random.uniform(-scale, scale, dim).astype(np.float32)
53 | all_count += self.counter[tag]
54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" %
55 | (self.name, oov_count, all_count, oov_count/all_count*100))
56 | else:
57 | scale = np.sqrt(3.0 / dim)
58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32)
59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \
60 | np.random.uniform(-scale, scale, dim).astype(np.float32)
61 | weights = torch.from_numpy(weights)
62 | if use_gpu:
63 | weights = weights.cuda()
64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze)
65 | return embedding
66 |
67 | def __len__(self):
68 | return len(self.id2s)
69 |
70 |
71 | class Label:
72 | def __init__(self, name, counter, specials=None):
73 | self.name = name
74 | self.counter = counter.copy()
75 | self.label2id = {}
76 | self.id2label = []
77 | if specials:
78 | for label in specials:
79 | self.counter.pop(label, None)  # avoid KeyError when a special label is absent from the counter
80 | self.label2id[label] = len(self.label2id)
81 | self.id2label.append(label)
82 |
83 | for label, freq in self.counter.items():
84 | if label not in self.label2id:
85 | self.label2id[label] = len(self.label2id)
86 | self.id2label.append(label)
87 | logger.info("label %s size %d" % (name, len(self)))
88 |
89 | def __getitem__(self, item):
90 | return self.label2id[item]
91 |
92 | def __len__(self):
93 | return len(self.id2label)
94 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/structure/vocab.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import logging
4 | from collections import defaultdict
5 | from gensim.models import KeyedVectors
6 | import numpy as np
7 | import torch
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 | PAD_TAG = "<pad>"
12 | UNK_TAG = "<unk>"
13 |
14 |
15 | class Vocab:
16 | def __init__(self, name, counter, min_occur=1):
17 | self.name = name
18 | self.s2id = defaultdict(int)
19 | self.s2id[UNK_TAG] = 0
20 | self.id2s = [UNK_TAG]
21 |
22 | self.counter = counter
23 | for tag, freq in counter.items():
24 | if tag not in self.s2id and freq >= min_occur:
25 | self.s2id[tag] = len(self.s2id)
26 | self.id2s.append(tag)
27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id)))
28 |
29 | def __getitem__(self, item):
30 | return self.s2id[item]
31 |
32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False):
33 | if dim is not None and pretrained is not None:
34 | raise ValueError("dim should not be given if pretrained weights are assigned")
35 |
36 | if dim is None and pretrained is None:
37 | raise ValueError("one of dim or pretrained must be assigned")
38 |
39 | if pretrained:
40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary)
41 | dim = w2v.vector_size
42 | scale = np.sqrt(3.0 / dim)
43 | weights = np.empty([len(self), dim], dtype=np.float32)
44 | oov_count = 0
45 | all_count = 0
46 | for tag, i in self.s2id.items():
47 | if tag in w2v.vocab:
48 | weights[i] = w2v[tag].astype(np.float32)
49 | else:
50 | oov_count += self.counter[tag]
51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \
52 | np.random.uniform(-scale, scale, dim).astype(np.float32)
53 | all_count += self.counter[tag]
54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" %
55 | (self.name, oov_count, all_count, oov_count/all_count*100))
56 | else:
57 | scale = np.sqrt(3.0 / dim)
58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32)
59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \
60 | np.random.uniform(-scale, scale, dim).astype(np.float32)
61 | weights = torch.from_numpy(weights)
62 | if use_gpu:
63 | weights = weights.cuda()
64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze)
65 | return embedding
66 |
67 | def __len__(self):
68 | return len(self.id2s)
69 |
70 |
71 | class Label:
72 | def __init__(self, name, counter, specials=None):
73 | self.name = name
74 | self.counter = counter.copy()
75 | self.label2id = {}
76 | self.id2label = []
77 | if specials:
78 | for label in specials:
79 | self.counter.pop(label, None)  # avoid KeyError when a special label is absent from the counter
80 | self.label2id[label] = len(self.label2id)
81 | self.id2label.append(label)
82 |
83 | for label, freq in self.counter.items():
84 | if label not in self.label2id:
85 | self.label2id[label] = len(self.label2id)
86 | self.id2label.append(label)
87 | logger.info("label %s size %d" % (name, len(self)))
88 |
89 | def __getitem__(self, item):
90 | return self.label2id[item]
91 |
92 | def __len__(self):
93 | return len(self.id2label)
94 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/structure/vocab.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import logging
4 | from collections import defaultdict
5 | from gensim.models import KeyedVectors
6 | import numpy as np
7 | import torch
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 | PAD_TAG = "<pad>"
12 | UNK_TAG = "<unk>"
13 |
14 |
15 | class Vocab:
16 | def __init__(self, name, counter, min_occur=1):
17 | self.name = name
18 | self.s2id = defaultdict(int)
19 | self.s2id[UNK_TAG] = 0
20 | self.id2s = [UNK_TAG]
21 |
22 | self.counter = counter
23 | for tag, freq in counter.items():
24 | if tag not in self.s2id and freq >= min_occur:
25 | self.s2id[tag] = len(self.s2id)
26 | self.id2s.append(tag)
27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id)))
28 |
29 | def __getitem__(self, item):
30 | return self.s2id[item]
31 |
32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False):
33 | if dim is not None and pretrained is not None:
34 | raise ValueError("dim should not be given if pretrained weights are assigned")
35 |
36 | if dim is None and pretrained is None:
37 | raise ValueError("one of dim or pretrained must be assigned")
38 |
39 | if pretrained:
40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary)
41 | dim = w2v.vector_size
42 | scale = np.sqrt(3.0 / dim)
43 | weights = np.empty([len(self), dim], dtype=np.float32)
44 | oov_count = 0
45 | all_count = 0
46 | for tag, i in self.s2id.items():
47 | if tag in w2v.vocab:
48 | weights[i] = w2v[tag].astype(np.float32)
49 | else:
50 | oov_count += self.counter[tag]
51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \
52 | np.random.uniform(-scale, scale, dim).astype(np.float32)
53 | all_count += self.counter[tag]
54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" %
55 | (self.name, oov_count, all_count, oov_count/all_count*100))
56 | else:
57 | scale = np.sqrt(3.0 / dim)
58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32)
59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \
60 | np.random.uniform(-scale, scale, dim).astype(np.float32)
61 | weights = torch.from_numpy(weights)
62 | if use_gpu:
63 | weights = weights.cuda()
64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze)
65 | return embedding
66 |
67 | def __len__(self):
68 | return len(self.id2s)
69 |
70 |
71 | class Label:
72 | def __init__(self, name, counter, specials=None):
73 | self.name = name
74 | self.counter = counter.copy()
75 | self.label2id = {}
76 | self.id2label = []
77 | if specials:
78 | for label in specials:
79 | self.counter.pop(label, None)  # avoid KeyError when a special label is absent from the counter
80 | self.label2id[label] = len(self.label2id)
81 | self.id2label.append(label)
82 |
83 | for label, freq in self.counter.items():
84 | if label not in self.label2id:
85 | self.label2id[label] = len(self.label2id)
86 | self.id2label.append(label)
87 | logger.info("label %s size %d" % (name, len(self)))
88 |
89 | def __getitem__(self, item):
90 | return self.label2id[item]
91 |
92 | def __len__(self):
93 | return len(self.id2label)
94 |
--------------------------------------------------------------------------------
/en_dp_gan/model/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import numpy as np
3 | from config import ids2nr
4 | from structure.rst_tree import rst_tree
5 |
6 |
7 | class PartitionPtrParser:
8 | def __init__(self, model):
9 | self.model = model
10 | self.ids2nr = ids2nr
11 |
12 | def parse(self, edus, ret_session=False):
13 | session = self.model.init_session(edus)
14 | while not session.terminate():
15 | split_score, nr_score, state = self.model(session)
16 | split = split_score.argmax()
17 | nr = self.ids2nr[nr_score[split].argmax()]
18 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:])
19 | session = session.forward(split_score, state, split, nuclear, relation)
20 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:])
21 | if ret_session:
22 | return tree_parsed, session
23 | else:
24 | return tree_parsed
25 |
26 | def build_rst_tree(self, edus, splits, nuclear, relations):
27 | left, split, right = splits.pop(0)
28 | nucl = nuclear.pop(0)
29 | rel = relations.pop(0)
30 | if right - split == 1:
31 | # leaf node
32 | right_node = edus[split]
33 | else:
34 | # non leaf
35 | right_node = self.build_rst_tree(edus, splits, nuclear, relations)
36 | if split - left == 1:
37 | # leaf node
38 | left_node = edus[left]
39 | else:
40 | # non-leaf node
41 | left_node = self.build_rst_tree(edus, splits, nuclear, relations)
42 | root = rst_tree(l_ch=left_node, r_ch=right_node, ch_ns_rel=nucl, child_rel=rel,
43 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1]))
44 | return root
45 |
46 | def traverse_tree(self, root):
47 | if root.left_child is not None:
48 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel)
49 | self.traverse_tree(root.right_child)
50 | self.traverse_tree(root.left_child)
51 |
52 | def draw_scores_matrix(self):
53 | scores = self.model.scores
54 | self.draw_decision_hot_map(scores)
55 |
56 | @staticmethod
57 | def draw_decision_hot_map(scores):
58 | import matplotlib
59 | import matplotlib.pyplot as plt
60 | text_colors = ["black", "white"]
61 | c_map = "YlGn"
62 | y_label = "split score"
63 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])]
64 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)]
65 | fig, ax = plt.subplots()
66 | im = ax.imshow(scores, cmap=c_map)
67 | c_bar = ax.figure.colorbar(im, ax=ax)
68 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom")
69 | ax.set_xticks(np.arange(scores.shape[1]))
70 | ax.set_yticks(np.arange(scores.shape[0]))
71 | ax.set_xticklabels(col_labels)
72 | ax.set_yticklabels(row_labels)
73 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
74 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
75 | for edge, spine in ax.spines.items():
76 | spine.set_visible(False)
77 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True)
78 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True)
79 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
80 | ax.tick_params(which="minor", bottom=False, left=False)
81 | threshold = im.norm(scores.max()) / 2.
82 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}")
83 | texts = []
84 | kw = dict(horizontalalignment="center", verticalalignment="center")
85 | for i in range(scores.shape[0]):
86 | for j in range(scores.shape[1]):
87 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold])
88 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw)
89 | texts.append(text)
90 | fig.tight_layout()
91 | plt.show()
92 |
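93 | # A sketch of the decoder output consumed by build_rst_tree (hypothetical
94 | # values). session.splits lists (left, split, right) boundaries in the order
95 | # root, right subtree, left subtree, aligned index-by-index with
96 | # session.nuclear and session.relations. For four EDUs grouped as ((0 1) (2 3)):
97 | #
98 | #   splits    = [(0, 2, 4), (2, 3, 4), (0, 1, 2)]
99 | #   nuclear   = ["NS", "NN", "NN"]
100 | #   relations = ["Elaboration", "Joint", "Joint"]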
--------------------------------------------------------------------------------
/en_dp_gan/util/drawer.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
4 | def draw_line(x, y_s, colors, labels, shapes, begins, lines_d, l_colors, step=1, x_name="STEP", y_name="LOSS"):
5 | fig = plt.figure(figsize=(5, 3.8), facecolor="white")
6 | ax1 = fig.add_subplot(1, 1, 1)
7 | ax1.set_xlabel(x_name)
8 | ax1.set_ylabel(y_name)
9 | for idx, y in enumerate(y_s):
10 | x_ = x[begins[idx]:]
11 | y_ = y[begins[idx]:]
12 | # ax1.plot(x[:], y[:], color=colors[idx], linewidth=3, marker="o", linestyle="", label=labels[idx])
13 | ax1.scatter(x_[::step], y_[::step], color=colors[idx], marker=shapes[idx], edgecolors=colors[idx], s=20, label=labels[idx])
14 | if lines_d[idx]:
15 | line_general = 20
16 | x_n = x_[::line_general]
17 | y_n = []
18 | sum_ = 0.
19 | sum_idx_num = 0
20 | for idx__, y__ in enumerate(y_):
21 | sum_idx_num += 1
22 | sum_ += y__
23 | if idx__ > 0 and idx__ % line_general == 0:
24 | y_n.append(sum_ / line_general)
25 | sum_ = 0.
26 | sum_idx_num = 0
27 | if sum_ > 0.:
28 | y_n.append(sum_ / sum_idx_num)
29 | if idx == 0:
30 | ax1.plot(x_n[:10], y_n[:10], color="PaleGreen", linewidth=2.6, linestyle="-")
31 | ax1.plot(x_n[9:], y_n[9:], color="Green", linewidth=2.6, linestyle="-")
32 | else:
33 | ax1.plot(x_n, y_n, color=l_colors[idx], linewidth=2.6, linestyle="-")
34 | plt.vlines(188, 0, 2, colors="black", linestyles="dashed", linewidth=2)
35 | plt.annotate("warm up", xy=(100, 0.45), xytext=(60, 0.15), arrowprops=dict(facecolor="white", headlength=4,
36 | headwidth=10, width=4))
37 | plt.legend(loc="upper left")
38 | plt.grid(linestyle='-')
39 | plt.show()
40 |
41 |
42 | def draw_all(x_y_s, colors, labels, shapes, begins, lines_d, l_colors, step=1, x_name="STEP", y_name="LOSS"):
43 | fig = plt.figure(figsize=(4, 10), facecolor="white")
44 | for idx_, x_y in enumerate(x_y_s):
45 | x_a, y_a = x_y
46 | ax1 = fig.add_subplot(3, 1, idx_ + 1)
47 | ax1.set_xlabel(x_name)
48 | ax1.set_ylabel(y_name)
49 | for idx, y in enumerate(y_a):
50 | x_ = x_a[begins[idx]:]
51 | y_ = y[begins[idx]:]
52 | # ax1.plot(x[:], y[:], color=colors[idx], linewidth=3, marker="o", linestyle="", label=labels[idx])
53 | ax1.scatter(x_[::step], y_[::step], color=colors[idx], marker=shapes[idx], edgecolors=colors[idx], s=20,
54 | label=labels[idx])
55 | if lines_d[idx]:
56 | line_general = 20
57 | x_n = x_[::line_general]
58 | y_n = []
59 | sum_ = 0.
60 | sum_idx_num = 0
61 | for idx__, y__ in enumerate(y_):
62 | sum_idx_num += 1
63 | sum_ += y__
64 | if idx__ > 0 and idx__ % line_general == 0:
65 | y_n.append(sum_ / line_general)
66 | sum_ = 0.
67 | sum_idx_num = 0
68 | if sum_ > 0.:
69 | y_n.append(sum_ / sum_idx_num)
70 | if idx == 0:
71 | ax1.plot(x_n[:10], y_n[:10], color="PaleGreen", linewidth=2.6, linestyle="-")
72 | ax1.plot(x_n[9:], y_n[9:], color="Green", linewidth=2.6, linestyle="-")
73 | else:
74 | ax1.plot(x_n, y_n, color=l_colors[idx], linewidth=2.6, linestyle="-")
75 | plt.vlines(188, 0, 2, colors="black", linestyles="dashed", linewidth=2)
76 | plt.annotate("warm up", xy=(100, 0.45), xytext=(60, 0.15), arrowprops=dict(facecolor="white", headlength=4,
77 | headwidth=10, width=4))
78 | plt.legend(loc="upper left")
79 | plt.grid(linestyle='-')
80 | plt.show()
81 |
82 |
83 | if __name__ == "__main__":
84 | pass
85 |
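86 | # Hypothetical invocation (synthetic loss curve; opens a matplotlib window):
87 | #
88 | #   x = list(range(400))
89 | #   y_s = [[2.0 / (i + 1) for i in x]]
90 | #   draw_line(x, y_s, colors=["Green"], labels=["G loss"], shapes=["o"],
91 | #             begins=[0], lines_d=[True], l_colors=["Green"], step=2)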
--------------------------------------------------------------------------------
/en_dp_gan/util/eval.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from structure import EDU, Sentence, Relation
3 | import numpy as np
4 |
5 |
6 | def factorize_tree(tree, binarize=True):
7 | quads = set() # (span, nuclear, coarse relation, fine relation)
8 |
9 | def factorize(root, offset=0):
10 | if isinstance(root, EDU):
11 | return [(offset, offset+len(root.text))]
12 | elif isinstance(root, Sentence):
13 | children_spans = []
14 | for child in root:
15 | spans = factorize(child, offset)
16 | children_spans.extend(spans)
17 | offset = spans[-1][1]
18 | return children_spans
19 | elif isinstance(root, Relation):
20 | children_spans = []
21 | for child in root:
22 | spans = factorize(child, offset)
23 | children_spans.extend(spans)
24 | offset = spans[-1][1]
25 | if binarize:
26 | while len(children_spans) >= 2:
27 | right = children_spans.pop()
28 | left = children_spans.pop()
29 | quads.add(((left, right), root.nuclear, root.ctype, root.ftype))
30 | children_spans.append((left[0], right[1]))
31 | else:
32 | quads.add((tuple(children_spans), root.nuclear, root.ctype, root.ftype))
33 | return [(children_spans[0][0], children_spans[-1][1])]
34 |
35 | factorize(tree.root_relation())
36 | return quads
37 |
38 |
39 | def f1_score(num_corr, num_gold, num_pred, treewise_average=True):
40 | num_corr = num_corr.astype(np.float64)
41 | num_gold = num_gold.astype(np.float64)
42 | num_pred = num_pred.astype(np.float64)
43 |
44 | if treewise_average:
45 | precision = (np.nan_to_num(num_corr / num_pred)).mean()
46 | recall = (np.nan_to_num(num_corr / num_gold)).mean()
47 | else:
48 | precision = np.nan_to_num(num_corr.sum() / num_pred.sum())
49 | recall = np.nan_to_num(num_corr.sum() / num_gold.sum())
50 |
51 | if precision + recall == 0:
52 | f1 = 0.
53 | else:
54 | f1 = 2. * precision * recall / (precision + recall)
55 | return precision, recall, f1
56 |
57 |
58 | def evaluation_trees(parses, golds, binarize=True, treewise_avearge=True):
59 | num_gold = np.zeros(len(golds))
60 | num_parse = np.zeros(len(parses))
61 | num_corr_span = np.zeros(len(parses))
62 | num_corr_nuc = np.zeros(len(parses))
63 | num_corr_ctype = np.zeros(len(parses))
64 | num_corr_ftype = np.zeros(len(parses))
65 |
66 | for i, (parse, gold) in enumerate(zip(parses, golds)):
67 | parse_factorized = factorize_tree(parse, binarize=binarize)
68 | gold_factorized = factorize_tree(gold, binarize=binarize)
69 | num_parse[i] = len(parse_factorized)
70 | num_gold[i] = len(gold_factorized)
71 |
72 | # index quads by child spans
73 | parse_dict = {quad[0]: quad for quad in parse_factorized}
74 | gold_dict = {quad[0]: quad for quad in gold_factorized}
75 |
76 | # find correct spans
77 | gold_spans = set(gold_dict.keys())
78 | parse_spans = set(parse_dict.keys())
79 | corr_spans = gold_spans & parse_spans
80 | num_corr_span[i] = len(corr_spans)
81 |
82 | # quad layout: (span, nuclear, ctype, ftype)
83 | for span in corr_spans:
84 | # count correct nuclear
85 | num_corr_nuc[i] += 1 if parse_dict[span][1] == gold_dict[span][1] else 0
86 | # count correct ctype
87 | num_corr_ctype[i] += 1 if parse_dict[span][2] == gold_dict[span][2] else 0
88 | # count correct ftype
89 | num_corr_ftype[i] += 1 if parse_dict[span][3] == gold_dict[span][3] else 0
90 |
91 | span_score = f1_score(num_corr_span, num_gold, num_parse, treewise_average=treewise_avearge)
92 | nuc_score = f1_score(num_corr_nuc, num_gold, num_parse, treewise_average=treewise_avearge)
93 | ctype_score = f1_score(num_corr_ctype, num_gold, num_parse, treewise_average=treewise_avearge)
94 | ftype_score = f1_score(num_corr_ftype, num_gold, num_parse, treewise_average=treewise_avearge)
95 | return span_score, nuc_score, ctype_score, ftype_score
96 |
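97 | # A quick sanity check of f1_score on hypothetical per-tree counts
98 | # (tree 1: 3 of 4 spans correct, tree 2: 2 of 2):
99 | #
100 | #   corr, gold, pred = np.array([3, 2]), np.array([4, 2]), np.array([4, 2])
101 | #   f1_score(corr, gold, pred, treewise_average=True)  # -> (0.875, 0.875, 0.875)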
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/model/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import numpy as np
3 | from config import ids2nr
4 | from structure.rst_tree import rst_tree
5 |
6 |
7 | class PartitionPtrParser:
8 | def __init__(self, model):
9 | self.model = model
10 | self.ids2nr = ids2nr
11 |
12 | def parse(self, edus, ret_session=False):
13 | # TODO implement beam search
14 | session = self.model.init_session(edus)
15 | while not session.terminate():
16 | split_score, nr_score, state = self.model(session)
17 | split = split_score.argmax()
18 | nr = self.ids2nr[nr_score[split].argmax()]
19 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:])
20 | session = session.forward(split_score, state, split, nuclear, relation)
21 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:])
22 | if ret_session:
23 | return tree_parsed, session
24 | else:
25 | return tree_parsed
26 |
27 | def build_rst_tree(self, edus, splits, nuclear, relations):
28 | left, split, right = splits.pop(0)
29 | nucl = nuclear.pop(0)
30 | rel = relations.pop(0)
31 | if right - split == 1:
32 | # leaf node
33 | right_node = edus[split]
34 | else:
35 | # non leaf
36 | right_node = self.build_rst_tree(edus, splits, nuclear, relations)
37 | if split - left == 1:
38 | # leaf node
39 | left_node = edus[left]
40 | else:
41 | # non-leaf node
42 | left_node = self.build_rst_tree(edus, splits, nuclear, relations)
43 | root = rst_tree(l_ch=left_node, r_ch=right_node, ch_ns_rel=nucl, child_rel=rel,
44 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1]))
45 | return root
46 |
47 | def traverse_tree(self, root):
48 | if root.left_child is not None:
49 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel)
50 | self.traverse_tree(root.right_child)
51 | self.traverse_tree(root.left_child)
52 |
53 | def draw_scores_matrix(self):
54 | scores = self.model.scores
55 | self.draw_decision_hot_map(scores)
56 |
57 | @staticmethod
58 | def draw_decision_hot_map(scores):
59 | import matplotlib
60 | import matplotlib.pyplot as plt
61 | text_colors = ["black", "white"]
62 | c_map = "YlGn"
63 | y_label = "split score"
64 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])]
65 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)]
66 | fig, ax = plt.subplots()
67 | im = ax.imshow(scores, cmap=c_map)
68 | c_bar = ax.figure.colorbar(im, ax=ax)
69 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom")
70 | ax.set_xticks(np.arange(scores.shape[1]))
71 | ax.set_yticks(np.arange(scores.shape[0]))
72 | ax.set_xticklabels(col_labels)
73 | ax.set_yticklabels(row_labels)
74 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
75 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
76 | for edge, spine in ax.spines.items():
77 | spine.set_visible(False)
78 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True)
79 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True)
80 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
81 | ax.tick_params(which="minor", bottom=False, left=False)
82 | threshold = im.norm(scores.max()) / 2.
83 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}")
84 | texts = []
85 | kw = dict(horizontalalignment="center", verticalalignment="center")
86 | for i in range(scores.shape[0]):
87 | for j in range(scores.shape[1]):
88 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold])
89 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw)
90 | texts.append(text)
91 | fig.tight_layout()
92 | plt.show()
93 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/util/eval.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from structure import EDU, Sentence, Relation
3 | import numpy as np
4 |
5 |
6 | def factorize_tree(tree, binarize=True):
7 | quads = set() # (span, nuclear, coarse relation, fine relation)
8 |
9 | def factorize(root, offset=0):
10 | if isinstance(root, EDU):
11 | return [(offset, offset+len(root.text))]
12 | elif isinstance(root, Sentence):
13 | children_spans = []
14 | for child in root:
15 | spans = factorize(child, offset)
16 | children_spans.extend(spans)
17 | offset = spans[-1][1]
18 | return children_spans
19 | elif isinstance(root, Relation):
20 | children_spans = []
21 | for child in root:
22 | spans = factorize(child, offset)
23 | children_spans.extend(spans)
24 | offset = spans[-1][1]
25 | if binarize:
26 | while len(children_spans) >= 2:
27 | right = children_spans.pop()
28 | left = children_spans.pop()
29 | quads.add(((left, right), root.nuclear, root.ctype, root.ftype))
30 | children_spans.append((left[0], right[1]))
31 | else:
32 | quads.add((tuple(children_spans), root.nuclear, root.ctype, root.ftype))
33 | return [(children_spans[0][0], children_spans[-1][1])]
34 |
35 | factorize(tree.root_relation())
36 | return quads
37 |
38 |
39 | def f1_score(num_corr, num_gold, num_pred, treewise_average=True):
40 | num_corr = num_corr.astype(np.float64)
41 | num_gold = num_gold.astype(np.float64)
42 | num_pred = num_pred.astype(np.float64)
43 |
44 | if treewise_average:
45 | precision = (np.nan_to_num(num_corr / num_pred)).mean()
46 | recall = (np.nan_to_num(num_corr / num_gold)).mean()
47 | else:
48 | precision = np.nan_to_num(num_corr.sum() / num_pred.sum())
49 | recall = np.nan_to_num(num_corr.sum() / num_gold.sum())
50 |
51 | if precision + recall == 0:
52 | f1 = 0.
53 | else:
54 | f1 = 2. * precision * recall / (precision + recall)
55 | return precision, recall, f1
56 |
57 |
58 | def evaluation_trees(parses, golds, binarize=True, treewise_avearge=True):
59 | num_gold = np.zeros(len(golds))
60 | num_parse = np.zeros(len(parses))
61 | num_corr_span = np.zeros(len(parses))
62 | num_corr_nuc = np.zeros(len(parses))
63 | num_corr_ctype = np.zeros(len(parses))
64 | num_corr_ftype = np.zeros(len(parses))
65 |
66 | for i, (parse, gold) in enumerate(zip(parses, golds)):
67 | parse_factorized = factorize_tree(parse, binarize=binarize)
68 | gold_factorized = factorize_tree(gold, binarize=binarize)
69 | num_parse[i] = len(parse_factorized)
70 | num_gold[i] = len(gold_factorized)
71 |
72 | # index quads by child spans
73 | parse_dict = {quad[0]: quad for quad in parse_factorized}
74 | gold_dict = {quad[0]: quad for quad in gold_factorized}
75 |
76 | # find correct spans
77 | gold_spans = set(gold_dict.keys())
78 | parse_spans = set(parse_dict.keys())
79 | corr_spans = gold_spans & parse_spans
80 | num_corr_span[i] = len(corr_spans)
81 |
82 | # quad layout: (span, nuclear, ctype, ftype)
83 | for span in corr_spans:
84 | # count correct nuclear
85 | num_corr_nuc[i] += 1 if parse_dict[span][1] == gold_dict[span][1] else 0
86 | # count correct ctype
87 | num_corr_ctype[i] += 1 if parse_dict[span][2] == gold_dict[span][2] else 0
88 | # count correct ftype
89 | num_corr_ftype[i] += 1 if parse_dict[span][3] == gold_dict[span][3] else 0
90 |
91 | span_score = f1_score(num_corr_span, num_gold, num_parse, treewise_average=treewise_avearge)
92 | nuc_score = f1_score(num_corr_nuc, num_gold, num_parse, treewise_average=treewise_avearge)
93 | ctype_score = f1_score(num_corr_ctype, num_gold, num_parse, treewise_average=treewise_avearge)
94 | ftype_score = f1_score(num_corr_ftype, num_gold, num_parse, treewise_average=treewise_avearge)
95 | return span_score, nuc_score, ctype_score, ftype_score
96 |
--------------------------------------------------------------------------------
/en_dp_gan/model/stacked_parser_tdt_gan3/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import numpy as np
3 | import torch
4 | from config import ids2nr, SEED
5 | from structure.rst_tree import rst_tree
6 | import random
7 | random.seed(SEED)
8 | torch.manual_seed(SEED)
9 | np.random.seed(SEED)
10 |
11 |
12 | class PartitionPtrParser:
13 | def __init__(self, model):
14 | self.model = model
15 | self.ids2nr = ids2nr
16 |
17 | def parse(self, edus, ret_session=False):
18 | # TODO implement beam search
19 | session = self.model.init_session(edus)
20 | d_masks, splits = None, []
21 | while not session.terminate():
22 | split_score, nr_score, state, d_mask = self.model(session)
23 | d_masks = d_mask if d_masks is None else torch.cat((d_masks, d_mask), 1)
24 | split = split_score.argmax()
25 | nr = self.ids2nr[nr_score[split].argmax()]
26 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:])
27 | session = session.forward(split_score, state, split, nuclear, relation)
28 |
29 | # build tree by splits (left, split, right)
30 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:])
31 | if ret_session:
32 | return tree_parsed, session
33 | else:
34 | return tree_parsed
35 |
36 | def build_rst_tree(self, edus, splits, nuclear, relations, type_="Root", rel_=None):
37 | left, split, right = splits.pop(0)
38 | nucl = nuclear.pop(0)
39 | rel = relations.pop(0)
40 | left_n, right_n = nucl[0], nucl[1]
41 | left_rel = rel if left_n == "N" else "span"
42 | right_rel = rel if right_n == "N" else "span"
43 | if right - split == 0:
44 | # leaf node
45 | right_node = edus[split + 1]
46 | else:
47 | # non leaf
48 | right_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=right_n, rel_=right_rel)
49 | if split - left == 0:
50 | # leaf node
51 | left_node = edus[split]
52 | else:
53 | # non-leaf node
54 | left_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=left_n, rel_=left_rel)
55 | node_height = max(left_node.node_height, right_node.node_height) + 1
56 | root = rst_tree(l_ch=left_node, r_ch=right_node, ch_ns_rel=nucl, child_rel=rel,
57 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1]),
58 | node_height=node_height, type_=type_, rel=rel_)
59 | return root
60 |
61 | def traverse_tree(self, root):
62 | if root.left_child is not None:
63 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel)
64 | self.traverse_tree(root.right_child)
65 | self.traverse_tree(root.left_child)
66 |
67 | def draw_scores_matrix(self):
68 | scores = self.model.scores
69 | self.draw_decision_hot_map(scores)
70 |
71 | @staticmethod
72 | def draw_decision_hot_map(scores):
73 | import matplotlib
74 | import matplotlib.pyplot as plt
75 | text_colors = ["black", "white"]
76 | c_map = "YlGn"
77 | y_label = "split score"
78 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])]
79 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)]
80 | fig, ax = plt.subplots()
81 | im = ax.imshow(scores, cmap=c_map)
82 | c_bar = ax.figure.colorbar(im, ax=ax)
83 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom")
84 | ax.set_xticks(np.arange(scores.shape[1]))
85 | ax.set_yticks(np.arange(scores.shape[0]))
86 | ax.set_xticklabels(col_labels)
87 | ax.set_yticklabels(row_labels)
88 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
89 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
90 | for edge, spine in ax.spines.items():
91 | spine.set_visible(False)
92 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True)
93 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True)
94 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
95 | ax.tick_params(which="minor", bottom=False, left=False)
96 | threshold = im.norm(scores.max()) / 2.
97 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}")
98 | texts = []
99 | kw = dict(horizontalalignment="center", verticalalignment="center")
100 | for i in range(scores.shape[0]):
101 | for j in range(scores.shape[1]):
102 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold])
103 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw)
104 | texts.append(text)
105 | fig.tight_layout()
106 | plt.show()
107 |
--------------------------------------------------------------------------------
/ch_dp_gan/berkeleyparser/README:
--------------------------------------------------------------------------------
1 | "THE BERKELEY PARSER"
2 | release 1.1
3 | migrated from Google Code to GitHub July 2015
4 |
5 | This package contains the Berkeley Parser as described in
6 |
7 | "Learning Accurate, Compact, and Interpretable Tree Annotation"
8 | Slav Petrov, Leon Barrett, Romain Thibaux and Dan Klein
9 | in COLING-ACL 2006
10 |
11 | and
12 |
13 | "Improved Inference for Unlexicalized Parsing"
14 | Slav Petrov and Dan Klein
15 | in HLT-NAACL 2007
16 |
17 | If you use this code in your research and would like to acknowledge it, please refer to one of those publications. Note that the jar-archive also contains all source files. For questions please contact Slav Petrov (petrov@cs.berkeley.edu).
18 |
19 |
20 | * PARSING
21 | The main class of the jar-archive is the parser. By default, it will read in PTB tokenized sentences from STDIN (one per line) and write parse trees to STDOUT. It can be invoked with:
22 | 
23 | java -jar berkeleyParser.jar -gr <grammar file>
24 |
25 | The parser can produce k-best lists and parse in parallel using multiple threads. Several additional options are also available (return binarized and/or annotated trees, produce an image of the parse tree, tokenize the input, run in fast/accurate mode, print out tree likelihoods, etc.). Starting the parser without supplying a grammar file will print a list of all options.
26 |
27 | * ADDITIONAL TOOLS
28 | A tool for annotating parse trees with their most likely Viterbi derivation over refined categories and scoring the subtrees can be started with
29 |
30 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.TreeLabeler -gr <grammar file>
31 |
32 | This tool reads in parse trees from STDIN, annotates them as specified and prints them out to STDOUT. You can use
33 |
34 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.TreeScorer -gr <grammar file>
35 |
36 | to compute the (log-)likelihood of a parse tree.
37 |
38 |
39 | * GRAMMARS
40 | Included are grammars for English, German and Chinese. For parsing English text which is not from the Wall Street Journal, we recommend that you use the English grammar after 5 split&merge iterations as experiments suggest that the 6 split&merge iterations grammars are overfitting the Wall Street Journal. Because of the coarse-to-fine method used by the parser, there is essentially no difference in parsing time between the different grammars.
41 |
42 |
43 | * LEARNING NEW GRAMMARS
44 | You will need a treebank in order to learn new grammars. The package contains code for reading in some of the standard treebanks. To learn a grammar from the Wall Street Journal section of the Penn Treebank, you can execute
45 |
46 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path <path to training corpus> -out <output grammar file>
47 |
48 | To learn a grammar from trees that are contained in a single file use the -treebank option, e.g.:
49 |
50 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path <treebank file> -out <output grammar file> -treebank SINGLEFILE
51 |
52 | This will read in the WSJ training set and do 6 iterations of split, merge, smooth. An intermediate grammar file will be written to disk once in a while and you can expect the final grammar to be written to <output grammar file> after 15-20 hours. The GrammarTrainer accepts a variety of options which have been set to reasonable default values. Most of the options should be self-explanatory and you are encouraged to experiment with them. Note that since EM is a local method each run will produce slightly different results. Furthermore, the default settings prune away rules with probability below a certain threshold, which greatly speeds up the training, but increases the variance. To train grammars on other training sets (e.g. for other languages), consult edu.berkeley.nlp.PCFGLA.Corpus.java and supply the correct language option to the trainer.
53 | To test the performance of a grammar you can use
54 |
55 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTester -path <test corpus> -in <grammar file>
56 |
57 |
58 | * WRITING GRAMMARS TO TEXT FILES
59 | The parser reads and writes grammar files as serialized java classes. To view the grammars, you can export them to text format with:
60 |
61 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.WriteGrammarToTextFile <grammar file> <outname>
62 |
63 | This will create three text files. outname.grammar and outname.lexicon contain the respective rule scores and outname.words should be used with the included perl script to map words to their signatures.
64 |
65 | * UNKNOWN WORDS
66 | The lexicon contains arrays with scores for each (tag,signature) pair. The array entries correspond to the scores for the respective tag substates. The signatures of known words are the words themselves. Unknown words, in contrast, are classified into a set of unknown word categories. The perl script "getSignature" takes as first argument a file containing the known words (presumably produced by WriteGrammarToTextFile). It then reads words from STDIN and returns the word signature to STDOUT. The signatures should be used to look up the tagging probabilities of words in the lexicon file.
67 |
68 |
69 |
70 |
71 |
72 |
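73 | For example, assuming the exported words file is named eng.words (a
74 | hypothetical name), the signature of a new word can be computed with:
75 | 
76 | echo "unseenword" | perl getSignature eng.words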
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/berkeleyparser/README:
--------------------------------------------------------------------------------
1 | "THE BERKELEY PARSER"
2 | release 1.1
3 | migrated from Google Code to GitHub July 2015
4 |
5 | This package contains the Berkeley Parser as described in
6 |
7 | "Learning Accurate, Compact, and Interpretable Tree Annotation"
8 | Slav Petrov, Leon Barrett, Romain Thibaux and Dan Klein
9 | in COLING-ACL 2006
10 |
11 | and
12 |
13 | "Improved Inference for Unlexicalized Parsing"
14 | Slav Petrov and Dan Klein
15 | in HLT-NAACL 2007
16 |
17 | If you use this code in your research and would like to acknowledge it, please refer to one of those publications. Note that the jar-archive also contains all source files. For questions please contact Slav Petrov (petrov@cs.berkeley.edu).
18 |
19 |
20 | * PARSING
21 | The main class of the jar-archive is the parser. By default, it will read in PTB tokenized sentences from STDIN (one per line) and write parse trees to STDOUT. It can be invoked with:
22 | 
23 | java -jar berkeleyParser.jar -gr <grammar file>
24 |
25 | The parser can produce k-best lists and parse in parallel using multiple threads. Several additional options are also available (return binarized and/or annotated trees, produce an image of the parse tree, tokenize the input, run in fast/accurate mode, print out tree likelihoods, etc.). Starting the parser without supplying a grammar file will print a list of all options.
26 |
27 | * ADDITIONAL TOOLS
28 | A tool for annotating parse trees with their most likely Viterbi derivation over refined categories and scoring the subtrees can be started with
29 |
30 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.TreeLabeler -gr <grammar file>
31 |
32 | This tool reads in parse trees from STDIN, annotates them as specified and prints them out to STDOUT. You can use
33 |
34 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.TreeScorer -gr <grammar file>
35 |
36 | to compute the (log-)likelihood of a parse tree.
37 |
38 |
39 | * GRAMMARS
40 | Included are grammars for English, German and Chinese. For parsing English text which is not from the Wall Street Journal, we recommend that you use the English grammar after 5 split&merge iterations as experiments suggest that the 6 split&merge iterations grammars are overfitting the Wall Street Journal. Because of the coarse-to-fine method used by the parser, there is essentially no difference in parsing time between the different grammars.
41 |
42 |
43 | * LEARNING NEW GRAMMARS
44 | You will need a treebank in order to learn new grammars. The package contains code for reading in some of the standard treebanks. To learn a grammar from the Wall Street Journal section of the Penn Treebank, you can execute
45 |
46 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path <path to training corpus> -out <output grammar file>
47 |
48 | To learn a grammar from trees that are contained in a single file use the -treebank option, e.g.:
49 |
50 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTrainer -path <treebank file> -out <output grammar file> -treebank SINGLEFILE
51 |
52 | This will read in the WSJ training set and do 6 iterations of split, merge, smooth. An intermediate grammar file will be written to disk once in a while and you can expect the final grammar to be written to <output grammar file> after 15-20 hours. The GrammarTrainer accepts a variety of options which have been set to reasonable default values. Most of the options should be self-explanatory and you are encouraged to experiment with them. Note that since EM is a local method each run will produce slightly different results. Furthermore, the default settings prune away rules with probability below a certain threshold, which greatly speeds up the training, but increases the variance. To train grammars on other training sets (e.g. for other languages), consult edu.berkeley.nlp.PCFGLA.Corpus.java and supply the correct language option to the trainer.
53 | To test the performance of a grammar you can use
54 |
55 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.GrammarTester -path <test corpus> -in <grammar file>
56 |
57 |
58 | * WRITING GRAMMARS TO TEXT FILES
59 | The parser reads and writes grammar files as serialized java classes. To view the grammars, you can export them to text format with:
60 |
61 | java -cp berkeleyParser.jar edu.berkeley.nlp.PCFGLA.WriteGrammarToTextFile <grammar file> <outname>
62 |
63 | This will create three text files. outname.grammar and outname.lexicon contain the respective rule scores and outname.words should be used with the included perl script to map words to their signatures.
64 |
65 | * UNKNOWN WORDS
66 | The lexicon contains arrays with scores for each (tag,signature) pair. The array entries correspond to the scores for the respective tag substates. The signatures of known words are the words themselves. Unknown words, in contrast, are classified into a set of unknown word categories. The perl script "getSignature" takes as first argument a file containing the known words (presumably produced by WriteGrammarToTextFile). It then reads words from STDIN and returns the word signature to STDOUT. The signatures should be used to look up the tagging probabilities of words in the lexicon file.
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/model/stacked_parser_tdt_xlnet/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import numpy as np
3 | import torch
4 | from config import ids2nr, SEED
5 | from structure.rst_tree import rst_tree
6 | import random
7 | random.seed(SEED)
8 | torch.manual_seed(SEED)
9 | np.random.seed(SEED)
10 |
11 |
12 | class PartitionPtrParser:
13 | def __init__(self, model):
14 | self.ids2nr = ids2nr
15 |
16 | def parse(self, edus, ret_session=False, model=None, model_xl=None, tokenizer_xl=None):
17 | # TODO implement beam search
18 | session = model.init_session(edus, model_xl, tokenizer_xl)  # initialize the parsing state
19 | d_masks, splits = None, []
20 | while not session.terminate():
21 | split_score, nr_score, state, d_mask = model.parse_predict(session)
22 | d_masks = d_mask if d_masks is None else torch.cat((d_masks, d_mask), 1)
23 | split = split_score.argmax()
24 | nr = self.ids2nr[nr_score[split].argmax()]
25 | nuclear, relation = nr.split("-")[0], "-".join(nr.split("-")[1:])
26 | session = session.forward(split_score, state, split, nuclear, relation)
27 | # build tree by splits (left, split, right)
28 | tree_parsed = self.build_rst_tree(edus, session.splits[:], session.nuclear[:], session.relations[:])
29 | if ret_session:
30 | return tree_parsed, session
31 | else:
32 | return tree_parsed
33 |
34 | def build_rst_tree(self, edus, splits, nuclear, relations, type_="Root", rel_=None):
35 | left, split, right = splits.pop(0)
36 | nucl = nuclear.pop(0)
37 | rel = relations.pop(0)
38 | left_n, right_n = nucl[0], nucl[1]
39 | left_rel = rel if left_n == "N" else "span"
40 | right_rel = rel if right_n == "N" else "span"
41 | if right - split == 0:
42 | # leaf node
43 | right_node = edus[split + 1]
44 | else:
45 | # non leaf
46 | right_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=right_n, rel_=right_rel)
47 | if split - left == 0:
48 | # leaf node
49 | left_node = edus[split]
50 | else:
51 | # non-leaf node
52 | left_node = self.build_rst_tree(edus, splits, nuclear, relations, type_=left_n, rel_=left_rel)
53 | node_height = max(left_node.node_height, right_node.node_height) + 1
54 | root = rst_tree(l_ch=left_node, r_ch=right_node, ch_ns_rel=nucl, child_rel=rel,
55 | temp_edu_span=(left_node.temp_edu_span[0], right_node.temp_edu_span[1]),
56 | node_height=node_height, type_=type_, rel=rel_)
57 | return root
58 |
59 | def traverse_tree(self, root):
60 | if root.left_child is not None:
61 | print(root.temp_edu_span, ", ", root.child_NS_rel, ", ", root.child_rel)
62 | self.traverse_tree(root.right_child)
63 | self.traverse_tree(root.left_child)
64 |
65 | def draw_scores_matrix(self, model):
66 | scores = model.scores
67 | self.draw_decision_hot_map(scores)
68 |
69 | @staticmethod
70 | def draw_decision_hot_map(scores):
71 | import matplotlib
72 | import matplotlib.pyplot as plt
73 | text_colors = ["black", "white"]
74 | c_map = "YlGn"
75 | y_label = "split score"
76 | col_labels = ["split %d" % i for i in range(0, scores.shape[1])]
77 | row_labels = ["step %d" % i for i in range(1, scores.shape[0] + 1)]
78 | fig, ax = plt.subplots()
79 | im = ax.imshow(scores, cmap=c_map)
80 | c_bar = ax.figure.colorbar(im, ax=ax)
81 | c_bar.ax.set_ylabel(y_label, rotation=-90, va="bottom")
82 | ax.set_xticks(np.arange(scores.shape[1]))
83 | ax.set_yticks(np.arange(scores.shape[0]))
84 | ax.set_xticklabels(col_labels)
85 | ax.set_yticklabels(row_labels)
86 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
87 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
88 | for edge, spine in ax.spines.items():
89 | spine.set_visible(False)
90 | ax.set_xticks(np.arange(scores.shape[1] + 1) - .5, minor=True)
91 | ax.set_yticks(np.arange(scores.shape[0] + 1) - .5, minor=True)
92 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
93 | ax.tick_params(which="minor", bottom=False, left=False)
94 | threshold = im.norm(scores.max()) / 2.
95 | val_fmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}")
96 | texts = []
97 | kw = dict(horizontalalignment="center", verticalalignment="center")
98 | for i in range(scores.shape[0]):
99 | for j in range(scores.shape[1]):
100 | kw.update(color=text_colors[im.norm(scores[i, j]) > threshold])
101 | text = im.axes.text(j, i, val_fmt(scores[i, j], None), **kw)
102 | texts.append(text)
103 | fig.tight_layout()
104 | plt.show()
105 |
--------------------------------------------------------------------------------
/ch_dp_gan/dataset/cdtb.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import os
4 | import re
5 | import pickle
6 | import gzip
7 | import hashlib
8 | import thulac
9 | import tqdm
10 | import logging
11 | from itertools import chain
12 | from nltk.tree import Tree as ParseTree
13 | from structure import Discourse, Sentence, EDU, TEXT, Connective, node_type_filter
14 |
15 |
16 | # note: this module will print a junk line "Model loaded succeed" to stdout when initializing
17 | thulac = thulac.thulac()
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class CDTB:
22 | def __init__(self, cdtb_dir, train, validate, test, encoding="UTF-8",
23 | ctb_dir=None, ctb_encoding="UTF-8", cache_dir=None, preprocess=False):
24 | if cache_dir:
25 | hash_figure = "-".join(map(str, [cdtb_dir, train, validate, test, ctb_dir, preprocess]))
26 | hash_key = hashlib.md5(hash_figure.encode()).hexdigest()
27 | cache = os.path.join(cache_dir, hash_key + ".gz")
28 | else:
29 | cache = None
30 |
31 | if cache is not None and os.path.isfile(cache):
32 | with gzip.open(cache, "rb") as cache_fd:
33 | self.preprocess, self.ctb = pickle.load(cache_fd)
34 | self.train = pickle.load(cache_fd)
35 | self.validate = pickle.load(cache_fd)
36 | self.test = pickle.load(cache_fd)
37 | logger.info("load cached dataset from %s" % cache)
38 | return
39 |
40 | self.preprocess = preprocess
41 | self.ctb = self.load_ctb(ctb_dir, ctb_encoding) if ctb_dir else {}
42 | self.train = self.load_dir(cdtb_dir, train, encoding=encoding)
43 | self.validate = self.load_dir(cdtb_dir, validate, encoding=encoding)
44 | self.test = self.load_dir(cdtb_dir, test, encoding=encoding)
45 |
46 | if preprocess:
47 | for discourse in tqdm.tqdm(chain(self.train, self.validate, self.test), desc="preprocessing"):
48 | self.preprocessing(discourse)
49 |
50 | if cache is not None:
51 | with gzip.open(cache, "wb") as cache_fd:
52 | pickle.dump((self.preprocess, self.ctb), cache_fd)
53 | pickle.dump(self.train, cache_fd)
54 | pickle.dump(self.validate, cache_fd)
55 | pickle.dump(self.test, cache_fd)
56 | logger.info("saved cached dataset to %s" % cache)
57 |
58 | def report(self):
59 | # TODO
60 | raise NotImplementedError()
61 |
62 | def preprocessing(self, discourse):
63 | for paragraph in discourse:
64 | for sentence in paragraph.iterfind(filter=node_type_filter(Sentence)):
65 | if self.ctb and (sentence.sid is not None) and (sentence.sid in self.ctb):
66 | parse = self.ctb[sentence.sid]
67 | pairs = [(node[0], node.label()) for node in parse.subtrees()
68 | if node.height() == 2 and node.label() != "-NONE-"]
69 | words, tags = list(zip(*pairs))
70 | else:
71 | words, tags = list(zip(*thulac.cut(sentence.text)))
72 | setattr(sentence, "words", list(words))
73 | setattr(sentence, "tags", list(tags))
74 |
75 | offset = 0
76 | for textnode in sentence.iterfind(filter=node_type_filter([TEXT, Connective, EDU]),
77 | terminal=node_type_filter([TEXT, Connective, EDU])):
78 | if isinstance(textnode, EDU):
79 | edu_words = []
80 | edu_tags = []
81 | cur = 0
82 | for word, tag in zip(sentence.words, sentence.tags):
83 | if offset <= cur < cur + len(word) <= offset + len(textnode.text):
84 | edu_words.append(word)
85 | edu_tags.append(tag)
86 | cur += len(word)
87 | setattr(textnode, "words", edu_words)
88 | setattr(textnode, "tags", edu_tags)
89 | offset += len(textnode.text)
90 | return discourse
91 |
92 | @staticmethod
93 | def load_dir(path, sub, encoding="UTF-8"):
94 | train_path = os.path.join(path, sub)
95 | discourses = []
96 | for file in os.listdir(train_path):
97 | file = os.path.join(train_path, file)
98 | discourse = Discourse.from_xml(file, encoding=encoding)
99 | discourses.append(discourse)
100 | return discourses
101 |
102 | @staticmethod
103 | def load_ctb(ctb_dir, encoding="UTF-8"):
104 | ctb = {}
105 | s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL)
106 | for file in os.listdir(ctb_dir):
107 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd:
108 | doc = fd.read()
109 | for match in s_pat.finditer(doc):
110 | sid = match.group("sid")
111 | sparse = ParseTree.fromstring(match.group("sparse"))
112 | ctb[sid] = sparse
113 | return ctb
114 |
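115 | # A usage sketch with hypothetical paths (train/validate/test name CDTB XML
116 | # sub-directories; cache_dir lets subsequent runs skip parsing and pickling):
117 | #
118 | #   cdtb = CDTB("data/CDTB", "TRAIN", "VALIDATE", "TEST",
119 | #               ctb_dir="data/CTB", cache_dir="data/cache", preprocess=True)
120 | #   print(len(cdtb.train), len(cdtb.validate), len(cdtb.test))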
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/dataset/cdtb.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import os
4 | import re
5 | import pickle
6 | import gzip
7 | import hashlib
8 | import thulac
9 | import tqdm
10 | import logging
11 | from itertools import chain
12 | from nltk.tree import Tree as ParseTree
13 | from structure import Discourse, Sentence, EDU, TEXT, Connective, node_type_filter
14 |
15 |
16 | # note: this module will print a junk line "Model loaded succeed" to stdout when initializing
17 | thulac = thulac.thulac()
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class CDTB:
22 | def __init__(self, cdtb_dir, train, validate, test, encoding="UTF-8",
23 | ctb_dir=None, ctb_encoding="UTF-8", cache_dir=None, preprocess=False):
24 | if cache_dir:
25 | hash_figure = "-".join(map(str, [cdtb_dir, train, validate, test, ctb_dir, preprocess]))
26 | hash_key = hashlib.md5(hash_figure.encode()).hexdigest()
27 | cache = os.path.join(cache_dir, hash_key + ".gz")
28 | else:
29 | cache = None
30 |
31 | if cache is not None and os.path.isfile(cache):
32 | with gzip.open(cache, "rb") as cache_fd:
33 | self.preprocess, self.ctb = pickle.load(cache_fd)
34 | self.train = pickle.load(cache_fd)
35 | self.validate = pickle.load(cache_fd)
36 | self.test = pickle.load(cache_fd)
37 | logger.info("load cached dataset from %s" % cache)
38 | return
39 |
40 | self.preprocess = preprocess
41 | self.ctb = self.load_ctb(ctb_dir, ctb_encoding) if ctb_dir else {}
42 | self.train = self.load_dir(cdtb_dir, train, encoding=encoding)
43 | self.validate = self.load_dir(cdtb_dir, validate, encoding=encoding)
44 | self.test = self.load_dir(cdtb_dir, test, encoding=encoding)
45 |
46 | if preprocess:
47 | for discourse in tqdm.tqdm(chain(self.train, self.validate, self.test), desc="preprocessing"):
48 | self.preprocessing(discourse)
49 |
50 | if cache is not None:
51 | with gzip.open(cache, "wb") as cache_fd:
52 | pickle.dump((self.preprocess, self.ctb), cache_fd)
53 | pickle.dump(self.train, cache_fd)
54 | pickle.dump(self.validate, cache_fd)
55 | pickle.dump(self.test, cache_fd)
56 | logger.info("saved cached dataset to %s" % cache)
57 |
58 | def report(self):
59 | # TODO
60 | raise NotImplementedError()
61 |
62 | def preprocessing(self, discourse):
63 | for paragraph in discourse:
64 | for sentence in paragraph.iterfind(filter=node_type_filter(Sentence)):
65 | if self.ctb and (sentence.sid is not None) and (sentence.sid in self.ctb):
66 | parse = self.ctb[sentence.sid]
67 | pairs = [(node[0], node.label()) for node in parse.subtrees()
68 | if node.height() == 2 and node.label() != "-NONE-"]
69 | words, tags = list(zip(*pairs))
70 | else:
71 | words, tags = list(zip(*thulac.cut(sentence.text)))
72 | setattr(sentence, "words", list(words))
73 | setattr(sentence, "tags", list(tags))
74 |
75 | offset = 0
76 | for textnode in sentence.iterfind(filter=node_type_filter([TEXT, Connective, EDU]),
77 | terminal=node_type_filter([TEXT, Connective, EDU])):
78 | if isinstance(textnode, EDU):
79 | edu_words = []
80 | edu_tags = []
81 | cur = 0
82 | for word, tag in zip(sentence.words, sentence.tags):
83 | if offset <= cur < cur + len(word) <= offset + len(textnode.text):
84 | edu_words.append(word)
85 | edu_tags.append(tag)
86 | cur += len(word)
87 | setattr(textnode, "words", edu_words)
88 | setattr(textnode, "tags", edu_tags)
89 | offset += len(textnode.text)
90 | return discourse
91 |
92 | @staticmethod
93 | def load_dir(path, sub, encoding="UTF-8"):
94 | train_path = os.path.join(path, sub)
95 | discourses = []
96 | for file in os.listdir(train_path):
97 | file = os.path.join(train_path, file)
98 | discourse = Discourse.from_xml(file, encoding=encoding)
99 | discourses.append(discourse)
100 | return discourses
101 |
102 | @staticmethod
103 | def load_ctb(ctb_dir, encoding="UTF-8"):
104 | ctb = {}
105 | s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL)
106 | for file in os.listdir(ctb_dir):
107 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd:
108 | doc = fd.read()
109 | for match in s_pat.finditer(doc):
110 | sid = match.group("sid")
111 | sparse = ParseTree.fromstring(match.group("sparse"))
112 | ctb[sid] = sparse
113 | return ctb
114 |
--------------------------------------------------------------------------------
/en_dp_gan_xlnet/structure/cdtb.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import os
4 | import re
5 | import pickle
6 | import gzip
7 | import hashlib
8 | import thulac
9 | import tqdm
10 | import logging
11 | from itertools import chain
12 | from nltk.tree import Tree as ParseTree
13 | from structure import Discourse, Sentence, EDU, TEXT, Connective, node_type_filter
14 |
15 |
16 | # note: this module prints a junk line "Model loaded succeed" to stdout when initializing
17 | thulac = thulac.thulac()
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class CDTB:
22 | def __init__(self, cdtb_dir, train, validate, test, encoding="UTF-8",
23 | ctb_dir=None, ctb_encoding="UTF-8", cache_dir=None, preprocess=False):
24 | if cache_dir:
25 | hash_figure = "-".join(map(str, [cdtb_dir, train, validate, test, ctb_dir, preprocess]))
26 | hash_key = hashlib.md5(hash_figure.encode()).hexdigest()
27 | cache = os.path.join(cache_dir, hash_key + ".gz")
28 | else:
29 | cache = None
30 |
31 | if cache is not None and os.path.isfile(cache):
32 | with gzip.open(cache, "rb") as cache_fd:
33 | self.preprocess, self.ctb = pickle.load(cache_fd)
34 | self.train = pickle.load(cache_fd)
35 | self.validate = pickle.load(cache_fd)
36 | self.test = pickle.load(cache_fd)
37 |             logger.info("loaded cached dataset from %s" % cache)
38 | return
39 |
40 | self.preprocess = preprocess
41 | self.ctb = self.load_ctb(ctb_dir, ctb_encoding) if ctb_dir else {}
42 | self.train = self.load_dir(cdtb_dir, train, encoding=encoding)
43 | self.validate = self.load_dir(cdtb_dir, validate, encoding=encoding)
44 | self.test = self.load_dir(cdtb_dir, test, encoding=encoding)
45 |
46 | if preprocess:
47 | for discourse in tqdm.tqdm(chain(self.train, self.validate, self.test), desc="preprocessing"):
48 | self.preprocessing(discourse)
49 |
50 | if cache is not None:
51 | with gzip.open(cache, "wb") as cache_fd:
52 | pickle.dump((self.preprocess, self.ctb), cache_fd)
53 | pickle.dump(self.train, cache_fd)
54 | pickle.dump(self.validate, cache_fd)
55 | pickle.dump(self.test, cache_fd)
56 | logger.info("saved cached dataset to %s" % cache)
57 |
58 | def report(self):
59 | # TODO
60 | raise NotImplementedError()
61 |
62 | def preprocessing(self, discourse):
63 | for paragraph in discourse:
64 | for sentence in paragraph.iterfind(filter=node_type_filter(Sentence)):
65 | if self.ctb and (sentence.sid is not None) and (sentence.sid in self.ctb):
66 | parse = self.ctb[sentence.sid]
67 | pairs = [(node[0], node.label()) for node in parse.subtrees()
68 | if node.height() == 2 and node.label() != "-NONE-"]
69 | words, tags = list(zip(*pairs))
70 | else:
71 | words, tags = list(zip(*thulac.cut(sentence.text)))
72 | setattr(sentence, "words", list(words))
73 | setattr(sentence, "tags", list(tags))
74 |
75 | offset = 0
76 | for textnode in sentence.iterfind(filter=node_type_filter([TEXT, Connective, EDU]),
77 | terminal=node_type_filter([TEXT, Connective, EDU])):
78 | if isinstance(textnode, EDU):
79 | edu_words = []
80 | edu_tags = []
81 | cur = 0
82 | for word, tag in zip(sentence.words, sentence.tags):
83 | if offset <= cur < cur + len(word) <= offset + len(textnode.text):
84 | edu_words.append(word)
85 | edu_tags.append(tag)
86 | cur += len(word)
87 | setattr(textnode, "words", edu_words)
88 | setattr(textnode, "tags", edu_tags)
89 | offset += len(textnode.text)
90 | return discourse
91 |
92 | @staticmethod
93 | def load_dir(path, sub, encoding="UTF-8"):
94 | train_path = os.path.join(path, sub)
95 | discourses = []
96 | for file in os.listdir(train_path):
97 | file = os.path.join(train_path, file)
98 | discourse = Discourse.from_xml(file, encoding=encoding)
99 | discourses.append(discourse)
100 | return discourses
101 |
102 | @staticmethod
103 | def load_ctb(ctb_dir, encoding="UTF-8"):
104 | ctb = {}
105 |         s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL)  # match <S ID=...>...</S> sentence blocks in CTB files
106 | for file in os.listdir(ctb_dir):
107 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd:
108 | doc = fd.read()
109 | for match in s_pat.finditer(doc):
110 | sid = match.group("sid")
111 | sparse = ParseTree.fromstring(match.group("sparse"))
112 | ctb[sid] = sparse
113 | return ctb
114 |
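A minimal usage sketch for the `CDTB` loader above, mirroring the call in `pre_process.py`
(the data paths are that script's defaults and are assumed to exist):

```
# hedged example, not part of the repository
from dataset import CDTB  # in en_dp_gan_xlnet this module sits under structure/

# load (and cache) the three splits; preprocess=True attaches words/tags to EDUs
cdtb = CDTB("data/CDTB", "TRAIN", "VALIDATE", "TEST",
            ctb_dir="data/CTB", preprocess=True, cache_dir="data/cache")
print(len(cdtb.train), len(cdtb.validate), len(cdtb.test))
```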
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | This project presents the recently proposed GAN-based DRS parser in
4 | **Adversarial Learning for Discourse Rhetorical Structure Parsing (ACL-IJCNLP2021)**.
5 | For any questions, please send e-mails directly to zzlynx@outlook.com (Longyin Zhang).
6 | One may find that the English parser is not included here; because of the close correlation
7 | between our various studies, we will release the English system in September 22.
8 |
9 | #### Installation
10 | - Python 3.6.10
11 | - transformers 3.0.2
12 | - pytorch 1.5.0
13 | - numpy 1.19.1
14 | - cudatoolkit 9.2 cudnn 7.6.5
15 | - other necessary Python packages
16 |
17 | #### Project Structure
18 | ```
19 | ---DP_GAN
20 | |-ch_dp_gan Chinese discourse parser based on Qiu-W2V
21 | |-ch_dp_gan_xlnet Chinese discourse parser based on XLNet
22 | |-en_dp_gan English discourse parser based on GloVe and ELMo
23 | |-en_dp_gan_xlnet English discourse parser based on XLNet
24 | ```
25 | We do not provide data in this project because the data is too large to be uploaded to GitHub. Please
26 | prepare the data yourself if you want to train a parser; a sketch of the expected layout is given below.
27 | By the way, we will provide a pre-trained end2end DRS parser (automatic EDU segmentation and discourse
28 | parsing) in the project **sota_end2end_parser** so that other researchers can apply it directly to other NLP tasks.
29 |
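A possible layout, inferred from the default paths in `pre_process.py` (`data/CDTB`, `data/CTB`,
`data/cache`, `data/pretrained/sgns.renmin.word`); treat it as an assumption, not a specification:

```
---ch_dp_gan (and likewise for the other three systems)
   |-data
     |-CDTB          CDTB XML files, split into TRAIN / VALIDATE / TEST subfolders
     |-CTB           Chinese Treebank files with <S ID=...> sentence parses
     |-cache         gzip-pickled dataset caches (created automatically)
     |-pretrained    e.g. sgns.renmin.word for the Qiu-W2V system
```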
30 | We explained in the paper that there are some differences between the XLNet-based systems and the other
31 | systems. I re-checked the system and found two further points that I forgot to mention:
32 | **(i)** We did not use the sentence and paragraph boundaries in the Chinese system, because they hurt
33 | performance;
34 |
35 | **(ii)** In the English XLNet-based parser, we did not feed the EDU vectors into RNNs before split-point
36 | encoding.
37 |
38 | #### Performance Evaluation
39 |
40 | As stated in the paper, we employ the *original Parseval (Morey et al. 2017)* to evaluate our English DRS
41 | parser and report the **micro-averaged F1-score** as the performance measure. We do not report results based on
42 | Marcu's RST Parseval because that metric overestimates the performance level of DRS parsing.
43 |
44 | When using *RST Parseval*, 19 relation categories are actually considered for evaluation, i.e.,
45 | the 18 rhetorical relations plus the **SPAN** tag. Among these tags, the SPAN relation accounts for more than
46 | half of the relation labels in RST-DT, which enlarges the uncertainty of performance comparison.
47 | Specifically, systems that predict relation tags for the parent node show weaker performance than
48 | systems that predict a relation category for each child node. **Why?** The second kind of system usually also
49 | employs SPAN tags during training, and this brings in additional gradients that let the model greedily maximize
50 | its reward by assigning the SPAN label to the appropriate tree nodes. For the first kind of system, however,
51 | SPAN labels are assigned only according to the predicted nuclearity category (our system belongs to this kind).
52 |
53 | Here we report the results of (Yu et al. 2018) and ours on **SPAN** for reference:
54 | ```
55 | --- system ---------- P ---- R ---- F
56 | Yu et al. (2018) 60.9 63.7 62.3 (The parsing results are from Yu Nan)
57 | Ours 46.1 43.1 44.5
58 | ```
59 |
60 | Obviously, it is hard to judge whether performance improvements come from the 18 rhetorical relations or the
61 | artificial "SPAN" relation when using RST Parseval. For a clearer performance comparison, we explicitly recommend
62 | that other DRS researchers use the original Parseval to evaluate their parsers (a sketch of the metric follows).
63 |
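As a hedged illustration (not the repo's `util/eval.py`), the sketch below computes micro-averaged
original-Parseval scores by pooling labeled-span counts over all documents before taking P/R/F1;
the `(start, end, label)` span representation is an assumption made for the example:

```
# micro-averaged original Parseval over labeled spans (illustrative sketch)
def micro_parseval(pred_docs, gold_docs):
    """pred_docs / gold_docs: one set of (start, end, label) triples per document."""
    matched = predicted = gold = 0
    for pred, ref in zip(pred_docs, gold_docs):
        matched += len(pred & ref)   # a span counts only with an exact label match
        predicted += len(pred)
        gold += len(ref)
    p = matched / predicted if predicted else 0.0
    r = matched / gold if gold else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1
```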
64 | For Chinese DRS parsing, we use a strict method for performance evaluation; one can refer to
65 | https://github.com/NLP-Discourse-SoochowU/t2d_discourseparser for details.
66 | Recently, some researchers asked me, "Why use different metrics for CDTB parsing in the 2020 and
67 | 2021 parsers?" I think it necessary to give an answer here. Our work on top-down RST parsing was actually
68 | finished in 2019; we once sent the research to ACL 2019 and it was rejected. There is a two-year gap
69 | between the two works, during which we carried out some new research, including our findings on parsing
70 | evaluation. We found that RST Parseval and the original Parseval consider only span boundaries for structure
71 | analysis and ignore the split points, which is not strict enough, so we proposed a new evaluation method
72 | and tried it out in CDTB analysis. Besides, when we sent the work to ACL 2021, we thought we were using the
73 | original Parseval for evaluation, but we later found that we had actually used the new evaluation method. We
74 | therefore re-report the results of our 2020 work using the new, stricter evaluation method as a reference for
75 | following RST researchers.
76 |
77 | #### Model Training
78 | In this project, we tuned the hyper-parameters for best performance, and the details are given in the ACL
79 | paper. Although we tried our best to check the correctness of the paper content, we still found one inaccuracy:
80 | **We trained the XLNet systems for 50 rounds instead of the 30 written in the Appendix.**
81 |
82 | To train the parsing models, run the following commands:
83 | ```
84 | (Chinese) python -m treebuilder.dp_nr.train
85 | (English) python main.py
86 | ```
87 |
88 | Due to the use of GAN nets, the parsing performance fluctuates somewhat; the fluctuation is related to the
89 | hardware and software environment you use (fixing the random seeds, as sketched below, helps). It should be
90 | noted that some researchers may use a preprocessed RST-DT corpus for experiments, which can differ slightly
91 | from the original data. We recommend using the original, official RST-DT corpus for experiments.
92 |
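A hedged sketch for pinning the random seeds (`set_seed` is our name, not the repo's; the default
value 7 matches the `--seed` default in `pre_process.py`); full determinism still depends on the
hardware and cuDNN version, as noted above:

```
# reduce run-to-run fluctuation by seeding every RNG the training code touches
import random
import numpy as np
import torch

def set_seed(seed=7):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```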
93 | We will provide a pre-trained end-to-end parser at https://github.com/NLP-Discourse-SoochowU/sota_end2end_parser,
94 | and one can directly apply it to downstream NLP tasks.
95 |
96 | #### Tips
97 | RST-style discourse parsing has long been known to be complicated, which is why the system we provide
98 | contains so many experimental settings. We conducted a set of experimental comparisons in this system and
99 | found that some implementation details could be helpful for your own system, e.g., the EDU attention, the
100 | context attention, the residual connection, the parameter initialization method, etc. These code details
101 | hardly bring research value at this point, but they will make your system more stable; two of them are sketched below.
102 |
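A hedged illustration of two of the tricks above (the module, its sizes, and the orthogonal
initialization scheme are our assumptions for the sketch, not the repo's exact code; 512 matches
the `-hidden_size` default in `pre_process.py`):

```
# residual connection around a bi-GRU sub-layer, with explicit weight initialization
import torch
import torch.nn as nn

class ResidualGRUEncoder(nn.Module):
    def __init__(self, hidden_size=512):
        super().__init__()
        # two directions of hidden_size // 2 concatenate back to hidden_size
        self.rnn = nn.GRU(hidden_size, hidden_size // 2,
                          batch_first=True, bidirectional=True)
        for name, param in self.rnn.named_parameters():
            if "weight" in name:
                nn.init.orthogonal_(param)  # assumed initialization scheme
            else:
                nn.init.zeros_(param)

    def forward(self, x):          # x: (batch, seq_len, hidden_size)
        out, _ = self.rnn(x)       # out has the same shape as x
        return x + out             # residual connection stabilizes training

x = torch.randn(2, 10, 512)
assert ResidualGRUEncoder()(x).shape == x.shape
```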
103 | #### License
104 | ```
105 | Copyright (c) 2019, Soochow University NLP research group. All rights reserved.
106 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that
107 | the following conditions are met:
108 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
109 | following disclaimer.
110 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
111 | following disclaimer in the documentation and/or other materials provided with the distribution.
112 | ```
113 |
--------------------------------------------------------------------------------
/en_dp_gan/util/data_builder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | @Author: lyzhang
5 | @Date: 2018.4.5
6 | @Description:
7 | """
8 | import progressbar
9 | from structure.tree_obj import tree_obj
10 | from stanfordcorenlp import StanfordCoreNLP
11 | from config import *
12 | from path_config import path_to_jar
13 | from structure.rst_tree import rst_tree
14 | from util.rst_utils import get_edus_info
15 |
16 |
17 | class Builder:
18 | def __init__(self):
19 | self.nlp = StanfordCoreNLP(path_to_jar)
20 | self.max_len_of_edu = 0
21 |
22 | @staticmethod
23 | def update_edu_trees():
24 | edu_trees = load_data(RST_TEST_EDUS_TREES)
25 | test_trees = load_data(RST_TEST_TREES)
26 | tree_num = len(test_trees)
27 | new_edu_trees = []
28 | for idx in range(tree_num):
29 | tmp_gold_tree = test_trees[idx]
30 | tmp_token_vec = []
31 | for edu in tmp_gold_tree.edus:
32 | tmp_token_vec.extend(edu.temp_edu_emlo_emb)
33 | tmp_edu_tree = edu_trees[idx]
34 | new_edus = []
35 | for edu in tmp_edu_tree.edus:
36 | edu_len = len(edu.temp_edu_ids)
37 | edu.temp_edu_emlo_emb = tmp_token_vec[:edu_len]
38 | tmp_token_vec = tmp_token_vec[edu_len:]
39 | new_edus.append(edu)
40 | new_tree = tree_obj()
41 | new_tree.assign_edus(new_edus)
42 | new_edu_trees.append(new_tree)
43 | save_data(new_edu_trees, RST_TEST_EDUS_TREES)
44 |
45 | def form_trees(self):
46 | # train_dir = os.path.join(RAW_RST_DT_PATH, "train")
47 | # train_tree_p = os.path.join(RST_DT_PATH, "train_trees.pkl")
48 | # train_trees_rst_p = os.path.join(RST_DT_PATH, "train_trees_rst.pkl")
49 | # self.build_specific_trees(train_dir, train_tree_p, train_trees_rst_p)
50 | #
51 | # dev_dir = os.path.join(RAW_RST_DT_PATH, "dev")
52 | # dev_tree_p = os.path.join(RST_DT_PATH, "dev_trees.pkl")
53 | # dev_trees_rst_p = os.path.join(RST_DT_PATH, "dev_trees_rst.pkl")
54 | # self.build_specific_trees(dev_dir, dev_tree_p, dev_trees_rst_p)
55 |
56 | test_dir = os.path.join(RAW_RST_DT_PATH, "test")
57 | test_tree_p = os.path.join(RST_DT_PATH, "test_trees.pkl")
58 | test_trees_rst_p = os.path.join(RST_DT_PATH, "test_trees_rst.pkl")
59 | edus_tree_p = os.path.join(RST_DT_PATH, "edus_test_trees.pkl")
60 | self.build_specific_trees(test_dir, test_tree_p, test_trees_rst_p, edus_tree_p)
61 |
62 | # test_dir = os.path.join(RAW_RST_DT_PATH, "test")
63 | # test_tree_p_n = os.path.join(RST_DT_PATH, "test_trees_n.pkl")
64 | # test_trees_rst_p_n = os.path.join(RST_DT_PATH, "test_trees_rst_n.pkl")
65 | # edus_tree_p_n = os.path.join(RST_DT_PATH, "edus_test_trees_n.pkl")
66 | # self.build_specific_trees(test_dir, test_tree_p_n, test_trees_rst_p_n, edus_tree_p_n)
67 |
68 | def build_specific_trees(self, raw_p=None, tree_p=None, tree_rst_p=None, edus_tree_p=None):
69 | p = progressbar.ProgressBar()
70 | p.start(len(os.listdir(raw_p)))
71 | pro_idx = 1
72 | trees_list, trees_rst_list = [], []
73 | edu_trees_list = []
74 | for file_name in os.listdir(raw_p):
75 | p.update(pro_idx)
76 | pro_idx += 1
77 | if file_name.endswith('.out.dis'):
78 | root, edus_list = self.build_one_tree(raw_p, file_name, edus_tree_p)
79 | trees_rst_list.append(root)
80 | tree_obj_ = tree_obj(root)
81 | trees_list.append(tree_obj_)
82 | if edus_tree_p is not None:
83 | edus_tree_obj_ = tree_obj()
84 | edus_tree_obj_.assign_edus(edus_list)
85 | edu_trees_list.append(edus_tree_obj_)
86 | p.finish()
87 | save_data(trees_list, tree_p)
88 | save_data(trees_rst_list, tree_rst_p)
89 | if edus_tree_p is not None:
90 | save_data(edu_trees_list, edus_tree_p)
91 |
92 | def build_one_tree(self, raw_p, file_name, edus_tree_p):
93 | """ build edu_list and edu_ids according to EDU files.
94 | """
95 | temp_path = os.path.join(raw_p, file_name)
96 | e_bound_list, e_list, e_span_list, e_ids_list, e_tags_list, e_headwords_list, e_cent_word_list, e_emlo_list =\
97 | get_edus_info(temp_path.replace(".dis", ".edus"), temp_path.replace(".out.dis", ".out"), nlp=self.nlp)
98 | lines_list = open(temp_path, 'r').readlines()
99 | root = rst_tree(type_="Root", lines_list=lines_list, temp_line=lines_list[0], file_name=file_name, rel="span")
100 | root.create_tree(temp_line_num=1, p_node_=root)
101 | root.config_edus(temp_node=root, e_list=e_list, e_span_list=e_span_list, e_ids_list=e_ids_list,
102 | e_tags_list=e_tags_list, e_headwords_list=e_headwords_list, e_cent_word_list=e_cent_word_list,
103 | total_e_num=len(e_ids_list), e_bound_list=e_bound_list, e_emlo_list=e_emlo_list)
104 |
105 | edus_list = []
106 | if edus_tree_p is not None:
107 | e_bound_list, e_list, e_span_list, e_ids_list, e_tags_list, e_headwords_list, e_cent_word_list, \
108 | e_emlo_list = get_edus_info(temp_path.replace(".dis", ".auto.edus"),
109 | temp_path.replace(".out.dis", ".out"), nlp=self.nlp)
110 | edu_num = len(e_bound_list)
111 | for idx in range(edu_num):
112 | tmp_edu = rst_tree(temp_edu_boundary=e_bound_list[idx], temp_edu=e_list[idx],
113 | temp_edu_span=e_span_list[idx], temp_edu_ids=e_ids_list[idx],
114 | temp_pos_ids=e_tags_list[idx], temp_edu_heads=e_headwords_list[idx],
115 | temp_edu_has_center_word=e_cent_word_list[idx],
116 | tmp_edu_emlo_emb=e_emlo_list[idx])
117 | edus_list.append(tmp_edu)
118 | return root, edus_list
119 |
120 | def build_tree_obj_list(self):
121 |         raw = ""  # placeholder: set this to the directory containing the raw *.out files before use
122 | parse_trees = []
123 | for file_name in os.listdir(raw):
124 | if file_name.endswith(".out"):
125 | tmp_edus_list = []
126 | sent_path = os.path.join(raw, file_name)
127 | edu_path = sent_path + ".edu"
128 | edus_boundary_list, edus_list, edu_span_list, edus_ids_list, edus_tag_ids_list, edus_conns_list, \
129 | edu_headword_ids_list, edu_has_center_word_list = get_edus_info(edu_path, sent_path, nlp=self.nlp)
130 | for _ in range(len(edus_list)):
131 | tmp_edu = rst_tree()
132 | tmp_edu.temp_edu = edus_list.pop(0)
133 | tmp_edu.temp_edu_span = edu_span_list.pop(0)
134 | tmp_edu.temp_edu_ids = edus_ids_list.pop(0)
135 | tmp_edu.temp_pos_ids = edus_tag_ids_list.pop(0)
136 | tmp_edu.temp_edu_conn_ids = edus_conns_list.pop(0)
137 | tmp_edu.temp_edu_heads = edu_headword_ids_list.pop(0)
138 | tmp_edu.temp_edu_has_center_word = edu_has_center_word_list.pop(0)
139 | tmp_edu.edu_node_boundary = edus_boundary_list.pop(0)
140 | tmp_edu.inner_sent = True
141 | tmp_edus_list.append(tmp_edu)
142 | tmp_tree_obj = tree_obj()
143 | tmp_tree_obj.file_name = file_name
144 | tmp_tree_obj.assign_edus(tmp_edus_list)
145 | parse_trees.append(tmp_tree_obj)
146 | return parse_trees
147 |
--------------------------------------------------------------------------------
/ch_dp_gan/pre_process.py:
--------------------------------------------------------------------------------
1 | # UTF-8
2 | # Author: Longyin Zhang
3 | # Date: 2020.10.9
4 | import argparse
5 | import logging
6 | import numpy as np
7 | from dataset import CDTB
8 | from collections import Counter
9 | from itertools import chain
10 | from structure.vocab import Vocab, Label
11 | from structure.nodes import node_type_filter, EDU, Relation, Sentence, TEXT
12 | import progressbar
13 | from util.file_util import *
14 | p = progressbar.ProgressBar()
15 |
16 |
17 | def build_vocab(dataset):
18 | word_freq = Counter()
19 | pos_freq = Counter()
20 | nuc_freq = Counter()
21 | rel_freq = Counter()
22 | for paragraph in chain(*dataset):
23 | for node in paragraph.iterfind(filter=node_type_filter([EDU, Relation])):
24 | if isinstance(node, EDU):
25 | word_freq.update(node.words)
26 | pos_freq.update(node.tags)
27 | elif isinstance(node, Relation):
28 | nuc_freq[node.nuclear] += 1
29 | rel_freq[node.ftype] += 1
30 |
31 | word_vocab = Vocab("word", word_freq)
32 | pos_vocab = Vocab("part of speech", pos_freq)
33 | nuc_label = Label("nuclear", nuc_freq)
34 | rel_label = Label("relation", rel_freq)
35 | return word_vocab, pos_vocab, nuc_label, rel_label
36 |
37 |
38 | def gen_decoder_data(root, edu2ids):
39 | # splits s0 s1 s2 s3 s4 s5 s6
40 | # edus s/ e0 e1 e2 e3 e4 e5 /s
41 |     splits = []  # [(start, split, end, nuc, rel), ...], e.g. (0, 3, 6, NS, rel)
42 | child_edus = [] # [edus]
43 |
44 | if isinstance(root, EDU):
45 | child_edus.append(root)
46 | elif isinstance(root, Sentence):
47 | for child in root:
48 | _child_edus, _splits = gen_decoder_data(child, edu2ids)
49 | child_edus.extend(_child_edus)
50 | splits.extend(_splits)
51 | elif isinstance(root, Relation):
52 | children = [gen_decoder_data(child, edu2ids) for child in root]
53 | if len(children) < 2:
54 | raise ValueError("relation node should have at least 2 children")
55 |
56 | while children:
57 | left_child_edus, left_child_splits = children.pop(0)
58 | if children:
59 | last_child_edus, _ = children[-1]
60 | start = edu2ids[left_child_edus[0]]
61 | split = edu2ids[left_child_edus[-1]] + 1
62 | end = edu2ids[last_child_edus[-1]] + 1
63 | nuc = root.nuclear
64 | rel = root.ftype
65 | splits.append((start, split, end, nuc, rel))
66 | child_edus.extend(left_child_edus)
67 | splits.extend(left_child_splits)
68 | return child_edus, splits
69 |
70 |
71 | def numericalize(dataset, word_vocab, pos_vocab, nuc_label, rel_label):
72 | instances = []
73 | for paragraph in filter(lambda d: d.root_relation(), chain(*dataset)):
74 | encoder_inputs = []
75 | decoder_inputs = []
76 | pred_splits = []
77 | pred_nucs = []
78 | pred_rels = []
79 | edus = list(paragraph.edus())
80 | for edu in edus:
81 | edu_word_ids = [word_vocab[word] for word in edu.words]
82 | edu_pos_ids = [pos_vocab[pos] for pos in edu.tags]
83 | encoder_inputs.append((edu_word_ids, edu_pos_ids))
84 | edu2ids = {edu: i for i, edu in enumerate(edus)}
85 | _, splits = gen_decoder_data(paragraph.root_relation(), edu2ids)
86 | for start, split, end, nuc, rel in splits:
87 | decoder_inputs.append((start, end))
88 | pred_splits.append(split)
89 | pred_nucs.append(nuc_label[nuc])
90 | pred_rels.append(rel_label[rel])
91 | instances.append((encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels))
92 | return instances
93 |
94 |
95 | nr2ids = dict()
96 | ids2nr = dict()
97 | nr_idx = 0
98 |
99 |
100 | def gen_batch_iter(instances, batch_size, use_gpu=False):
101 | global nr2ids, nr_idx, ids2nr
102 | random_instances = np.random.permutation(instances)
103 | num_instances = len(instances)
104 | offset = 0
105 | p.start(num_instances)
106 | while offset < num_instances:
107 | p.update(offset)
108 | batch = random_instances[offset: min(num_instances, offset+batch_size)]
109 | for batchi, (_, _, _, pred_nucs, pred_rels) in enumerate(batch):
110 | for nuc, rel in zip(pred_nucs, pred_rels):
111 | nr = str(nuc) + "-" + str(rel)
112 | if nr not in nr2ids.keys():
113 | nr2ids[nr] = nr_idx
114 | ids2nr[nr_idx] = nr
115 | nr_idx += 1
116 | offset = offset + batch_size
117 | p.finish()
118 |
119 |
120 | def main(args):
121 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
122 | word_vocab, pos_vocab, nuc_label, rel_label = build_vocab(cdtb.train)
123 | trainset = numericalize(cdtb.train, word_vocab, pos_vocab, nuc_label, rel_label)
124 | test = numericalize(cdtb.test, word_vocab, pos_vocab, nuc_label, rel_label)
125 | dev = numericalize(cdtb.validate, word_vocab, pos_vocab, nuc_label, rel_label)
126 | for idx in range(1, 8):
127 | gen_batch_iter(trainset, args.batch_size, args.use_gpu)
128 | gen_batch_iter(test, args.batch_size, args.use_gpu)
129 | gen_batch_iter(dev, args.batch_size, args.use_gpu)
130 | print(nr2ids)
131 | save_data(nr2ids, "data/nr2ids.pkl")
132 | save_data(ids2nr, "data/ids2nr.pkl")
133 |
134 |
135 | if __name__ == '__main__':
136 | logging.basicConfig(level=logging.INFO)
137 | arg_parser = argparse.ArgumentParser()
138 |
139 | # dataset parameters
140 | arg_parser.add_argument("--data", default="data/CDTB")
141 | arg_parser.add_argument("--ctb_dir", default="data/CTB")
142 | arg_parser.add_argument("--cache_dir", default="data/cache")
143 |
144 | # model parameters
145 | arg_parser.add_argument("-hidden_size", default=512, type=int)
146 | arg_parser.add_argument("-dropout", default=0.33, type=float)
147 | # w2v_group = arg_parser.add_mutually_exclusive_group(required=True)
148 | arg_parser.add_argument("-w2v_size", default=300, type=int)
149 | arg_parser.add_argument("-pos_size", default=30, type=int)
150 | arg_parser.add_argument("-split_mlp_size", default=64, type=int)
151 | arg_parser.add_argument("-nuc_mlp_size", default=32, type=int)
152 | arg_parser.add_argument("-rel_mlp_size", default=128, type=int)
153 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true")
154 | arg_parser.add_argument("-pretrained", default="data/pretrained/sgns.renmin.word")
155 | arg_parser.set_defaults(w2v_freeze=True)
156 |
157 | # train parameters
158 | arg_parser.add_argument("-epoch", default=20, type=int)
159 | arg_parser.add_argument("-batch_size", default=64, type=int)
160 | arg_parser.add_argument("-lr", default=0.001, type=float)
161 | arg_parser.add_argument("-l2", default=0.0, type=float)
162 | arg_parser.add_argument("-log_every", default=10, type=int)
163 | arg_parser.add_argument("-validate_every", default=10, type=int)
164 | arg_parser.add_argument("-a_split_loss", default=0.3, type=float)
165 | arg_parser.add_argument("-a_nuclear_loss", default=1.0, type=float)
166 | arg_parser.add_argument("-a_relation_loss", default=1.0, type=float)
167 | arg_parser.add_argument("-log_dir", default="data/log")
168 | arg_parser.add_argument("-model_save", default="data/models/treebuilder.partptr.model")
169 | arg_parser.add_argument("--seed", default=7, type=int)
170 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
171 | arg_parser.set_defaults(use_gpu=True)
172 |
173 | main(arg_parser.parse_args())
174 |
175 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/pre_process.py:
--------------------------------------------------------------------------------
1 | # UTF-8
2 | # Author: Longyin Zhang
3 | # Date: 2020.10.9
4 | import argparse
5 | import logging
6 | import numpy as np
7 | from dataset import CDTB
8 | from collections import Counter
9 | from itertools import chain
10 | from structure.vocab import Vocab, Label
11 | from structure.nodes import node_type_filter, EDU, Relation, Sentence, TEXT
12 | import progressbar
13 | from util.file_util import *
14 | p = progressbar.ProgressBar()
15 |
16 |
17 | def build_vocab(dataset):
18 | word_freq = Counter()
19 | pos_freq = Counter()
20 | nuc_freq = Counter()
21 | rel_freq = Counter()
22 | for paragraph in chain(*dataset):
23 | for node in paragraph.iterfind(filter=node_type_filter([EDU, Relation])):
24 | if isinstance(node, EDU):
25 | word_freq.update(node.words)
26 | pos_freq.update(node.tags)
27 | elif isinstance(node, Relation):
28 | nuc_freq[node.nuclear] += 1
29 | rel_freq[node.ftype] += 1
30 |
31 | word_vocab = Vocab("word", word_freq)
32 | pos_vocab = Vocab("part of speech", pos_freq)
33 | nuc_label = Label("nuclear", nuc_freq)
34 | rel_label = Label("relation", rel_freq)
35 | return word_vocab, pos_vocab, nuc_label, rel_label
36 |
37 |
38 | def gen_decoder_data(root, edu2ids):
39 | # splits s0 s1 s2 s3 s4 s5 s6
40 | # edus s/ e0 e1 e2 e3 e4 e5 /s
41 |     splits = []  # [(start, split, end, nuc, rel), ...], e.g. (0, 3, 6, NS, rel)
42 | child_edus = [] # [edus]
43 |
44 | if isinstance(root, EDU):
45 | child_edus.append(root)
46 | elif isinstance(root, Sentence):
47 | for child in root:
48 | _child_edus, _splits = gen_decoder_data(child, edu2ids)
49 | child_edus.extend(_child_edus)
50 | splits.extend(_splits)
51 | elif isinstance(root, Relation):
52 | children = [gen_decoder_data(child, edu2ids) for child in root]
53 | if len(children) < 2:
54 | raise ValueError("relation node should have at least 2 children")
55 |
56 | while children:
57 | left_child_edus, left_child_splits = children.pop(0)
58 | if children:
59 | last_child_edus, _ = children[-1]
60 | start = edu2ids[left_child_edus[0]]
61 | split = edu2ids[left_child_edus[-1]] + 1
62 | end = edu2ids[last_child_edus[-1]] + 1
63 | nuc = root.nuclear
64 | rel = root.ftype
65 | splits.append((start, split, end, nuc, rel))
66 | child_edus.extend(left_child_edus)
67 | splits.extend(left_child_splits)
68 | return child_edus, splits
69 |
70 |
71 | def numericalize(dataset, word_vocab, pos_vocab, nuc_label, rel_label):
72 | instances = []
73 | for paragraph in filter(lambda d: d.root_relation(), chain(*dataset)):
74 | encoder_inputs = []
75 | decoder_inputs = []
76 | pred_splits = []
77 | pred_nucs = []
78 | pred_rels = []
79 | edus = list(paragraph.edus())
80 | for edu in edus:
81 | edu_word_ids = [word_vocab[word] for word in edu.words]
82 | edu_pos_ids = [pos_vocab[pos] for pos in edu.tags]
83 | encoder_inputs.append((edu_word_ids, edu_pos_ids))
84 | edu2ids = {edu: i for i, edu in enumerate(edus)}
85 | _, splits = gen_decoder_data(paragraph.root_relation(), edu2ids)
86 | for start, split, end, nuc, rel in splits:
87 | decoder_inputs.append((start, end))
88 | pred_splits.append(split)
89 | pred_nucs.append(nuc_label[nuc])
90 | pred_rels.append(rel_label[rel])
91 | instances.append((encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels))
92 | return instances
93 |
94 |
95 | nr2ids = dict()
96 | ids2nr = dict()
97 | nr_idx = 0
98 |
99 |
100 | def gen_batch_iter(instances, batch_size, use_gpu=False):
101 |     """ iterate over the batches to collect all nuclearity-relation (nr) combinations into nr2ids / ids2nr
102 |     """
103 | global nr2ids, nr_idx, ids2nr
104 | random_instances = np.random.permutation(instances)
105 | num_instances = len(instances)
106 | offset = 0
107 | p.start(num_instances)
108 | while offset < num_instances:
109 | p.update(offset)
110 | batch = random_instances[offset: min(num_instances, offset+batch_size)]
111 | for batchi, (_, _, _, pred_nucs, pred_rels) in enumerate(batch):
112 | for nuc, rel in zip(pred_nucs, pred_rels):
113 | nr = str(nuc) + "-" + str(rel)
114 | if nr not in nr2ids.keys():
115 | nr2ids[nr] = nr_idx
116 | ids2nr[nr_idx] = nr
117 | nr_idx += 1
118 | offset = offset + batch_size
119 | p.finish()
120 |
121 |
122 | def main(args):
123 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
124 | word_vocab, pos_vocab, nuc_label, rel_label = build_vocab(cdtb.train)
125 | trainset = numericalize(cdtb.train, word_vocab, pos_vocab, nuc_label, rel_label)
126 | test = numericalize(cdtb.test, word_vocab, pos_vocab, nuc_label, rel_label)
127 | dev = numericalize(cdtb.validate, word_vocab, pos_vocab, nuc_label, rel_label)
128 | for idx in range(1, 8):
129 | gen_batch_iter(trainset, args.batch_size, args.use_gpu)
130 | gen_batch_iter(test, args.batch_size, args.use_gpu)
131 | gen_batch_iter(dev, args.batch_size, args.use_gpu)
132 | print(nr2ids)
133 | save_data(nr2ids, "data/nr2ids.pkl")
134 | save_data(ids2nr, "data/ids2nr.pkl")
135 |
136 |
137 | if __name__ == '__main__':
138 | logging.basicConfig(level=logging.INFO)
139 | arg_parser = argparse.ArgumentParser()
140 |
141 | # dataset parameters
142 | arg_parser.add_argument("--data", default="data/CDTB")
143 | arg_parser.add_argument("--ctb_dir", default="data/CTB")
144 | arg_parser.add_argument("--cache_dir", default="data/cache")
145 |
146 | # model parameters
147 | arg_parser.add_argument("-hidden_size", default=512, type=int)
148 | arg_parser.add_argument("-dropout", default=0.33, type=float)
149 | # w2v_group = arg_parser.add_mutually_exclusive_group(required=True)
150 | arg_parser.add_argument("-w2v_size", default=300, type=int)
151 | arg_parser.add_argument("-pos_size", default=30, type=int)
152 | arg_parser.add_argument("-split_mlp_size", default=64, type=int)
153 | arg_parser.add_argument("-nuc_mlp_size", default=32, type=int)
154 | arg_parser.add_argument("-rel_mlp_size", default=128, type=int)
155 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true")
156 | arg_parser.add_argument("-pretrained", default="data/pretrained/sgns.renmin.word")
157 | arg_parser.set_defaults(w2v_freeze=True)
158 |
159 | # train parameters
160 | arg_parser.add_argument("-epoch", default=20, type=int)
161 | arg_parser.add_argument("-batch_size", default=64, type=int)
162 | arg_parser.add_argument("-lr", default=0.001, type=float)
163 | arg_parser.add_argument("-l2", default=0.0, type=float)
164 | arg_parser.add_argument("-log_every", default=10, type=int)
165 | arg_parser.add_argument("-validate_every", default=10, type=int)
166 | arg_parser.add_argument("-a_split_loss", default=0.3, type=float)
167 | arg_parser.add_argument("-a_nuclear_loss", default=1.0, type=float)
168 | arg_parser.add_argument("-a_relation_loss", default=1.0, type=float)
169 | arg_parser.add_argument("-log_dir", default="data/log")
170 | arg_parser.add_argument("-model_save", default="data/models/treebuilder.partptr.model")
171 | arg_parser.add_argument("--seed", default=7, type=int)
172 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
173 | arg_parser.set_defaults(use_gpu=True)
174 |
175 | main(arg_parser.parse_args())
176 |
177 |
--------------------------------------------------------------------------------
/ch_dp_gan/treebuilder/dp_nr/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import numpy as np
4 | from structure.nodes import Paragraph, Relation, rev_relationmap
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | from interface import ParserI
8 | from structure.nodes import EDU, TEXT
9 | from config import *
10 |
11 |
12 | class PartPtrParser(ParserI):
13 | def __init__(self, model):
14 | self.parser = PartitionPtrParser(model)
15 |
16 | def parse(self, para):
17 | edus = []
18 | for edu in para.edus():
19 | edu_copy = EDU([TEXT(edu.text)])
20 | setattr(edu_copy, "words", edu.words)
21 | setattr(edu_copy, "tags", edu.tags)
22 | edus.append(edu_copy)
23 | return self.parser.parse(edus)
24 |
25 |
26 | class PartitionPtrParser:
27 | def __init__(self, model):
28 | self.model = model
29 |
30 | def parse(self, edus, ret_session=False):
31 | if len(edus) < 2:
32 | return Paragraph(edus)
33 |
34 | # TODO implement beam search
35 | session = self.init_session(edus)
36 | while not session.terminate():
37 | split_scores, nr_score, state = self.decode(session)
38 | split = split_scores.argmax()
39 | nr_id = nr_score[split].argmax()
40 | nr = ids2nr[nr_id]
41 | nucl_id, rel_id = int(nr.split("-")[0]), int(nr.split("-")[1])
42 | nuclear = self.model.nuc_label.id2label[nucl_id]
43 | relation = self.model.rel_label.id2label[rel_id]
44 | session = session.forward(split_scores, state, split, nuclear, relation)
45 | root_relation = self.build_tree(edus, session.splits[:], session.nuclears[:], session.relations[:])
46 | discourse = Paragraph([root_relation])
47 | if ret_session:
48 | return discourse, session
49 | else:
50 | return discourse
51 |
52 | def init_session(self, edus):
53 | edu_words = [edu.words for edu in edus]
54 | edu_poses = [edu.tags for edu in edus]
55 | max_word_seqlen = max(len(words) for words in edu_words)
56 | edu_seqlen = len(edu_words)
57 |
58 | e_input_words = np.zeros((1, edu_seqlen, max_word_seqlen), dtype=np.long)
59 | e_boundaries = np.zeros([1, edu_seqlen + 1], dtype=np.long)
60 | e_input_poses = np.zeros_like(e_input_words)
61 | e_input_masks = np.zeros_like(e_input_words, dtype=np.uint8)
62 |
63 | for i, (words, poses) in enumerate(zip(edu_words, edu_poses)):
64 | e_input_words[0, i, :len(words)] = [self.model.word_vocab[word] for word in words]
65 | e_input_poses[0, i, :len(poses)] = [self.model.pos_vocab[pos] for pos in poses]
66 | e_input_masks[0, i, :len(words)] = 1
67 | if self.model.word_vocab[words[-1]] in [39, 3015, 178]:
68 | e_boundaries[0][i + 1] = 1
69 |
70 | e_input_words = torch.from_numpy(e_input_words).long()
71 | e_boundaries = torch.from_numpy(e_boundaries).long()
72 | e_input_poses = torch.from_numpy(e_input_poses).long()
73 | e_input_masks = torch.from_numpy(e_input_masks).byte()
74 |
75 | if self.model.use_gpu:
76 | e_input_words = e_input_words.cuda()
77 | e_boundaries = e_boundaries.cuda()
78 | e_input_poses = e_input_poses.cuda()
79 | e_input_masks = e_input_masks.cuda()
80 |
81 | edu_encoded, e_masks = self.model.encode_edus((e_input_words, e_input_poses, e_input_masks, e_boundaries))
82 | memory, _, context = self.model.encoder(edu_encoded, e_masks, e_boundaries)
83 | state = self.model.context_dense(context).unsqueeze(0)
84 | return Session(memory, state)
85 |
86 | def decode(self, session):
87 | left, right = session.stack[-1]
88 | return self.model(left, right, session.memory, session.state)
89 |
90 | def build_tree(self, edus, splits, nuclears, relations):
91 | left, split, right = splits.pop(0)
92 | nuclear = nuclears.pop(0)
93 | ftype = relations.pop(0)
94 | ctype = rev_relationmap[ftype]
95 | if split - left == 1:
96 | left_node = edus[left]
97 | else:
98 | left_node = self.build_tree(edus, splits, nuclears, relations)
99 |
100 | if right - split == 1:
101 | right_node = edus[split]
102 | else:
103 | right_node = self.build_tree(edus, splits, nuclears, relations)
104 |
105 | relation = Relation([left_node, right_node], nuclear=nuclear, ftype=ftype, ctype=ctype)
106 | return relation
107 |
108 |
109 | class Session:
110 | def __init__(self, memory, state):
111 | self.n = memory.size(1) - 2
112 | self.step = 0
113 | self.memory = memory
114 | self.state = state
115 | self.stack = [(0, self.n + 1)]
116 | self.scores = np.zeros((self.n, self.n+2), dtype=np.float)
117 | self.splits = []
118 | self.nuclears = []
119 | self.relations = []
120 |
121 | def forward(self, score, state, split, nuclear, relation):
122 | left, right = self.stack.pop()
123 | if right - split > 1:
124 | self.stack.append((split, right))
125 | if split - left > 1:
126 | self.stack.append((left, split))
127 | self.splits.append((left, split, right))
128 | self.nuclears.append(nuclear)
129 | self.relations.append(relation)
130 | self.state = state
131 | self.scores[self.step] = score
132 | self.step += 1
133 | return self
134 |
135 | def terminate(self):
136 | return self.step >= self.n
137 |
138 | def draw_decision_hotmap(self):
139 | textcolors = ["black", "white"]
140 | cmap = "YlGn"
141 | ylabel = "split score"
142 | col_labels = ["split %d" % i for i in range(0, self.scores.shape[1])]
143 | row_labels = ["step %d" % i for i in range(1, self.scores.shape[0] + 1)]
144 | fig, ax = plt.subplots()
145 | im = ax.imshow(self.scores, cmap=cmap)
146 | cbar = ax.figure.colorbar(im, ax=ax)
147 | cbar.ax.set_ylabel(ylabel, rotation=-90, va="bottom")
148 | ax.set_xticks(np.arange(self.scores.shape[1]))
149 | ax.set_yticks(np.arange(self.scores.shape[0]))
150 | ax.set_xticklabels(col_labels)
151 | ax.set_yticklabels(row_labels)
152 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
153 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
154 | for edge, spine in ax.spines.items():
155 | spine.set_visible(False)
156 | ax.set_xticks(np.arange(self.scores.shape[1] + 1) - .5, minor=True)
157 | ax.set_yticks(np.arange(self.scores.shape[0] + 1) - .5, minor=True)
158 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
159 | ax.tick_params(which="minor", bottom=False, left=False)
160 | threshold = im.norm(self.scores.max()) / 2.
161 | valfmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}")
162 | texts = []
163 | kw = dict(horizontalalignment="center", verticalalignment="center")
164 | for i in range(self.scores.shape[0]):
165 | for j in range(self.scores.shape[1]):
166 | kw.update(color=textcolors[im.norm(self.scores[i, j]) > threshold])
167 | text = im.axes.text(j, i, valfmt(self.scores[i, j], None), **kw)
168 | texts.append(text)
169 | fig.tight_layout()
170 | plt.show()
171 |
172 | def __repr__(self):
173 | return "[step %d]memory size: %s, state size: %s\n stack:\n%s\n, scores:\n %s" % \
174 | (self.step, str(self.memory.size()), str(self.state.size()),
175 | "\n".join(map(str, self.stack)) or "[]",
176 | str(self.scores))
177 |
178 | def __str__(self):
179 | return repr(self)
180 |
--------------------------------------------------------------------------------
/ch_dp_gan_xlnet/treebuilder/dp_nr/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import numpy as np
4 | from structure.nodes import Paragraph, Relation, rev_relationmap
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | from interface import ParserI
8 | from structure.nodes import EDU, TEXT
9 | from config import *
10 |
11 |
12 | class PartPtrParser(ParserI):
13 | def __init__(self, model):
14 | self.parser = PartitionPtrParser(model)
15 |
16 | def parse(self, para):
17 | edus = []
18 | for edu in para.edus():
19 | edu_copy = EDU([TEXT(edu.text)])
20 | setattr(edu_copy, "words", edu.words)
21 | setattr(edu_copy, "tags", edu.tags)
22 | edus.append(edu_copy)
23 | return self.parser.parse(edus)
24 |
25 |
26 | class PartitionPtrParser:
27 | def __init__(self, model):
28 | self.model = model
29 |
30 | def parse(self, edus, ret_session=False, tokenizer_xl=None, model_xl=None):
31 | if len(edus) < 2:
32 | return Paragraph(edus)
33 |
34 | # TODO implement beam search
35 | session = self.init_session(edus, tokenizer_xl, model_xl)
36 | while not session.terminate():
37 | split_scores, nr_score, state = self.decode(session)
38 | split = split_scores.argmax()
39 | nr_id = nr_score[split].argmax()
40 | nr = ids2nr[nr_id]
41 | nucl_id, rel_id = int(nr.split("-")[0]), int(nr.split("-")[1])
42 | nuclear = self.model.nuc_label.id2label[nucl_id]
43 | relation = self.model.rel_label.id2label[rel_id]
44 | session = session.forward(split_scores, state, split, nuclear, relation)
45 | root_relation = self.build_tree(edus, session.splits[:], session.nuclears[:], session.relations[:])
46 | discourse = Paragraph([root_relation])
47 | if ret_session:
48 | return discourse, session
49 | else:
50 | return discourse
51 |
52 | def init_session(self, edus, tokenizer_xl, model_xl):
53 | edu_words = [edu.words for edu in edus]
54 | edu_txts = ["".join(edu.words) for edu in edus]
55 | edu_poses = [edu.tags for edu in edus]
56 | max_word_seqlen = max(len(words) for words in edu_words)
57 | edu_seqlen = len(edu_words)
58 |
59 | e_input_words = np.zeros((1, edu_seqlen, max_word_seqlen), dtype=np.long)
60 | e_txt = [edu_txts]
61 | e_boundaries = np.zeros([1, edu_seqlen + 1], dtype=np.long)
62 | e_input_poses = np.zeros_like(e_input_words)
63 | e_input_masks = np.zeros_like(e_input_words, dtype=np.uint8)
64 |
65 | for i, (words, poses) in enumerate(zip(edu_words, edu_poses)):
66 | e_input_words[0, i, :len(words)] = [self.model.word_vocab[word] for word in words]
67 | e_input_poses[0, i, :len(poses)] = [self.model.pos_vocab[pos] for pos in poses]
68 | e_input_masks[0, i, :len(words)] = 1
69 | if self.model.word_vocab[words[-1]] in [39, 3015, 178]:
70 |                 e_boundaries[0][i + 1] = 1  # sentence boundary
71 |
72 | e_input_words = torch.from_numpy(e_input_words).long()
73 | e_boundaries = torch.from_numpy(e_boundaries).long()
74 | e_input_poses = torch.from_numpy(e_input_poses).long()
75 | e_input_masks = torch.from_numpy(e_input_masks).byte()
76 |
77 | if self.model.use_gpu:
78 | e_input_words = e_input_words.cuda()
79 | e_boundaries = e_boundaries.cuda()
80 | e_input_poses = e_input_poses.cuda()
81 | e_input_masks = e_input_masks.cuda()
82 |
83 |         edu_encoded, e_masks = self.model.encode_edus((e_input_words, e_input_poses, e_input_masks, e_boundaries,
84 |                                                        e_txt), tokenizer_xl, model_xl)
85 | memory, _, context = self.model.encoder(edu_encoded, e_masks, e_boundaries)
86 | state = self.model.context_dense(context).unsqueeze(0)
87 | return Session(memory, state)
88 |
89 | def decode(self, session):
90 | left, right = session.stack[-1]
91 | return self.model(left, right, session.memory, session.state)
92 |
93 | def build_tree(self, edus, splits, nuclears, relations):
94 | left, split, right = splits.pop(0)
95 | nuclear = nuclears.pop(0)
96 | ftype = relations.pop(0)
97 | ctype = rev_relationmap[ftype]
98 | if split - left == 1:
99 | left_node = edus[left]
100 | else:
101 | left_node = self.build_tree(edus, splits, nuclears, relations)
102 |
103 | if right - split == 1:
104 | right_node = edus[split]
105 | else:
106 | right_node = self.build_tree(edus, splits, nuclears, relations)
107 |
108 | relation = Relation([left_node, right_node], nuclear=nuclear, ftype=ftype, ctype=ctype)
109 | return relation
110 |
111 |
112 | class Session:
113 | def __init__(self, memory, state):
114 | self.n = memory.size(1) - 2
115 | self.step = 0
116 | self.memory = memory
117 | self.state = state
118 | self.stack = [(0, self.n + 1)]
119 | self.scores = np.zeros((self.n, self.n+2), dtype=np.float)
120 | self.splits = []
121 | self.nuclears = []
122 | self.relations = []
123 |
124 | def forward(self, score, state, split, nuclear, relation):
125 | left, right = self.stack.pop()
126 | if right - split > 1:
127 | self.stack.append((split, right))
128 | if split - left > 1:
129 | self.stack.append((left, split))
130 | self.splits.append((left, split, right))
131 | self.nuclears.append(nuclear)
132 | self.relations.append(relation)
133 | self.state = state
134 | self.scores[self.step] = score
135 | self.step += 1
136 | return self
137 |
138 | def terminate(self):
139 | return self.step >= self.n
140 |
141 | def draw_decision_hotmap(self):
142 | textcolors = ["black", "white"]
143 | cmap = "YlGn"
144 | ylabel = "split score"
145 | col_labels = ["split %d" % i for i in range(0, self.scores.shape[1])]
146 | row_labels = ["step %d" % i for i in range(1, self.scores.shape[0] + 1)]
147 | fig, ax = plt.subplots()
148 | im = ax.imshow(self.scores, cmap=cmap)
149 | cbar = ax.figure.colorbar(im, ax=ax)
150 | cbar.ax.set_ylabel(ylabel, rotation=-90, va="bottom")
151 | ax.set_xticks(np.arange(self.scores.shape[1]))
152 | ax.set_yticks(np.arange(self.scores.shape[0]))
153 | ax.set_xticklabels(col_labels)
154 | ax.set_yticklabels(row_labels)
155 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
156 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
157 | for edge, spine in ax.spines.items():
158 | spine.set_visible(False)
159 | ax.set_xticks(np.arange(self.scores.shape[1] + 1) - .5, minor=True)
160 | ax.set_yticks(np.arange(self.scores.shape[0] + 1) - .5, minor=True)
161 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
162 | ax.tick_params(which="minor", bottom=False, left=False)
163 | threshold = im.norm(self.scores.max()) / 2.
164 | valfmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}")
165 | texts = []
166 | kw = dict(horizontalalignment="center", verticalalignment="center")
167 | for i in range(self.scores.shape[0]):
168 | for j in range(self.scores.shape[1]):
169 | kw.update(color=textcolors[im.norm(self.scores[i, j]) > threshold])
170 | text = im.axes.text(j, i, valfmt(self.scores[i, j], None), **kw)
171 | texts.append(text)
172 | fig.tight_layout()
173 | plt.show()
174 |
175 | def __repr__(self):
176 | return "[step %d]memory size: %s, state size: %s\n stack:\n%s\n, scores:\n %s" % \
177 | (self.step, str(self.memory.size()), str(self.state.size()),
178 | "\n".join(map(str, self.stack)) or "[]",
179 | str(self.scores))
180 |
181 | def __str__(self):
182 | return repr(self)
183 |
--------------------------------------------------------------------------------