├── data
│   ├── idx2ndc.pkl
│   ├── idx2drug.pkl
│   ├── ndc2drug.pkl
│   ├── ddi_A_final.pkl
│   ├── ddi_mask_H.pkl
│   ├── ehr_adj_final.pkl
│   ├── ddi_mask_H.py
│   ├── get_SMILES.py
│   ├── processing.py
│   └── drug-atc.csv
├── src
│   ├── __pycache__
│   │   ├── beam.cpython-39.pyc
│   │   ├── loss.cpython-39.pyc
│   │   ├── util.cpython-39.pyc
│   │   ├── layers.cpython-39.pyc
│   │   ├── models.cpython-39.pyc
│   │   ├── model_v2.cpython-39.pyc
│   │   ├── recommend.cpython-39.pyc
│   │   ├── data_loader.cpython-39.pyc
│   │   └── ablation_model.cpython-39.pyc
│   ├── loss.py
│   ├── layers.py
│   ├── beam.py
│   ├── COGNet.py
│   ├── data_loader.py
│   ├── MICRON.py
│   ├── recommend.py
│   ├── util.py
│   ├── COGNet_model.py
│   └── models.py
├── mimic_env.yaml
└── README.md

/data/idx2ndc.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/idx2ndc.pkl
--------------------------------------------------------------------------------

/data/idx2drug.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/idx2drug.pkl
--------------------------------------------------------------------------------

/data/ndc2drug.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ndc2drug.pkl
--------------------------------------------------------------------------------

/data/ddi_A_final.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ddi_A_final.pkl
--------------------------------------------------------------------------------

/data/ddi_mask_H.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ddi_mask_H.pkl
--------------------------------------------------------------------------------

/data/ehr_adj_final.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ehr_adj_final.pkl
--------------------------------------------------------------------------------

/src/__pycache__/beam.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/beam.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/layers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/layers.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/models.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/models.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/model_v2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/model_v2.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/recommend.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/recommend.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/data_loader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/data_loader.cpython-39.pyc
--------------------------------------------------------------------------------

/src/__pycache__/ablation_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/ablation_model.cpython-39.pyc
--------------------------------------------------------------------------------

/data/ddi_mask_H.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 | from rdkit.Chem import BRICS
3 | import dill
4 | import numpy as np
5 | # from pinky.smiles import smilin
6 | # from pinky.fingerprints import ecfp
7 |
8 | NDCList = dill.load(open('./idx2drug.pkl', 'rb'))
9 | voc = dill.load(open('./voc_final.pkl', 'rb'))
10 | med_voc = voc['med_voc']
11 |
12 | """
13 | fraction = []
14 | for k, v in med_voc.idx2word.items():
15 |     tempF = []
16 |     for SMILES in NDCList[v]:
17 |         try:
18 |             m = ecfp(smilin(SMILES), radius=2)
19 |             tempF.append(m)
20 |         except:
21 |             pass
22 |     fraction.append(tempF)
23 |
24 |
25 | """
26 | NDCList[22] = {0}
27 | NDCList[25] = {0}
28 | NDCList[27] = {0}
29 | fraction = []
30 | for k, v in med_voc.idx2word.items():
31 |     tempF = set()
32 |     for SMILES in NDCList[v]:
33 |         try:
34 |             m = BRICS.BRICSDecompose(Chem.MolFromSmiles(SMILES))
35 |             for frac in m:
36 |                 tempF.add(frac)
37 |         except:
38 |             pass
39 |
40 |     fraction.append(tempF)
41 |
42 | fracSet = []
43 | for i in fraction:
44 |     fracSet += i
45 | fracSet = list(set(fracSet))
46 |
47 | ddi_matrix = np.zeros((len(med_voc.idx2word), len(fracSet)))
48 |
49 | for i, fracList in enumerate(fraction):
50 |     for frac in fracList:
51 |         ddi_matrix[i, fracSet.index(frac)] = 1
52 |
53 | dill.dump(ddi_matrix, open('ddi_mask_H.pkl', 'wb'))
--------------------------------------------------------------------------------
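The substructure mask above hinges on one RDKit call: BRICS decomposition of each medication's SMILES strings. A minimal, self-contained sketch of that single step (aspirin is an arbitrary example molecule, not from the dataset):

```python
from rdkit import Chem
from rdkit.Chem import BRICS

# decompose one SMILES string into BRICS fragments, as ddi_mask_H.py
# does per medication; each distinct fragment becomes one mask column
mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
fragments = BRICS.BRICSDecompose(mol)
print(sorted(fragments))
```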
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class cross_entropy_loss(nn.Module):
7 |     def __init__(self, device):
8 |         super().__init__()
9 |         self.device = device
10 |
11 |     def forward(self, labels, logits, seq_length, m_length_matrix, med_num, END_TOKEN):
12 |         # labels: [batch_size, max_seq_length, max_med_num]
13 |         # logits: [batch_size, max_seq_length, max_med_num, med_num]
14 |         # m_length_matrix: [batch_size, seq_length]
15 |         # seq_length: [batch_size]
16 |
17 |         batch_size, max_seq_length = labels.size()[:2]
18 |         assert max_seq_length == max(seq_length)
19 |         whole_seqs_num = seq_length.sum().item()
20 |         whole_med_sum = sum([sum(buf) for buf in m_length_matrix]) + whole_seqs_num
21 |
22 |         labels_flatten = torch.empty(whole_med_sum).to(self.device)
23 |         logits_flatten = torch.empty(whole_med_sum, med_num).to(self.device)
24 |
25 |         start_idx = 0
26 |         for i in range(batch_size):
27 |             for j in range(seq_length[i]):
28 |                 for k in range(m_length_matrix[i][j]+1):
29 |                     if k==m_length_matrix[i][j]:
30 |                         labels_flatten[start_idx] = END_TOKEN
31 |                     else:
32 |                         labels_flatten[start_idx] = labels[i, j, k]
33 |                     logits_flatten[start_idx, :] = logits[i, j, k, :]
34 |                     start_idx += 1
35 |
36 |
37 |         loss = F.cross_entropy(logits_flatten, labels_flatten.long())
38 |         return loss
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
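To make the flattening in cross_entropy_loss concrete, here is a minimal usage sketch with made-up toy shapes (note the decoder emits one extra slot per visit, which this loss labels with END_TOKEN; run from src/ so that loss.py is importable):

```python
import torch
from loss import cross_entropy_loss

med_num, END_TOKEN = 6, 5                  # toy vocabulary: 5 meds + an END id
labels = torch.zeros(2, 2, 3)              # [batch, max_seq, max_med_num]
logits = torch.rand(2, 2, 4, med_num)      # one extra med slot for the END token
seq_length = torch.tensor([2, 1])          # visits per patient
m_length_matrix = [[2, 3], [1]]            # medications per visit
criterion = cross_entropy_loss(device=torch.device('cpu'))
print(criterion(labels, logits, seq_length, m_length_matrix, med_num, END_TOKEN))
```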
/src/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.nn.parameter import Parameter
7 |
8 |
9 | class GraphConvolution(nn.Module):
10 |     """
11 |     Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
12 |     """
13 |
14 |     def __init__(self, in_features, out_features, bias=True):
15 |         super(GraphConvolution, self).__init__()
16 |         self.in_features = in_features
17 |         self.out_features = out_features
18 |         self.weight = Parameter(torch.FloatTensor(in_features, out_features))
19 |         if bias:
20 |             self.bias = Parameter(torch.FloatTensor(out_features))
21 |         else:
22 |             self.register_parameter('bias', None)
23 |         self.reset_parameters()
24 |
25 |     def reset_parameters(self):
26 |         stdv = 1. / math.sqrt(self.weight.size(1))
27 |         self.weight.data.uniform_(-stdv, stdv)
28 |         if self.bias is not None:
29 |             self.bias.data.uniform_(-stdv, stdv)
30 |
31 |     def forward(self, input, adj):
32 |         support = torch.mm(input, self.weight)
33 |         output = torch.mm(adj, support)
34 |         if self.bias is not None:
35 |             return output + self.bias
36 |         else:
37 |             return output
38 |
39 |     def __repr__(self):
40 |         return self.__class__.__name__ + ' (' \
41 |                + str(self.in_features) + ' -> ' \
42 |                + str(self.out_features) + ')'
43 |
44 |
45 | class SelfAttend(nn.Module):
46 |     def __init__(self, embedding_size: int) -> None:
47 |         super(SelfAttend, self).__init__()
48 |
49 |         self.h1 = nn.Sequential(
50 |             nn.Linear(embedding_size, 32),
51 |             nn.Tanh()
52 |         )
53 |
54 |         self.gate_layer = nn.Linear(32, 1)
55 |
56 |     def forward(self, seqs, seq_masks=None):
57 |         """
58 |         :param seqs: shape [batch_size, seq_length, embedding_size]
59 |         :param seq_masks: shape [batch_size, seq_length]
60 |         :return: shape [batch_size, embedding_size]
61 |         """
62 |         gates = self.gate_layer(self.h1(seqs)).squeeze(-1)
63 |         if seq_masks is not None:
64 |             gates = gates + seq_masks
65 |         p_attn = F.softmax(gates, dim=-1)
66 |         p_attn = p_attn.unsqueeze(-1)
67 |         h = seqs * p_attn
68 |         output = torch.sum(h, dim=1)
69 |         return output
--------------------------------------------------------------------------------
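A quick shape check for the GraphConvolution layer above (toy adjacency and feature sizes, chosen arbitrarily; run from src/):

```python
import torch
from layers import GraphConvolution

adj = torch.eye(4)                       # toy 4-node graph, self-loops only
x = torch.rand(4, 8)                     # 8-dim node features
gcn = GraphConvolution(in_features=8, out_features=16)
print(gcn(x, adj).shape)                 # torch.Size([4, 16])
print(gcn)                               # GraphConvolution (8 -> 16)
```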
/data/get_SMILES.py:
--------------------------------------------------------------------------------
1 | import dill
2 | import numpy as np
3 | import pandas as pd
4 | import requests
5 | import re
6 |
7 | # fix mismatch between two mappings
8 | def fix_mismatch(idx2atc, atc2ndc, ndc2atc_original_path):
9 |     ndc2atc = pd.read_csv(open(ndc2atc_original_path, 'rb'))
10 |     ndc2atc.ATC4 = ndc2atc.ATC4.apply(lambda x: x[:4])
11 |
12 |     mismatch = []
13 |     for k, v in idx2atc.items():
14 |         if v in atc2ndc.NDC.tolist():
15 |             pass
16 |         else:
17 |             mismatch.append(v)
18 |
19 |     for i in mismatch:
20 |         atc2ndc = atc2ndc.append({'NDC': i, 'NDC_orig': [s.replace('-', '') for s in ndc2atc[ndc2atc.ATC4 == i].NDC.tolist()]}, ignore_index=True)
21 |
22 |     atc2ndc = atc2ndc.append({'NDC': 'seperator', 'NDC_orig': []}, ignore_index=True)
23 |     atc2ndc = atc2ndc.append({'NDC': 'decoder_point', 'NDC_orig': []}, ignore_index=True)
24 |
25 |     return atc2ndc
26 |
27 | def ndc2smiles(NDC):
28 |     url3 = 'https://ndclist.com/?s=' + NDC
29 |     r3 = requests.get(url3)
30 |     name = re.findall('(.+?)', r3.text)[0]
31 |
32 |     url = 'https://dev.drugbankplus.com/guides/tutorials/api_request?request_path=us/product_concepts?q=' + name
33 |     r = requests.get(url)
34 |     drugbankID = re.findall('(DB\d+)', r.text)[0]
35 |
36 |     # re matching might need to update (drugbank may change their html script)
37 |     url2 = 'https://www.drugbank.ca/drugs/' + drugbankID
38 |     r2 = requests.get(url2)
39 |     SMILES = re.findall('SMILES(.+?)', r2.text)[0]
40 |     return SMILES
41 |
42 | def atc2smiles(atc2ndc):
43 |     atc2SMILES = {}
44 |     for k, ndc in atc2ndc.values:
45 |         if k not in list(atc2SMILES.keys()):
46 |             for index, code in enumerate(ndc):
47 |                 if index > 100: break
48 |                 try:
49 |                     SMILES = ndc2smiles(code)
50 |                     if 'href' in SMILES:
51 |                         continue
52 |                     print (k, index, len(ndc), SMILES)
53 |                     if k not in atc2SMILES:
54 |                         atc2SMILES[k] = set()
55 |                     atc2SMILES[k].add(SMILES)
56 |                     # if len(atc2SMILES[k]) >= 3:
57 |                     #     break
58 |                 except:
59 |                     pass
60 |     return atc2SMILES
61 |
62 |
63 | def idx2smiles(idx2atc, atc2SMILES):
64 |     idx2drug = {}
65 |     idx2drug['seperator'] = {}
66 |     idx2drug['decoder_point'] = {}
67 |
68 |     for idx, atc in idx2atc.items():
69 |         try:
70 |             idx2drug[idx] = atc2SMILES[atc]
71 |         except:
72 |             pass
73 |     dill.dump(idx2drug, open('idx2drug.pkl', 'wb'))
74 |
75 |
76 | if __name__ == '__main__':
77 |     # get idx2atc
78 |     path = './voc_final.pkl'
79 |     voc_final = dill.load(open(path, 'rb'))
80 |     idx2atc = voc_final['med_voc'].idx2word
81 |
82 |     # get atc2ndc
83 |     path = './ndc2drug.pkl'
84 |     atc2ndc = dill.load(open(path, 'rb'))
85 |
86 |     # fix atc2ndc mismatch
87 |     ndc2atc_original_path = './ndc2atc_level4.csv'
88 |     atc2ndc = fix_mismatch(idx2atc, atc2ndc, ndc2atc_original_path)
89 |
90 |     # atc2smiles
91 |     atc2SMILES = atc2smiles(atc2ndc)
92 |
93 |     # idx2smiles (dumped)
94 |     idx2smiles(idx2atc, atc2SMILES)
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
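The script above writes idx2drug.pkl, mapping each medication index (plus the 'seperator'/'decoder_point' keys) to a set of SMILES strings. A quick inspection sketch, assuming the pickle already exists in data/:

```python
import dill

# load the generated index -> SMILES-set mapping and peek at one entry
idx2drug = dill.load(open('./idx2drug.pkl', 'rb'))
print(len(idx2drug), 'entries')
idx, smiles_set = next(iter(idx2drug.items()))
print(idx, smiles_set)
```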
/mimic_env.yaml:
--------------------------------------------------------------------------------
1 | name: mimic
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 |   - defaults
6 | dependencies:
7 |   - _libgcc_mutex=0.1=conda_forge
8 |   - _openmp_mutex=4.5=1_gnu
9 |   - autopep8=1.5.7=pyhd3eb1b0_0
10 |   - blas=1.0=mkl
11 |   - boost=1.74.0=py39h5472131_3
12 |   - boost-cpp=1.74.0=h312852a_4
13 |   - bottleneck=1.3.2=py39hdd57654_1
14 |   - bzip2=1.0.8=h7f98852_4
15 |   - ca-certificates=2021.7.5=h06a4308_1
16 |   - cairo=1.16.0=h6cf1ce9_1008
17 |   - certifi=2021.5.30=py39h06a4308_0
18 |   - cudatoolkit=11.0.221=h6bb024c_0
19 |   - cycler=0.10.0=py_2
20 |   - dill=0.3.4=pyhd3eb1b0_0
21 |   - fontconfig=2.13.1=hba837de_1005
22 |   - freetype=2.10.4=h0708190_1
23 |   - gettext=0.19.8.1=h0b5b191_1005
24 |   - greenlet=1.1.1=py39he80948d_0
25 |   - icu=68.1=h58526e2_0
26 |   - intel-openmp=2021.3.0=h06a4308_3350
27 |   - jbig=2.1=h7f98852_2003
28 |   - joblib=1.0.1=pyhd3eb1b0_0
29 |   - jpeg=9d=h36c2ea0_0
30 |   - kiwisolver=1.3.1=py39h1a9c180_1
31 |   - lcms2=2.12=hddcbb42_0
32 |   - ld_impl_linux-64=2.36.1=hea4e1c9_2
33 |   - lerc=2.2.1=h9c3ff4c_0
34 |   - libdeflate=1.7=h7f98852_5
35 |   - libffi=3.3=h58526e2_2
36 |   - libgcc-ng=11.1.0=hc902ee8_8
37 |   - libgfortran-ng=7.5.0=ha8ba4b0_17
38 |   - libgfortran4=7.5.0=ha8ba4b0_17
39 |   - libglib=2.68.3=h3e27bee_0
40 |   - libgomp=11.1.0=hc902ee8_8
41 |   - libiconv=1.16=h516909a_0
42 |   - libpng=1.6.37=h21135ba_2
43 |   - libstdcxx-ng=11.1.0=h56837e0_8
44 |   - libtiff=4.3.0=hf544144_1
45 |   - libuuid=2.32.1=h7f98852_1000
46 |   - libwebp-base=1.2.0=h7f98852_2
47 |   - libxcb=1.13=h7f98852_1003
48 |   - libxml2=2.9.12=h72842e0_0
49 |   - lz4-c=1.9.3=h9c3ff4c_1
50 |   - matplotlib-base=3.4.2=py39h2fa2bec_0
51 |   - mkl=2021.3.0=h06a4308_520
52 |   - mkl-service=2.4.0=py39h7f8727e_0
53 |   - mkl_fft=1.3.0=py39h42c9631_2
54 |   - mkl_random=1.2.2=py39h51133e4_0
55 |   - ncurses=6.2=h58526e2_4
56 |   - ninja=1.10.2=hff7bd54_1
57 |   - numexpr=2.7.3=py39h22e1b3c_1
58 |   - numpy=1.20.3=py39hf144106_0
59 |   - numpy-base=1.20.3=py39h74d4b33_0
60 |   - olefile=0.46=pyh9f0ad1d_1
61 |   - openjpeg=2.4.0=hb52868f_1
62 |   - openssl=1.1.1k=h27cfd23_0
63 |   - pandas=1.3.1=py39h8c16a72_0
64 |   - pcre=8.45=h9c3ff4c_0
65 |   - pillow=8.3.1=py39ha612740_0
66 |   - pip=21.2.3=pyhd8ed1ab_0
67 |   - pixman=0.40.0=h36c2ea0_0
68 |   - pthread-stubs=0.4=h36c2ea0_1001
69 |   - pycairo=1.20.1=py39hedcb9fc_0
70 |   - pycodestyle=2.7.0=pyhd3eb1b0_0
71 |   - pyparsing=2.4.7=pyh9f0ad1d_0
72 |   - python=3.9.6=h49503c6_1_cpython
73 |   - python-dateutil=2.8.2=pyhd8ed1ab_0
74 |   - python_abi=3.9=2_cp39
75 |   - pytorch=1.7.1=py3.9_cuda11.0.221_cudnn8.0.5_0
76 |   - pytz=2021.1=pyhd8ed1ab_0
77 |   - rdkit=2021.03.4=py39hccf6a74_0
78 |   - readline=8.1=h46c0cb4_0
79 |   - reportlab=3.5.68=py39he59360d_0
80 |   - scikit-learn=0.24.2=py39ha9443f7_0
81 |   - scipy=1.6.2=py39had2a1c9_1
82 |   - setuptools=49.6.0=py39hf3d152e_3
83 |   - six=1.16.0=pyh6c4a22f_0
84 |   - sqlalchemy=1.4.22=py39h3811e60_0
85 |   - sqlite=3.36.0=h9cd32fc_0
86 |   - threadpoolctl=2.2.0=pyhb85f177_0
87 |   - tk=8.6.10=h21135ba_1
88 |   - toml=0.10.2=pyhd3eb1b0_0
89 |   - torchaudio=0.7.2=py39
90 |   - torchvision=0.2.2=py_3
91 |   - tornado=6.1=py39h3811e60_1
92 |   - typing_extensions=3.10.0.0=pyh06a4308_0
93 |   - tzdata=2021a=he74cb21_1
94 |   - wheel=0.36.2=pyhd3deb0d_0
95 |   - xorg-kbproto=1.0.7=h7f98852_1002
96 |   - xorg-libice=1.0.10=h7f98852_0
97 |   - xorg-libsm=1.2.3=hd9c2040_1000
98 |   - xorg-libx11=1.7.2=h7f98852_0
99 |   - xorg-libxau=1.0.9=h7f98852_0
100 |   - xorg-libxdmcp=1.1.3=h7f98852_0
101 |   - xorg-libxext=1.3.4=h7f98852_1
102 |   - xorg-libxrender=0.9.10=h7f98852_1003
103 |   - xorg-renderproto=0.11.1=h7f98852_1002
104 |   - xorg-xextproto=7.3.0=h7f98852_1002
105 |   - xorg-xproto=7.0.31=h7f98852_1007
106 |   - xz=5.2.5=h516909a_1
107 |   - zlib=1.2.11=h516909a_1010
108 |   - zstd=1.5.0=ha95c52a_0
109 |   - pip:
110 |     - blessings==1.7
111 |     - dnc==1.1.0
112 |     - flann==1.6.13
113 |     - gpustat==0.6.0
114 |     - nvidia-ml-py3==7.352.0
115 |     - psutil==5.8.0
116 | prefix: /home/wurui/miniconda3/envs/mimic
--------------------------------------------------------------------------------
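After `conda env create`, a quick sanity check that the pinned stack imports (expected versions per the yaml above; the print values are what the environment should report, not guaranteed on every machine):

```python
import torch, rdkit, dill, pandas

print(torch.__version__)           # expected 1.7.1 per mimic_env.yaml
print(torch.cuda.is_available())   # True only with the cudatoolkit 11.0 build and a GPU
```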
/README.md:
--------------------------------------------------------------------------------
1 | # Implementation of WWW 2022 paper: Conditional Generation Net for Medication Recommendation
2 |
3 | ### Folder Specification
4 | - mimic_env.yaml
5 | - src/
6 |     - COGNet.py: train/test COGNet
7 |     - recommend.py: some test functions used for COGNet
8 |     - COGNet_model.py: full model of COGNet
9 |     - COGNet_ablation.py: ablation models of COGNet
10 |     - train/test baselines:
11 |         - MICRON.py
12 |         - Other code for training/testing baselines can be found [here](https://github.com/ycq091044/SafeDrug).
13 |     - models.py: baseline models
14 |     - util.py
15 |     - layers.py
16 | - data/ **(For a fair comparison, we use the same data and pre-processing scripts used in [SafeDrug](https://github.com/ycq091044/SafeDrug))**
17 |     - mapping files collected from external sources
18 |         - drug-atc.csv: drug to ATC code mapping file
19 |         - drug-DDI.csv: this is a large file; it can be downloaded from https://drive.google.com/file/d/1mnPc0O0ztz0fkv3HF-dpmBb8PLWsEoDz/view?usp=sharing
20 |         - ndc2atc_level4.csv: NDC code to ATC-4 code mapping file
21 |         - ndc2rxnorm_mapping.txt: NDC to RxNorm mapping file
22 |         - idx2drug.pkl: drug ID to drug SMILES string dict
23 |     - other files generated from the mapping files and the MIMIC dataset (we attach these files here; users can also regenerate them with our provided scripts)
24 |         - data_final.pkl: intermediate result
25 |         - ddi_A_final.pkl: ddi matrix
26 |         - ddi_mask_H.pkl: H mask structure (created by ddi_mask_H.py), used in the SafeDrug baseline
27 |         - idx2ndc.pkl: idx2ndc mapping file
28 |         - ndc2drug.pkl: ndc2drug mapping file
29 |         - Under the MIMIC dataset policy, we are not allowed to distribute the dataset. Practitioners can go to https://physionet.org/content/mimiciii/1.4/, request access to the MIMIC-III dataset, and then run our processing script to get the complete preprocessed dataset file.
30 |         - voc_final.pkl: diag/proc/med dictionary
31 |     - dataset processing scripts
32 |         - processing.py: used to process the original MIMIC dataset.
33 |
34 |
35 |
36 |
37 | ### Step 1: Data Processing
38 |
39 | - Go to https://physionet.org/content/mimiciii/1.4/ to download the MIMIC-III dataset (you may need to obtain the certificate)
40 |
41 | - Go into the folder and unzip the three main files (PROCEDURES_ICD.csv.gz, PRESCRIPTIONS.csv.gz, DIAGNOSES_ICD.csv.gz)
42 |
43 | - Change the paths in processing.py and process the data to get the complete records_final.pkl
44 |
45 | ```bash
46 | vim processing.py
47 |
48 | # line 310-312
49 | # med_file = '/data/mimic-iii/PRESCRIPTIONS.csv'
50 | # diag_file = '/data/mimic-iii/DIAGNOSES_ICD.csv'
51 | # procedure_file = '/data/mimic-iii/PROCEDURES_ICD.csv'
52 |
53 | python processing.py
54 | ```
55 |
56 | - Run ddi_mask_H.py to get ddi_mask_H.pkl
57 |
58 | ```bash
59 | python ddi_mask_H.py
60 | ```
61 |
62 |
63 |
64 | ### Step 2: Package Dependency
65 |
66 | - First, install [conda](https://www.anaconda.com/)
67 |
68 | - Then, create the conda environment from the yaml file
69 | ```bash
70 | conda env create -f mimic_env.yaml
71 | ```
72 | Note: you may need to upgrade PyTorch to version 1.10.0. (Thanks a lot to [Thomaswbt](https://github.com/Thomaswbt)!)
73 |
74 |
75 | ### Step 3: Run the code
76 |
77 | ```bash
78 | python COGNet.py
79 | ```
80 |
81 | Here are the arguments:
82 |
83 |     usage: COGNet.py [-h] [--Test] [--model_name MODEL_NAME]
84 |                      [--resume_path RESUME_PATH] [--lr LR]
85 |                      [--batch_size BATCH_SIZE] [--emb_dim EMB_DIM] [--max_len MAX_LEN] [--beam_size BEAM_SIZE]
86 |
87 |     optional arguments:
88 |       -h, --help            show this help message and exit
89 |       --Test                test mode
90 |       --model_name MODEL_NAME
91 |                             model name
92 |       --resume_path RESUME_PATH
93 |                             resume path
94 |       --lr LR               learning rate
95 |       --batch_size BATCH_SIZE   batch size
96 |       --emb_dim EMB_DIM     dimension size of embedding
97 |       --max_len MAX_LEN     max number of recommended medications
98 |       --beam_size BEAM_SIZE number of ways in beam search
99 |
100 | If you cannot run the code on GPU, just change line 61, "cuda" to "cpu".
101 |
102 | ### Citation
103 | ```bibtex
104 | @inproceedings{wu2022cognet,
105 |   title = {Conditional Generation Net for Medication Recommendation},
106 |   author = {Rui Wu and Zhaopeng Qiu and Jiacheng Jiang and Guilin Qi and Xian Wu},
107 |   booktitle = {{WWW} '22: The Web Conference 2022, Virtual Event, Lyon, France, April 25-29, 2022},
108 |   year = {2022}
109 | }
110 | ```
111 |
112 | Please feel free to contact me with any questions.
113 |
114 | Partial credit to previous repositories:
115 | - https://github.com/sjy1203/GAMENet
116 | - https://github.com/ycq091044/SafeDrug
117 | - https://github.com/ycq091044/MICRON
118 |
119 | Thank [Chaoqi Yang](https://github.com/ycq091044) and [Junyuan Shang](https://github.com/sjy1203) for releasing their codes!
120 |
121 | Thank my mentor, [Zhaopeng Qiu](https://github.com/zpqiu), for helping me complete the code.
122 |
--------------------------------------------------------------------------------
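As a concrete version of the GPU/CPU note in the README above, a device fallback you could use instead of the hard-coded `torch.device('cuda')` (a sketch, not part of the original script):

```python
import torch

# picks the GPU when available and falls back to CPU otherwise
# (an alternative to editing line 61 of COGNet.py by hand)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
```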
/src/beam.py:
--------------------------------------------------------------------------------
1 | """ Manage beam search info structure.
2 | Heavily borrowed from OpenNMT-py.
3 | For code in OpenNMT-py, please check the following link:
4 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
5 | """
6 |
7 | # from shutil import which
8 | from util import ddi_rate_score
9 | import torch
10 | import numpy as np
11 | import copy
12 | import random
13 | # from torch.autograd.grad_mode import F
14 |
15 |
16 | class Beam(object):
17 |     ''' Store the necessary info for beam search '''
18 |     def __init__(self, size, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, ddi_adj, device):
19 |         self.ddi_adj = ddi_adj
20 |         self.PAD = PAD_TOKEN
21 |         self.BOS = BOS_TOKEN
22 |         self.EOS = EOS_TOKEN
23 |         # print(PAD_TOKEN, EOS_TOKEN, BOS_TOKEN)
24 |
25 |         self.device = device
26 |         self.size = size
27 |         self.done = False  # whether the beam search has finished
28 |
29 |         self.beam_status = [False] * size  # records whether each beam has reached the EOS state
30 |
31 |         self.tt = torch.cuda if device.type=='cuda' else torch
32 |
33 |         # score of each hypothesis; initialized to beam_size zeros
34 |         self.scores = self.tt.FloatTensor(size).zero_()
35 |         self.all_scores = []
36 |
37 |         # The backpointers at each time-step.
38 |         self.prev_ks = []
39 |
40 |         # The outputs at each time-step.
41 |         self.next_ys = [self.tt.LongTensor(size).fill_(self.BOS)]
42 |         self.prob_list = []
43 |
44 |
45 |     def get_current_state(self, sort=True):
46 |         "Get the outputs for the current timestep."
47 |         if sort:
48 |             return self.get_tentative_hypothesis()
49 |         else:
50 |             return self.get_tentative_hypothesis_wo_sort()
51 |
52 |     def get_current_origin(self):
53 |         "Get the backpointers for the current timestep."
54 |         return self.prev_ks[-1]
55 |
56 |     def advance(self, word_lk):
57 |         "Update the status and check for finished or not."
58 |         num_words = word_lk.size(1)
59 |         if self.done:
60 |             self.prev_ks.append(torch.tensor(list(range(self.size)), device=self.device))
61 |             self.next_ys.append(torch.tensor([self.EOS]*self.size, device=self.device))
62 |             self.prob_list.append([[0]*num_words, [0]*num_words])
63 |             return True
64 |
65 |         active_beam_idx = torch.tensor([idx for idx in range(self.size) if self.beam_status[idx]==False]).long().to(self.device)
66 |         end_beam_idx = torch.tensor([idx for idx in range(self.size) if self.beam_status[idx]==True]).long().to(self.device)
67 |         active_word_lk = word_lk[active_beam_idx]  # active_beam_num * num_words
68 |
69 |         cur_output = self.get_current_state(sort=False)
70 |
71 |         active_scores = self.scores[active_beam_idx]
72 |         end_scores = self.scores[end_beam_idx]
73 |
74 |         if len(self.prev_ks) > 0:
75 |             beam_lk = active_word_lk + active_scores.unsqueeze(dim=1).expand_as(active_word_lk)  # (active_beam_num, num_words)
76 |         else:
77 |             beam_lk = active_word_lk[0]
78 |
79 |         flat_beam_lk = beam_lk.view(-1)
80 |         active_max_idx = len(flat_beam_lk)
81 |         flat_beam_lk = torch.cat([flat_beam_lk, end_scores], dim=-1)
82 |
83 |         self.all_scores.append(self.scores)
84 |
85 |
86 |         sorted_scores, sorted_score_ids = torch.sort(flat_beam_lk, descending=True)
87 |         select_num, cur_idx = 0, 0
88 |         selected_scores = []
89 |         selected_words = []
90 |         selected_beams = []
91 |         new_active_status = []
92 |
93 |         prob_buf = []
94 |         while select_num < self.size:
95 |             cur_score, cur_id = sorted_scores[cur_idx], sorted_score_ids[cur_idx]
96 |             if cur_id >= active_max_idx:
97 |                 which_beam = end_beam_idx[cur_id-active_max_idx]
98 |                 which_word = torch.tensor(self.EOS).to(self.device)
99 |                 select_num += 1
100 |                 new_active_status.append(True)
101 |                 selected_scores.append(cur_score)
102 |                 selected_beams.append(which_beam)
103 |                 selected_words.append(which_word)
104 |                 prob_buf.append([0]*num_words)
105 |             else:
106 |                 which_beam_idx = cur_id // num_words
107 |                 which_beam = active_beam_idx[which_beam_idx]
108 |                 which_word = cur_id - which_beam_idx*num_words
109 |                 if which_word not in cur_output[which_beam]:
110 |                     if which_word in [self.EOS, self.BOS]:
111 |                         new_active_status.append(True)
112 |                     else:
113 |                         new_active_status.append(False)
114 |                     select_num += 1
115 |                     selected_scores.append(cur_score)
116 |                     selected_beams.append(which_beam)
117 |                     selected_words.append(which_word)
118 |                     prob_buf.append(active_word_lk[which_beam_idx].detach().cpu().numpy().tolist())
119 |             cur_idx += 1
120 |         self.prob_list.append(prob_buf)
121 |
122 |         self.beam_status = new_active_status
123 |         self.scores = torch.stack(selected_scores)
124 |         self.prev_ks.append(torch.stack(selected_beams))
125 |         self.next_ys.append(torch.stack(selected_words))
126 |
127 |         if_done = True
128 |         for i in range(self.size):
129 |             if not self.beam_status[i]:
130 |                 if_done = False
131 |                 break
132 |         if if_done:
133 |             self.done=True
134 |
135 |         return self.done
136 |
137 |     def sort_scores(self):
138 |         "Sort the scores."
139 |         return torch.sort(self.scores, 0, True)
140 |
141 |
142 |     def get_tentative_hypothesis(self):
143 |         "Get the decoded sequence for the current timestep."
144 |
145 |         if len(self.next_ys) == 1:
146 |             dec_seq = self.next_ys[0].unsqueeze(1)
147 |         else:
148 |             _, keys = self.sort_scores()
149 |             hyps = [self.get_hypothesis(k) for k in keys]
150 |             hyps = [[self.BOS] + h for h in hyps]
151 |             dec_seq = torch.from_numpy(np.array(hyps)).long().to(self.device)
152 |         return dec_seq
153 |
154 |     def get_tentative_hypothesis_wo_sort(self):
155 |         "Get the decoded sequence for the current timestep."
156 |
157 |         if len(self.next_ys) == 1:
158 |             dec_seq = self.next_ys[0].unsqueeze(1)
159 |         else:
160 |             keys = list(range(self.size))
161 |             hyps = [self.get_hypothesis(k) for k in keys]
162 |             hyps = [[self.BOS] + h for h in hyps]
163 |             dec_seq = torch.from_numpy(np.array(hyps)).long().to(self.device)
164 |         return dec_seq
165 |
166 |     def get_hypothesis(self, k):
167 |         """
168 |         Walk back to construct the full hypothesis.
169 |         Parameters.
170 |              * `k` - the position in the beam to construct.
171 |         Returns.
172 |             1. The hypothesis
173 |             2. The attention at each time step.
174 |         """
175 |         hyp = []
176 |         for j in range(len(self.prev_ks)-1, -1, -1):
177 |             hyp.append(self.next_ys[j + 1][k].item())
178 |             k = self.prev_ks[j][k]
179 |         return hyp[::-1]
180 |
181 |     def get_prob_list(self, k):
182 |         ret_prob_list = []
183 |         for j in range(len(self.prev_ks)-1, -1, -1):
184 |             ret_prob_list.append(self.prob_list[j][k])
185 |             k = self.prev_ks[j][k]
186 |         return ret_prob_list[::-1]
--------------------------------------------------------------------------------
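A minimal driver sketch for the Beam class above, with toy token ids and random log-probabilities (all sizes and ids are made up; run from src/ so that util.py is importable):

```python
import numpy as np
import torch
from beam import Beam

device = torch.device('cpu')
vocab = 10                                   # toy vocabulary incl. special ids
beam = Beam(size=3, PAD_TOKEN=7, BOS_TOKEN=8, EOS_TOKEN=9,
            ddi_adj=np.zeros((vocab, vocab)), device=device)

for _ in range(4):                           # up to 4 decoding steps
    word_lk = torch.log_softmax(torch.rand(3, vocab), dim=-1)  # [beam, vocab]
    if beam.advance(word_lk):                # returns True once all beams hit EOS
        break
print(beam.get_hypothesis(0))                # backtracked token ids of beam 0
```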
/src/COGNet.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import torch
3 | import torch.nn as nn
4 | import argparse
5 | from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
6 | import numpy as np
7 | import dill
8 | import time
9 | from torch.nn import CrossEntropyLoss
10 | from torch.optim import Adam
11 | from torch.utils import data
12 | from loss import cross_entropy_loss
13 | import os
14 | import torch.nn.functional as F
15 | import random
16 | from collections import defaultdict
17 |
18 | from torch.utils.data.dataloader import DataLoader
19 | from data_loader import mimic_data, pad_batch_v2_train, pad_batch_v2_eval, pad_num_replace
20 |
21 | import sys
22 | sys.path.append("..")
23 | from COGNet_ablation import COGNet_wo_copy, COGNet_wo_visit_score, COGNet_wo_graph, COGNet_wo_diag, COGNet_wo_proc
24 | from COGNet_model import COGNet, policy_network
25 | from util import llprint, sequence_metric, sequence_output_process, ddi_rate_score, get_n_params, output_flatten, print_result
26 | from recommend import eval, test
27 |
28 | torch.manual_seed(1203)
29 |
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
31 |
32 | model_name = 'COGNet'
33 | resume_path = ''
34 |
35 | if not os.path.exists(os.path.join("saved", model_name)):
36 |     os.makedirs(os.path.join("saved", model_name))
37 |
38 | # Training settings
39 | parser = argparse.ArgumentParser()
40 | # parser.add_argument('--Test', action='store_true', default=True, help="test mode")
41 | parser.add_argument('--Test', action='store_true', default=False, help="test mode")
42 | parser.add_argument('--model_name', type=str, default=model_name, help="model name")
43 | parser.add_argument('--resume_path', type=str, default=resume_path, help='resume path')
44 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
45 | parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
46 | parser.add_argument('--emb_dim', type=int, default=64, help='embedding dimension size')
47 | parser.add_argument('--max_len', type=int, default=45, help='maximum prediction medication sequence')
48 | parser.add_argument('--beam_size', type=int, default=4, help='max num of sentences in beam searching')
49 |
50 | args = parser.parse_args()
51 |
52 | def main(args):
53 |     # load data
54 |     data_path = '../data/records_final.pkl'
55 |     voc_path = '../data/voc_final.pkl'
56 |
57 |     # ehr_adj_path = '../data/weighted_ehr_adj_final.pkl'
58 |     ehr_adj_path = '../data/ehr_adj_final.pkl'
59 |     ddi_adj_path = '../data/ddi_A_final.pkl'
60 |     ddi_mask_path = '../data/ddi_mask_H.pkl'
61 |     device = torch.device('cuda')
62 |
63 |     data = dill.load(open(data_path, 'rb'))
64 |     voc = dill.load(open(voc_path, 'rb'))
65 |     ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
66 |     ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
67 |     ddi_mask_H = dill.load(open(ddi_mask_path, 'rb'))
68 |
69 |     diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']
70 |     print(f"Diag num:{len(diag_voc.idx2word)}")
71 |     print(f"Proc num:{len(pro_voc.idx2word)}")
72 |     print(f"Med num:{len(med_voc.idx2word)}")
73 |
74 |     # frequency statistic
75 |     med_count = defaultdict(int)
76 |     for patient in data:
77 |         for adm in patient:
78 |             for med in adm[2]:
79 |                 med_count[med] += 1
80 |
81 |     ## rare first
82 |     for i in range(len(data)):
83 |         for j in range(len(data[i])):
84 |             cur_medications = sorted(data[i][j][2], key=lambda x:med_count[x])
85 |             data[i][j][2] = cur_medications
86 |
87 |
88 |     split_point = int(len(data) * 2 / 3)
89 |     data_train = data[:split_point]
90 |     eval_len = int(len(data[split_point:]) / 2)
91 |     data_test = data[split_point:split_point + eval_len]
92 |     data_eval = data[split_point+eval_len:]
93 |
94 |     train_dataset = mimic_data(data_train)
95 |     eval_dataset = mimic_data(data_eval)
96 |     test_dataset = mimic_data(data_test)
97 |
98 |     train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_batch_v2_train, shuffle=True, pin_memory=True)
99 |     eval_dataloader = DataLoader(eval_dataset, batch_size=1, collate_fn=pad_batch_v2_eval, shuffle=True, pin_memory=True)
100 |     test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=pad_batch_v2_eval, shuffle=True, pin_memory=True)
101 |
102 |     voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))
103 |
104 |     END_TOKEN = voc_size[2] + 1
105 |     DIAG_PAD_TOKEN = voc_size[0] + 2
106 |     PROC_PAD_TOKEN = voc_size[1] + 2
107 |     MED_PAD_TOKEN = voc_size[2] + 2
108 |     SOS_TOKEN = voc_size[2]
109 |     TOKENS = [END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN]
110 |
111 |     model = COGNet(voc_size, ehr_adj, ddi_adj, ddi_mask_H, emb_dim=args.emb_dim, device=device)
112 |
113 |     if args.Test:
114 |         model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
115 |         model.to(device=device)
116 |         tic = time.time()
117 |         smm_record, ja, prauc, precision, recall, f1, med_num = test(model, test_dataloader, diag_voc, pro_voc, med_voc, voc_size, 0, device, TOKENS, ddi_adj, args)
118 |         result = []
119 |         for _ in range(10):
120 |             data_num = len(ja)
121 |             final_length = int(0.8 * data_num)
122 |             idx_list = list(range(data_num))
123 |             random.shuffle(idx_list)
124 |             idx_list = idx_list[:final_length]
125 |             avg_ja = np.mean([ja[i] for i in idx_list])
126 |             avg_prauc = np.mean([prauc[i] for i in idx_list])
127 |             avg_precision = np.mean([precision[i] for i in idx_list])
128 |             avg_recall = np.mean([recall[i] for i in idx_list])
129 |             avg_f1 = np.mean([f1[i] for i in idx_list])
130 |             avg_med = np.mean([med_num[i] for i in idx_list])
131 |             cur_smm_record = [smm_record[i] for i in idx_list]
132 |             ddi_rate = ddi_rate_score(cur_smm_record, path='../data/ddi_A_final.pkl')
133 |             result.append([ddi_rate, avg_ja, avg_prauc, avg_precision, avg_recall, avg_f1, avg_med])
134 |             llprint('\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
135 |                 ddi_rate, avg_ja, avg_prauc, avg_precision, avg_recall, avg_f1, avg_med))
136 |         result = np.array(result)
137 |         mean = result.mean(axis=0)
138 |         std = result.std(axis=0)
139 |
140 |         outstring = ""
141 |         for m, s in zip(mean, std):
142 |             outstring += "{:.4f} $\pm$ {:.4f} & ".format(m, s)
143 |
144 |         print (outstring)
145 |         print ('test time: {}'.format(time.time() - tic))
146 |         return
147 |
148 |     model.to(device=device)
149 |     print('parameters', get_n_params(model))
150 |     optimizer = Adam(model.parameters(), lr=args.lr)
151 |
152 |     history = defaultdict(list)
153 |     best_epoch, best_ja = 0, 0
154 |
155 |     EPOCH = 200
156 |     for epoch in range(EPOCH):
157 |         tic = time.time()
158 |         print ('\nepoch {} --------------------------'.format(epoch))
159 |
160 |         model.train()
161 |         for idx, data in enumerate(train_dataloader):
162 |             diseases, procedures, medications, seq_length, \
163 |                 d_length_matrix, p_length_matrix, m_length_matrix, \
164 |                 d_mask_matrix, p_mask_matrix, m_mask_matrix, \
165 |                 dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \
166 |                 dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = data
167 |
168 |             diseases = pad_num_replace(diseases, -1, DIAG_PAD_TOKEN).to(device)
169 |             procedures = pad_num_replace(procedures, -1, PROC_PAD_TOKEN).to(device)
170 |             dec_disease = pad_num_replace(dec_disease, -1, DIAG_PAD_TOKEN).to(device)
171 |             stay_disease = pad_num_replace(stay_disease, -1, DIAG_PAD_TOKEN).to(device)
172 |             dec_proc = pad_num_replace(dec_proc, -1, PROC_PAD_TOKEN).to(device)
173 |             stay_proc = pad_num_replace(stay_proc, -1, PROC_PAD_TOKEN).to(device)
174 |             medications = medications.to(device)
175 |             m_mask_matrix = m_mask_matrix.to(device)
176 |             d_mask_matrix = d_mask_matrix.to(device)
177 |             p_mask_matrix = p_mask_matrix.to(device)
178 |             dec_disease_mask = dec_disease_mask.to(device)
179 |             stay_disease_mask = stay_disease_mask.to(device)
180 |             dec_proc_mask = dec_proc_mask.to(device)
181 |             stay_proc_mask = stay_proc_mask.to(device)
182 |             output_logits = model(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask,
183 |                 dec_proc, stay_proc, dec_proc_mask, stay_proc_mask)
184 |             labels, predictions = output_flatten(medications, output_logits, seq_length, m_length_matrix, voc_size[2] + 2, END_TOKEN, device, max_len=args.max_len)
185 |             loss = F.nll_loss(predictions, labels.long())
186 |             optimizer.zero_grad()
187 |             loss.backward()
188 |             optimizer.step()
189 |             llprint('\rtraining step: {} / {}'.format(idx, len(train_dataloader)))
190 |
191 |         print ()
192 |         tic2 = time.time()
193 |         ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, eval_dataloader, voc_size, epoch, device, TOKENS, args)
194 |         print ('training time: {}, test time: {}'.format(time.time() - tic, time.time() - tic2))
195 |
196 |         history['ja'].append(ja)
197 |         history['ddi_rate'].append(ddi_rate)
198 |         history['avg_p'].append(avg_p)
199 |         history['avg_r'].append(avg_r)
200 |         history['avg_f1'].append(avg_f1)
201 |         history['prauc'].append(prauc)
202 |         history['med'].append(avg_med)
203 |
204 |         if epoch >= 5:
205 |             print ('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
206 |                 np.mean(history['ddi_rate'][-5:]),
207 |                 np.mean(history['med'][-5:]),
208 |                 np.mean(history['ja'][-5:]),
209 |                 np.mean(history['avg_f1'][-5:]),
210 |                 np.mean(history['prauc'][-5:])
211 |             ))
212 |
213 |         torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
214 |             'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb'))
215 |
216 |         if best_ja < ja:
217 |             best_epoch = epoch
218 |             best_ja = ja
219 |
220 |         print ('best_epoch: {}'.format(best_epoch))
221 |
222 |     dill.dump(history, open(os.path.join('saved', args.model_name, 'history_{}.pkl'.format(args.model_name)), 'wb'))
223 |
224 |
225 | if __name__ == '__main__':
226 |     main(args)
227 |
--------------------------------------------------------------------------------
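processing.py below keeps only patients with at least two admissions (process_visit_lg2). A toy check of that filter with a made-up DataFrame (run inside data/; the module-level code of processing.py is guarded by `__main__`, so importing it is safe):

```python
import pandas as pd
from processing import process_visit_lg2

toy = pd.DataFrame({'SUBJECT_ID': [1, 1, 2], 'HADM_ID': [10, 11, 20]})
print(process_visit_lg2(toy))
# only SUBJECT_ID 1 survives: it has two distinct admissions (HADM_ID 10 and 11)
```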
/data/processing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import dill
3 | import numpy as np
4 | from collections import defaultdict
5 |
6 |
7 | ##### process medications #####
8 | # load med data
9 | def med_process(med_file):
10 |     """Read the raw MIMIC prescriptions file; keep pid, adm_id, date and NDC; return a DataFrame."""
11 |     # read the medication file; NDC (National Drug Code) is stored as a categorical type
12 |     med_pd = pd.read_csv(med_file, dtype={'NDC':'category'})
13 |
14 |     # drop unused columns
15 |     med_pd.drop(columns=['ROW_ID','DRUG_TYPE','DRUG_NAME_POE','DRUG_NAME_GENERIC',
16 |                          'FORMULARY_DRUG_CD','PROD_STRENGTH','DOSE_VAL_RX',
17 |                          'DOSE_UNIT_RX','FORM_VAL_DISP','FORM_UNIT_DISP', 'GSN', 'FORM_UNIT_DISP',
18 |                          'ROUTE','ENDDATE','DRUG'], axis=1, inplace=True)
19 |     med_pd.drop(index = med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
20 |     med_pd.fillna(method='pad', inplace=True)
21 |     med_pd.dropna(inplace=True)
22 |     med_pd.drop_duplicates(inplace=True)
23 |     med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
24 |     med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')
25 |     med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
26 |     med_pd = med_pd.reset_index(drop=True)  # reset the index and drop the old one
27 |
28 |     med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
29 |     med_pd = med_pd.drop_duplicates()
30 |     med_pd = med_pd.reset_index(drop=True)
31 |
32 |     return med_pd
33 |
34 | # medication mapping
35 | def ndc2atc4(med_pd):
36 |     """Map NDC codes to ATC-4."""
37 |     with open(ndc_rxnorm_file, 'r') as f:
38 |         ndc2rxnorm = eval(f.read())
39 |     # read the NDC-to-RxNorm mapping from ndc_rxnorm_file (RxNorm corresponds to the RXCUI column below)
40 |     med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
41 |     med_pd.dropna(inplace=True)  # in practice this drops nothing
42 |
43 |     rxnorm2atc = pd.read_csv(ndc2atc_file)
44 |     rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC'])  # drop NDC; map directly from RXCUI to ATC
45 |     # drop duplicate rows by RXCUI
46 |     rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
47 |
48 |     med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)  # drop rows with an empty RXCUI
49 |
50 |     med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
51 |     med_pd = med_pd.reset_index(drop=True)
52 |     med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])  # merge the two tables
53 |     med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)  # drop NDC/RXCUI, leaving only ATC4
54 |     med_pd = med_pd.rename(columns={'ATC4':'NDC'})  # rename it back to NDC
55 |     med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])  # keep only the first four characters
56 |     med_pd = med_pd.drop_duplicates()
57 |     med_pd = med_pd.reset_index(drop=True)
58 |     return med_pd
59 |
60 | # visit >= 2
61 | def process_visit_lg2(med_pd):
62 |     """Filter out patients with fewer than two admissions."""
63 |     a = med_pd[['SUBJECT_ID', 'HADM_ID']].groupby(by='SUBJECT_ID')['HADM_ID'].unique().reset_index()
64 |     a['HADM_ID_Len'] = a['HADM_ID'].map(lambda x:len(x))
65 |     a = a[a['HADM_ID_Len'] > 1]
66 |     return a
67 |
68 | # most common medications
69 | def filter_300_most_med(med_pd):
70 |     # sort NDC codes by frequency in descending order and keep the top 300
71 |     med_count = med_pd.groupby(by=['NDC']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
72 |     med_pd = med_pd[med_pd['NDC'].isin(med_count.loc[:299, 'NDC'])]
73 |
74 |     return med_pd.reset_index(drop=True)
75 |
76 | ##### process diagnosis #####
77 | def diag_process(diag_file):
78 |     diag_pd = pd.read_csv(diag_file)
79 |     diag_pd.dropna(inplace=True)
80 |     diag_pd.drop(columns=['SEQ_NUM','ROW_ID'],inplace=True)
81 |     diag_pd.drop_duplicates(inplace=True)
82 |     diag_pd.sort_values(by=['SUBJECT_ID','HADM_ID'], inplace=True)
83 |     diag_pd = diag_pd.reset_index(drop=True)
84 |
85 |     def filter_2000_most_diag(diag_pd):
86 |         diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
87 |         diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:1999, 'ICD9_CODE'])]
88 |
89 |         return diag_pd.reset_index(drop=True)
90 |
91 |     diag_pd = filter_2000_most_diag(diag_pd)
92 |
93 |     return diag_pd
94 |
95 | ##### process procedure #####
96 | def procedure_process(procedure_file):
97 |     pro_pd = pd.read_csv(procedure_file, dtype={'ICD9_CODE':'category'})
98 |     pro_pd.drop(columns=['ROW_ID'], inplace=True)
99 |     pro_pd.drop_duplicates(inplace=True)
100 |     pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM'], inplace=True)
101 |     pro_pd.drop(columns=['SEQ_NUM'], inplace=True)
102 |     pro_pd.drop_duplicates(inplace=True)
103 |     pro_pd.reset_index(drop=True, inplace=True)
104 |
105 |     return pro_pd
106 |
107 | def filter_1000_most_pro(pro_pd):
108 |     pro_count = pro_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
109 |     pro_pd = pro_pd[pro_pd['ICD9_CODE'].isin(pro_count.loc[:1000, 'ICD9_CODE'])]
110 |
111 |     return pro_pd.reset_index(drop=True)
112 |
113 | ###### combine three tables #####
114 | def combine_process(med_pd, diag_pd, pro_pd):
115 |     """Combine the medication, diagnosis and procedure tables."""
116 |
117 |     med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
118 |     diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
119 |     pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
120 |
121 |     combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
122 |     combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
123 |
124 |     diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
125 |     med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
126 |     pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
127 |
128 |     # flatten and merge
129 |     diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()
130 |     med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
131 |     pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})
132 |     med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
133 |     pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
134 |     data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
135 |     data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
136 |     # data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
137 |     data['NDC_Len'] = data['NDC'].map(lambda x: len(x))
138 |
139 |     return data
140 |
141 | def statistics(data):
142 |     print('#patients ', data['SUBJECT_ID'].unique().shape)
143 |     print('#clinical events ', len(data))
144 |
145 |     diag = data['ICD9_CODE'].values
146 |     med = data['NDC'].values
147 |     pro = data['PRO_CODE'].values
148 |
149 |     unique_diag = set([j for i in diag for j in list(i)])
150 |     unique_med = set([j for i in med for j in list(i)])
151 |     unique_pro = set([j for i in pro for j in list(i)])
152 |
153 |     print('#diagnosis ', len(unique_diag))
154 |     print('#med ', len(unique_med))
155 |     print('#procedure', len(unique_pro))
156 |
157 |     avg_diag, avg_med, avg_pro, max_diag, max_med, max_pro, cnt, max_visit, avg_visit = [0 for i in range(9)]
158 |
159 |     for subject_id in data['SUBJECT_ID'].unique():
160 |         item_data = data[data['SUBJECT_ID'] == subject_id]
161 |         x, y, z = [], [], []
162 |         visit_cnt = 0
163 |         for index, row in item_data.iterrows():
164 |             visit_cnt += 1
165 |             cnt += 1
166 |             x.extend(list(row['ICD9_CODE']))
167 |             y.extend(list(row['NDC']))
168 |             z.extend(list(row['PRO_CODE']))
169 |         x, y, z = set(x), set(y), set(z)
170 |         avg_diag += len(x)
171 |         avg_med += len(y)
172 |         avg_pro += len(z)
173 |         avg_visit += visit_cnt
174 |         if len(x) > max_diag:
175 |             max_diag = len(x)
176 |         if len(y) > max_med:
177 |             max_med = len(y)
178 |         if len(z) > max_pro:
179 |             max_pro = len(z)
180 |         if visit_cnt > max_visit:
181 |             max_visit = visit_cnt
182 |
183 |     print('#avg of diagnoses ', avg_diag/ cnt)
184 |     print('#avg of medicines ', avg_med/ cnt)
185 |     print('#avg of procedures ', avg_pro/ cnt)
186 |     print('#avg of vists ', avg_visit/ len(data['SUBJECT_ID'].unique()))
187 |
188 |     print('#max of diagnoses ', max_diag)
189 |     print('#max of medicines ', max_med)
190 |     print('#max of procedures ', max_pro)
191 |     print('#max of visit ', max_visit)
192 |
193 | ##### indexing file and final record
194 | class Voc(object):
195 |     def __init__(self):
196 |         self.idx2word = {}
197 |         self.word2idx = {}
198 |
199 |     def add_sentence(self, sentence):
200 |         for word in sentence:
201 |             if word not in self.word2idx:
202 |                 self.idx2word[len(self.word2idx)] = word
203 |                 self.word2idx[word] = len(self.word2idx)
204 |
205 | # create voc set
206 | def create_str_token_mapping(df):
207 |     diag_voc = Voc()
208 |     med_voc = Voc()
209 |     pro_voc = Voc()
210 |
211 |     for index, row in df.iterrows():
212 |         diag_voc.add_sentence(row['ICD9_CODE'])
213 |         med_voc.add_sentence(row['NDC'])
214 |         pro_voc.add_sentence(row['PRO_CODE'])
215 |
216 |     dill.dump(obj={'diag_voc':diag_voc, 'med_voc':med_voc ,'pro_voc':pro_voc}, file=open('voc_final.pkl','wb'))
217 |     return diag_voc, med_voc, pro_voc
218 |
219 | # create final records
220 | def create_patient_record(df, diag_voc, med_voc, pro_voc):
221 |     """
222 |     Save the records as a list.
223 |     Each item represents one patient with multiple visits; each visit holds three arrays, in order: diagnoses, procedures and medications.
224 |     Only the indices are stored; the corresponding words can be looked up in voc_final.pkl.
225 |     """
226 |     records = []  # (patient, code_kind:3, codes)  code_kind:diag, proc, med
227 |     for subject_id in df['SUBJECT_ID'].unique():
228 |         item_df = df[df['SUBJECT_ID'] == subject_id]
229 |         patient = []
230 |         for index, row in item_df.iterrows():
231 |             admission = []
232 |             admission.append([diag_voc.word2idx[i] for i in row['ICD9_CODE']])
233 |             admission.append([pro_voc.word2idx[i] for i in row['PRO_CODE']])
234 |             admission.append([med_voc.word2idx[i] for i in row['NDC']])
235 |             patient.append(admission)
236 |         records.append(patient)
237 |     dill.dump(obj=records, file=open('records_final.pkl', 'wb'))
238 |     return records
239 |
240 |
241 |
242 | # get ddi matrix
243 | def get_ddi_matrix(records, med_voc, ddi_file):
244 |
245 |     TOPK = 40  # topk drug-drug interaction
246 |     cid2atc_dic = defaultdict(set)
247 |     med_voc_size = len(med_voc.idx2word)
248 |     med_unique_word = [med_voc.idx2word[i] for i in range(med_voc_size)]  # ATC-4 codes of all medications
249 |     atc3_atc4_dic = defaultdict(set)
250 |     for item in med_unique_word:
251 |         atc3_atc4_dic[item[:4]].add(item)  #
252 |
253 |     with open(cid_atc, 'r') as f:
254 |         for line in f:
255 |             line_ls = line[:-1].split(',')
256 |             cid = line_ls[0]
257 |             atcs = line_ls[1:]
258 |             for atc in atcs:
259 |                 if len(atc3_atc4_dic[atc[:4]]) != 0:
260 |                     cid2atc_dic[cid].add(atc[:4])
261 |
262 |     # load the DDI data
263 |     ddi_df = pd.read_csv(ddi_file)
264 |     # filter severe side effects, also in a top-K fashion
265 |     ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
266 |     ddi_most_pd = ddi_most_pd.iloc[-TOPK:,:]
267 |     # ddi_most_pd = pd.DataFrame(columns=['Side Effect Name'], data=['as','asd','as'])
268 |     fliter_ddi_df = ddi_df.merge(ddi_most_pd[['Side Effect Name']], how='inner', on=['Side Effect Name'])
269 |     ddi_df = fliter_ddi_df[['STITCH 1','STITCH 2']].drop_duplicates().reset_index(drop=True)
270 |
271 |
272 |     # weighted ehr adj
273 |     ehr_adj = np.zeros((med_voc_size, med_voc_size))
274 |     for patient in records:
275 |         for adm in patient:
276 |             med_set = adm[2]
277 |             for i, med_i in enumerate(med_set):
278 |                 for j, med_j in enumerate(med_set):
279 |                     if j<=i:
280 |                         continue
281 |                     # ehr_adj[med_i, med_j] = 1
282 |                     # ehr_adj[med_j, med_i] = 1
283 |                     ehr_adj[med_i, med_j] += 1
284 |                     ehr_adj[med_j, med_i] += 1
285 |     dill.dump(ehr_adj, open('ehr_adj_final.pkl', 'wb'))
286 |
287 |     # ddi adj: the DDI table uses CID codes, so CIDs must be mapped to ATC codes to record the conflicts between drugs in the dataset
288 |     ddi_adj = np.zeros((med_voc_size,med_voc_size))
289 |     for index, row in ddi_df.iterrows():
290 |         # ddi
291 |         cid1 = row['STITCH 1']
292 |         cid2 = row['STITCH 2']
293 |
294 |         # cid -> atc_level3
295 |         for atc_i in cid2atc_dic[cid1]:
296 |             for atc_j in cid2atc_dic[cid2]:
297 |
298 |                 # atc_level3 -> atc_level4
299 |                 for i in atc3_atc4_dic[atc_i]:
300 |                     for j in atc3_atc4_dic[atc_j]:
301 |                         if med_voc.word2idx[i] != med_voc.word2idx[j]:
302 |                             ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
303 |                             ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
304 |     dill.dump(ddi_adj, open('ddi_A_final.pkl', 'wb'))
305 |
306 |     return ddi_adj
307 |
308 | if __name__ == '__main__':
309 |     # MIMIC data files: medications, diagnoses and procedures
310 |     med_file = '/data/mimic-iii/PRESCRIPTIONS.csv'
311 |     diag_file = '/data/mimic-iii/DIAGNOSES_ICD.csv'
312 |     procedure_file = '/data/mimic-iii/PROCEDURES_ICD.csv'
313 |
314 |     # drug information
315 |     med_structure_file = './idx2drug.pkl'  # mapping from drug to molecular structure (SMILES)
316 |
317 |     # drug code mapping files
318 |     ndc2atc_file = './ndc2atc_level4.csv'  # NDC code to ATC-4 code mapping file, used to map RxNorm to ATC
319 |     cid_atc = './drug-atc.csv'  # drug (CID) to ATC code mapping file, used to process the DDI table
320 |     ndc_rxnorm_file = './ndc2rxnorm_mapping.txt'  # NDC to RxNorm mapping file
321 |
322 |     # ddi information
323 |     # data example
324 |     # STITCH 1,STITCH 2,Polypharmacy Side Effect,Side Effect Name
325 |     # CID000002173,CID000003345,C0151714,hypermagnesemia
326 |     # CID000002173,CID000003345,C0035344,retinopathy of prematurity
327 |     # CID000002173,CID000003345,C0004144,atelectasis
328 |     # CID000002173,CID000003345,C0002063,alkalosis
329 |     # CID000002173,CID000003345,C0004604,Back Ache
330 |     # CID000002173,CID000003345,C0034063,lung edema
331 |     ddi_file = '/data/drug-DDI.csv'
332 |
333 |     # process the medication data in MIMIC
334 |     med_pd = med_process(med_file)
335 |     med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)  # note: this only keeps patients with more than one admission in the medication table
336 |     med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']], on='SUBJECT_ID', how='inner').reset_index(drop=True)
337 |
338 |     med_pd = ndc2atc4(med_pd)
339 |     NDCList = dill.load(open(med_structure_file, 'rb'))
340 |     med_pd = med_pd[med_pd.NDC.isin(list(NDCList.keys()))]
341 |     med_pd = filter_300_most_med(med_pd)
342 |
343 |     print ('complete medication processing')
344 |
345 |     # for diagnosis
346 |     diag_pd = diag_process(diag_file)
347 |
348 |     print ('complete diagnosis processing')
349 |
350 |     # for procedure
351 |     pro_pd = procedure_process(procedure_file)
352 |     # pro_pd = filter_1000_most_pro(pro_pd)
353 |
354 |     print ('complete procedure processing')
355 |
356 |     # combine
357 |     data = combine_process(med_pd, diag_pd, pro_pd)
358 |     statistics(data)
359 |     data.to_pickle('data_final.pkl')
360 |
361 |     print ('complete combining')
362 |
363 |
364 |     # ddi_matrix
365 |     diag_voc, med_voc, pro_voc = create_str_token_mapping(data)
366 |     records = create_patient_record(data, diag_voc, med_voc, pro_voc)  # stored in order: diag, proc, medication
367 |     ddi_adj = get_ddi_matrix(records, med_voc, ddi_file)
368 |
--------------------------------------------------------------------------------
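A toy walk-through of the weighted ehr_adj accumulation in get_ddi_matrix above, with made-up records over three medication ids:

```python
import numpy as np

# two visits with med sets {0, 1} and {0, 1, 2}; every co-prescribed pair
# increments both symmetric entries, exactly as in get_ddi_matrix
records = [[[[], [], [0, 1]], [[], [], [0, 1, 2]]]]
ehr_adj = np.zeros((3, 3))
for patient in records:
    for adm in patient:
        for i, med_i in enumerate(adm[2]):
            for j, med_j in enumerate(adm[2]):
                if j <= i:
                    continue
                ehr_adj[med_i, med_j] += 1
                ehr_adj[med_j, med_i] += 1
print(ehr_adj)   # pair (0,1) appears in both visits -> 2; (0,2) and (1,2) -> 1
```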
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | from sklearn import preprocessing
2 | from torch.nn.functional import pad
3 | from torch.utils import data
4 | import torch
5 | import random
6 |
7 |
8 | class mimic_data(data.Dataset):
9 |     def __init__(self, data) -> None:
10 |         super().__init__()
11 |         self.data = data
12 |
13 |     def __getitem__(self, index):
14 |         return self.data[index]
15 |
16 |     def __len__(self):
17 |         return len(self.data)
18 |
19 |
20 | def pad_batch(batch):
21 |     seq_length = torch.tensor([len(data) for data in batch])
22 |     batch_size = len(batch)
23 |     max_seq = max(seq_length)
24 |
25 |     # count the diseases, procedures and medications in each visit, and track the maxima;
26 |     # also compute, for each visit's diseases, the intersection and difference with the previous visit's diseases
27 |     d_length_matrix = []
28 |     p_length_matrix = []
29 |     m_length_matrix = []
30 |     d_max_num = 0
31 |     p_max_num = 0
32 |     m_max_num = 0
33 |     d_dec_list = []
34 |     d_stay_list = []
35 |     for data in batch:
36 |         d_buf, p_buf, m_buf = [], [], []
37 |         d_dec_list_buf, d_stay_list_buf = [], []
38 |         for idx, seq in enumerate(data):
39 |             d_buf.append(len(seq[0]))
40 |             p_buf.append(len(seq[1]))
41 |             m_buf.append(len(seq[2]))
42 |             d_max_num = max(d_max_num, len(seq[0]))
43 |             p_max_num = max(p_max_num, len(seq[1]))
44 |             m_max_num = max(m_max_num, len(seq[2]))
45 |             if idx==0:
46 |                 # first visit: intersection and difference are empty
47 |                 d_dec_list_buf.append([])
48 |                 d_stay_list_buf.append([])
49 |             else:
50 |                 # compute the difference and intersection
51 |                 cur_d = set(seq[0])
52 |                 last_d = set(data[idx-1][0])
53 |                 stay_list = list(cur_d & last_d)
54 |                 dec_list = list(last_d - cur_d)
55 |                 d_dec_list_buf.append(dec_list)
56 |                 d_stay_list_buf.append(stay_list)
57 |         d_length_matrix.append(d_buf)
58 |         p_length_matrix.append(p_buf)
59 |         m_length_matrix.append(m_buf)
60 |         d_dec_list.append(d_dec_list_buf)
61 |         d_stay_list.append(d_stay_list_buf)
62 |
63 |     # build m_mask_matrix
64 |     m_mask_matrix = torch.full((batch_size, max_seq, m_max_num), -1e9)
65 |     for i in range(batch_size):
66 |         for j in range(len(m_length_matrix[i])):
67 |             m_mask_matrix[i, j, :m_length_matrix[i][j]] = 0.
68 |
69 |     # build d_mask_matrix
70 |     d_mask_matrix = torch.full((batch_size, max_seq, d_max_num), -1e9)
71 |     for i in range(batch_size):
72 |         for j in range(len(d_length_matrix[i])):
73 |             d_mask_matrix[i, j, :d_length_matrix[i][j]] = 0.
74 |
75 |     # build p_mask_matrix
76 |     p_mask_matrix = torch.full((batch_size, max_seq, p_max_num), -1e9)
77 |     for i in range(batch_size):
78 |         for j in range(len(p_length_matrix[i])):
79 |             p_mask_matrix[i, j, :p_length_matrix[i][j]] = 0.
80 |
81 |     # build dec_disease_tensor and stay_disease_tensor
82 |     dec_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
83 |     stay_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
84 |     dec_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
85 |     stay_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
86 |     for b_id, (dec_seqs, stay_seqs) in enumerate(zip(d_dec_list, d_stay_list)):
87 |         for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
88 |             dec_disease_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
89 |             stay_disease_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
90 |             dec_disease_mask[b_id, s_id, :len(dec_adm)] = 0.
91 |             stay_disease_mask[b_id, s_id, :len(stay_adm)] = 0.
92 |
93 |     # build the disease, procedure and medication tensors
94 |     disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
95 |     procedure_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
96 |     medication_tensor = torch.full((batch_size, max_seq, m_max_num), 0)
97 |
98 |     # assemble the batch
99 |     for b_id, data in enumerate(batch):
100 |         for s_id, adm in enumerate(data):
101 |             # each admission is ordered as disease, procedure, medication
102 |             disease_tensor[b_id, s_id, :len(adm[0])] = torch.tensor(adm[0])
103 |             procedure_tensor[b_id, s_id, :len(adm[1])] = torch.tensor(adm[1])
104 |             medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(adm[2])
105 |     # print(disease_tensor[1])
106 |     return disease_tensor, procedure_tensor, medication_tensor, seq_length, \
107 |         d_length_matrix, p_length_matrix, m_length_matrix, \
108 |         d_mask_matrix, p_mask_matrix, m_mask_matrix, \
109 |         dec_disease_tensor, stay_disease_tensor, dec_disease_mask, stay_disease_mask
110 |
111 |
112 | def pad_batch_v2_train(batch):
113 |     seq_length = torch.tensor([len(data) for data in batch])
114 |     batch_size = len(batch)
115 |     max_seq = max(seq_length)
116 |
117 |     # count the diseases, procedures and medications in each visit, and track the maxima;
118 |     # also compute, for each visit's diseases, the intersection and difference with the previous visit's diseases
119 |     d_length_matrix = []
120 |     p_length_matrix = []
121 |     m_length_matrix = []
122 |     d_max_num = 0
123 |     p_max_num = 0
124 |     m_max_num = 0
125 |     d_dec_list = []
126 |     d_stay_list = []
127 |     p_dec_list = []
128 |     p_stay_list = []
129 |     for data in batch:
130 |         d_buf, p_buf, m_buf = [], [], []
131 |         d_dec_list_buf, d_stay_list_buf = [], []
132 |         p_dec_list_buf, p_stay_list_buf = [], []
133 |         for idx, seq in enumerate(data):
134 |             d_buf.append(len(seq[0]))
135 |             p_buf.append(len(seq[1]))
136 |             m_buf.append(len(seq[2]))
137 |             d_max_num = max(d_max_num, len(seq[0]))
138 |             p_max_num = max(p_max_num, len(seq[1]))
139 |             m_max_num = max(m_max_num, len(seq[2]))
140 |             if idx==0:
141 |                 # first visit: intersection and difference are empty
142 |                 d_dec_list_buf.append([])
143 |                 d_stay_list_buf.append([])
144 |                 p_dec_list_buf.append([])
145 |                 p_stay_list_buf.append([])
146 |             else:
147 |                 # compute the difference and intersection
148 |                 cur_d = set(seq[0])
149 |                 last_d = set(data[idx-1][0])
112 | def pad_batch_v2_train(batch):
113 |     seq_length = torch.tensor([len(data) for data in batch])
114 |     batch_size = len(batch)
115 |     max_seq = max(seq_length)
116 | 
117 |     # count each visit's number of diseases, procedures and medications, and track the maxima
118 |     # also compute, for each visit's diseases and procedures, the intersection with and difference from the previous visit
119 |     d_length_matrix = []
120 |     p_length_matrix = []
121 |     m_length_matrix = []
122 |     d_max_num = 0
123 |     p_max_num = 0
124 |     m_max_num = 0
125 |     d_dec_list = []
126 |     d_stay_list = []
127 |     p_dec_list = []
128 |     p_stay_list = []
129 |     for data in batch:
130 |         d_buf, p_buf, m_buf = [], [], []
131 |         d_dec_list_buf, d_stay_list_buf = [], []
132 |         p_dec_list_buf, p_stay_list_buf = [], []
133 |         for idx, seq in enumerate(data):
134 |             d_buf.append(len(seq[0]))
135 |             p_buf.append(len(seq[1]))
136 |             m_buf.append(len(seq[2]))
137 |             d_max_num = max(d_max_num, len(seq[0]))
138 |             p_max_num = max(p_max_num, len(seq[1]))
139 |             m_max_num = max(m_max_num, len(seq[2]))
140 |             if idx==0:
141 |                 # first visit: intersection and difference are empty
142 |                 d_dec_list_buf.append([])
143 |                 d_stay_list_buf.append([])
144 |                 p_dec_list_buf.append([])
145 |                 p_stay_list_buf.append([])
146 |             else:
147 |                 # compute the difference and intersection with the previous visit
148 |                 cur_d = set(seq[0])
149 |                 last_d = set(data[idx-1][0])
150 |                 stay_list = list(cur_d & last_d)
151 |                 dec_list = list(last_d - cur_d)
152 |                 d_dec_list_buf.append(dec_list)
153 |                 d_stay_list_buf.append(stay_list)
154 | 
155 |                 cur_p = set(seq[1])
156 |                 last_p = set(data[idx-1][1])
157 |                 proc_stay_list = list(cur_p & last_p)
158 |                 proc_dec_list = list(last_p - cur_p)
159 |                 p_dec_list_buf.append(proc_dec_list)
160 |                 p_stay_list_buf.append(proc_stay_list)
161 |         d_length_matrix.append(d_buf)
162 |         p_length_matrix.append(p_buf)
163 |         m_length_matrix.append(m_buf)
164 |         d_dec_list.append(d_dec_list_buf)
165 |         d_stay_list.append(d_stay_list_buf)
166 |         p_dec_list.append(p_dec_list_buf)
167 |         p_stay_list.append(p_stay_list_buf)
168 | 
169 |     # generate m_mask_matrix
170 |     m_mask_matrix = torch.full((batch_size, max_seq, m_max_num), -1e9)
171 |     for i in range(batch_size):
172 |         for j in range(len(m_length_matrix[i])):
173 |             m_mask_matrix[i, j, :m_length_matrix[i][j]] = 0.
174 | 
175 |     # generate d_mask_matrix
176 |     d_mask_matrix = torch.full((batch_size, max_seq, d_max_num), -1e9)
177 |     for i in range(batch_size):
178 |         for j in range(len(d_length_matrix[i])):
179 |             d_mask_matrix[i, j, :d_length_matrix[i][j]] = 0.
180 | 
181 |     # generate p_mask_matrix
182 |     p_mask_matrix = torch.full((batch_size, max_seq, p_max_num), -1e9)
183 |     for i in range(batch_size):
184 |         for j in range(len(p_length_matrix[i])):
185 |             p_mask_matrix[i, j, :p_length_matrix[i][j]] = 0.
186 | 
187 |     # generate dec_disease_tensor and stay_disease_tensor
188 |     dec_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
189 |     stay_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
190 |     dec_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
191 |     stay_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
192 |     for b_id, (dec_seqs, stay_seqs) in enumerate(zip(d_dec_list, d_stay_list)):
193 |         for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
194 |             dec_disease_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
195 |             stay_disease_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
196 |             dec_disease_mask[b_id, s_id, :len(dec_adm)] = 0.
197 |             stay_disease_mask[b_id, s_id, :len(stay_adm)] = 0.
198 | 
199 |     # generate dec_proc_tensor and stay_proc_tensor
200 |     dec_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
201 |     stay_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
202 |     dec_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
203 |     stay_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
204 |     for b_id, (dec_seqs, stay_seqs) in enumerate(zip(p_dec_list, p_stay_list)):
205 |         for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
206 |             dec_proc_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
207 |             stay_proc_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
208 |             dec_proc_mask[b_id, s_id, :len(dec_adm)] = 0.
209 |             stay_proc_mask[b_id, s_id, :len(stay_adm)] = 0.
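# [illustrative note] All *_mask_matrix tensors above follow the additive
# attention convention: 0 marks a valid slot and -1e9 marks padding, so the
# mask can simply be added to attention scores before a softmax. A minimal
# self-contained sketch:
#
#     import torch
#     scores = torch.tensor([1.0, 2.0, 0.5])
#     mask = torch.tensor([0., 0., -1e9])             # third slot is padding
#     weights = torch.softmax(scores + mask, dim=-1)  # padding gets ~0 weight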
210 | 
211 |     # generate the disease, procedure and medication tensors
212 |     disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
213 |     procedure_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
214 |     medication_tensor = torch.full((batch_size, max_seq, m_max_num), 0)
215 | 
216 |     # stitch everything together into one batch
217 |     for b_id, data in enumerate(batch):
218 |         for s_id, adm in enumerate(data):
219 |             # within each admission the data are ordered as disease, procedure, medication
220 |             disease_tensor[b_id, s_id, :len(adm[0])] = torch.tensor(adm[0])
221 |             procedure_tensor[b_id, s_id, :len(adm[1])] = torch.tensor(adm[1])
222 |             # dynamic shuffle
223 |             # cur_medications = adm[2]
224 |             # random.shuffle(cur_medications)
225 |             # medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(cur_medications)
226 |             medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(adm[2])
227 | 
228 |     # print(disease_tensor[1])
229 |     return disease_tensor, procedure_tensor, medication_tensor, seq_length, \
230 |         d_length_matrix, p_length_matrix, m_length_matrix, \
231 |         d_mask_matrix, p_mask_matrix, m_mask_matrix, \
232 |         dec_disease_tensor, stay_disease_tensor, dec_disease_mask, stay_disease_mask, \
233 |         dec_proc_tensor, stay_proc_tensor, dec_proc_mask, stay_proc_mask
234 | 
235 | 
236 | 
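# [note] pad_batch_v2_eval below mirrors pad_batch_v2_train line for line; the
# only difference is that the (commented-out) dynamic-shuffle hook for the
# medication targets is omitted, so train and eval share one tensor layout.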
237 | def pad_batch_v2_eval(batch):
238 |     seq_length = torch.tensor([len(data) for data in batch])
239 |     batch_size = len(batch)
240 |     max_seq = max(seq_length)
241 | 
242 |     # count each visit's number of diseases, procedures and medications, and track the maxima
243 |     # also compute, for each visit's diseases and procedures, the intersection with and difference from the previous visit
244 |     d_length_matrix = []
245 |     p_length_matrix = []
246 |     m_length_matrix = []
247 |     d_max_num = 0
248 |     p_max_num = 0
249 |     m_max_num = 0
250 |     d_dec_list = []
251 |     d_stay_list = []
252 |     p_dec_list = []
253 |     p_stay_list = []
254 |     for data in batch:
255 |         d_buf, p_buf, m_buf = [], [], []
256 |         d_dec_list_buf, d_stay_list_buf = [], []
257 |         p_dec_list_buf, p_stay_list_buf = [], []
258 |         for idx, seq in enumerate(data):
259 |             d_buf.append(len(seq[0]))
260 |             p_buf.append(len(seq[1]))
261 |             m_buf.append(len(seq[2]))
262 |             d_max_num = max(d_max_num, len(seq[0]))
263 |             p_max_num = max(p_max_num, len(seq[1]))
264 |             m_max_num = max(m_max_num, len(seq[2]))
265 |             if idx==0:
266 |                 # first visit: intersection and difference are empty
267 |                 d_dec_list_buf.append([])
268 |                 d_stay_list_buf.append([])
269 |                 p_dec_list_buf.append([])
270 |                 p_stay_list_buf.append([])
271 |             else:
272 |                 # compute the difference and intersection with the previous visit
273 |                 cur_d = set(seq[0])
274 |                 last_d = set(data[idx-1][0])
275 |                 stay_list = list(cur_d & last_d)
276 |                 dec_list = list(last_d - cur_d)
277 |                 d_dec_list_buf.append(dec_list)
278 |                 d_stay_list_buf.append(stay_list)
279 | 
280 |                 cur_p = set(seq[1])
281 |                 last_p = set(data[idx-1][1])
282 |                 proc_stay_list = list(cur_p & last_p)
283 |                 proc_dec_list = list(last_p - cur_p)
284 |                 p_dec_list_buf.append(proc_dec_list)
285 |                 p_stay_list_buf.append(proc_stay_list)
286 |         d_length_matrix.append(d_buf)
287 |         p_length_matrix.append(p_buf)
288 |         m_length_matrix.append(m_buf)
289 |         d_dec_list.append(d_dec_list_buf)
290 |         d_stay_list.append(d_stay_list_buf)
291 |         p_dec_list.append(p_dec_list_buf)
292 |         p_stay_list.append(p_stay_list_buf)
293 | 
294 |     # generate m_mask_matrix
295 |     m_mask_matrix = torch.full((batch_size, max_seq, m_max_num), -1e9)
296 |     for i in range(batch_size):
297 |         for j in range(len(m_length_matrix[i])):
298 |             m_mask_matrix[i, j, :m_length_matrix[i][j]] = 0.
299 | 
300 |     # generate d_mask_matrix
301 |     d_mask_matrix = torch.full((batch_size, max_seq, d_max_num), -1e9)
302 |     for i in range(batch_size):
303 |         for j in range(len(d_length_matrix[i])):
304 |             d_mask_matrix[i, j, :d_length_matrix[i][j]] = 0.
305 | 
306 |     # generate p_mask_matrix
307 |     p_mask_matrix = torch.full((batch_size, max_seq, p_max_num), -1e9)
308 |     for i in range(batch_size):
309 |         for j in range(len(p_length_matrix[i])):
310 |             p_mask_matrix[i, j, :p_length_matrix[i][j]] = 0.
311 | 
312 |     # generate dec_disease_tensor and stay_disease_tensor
313 |     dec_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
314 |     stay_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
315 |     dec_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
316 |     stay_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
317 |     for b_id, (dec_seqs, stay_seqs) in enumerate(zip(d_dec_list, d_stay_list)):
318 |         for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
319 |             dec_disease_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
320 |             stay_disease_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
321 |             dec_disease_mask[b_id, s_id, :len(dec_adm)] = 0.
322 |             stay_disease_mask[b_id, s_id, :len(stay_adm)] = 0.
323 | 
324 |     # generate dec_proc_tensor and stay_proc_tensor
325 |     dec_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
326 |     stay_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
327 |     dec_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
328 |     stay_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
329 |     for b_id, (dec_seqs, stay_seqs) in enumerate(zip(p_dec_list, p_stay_list)):
330 |         for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
331 |             dec_proc_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
332 |             stay_proc_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
333 |             dec_proc_mask[b_id, s_id, :len(dec_adm)] = 0.
334 |             stay_proc_mask[b_id, s_id, :len(stay_adm)] = 0.
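# [note] The -1 used to pad disease/procedure slots is only a placeholder;
# recommend.py later swaps it for the vocabulary's PAD index via the
# pad_num_replace helper defined at the end of this file, e.g.
#     diseases = pad_num_replace(diseases, -1, DIAG_PAD_TOKEN)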
335 | 
336 |     # generate the disease, procedure and medication tensors
337 |     disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
338 |     procedure_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
339 |     medication_tensor = torch.full((batch_size, max_seq, m_max_num), 0)
340 | 
341 |     # stitch everything together into one batch
342 |     for b_id, data in enumerate(batch):
343 |         for s_id, adm in enumerate(data):
344 |             # within each admission the data are ordered as disease, procedure, medication
345 |             disease_tensor[b_id, s_id, :len(adm[0])] = torch.tensor(adm[0])
346 |             procedure_tensor[b_id, s_id, :len(adm[1])] = torch.tensor(adm[1])
347 |             medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(adm[2])
348 | 
349 |     # print(disease_tensor[1])
350 |     return disease_tensor, procedure_tensor, medication_tensor, seq_length, \
351 |         d_length_matrix, p_length_matrix, m_length_matrix, \
352 |         d_mask_matrix, p_mask_matrix, m_mask_matrix, \
353 |         dec_disease_tensor, stay_disease_tensor, dec_disease_mask, stay_disease_mask, \
354 |         dec_proc_tensor, stay_proc_tensor, dec_proc_mask, stay_proc_mask
355 | 
356 | 
357 | def pad_num_replace(tensor, src_num, target_num):
358 |     # replace_tensor = torch.full_like(tensor, target_num)
359 |     return torch.where(tensor==src_num, target_num, tensor)
360 | 
361 | 
362 | 
--------------------------------------------------------------------------------
/src/MICRON.py:
--------------------------------------------------------------------------------
1 | from types import new_class
2 | import dill
3 | import numpy as np
4 | import argparse
5 | from collections import defaultdict
6 | from sklearn.metrics import jaccard_score, roc_curve
7 | from torch.optim import Adam, RMSprop
8 | import os
9 | import torch
10 | import time
11 | import math
12 | from models import MICRON
13 | from util import llprint, multi_label_metric, ddi_rate_score, get_n_params
14 | import torch.nn.functional as F
15 | 
16 | # torch.set_num_threads(30)
17 | torch.manual_seed(1203)
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
19 | 
20 | # setting
21 | model_name = 'MICRON_pad_lr1e-4'
22 | resume_path = './Epoch_39_JA_0.5209_DDI_0.06952.model'
23 | # resume_path = './{}_Epoch_39_JA_0.5209_DDI_0.06952.model'.format(model_name)
24 | 
25 | if not os.path.exists(os.path.join("saved", model_name)):
26 |     os.makedirs(os.path.join("saved", model_name))
27 | 
28 | # Training settings
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--Test', action='store_true', default=False, help="test mode")
31 | parser.add_argument('--model_name', type=str, default=model_name, help="model name")
32 | parser.add_argument('--resume_path', type=str, default=resume_path, help='resume path')
33 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')  # original 0.0002
34 | parser.add_argument('--weight_decay', type=float, default=1e-5, help='weight decay')
35 | parser.add_argument('--dim', type=int, default=64, help='dimension')
36 | 
37 | args = parser.parse_args()
38 | 
39 | # evaluate
40 | def eval(model, data_eval, voc_size, epoch, val=0, threshold1=0.8, threshold2=0.2):
41 |     model.eval()
42 | 
43 |     smm_record = []
44 |     ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
45 |     med_cnt, visit_cnt = 0, 0
46 |     label_list, prob_list = [], []
47 |     add_list, delete_list = [], []
48 |     # per-visit metric bookkeeping
49 |     ja_by_visit = [[] for _ in range(5)]
50 |     prauc_by_visit = [[] for _ in range(5)]
51 |     pre_by_visit = [[] for _ in range(5)]
52 |     recall_by_visit = [[] for _ in range(5)]
53 |     f1_by_visit = [[] for _ in range(5)]
54 |     smm_record_by_visit = [[] for _ in range(5)]
55 | 
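# [note] threshold1/threshold2 act as an upper/lower decision band: a drug is
# switched on once its predicted probability clears threshold1 and switched
# off once it drops below threshold2, so the recommendation edits the previous
# visit's medication set (tracked via add_list/delete_list) rather than being
# rebuilt from scratch. The visible fragments below follow this pattern, e.g.
#     y_old[y_pred_tmp >= threshold1] = 1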
56 | for step, input in enumerate(data_eval): 57 | y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], [] 58 | add_temp_list, delete_temp_list = [], [] 59 | if len(input) < 2: continue 60 | for adm_idx, adm in enumerate(input): 61 | # 第0个visit也要添加到结果中去 62 | y_gt_tmp = np.zeros(voc_size[2]) 63 | y_gt_tmp[adm[2]] = 1 64 | y_gt.append(y_gt_tmp) 65 | label_list.append(y_gt_tmp) 66 | 67 | if adm_idx == 0: 68 | representation_base, _, _, _, _ = model(input[:adm_idx+1]) 69 | # 第0个visit也添加 70 | y_pred_tmp = F.sigmoid(representation_base).detach().cpu().numpy()[0] 71 | y_pred_prob.append(y_pred_tmp) 72 | prob_list.append(y_pred_tmp) 73 | 74 | y_old = np.zeros(voc_size[2]) 75 | y_old[y_pred_tmp>=threshold1] = 1 76 | y_old[y_pred_tmp=threshold1] = 1 108 | y_old[y_pred_tmp 1: 141 | add_list.append(np.mean(add_temp_list)) 142 | delete_list.append(np.mean(delete_temp_list)) 143 | elif len(add_temp_list) == 1: 144 | add_list.append(add_temp_list[0]) 145 | delete_list.append(delete_temp_list[0]) 146 | 147 | smm_record.append(y_pred_label) 148 | adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob)) 149 | 150 | ja.append(adm_ja) 151 | prauc.append(adm_prauc) 152 | avg_p.append(adm_avg_p) 153 | avg_r.append(adm_avg_r) 154 | avg_f1.append(adm_avg_f1) 155 | llprint('\rtest step: {} / {}'.format(step, len(data_eval))) 156 | 157 | # 分析各个visit的结果 158 | print('\tvisit1\tvisit2\tvisit3\tvisit4\tvisit5') 159 | print('count:', [len(buf) for buf in ja_by_visit]) 160 | print('prauc:', [np.mean(buf) for buf in prauc_by_visit]) 161 | print('jaccard:', [np.mean(buf) for buf in ja_by_visit]) 162 | print('precision:', [np.mean(buf) for buf in pre_by_visit]) 163 | print('recall:', [np.mean(buf) for buf in recall_by_visit]) 164 | print('f1:', [np.mean(buf) for buf in f1_by_visit]) 165 | print('DDI:', [ddi_rate_score(buf) for buf in smm_record_by_visit]) 166 | 167 | # ddi rate 168 | ddi_rate = ddi_rate_score(smm_record, path='../data/ddi_A_final.pkl') 169 | 170 | llprint('\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'.format( 171 | ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_f1), np.mean(add_list), np.mean(delete_list), med_cnt / visit_cnt 172 | )) 173 | 174 | if val == 0: 175 | return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), np.mean(add_list), np.mean(delete_list), med_cnt / visit_cnt 176 | else: 177 | return np.array(label_list), np.array(prob_list) 178 | 179 | 180 | def main(): 181 | 182 | # load data 183 | data_path = '../data/records_final.pkl' 184 | voc_path = '../data/voc_final.pkl' 185 | 186 | ddi_adj_path = '../data/ddi_A_final.pkl' 187 | device = torch.device('cuda') 188 | 189 | ddi_adj = dill.load(open(ddi_adj_path, 'rb')) 190 | data = dill.load(open(data_path, 'rb')) 191 | 192 | voc = dill.load(open(voc_path, 'rb')) 193 | diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc'] 194 | 195 | # np.random.seed(1203) 196 | # np.random.shuffle(data) 197 | 198 | # "添加第一个visit" 199 | # new_data = [] 200 | # for patient in data: 201 | # patient.insert(0, [[],[],[]]) 202 | # # patient.insert(0, patient[0]) 203 | # new_data.append(patient) 204 | # data = new_data 205 | 206 | split_point = int(len(data) * 2 / 3) 207 | data_train = data[:split_point] 208 | eval_len = int(len(data[split_point:]) / 2) 209 | data_test = data[split_point:split_point + eval_len] 210 | data_eval = data[split_point+eval_len:] 211 | 212 | 
voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word)) 213 | 214 | model = MICRON(voc_size, ddi_adj, emb_dim=args.dim, device=device) 215 | # model.load_state_dict(torch.load(open(args.resume_path, 'rb'))) 216 | 217 | if args.Test: 218 | model.load_state_dict(torch.load(open(args.resume_path, 'rb'))) 219 | model.to(device=device) 220 | tic = time.time() 221 | label_list, prob_list = eval(model, data_eval, voc_size, 0, 1) 222 | 223 | threshold1, threshold2 = [], [] 224 | for i in range(label_list.shape[1]): 225 | _, _, boundary = roc_curve(label_list[:, i], prob_list[:, i], pos_label=1) 226 | # boundary1 should be in [0.5, 0.9], boundary2 should be in [0.1, 0.5] 227 | threshold1.append(min(0.9, max(0.5, boundary[max(0, round(len(boundary) * 0.05) - 1)]))) 228 | threshold2.append(max(0.1, min(0.5, boundary[min(round(len(boundary) * 0.95), len(boundary) - 1)]))) 229 | print (np.mean(threshold1), np.mean(threshold2)) 230 | threshold1 = np.ones(voc_size[2]) * np.mean(threshold1) 231 | threshold2 = np.ones(voc_size[2]) * np.mean(threshold2) 232 | eval(model, data_test, voc_size, 0, 0, threshold1, threshold2) 233 | print ('test time: {}'.format(time.time() - tic)) 234 | 235 | result = [] 236 | for _ in range(10): 237 | test_sample = np.random.choice(data_test, round(len(data_test) * 0.8), replace=True) 238 | ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_add, avg_del, avg_med = eval(model, test_sample, voc_size, 0) 239 | result.append([ddi_rate, ja, avg_f1, prauc, avg_med]) 240 | 241 | result = np.array(result) 242 | mean = result.mean(axis=0) 243 | std = result.std(axis=0) 244 | 245 | outstring = "" 246 | for m, s in zip(mean, std): 247 | outstring += "{:.4f} $\pm$ {:.4f} & ".format(m, s) 248 | 249 | print (outstring) 250 | 251 | print ('test time: {}'.format(time.time() - tic)) 252 | return 253 | 254 | model.to(device=device) 255 | print('parameters', get_n_params(model)) 256 | # exit() 257 | optimizer = RMSprop(list(model.parameters()), lr=args.lr, weight_decay=args.weight_decay) 258 | 259 | # start iterations 260 | history = defaultdict(list) 261 | best_epoch, best_ja = 0, 0 262 | 263 | weight_list = [[0.25, 0.25, 0.25, 0.25]] 264 | 265 | EPOCH = 40 266 | for epoch in range(EPOCH): 267 | t = 0 268 | tic = time.time() 269 | print ('\nepoch {} --------------------------'.format(epoch + 1)) 270 | 271 | sample_counter = 0 272 | mean_loss = np.array([0, 0, 0, 0]) 273 | 274 | model.train() 275 | for step, input in enumerate(data_train): 276 | loss = 0 277 | if len(input) < 2: continue 278 | for adm_idx, adm in enumerate(input): 279 | """第一个visit也参与训练""" 280 | # if adm_idx == 0: continue 281 | # sample_counter += 1 282 | seq_input = input[:adm_idx+1] 283 | 284 | loss_bce_target = np.zeros((1, voc_size[2])) 285 | loss_bce_target[:, adm[2]] = 1 286 | 287 | loss_bce_target_last = np.zeros((1, voc_size[2])) 288 | if adm_idx > 0: 289 | loss_bce_target_last[:, input[adm_idx-1][2]] = 1 290 | 291 | loss_multi_target = np.full((1, voc_size[2]), -1) 292 | for idx, item in enumerate(adm[2]): 293 | loss_multi_target[0][idx] = item 294 | 295 | loss_multi_target_last = np.full((1, voc_size[2]), -1) 296 | if adm_idx > 0: 297 | for idx, item in enumerate(input[adm_idx-1][2]): 298 | loss_multi_target_last[0][idx] = item 299 | 300 | result, result_last, _, loss_ddi, loss_rec = model(seq_input) 301 | 302 | loss_bce = 0.75 * F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device)) + \ 303 | (1 - 0.75) * F.binary_cross_entropy_with_logits(result_last, 
torch.FloatTensor(loss_bce_target_last).to(device)) 304 | loss_multi = 5e-2 * (0.75 * F.multilabel_margin_loss(F.sigmoid(result), torch.LongTensor(loss_multi_target).to(device)) + \ 305 | (1 - 0.75) * F.multilabel_margin_loss(F.sigmoid(result_last), torch.LongTensor(loss_multi_target_last).to(device))) 306 | 307 | y_pred_tmp = F.sigmoid(result).detach().cpu().numpy()[0] 308 | y_pred_tmp[y_pred_tmp >= 0.5] = 1 309 | y_pred_tmp[y_pred_tmp < 0.5] = 0 310 | y_label = np.where(y_pred_tmp == 1)[0] 311 | current_ddi_rate = ddi_rate_score([[y_label]], path='../data/ddi_A_final.pkl') 312 | 313 | # l2 = 0 314 | # for p in model.parameters(): 315 | # l2 = l2 + (p ** 2).sum() 316 | 317 | if sample_counter == 0: 318 | lambda1, lambda2, lambda3, lambda4 = weight_list[-1] 319 | else: 320 | current_loss = np.array([loss_bce.detach().cpu().numpy(), loss_multi.detach().cpu().numpy(), loss_ddi.detach().cpu().numpy(), loss_rec.detach().cpu().numpy()]) 321 | current_ratio = (current_loss - np.array(mean_loss)) / np.array(mean_loss) 322 | instant_weight = np.exp(current_ratio) / sum(np.exp(current_ratio)) 323 | lambda1, lambda2, lambda3, lambda4 = instant_weight * 0.75 + np.array(weight_list[-1]) * 0.25 324 | # update weight_list 325 | weight_list.append([lambda1, lambda2, lambda3, lambda4]) 326 | # update mean_loss 327 | mean_loss = (mean_loss * (sample_counter - 1) + np.array([loss_bce.detach().cpu().numpy(), \ 328 | loss_multi.detach().cpu().numpy(), loss_ddi.detach().cpu().numpy(), loss_rec.detach().cpu().numpy()])) / sample_counter 329 | # lambda1, lambda2, lambda3, lambda4 = weight_list[-1] 330 | if current_ddi_rate > 0.08: 331 | loss += lambda1 * loss_bce + lambda2 * loss_multi + \ 332 | lambda3 * loss_ddi + lambda4 * loss_rec 333 | else: 334 | loss += lambda1 * loss_bce + lambda2 * loss_multi + \ 335 | lambda4 * loss_rec 336 | 337 | optimizer.zero_grad() 338 | loss.backward(retain_graph=True) 339 | optimizer.step() 340 | 341 | llprint('\rtraining step: {} / {}'.format(step, len(data_train))) 342 | 343 | print() 344 | tic2 = time.time() 345 | ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete, avg_med = eval(model, data_eval, voc_size, epoch) 346 | print ('training time: {}, test time: {}'.format(time.time() - tic, time.time() - tic2)) 347 | 348 | history['ja'].append(ja) 349 | history['ddi_rate'].append(ddi_rate) 350 | history['avg_p'].append(avg_p) 351 | history['avg_r'].append(avg_r) 352 | history['avg_f1'].append(avg_f1) 353 | history['prauc'].append(prauc) 354 | history['add'].append(add) 355 | history['delete'].append(delete) 356 | history['med'].append(avg_med) 357 | 358 | if epoch >= 5: 359 | print ('ddi: {}, Med: {}, Ja: {}, F1: {}, Add: {}, Delete: {}'.format( 360 | np.mean(history['ddi_rate'][-5:]), 361 | np.mean(history['med'][-5:]), 362 | np.mean(history['ja'][-5:]), 363 | np.mean(history['avg_f1'][-5:]), 364 | np.mean(history['add'][-5:]), 365 | np.mean(history['delete'][-5:]) 366 | )) 367 | 368 | torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \ 369 | 'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb')) 370 | 371 | if epoch != 0 and best_ja < ja: 372 | best_epoch = epoch 373 | best_ja = ja 374 | 375 | print ('best_epoch: {}'.format(best_epoch)) 376 | 377 | dill.dump(history, open(os.path.join('saved', args.model_name, 'history_{}.pkl'.format(args.model_name)), 'wb')) 378 | 379 | if __name__ == '__main__': 380 | main() 381 | -------------------------------------------------------------------------------- /src/recommend.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import torch.nn as nn 4 | import argparse 5 | from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score 6 | import numpy as np 7 | import dill 8 | import time 9 | from torch.nn import CrossEntropyLoss 10 | from torch.optim import Adam 11 | from torch.utils import data 12 | from loss import cross_entropy_loss 13 | import os 14 | import torch.nn.functional as F 15 | import random 16 | from collections import defaultdict 17 | 18 | from torch.utils.data.dataloader import DataLoader 19 | from data_loader import mimic_data, pad_num_replace 20 | from beam import Beam 21 | 22 | import sys 23 | sys.path.append("..") 24 | from models import Leap, CopyDrug_batch, CopyDrug_tranformer, CopyDrug_generate_prob, CopyDrug_diag_proc_encode 25 | from COGNet_model import COGNet 26 | from util import llprint, sequence_metric, sequence_metric_v2, sequence_output_process, ddi_rate_score, get_n_params, output_flatten, print_result 27 | 28 | torch.manual_seed(1203) 29 | 30 | # 读取disease跟proc的英文名 31 | icd_diag_path = '../data/D_ICD_DIAGNOSES.csv' 32 | icd_proc_path = '../data/D_ICD_PROCEDURES.csv' 33 | code2diag = {} 34 | code2proc = {} 35 | with open(icd_diag_path, 'r') as f: 36 | lines = f.readlines()[1:] 37 | for line in lines: 38 | line = line.strip().split(',"') 39 | if line[-1] == '': line = line[:-1] 40 | _, icd_code, _, title = line 41 | code2diag[icd_code[:-1]] = title 42 | 43 | with open(icd_proc_path, 'r') as f: 44 | lines = f.readlines()[1:] 45 | for line in lines: 46 | _, icd_code, _, title = line.strip().split(',"') 47 | code2proc[icd_code[:-1]] = title 48 | 49 | 50 | 51 | def eval_recommend_batch(model, batch_data, device, TOKENS, args): 52 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS 53 | 54 | diseases, procedures, medications, seq_length, \ 55 | d_length_matrix, p_length_matrix, m_length_matrix, \ 56 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \ 57 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \ 58 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = batch_data 59 | # continue 60 | # 根据vocab对padding数值进行替换 61 | diseases = pad_num_replace(diseases, -1, DIAG_PAD_TOKEN).to(device) 62 | procedures = pad_num_replace(procedures, -1, PROC_PAD_TOKEN).to(device) 63 | dec_disease = pad_num_replace(dec_disease, -1, DIAG_PAD_TOKEN).to(device) 64 | stay_disease = pad_num_replace(stay_disease, -1, DIAG_PAD_TOKEN).to(device) 65 | dec_proc = pad_num_replace(dec_proc, -1, PROC_PAD_TOKEN).to(device) 66 | stay_proc = pad_num_replace(stay_proc, -1, PROC_PAD_TOKEN).to(device) 67 | medications = medications.to(device) 68 | m_mask_matrix = m_mask_matrix.to(device) 69 | d_mask_matrix = d_mask_matrix.to(device) 70 | p_mask_matrix = p_mask_matrix.to(device) 71 | dec_disease_mask = dec_disease_mask.to(device) 72 | stay_disease_mask = stay_disease_mask.to(device) 73 | dec_proc_mask = dec_proc_mask.to(device) 74 | stay_proc_mask = stay_proc_mask.to(device) 75 | 76 | batch_size = medications.size(0) 77 | max_visit_num = medications.size(1) 78 | 79 | input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory = model.encode(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, 80 | seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, 
max_len=20) 81 | 82 | partial_input_medication = torch.full((batch_size, max_visit_num, 1), SOS_TOKEN).to(device) 83 | parital_logits = None 84 | 85 | 86 | for i in range(args.max_len): 87 | partial_input_med_num = partial_input_medication.size(2) 88 | partial_m_mask_matrix = torch.zeros((batch_size, max_visit_num, partial_input_med_num), device=device).float() 89 | # print('val', i, partial_m_mask_matrix.size()) 90 | 91 | parital_logits = model.decode(partial_input_medication, input_disease_embdding, input_proc_embedding, encoded_medication, last_seq_medication, cross_visit_scores, 92 | d_mask_matrix, p_mask_matrix, partial_m_mask_matrix, last_m_mask, drug_memory) 93 | _, next_medication = torch.topk(parital_logits[:, :, -1, :], 1, dim=-1) 94 | partial_input_medication = torch.cat([partial_input_medication, next_medication], dim=-1) 95 | 96 | return parital_logits 97 | 98 | 99 | 100 | def test_recommend_batch(model, batch_data, device, TOKENS, ddi_adj, args): 101 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS 102 | 103 | diseases, procedures, medications, seq_length, \ 104 | d_length_matrix, p_length_matrix, m_length_matrix, \ 105 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \ 106 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \ 107 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = batch_data 108 | # continue 109 | # 根据vocab对padding数值进行替换 110 | diseases = pad_num_replace(diseases, -1, DIAG_PAD_TOKEN).to(device) 111 | procedures = pad_num_replace(procedures, -1, PROC_PAD_TOKEN).to(device) 112 | dec_disease = pad_num_replace(dec_disease, -1, DIAG_PAD_TOKEN).to(device) 113 | stay_disease = pad_num_replace(stay_disease, -1, DIAG_PAD_TOKEN).to(device) 114 | dec_proc = pad_num_replace(dec_proc, -1, PROC_PAD_TOKEN).to(device) 115 | stay_proc = pad_num_replace(stay_proc, -1, PROC_PAD_TOKEN).to(device) 116 | medications = medications.to(device) 117 | m_mask_matrix = m_mask_matrix.to(device) 118 | d_mask_matrix = d_mask_matrix.to(device) 119 | p_mask_matrix = p_mask_matrix.to(device) 120 | dec_disease_mask = dec_disease_mask.to(device) 121 | stay_disease_mask = stay_disease_mask.to(device) 122 | dec_proc_mask = dec_proc_mask.to(device) 123 | stay_proc_mask = stay_proc_mask.to(device) 124 | 125 | batch_size = medications.size(0) 126 | visit_num = medications.size(1) 127 | 128 | input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory = model.encode(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, 129 | seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20) 130 | 131 | # partial_input_medication = torch.full((batch_size, visit_num, 1), SOS_TOKEN).to(device) 132 | # parital_logits = None 133 | 134 | # 为每一个样本声明一个beam 135 | # 这里为了方便实现,写死batch_size必须为1 136 | assert batch_size == 1 137 | # visit_num个batch 138 | beams = [Beam(args.beam_size, MED_PAD_TOKEN, SOS_TOKEN, END_TOKEN, ddi_adj, device) for _ in range(visit_num)] 139 | 140 | # 构建decode输入,每一个visit上需要重复beam_size次数据 141 | input_disease_embdding = input_disease_embdding.repeat_interleave(args.beam_size, dim=0) 142 | input_proc_embedding = input_proc_embedding.repeat_interleave(args.beam_size, dim=0) 143 | encoded_medication = encoded_medication.repeat_interleave(args.beam_size, dim=0) 144 | last_seq_medication = last_seq_medication.repeat_interleave(args.beam_size, dim=0) 145 | cross_visit_scores = 
cross_visit_scores.repeat_interleave(args.beam_size, dim=0) 146 | # cross_visit_scores = cross_visit_scores.repeat_interleave(args.beam_size, dim=2) 147 | d_mask_matrix = d_mask_matrix.repeat_interleave(args.beam_size, dim=0) 148 | p_mask_matrix = p_mask_matrix.repeat_interleave(args.beam_size, dim=0) 149 | last_m_mask = last_m_mask.repeat_interleave(args.beam_size, dim=0) 150 | 151 | for i in range(args.max_len): 152 | len_dec_seq = i + 1 153 | # b.get_current_state(): (beam_size, len_dec_seq) --> (beam_size, 1, len_dec_seq) 154 | # dec_partial_inputs: (beam_size, visit_num, len_dec_seq) 155 | dec_partial_inputs = torch.cat([b.get_current_state().unsqueeze(dim=1) for b in beams], dim=1) 156 | # dec_partial_inputs = dec_partial_inputs.view(args.beam_size, visit_num, len_dec_seq) 157 | 158 | partial_m_mask_matrix = torch.zeros((args.beam_size, visit_num, len_dec_seq), device=device).float() 159 | # print('val', i, partial_m_mask_matrix.size()) 160 | 161 | # parital_logits: (beam_size, visit_sum, len_dec_seq, all_med_num) 162 | parital_logits = model.decode(dec_partial_inputs, input_disease_embdding, input_proc_embedding, encoded_medication, last_seq_medication, cross_visit_scores, 163 | d_mask_matrix, p_mask_matrix, partial_m_mask_matrix, last_m_mask, drug_memory) 164 | 165 | # word_lk: (beam_size, visit_sum, all_med_num) 166 | word_lk = parital_logits[:, :, -1, :] 167 | 168 | active_beam_idx_list = [] # 记录目前仍然active的beam 169 | for beam_idx in range(visit_num): 170 | # # 如果当前beam完成了,则跳过,这里beams的size应该是不变的 171 | # if beams[beam_idx].done: continue 172 | # inst_idx = beam_inst_idx_map[beam_idx] # 该beam所对应的adm下标 173 | # 更新beam,同时返回当前beam是否完成,如果未完成则表示active 174 | if not beams[beam_idx].advance(word_lk[:, beam_idx, :]): 175 | active_beam_idx_list.append(beam_idx) 176 | 177 | # 如果没有active的beam,则全部样本预测完毕 178 | if not active_beam_idx_list: break 179 | 180 | # Return useful information 181 | all_hyp = [] 182 | all_prob = [] 183 | for beam_idx in range(visit_num): 184 | scores, tail_idxs = beams[beam_idx].sort_scores() # 每个beam按照score排序,找出最优的生成 185 | hyps = beams[beam_idx].get_hypothesis(tail_idxs[0]) 186 | probs = beams[beam_idx].get_prob_list(tail_idxs[0]) 187 | all_hyp += [hyps] # 注意这里只关注最优解,否则写法上要修改 188 | all_prob += [probs] 189 | 190 | return all_hyp, all_prob 191 | 192 | 193 | # evaluate 194 | def eval(model, eval_dataloader, voc_size, epoch, device, TOKENS, args): 195 | model.eval() 196 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS 197 | ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)] 198 | smm_record = [] 199 | med_cnt, visit_cnt = 0, 0 200 | 201 | # fw = open("prediction_results.txt", "w") 202 | 203 | for idx, data in enumerate(eval_dataloader): 204 | diseases, procedures, medications, seq_length, \ 205 | d_length_matrix, p_length_matrix, m_length_matrix, \ 206 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \ 207 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \ 208 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = data 209 | visit_cnt += seq_length.sum().item() 210 | 211 | output_logits = eval_recommend_batch(model, data, device, TOKENS, args) 212 | 213 | # 每一个med上的预测结果 214 | labels, predictions = output_flatten(medications, output_logits, seq_length, m_length_matrix, voc_size[2], END_TOKEN, device, training=False, testing=False, max_len=args.max_len) 215 | 216 | y_gt = [] # groud truth 表示正确的label 0-1序列 217 | y_pred = [] # 预测的结果 0-1序列 218 | y_pred_prob = [] # 预测的每一个药物的平均概率,非0-1序列 219 | y_pred_label = [] # 预测的结果,非0-1序列 220 | # 
针对每一个admission的预测结果 221 | for label, prediction in zip(labels, predictions): 222 | y_gt_tmp = np.zeros(voc_size[2]) 223 | y_gt_tmp[label] = 1 # 01序列,表示正确的label 224 | y_gt.append(y_gt_tmp) 225 | 226 | # label: med set 227 | # prediction: [med_num, probability] 228 | out_list, sorted_predict = sequence_output_process(prediction, [voc_size[2], voc_size[2]+1]) 229 | y_pred_label.append(sorted(sorted_predict)) 230 | y_pred_prob.append(np.mean(prediction[:, :-2], axis=0)) 231 | 232 | # prediction label 233 | y_pred_tmp = np.zeros(voc_size[2]) 234 | y_pred_tmp[out_list] = 1 235 | y_pred.append(y_pred_tmp) 236 | med_cnt += len(sorted_predict) 237 | 238 | # if idx < 100: 239 | # fw.write(print_result(label, sorted_predict)) 240 | 241 | smm_record.append(y_pred_label) 242 | 243 | adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = \ 244 | sequence_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob), np.array(y_pred_label)) 245 | ja.append(adm_ja) 246 | prauc.append(adm_prauc) 247 | avg_p.append(adm_avg_p) 248 | avg_r.append(adm_avg_r) 249 | avg_f1.append(adm_avg_f1) 250 | llprint('\rtest step: {} / {}'.format(idx, len(eval_dataloader))) 251 | 252 | # fw.close() 253 | 254 | # ddi rate 255 | ddi_rate = ddi_rate_score(smm_record, path='../data/ddi_A_final.pkl') 256 | 257 | llprint('\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format( 258 | ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt 259 | )) 260 | 261 | return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt 262 | 263 | 264 | # test 265 | def test(model, test_dataloader, diag_voc, pro_voc, med_voc, voc_size, epoch, device, TOKENS, ddi_adj, args): 266 | model.eval() 267 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS 268 | ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)] 269 | med_cnt_list = [] 270 | smm_record = [] 271 | med_cnt, visit_cnt = 0, 0 272 | all_pred_list = [] 273 | all_label_list = [] 274 | 275 | ja_by_visit = [[] for _ in range(5)] 276 | auc_by_visit = [[] for _ in range(5)] 277 | pre_by_visit = [[] for _ in range(5)] 278 | recall_by_visit = [[] for _ in range(5)] 279 | f1_by_visit = [[] for _ in range(5)] 280 | smm_record_by_visit = [[] for _ in range(5)] 281 | 282 | for idx, data in enumerate(test_dataloader): 283 | diseases, procedures, medications, seq_length, \ 284 | d_length_matrix, p_length_matrix, m_length_matrix, \ 285 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \ 286 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \ 287 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = data 288 | visit_cnt += seq_length.sum().item() 289 | 290 | output_logits, output_probs = test_recommend_batch(model, data, device, TOKENS, ddi_adj, args) 291 | 292 | labels, predictions = output_flatten(medications, output_logits, seq_length, m_length_matrix, voc_size[2], END_TOKEN, device, training=False, testing=True, max_len=args.max_len) 293 | _, probs = output_flatten(medications, output_probs, seq_length, m_length_matrix, voc_size[2], END_TOKEN, device, training=False, testing=True, max_len=args.max_len) 294 | y_gt = [] 295 | y_pred = [] 296 | y_pred_label = [] 297 | y_pred_prob = [] 298 | 299 | label_hisory = [] 300 | label_hisory_list = [] 301 | pred_list = [] 302 | jaccard_list = [] 303 | def cal_jaccard(set1, set2): 304 | if not set1 or not set2: 305 | return 0 306 | set1 = set(set1) 
307 | set2 = set(set2) 308 | a, b = len(set1 & set2), len(set1 | set2) 309 | return a/b 310 | def cal_overlap_num(set1, set2): 311 | count = 0 312 | for d in set1: 313 | if d in set2: 314 | count += 1 315 | return count 316 | 317 | # 针对每一个admission的预测结果 318 | for label, prediction, prob_list in zip(labels, predictions, probs): 319 | label_hisory += label.tolist() ### case study 320 | 321 | y_gt_tmp = np.zeros(voc_size[2]) 322 | y_gt_tmp[label] = 1 # 01序列,表示正确的label 323 | y_gt.append(y_gt_tmp) 324 | 325 | out_list = [] 326 | out_prob_list = [] 327 | for med, prob in zip(prediction, prob_list): 328 | if med in [voc_size[2], voc_size[2]+1]: 329 | break 330 | out_list.append(med) 331 | out_prob_list.append(prob[:-2]) # 去掉SOS与EOS符号 332 | 333 | ## case study 334 | if label_hisory: 335 | jaccard_list.append(cal_jaccard(prediction, label_hisory)) 336 | pred_list.append(out_list) 337 | label_hisory_list.append(label.tolist()) 338 | 339 | # 对于没预测的药物,取每个位置上平均的概率,否则直接取对应的概率 340 | # pred_out_prob_list = np.mean(out_prob_list, axis=0) 341 | pred_out_prob_list = np.max(out_prob_list, axis=0) 342 | # pred_out_prob_list = np.min(out_prob_list, axis=0) 343 | for i in range(131): 344 | if i in out_list: 345 | pred_out_prob_list[i] = out_prob_list[out_list.index(i)][i] 346 | 347 | y_pred_prob.append(pred_out_prob_list) 348 | y_pred_label.append(out_list) 349 | 350 | # prediction label 351 | y_pred_tmp = np.zeros(voc_size[2]) 352 | y_pred_tmp[out_list] = 1 353 | y_pred.append(y_pred_tmp) 354 | med_cnt += len(prediction) 355 | med_cnt_list.append(len(prediction)) 356 | 357 | 358 | smm_record.append(y_pred_label) 359 | for i in range(min(len(labels), 5)): 360 | # single_ja, single_p, single_r, single_f1 = sequence_metric_v2(np.array(y_gt[i:i+1]), np.array(y_pred[i:i+1]), np.array(y_pred_label[i:i+1])) 361 | single_ja, single_auc, single_p, single_r, single_f1 = sequence_metric(np.array([y_gt[i]]), np.array([y_pred[i]]), np.array([y_pred_prob[i]]),np.array([y_pred_label[i]])) 362 | ja_by_visit[i].append(single_ja) 363 | auc_by_visit[i].append(single_auc) 364 | pre_by_visit[i].append(single_p) 365 | recall_by_visit[i].append(single_r) 366 | f1_by_visit[i].append(single_f1) 367 | smm_record_by_visit[i].append(y_pred_label[i:i+1]) 368 | 369 | # 存储所有预测结果 370 | all_pred_list.append(pred_list) 371 | all_label_list.append(labels) 372 | adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = \ 373 | sequence_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob), np.array(y_pred_label)) 374 | ja.append(adm_ja) 375 | prauc.append(adm_prauc) 376 | avg_p.append(adm_avg_p) 377 | avg_r.append(adm_avg_r) 378 | avg_f1.append(adm_avg_f1) 379 | llprint('\rtest step: {} / {}'.format(idx, len(test_dataloader))) 380 | 381 | # 统计不同visit的指标 382 | if idx%100==0: 383 | print('\tvisit1\tvisit2\tvisit3\tvisit4\tvisit5') 384 | print('count:', [len(buf) for buf in ja_by_visit]) 385 | print('jaccard:', [np.mean(buf) for buf in ja_by_visit]) 386 | print('auc:', [np.mean(buf) for buf in auc_by_visit]) 387 | print('precision:', [np.mean(buf) for buf in pre_by_visit]) 388 | print('recall:', [np.mean(buf) for buf in recall_by_visit]) 389 | print('f1:', [np.mean(buf) for buf in f1_by_visit]) 390 | print('DDI:', [ddi_rate_score(buf) for buf in smm_record_by_visit]) 391 | 392 | print('\tvisit1\tvisit2\tvisit3\tvisit4\tvisit5') 393 | print('count:', [len(buf) for buf in ja_by_visit]) 394 | print('jaccard:', [np.mean(buf) for buf in ja_by_visit]) 395 | print('auc:', [np.mean(buf) for buf in auc_by_visit]) 396 | print('precision:', [np.mean(buf) 
for buf in pre_by_visit])
397 |     print('recall:', [np.mean(buf) for buf in recall_by_visit])
398 |     print('f1:', [np.mean(buf) for buf in f1_by_visit])
399 |     print('DDI:', [ddi_rate_score(buf) for buf in smm_record_by_visit])
400 | 
401 |     pickle.dump(all_pred_list, open('out_list.pkl', 'wb'))
402 |     pickle.dump(all_label_list, open('out_list_gt.pkl', 'wb'))
403 | 
404 |     return smm_record, ja, prauc, avg_p, avg_r, avg_f1, med_cnt_list
405 | 
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
2 | import numpy as np
3 | import pandas as pd
4 | from sklearn.model_selection import train_test_split
5 | import sys
6 | import warnings
7 | import dill
8 | from collections import Counter
9 | from rdkit import Chem
10 | from collections import defaultdict
11 | import torch
12 | warnings.filterwarnings('ignore')
13 | 
14 | def get_n_params(model):
15 |     pp=0
16 |     for p in list(model.parameters()):
17 |         nn=1
18 |         for s in list(p.size()):
19 |             nn = nn*s
20 |         pp += nn
21 |     return pp
22 | 
23 | # use the same metric from DMNC
24 | def llprint(message):
25 |     sys.stdout.write(message)
26 |     sys.stdout.flush()
27 | 
28 | def transform_split(X, Y):
29 |     x_train, x_eval, y_train, y_eval = train_test_split(X, Y, train_size=2/3, random_state=1203)
30 |     x_eval, x_test, y_eval, y_test = train_test_split(x_eval, y_eval, test_size=0.5, random_state=1203)
31 |     return x_train, x_eval, x_test, y_train, y_eval, y_test
32 | 
33 | def sequence_output_process(output_logits, filter_token):
34 |     """Build the final output sequence; output_logits holds each decoding position's probabilities, filter_token holds the SOS and END token indices"""
35 |     pind = np.argsort(output_logits, axis=-1)[:, ::-1]  # sort each position's candidates by descending probability
36 | 
37 |     out_list = []  # the generated result
38 |     break_flag = False
39 |     for i in range(len(pind)):
40 |         # walk over all positions of pind in order;
41 |         # break_flag decides when to leave the sentence-generation loop
42 |         if break_flag:
43 |             break
44 |         # every position is already sorted in descending order
45 |         for j in range(pind.shape[1]):
46 |             label = pind[i][j]
47 |             # hitting SOS or END means the sentence is over
48 |             if label in filter_token:
49 |                 break_flag = True
50 |                 break
51 |             # emit a label that has not been generated yet;
52 |             # otherwise fall through to the next most probable drug
53 |             if label not in out_list:
54 |                 out_list.append(label)
55 |                 break
56 |     y_pred_prob_tmp = []
57 |     for idx, item in enumerate(out_list):
58 |         y_pred_prob_tmp.append(output_logits[idx, item])
59 |     # sort all drugs in out_list by descending probability?
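# [illustrative note] A hedged toy call of sequence_output_process; filter_token
# carries the SOS/END indices ([voc_size[2], voc_size[2]+1] in recommend.py):
#
#     logits = np.array([[0.1, 0.7, 0.2, 0.0],   # step 0 -> drug 1
#                        [0.6, 0.1, 0.1, 0.2],   # step 1 -> drug 0
#                        [0.0, 0.1, 0.1, 0.8]])  # step 2 -> END, stop
#     out_list, sorted_predict = sequence_output_process(logits, [3, 4])
#     # out_list == [1, 0]; sorted_predict re-orders them by probability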
60 | sorted_predict = [x for _, x in sorted(zip(y_pred_prob_tmp, out_list), reverse=True)] 61 | return out_list, sorted_predict 62 | 63 | 64 | def sequence_metric(y_gt, y_pred, y_prob, y_label): 65 | def average_prc(y_gt, y_label): 66 | score = [] 67 | for b in range(y_gt.shape[0]): 68 | target = np.where(y_gt[b]==1)[0] 69 | out_list = y_label[b] 70 | inter = set(out_list) & set(target) 71 | prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list) 72 | score.append(prc_score) 73 | return score 74 | 75 | 76 | def average_recall(y_gt, y_label): 77 | score = [] 78 | for b in range(y_gt.shape[0]): 79 | target = np.where(y_gt[b] == 1)[0] 80 | out_list = y_label[b] 81 | inter = set(out_list) & set(target) 82 | recall_score = 0 if len(target) == 0 else len(inter) / len(target) 83 | score.append(recall_score) 84 | return score 85 | 86 | 87 | def average_f1(average_prc, average_recall): 88 | score = [] 89 | for idx in range(len(average_prc)): 90 | if (average_prc[idx] + average_recall[idx]) == 0: 91 | score.append(0) 92 | else: 93 | score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx])) 94 | return score 95 | 96 | 97 | def jaccard(y_gt, y_label): 98 | score = [] 99 | for b in range(y_gt.shape[0]): 100 | target = np.where(y_gt[b] == 1)[0] 101 | out_list = y_label[b] 102 | inter = set(out_list) & set(target) 103 | union = set(out_list) | set(target) 104 | jaccard_score = 0 if union == 0 else len(inter) / len(union) 105 | score.append(jaccard_score) 106 | return np.mean(score) 107 | 108 | def f1(y_gt, y_pred): 109 | all_micro = [] 110 | for b in range(y_gt.shape[0]): 111 | all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro')) 112 | return np.mean(all_micro) 113 | 114 | def roc_auc(y_gt, y_pred_prob): 115 | all_micro = [] 116 | for b in range(len(y_gt)): 117 | all_micro.append(roc_auc_score(y_gt[b], y_pred_prob[b], average='macro')) 118 | return np.mean(all_micro) 119 | 120 | def precision_auc(y_gt, y_prob): 121 | all_micro = [] 122 | for b in range(len(y_gt)): 123 | # try: 124 | # all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro')) 125 | # except: 126 | # continue 127 | all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro')) 128 | return np.mean(all_micro) 129 | 130 | def precision_at_k(y_gt, y_prob_label, k): 131 | precision = 0 132 | for i in range(len(y_gt)): 133 | TP = 0 134 | for j in y_prob_label[i][:k]: 135 | if y_gt[i, j] == 1: 136 | TP += 1 137 | precision += TP / k 138 | return precision / len(y_gt) 139 | try: 140 | auc = roc_auc(y_gt, y_prob) 141 | except ValueError: 142 | auc = 0 143 | p_1 = precision_at_k(y_gt, y_label, k=1) 144 | p_3 = precision_at_k(y_gt, y_label, k=3) 145 | p_5 = precision_at_k(y_gt, y_label, k=5) 146 | f1 = f1(y_gt, y_pred) 147 | prauc = precision_auc(y_gt, y_prob) 148 | ja = jaccard(y_gt, y_label) 149 | avg_prc = average_prc(y_gt, y_label) 150 | avg_recall = average_recall(y_gt, y_label) 151 | avg_f1 = average_f1(avg_prc, avg_recall) 152 | 153 | return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1) 154 | 155 | 156 | def sequence_metric_v2(y_gt, y_pred, y_label): 157 | def average_prc(y_gt, y_label): 158 | score = [] 159 | for b in range(y_gt.shape[0]): 160 | target = np.where(y_gt[b]==1)[0] 161 | out_list = y_label[b] 162 | inter = set(out_list) & set(target) 163 | prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list) 164 | score.append(prc_score) 165 | return score 166 | 167 | 168 | def average_recall(y_gt, y_label): 169 | 
score = [] 170 | for b in range(y_gt.shape[0]): 171 | target = np.where(y_gt[b] == 1)[0] 172 | out_list = y_label[b] 173 | inter = set(out_list) & set(target) 174 | recall_score = 0 if len(target) == 0 else len(inter) / len(target) 175 | score.append(recall_score) 176 | return score 177 | 178 | 179 | def average_f1(average_prc, average_recall): 180 | score = [] 181 | for idx in range(len(average_prc)): 182 | if (average_prc[idx] + average_recall[idx]) == 0: 183 | score.append(0) 184 | else: 185 | score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx])) 186 | return score 187 | 188 | 189 | def jaccard(y_gt, y_label): 190 | score = [] 191 | for b in range(y_gt.shape[0]): 192 | target = np.where(y_gt[b] == 1)[0] 193 | out_list = y_label[b] 194 | inter = set(out_list) & set(target) 195 | union = set(out_list) | set(target) 196 | jaccard_score = 0 if union == 0 else len(inter) / len(union) 197 | score.append(jaccard_score) 198 | return np.mean(score) 199 | 200 | def f1(y_gt, y_pred): 201 | all_micro = [] 202 | for b in range(y_gt.shape[0]): 203 | all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro')) 204 | return np.mean(all_micro) 205 | 206 | def roc_auc(y_gt, y_pred_prob): 207 | all_micro = [] 208 | for b in range(len(y_gt)): 209 | all_micro.append(roc_auc_score(y_gt[b], y_pred_prob[b], average='macro')) 210 | return np.mean(all_micro) 211 | 212 | def precision_auc(y_gt, y_prob): 213 | all_micro = [] 214 | for b in range(len(y_gt)): 215 | all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro')) 216 | return np.mean(all_micro) 217 | 218 | def precision_at_k(y_gt, y_prob_label, k): 219 | precision = 0 220 | for i in range(len(y_gt)): 221 | TP = 0 222 | for j in y_prob_label[i][:k]: 223 | if y_gt[i, j] == 1: 224 | TP += 1 225 | precision += TP / k 226 | return precision / len(y_gt) 227 | # try: 228 | # auc = roc_auc(y_gt, y_prob) 229 | # except ValueError: 230 | # auc = 0 231 | # p_1 = precision_at_k(y_gt, y_label, k=1) 232 | # p_3 = precision_at_k(y_gt, y_label, k=3) 233 | # p_5 = precision_at_k(y_gt, y_label, k=5) 234 | f1 = f1(y_gt, y_pred) 235 | # prauc = precision_auc(y_gt, y_prob) 236 | ja = jaccard(y_gt, y_label) 237 | avg_prc = average_prc(y_gt, y_label) 238 | avg_recall = average_recall(y_gt, y_label) 239 | avg_f1 = average_f1(avg_prc, avg_recall) 240 | 241 | return ja, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1) 242 | 243 | def multi_label_metric(y_gt, y_pred, y_prob): 244 | 245 | def jaccard(y_gt, y_pred): 246 | score = [] 247 | for b in range(y_gt.shape[0]): 248 | target = np.where(y_gt[b] == 1)[0] 249 | out_list = np.where(y_pred[b] == 1)[0] 250 | inter = set(out_list) & set(target) 251 | union = set(out_list) | set(target) 252 | jaccard_score = 0 if union == 0 else len(inter) / len(union) 253 | score.append(jaccard_score) 254 | return np.mean(score) 255 | 256 | def average_prc(y_gt, y_pred): 257 | score = [] 258 | for b in range(y_gt.shape[0]): 259 | target = np.where(y_gt[b] == 1)[0] 260 | out_list = np.where(y_pred[b] == 1)[0] 261 | inter = set(out_list) & set(target) 262 | prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list) 263 | score.append(prc_score) 264 | return score 265 | 266 | def average_recall(y_gt, y_pred): 267 | score = [] 268 | for b in range(y_gt.shape[0]): 269 | target = np.where(y_gt[b] == 1)[0] 270 | out_list = np.where(y_pred[b] == 1)[0] 271 | inter = set(out_list) & set(target) 272 | recall_score = 0 if len(target) == 0 else len(inter) / len(target) 273 | 
score.append(recall_score) 274 | return score 275 | 276 | def average_f1(average_prc, average_recall): 277 | score = [] 278 | for idx in range(len(average_prc)): 279 | if average_prc[idx] + average_recall[idx] == 0: 280 | score.append(0) 281 | else: 282 | score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx])) 283 | return score 284 | 285 | def f1(y_gt, y_pred): 286 | all_micro = [] 287 | for b in range(y_gt.shape[0]): 288 | all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro')) 289 | return np.mean(all_micro) 290 | 291 | def roc_auc(y_gt, y_prob): 292 | all_micro = [] 293 | for b in range(len(y_gt)): 294 | all_micro.append(roc_auc_score(y_gt[b], y_prob[b], average='macro')) 295 | return np.mean(all_micro) 296 | 297 | def precision_auc(y_gt, y_prob): 298 | all_micro = [] 299 | for b in range(len(y_gt)): 300 | all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro')) 301 | return np.mean(all_micro) 302 | 303 | def precision_at_k(y_gt, y_prob, k=3): 304 | precision = 0 305 | sort_index = np.argsort(y_prob, axis=-1)[:, ::-1][:, :k] 306 | for i in range(len(y_gt)): 307 | TP = 0 308 | for j in range(len(sort_index[i])): 309 | if y_gt[i, sort_index[i, j]] == 1: 310 | TP += 1 311 | precision += TP / len(sort_index[i]) 312 | return precision / len(y_gt) 313 | 314 | # roc_auc 315 | try: 316 | auc = roc_auc(y_gt, y_prob) 317 | except: 318 | auc = 0 319 | # precision 320 | p_1 = precision_at_k(y_gt, y_prob, k=1) 321 | p_3 = precision_at_k(y_gt, y_prob, k=3) 322 | p_5 = precision_at_k(y_gt, y_prob, k=5) 323 | # macro f1 324 | f1 = f1(y_gt, y_pred) 325 | # precision 326 | prauc = precision_auc(y_gt, y_prob) 327 | # jaccard 328 | ja = jaccard(y_gt, y_pred) 329 | # pre, recall, f1 330 | avg_prc = average_prc(y_gt, y_pred) 331 | avg_recall = average_recall(y_gt, y_pred) 332 | avg_f1 = average_f1(avg_prc, avg_recall) 333 | 334 | return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1) 335 | 336 | def ddi_rate_score(record, path='../data/ddi_A_final.pkl'): 337 | # ddi rate 338 | ddi_A = dill.load(open(path, 'rb')) 339 | all_cnt = 0 340 | dd_cnt = 0 341 | for patient in record: 342 | for adm in patient: 343 | med_code_set = adm 344 | for i, med_i in enumerate(med_code_set): 345 | for j, med_j in enumerate(med_code_set): 346 | if j <= i: 347 | continue 348 | all_cnt += 1 349 | if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1: 350 | dd_cnt += 1 351 | if all_cnt == 0: 352 | return 0 353 | return dd_cnt / all_cnt 354 | 355 | 356 | def create_atoms(mol, atom_dict): 357 | """Transform the atom types in a molecule (e.g., H, C, and O) 358 | into the indices (e.g., H=0, C=1, and O=2). 359 | Note that each atom index considers the aromaticity. 360 | """ 361 | atoms = [a.GetSymbol() for a in mol.GetAtoms()] 362 | for a in mol.GetAromaticAtoms(): 363 | i = a.GetIdx() 364 | atoms[i] = (atoms[i], 'aromatic') 365 | atoms = [atom_dict[a] for a in atoms] 366 | return np.array(atoms) 367 | 368 | def create_ijbonddict(mol, bond_dict): 369 | """Create a dictionary, in which each key is a node ID 370 | and each value is the tuples of its neighboring node 371 | and chemical bond (e.g., single and double) IDs. 
372 | """ 373 | i_jbond_dict = defaultdict(lambda: []) 374 | for b in mol.GetBonds(): 375 | i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx() 376 | bond = bond_dict[str(b.GetBondType())] 377 | i_jbond_dict[i].append((j, bond)) 378 | i_jbond_dict[j].append((i, bond)) 379 | return i_jbond_dict 380 | 381 | def extract_fingerprints(radius, atoms, i_jbond_dict, 382 | fingerprint_dict, edge_dict): 383 | """Extract the fingerprints from a molecular graph 384 | based on Weisfeiler-Lehman algorithm. 385 | """ 386 | 387 | if (len(atoms) == 1) or (radius == 0): 388 | nodes = [fingerprint_dict[a] for a in atoms] 389 | 390 | else: 391 | nodes = atoms 392 | i_jedge_dict = i_jbond_dict 393 | 394 | for _ in range(radius): 395 | 396 | """Update each node ID considering its neighboring nodes and edges. 397 | The updated node IDs are the fingerprint IDs. 398 | """ 399 | nodes_ = [] 400 | for i, j_edge in i_jedge_dict.items(): 401 | neighbors = [(nodes[j], edge) for j, edge in j_edge] 402 | fingerprint = (nodes[i], tuple(sorted(neighbors))) 403 | nodes_.append(fingerprint_dict[fingerprint]) 404 | 405 | """Also update each edge ID considering 406 | its two nodes on both sides. 407 | """ 408 | i_jedge_dict_ = defaultdict(lambda: []) 409 | for i, j_edge in i_jedge_dict.items(): 410 | for j, edge in j_edge: 411 | both_side = tuple(sorted((nodes[i], nodes[j]))) 412 | edge = edge_dict[(both_side, edge)] 413 | i_jedge_dict_[i].append((j, edge)) 414 | 415 | nodes = nodes_ 416 | i_jedge_dict = i_jedge_dict_ 417 | 418 | return np.array(nodes) 419 | 420 | 421 | def buildMPNN(molecule, med_voc, radius=1, device="cpu:0"): 422 | 423 | atom_dict = defaultdict(lambda: len(atom_dict)) 424 | bond_dict = defaultdict(lambda: len(bond_dict)) 425 | fingerprint_dict = defaultdict(lambda: len(fingerprint_dict)) 426 | edge_dict = defaultdict(lambda: len(edge_dict)) 427 | MPNNSet, average_index = [], [] 428 | 429 | print (len(med_voc.items())) 430 | for index, ndc in med_voc.items(): 431 | 432 | smilesList = list(molecule[ndc]) 433 | 434 | """Create each data with the above defined functions.""" 435 | counter = 0 # counter how many drugs are under that ATC-3 436 | for smiles in smilesList: 437 | try: 438 | mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) 439 | atoms = create_atoms(mol, atom_dict) 440 | molecular_size = len(atoms) 441 | i_jbond_dict = create_ijbonddict(mol, bond_dict) 442 | fingerprints = extract_fingerprints(radius, atoms, i_jbond_dict, 443 | fingerprint_dict, edge_dict) 444 | adjacency = Chem.GetAdjacencyMatrix(mol) 445 | # if fingerprints.shape[0] == adjacency.shape[0]: 446 | for _ in range(adjacency.shape[0] - fingerprints.shape[0]): 447 | fingerprints = np.append(fingerprints, 1) 448 | fingerprints = torch.LongTensor(fingerprints).to(device) 449 | adjacency = torch.FloatTensor(adjacency).to(device) 450 | MPNNSet.append((fingerprints, adjacency, molecular_size)) 451 | counter += 1 452 | except: 453 | continue 454 | average_index.append(counter) 455 | 456 | """Transform the above each data of numpy 457 | to pytorch tensor on a device (i.e., CPU or GPU). 
458 | """ 459 | 460 | N_fingerprint = len(fingerprint_dict) 461 | 462 | # transform into projection matrix 463 | n_col = sum(average_index) 464 | n_row = len(average_index) 465 | 466 | average_projection = np.zeros((n_row, n_col)) 467 | col_counter = 0 468 | for i, item in enumerate(average_index): 469 | average_projection[i, col_counter : col_counter + item] = 1 / item 470 | col_counter += item 471 | 472 | return MPNNSet, N_fingerprint, torch.FloatTensor(average_projection) 473 | 474 | 475 | 476 | def output_flatten(labels, logits, seq_length, m_length_matrix, med_num, END_TOKEN, device, training=True, testing=False, max_len=20): 477 | ''' 478 | labels: [batch_size, visit_num, medication_num] 479 | logits: [batch_size, visit_num, max_med_num, medication_vocab_size] 480 | ''' 481 | # 将最终多个维度的结果展开 482 | batch_size, max_seq_length = labels.size()[:2] 483 | assert max_seq_length == max(seq_length) 484 | whole_seqs_num = seq_length.sum().item() 485 | if training: 486 | whole_med_sum = sum([sum(buf) for buf in m_length_matrix]) + whole_seqs_num # 因为每一个seq后面会多一个END_TOKEN 487 | 488 | # 将结果展开,然后用库函数进行计算 489 | labels_flatten = torch.empty(whole_med_sum).to(device) 490 | logits_flatten = torch.empty(whole_med_sum, med_num).to(device) 491 | 492 | start_idx = 0 493 | for i in range(batch_size): # 每个batch 494 | for j in range(seq_length[i]): # seq_length[i]指这个batch对应的seq数目 495 | for k in range(m_length_matrix[i][j]+1): # m_length_matrix[i][j]对应seq中med的数目 496 | if k==m_length_matrix[i][j]: # 最后一个label指定为END_TOKEN 497 | labels_flatten[start_idx] = END_TOKEN 498 | else: 499 | labels_flatten[start_idx] = labels[i, j, k] 500 | logits_flatten[start_idx, :] = logits[i, j, k, :] 501 | start_idx += 1 502 | return labels_flatten, logits_flatten 503 | else: 504 | # 将结果按照adm展开,然后用库函数进行计算 505 | labels_flatten = [] 506 | logits_flatten = [] 507 | 508 | start_idx = 0 509 | for i in range(batch_size): # 每个batch 510 | for j in range(seq_length[i]): # seq_length[i]指这个batch对应的seq数目 511 | labels_flatten.append(labels[i,j,:m_length_matrix[i][j]].detach().cpu().numpy()) 512 | 513 | if testing: 514 | logits_flatten.append(logits[j]) # beam search目前直接给出了预测结果 515 | else: 516 | logits_flatten.append(logits[i,j,:max_len,:].detach().cpu().numpy()) # 注意这里手动定义了max_len 517 | # cur_label = [] 518 | # cur_seq_length = [] 519 | # for k in range(m_length_matrix[i][j]+1): # m_length_matrix[i][j]对应seq中med的数目 520 | # if k==m_length_matrix[i][j]: # 最后一个label指定为END_TOKEN 521 | # continue 522 | # else: 523 | # labels_flatten[start_idx] = labels[i, j, k] 524 | # logits_flatten[start_idx, :] = logits[i, j, k, :med_num] 525 | # start_idx += 1 526 | return labels_flatten, logits_flatten 527 | 528 | 529 | def print_result(label, prediction): 530 | ''' 531 | label: [real_med_num, ] 532 | logits: [20, med_vocab_size] 533 | ''' 534 | label_text = " ".join([str(x) for x in label]) 535 | predict_text = " ".join([str(x) for x in prediction]) 536 | 537 | return "[GT]\t{}\n[PR]\t{}\n\n".format(label_text, predict_text) 538 | -------------------------------------------------------------------------------- /data/drug-atc.csv: -------------------------------------------------------------------------------- 1 | CID000003015,G03HA01 2 | CID000004011,N06AA21 3 | CID000071273,N01BB09 4 | CID000062816,C10AC02 5 | CID000052421,D07AC18,V01AA07,V01AA03 6 | CID000056339,N03AB05 7 | CID000077992,N02CC07 8 | CID000005300,L01AD04 9 | CID000222786,H02AB10,S01BA03 10 | CID000071301,C07AB12 11 | CID000004870,C03AA05 12 | CID000004873,A12BA01,B05XA01 13 | 
CID000002142,D06AX12,J01GB06,S01AA21 14 | CID000002141,V03AF05 15 | CID000131536,J05AE07 16 | CID000001775,N03AB02,N03AB04,N03AB05 17 | CID000003305,M05BA01,M05BB01 18 | CID000005454,N05AF04 19 | CID000003308,M01AB08 20 | CID000004909,D08AC04,S03AA05,S01AX08,R01AX07,R02AA18,N03AA03 21 | CID000004908,P01BA03 22 | CID000002099,A03AE01 23 | CID000054688,J01FA09 24 | CID000002092,G04CA01 25 | CID000005503,A10BB05 26 | CID000004536,G03AC01 27 | CID006918453,M01CB03 28 | CID000003706,J05AE02 29 | CID000005408,G03BA02,G03BA03,G03EK01 30 | CID000004497,C08CA06 31 | CID000004493,L02BB02 32 | CID000003219,S01GX06 33 | CID000002435,S01EA05 34 | CID000003121,N03AG01 35 | CID000040159,P03AC04 36 | CID000005878,A14AA08 37 | CID000002554,N03AF01 38 | CID000002550,C09AA01 39 | CID000002551,N07AB01,S01EB02 40 | CID000010100,N02AC04,N02AC54,N02AC74 41 | CID000002559,J01CA03 42 | CID000443871,A10BX03 43 | CID000003059,N02BA11 44 | CID000019090,G03AC05,G03DB02,L02AB01,G03DB04 45 | CID000059768,C07AB09 46 | CID004183806,A02BB01 47 | CID000176870,L01XE03 48 | CID000004473,C08CA04 49 | CID000003440,C03CA01 50 | CID000003449,N06DA04 51 | CID000003685,L01DB06 52 | CID000003937,C09AA03 53 | CID004630253,D07AB10,S01BA10 54 | CID000018140,L01XX11 55 | CID000000727,P03AB02,D08AE01 56 | CID004475485,C10AA07 57 | CID000077993,N02CC06 58 | CID000148211,A04AA05 59 | CID000072938,J01MA15 60 | CID000002153,R03DA04,R03DA05 61 | CID000002156,C01BD01 62 | CID000008953,B05XA02,B05CB04 63 | CID000005203,N06AB06 64 | CID000005206,N01AB08 65 | CID005281104,H05BX02 66 | CID000004917,N05AB04 67 | CID000004914,C05AD05,D04AB03,J01CE09,N01BA02,N01BA04,S01HA02,S01HA05,C01BA02 68 | CID000004913,C01BA02 69 | CID000003339,C10AB05 70 | CID000004911,M04AB01 71 | CID000213039,J05AE10 72 | CID006436173,A07AA11,D06AX11 73 | CID000002088,M05BA04 74 | CID000002082,P02CA03 75 | CID000005515,L01XX17 76 | CID000005516,L02BA02 77 | CID000002145,L02BG01 78 | CID000005512,G04BD07 79 | CID000005735,N05CF01 80 | CID000005734,N03AX15 81 | CID000005731,N02CC03 82 | CID000005732,N05CF02 83 | CID000003203,J05AG03 84 | CID000002140,V08AA01,V08AA04 85 | CID000004485,C08CA05 86 | CID000060707,J01DH02 87 | CID000003117,N07BB01,P03AA04 88 | CID000003114,C01BA03 89 | CID000004645,D06AA03,G01AA07,J01AA06,S01AA04 90 | CID000038904,L01XA02 91 | CID000060147,G04CA02 92 | CID000004674,M05BA03 93 | CID000002522,D05AX02 94 | CID000002524,A11CC04,D05AX03 95 | CID000104865,C02KX01 96 | CID004659568,N04BX02 97 | CID004659569,N04BX01 98 | CID000003062,C01AA02,C01AA05,C01AA08,V03AB24 99 | CID000003066,N02CA01 100 | CID000154059,G04BD08 101 | CID000004889,C10AA03 102 | CID000004264,D06AX09,R01AX06 103 | CID000000750,B03AA01,B05CX03 104 | CID000002637,J01DC01 105 | CID000002631,J01DD01 106 | CID000160051,C10AC04 107 | CID000001986,S01EC01 108 | CID000003454,J05AB06,J05AB14,S01AD09 109 | CID000003690,L01AA06 110 | CID000151165,A04AD12 111 | CID000003696,N06AA02,N06AA03,N06AA06 112 | CID000003698,C01CE01 113 | CID000003929,J01XX08 114 | CID000004195,C01CA17 115 | CID000002949,G03XA01 116 | CID000004078,N05AC03 117 | CID000004900,A07EA03,H02AB07,H02AB15 118 | CID000005593,S01FA06 119 | CID000005452,N05AC02 120 | CID000004107,M03BA03 121 | CID000071329,C01BD04 122 | CID000005212,G04BE03 123 | CID000005210,A08AA10 124 | CID000005215,D06BA01,J01EC02 125 | CID000004920,G03AC06,G03DA02,G03DA03,G03DA04,L02AB02 126 | CID000004927,D04AA10,R06AD02,R06AD05 127 | CID000008612,N01BA04 128 | CID000005430,D01AC06,P02CA02 129 | CID000003324,J05AB09,S01AD07 130 | CID000003325,A02BA03 131 | 
CID000124087,R06AX27 132 | CID000004856,M01AC01,M02AA07,S01BC06 133 | CID000005508,M01AB03,M02AA21 134 | CID000005504,C04AB02,M02AX02 135 | CID000005505,A10BB03,V04CA01 136 | CID000054841,N06BA09 137 | CID000005726,J05AF01 138 | CID000005721,J05AH01 139 | CID000002182,L01XX35 140 | CID000002187,L02BG03 141 | CID000004993,P01BD01 142 | CID000003100,D04AA32,D04AA33,R06AA02 143 | CID000000861,H03AA02 144 | CID000003105,S01EA02 145 | CID000003108,B01AC07,C01DX20 146 | CID000004675,M03AC01 147 | CID003468412,A07AA02,D01AA01,G01AA01 148 | CID000003075,C08DB01 149 | CID000068844,S01EC04 150 | CID000060714,V08CA04 151 | CID009571074,J01DD16 152 | CID000002583,C07AA15,S01ED05 153 | CID000004451,J05AE04 154 | CID000002512,G02CB03,N04BC06 155 | CID000002629,J01DD12 156 | CID000000137,L01XD04 157 | CID000041317,D05BB02 158 | CID000002622,J01DE01 159 | CID000062959,J01MA13 160 | CID000002995,N06AA01 161 | CID000002487,N02AF01 162 | CID000003467,D06AX07,J01GB03,S01AA11,S02AA14,S03AA06 163 | CID000003463,C10AB04 164 | CID000003461,L01BC05 165 | CID000003915,R01AC02,S01GX02 166 | CID000003914,S01ED03 167 | CID000003661,A03BA01,A03BB02,N04AC01,N04AC30,S01FA01,S01FA05,A03BA03 168 | CID000039042,C10AB02 169 | CID000004062,N01BB03 170 | CID000004060,N03AB04 171 | CID000004064,N05BC01 172 | CID000002249,C07AB03,C07AB11 173 | CID000152945,G04CB02 174 | CID000004112,L01BA01,L04AX03 175 | CID011947681,C02DD01 176 | CID000002658,J01DC02 177 | CID005329102,L01XE04 178 | CID000004934,A03AB05 179 | CID000004935,S01HA04 180 | CID000004932,C01BC03 181 | CID000002655,J01DD07 182 | CID000039860,A04AD11 183 | CID000005538,D10AD01,D10AD51,L01XX14,D10BA01,D10AD04,L01XX22 184 | CID000004845,R03AC08,R03CC07 185 | CID000005533,N06AX05 186 | CID000005717,R03DC01 187 | CID000005719,N05CF03 188 | CID000005718,J05AF03 189 | CID000004201,C02DC01,D11AX01 190 | CID000005155,J05AF04 191 | CID000005404,G01AG02 192 | CID000005403,R03AC03,R03CC03 193 | CID000005402,D01AE15,D01BA02 194 | CID000005401,G04CA03 195 | CID000005152,R03AC12 196 | CID000004666,L01CD01 197 | CID000002160,N06AA09 198 | CID000048041,C01BC08 199 | CID000000815,A07AA08,J01GB04,S01AA24 200 | CID000002441,N05BA08 201 | CID000060169,M03AC07 202 | CID000004595,A04AA01 203 | CID000060164,D10AD03 204 | CID000004599,A08AB01 205 | CID000002431,C01BD02 206 | CID000003000,D07AC03,D07XC02 207 | CID000003003,A01AC02,C05AA09,D07AB19,D07XB05,D10AA03,H02AB02,R01AD03,S01BA01,S01CB01,S02BA06,S03BA01 208 | CID000003007,N06BA01,N06BA02 209 | CID000438204,D06BA01 210 | CID000007029,A08AA03 211 | CID000123620,D07AC13,D07XC03,R01AD09,R03BA07 212 | CID003002190,J01FA15 213 | CID000004449,N06AX06 214 | CID000002617,J01DB04 215 | CID000104758,L01BA03 216 | CID000003478,A10BB07 217 | CID000023897,N05AE02 218 | CID000163742,B01AX05 219 | CID000003475,A10BB09 220 | CID000003476,A10BB12 221 | CID000002891,B03BB01,B03BA01 222 | CID000003676,C01BB01,C05AD01,D04AB01,N01BB02,R02AD02,S01HA07,S02DA01,A01AD11 223 | CID000003902,L02BG04 224 | CID000003675,N06AF03 225 | CID000003672,C01EB16,G02CC01,M01AE01,M01AE14,M02AA13 226 | CID000004058,N02AB02 227 | CID000004057,A03AB12 228 | CID000004054,N06DX01 229 | CID000004053,L01AA03 230 | CID000004052,M01AC06 231 | CID000003877,J05AF05 232 | CID000091270,C09AA13 233 | CID000003878,N03AX09 234 | CID000004121,C03AA08 235 | CID000065027,J05AE09 236 | CID000003494,A03AB02 237 | CID005311048,L01XA01 238 | CID000004409,M01AX01 239 | CID000016362,N05AG02 240 | CID000003869,C07AG01 241 | CID000004943,N01AX10 242 | CID000004946,C07AA05 243 | CID000005090,M01AH02 244 | 
CID000005525,C09AA10 245 | CID000005526,B02AA02 246 | CID000130881,C09CA08 247 | CID000009433,R03DA05 248 | CID000093860,L01XX32 249 | CID000003278,C03CC01 250 | CID000003279,J04AK02 251 | CID000072057,V08CA09 252 | CID000072054,G04BD10 253 | CID000005412,A01AB13,D06AA04,J01AA07,S01AA09,S02AA08,S03AA02,A01AB21,S01AA02,J01AA06,J01AA09,S01AA04,G01AA07,D06AA02,D06AA03,J01AA03 254 | CID000004614,M01AE12 255 | CID000004616,N05BA04 256 | CID000002477,N05BE01 257 | CID000002476,N02AE01,N07BC01 258 | CID000002471,C03CA02 259 | CID000002274,J01DF01 260 | CID000001065,C01BA01 261 | CID000002519,N06BC01,N02BE01 262 | CID000003928,J01FF02 263 | CID000004724,C07AA23 264 | CID000004723,N06BA05 265 | CID000003019,C02DA01,V03AH01 266 | CID000000937,C04AC01,C10AD02 267 | CID000003016,N05BA01,N05BA17 268 | CID000122316,N04BD02 269 | CID000002609,J01DC04 270 | CID000005071,J05AC02 271 | CID000003702,C03BA11 272 | CID000150310,C03DA04 273 | CID000150311,C10AX09 274 | CID000004436,R01AA08,S01GA01,R01AB02 275 | CID000004100,S01EC05 276 | CID000004547,A10BX02 277 | CID000003648,N02AA03 278 | CID000004543,N06AA10 279 | CID000004542,G03AC03,G03AD01 280 | CID000003647,C03AA02 281 | CID000060198,L02BG06 282 | CID000004043,S01BA08 283 | CID000004044,M01AG01 284 | CID000004046,P01BC02 285 | CID000003488,A10BB01 286 | CID000002720,C03AA03,C03AA04 287 | CID000002726,N05AA01 288 | CID000004138,C02AB01,C02AB02 289 | CID000002283,D06AX05,R02AB04,J01XX10 290 | CID000002284,M03BX01 291 | CID000005556,N05CD05 292 | CID000005396,L01CB02 293 | CID000005394,L01AX03 294 | CID000005391,N05CD07 295 | CID005311297,R03DC03 296 | CID000034633,C01BG01 297 | CID000005775,C04AB01,V03AB36 298 | CID000003241,R06AX24,S01GX10 299 | CID000667490,L01BB02 300 | CID000060835,N06AX21 301 | CID000060831,R03BB04 302 | CID000004603,J05AH02 303 | CID000004601,M03BC01,N04AB02 304 | CID000004607,J01CF01,J01CF02,J01CF04,J01CF05 305 | CID000004609,L01XA03 306 | CID000002462,A07EA06,D07AC09,R01AD05,R03BA02 307 | CID001349907,H03BB02 308 | CID000000838,A01AD01,B02BC09,C01CA03,C01CA24,R01AA14,R03AA01,S01EA01 309 | CID000030623,V03AF02 310 | CID000002266,D10AX03 311 | CID000002267,R01AC03,R06AX19,S01GX07 312 | CID000002265,L04AX01 313 | CID000003393,N05CD01 314 | CID000003392,D07AC07 315 | CID000003394,M01AE09,M02AA19,S01BC04 316 | CID000003397,L02BB01 317 | CID000002269,J01FA10,S01AA26 318 | CID000041781,C03CA01,C03CA04 319 | CID000004739,A03AB03 320 | CID000004736,N02AD01 321 | CID000004737,N05CA01 322 | CID000060696,M03AC09 323 | CID000004730,J01CE02,J01CE10 324 | CID000096312,L01BB07 325 | CID003085017,V03AE02 326 | CID000002673,J01DB09 327 | CID000002675,J01DD08 328 | CID000002676,C10AA06 329 | CID000002678,R06AE07,R06AE09 330 | CID000003730,V08AB02 331 | CID000000232,B05XB01 332 | CID000003657,L01XX05 333 | CID000003652,P01BA02 334 | CID000003658,N05BB01 335 | CID000002708,L01AA02 336 | CID000000581,R05CB01,S01XA08,V03AB23 337 | CID000028112,N05AN01,D11AX04 338 | CID000002818,N05AH02 339 | CID000068740,M05BA08 340 | CID000004091,A10BA02,A10BD11 341 | CID000004095,N02AC52,N07BC02,R05DA06 342 | CID000005546,C03DB02 343 | CID000005544,A01AC01,D07AB09,D07XB02,H02AB08,R01AD11,R03BA06,S01BA05,C05AA12 344 | CID000040976,G03AC08 345 | CID000004197,C01CE02 346 | CID000004893,C02CA01 347 | CID000004891,P02BA01 348 | CID000031477,S01ED04 349 | CID000016231,C03DB01 350 | CID000004894,A07EA01,C05AA04,D07AA01,D07AA03,D07AC14,D07XA02,D10AA02,H02AB04,H02AB06,R01AD02,S01BA04,S01CB02,S02BA03,S03BA02,H02AB07 351 | CID000002381,N04AA02 352 | CID000003261,N05CD04 353 | 
CID000003251,N02CA01,N02CA02 354 | CID000003255,D10AF02,J01FA01,S01AA17 355 | CID000050614,J01DC05 356 | CID000004635,N02AA05 357 | CID000004634,G04BD04 358 | CID000004428,N07BB04 359 | CID005282226,S01EE04 360 | CID000054454,C10AA01 361 | CID000057469,D06BB10 362 | CID000003382,C05AA11,D07AC08 363 | CID000003381,C05AA10,D07AC04,S01BA15,S02BA08 364 | CID000003387,G03BA01 365 | CID000003384,C05AA06,D07AB06,D07XB04,D10AA01,S01BA07,S01CB05 366 | CID000001046,J04AK01 367 | CID000003739,V08AC04 368 | CID000056959,C01EB18 369 | CID000000085,A16AA01 370 | CID000003032,M01AB05,M02AA15,S01BC03,D11AX18 371 | CID000060753,C01BD05 372 | CID000060754,V08CA03 373 | CID000002662,L01XX33,M01AH01 374 | CID000002666,J01DB01 375 | CID000004419,N02AF02 376 | CID006323497,J04AB05 377 | CID000003724,V08AB09 378 | CID000054547,J01DD13 379 | CID000002719,P01BA01,P01BA02 380 | CID000002712,N05BA02 381 | CID000000596,L01BC01 382 | CID000000772,B01AB01,C05BA01,C05BA03,S01XA09,S01XA14,B01AB05 383 | CID000003827,R06AX17,S01GX08 384 | CID000003826,M01AB15,S01BC05 385 | CID000003825,M01AE03,M01AE17,M02AA10 386 | CID000003823,D01AC08,G01AF11,J02AB02 387 | CID000003821,N01AX03,N01AX14 388 | CID000004140,G02AB01 389 | CID000002803,C02AC01,N02CX02,S01EA04,S01EA03 390 | CID000002806,B01AC04 391 | CID000062867,N05AC04 392 | CID000004158,N06BA04 393 | CID000004159,D07AA01,D07AC14,D10AA02,H02AB04 394 | CID000115237,N05AX08 395 | CID000013342,L01CA01 396 | CID000004259,J01MA14,S01AE07 397 | CID004479097,B03BA03,V03AB33 398 | CID000004086,R03CB03,R03AB03 399 | CID000005578,J01EA01 400 | CID000005572,N04AA01 401 | CID000005379,J01MA16,S01AE06 402 | CID000005376,L02BA01 403 | CID000005596,G04BD09 404 | CID000005591,A10BG01 405 | CID000060815,N01AH06 406 | CID000004976,N06AA11 407 | CID000002443,G02CB01,N04BC01 408 | CID005353894,V03AB04 409 | CID000005625,J05AG02 410 | CID000004623,D01AC11,G01AF17 411 | CID000002244,A01AD05,B01AC06,N02BA01 412 | CID000047319,M03AC04,M03AC11 413 | CID000170361,N07BA03 414 | CID000003375,C05AA10,D07AC04 415 | CID000087177,A06AD11 416 | CID000037393,P01BX01 417 | CID000041693,N01AH03 418 | CID000005978,L01CA02 419 | CID000060787,J05AE01 420 | CID000002656,J01DD04 421 | CID000002650,J01DD02 422 | CID000003759,N06AF01 423 | CID000004594,A02BC01,A02BC05 424 | CID000003750,L01XX19 425 | CID000065999,C09CA07 426 | CID000005193,N05CA06 427 | CID000157922,L01XD03 428 | CID000002951,M03CA01 429 | CID000003637,C02DB02,C02DB01 430 | CID000003639,C03AA03 431 | CID000054374,H01CB02 432 | CID000002727,A10BB02 433 | CID000002725,R06AB02,R06AB04 434 | CID000003767,J04AC01 435 | CID000004915,L01XB01 436 | CID000005582,P01AX07 437 | CID000004189,A01AB09,A07AC01,D01AC02,G01AF04,J02AB01,S02AA13 438 | CID000003958,N05BA06 439 | CID000000143,V03AF03 440 | CID000036339,N01AX07 441 | CID000004168,A03FA01 442 | CID000004163,N02CA04 443 | CID000004160,G03BA02,G03EK01 444 | CID000002812,A01AB18,D01AC01,G01AF02 445 | CID000002907,L01AA01,L01DB07 446 | CID000002905,S01FA04 447 | CID006398970,J01DD15 448 | CID000002019,L01DA01 449 | CID000005291,L01XE01 450 | CID000005297,A07AA04,J01GA01,S01AA15 451 | CID000002913,R06AX02 452 | CID000005566,N05AB06 453 | CID000002083,R03CC02,R03AC02 454 | CID000158440,V08CA11 455 | CID000005584,N06AA06 456 | CID000016850,S01JA01 457 | CID000002123,L01XX03 458 | CID000034312,N03AF02 459 | CID000002232,B01AE03 460 | CID000005035,G03XC01 461 | CID000039507,H02AB12,S01BA13 462 | CID000005746,L01DC03 463 | CID000005038,C09AA05 464 | CID000005039,A02BA02,A02BA07 465 | CID000004768,C04AX02 466 | 
CID000002344,N04AC01 467 | CID000002349,J01CE01,S01AA14,J01CE02,J01CE10 468 | CID000003291,N03AD01 469 | CID000003715,M01AB01,C01EB03 470 | CID000072467,L01DC01 471 | CID000050294,R01AC07,R03BC03,S01GX04 472 | CID000002646,J01DC10 473 | CID000003749,C09CA04 474 | CID000003746,R01AX03,R03BB01 475 | CID000002610,J01DB05 476 | CID000004440,N02CC02 477 | CID000002732,C03BA04 478 | CID000064147,J05AB14 479 | CID000004509,J01XE01 480 | CID000444033,R03BA08 481 | CID000004506,N05CD03,N05CD02 482 | CID005311039,J02AX04 483 | CID000003519,C02AC02 484 | CID000003518,C02CC02,S01EX01 485 | CID000004211,L01XX23 486 | CID000003512,D01AA08,D01BA01 487 | CID000003510,A04AA02 488 | CID000000159,B01AC09 489 | CID000000158,G02AD02 490 | CID000004178,C01BB02 491 | CID005481350,J05AF07 492 | CID000001935,N06AA18,N06DA01 493 | CID000004170,C03BA08 494 | CID000004171,C07AB02,C07AB52 495 | CID000004173,A01AB17,D06BX01,G01AF01,J01XD01,P01AB01 496 | CID000027686,R01AC01,A07EB01,S01GX01,R03BC01,D11AH03 497 | CID000003226,N01AB04 498 | CID000150610,J01DH03 499 | CID000071158,N07BB03 500 | CID000657298,H03BA02 501 | CID000006691,B01AA03 502 | CID005311181,B01AC11 503 | CID005381226,J04AB02 504 | CID000036811,C01CA07 505 | CID000004196,G03XB01 506 | CID000005359,M01AE07 507 | CID000004192,N05CD08 508 | CID000005352,M01AB02 509 | CID000060184,C09AA04 510 | CID000002133,D07AC11 511 | CID000002130,N04BB01 512 | CID000002131,N07AA30 513 | CID000060879,C09CA02 514 | CID000060871,J05AF08 515 | CID000004991,N07AA02 516 | CID000005645,A05AA02 517 | CID000005647,J05AB11 518 | CID000002171,J01CA04 519 | CID000003355,C01BC04 520 | CID000003354,G04BD02 521 | CID000051263,N01AH02 522 | CID000003350,D11AX10,G04CB01 523 | CID000005002,N05AH04 524 | CID000005005,C09AA06 525 | CID000004771,A08AA01 526 | CID005353980,A07EC01 527 | CID000000942,N07BA01,A11HA01,C04AC01,C10AD02 528 | CID006447131,D11AH02 529 | CID000003292,N03AB01 530 | CID000003779,C01CA02,D08AX05 531 | CID000110635,G04BE08 532 | CID000110634,G04BE09 533 | CID000041744,L01DB09 534 | CID000477468,J02AX05 535 | CID000002749,D01AE14,G01AX12 536 | CID000002585,C07AG02 537 | CID000145068,R07AX01 538 | CID000004510,C01DA02 539 | CID000048175,D07AC21 540 | CID000004205,N06AX11 541 | CID000004200,A01AB23,J01AA08 542 | CID000005381,D05AX05 543 | CID005361912,J04AB04 544 | CID000002520,C08DA01 545 | CID000039765,M03AC03 546 | CID000003403,C10AA04 547 | CID000042113,N01AB07 548 | CID000005344,J01EB05,S01AB02 549 | CID000005342,M04AB02 550 | CID000004834,J01CA12 551 | CID000123606,N02CC05 552 | CID000060865,S01GX09,R01AC08 553 | CID000005650,C09CA03 554 | CID000005651,A07AA09,J01XA01 555 | CID006435110,P02CF01 556 | CID000002216,S01EA03 557 | CID000002215,G04BE07,N04BC07 558 | CID000003348,R06AX26 559 | CID000003342,M01AE04 560 | CID000003340,C01CA19 561 | CID000006058,A16AA04 562 | CID000004748,N05AB03 563 | CID000004740,C04AD03 564 | CID000000951,C01CA03 565 | CID000002366,N07CA01 566 | CID000002369,C07AB05,S01ED02 567 | CID000025419,M05BA02 568 | CID000006476,N03AD03 569 | CID000003763,N01AB06 570 | CID000059708,N03AX14 571 | CID000003161,A01AB22,J01AA02 572 | CID000003168,N05AD08 573 | CID000000888,A12CC08,B05XA11,A06AD03,B05XA10,A06AD19,A02AA04,A12CC09,A02AA01,A06AD02,A12CC10,A12CC01,A12CC04,B05CB03,A12CC07,G04BX01,A12CC03,A12CC02,A02AA03,A06AD04,A02AA02,A12CC05,A02AA05,A06AD01,C10AX07,A12CC06,A12CC30,D11AX05,V04CC02,B05XA05 574 | CID000041774,A10BF01 575 | CID000002733,M03BB03 576 | CID000000401,J04AB01 577 | CID000002751,C09AA08 578 | CID000002756,A02BA01 579 | 
CID005281007,L01AX04 580 | CID000005358,N02CC01 581 | CID000444013,V08CA06 582 | CID000004463,J05AG01 583 | CID000004236,N06BA07 584 | CID000104741,L02BA03 585 | CID000003883,A02BC03 586 | CID000003405,B03BB01 587 | CID000003406,V03AB34 588 | CID000002978,A04AD10 589 | CID000002973,V03AC01 590 | CID000002250,C10AA05 591 | CID002761171,J04AD03 592 | CID000002021,J01XX04 593 | CID000002022,D06BB03,J05AB01,S01AD03,J05AB11 594 | CID003000502,J04AB30 595 | CID000003385,L01BC02 596 | CID000002118,N05BA12 597 | CID000004828,C07AA03,C07AA14,C07AA17 598 | CID000004829,A10BG03 599 | CID000001125,A16AX07 600 | CID000123631,L01XE02 601 | CID000004919,N04AA04 602 | CID000060852,M05BA06 603 | CID000005487,M03BX02 604 | CID000005486,B01AC17 605 | CID000047320,M03AC11 606 | CID000005245,M05BA07 607 | CID000005466,N03AG06 608 | CID000003379,R01AD04,R03BA03 609 | CID000003373,V03AB25 610 | CID000003372,N05AB02 611 | CID006398525,A02BX02 612 | CID000005064,J05AB04 613 | CID000005453,L01AC01 614 | CID000002375,L02BB03 615 | CID000004212,L01DB07 616 | CID000002370,N07AB02 617 | CID000002983,D06AA01,J01AA01 618 | CID000003793,J02AC02 619 | CID000003151,A03FA03 620 | CID000004689,A07AA06 621 | CID000003157,C02CA04 622 | CID000003156,R07AB01 623 | CID000003158,N06AA12 624 | CID000082146,L01XX25 625 | CID000002769,A03FA02 626 | CID000004539,J01MA06,S01AE02 627 | CID000002762,J01MB06 628 | CID000002564,R06AA08 629 | CID000003562,N01AB01 630 | CID000000187,S01EB09 631 | CID000060854,N05AE04 632 | CID000003899,L04AA13 633 | CID000003890,S01EE01 634 | CID000156419,H05BX01 635 | CID000003419,C09AA09 636 | CID000003417,J01XX01 637 | CID000003414,J05AD01 638 | CID000003410,R03AC13 639 | CID000002786,D10AF01,G01AA10,J01FF01 640 | CID000002781,R06AA04,D04AA14 641 | CID000004075,A07EC02 642 | CID000003964,N05AH01 643 | CID000003962,C10AA02 644 | CID000003961,C09CA01 645 | CID000027991,H01BA02 646 | CID000001546,L01BB04 647 | CID000004036,M01AG04,M02AA18 648 | CID000004030,P02CA01 649 | CID000004032,C02BB01 650 | CID000003333,C08CA02 651 | CID000002162,C08CA01 652 | CID000004411,C07AA12 653 | CID000004819,N07AX01,S01EB01 654 | CID005311027,S01EE03 655 | CID000060843,L01BA04 656 | CID000005496,J01GB01,S01AA12 657 | CID000009034,N01AF01,N05CA15 658 | CID000005253,C07AA07,C07AA57 659 | CID000003365,D01AC15,J02AC01 660 | CID000003366,D01AE21,J02AX01 661 | CID000003367,L01BB05 662 | CID000005472,B01AC05 663 | CID000005478,C07AA06,S01ED01 664 | CID000005479,J01XD02,P01AB02 665 | CID005282044,J01AA12 666 | CID000005078,N02CC04 667 | CID000060937,M05BA05 668 | CID000005076,J05AE03 669 | CID000002308,A07EA07,D07AC15,R01AD01,R03BA01 670 | CID000005672,L01CA04 671 | CID000047725,L02AE03 672 | CID000010548,S01EB03 673 | CID000003784,C08CA03 674 | CID000003780,C01DA08,D03AX08 675 | CID000003148,A04AA04 676 | CID000003143,L01CD02 677 | CID000006256,S01AD02 678 | CID000001206,N06BA03 679 | CID000014888,L01XX27 680 | CID000002578,L01AD01 681 | CID000003559,N05AD01 682 | CID000000191,C01EB10 683 | CID000004253,G04BE07,N02AA01,N02AA04,N02AA09,N04BC07,R05DA01,R05DA05,S01XA06,D10AX30 684 | CID000051577,A10BF02 685 | CID000001972,A01AB04,A07AA07,G01AA03,J02AA01 686 | CID000001971,J05AF06 687 | CID000001978,C07AB04 688 | CID000054786,B01AC21 689 | CID000032797,D07AD01 690 | CID000047472,G01AF15 691 | CID000002955,J04BA02,D10AX05 692 | CID000002958,L01DB02 693 | CID000002794,J04BA01 694 | CID000003950,L01AD02 695 | CID000003957,R06AX13,R06AX27,C09AA03 696 | CID000003954,A07DA03,A07DA05 697 | CID005362070,A07EC04 698 | CID000125017,N06AX23 699 | 
CID000002474,N01BB01,N01BB10
700 | CID000004513,A02BA04
701 | CID000002713,A01AB03,B05CA02,D08AC02,D09AA12,R02AA05,S01AX09,S02AA09,S03AA04,D08AE02
702 | CID000005314,M03AB01
703 | CID000000453,A06AD16,B05BC01,B05CX04,V08AC04
704 | CID000077999,A10BG02
705 | CID000002478,L01AB01
706 | CID000002173,J01CA01,J01CA02,J01CA06,J01CA14,J01CA15,S01AA19
707 | CID000002170,N06AA17
708 | CID000002177,J05AE05,J05AE07
709 | CID000002179,L01XX01
710 | CID000003222,C09AA02
711 | CID000001148,N06AX02
712 | CID000005267,C03DA01
713 | CID000031378,H02AA02
714 | CID000003310,L01CB01
715 | CID000005052,C02AA02
716 | CID000005040,L04AA10
717 | CID000074989,P01AX06
718 | CID000005665,N03AG04
719 | CID000002311,C09AA07
720 | CID000002315,C03AA01
721 | CID005473385,L01BB03
722 | CID000104803,M03AC10
723 | CID000002405,C07AB07
724 | CID000000853,C10AX01,H03AA01
725 | CID000060612,N05CM18
726 | CID000060613,J05AB12
727 | CID000504578,A01AB08,A07AA01,B05CA09,D06AX04,J01GB05,R02AB01,S01AA03,S02AA07,S03AA01,D09AA01,S01AA07
728 | CID005487301,A06AX06
729 | CID003062316,L01XE06
730 | CID000000298,D06AX02,D10AF03,G01AA05,J01BA01,S01AA01,S02AA01,S03AA08
731 | CID000003040,J01CF01
732 | CID000003043,J05AF02
733 | CID000003042,A03AA07
734 | CID000002541,C09CA06
735 | CID000083786,M04AA01
736 | CID000002895,M03BX08
737 | CID005493444,C09XA02
738 | CID000004727,M01CC01
739 | CID005362420,S01LA01
740 | CID000027661,C01DA14
741 | CID000000450,G03CA01,G03CA03,L02AA02,L02AA03
742 | CID000000681,C01CA04
743 | CID000002802,N03AE01
744 | CID000002800,G03GB02
745 | CID000003948,J01MA07,S01AE04
746 | CID000004425,V03AB15
747 | 
--------------------------------------------------------------------------------
/src/COGNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._C import device
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 | import numpy as np
7 | from torch.nn.modules.linear import Linear
8 | from data.processing import process_visit_lg2
9 | 
10 | from layers import SelfAttend
11 | from layers import GraphConvolution
12 | 
13 | 
14 | class COGNet(nn.Module):
15 | """Based on CopyDrug_batch, with the medication encoding part replaced by a transformer encoder"""
16 | def __init__(self, voc_size, ehr_adj, ddi_adj, ddi_mask_H, emb_dim=64, device=torch.device('cpu:0')):
17 | super(COGNet, self).__init__()
18 | self.voc_size = voc_size
19 | self.emb_dim = emb_dim
20 | self.device = device
21 | self.nhead = 2
22 | self.SOS_TOKEN = voc_size[2] # start of sentence
23 | self.END_TOKEN = voc_size[2]+1 # end; these two newly added indices are both medication embeddings
24 | self.MED_PAD_TOKEN = voc_size[2]+2 # used as padding in the embedding matrix (all zeros)
25 | self.DIAG_PAD_TOKEN = voc_size[0]+2
26 | self.PROC_PAD_TOKEN = voc_size[1]+2
27 | 
28 | self.tensor_ddi_mask_H = torch.FloatTensor(ddi_mask_H).to(device)
29 | 
30 | # diag_num * emb_dim
31 | self.diag_embedding = nn.Sequential(
32 | nn.Embedding(voc_size[0]+3, emb_dim, self.DIAG_PAD_TOKEN),
33 | nn.Dropout(0.3)
34 | )
35 | 
36 | # proc_num * emb_dim
37 | self.proc_embedding = nn.Sequential(
38 | nn.Embedding(voc_size[1]+3, emb_dim, self.PROC_PAD_TOKEN),
39 | nn.Dropout(0.3)
40 | )
41 | 
42 | # med_num * emb_dim
43 | self.med_embedding = nn.Sequential(
44 | # set padding_idx so that this index maps to the zero vector
45 | nn.Embedding(voc_size[2]+3, emb_dim, self.MED_PAD_TOKEN),
46 | nn.Dropout(0.3)
47 | )
48 | 
49 | # encode the medications of the previous visit
50 | # self.medication_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, dim_feedforward=emb_dim*8, batch_first=True, dropout=0.2)
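# Vocabulary layout used throughout this class (explanatory note): with
# voc_size[2] = M real medications, indices 0..M-1 denote drugs, M is SOS_TOKEN,
# M+1 is END_TOKEN and M+2 is MED_PAD_TOKEN, which is why the embedding tables
# above are allocated with voc_size[...]+3 rows.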
51 | self.medication_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, batch_first=True, dropout=0.2)
52 | # encode the diseases and procedures of the current visit
53 | # self.diagnoses_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, dim_feedforward=emb_dim*8, batch_first=True, dropout=0.2)
54 | # self.procedure_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, dim_feedforward=emb_dim*8, batch_first=True, dropout=0.2)
55 | self.diagnoses_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, batch_first=True, dropout=0.2)
56 | self.procedure_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, batch_first=True, dropout=0.2)
57 | # self.enc_gru = nn.GRU(emb_dim, emb_dim, batch_first=True, bidirectional=True)
58 | 
59 | # self.ehr_gcn = GCN(
60 | # voc_size=voc_size[2], emb_dim=emb_dim, adj=ehr_adj, device=device)
61 | # self.ddi_gcn = GCN(
62 | # voc_size=voc_size[2], emb_dim=emb_dim, adj=ddi_adj, device=device)
63 | self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
64 | 
65 | self.gcn = GCN(voc_size=voc_size[2], emb_dim=emb_dim, ehr_adj=ehr_adj, ddi_adj=ddi_adj, device=device)
66 | self.inter = nn.Parameter(torch.FloatTensor(1))
67 | 
68 | # aggregate the diag and proc within a single visit into a visit-level representation
69 | self.diag_self_attend = SelfAttend(emb_dim)
70 | self.proc_self_attend = SelfAttend(emb_dim)
71 | 
72 | self.decoder = MedTransformerDecoder(emb_dim, self.nhead, dim_feedforward=emb_dim*2, dropout=0.2,
73 | layer_norm_eps=1e-5)
74 | 
75 | # encode the diagnoses of each visit
76 | 
77 | # generate the medication sequence
78 | self.dec_gru = nn.GRU(emb_dim*3, emb_dim, batch_first=True)
79 | 
80 | self.diag_attn = nn.Linear(emb_dim*2, 1)
81 | self.proc_attn = nn.Linear(emb_dim*2, 1)
82 | self.W_diag_attn = nn.Linear(emb_dim, emb_dim)
83 | self.W_proc_attn = nn.Linear(emb_dim, emb_dim)
84 | self.W_diff_attn = nn.Linear(emb_dim, emb_dim)
85 | self.W_diff_proc_attn = nn.Linear(emb_dim, emb_dim)
86 | 
87 | # weights
88 | self.Ws = nn.Linear(emb_dim*2, emb_dim) # only used at the initial stage
89 | self.Wo = nn.Linear(emb_dim, voc_size[2]+2) # generate mode
90 | # self.Wc = nn.Linear(emb_dim*2, emb_dim) # copy mode
91 | self.Wc = nn.Linear(emb_dim, emb_dim) # copy mode
92 | 
93 | self.W_dec = nn.Linear(emb_dim, emb_dim)
94 | self.W_stay = nn.Linear(emb_dim, emb_dim)
95 | self.W_proc_dec = nn.Linear(emb_dim, emb_dim)
96 | self.W_proc_stay = nn.Linear(emb_dim, emb_dim)
97 | 
98 | # switch network to calculate the generation probability
99 | self.W_z = nn.Linear(emb_dim, 1)
100 | 
101 | 
102 | self.weight = nn.Parameter(torch.tensor([0.3]), requires_grad=True)
103 | # bipartite local embedding
104 | self.bipartite_transform = nn.Sequential(
105 | nn.Linear(emb_dim, ddi_mask_H.shape[1])
106 | )
107 | self.bipartite_output = MaskLinear(
108 | ddi_mask_H.shape[1], voc_size[2], False)
109 | 
110 | def encode(self, diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20):
111 | device = self.device
112 | # computed in parallel over the batch and seq dimensions (temporal order is not modeled here); each medication sequence is still predicted step by step
113 | batch_size, max_visit_num, max_med_num = medications.size()
114 | max_diag_num = diseases.size()[2]
115 | max_proc_num = procedures.size()[2]
116 | 
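# Note on the mask convention (explanatory addition): the *_mask_matrix tensors are
# additive attention masks holding 0 at visible positions and roughly -1e9 at padded
# ones, so softmax gives padded slots effectively zero weight; each mask is repeated
# self.nhead times below because the attention layers expect one mask per head.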
117 | ############################ Data preprocessing #########################
118 | # 1. encode the current diseases and procedures
119 | input_disease_embdding = self.diag_embedding(diseases).view(batch_size * max_visit_num, max_diag_num, self.emb_dim) # [batch, seq, max_diag_num, emb]
120 | input_proc_embedding = self.proc_embedding(procedures).view(batch_size * max_visit_num, max_proc_num, self.emb_dim) # [batch, seq, max_proc_num, emb]
121 | d_enc_mask_matrix = d_mask_matrix.view(batch_size * max_visit_num, max_diag_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_diag_num,1) # [batch*seq, nhead, input_length, output_length]
122 | d_enc_mask_matrix = d_enc_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_diag_num, max_diag_num)
123 | p_enc_mask_matrix = p_mask_matrix.view(batch_size * max_visit_num, max_proc_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_proc_num,1)
124 | p_enc_mask_matrix = p_enc_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_proc_num, max_proc_num)
125 | input_disease_embdding = self.diagnoses_encoder(input_disease_embdding, src_mask=d_enc_mask_matrix).view(batch_size, max_visit_num, max_diag_num, self.emb_dim)
126 | input_proc_embedding = self.procedure_encoder(input_proc_embedding, src_mask=p_enc_mask_matrix).view(batch_size, max_visit_num, max_proc_num, self.emb_dim)
127 | 
128 | # 1.1 encode visit-level diag and proc representations
129 | visit_diag_embedding = self.diag_self_attend(input_disease_embdding.view(batch_size * max_visit_num, max_diag_num, -1), d_mask_matrix.view(batch_size * max_visit_num, -1))
130 | visit_proc_embedding = self.proc_self_attend(input_proc_embedding.view(batch_size * max_visit_num, max_proc_num, -1), p_mask_matrix.view(batch_size * max_visit_num, -1))
131 | visit_diag_embedding = visit_diag_embedding.view(batch_size, max_visit_num, -1)
132 | visit_proc_embedding = visit_proc_embedding.view(batch_size, max_visit_num, -1)
133 | 
134 | # 1.3 compute the visit-level attention scores
135 | # [batch_size, max_visit_num, max_visit_num]
136 | cross_visit_scores = self.calc_cross_visit_scores(visit_diag_embedding, visit_proc_embedding)
137 | 
138 | 
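# What cross_visit_scores contains (explanatory note): for every visit t, a softmax
# over key positions 0..t, where the keys are the visit representations shifted
# right by one position, i.e. a distribution over strictly earlier visits that is
# later used to weight copied medications; see calc_cross_visit_scores below.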
139 | # 3. build last_seq_medication, the medications of the previous visit; the first visit has no previous medication, so it is filled with 0 (any filler works, since it is never used)
140 | last_seq_medication = torch.full((batch_size, 1, max_med_num), 0).to(device)
141 | last_seq_medication = torch.cat([last_seq_medication, medications[:, :-1, :]], dim=1)
142 | # the m_mask_matrix matrix also needs to be shifted back by one visit
143 | last_m_mask = torch.full((batch_size, 1, max_med_num), -1e9).to(device) # use a large negative value here, so that softmax does not leak probability to these slots
144 | last_m_mask = torch.cat([last_m_mask, m_mask_matrix[:, :-1, :]], dim=1)
145 | # encode last_seq_medication
146 | last_seq_medication_emb = self.med_embedding(last_seq_medication)
147 | last_m_enc_mask = last_m_mask.view(batch_size * max_visit_num, max_med_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1,self.nhead,max_med_num,1)
148 | last_m_enc_mask = last_m_enc_mask.view(batch_size * max_visit_num * self.nhead, max_med_num, max_med_num)
149 | encoded_medication = self.medication_encoder(last_seq_medication_emb.view(batch_size * max_visit_num, max_med_num, self.emb_dim), src_mask=last_m_enc_mask) # (batch*seq, max_med_num, emb_dim)
150 | encoded_medication = encoded_medication.view(batch_size, max_visit_num, max_med_num, self.emb_dim)
151 | 
152 | # vocab_size, emb_size
153 | ehr_embedding, ddi_embedding = self.gcn()
154 | drug_memory = ehr_embedding - ddi_embedding * self.inter
155 | drug_memory_padding = torch.zeros((3, self.emb_dim), device=self.device).float()
156 | drug_memory = torch.cat([drug_memory, drug_memory_padding], dim=0)
157 | 
158 | return input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory
159 | 
160 | def decode(self, input_medications, input_disease_embedding, input_proc_embedding, last_medication_embedding, last_medications, cross_visit_scores,
161 | d_mask_matrix, p_mask_matrix, m_mask_matrix, last_m_mask, drug_memory):
162 | """
163 | input_medications: [batch_size, max_visit_num, max_med_num + 1], starts with SOS_TOKEN
164 | """
165 | batch_size = input_medications.size(0)
166 | max_visit_num = input_medications.size(1)
167 | max_med_num = input_medications.size(2)
168 | max_diag_num = input_disease_embedding.size(2)
169 | max_proc_num = input_proc_embedding.size(2)
170 | 
171 | input_medication_embs = self.med_embedding(input_medications).view(batch_size * max_visit_num, max_med_num, -1)
172 | # input_medication_embs = self.dropout_emb(input_medication_embs)
173 | input_medication_memory = drug_memory[input_medications].view(batch_size * max_visit_num, max_med_num, -1)
174 | 
175 | # m_sos_mask = torch.zeros((batch_size, max_visit_num, 1), device=self.device).float() # use a large negative value here, so that softmax does not leak probability to these slots
176 | m_self_mask = m_mask_matrix
177 | 
178 | last_m_enc_mask = m_self_mask.view(batch_size * max_visit_num, max_med_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_med_num, 1)
179 | medication_self_mask = last_m_enc_mask.view(batch_size * max_visit_num * self.nhead, max_med_num, max_med_num)
180 | m2d_mask_matrix = d_mask_matrix.view(batch_size * max_visit_num, max_diag_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_med_num, 1)
181 | m2d_mask_matrix = m2d_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_med_num, max_diag_num)
182 | m2p_mask_matrix = p_mask_matrix.view(batch_size * max_visit_num, max_proc_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_med_num,1)
183 | m2p_mask_matrix = m2p_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_med_num, max_proc_num)
184 | 
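# Shape summary for the decoder call that follows (explanatory note): each mask has
# been flattened to [batch_size * max_visit_num * nhead, query_len, key_len]; the
# queries are the partially generated medications, and the keys are the medications
# themselves (self-attention), the diagnoses (m2d) and the procedures (m2p).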
185 | dec_hidden = self.decoder(input_medication_embedding=input_medication_embs, input_medication_memory=input_medication_memory,
186 | input_disease_embdding=input_disease_embedding.view(batch_size * max_visit_num, max_diag_num, -1),
187 | input_proc_embedding=input_proc_embedding.view(batch_size * max_visit_num, max_proc_num, -1),
188 | input_medication_self_mask=medication_self_mask,
189 | d_mask=m2d_mask_matrix,
190 | p_mask=m2p_mask_matrix)
191 | 
192 | score_g = self.Wo(dec_hidden) # (batch * max_visit_num, max_med_num, voc_size[2]+2)
193 | score_g = score_g.view(batch_size, max_visit_num, max_med_num, -1)
194 | prob_g = F.softmax(score_g, dim=-1)
195 | score_c = self.copy_med(dec_hidden.view(batch_size, max_visit_num, max_med_num, -1), last_medication_embedding, last_m_mask, cross_visit_scores)
196 | # (batch_size, max_visit_num * input_med_num, max_visit_num * max_med_num)
197 | 
198 | ###### case study
199 | # this assumes batch_size == 1
200 | # notes on the chosen indices:
201 | # 1. take the newest generated drug's attention over historical drugs, hence -1 in the third dimension
202 | # 2. take the copy scores of the last visit, hence -1 in the second dimension
203 | # 3. take the last visit's attention over the second-to-last visit's drugs, hence the last max_med_num entries in the fourth dimension
204 | # score_c_buf = score_c.view(batch_size, max_visit_num, max_med_num, -1)
205 | # score_c_buf = score_c_buf[0, -1, -1, :] # visit_num * (visit_num * max_med_num)
206 | # max_med_num_in_last = len(score_c_buf) // max_visit_num
207 | # print(score_c_buf[-max_med_num_in_last:])
208 | prob_c_to_g = torch.zeros_like(prob_g).to(self.device).view(batch_size, max_visit_num * max_med_num, -1) # (batch, max_visit_num * input_med_num, voc_size[2]+2)
209 | 
210 | # use a scatter operation instead of nested loops
211 | # following the indices in last_seq_medication, add the values in score_c into score_c_to_g
212 | copy_source = last_medications.view(batch_size, 1, -1).repeat(1, max_visit_num * max_med_num, 1)
213 | prob_c_to_g.scatter_add_(2, copy_source, score_c)
214 | prob_c_to_g = prob_c_to_g.view(batch_size, max_visit_num, max_med_num, -1)
215 | 
216 | generate_prob = torch.sigmoid(self.W_z(dec_hidden)).view(batch_size, max_visit_num, max_med_num, 1)
217 | prob = prob_g * generate_prob + prob_c_to_g * (1. - generate_prob)
218 | prob[:, 0, :, :] = prob_g[:, 0, :, :] # the first visit has no last_medication information, so only prob_g is used
219 | 
220 | return torch.log(prob)
221 | 
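# How copy and generation are blended in decode() above (explanatory note, numbers
# made up): if at one decoding step generate_prob = 0.7, prob_g assigns 0.4 to some
# drug and the scattered copy distribution prob_c_to_g assigns 0.9 to the same drug,
# its final probability is 0.7 * 0.4 + 0.3 * 0.9 = 0.55; log-probabilities are
# returned, so a negative log-likelihood objective can consume them directly.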
222 | # def forward(self, input, last_input=None, max_len=20):
223 | def forward(self, diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20):
224 | device = self.device
225 | # computed in parallel over the batch and seq dimensions (temporal order is not modeled here); each medication sequence is still predicted step by step
226 | batch_size, max_seq_length, max_med_num = medications.size()
227 | max_diag_num = diseases.size()[2]
228 | max_proc_num = procedures.size()[2]
229 | 
230 | input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory = self.encode(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix,
231 | seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20)
232 | 
233 | # 4. build the medications fed to the decoder for teacher forcing during decoding; note the extra entry in the last dimension, because one extra END_TOKEN will be generated
234 | input_medication = torch.full((batch_size, max_seq_length, 1), self.SOS_TOKEN).to(device) # [batch_size, seq, 1]
235 | input_medication = torch.cat([input_medication, medications], dim=2) # [batch_size, seq, max_med_num + 1]
236 | 
237 | m_sos_mask = torch.zeros((batch_size, max_seq_length, 1), device=self.device).float() # the SOS position must stay visible, so its mask entry is 0
238 | m_mask_matrix = torch.cat([m_sos_mask, m_mask_matrix], dim=-1)
239 | 
240 | output_logits = self.decode(input_medication, input_disease_embdding, input_proc_embedding, encoded_medication, last_seq_medication, cross_visit_scores,
241 | d_mask_matrix, p_mask_matrix, m_mask_matrix, last_m_mask, drug_memory)
242 | 
243 | # 5. add the ddi loss
244 | # output_logits_part = torch.exp(output_logits[:, :, :, :-2] + m_mask_matrix.unsqueeze(-1)) # drop SOS and EOS
245 | # output_logits_part = torch.mean(output_logits_part, dim=2)
246 | # neg_pred_prob1 = output_logits_part.unsqueeze(-1)
247 | # neg_pred_prob2 = output_logits_part.unsqueeze(-2)
248 | # neg_pred_prob = neg_pred_prob1 * neg_pred_prob2 # batch * seq * max_med_num * all_med_num * all_med_num
249 | # batch_neg = 0.0005 * neg_pred_prob.mul(self.tensor_ddi_adj).sum()
250 | # return output_logits, batch_neg
251 | return output_logits
252 | 
253 | def calc_cross_visit_scores(self, visit_diag_embedding, visit_proc_embedding):
254 | """
255 | visit_diag_embedding: (batch * visit_num * emb)
256 | visit_proc_embedding: (batch * visit_num * emb)
257 | """
258 | max_visit_num = visit_diag_embedding.size(1)
259 | batch_size = visit_diag_embedding.size(0)
260 | 
261 | # the mask ensures that each visit can only see the visits before it
262 | mask = (torch.triu(torch.ones((max_visit_num, max_visit_num), device=self.device)) == 1).transpose(0, 1) # lower-triangular matrix
263 | mask = mask.float().masked_fill(mask == 0, -1e9).masked_fill(mask == 1, float(0.0))
264 | mask = mask.unsqueeze(0).repeat(batch_size, 1, 1) # batch * max_visit_num * max_visit_num
265 | 
266 | # shift each visit back by one position
267 | padding = torch.zeros((batch_size, 1, self.emb_dim), device=self.device).float()
268 | diag_keys = torch.cat([padding, visit_diag_embedding[:, :-1, :]], dim=1) # batch * max_visit_num * emb
269 | proc_keys = torch.cat([padding, visit_proc_embedding[:, :-1, :]], dim=1)
270 | 
271 | # score of each visit against all of its preceding visits
272 | diag_scores = torch.matmul(visit_diag_embedding, diag_keys.transpose(-2, -1)) \
273 | / math.sqrt(visit_diag_embedding.size(-1))
274 | proc_scores = torch.matmul(visit_proc_embedding, proc_keys.transpose(-2, -1)) \
275 | / math.sqrt(visit_proc_embedding.size(-1))
276 | # the 1st visit's scores are not zero!
277 | scores = F.softmax(diag_scores + proc_scores + mask, dim=-1)
278 | 
279 | ###### case study
280 | # set the 0th value to 0, then re-normalize
281 | # scores_buf = scores
282 | # scores_buf[:, :, 0] = 0.
283 | # scores_buf = scores_buf / torch.sum(scores_buf, dim=2, keepdim=True) 284 | 285 | # print(scores_buf) 286 | return scores 287 | 288 | def copy_med(self, decode_input_hiddens, last_medications, last_m_mask, cross_visit_scores): 289 | """ 290 | decode_input_hiddens: [batch_size, max_visit_num, input_med_num, emb_size] 291 | last_medications: [batch_size, max_visit_num, max_med_num, emb_size] 292 | last_m_mask: [batch_size, max_visit_num, max_med_num] 293 | cross_visit_scores: [batch_size, max_visit_num, max_visit_num] 294 | """ 295 | max_visit_num = decode_input_hiddens.size(1) 296 | input_med_num = decode_input_hiddens.size(2) 297 | max_med_num = last_medications.size(2) 298 | copy_query = self.Wc(decode_input_hiddens).view(-1, max_visit_num*input_med_num, self.emb_dim) 299 | attn_scores = torch.matmul(copy_query, last_medications.view(-1, max_visit_num*max_med_num, self.emb_dim).transpose(-2, -1)) / math.sqrt(self.emb_dim) 300 | med_mask = last_m_mask.view(-1, 1, max_visit_num * max_med_num).repeat(1, max_visit_num * input_med_num, 1) 301 | # [batch_size, max_vist_num * input_med_num, max_visit_num * max_med_num] 302 | attn_scores = F.softmax(attn_scores + med_mask, dim=-1) 303 | 304 | # (batch_size, max_visit_num * input_med_num, max_visit_num) 305 | visit_scores = cross_visit_scores.repeat(1, 1, input_med_num).view(-1, max_visit_num * input_med_num, max_visit_num) 306 | 307 | # (batch_size, max_visit_num * input_med_num, max_visit_num * max_med_num) 308 | visit_scores = visit_scores.unsqueeze(-1).repeat(1, 1, 1, max_med_num).view(-1, max_visit_num * input_med_num, max_visit_num * max_med_num) 309 | 310 | scores = torch.mul(attn_scores, visit_scores).clamp(min=1e-9) 311 | row_scores = scores.sum(dim=-1, keepdim=True) 312 | scores = scores / row_scores # (batch_size, max_visit_num * input_med_num, max_visit_num * max_med_num) 313 | 314 | return scores 315 | 316 | 317 | class MedTransformerDecoder(nn.Module): 318 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 319 | layer_norm_eps=1e-5) -> None: 320 | super(MedTransformerDecoder, self).__init__() 321 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 322 | self.m2d_multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 323 | self.m2p_multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 324 | # Implementation of Feedforward model 325 | self.linear1 = nn.Linear(d_model, dim_feedforward) 326 | self.dropout = nn.Dropout(dropout) 327 | self.linear2 = nn.Linear(dim_feedforward, d_model) 328 | 329 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 330 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 331 | self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) 332 | self.dropout1 = nn.Dropout(dropout) 333 | self.dropout2 = nn.Dropout(dropout) 334 | self.dropout3 = nn.Dropout(dropout) 335 | 336 | self.activation = nn.ReLU() 337 | self.nhead = nhead 338 | 339 | # self.align = nn.Linear(d_model, d_model) 340 | 341 | def forward(self, input_medication_embedding, input_medication_memory, input_disease_embdding, input_proc_embedding, 342 | input_medication_self_mask, d_mask, p_mask): 343 | r"""Pass the inputs (and mask) through the decoder layer. 344 | Args: 345 | input_medication_embedding: [*, max_med_num+1, embedding_size] 346 | Shape: 347 | see the docs in Transformer class. 
348 | """ 349 | input_len = input_medication_embedding.size(0) 350 | tgt_len = input_medication_embedding.size(1) 351 | 352 | # [batch_size*visit_num, max_med_num+1, max_med_num+1] 353 | subsequent_mask = self.generate_square_subsequent_mask(tgt_len, input_len * self.nhead, input_disease_embdding.device) 354 | self_attn_mask = subsequent_mask + input_medication_self_mask 355 | 356 | x = input_medication_embedding + input_medication_memory 357 | 358 | x = self.norm1(x + self._sa_block(x, self_attn_mask)) 359 | # attentioned_disease_embedding = self._m2d_mha_block(x, input_disease_embdding, d_mask) 360 | # attentioned_proc_embedding = self._m2p_mha_block(x, input_proc_embedding, p_mask) 361 | # x = self.norm3(x + self._ff_block(torch.cat([attentioned_disease_embedding, self.align(attentioned_proc_embedding)], dim=-1))) 362 | x = self.norm2(x + self._m2d_mha_block(x, input_disease_embdding, d_mask) + self._m2p_mha_block(x, input_proc_embedding, p_mask)) 363 | x = self.norm3(x + self._ff_block(x)) 364 | 365 | return x 366 | 367 | # self-attention block 368 | def _sa_block(self, x, attn_mask): 369 | x = self.self_attn(x, x, x, 370 | attn_mask=attn_mask, 371 | need_weights=False)[0] 372 | return self.dropout1(x) 373 | 374 | # multihead attention block 375 | def _m2d_mha_block(self, x, mem, attn_mask): 376 | x = self.m2d_multihead_attn(x, mem, mem, 377 | attn_mask=attn_mask, 378 | need_weights=False)[0] 379 | return self.dropout2(x) 380 | 381 | def _m2p_mha_block(self, x, mem, attn_mask): 382 | x = self.m2p_multihead_attn(x, mem, mem, 383 | attn_mask=attn_mask, 384 | need_weights=False)[0] 385 | return self.dropout2(x) 386 | 387 | # feed forward block 388 | def _ff_block(self, x): 389 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 390 | return self.dropout3(x) 391 | 392 | def generate_square_subsequent_mask(self, sz: int, batch_size: int, device): 393 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 394 | Unmasked positions are filled with float(0.0). 395 | """ 396 | mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1) 397 | mask = mask.float().masked_fill(mask == 0, -1e9).masked_fill(mask == 1, float(0.0)) 398 | mask = mask.unsqueeze(0).repeat(batch_size, 1, 1) 399 | return mask 400 | 401 | 402 | class PositionEmbedding(nn.Module): 403 | """ 404 | We assume that the sequence length is less than 512. 405 | """ 406 | def __init__(self, emb_size, max_length=512): 407 | super(PositionEmbedding, self).__init__() 408 | self.max_length = max_length 409 | self.embedding_layer = nn.Embedding(max_length, emb_size) 410 | 411 | def forward(self, batch_size, seq_length, device): 412 | assert(seq_length <= self.max_length) 413 | ids = torch.arange(0, seq_length).long().to(torch.device(device)) 414 | ids = ids.unsqueeze(0).repeat(batch_size, 1) 415 | emb = self.embedding_layer(ids) 416 | return emb 417 | 418 | 419 | class MaskLinear(nn.Module): 420 | def __init__(self, in_features, out_features, bias=True): 421 | super(MaskLinear, self).__init__() 422 | self.in_features = in_features 423 | self.out_features = out_features 424 | self.weight = nn.parameter.Parameter(torch.FloatTensor(in_features, out_features)) 425 | if bias: 426 | self.bias = nn.parameter.Parameter(torch.FloatTensor(out_features)) 427 | else: 428 | self.register_parameter('bias', None) 429 | self.reset_parameters() 430 | 431 | def reset_parameters(self): 432 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 433 | self.weight.data.uniform_(-stdv, stdv) 434 | if self.bias is not None: 435 | self.bias.data.uniform_(-stdv, stdv) 436 | 437 | def forward(self, input, mask): 438 | weight = torch.mul(self.weight, mask) 439 | output = torch.mm(input, weight) 440 | 441 | if self.bias is not None: 442 | return output + self.bias 443 | else: 444 | return output 445 | 446 | def __repr__(self): 447 | return self.__class__.__name__ + ' (' \ 448 | + str(self.in_features) + ' -> ' \ 449 | + str(self.out_features) + ')' 450 | 451 | 452 | class GCN(nn.Module): 453 | def __init__(self, voc_size, emb_dim, ehr_adj, ddi_adj, device=torch.device('cpu:0')): 454 | super(GCN, self).__init__() 455 | self.voc_size = voc_size 456 | self.emb_dim = emb_dim 457 | self.device = device 458 | 459 | ehr_adj = self.normalize(ehr_adj + np.eye(ehr_adj.shape[0])) 460 | ddi_adj = self.normalize(ddi_adj + np.eye(ddi_adj.shape[0])) 461 | 462 | self.ehr_adj = torch.FloatTensor(ehr_adj).to(device) 463 | self.ddi_adj = torch.FloatTensor(ddi_adj).to(device) 464 | self.x = torch.eye(voc_size).to(device) 465 | 466 | self.gcn1 = GraphConvolution(voc_size, emb_dim) 467 | self.dropout = nn.Dropout(p=0.3) 468 | self.gcn2 = GraphConvolution(emb_dim, emb_dim) 469 | self.gcn3 = GraphConvolution(emb_dim, emb_dim) 470 | 471 | def forward(self): 472 | ehr_node_embedding = self.gcn1(self.x, self.ehr_adj) 473 | ddi_node_embedding = self.gcn1(self.x, self.ddi_adj) 474 | 475 | ehr_node_embedding = F.relu(ehr_node_embedding) 476 | ddi_node_embedding = F.relu(ddi_node_embedding) 477 | ehr_node_embedding = self.dropout(ehr_node_embedding) 478 | ddi_node_embedding = self.dropout(ddi_node_embedding) 479 | 480 | ehr_node_embedding = self.gcn2(ehr_node_embedding, self.ehr_adj) 481 | ddi_node_embedding = self.gcn3(ddi_node_embedding, self.ddi_adj) 482 | return ehr_node_embedding, ddi_node_embedding 483 | 484 | def normalize(self, mx): 485 | """Row-normalize sparse matrix""" 486 | rowsum = np.array(mx.sum(1)) 487 | r_inv = np.power(rowsum, -1).flatten() 488 | r_inv[np.isinf(r_inv)] = 0. 
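# (Explanatory note) normalize() computes D^{-1} @ A for the self-loop-augmented
# adjacency A: a row [1, 1, 0] with row sum 2 becomes [0.5, 0.5, 0.0], so every node
# averages over its neighbors (including itself) in the GraphConvolution layers above.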
489 | r_mat_inv = np.diagflat(r_inv) 490 | mx = r_mat_inv.dot(mx) 491 | return mx 492 | 493 | 494 | class policy_network(nn.Module): 495 | def __init__(self, in_dim, out_dim, hidden_dim): 496 | super(policy_network, self).__init__() 497 | self.layers = nn.Sequential( 498 | nn.Linear(in_dim, hidden_dim), 499 | nn.ReLU(), 500 | nn.Linear(hidden_dim, out_dim) 501 | ) 502 | 503 | def forward(self, x): 504 | return self.layers(x) -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from torch._C import contiguous_format 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from dnc import DNC 7 | from layers import GraphConvolution 8 | import math 9 | from torch.nn.parameter import Parameter 10 | 11 | ''' 12 | Our model 13 | ''' 14 | 15 | 16 | class GCN(nn.Module): 17 | def __init__(self, voc_size, emb_dim, adj, device=torch.device('cpu:0')): 18 | super(GCN, self).__init__() 19 | self.voc_size = voc_size 20 | self.emb_dim = emb_dim 21 | self.device = device 22 | 23 | adj = self.normalize(adj + np.eye(adj.shape[0])) 24 | 25 | self.adj = torch.FloatTensor(adj).to(device) 26 | self.x = torch.eye(voc_size).to(device) 27 | 28 | self.gcn1 = GraphConvolution(voc_size, emb_dim) 29 | self.dropout = nn.Dropout(p=0.3) 30 | self.gcn2 = GraphConvolution(emb_dim, emb_dim) 31 | 32 | def forward(self): 33 | node_embedding = self.gcn1(self.x, self.adj) 34 | node_embedding = F.relu(node_embedding) 35 | node_embedding = self.dropout(node_embedding) 36 | node_embedding = self.gcn2(node_embedding, self.adj) 37 | return node_embedding 38 | 39 | def normalize(self, mx): 40 | """Row-normalize sparse matrix""" 41 | rowsum = np.array(mx.sum(1)) 42 | r_inv = np.power(rowsum, -1).flatten() 43 | r_inv[np.isinf(r_inv)] = 0. 44 | r_mat_inv = np.diagflat(r_inv) 45 | mx = r_mat_inv.dot(mx) 46 | return mx 47 | 48 | 49 | class MaskLinear(nn.Module): 50 | def __init__(self, in_features, out_features, bias=True): 51 | super(MaskLinear, self).__init__() 52 | self.in_features = in_features 53 | self.out_features = out_features 54 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 55 | if bias: 56 | self.bias = Parameter(torch.FloatTensor(out_features)) 57 | else: 58 | self.register_parameter('bias', None) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | stdv = 1. 
/ math.sqrt(self.weight.size(1))
63 | self.weight.data.uniform_(-stdv, stdv)
64 | if self.bias is not None:
65 | self.bias.data.uniform_(-stdv, stdv)
66 | 
67 | def forward(self, input, mask):
68 | weight = torch.mul(self.weight, mask)
69 | output = torch.mm(input, weight)
70 | 
71 | if self.bias is not None:
72 | return output + self.bias
73 | else:
74 | return output
75 | 
76 | def __repr__(self):
77 | return self.__class__.__name__ + ' (' \
78 | + str(self.in_features) + ' -> ' \
79 | + str(self.out_features) + ')'
80 | 
81 | 
82 | class MolecularGraphNeuralNetwork(nn.Module):
83 | def __init__(self, N_fingerprint, dim, layer_hidden, device):
84 | super(MolecularGraphNeuralNetwork, self).__init__()
85 | self.device = device
86 | self.embed_fingerprint = nn.Embedding(
87 | N_fingerprint, dim).to(self.device)
88 | self.W_fingerprint = nn.ModuleList([nn.Linear(dim, dim).to(self.device)
89 | for _ in range(layer_hidden)])
90 | self.layer_hidden = layer_hidden
91 | 
92 | def pad(self, matrices, pad_value):
93 | """Pad the list of matrices
94 | with a pad_value (e.g., 0) for batch processing.
95 | For example, given a list of matrices [A, B, C],
96 | we obtain a new matrix [A00, 0B0, 00C],
97 | where 0 is the zero (i.e., pad value) matrix.
98 | """
99 | shapes = [m.shape for m in matrices]
100 | M, N = sum([s[0] for s in shapes]), sum([s[1] for s in shapes])
101 | zeros = torch.FloatTensor(np.zeros((M, N))).to(self.device)
102 | pad_matrices = pad_value + zeros
103 | i, j = 0, 0
104 | for k, matrix in enumerate(matrices):
105 | m, n = shapes[k]
106 | pad_matrices[i:i+m, j:j+n] = matrix
107 | i += m
108 | j += n
109 | return pad_matrices
110 | 
111 | def update(self, matrix, vectors, layer):
112 | hidden_vectors = torch.relu(self.W_fingerprint[layer](vectors))
113 | return hidden_vectors + torch.mm(matrix, hidden_vectors)
114 | 
115 | def sum(self, vectors, axis):
116 | sum_vectors = [torch.sum(v, 0) for v in torch.split(vectors, axis)]
117 | return torch.stack(sum_vectors)
118 | 
119 | def mean(self, vectors, axis):
120 | mean_vectors = [torch.mean(v, 0) for v in torch.split(vectors, axis)]
121 | return torch.stack(mean_vectors)
122 | 
123 | def forward(self, inputs):
124 | """Cat or pad each input data for batch processing."""
125 | fingerprints, adjacencies, molecular_sizes = inputs
126 | fingerprints = torch.cat(fingerprints)
127 | adjacencies = self.pad(adjacencies, 0)
128 | 
129 | """MPNN layer (update the fingerprint vectors)."""
130 | fingerprint_vectors = self.embed_fingerprint(fingerprints)
131 | for l in range(self.layer_hidden):
132 | hs = self.update(adjacencies, fingerprint_vectors, l)
133 | # fingerprint_vectors = F.normalize(hs, 2, 1) # normalize.
134 | fingerprint_vectors = hs 135 | 136 | """Molecular vector by sum or mean of the fingerprint vectors.""" 137 | molecular_vectors = self.sum(fingerprint_vectors, molecular_sizes) 138 | # molecular_vectors = self.mean(fingerprint_vectors, molecular_sizes) 139 | 140 | return molecular_vectors 141 | 142 | 143 | class SafeDrugModel(nn.Module): 144 | def __init__(self, vocab_size, ddi_adj, ddi_mask_H, MPNNSet, N_fingerprints, average_projection, emb_dim=256, device=torch.device('cpu:0')): 145 | super(SafeDrugModel, self).__init__() 146 | 147 | self.device = device 148 | 149 | # pre-embedding 150 | self.embeddings = nn.ModuleList( 151 | [nn.Embedding(vocab_size[i], emb_dim) for i in range(2)]) 152 | self.dropout = nn.Dropout(p=0.5) 153 | self.encoders = nn.ModuleList( 154 | [nn.GRU(emb_dim, emb_dim, batch_first=True) for _ in range(2)]) 155 | self.query = nn.Sequential( 156 | nn.ReLU(), 157 | nn.Linear(2 * emb_dim, emb_dim) 158 | ) 159 | 160 | # bipartite local embedding 161 | self.bipartite_transform = nn.Sequential( 162 | nn.Linear(emb_dim, ddi_mask_H.shape[1]) 163 | ) 164 | self.bipartite_output = MaskLinear( 165 | ddi_mask_H.shape[1], vocab_size[2], False) 166 | 167 | # MPNN global embedding 168 | self.MPNN_molecule_Set = list(zip(*MPNNSet)) 169 | 170 | self.MPNN_emb = MolecularGraphNeuralNetwork( 171 | N_fingerprints, emb_dim, layer_hidden=2, device=device).forward(self.MPNN_molecule_Set) 172 | self.MPNN_emb = torch.mm(average_projection.to( 173 | device=self.device), self.MPNN_emb.to(device=self.device)) 174 | # self.MPNN_emb.to(device=self.device) 175 | self.MPNN_emb = torch.tensor(self.MPNN_emb, requires_grad=True) 176 | self.MPNN_output = nn.Linear(vocab_size[2], vocab_size[2]) 177 | self.MPNN_layernorm = nn.LayerNorm(vocab_size[2]) 178 | 179 | # graphs, bipartite matrix 180 | self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device) 181 | self.tensor_ddi_mask_H = torch.FloatTensor(ddi_mask_H).to(device) 182 | self.init_weights() 183 | 184 | def forward(self, input): 185 | 186 | # patient health representation 187 | i1_seq = [] 188 | i2_seq = [] 189 | 190 | def sum_embedding(embedding): 191 | return embedding.sum(dim=1).unsqueeze(dim=0) # (1,1,dim) 192 | for adm in input: 193 | i1 = sum_embedding(self.dropout(self.embeddings[0]( 194 | torch.LongTensor(adm[0]).unsqueeze(dim=0).to(self.device)))) # (1,1,dim) 195 | i2 = sum_embedding(self.dropout(self.embeddings[1]( 196 | torch.LongTensor(adm[1]).unsqueeze(dim=0).to(self.device)))) 197 | i1_seq.append(i1) 198 | i2_seq.append(i2) 199 | i1_seq = torch.cat(i1_seq, dim=1) # (1,seq,dim) 200 | i2_seq = torch.cat(i2_seq, dim=1) # (1,seq,dim) 201 | 202 | o1, h1 = self.encoders[0]( 203 | i1_seq 204 | ) 205 | o2, h2 = self.encoders[1]( 206 | i2_seq 207 | ) 208 | patient_representations = torch.cat( 209 | [o1, o2], dim=-1).squeeze(dim=0) # (seq, dim*2) 210 | query = self.query(patient_representations)[-1:, :] # (seq, dim) 211 | 212 | # MPNN embedding 213 | MPNN_match = F.sigmoid(torch.mm(query, self.MPNN_emb.t())) 214 | MPNN_att = self.MPNN_layernorm( 215 | MPNN_match + self.MPNN_output(MPNN_match)) 216 | 217 | # local embedding 218 | bipartite_emb = self.bipartite_output( 219 | F.sigmoid(self.bipartite_transform(query)), self.tensor_ddi_mask_H.t()) 220 | 221 | result = torch.mul(bipartite_emb, MPNN_att) 222 | 223 | neg_pred_prob = F.sigmoid(result) 224 | neg_pred_prob = neg_pred_prob.t() * neg_pred_prob # (voc_size, voc_size) 225 | 226 | batch_neg = 0.0005 * neg_pred_prob.mul(self.tensor_ddi_adj).sum() 227 | 228 | return result, batch_neg 229 | 
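# (Explanatory note) batch_neg above is a soft DDI penalty: the outer product
# neg_pred_prob.t() * neg_pred_prob holds the joint prediction strength of every
# drug pair, and multiplying element-wise by tensor_ddi_adj keeps only pairs with a
# known interaction, so the 0.0005-weighted sum discourages co-prescribing
# interacting drugs.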

class DMNC(nn.Module):
    def __init__(self, vocab_size, emb_dim=64, device=torch.device('cpu:0')):
        super(DMNC, self).__init__()
        K = len(vocab_size)
        self.K = K
        self.vocab_size = vocab_size
        self.device = device

        self.token_start = vocab_size[2]
        self.token_end = vocab_size[2] + 1

        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size[i] if i != 2 else vocab_size[2] + 2, emb_dim) for i in range(K)])
        self.dropout = nn.Dropout(p=0.5)

        self.encoders = nn.ModuleList([DNC(
            input_size=emb_dim,
            hidden_size=emb_dim,
            rnn_type='gru',
            num_layers=1,
            num_hidden_layers=1,
            nr_cells=16,
            cell_size=emb_dim,
            read_heads=1,
            batch_first=True,
            gpu_id=0,  # pins the DNC memory to the first GPU
            independent_linears=False
        ) for _ in range(K - 1)])

        self.decoder = nn.GRU(emb_dim + emb_dim * 2, emb_dim * 2,
                              batch_first=True)  # input: (y, r1, r2), hidden: (hidden1, hidden2)
        self.interface_weighting = nn.Linear(
            emb_dim * 2, 2 * (emb_dim + 1 + 3))  # 2 read heads: (key, strength, mode)
        self.decoder_r2o = nn.Linear(2 * emb_dim, emb_dim * 2)

        self.output = nn.Linear(emb_dim * 2, vocab_size[2] + 2)

    def forward(self, input, i1_state=None, i2_state=None, h_n=None, max_len=20):
        # input (3, code)
        i1_input_tensor = self.embeddings[0](
            torch.LongTensor(input[0]).unsqueeze(dim=0).to(self.device))  # (1, seq, dim)
        i2_input_tensor = self.embeddings[1](
            torch.LongTensor(input[1]).unsqueeze(dim=0).to(self.device))  # (1, seq, dim)

        o1, (ch1, m1, r1) = \
            self.encoders[0](i1_input_tensor, (None, None, None)
                             if i1_state is None else i1_state)
        o2, (ch2, m2, r2) = \
            self.encoders[1](i2_input_tensor, (None, None, None)
                             if i2_state is None else i2_state)

        # save memory state
        i1_state = (ch1, m1, r1)
        i2_state = (ch2, m2, r2)

        predict_sequence = [self.token_start] + input[2]
        if h_n is None:
            h_n = torch.cat([ch1[0], ch2[0]], dim=-1)

        output_logits = []
        r1 = r1.unsqueeze(dim=0)
        r2 = r2.unsqueeze(dim=0)

        if self.training:
            for item in predict_sequence:
                # teacher-forced drug prediction
                item_tensor = self.embeddings[2](
                    torch.LongTensor([item]).unsqueeze(dim=0).to(self.device))  # (1, 1, dim)

                o3, h_n = self.decoder(
                    torch.cat([item_tensor, r1, r2], dim=-1), h_n)
                read_keys, read_strengths, read_modes = self.decode_read_variable(
                    h_n.squeeze(0))

                # read from i1_mem and i2_mem
                r1, _ = self.read_from_memory(
                    self.encoders[0],
                    read_keys[:, 0, :].unsqueeze(dim=1),
                    read_strengths[:, 0].unsqueeze(dim=1),
                    read_modes[:, 0, :].unsqueeze(dim=1), i1_state[1])

                r2, _ = self.read_from_memory(
                    self.encoders[1],
                    read_keys[:, 1, :].unsqueeze(dim=1),
                    read_strengths[:, 1].unsqueeze(dim=1),
                    read_modes[:, 1, :].unsqueeze(dim=1), i2_state[1])

                output = self.decoder_r2o(torch.cat([r1, r2], dim=-1))
                output = self.output(output + o3).squeeze(dim=0)
                output_logits.append(output)
        else:
            item_tensor = self.embeddings[2](
                torch.LongTensor([self.token_start]).unsqueeze(dim=0).to(self.device))  # (1, 1, dim)
            for idx in range(max_len):
                # greedy decoding: feed back the previous prediction
                o3, h_n = self.decoder(
                    torch.cat([item_tensor, r1, r2], dim=-1), h_n)
                read_keys, read_strengths, read_modes = self.decode_read_variable(
                    h_n.squeeze(0))

                # read from i1_mem and i2_mem
                r1, _ = self.read_from_memory(
                    self.encoders[0],
                    read_keys[:, 0, :].unsqueeze(dim=1),
                    read_strengths[:, 0].unsqueeze(dim=1),
                    read_modes[:, 0, :].unsqueeze(dim=1), i1_state[1])

                r2, _ = self.read_from_memory(
                    self.encoders[1],
                    read_keys[:, 1, :].unsqueeze(dim=1),
                    read_strengths[:, 1].unsqueeze(dim=1),
                    read_modes[:, 1, :].unsqueeze(dim=1), i2_state[1])

                output = self.decoder_r2o(torch.cat([r1, r2], dim=-1))
                output = self.output(output + o3).squeeze(dim=0)
                output = F.softmax(output, dim=-1)
                output_logits.append(output)

                input_token = torch.argmax(output, dim=-1)
                input_token = input_token.item()
                item_tensor = self.embeddings[2](
                    torch.LongTensor([input_token]).unsqueeze(dim=0).to(self.device))  # (1, 1, dim)

        return torch.cat(output_logits, dim=0), i1_state, i2_state, h_n

    def read_from_memory(self, dnc, read_key, read_str, read_mode, m_hidden):
        read_vectors, hidden = dnc.memories[0].read(
            read_key, read_str, read_mode, m_hidden)
        return read_vectors, hidden

    def decode_read_variable(self, input):
        w = 64  # assumes emb_dim == 64
        r = 2   # number of read heads
        b = input.size(0)

        input = self.interface_weighting(input)
        # r read keys (b, r, w)
        read_keys = torch.tanh(input[:, :r * w].contiguous().view(b, r, w))
        # r read strengths (b, r)
        read_strengths = F.softplus(
            input[:, r * w:r * w + r].contiguous().view(b, r))
        # read modes (b, r, 3)
        read_modes = F.softmax(
            input[:, (r * w + r):].contiguous().view(b, r, 3), dim=-1)
        return read_keys, read_strengths, read_modes
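# ---------------------------------------------------------------------------
# Shape note for decode_read_variable above: with emb_dim == 64 (w) and two
# read heads (r == 2), one interface vector of size 2*(64+1+3) == 136 is cut
# into
#   read keys      input[:, :128]    -> tanh     -> (b, 2, 64)
#   read strengths input[:, 128:130] -> softplus -> (b, 2)
#   read modes     input[:, 130:136] -> softmax  -> (b, 2, 3)
# which mirrors the DNC read interface (content key, key strength, and a
# 3-way addressing mode per head).
# ---------------------------------------------------------------------------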

class GAMENet(nn.Module):
    def __init__(self, vocab_size, ehr_adj, ddi_adj, emb_dim=64, device=torch.device('cpu:0'), ddi_in_memory=True):
        super(GAMENet, self).__init__()
        K = len(vocab_size)
        self.K = K
        self.vocab_size = vocab_size
        self.device = device
        self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
        self.ddi_in_memory = ddi_in_memory
        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size[i], emb_dim) for i in range(K - 1)])
        self.dropout = nn.Dropout(p=0.5)

        self.encoders = nn.ModuleList(
            [nn.GRU(emb_dim, emb_dim * 2, batch_first=True) for _ in range(K - 1)])

        self.query = nn.Sequential(
            nn.ReLU(),
            nn.Linear(emb_dim * 4, emb_dim),
        )

        self.ehr_gcn = GCN(
            voc_size=vocab_size[2], emb_dim=emb_dim, adj=ehr_adj, device=device)
        self.ddi_gcn = GCN(
            voc_size=vocab_size[2], emb_dim=emb_dim, adj=ddi_adj, device=device)
        self.inter = nn.Parameter(torch.FloatTensor(1))

        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Linear(emb_dim * 3, emb_dim * 2),
            nn.ReLU(),
            nn.Linear(emb_dim * 2, vocab_size[2])
        )

        self.init_weights()

    def forward(self, input):
        # input (adm, 3, codes)

        # generate medical embeddings and queries
        i1_seq = []
        i2_seq = []

        def mean_embedding(embedding):
            return embedding.mean(dim=1).unsqueeze(dim=0)  # (1,1,dim)

        for adm in input:
            i1 = mean_embedding(self.dropout(self.embeddings[0](
                torch.LongTensor(adm[0]).unsqueeze(dim=0).to(self.device))))  # (1,1,dim)
            i2 = mean_embedding(self.dropout(self.embeddings[1](
                torch.LongTensor(adm[1]).unsqueeze(dim=0).to(self.device))))
            i1_seq.append(i1)
            i2_seq.append(i2)
        i1_seq = torch.cat(i1_seq, dim=1)  # (1,seq,dim)
        i2_seq = torch.cat(i2_seq, dim=1)  # (1,seq,dim)

        o1, h1 = self.encoders[0](i1_seq)  # o1: (1, seq, dim*2), h1: (1, 1, dim*2)
        o2, h2 = self.encoders[1](i2_seq)
        patient_representations = torch.cat(
            [o1, o2], dim=-1).squeeze(dim=0)  # (seq, dim*4)
        queries = self.query(patient_representations)  # (seq, dim)

        # graph memory module
        '''I: generate current input'''
        query = queries[-1:]  # (1, dim)

        '''G: generate graph memory bank and insert history information'''
        if self.ddi_in_memory:
            drug_memory = self.ehr_gcn() - self.ddi_gcn() * self.inter  # (size, dim)
        else:
            drug_memory = self.ehr_gcn()

        if len(input) > 1:
            history_keys = queries[:(queries.size(0) - 1)]  # (seq-1, dim)

            history_values = np.zeros((len(input) - 1, self.vocab_size[2]))
            for idx, adm in enumerate(input):
                if idx == len(input) - 1:
                    break
                history_values[idx, adm[2]] = 1
            history_values = torch.FloatTensor(
                history_values).to(self.device)  # (seq-1, size)

        '''O: read from global memory bank and dynamic memory bank'''
        key_weights1 = F.softmax(
            torch.mm(query, drug_memory.t()), dim=-1)  # (1, size)
        fact1 = torch.mm(key_weights1, drug_memory)  # (1, dim)

        if len(input) > 1:
            visit_weight = F.softmax(
                torch.mm(query, history_keys.t()), dim=-1)  # (1, seq-1)
            weighted_values = visit_weight.mm(history_values)  # (1, size)
            fact2 = torch.mm(weighted_values, drug_memory)  # (1, dim)
        else:
            fact2 = fact1

        '''R: convert O and predict'''
        output = self.output(
            torch.cat([query, fact1, fact2], dim=-1))  # (1, size)

        if self.training:
            neg_pred_prob = torch.sigmoid(output)
            neg_pred_prob = neg_pred_prob.t() * neg_pred_prob  # (voc_size, voc_size)
            batch_neg = neg_pred_prob.mul(self.tensor_ddi_adj).mean()

            return output, batch_neg
        else:
            return output

    def init_weights(self):
        """Initialize weights."""
        initrange = 0.1
        for item in self.embeddings:
            item.weight.data.uniform_(-initrange, initrange)

        self.inter.data.uniform_(-initrange, initrange)
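# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `patient`, `target`, and the DDI
# loss weight are assumptions, not the repository's training script):
#
#   model = GAMENet(vocab_size=(diag_n, proc_n, med_n),
#                   ehr_adj=ehr_adj, ddi_adj=ddi_adj, ddi_in_memory=True)
#   model.train()
#   output, batch_neg = model(patient)  # output: (1, med_n) logits
#   loss = F.binary_cross_entropy_with_logits(output, target) + 0.05 * batch_neg
#   model.eval()
#   output = model(patient)             # eval mode returns logits only
# ---------------------------------------------------------------------------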

class Leap(nn.Module):
    def __init__(self, voc_size, emb_dim=64, device=torch.device('cpu:0')):
        super(Leap, self).__init__()
        self.voc_size = voc_size
        self.device = device
        self.SOS_TOKEN = voc_size[2]      # start of sentence
        self.END_TOKEN = voc_size[2] + 1  # end; SOS/END are two extra indices in the medication embedding

        # diag_num * emb_dim
        self.enc_embedding = nn.Sequential(
            nn.Embedding(voc_size[0], emb_dim),
            nn.Dropout(0.3)
        )

        # med_num * emb_dim
        self.dec_embedding = nn.Sequential(
            nn.Embedding(voc_size[2] + 2, emb_dim),
            nn.Dropout(0.3)
        )

        self.dec_gru = nn.GRU(emb_dim * 2, emb_dim, batch_first=True)

        self.attn = nn.Linear(emb_dim * 2, 1)

        self.output = nn.Linear(emb_dim, voc_size[2] + 2)

    def forward(self, input, max_len=20):
        device = self.device
        # input (3, codes)
        input_tensor = torch.LongTensor(input[0]).to(device)
        # encode the diagnoses: (len, dim)
        input_embedding = self.enc_embedding(
            input_tensor.unsqueeze(dim=0)).squeeze(dim=0)

        output_logits = []
        hidden_state = None
        if self.training:
            # training pass: step through each known medication,
            # in the spirit of teacher forcing
            for med_code in [self.SOS_TOKEN] + input[2]:
                dec_input = torch.LongTensor(
                    [med_code]).unsqueeze(dim=0).to(device)
                dec_input = self.dec_embedding(dec_input).squeeze(
                    dim=0)  # (1, dim) embedding of the current medication

                if hidden_state is None:
                    hidden_state = dec_input
                hidden_state_repeat = hidden_state.repeat(
                    input_embedding.size(0), 1)  # (len, dim)

                combined_input = torch.cat(
                    [hidden_state_repeat, input_embedding], dim=-1)  # (len, dim*2)
                # (1, len) attention of this medication over all diagnoses
                attn_weight = F.softmax(self.attn(combined_input).t(), dim=-1)
                # pooled diagnosis context for this step, kept separate so the
                # diagnosis embeddings stay intact across steps
                context = attn_weight.mm(input_embedding)  # (1, dim) weighted sum

                _, hidden_state = self.dec_gru(torch.cat(
                    [context, dec_input], dim=-1).unsqueeze(dim=0), hidden_state.unsqueeze(dim=0))
                hidden_state = hidden_state.squeeze(dim=0)  # (1, dim)

                # (1, med_num) logits for every medication at this position
                output_logits.append(self.output(F.relu(hidden_state)))
            return torch.cat(output_logits, dim=0)

        else:
            # inference pass: the ground-truth medications in input[2] must not
            # be used; max_len caps the generated length (tune it to the data)
            for di in range(max_len):
                if di == 0:
                    # the first step uses SOS; later steps feed back the previous prediction
                    dec_input = torch.LongTensor([[self.SOS_TOKEN]]).to(device)
                dec_input = self.dec_embedding(
                    dec_input).squeeze(dim=0)  # (1, dim)
                if hidden_state is None:
                    hidden_state = dec_input
                hidden_state_repeat = hidden_state.repeat(
                    input_embedding.size(0), 1)  # (len, dim)
                combined_input = torch.cat(
                    [hidden_state_repeat, input_embedding], dim=-1)  # (len, dim*2)
                attn_weight = F.softmax(
                    self.attn(combined_input).t(), dim=-1)  # (1, len)
                context = attn_weight.mm(input_embedding)  # (1, dim)
                _, hidden_state = self.dec_gru(torch.cat([context, dec_input], dim=-1).unsqueeze(dim=0),
                                               hidden_state.unsqueeze(dim=0))
                hidden_state = hidden_state.squeeze(dim=0)  # (1, dim)
                output = self.output(F.relu(hidden_state))
                # .data/topk detaches from the graph; take the most likely token here
                topv, topi = output.data.topk(1)
                output_logits.append(F.softmax(output, dim=-1))
                dec_input = topi.detach()
            return torch.cat(output_logits, dim=0)
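# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Leap decodes one medication per
# step: in training the target sequence is teacher forced, in eval tokens are
# fed back greedily:
#
#   model = Leap(voc_size=(diag_n, proc_n, med_n))
#   model.train()
#   logits = model(adm)             # (len(adm[2]) + 1, med_n + 2)
#   model.eval()
#   probs = model(adm, max_len=20)  # (20, med_n + 2), softmax rows
# ---------------------------------------------------------------------------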

class Retain(nn.Module):
    def __init__(self, voc_size, emb_size=64, device=torch.device('cpu:0')):
        super(Retain, self).__init__()
        self.device = device
        self.voc_size = voc_size
        self.emb_size = emb_size
        self.input_len = voc_size[0] + voc_size[1] + voc_size[2]
        self.output_len = voc_size[2]

        self.embedding = nn.Sequential(
            nn.Embedding(self.input_len + 1, self.emb_size,
                         padding_idx=self.input_len),
            nn.Dropout(0.5)
        )

        self.alpha_gru = nn.GRU(emb_size, emb_size, batch_first=True)
        self.beta_gru = nn.GRU(emb_size, emb_size, batch_first=True)

        self.alpha_li = nn.Linear(emb_size, 1)
        self.beta_li = nn.Linear(emb_size, emb_size)

        self.output = nn.Linear(emb_size, self.output_len)

    def forward(self, input):
        device = self.device
        # input: (visit, 3, codes)
        max_len = max([(len(v[0]) + len(v[1]) + len(v[2])) for v in input])
        input_np = []
        for visit in input:
            input_tmp = []
            input_tmp.extend(visit[0])
            input_tmp.extend(list(np.array(visit[1]) + self.voc_size[0]))
            input_tmp.extend(
                list(np.array(visit[2]) + self.voc_size[0] + self.voc_size[1]))
            if len(input_tmp) < max_len:
                input_tmp.extend([self.input_len] * (max_len - len(input_tmp)))

            input_np.append(input_tmp)

        visit_emb = self.embedding(torch.LongTensor(
            input_np).to(device))  # (visit, max_len, emb)
        visit_emb = torch.sum(visit_emb, dim=1)  # (visit, emb)

        g, _ = self.alpha_gru(visit_emb.unsqueeze(dim=0))  # g: (1, visit, emb)
        h, _ = self.beta_gru(visit_emb.unsqueeze(dim=0))   # h: (1, visit, emb)

        g = g.squeeze(dim=0)  # (visit, emb)
        h = h.squeeze(dim=0)  # (visit, emb)
        # attention over visits (the last dim of alpha_li's output has size 1)
        attn_g = F.softmax(self.alpha_li(g), dim=0)  # (visit, 1)
        attn_h = torch.tanh(self.beta_li(h))  # (visit, emb)

        c = attn_g * attn_h * visit_emb  # (visit, emb)
        c = torch.sum(c, dim=0).unsqueeze(dim=0)  # (1, emb)

        return self.output(c)
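# ---------------------------------------------------------------------------
# Note on the attention above: RETAIN forms the context c = sum_i a_i * (b_i ⊙ v_i)
# over visit embeddings v_i, where the scalar visit weights a come from
# alpha_gru -> alpha_li -> softmax over visits, and the per-dimension gates b
# come from beta_gru -> beta_li -> tanh; self.output(c) then scores all
# medications for the next visit at once.
# ---------------------------------------------------------------------------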

class Leap_batch(nn.Module):
    def __init__(self, voc_size, emb_dim=64, device=torch.device('cpu:0')):
        super(Leap_batch, self).__init__()
        self.voc_size = voc_size
        self.emb_dim = emb_dim
        self.device = device
        self.SOS_TOKEN = voc_size[2]          # start of sentence
        self.END_TOKEN = voc_size[2] + 1      # end; SOS/END are two extra medication-embedding indices
        self.MED_PAD_TOKEN = voc_size[2] + 2  # padding index in the embedding matrix (all-zero row)
        self.DIAG_PAD_TOKEN = voc_size[0] + 2
        self.PROC_PAD_TOKEN = voc_size[1] + 2

        # diag_num * emb_dim
        self.diag_embedding = nn.Sequential(
            nn.Embedding(voc_size[0] + 3, emb_dim, self.DIAG_PAD_TOKEN),
            nn.Dropout(0.3)
        )

        # proc_num * emb_dim
        self.proc_embedding = nn.Sequential(
            nn.Embedding(voc_size[1] + 3, emb_dim, self.PROC_PAD_TOKEN),
            nn.Dropout(0.3)
        )

        # med_num * emb_dim
        self.med_embedding = nn.Sequential(
            # padding_idx maps to the zero vector
            nn.Embedding(voc_size[2] + 3, emb_dim, self.MED_PAD_TOKEN),
            nn.Dropout(0.3)
        )

        # encodes the previous visit's medications
        self.enc_gru = nn.GRU(emb_dim, emb_dim, batch_first=True, bidirectional=True)

        # generates the medication sequence
        self.dec_gru = nn.GRU(emb_dim * 2, emb_dim, batch_first=True)

        self.attn = nn.Linear(emb_dim * 2, 1)

        # self.output = nn.Linear(emb_dim, voc_size[2]+2)
        # self.output2 = nn.Linear(emb_dim, voc_size[2]+2)

        # weights
        self.Ws = nn.Linear(emb_dim * 2, emb_dim)      # only used at the initial stage
        self.Wo = nn.Linear(emb_dim, voc_size[2] + 2)  # generate mode
        self.Wc = nn.Linear(emb_dim * 2, emb_dim)      # copy mode

    def encoder(self, x):
        # input: (med_num)
        embedded = self.med_embedding(x)
        out, h = self.enc_gru(embedded)  # out: (b, seq, hid*2) from the biGRU
        return out, h

    # def forward(self, input, last_input=None, max_len=20):
size", procedures.size()) 723 | # print("med size", medications.size()) 724 | input_disease_embdding = self.diag_embedding(diseases) # [batch, seq, max_d_num, emb] 725 | input_med_embedding = self.med_embedding(medications) # [batch, seq, max_med_num, emb] 726 | 727 | # 拼接一个last_seq_medication,表示对应seq对应的上一次的medication,第一次的由于没有上一次medication,用0填补(用啥填补都行,反正不会用到) 728 | last_seq_medication = torch.full((batch_size, 1, max_med_num), 0).to(device) 729 | last_seq_medication = torch.cat([last_seq_medication, medications[:, :-1, :]], dim=1) 730 | # m_mask_matrix矩阵同样也需要后移 731 | last_m_mask = torch.full((batch_size, 1, max_med_num), -1e9).to(device) # 这里用较大负值,避免softmax之后分走了概率 732 | last_m_mask = torch.cat([last_m_mask, m_mask_matrix[:, :-1, :]], dim=1) 733 | 734 | last_seq_medication_emb = self.med_embedding(last_seq_medication) 735 | # print(last_seq_medication.size(), input_med_embedding.size()) 736 | # 对last_visit进行编码,注意这里用的是last_seq_medication的编码结果 737 | # (batch*seq, max_med_num, emb_dim*2) 738 | encoded_disease, _ = self.enc_gru(last_seq_medication_emb.view(batch_size * max_seq_length, max_med_num, self.emb_dim)) 739 | 740 | # 同样拼接一个last_medication,用于进行序列生成,注意维度上增加了一维 741 | last_medication = torch.full((batch_size, max_seq_length, 1), self.SOS_TOKEN).to(device) # [batch_size, seq, 1] 742 | last_medication = torch.cat([last_medication, medications], dim=2) # [batch_size, seq, max_med_num + 1] 743 | # print(last_medication.size(), medications.size()) 744 | 745 | hidden_state = None 746 | # 预定义结果矩阵 747 | if self.training: 748 | output_logits = torch.zeros(batch_size, max_seq_length, max_med_num+1, self.voc_size[2]+2).to(device) 749 | loop_size=max_med_num+1 750 | else: 751 | output_logits = torch.zeros(batch_size, max_seq_length, max_len, self.voc_size[2]+2).to(device) 752 | loop_size=max_len 753 | 754 | # 开始遍历生成每一个位置的结果 755 | for i in range(loop_size): 756 | if self.training: 757 | dec_input = self.med_embedding(last_medication[:,:,i]) # (batch, seq, emb_dim) 取上一个药物的embedding 758 | else: 759 | if i==0: 760 | dec_input = self.med_embedding(last_medication[:,:,0]) 761 | elif i==max_len: 762 | break 763 | else: 764 | # 非训练时,只能取上一次的输出 765 | dec_input=self.med_embedding(dec_input) 766 | 767 | if hidden_state is None: 768 | hidden_state=dec_input 769 | 770 | # 根据当前的疾病做attention,计算hidden_state (batch, seq, emb_dim) 771 | # print(dec_input.size()) 772 | hidden_state_repeat = hidden_state.unsqueeze(dim=2).repeat(1,1,max_diag_num,1) # (batch, seq, max_diag_num, emb_dim) 773 | # print(hidden_state_repeat.size()) 774 | combined_input=torch.cat([hidden_state_repeat, input_disease_embdding], dim=-1) # (batch, seq, max_diag_num, emb_dim*2) 775 | """这里attn_score结果需要根据mask来加上一个较大的负值,使得对应softmax值接近0,来避免分散注意力""" 776 | attn_score = self.attn(combined_input).squeeze(dim=-1) # (batch, seq, max_diag_num, 1) -> (batch, seq, max_diag_num) 777 | attn_score = attn_score + d_mask_matrix 778 | 779 | attn_weight=F.softmax(attn_score, dim=-1).unsqueeze(dim=2) # (batch, seq, 1, max_diag_num) 注意力权重 780 | # print(attn_weight.size()) 781 | input_embedding=torch.matmul(attn_weight, input_disease_embdding).squeeze(dim=2) # (batch, seq, emb_dim) 782 | 783 | # 为了送到dec_gru中进行reshape 784 | input_embedding_buf = input_embedding.view(batch_size * max_seq_length, 1, -1) 785 | dec_input_buf = dec_input.view(batch_size * max_seq_length, 1, -1) 786 | # print(input_embedding_buf.size()) 787 | # print(dec_input_buf.size()) 788 | hidden_state_buf = hidden_state.view(1, batch_size * max_seq_length, -1) 789 | _, hidden_state_buf = 

class MICRON(nn.Module):
    def __init__(self, vocab_size, ddi_adj, emb_dim=256, device=torch.device('cpu:0')):
        super(MICRON, self).__init__()

        self.device = device

        # pre-embedding
        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size[i], emb_dim) for i in range(2)])
        self.dropout = nn.Dropout(p=0.5)

        self.health_net = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim)
        )

        self.prescription_net = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * 4),
            nn.ReLU(),
            nn.Linear(emb_dim * 4, vocab_size[2])
        )

        # graphs, bipartite matrix
        self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
        self.init_weights()

    def forward(self, input):

        # patient health representation
        def sum_embedding(embedding):
            return embedding.sum(dim=1).unsqueeze(dim=0)  # (1,1,dim)

        diag_emb = sum_embedding(self.dropout(self.embeddings[0](
            torch.LongTensor(input[-1][0]).unsqueeze(dim=0).to(self.device))))  # (1,1,dim)
        prod_emb = sum_embedding(self.dropout(self.embeddings[1](
            torch.LongTensor(input[-1][1]).unsqueeze(dim=0).to(self.device))))

        if len(input) < 2:
            # no previous visit: use zero vectors of the same shape
            diag_emb_last = diag_emb * torch.tensor(0.0)
            prod_emb_last = prod_emb * torch.tensor(0.0)
        else:
            diag_emb_last = sum_embedding(self.dropout(self.embeddings[0](
                torch.LongTensor(input[-2][0]).unsqueeze(dim=0).to(self.device))))  # (1,1,dim)
            prod_emb_last = sum_embedding(self.dropout(self.embeddings[1](
                torch.LongTensor(input[-2][1]).unsqueeze(dim=0).to(self.device))))

        health_representation = torch.cat(
            [diag_emb, prod_emb], dim=-1).squeeze(dim=0)  # (seq, dim*2)
        health_representation_last = torch.cat(
            [diag_emb_last, prod_emb_last], dim=-1).squeeze(dim=0)  # (seq, dim*2)

        health_rep = self.health_net(health_representation)[-1:, :]  # (1, dim)
        health_rep_last = self.health_net(health_representation_last)[-1:, :]  # (1, dim)
        health_residual_rep = health_rep - health_rep_last

        # drug representation
        drug_rep = self.prescription_net(health_rep)
        drug_rep_last = self.prescription_net(health_rep_last)
        drug_residual_rep = self.prescription_net(health_residual_rep)

        # reconstruction loss
        rec_loss = 1 / self.tensor_ddi_adj.shape[0] * torch.sum(
            torch.pow(torch.sigmoid(drug_rep) - torch.sigmoid(drug_rep_last + drug_residual_rep), 2))

        # ddi loss
        neg_pred_prob = torch.sigmoid(drug_rep)
        neg_pred_prob = neg_pred_prob.t() * neg_pred_prob  # (voc_size, voc_size)

        batch_neg = 1 / self.tensor_ddi_adj.shape[0] * neg_pred_prob.mul(self.tensor_ddi_adj).sum()
        return drug_rep, drug_rep_last, drug_residual_rep, batch_neg, rec_loss

    def init_weights(self):
        """Initialize weights."""
        initrange = 0.1
        for item in self.embeddings:
            item.weight.data.uniform_(-initrange, initrange)
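# ---------------------------------------------------------------------------
# Note: MICRON models the *change* between consecutive prescriptions. With
# health representations h_t and h_{t-1}, it returns logits f(h_t) and
# f(h_{t-1}) plus a residual f(h_t - h_{t-1}); rec_loss pushes
# sigmoid(f(h_t)) toward sigmoid(f(h_{t-1}) + residual), and batch_neg
# penalizes co-predicted DDI pairs via tensor_ddi_adj.
# ---------------------------------------------------------------------------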
--------------------------------------------------------------------------------