├── .gitignore
├── README.md
├── download_checkpoints.sh
├── download_data.sh
├── download_results.sh
├── dual_flow
│   ├── GAT.py
│   ├── dataset.py
│   ├── dataset_entity.py
│   ├── eval_kamed.sh
│   ├── eval_meddg.sh
│   ├── evaluating.py
│   ├── get_entity_embed.py
│   ├── get_prediction_topk.py
│   ├── get_prediction_topk.sh
│   ├── main.py
│   ├── model.py
│   ├── parsing.py
│   ├── train_kamed.sh
│   ├── train_meddg.sh
│   ├── training.py
│   └── utils.py
├── generation
│   ├── dataset.py
│   ├── eval_kamed.sh
│   ├── eval_meddg.sh
│   ├── evaluating.py
│   ├── inference.py
│   ├── main.py
│   ├── metric.sh
│   ├── metrics.py
│   ├── model.py
│   ├── modeling_bart.py
│   ├── parsing.py
│   ├── train_kamed.sh
│   ├── train_meddg.sh
│   ├── training.py
│   └── utils.py
├── images
│   └── framework.png
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | #.DS_Store 163 | .DS_Store 164 | dual_flow/.DS_Store 165 | generation/.DS_Store 166 | images/.DS_Store 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DFMed: Medical Dialogue Generation via Dual Flow Modeling 2 | This is the code for the ACL 2023 Findings paper: [*Medical Dialogue Generation via Dual Flow Modeling*](https://arxiv.org/abs/2305.18109) by Kaishuai Xu, Wenjun Hou, Yi Cheng, Jian Wang, and Wenjie Li. 3 | 4 | **DFMed** is a novel medical dialogue generation framework that models the transitions of medical entities and dialogue acts via step-by-step interweaving. 5 | 6 | ![](images/framework.png) 7 | 8 | ## Requirements 9 | Please create a new conda environment and install the PyTorch build below together with the main required packages (the remaining dependencies are listed in `requirements.txt`). 10 | ``` 11 | conda create -n dual_flow python=3.8 12 | conda activate dual_flow 13 | 14 | pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1 --extra-index-url https://download.pytorch.org/whl/cu113 15 | ``` 16 | - `transformers==4.24.0` 17 | - `nltk==3.4.1` 18 | - `rouge==1.0.1` 19 | - `setuptools==59.5.0` 20 | - `numpy` 21 | 22 | Please note that the `nltk` version is important for calculating BLEU; other versions can produce different scores.
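As an optional sanity check, you can confirm the pinned versions inside the activated environment (the expected values below simply mirror the pins above):
```python
# Print installed versions; a mismatch usually means the wrong environment is active.
import torch, transformers, nltk
print(torch.__version__)         # expected: 1.10.1+cu113
print(transformers.__version__)  # expected: 4.24.0
print(nltk.__version__)          # expected: 3.4.1 (BLEU results differ across nltk versions)
```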
23 | ## Data and Checkpoints 24 | Download the data (including the CMeKG knowledge graph, the MedDG dataset, and the KaMed dataset) by running: 25 | ``` 26 | sh download_data.sh 27 | ``` 28 | Download the fine-tuned generation checkpoints by running: 29 | ``` 30 | sh download_checkpoints.sh 31 | ``` 32 | Download the fine-tuned dual flow learning checkpoints from: 33 | - MedDG at Dropbox [link](https://www.dropbox.com/scl/fi/aol4jav6mjb25p2x1sekn/test_meddg.tar.gz?rlkey=6fbr2pz4gia8tmduzbibqlbqj&dl=0) 34 | - KaMed at Dropbox [link](https://www.dropbox.com/scl/fi/m73ri6kit8u6xjc79ol8k/test_kamed.tar.gz?rlkey=p5rqv1okcqzl7jjvrm0m1zhee&dl=0) 35 | 36 | ## Results 37 | Download the results of dual flow learning and response generation by running: 38 | ``` 39 | sh download_results.sh 40 | ``` 41 | 42 | ## Directory 43 | The final directory is as follows: 44 | ``` 45 | └── DFMed 46 | ├── dual_flow 47 | ├── generation 48 | ├── data 49 | ├── images 50 | ├── results 51 | │ ├── df_results 52 | │ └── generation_results 53 | ├── download_checkpoints.sh 54 | ├── download_data.sh 55 | ├── download_results.sh 56 | └── requirements.txt 57 | ``` 58 | 59 | ## Implementations 60 | 1. Train the dual flow learning model. 61 | ``` 62 | cd dual_flow 63 | 64 | # For the MedDG dataset 65 | sh train_meddg.sh 66 | 67 | # For the KaMed dataset 68 | sh train_kamed.sh 69 | ``` 70 | 2. Get act and entity predictions from the top-performing checkpoints. 71 | ``` 72 | sh get_prediction_topk.sh 73 | ``` 74 | 3. Train the generation model. 75 | ``` 76 | cd generation 77 | 78 | # For the MedDG dataset 79 | sh train_meddg.sh 80 | 81 | # For the KaMed dataset 82 | sh train_kamed.sh 83 | ``` 84 | 4. Inference. The `df_results` directory contains the act and entity predictions from our training runs. 85 | ``` 86 | cd generation 87 | 88 | # For the MedDG dataset 89 | sh eval_meddg.sh 90 | 91 | # For the KaMed dataset 92 | sh eval_kamed.sh 93 | ``` 94 | 5. Calculate metrics. We use the algorithm from the official [code](https://github.com/lwgkzl/MedDG/blob/master/MedDG/generation/CY_DataReadandMetric.py) of the MedDG dataset.
95 | ``` 96 | python metrics.py --hp ./generate.txt --rf ./reference.txt 97 | ``` 98 | ## Cite 99 | If you use our codes or your research is related to our work, please kindly cite our paper: 100 | ```bibtex 101 | @inproceedings{xu-etal-2023-medical, 102 | title = "Medical Dialogue Generation via Dual Flow Modeling", 103 | author = "Xu, Kaishuai and 104 | Hou, Wenjun and 105 | Cheng, Yi and 106 | Wang, Jian and 107 | Li, Wenjie", 108 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2023", 109 | year = "2023", 110 | address = "Toronto, Canada", 111 | publisher = "Association for Computational Linguistics", 112 | url = "https://aclanthology.org/2023.findings-acl.423", 113 | doi = "10.18653/v1/2023.findings-acl.423", 114 | pages = "6771--6784", 115 | } 116 | ``` -------------------------------------------------------------------------------- /download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | # Download meddg generation checkpoint from Google Drive 2 | filename='meddg_fine_tuned_checkpoint.zip' 3 | fileid='1cIzbRu4Hb6IxMJZ_Ig8I25gtOSRE5aTO' 4 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 5 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 6 | 7 | # Unzip 8 | unzip -q ${filename} 9 | rm ${filename} 10 | 11 | # Download kamed generation checkpoint from Google Drive 12 | filename='kamed_fine_tuned_checkpoint.zip' 13 | fileid='15ZwWVqkaugA7EFZEQDS-2B4f3LKgtpAA' 14 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 15 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 16 | 17 | # Unzip 18 | unzip -q ${filename} 19 | rm ${filename} 20 | 21 | rm ./cookie 22 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | # Download dataset from Google Drive 2 | filename='data.zip' 3 | fileid='1Tb2tO3tQ-6a-j0AiBKDu0RCiDbygH07r' 4 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 5 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 6 | 7 | # Unzip 8 | unzip -q ${filename} 9 | rm ${filename} 10 | 11 | # Download medbert-kd-chinese from Google Drive 12 | filename='medbert-kd-chinese.zip' 13 | fileid='18inyU0OPaPJLh7UleQh8g9hrQV4Og9Cv' 14 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 15 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 16 | 17 | # Unzip 18 | unzip -q ${filename} 19 | rm ${filename} 20 | 21 | # Download bart-base-chinese from Google Drive 22 | filename='bart-base-chinese.zip' 23 | fileid='1W6Yu3-WBrDuxg9qGs1GDCtxzmhpaKEv6' 24 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 25 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 26 | 27 | # Unzip 28 | unzip -q ${filename} 29 | rm ${filename} 30 | 31 | rm ./cookie 32 | -------------------------------------------------------------------------------- 
/download_results.sh: -------------------------------------------------------------------------------- 1 | # Download dual flow learning results from Google Drive 2 | filename='df_results.zip' 3 | fileid='1ZOaOjDYqqCV4bUyJBG6KA4eAh3zlqz6j' 4 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 5 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 6 | 7 | # Unzip 8 | unzip -q ${filename} 9 | rm ${filename} 10 | 11 | # Download response generation results from Google Drive 12 | filename='generation_results.zip' 13 | fileid='1bCgziRag6uL_kDRzN49vJeJQ76iSQyo1' 14 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 15 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 16 | 17 | # Unzip 18 | unzip -q ${filename} 19 | rm ${filename} 20 | 21 | rm ./cookie 22 | -------------------------------------------------------------------------------- /dual_flow/GAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import enum 4 | 5 | class LayerType(enum.Enum): 6 | IMP1 = 0 7 | IMP2 = 1 8 | IMP3 = 2 9 | 10 | class GAT(torch.nn.Module): 11 | """ 12 | I've added 3 GAT implementations - some are conceptually easier to understand, while others are more efficient. 13 | 14 | The most interesting and hardest one to understand is implementation #3. 15 | Imp1 and imp2 differ in subtle details but are basically the same thing. 16 | 17 | Tip on how to approach this: 18 | understand implementation 2 first, check out the differences from imp1, and finally tackle imp #3. 19 | 20 | """ 21 | 22 | def __init__(self, num_of_layers, num_heads_per_layer, num_features_per_layer, add_skip_connection=True, bias=True, 23 | dropout=0.6, layer_type=LayerType.IMP3, log_attention_weights=False): 24 | super().__init__() 25 | assert num_of_layers == len(num_heads_per_layer) == len(num_features_per_layer), 'Enter valid arch params.' 26 | 27 | GATLayer = get_layer_type(layer_type) # fetch one of 3 available implementations 28 | gat_layers = [] # collect GAT layers 29 | for i in range(num_of_layers): 30 | layer = GATLayer( 31 | num_in_features=num_features_per_layer[i] * num_heads_per_layer[i], # consequence of concatenation 32 | num_out_features=num_features_per_layer[i], 33 | num_of_heads=num_heads_per_layer[i], 34 | concat=True, 35 | activation=nn.ELU(), 36 | dropout_prob=dropout, 37 | add_skip_connection=add_skip_connection, 38 | bias=bias, 39 | log_attention_weights=log_attention_weights 40 | ) 41 | gat_layers.append(layer) 42 | 43 | self.gat_net = nn.Sequential( 44 | *gat_layers, 45 | ) 46 | 47 | # data is just a (in_nodes_features, topology) tuple, I had to do it like this because of the nn.Sequential: 48 | # https://discuss.pytorch.org/t/forward-takes-2-positional-arguments-but-3-were-given-for-nn-sqeuential-with-linear-layers/65698 49 | def forward(self, data): 50 | return self.gat_net(data) 51 | 52 | 53 | class GATLayer(torch.nn.Module): 54 | """ 55 | Base class for all implementations as there is much code that would otherwise be copy/pasted.
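    A minimal usage sketch of the GAT wrapper defined above (not taken from the
    original training pipeline; all sizes below are hypothetical). Note that in
    this variant, layer i expects inputs of width
    num_features_per_layer[i] * num_heads_per_layer[i], so consecutive layers
    must satisfy H[i] * F[i] == H[i+1] * F[i+1]:

        node_features = torch.rand(3, 64)                # (N, FIN), FIN = 4 * 16
        edge_index = torch.tensor([[0, 1, 2, 0, 1, 2],   # source nodes
                                   [1, 2, 0, 0, 1, 2]])  # target nodes (incl. self-loops)
        gat = GAT(num_of_layers=2, num_heads_per_layer=[4, 4], num_features_per_layer=[16, 16])
        out_features, _ = gat((node_features, edge_index))  # (N, 64) with the default IMP3 layer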
56 | """ 57 | 58 | head_dim = 1 59 | 60 | def __init__(self, num_in_features, num_out_features, num_of_heads, layer_type, concat=True, activation=nn.ELU(), 61 | dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False): 62 | 63 | super().__init__() 64 | 65 | # Saving these as we'll need them in forward propagation in children layers (imp1/2/3) 66 | self.num_of_heads = num_of_heads 67 | self.num_out_features = num_out_features 68 | self.concat = concat # whether we should concatenate or average the attention heads 69 | self.add_skip_connection = add_skip_connection 70 | 71 | # 72 | # Trainable weights: linear projection matrix (denoted as "W" in the paper), attention target/source 73 | # (denoted as "a" in the paper) and bias (not mentioned in the paper but present in the official GAT repo) 74 | # 75 | 76 | if layer_type == LayerType.IMP1: 77 | # Experimenting with different options to see what is faster (tip: focus on 1 implementation at a time) 78 | self.proj_param = nn.Parameter(torch.Tensor(num_of_heads, num_in_features, num_out_features)) 79 | else: 80 | # You can treat this one matrix as num_of_heads independent W matrices 81 | self.linear_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False) 82 | 83 | # After we concatenate target node (node i) and source node (node j) we apply the additive scoring function 84 | # which gives us un-normalized score "e". Here we split the "a" vector - but the semantics remain the same. 85 | 86 | # Basically instead of doing [x, y] (concatenation, x/y are node feature vectors) and dot product with "a" 87 | # we instead do a dot product between x and "a_left" and y and "a_right" and we sum them up 88 | self.scoring_fn_target = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features)) 89 | self.scoring_fn_source = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features)) 90 | 91 | if layer_type == LayerType.IMP1: # simple reshape in the case of implementation 1 92 | self.scoring_fn_target = nn.Parameter(self.scoring_fn_target.reshape(num_of_heads, num_out_features, 1)) 93 | self.scoring_fn_source = nn.Parameter(self.scoring_fn_source.reshape(num_of_heads, num_out_features, 1)) 94 | 95 | # Bias is definitely not crucial to GAT - feel free to experiment (I pinged the main author, Petar, on this one) 96 | if bias and concat: 97 | self.bias = nn.Parameter(torch.Tensor(num_of_heads * num_out_features)) 98 | elif bias and not concat: 99 | self.bias = nn.Parameter(torch.Tensor(num_out_features)) 100 | else: 101 | self.register_parameter('bias', None) 102 | 103 | if add_skip_connection: 104 | self.skip_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False) 105 | else: 106 | self.register_parameter('skip_proj', None) 107 | 108 | # 109 | # End of trainable weights 110 | # 111 | 112 | self.leakyReLU = nn.LeakyReLU(0.2) # using 0.2 as in the paper, no need to expose every setting 113 | self.softmax = nn.Softmax(dim=-1) # -1 stands for apply the log-softmax along the last dimension 114 | self.activation = activation 115 | # Probably not the nicest design but I use the same module in 3 locations, before/after features projection 116 | # and for attention coefficients. Functionality-wise it's the same as using independent modules. 
117 | self.dropout = nn.Dropout(p=dropout_prob) 118 | 119 | self.log_attention_weights = log_attention_weights # whether we should log the attention weights 120 | self.attention_weights = None # for later visualization purposes, I cache the weights here 121 | 122 | self.init_params(layer_type) 123 | 124 | def init_params(self, layer_type): 125 | """ 126 | The reason we're using Glorot (aka Xavier uniform) initialization is because it's a default TF initialization: 127 | https://stackoverflow.com/questions/37350131/what-is-the-default-variable-initializer-in-tensorflow 128 | The original repo was developed in TensorFlow (TF) and they used the default initialization. 129 | Feel free to experiment - there may be better initializations depending on your problem. 130 | """ 131 | nn.init.xavier_uniform_(self.proj_param if layer_type == LayerType.IMP1 else self.linear_proj.weight) 132 | nn.init.xavier_uniform_(self.scoring_fn_target) 133 | nn.init.xavier_uniform_(self.scoring_fn_source) 134 | 135 | if self.bias is not None: 136 | torch.nn.init.zeros_(self.bias) 137 | 138 | def skip_concat_bias(self, attention_coefficients, in_nodes_features, out_nodes_features): 139 | if self.log_attention_weights: # potentially log for later visualization in playground.py 140 | self.attention_weights = attention_coefficients 141 | 142 | # if the tensor is not contiguously stored in memory we'll get an error after we try to do certain ops like view 143 | # only imp1 will enter this one 144 | if not out_nodes_features.is_contiguous(): 145 | out_nodes_features = out_nodes_features.contiguous() 146 | 147 | if self.add_skip_connection: # add skip or residual connection 148 | if out_nodes_features.shape[-1] == in_nodes_features.shape[-1]: # if FIN == FOUT 149 | # unsqueeze does this: (N, FIN) -> (N, 1, FIN), out features are (N, NH, FOUT) so 1 gets broadcast to NH 150 | # thus we're basically copying input vectors NH times and adding to processed vectors 151 | out_nodes_features += in_nodes_features.unsqueeze(1) 152 | else: 153 | # FIN != FOUT so we need to project input feature vectors into dimension that can be added to output 154 | # feature vectors. skip_proj adds lots of additional capacity which may cause overfitting. 155 | out_nodes_features += self.skip_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features) 156 | 157 | if self.concat: 158 | # shape = (N, NH, FOUT) -> (N, NH*FOUT) 159 | out_nodes_features = out_nodes_features.view(-1, self.num_of_heads * self.num_out_features) 160 | else: 161 | # shape = (N, NH, FOUT) -> (N, FOUT) 162 | out_nodes_features = out_nodes_features.mean(dim=self.head_dim) 163 | 164 | if self.bias is not None: 165 | out_nodes_features += self.bias 166 | 167 | return out_nodes_features if self.activation is None else self.activation(out_nodes_features) 168 | 169 | 170 | class GATLayerImp3(GATLayer): 171 | """ 172 | Implementation #3 was inspired by PyTorch Geometric: https://github.com/rusty1s/pytorch_geometric 173 | 174 | But, it's hopefully much more readable! (and of similar performance) 175 | 176 | It's suitable for both transductive and inductive settings. In the inductive setting we just merge the graphs 177 | into a single graph with multiple components and this layer is agnostic to that fact! 
<3 178 | 179 | """ 180 | 181 | src_nodes_dim = 0 # position of source nodes in edge index 182 | trg_nodes_dim = 1 # position of target nodes in edge index 183 | 184 | nodes_dim = 0 # node dimension/axis 185 | head_dim = 1 # attention head dimension/axis 186 | 187 | def __init__(self, num_in_features, num_out_features, num_of_heads, concat=True, activation=nn.ELU(), 188 | dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False): 189 | 190 | # Delegate initialization to the base class 191 | super().__init__(num_in_features, num_out_features, num_of_heads, LayerType.IMP3, concat, activation, dropout_prob, 192 | add_skip_connection, bias, log_attention_weights) 193 | 194 | def forward(self, data): 195 | # 196 | # Step 1: Linear Projection + regularization 197 | # 198 | 199 | in_nodes_features, edge_index = data # unpack data 200 | num_of_nodes = in_nodes_features.shape[self.nodes_dim] 201 | assert edge_index.shape[0] == 2, f'Expected edge index with shape=(2,E) got {edge_index.shape}' 202 | 203 | # shape = (N, FIN) where N - number of nodes in the graph, FIN - number of input features per node 204 | # We apply the dropout to all of the input node features (as mentioned in the paper) 205 | # Note: for Cora features are already super sparse so it's questionable how much this actually helps 206 | in_nodes_features = self.dropout(in_nodes_features) 207 | 208 | # shape = (N, FIN) * (FIN, NH*FOUT) -> (N, NH, FOUT) where NH - number of heads, FOUT - num of output features 209 | # We project the input node features into NH independent output features (one for each attention head) 210 | nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features) 211 | 212 | nodes_features_proj = self.dropout(nodes_features_proj) # in the official GAT imp they did dropout here as well 213 | 214 | # 215 | # Step 2: Edge attention calculation 216 | # 217 | 218 | # Apply the scoring function (* represents element-wise (a.k.a. Hadamard) product) 219 | # shape = (N, NH, FOUT) * (1, NH, FOUT) -> (N, NH, 1) -> (N, NH) because sum squeezes the last dimension 220 | # Optimization note: torch.sum() is as performant as .sum() in my experiments 221 | scores_source = (nodes_features_proj * self.scoring_fn_source).sum(dim=-1) 222 | scores_target = (nodes_features_proj * self.scoring_fn_target).sum(dim=-1) 223 | 224 | # We simply copy (lift) the scores for source/target nodes based on the edge index. Instead of preparing all 225 | # the possible combinations of scores we just prepare those that will actually be used and those are defined 226 | # by the edge index. 227 | # scores shape = (E, NH), nodes_features_proj_lifted shape = (E, NH, FOUT), E - number of edges in the graph 228 | scores_source_lifted, scores_target_lifted, nodes_features_proj_lifted = self.lift(scores_source, scores_target, nodes_features_proj, edge_index) 229 | scores_per_edge = self.leakyReLU(scores_source_lifted + scores_target_lifted) 230 | 231 | # shape = (E, NH, 1) 232 | attentions_per_edge = self.neighborhood_aware_softmax(scores_per_edge, edge_index[self.trg_nodes_dim], num_of_nodes) 233 | # Add stochasticity to neighborhood aggregation 234 | attentions_per_edge = self.dropout(attentions_per_edge) 235 | 236 | # 237 | # Step 3: Neighborhood aggregation 238 | # 239 | 240 | # Element-wise (aka Hadamard) product. 
Operator * does the same thing as torch.mul 241 | # shape = (E, NH, FOUT) * (E, NH, 1) -> (E, NH, FOUT), 1 gets broadcast into FOUT 242 | nodes_features_proj_lifted_weighted = nodes_features_proj_lifted * attentions_per_edge 243 | 244 | # This part sums up weighted and projected neighborhood feature vectors for every target node 245 | # shape = (N, NH, FOUT) 246 | out_nodes_features = self.aggregate_neighbors(nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes) 247 | 248 | # 249 | # Step 4: Residual/skip connections, concat and bias 250 | # 251 | 252 | out_nodes_features = self.skip_concat_bias(attentions_per_edge, in_nodes_features, out_nodes_features) 253 | return (out_nodes_features, edge_index) 254 | 255 | # 256 | # Helper functions (without comments there is very little code so don't be scared!) 257 | # 258 | 259 | def neighborhood_aware_softmax(self, scores_per_edge, trg_index, num_of_nodes): 260 | """ 261 | As the fn name suggests it does softmax over the neighborhoods. Example: say we have 5 nodes in a graph. 262 | Two of them 1, 2 are connected to node 3. If we want to calculate the representation for node 3 we should take 263 | into account feature vectors of 1, 2 and 3 itself. Since we have scores for edges 1-3, 2-3 and 3-3 264 | in scores_per_edge variable, this function will calculate attention scores like this: 1-3/(1-3+2-3+3-3) 265 | (where 1-3 is overloaded notation: it represents the edge 1-3 and its (exp) score) and similarly for 2-3 and 3-3 266 | i.e. for this neighborhood we don't care about other edge scores that include nodes 4 and 5. 267 | 268 | Note: 269 | Subtracting the max value from logits doesn't change the end result but it improves the numerical stability 270 | and it's a fairly common "trick" used in pretty much every deep learning framework. 271 | Check out this link for more details: 272 | 273 | https://stats.stackexchange.com/questions/338285/how-does-the-subtraction-of-the-logit-maximum-improve-learning 274 | 275 | """ 276 | # Calculate the numerator. Make logits <= 0 so that e^logit <= 1 (this will improve the numerical stability) 277 | scores_per_edge = scores_per_edge - scores_per_edge.max() 278 | exp_scores_per_edge = scores_per_edge.exp() # numerator of the softmax 279 | 280 | # Calculate the denominator. shape = (E, NH) 281 | neighborhood_aware_denominator = self.sum_edge_scores_neighborhood_aware(exp_scores_per_edge, trg_index, num_of_nodes) 282 | 283 | # 1e-16 is theoretically not needed but is only there for numerical stability (avoid div by 0) - due to the 284 | # possibility of the computer rounding a very small number all the way to 0. 285 | attentions_per_edge = exp_scores_per_edge / (neighborhood_aware_denominator + 1e-16) 286 | 287 | # shape = (E, NH) -> (E, NH, 1) so that we can do element-wise multiplication with projected node features 288 | return attentions_per_edge.unsqueeze(-1) 289 | 290 | def sum_edge_scores_neighborhood_aware(self, exp_scores_per_edge, trg_index, num_of_nodes): 291 | # The shape must be the same as in exp_scores_per_edge (required by scatter_add_) i.e.
from E -> (E, NH) 292 | trg_index_broadcasted = self.explicit_broadcast(trg_index, exp_scores_per_edge) 293 | 294 | # shape = (N, NH), where N is the number of nodes and NH the number of attention heads 295 | size = list(exp_scores_per_edge.shape) # convert to list otherwise assignment is not possible 296 | size[self.nodes_dim] = num_of_nodes 297 | neighborhood_sums = torch.zeros(size, dtype=exp_scores_per_edge.dtype, device=exp_scores_per_edge.device) 298 | 299 | # position i will contain a sum of exp scores of all the nodes that point to the node i (as dictated by the 300 | # target index) 301 | neighborhood_sums.scatter_add_(self.nodes_dim, trg_index_broadcasted, exp_scores_per_edge) 302 | 303 | # Expand again so that we can use it as a softmax denominator. e.g. node i's sum will be copied to 304 | # all the locations where the source nodes pointed to i (as dictated by the target index) 305 | # shape = (N, NH) -> (E, NH) 306 | return neighborhood_sums.index_select(self.nodes_dim, trg_index) 307 | 308 | def aggregate_neighbors(self, nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes): 309 | size = list(nodes_features_proj_lifted_weighted.shape) # convert to list otherwise assignment is not possible 310 | size[self.nodes_dim] = num_of_nodes # shape = (N, NH, FOUT) 311 | out_nodes_features = torch.zeros(size, dtype=in_nodes_features.dtype, device=in_nodes_features.device) 312 | 313 | # shape = (E) -> (E, NH, FOUT) 314 | trg_index_broadcasted = self.explicit_broadcast(edge_index[self.trg_nodes_dim], nodes_features_proj_lifted_weighted) 315 | # aggregation step - we accumulate projected, weighted node features for all the attention heads 316 | # shape = (E, NH, FOUT) -> (N, NH, FOUT) 317 | out_nodes_features.scatter_add_(self.nodes_dim, trg_index_broadcasted, nodes_features_proj_lifted_weighted) 318 | 319 | return out_nodes_features 320 | 321 | def lift(self, scores_source, scores_target, nodes_features_matrix_proj, edge_index): 322 | """ 323 | Lifts i.e. duplicates certain vectors depending on the edge index. 324 | One of the tensor dims goes from N -> E (that's where the "lift" comes from). 325 | 326 | """ 327 | src_nodes_index = edge_index[self.src_nodes_dim] 328 | trg_nodes_index = edge_index[self.trg_nodes_dim] 329 | 330 | # Using index_select is faster than "normal" indexing (scores_source[src_nodes_index]) in PyTorch! 331 | scores_source = scores_source.index_select(self.nodes_dim, src_nodes_index) 332 | scores_target = scores_target.index_select(self.nodes_dim, trg_nodes_index) 333 | nodes_features_matrix_proj_lifted = nodes_features_matrix_proj.index_select(self.nodes_dim, src_nodes_index) 334 | 335 | return scores_source, scores_target, nodes_features_matrix_proj_lifted 336 | 337 | def explicit_broadcast(self, this, other): 338 | # Append singleton dimensions until this.dim() == other.dim() 339 | for _ in range(this.dim(), other.dim()): 340 | this = this.unsqueeze(-1) 341 | 342 | # Explicitly expand so that shapes are the same 343 | return this.expand_as(other) 344 | 345 | 346 | class GATLayerImp2(GATLayer): 347 | """ 348 | Implementation #2 was inspired by the official GAT implementation: https://github.com/PetarV-/GAT 349 | It's conceptually simpler than implementation #3 but computationally much less efficient. 350 | Note: this is the naive implementation not the sparse one and it's only suitable for a transductive setting. 
351 | It would be fairly easy to make it work in the inductive setting as well but the purpose of this layer 352 | is more educational since it's way less efficient than implementation 3. 353 | """ 354 | 355 | def __init__(self, num_in_features, num_out_features, num_of_heads, concat=True, activation=nn.ELU(), 356 | dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False): 357 | 358 | super().__init__(num_in_features, num_out_features, num_of_heads, LayerType.IMP2, concat, activation, dropout_prob, 359 | add_skip_connection, bias, log_attention_weights) 360 | 361 | def forward(self, data): 362 | # 363 | # Step 1: Linear Projection + regularization (using linear layer instead of matmul as in imp1) 364 | # 365 | 366 | in_nodes_features, connectivity_mask = data # unpack data 367 | num_of_nodes = in_nodes_features.shape[0] 368 | assert connectivity_mask.shape == (num_of_nodes, num_of_nodes), \ 369 | f'Expected connectivity matrix with shape=({num_of_nodes},{num_of_nodes}), got shape={connectivity_mask.shape}.' 370 | 371 | # shape = (N, FIN) where N - number of nodes in the graph, FIN - number of input features per node 372 | # We apply the dropout to all of the input node features (as mentioned in the paper) 373 | in_nodes_features = self.dropout(in_nodes_features) 374 | 375 | # shape = (N, FIN) * (FIN, NH*FOUT) -> (N, NH, FOUT) where NH - number of heads, FOUT - num of output features 376 | # We project the input node features into NH independent output features (one for each attention head) 377 | nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features) 378 | 379 | nodes_features_proj = self.dropout(nodes_features_proj) # in the official GAT imp they did dropout here as well 380 | 381 | # 382 | # Step 2: Edge attention calculation (using sum instead of bmm + additional permute calls - compared to imp1) 383 | # 384 | 385 | # Apply the scoring function (* represents element-wise (a.k.a. Hadamard) product) 386 | # shape = (N, NH, FOUT) * (1, NH, FOUT) -> (N, NH, 1) 387 | # Optimization note: torch.sum() is as performant as .sum() in my experiments 388 | scores_source = torch.sum((nodes_features_proj * self.scoring_fn_source), dim=-1, keepdim=True) 389 | scores_target = torch.sum((nodes_features_proj * self.scoring_fn_target), dim=-1, keepdim=True) 390 | 391 | # src shape = (NH, N, 1) and trg shape = (NH, 1, N) 392 | scores_source = scores_source.transpose(0, 1) 393 | scores_target = scores_target.permute(1, 2, 0) 394 | 395 | # shape = (NH, N, 1) + (NH, 1, N) -> (NH, N, N) with the magic of automatic broadcast <3 396 | # In Implementation 3 we are much smarter and don't have to calculate all NxN scores! (only E!) 
397 | # Tip: it's conceptually easier to understand what happens here if you delete the NH dimension 398 | all_scores = self.leakyReLU(scores_source + scores_target) 399 | # connectivity mask will put -inf on all locations where there are no edges, after applying the softmax 400 | # this will result in attention scores being computed only for existing edges 401 | all_attention_coefficients = self.softmax(all_scores + connectivity_mask) 402 | 403 | # 404 | # Step 3: Neighborhood aggregation (same as in imp1) 405 | # 406 | 407 | # batch matrix multiply, shape = (NH, N, N) * (NH, N, FOUT) -> (NH, N, FOUT) 408 | out_nodes_features = torch.bmm(all_attention_coefficients, nodes_features_proj.transpose(0, 1)) 409 | 410 | # Note: watch out here I made a silly mistake of using reshape instead of permute thinking it will 411 | # end up doing the same thing, but it didn't! The acc on Cora didn't go above 52%! (compared to reported ~82%) 412 | # shape = (N, NH, FOUT) 413 | out_nodes_features = out_nodes_features.permute(1, 0, 2) 414 | 415 | # 416 | # Step 4: Residual/skip connections, concat and bias (same as in imp1) 417 | # 418 | 419 | out_nodes_features = self.skip_concat_bias(all_attention_coefficients, in_nodes_features, out_nodes_features) 420 | return (out_nodes_features, connectivity_mask) 421 | 422 | 423 | class GATLayerImp1(GATLayer): 424 | """ 425 | This implementation is only suitable for a transductive setting. 426 | It would be fairly easy to make it work in the inductive setting as well but the purpose of this layer 427 | is more educational since it's way less efficient than implementation 3. 428 | 429 | """ 430 | def __init__(self, num_in_features, num_out_features, num_of_heads, concat=True, activation=nn.ELU(), 431 | dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False): 432 | 433 | super().__init__(num_in_features, num_out_features, num_of_heads, LayerType.IMP1, concat, activation, dropout_prob, 434 | add_skip_connection, bias, log_attention_weights) 435 | 436 | def forward(self, data): 437 | # 438 | # Step 1: Linear Projection + regularization 439 | # 440 | 441 | in_nodes_features, connectivity_mask = data # unpack data 442 | num_of_nodes = in_nodes_features.shape[0] 443 | assert connectivity_mask.shape == (num_of_nodes, num_of_nodes), \ 444 | f'Expected connectivity matrix with shape=({num_of_nodes},{num_of_nodes}), got shape={connectivity_mask.shape}.' 445 | 446 | # shape = (N, FIN) where N - number of nodes in the graph, FIN number of input features per node 447 | # We apply the dropout to all of the input node features (as mentioned in the paper) 448 | in_nodes_features = self.dropout(in_nodes_features) 449 | 450 | # shape = (1, N, FIN) * (NH, FIN, FOUT) -> (NH, N, FOUT) where NH - number of heads, FOUT num of output features 451 | # We project the input node features into NH independent output features (one for each attention head) 452 | nodes_features_proj = torch.matmul(in_nodes_features.unsqueeze(0), self.proj_param) 453 | 454 | nodes_features_proj = self.dropout(nodes_features_proj) # in the official GAT imp they did dropout here as well 455 | 456 | # 457 | # Step 2: Edge attention calculation 458 | # 459 | 460 | # Apply the scoring function (* represents element-wise (a.k.a. 
Hadamard) product) 461 | # batch matrix multiply, shape = (NH, N, FOUT) * (NH, FOUT, 1) -> (NH, N, 1) 462 | scores_source = torch.bmm(nodes_features_proj, self.scoring_fn_source) 463 | scores_target = torch.bmm(nodes_features_proj, self.scoring_fn_target) 464 | 465 | # shape = (NH, N, 1) + (NH, 1, N) -> (NH, N, N) with the magic of automatic broadcast <3 466 | # In Implementation 3 we are much smarter and don't have to calculate all NxN scores! (only E!) 467 | # Tip: it's conceptually easier to understand what happens here if you delete the NH dimension 468 | all_scores = self.leakyReLU(scores_source + scores_target.transpose(1, 2)) 469 | # connectivity mask will put -inf on all locations where there are no edges, after applying the softmax 470 | # this will result in attention scores being computed only for existing edges 471 | all_attention_coefficients = self.softmax(all_scores + connectivity_mask) 472 | 473 | # 474 | # Step 3: Neighborhood aggregation 475 | # 476 | 477 | # shape = (NH, N, N) * (NH, N, FOUT) -> (NH, N, FOUT) 478 | out_nodes_features = torch.bmm(all_attention_coefficients, nodes_features_proj) 479 | 480 | # shape = (N, NH, FOUT) 481 | out_nodes_features = out_nodes_features.transpose(0, 1) 482 | 483 | # 484 | # Step 4: Residual/skip connections, concat and bias (same across all the implementations) 485 | # 486 | 487 | out_nodes_features = self.skip_concat_bias(all_attention_coefficients, in_nodes_features, out_nodes_features) 488 | return (out_nodes_features, connectivity_mask) 489 | 490 | 491 | # 492 | # Helper functions 493 | # 494 | def get_layer_type(layer_type): 495 | assert isinstance(layer_type, LayerType), f'Expected {LayerType} got {type(layer_type)}.' 496 | 497 | if layer_type == LayerType.IMP1: 498 | return GATLayerImp1 499 | elif layer_type == LayerType.IMP2: 500 | return GATLayerImp2 501 | elif layer_type == LayerType.IMP3: 502 | return GATLayerImp3 503 | else: 504 | raise Exception(f'Layer type {layer_type} not yet supported.') -------------------------------------------------------------------------------- /dual_flow/dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | import random 5 | import pickle 6 | import logging 7 | import numpy as np 8 | from tqdm import tqdm 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | from utils import squeeze_lst, get_ner_entity, get_cmekg_entity_specific 12 | 13 | logger = logging.getLogger(__name__) 14 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 15 | datefmt = '%d %H:%M:%S', 16 | level = logging.INFO) 17 | 18 | def load_data(data_name, data_type): 19 | with open(f"../data/{data_name}_{data_type}_kg_entity_specific.pk", "rb") as f: 20 | data = pickle.load(f) 21 | 22 | logger.info(f"Num of {data_name} {data_type} dialogues: %d", len(data)) 23 | return data 24 | 25 | def get_turn_range(raw_ids): 26 | length_lst = [len(x) for x in raw_ids] #length of each turn 27 | 28 | turn_range = [] 29 | end = 1 30 | while end < len(raw_ids): 31 | for start in range(0, end): 32 | if np.sum(length_lst[start:end]) <= 512: 33 | break 34 | turn_range.append([start, end]) 35 | end += 2 36 | return turn_range 37 | 38 | def get_context_samples(raw_ids, turn_range, start_token=None): 39 | ids = [] 40 | idx = [] 41 | for (start, end) in turn_range: 42 | ids_tmp = [start_token] #one sample 43 | idx_tmp = [] 44 | for i in range(start, end): 45 | c_start = len(ids_tmp) 46 | ids_tmp += raw_ids[i] 47 | c_end = 
len(ids_tmp) 48 | idx_tmp.append((c_start, c_end)) 49 | 50 | if len(ids_tmp) > 512: 51 | ids_tmp = ids_tmp[:512] 52 | idx_tmp[-1] = (idx_tmp[-1][0], 512) 53 | 54 | assert len(ids_tmp) <= 512 55 | assert len(ids_tmp) == idx_tmp[-1][1] 56 | 57 | ids.append(ids_tmp) 58 | idx.append(idx_tmp) 59 | return ids, idx #multiple samples 60 | 61 | def process_data(data_name, 62 | data_type, 63 | mode, 64 | raw_data, 65 | entity_lst, 66 | entity_dict, 67 | tokenizer): 68 | 69 | #ner entity 70 | _, ner_entity_dict = get_ner_entity(data_name) 71 | 72 | #sub kg 73 | with open(f"../data/{data_name}_{data_type}_sub_kg.pk", "rb") as f: 74 | sub_kg_dict = pickle.load(f) 75 | with open(f"../data/{data_name}_{data_type}_sub_kg_center.pk", "rb") as f: 76 | center_node_dict = pickle.load(f) 77 | 78 | #role dict 79 | role_token_id = {"patient": 1, "doctor": 2} 80 | 81 | #act dict 82 | act_id = {"INQUIRE": 0, "DIAGNOSIS": 1, "TREATMENT": 2, "TEST": 3, "PRECAUTION": 4, "INFORM": 5, "CHITCHAT": 6} 83 | 84 | samples_context_ids = [] 85 | samples_context_idx = [] 86 | samples_context_raw = [] 87 | samples_target_raw = [] 88 | 89 | samples_act_turn_idx = [] 90 | samples_entity_turn_idx = [] 91 | 92 | samples_act_labels = [] 93 | samples_entity_labels = [] 94 | samples_target_kg_entity_idx = [] 95 | samples_neg_kg_entity_idx = [] 96 | 97 | samples_idx = [] 98 | samples_turn_idx = [] 99 | for sample in tqdm(raw_data, desc=f"Construct {data_type} samples"): 100 | 101 | dialogue_history = sample["dialogues"] 102 | if dialogue_history[-1]["role"] == "patient": #all conversations end with doctor utterances 103 | del dialogue_history[-1] 104 | 105 | assert len(dialogue_history) % 2 == 0 106 | 107 | context_ids_tmp = [] 108 | act_raw = [] 109 | entity_raw = [] 110 | for i, turn in enumerate(dialogue_history): 111 | context_ids_turn_tmp = [role_token_id[turn["role"]]] #add role token 112 | tokens = tokenizer.tokenize(turn["sentence"]) 113 | ids = tokenizer.convert_tokens_to_ids(tokens) 114 | context_ids_turn_tmp += ids #dialogue input ids 115 | 116 | extra_length = max(len(context_ids_turn_tmp) - 512, 0) #confirm exceeding length 117 | if extra_length > 0 and turn["role"] == "patient": 118 | context_ids_turn_tmp = context_ids_turn_tmp[:-extra_length] 119 | context_ids_tmp.append(context_ids_turn_tmp) 120 | 121 | #save raw entity 122 | entity_raw.append([entity_dict[x] for x in turn["sub_kg_entity"]]) #the entity of the current turn and the entity of sub-kg 123 | assert len(turn["sub_kg_entity"]) == len(set(turn["sub_kg_entity"])) 124 | 125 | #save raw act 126 | turn_act_lst = set() 127 | for act_lst in turn["act"]: 128 | turn_act_lst.update(act_lst) 129 | turn_act_lst = [act_id[act] for act in act_id if act in turn_act_lst] 130 | act_raw.append(turn_act_lst) 131 | 132 | if turn["role"] == "doctor": 133 | #save original target 134 | samples_target_raw.append(turn["sentence"]) 135 | 136 | #save original context 137 | context_raw_tmp = [] 138 | for j in range(i): 139 | context_raw_tmp.append(dialogue_history[j]["sentence"]) 140 | samples_context_raw.append({ 141 | "idx": sample["idx"], 142 | "turn_idx": str(i), 143 | "context": context_raw_tmp, 144 | "entity": turn["entity"], 145 | "target_kg_entity": turn["target_kg_entity"], 146 | "response": turn["sentence"], 147 | }) 148 | 149 | sub_kg = sub_kg_dict[sample["idx"] + "_" + str(turn["turn"])] 150 | center_node = center_node_dict[sample["idx"] + "_" + str(turn["turn"])] 151 | 152 | #save target/negative kg entity 153 | target_kg_entity_lst = copy.deepcopy(turn["target_kg_entity"]) 
154 | target_kg_entity_lst = [x[0] for x in target_kg_entity_lst] 155 | neg_tmp = [x for x in sub_kg if x not in center_node and x not in target_kg_entity_lst] #in sub-kg, entities that do not belong to the central node and target node are negative. 156 | neg_tmp.sort() #sort it in order for easy reproduction 157 | neg_n = 20 158 | if len(target_kg_entity_lst) > 0: 159 | if len(neg_tmp) >= neg_n: 160 | neg_kg_entity_lst = random.sample(neg_tmp, neg_n) 161 | else: 162 | neg_kg_entity_lst = neg_tmp + random.sample(entity_lst, neg_n - len(neg_tmp)) 163 | else: 164 | neg_kg_entity_lst = [] 165 | 166 | kg_entity_idx = squeeze_lst([entity_dict[x] for x in target_kg_entity_lst]) 167 | samples_target_kg_entity_idx.append(copy.deepcopy(kg_entity_idx)) 168 | kg_entity_idx = [entity_dict[x] for x in neg_kg_entity_lst] 169 | samples_neg_kg_entity_idx.append(copy.deepcopy(kg_entity_idx)) 170 | 171 | #act label 172 | act_labels = [0] * 7 173 | for act in turn_act_lst: 174 | act_labels[act] = 1 175 | samples_act_labels.append(act_labels) 176 | 177 | #entity label 178 | entity_labels = [0] * len(ner_entity_dict) 179 | for entity in turn["entity"]: 180 | entity_labels[ner_entity_dict[entity]] = 1 181 | samples_entity_labels.append(entity_labels) 182 | 183 | turn_range = get_turn_range(context_ids_tmp) #get the context turn range that meets the encoding length requirements (<=512) 184 | context_ids, context_idx = get_context_samples(context_ids_tmp, turn_range, start_token=tokenizer.cls_token_id) #contains multiple samples 185 | 186 | samples_context_ids += context_ids 187 | samples_context_idx += context_idx 188 | 189 | for i, (start, end) in enumerate(turn_range): 190 | act_turn_idx = act_raw[start:end] 191 | entity_turn_idx = entity_raw[start:end] 192 | samples_act_turn_idx.append(copy.deepcopy(act_turn_idx)) 193 | samples_entity_turn_idx.append(copy.deepcopy(entity_turn_idx)) 194 | 195 | turn_idx = [str(end) for (start, end) in turn_range] 196 | samples_turn_idx += turn_idx 197 | samples_idx += [sample["idx"]] * len(context_ids) 198 | 199 | return (samples_context_ids, samples_context_idx, 200 | samples_act_turn_idx, samples_entity_turn_idx, 201 | samples_act_labels, samples_entity_labels, 202 | samples_target_kg_entity_idx, samples_neg_kg_entity_idx, 203 | samples_target_raw, samples_context_raw, 204 | samples_idx, samples_turn_idx) 205 | 206 | class BaseDataset(Dataset): 207 | def __init__(self, data_name, data_type, mode, tokenizer): 208 | self.data_name = data_name 209 | self.data_type = data_type 210 | self.mode = mode 211 | self.tokenizer = tokenizer 212 | self.raw_data = load_data(data_name, data_type) 213 | self.entity_lst, self.entity_dict, _ = get_cmekg_entity_specific(data_name) 214 | (self.samples_context_ids, self.samples_context_idx, 215 | self.samples_act_turn_idx, self.samples_entity_turn_idx, 216 | self.samples_act_labels, self.samples_entity_labels, 217 | self.samples_target_kg_entity_idx, self.samples_neg_kg_entity_idx, 218 | self.samples_target_raw, self.samples_context_raw, 219 | self.samples_idx, self.samples_turn_idx)= process_data(data_name, data_type, mode, self.raw_data, self.entity_lst, self.entity_dict, tokenizer) 220 | 221 | def __len__(self): 222 | assert len(self.samples_context_ids) == len(self.samples_context_idx) 223 | assert len(self.samples_idx) == len(self.samples_turn_idx) 224 | assert len(self.samples_context_raw) == len(self.samples_target_raw) 225 | return len(self.samples_context_ids) 226 | 227 | def __getitem__(self, item): 228 | lst_data = { 229 | 
"context_ids": self.samples_context_ids[item], 230 | "context_idx": self.samples_context_idx[item], 231 | "act_turn_idx": self.samples_act_turn_idx[item], 232 | "entity_turn_idx": self.samples_entity_turn_idx[item], 233 | "idx": self.samples_idx[item], 234 | "turn_idx": self.samples_turn_idx[item], 235 | "act_labels": self.samples_act_labels[item], 236 | "entity_labels": self.samples_entity_labels[item], 237 | "target_kg_entity_idx": self.samples_target_kg_entity_idx[item], 238 | "neg_kg_entity_idx": self.samples_neg_kg_entity_idx[item], 239 | } 240 | return lst_data 241 | 242 | def pack_tensor_2D(raw_lst, default, dtype, length=None): 243 | batch_size = len(raw_lst) 244 | length = length if length is not None else max(len(raw) for raw in raw_lst) 245 | tensor = default * torch.ones((batch_size, length), dtype=dtype) 246 | for i, raw in enumerate(raw_lst): 247 | tensor[i, :len(raw)] = torch.tensor(raw, dtype=dtype) 248 | return tensor 249 | 250 | def get_collate_function(data_name): 251 | def collate_function(batch): 252 | context_ids_lst = [x["context_ids"] for x in batch] 253 | context_idx_lst = [x["context_idx"] for x in batch] 254 | context_mask_lst = [[1] * len(context_ids) for context_ids in context_ids_lst] 255 | 256 | act_turn_idx_lst = [x["act_turn_idx"] for x in batch] 257 | entity_turn_idx_lst = [x["entity_turn_idx"] for x in batch] 258 | 259 | # collect all entity idx in batch 260 | index = 0 261 | batch_entity_turn_idx_lst = [] 262 | batch_entity_turn_idx_dict = {} 263 | tmp_idx_lst = set() 264 | for i in range(len(entity_turn_idx_lst)): 265 | for k in range(len(entity_turn_idx_lst[i])): 266 | for item in entity_turn_idx_lst[i][k]: 267 | if item not in tmp_idx_lst: 268 | batch_entity_turn_idx_lst.append(item) 269 | batch_entity_turn_idx_dict[item] = index 270 | tmp_idx_lst.add(item) 271 | index += 1 272 | 273 | # replace entity idx in batch with new idx 274 | new_entity_turn_idx_lst = [] 275 | for i in range(len(entity_turn_idx_lst)): 276 | new_entity_turn_idx_lst_i = [] 277 | for k in range(len(entity_turn_idx_lst[i])): 278 | new_entity_turn_idx_lst_k = [] 279 | for item in entity_turn_idx_lst[i][k]: 280 | new_entity_turn_idx_lst_k.append(batch_entity_turn_idx_dict[item]) 281 | new_entity_turn_idx_lst_i.append(new_entity_turn_idx_lst_k) 282 | new_entity_turn_idx_lst.append(new_entity_turn_idx_lst_i) 283 | 284 | act_labels_lst = [x["act_labels"] for x in batch] 285 | entity_labels_lst = [x["entity_labels"] for x in batch] 286 | 287 | target_kg_entity_idx_lst = [x["target_kg_entity_idx"] for x in batch] 288 | neg_kg_entity_idx_lst = [x["neg_kg_entity_idx"] for x in batch] 289 | 290 | data = { 291 | "input_ids": pack_tensor_2D(context_ids_lst, default=0, dtype=torch.int64), 292 | "attention_mask": pack_tensor_2D(context_mask_lst, default=0, dtype=torch.int64), 293 | "context_idx": context_idx_lst, 294 | "act_turn_idx": act_turn_idx_lst, 295 | "entity_turn_idx": new_entity_turn_idx_lst, 296 | "batch_entity_turn_idx": batch_entity_turn_idx_lst, 297 | "act_labels": torch.tensor(act_labels_lst, dtype=torch.float32), 298 | "entity_labels": torch.tensor(entity_labels_lst, dtype=torch.float32), 299 | "target_kg_entity_idx": target_kg_entity_idx_lst, 300 | "neg_kg_entity_idx": neg_kg_entity_idx_lst, 301 | } 302 | 303 | idx = [x["idx"] for x in batch] 304 | turn_idx = [x["turn_idx"] for x in batch] 305 | return data, idx, turn_idx 306 | return collate_function 307 | 308 | def construct_data(args, data_type, mode, per_gpu_batch_size, tokenizer, data_sampler): 309 | batch_size = 
per_gpu_batch_size * max(1, args.n_gpu) 310 | dataset = BaseDataset(args.data_name, data_type, mode, tokenizer) 311 | sampler = data_sampler(dataset) 312 | collate_fn = get_collate_function(args.data_name) 313 | dataloader = DataLoader(dataset, sampler=sampler, 314 | batch_size=batch_size, num_workers=args.data_num_workers, collate_fn=collate_fn) 315 | return dataset, dataloader, batch_size 316 | -------------------------------------------------------------------------------- /dual_flow/dataset_entity.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import logging 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | from utils import get_cmekg_entity_specific 7 | 8 | logger = logging.getLogger(__name__) 9 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 10 | datefmt = '%d %H:%M:%S', 11 | level = logging.INFO) 12 | 13 | class EntityDataset(Dataset): 14 | def __init__(self, args, tokenizer): 15 | self.tokenizer = tokenizer 16 | self.args = args 17 | entity_lst, _, entity_type_dict = get_cmekg_entity_specific(args.data_name) 18 | 19 | entity_ids = [] 20 | for entity in entity_lst: 21 | tmp = [] 22 | 23 | tokens = tokenizer.tokenize(entity) 24 | ids = tokenizer.convert_tokens_to_ids(tokens) 25 | tmp += ids 26 | entity_ids.append(tmp[:512]) 27 | 28 | self.samples_entity_ids = entity_ids 29 | 30 | def __len__(self): 31 | return len(self.samples_entity_ids) 32 | 33 | def __getitem__(self, item): 34 | lst_data = { 35 | "entity_ids": self.samples_entity_ids[item], 36 | } 37 | return lst_data 38 | 39 | def pack_tensor_2D(raw_lst, default, dtype, length=None): 40 | batch_size = len(raw_lst) 41 | length = length if length is not None else max(len(raw) for raw in raw_lst) 42 | tensor = default * torch.ones((batch_size, length), dtype=dtype) 43 | for i, raw in enumerate(raw_lst): 44 | tensor[i, :len(raw)] = torch.tensor(raw, dtype=dtype) 45 | return tensor 46 | 47 | def get_collate_function(tokenizer): 48 | def collate_function(batch): 49 | entity_ids_lst = [x["entity_ids"] for x in batch] 50 | entity_mask_lst = [len(x) * [1] for x in entity_ids_lst] 51 | 52 | data = { 53 | "entity_ids": pack_tensor_2D(entity_ids_lst, default=0, dtype=torch.int64), 54 | "entity_mask": pack_tensor_2D(entity_mask_lst, default=0, dtype=torch.int64), 55 | } 56 | 57 | return data 58 | return collate_function 59 | 60 | def construct_data(args, per_gpu_batch_size, tokenizer, data_sampler): 61 | batch_size = per_gpu_batch_size * max(1, args.n_gpu) 62 | dataset = EntityDataset(args, tokenizer) 63 | sampler = data_sampler(dataset) 64 | collate_fn = get_collate_function(tokenizer) 65 | dataloader = DataLoader(dataset, sampler=sampler, 66 | batch_size=batch_size, num_workers=args.data_num_workers, collate_fn=collate_fn) 67 | return dataset, dataloader, batch_size 68 | -------------------------------------------------------------------------------- /dual_flow/eval_kamed.sh: -------------------------------------------------------------------------------- 1 | output_dir="./train" 2 | train_name="demo" 3 | checkpoint="step-*" 4 | 5 | python main.py \ 6 | --mode evaluate \ 7 | --data_name kamed \ 8 | --data_type test \ 9 | --output_dir "${output_dir}" \ 10 | --eval_model_path "${output_dir}/models/${train_name}/${checkpoint}" \ 11 | --per_gpu_eval_batch_size 16 \ 12 | --result_save_dir "${output_dir}/results/${train_name}" \ 13 | --train_name "${train_name}" \ 14 | 
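# NOTE: checkpoint="step-*" above is a placeholder; the shell does not expand the
# glob inside the quoted --eval_model_path, so set it to the concrete checkpoint
# directory saved by training (e.g. a hypothetical "step-2000").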
-------------------------------------------------------------------------------- /dual_flow/eval_meddg.sh: -------------------------------------------------------------------------------- 1 | output_dir="./train" 2 | train_name="demo" 3 | checkpoint="step-*" 4 | 5 | python main.py \ 6 | --mode evaluate \ 7 | --data_name meddg \ 8 | --data_type test \ 9 | --output_dir "${output_dir}" \ 10 | --eval_model_path "${output_dir}/models/${train_name}/${checkpoint}" \ 11 | --per_gpu_eval_batch_size 16 \ 12 | --result_save_dir "${output_dir}/results/${train_name}" \ 13 | --train_name "${train_name}" \ 14 | --for_meddg_160 \ 15 | -------------------------------------------------------------------------------- /dual_flow/evaluating.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import math 5 | import pickle 6 | import torch 7 | import logging 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from utils import get_cmekg_entity_specific, get_ner_entity 12 | 13 | logger = logging.getLogger(__name__) 14 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 15 | datefmt = '%d %H:%M:%S', 16 | level = logging.INFO) 17 | 18 | act_dict = {0: "INQUIRE", 1: "DIAGNOSIS", 2: "TREATMENT", 3: "TEST", 4: "PRECAUTION", 5: "INFORM", 6: "CHITCHAT"} 19 | 20 | def calculate_metrics(scores, labels, threshold): 21 | ep = 1e-8 22 | TP = ((scores > threshold) & (labels == 1)).sum() 23 | TN = ((scores < threshold) & (labels == 0)).sum() 24 | FN = ((scores < threshold) & (labels == 1)).sum() 25 | FP = ((scores > threshold) & (labels == 0)).sum() 26 | 27 | p = TP / (TP + FP + ep) 28 | r = TP / (TP + FN + ep) 29 | f1 = 2 * r * p / (r + p + ep) 30 | acc = (TP + TN) / (TP + TN + FP + FN + ep) 31 | return f1, acc, r, p 32 | 33 | def evaluate(args, model, embeds, entity_matrix, eval_dataset, eval_dataloader, prefix): 34 | # multi-gpu eval 35 | if args.n_gpu > 1: 36 | model = torch.nn.DataParallel(model) 37 | 38 | # Eval 39 | logger.info("***** Running Evaluation {} *****".format(prefix)) 40 | logger.info(" Num examples = %d", len(eval_dataset)) 41 | logger.info(" Batch size = %d", args.eval_batch_size) 42 | 43 | test_loss = 0 44 | test_act_loss = 0 45 | test_entity_loss = 0 46 | test_act_scores = [] 47 | test_act_labels = [] 48 | test_entity_scores = [] 49 | test_entity_labels = [] 50 | test_idx = [] 51 | test_turn_idx = [] 52 | 53 | check_lst = [] 54 | context_embeds = [] 55 | 56 | # gpu memory is not enough so GAT is applied on cpu 57 | # to update all entity embeddings 58 | embeds_tmp_save_path = f"{args.model_save_dir}/{prefix}/{args.data_name}_{eval_dataset.data_type}_embeds_tmp.pk" 59 | if os.path.exists(embeds_tmp_save_path): 60 | with open(embeds_tmp_save_path, "rb") as infile: 61 | embeds_tmp = pickle.load(infile) 62 | else: 63 | embeds_tmp = embeds.to(torch.device("cpu")) 64 | with torch.no_grad(): 65 | model.entity_gat.to(torch.device("cpu")) 66 | embeds_tmp, _ = model.entity_gat((embeds_tmp, entity_matrix)) 67 | with open(embeds_tmp_save_path, "wb") as outfile: 68 | pickle.dump(embeds_tmp, outfile) 69 | embeds_tmp = embeds_tmp.to(args.device) 70 | model.entity_gat.to(args.device) 71 | 72 | model.eval() 73 | for i, (batch, idx, turn_idx) in enumerate(tqdm(eval_dataloader, desc="Evaluating")): 74 | 75 | with torch.no_grad(): 76 | context_idx = batch["context_idx"] 77 | act_turn_idx = batch["act_turn_idx"] 78 | entity_turn_idx = batch["entity_turn_idx"] 79 | batch_entity_turn_idx = 
batch["batch_entity_turn_idx"] 80 | target_kg_entity_idx = batch["target_kg_entity_idx"] 81 | neg_kg_entity_idx = batch["neg_kg_entity_idx"] 82 | batch = {k:v.to(args.device) for k, v in batch.items() if "idx" not in k and "entity_ids" not in k and "entity_mask" not in k} 83 | 84 | if not args.for_meddg_160: 85 | batch["entity_labels"] = None 86 | 87 | outputs = model(**batch, 88 | entity_embeds=embeds_tmp, 89 | context_idx=context_idx, 90 | act_turn_idx=act_turn_idx, 91 | entity_turn_idx=entity_turn_idx, 92 | batch_entity_turn_idx=batch_entity_turn_idx, 93 | target_kg_entity_idx=target_kg_entity_idx, 94 | neg_kg_entity_idx=neg_kg_entity_idx) 95 | 96 | loss, act_loss, entity_loss, scores_n_labels, context_hiddens = outputs 97 | act_scores, act_labels, entity_scores, entity_labels = scores_n_labels 98 | test_loss += loss 99 | test_act_loss += act_loss 100 | test_entity_loss += entity_loss 101 | test_act_scores.append(act_scores) 102 | test_act_labels.append(act_labels) 103 | test_entity_scores.append(entity_scores) 104 | test_entity_labels.append(entity_labels) 105 | test_idx += idx 106 | test_turn_idx += turn_idx 107 | 108 | for i, (sample_idx, sample_turn_idx) in enumerate(zip(idx, turn_idx)): 109 | check_lst.append(sample_idx + "_" + sample_turn_idx) 110 | context_embeds.append(context_hiddens) 111 | 112 | if os.path.exists(f"{args.log_dir}/{args.data_name}_{eval_dataset.data_type}_best_metric.json"): 113 | with open(f"{args.log_dir}/{args.data_name}_{eval_dataset.data_type}_best_metric.json", "r") as infile: 114 | best_metric = json.load(infile) 115 | else: 116 | best_metric = { 117 | "act": { 118 | "INQUIRE": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""}, 119 | "DIAGNOSIS": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""}, 120 | "TREATMENT": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""}, 121 | "TEST": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""}, 122 | "PRECAUTION": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""}, 123 | "INFORM": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""}, 124 | "CHITCHAT": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""}, 125 | }, 126 | "entity": {"f1": 0, "acc": 0, "recall": 0, "precision": 0, "checkpoint": ""} 127 | } 128 | 129 | if args.result_save_dir: 130 | result_save_dir = args.result_save_dir 131 | else: 132 | result_save_dir = f"{args.model_save_dir}/{prefix}" 133 | 134 | if not os.path.exists(result_save_dir): 135 | os.makedirs(result_save_dir) 136 | 137 | ################ 138 | # Act Evaulate # 139 | ################ 140 | 141 | test_act_scores = torch.cat(test_act_scores, dim=0).cpu().numpy() 142 | test_act_labels = torch.cat(test_act_labels, dim=0).cpu().numpy() 143 | 144 | tshd = [0.4, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5] 145 | tshd = np.ones_like(test_act_labels) * np.array(tshd) 146 | 147 | context_idx_lst = [] 148 | for idx, turn_idx in zip(test_idx, test_turn_idx): 149 | context_idx_lst.append(str(idx) + "_" + str(turn_idx)) 150 | 151 | pred_labels = test_act_scores > tshd 152 | act_predicted_dict = dict() 153 | for labels, context_idx in zip(pred_labels, context_idx_lst): 154 | act_predicted_dict[context_idx] = [] 155 | for i in range(7): 156 | if labels[i]: 157 | act_predicted_dict[context_idx].append(act_dict[i]) 158 | with open(f"{result_save_dir}/{args.data_name}_{eval_dataset.data_type}_predicted_act.pk", "wb") as outfile: 159 | pickle.dump((test_act_scores, test_act_labels, context_idx_lst, 
160 |
161 |     act_f1 = []  # individual metric for each act
162 |     act_acc = []
163 |     act_r = []
164 |     act_p = []
165 |     for i in range(7):
166 |         f1, acc, r, p = calculate_metrics(test_act_scores[:,i], test_act_labels[:,i], tshd[:,i])
167 |         act_f1.append(f1)
168 |         act_acc.append(acc)
169 |         act_r.append(r)
170 |         act_p.append(p)
171 |
172 |         #################
173 |         # Save Best Act #
174 |         #################
175 |
176 |         if f1 > best_metric["act"][act_dict[i]]["f1"]:
177 |             best_metric["act"][act_dict[i]]["f1"] = f1
178 |             best_metric["act"][act_dict[i]]["acc"] = acc
179 |             best_metric["act"][act_dict[i]]["recall"] = r
180 |             best_metric["act"][act_dict[i]]["precision"] = p
181 |             best_metric["act"][act_dict[i]]["checkpoint"] = prefix
182 |
183 |     ###################
184 |     # Entity Evaluate #
185 |     ###################
186 |
187 |     if args.for_meddg_160:
188 |         # For the MedDG dataset
189 |         # ner entity
190 |         ner_entity_lst, _ = get_ner_entity(args.data_name)
191 |
192 |         test_entity_scores = torch.cat(test_entity_scores, dim=0).cpu().numpy()
193 |         test_entity_labels = torch.cat(test_entity_labels, dim=0).cpu().numpy()
194 |
195 |         tshd = 0.14
196 |
197 |         f1, acc, r, p = calculate_metrics(test_entity_scores, test_entity_labels, tshd)
198 |         ent_r = f1
199 |
200 |         pred_labels = test_entity_scores > tshd
201 |         entity_predicted_dict = dict()
202 |         for labels, context_idx in zip(pred_labels, context_idx_lst):
203 |             entity_predicted_dict[context_idx] = []
204 |             for i in range(160):
205 |                 if labels[i]:
206 |                     entity_predicted_dict[context_idx].append(ner_entity_lst[i])
207 |
208 |         with open(f"{result_save_dir}/{args.data_name}_{eval_dataset.data_type}_selected_entity.pk", "wb") as outfile:
209 |             pickle.dump((test_entity_scores, test_entity_labels, context_idx_lst, entity_predicted_dict), outfile)
210 |
211 |         ####################
212 |         # Save Best Entity #
213 |         ####################
214 |
215 |         if f1 > best_metric["entity"]["f1"]:
216 |             best_metric["entity"]["f1"] = f1
217 |             best_metric["entity"]["acc"] = acc
218 |             best_metric["entity"]["recall"] = r
219 |             best_metric["entity"]["precision"] = p
220 |             best_metric["entity"]["checkpoint"] = prefix
221 |
222 |     else:
223 |         # For the KaMed dataset
224 |         # context embeds
225 |         context_embeds = torch.cat(context_embeds, dim=0)
226 |
227 |         # entity dict
228 |         _, entity_dict, _ = get_cmekg_entity_specific(args.data_name)
229 |
230 |         # target KG entity dict (MS MARCO format)
231 |         context_to_kg_entity = dict()
232 |         with open(f"./{args.data_name}_{eval_dataset.data_type}_reference_entity.txt", "r") as f:
233 |             for line in f:
234 |                 line = line.strip().split('\t')
235 |                 context_idx = line[0]
236 |                 if context_idx in context_to_kg_entity:
237 |                     pass
238 |                 else:
239 |                     context_to_kg_entity[context_idx] = set()
240 |                 context_to_kg_entity[context_idx].add(entity_dict[line[2]])
241 |
242 |         # target sub-graph
243 |         with open(f"../data/{args.data_name}_{eval_dataset.data_type}_sub_kg.pk", "rb") as f:
244 |             sub_kg = pickle.load(f)
245 |
246 |         ent_r = 0
247 |         total_num = 0
248 |         entity_ranked_dict = dict()
249 |         for i, context_idx in enumerate(check_lst):
250 |             sub_kg_idx = [entity_dict[x] for x in sub_kg[context_idx]]  # candidate entities in the sub-graph
251 |             if sub_kg_idx != []:
252 |                 sub_kg_entity_embeds = embeds[sub_kg_idx]  # select the corresponding embeddings
253 |                 score = torch.matmul(context_embeds[i:i+1], sub_kg_entity_embeds.T).squeeze()  # compute similarity scores
254 |                 rank = score.sort(descending=True).indices.tolist()  # position indices of the ranked candidates
255 |                 if isinstance(rank, int):  # if there is only one candidate entity
256 |                     rank = [rank]
257 |                 entity_ranked = [sub_kg_idx[idx] for idx in rank]  # map position indices back to entity idx
258 |                 entity_ranked_dict[context_idx] = entity_ranked[:50]  # keep the top-50 entities
259 |
260 |                 if context_idx in context_to_kg_entity:  # compute recall
261 |                     count = 0
262 |                     for entity in entity_ranked[:20]:  # evaluate the top-20 entities
263 |                         if entity in context_to_kg_entity[context_idx]:
264 |                             count += 1
265 |                     ent_r += count / len(context_to_kg_entity[context_idx])  # fraction of all target entities recalled
266 |                     total_num += 1
267 |             else:
268 |                 entity_ranked_dict[context_idx] = []
269 |         ent_r = ent_r / total_num
270 |
271 |         with open(f"{result_save_dir}/{args.data_name}_{eval_dataset.data_type}_ranked_entity.pk", "wb") as outfile:
272 |             pickle.dump(entity_ranked_dict, outfile)
273 |
274 |         ####################
275 |         # Save Best Entity #
276 |         ####################
277 |
278 |         if ent_r > best_metric["entity"]["recall"]:
279 |             best_metric["entity"]["recall"] = ent_r
280 |             best_metric["entity"]["checkpoint"] = prefix
281 |
282 |     # output best_metric in json file
283 |     with open(f"{args.log_dir}/{args.data_name}_{eval_dataset.data_type}_best_metric.json", "w") as outfile:
284 |         json.dump(best_metric, outfile, indent=4, separators=(',', ': '))
285 |
286 |     return (test_loss/len(eval_dataloader), test_act_loss/len(eval_dataloader), test_entity_loss/len(eval_dataloader),
287 |             act_f1, act_acc, act_r, act_p, ent_r)
288 |
--------------------------------------------------------------------------------
/dual_flow/get_entity_embed.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import logging
4 | import pickle
5 | from tqdm import tqdm
6 |
7 | logger = logging.getLogger(__name__)
8 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s',
9 |                     datefmt = '%d %H:%M:%S',
10 |                     level = logging.INFO)
11 |
12 | def get_entity_embeds(args, dataloader, model):
13 |     embeds = []
14 |     for batch in tqdm(dataloader, desc="Get entity embedding"):
15 |         with torch.no_grad():
16 |             batch = {k:v.to(args.device) for k, v in batch.items()}
17 |             model.eval()
18 |             outputs = model(**batch)
19 |             embeds.append(outputs)
20 |     embeds = torch.cat(embeds)
21 |     return embeds
22 |
23 | def save_embeds(embeds, idx_lst, output_dir):
24 |     with open(output_dir, "wb") as outfile:
25 |         pickle.dump((embeds.cpu().numpy(), idx_lst), outfile)
26 |
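For intuition, a minimal standalone sketch (toy shapes, random vectors) of the dot-product ranking performed in the KaMed branch of `evaluating.py` above:

```
import torch

context_embed = torch.randn(1, 768)      # one dialogue-context vector (toy)
candidate_embeds = torch.randn(5, 768)   # embeddings of 5 sub-graph candidates (toy)

scores = torch.matmul(context_embed, candidate_embeds.T).squeeze()  # (5,)
rank = scores.sort(descending=True).indices.tolist()
print(rank[:3])  # positions of the best-scoring candidates
```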
--------------------------------------------------------------------------------
/dual_flow/get_prediction_topk.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import pickle
5 | import numpy as np
6 | import argparse
7 | import math
8 | from typing import List
9 |
10 | act_id = {"INQUIRE": 0, "DIAGNOSIS": 1, "TREATMENT": 2, "TEST": 3, "PRECAUTION": 4, "INFORM": 5, "CHITCHAT": 6}
11 | act_dict = {0: "INQUIRE", 1: "DIAGNOSIS", 2: "TREATMENT", 3: "TEST", 4: "PRECAUTION", 5: "INFORM", 6: "CHITCHAT"}
12 |
13 | def calculate_metrics(scores, labels, threshold):
14 |     ep = 1e-8
15 |     TP = ((scores > threshold) & (labels == 1)).sum()
16 |     TN = ((scores < threshold) & (labels == 0)).sum()
17 |     FN = ((scores < threshold) & (labels == 1)).sum()
18 |     FP = ((scores > threshold) & (labels == 0)).sum()
19 |
20 |     p = TP / (TP + FP + ep)
21 |     r = TP / (TP + FN + ep)
22 |     f1 = 2 * r * p / (r + p + ep)
23 |     acc = (TP + TN) / (TP + TN + FP + FN + ep)
24 |     return f1, acc, r, p
25 |
26 | def get_acts(args):
27 |
28 |     for _, checkpoints, _ in os.walk(args.saved_model_dir):
29 |         break
30 |
31 |     # original thresholds
32 |     tshd = [0.4, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5]
33 |
34 |     # rank top-k metrics in valid set
35 |     metric_dict = dict()
36 |     for ckpt in checkpoints:
37 |         with open(f"{args.saved_model_dir}/{ckpt}/{args.data_name}_valid_predicted_act.pk", "rb") as infile:
38 |             test_scores, test_labels, test_context_idx_lst, test_predicted_dict = pickle.load(infile)
39 |
40 |         for i in range(7):
41 |             f1, acc, r, p = calculate_metrics(test_scores[:,i], test_labels[:,i], threshold=tshd[i])
42 |             if i not in metric_dict:
43 |                 metric_dict[i] = []
44 |             metric_dict[i].append((ckpt, f1))
45 |
46 |     for i in range(7):
47 |         metric_dict[i] = sorted(metric_dict[i], key=lambda x: x[1], reverse=True)
48 |         for topk in range(args.top_k):
49 |             with open(f"{args.saved_model_dir}/{metric_dict[i][topk][0]}/{args.data_name}_{args.data_type}_predicted_act.pk", "rb") as infile:
50 |                 test_scores, test_labels, test_context_idx_lst, test_predicted_dict = pickle.load(infile)
51 |             if 'best_scores' not in locals():
52 |                 best_scores = np.zeros_like(test_scores)
53 |             best_scores[:, i] = np.maximum(best_scores[:, i], test_scores[:, i])
54 |
55 |     # modify threshold since the maximum may raise the overall recall rate
56 |     tshd = [0.4, 0.2, 0.25, 0.2, 0.25, 0.45, 0.5]
57 |
58 |     tshd = np.ones_like(best_scores) * np.array(tshd)
59 |     pred_labels = best_scores > tshd
60 |     predicted_dict = dict()
61 |     for labels, context_idx in zip(pred_labels, test_context_idx_lst):
62 |         predicted_dict[context_idx] = []
63 |         for i in range(7):
64 |             if labels[i]:
65 |                 predicted_dict[context_idx].append(act_dict[i])
66 |
67 |     with open(f"{args.saved_result_dir}/{args.data_name}_{args.data_type}_predicted_act.pk", "wb") as outfile:
68 |         pickle.dump((best_scores, test_labels, test_context_idx_lst, predicted_dict), outfile)
69 |
70 | def get_ner_entity(data_name):
71 |     ner_entity_lst = []
72 |     ner_entity_dict = {}
73 |     with open(f"../data/{data_name}_ner_entity/{data_name}_entity.txt", "r") as f:
74 |         for i, line in enumerate(f):
75 |             ner_entity_lst.append(line.strip())
76 |             ner_entity_dict[line.strip()] = i
77 |     return ner_entity_lst, ner_entity_dict
78 |
79 | def get_ranked_entitis(args, best_metric):
80 |
81 |     with open(f"{args.saved_model_dir}/{best_metric['entity']['checkpoint']}/{args.data_name}_{args.data_type}_ranked_entity.pk", "rb") as infile:
82 |         entity_ranked_dict = pickle.load(infile)
83 |
84 |     with open(f"{args.saved_result_dir}/{args.data_name}_{args.data_type}_ranked_entity.pk", "wb") as outfile:
85 |         pickle.dump(entity_ranked_dict, outfile)
86 |
87 | def get_selected_entitis(args):
88 |
89 |     ner_entity_lst, ner_entity_dict = get_ner_entity(args.data_name)
90 |     for _, checkpoints, _ in os.walk(args.saved_model_dir):
91 |         break
92 |
93 |     # original threshold
94 |     tshd = 0.14
95 |
96 |     # rank top-k metrics in valid set
97 |     metric_lst = list()
98 |     for ckpt in checkpoints:
99 |         with open(f"{args.saved_model_dir}/{ckpt}/{args.data_name}_valid_selected_entity.pk", "rb") as infile:
100 |             test_scores, test_labels, test_context_idx_lst, test_predicted_dict = pickle.load(infile)
101 |
102 |         f1, acc, r, p = calculate_metrics(test_scores, test_labels, threshold=tshd)
103 |         metric_lst.append((ckpt, f1))
104 |
105 |     metric_lst = sorted(metric_lst, key=lambda x: x[1], reverse=True)
106 |     for topk in range(args.top_k):
107 |         with open(f"{args.saved_model_dir}/{metric_lst[topk][0]}/{args.data_name}_{args.data_type}_selected_entity.pk", "rb") as infile:
108 |             test_scores, test_labels, test_context_idx_lst, test_predicted_dict = pickle.load(infile)
109 |         if 'best_scores' not in locals():
110 |             best_scores = np.zeros_like(test_scores)
111 |         best_scores = np.maximum(best_scores, test_scores)
112 |
113 |     # modify threshold since the maximum may raise the overall recall rate
114 |     tshd = 0.15
115 |
116 |     pred_labels = best_scores > tshd
117 |     predicted_dict = dict()
118 |     for labels, context_idx in zip(pred_labels, test_context_idx_lst):
119 |         predicted_dict[context_idx] = []
120 |         for i in range(160):
121 |             if labels[i]:
122 |                 predicted_dict[context_idx].append(ner_entity_lst[i])
123 |
124 |     with open(f"{args.saved_result_dir}/{args.data_name}_{args.data_type}_selected_entity.pk", "wb") as outfile:
125 |         pickle.dump((best_scores, test_labels, test_context_idx_lst, predicted_dict), outfile)
126 |
127 | if __name__ == "__main__":
128 |     parser = argparse.ArgumentParser()
129 |     parser.add_argument("--data_name", type=str, default="kamed", help="select from [kamed, meddg]")
130 |     parser.add_argument("--data_type", type=str, default="train", help="select from [train, test, valid]")
131 |     parser.add_argument("--log_dir", type=str, default="./train/log/demo")
132 |     parser.add_argument("--saved_model_dir", type=str, default="./train/models")
133 |     parser.add_argument("--saved_result_dir", type=str, default="./train/results")
134 |     parser.add_argument("--top_k", type=int, default=3)
135 |     parser.add_argument("--get_acts", action='store_true')
136 |     parser.add_argument("--get_ranked_entitis", action='store_true')
137 |     parser.add_argument("--get_selected_entitis", action='store_true')
138 |     args = parser.parse_args()
139 |
140 |     if not os.path.exists(args.saved_result_dir):
141 |         os.makedirs(args.saved_result_dir)
142 |
143 |     if args.get_acts:
144 |         get_acts(args)
145 |     if args.get_ranked_entitis:
146 |         # load best metric based on valid set
147 |         with open(f"{args.log_dir}/{args.data_name}_valid_best_metric.json", "r") as infile:
148 |             best_metric = json.load(infile)
149 |         get_ranked_entitis(args, best_metric)
150 |     if args.get_selected_entitis:
151 |         get_selected_entitis(args)
152 |
--------------------------------------------------------------------------------
/dual_flow/get_prediction_topk.sh:
--------------------------------------------------------------------------------
1 | output_dir="./train"
2 | train_name="demo"
3 |
4 | # For MedDG: --get_selected_entitis gets the selected entities (MedDG only)
5 | python get_prediction_topk.py \
6 |     --data_name meddg \
7 |     --data_type test \
8 |     --log_dir "${output_dir}/log/${train_name}" \
9 |     --saved_model_dir "${output_dir}/models/${train_name}" \
10 |     --saved_result_dir "${output_dir}/results/${train_name}" \
11 |     --top_k 3 \
12 |     --get_acts \
13 |     --get_selected_entitis
14 |
15 | # For KaMed: --get_ranked_entitis gets the ranked entities (KaMed only)
16 | # python get_prediction_topk.py \
17 | #     --data_name kamed \
18 | #     --data_type test \
19 | #     --log_dir "${output_dir}/log/${train_name}" \
20 | #     --saved_model_dir "${output_dir}/models/${train_name}" \
21 | #     --saved_result_dir "${output_dir}/results/${train_name}" \
22 | #     --top_k 3 \
23 | #     --get_acts \
24 | #     --get_ranked_entitis
25 |
--------------------------------------------------------------------------------
/dual_flow/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle
3 | import torch
4 | import logging
5 | from tqdm import tqdm
6 | from parsing import run_parse_args
7 | from transformers import BertTokenizer, BartForConditionalGeneration
8 | from torch.utils.data import RandomSampler, SequentialSampler
9 |
10 | from model import ActEntityModel
11 | 
from dataset import construct_data 12 | from utils import set_seed, load_entity_embed, get_entity_matrix 13 | from training import train 14 | from evaluating import evaluate 15 | 16 | logger = logging.getLogger(__name__) 17 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 18 | datefmt = '%d %H:%M:%S', 19 | level = logging.INFO) 20 | 21 | act_dict = {0: "INQUIRE", 1: "DIAGNOSIS", 2: "TREATMENT", 3: "TEST", 4: "PRECAUTION", 5: "INFORM", 6: "CHITCHAT"} 22 | 23 | def output_reference_entity(args, dataset, data_type): 24 | with open(f"./{args.data_name}_{data_type}_reference_entity.txt", "w") as outfile: 25 | for sample_raw in dataset.samples_context_raw: 26 | for key in sample_raw["target_kg_entity"]: 27 | outfile.write(sample_raw["idx"] + "_" + sample_raw["turn_idx"] + "\t0\t" + str(key[0]) + "\t1\n") 28 | 29 | def main(): 30 | args = run_parse_args() 31 | 32 | # Setup CUDA, GPU 33 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 34 | args.n_gpu = torch.cuda.device_count() 35 | args.device = device 36 | 37 | # Setup logging 38 | logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu) 39 | 40 | # Set seed 41 | set_seed(args) 42 | 43 | # Tokenizer 44 | tokenizer = BertTokenizer.from_pretrained(args.base_model_path) 45 | 46 | # Entity Embeddings 47 | embeds = None 48 | 49 | # Model 50 | if args.mode == "train": 51 | model_path = args.base_model_path 52 | else: 53 | model_path = args.eval_model_path 54 | embeds = load_entity_embed(model_path, args.data_name) 55 | embeds = embeds.to(args.device) 56 | model = ActEntityModel.from_pretrained(model_path) 57 | model.act_weight = args.act_weight 58 | model.entity_weight = args.entity_weight 59 | model.to(args.device) 60 | 61 | logger.info("Training/Evaluation parameters %s", args) 62 | 63 | # Train/Evaluate 64 | if args.mode == "train": 65 | train_dataset, train_dataloader, args.train_batch_size = construct_data(args, args.data_type, "train", args.per_gpu_train_batch_size, tokenizer, RandomSampler) 66 | val_dataset, val_dataloader, args.eval_batch_size = construct_data(args, "valid", "evaluate", args.per_gpu_eval_batch_size, tokenizer, SequentialSampler) 67 | test_dataset, test_dataloader, args.eval_batch_size = construct_data(args, "test", "evaluate", args.per_gpu_eval_batch_size, tokenizer, SequentialSampler) 68 | output_reference_entity(args, val_dataset, "valid") 69 | output_reference_entity(args, test_dataset, "test") 70 | train(args, model, embeds, train_dataset, train_dataloader, val_dataset, val_dataloader, test_dataset, test_dataloader, tokenizer) 71 | elif args.mode == "evaluate": 72 | prefix = args.eval_model_path.split("/")[-1] 73 | test_dataset, test_dataloader, args.eval_batch_size = construct_data(args, args.data_type, "evaluate", args.per_gpu_eval_batch_size, tokenizer, SequentialSampler) 74 | output_reference_entity(args, test_dataset, args.data_type) 75 | entity_matrix = get_entity_matrix(args.data_name, test_dataset.entity_lst) 76 | (_, _, _, act_f1, act_accuray, act_recall, act_precision, ent_recall) = evaluate(args, model, embeds, entity_matrix, test_dataset, test_dataloader, prefix) 77 | for i in range(7): 78 | print("{}: F1: {}, Acc: {}, Recall: {} Precision: {}".format(act_dict[i], round(act_f1[i], 4), round(act_accuray[i], 4), round(act_recall[i], 4), round(act_precision[i], 4))) 79 | print() 80 | print("{}: Recall: {}".format("Entity", round(ent_recall, 4))) 81 | 82 | if __name__ == "__main__": 83 | main() 84 | 
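Before the model code, a self-contained sketch of the scaled dot-product pattern that the `CrossAttention` module below builds on; for brevity this toy version omits the learned query/key/value projections and the residual LayerNorm:

```
import math
import torch

D = 8
q = torch.randn(1, 1, D)    # pooled context state (toy)
kv = torch.randn(1, 4, D)   # act or entity states for one turn (toy)

scores = torch.matmul(q, kv.transpose(-1, -2)) / math.sqrt(D)  # (1, 1, 4)
probs = torch.softmax(scores, dim=-1)
out = torch.matmul(probs, kv)                                  # (1, 1, D)
print(out.shape)  # torch.Size([1, 1, 8])
```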
-------------------------------------------------------------------------------- /dual_flow/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers import BertForPreTraining 7 | from typing import List, Optional 8 | from GAT import GAT, LayerType 9 | 10 | class CrossAttention(nn.Module): 11 | def __init__(self, hidden_size): 12 | super(CrossAttention, self).__init__() 13 | self.hidden_size = hidden_size 14 | self.query = nn.Linear(hidden_size, hidden_size) 15 | self.key = nn.Linear(hidden_size, hidden_size) 16 | self.value = nn.Linear(hidden_size, hidden_size) 17 | self.norm = nn.LayerNorm(hidden_size) 18 | 19 | def forward(self, hidden_states, key_value_states, attention_mask=None, add_original_input=False): 20 | query_states = self.query(hidden_states) 21 | key_states = self.key(key_value_states) 22 | value_states = self.value(key_value_states) 23 | 24 | attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) 25 | attention_scores = attention_scores / math.sqrt(self.hidden_size) 26 | 27 | if attention_mask is not None: 28 | converted_attention_mask = (1.0 - attention_mask) * torch.finfo(attention_scores.dtype).min 29 | attention_scores = attention_scores + converted_attention_mask 30 | 31 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 32 | attention_output = torch.matmul(attention_probs, value_states) 33 | 34 | if add_original_input: 35 | hidden_states = self.norm(hidden_states + attention_output) 36 | return hidden_states 37 | else: 38 | return attention_output 39 | 40 | class ActEntityModel(BertForPreTraining): 41 | def __init__(self, config): 42 | super(ActEntityModel, self).__init__(config) 43 | self.hidden_size = config.hidden_size 44 | self.act_embeds = nn.Parameter(torch.empty((7, config.hidden_size))) 45 | self.none_entity_embeds = nn.Parameter(torch.empty((1, config.hidden_size))) 46 | self.none_act_embeds = nn.Parameter(torch.empty((1, config.hidden_size))) 47 | self.context_act_attn = CrossAttention(config.hidden_size) 48 | self.context_entity_attn = CrossAttention(config.hidden_size) 49 | self.act_entity_attn = CrossAttention(config.hidden_size) 50 | self.entity_act_attn = CrossAttention(config.hidden_size) 51 | 52 | gru_layer = 2 53 | self.gru_act = nn.GRU(config.hidden_size * 3, config.hidden_size, gru_layer, batch_first=True) 54 | self.gru_entity = nn.GRU(config.hidden_size * 3, config.hidden_size, gru_layer, batch_first=True) 55 | self.w = nn.Linear(config.hidden_size * gru_layer, config.hidden_size) 56 | 57 | self.act_mlp = nn.Linear(config.hidden_size, 7) 58 | self.entity_mlp = nn.Linear(config.hidden_size, config.n_entity) 59 | self.sigmoid = nn.Sigmoid() 60 | self.bceloss = nn.BCELoss() 61 | self.act_weight = 1 62 | self.entity_weight = 1 63 | 64 | # Initialize weights and apply final processing 65 | self.post_init() 66 | 67 | # Initialize GAT 68 | self.entity_gat = GAT(num_of_layers=1, num_heads_per_layer=[4], num_features_per_layer=[config.hidden_size//4], add_skip_connection=True, dropout=0.1, layer_type=LayerType.IMP2) 69 | 70 | def pooling(self, x): 71 | return torch.mean(x, dim=-2, keepdim=True) 72 | 73 | def forward( 74 | self, 75 | input_ids: torch.LongTensor = None, 76 | attention_mask: Optional[torch.Tensor] = None, 77 | context_idx: Optional[List] = None, 78 | entity_ids: Optional[torch.Tensor] = None, 79 | entity_mask: Optional[torch.Tensor] = None, 80 
| entity_embeds: Optional[torch.Tensor] = None, 81 | entity_matrix: Optional[torch.Tensor] = None, 82 | act_turn_idx: Optional[List] = None, 83 | entity_turn_idx: Optional[List] = None, 84 | batch_entity_turn_idx: Optional[List] = None, 85 | act_labels: Optional[torch.Tensor] = None, 86 | entity_labels: Optional[torch.Tensor] = None, 87 | target_kg_entity_idx: Optional[List] = None, 88 | neg_kg_entity_idx: Optional[List] = None, 89 | ): 90 | 91 | if entity_ids != None: # only to calculate entity representation 92 | encoder_outputs = self.bert( 93 | input_ids=entity_ids, 94 | attention_mask=entity_mask) 95 | entity_state = encoder_outputs.last_hidden_state #M, L, D 96 | entity_state = (entity_state * entity_mask[:,:,None]).sum(dim=1) / entity_mask.sum(dim=1)[:,None] 97 | return entity_state 98 | 99 | encoder_outputs = self.bert( 100 | input_ids=input_ids, 101 | attention_mask=attention_mask) 102 | context_state = encoder_outputs.last_hidden_state #B, L, D 103 | batch_size, _, _ = context_state.shape 104 | 105 | # get all entity embeds in batch (idx based on all entity list) 106 | batch_entity_embeds = entity_embeds[batch_entity_turn_idx] 107 | # GAT for entity embeds if training 108 | # no need to update entity embeds if not training, since embeds have been updated by GAT before evaluation loop. 109 | if self.training: 110 | batch_entity_matrix = entity_matrix[batch_entity_turn_idx,:][:,batch_entity_turn_idx] 111 | batch_entity_matrix = batch_entity_matrix.to(entity_embeds.device) 112 | batch_entity_embeds, _ = self.entity_gat((batch_entity_embeds, batch_entity_matrix)) 113 | 114 | act_loss = 0 115 | entity_loss = 0 116 | act_scores_batch = [] 117 | entity_scores_batch = [] 118 | pos_embeds = [] 119 | neg_embeds = [] 120 | entity_hiddens = [] 121 | entity_hiddens_batch = [] 122 | for i in range(batch_size): 123 | 124 | final_act_state_lst = [] 125 | final_entity_state_lst = [] 126 | act_state_turn_lst = [] 127 | act_state_turn_one_lst = [] 128 | entity_state_turn_lst = [] 129 | entity_state_turn_one_lst = [] 130 | n_turn = len(context_idx[i]) 131 | for k in range(n_turn): 132 | 133 | if entity_turn_idx[i][k] != []: 134 | turn_entity_embeds = batch_entity_embeds[entity_turn_idx[i][k]] # (idx based on batch entity list) 135 | 136 | entity_state_turn_lst.append(turn_entity_embeds) 137 | entity_state_turn_one_lst.append(turn_entity_embeds) 138 | else: 139 | entity_state_turn_lst.append(self.none_entity_embeds) 140 | entity_state_turn_one_lst.append(self.none_entity_embeds) 141 | 142 | if act_turn_idx[i][k] != []: 143 | act_state_turn_lst.append(self.act_embeds[act_turn_idx[i][k]]) 144 | act_state_turn_one_lst.append(self.act_embeds[act_turn_idx[i][k]]) 145 | else: 146 | if act_state_turn_lst == []: #there is no act in the first turn 147 | act_state_turn_lst.append(self.none_act_embeds) 148 | act_state_turn_one_lst.append(self.none_act_embeds) 149 | act_state_turn = torch.cat(act_state_turn_lst, dim=0) #act in previous turns 150 | act_state_turn_one = torch.cat(act_state_turn_one_lst, dim=0) #act in current turn 151 | entity_state_turn = torch.cat(entity_state_turn_lst, dim=0) #entity in previous turn 152 | entity_state_turn_one = torch.cat(entity_state_turn_one_lst, dim=0) #entity in current turn 153 | context_state_turn = context_state[i][:context_idx[i][k][1]] #context 154 | 155 | context_act_state = self.context_act_attn(self.pooling(context_state_turn), act_state_turn_one, add_original_input=True) 156 | context_entity_state = self.context_entity_attn(self.pooling(context_state_turn), 
entity_state_turn_one, add_original_input=True) 157 | 158 | entity_act_state = self.entity_act_attn(self.pooling(entity_state_turn_one), act_state_turn, add_original_input=True) 159 | act_entity_state = self.act_entity_attn(self.pooling(act_state_turn_one), entity_state_turn, add_original_input=True) 160 | 161 | final_act_state_turn = torch.cat([self.pooling(act_state_turn_one), context_act_state, act_entity_state], dim=-1) 162 | final_entity_state_turn = torch.cat([self.pooling(entity_state_turn_one), context_entity_state, entity_act_state], dim=-1) 163 | 164 | final_act_state_lst.append(final_act_state_turn) 165 | final_entity_state_lst.append(final_entity_state_turn) 166 | 167 | entity_state_turn_one_lst = [] #reset 168 | act_state_turn_one_lst = [] #reset 169 | 170 | final_act_state = torch.stack(final_act_state_lst, dim=1) 171 | final_entity_state = torch.stack(final_entity_state_lst, dim=1) 172 | _, act_hidden = self.gru_act(final_act_state) 173 | _, entity_hidden = self.gru_entity(final_entity_state) 174 | act_hidden = self.w(act_hidden.reshape(1, -1)) 175 | entity_hidden = self.w(entity_hidden.reshape(1, -1)) 176 | 177 | entity_hiddens_batch.append(entity_hidden) 178 | if target_kg_entity_idx[i] != []: 179 | entity_hiddens.append(entity_hidden.repeat(len(target_kg_entity_idx[i]), 1)) 180 | pos_embeds.append(entity_embeds[target_kg_entity_idx[i]]) 181 | neg_embeds.append(entity_embeds[neg_kg_entity_idx[i]]) 182 | 183 | # act loss 184 | act_scores = self.sigmoid(self.act_mlp(act_hidden)) 185 | act_loss_tmp = self.bceloss(act_scores, act_labels[i:i+1]) 186 | act_loss += act_loss_tmp 187 | act_scores_batch.append(act_scores) 188 | 189 | # entity loss (for meddg 160) 190 | if entity_labels != None: 191 | entity_scores = self.sigmoid(self.entity_mlp(entity_hidden)) 192 | entity_loss_tmp = self.bceloss(entity_scores, entity_labels[i:i+1]) 193 | entity_loss += entity_loss_tmp 194 | entity_scores_batch.append(entity_scores) 195 | 196 | entity_loss = entity_loss / batch_size 197 | act_loss = act_loss / batch_size 198 | 199 | # entity loss (for ranking) 200 | if entity_hiddens != [] and entity_labels == None: 201 | # cosine similarity 202 | entity_hiddens = torch.cat(entity_hiddens, dim=0) 203 | pos_embeds = torch.cat(pos_embeds, dim=0) 204 | neg_embeds = torch.cat(neg_embeds, dim=0) 205 | extra_neg = torch.matmul(entity_hiddens, neg_embeds.T) 206 | 207 | logit_matrix = torch.cat([(entity_hiddens * pos_embeds).sum(-1).unsqueeze(1), extra_neg], dim=1) # [B, 1 + B] 208 | lsm = F.log_softmax(logit_matrix, dim=1) 209 | entity_loss = (-1.0 * lsm[:, 0]).mean() 210 | 211 | loss = act_loss * self.act_weight + entity_loss * self.entity_weight 212 | if entity_scores_batch == []: 213 | return (loss, act_loss, entity_loss, 214 | (torch.cat(act_scores_batch, dim=0), act_labels, None, None), 215 | torch.cat(entity_hiddens_batch, dim=0)) 216 | else: 217 | return (loss, act_loss, entity_loss, 218 | (torch.cat(act_scores_batch, dim=0), act_labels, torch.cat(entity_scores_batch, dim=0), entity_labels), 219 | None) 220 | -------------------------------------------------------------------------------- /dual_flow/parsing.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | 4 | def run_parse_args(): 5 | parser = argparse.ArgumentParser() 6 | 7 | ## Required parameters 8 | parser.add_argument("--mode", type=str, default="train", help="select from [train, evaluate]") 9 | parser.add_argument("--output_dir", type=str, default="./train") 10 | 
parser.add_argument("--base_model_path", type=str, default="../medbert-kd-chinese") 11 | parser.add_argument("--data_name", type=str, default="kamed", help="select from [kamed, meddg]") 12 | parser.add_argument("--data_type", type=str, default="train", help="select from [train, test, valid]") 13 | parser.add_argument("--train_name", type=str, default=None, help="for recording the objective of this training") 14 | parser.add_argument("--act_weight", type=float, default=1, help="for balancing loss") 15 | parser.add_argument("--entity_weight", type=float, default=1, help="for balancing loss") 16 | parser.add_argument("--result_save_dir", type=str, default=None, help="the directory of the prediction results") 17 | parser.add_argument("--for_meddg_160", action='store_true') 18 | 19 | ## General parameters 20 | parser.add_argument("--eval_model_path", type=str, default=None) 21 | parser.add_argument("--per_gpu_eval_batch_size", default=16, type=int) 22 | parser.add_argument("--per_gpu_train_batch_size", default=16, type=int) 23 | parser.add_argument("--gradient_accumulation_steps", type=int, default=2) 24 | parser.add_argument("--entity_update_steps", type=int, default=1) 25 | 26 | parser.add_argument("--no_cuda", action='store_true') 27 | parser.add_argument('--seed', type=int, default=42) 28 | 29 | parser.add_argument("--evaluate_during_training", action="store_true") 30 | parser.add_argument("--logging_steps", type=int, default=100) 31 | parser.add_argument("--data_num_workers", default=0, type=int) 32 | 33 | parser.add_argument("--lr", default=1e-5, type=float) 34 | parser.add_argument("--weight_decay", default=0.01, type=float) 35 | parser.add_argument("--warmup_steps", default=1000, type=int) 36 | parser.add_argument("--eval_steps", default=5000, type=int) 37 | parser.add_argument("--adam_epsilon", default=1e-8, type=float) 38 | parser.add_argument("--max_grad_norm", default=1.0, type=float) 39 | parser.add_argument("--num_train_epochs", default=10, type=int) 40 | 41 | args = parser.parse_args() 42 | 43 | if args.train_name: 44 | args.log_dir = f"{args.output_dir}/log/{args.train_name}" 45 | args.model_save_dir = f"{args.output_dir}/models/{args.train_name}" 46 | else: 47 | time_stamp = time.strftime("%b-%d_%H:%M:%S", time.localtime()) 48 | args.log_dir = f"{args.output_dir}/log/{time_stamp}" 49 | args.model_save_dir = f"{args.output_dir}/models/{time_stamp}" 50 | return args 51 | -------------------------------------------------------------------------------- /dual_flow/train_kamed.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode train \ 3 | --data_name kamed \ 4 | --data_type train \ #choose the division of the dataset 5 | --evaluate_during_training \ 6 | --eval_steps 2500 \ 7 | --per_gpu_train_batch_size 12 \ 8 | --entity_update_steps 1 \ #update entity only at the first step 9 | --output_dir ./train \ #dir to save log and fine-tuned models 10 | --lr 3e-5 \ 11 | --num_train_epochs 5 \ 12 | --act_weight 1 \ 13 | --entity_weight 0.1 \ 14 | --train_name demo_kamed \ #training name 15 | -------------------------------------------------------------------------------- /dual_flow/train_meddg.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode train \ 3 | --data_name meddg \ 4 | --data_type train \ #choose the division of the dataset 5 | --evaluate_during_training \ 6 | --eval_steps 2500 \ 7 | --per_gpu_train_batch_size 12 \ 8 | --entity_update_steps 10000 \ 
--------------------------------------------------------------------------------
/dual_flow/train_meddg.sh:
--------------------------------------------------------------------------------
1 | # --data_type: choose the division of the dataset
2 | # --entity_update_steps 10000: update entity embeddings every 100 steps during the first 10k steps
3 | # --output_dir: dir to save the log and fine-tuned models
4 | # --for_meddg_160: train the entity predictor
5 | # --train_name: name of this training run
6 | python main.py \
7 |     --mode train \
8 |     --data_name meddg \
9 |     --data_type train \
10 |     --evaluate_during_training \
11 |     --eval_steps 2500 \
12 |     --per_gpu_train_batch_size 12 \
13 |     --entity_update_steps 10000 \
14 |     --output_dir ./train \
15 |     --lr 3e-5 \
16 |     --num_train_epochs 5 \
17 |     --act_weight 0.05 \
18 |     --entity_weight 1 \
19 |     --for_meddg_160 \
20 |     --train_name demo_meddg
21 |
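Both scripts rely on `--gradient_accumulation_steps 2` from the parser default; a runnable toy version of the accumulation pattern used in `training.py` below (toy model and data, LR scheduler omitted):

```
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
data = [torch.randn(2, 4) for _ in range(4)]
accum = 2  # mirrors --gradient_accumulation_steps 2

for step, x in enumerate(data):
    loss = model(x).pow(2).mean() / accum  # scale the loss before backward
    loss.backward()                        # gradients accumulate across micro-batches
    if (step + 1) % accum == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```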
--------------------------------------------------------------------------------
/dual_flow/training.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle
3 | import torch
4 | import torch.nn as nn
5 | import logging
6 | from tqdm import tqdm, trange
7 | from torch.utils.tensorboard import SummaryWriter
8 | from torch.optim import AdamW
9 | from transformers import get_linear_schedule_with_warmup
10 | from transformers.trainer_pt_utils import get_parameter_names
11 | from torch.utils.data import SequentialSampler
12 |
13 | from evaluating import evaluate
14 | from utils import save_model, set_seed, get_entity_matrix
15 | from dataset_entity import construct_data
16 | from get_entity_embed import get_entity_embeds
17 |
18 | logger = logging.getLogger(__name__)
19 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s',
20 |                     datefmt = '%d %H:%M:%S',
21 |                     level = logging.INFO)
22 |
23 | act_dict = {0: "INQUIRE", 1: "DIAGNOSIS", 2: "TREATMENT", 3: "TEST", 4: "PRECAUTION", 5: "INFORM", 6: "CHITCHAT"}
24 |
25 | def get_optimizer(args, model):
26 |     # Parameters with weight decay
27 |     decay_parameters = get_parameter_names(model, [nn.LayerNorm])
28 |     decay_parameters = [name for name in decay_parameters if "bias" not in name]
29 |
30 |     optimizer_grouped_parameters = [
31 |         {
32 |             "params": [p for n, p in model.named_parameters() if n in decay_parameters],
33 |             "lr": args.lr,
34 |             "weight_decay": args.weight_decay,
35 |         },
36 |         {
37 |             "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
38 |             "lr": args.lr,
39 |             "weight_decay": 0.0,
40 |         },
41 |     ]
42 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
43 |     return optimizer
44 |
45 | def run_eval(args, model, embeds, entity_matrix, dataset, dataloader, tb_writer, global_step, prefix, data_type):
46 |     (test_loss, test_act_loss, test_entity_loss,
47 |      act_f1, act_accuray, act_recall, act_precision, ent_recall) = evaluate(args, model, embeds, entity_matrix, dataset, dataloader, prefix)
48 |     tb_writer.add_scalar(f'{data_type}/loss', test_loss, global_step)
49 |     tb_writer.add_scalar(f'{data_type}/act_loss', test_act_loss, global_step)
50 |     tb_writer.add_scalar(f'{data_type}/entity_loss', test_entity_loss, global_step)
51 |     for i in range(7):
52 |         tb_writer.add_scalar(f'{data_type}/{act_dict[i]}_f1', act_f1[i], global_step)
53 |         tb_writer.add_scalar(f'{data_type}/{act_dict[i]}_accuray', act_accuray[i], global_step)
54 |         tb_writer.add_scalar(f'{data_type}/{act_dict[i]}_recall', act_recall[i], global_step)
55 |         tb_writer.add_scalar(f'{data_type}/{act_dict[i]}_precision', act_precision[i], global_step)
56 |     tb_writer.add_scalar(f'{data_type}/entity_recall', ent_recall, global_step)
57 |
58 | def train(args, model, embeds, train_dataset, train_dataloader, val_dataset, val_dataloader, test_dataset, test_dataloader, tokenizer):
59 |     # Train the model
60 |     tb_writer = SummaryWriter(args.log_dir)
61 |
62 |     # Total steps
63 |     t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
64 |
65 |     # Get entity matrix
66 |     entity_matrix = get_entity_matrix(args.data_name, train_dataset.entity_lst)
67 |
68 |     # Prepare optimizer and schedule (linear warmup and decay)
69 |     optimizer = get_optimizer(args, model)
70 |     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
71 |
72 |     # multi-gpu training (should be after apex fp16 initialization)
73 |     if args.n_gpu > 1:
74 |         model = torch.nn.DataParallel(model)
75 |
76 |     # Train!
77 |     logger.info("***** Running Training *****")
78 |     logger.info("  Num examples = %d", len(train_dataset))
79 |     logger.info("  Num Epochs = %d", args.num_train_epochs)
80 |     logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
81 |     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
82 |                 args.train_batch_size * args.gradient_accumulation_steps)
83 |     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
84 |     logger.info("  Total optimization steps = %d", t_total)
85 |
86 |     global_step = 0
87 |     tr_loss, logging_loss = 0.0, 0.0
88 |     model.zero_grad()
89 |     train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
90 |     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
91 |     _, entity_dataloader, _ = construct_data(args, args.per_gpu_eval_batch_size, tokenizer, SequentialSampler)
92 |     for epoch_idx in train_iterator:
93 |         epoch_iterator = tqdm(train_dataloader, desc="Iteration")
94 |         for step, (batch, idx, turn_idx) in enumerate(epoch_iterator):
95 |
96 |             # update entity embeddings every 100 steps and then fix them;
97 |             # GPU memory is not enough to update them at every step
98 |             if step % 100 == 0 and global_step < args.entity_update_steps:
99 |                 embeds = get_entity_embeds(args, entity_dataloader, model)
100 |                 torch.cuda.empty_cache()
101 |
102 |             context_idx = batch["context_idx"]
103 |             act_turn_idx = batch["act_turn_idx"]
104 |             entity_turn_idx = batch["entity_turn_idx"]
105 |             batch_entity_turn_idx = batch["batch_entity_turn_idx"]
106 |             target_kg_entity_idx = batch["target_kg_entity_idx"]
107 |             neg_kg_entity_idx = batch["neg_kg_entity_idx"]
108 |             batch = {k:v.to(args.device) for k, v in batch.items() if "idx" not in k and "entity_ids" not in k and "entity_mask" not in k}
109 |
110 |             if not args.for_meddg_160:
111 |                 batch["entity_labels"] = None
112 |
113 |             model.train()
114 |             outputs = model(**batch,
115 |                             entity_embeds=embeds.detach(),
116 |                             entity_matrix=entity_matrix.detach(),
117 |                             context_idx=context_idx,
118 |                             act_turn_idx=act_turn_idx,
119 |                             entity_turn_idx=entity_turn_idx,
120 |                             batch_entity_turn_idx=batch_entity_turn_idx,
121 |                             target_kg_entity_idx=target_kg_entity_idx,
122 |                             neg_kg_entity_idx=neg_kg_entity_idx,
123 |                             )
124 |
125 |             loss = outputs[0]
126 |             if args.n_gpu > 1:
127 |                 loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
128 |             if args.gradient_accumulation_steps > 1:
129 |                 loss = loss / args.gradient_accumulation_steps
130 |             loss.backward()
131 |             tr_loss += loss.item()
132 |
133 |             if (step + 1) % args.gradient_accumulation_steps == 0:
134 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
135 |                 optimizer.step()
136 |                 scheduler.step()  # Update learning rate schedule
137 |                 model.zero_grad()
138 |                 optimizer.zero_grad()
139 |                 global_step += 1
140 |
141 |                 if args.logging_steps > 0 and global_step % args.logging_steps == 0:
142 |                     cur_loss = (tr_loss - logging_loss)/args.logging_steps
tb_writer.add_scalar("train/lr", scheduler.get_lr()[0], global_step) 156 | tb_writer.add_scalar("train/loss", cur_loss, global_step) 157 | tb_writer.add_scalar("train/act_loss", outputs[1], global_step) 158 | tb_writer.add_scalar("train/entity_loss", outputs[2], global_step) 159 | logging_loss = tr_loss 160 | 161 | # Save model checkpoint 162 | if global_step % args.eval_steps == 0: 163 | save_model(model, embeds, args.model_save_dir, "step-{}".format(global_step), args) 164 | 165 | if args.evaluate_during_training and global_step % args.eval_steps == 0: 166 | run_eval(args, model, embeds, entity_matrix, val_dataset, val_dataloader, tb_writer, global_step, prefix="step-{}".format(global_step), data_type="valid") 167 | run_eval(args, model, embeds, entity_matrix, test_dataset, test_dataloader, tb_writer, global_step, prefix="step-{}".format(global_step), data_type="test") 168 | 169 | # Save model checkpoint 170 | save_model(model, embeds, args.model_save_dir, "epoch-{}".format(epoch_idx+1), args) 171 | 172 | if args.evaluate_during_training: 173 | run_eval(args, model, embeds, entity_matrix, val_dataset, val_dataloader, tb_writer, global_step, prefix="epoch-{}".format(epoch_idx+1), data_type="valid") 174 | run_eval(args, model, embeds, entity_matrix, test_dataset, test_dataloader, tb_writer, global_step, prefix="epoch-{}".format(epoch_idx+1), data_type="test") 175 | -------------------------------------------------------------------------------- /dual_flow/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import random 5 | import numpy as np 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | def set_seed(args): 10 | random.seed(args.seed) 11 | np.random.seed(args.seed) 12 | torch.manual_seed(args.seed) 13 | torch.cuda.manual_seed(args.seed) 14 | torch.cuda.manual_seed_all(args.seed) 15 | torch.backends.cudnn.deterministic = True 16 | 17 | def save_model(model, embeds, output_dir, save_name, args): 18 | save_dir = os.path.join(output_dir, save_name) 19 | if not os.path.exists(save_dir): 20 | os.makedirs(save_dir) 21 | model.save_pretrained(save_dir) 22 | torch.save(args, os.path.join(save_dir, 'training_args.bin')) 23 | 24 | embeds_np = embeds.cpu().numpy() 25 | entity_embed_dir = os.path.join(save_dir, f"{args.data_name}_entity_embeds.pk") 26 | with open(entity_embed_dir, "wb") as outfile: 27 | pickle.dump(embeds_np, outfile) 28 | 29 | def load_entity_embed(input_dir, data_name): 30 | path = os.path.join(input_dir, f"{data_name}_entity_embeds.pk") 31 | with open(path, "rb") as infile: 32 | embeds = pickle.load(infile) 33 | return torch.from_numpy(embeds) 34 | 35 | def squeeze_lst(lst): 36 | tmp = [] 37 | for x in lst: 38 | if x not in tmp: 39 | tmp.append(x) 40 | return tmp 41 | 42 | def get_cmekg_link_specific(data_name): 43 | head_lst = set() 44 | with open(f"../data/cmekg/head_{data_name}.txt", "r") as f: 45 | for line in f: 46 | head_lst.add(line.strip()) 47 | 48 | del_entity_lst = [] #额外单独删除 49 | 50 | cmekg_link_relation_dict = defaultdict(list) 51 | entity_lst = set() 52 | cmekg_link_dict = defaultdict(list) 53 | cmekg_link_set = set() 54 | useless_r = ["中心词", "ICD-10", "UMLS", "拼音"] 55 | with open("../data/cmekg/kg_disease.txt", "r") as f: 56 | for line in f: 57 | h, r, t = line.split("\t") 58 | t = t.strip() 59 | if r == "英文名称" and not t.isupper(): #仅保留英文缩写 60 | continue 61 | useless_disease_r = ["进入路径标准", "多发地区", "多发季节", "多发群体", "易感人群", "标准住院时间", "出院标准", "发病机制", 
"发病率", "发病性别倾向", "发病年龄"] 62 | if h != t and r not in useless_r and r not in useless_disease_r and h in head_lst and t not in del_entity_lst: 63 | cmekg_link_set.add(h + "\t" + t) 64 | cmekg_link_set.add(t + "\t" + h) 65 | cmekg_link_dict[h].append(t) 66 | cmekg_link_dict[t].append(h) 67 | cmekg_link_relation_dict[f"{h}\t{t}"].append(r) 68 | entity_lst.add(h) 69 | entity_lst.add(t) 70 | 71 | with open("../data/cmekg/kg_symptom.txt", "r") as f: 72 | for line in f: 73 | h, r, t = line.split("\t") 74 | t = t.strip() 75 | if r == "英文名称" and not t.isupper(): #仅保留英文缩写 76 | continue 77 | useless_symptom_r = ["进入路径标准", "多发地区", "多发季节", "多发群体", "易感人群", "标准住院时间", "出院标准", "发病机制", "发病率", "发病性别倾向", "发病年龄"] 78 | if h != t and r not in useless_r and r not in useless_symptom_r and h in head_lst and t not in del_entity_lst: 79 | cmekg_link_set.add(h + "\t" + t) 80 | cmekg_link_set.add(t + "\t" + h) 81 | cmekg_link_dict[h].append(t) 82 | cmekg_link_dict[t].append(h) 83 | cmekg_link_relation_dict[f"{h}\t{t}"].append(r) 84 | entity_lst.add(h) 85 | entity_lst.add(t) 86 | 87 | with open("../data/cmekg/kg_test.txt", "r") as f: 88 | for line in f: 89 | h, r, t = line.split("\t") 90 | t = t.strip() 91 | if r == "英文名称" and not t.isupper(): #仅保留英文缩写 92 | continue 93 | useless_test_r = ["试剂", "原理", "所属分类", "操作方法", "临床意义", "正常值"] 94 | if h != t and r not in useless_r and r not in useless_test_r and h in head_lst and t not in del_entity_lst: 95 | cmekg_link_set.add(h + "\t" + t) 96 | cmekg_link_set.add(t + "\t" + h) 97 | cmekg_link_dict[h].append(t) 98 | cmekg_link_dict[t].append(h) 99 | cmekg_link_relation_dict[f"{h}\t{t}"].append(r) 100 | entity_lst.add(h) 101 | entity_lst.add(t) 102 | 103 | with open("../data/cmekg/kg_medicine.txt", "r") as f: 104 | for line in f: 105 | h, r, t = line.split("\t") 106 | t = t.strip() 107 | useless_medicine_r = ["英文名称", "拉丁学名", "OTC类型", "出处", "分子量", 108 | "晶系", "化学式", "比重", "硬度", "采集加工", "执行标准", 109 | "批准文号", "有效期", "分布区域", "采收时间", "是否纳入医保", 110 | "是否处方药", "药品监管分级", 111 | "贮藏", "界", "门", "纲", "目", "科", "属", "种", 112 | "入药部位", "性味", "性状", "特殊药品", "规格", "剂型", "成份", "组成"] 113 | if h != t and r not in useless_r and r not in useless_medicine_r and h in head_lst and t not in del_entity_lst: 114 | cmekg_link_set.add(h + "\t" + t) 115 | cmekg_link_set.add(t + "\t" + h) 116 | cmekg_link_dict[h].append(t) 117 | cmekg_link_dict[t].append(h) 118 | cmekg_link_relation_dict[f"{h}\t{t}"].append(r) 119 | entity_lst.add(h) 120 | entity_lst.add(t) 121 | 122 | cmekg_link_dict = dict(cmekg_link_dict) 123 | cmekg_link_relation_dict = dict(cmekg_link_relation_dict) 124 | for key in cmekg_link_dict: 125 | cmekg_link_dict[key] = squeeze_lst(cmekg_link_dict[key]) 126 | return cmekg_link_dict, cmekg_link_set, entity_lst, head_lst 127 | 128 | def get_cmekg_entity_specific(data_name): 129 | entity_lst = [] 130 | entity_dict = dict() 131 | with open(f"../data/cmekg/entities_{data_name}.txt", "r") as f: 132 | for i, line in enumerate(f): 133 | entity_lst.append(line.strip()) 134 | entity_dict[line.strip()] = i 135 | entity_type_dict = dict() 136 | for key in ["disease", "medicine", "symptom", "test"]: 137 | with open(f"../data/cmekg/{key}_lst.txt", "r") as f: 138 | for line in f: 139 | entity_type_dict[line.strip()] = key 140 | return entity_lst, entity_dict, entity_type_dict 141 | 142 | def get_ner_entity(data_name): 143 | ner_entity_lst = [] 144 | ner_entity_dict = {} 145 | with open(f"../data/{data_name}_ner_entity/{data_name}_entity.txt", "r") as f: 146 | for i, line in enumerate(f): 147 | 
--------------------------------------------------------------------------------
/generation/dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import torch
4 | import pickle
5 | import logging
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 |
10 | from utils import get_cmekg_entity_specific
11 |
12 | logger = logging.getLogger(__name__)
13 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s',
14 |                     datefmt = '%d %H:%M:%S',
15 |                     level = logging.INFO)
16 |
17 | def load_data(data_name, data_type):
18 |     with open(f"../data/{data_name}_{data_type}_kg_entity_specific.pk", "rb") as f:
19 |         data = pickle.load(f)
20 |
21 |     logger.info(f"Num of {data_name} {data_type} dialogues: %d", len(data))
22 |     return data
23 |
24 | def get_turn_range(raw_ids):
25 |     length_lst = [len(x) for x in raw_ids]  # length of each turn
26 |
27 |     turn_range = []
28 |     end = 1
29 |     while end < len(raw_ids):
30 |         for start in range(0, end):
31 |             if np.sum(length_lst[start:end]) <= 512:
32 |                 break
33 |         turn_range.append([start, end])
34 |         end += 2
35 |     return turn_range
36 |
37 | def get_samples(raw_ids, turn_range, start_token=None, end_token=None, is_target=False):
38 |     ids = []
39 |     for (start, end) in turn_range:
40 |         ids_tmp = []  # one sample
41 |         if not is_target:  # build a context sample
42 |             for i in range(start, end):
43 |                 ids_tmp += raw_ids[i]
44 |             if len(ids_tmp) == 512:  # leave room for the prepended [CLS]
45 |                 ids_tmp = ids_tmp[0:1] + ids_tmp[2:]
46 |             ids_tmp = [start_token] + ids_tmp
47 |
48 |         else:  # build a target sample
49 |             ids_tmp += raw_ids[end]
50 |             ids_tmp[0] = start_token
51 |             ids_tmp = ids_tmp[:512-1]
52 |             ids_tmp += [end_token]
53 |
54 |         assert len(ids_tmp) <= 512
55 |         ids.append(ids_tmp)
56 |     return ids  # multiple samples
57 |
58 | def process_data(args,
59 |                  mode,
60 |                  data_name,
61 |                  data_type,
62 |                  raw_data,
63 |                  tokenizer):
64 |
65 |     act_id = {"INQUIRE": 0, "DIAGNOSIS": 1, "TREATMENT": 2, "TEST": 3, "PRECAUTION": 4, "INFORM": 5, "CHITCHAT": 6}
66 |
67 |     if not args.for_meddg_160:
68 |         # retrieved kg entity
69 |         with open(f"{args.act_entity_dir}/{data_name}_{data_type}_ranked_entity.pk", "rb") as f:
70 |             entity_ranked_dict = pickle.load(f)
71 |
72 |         # cmekg entity
73 |         entity_lst, _ = get_cmekg_entity_specific(data_name)
74 |
75 |     # predicted act and selected entity
76 |     if mode != "train":
77 |
78 |         with open(f"{args.act_entity_dir}/{data_name}_{data_type}_predicted_act.pk", "rb") as f:
79 |             _, _, _, act_predicted_dict = pickle.load(f)
80 |
81 |         if args.for_meddg_160:
82 |             with open(f"{args.act_entity_dir}/{data_name}_{data_type}_selected_entity.pk", "rb") as f:
83 |                 _, _, _, entity_predicted_dict = pickle.load(f)
84 |
85 |     # role dict
86 |     role_token_id = {"patient": 1, "doctor": 2}
87 |
88 |     samples_context_ids = []
89 |     samples_context_raw = []
90 |     samples_target_ids = []
91 |     samples_target_raw = []
92 |     sample_entity_ids = []
93 |     samples_idx = []
94 |     samples_turn_idx = []
95 |     for sample in tqdm(raw_data, desc=f"Construct {data_type} samples"):
96 |
97 |         dialogue_history = sample["dialogues"]
98 |         if dialogue_history[-1]["role"] == "patient":  # every dialogue should end with a doctor response
99 |             del dialogue_history[-1]
100 |
101 |         assert len(dialogue_history) % 2 == 0
102 |
103 |         context_ids_tmp = []
104 |         for i, turn in enumerate(dialogue_history):
105 |             tmp = [role_token_id[turn["role"]]]
106 |             tokens = tokenizer.tokenize(turn["sentence"])
107 |             ids = tokenizer.convert_tokens_to_ids(tokens)
108 |             tmp += ids  # dialogue input ids
109 |
110 |             extra_length = max(len(tmp) - 512, 0)  # check how far the turn exceeds the length limit
111 |             if extra_length > 0 and turn["role"] == "patient":
112 |                 tmp = tmp[:-extra_length]
113 |
114 |             context_ids_tmp.append(tmp)
115 |             if turn["role"] == "doctor":
116 |                 # save the raw target
117 |                 samples_target_raw.append(turn["sentence"])
118 |
119 |                 # save the raw context
120 |                 context_raw_tmp = []
121 |                 for j in range(i):
122 |                     context_raw_tmp.append(dialogue_history[j]["sentence"])
123 |                 samples_context_raw.append({
124 |                     "context": context_raw_tmp,
125 |                     "idx": sample["idx"],
126 |                     "turn_idx": str(i),
127 |                 })
128 |
129 |                 tmp_entity = [tokenizer.cls_token_id]
130 |
131 |                 if mode != "train":
132 |                     for x in act_predicted_dict[sample["idx"] + "_" + str(turn["turn"])]:
133 |                         tmp_entity += [act_id[x]+1,]
134 |
135 |                     if args.for_meddg_160:
136 |                         for x in entity_predicted_dict[sample["idx"] + "_" + str(turn["turn"])]:
137 |                             tokens = tokenizer.tokenize(x)
138 |                             ids = tokenizer.convert_tokens_to_ids(tokens)
139 |                             tmp_entity += ids
140 |                 else:
141 |                     turn_act_lst = set()
142 |                     for act_lst in turn["act"]:
143 |                         turn_act_lst.update(act_lst)
144 |                     turn_act_lst = [act_id[x]+1 for x in act_id if x in turn_act_lst]  # directly store the act ids
145 |                     tmp_entity += turn_act_lst
146 |
147 |                     if args.for_meddg_160:
148 |                         for entity in turn["entity"]:
149 |                             tokens = tokenizer.tokenize(entity)
150 |                             ids = tokenizer.convert_tokens_to_ids(tokens)
151 |                             tmp_entity += ids  # entity ids
152 |
153 |                 if not args.for_meddg_160:
154 |                     for idx in entity_ranked_dict[sample["idx"] + "_" + str(turn["turn"])][:args.k_entity]:
155 |                         tokens = tokenizer.tokenize(entity_lst[idx])
156 |                         ids = tokenizer.convert_tokens_to_ids(tokens)
157 |                         tmp_entity += [tokenizer.sep_token_id] + ids  # entity ids
158 |
159 |                 tmp_entity = tmp_entity[:512]
160 |                 sample_entity_ids.append(tmp_entity)
161 |
162 |                 turn_range = get_turn_range(context_ids_tmp)  # yields multiple samples
163 |                 context_ids = get_samples(context_ids_tmp, turn_range, start_token=tokenizer.cls_token_id)  # multiple samples
164 |                 target_ids = get_samples(context_ids_tmp, turn_range, start_token=tokenizer.cls_token_id, end_token=tokenizer.sep_token_id, is_target=True)  # multiple samples
165 |
166 |                 samples_context_ids += context_ids
167 |                 samples_target_ids += target_ids
168 |
169 |                 turn_idx = [str(end) for (start, end) in turn_range]
170 |                 samples_turn_idx += turn_idx
171 |                 samples_idx += [sample["idx"]] * len(context_ids)
172 |
173 |     return (samples_context_ids, samples_target_ids, samples_target_raw, samples_context_raw, sample_entity_ids,
174 |             samples_idx, samples_turn_idx)
175 |
176 | class BaseDataset(Dataset):
177 |     def __init__(self, args, data_name, data_type, mode, tokenizer):
178 |         self.data_name = data_name
179 |         self.data_type = data_type
180 |         self.mode = mode
181 |         self.tokenizer = tokenizer
182 |         self.raw_data = load_data(data_name, data_type)
183 |         (self.samples_context_ids, self.samples_target_ids, self.samples_target_raw, self.samples_context_raw, self.sample_entity_ids,
184 |          self.samples_idx, self.samples_turn_idx) = process_data(args, mode, data_name, data_type, self.raw_data, tokenizer)
185 |
186 |     def __len__(self):
187 |         assert len(self.samples_context_ids) == len(self.samples_target_ids)
188 |         assert len(self.samples_idx) == len(self.samples_turn_idx)
189 |         assert len(self.samples_target_raw) == len(self.samples_idx)
190 |         assert len(self.samples_context_raw) == len(self.samples_target_raw)
191 |         assert len(self.samples_context_ids) == len(self.sample_entity_ids)
192 |         return len(self.samples_context_ids)
193 |
194 |     def __getitem__(self, item):
195 |         lst_data = {
196 |             "context_ids": self.samples_context_ids[item],
197 |             "target_ids": self.samples_target_ids[item],
198 |             "entity_ids": self.sample_entity_ids[item],
199 |             "idx": self.samples_idx[item],
200 |             "turn_idx": self.samples_turn_idx[item],
201 |         }
202 |         return lst_data
203 |
204 | def pack_tensor_2D(raw_lst, default, dtype, length=None):
205 |     batch_size = len(raw_lst)
206 |     length = length if length is not None else max(len(raw) for raw in raw_lst)
207 |     tensor = default * torch.ones((batch_size, length), dtype=dtype)
208 |     for i, raw in enumerate(raw_lst):
209 |         tensor[i, :len(raw)] = torch.tensor(raw, dtype=dtype)
210 |     return tensor
211 |
212 | def get_collate_function(tokenizer):
213 |     def collate_function(batch):
214 |         context_ids_lst = [x["context_ids"] for x in batch]
215 |         context_mask_lst = [[1] * len(context_ids) for context_ids in context_ids_lst]
216 |         target_ids_lst = [x["target_ids"] for x in batch]
217 |         target_mask_lst = [[1] * len(target_ids) for target_ids in target_ids_lst]
218 |         entity_ids_lst = [x["entity_ids"] for x in batch]
219 |         entity_mask_lst = [[1] * len(entity_ids) for entity_ids in entity_ids_lst]
220 |         labels_lst = [x["target_ids"][1:] + [-100] for x in batch]  # pad with "-100" to keep the same length as the target
221 |
222 |         data = {
223 |             "input_ids": pack_tensor_2D(context_ids_lst, default=0, dtype=torch.int64),
224 |             "attention_mask": pack_tensor_2D(context_mask_lst, default=0, dtype=torch.int64),
225 |             "decoder_input_ids": pack_tensor_2D(target_ids_lst, default=0, dtype=torch.int64),
226 |             "decoder_attention_mask": pack_tensor_2D(target_mask_lst, default=0, dtype=torch.int64),
227 |             "entity_input_ids": pack_tensor_2D(entity_ids_lst, default=0, dtype=torch.int64),
228 |             "entity_attention_mask": pack_tensor_2D(entity_mask_lst, default=0, dtype=torch.int64),
229 |             "labels": pack_tensor_2D(labels_lst, default=-100, dtype=torch.int64),
230 |         }
231 |
232 |         idx = [x["idx"] for x in batch]
233 |         turn_idx = [x["turn_idx"] for x in batch]
234 |         return data, idx, turn_idx
235 |     return collate_function
236 |
237 | def construct_data(args, data_type, mode, per_gpu_batch_size, tokenizer, data_sampler):
238 |     batch_size = per_gpu_batch_size * max(1, args.n_gpu)
239 |     dataset = BaseDataset(args, args.data_name, data_type, mode, tokenizer)
240 |     sampler = data_sampler(dataset)
241 |     collate_fn = get_collate_function(tokenizer)
242 |     dataloader = DataLoader(dataset, sampler=sampler,
243 |                             batch_size=batch_size, num_workers=args.data_num_workers, collate_fn=collate_fn)
244 |     return dataset, dataloader, batch_size
245 |
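A small standalone check of the label-shift convention in `get_collate_function` above (toy IDs):

```
target_ids = [101, 8, 9, 102]     # [CLS] w1 w2 [SEP] as toy IDs
labels = target_ids[1:] + [-100]  # predict w1 at position 0, w2 at 1, [SEP] at 2
print(labels)                     # [8, 9, 102, -100]
# -100 is the ignore_index of torch.nn.CrossEntropyLoss, so padded
# positions contribute nothing to the loss
```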
--mode inference \ 7 | --data_name kamed \ 8 | --data_type test \ 9 | --output_dir "${output_dir}" \ 10 | --eval_model_path "${output_dir}/models/${train_name}/${checkpoint}" \ 11 | --per_gpu_eval_batch_size 16 \ 12 | --decode_max_length 150 \ 13 | --act_entity_dir ../results/df_results/kamed \ 14 | --result_save_dir "${output_dir}/results/${train_name}" \ 15 | --k_entity 10 \ 16 | --top_p 0.35 \ 17 | --num_beams 5 \ 18 | 19 | # To evaluate our checkpoints, replace "eval_model_path" with the path of the downloaded checkpoints 20 | # and "result_save_dir" with the directory where you want the results saved. 21 | -------------------------------------------------------------------------------- /generation/eval_meddg.sh: -------------------------------------------------------------------------------- 1 | output_dir="./train" 2 | train_name="demo" 3 | 4 | checkpoint="epoch-*" 5 | python main.py \ 6 | --mode inference \ 7 | --data_name meddg \ 8 | --data_type test \ 9 | --output_dir "${output_dir}" \ 10 | --eval_model_path "${output_dir}/models/${train_name}/${checkpoint}" \ 11 | --per_gpu_eval_batch_size 16 \ 12 | --decode_max_length 150 \ 13 | --act_entity_dir ../results/df_results/meddg \ 14 | --result_save_dir "${output_dir}/results/${train_name}" \ 15 | --top_k 64 \ 16 | --num_beams 5 \ 17 | --for_meddg_160 # only for the MedDG dataset (no trailing backslash here: an inline comment after "\" would be passed to the command as extra arguments) 18 | 19 | # To evaluate our checkpoints, replace "eval_model_path" with the path of the downloaded checkpoints 20 | # and "result_save_dir" with the directory where you want the results saved. 21 | -------------------------------------------------------------------------------- /generation/evaluating.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import logging 4 | from tqdm import tqdm 5 | 6 | logger = logging.getLogger(__name__) 7 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 8 | datefmt = '%d %H:%M:%S', 9 | level = logging.INFO) 10 | 11 | def evaluate(args, model, eval_dataset, eval_dataloader, prefix): 12 | # multi-gpu eval 13 | if args.n_gpu > 1: 14 | model = torch.nn.DataParallel(model) 15 | 16 | # Eval 17 | logger.info("***** Running Evaluation {} *****".format(prefix)) 18 | logger.info(" Num examples = %d", len(eval_dataset)) 19 | logger.info(" Batch size = %d", args.eval_batch_size) 20 | 21 | test_loss = 0 22 | test_corrects = 0 23 | test_num_targets = 0 24 | test_idx = [] 25 | test_turn_idx = [] 26 | 27 | for batch, idx, turn_idx in tqdm(eval_dataloader, desc="Evaluating"): 28 | with torch.no_grad(): 29 | batch = {k:v.to(args.device) for k, v in batch.items()} 30 | model.eval() 31 | outputs = model(**batch) 32 | 33 | loss, logits = outputs.loss, outputs.logits 34 | test_loss += loss.mean() # mean() also reduces the per-GPU loss vector under DataParallel; a no-op for a scalar loss 35 | test_idx += idx 36 | test_turn_idx += turn_idx 37 | 38 | shift_logits = logits.contiguous() # labels are pre-shifted in the collate function, so the logits need no extra shift 39 | shift_labels = batch["labels"].contiguous() 40 | 41 | _, preds = shift_logits.max(dim=-1) 42 | not_ignore = shift_labels.ne(-100) 43 | num_targets = not_ignore.long().sum().item() 44 | corrects = (shift_labels == preds) & not_ignore 45 | corrects = corrects.float().sum() 46 | 47 | test_corrects += corrects 48 | test_num_targets += num_targets 49 | 50 | return test_loss/len(eval_dataloader), test_corrects/test_num_targets 51 | -------------------------------------------------------------------------------- /generation/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 |
import logging 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | from transformers import BertTokenizer 9 | from metrics import get_bleu, get_entity_acc, get_rouge 10 | 11 | logger = logging.getLogger(__name__) 12 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 13 | datefmt = '%d %H:%M:%S', 14 | level = logging.INFO) 15 | 16 | def inference(args, model, eval_dataset, eval_dataloader, prefix): 17 | # multi-gpu inference 18 | if args.n_gpu > 1: 19 | model = torch.nn.DataParallel(model) 20 | 21 | # Tokenizer 22 | tokenizer = BertTokenizer.from_pretrained(args.base_model_path) 23 | 24 | # Inference 25 | logger.info("***** Running Inference {} *****".format(prefix)) 26 | logger.info(" Num examples = %d", len(eval_dataset)) 27 | logger.info(" Batch size = %d", args.eval_batch_size) 28 | 29 | args.early_stopping = True if args.num_beams > 1 else None 30 | args.do_sample = True if args.top_k is not None or args.top_p != 1 else None 31 | print("Do sample:", args.do_sample) 32 | 33 | generate_text_dict = defaultdict(dict) 34 | generate_texts = [] 35 | for batch, idx, turn_idx in tqdm(eval_dataloader, desc="Inference"): 36 | with torch.no_grad(): 37 | del batch["labels"], batch["decoder_input_ids"], batch["decoder_attention_mask"] # drop target-side tensors; generate() builds its own decoder inputs 38 | batch = {k:v.to(args.device) for k, v in batch.items()} 39 | model.eval() 40 | outputs = model.generate(**batch, 41 | num_beams=args.num_beams, 42 | early_stopping=args.early_stopping, 43 | do_sample=args.do_sample, 44 | top_k=args.top_k, 45 | top_p=args.top_p, 46 | max_length=args.decode_max_length, 47 | bos_token_id=tokenizer.cls_token_id, 48 | eos_token_id=tokenizer.sep_token_id, 49 | pad_token_id=tokenizer.pad_token_id, 50 | ) 51 | for i, output in enumerate(outputs): 52 | text = tokenizer.decode(output, skip_special_tokens=True) 53 | generate_texts.append((idx[i], turn_idx[i], text.replace(" ", ""))) 54 | generate_text_dict[idx[i]][turn_idx[i]] = text.replace(" ", "") 55 | 56 | if args.result_save_dir: 57 | result_save_dir = args.result_save_dir 58 | else: 59 | result_save_dir = f"{args.model_save_dir}/{prefix}" 60 | 61 | if not os.path.exists(result_save_dir): 62 | os.makedirs(result_save_dir) 63 | 64 | reference_text_dict = defaultdict(dict) 65 | ref_file = f"{result_save_dir}/{eval_dataset.data_name}_{eval_dataset.data_type}_reference.txt" 66 | with open(ref_file, "w") as outfile: 67 | for (idx, turn_idx, target) in zip(eval_dataset.samples_idx, eval_dataset.samples_turn_idx, eval_dataset.samples_target_raw): 68 | outfile.write(str(idx) + "\t") 69 | outfile.write(str(turn_idx) + "\t") 70 | outfile.write(target + "\n") 71 | reference_text_dict[idx][turn_idx] = target 72 | 73 | hyp_file = f"{result_save_dir}/{eval_dataset.data_name}_{eval_dataset.data_type}_generate.txt" 74 | with open(hyp_file, "w") as outfile: 75 | for text in generate_texts: 76 | outfile.write(str(text[0]) + "\t") 77 | outfile.write(str(text[1]) + "\t") 78 | outfile.write(text[2] + "\n") 79 | 80 | bleu = get_bleu(ref_file, hyp_file) 81 | acc = get_entity_acc(ref_file, hyp_file) 82 | rouge = get_rouge(ref_file, hyp_file) 83 | 84 | metric_file = f"{result_save_dir}/{eval_dataset.data_name}_{eval_dataset.data_type}_metric.json" 85 | with open(metric_file, "w") as outfile: 86 | json.dump({"bleu-1": bleu["bleu-1"], 87 | "bleu-2": bleu["bleu-2"], 88 | "bleu-4": bleu["bleu-4"], 89 | "entity-f1": acc["f1"], 90 | "rouge-1": rouge["rouge-1"], 91 | "rouge-2": rouge["rouge-2"]}, 92 | outfile, indent=4, separators=(',', ': ')) 93 | return
bleu["bleu-1"], bleu["bleu-4"], acc["f1"], rouge["rouge-1"], rouge["rouge-2"] 94 | -------------------------------------------------------------------------------- /generation/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import logging 4 | from parsing import run_parse_args 5 | from transformers import BertTokenizer, BartForConditionalGeneration 6 | from torch.utils.data import RandomSampler, SequentialSampler 7 | 8 | from model import Generator 9 | from dataset import construct_data 10 | from utils import set_seed 11 | from training import train 12 | from evaluating import evaluate 13 | from inference import inference 14 | 15 | logger = logging.getLogger(__name__) 16 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 17 | datefmt = '%d %H:%M:%S', 18 | level = logging.INFO) 19 | 20 | def main(): 21 | args = run_parse_args() 22 | 23 | # Setup CUDA, GPU 24 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 25 | args.n_gpu = torch.cuda.device_count() 26 | args.device = device 27 | 28 | # Setup logging 29 | logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu) 30 | 31 | # Set seed 32 | set_seed(args) 33 | 34 | # Tokenizer 35 | tokenizer = BertTokenizer.from_pretrained(args.base_model_path) 36 | 37 | # Model 38 | if args.mode == "train": 39 | model_path = args.base_model_path 40 | else: 41 | model_path = args.eval_model_path 42 | model = Generator.from_pretrained(model_path) 43 | model.to(args.device) 44 | 45 | logger.info("Training/Evaluation parameters %s", args) 46 | 47 | # Train/Evaluate 48 | if args.mode == "train": 49 | train_dataset, train_dataloader, args.train_batch_size = construct_data(args, args.data_type, "train", args.per_gpu_train_batch_size, tokenizer, RandomSampler) 50 | val_dataset, val_dataloader, args.eval_batch_size = construct_data(args, "valid", "evaluate", args.per_gpu_eval_batch_size, tokenizer, SequentialSampler) 51 | test_dataset, test_dataloader, args.eval_batch_size = construct_data(args, "test", "evaluate", args.per_gpu_eval_batch_size, tokenizer, SequentialSampler) 52 | train(args, model, train_dataset, train_dataloader, val_dataset, val_dataloader, test_dataset, test_dataloader) 53 | else: 54 | prefix = args.eval_model_path.split("/")[-1] 55 | eval_dataset, eval_dataloader, args.eval_batch_size = construct_data(args, args.data_type, args.mode, args.per_gpu_eval_batch_size, tokenizer, SequentialSampler) 56 | if args.mode != "inference": 57 | result = evaluate(args, model, eval_dataset, eval_dataloader, prefix) 58 | print('Acc: {}'.format(result[1])) 59 | else: 60 | bleu_1, bleu_4, f1, rouge_1, rouge_2 = inference(args, model, eval_dataset, eval_dataloader, prefix) 61 | print("BLEU-1: {:.2f}".format(bleu_1 * 100), end="\t") 62 | print("BLEU-4: {:.2f}".format(bleu_4 * 100)) 63 | print("ROUGE-1: {:.2f}".format(rouge_1 * 100), end="\t") 64 | print("ROUGE-2: {:.2f}".format(rouge_2 * 100)) 65 | print("F1: {:.2f}".format(f1 * 100)) 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /generation/metric.sh: -------------------------------------------------------------------------------- 1 | python metrics.py \ 2 | --hp ./generate.txt \ 3 | --rf ./reference.txt \ 4 | -------------------------------------------------------------------------------- /generation/metrics.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import pickle 4 | import argparse 5 | from typing import Iterable 6 | from nltk.translate.bleu_score import sentence_bleu 7 | from nltk.translate.bleu_score import SmoothingFunction 8 | from rouge import Rouge 9 | from collections import Counter 10 | import numpy as np 11 | sys.setrecursionlimit(512 * 512 + 10) 12 | 13 | def get_youyin(sentence): 14 | aa = "(吃早饭|吃饭规律|三餐规律|饮食规律|作息规律|辣|油炸|熬夜|豆类|油腻|生冷|煎炸|浓茶|喝酒|抽烟|吃的多|暴饮暴食|不容易消化的食物|情绪稳定|精神紧张|夜宵).*?(吗|啊?|呢?|么|嘛?)" 15 | bb = "吃得过多过饱|(有没有吃|最近吃|喜欢吃|经常进食|经常吃).*?(辣|油炸|油腻|生冷|煎炸|豆类|不容易消化的食物|夜宵)" 16 | cc = "(工作压力|精神压力|压力).*?(大不大|大吗)|(心情|情绪|精神).*(怎么样|怎样|如何)|(活动量|运动|锻炼).*?(大|多|少|多不多|咋样|怎么样|怎样).*?(吗|呢)" 17 | if re.search(aa,sentence) or re.search(bb,sentence) or re.search(cc,sentence): 18 | return True 19 | return False 20 | 21 | def get_location(sentence): 22 | cc = "哪个部位|哪个位置|哪里痛|什么部位|什么位置|哪个部位|哪个位置|哪一块|那个部位痛|肚脐眼以上|描述一下位置|具体部位|具体位置" 23 | if re.search(cc,sentence) is not None: 24 | return True 25 | return False 26 | 27 | def get_xingzhi(sentence): 28 | cc = "是哪种疼|怎么样的疼|绞痛|钝痛|隐痛|胀痛|隐隐作痛|疼痛.*性质|(性质|什么样).*(的|得)(痛|疼)" 29 | if re.search(cc,sentence) is not None: 30 | return True 31 | return False 32 | 33 | def get_fan(sentence): 34 | cc = "(饭.*?前|吃东西前|餐前).*(疼|痛|不舒服|不适)|(饭.*?后|吃东西后|餐后).*(疼|痛|不舒服|不适)|(早上|早晨|夜里|半夜|晚饭).*(疼|痛|不舒服|不适)" 35 | if re.search(cc,sentence) is not None: 36 | aa = re.search(cc,sentence).span() 37 | if aa[1] - aa[0] <20: 38 | return True 39 | return False 40 | 41 | def get_tong_pinglv(sentence): 42 | cc = "持续的疼|疼一会儿会自行缓解|持续的,还是阵发|症状减轻了没|(疼|痛).*轻|现在没有症状了吗|现在还有症状吗|(一阵一阵|一直|持续).*(疼|痛)|一阵阵.*(痛|疼)|阵发性|持续性" 43 | if re.search(cc,sentence) is not None: 44 | aa = re.search(cc,sentence).span() 45 | return True 46 | return False 47 | 48 | def get_tong(sentence): 49 | if get_tong_pinglv(sentence) or get_fan(sentence) or get_xingzhi(sentence): 50 | return True 51 | return False 52 | 53 | 54 | def get_other_sym(sentence): 55 | cc = "(还有什么|还有啥|有没有其|都有啥|都有什么|还有别的|有其他|有没有什么|还有其他).*(症状|不舒服)|别的不舒服|有其他不舒服|主要是什么症状|主要.*症状|哪些不适症状|哪些.*症状|出现了什么症状" 56 | if re.search(cc,sentence) is not None: 57 | aa = re.search(cc,sentence).span() 58 | return True 59 | return False 60 | 61 | def get_time(sentence): 62 | aa = "(情况|症状|痛|发病|病|感觉|疼|这样|不舒服|大约).*?(多久|多长时间|几周了?|几天了?)" 63 | bb = "(,|。|、|?)(多长时间了|多久了|有多久了|有多长时间了)|^(多久了|多长时间了|有多久了|有多长时间了|几天了|几周了)" 64 | cc = "有多长时间|有多久" 65 | if re.search(aa,sentence) is not None or re.search(bb,sentence) is not None: 66 | return True 67 | return False 68 | 69 | class KD_Metric(): 70 | def __init__(self) -> None: 71 | self._pred_true = 0 72 | self._total_pred = 0 73 | self._total_true = 0 74 | with open("../data/new_cy_bii.pk", "rb") as f: 75 | self.norm_dict = pickle.load(f) 76 | 77 | def reset(self) -> None: 78 | self._pred_true = 0 79 | self._total_pred = 0 80 | self._total_true = 0 81 | 82 | def get_metric(self, reset: bool = False): 83 | rec, acc, f1 = 0., 0., 0. 
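# Despite its name, "acc" here is entity-level precision (pred_true / total_pred),
# "rec" is recall (pred_true / total_true), and "f1" is their harmonic mean;
# the guards below avoid division by zero when either count is empty.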
84 | if self._total_pred > 0: 85 | acc = self._pred_true / self._total_pred 86 | if self._total_true > 0: 87 | rec = self._pred_true / self._total_true 88 | if acc > 0 and rec > 0: 89 | f1 = acc * rec * 2 / (acc + rec) 90 | if reset: 91 | self.reset() 92 | return {"rec":rec, "acc":acc, "f1":f1} 93 | 94 | def convert_sen_to_entity_set(self, sen): 95 | entity_set = set() 96 | for entity in self.norm_dict.keys(): 97 | if entity in sen: 98 | entity_set.add(self.norm_dict[entity]) 99 | if get_location(sen): 100 | entity_set.add("位置") 101 | if get_youyin(sen): 102 | entity_set.add("诱因") 103 | if get_tong(sen): 104 | entity_set.add("性质") 105 | if get_time(sen): 106 | entity_set.add("时长") 107 | return entity_set 108 | 109 | def __call__( 110 | self, 111 | references, # list(list(str)) 112 | hypothesis, # list(list(str)) 113 | ) -> None: 114 | for batch_num in range(len(references)): 115 | ref = "".join(references[batch_num]) 116 | hypo = "".join(hypothesis[batch_num]) 117 | hypo_list = self.convert_sen_to_entity_set(hypo) 118 | ref_list = self.convert_sen_to_entity_set(ref) 119 | 120 | self._total_true += len(ref_list) 121 | self._total_pred += len(hypo_list) 122 | for entity in hypo_list: 123 | if entity in ref_list: 124 | self._pred_true += 1 125 | 126 | def get_entity_acc(ref_dir, hyp_dir): 127 | kd_metric = KD_Metric() 128 | 129 | ref = [] 130 | with open(ref_dir, "r") as f: 131 | for line in f: 132 | text = line.split()[-1] 133 | text_lst = [x for x in text.strip()] 134 | ref.append(text_lst) 135 | 136 | hyp = [] 137 | with open(hyp_dir, "r") as f: 138 | for line in f: 139 | text = line.split()[-1] 140 | text_lst = [x for x in text.strip()] 141 | hyp.append(text_lst) 142 | 143 | kd_metric(ref, hyp) 144 | scores = kd_metric.get_metric() 145 | return scores 146 | 147 | class NLTK_BLEU(): 148 | def __init__( 149 | self, 150 | ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25), 151 | ) -> None: 152 | self._ngram_weights = ngram_weights 153 | self._scores = [] 154 | self.smoothfunc = SmoothingFunction().method7 155 | 156 | def reset(self) -> None: 157 | self._scores = [] 158 | 159 | def get_metric(self, reset: bool = False): 160 | score = 0. 
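# The reported score is the mean of per-sentence BLEU values computed with NLTK's
# sentence_bleu and SmoothingFunction().method7 (see __call__ below), not a
# corpus-level BLEU; hypotheses of length <= 1 are scored as 0.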
161 | if len(self._scores): 162 | score = sum(self._scores) / len(self._scores) 163 | if reset: 164 | self.reset() 165 | return score 166 | 167 | def __call__( 168 | self, 169 | references, # list(list(str)) 170 | hypothesis, # list(list(str)) 171 | ) -> None: 172 | for batch_num in range(len(references)): 173 | if len(hypothesis[batch_num]) <= 1: 174 | self._scores.append(0) 175 | else: 176 | self._scores.append(sentence_bleu([references[batch_num]], hypothesis[batch_num], 177 | smoothing_function=self.smoothfunc, 178 | weights=self._ngram_weights)) 179 | 180 | def get_bleu(ref_dir, hyp_dir): 181 | bleu1 = NLTK_BLEU(ngram_weights=(1, 0, 0, 0)) 182 | bleu2 = NLTK_BLEU(ngram_weights=(0.5, 0.5, 0, 0)) 183 | bleu4 = NLTK_BLEU(ngram_weights=(0.25, 0.25, 0.25, 0.25)) 184 | 185 | ref = [] 186 | with open(ref_dir, "r") as f: 187 | for line in f: 188 | text = line.split()[-1] 189 | text_lst = [x for x in text.strip()] 190 | ref.append(text_lst) 191 | 192 | hyp = [] 193 | with open(hyp_dir, "r") as f: 194 | for line in f: 195 | text = line.split()[-1] 196 | text_lst = [x for x in text.strip()] 197 | hyp.append(text_lst) 198 | 199 | print("Num of Samples:",len(ref)) 200 | 201 | bleu1(ref, hyp) 202 | bleu2(ref, hyp) 203 | bleu4(ref, hyp) 204 | scores = { 205 | "bleu-1": bleu1.get_metric(), 206 | "bleu-2": bleu2.get_metric(), 207 | "bleu-4": bleu4.get_metric(), 208 | } 209 | return scores 210 | 211 | def get_rouge(ref_dir, hyp_dir): 212 | ref = [] 213 | with open(ref_dir, "r") as f: 214 | for line in f: 215 | text = line.split()[-1] 216 | ref.append(text.strip()) 217 | 218 | hyp = [] 219 | with open(hyp_dir, "r") as f: 220 | for line in f: 221 | text = line.split()[-1] 222 | hyp.append(text.strip()) 223 | 224 | pred_cleaned = [] 225 | target_cleaned = [] 226 | for x, y in zip(hyp, ref): 227 | if x != "" and y != "": 228 | pred_cleaned.append(" ".join(x)) 229 | target_cleaned.append(" ".join(y)) 230 | 231 | rouge = Rouge() 232 | rouge_scores = rouge.get_scores(pred_cleaned, target_cleaned, avg=True) 233 | scores = { 234 | "rouge-1": rouge_scores["rouge-1"]["f"], 235 | "rouge-2": rouge_scores["rouge-2"]["f"], 236 | "rouge-l": rouge_scores["rouge-l"]["f"], 237 | } 238 | 239 | return scores 240 | 241 | def distinct(hyp_dir): 242 | """ Calculate intra/inter distinct 1/2. 
""" 243 | hyp = [] 244 | with open(hyp_dir, "r") as f: 245 | for line in f: 246 | text = line.split()[-1] 247 | hyp.append(text.strip()) 248 | 249 | batch_size = len(hyp) 250 | intra_dist1, intra_dist2 = [], [] 251 | unigrams_all, bigrams_all = Counter(), Counter() 252 | for seq in hyp: 253 | unigrams = Counter(seq) 254 | bigrams = Counter(zip(seq, seq[1:])) 255 | intra_dist1.append((len(unigrams)+1e-12) / (len(seq)+1e-5)) 256 | intra_dist2.append((len(bigrams)+1e-12) / (max(0, len(seq)-1)+1e-5)) 257 | 258 | unigrams_all.update(unigrams) 259 | bigrams_all.update(bigrams) 260 | 261 | inter_dist1 = (len(unigrams_all)+1e-12) / (sum(unigrams_all.values())+1e-5) 262 | inter_dist2 = (len(bigrams_all)+1e-12) / (sum(bigrams_all.values())+1e-5) 263 | intra_dist1 = np.average(intra_dist1) 264 | intra_dist2 = np.average(intra_dist2) 265 | return intra_dist1, intra_dist2, inter_dist1, inter_dist2 266 | 267 | if __name__ == "__main__": 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument("--hp", help="generated samples") 270 | parser.add_argument("--rf", help="reference samples") 271 | args = parser.parse_args() 272 | 273 | scores = get_bleu(args.rf, args.hp) 274 | scores.update(get_entity_acc(args.rf, args.hp)) 275 | scores.update(get_rouge(args.rf, args.hp)) 276 | print("BLEU-1:", round(scores["bleu-1"]*100, 2), end="\t") 277 | print("BLEU-2:", round(scores["bleu-2"]*100, 2), end="\t") 278 | print("BLEU-4:", round(scores["bleu-4"]*100, 2)) 279 | print("Rouge-1:", round(scores["rouge-1"]*100, 2), end="\t") 280 | print("Rouge-2:", round(scores["rouge-2"]*100, 2), end="\t") 281 | print("Rouge-L:", round(scores["rouge-l"]*100, 2)) 282 | print("Entity-F1:", round(scores["f1"]*100, 2)) 283 | print("Entity-R:", round(scores["rec"]*100, 2)) 284 | print("Entity-P:", round(scores["acc"]*100, 2)) 285 | print(distinct(args.hp)) 286 | -------------------------------------------------------------------------------- /generation/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import torch 4 | from torch.nn import CrossEntropyLoss 5 | from typing import List, Optional, Tuple, Union, Any, Dict 6 | from transformers.utils import ModelOutput 7 | from transformers.modeling_outputs import Seq2SeqLMOutput 8 | 9 | from modeling_bart import BartForConditionalGeneration, shift_tokens_right 10 | 11 | logger = logging.getLogger(__name__) 12 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 13 | datefmt = '%d %H:%M:%S', 14 | level = logging.INFO) 15 | 16 | class Generator(BartForConditionalGeneration): 17 | def __init__(self, config): 18 | super(Generator, self).__init__(config) 19 | # Initialize weights and apply final processing 20 | self.post_init() 21 | 22 | def forward( 23 | self, 24 | input_ids: torch.LongTensor = None, 25 | attention_mask: Optional[torch.Tensor] = None, 26 | decoder_input_ids: Optional[torch.LongTensor] = None, 27 | decoder_attention_mask: Optional[torch.LongTensor] = None, 28 | entity_input_ids: torch.LongTensor = None, 29 | entity_attention_mask: Optional[torch.Tensor] = None, 30 | head_mask: Optional[torch.Tensor] = None, 31 | decoder_head_mask: Optional[torch.Tensor] = None, 32 | cross_attn_head_mask: Optional[torch.Tensor] = None, 33 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 34 | entity_encoder_outputs: Optional[List[torch.FloatTensor]] = None, 35 | past_key_values: Optional[List[torch.FloatTensor]] = None, 36 | inputs_embeds: Optional[torch.FloatTensor] = 
None, 37 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 38 | labels: Optional[torch.LongTensor] = None, 39 | use_cache: Optional[bool] = None, 40 | output_attentions: Optional[bool] = None, 41 | output_hidden_states: Optional[bool] = None, 42 | return_dict: Optional[bool] = None, 43 | ) -> Union[Tuple, Seq2SeqLMOutput]: 44 | r""" 45 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 46 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 47 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 48 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 49 | 50 | Returns: 51 | """ 52 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 53 | 54 | if labels is not None: 55 | if use_cache: 56 | logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") 57 | use_cache = False 58 | if decoder_input_ids is None and decoder_inputs_embeds is None: 59 | decoder_input_ids = shift_tokens_right( 60 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 61 | ) 62 | 63 | outputs = self.model( 64 | input_ids, 65 | attention_mask=attention_mask, 66 | decoder_input_ids=decoder_input_ids, 67 | encoder_outputs=encoder_outputs, 68 | decoder_attention_mask=decoder_attention_mask, 69 | entity_input_ids=entity_input_ids, 70 | entity_attention_mask=entity_attention_mask, 71 | entity_encoder_outputs=entity_encoder_outputs, 72 | head_mask=head_mask, 73 | decoder_head_mask=decoder_head_mask, 74 | cross_attn_head_mask=cross_attn_head_mask, 75 | past_key_values=past_key_values, 76 | inputs_embeds=inputs_embeds, 77 | decoder_inputs_embeds=decoder_inputs_embeds, 78 | use_cache=use_cache, 79 | output_attentions=output_attentions, 80 | output_hidden_states=output_hidden_states, 81 | return_dict=return_dict, 82 | ) 83 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 84 | 85 | masked_lm_loss = None 86 | if labels is not None: 87 | loss_fct = CrossEntropyLoss() 88 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 89 | 90 | if not return_dict: 91 | output = (lm_logits,) + outputs[1:] 92 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 93 | 94 | return Seq2SeqLMOutput( 95 | loss=masked_lm_loss, 96 | logits=lm_logits, 97 | past_key_values=outputs.past_key_values, 98 | decoder_hidden_states=outputs.decoder_hidden_states, 99 | decoder_attentions=outputs.decoder_attentions, 100 | cross_attentions=outputs.cross_attentions, 101 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 102 | encoder_hidden_states=outputs.encoder_hidden_states, 103 | encoder_attentions=outputs.encoder_attentions, 104 | ) 105 | 106 | def _prepare_encoder_decoder_kwargs_for_generation( 107 | self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None 108 | ) -> Dict[str, Any]: 109 | # 1. get encoder 110 | encoder = self.get_encoder() 111 | 112 | # 2. 
prepare encoder args and encoder kwargs from model kwargs 113 | irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] 114 | encoder_kwargs = { 115 | argument: value 116 | for argument, value in model_kwargs.items() 117 | if not any(argument.startswith(p) for p in irrelevant_prefix) 118 | } 119 | 120 | entity_input_ids = encoder_kwargs.pop("entity_input_ids", None) 121 | entity_attention_mask = encoder_kwargs.pop("entity_attention_mask", None) 122 | 123 | # 3. make sure that encoder returns `ModelOutput` 124 | model_input_name = model_input_name if model_input_name is not None else self.main_input_name 125 | encoder_kwargs["return_dict"] = True 126 | encoder_kwargs[model_input_name] = inputs_tensor 127 | model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) 128 | 129 | encoder_kwargs[model_input_name] = entity_input_ids 130 | encoder_kwargs["attention_mask"] = entity_attention_mask 131 | model_kwargs["entity_encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) 132 | 133 | return model_kwargs 134 | 135 | def prepare_inputs_for_generation( 136 | self, 137 | decoder_input_ids, 138 | past=None, 139 | attention_mask=None, 140 | head_mask=None, 141 | decoder_head_mask=None, 142 | cross_attn_head_mask=None, 143 | use_cache=None, 144 | encoder_outputs=None, 145 | entity_encoder_outputs=None, 146 | entity_attention_mask=None, 147 | **kwargs 148 | ): 149 | # cut decoder_input_ids if past is used 150 | if past is not None: 151 | decoder_input_ids = decoder_input_ids[:, -1:] 152 | 153 | return { 154 | "input_ids": None, # encoder_outputs is defined. input_ids not needed 155 | "encoder_outputs": encoder_outputs, 156 | "entity_encoder_outputs": entity_encoder_outputs, 157 | "entity_attention_mask": entity_attention_mask, 158 | "past_key_values": past, 159 | "decoder_input_ids": decoder_input_ids, 160 | "attention_mask": attention_mask, 161 | "head_mask": head_mask, 162 | "decoder_head_mask": decoder_head_mask, 163 | "cross_attn_head_mask": cross_attn_head_mask, 164 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 165 | } 166 | 167 | def _expand_inputs_for_generation( 168 | self, 169 | input_ids: torch.LongTensor, 170 | expand_size: int = 1, 171 | is_encoder_decoder: bool = False, 172 | attention_mask: Optional[torch.LongTensor] = None, 173 | entity_attention_mask: Optional[torch.LongTensor] = None, 174 | encoder_outputs: Optional[ModelOutput] = None, 175 | entity_encoder_outputs: Optional[ModelOutput] = None, 176 | **model_kwargs, 177 | ) -> Tuple[torch.LongTensor, Dict[str, Any]]: 178 | expanded_return_idx = ( 179 | torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) 180 | ) 181 | input_ids = input_ids.index_select(0, expanded_return_idx) 182 | 183 | if "token_type_ids" in model_kwargs: 184 | token_type_ids = model_kwargs["token_type_ids"] 185 | model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) 186 | 187 | if attention_mask is not None: 188 | model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) 189 | model_kwargs["entity_attention_mask"] = entity_attention_mask.index_select(0, expanded_return_idx) 190 | 191 | if is_encoder_decoder: 192 | if encoder_outputs is None: 193 | raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") 194 | encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 195 | 0, 
expanded_return_idx.to(encoder_outputs.last_hidden_state.device) 196 | ) 197 | model_kwargs["encoder_outputs"] = encoder_outputs 198 | entity_encoder_outputs["last_hidden_state"] = entity_encoder_outputs.last_hidden_state.index_select( 199 | 0, expanded_return_idx.to(entity_encoder_outputs.last_hidden_state.device) 200 | ) 201 | model_kwargs["entity_encoder_outputs"] = entity_encoder_outputs 202 | return input_ids, model_kwargs -------------------------------------------------------------------------------- /generation/parsing.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | 4 | def run_parse_args(): 5 | parser = argparse.ArgumentParser() 6 | 7 | ## Required parameters 8 | parser.add_argument("--mode", type=str, default="train", help="select from [train, test, inference]") 9 | parser.add_argument("--output_dir", type=str, default="./train") 10 | parser.add_argument("--base_model_path", type=str, default="../bart-base-chinese") 11 | parser.add_argument("--data_name", type=str, default="kamed", help="select from [kamed, meddg]") 12 | parser.add_argument("--data_type", type=str, default="train", help="select from [train, test, valid]") 13 | parser.add_argument("--train_name", type=str, default=None, help="a name identifying this training run; used for the log and model save directories") 14 | parser.add_argument("--for_meddg_160", action='store_true', help="use predicted entities on the MedDG dataset") 15 | parser.add_argument("--k_entity", type=int, default=None, help="number of top-k retrieved entities used for response generation") 16 | parser.add_argument("--act_entity_dir", type=str, default=None, help="the directory of the predicted acts and entities") 17 | parser.add_argument("--result_save_dir", type=str, default=None, help="the directory for saving generated results") 18 | 19 | ## General parameters 20 | parser.add_argument("--num_beams", type=int, default=1) 21 | parser.add_argument("--top_k", type=int, default=None) 22 | parser.add_argument("--top_p", type=float, default=1) 23 | parser.add_argument("--decode_max_length", type=int, default=150) 24 | parser.add_argument("--eval_model_path", type=str, default=None) 25 | parser.add_argument("--per_gpu_eval_batch_size", default=16, type=int) 26 | parser.add_argument("--per_gpu_train_batch_size", default=16, type=int) 27 | parser.add_argument("--gradient_accumulation_steps", type=int, default=2) 28 | 29 | parser.add_argument("--no_cuda", action='store_true') 30 | parser.add_argument('--seed', type=int, default=42) 31 | 32 | parser.add_argument("--evaluate_during_training", action="store_true") 33 | 34 | parser.add_argument("--logging_steps", type=int, default=100) 35 | parser.add_argument("--data_num_workers", default=0, type=int) 36 | 37 | parser.add_argument("--lr", default=1e-5, type=float) 38 | parser.add_argument("--weight_decay", default=0.01, type=float) 39 | parser.add_argument("--warmup_steps", default=1000, type=int) 40 | parser.add_argument("--adam_epsilon", default=1e-8, type=float) 41 | parser.add_argument("--max_grad_norm", default=1.0, type=float) 42 | parser.add_argument("--num_train_epochs", default=20, type=int) 43 | 44 | args = parser.parse_args() 45 | 46 | if args.train_name: 47 | args.log_dir = f"{args.output_dir}/log/{args.train_name}" 48 | args.model_save_dir = f"{args.output_dir}/models/{args.train_name}" 49 | else: 50 | time_stamp = time.strftime("%b-%d_%H:%M:%S", time.localtime()) 51 | args.log_dir = f"{args.output_dir}/log/{time_stamp}" 52 | args.model_save_dir =
f"{args.output_dir}/models/{time_stamp}" 53 | return args 54 | -------------------------------------------------------------------------------- /generation/train_kamed.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode train \ 3 | --data_name kamed \ 4 | --data_type train \ 5 | --evaluate_during_training \ 6 | --per_gpu_train_batch_size 4 \ 7 | --output_dir ./train \ 8 | --act_entity_dir ../results/df_results/kamed \ 9 | --lr 2e-5 \ 10 | --num_train_epochs 10 \ 11 | --k_entity 10 \ 12 | --train_name demo_kamed \ 13 | -------------------------------------------------------------------------------- /generation/train_meddg.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --mode train \ 3 | --data_name meddg \ 4 | --data_type train \ 5 | --evaluate_during_training \ 6 | --per_gpu_train_batch_size 4 \ 7 | --output_dir ./train \ 8 | --act_entity_dir ../results/df_results/meddg \ 9 | --lr 3e-5 \ 10 | --num_train_epochs 10 \ 11 | --for_meddg_160 \ 12 | --train_name demo_meddg \ 13 | -------------------------------------------------------------------------------- /generation/training.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import logging 5 | from tqdm import tqdm, trange 6 | from torch.utils.tensorboard import SummaryWriter 7 | from torch.optim import AdamW 8 | from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup 9 | from transformers.trainer_pt_utils import get_parameter_names 10 | 11 | from inference import inference 12 | from evaluating import evaluate 13 | from utils import save_model, set_seed 14 | 15 | logger = logging.getLogger(__name__) 16 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 17 | datefmt = '%d %H:%M:%S', 18 | level = logging.INFO) 19 | 20 | def get_optimizer(args, model): 21 | # Parameters with decaying 22 | decay_parameters = get_parameter_names(model, [nn.LayerNorm]) 23 | decay_parameters = [ 24 | name for name in decay_parameters if "bias" not in name 25 | ] 26 | 27 | optimizer_grouped_parameters = [ 28 | { 29 | "params": [ 30 | p for n, p in model.named_parameters() 31 | if n in decay_parameters 32 | ], 33 | "lr": 34 | args.lr, 35 | "weight_decay": 36 | args.weight_decay, 37 | }, 38 | { 39 | "params": [ 40 | p for n, p in model.named_parameters() 41 | if n not in decay_parameters 42 | ], 43 | "lr": 44 | args.lr, 45 | "weight_decay": 46 | 0.0, 47 | }, 48 | ] 49 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon) 50 | return optimizer 51 | 52 | def run_inf(args, model, dataset, dataloader, tb_writer, global_step, prefix, data_type): 53 | bleu_1, bleu_4, f1, rouge_1, rouge_2 = inference(args, model, dataset, dataloader, prefix=prefix) 54 | tb_writer.add_scalar(f'{data_type}/bleu-1', bleu_1, global_step) 55 | tb_writer.add_scalar(f'{data_type}/bleu-4', bleu_4, global_step) 56 | tb_writer.add_scalar(f'{data_type}/entity-f1', f1, global_step) 57 | tb_writer.add_scalar(f'{data_type}/rouge-1', rouge_1, global_step) 58 | tb_writer.add_scalar(f'{data_type}/rouge-2', rouge_2, global_step) 59 | 60 | def train(args, model, train_dataset, train_dataloader, val_dataset, val_dataloader, test_dataset, test_dataloader): 61 | # Train the model 62 | tb_writer = SummaryWriter(args.log_dir) 63 | 64 | # Total steps 65 | t_total = len(train_dataloader) // 
args.gradient_accumulation_steps * args.num_train_epochs 66 | 67 | # Prepare optimizer and schedule (linear warmup and decay) 68 | optimizer = get_optimizer(args, model) 69 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 70 | 71 | # multi-gpu training (should be after apex fp16 initialization) 72 | if args.n_gpu > 1: 73 | model = torch.nn.DataParallel(model) 74 | 75 | # Train! 76 | logger.info("***** Running Training *****") 77 | logger.info(" Num examples = %d", len(train_dataset)) 78 | logger.info(" Num Epochs = %d", args.num_train_epochs) 79 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 80 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 81 | args.train_batch_size * args.gradient_accumulation_steps) 82 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 83 | logger.info(" Total optimization steps = %d", t_total) 84 | 85 | global_step = 0 86 | tr_loss, logging_loss = 0.0, 0.0 87 | model.zero_grad() 88 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch") 89 | set_seed(args) # Added here for reproducibility 90 | for epoch_idx in train_iterator: 91 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 92 | for step, (batch, idx, turn_idx) in enumerate(epoch_iterator): 93 | 94 | batch = {k:v.to(args.device) for k, v in batch.items()} 95 | model.train() 96 | outputs = model(**batch) 97 | loss = outputs.loss 98 | 99 | if args.n_gpu > 1: 100 | loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training 101 | if args.gradient_accumulation_steps > 1: 102 | loss = loss / args.gradient_accumulation_steps 103 | loss.backward() 104 | tr_loss += loss.item() 105 | 106 | if (step + 1) % args.gradient_accumulation_steps == 0: 107 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 108 | optimizer.step() 109 | scheduler.step() # Update learning rate schedule 110 | model.zero_grad() 111 | optimizer.zero_grad() 112 | global_step += 1 113 | 114 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 115 | cur_loss = (tr_loss - logging_loss)/args.logging_steps 116 | tb_writer.add_scalar('train/lr', scheduler.get_last_lr()[0], global_step) 117 | tb_writer.add_scalar('train/loss', cur_loss, global_step) 118 | logging_loss = tr_loss 119 | 120 | if epoch_idx >= 5 and global_step % 5000 == 0: 121 | # Save model checkpoint 122 | save_model(model, args.model_save_dir, 'step-{}'.format(global_step), args) 123 | 124 | run_inf(args, model, val_dataset, val_dataloader, tb_writer, global_step, prefix="step-{}".format(global_step), data_type="valid") 125 | 126 | # Save model checkpoint 127 | save_model(model, args.model_save_dir, 'epoch-{}'.format(epoch_idx+1), args) 128 | 129 | if args.evaluate_during_training: 130 | test_loss, test_gen_accuracy = evaluate(args, model, val_dataset, val_dataloader, prefix="epoch-{}".format(epoch_idx+1)) 131 | tb_writer.add_scalar('valid/loss', test_loss, epoch_idx+1) 132 | tb_writer.add_scalar('valid/gen_accuracy', test_gen_accuracy, epoch_idx+1) 133 | 134 | run_inf(args, model, val_dataset, val_dataloader, tb_writer, global_step, prefix="epoch-{}".format(epoch_idx+1), data_type="valid") 135 | -------------------------------------------------------------------------------- /generation/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import
torch 3 | import random 4 | import numpy as np 5 | 6 | def set_seed(args): 7 | random.seed(args.seed) 8 | np.random.seed(args.seed) 9 | torch.manual_seed(args.seed) 10 | torch.cuda.manual_seed(args.seed) 11 | torch.cuda.manual_seed_all(args.seed) 12 | torch.backends.cudnn.deterministic = True 13 | 14 | def save_model(model, output_dir, save_name, args): 15 | save_dir = os.path.join(output_dir, save_name) 16 | if not os.path.exists(save_dir): 17 | os.makedirs(save_dir) 18 | model.save_pretrained(save_dir) 19 | torch.save(args, os.path.join(save_dir, 'training_args.bin')) 20 | 21 | def squeeze_lst(lst): 22 | tmp = [] 23 | for x in lst: 24 | if x not in tmp: 25 | tmp.append(x) 26 | return tmp 27 | 28 | def get_cmekg_entity_specific(data_name): 29 | entity_lst = [] 30 | entity_dict = dict() 31 | with open(f"../data/cmekg/entities_{data_name}.txt", "r") as f: 32 | for i, line in enumerate(f): 33 | entity_lst.append(line.strip()) 34 | entity_dict[line.strip()] = i 35 | return entity_lst, entity_dict 36 | -------------------------------------------------------------------------------- /images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaishxu/DFMed/ad2f57eb77a9f093d5fec24cf464e96e1e384bd4/images/framework.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | alabaster==0.7.12 3 | anyio==3.6.2 4 | argon2-cffi==21.3.0 5 | argon2-cffi-bindings==21.2.0 6 | asttokens==2.0.8 7 | attrs==22.1.0 8 | Babel==2.10.3 9 | backcall==0.2.0 10 | beautifulsoup4==4.11.1 11 | bleach==5.0.1 12 | blessed==1.20.0 13 | blis==0.2.4 14 | boto3==1.24.2 15 | botocore==1.27.2 16 | cachetools==4.2.4 17 | certifi==2022.9.24 18 | cffi==1.15.1 19 | charset-normalizer==2.1.1 20 | click==8.1.3 21 | conllu==1.3.1 22 | cycler==0.11.0 23 | cymem==2.0.6 24 | Cython==0.29.32 25 | debugpy==1.6.3 26 | decorator==5.1.1 27 | defusedxml==0.7.1 28 | docutils==0.19 29 | editdistance==0.6.0 30 | entrypoints==0.4 31 | executing==1.1.1 32 | fastjsonschema==2.16.2 33 | filelock==3.8.0 34 | flaky==3.7.0 35 | Flask==2.2.2 36 | Flask-Cors==3.0.10 37 | fonttools==4.32.0 38 | ftfy==6.1.1 39 | fuzzywuzzy==0.18.0 40 | gensim==3.8.3 41 | gevent==22.10.1 42 | greenlet==1.1.3.post0 43 | grpcio==1.43.0 44 | h5py==3.7.0 45 | huggingface-hub==0.10.1 46 | idna==3.4 47 | imagesize==1.4.1 48 | importlib-metadata==4.10.0 49 | importlib-resources==5.10.0 50 | iniconfig==1.1.1 51 | ipywidgets==8.0.2 52 | itsdangerous==2.1.2 53 | jedi==0.18.1 54 | Jinja2==3.1.2 55 | jmespath==1.0.0 56 | joblib==1.2.0 57 | jsonnet==0.18.0 58 | jsonpickle==2.2.0 59 | jsonschema==4.17.0 60 | kiwisolver==1.4.4 61 | Markdown==3.3.6 62 | MarkupSafe==2.1.1 63 | matplotlib==3.5.1 64 | matplotlib-inline==0.1.6 65 | mistune==2.0.4 66 | murmurhash==1.0.7 67 | nbclassic==0.4.8 68 | nbclient==0.7.0 69 | nbconvert==7.2.3 70 | nbformat==5.7.0 71 | nest-asyncio==1.5.6 72 | nltk==3.4.1 73 | notebook==6.5.2 74 | notebook_shim==0.2.2 75 | numpy==1.23.4 76 | numpydoc==1.5.0 77 | nvidia-ml-py==12.535.108 78 | oauthlib==3.1.1 79 | overrides==3.1.0 80 | packaging==21.3 81 | pandas==1.5.2 82 | pandocfilters==1.5.0 83 | parsimonious==0.9.0 84 | parso==0.8.3 85 | pexpect==4.8.0 86 | pickleshare==0.7.5 87 | Pillow==9.2.0 88 | pkgutil_resolve_name==1.3.10 89 | plac==0.9.6 90 | pluggy==1.0.0 91 | preshed==2.0.1 92 | prometheus-client==0.15.0 93 | prompt-toolkit==3.0.31 94 | 
protobuf==3.19.1 95 | psutil==5.9.3 96 | ptyprocess==0.7.0 97 | pure-eval==0.2.2 98 | py==1.11.0 99 | pyasn1==0.4.8 100 | pyasn1-modules==0.2.8 101 | pycparser==2.21 102 | Pygments==2.13.0 103 | pygtrie==2.4.2 104 | pyparsing==3.0.9 105 | pyrsistent==0.19.2 106 | pytest==7.1.3 107 | python-dateutil==2.8.2 108 | pytorch-pretrained-bert==0.6.2 109 | pytorch-transformers==1.1.0 110 | pytz==2022.5 111 | PyYAML==6.0 112 | pyzmq==24.0.1 113 | qtconsole==5.4.0 114 | QtPy==2.2.1 115 | rapidfuzz==2.13.2 116 | regex==2022.9.13 117 | requests==2.28.1 118 | requests-oauthlib==1.3.0 119 | responses==0.21.0 120 | rouge==1.0.1 121 | rsa==4.8 122 | s3transfer==0.6.0 123 | sacremoses==0.0.46 124 | scikit-learn==1.1.2 125 | scipy==1.9.2 126 | Send2Trash==1.8.0 127 | sentencepiece==0.1.96 128 | setuptools==59.5.0 129 | six==1.16.0 130 | smart-open==5.2.1 131 | sniffio==1.3.0 132 | snowballstemmer==2.2.0 133 | soupsieve==2.3.2.post1 134 | Sphinx==5.3.0 135 | sphinxcontrib-applehelp==1.0.2 136 | sphinxcontrib-devhelp==1.0.2 137 | sphinxcontrib-htmlhelp==2.0.0 138 | sphinxcontrib-jsmath==1.0.1 139 | sphinxcontrib-qthelp==1.0.3 140 | sphinxcontrib-serializinghtml==1.1.5 141 | sqlparse==0.4.2 142 | srsly==1.0.5 143 | stack-data==0.5.1 144 | synonyms==3.16.0 145 | tensorboard==2.7.0 146 | tensorboard-data-server==0.6.1 147 | tensorboard-plugin-wit==1.8.0 148 | tensorboardX==2.5 149 | termcolor==1.1.0 150 | terminado==0.17.0 151 | thinc==7.0.8 152 | threadpoolctl==3.1.0 153 | tinycss2==1.2.1 154 | tokenizers==0.13.1 155 | tomli==2.0.1 156 | tornado==6.2 157 | tqdm==4.64.1 158 | traitlets==5.5.0 159 | transformers==4.23.1 160 | typing_extensions==4.2.0 161 | typing-utils==0.1.0 162 | Unidecode==1.3.4 163 | urllib3==1.26.12 164 | wasabi==0.9.1 165 | wcwidth==0.2.5 166 | webencodings==0.5.1 167 | websocket-client==1.4.2 168 | Werkzeug==2.2.2 169 | wheel==0.37.1 170 | widgetsnbextension==4.0.3 171 | word2number==1.1 172 | xdg==5.1.1 173 | zipp==3.9.0 174 | zope.event==4.5.0 175 | zope.interface==5.5.0 --------------------------------------------------------------------------------