├── data
│   ├── idx2ndc.pkl
│   ├── idx2drug.pkl
│   ├── ndc2drug.pkl
│   ├── ddi_A_final.pkl
│   ├── ddi_mask_H.pkl
│   ├── ehr_adj_final.pkl
│   ├── ddi_mask_H.py
│   ├── get_SMILES.py
│   ├── processing.py
│   └── drug-atc.csv
├── src
│   ├── __pycache__
│   │   ├── beam.cpython-39.pyc
│   │   ├── loss.cpython-39.pyc
│   │   ├── util.cpython-39.pyc
│   │   ├── layers.cpython-39.pyc
│   │   ├── models.cpython-39.pyc
│   │   ├── model_v2.cpython-39.pyc
│   │   ├── recommend.cpython-39.pyc
│   │   ├── data_loader.cpython-39.pyc
│   │   └── ablation_model.cpython-39.pyc
│   ├── loss.py
│   ├── layers.py
│   ├── beam.py
│   ├── COGNet.py
│   ├── data_loader.py
│   ├── MICRON.py
│   ├── recommend.py
│   ├── util.py
│   ├── COGNet_model.py
│   └── models.py
├── mimic_env.yaml
└── README.md
/data/idx2ndc.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/idx2ndc.pkl
--------------------------------------------------------------------------------
/data/idx2drug.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/idx2drug.pkl
--------------------------------------------------------------------------------
/data/ndc2drug.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ndc2drug.pkl
--------------------------------------------------------------------------------
/data/ddi_A_final.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ddi_A_final.pkl
--------------------------------------------------------------------------------
/data/ddi_mask_H.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ddi_mask_H.pkl
--------------------------------------------------------------------------------
/data/ehr_adj_final.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/data/ehr_adj_final.pkl
--------------------------------------------------------------------------------
/src/__pycache__/beam.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/beam.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/layers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/layers.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/models.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/models.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/model_v2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/model_v2.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/recommend.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/recommend.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/data_loader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/data_loader.cpython-39.pyc
--------------------------------------------------------------------------------
/src/__pycache__/ablation_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BarryRun/COGNet/HEAD/src/__pycache__/ablation_model.cpython-39.pyc
--------------------------------------------------------------------------------
/data/ddi_mask_H.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 | from rdkit.Chem import BRICS
3 | import dill
4 | import numpy as np
5 | # from pinky.smiles import smilin
6 | # from pinky.fingerprints import ecfp
7 |
8 | NDCList = dill.load(open('./idx2drug.pkl', 'rb'))
9 | voc = dill.load(open('./voc_final.pkl', 'rb'))
10 | med_voc = voc['med_voc']
11 |
12 | """
13 | fraction = []
14 | for k, v in med_voc.idx2word.items():
15 | tempF = []
16 | for SMILES in NDCList[v]:
17 | try:
18 | m = ecfp(smilin(SMILES), radius=2)
19 | tempF.append(m)
20 | except:
21 | pass
22 | fraction.append(tempF)
23 |
24 |
25 | """
26 | NDCList[22] = {0}
27 | NDCList[25] = {0}
28 | NDCList[27] = {0}
29 | fraction = []
30 | for k, v in med_voc.idx2word.items():
31 | tempF = set()
32 | for SMILES in NDCList[v]:
33 | try:
34 | m = BRICS.BRICSDecompose(Chem.MolFromSmiles(SMILES))
35 | for frac in m:
36 | tempF.add(frac)
37 | except:
38 | pass
39 |
40 | fraction.append(tempF)
41 |
42 | fracSet = []
43 | for i in fraction:
44 | fracSet += i
45 | fracSet = list(set(fracSet))
46 |
47 | ddi_matrix = np.zeros((len(med_voc.idx2word), len(fracSet)))
48 |
49 | for i, fracList in enumerate(fraction):
50 | for frac in fracList:
51 | ddi_matrix[i, fracSet.index(frac)] = 1
52 |
53 | dill.dump(ddi_matrix, open('ddi_mask_H.pkl', 'wb'))
54 |
--------------------------------------------------------------------------------
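A minimal, self-contained sketch of the BRICS step used in ddi_mask_H.py above (the aspirin SMILES below is an illustrative input, not a drug from the dataset); each unique fragment produced this way becomes one column of the drug-substructure mask:

```python
from rdkit import Chem
from rdkit.Chem import BRICS

# decompose one molecule into BRICS fragments, as ddi_mask_H.py does per drug
smiles = 'CC(=O)Oc1ccccc1C(=O)O'  # aspirin, illustrative only
fragments = BRICS.BRICSDecompose(Chem.MolFromSmiles(smiles))
print(sorted(fragments))  # fragment SMILES; each would index a column of ddi_mask_H
```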
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class cross_entropy_loss(nn.Module):
7 | def __init__(self, device):
8 | super().__init__()
9 | self.device = device
10 |
11 | def forward(self, labels, logits, seq_length, m_length_matrix, med_num, END_TOKEN):
12 | # labels: [batch_size, max_seq_length, max_med_num]
13 |         # logits: [batch_size, max_seq_length, max_med_num + 1, med_num] (one extra step per visit for the END token)
14 | # m_length_matrix: [batch_size, seq_length]
15 | # seq_length: [batch_size]
16 |
17 | batch_size, max_seq_length = labels.size()[:2]
18 | assert max_seq_length == max(seq_length)
19 | whole_seqs_num = seq_length.sum().item()
20 | whole_med_sum = sum([sum(buf) for buf in m_length_matrix]) + whole_seqs_num
21 |
22 | labels_flatten = torch.empty(whole_med_sum).to(self.device)
23 | logits_flatten = torch.empty(whole_med_sum, med_num).to(self.device)
24 |
25 | start_idx = 0
26 | for i in range(batch_size):
27 | for j in range(seq_length[i]):
28 | for k in range(m_length_matrix[i][j]+1):
29 | if k==m_length_matrix[i][j]:
30 | labels_flatten[start_idx] = END_TOKEN
31 | else:
32 | labels_flatten[start_idx] = labels[i, j, k]
33 | logits_flatten[start_idx, :] = logits[i, j, k, :]
34 | start_idx += 1
35 |
36 |
37 | loss = F.cross_entropy(logits_flatten, labels_flatten.long())
38 | return loss
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
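A toy invocation of cross_entropy_loss with made-up shapes (run from src/ so the import resolves). Note the third logits dimension is max_med_num + 1, leaving one slot per visit for the END-token step:

```python
import torch
from loss import cross_entropy_loss

med_num, END_TOKEN = 10, 9                # toy medication vocabulary and END id
labels = torch.randint(0, 8, (2, 3, 4))   # [batch, max_seq, max_med_num]
logits = torch.randn(2, 3, 5, med_num)    # [batch, max_seq, max_med_num + 1, med_num]
seq_length = torch.tensor([3, 2])         # visits per patient
m_length_matrix = [[2, 3, 1], [4, 2]]     # medications per visit
loss_fn = cross_entropy_loss(device=torch.device('cpu'))
print(loss_fn(labels, logits, seq_length, m_length_matrix, med_num, END_TOKEN))
```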
/src/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.nn.parameter import Parameter
7 |
8 |
9 | class GraphConvolution(nn.Module):
10 | """
11 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
12 | """
13 |
14 | def __init__(self, in_features, out_features, bias=True):
15 | super(GraphConvolution, self).__init__()
16 | self.in_features = in_features
17 | self.out_features = out_features
18 | self.weight = Parameter(torch.FloatTensor(in_features, out_features))
19 | if bias:
20 | self.bias = Parameter(torch.FloatTensor(out_features))
21 | else:
22 | self.register_parameter('bias', None)
23 | self.reset_parameters()
24 |
25 | def reset_parameters(self):
26 | stdv = 1. / math.sqrt(self.weight.size(1))
27 | self.weight.data.uniform_(-stdv, stdv)
28 | if self.bias is not None:
29 | self.bias.data.uniform_(-stdv, stdv)
30 |
31 | def forward(self, input, adj):
32 | support = torch.mm(input, self.weight)
33 | output = torch.mm(adj, support)
34 | if self.bias is not None:
35 | return output + self.bias
36 | else:
37 | return output
38 |
39 | def __repr__(self):
40 | return self.__class__.__name__ + ' (' \
41 | + str(self.in_features) + ' -> ' \
42 | + str(self.out_features) + ')'
43 |
44 |
45 | class SelfAttend(nn.Module):
46 | def __init__(self, embedding_size: int) -> None:
47 | super(SelfAttend, self).__init__()
48 |
49 | self.h1 = nn.Sequential(
50 | nn.Linear(embedding_size, 32),
51 | nn.Tanh()
52 | )
53 |
54 | self.gate_layer = nn.Linear(32, 1)
55 |
56 | def forward(self, seqs, seq_masks=None):
57 | """
58 | :param seqs: shape [batch_size, seq_length, embedding_size]
59 |         :param seq_masks: shape [batch_size, seq_length]
60 |         :return: shape [batch_size, embedding_size]
61 | """
62 | gates = self.gate_layer(self.h1(seqs)).squeeze(-1)
63 | if seq_masks is not None:
64 | gates = gates + seq_masks
65 | p_attn = F.softmax(gates, dim=-1)
66 | p_attn = p_attn.unsqueeze(-1)
67 | h = seqs * p_attn
68 | output = torch.sum(h, dim=1)
69 | return output
70 |
--------------------------------------------------------------------------------
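A toy forward pass through GraphConvolution (all sizes illustrative; run from src/ so the import resolves):

```python
import torch
from layers import GraphConvolution

gcn = GraphConvolution(in_features=8, out_features=4)
x = torch.randn(5, 8)  # 5 nodes with 8-dim features
adj = torch.eye(5)     # identity adjacency: each node sees only itself
out = gcn(x, adj)      # support = x @ W, then output = adj @ support (+ bias)
print(out.shape)       # torch.Size([5, 4])
```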
/data/get_SMILES.py:
--------------------------------------------------------------------------------
1 | import dill
2 | import numpy as np
3 | import pandas as pd
4 | import requests
5 | import re
6 |
7 | # fix mismatch between two mappings
8 | def fix_mismatch(idx2atc, atc2ndc, ndc2atc_original_path):
9 | ndc2atc = pd.read_csv(open(ndc2atc_original_path, 'rb'))
10 | ndc2atc.ATC4 = ndc2atc.ATC4.apply(lambda x: x[:4])
11 |
12 | mismatch = []
13 | for k, v in idx2atc.items():
14 | if v in atc2ndc.NDC.tolist():
15 | pass
16 | else:
17 | mismatch.append(v)
18 |
19 | for i in mismatch:
20 | atc2ndc = atc2ndc.append({'NDC': i, 'NDC_orig': [s.replace('-', '') for s in ndc2atc[ndc2atc.ATC4 == i].NDC.tolist()]}, ignore_index=True)
21 |
22 | atc2ndc = atc2ndc.append({'NDC': 'seperator', 'NDC_orig': []}, ignore_index=True)
23 | atc2ndc = atc2ndc.append({'NDC': 'decoder_point', 'NDC_orig': []}, ignore_index=True)
24 |
25 | return atc2ndc
26 |
27 | def ndc2smiles(NDC):
28 | url3 = 'https://ndclist.com/?s=' + NDC
29 | r3 = requests.get(url3)
30 |     name = re.findall('<td[^>]*>(.+?)</td>', r3.text)[0]  # table-cell pattern approximate; adjust to ndclist.com's current markup
31 |
32 | url = 'https://dev.drugbankplus.com/guides/tutorials/api_request?request_path=us/product_concepts?q=' + name
33 | r = requests.get(url)
34 | drugbankID = re.findall('(DB\d+)', r.text)[0]
35 |
36 | # re matching might need to update (drugbank may change their html script)
37 | url2 = 'https://www.drugbank.ca/drugs/' + drugbankID
38 | r2 = requests.get(url2)
39 |     SMILES = re.findall('SMILES(.+?)</dd>', r2.text)[0]  # terminator approximate; adjust to DrugBank's current markup
40 | return SMILES
41 |
42 | def atc2smiles(atc2ndc):
43 | atc2SMILES = {}
44 | for k, ndc in atc2ndc.values:
45 | if k not in list(atc2SMILES.keys()):
46 | for index, code in enumerate(ndc):
47 | if index > 100: break
48 | try:
49 | SMILES = ndc2smiles(code)
50 | if 'href' in SMILES:
51 | continue
52 | print (k, index, len(ndc), SMILES)
53 | if k not in atc2SMILES:
54 | atc2SMILES[k] = set()
55 | atc2SMILES[k].add(SMILES)
56 | # if len(atc2SMILES[k]) >= 3:
57 | # break
58 | except:
59 | pass
60 | return atc2SMILES
61 |
62 |
63 | def idx2smiles(idx2atc, atc2SMILES):
64 | idx2drug = {}
65 | idx2drug['seperator'] = {}
66 | idx2drug['decoder_point'] = {}
67 |
68 | for idx, atc in idx2atc.items():
69 | try:
70 | idx2drug[idx] = atc2SMILES[atc]
71 | except:
72 | pass
73 | dill.dump(idx2drug, open('idx2drug.pkl', 'wb'))
74 |
75 |
76 | if __name__ == '__main__':
77 | # get idx2atc
78 | path = './voc_final.pkl'
79 | voc_final = dill.load(open(path, 'rb'))
80 | idx2atc = voc_final['med_voc'].idx2word
81 |
82 | # get atc2ndc
83 | path = './ndc2drug.pkl'
84 | atc2ndc = dill.load(open(path, 'rb'))
85 |
86 | # fix atc2ndc mismatch
87 | ndc2atc_original_path = './ndc2atc_level4.csv'
88 | atc2ndc = fix_mismatch(idx2atc, atc2ndc, ndc2atc_original_path)
89 |
90 | # atc2smiles
91 | atc2SMILES = atc2smiles(atc2ndc)
92 |
93 | # idx2smiles (dumpped)
94 | idx2smiles(idx2atc, atc2SMILES)
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
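The end product of this script, idx2drug.pkl, maps each medication index to a set of SMILES strings (plus the special 'seperator' and 'decoder_point' entries). A minimal sketch of reading it back, assuming the file has been generated or downloaded into data/:

```python
import dill

idx2drug = dill.load(open('idx2drug.pkl', 'rb'))  # med index -> set of SMILES strings
for idx in list(idx2drug)[:5]:
    print(idx, idx2drug[idx])
```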
/mimic_env.yaml:
--------------------------------------------------------------------------------
1 | name: mimic
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=conda_forge
8 | - _openmp_mutex=4.5=1_gnu
9 | - autopep8=1.5.7=pyhd3eb1b0_0
10 | - blas=1.0=mkl
11 | - boost=1.74.0=py39h5472131_3
12 | - boost-cpp=1.74.0=h312852a_4
13 | - bottleneck=1.3.2=py39hdd57654_1
14 | - bzip2=1.0.8=h7f98852_4
15 | - ca-certificates=2021.7.5=h06a4308_1
16 | - cairo=1.16.0=h6cf1ce9_1008
17 | - certifi=2021.5.30=py39h06a4308_0
18 | - cudatoolkit=11.0.221=h6bb024c_0
19 | - cycler=0.10.0=py_2
20 | - dill=0.3.4=pyhd3eb1b0_0
21 | - fontconfig=2.13.1=hba837de_1005
22 | - freetype=2.10.4=h0708190_1
23 | - gettext=0.19.8.1=h0b5b191_1005
24 | - greenlet=1.1.1=py39he80948d_0
25 | - icu=68.1=h58526e2_0
26 | - intel-openmp=2021.3.0=h06a4308_3350
27 | - jbig=2.1=h7f98852_2003
28 | - joblib=1.0.1=pyhd3eb1b0_0
29 | - jpeg=9d=h36c2ea0_0
30 | - kiwisolver=1.3.1=py39h1a9c180_1
31 | - lcms2=2.12=hddcbb42_0
32 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
33 | - lerc=2.2.1=h9c3ff4c_0
34 | - libdeflate=1.7=h7f98852_5
35 | - libffi=3.3=h58526e2_2
36 | - libgcc-ng=11.1.0=hc902ee8_8
37 | - libgfortran-ng=7.5.0=ha8ba4b0_17
38 | - libgfortran4=7.5.0=ha8ba4b0_17
39 | - libglib=2.68.3=h3e27bee_0
40 | - libgomp=11.1.0=hc902ee8_8
41 | - libiconv=1.16=h516909a_0
42 | - libpng=1.6.37=h21135ba_2
43 | - libstdcxx-ng=11.1.0=h56837e0_8
44 | - libtiff=4.3.0=hf544144_1
45 | - libuuid=2.32.1=h7f98852_1000
46 | - libwebp-base=1.2.0=h7f98852_2
47 | - libxcb=1.13=h7f98852_1003
48 | - libxml2=2.9.12=h72842e0_0
49 | - lz4-c=1.9.3=h9c3ff4c_1
50 | - matplotlib-base=3.4.2=py39h2fa2bec_0
51 | - mkl=2021.3.0=h06a4308_520
52 | - mkl-service=2.4.0=py39h7f8727e_0
53 | - mkl_fft=1.3.0=py39h42c9631_2
54 | - mkl_random=1.2.2=py39h51133e4_0
55 | - ncurses=6.2=h58526e2_4
56 | - ninja=1.10.2=hff7bd54_1
57 | - numexpr=2.7.3=py39h22e1b3c_1
58 | - numpy=1.20.3=py39hf144106_0
59 | - numpy-base=1.20.3=py39h74d4b33_0
60 | - olefile=0.46=pyh9f0ad1d_1
61 | - openjpeg=2.4.0=hb52868f_1
62 | - openssl=1.1.1k=h27cfd23_0
63 | - pandas=1.3.1=py39h8c16a72_0
64 | - pcre=8.45=h9c3ff4c_0
65 | - pillow=8.3.1=py39ha612740_0
66 | - pip=21.2.3=pyhd8ed1ab_0
67 | - pixman=0.40.0=h36c2ea0_0
68 | - pthread-stubs=0.4=h36c2ea0_1001
69 | - pycairo=1.20.1=py39hedcb9fc_0
70 | - pycodestyle=2.7.0=pyhd3eb1b0_0
71 | - pyparsing=2.4.7=pyh9f0ad1d_0
72 | - python=3.9.6=h49503c6_1_cpython
73 | - python-dateutil=2.8.2=pyhd8ed1ab_0
74 | - python_abi=3.9=2_cp39
75 | - pytorch=1.7.1=py3.9_cuda11.0.221_cudnn8.0.5_0
76 | - pytz=2021.1=pyhd8ed1ab_0
77 | - rdkit=2021.03.4=py39hccf6a74_0
78 | - readline=8.1=h46c0cb4_0
79 | - reportlab=3.5.68=py39he59360d_0
80 | - scikit-learn=0.24.2=py39ha9443f7_0
81 | - scipy=1.6.2=py39had2a1c9_1
82 | - setuptools=49.6.0=py39hf3d152e_3
83 | - six=1.16.0=pyh6c4a22f_0
84 | - sqlalchemy=1.4.22=py39h3811e60_0
85 | - sqlite=3.36.0=h9cd32fc_0
86 | - threadpoolctl=2.2.0=pyhb85f177_0
87 | - tk=8.6.10=h21135ba_1
88 | - toml=0.10.2=pyhd3eb1b0_0
89 | - torchaudio=0.7.2=py39
90 | - torchvision=0.2.2=py_3
91 | - tornado=6.1=py39h3811e60_1
92 | - typing_extensions=3.10.0.0=pyh06a4308_0
93 | - tzdata=2021a=he74cb21_1
94 | - wheel=0.36.2=pyhd3deb0d_0
95 | - xorg-kbproto=1.0.7=h7f98852_1002
96 | - xorg-libice=1.0.10=h7f98852_0
97 | - xorg-libsm=1.2.3=hd9c2040_1000
98 | - xorg-libx11=1.7.2=h7f98852_0
99 | - xorg-libxau=1.0.9=h7f98852_0
100 | - xorg-libxdmcp=1.1.3=h7f98852_0
101 | - xorg-libxext=1.3.4=h7f98852_1
102 | - xorg-libxrender=0.9.10=h7f98852_1003
103 | - xorg-renderproto=0.11.1=h7f98852_1002
104 | - xorg-xextproto=7.3.0=h7f98852_1002
105 | - xorg-xproto=7.0.31=h7f98852_1007
106 | - xz=5.2.5=h516909a_1
107 | - zlib=1.2.11=h516909a_1010
108 | - zstd=1.5.0=ha95c52a_0
109 | - pip:
110 | - blessings==1.7
111 | - dnc==1.1.0
112 | - flann==1.6.13
113 | - gpustat==0.6.0
114 | - nvidia-ml-py3==7.352.0
115 | - psutil==5.8.0
116 | prefix: /home/wurui/miniconda3/envs/mimic
117 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Implementation of WWW 2022 paper: Conditional Generation Net for Medication Recommendation
2 |
3 | ### Folder Specification
4 | - mimic_env.yaml
5 | - src/
6 | - COGNet.py: train/test COGNet
7 | - recommend.py: test functions used by COGNet
8 | - COGNet_model.py: full model of COGNet
9 | - COGNet_ablation.py: ablation models of COGNet
10 | - train/test baselines:
11 | - MICRON.py
12 | - Other code for training/testing baselines can be found [here](https://github.com/ycq091044/SafeDrug).
13 | - models.py: baseline models
14 | - util.py
15 | - layers.py
16 | - data/ **(For a fair comparison, we use the same data and pre-processing scripts used in [SafeDrug](https://github.com/ycq091044/SafeDrug))**
17 | - mapping files collected from external sources
18 | - drug-atc.csv: drug to ATC code mapping file
19 | - drug-DDI.csv: this is a large file; it can be downloaded from https://drive.google.com/file/d/1mnPc0O0ztz0fkv3HF-dpmBb8PLWsEoDz/view?usp=sharing
20 | - ndc2atc_level4.csv: NDC code to ATC-4 code mapping file
21 | - ndc2rxnorm_mapping.txt: NDC to RxNorm mapping file
22 | - idx2drug.pkl: drug ID to drug SMILES string dict
23 | - other files generated from the mapping files and the MIMIC dataset (we attach these files here; users can also generate them with our provided scripts)
24 | - data_final.pkl: intermediate result
25 | - ddi_A_final.pkl: ddi matrix
26 | - ddi_mask_H.pkl: H mask structure (created by ddi_mask_H.py), used in the SafeDrug baseline
27 | - idx2ndc.pkl: idx2ndc mapping file
28 | - ndc2drug.pkl: ndc2drug mapping file
29 | - Under the MIMIC dataset policy, we are not allowed to distribute the datasets. Practitioners can go to https://physionet.org/content/mimiciii/1.4/, request access to the MIMIC-III dataset, and then run our processing script to get the complete preprocessed dataset files.
30 | - voc_final.pkl: diag/proc/med dictionary
31 | - dataset processing scripts
32 | - processing.py: processes the original MIMIC dataset.
33 |
34 |
35 |
36 |
37 | ### Step 1: Data Processing
38 |
39 | - Go to https://physionet.org/content/mimiciii/1.4/ to download the MIMIC-III dataset (you may need to complete the required certification)
40 |
41 | - Go into the folder and unzip the three main files (PROCEDURES_ICD.csv.gz, PRESCRIPTIONS.csv.gz, DIAGNOSES_ICD.csv.gz)
42 |
43 | - Change the paths in processing.py and process the data to get the complete records_final.pkl
44 |
45 | ```bash
46 | vim processing.py
47 |
48 | # line 310-312
49 | # med_file = '/data/mimic-iii/PRESCRIPTIONS.csv'
50 | # diag_file = '/data/mimic-iii/DIAGNOSES_ICD.csv'
51 | # procedure_file = '/data/mimic-iii/PROCEDURES_ICD.csv'
52 |
53 | python processing.py
54 | ```
55 |
56 | - run ddi_mask_H.py to get the ddi_mask_H.pkl
57 |
58 | ```bash
59 | python ddi_mask_H.py
60 | ```
61 |
62 |
63 |
64 | ### Step 2: Package Dependency
65 |
66 | - First, install [conda](https://www.anaconda.com/)
67 |
68 | - Then, create the conda environment from the yaml file
69 | ```bash
70 | conda env create -f mimic_env.yaml
71 | ```
72 | Note: you may need to upgrade PyTorch to version 1.10.0. (Thanks a lot to [Thomaswbt](https://github.com/Thomaswbt)!)
73 |
74 |
75 | ### Step 3: Run the Code
76 |
77 | ```bash
78 | python COGNet.py
79 | ```
80 |
81 | Here are the arguments:
82 |
83 | usage: COGNet.py [-h] [--Test] [--model_name MODEL_NAME]
84 | [--resume_path RESUME_PATH] [--lr LR] [--batch_size BATCH_SIZE]
85 | [--emb_dim EMB_DIM] [--max_len MAX_LEN] [--beam_size BEAM_SIZE]
86 |
87 | optional arguments:
88 | -h, --help show this help message and exit
89 | --Test test mode
90 | --model_name MODEL_NAME
91 | model name
92 | --resume_path RESUME_PATH
93 | resume path
94 | --lr LR learning rate
95 | --batch_size batch size
96 | --emb_dim dimension size of embedding
97 | --max_len max number of recommended medications
98 | --beam_size number of ways in beam search
99 |
100 | If you cannot run the code on a GPU, change "cuda" to "cpu" on line 61 of COGNet.py.
101 |
102 | ### Citation
103 | ```bibtex
104 | @inproceedings{wu2022cognet,
105 | title = {Conditional Generation Net for Medication Recommendation},
106 | author = {Rui Wu and Zhaopeng Qiu and Jiacheng Jiang and Guilin Qi and Xian Wu},
107 | booktitle = {{WWW} '22: The Web Conference 2022, Virtual Event, Lyon, France, April 25-29, 2022},
108 | year = {2022}
109 | }
110 | ```
111 |
112 | Please feel free to contact me with any questions.
113 |
114 | Partial credit to previous repositories:
115 | - https://github.com/sjy1203/GAMENet
116 | - https://github.com/ycq091044/SafeDrug
117 | - https://github.com/ycq091044/MICRON
118 |
119 | Thank [Chaoqi Yang](https://github.com/ycq091044) and [Junyuan Shang](https://github.com/sjy1203) for releasing their codes!
120 |
121 | Thank my mentor, [Zhaopeng Qiu](https://github.com/zpqiu), for helping me complete the code.
122 |
--------------------------------------------------------------------------------
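As a usage example for Step 3 of the README (the checkpoint filename below is illustrative; training writes files named Epoch_{epoch}_JA_{ja}_DDI_{ddi}.model under saved/COGNet/):

```bash
# train from scratch (checkpoints go to saved/COGNet/)
python COGNet.py

# evaluate a saved checkpoint in test mode
python COGNet.py --Test --resume_path saved/COGNet/Epoch_42_JA_0.5123_DDI_0.0789.model
```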
/src/beam.py:
--------------------------------------------------------------------------------
1 | """ Manage beam search info structure.
2 | Heavily borrowed from OpenNMT-py.
3 | For code in OpenNMT-py, please check the following link:
4 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
5 | """
6 |
7 | from shutil import which
8 | from util import ddi_rate_score
9 | import torch
10 | import numpy as np
11 | import copy
12 | import random
13 | from torch.autograd.grad_mode import F
14 |
15 |
16 | class Beam(object):
17 | ''' Store the necessary info for beam search '''
18 | def __init__(self, size, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, ddi_adj, device):
19 | self.ddi_adj = ddi_adj
20 | self.PAD = PAD_TOKEN
21 | self.BOS = BOS_TOKEN
22 | self.EOS = EOS_TOKEN
23 | # print(PAD_TOKEN, EOS_TOKEN, BOS_TOKEN)
24 |
25 | self.device = device
26 | self.size = size
27 | self.done = False # whether the beam search has finished
28 |
29 | self.beam_status = [False] * size # whether each beam has already reached EOS
30 |
31 | self.tt = torch.cuda if device.type=='cuda' else torch
32 |
33 | # score of each hypothesis; initialized to beam_size zeros
34 | self.scores = self.tt.FloatTensor(size).zero_()
35 | self.all_scores = []
36 |
37 | # The backpointers at each time-step.
38 | self.prev_ks = []
39 |
40 | # The outputs at each time-step.
41 | self.next_ys = [self.tt.LongTensor(size).fill_(self.BOS)]
42 | self.prob_list = []
43 |
44 |
45 | def get_current_state(self, sort=True):
46 | "Get the outputs for the current timestep."
47 | if sort:
48 | return self.get_tentative_hypothesis()
49 | else:
50 | return self.get_tentative_hypothesis_wo_sort()
51 |
52 | def get_current_origin(self):
53 | "Get the backpointers for the current timestep."
54 | return self.prev_ks[-1]
55 |
56 | def advance(self, word_lk):
57 | "Update the status and check for finished or not."
58 | num_words = word_lk.size(1)
59 | if self.done:
60 | self.prev_ks.append(torch.tensor(list(range(self.size)), device=self.device))
61 | self.next_ys.append(torch.tensor([self.EOS]*self.size, device=self.device))
62 | self.prob_list.append([[0]*num_words, [0]*num_words])
63 | return True
64 |
65 | active_beam_idx = torch.tensor([idx for idx in range(self.size) if self.beam_status[idx]==False]).long().to(self.device)
66 | end_beam_idx = torch.tensor([idx for idx in range(self.size) if self.beam_status[idx]==True]).long().to(self.device)
67 | active_word_lk = word_lk[active_beam_idx] # active_beam_num * num_words
68 |
69 | cur_output = self.get_current_state(sort=False)
70 |
71 | active_scores = self.scores[active_beam_idx]
72 | end_scores = self.scores[end_beam_idx]
73 |
74 | if len(self.prev_ks) > 0:
75 | beam_lk = active_word_lk + active_scores.unsqueeze(dim=1).expand_as(active_word_lk) # (active_beam_num, num_words)
76 | else:
77 | beam_lk = active_word_lk[0]
78 |
79 | flat_beam_lk = beam_lk.view(-1)
80 | active_max_idx = len(flat_beam_lk)
81 | flat_beam_lk = torch.cat([flat_beam_lk, end_scores], dim=-1)
82 |
83 | self.all_scores.append(self.scores)
84 |
85 |
86 | sorted_scores, sorted_score_ids = torch.sort(flat_beam_lk, descending=True)
87 | select_num, cur_idx = 0, 0
88 | selected_scores = []
89 | selected_words = []
90 | selected_beams = []
91 | new_active_status = []
92 |
93 | prob_buf = []
94 | while select_num < self.size:
95 | cur_score, cur_id = sorted_scores[cur_idx], sorted_score_ids[cur_idx]
96 | if cur_id >= active_max_idx:
97 | which_beam = end_beam_idx[cur_id-active_max_idx]
98 | which_word = torch.tensor(self.EOS).to(self.device)
99 | select_num += 1
100 | new_active_status.append(True)
101 | selected_scores.append(cur_score)
102 | selected_beams.append(which_beam)
103 | selected_words.append(which_word)
104 | prob_buf.append([0]*num_words)
105 | else:
106 | which_beam_idx = cur_id // num_words
107 | which_beam = active_beam_idx[which_beam_idx]
108 | which_word = cur_id - which_beam_idx*num_words
109 | if which_word not in cur_output[which_beam]:
110 | if which_word in [self.EOS, self.BOS]:
111 | new_active_status.append(True)
112 | else:
113 | new_active_status.append(False)
114 | select_num += 1
115 | selected_scores.append(cur_score)
116 | selected_beams.append(which_beam)
117 | selected_words.append(which_word)
118 | prob_buf.append(active_word_lk[which_beam_idx].detach().cpu().numpy().tolist())
119 | cur_idx += 1
120 | self.prob_list.append(prob_buf)
121 |
122 | self.beam_status = new_active_status
123 | self.scores = torch.stack(selected_scores)
124 | self.prev_ks.append(torch.stack(selected_beams))
125 | self.next_ys.append(torch.stack(selected_words))
126 |
127 | if_done = True
128 | for i in range(self.size):
129 | if not self.beam_status[i]:
130 | if_done = False
131 | break
132 | if if_done:
133 | self.done=True
134 |
135 | return self.done
136 |
137 | def sort_scores(self):
138 | "Sort the scores."
139 | return torch.sort(self.scores, 0, True)
140 |
141 |
142 | def get_tentative_hypothesis(self):
143 | "Get the decoded sequence for the current timestep."
144 |
145 | if len(self.next_ys) == 1:
146 | dec_seq = self.next_ys[0].unsqueeze(1)
147 | else:
148 | _, keys = self.sort_scores()
149 | hyps = [self.get_hypothesis(k) for k in keys]
150 | hyps = [[self.BOS] + h for h in hyps]
151 | dec_seq = torch.from_numpy(np.array(hyps)).long().to(self.device)
152 | return dec_seq
153 |
154 | def get_tentative_hypothesis_wo_sort(self):
155 | "Get the decoded sequence for the current timestep."
156 |
157 | if len(self.next_ys) == 1:
158 | dec_seq = self.next_ys[0].unsqueeze(1)
159 | else:
160 | keys = list(range(self.size))
161 | hyps = [self.get_hypothesis(k) for k in keys]
162 | hyps = [[self.BOS] + h for h in hyps]
163 | dec_seq = torch.from_numpy(np.array(hyps)).long().to(self.device)
164 | return dec_seq
165 |
166 | def get_hypothesis(self, k):
167 | """
168 | Walk back to construct the full hypothesis.
169 | Parameters.
170 | * `k` - the position in the beam to construct.
171 | Returns.
172 | 1. The hypothesis
173 | 2. The attention at each time step.
174 | """
175 | hyp = []
176 | for j in range(len(self.prev_ks)-1, -1, -1):
177 | hyp.append(self.next_ys[j + 1][k].item())
178 | k = self.prev_ks[j][k]
179 | return hyp[::-1]
180 |
181 | def get_prob_list(self, k):
182 | ret_prob_list = []
183 | for j in range(len(self.prev_ks)-1, -1, -1):
184 | ret_prob_list.append(self.prob_list[j][k])
185 | k = self.prev_ks[j][k]
186 | return ret_prob_list[::-1]
--------------------------------------------------------------------------------
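A self-contained toy of the backpointer walk in Beam.get_hypothesis, with plain lists standing in for the tensors (all values are made up):

```python
# next_ys[t][k]: token chosen at step t for beam k; prev_ks[t][k]: beam it extended
next_ys = [[0, 0], [5, 7], [9, 5]]  # step 0 holds the BOS tokens
prev_ks = [[0, 0], [1, 0]]

def get_hypothesis(k):
    hyp = []
    for j in range(len(prev_ks) - 1, -1, -1):
        hyp.append(next_ys[j + 1][k])  # token emitted at this step
        k = prev_ks[j][k]              # follow the backpointer to the parent beam
    return hyp[::-1]

print(get_hypothesis(0))  # [7, 9]: beam 0's final token 9 descends from beam 1's token 7
```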
/src/COGNet.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import torch
3 | import torch.nn as nn
4 | import argparse
5 | from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
6 | import numpy as np
7 | import dill
8 | import time
9 | from torch.nn import CrossEntropyLoss
10 | from torch.optim import Adam
11 | from torch.utils import data
12 | from loss import cross_entropy_loss
13 | import os
14 | import torch.nn.functional as F
15 | import random
16 | from collections import defaultdict
17 |
18 | from torch.utils.data.dataloader import DataLoader
19 | from data_loader import mimic_data, pad_batch_v2_train, pad_batch_v2_eval, pad_num_replace
20 |
21 | import sys
22 | sys.path.append("..")
23 | from COGNet_ablation import COGNet_wo_copy, COGNet_wo_visit_score, COGNet_wo_graph, COGNet_wo_diag, COGNet_wo_proc
24 | from COGNet_model import COGNet, policy_network
25 | from util import llprint, sequence_metric, sequence_output_process, ddi_rate_score, get_n_params, output_flatten, print_result
26 | from recommend import eval, test
27 |
28 | torch.manual_seed(1203)
29 |
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
31 |
32 | model_name = 'COGNet'
33 | resume_path = ''
34 |
35 | if not os.path.exists(os.path.join("saved", model_name)):
36 | os.makedirs(os.path.join("saved", model_name))
37 |
38 | # Training settings
39 | parser = argparse.ArgumentParser()
40 | # parser.add_argument('--Test', action='store_true', default=True, help="test mode")
41 | parser.add_argument('--Test', action='store_true', default=False, help="test mode")
42 | parser.add_argument('--model_name', type=str, default=model_name, help="model name")
43 | parser.add_argument('--resume_path', type=str, default=resume_path, help='resume path')
44 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
45 | parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
46 | parser.add_argument('--emb_dim', type=int, default=64, help='embedding dimension size')
47 | parser.add_argument('--max_len', type=int, default=45, help='maximum prediction medication sequence')
48 | parser.add_argument('--beam_size', type=int, default=4, help='max num of sentences in beam searching')
49 |
50 | args = parser.parse_args()
51 |
52 | def main(args):
53 | # load data
54 | data_path = '../data/records_final.pkl'
55 | voc_path = '../data/voc_final.pkl'
56 |
57 | # ehr_adj_path = '../data/weighted_ehr_adj_final.pkl'
58 | ehr_adj_path = '../data/ehr_adj_final.pkl'
59 | ddi_adj_path = '../data/ddi_A_final.pkl'
60 | ddi_mask_path = '../data/ddi_mask_H.pkl'
61 | device = torch.device('cuda')
62 |
63 | data = dill.load(open(data_path, 'rb'))
64 | voc = dill.load(open(voc_path, 'rb'))
65 | ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
66 | ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
67 | ddi_mask_H = dill.load(open(ddi_mask_path, 'rb'))
68 |
69 | diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']
70 | print(f"Diag num:{len(diag_voc.idx2word)}")
71 | print(f"Proc num:{len(pro_voc.idx2word)}")
72 | print(f"Med num:{len(med_voc.idx2word)}")
73 |
74 | # frequency statistic
75 | med_count = defaultdict(int)
76 | for patient in data:
77 | for adm in patient:
78 | for med in adm[2]:
79 | med_count[med] += 1
80 |
81 | ## rare first
82 | for i in range(len(data)):
83 | for j in range(len(data[i])):
84 | cur_medications = sorted(data[i][j][2], key=lambda x:med_count[x])
85 | data[i][j][2] = cur_medications
86 |
87 |
88 | split_point = int(len(data) * 2 / 3)
89 | data_train = data[:split_point]
90 | eval_len = int(len(data[split_point:]) / 2)
91 | data_test = data[split_point:split_point + eval_len]
92 | data_eval = data[split_point+eval_len:]
93 |
94 | train_dataset = mimic_data(data_train)
95 | eval_dataset = mimic_data(data_eval)
96 | test_dataset = mimic_data(data_test)
97 |
98 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_batch_v2_train, shuffle=True, pin_memory=True)
99 | eval_dataloader = DataLoader(eval_dataset, batch_size=1, collate_fn=pad_batch_v2_eval, shuffle=True, pin_memory=True)
100 | test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=pad_batch_v2_eval, shuffle=True, pin_memory=True)
101 |
102 | voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))
103 |
104 | END_TOKEN = voc_size[2] + 1
105 | DIAG_PAD_TOKEN = voc_size[0] + 2
106 | PROC_PAD_TOKEN = voc_size[1] + 2
107 | MED_PAD_TOKEN = voc_size[2] + 2
108 | SOS_TOKEN = voc_size[2]
109 | TOKENS = [END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN]
110 |
111 | model = COGNet(voc_size, ehr_adj, ddi_adj, ddi_mask_H, emb_dim=args.emb_dim, device=device)
112 |
113 | if args.Test:
114 | model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
115 | model.to(device=device)
116 | tic = time.time()
117 | smm_record, ja, prauc, precision, recall, f1, med_num = test(model, test_dataloader, diag_voc, pro_voc, med_voc, voc_size, 0, device, TOKENS, ddi_adj, args)
118 | result = []
119 | for _ in range(10):
120 | data_num = len(ja)
121 | final_length = int(0.8 * data_num)
122 | idx_list = list(range(data_num))
123 | random.shuffle(idx_list)
124 | idx_list = idx_list[:final_length]
125 | avg_ja = np.mean([ja[i] for i in idx_list])
126 | avg_prauc = np.mean([prauc[i] for i in idx_list])
127 | avg_precision = np.mean([precision[i] for i in idx_list])
128 | avg_recall = np.mean([recall[i] for i in idx_list])
129 | avg_f1 = np.mean([f1[i] for i in idx_list])
130 | avg_med = np.mean([med_num[i] for i in idx_list])
131 | cur_smm_record = [smm_record[i] for i in idx_list]
132 | ddi_rate = ddi_rate_score(cur_smm_record, path='../data/ddi_A_final.pkl')
133 | result.append([ddi_rate, avg_ja, avg_prauc, avg_precision, avg_recall, avg_f1, avg_med])
134 | llprint('\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
135 | ddi_rate, avg_ja, avg_prauc, avg_precision, avg_recall, avg_f1, avg_med))
136 | result = np.array(result)
137 | mean = result.mean(axis=0)
138 | std = result.std(axis=0)
139 |
140 | outstring = ""
141 | for m, s in zip(mean, std):
142 | outstring += "{:.4f} $\pm$ {:.4f} & ".format(m, s)
143 |
144 | print (outstring)
145 | print ('test time: {}'.format(time.time() - tic))
146 | return
147 |
148 | model.to(device=device)
149 | print('parameters', get_n_params(model))
150 | optimizer = Adam(model.parameters(), lr=args.lr)
151 |
152 | history = defaultdict(list)
153 | best_epoch, best_ja = 0, 0
154 |
155 | EPOCH = 200
156 | for epoch in range(EPOCH):
157 | tic = time.time()
158 | print ('\nepoch {} --------------------------'.format(epoch))
159 |
160 | model.train()
161 | for idx, data in enumerate(train_dataloader):
162 | diseases, procedures, medications, seq_length, \
163 | d_length_matrix, p_length_matrix, m_length_matrix, \
164 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
165 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \
166 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = data
167 |
168 | diseases = pad_num_replace(diseases, -1, DIAG_PAD_TOKEN).to(device)
169 | procedures = pad_num_replace(procedures, -1, PROC_PAD_TOKEN).to(device)
170 | dec_disease = pad_num_replace(dec_disease, -1, DIAG_PAD_TOKEN).to(device)
171 | stay_disease = pad_num_replace(stay_disease, -1, DIAG_PAD_TOKEN).to(device)
172 | dec_proc = pad_num_replace(dec_proc, -1, PROC_PAD_TOKEN).to(device)
173 | stay_proc = pad_num_replace(stay_proc, -1, PROC_PAD_TOKEN).to(device)
174 | medications = medications.to(device)
175 | m_mask_matrix = m_mask_matrix.to(device)
176 | d_mask_matrix = d_mask_matrix.to(device)
177 | p_mask_matrix = p_mask_matrix.to(device)
178 | dec_disease_mask = dec_disease_mask.to(device)
179 | stay_disease_mask = stay_disease_mask.to(device)
180 | dec_proc_mask = dec_proc_mask.to(device)
181 | stay_proc_mask = stay_proc_mask.to(device)
182 | output_logits = model(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask,
183 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask)
184 | labels, predictions = output_flatten(medications, output_logits, seq_length, m_length_matrix, voc_size[2] + 2, END_TOKEN, device, max_len=args.max_len)
185 | loss = F.nll_loss(predictions, labels.long())
186 | optimizer.zero_grad()
187 | loss.backward()
188 | optimizer.step()
189 | llprint('\rtraining step: {} / {}'.format(idx, len(train_dataloader)))
190 |
191 | print ()
192 | tic2 = time.time()
193 | ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, eval_dataloader, voc_size, epoch, device, TOKENS, args)
194 | print ('training time: {}, test time: {}'.format(time.time() - tic, time.time() - tic2))
195 |
196 | history['ja'].append(ja)
197 | history['ddi_rate'].append(ddi_rate)
198 | history['avg_p'].append(avg_p)
199 | history['avg_r'].append(avg_r)
200 | history['avg_f1'].append(avg_f1)
201 | history['prauc'].append(prauc)
202 | history['med'].append(avg_med)
203 |
204 | if epoch >= 5:
205 | print ('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
206 | np.mean(history['ddi_rate'][-5:]),
207 | np.mean(history['med'][-5:]),
208 | np.mean(history['ja'][-5:]),
209 | np.mean(history['avg_f1'][-5:]),
210 | np.mean(history['prauc'][-5:])
211 | ))
212 |
213 | torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
214 | 'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb'))
215 |
216 | if best_ja < ja:
217 | best_epoch = epoch
218 | best_ja = ja
219 |
220 | print ('best_epoch: {}'.format(best_epoch))
221 |
222 | dill.dump(history, open(os.path.join('saved', args.model_name, 'history_{}.pkl'.format(args.model_name)), 'wb'))
223 |
224 |
225 | if __name__ == '__main__':
226 | main(args)
227 |
--------------------------------------------------------------------------------
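In test mode the script above reports mean ± std over 10 random 80% subsamples of the per-visit metrics. A stripped-down sketch of that bootstrap (toy scores in place of real Jaccard values):

```python
import random
import numpy as np

ja = [random.random() for _ in range(100)]  # toy per-visit Jaccard scores
rounds = []
for _ in range(10):
    idx = random.sample(range(len(ja)), int(0.8 * len(ja)))  # one 80% subsample
    rounds.append(np.mean([ja[i] for i in idx]))
rounds = np.array(rounds)
print('{:.4f} $\\pm$ {:.4f}'.format(rounds.mean(), rounds.std()))
```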
/data/processing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import dill
3 | import numpy as np
4 | from collections import defaultdict
5 |
6 |
7 | ##### process medications #####
8 | # load med data
9 | def med_process(med_file):
10 | """读取MIMIC原数据文件,保留pid、adm_id、data以及NDC,以DF类型返回"""
11 | # 读取药物文件,NDC(National Drug Code)以类别类型存储
12 | med_pd = pd.read_csv(med_file, dtype={'NDC':'category'})
13 |
14 | # drop unused columns
15 | med_pd.drop(columns=['ROW_ID','DRUG_TYPE','DRUG_NAME_POE','DRUG_NAME_GENERIC',
16 | 'FORMULARY_DRUG_CD','PROD_STRENGTH','DOSE_VAL_RX',
17 | 'DOSE_UNIT_RX','FORM_VAL_DISP','FORM_UNIT_DISP', 'GSN', 'FORM_UNIT_DISP',
18 | 'ROUTE','ENDDATE','DRUG'], axis=1, inplace=True)
19 | med_pd.drop(index = med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
20 | med_pd.fillna(method='pad', inplace=True)
21 | med_pd.dropna(inplace=True)
22 | med_pd.drop_duplicates(inplace=True)
23 | med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
24 | med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')
25 | med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
26 | med_pd = med_pd.reset_index(drop=True) # reset the index and drop the old one
27 |
28 | med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
29 | med_pd = med_pd.drop_duplicates()
30 | med_pd = med_pd.reset_index(drop=True)
31 |
32 | return med_pd
33 |
34 | # medication mapping
35 | def ndc2atc4(med_pd):
36 | """将NDC映射到ACT4"""
37 | with open(ndc_rxnorm_file, 'r') as f:
38 | ndc2rxnorm = eval(f.read())
39 | # read the NDC-to-RxNorm mapping from ndc_rxnorm_file (RxNorm here corresponds to the RXCUI column below)
40 | med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
41 | med_pd.dropna(inplace=True) # in practice this drops nothing
42 |
43 | rxnorm2atc = pd.read_csv(ndc2atc_file)
44 | rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC']) # drop NDC; map directly from RXCUI to ATC
45 | # drop duplicate rows by RXCUI
46 | rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
47 |
48 | med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True) # drop rows with an empty RXCUI
49 |
50 | med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
51 | med_pd = med_pd.reset_index(drop=True)
52 | med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI']) # join the two tables
53 | med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True) # drop NDC/RXCUI; only ATC4 remains
54 | med_pd = med_pd.rename(columns={'ATC4':'NDC'}) # rename the column back to NDC
55 | med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4]) # keep only the first four characters
56 | med_pd = med_pd.drop_duplicates()
57 | med_pd = med_pd.reset_index(drop=True)
58 | return med_pd
59 |
60 | # visit >= 2
61 | def process_visit_lg2(med_pd):
62 | """筛除admission次数小于两次的患者数据"""
63 | a = med_pd[['SUBJECT_ID', 'HADM_ID']].groupby(by='SUBJECT_ID')['HADM_ID'].unique().reset_index()
64 | a['HADM_ID_Len'] = a['HADM_ID'].map(lambda x:len(x))
65 | a = a[a['HADM_ID_Len'] > 1]
66 | return a
67 |
68 | # most common medications
69 | def filter_300_most_med(med_pd):
70 | # sort NDC codes by frequency in descending order and keep the top 300
71 | med_count = med_pd.groupby(by=['NDC']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
72 | med_pd = med_pd[med_pd['NDC'].isin(med_count.loc[:299, 'NDC'])]
73 |
74 | return med_pd.reset_index(drop=True)
75 |
76 | ##### process diagnosis #####
77 | def diag_process(diag_file):
78 | diag_pd = pd.read_csv(diag_file)
79 | diag_pd.dropna(inplace=True)
80 | diag_pd.drop(columns=['SEQ_NUM','ROW_ID'],inplace=True)
81 | diag_pd.drop_duplicates(inplace=True)
82 | diag_pd.sort_values(by=['SUBJECT_ID','HADM_ID'], inplace=True)
83 | diag_pd = diag_pd.reset_index(drop=True)
84 |
85 | def filter_2000_most_diag(diag_pd):
86 | diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
87 | diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:1999, 'ICD9_CODE'])]
88 |
89 | return diag_pd.reset_index(drop=True)
90 |
91 | diag_pd = filter_2000_most_diag(diag_pd)
92 |
93 | return diag_pd
94 |
95 | ##### process procedure #####
96 | def procedure_process(procedure_file):
97 | pro_pd = pd.read_csv(procedure_file, dtype={'ICD9_CODE':'category'})
98 | pro_pd.drop(columns=['ROW_ID'], inplace=True)
99 | pro_pd.drop_duplicates(inplace=True)
100 | pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM'], inplace=True)
101 | pro_pd.drop(columns=['SEQ_NUM'], inplace=True)
102 | pro_pd.drop_duplicates(inplace=True)
103 | pro_pd.reset_index(drop=True, inplace=True)
104 |
105 | return pro_pd
106 |
107 | def filter_1000_most_pro(pro_pd):
108 | pro_count = pro_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
109 | pro_pd = pro_pd[pro_pd['ICD9_CODE'].isin(pro_count.loc[:1000, 'ICD9_CODE'])]
110 |
111 | return pro_pd.reset_index(drop=True)
112 |
113 | ###### combine three tables #####
114 | def combine_process(med_pd, diag_pd, pro_pd):
115 | """药物、症状、proc的数据结合"""
116 |
117 | med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
118 | diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
119 | pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
120 |
121 | combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
122 | combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
123 |
124 | diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
125 | med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
126 | pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
127 |
128 | # flatten and merge
129 | diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()
130 | med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
131 | pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})
132 | med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
133 | pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
134 | data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
135 | data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
136 | # data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
137 | data['NDC_Len'] = data['NDC'].map(lambda x: len(x))
138 |
139 | return data
140 |
141 | def statistics(data):
142 | print('#patients ', data['SUBJECT_ID'].unique().shape)
143 | print('#clinical events ', len(data))
144 |
145 | diag = data['ICD9_CODE'].values
146 | med = data['NDC'].values
147 | pro = data['PRO_CODE'].values
148 |
149 | unique_diag = set([j for i in diag for j in list(i)])
150 | unique_med = set([j for i in med for j in list(i)])
151 | unique_pro = set([j for i in pro for j in list(i)])
152 |
153 | print('#diagnosis ', len(unique_diag))
154 | print('#med ', len(unique_med))
155 | print('#procedure', len(unique_pro))
156 |
157 | avg_diag, avg_med, avg_pro, max_diag, max_med, max_pro, cnt, max_visit, avg_visit = [0 for i in range(9)]
158 |
159 | for subject_id in data['SUBJECT_ID'].unique():
160 | item_data = data[data['SUBJECT_ID'] == subject_id]
161 | x, y, z = [], [], []
162 | visit_cnt = 0
163 | for index, row in item_data.iterrows():
164 | visit_cnt += 1
165 | cnt += 1
166 | x.extend(list(row['ICD9_CODE']))
167 | y.extend(list(row['NDC']))
168 | z.extend(list(row['PRO_CODE']))
169 | x, y, z = set(x), set(y), set(z)
170 | avg_diag += len(x)
171 | avg_med += len(y)
172 | avg_pro += len(z)
173 | avg_visit += visit_cnt
174 | if len(x) > max_diag:
175 | max_diag = len(x)
176 | if len(y) > max_med:
177 | max_med = len(y)
178 | if len(z) > max_pro:
179 | max_pro = len(z)
180 | if visit_cnt > max_visit:
181 | max_visit = visit_cnt
182 |
183 | print('#avg of diagnoses ', avg_diag/ cnt)
184 | print('#avg of medicines ', avg_med/ cnt)
185 | print('#avg of procedures ', avg_pro/ cnt)
186 | print('#avg of visits ', avg_visit/ len(data['SUBJECT_ID'].unique()))
187 |
188 | print('#max of diagnoses ', max_diag)
189 | print('#max of medicines ', max_med)
190 | print('#max of procedures ', max_pro)
191 | print('#max of visit ', max_visit)
192 |
193 | ##### indexing file and final record
194 | class Voc(object):
195 | def __init__(self):
196 | self.idx2word = {}
197 | self.word2idx = {}
198 |
199 | def add_sentence(self, sentence):
200 | for word in sentence:
201 | if word not in self.word2idx:
202 | self.idx2word[len(self.word2idx)] = word
203 | self.word2idx[word] = len(self.word2idx)
204 |
205 | # create voc set
206 | def create_str_token_mapping(df):
207 | diag_voc = Voc()
208 | med_voc = Voc()
209 | pro_voc = Voc()
210 |
211 | for index, row in df.iterrows():
212 | diag_voc.add_sentence(row['ICD9_CODE'])
213 | med_voc.add_sentence(row['NDC'])
214 | pro_voc.add_sentence(row['PRO_CODE'])
215 |
216 | dill.dump(obj={'diag_voc':diag_voc, 'med_voc':med_voc ,'pro_voc':pro_voc}, file=open('voc_final.pkl','wb'))
217 | return diag_voc, med_voc, pro_voc
218 |
219 | # create final records
220 | def create_patient_record(df, diag_voc, med_voc, pro_voc):
221 | """
222 | Save the records as a nested list.
223 | Each entry is one patient with multiple visits; each visit holds three lists, in order: diagnoses, procedures, and medications.
224 | All values are indices; look up the corresponding words via voc_final.pkl.
225 | """
226 | records = [] # (patient, code_kind:3, codes) code_kind:diag, proc, med
227 | for subject_id in df['SUBJECT_ID'].unique():
228 | item_df = df[df['SUBJECT_ID'] == subject_id]
229 | patient = []
230 | for index, row in item_df.iterrows():
231 | admission = []
232 | admission.append([diag_voc.word2idx[i] for i in row['ICD9_CODE']])
233 | admission.append([pro_voc.word2idx[i] for i in row['PRO_CODE']])
234 | admission.append([med_voc.word2idx[i] for i in row['NDC']])
235 | patient.append(admission)
236 | records.append(patient)
237 | dill.dump(obj=records, file=open('records_final.pkl', 'wb'))
238 | return records
239 |
240 |
241 |
242 | # get ddi matrix
243 | def get_ddi_matrix(records, med_voc, ddi_file):
244 |
245 | TOPK = 40 # topk drug-drug interaction
246 | cid2atc_dic = defaultdict(set)
247 | med_voc_size = len(med_voc.idx2word)
248 | med_unique_word = [med_voc.idx2word[i] for i in range(med_voc_size)] # ATC4 codes of all medications
249 | atc3_atc4_dic = defaultdict(set)
250 | for item in med_unique_word:
251 | atc3_atc4_dic[item[:4]].add(item)
252 |
253 | with open(cid_atc, 'r') as f:
254 | for line in f:
255 | line_ls = line[:-1].split(',')
256 | cid = line_ls[0]
257 | atcs = line_ls[1:]
258 | for atc in atcs:
259 | if len(atc3_atc4_dic[atc[:4]]) != 0:
260 | cid2atc_dic[cid].add(atc[:4])
261 |
262 | # load the DDI data
263 | ddi_df = pd.read_csv(ddi_file)
264 | # filter severe side effects, again keeping only the top K
265 | ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
266 | ddi_most_pd = ddi_most_pd.iloc[-TOPK:,:]
267 | # ddi_most_pd = pd.DataFrame(columns=['Side Effect Name'], data=['as','asd','as'])
268 | fliter_ddi_df = ddi_df.merge(ddi_most_pd[['Side Effect Name']], how='inner', on=['Side Effect Name'])
269 | ddi_df = fliter_ddi_df[['STITCH 1','STITCH 2']].drop_duplicates().reset_index(drop=True)
270 |
271 |
272 | # weighted ehr adj
273 | ehr_adj = np.zeros((med_voc_size, med_voc_size))
274 | for patient in records:
275 | for adm in patient:
276 | med_set = adm[2]
277 | for i, med_i in enumerate(med_set):
278 | for j, med_j in enumerate(med_set):
279 | if j<=i:
280 | continue
281 | # ehr_adj[med_i, med_j] = 1
282 | # ehr_adj[med_j, med_i] = 1
283 | ehr_adj[med_i, med_j] += 1
284 | ehr_adj[med_j, med_i] += 1
285 | dill.dump(ehr_adj, open('ehr_adj_final.pkl', 'wb'))
286 |
287 | # DDI adjacency: the DDI table is keyed by CID, so CIDs must be mapped to ATC codes before drug-drug conflicts in the dataset can be recorded
288 | ddi_adj = np.zeros((med_voc_size,med_voc_size))
289 | for index, row in ddi_df.iterrows():
290 | # ddi
291 | cid1 = row['STITCH 1']
292 | cid2 = row['STITCH 2']
293 |
294 | # cid -> atc_level3
295 | for atc_i in cid2atc_dic[cid1]:
296 | for atc_j in cid2atc_dic[cid2]:
297 |
298 | # atc_level3 -> atc_level4
299 | for i in atc3_atc4_dic[atc_i]:
300 | for j in atc3_atc4_dic[atc_j]:
301 | if med_voc.word2idx[i] != med_voc.word2idx[j]:
302 | ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
303 | ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
304 | dill.dump(ddi_adj, open('ddi_A_final.pkl', 'wb'))
305 |
306 | return ddi_adj
307 |
308 | if __name__ == '__main__':
309 | # MIMIC data files: medications, diagnoses, and procedures
310 | med_file = '/data/mimic-iii/PRESCRIPTIONS.csv'
311 | diag_file = '/data/mimic-iii/DIAGNOSES_ICD.csv'
312 | procedure_file = '/data/mimic-iii/PROCEDURES_ICD.csv'
313 |
314 | # drug information
315 | med_structure_file = './idx2drug.pkl' # mapping from drug to molecular structure (SMILES)
316 |
317 | # drug code mapping files
318 | ndc2atc_file = './ndc2atc_level4.csv' # NDC code to ATC-4 code mapping file, used to map RxNorm to ATC
319 | cid_atc = './drug-atc.csv' # drug (CID) to ATC code mapping file, used to process the DDI table
320 | ndc_rxnorm_file = './ndc2rxnorm_mapping.txt' # NDC to RxNorm mapping file
321 |
322 | # ddi information
323 | # data example
324 | # STITCH 1,STITCH 2,Polypharmacy Side Effect,Side Effect Name
325 | # CID000002173,CID000003345,C0151714,hypermagnesemia
326 | # CID000002173,CID000003345,C0035344,retinopathy of prematurity
327 | # CID000002173,CID000003345,C0004144,atelectasis
328 | # CID000002173,CID000003345,C0002063,alkalosis
329 | # CID000002173,CID000003345,C0004604,Back Ache
330 | # CID000002173,CID000003345,C0034063,lung edema
331 | ddi_file = '/data/drug-DDI.csv'
332 |
333 | # process the MIMIC medication data
334 | med_pd = med_process(med_file)
335 | med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True) # note: this only keeps patients with two or more admissions in the medication table
336 | med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']], on='SUBJECT_ID', how='inner').reset_index(drop=True)
337 |
338 | med_pd = ndc2atc4(med_pd)
339 | NDCList = dill.load(open(med_structure_file, 'rb'))
340 | med_pd = med_pd[med_pd.NDC.isin(list(NDCList.keys()))]
341 | med_pd = filter_300_most_med(med_pd)
342 |
343 | print ('complete medication processing')
344 |
345 | # for diagnosis
346 | diag_pd = diag_process(diag_file)
347 |
348 | print ('complete diagnosis processing')
349 |
350 | # for procedure
351 | pro_pd = procedure_process(procedure_file)
352 | # pro_pd = filter_1000_most_pro(pro_pd)
353 |
354 | print ('complete procedure processing')
355 |
356 | # combine
357 | data = combine_process(med_pd, diag_pd, pro_pd)
358 | statistics(data)
359 | data.to_pickle('data_final.pkl')
360 |
361 | print ('complete combining')
362 |
363 |
364 | # ddi_matrix
365 | diag_voc, med_voc, pro_voc = create_str_token_mapping(data)
366 | records = create_patient_record(data, diag_voc, med_voc, pro_voc) # stored in order: diag, proc, medication
367 | ddi_adj = get_ddi_matrix(records, med_voc, ddi_file)
368 |
--------------------------------------------------------------------------------
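A quick sanity check of the artifacts written by processing.py, assuming it has already been run inside data/:

```python
import dill

records = dill.load(open('records_final.pkl', 'rb'))  # [patient][visit] = [diags, procs, meds]
voc = dill.load(open('voc_final.pkl', 'rb'))
med_voc = voc['med_voc']
print(len(records), 'patients')
first_visit = records[0][0]
print('first visit meds:', [med_voc.idx2word[m] for m in first_visit[2]])
```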
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | from sklearn import preprocessing
2 | from torch.nn.functional import pad
3 | from torch.utils import data
4 | import torch
5 | import random
6 |
7 |
8 | class mimic_data(data.Dataset):
9 | def __init__(self, data) -> None:
10 | super().__init__()
11 | self.data = data
12 |
13 | def __getitem__(self, index):
14 | return self.data[index]
15 |
16 | def __len__(self):
17 | return len(self.data)
18 |
19 |
20 | def pad_batch(batch):
21 | seq_length = torch.tensor([len(data) for data in batch])
22 | batch_size = len(batch)
23 | max_seq = max(seq_length)
24 |
25 | # count each visit's diseases, procedures, and medications, and track the maxima
26 | # also compute, for each visit's diseases, the intersection with and difference from the previous visit's diseases
27 | d_length_matrix = []
28 | p_length_matrix = []
29 | m_length_matrix = []
30 | d_max_num = 0
31 | p_max_num = 0
32 | m_max_num = 0
33 | d_dec_list = []
34 | d_stay_list = []
35 | for data in batch:
36 | d_buf, p_buf, m_buf = [], [], []
37 | d_dec_list_buf, d_stay_list_buf = [], []
38 | for idx, seq in enumerate(data):
39 | d_buf.append(len(seq[0]))
40 | p_buf.append(len(seq[1]))
41 | m_buf.append(len(seq[2]))
42 | d_max_num = max(d_max_num, len(seq[0]))
43 | p_max_num = max(p_max_num, len(seq[1]))
44 | m_max_num = max(m_max_num, len(seq[2]))
45 | if idx==0:
46 | # first visit: intersection and difference are empty
47 | d_dec_list_buf.append([])
48 | d_stay_list_buf.append([])
49 | else:
50 | # compute the difference and intersection
51 | cur_d = set(seq[0])
52 | last_d = set(data[idx-1][0])
53 | stay_list = list(cur_d & last_d)
54 | dec_list = list(last_d - cur_d)
55 | d_dec_list_buf.append(dec_list)
56 | d_stay_list_buf.append(stay_list)
57 | d_length_matrix.append(d_buf)
58 | p_length_matrix.append(p_buf)
59 | m_length_matrix.append(m_buf)
60 | d_dec_list.append(d_dec_list_buf)
61 | d_stay_list.append(d_stay_list_buf)
62 |
63 | # build m_mask_matrix
64 | m_mask_matrix = torch.full((batch_size, max_seq, m_max_num), -1e9)
65 | for i in range(batch_size):
66 | for j in range(len(m_length_matrix[i])):
67 | m_mask_matrix[i, j, :m_length_matrix[i][j]] = 0.
68 |
69 | # build d_mask_matrix
70 | d_mask_matrix = torch.full((batch_size, max_seq, d_max_num), -1e9)
71 | for i in range(batch_size):
72 | for j in range(len(d_length_matrix[i])):
73 | d_mask_matrix[i, j, :d_length_matrix[i][j]] = 0.
74 |
75 | # build p_mask_matrix
76 | p_mask_matrix = torch.full((batch_size, max_seq, p_max_num), -1e9)
77 | for i in range(batch_size):
78 | for j in range(len(p_length_matrix[i])):
79 | p_mask_matrix[i, j, :p_length_matrix[i][j]] = 0.
80 |
81 | # build dec_disease_tensor and stay_disease_tensor
82 | dec_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
83 | stay_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
84 | dec_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
85 | stay_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
86 | for b_id, (dec_seqs, stay_seqs) in enumerate(zip(d_dec_list, d_stay_list)):
87 | for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
88 | dec_disease_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
89 | stay_disease_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
90 | dec_disease_mask[b_id, s_id, :len(dec_adm)] = 0.
91 | stay_disease_mask[b_id, s_id, :len(stay_adm)] = 0.
92 |
93 | # build the disease, procedure and medication tensors
94 | disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
95 | procedure_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
96 | medication_tensor = torch.full((batch_size, max_seq, m_max_num), 0)
97 |
98 | # assemble everything into one batch
99 | for b_id, data in enumerate(batch):
100 | for s_id, adm in enumerate(data):
101 | # each adm stores disease, procedure and medication in that order
102 | disease_tensor[b_id, s_id, :len(adm[0])] = torch.tensor(adm[0])
103 | procedure_tensor[b_id, s_id, :len(adm[1])] = torch.tensor(adm[1])
104 | medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(adm[2])
105 | # print(disease_tensor[1])
106 | return disease_tensor, procedure_tensor, medication_tensor, seq_length, \
107 | d_length_matrix, p_length_matrix, m_length_matrix, \
108 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
109 | dec_disease_tensor, stay_disease_tensor, dec_disease_mask, stay_disease_mask
110 |
111 |
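A minimal usage sketch for the dataset/collate pair above, assuming `records` holds the nested [patient][visit][diagnoses, procedures, medications] index lists produced by the preprocessing step; the batch size is illustrative:

from torch.utils.data.dataloader import DataLoader

dataset = mimic_data(records)                       # records: [patient][visit][diag, proc, med]
loader = DataLoader(dataset, batch_size=4, collate_fn=pad_batch, shuffle=True)
disease_tensor, procedure_tensor, medication_tensor, seq_length, *rest = next(iter(loader))
# disease_tensor: (batch_size, max_seq, d_max_num), padded with -1 until
# pad_num_replace (defined below) swaps in a vocabulary-specific pad token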
112 | def pad_batch_v2_train(batch):
113 | seq_length = torch.tensor([len(data) for data in batch])
114 | batch_size = len(batch)
115 | max_seq = max(seq_length)
116 |
117 | # Count the number of diseases, procedures and medications in each seq, and track the maxima
118 | # Also compute, for each seq, the intersection and difference between its diseases and the previous seq's diseases
119 | d_length_matrix = []
120 | p_length_matrix = []
121 | m_length_matrix = []
122 | d_max_num = 0
123 | p_max_num = 0
124 | m_max_num = 0
125 | d_dec_list = []
126 | d_stay_list = []
127 | p_dec_list = []
128 | p_stay_list = []
129 | for data in batch:
130 | d_buf, p_buf, m_buf = [], [], []
131 | d_dec_list_buf, d_stay_list_buf = [], []
132 | p_dec_list_buf, p_stay_list_buf = [], []
133 | for idx, seq in enumerate(data):
134 | d_buf.append(len(seq[0]))
135 | p_buf.append(len(seq[1]))
136 | m_buf.append(len(seq[2]))
137 | d_max_num = max(d_max_num, len(seq[0]))
138 | p_max_num = max(p_max_num, len(seq[1]))
139 | m_max_num = max(m_max_num, len(seq[2]))
140 | if idx==0:
141 | # first seq: intersection and difference are empty
142 | d_dec_list_buf.append([])
143 | d_stay_list_buf.append([])
144 | p_dec_list_buf.append([])
145 | p_stay_list_buf.append([])
146 | else:
147 | # compute the difference and intersection
148 | cur_d = set(seq[0])
149 | last_d = set(data[idx-1][0])
150 | stay_list = list(cur_d & last_d)
151 | dec_list = list(last_d - cur_d)
152 | d_dec_list_buf.append(dec_list)
153 | d_stay_list_buf.append(stay_list)
154 |
155 | cur_p = set(seq[1])
156 | last_p = set(data[idx-1][1])
157 | proc_stay_list = list(cur_p & last_p)
158 | proc_dec_list = list(last_p - cur_p)
159 | p_dec_list_buf.append(proc_dec_list)
160 | p_stay_list_buf.append(proc_stay_list)
161 | d_length_matrix.append(d_buf)
162 | p_length_matrix.append(p_buf)
163 | m_length_matrix.append(m_buf)
164 | d_dec_list.append(d_dec_list_buf)
165 | d_stay_list.append(d_stay_list_buf)
166 | p_dec_list.append(p_dec_list_buf)
167 | p_stay_list.append(p_stay_list_buf)
168 |
169 | # build m_mask_matrix
170 | m_mask_matrix = torch.full((batch_size, max_seq, m_max_num), -1e9)
171 | for i in range(batch_size):
172 | for j in range(len(m_length_matrix[i])):
173 | m_mask_matrix[i, j, :m_length_matrix[i][j]] = 0.
174 |
175 | # build d_mask_matrix
176 | d_mask_matrix = torch.full((batch_size, max_seq, d_max_num), -1e9)
177 | for i in range(batch_size):
178 | for j in range(len(d_length_matrix[i])):
179 | d_mask_matrix[i, j, :d_length_matrix[i][j]] = 0.
180 |
181 | # build p_mask_matrix
182 | p_mask_matrix = torch.full((batch_size, max_seq, p_max_num), -1e9)
183 | for i in range(batch_size):
184 | for j in range(len(p_length_matrix[i])):
185 | p_mask_matrix[i, j, :p_length_matrix[i][j]] = 0.
186 |
187 | # build dec_disease_tensor and stay_disease_tensor
188 | dec_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
189 | stay_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
190 | dec_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
191 | stay_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
192 | for b_id, (dec_seqs, stay_seqs) in enumerate(zip(d_dec_list, d_stay_list)):
193 | for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
194 | dec_disease_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
195 | stay_disease_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
196 | dec_disease_mask[b_id, s_id, :len(dec_adm)] = 0.
197 | stay_disease_mask[b_id, s_id, :len(stay_adm)] = 0. # mask length follows stay_adm
198 |
199 | # build dec_proc_tensor and stay_proc_tensor
200 | dec_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
201 | stay_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
202 | dec_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
203 | stay_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
204 | for b_id, (dec_seqs, stay_seqs) in enumerate(zip(p_dec_list, p_stay_list)):
205 | for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
206 | dec_proc_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
207 | stay_proc_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
208 | dec_proc_mask[b_id, s_id, :len(dec_adm)] = 0.
209 | stay_proc_mask[b_id, s_id, :len(stay_adm)] = 0. # mask length follows stay_adm
210 |
211 | # build the disease, procedure and medication tensors
212 | disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
213 | procedure_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
214 | medication_tensor = torch.full((batch_size, max_seq, m_max_num), 0)
215 |
216 | # assemble everything into one batch
217 | for b_id, data in enumerate(batch):
218 | for s_id, adm in enumerate(data):
219 | # each adm stores disease, procedure and medication in that order
220 | disease_tensor[b_id, s_id, :len(adm[0])] = torch.tensor(adm[0])
221 | procedure_tensor[b_id, s_id, :len(adm[1])] = torch.tensor(adm[1])
222 | # dynamic shuffle
223 | # cur_medications = adm[2]
224 | # random.shuffle(cur_medications)
225 | # medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(cur_medications)
226 | medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(adm[2])
227 |
228 | # print(disease_tensor[1])
229 | return disease_tensor, procedure_tensor, medication_tensor, seq_length, \
230 | d_length_matrix, p_length_matrix, m_length_matrix, \
231 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
232 | dec_disease_tensor, stay_disease_tensor, dec_disease_mask, stay_disease_mask, \
233 | dec_proc_tensor, stay_proc_tensor, dec_proc_mask, stay_proc_mask
234 |
235 |
236 |
237 | def pad_batch_v2_eval(batch):
238 | seq_length = torch.tensor([len(data) for data in batch])
239 | batch_size = len(batch)
240 | max_seq = max(seq_length)
241 |
242 | # Count the number of diseases, procedures and medications in each seq, and track the maxima
243 | # Also compute, for each seq, the intersection and difference between its diseases and the previous seq's diseases
244 | d_length_matrix = []
245 | p_length_matrix = []
246 | m_length_matrix = []
247 | d_max_num = 0
248 | p_max_num = 0
249 | m_max_num = 0
250 | d_dec_list = []
251 | d_stay_list = []
252 | p_dec_list = []
253 | p_stay_list = []
254 | for data in batch:
255 | d_buf, p_buf, m_buf = [], [], []
256 | d_dec_list_buf, d_stay_list_buf = [], []
257 | p_dec_list_buf, p_stay_list_buf = [], []
258 | for idx, seq in enumerate(data):
259 | d_buf.append(len(seq[0]))
260 | p_buf.append(len(seq[1]))
261 | m_buf.append(len(seq[2]))
262 | d_max_num = max(d_max_num, len(seq[0]))
263 | p_max_num = max(p_max_num, len(seq[1]))
264 | m_max_num = max(m_max_num, len(seq[2]))
265 | if idx==0:
266 | # first seq: intersection and difference are empty
267 | d_dec_list_buf.append([])
268 | d_stay_list_buf.append([])
269 | p_dec_list_buf.append([])
270 | p_stay_list_buf.append([])
271 | else:
272 | # compute the difference and intersection
273 | cur_d = set(seq[0])
274 | last_d = set(data[idx-1][0])
275 | stay_list = list(cur_d & last_d)
276 | dec_list = list(last_d - cur_d)
277 | d_dec_list_buf.append(dec_list)
278 | d_stay_list_buf.append(stay_list)
279 |
280 | cur_p = set(seq[1])
281 | last_p = set(data[idx-1][1])
282 | proc_stay_list = list(cur_p & last_p)
283 | proc_dec_list = list(last_p - cur_p)
284 | p_dec_list_buf.append(proc_dec_list)
285 | p_stay_list_buf.append(proc_stay_list)
286 | d_length_matrix.append(d_buf)
287 | p_length_matrix.append(p_buf)
288 | m_length_matrix.append(m_buf)
289 | d_dec_list.append(d_dec_list_buf)
290 | d_stay_list.append(d_stay_list_buf)
291 | p_dec_list.append(p_dec_list_buf)
292 | p_stay_list.append(p_stay_list_buf)
293 |
294 | # build m_mask_matrix
295 | m_mask_matrix = torch.full((batch_size, max_seq, m_max_num), -1e9)
296 | for i in range(batch_size):
297 | for j in range(len(m_length_matrix[i])):
298 | m_mask_matrix[i, j, :m_length_matrix[i][j]] = 0.
299 |
300 | # build d_mask_matrix
301 | d_mask_matrix = torch.full((batch_size, max_seq, d_max_num), -1e9)
302 | for i in range(batch_size):
303 | for j in range(len(d_length_matrix[i])):
304 | d_mask_matrix[i, j, :d_length_matrix[i][j]] = 0.
305 |
306 | # build p_mask_matrix
307 | p_mask_matrix = torch.full((batch_size, max_seq, p_max_num), -1e9)
308 | for i in range(batch_size):
309 | for j in range(len(p_length_matrix[i])):
310 | p_mask_matrix[i, j, :p_length_matrix[i][j]] = 0.
311 |
312 | # build dec_disease_tensor and stay_disease_tensor
313 | dec_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
314 | stay_disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
315 | dec_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
316 | stay_disease_mask = torch.full((batch_size, max_seq, d_max_num), -1e9)
317 | for b_id, (dec_seqs, stay_seqs) in enumerate(zip(d_dec_list, d_stay_list)):
318 | for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
319 | dec_disease_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
320 | stay_disease_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
321 | dec_disease_mask[b_id, s_id, :len(dec_adm)] = 0.
322 | stay_disease_mask[b_id, s_id, :len(stay_adm)] = 0. # mask length follows stay_adm
323 |
324 | # build dec_proc_tensor and stay_proc_tensor
325 | dec_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
326 | stay_proc_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
327 | dec_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
328 | stay_proc_mask = torch.full((batch_size, max_seq, p_max_num), -1e9)
329 | for b_id, (dec_seqs, stay_seqs) in enumerate(zip(p_dec_list, p_stay_list)):
330 | for s_id, (dec_adm, stay_adm) in enumerate(zip(dec_seqs, stay_seqs)):
331 | dec_proc_tensor[b_id, s_id, :len(dec_adm)] = torch.tensor(dec_adm)
332 | stay_proc_tensor[b_id, s_id, :len(stay_adm)] = torch.tensor(stay_adm)
333 | dec_proc_mask[b_id, s_id, :len(dec_adm)] = 0.
334 | stay_proc_mask[b_id, s_id, :len(stay_adm)] = 0. # mask length follows stay_adm
335 |
336 | # build the disease, procedure and medication tensors
337 | disease_tensor = torch.full((batch_size, max_seq, d_max_num), -1)
338 | procedure_tensor = torch.full((batch_size, max_seq, p_max_num), -1)
339 | medication_tensor = torch.full((batch_size, max_seq, m_max_num), 0)
340 |
341 | # assemble everything into one batch
342 | for b_id, data in enumerate(batch):
343 | for s_id, adm in enumerate(data):
344 | # each adm stores disease, procedure and medication in that order
345 | disease_tensor[b_id, s_id, :len(adm[0])] = torch.tensor(adm[0])
346 | procedure_tensor[b_id, s_id, :len(adm[1])] = torch.tensor(adm[1])
347 | medication_tensor[b_id, s_id, :len(adm[2])] = torch.tensor(adm[2])
348 |
349 | # print(disease_tensor[1])
350 | return disease_tensor, procedure_tensor, medication_tensor, seq_length, \
351 | d_length_matrix, p_length_matrix, m_length_matrix, \
352 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
353 | dec_disease_tensor, stay_disease_tensor, dec_disease_mask, stay_disease_mask, \
354 | dec_proc_tensor, stay_proc_tensor, dec_proc_mask, stay_proc_mask
355 |
356 |
357 | def pad_num_replace(tensor, src_num, target_num):
358 | # replace_tensor = torch.full_like(tensor, target_num)
359 | return torch.where(tensor==src_num, target_num, tensor)
360 |
361 |
362 |
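For reference, `pad_num_replace` swaps one placeholder value for another elementwise; a tiny illustration with made-up values (1958 standing in for a vocabulary pad token; recent PyTorch accepts a Python scalar as the second argument of torch.where):

t = torch.tensor([[3, 7, -1, -1]])
pad_num_replace(t, -1, 1958)   # -> tensor([[   3,    7, 1958, 1958]])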
--------------------------------------------------------------------------------
/src/MICRON.py:
--------------------------------------------------------------------------------
1 | from types import new_class
2 | import dill
3 | import numpy as np
4 | import argparse
5 | from collections import defaultdict
6 | from sklearn.metrics import jaccard_score, roc_curve
7 | from torch.optim import Adam, RMSprop
8 | import os
9 | import torch
10 | import time
11 | import math
12 | from models import MICRON
13 | from util import llprint, multi_label_metric, ddi_rate_score, get_n_params
14 | import torch.nn.functional as F
15 |
16 | # torch.set_num_threads(30)
17 | torch.manual_seed(1203)
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
19 |
20 | # setting
21 | model_name = 'MICRON_pad_lr1e-4'
22 | resume_path = './Epoch_39_JA_0.5209_DDI_0.06952.model'
23 | # resume_path = './{}_Epoch_39_JA_0.5209_DDI_0.06952.model'.format(model_name)
24 |
25 | if not os.path.exists(os.path.join("saved", model_name)):
26 | os.makedirs(os.path.join("saved", model_name))
27 |
28 | # Training settings
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--Test', action='store_true', default=False, help="test mode")
31 | parser.add_argument('--model_name', type=str, default=model_name, help="model name")
32 | parser.add_argument('--resume_path', type=str, default=resume_path, help='resume path')
33 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate') # original 0.0002
34 | parser.add_argument('--weight_decay', type=float, default=1e-5, help='weight decay')
35 | parser.add_argument('--dim', type=int, default=64, help='dimension')
36 |
37 | args = parser.parse_args()
38 |
39 | # evaluate
40 | def eval(model, data_eval, voc_size, epoch, val=0, threshold1=0.8, threshold2=0.2):
41 | model.eval()
42 |
43 | smm_record = []
44 | ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
45 | med_cnt, visit_cnt = 0, 0
46 | label_list, prob_list = [], []
47 | add_list, delete_list = [], []
48 | # per-visit metric statistics
49 | ja_by_visit = [[] for _ in range(5)]
50 | prauc_by_visit = [[] for _ in range(5)]
51 | pre_by_visit = [[] for _ in range(5)]
52 | recall_by_visit = [[] for _ in range(5)]
53 | f1_by_visit = [[] for _ in range(5)]
54 | smm_record_by_visit = [[] for _ in range(5)]
55 |
56 | for step, input in enumerate(data_eval):
57 | y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
58 | add_temp_list, delete_temp_list = [], []
59 | if len(input) < 2: continue
60 | for adm_idx, adm in enumerate(input):
61 | # visit 0 must also be added to the results
62 | y_gt_tmp = np.zeros(voc_size[2])
63 | y_gt_tmp[adm[2]] = 1
64 | y_gt.append(y_gt_tmp)
65 | label_list.append(y_gt_tmp)
66 |
67 | if adm_idx == 0:
68 | representation_base, _, _, _, _ = model(input[:adm_idx+1])
69 | # add visit 0 as well
70 | y_pred_tmp = F.sigmoid(representation_base).detach().cpu().numpy()[0]
71 | y_pred_prob.append(y_pred_tmp)
72 | prob_list.append(y_pred_tmp)
73 |
74 | y_old = np.zeros(voc_size[2])
75 | y_old[y_pred_tmp>=threshold1] = 1
76 | y_old[y_pred_tmp<threshold2] = 0
# ... (lines 77-106 are missing from this dump: the adm_idx > 0 branch, which ends with the same thresholding) ...
107 | y_old[y_pred_tmp>=threshold1] = 1
108 | y_old[y_pred_tmp<threshold2] = 0
# ... (lines 109-139 are missing from this dump) ...
140 | if len(add_temp_list) > 1:
141 | add_list.append(np.mean(add_temp_list))
142 | delete_list.append(np.mean(delete_temp_list))
143 | elif len(add_temp_list) == 1:
144 | add_list.append(add_temp_list[0])
145 | delete_list.append(delete_temp_list[0])
146 |
147 | smm_record.append(y_pred_label)
148 | adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))
149 |
150 | ja.append(adm_ja)
151 | prauc.append(adm_prauc)
152 | avg_p.append(adm_avg_p)
153 | avg_r.append(adm_avg_r)
154 | avg_f1.append(adm_avg_f1)
155 | llprint('\rtest step: {} / {}'.format(step, len(data_eval)))
156 |
157 | # analyze the results per visit
158 | print('\tvisit1\tvisit2\tvisit3\tvisit4\tvisit5')
159 | print('count:', [len(buf) for buf in ja_by_visit])
160 | print('prauc:', [np.mean(buf) for buf in prauc_by_visit])
161 | print('jaccard:', [np.mean(buf) for buf in ja_by_visit])
162 | print('precision:', [np.mean(buf) for buf in pre_by_visit])
163 | print('recall:', [np.mean(buf) for buf in recall_by_visit])
164 | print('f1:', [np.mean(buf) for buf in f1_by_visit])
165 | print('DDI:', [ddi_rate_score(buf) for buf in smm_record_by_visit])
166 |
167 | # ddi rate
168 | ddi_rate = ddi_rate_score(smm_record, path='../data/ddi_A_final.pkl')
169 |
170 | llprint('\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'.format(
171 | ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_f1), np.mean(add_list), np.mean(delete_list), med_cnt / visit_cnt
172 | ))
173 |
174 | if val == 0:
175 | return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), np.mean(add_list), np.mean(delete_list), med_cnt / visit_cnt
176 | else:
177 | return np.array(label_list), np.array(prob_list)
178 |
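The thresholding above acts as a hysteresis band: a drug is switched on once its probability reaches threshold1 and only switched off once it drops below threshold2, so predictions between the two thresholds inherit the previous visit's decision. A condensed sketch of just that step (names mirror the function; the wrapper itself is illustrative):

import numpy as np

def hysteresis_update(y_old, y_pred_tmp, threshold1=0.8, threshold2=0.2):
    y_new = y_old.copy()
    y_new[y_pred_tmp >= threshold1] = 1   # confident enough: keep / add the drug
    y_new[y_pred_tmp < threshold2] = 0    # confident enough: drop the drug
    return y_new                          # values in between keep the old decision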
179 |
180 | def main():
181 |
182 | # load data
183 | data_path = '../data/records_final.pkl'
184 | voc_path = '../data/voc_final.pkl'
185 |
186 | ddi_adj_path = '../data/ddi_A_final.pkl'
187 | device = torch.device('cuda')
188 |
189 | ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
190 | data = dill.load(open(data_path, 'rb'))
191 |
192 | voc = dill.load(open(voc_path, 'rb'))
193 | diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']
194 |
195 | # np.random.seed(1203)
196 | # np.random.shuffle(data)
197 |
198 | # "添加第一个visit"
199 | # new_data = []
200 | # for patient in data:
201 | # patient.insert(0, [[],[],[]])
202 | # # patient.insert(0, patient[0])
203 | # new_data.append(patient)
204 | # data = new_data
205 |
206 | split_point = int(len(data) * 2 / 3)
207 | data_train = data[:split_point]
208 | eval_len = int(len(data[split_point:]) / 2)
209 | data_test = data[split_point:split_point + eval_len]
210 | data_eval = data[split_point+eval_len:]
211 |
212 | voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))
213 |
214 | model = MICRON(voc_size, ddi_adj, emb_dim=args.dim, device=device)
215 | # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
216 |
217 | if args.Test:
218 | model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
219 | model.to(device=device)
220 | tic = time.time()
221 | label_list, prob_list = eval(model, data_eval, voc_size, 0, 1)
222 |
223 | threshold1, threshold2 = [], []
224 | for i in range(label_list.shape[1]):
225 | _, _, boundary = roc_curve(label_list[:, i], prob_list[:, i], pos_label=1)
226 | # boundary1 should be in [0.5, 0.9], boundary2 should be in [0.1, 0.5]
227 | threshold1.append(min(0.9, max(0.5, boundary[max(0, round(len(boundary) * 0.05) - 1)])))
228 | threshold2.append(max(0.1, min(0.5, boundary[min(round(len(boundary) * 0.95), len(boundary) - 1)])))
229 | print (np.mean(threshold1), np.mean(threshold2))
230 | threshold1 = np.ones(voc_size[2]) * np.mean(threshold1)
231 | threshold2 = np.ones(voc_size[2]) * np.mean(threshold2)
232 | eval(model, data_test, voc_size, 0, 0, threshold1, threshold2)
233 | print ('test time: {}'.format(time.time() - tic))
234 |
235 | result = []
236 | for _ in range(10):
237 | test_sample = np.random.choice(data_test, round(len(data_test) * 0.8), replace=True)
238 | ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_add, avg_del, avg_med = eval(model, test_sample, voc_size, 0)
239 | result.append([ddi_rate, ja, avg_f1, prauc, avg_med])
240 |
241 | result = np.array(result)
242 | mean = result.mean(axis=0)
243 | std = result.std(axis=0)
244 |
245 | outstring = ""
246 | for m, s in zip(mean, std):
247 | outstring += r"{:.4f} $\pm$ {:.4f} & ".format(m, s)
248 |
249 | print (outstring)
250 |
251 | print ('test time: {}'.format(time.time() - tic))
252 | return
253 |
254 | model.to(device=device)
255 | print('parameters', get_n_params(model))
256 | # exit()
257 | optimizer = RMSprop(list(model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
258 |
259 | # start iterations
260 | history = defaultdict(list)
261 | best_epoch, best_ja = 0, 0
262 |
263 | weight_list = [[0.25, 0.25, 0.25, 0.25]]
264 |
265 | EPOCH = 40
266 | for epoch in range(EPOCH):
267 | t = 0
268 | tic = time.time()
269 | print ('\nepoch {} --------------------------'.format(epoch + 1))
270 |
271 | sample_counter = 0
272 | mean_loss = np.array([0, 0, 0, 0])
273 |
274 | model.train()
275 | for step, input in enumerate(data_train):
276 | loss = 0
277 | if len(input) < 2: continue
278 | for adm_idx, adm in enumerate(input):
279 | """第一个visit也参与训练"""
280 | # if adm_idx == 0: continue
281 | # sample_counter += 1
282 | seq_input = input[:adm_idx+1]
283 |
284 | loss_bce_target = np.zeros((1, voc_size[2]))
285 | loss_bce_target[:, adm[2]] = 1
286 |
287 | loss_bce_target_last = np.zeros((1, voc_size[2]))
288 | if adm_idx > 0:
289 | loss_bce_target_last[:, input[adm_idx-1][2]] = 1
290 |
291 | loss_multi_target = np.full((1, voc_size[2]), -1)
292 | for idx, item in enumerate(adm[2]):
293 | loss_multi_target[0][idx] = item
294 |
295 | loss_multi_target_last = np.full((1, voc_size[2]), -1)
296 | if adm_idx > 0:
297 | for idx, item in enumerate(input[adm_idx-1][2]):
298 | loss_multi_target_last[0][idx] = item
299 |
300 | result, result_last, _, loss_ddi, loss_rec = model(seq_input)
301 |
302 | loss_bce = 0.75 * F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device)) + \
303 | (1 - 0.75) * F.binary_cross_entropy_with_logits(result_last, torch.FloatTensor(loss_bce_target_last).to(device))
304 | loss_multi = 5e-2 * (0.75 * F.multilabel_margin_loss(F.sigmoid(result), torch.LongTensor(loss_multi_target).to(device)) + \
305 | (1 - 0.75) * F.multilabel_margin_loss(F.sigmoid(result_last), torch.LongTensor(loss_multi_target_last).to(device)))
306 |
307 | y_pred_tmp = F.sigmoid(result).detach().cpu().numpy()[0]
308 | y_pred_tmp[y_pred_tmp >= 0.5] = 1
309 | y_pred_tmp[y_pred_tmp < 0.5] = 0
310 | y_label = np.where(y_pred_tmp == 1)[0]
311 | current_ddi_rate = ddi_rate_score([[y_label]], path='../data/ddi_A_final.pkl')
312 |
313 | # l2 = 0
314 | # for p in model.parameters():
315 | # l2 = l2 + (p ** 2).sum()
316 |
317 | if sample_counter == 0:
318 | lambda1, lambda2, lambda3, lambda4 = weight_list[-1]
319 | else:
320 | current_loss = np.array([loss_bce.detach().cpu().numpy(), loss_multi.detach().cpu().numpy(), loss_ddi.detach().cpu().numpy(), loss_rec.detach().cpu().numpy()])
321 | current_ratio = (current_loss - np.array(mean_loss)) / np.array(mean_loss)
322 | instant_weight = np.exp(current_ratio) / sum(np.exp(current_ratio))
323 | lambda1, lambda2, lambda3, lambda4 = instant_weight * 0.75 + np.array(weight_list[-1]) * 0.25
324 | # update weight_list
325 | weight_list.append([lambda1, lambda2, lambda3, lambda4])
326 | # update mean_loss
327 | mean_loss = (mean_loss * (sample_counter - 1) + np.array([loss_bce.detach().cpu().numpy(), \
328 | loss_multi.detach().cpu().numpy(), loss_ddi.detach().cpu().numpy(), loss_rec.detach().cpu().numpy()])) / sample_counter
329 | # lambda1, lambda2, lambda3, lambda4 = weight_list[-1]
330 | if current_ddi_rate > 0.08:
331 | loss += lambda1 * loss_bce + lambda2 * loss_multi + \
332 | lambda3 * loss_ddi + lambda4 * loss_rec
333 | else:
334 | loss += lambda1 * loss_bce + lambda2 * loss_multi + \
335 | lambda4 * loss_rec
336 |
337 | optimizer.zero_grad()
338 | loss.backward(retain_graph=True)
339 | optimizer.step()
340 |
341 | llprint('\rtraining step: {} / {}'.format(step, len(data_train)))
342 |
343 | print()
344 | tic2 = time.time()
345 | ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete, avg_med = eval(model, data_eval, voc_size, epoch)
346 | print ('training time: {}, test time: {}'.format(time.time() - tic, time.time() - tic2))
347 |
348 | history['ja'].append(ja)
349 | history['ddi_rate'].append(ddi_rate)
350 | history['avg_p'].append(avg_p)
351 | history['avg_r'].append(avg_r)
352 | history['avg_f1'].append(avg_f1)
353 | history['prauc'].append(prauc)
354 | history['add'].append(add)
355 | history['delete'].append(delete)
356 | history['med'].append(avg_med)
357 |
358 | if epoch >= 5:
359 | print ('ddi: {}, Med: {}, Ja: {}, F1: {}, Add: {}, Delete: {}'.format(
360 | np.mean(history['ddi_rate'][-5:]),
361 | np.mean(history['med'][-5:]),
362 | np.mean(history['ja'][-5:]),
363 | np.mean(history['avg_f1'][-5:]),
364 | np.mean(history['add'][-5:]),
365 | np.mean(history['delete'][-5:])
366 | ))
367 |
368 | torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
369 | 'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb'))
370 |
371 | if epoch != 0 and best_ja < ja:
372 | best_epoch = epoch
373 | best_ja = ja
374 |
375 | print ('best_epoch: {}'.format(best_epoch))
376 |
377 | dill.dump(history, open(os.path.join('saved', args.model_name, 'history_{}.pkl'.format(args.model_name)), 'wb'))
378 |
379 | if __name__ == '__main__':
380 | main()
381 |
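The dynamic loss weighting in the training loop above turns each loss's relative deviation from its running mean into a softmax weight, then smooths it against the previous weights. A condensed sketch of that update, assuming `losses` and `mean_loss` are 4-element arrays (the helper name is illustrative):

import numpy as np

def rebalance_weights(losses, mean_loss, prev_weights, mix=0.75):
    ratio = (losses - mean_loss) / mean_loss            # relative deviation per loss term
    instant = np.exp(ratio) / np.exp(ratio).sum()       # softmax over the deviations
    return instant * mix + np.asarray(prev_weights) * (1 - mix)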
--------------------------------------------------------------------------------
/src/recommend.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | import torch.nn as nn
4 | import argparse
5 | from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
6 | import numpy as np
7 | import dill
8 | import time
9 | from torch.nn import CrossEntropyLoss
10 | from torch.optim import Adam
11 | from torch.utils import data
12 | from loss import cross_entropy_loss
13 | import os
14 | import torch.nn.functional as F
15 | import random
16 | from collections import defaultdict
17 |
18 | from torch.utils.data.dataloader import DataLoader
19 | from data_loader import mimic_data, pad_num_replace
20 | from beam import Beam
21 |
22 | import sys
23 | sys.path.append("..")
24 | from models import Leap, CopyDrug_batch, CopyDrug_tranformer, CopyDrug_generate_prob, CopyDrug_diag_proc_encode
25 | from COGNet_model import COGNet
26 | from util import llprint, sequence_metric, sequence_metric_v2, sequence_output_process, ddi_rate_score, get_n_params, output_flatten, print_result
27 |
28 | torch.manual_seed(1203)
29 |
30 | # load the English names of diseases and procedures
31 | icd_diag_path = '../data/D_ICD_DIAGNOSES.csv'
32 | icd_proc_path = '../data/D_ICD_PROCEDURES.csv'
33 | code2diag = {}
34 | code2proc = {}
35 | with open(icd_diag_path, 'r') as f:
36 | lines = f.readlines()[1:]
37 | for line in lines:
38 | line = line.strip().split(',"')
39 | if line[-1] == '': line = line[:-1]
40 | _, icd_code, _, title = line
41 | code2diag[icd_code[:-1]] = title
42 |
43 | with open(icd_proc_path, 'r') as f:
44 | lines = f.readlines()[1:]
45 | for line in lines:
46 | _, icd_code, _, title = line.strip().split(',"')
47 | code2proc[icd_code[:-1]] = title
48 |
49 |
50 |
51 | def eval_recommend_batch(model, batch_data, device, TOKENS, args):
52 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS
53 |
54 | diseases, procedures, medications, seq_length, \
55 | d_length_matrix, p_length_matrix, m_length_matrix, \
56 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
57 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \
58 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = batch_data
59 | # continue
60 | # replace padding values according to the vocab
61 | diseases = pad_num_replace(diseases, -1, DIAG_PAD_TOKEN).to(device)
62 | procedures = pad_num_replace(procedures, -1, PROC_PAD_TOKEN).to(device)
63 | dec_disease = pad_num_replace(dec_disease, -1, DIAG_PAD_TOKEN).to(device)
64 | stay_disease = pad_num_replace(stay_disease, -1, DIAG_PAD_TOKEN).to(device)
65 | dec_proc = pad_num_replace(dec_proc, -1, PROC_PAD_TOKEN).to(device)
66 | stay_proc = pad_num_replace(stay_proc, -1, PROC_PAD_TOKEN).to(device)
67 | medications = medications.to(device)
68 | m_mask_matrix = m_mask_matrix.to(device)
69 | d_mask_matrix = d_mask_matrix.to(device)
70 | p_mask_matrix = p_mask_matrix.to(device)
71 | dec_disease_mask = dec_disease_mask.to(device)
72 | stay_disease_mask = stay_disease_mask.to(device)
73 | dec_proc_mask = dec_proc_mask.to(device)
74 | stay_proc_mask = stay_proc_mask.to(device)
75 |
76 | batch_size = medications.size(0)
77 | max_visit_num = medications.size(1)
78 |
79 | input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory = model.encode(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix,
80 | seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20)
81 |
82 | partial_input_medication = torch.full((batch_size, max_visit_num, 1), SOS_TOKEN).to(device)
83 | parital_logits = None
84 |
85 |
86 | for i in range(args.max_len):
87 | partial_input_med_num = partial_input_medication.size(2)
88 | partial_m_mask_matrix = torch.zeros((batch_size, max_visit_num, partial_input_med_num), device=device).float()
89 | # print('val', i, partial_m_mask_matrix.size())
90 |
91 | parital_logits = model.decode(partial_input_medication, input_disease_embdding, input_proc_embedding, encoded_medication, last_seq_medication, cross_visit_scores,
92 | d_mask_matrix, p_mask_matrix, partial_m_mask_matrix, last_m_mask, drug_memory)
93 | _, next_medication = torch.topk(parital_logits[:, :, -1, :], 1, dim=-1)
94 | partial_input_medication = torch.cat([partial_input_medication, next_medication], dim=-1)
95 |
96 | return parital_logits
97 |
98 |
99 |
100 | def test_recommend_batch(model, batch_data, device, TOKENS, ddi_adj, args):
101 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS
102 |
103 | diseases, procedures, medications, seq_length, \
104 | d_length_matrix, p_length_matrix, m_length_matrix, \
105 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
106 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \
107 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = batch_data
108 | # continue
109 | # replace padding values according to the vocab
110 | diseases = pad_num_replace(diseases, -1, DIAG_PAD_TOKEN).to(device)
111 | procedures = pad_num_replace(procedures, -1, PROC_PAD_TOKEN).to(device)
112 | dec_disease = pad_num_replace(dec_disease, -1, DIAG_PAD_TOKEN).to(device)
113 | stay_disease = pad_num_replace(stay_disease, -1, DIAG_PAD_TOKEN).to(device)
114 | dec_proc = pad_num_replace(dec_proc, -1, PROC_PAD_TOKEN).to(device)
115 | stay_proc = pad_num_replace(stay_proc, -1, PROC_PAD_TOKEN).to(device)
116 | medications = medications.to(device)
117 | m_mask_matrix = m_mask_matrix.to(device)
118 | d_mask_matrix = d_mask_matrix.to(device)
119 | p_mask_matrix = p_mask_matrix.to(device)
120 | dec_disease_mask = dec_disease_mask.to(device)
121 | stay_disease_mask = stay_disease_mask.to(device)
122 | dec_proc_mask = dec_proc_mask.to(device)
123 | stay_proc_mask = stay_proc_mask.to(device)
124 |
125 | batch_size = medications.size(0)
126 | visit_num = medications.size(1)
127 |
128 | input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory = model.encode(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix,
129 | seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20)
130 |
131 | # partial_input_medication = torch.full((batch_size, visit_num, 1), SOS_TOKEN).to(device)
132 | # parital_logits = None
133 |
134 | # declare one beam per sample
135 | # for implementation simplicity, batch_size is hard-coded to 1 here
136 | assert batch_size == 1
137 | # visit_num of these, one per visit
138 | beams = [Beam(args.beam_size, MED_PAD_TOKEN, SOS_TOKEN, END_TOKEN, ddi_adj, device) for _ in range(visit_num)]
139 |
140 | # build the decoder input; each visit's data must be repeated beam_size times
141 | input_disease_embdding = input_disease_embdding.repeat_interleave(args.beam_size, dim=0)
142 | input_proc_embedding = input_proc_embedding.repeat_interleave(args.beam_size, dim=0)
143 | encoded_medication = encoded_medication.repeat_interleave(args.beam_size, dim=0)
144 | last_seq_medication = last_seq_medication.repeat_interleave(args.beam_size, dim=0)
145 | cross_visit_scores = cross_visit_scores.repeat_interleave(args.beam_size, dim=0)
146 | # cross_visit_scores = cross_visit_scores.repeat_interleave(args.beam_size, dim=2)
147 | d_mask_matrix = d_mask_matrix.repeat_interleave(args.beam_size, dim=0)
148 | p_mask_matrix = p_mask_matrix.repeat_interleave(args.beam_size, dim=0)
149 | last_m_mask = last_m_mask.repeat_interleave(args.beam_size, dim=0)
150 |
151 | for i in range(args.max_len):
152 | len_dec_seq = i + 1
153 | # b.get_current_state(): (beam_size, len_dec_seq) --> (beam_size, 1, len_dec_seq)
154 | # dec_partial_inputs: (beam_size, visit_num, len_dec_seq)
155 | dec_partial_inputs = torch.cat([b.get_current_state().unsqueeze(dim=1) for b in beams], dim=1)
156 | # dec_partial_inputs = dec_partial_inputs.view(args.beam_size, visit_num, len_dec_seq)
157 |
158 | partial_m_mask_matrix = torch.zeros((args.beam_size, visit_num, len_dec_seq), device=device).float()
159 | # print('val', i, partial_m_mask_matrix.size())
160 |
161 | # parital_logits: (beam_size, visit_sum, len_dec_seq, all_med_num)
162 | parital_logits = model.decode(dec_partial_inputs, input_disease_embdding, input_proc_embedding, encoded_medication, last_seq_medication, cross_visit_scores,
163 | d_mask_matrix, p_mask_matrix, partial_m_mask_matrix, last_m_mask, drug_memory)
164 |
165 | # word_lk: (beam_size, visit_sum, all_med_num)
166 | word_lk = parital_logits[:, :, -1, :]
167 |
168 | active_beam_idx_list = [] # records the beams that are still active
169 | for beam_idx in range(visit_num):
170 | # # if the current beam is already done, skip it; the number of beams stays unchanged
171 | # if beams[beam_idx].done: continue
172 | # inst_idx = beam_inst_idx_map[beam_idx] # adm index this beam corresponds to
173 | # advance the beam; returns whether it is done (not done means still active)
174 | if not beams[beam_idx].advance(word_lk[:, beam_idx, :]):
175 | active_beam_idx_list.append(beam_idx)
176 |
177 | # if no beam is active, all samples have finished predicting
178 | if not active_beam_idx_list: break
179 |
180 | # Return useful information
181 | all_hyp = []
182 | all_prob = []
183 | for beam_idx in range(visit_num):
184 | scores, tail_idxs = beams[beam_idx].sort_scores() # sort each beam's hypotheses by score to find the best generation
185 | hyps = beams[beam_idx].get_hypothesis(tail_idxs[0])
186 | probs = beams[beam_idx].get_prob_list(tail_idxs[0])
187 | all_hyp += [hyps] # note: only the best hypothesis is kept; change this if more are needed
188 | all_prob += [probs]
189 |
190 | return all_hyp, all_prob
191 |
192 |
193 | # evaluate
194 | def eval(model, eval_dataloader, voc_size, epoch, device, TOKENS, args):
195 | model.eval()
196 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS
197 | ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
198 | smm_record = []
199 | med_cnt, visit_cnt = 0, 0
200 |
201 | # fw = open("prediction_results.txt", "w")
202 |
203 | for idx, data in enumerate(eval_dataloader):
204 | diseases, procedures, medications, seq_length, \
205 | d_length_matrix, p_length_matrix, m_length_matrix, \
206 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
207 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \
208 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = data
209 | visit_cnt += seq_length.sum().item()
210 |
211 | output_logits = eval_recommend_batch(model, data, device, TOKENS, args)
212 |
213 | # prediction results at each med position
214 | labels, predictions = output_flatten(medications, output_logits, seq_length, m_length_matrix, voc_size[2], END_TOKEN, device, training=False, testing=False, max_len=args.max_len)
215 |
216 | y_gt = [] # ground truth: 0-1 vector of the correct labels
217 | y_pred = [] # predicted result as a 0-1 vector
218 | y_pred_prob = [] # average predicted probability per drug (not a 0-1 vector)
219 | y_pred_label = [] # predicted label indices (not a 0-1 vector)
220 | # prediction results for each admission
221 | for label, prediction in zip(labels, predictions):
222 | y_gt_tmp = np.zeros(voc_size[2])
223 | y_gt_tmp[label] = 1 # 0-1 vector marking the correct labels
224 | y_gt.append(y_gt_tmp)
225 |
226 | # label: med set
227 | # prediction: [med_num, probability]
228 | out_list, sorted_predict = sequence_output_process(prediction, [voc_size[2], voc_size[2]+1])
229 | y_pred_label.append(sorted(sorted_predict))
230 | y_pred_prob.append(np.mean(prediction[:, :-2], axis=0))
231 |
232 | # prediction label
233 | y_pred_tmp = np.zeros(voc_size[2])
234 | y_pred_tmp[out_list] = 1
235 | y_pred.append(y_pred_tmp)
236 | med_cnt += len(sorted_predict)
237 |
238 | # if idx < 100:
239 | # fw.write(print_result(label, sorted_predict))
240 |
241 | smm_record.append(y_pred_label)
242 |
243 | adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = \
244 | sequence_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob), np.array(y_pred_label))
245 | ja.append(adm_ja)
246 | prauc.append(adm_prauc)
247 | avg_p.append(adm_avg_p)
248 | avg_r.append(adm_avg_r)
249 | avg_f1.append(adm_avg_f1)
250 | llprint('\rtest step: {} / {}'.format(idx, len(eval_dataloader)))
251 |
252 | # fw.close()
253 |
254 | # ddi rate
255 | ddi_rate = ddi_rate_score(smm_record, path='../data/ddi_A_final.pkl')
256 |
257 | llprint('\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
258 | ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt
259 | ))
260 |
261 | return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt
262 |
263 |
264 | # test
265 | def test(model, test_dataloader, diag_voc, pro_voc, med_voc, voc_size, epoch, device, TOKENS, ddi_adj, args):
266 | model.eval()
267 | END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN = TOKENS
268 | ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
269 | med_cnt_list = []
270 | smm_record = []
271 | med_cnt, visit_cnt = 0, 0
272 | all_pred_list = []
273 | all_label_list = []
274 |
275 | ja_by_visit = [[] for _ in range(5)]
276 | auc_by_visit = [[] for _ in range(5)]
277 | pre_by_visit = [[] for _ in range(5)]
278 | recall_by_visit = [[] for _ in range(5)]
279 | f1_by_visit = [[] for _ in range(5)]
280 | smm_record_by_visit = [[] for _ in range(5)]
281 |
282 | for idx, data in enumerate(test_dataloader):
283 | diseases, procedures, medications, seq_length, \
284 | d_length_matrix, p_length_matrix, m_length_matrix, \
285 | d_mask_matrix, p_mask_matrix, m_mask_matrix, \
286 | dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, \
287 | dec_proc, stay_proc, dec_proc_mask, stay_proc_mask = data
288 | visit_cnt += seq_length.sum().item()
289 |
290 | output_logits, output_probs = test_recommend_batch(model, data, device, TOKENS, ddi_adj, args)
291 |
292 | labels, predictions = output_flatten(medications, output_logits, seq_length, m_length_matrix, voc_size[2], END_TOKEN, device, training=False, testing=True, max_len=args.max_len)
293 | _, probs = output_flatten(medications, output_probs, seq_length, m_length_matrix, voc_size[2], END_TOKEN, device, training=False, testing=True, max_len=args.max_len)
294 | y_gt = []
295 | y_pred = []
296 | y_pred_label = []
297 | y_pred_prob = []
298 |
299 | label_hisory = []
300 | label_hisory_list = []
301 | pred_list = []
302 | jaccard_list = []
303 | def cal_jaccard(set1, set2):
304 | if not set1 or not set2:
305 | return 0
306 | set1 = set(set1)
307 | set2 = set(set2)
308 | a, b = len(set1 & set2), len(set1 | set2)
309 | return a/b
310 | def cal_overlap_num(set1, set2):
311 | count = 0
312 | for d in set1:
313 | if d in set2:
314 | count += 1
315 | return count
316 |
317 | # prediction results for each admission
318 | for label, prediction, prob_list in zip(labels, predictions, probs):
319 | label_hisory += label.tolist() ### case study
320 |
321 | y_gt_tmp = np.zeros(voc_size[2])
322 | y_gt_tmp[label] = 1 # 0-1 vector marking the correct labels
323 | y_gt.append(y_gt_tmp)
324 |
325 | out_list = []
326 | out_prob_list = []
327 | for med, prob in zip(prediction, prob_list):
328 | if med in [voc_size[2], voc_size[2]+1]:
329 | break
330 | out_list.append(med)
331 | out_prob_list.append(prob[:-2]) # drop the SOS and EOS tokens
332 |
333 | ## case study
334 | if label_hisory:
335 | jaccard_list.append(cal_jaccard(prediction, label_hisory))
336 | pred_list.append(out_list)
337 | label_hisory_list.append(label.tolist())
338 |
339 | # for drugs that were not predicted, take the probability aggregated over positions (max here); otherwise take the probability at the drug's own position
340 | # pred_out_prob_list = np.mean(out_prob_list, axis=0)
341 | pred_out_prob_list = np.max(out_prob_list, axis=0)
342 | # pred_out_prob_list = np.min(out_prob_list, axis=0)
343 | for i in range(voc_size[2]): # med vocabulary size (131 for this dataset)
344 | if i in out_list:
345 | pred_out_prob_list[i] = out_prob_list[out_list.index(i)][i]
346 |
347 | y_pred_prob.append(pred_out_prob_list)
348 | y_pred_label.append(out_list)
349 |
350 | # prediction label
351 | y_pred_tmp = np.zeros(voc_size[2])
352 | y_pred_tmp[out_list] = 1
353 | y_pred.append(y_pred_tmp)
354 | med_cnt += len(prediction)
355 | med_cnt_list.append(len(prediction))
356 |
357 |
358 | smm_record.append(y_pred_label)
359 | for i in range(min(len(labels), 5)):
360 | # single_ja, single_p, single_r, single_f1 = sequence_metric_v2(np.array(y_gt[i:i+1]), np.array(y_pred[i:i+1]), np.array(y_pred_label[i:i+1]))
361 | single_ja, single_auc, single_p, single_r, single_f1 = sequence_metric(np.array([y_gt[i]]), np.array([y_pred[i]]), np.array([y_pred_prob[i]]),np.array([y_pred_label[i]]))
362 | ja_by_visit[i].append(single_ja)
363 | auc_by_visit[i].append(single_auc)
364 | pre_by_visit[i].append(single_p)
365 | recall_by_visit[i].append(single_r)
366 | f1_by_visit[i].append(single_f1)
367 | smm_record_by_visit[i].append(y_pred_label[i:i+1])
368 |
369 | # store all prediction results
370 | all_pred_list.append(pred_list)
371 | all_label_list.append(labels)
372 | adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = \
373 | sequence_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob), np.array(y_pred_label))
374 | ja.append(adm_ja)
375 | prauc.append(adm_prauc)
376 | avg_p.append(adm_avg_p)
377 | avg_r.append(adm_avg_r)
378 | avg_f1.append(adm_avg_f1)
379 | llprint('\rtest step: {} / {}'.format(idx, len(test_dataloader)))
380 |
381 | # per-visit metric statistics
382 | if idx%100==0:
383 | print('\tvisit1\tvisit2\tvisit3\tvisit4\tvisit5')
384 | print('count:', [len(buf) for buf in ja_by_visit])
385 | print('jaccard:', [np.mean(buf) for buf in ja_by_visit])
386 | print('auc:', [np.mean(buf) for buf in auc_by_visit])
387 | print('precision:', [np.mean(buf) for buf in pre_by_visit])
388 | print('recall:', [np.mean(buf) for buf in recall_by_visit])
389 | print('f1:', [np.mean(buf) for buf in f1_by_visit])
390 | print('DDI:', [ddi_rate_score(buf) for buf in smm_record_by_visit])
391 |
392 | print('\tvisit1\tvisit2\tvisit3\tvisit4\tvisit5')
393 | print('count:', [len(buf) for buf in ja_by_visit])
394 | print('jaccard:', [np.mean(buf) for buf in ja_by_visit])
395 | print('auc:', [np.mean(buf) for buf in auc_by_visit])
396 | print('precision:', [np.mean(buf) for buf in pre_by_visit])
397 | print('recall:', [np.mean(buf) for buf in recall_by_visit])
398 | print('f1:', [np.mean(buf) for buf in f1_by_visit])
399 | print('DDI:', [ddi_rate_score(buf) for buf in smm_record_by_visit])
400 |
401 | pickle.dump(all_pred_list, open('out_list.pkl', 'wb'))
402 | pickle.dump(all_label_list, open('out_list_gt.pkl', 'wb'))
403 |
404 | return smm_record, ja, prauc, avg_p, avg_r, avg_f1, med_cnt_list
405 |
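A hedged sketch of how the test entry point above is typically driven: beam search asserts batch_size == 1, and from the filter `[voc_size[2], voc_size[2]+1]` used during decoding, SOS and END sit just past the real medication indices. The pad-token values below are assumptions for illustration, not taken from this file:

SOS_TOKEN, END_TOKEN = voc_size[2], voc_size[2] + 1        # inferred from the decode-time filter
MED_PAD_TOKEN = voc_size[2] + 2                            # assumption: one slot past END
DIAG_PAD_TOKEN, PROC_PAD_TOKEN = voc_size[0], voc_size[1]  # assumption: first index past each vocab
TOKENS = (END_TOKEN, DIAG_PAD_TOKEN, PROC_PAD_TOKEN, MED_PAD_TOKEN, SOS_TOKEN)
test_loader = DataLoader(mimic_data(data_test), batch_size=1,  # beam search requires batch_size == 1
                         collate_fn=pad_batch_v2_eval, shuffle=False)
smm_record, ja, prauc, avg_p, avg_r, avg_f1, med_cnt_list = \
    test(model, test_loader, diag_voc, pro_voc, med_voc, voc_size, 0, device, TOKENS, ddi_adj, args)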
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
2 | import numpy as np
3 | import pandas as pd
4 | from sklearn.model_selection import train_test_split
5 | import sys
6 | import warnings
7 | import dill
8 | from collections import Counter
9 | from rdkit import Chem
10 | from collections import defaultdict
11 | import torch
12 | warnings.filterwarnings('ignore')
13 |
14 | def get_n_params(model):
15 | pp=0
16 | for p in list(model.parameters()):
17 | nn=1
18 | for s in list(p.size()):
19 | nn = nn*s
20 | pp += nn
21 | return pp
22 |
23 | # use the same metric from DMNC
24 | def llprint(message):
25 | sys.stdout.write(message)
26 | sys.stdout.flush()
27 |
28 | def transform_split(X, Y):
29 | x_train, x_eval, y_train, y_eval = train_test_split(X, Y, train_size=2/3, random_state=1203)
30 | x_eval, x_test, y_eval, y_test = train_test_split(x_eval, y_eval, test_size=0.5, random_state=1203)
31 | return x_train, x_eval, x_test, y_train, y_eval, y_test
32 |
33 | def sequence_output_process(output_logits, filter_token):
34 | """生成最终正确的序列,output_logits表示每个位置的prob,filter_token代表SOS与END"""
35 | pind = np.argsort(output_logits, axis=-1)[:, ::-1] # 每个位置上按概率的降序排序
36 |
37 | out_list = [] # the generated result
38 | break_flag = False
39 | for i in range(len(pind)):
40 | # walk over all values of pind in order
41 | # break_flag controls exiting the sentence-generation loop
42 | if break_flag:
43 | break
44 | # each position holds results already sorted in descending order
45 | for j in range(pind.shape[1]):
46 | label = pind[i][j]
47 | # hitting SOS or END means the sentence is over
48 | if label in filter_token:
49 | break_flag = True
50 | break
51 | # if the label has not been emitted yet, keep it and continue generating
52 | # otherwise move on to the next most probable drug
53 | if label not in out_list:
54 | out_list.append(label)
55 | break
56 | y_pred_prob_tmp = []
57 | for idx, item in enumerate(out_list):
58 | y_pred_prob_tmp.append(output_logits[idx, item])
59 | # re-sort the drugs in out_list from high to low probability
60 | sorted_predict = [x for _, x in sorted(zip(y_pred_prob_tmp, out_list), reverse=True)]
61 | return out_list, sorted_predict
62 |
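A tiny worked example of `sequence_output_process`, with made-up probabilities over a 5-token vocabulary in which tokens 3 and 4 play the SOS/END roles:

probs = np.array([[0.10, 0.70, 0.10, 0.05, 0.05],   # position 0: top token is 1
                  [0.60, 0.20, 0.10, 0.05, 0.05],   # position 1: top token is 0
                  [0.10, 0.10, 0.10, 0.60, 0.10]])  # position 2: top token is 3 (END) -> stop
out_list, sorted_predict = sequence_output_process(probs, filter_token=[3, 4])
# out_list == [1, 0] (generation order); sorted_predict == [1, 0] (by probability, 0.70 > 0.60)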
63 |
64 | def sequence_metric(y_gt, y_pred, y_prob, y_label):
65 | def average_prc(y_gt, y_label):
66 | score = []
67 | for b in range(y_gt.shape[0]):
68 | target = np.where(y_gt[b]==1)[0]
69 | out_list = y_label[b]
70 | inter = set(out_list) & set(target)
71 | prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
72 | score.append(prc_score)
73 | return score
74 |
75 |
76 | def average_recall(y_gt, y_label):
77 | score = []
78 | for b in range(y_gt.shape[0]):
79 | target = np.where(y_gt[b] == 1)[0]
80 | out_list = y_label[b]
81 | inter = set(out_list) & set(target)
82 | recall_score = 0 if len(target) == 0 else len(inter) / len(target)
83 | score.append(recall_score)
84 | return score
85 |
86 |
87 | def average_f1(average_prc, average_recall):
88 | score = []
89 | for idx in range(len(average_prc)):
90 | if (average_prc[idx] + average_recall[idx]) == 0:
91 | score.append(0)
92 | else:
93 | score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
94 | return score
95 |
96 |
97 | def jaccard(y_gt, y_label):
98 | score = []
99 | for b in range(y_gt.shape[0]):
100 | target = np.where(y_gt[b] == 1)[0]
101 | out_list = y_label[b]
102 | inter = set(out_list) & set(target)
103 | union = set(out_list) | set(target)
104 | jaccard_score = 0 if len(union) == 0 else len(inter) / len(union)
105 | score.append(jaccard_score)
106 | return np.mean(score)
107 |
108 | def f1(y_gt, y_pred):
109 | all_micro = []
110 | for b in range(y_gt.shape[0]):
111 | all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
112 | return np.mean(all_micro)
113 |
114 | def roc_auc(y_gt, y_pred_prob):
115 | all_micro = []
116 | for b in range(len(y_gt)):
117 | all_micro.append(roc_auc_score(y_gt[b], y_pred_prob[b], average='macro'))
118 | return np.mean(all_micro)
119 |
120 | def precision_auc(y_gt, y_prob):
121 | all_micro = []
122 | for b in range(len(y_gt)):
123 | # try:
124 | # all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
125 | # except:
126 | # continue
127 | all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
128 | return np.mean(all_micro)
129 |
130 | def precision_at_k(y_gt, y_prob_label, k):
131 | precision = 0
132 | for i in range(len(y_gt)):
133 | TP = 0
134 | for j in y_prob_label[i][:k]:
135 | if y_gt[i, j] == 1:
136 | TP += 1
137 | precision += TP / k
138 | return precision / len(y_gt)
139 | try:
140 | auc = roc_auc(y_gt, y_prob)
141 | except ValueError:
142 | auc = 0
143 | p_1 = precision_at_k(y_gt, y_label, k=1)
144 | p_3 = precision_at_k(y_gt, y_label, k=3)
145 | p_5 = precision_at_k(y_gt, y_label, k=5)
146 | f1 = f1(y_gt, y_pred)
147 | prauc = precision_auc(y_gt, y_prob)
148 | ja = jaccard(y_gt, y_label)
149 | avg_prc = average_prc(y_gt, y_label)
150 | avg_recall = average_recall(y_gt, y_label)
151 | avg_f1 = average_f1(avg_prc, avg_recall)
152 |
153 | return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)
154 |
155 |
156 | def sequence_metric_v2(y_gt, y_pred, y_label):
157 | def average_prc(y_gt, y_label):
158 | score = []
159 | for b in range(y_gt.shape[0]):
160 | target = np.where(y_gt[b]==1)[0]
161 | out_list = y_label[b]
162 | inter = set(out_list) & set(target)
163 | prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
164 | score.append(prc_score)
165 | return score
166 |
167 |
168 | def average_recall(y_gt, y_label):
169 | score = []
170 | for b in range(y_gt.shape[0]):
171 | target = np.where(y_gt[b] == 1)[0]
172 | out_list = y_label[b]
173 | inter = set(out_list) & set(target)
174 | recall_score = 0 if len(target) == 0 else len(inter) / len(target)
175 | score.append(recall_score)
176 | return score
177 |
178 |
179 | def average_f1(average_prc, average_recall):
180 | score = []
181 | for idx in range(len(average_prc)):
182 | if (average_prc[idx] + average_recall[idx]) == 0:
183 | score.append(0)
184 | else:
185 | score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
186 | return score
187 |
188 |
189 | def jaccard(y_gt, y_label):
190 | score = []
191 | for b in range(y_gt.shape[0]):
192 | target = np.where(y_gt[b] == 1)[0]
193 | out_list = y_label[b]
194 | inter = set(out_list) & set(target)
195 | union = set(out_list) | set(target)
196 | jaccard_score = 0 if len(union) == 0 else len(inter) / len(union)
197 | score.append(jaccard_score)
198 | return np.mean(score)
199 |
200 | def f1(y_gt, y_pred):
201 | all_micro = []
202 | for b in range(y_gt.shape[0]):
203 | all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
204 | return np.mean(all_micro)
205 |
206 | def roc_auc(y_gt, y_pred_prob):
207 | all_micro = []
208 | for b in range(len(y_gt)):
209 | all_micro.append(roc_auc_score(y_gt[b], y_pred_prob[b], average='macro'))
210 | return np.mean(all_micro)
211 |
212 | def precision_auc(y_gt, y_prob):
213 | all_micro = []
214 | for b in range(len(y_gt)):
215 | all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
216 | return np.mean(all_micro)
217 |
218 | def precision_at_k(y_gt, y_prob_label, k):
219 | precision = 0
220 | for i in range(len(y_gt)):
221 | TP = 0
222 | for j in y_prob_label[i][:k]:
223 | if y_gt[i, j] == 1:
224 | TP += 1
225 | precision += TP / k
226 | return precision / len(y_gt)
227 | # try:
228 | # auc = roc_auc(y_gt, y_prob)
229 | # except ValueError:
230 | # auc = 0
231 | # p_1 = precision_at_k(y_gt, y_label, k=1)
232 | # p_3 = precision_at_k(y_gt, y_label, k=3)
233 | # p_5 = precision_at_k(y_gt, y_label, k=5)
234 | f1 = f1(y_gt, y_pred)
235 | # prauc = precision_auc(y_gt, y_prob)
236 | ja = jaccard(y_gt, y_label)
237 | avg_prc = average_prc(y_gt, y_label)
238 | avg_recall = average_recall(y_gt, y_label)
239 | avg_f1 = average_f1(avg_prc, avg_recall)
240 |
241 | return ja, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)
242 |
243 | def multi_label_metric(y_gt, y_pred, y_prob):
244 |
245 | def jaccard(y_gt, y_pred):
246 | score = []
247 | for b in range(y_gt.shape[0]):
248 | target = np.where(y_gt[b] == 1)[0]
249 | out_list = np.where(y_pred[b] == 1)[0]
250 | inter = set(out_list) & set(target)
251 | union = set(out_list) | set(target)
252 | jaccard_score = 0 if len(union) == 0 else len(inter) / len(union)
253 | score.append(jaccard_score)
254 | return np.mean(score)
255 |
256 | def average_prc(y_gt, y_pred):
257 | score = []
258 | for b in range(y_gt.shape[0]):
259 | target = np.where(y_gt[b] == 1)[0]
260 | out_list = np.where(y_pred[b] == 1)[0]
261 | inter = set(out_list) & set(target)
262 | prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
263 | score.append(prc_score)
264 | return score
265 |
266 | def average_recall(y_gt, y_pred):
267 | score = []
268 | for b in range(y_gt.shape[0]):
269 | target = np.where(y_gt[b] == 1)[0]
270 | out_list = np.where(y_pred[b] == 1)[0]
271 | inter = set(out_list) & set(target)
272 | recall_score = 0 if len(target) == 0 else len(inter) / len(target)
273 | score.append(recall_score)
274 | return score
275 |
276 | def average_f1(average_prc, average_recall):
277 | score = []
278 | for idx in range(len(average_prc)):
279 | if average_prc[idx] + average_recall[idx] == 0:
280 | score.append(0)
281 | else:
282 | score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
283 | return score
284 |
285 | def f1(y_gt, y_pred):
286 | all_micro = []
287 | for b in range(y_gt.shape[0]):
288 | all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
289 | return np.mean(all_micro)
290 |
291 | def roc_auc(y_gt, y_prob):
292 | all_micro = []
293 | for b in range(len(y_gt)):
294 | all_micro.append(roc_auc_score(y_gt[b], y_prob[b], average='macro'))
295 | return np.mean(all_micro)
296 |
297 | def precision_auc(y_gt, y_prob):
298 | all_micro = []
299 | for b in range(len(y_gt)):
300 | all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
301 | return np.mean(all_micro)
302 |
303 | def precision_at_k(y_gt, y_prob, k=3):
304 | precision = 0
305 | sort_index = np.argsort(y_prob, axis=-1)[:, ::-1][:, :k]
306 | for i in range(len(y_gt)):
307 | TP = 0
308 | for j in range(len(sort_index[i])):
309 | if y_gt[i, sort_index[i, j]] == 1:
310 | TP += 1
311 | precision += TP / len(sort_index[i])
312 | return precision / len(y_gt)
313 |
314 | # roc_auc
315 | try:
316 | auc = roc_auc(y_gt, y_prob)
317 | except:
318 | auc = 0
319 | # precision
320 | p_1 = precision_at_k(y_gt, y_prob, k=1)
321 | p_3 = precision_at_k(y_gt, y_prob, k=3)
322 | p_5 = precision_at_k(y_gt, y_prob, k=5)
323 | # macro f1
324 | f1 = f1(y_gt, y_pred)
325 | # precision
326 | prauc = precision_auc(y_gt, y_prob)
327 | # jaccard
328 | ja = jaccard(y_gt, y_pred)
329 | # pre, recall, f1
330 | avg_prc = average_prc(y_gt, y_pred)
331 | avg_recall = average_recall(y_gt, y_pred)
332 | avg_f1 = average_f1(avg_prc, avg_recall)
333 |
334 | return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)
335 |
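A quick sanity check for `multi_label_metric` with toy 0-1 arrays (values made up):

y_gt   = np.array([[1, 0, 1, 0]])
y_pred = np.array([[1, 0, 0, 0]])
y_prob = np.array([[0.9, 0.2, 0.4, 0.1]])
ja, prauc, p, r, f1 = multi_label_metric(y_gt, y_pred, y_prob)
# ja == 0.5 (|{0}| / |{0, 2}|), p == 1.0, r == 0.5, f1 == 2/3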
336 | def ddi_rate_score(record, path='../data/ddi_A_final.pkl'):
337 | # ddi rate
338 | ddi_A = dill.load(open(path, 'rb'))
339 | all_cnt = 0
340 | dd_cnt = 0
341 | for patient in record:
342 | for adm in patient:
343 | med_code_set = adm
344 | for i, med_i in enumerate(med_code_set):
345 | for j, med_j in enumerate(med_code_set):
346 | if j <= i:
347 | continue
348 | all_cnt += 1
349 | if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
350 | dd_cnt += 1
351 | if all_cnt == 0:
352 | return 0
353 | return dd_cnt / all_cnt
354 |
355 |
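`ddi_rate_score` walks the nested [patient][visit][med indices] structure and counts, per visit, the fraction of unordered medication pairs that hit an edge in the DDI adjacency matrix; a usage sketch with made-up indices:

record = [[[0, 5, 12], [5, 12]],   # patient 1: two visits
          [[3, 7]]]                # patient 2: one visit
rate = ddi_rate_score(record, path='../data/ddi_A_final.pkl')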
356 | def create_atoms(mol, atom_dict):
357 | """Transform the atom types in a molecule (e.g., H, C, and O)
358 | into the indices (e.g., H=0, C=1, and O=2).
359 | Note that each atom index considers the aromaticity.
360 | """
361 | atoms = [a.GetSymbol() for a in mol.GetAtoms()]
362 | for a in mol.GetAromaticAtoms():
363 | i = a.GetIdx()
364 | atoms[i] = (atoms[i], 'aromatic')
365 | atoms = [atom_dict[a] for a in atoms]
366 | return np.array(atoms)
367 |
368 | def create_ijbonddict(mol, bond_dict):
369 | """Create a dictionary, in which each key is a node ID
370 | and each value is the tuples of its neighboring node
371 | and chemical bond (e.g., single and double) IDs.
372 | """
373 | i_jbond_dict = defaultdict(lambda: [])
374 | for b in mol.GetBonds():
375 | i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
376 | bond = bond_dict[str(b.GetBondType())]
377 | i_jbond_dict[i].append((j, bond))
378 | i_jbond_dict[j].append((i, bond))
379 | return i_jbond_dict
380 |
381 | def extract_fingerprints(radius, atoms, i_jbond_dict,
382 | fingerprint_dict, edge_dict):
383 | """Extract the fingerprints from a molecular graph
384 | based on Weisfeiler-Lehman algorithm.
385 | """
386 |
387 | if (len(atoms) == 1) or (radius == 0):
388 | nodes = [fingerprint_dict[a] for a in atoms]
389 |
390 | else:
391 | nodes = atoms
392 | i_jedge_dict = i_jbond_dict
393 |
394 | for _ in range(radius):
395 |
396 | """Update each node ID considering its neighboring nodes and edges.
397 | The updated node IDs are the fingerprint IDs.
398 | """
399 | nodes_ = []
400 | for i, j_edge in i_jedge_dict.items():
401 | neighbors = [(nodes[j], edge) for j, edge in j_edge]
402 | fingerprint = (nodes[i], tuple(sorted(neighbors)))
403 | nodes_.append(fingerprint_dict[fingerprint])
404 |
405 | """Also update each edge ID considering
406 | its two nodes on both sides.
407 | """
408 | i_jedge_dict_ = defaultdict(lambda: [])
409 | for i, j_edge in i_jedge_dict.items():
410 | for j, edge in j_edge:
411 | both_side = tuple(sorted((nodes[i], nodes[j])))
412 | edge = edge_dict[(both_side, edge)]
413 | i_jedge_dict_[i].append((j, edge))
414 |
415 | nodes = nodes_
416 | i_jedge_dict = i_jedge_dict_
417 |
418 | return np.array(nodes)
419 |
420 |
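Mirroring how `buildMPNN` below drives these helpers, a small end-to-end sketch on a toy molecule (ethanol); the defaultdicts assign fresh integer IDs on first sight:

from collections import defaultdict
from rdkit import Chem

atom_dict = defaultdict(lambda: len(atom_dict))
bond_dict = defaultdict(lambda: len(bond_dict))
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
edge_dict = defaultdict(lambda: len(edge_dict))

mol = Chem.AddHs(Chem.MolFromSmiles('CCO'))          # ethanol, hydrogens made explicit
atoms = create_atoms(mol, atom_dict)                 # per-atom type IDs
i_jbond_dict = create_ijbonddict(mol, bond_dict)     # neighbor/bond adjacency
fps = extract_fingerprints(1, atoms, i_jbond_dict, fingerprint_dict, edge_dict)
# fps: one Weisfeiler-Lehman fingerprint ID per atom after one refinement round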
421 | def buildMPNN(molecule, med_voc, radius=1, device="cpu:0"):
422 |
423 | atom_dict = defaultdict(lambda: len(atom_dict))
424 | bond_dict = defaultdict(lambda: len(bond_dict))
425 | fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
426 | edge_dict = defaultdict(lambda: len(edge_dict))
427 | MPNNSet, average_index = [], []
428 |
429 | print(len(med_voc.items()))
430 | for index, ndc in med_voc.items():
431 |
432 | smilesList = list(molecule[ndc])
433 |
434 | """Create each data with the above defined functions."""
435 | counter = 0 # counter how many drugs are under that ATC-3
436 | for smiles in smilesList:
437 | try:
438 | mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
439 | atoms = create_atoms(mol, atom_dict)
440 | molecular_size = len(atoms)
441 | i_jbond_dict = create_ijbonddict(mol, bond_dict)
442 | fingerprints = extract_fingerprints(radius, atoms, i_jbond_dict,
443 | fingerprint_dict, edge_dict)
444 | adjacency = Chem.GetAdjacencyMatrix(mol)
445 | # if fingerprints.shape[0] == adjacency.shape[0]:
446 | for _ in range(adjacency.shape[0] - fingerprints.shape[0]): # pad fingerprints to match the adjacency size
447 | fingerprints = np.append(fingerprints, 1)
448 | fingerprints = torch.LongTensor(fingerprints).to(device)
449 | adjacency = torch.FloatTensor(adjacency).to(device)
450 | MPNNSet.append((fingerprints, adjacency, molecular_size))
451 | counter += 1
452 | except Exception: # skip SMILES that RDKit cannot parse
453 | continue
454 | average_index.append(counter)
455 |
456 | """Transform the above each data of numpy
457 | to pytorch tensor on a device (i.e., CPU or GPU).
458 | """
459 |
460 | N_fingerprint = len(fingerprint_dict)
461 |
462 | # transform into projection matrix
463 | n_col = sum(average_index)
464 | n_row = len(average_index)
465 |
466 | average_projection = np.zeros((n_row, n_col))
467 | col_counter = 0
468 | for i, item in enumerate(average_index):
469 | average_projection[i, col_counter : col_counter + item] = (1 / item) if item > 0 else 0.0 # guard ATC codes with no parsable SMILES
470 | col_counter += item
471 |
472 | return MPNNSet, N_fingerprint, torch.FloatTensor(average_projection)
473 |
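# --- Editor's sketch (not part of the original file): a purely numeric
# --- illustration of the average_projection built in buildMPNN. With
# --- average_index = [2, 1, 3], row i spreads weight 1/item over that
# --- medication's molecules, so projection @ molecule_embeddings averages them.
def _demo_average_projection():
    average_index = [2, 1, 3]
    proj = np.zeros((len(average_index), sum(average_index)))
    col = 0
    for i, item in enumerate(average_index):
        proj[i, col: col + item] = 1 / item
        col += item
    return proj  # [[.5 .5 0 0 0 0], [0 0 1 0 0 0], [0 0 0 1/3 1/3 1/3]]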
474 |
475 |
476 | def output_flatten(labels, logits, seq_length, m_length_matrix, med_num, END_TOKEN, device, training=True, testing=False, max_len=20):
477 | '''
478 | labels: [batch_size, visit_num, medication_num]
479 | logits: [batch_size, visit_num, max_med_num, medication_vocab_size]
480 | '''
481 | # flatten the multi-dimensional outputs
482 | batch_size, max_seq_length = labels.size()[:2]
483 | assert max_seq_length == max(seq_length)
484 | whole_seqs_num = seq_length.sum().item()
485 | if training:
486 | whole_med_sum = sum([sum(buf) for buf in m_length_matrix]) + whole_seqs_num # each sequence is followed by one extra END_TOKEN
487 |
488 | # flatten the results, then compute with library functions
489 | labels_flatten = torch.empty(whole_med_sum).to(device)
490 | logits_flatten = torch.empty(whole_med_sum, med_num).to(device)
491 |
492 | start_idx = 0
493 | for i in range(batch_size): # for each sample in the batch
494 | for j in range(seq_length[i]): # seq_length[i] is the number of visits in this sample
495 | for k in range(m_length_matrix[i][j]+1): # m_length_matrix[i][j] is the number of medications in this visit
496 | if k==m_length_matrix[i][j]: # the last label is set to END_TOKEN
497 | labels_flatten[start_idx] = END_TOKEN
498 | else:
499 | labels_flatten[start_idx] = labels[i, j, k]
500 | logits_flatten[start_idx, :] = logits[i, j, k, :]
501 | start_idx += 1
502 | return labels_flatten, logits_flatten
503 | else:
504 | # flatten the results per admission, then compute with library functions
505 | labels_flatten = []
506 | logits_flatten = []
507 |
508 | start_idx = 0
509 | for i in range(batch_size): # for each sample in the batch
510 | for j in range(seq_length[i]): # seq_length[i] is the number of visits in this sample
511 | labels_flatten.append(labels[i,j,:m_length_matrix[i][j]].detach().cpu().numpy())
512 |
513 | if testing:
514 | logits_flatten.append(logits[j]) # beam search already outputs the final predictions
515 | else:
516 | logits_flatten.append(logits[i,j,:max_len,:].detach().cpu().numpy()) # note: max_len is fixed manually here
517 | # cur_label = []
518 | # cur_seq_length = []
519 | # for k in range(m_length_matrix[i][j]+1): # m_length_matrix[i][j] is the number of medications in this visit
520 | # if k==m_length_matrix[i][j]: # the last label is END_TOKEN
521 | # continue
522 | # else:
523 | # labels_flatten[start_idx] = labels[i, j, k]
524 | # logits_flatten[start_idx, :] = logits[i, j, k, :med_num]
525 | # start_idx += 1
526 | return labels_flatten, logits_flatten
527 |
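# --- Editor's sketch (not part of the original file): a toy training-mode call of
# --- output_flatten; every size and token value below is hypothetical.
def _demo_output_flatten():
    med_num, end_token = 10, 8
    labels = torch.tensor([[[3., 7.]]])      # [batch=1, visit=1, max_med_num=2]
    logits = torch.randn(1, 1, 3, med_num)   # one extra position for END_TOKEN
    y, p = output_flatten(labels, logits, torch.tensor([1]), [[2]],
                          med_num, end_token, device='cpu', training=True)
    return y.shape, p.shape                  # torch.Size([3]), torch.Size([3, 10])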
528 |
529 | def print_result(label, prediction):
530 | '''
531 | label: [real_med_num, ]
532 | prediction: [pred_med_num, ]
533 | '''
534 | label_text = " ".join([str(x) for x in label])
535 | predict_text = " ".join([str(x) for x in prediction])
536 |
537 | return "[GT]\t{}\n[PR]\t{}\n\n".format(label_text, predict_text)
538 |
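# --- Editor's sketch (not part of the original file): print_result simply joins
# --- the two id sequences into a ground-truth / prediction text block.
def _demo_print_result():
    return print_result([1, 5, 9], [1, 5, 7])  # "[GT]\t1 5 9\n[PR]\t1 5 7\n\n"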
--------------------------------------------------------------------------------
/data/drug-atc.csv:
--------------------------------------------------------------------------------
1 | CID000003015,G03HA01
2 | CID000004011,N06AA21
3 | CID000071273,N01BB09
4 | CID000062816,C10AC02
5 | CID000052421,D07AC18,V01AA07,V01AA03
6 | CID000056339,N03AB05
7 | CID000077992,N02CC07
8 | CID000005300,L01AD04
9 | CID000222786,H02AB10,S01BA03
10 | CID000071301,C07AB12
11 | CID000004870,C03AA05
12 | CID000004873,A12BA01,B05XA01
13 | CID000002142,D06AX12,J01GB06,S01AA21
14 | CID000002141,V03AF05
15 | CID000131536,J05AE07
16 | CID000001775,N03AB02,N03AB04,N03AB05
17 | CID000003305,M05BA01,M05BB01
18 | CID000005454,N05AF04
19 | CID000003308,M01AB08
20 | CID000004909,D08AC04,S03AA05,S01AX08,R01AX07,R02AA18,N03AA03
21 | CID000004908,P01BA03
22 | CID000002099,A03AE01
23 | CID000054688,J01FA09
24 | CID000002092,G04CA01
25 | CID000005503,A10BB05
26 | CID000004536,G03AC01
27 | CID006918453,M01CB03
28 | CID000003706,J05AE02
29 | CID000005408,G03BA02,G03BA03,G03EK01
30 | CID000004497,C08CA06
31 | CID000004493,L02BB02
32 | CID000003219,S01GX06
33 | CID000002435,S01EA05
34 | CID000003121,N03AG01
35 | CID000040159,P03AC04
36 | CID000005878,A14AA08
37 | CID000002554,N03AF01
38 | CID000002550,C09AA01
39 | CID000002551,N07AB01,S01EB02
40 | CID000010100,N02AC04,N02AC54,N02AC74
41 | CID000002559,J01CA03
42 | CID000443871,A10BX03
43 | CID000003059,N02BA11
44 | CID000019090,G03AC05,G03DB02,L02AB01,G03DB04
45 | CID000059768,C07AB09
46 | CID004183806,A02BB01
47 | CID000176870,L01XE03
48 | CID000004473,C08CA04
49 | CID000003440,C03CA01
50 | CID000003449,N06DA04
51 | CID000003685,L01DB06
52 | CID000003937,C09AA03
53 | CID004630253,D07AB10,S01BA10
54 | CID000018140,L01XX11
55 | CID000000727,P03AB02,D08AE01
56 | CID004475485,C10AA07
57 | CID000077993,N02CC06
58 | CID000148211,A04AA05
59 | CID000072938,J01MA15
60 | CID000002153,R03DA04,R03DA05
61 | CID000002156,C01BD01
62 | CID000008953,B05XA02,B05CB04
63 | CID000005203,N06AB06
64 | CID000005206,N01AB08
65 | CID005281104,H05BX02
66 | CID000004917,N05AB04
67 | CID000004914,C05AD05,D04AB03,J01CE09,N01BA02,N01BA04,S01HA02,S01HA05,C01BA02
68 | CID000004913,C01BA02
69 | CID000003339,C10AB05
70 | CID000004911,M04AB01
71 | CID000213039,J05AE10
72 | CID006436173,A07AA11,D06AX11
73 | CID000002088,M05BA04
74 | CID000002082,P02CA03
75 | CID000005515,L01XX17
76 | CID000005516,L02BA02
77 | CID000002145,L02BG01
78 | CID000005512,G04BD07
79 | CID000005735,N05CF01
80 | CID000005734,N03AX15
81 | CID000005731,N02CC03
82 | CID000005732,N05CF02
83 | CID000003203,J05AG03
84 | CID000002140,V08AA01,V08AA04
85 | CID000004485,C08CA05
86 | CID000060707,J01DH02
87 | CID000003117,N07BB01,P03AA04
88 | CID000003114,C01BA03
89 | CID000004645,D06AA03,G01AA07,J01AA06,S01AA04
90 | CID000038904,L01XA02
91 | CID000060147,G04CA02
92 | CID000004674,M05BA03
93 | CID000002522,D05AX02
94 | CID000002524,A11CC04,D05AX03
95 | CID000104865,C02KX01
96 | CID004659568,N04BX02
97 | CID004659569,N04BX01
98 | CID000003062,C01AA02,C01AA05,C01AA08,V03AB24
99 | CID000003066,N02CA01
100 | CID000154059,G04BD08
101 | CID000004889,C10AA03
102 | CID000004264,D06AX09,R01AX06
103 | CID000000750,B03AA01,B05CX03
104 | CID000002637,J01DC01
105 | CID000002631,J01DD01
106 | CID000160051,C10AC04
107 | CID000001986,S01EC01
108 | CID000003454,J05AB06,J05AB14,S01AD09
109 | CID000003690,L01AA06
110 | CID000151165,A04AD12
111 | CID000003696,N06AA02,N06AA03,N06AA06
112 | CID000003698,C01CE01
113 | CID000003929,J01XX08
114 | CID000004195,C01CA17
115 | CID000002949,G03XA01
116 | CID000004078,N05AC03
117 | CID000004900,A07EA03,H02AB07,H02AB15
118 | CID000005593,S01FA06
119 | CID000005452,N05AC02
120 | CID000004107,M03BA03
121 | CID000071329,C01BD04
122 | CID000005212,G04BE03
123 | CID000005210,A08AA10
124 | CID000005215,D06BA01,J01EC02
125 | CID000004920,G03AC06,G03DA02,G03DA03,G03DA04,L02AB02
126 | CID000004927,D04AA10,R06AD02,R06AD05
127 | CID000008612,N01BA04
128 | CID000005430,D01AC06,P02CA02
129 | CID000003324,J05AB09,S01AD07
130 | CID000003325,A02BA03
131 | CID000124087,R06AX27
132 | CID000004856,M01AC01,M02AA07,S01BC06
133 | CID000005508,M01AB03,M02AA21
134 | CID000005504,C04AB02,M02AX02
135 | CID000005505,A10BB03,V04CA01
136 | CID000054841,N06BA09
137 | CID000005726,J05AF01
138 | CID000005721,J05AH01
139 | CID000002182,L01XX35
140 | CID000002187,L02BG03
141 | CID000004993,P01BD01
142 | CID000003100,D04AA32,D04AA33,R06AA02
143 | CID000000861,H03AA02
144 | CID000003105,S01EA02
145 | CID000003108,B01AC07,C01DX20
146 | CID000004675,M03AC01
147 | CID003468412,A07AA02,D01AA01,G01AA01
148 | CID000003075,C08DB01
149 | CID000068844,S01EC04
150 | CID000060714,V08CA04
151 | CID009571074,J01DD16
152 | CID000002583,C07AA15,S01ED05
153 | CID000004451,J05AE04
154 | CID000002512,G02CB03,N04BC06
155 | CID000002629,J01DD12
156 | CID000000137,L01XD04
157 | CID000041317,D05BB02
158 | CID000002622,J01DE01
159 | CID000062959,J01MA13
160 | CID000002995,N06AA01
161 | CID000002487,N02AF01
162 | CID000003467,D06AX07,J01GB03,S01AA11,S02AA14,S03AA06
163 | CID000003463,C10AB04
164 | CID000003461,L01BC05
165 | CID000003915,R01AC02,S01GX02
166 | CID000003914,S01ED03
167 | CID000003661,A03BA01,A03BB02,N04AC01,N04AC30,S01FA01,S01FA05,A03BA03
168 | CID000039042,C10AB02
169 | CID000004062,N01BB03
170 | CID000004060,N03AB04
171 | CID000004064,N05BC01
172 | CID000002249,C07AB03,C07AB11
173 | CID000152945,G04CB02
174 | CID000004112,L01BA01,L04AX03
175 | CID011947681,C02DD01
176 | CID000002658,J01DC02
177 | CID005329102,L01XE04
178 | CID000004934,A03AB05
179 | CID000004935,S01HA04
180 | CID000004932,C01BC03
181 | CID000002655,J01DD07
182 | CID000039860,A04AD11
183 | CID000005538,D10AD01,D10AD51,L01XX14,D10BA01,D10AD04,L01XX22
184 | CID000004845,R03AC08,R03CC07
185 | CID000005533,N06AX05
186 | CID000005717,R03DC01
187 | CID000005719,N05CF03
188 | CID000005718,J05AF03
189 | CID000004201,C02DC01,D11AX01
190 | CID000005155,J05AF04
191 | CID000005404,G01AG02
192 | CID000005403,R03AC03,R03CC03
193 | CID000005402,D01AE15,D01BA02
194 | CID000005401,G04CA03
195 | CID000005152,R03AC12
196 | CID000004666,L01CD01
197 | CID000002160,N06AA09
198 | CID000048041,C01BC08
199 | CID000000815,A07AA08,J01GB04,S01AA24
200 | CID000002441,N05BA08
201 | CID000060169,M03AC07
202 | CID000004595,A04AA01
203 | CID000060164,D10AD03
204 | CID000004599,A08AB01
205 | CID000002431,C01BD02
206 | CID000003000,D07AC03,D07XC02
207 | CID000003003,A01AC02,C05AA09,D07AB19,D07XB05,D10AA03,H02AB02,R01AD03,S01BA01,S01CB01,S02BA06,S03BA01
208 | CID000003007,N06BA01,N06BA02
209 | CID000438204,D06BA01
210 | CID000007029,A08AA03
211 | CID000123620,D07AC13,D07XC03,R01AD09,R03BA07
212 | CID003002190,J01FA15
213 | CID000004449,N06AX06
214 | CID000002617,J01DB04
215 | CID000104758,L01BA03
216 | CID000003478,A10BB07
217 | CID000023897,N05AE02
218 | CID000163742,B01AX05
219 | CID000003475,A10BB09
220 | CID000003476,A10BB12
221 | CID000002891,B03BB01,B03BA01
222 | CID000003676,C01BB01,C05AD01,D04AB01,N01BB02,R02AD02,S01HA07,S02DA01,A01AD11
223 | CID000003902,L02BG04
224 | CID000003675,N06AF03
225 | CID000003672,C01EB16,G02CC01,M01AE01,M01AE14,M02AA13
226 | CID000004058,N02AB02
227 | CID000004057,A03AB12
228 | CID000004054,N06DX01
229 | CID000004053,L01AA03
230 | CID000004052,M01AC06
231 | CID000003877,J05AF05
232 | CID000091270,C09AA13
233 | CID000003878,N03AX09
234 | CID000004121,C03AA08
235 | CID000065027,J05AE09
236 | CID000003494,A03AB02
237 | CID005311048,L01XA01
238 | CID000004409,M01AX01
239 | CID000016362,N05AG02
240 | CID000003869,C07AG01
241 | CID000004943,N01AX10
242 | CID000004946,C07AA05
243 | CID000005090,M01AH02
244 | CID000005525,C09AA10
245 | CID000005526,B02AA02
246 | CID000130881,C09CA08
247 | CID000009433,R03DA05
248 | CID000093860,L01XX32
249 | CID000003278,C03CC01
250 | CID000003279,J04AK02
251 | CID000072057,V08CA09
252 | CID000072054,G04BD10
253 | CID000005412,A01AB13,D06AA04,J01AA07,S01AA09,S02AA08,S03AA02,A01AB21,S01AA02,J01AA06,J01AA09,S01AA04,G01AA07,D06AA02,D06AA03,J01AA03
254 | CID000004614,M01AE12
255 | CID000004616,N05BA04
256 | CID000002477,N05BE01
257 | CID000002476,N02AE01,N07BC01
258 | CID000002471,C03CA02
259 | CID000002274,J01DF01
260 | CID000001065,C01BA01
261 | CID000002519,N06BC01,N02BE01
262 | CID000003928,J01FF02
263 | CID000004724,C07AA23
264 | CID000004723,N06BA05
265 | CID000003019,C02DA01,V03AH01
266 | CID000000937,C04AC01,C10AD02
267 | CID000003016,N05BA01,N05BA17
268 | CID000122316,N04BD02
269 | CID000002609,J01DC04
270 | CID000005071,J05AC02
271 | CID000003702,C03BA11
272 | CID000150310,C03DA04
273 | CID000150311,C10AX09
274 | CID000004436,R01AA08,S01GA01,R01AB02
275 | CID000004100,S01EC05
276 | CID000004547,A10BX02
277 | CID000003648,N02AA03
278 | CID000004543,N06AA10
279 | CID000004542,G03AC03,G03AD01
280 | CID000003647,C03AA02
281 | CID000060198,L02BG06
282 | CID000004043,S01BA08
283 | CID000004044,M01AG01
284 | CID000004046,P01BC02
285 | CID000003488,A10BB01
286 | CID000002720,C03AA03,C03AA04
287 | CID000002726,N05AA01
288 | CID000004138,C02AB01,C02AB02
289 | CID000002283,D06AX05,R02AB04,J01XX10
290 | CID000002284,M03BX01
291 | CID000005556,N05CD05
292 | CID000005396,L01CB02
293 | CID000005394,L01AX03
294 | CID000005391,N05CD07
295 | CID005311297,R03DC03
296 | CID000034633,C01BG01
297 | CID000005775,C04AB01,V03AB36
298 | CID000003241,R06AX24,S01GX10
299 | CID000667490,L01BB02
300 | CID000060835,N06AX21
301 | CID000060831,R03BB04
302 | CID000004603,J05AH02
303 | CID000004601,M03BC01,N04AB02
304 | CID000004607,J01CF01,J01CF02,J01CF04,J01CF05
305 | CID000004609,L01XA03
306 | CID000002462,A07EA06,D07AC09,R01AD05,R03BA02
307 | CID001349907,H03BB02
308 | CID000000838,A01AD01,B02BC09,C01CA03,C01CA24,R01AA14,R03AA01,S01EA01
309 | CID000030623,V03AF02
310 | CID000002266,D10AX03
311 | CID000002267,R01AC03,R06AX19,S01GX07
312 | CID000002265,L04AX01
313 | CID000003393,N05CD01
314 | CID000003392,D07AC07
315 | CID000003394,M01AE09,M02AA19,S01BC04
316 | CID000003397,L02BB01
317 | CID000002269,J01FA10,S01AA26
318 | CID000041781,C03CA01,C03CA04
319 | CID000004739,A03AB03
320 | CID000004736,N02AD01
321 | CID000004737,N05CA01
322 | CID000060696,M03AC09
323 | CID000004730,J01CE02,J01CE10
324 | CID000096312,L01BB07
325 | CID003085017,V03AE02
326 | CID000002673,J01DB09
327 | CID000002675,J01DD08
328 | CID000002676,C10AA06
329 | CID000002678,R06AE07,R06AE09
330 | CID000003730,V08AB02
331 | CID000000232,B05XB01
332 | CID000003657,L01XX05
333 | CID000003652,P01BA02
334 | CID000003658,N05BB01
335 | CID000002708,L01AA02
336 | CID000000581,R05CB01,S01XA08,V03AB23
337 | CID000028112,N05AN01,D11AX04
338 | CID000002818,N05AH02
339 | CID000068740,M05BA08
340 | CID000004091,A10BA02,A10BD11
341 | CID000004095,N02AC52,N07BC02,R05DA06
342 | CID000005546,C03DB02
343 | CID000005544,A01AC01,D07AB09,D07XB02,H02AB08,R01AD11,R03BA06,S01BA05,C05AA12
344 | CID000040976,G03AC08
345 | CID000004197,C01CE02
346 | CID000004893,C02CA01
347 | CID000004891,P02BA01
348 | CID000031477,S01ED04
349 | CID000016231,C03DB01
350 | CID000004894,A07EA01,C05AA04,D07AA01,D07AA03,D07AC14,D07XA02,D10AA02,H02AB04,H02AB06,R01AD02,S01BA04,S01CB02,S02BA03,S03BA02,H02AB07
351 | CID000002381,N04AA02
352 | CID000003261,N05CD04
353 | CID000003251,N02CA01,N02CA02
354 | CID000003255,D10AF02,J01FA01,S01AA17
355 | CID000050614,J01DC05
356 | CID000004635,N02AA05
357 | CID000004634,G04BD04
358 | CID000004428,N07BB04
359 | CID005282226,S01EE04
360 | CID000054454,C10AA01
361 | CID000057469,D06BB10
362 | CID000003382,C05AA11,D07AC08
363 | CID000003381,C05AA10,D07AC04,S01BA15,S02BA08
364 | CID000003387,G03BA01
365 | CID000003384,C05AA06,D07AB06,D07XB04,D10AA01,S01BA07,S01CB05
366 | CID000001046,J04AK01
367 | CID000003739,V08AC04
368 | CID000056959,C01EB18
369 | CID000000085,A16AA01
370 | CID000003032,M01AB05,M02AA15,S01BC03,D11AX18
371 | CID000060753,C01BD05
372 | CID000060754,V08CA03
373 | CID000002662,L01XX33,M01AH01
374 | CID000002666,J01DB01
375 | CID000004419,N02AF02
376 | CID006323497,J04AB05
377 | CID000003724,V08AB09
378 | CID000054547,J01DD13
379 | CID000002719,P01BA01,P01BA02
380 | CID000002712,N05BA02
381 | CID000000596,L01BC01
382 | CID000000772,B01AB01,C05BA01,C05BA03,S01XA09,S01XA14,B01AB05
383 | CID000003827,R06AX17,S01GX08
384 | CID000003826,M01AB15,S01BC05
385 | CID000003825,M01AE03,M01AE17,M02AA10
386 | CID000003823,D01AC08,G01AF11,J02AB02
387 | CID000003821,N01AX03,N01AX14
388 | CID000004140,G02AB01
389 | CID000002803,C02AC01,N02CX02,S01EA04,S01EA03
390 | CID000002806,B01AC04
391 | CID000062867,N05AC04
392 | CID000004158,N06BA04
393 | CID000004159,D07AA01,D07AC14,D10AA02,H02AB04
394 | CID000115237,N05AX08
395 | CID000013342,L01CA01
396 | CID000004259,J01MA14,S01AE07
397 | CID004479097,B03BA03,V03AB33
398 | CID000004086,R03CB03,R03AB03
399 | CID000005578,J01EA01
400 | CID000005572,N04AA01
401 | CID000005379,J01MA16,S01AE06
402 | CID000005376,L02BA01
403 | CID000005596,G04BD09
404 | CID000005591,A10BG01
405 | CID000060815,N01AH06
406 | CID000004976,N06AA11
407 | CID000002443,G02CB01,N04BC01
408 | CID005353894,V03AB04
409 | CID000005625,J05AG02
410 | CID000004623,D01AC11,G01AF17
411 | CID000002244,A01AD05,B01AC06,N02BA01
412 | CID000047319,M03AC04,M03AC11
413 | CID000170361,N07BA03
414 | CID000003375,C05AA10,D07AC04
415 | CID000087177,A06AD11
416 | CID000037393,P01BX01
417 | CID000041693,N01AH03
418 | CID000005978,L01CA02
419 | CID000060787,J05AE01
420 | CID000002656,J01DD04
421 | CID000002650,J01DD02
422 | CID000003759,N06AF01
423 | CID000004594,A02BC01,A02BC05
424 | CID000003750,L01XX19
425 | CID000065999,C09CA07
426 | CID000005193,N05CA06
427 | CID000157922,L01XD03
428 | CID000002951,M03CA01
429 | CID000003637,C02DB02,C02DB01
430 | CID000003639,C03AA03
431 | CID000054374,H01CB02
432 | CID000002727,A10BB02
433 | CID000002725,R06AB02,R06AB04
434 | CID000003767,J04AC01
435 | CID000004915,L01XB01
436 | CID000005582,P01AX07
437 | CID000004189,A01AB09,A07AC01,D01AC02,G01AF04,J02AB01,S02AA13
438 | CID000003958,N05BA06
439 | CID000000143,V03AF03
440 | CID000036339,N01AX07
441 | CID000004168,A03FA01
442 | CID000004163,N02CA04
443 | CID000004160,G03BA02,G03EK01
444 | CID000002812,A01AB18,D01AC01,G01AF02
445 | CID000002907,L01AA01,L01DB07
446 | CID000002905,S01FA04
447 | CID006398970,J01DD15
448 | CID000002019,L01DA01
449 | CID000005291,L01XE01
450 | CID000005297,A07AA04,J01GA01,S01AA15
451 | CID000002913,R06AX02
452 | CID000005566,N05AB06
453 | CID000002083,R03CC02,R03AC02
454 | CID000158440,V08CA11
455 | CID000005584,N06AA06
456 | CID000016850,S01JA01
457 | CID000002123,L01XX03
458 | CID000034312,N03AF02
459 | CID000002232,B01AE03
460 | CID000005035,G03XC01
461 | CID000039507,H02AB12,S01BA13
462 | CID000005746,L01DC03
463 | CID000005038,C09AA05
464 | CID000005039,A02BA02,A02BA07
465 | CID000004768,C04AX02
466 | CID000002344,N04AC01
467 | CID000002349,J01CE01,S01AA14,J01CE02,J01CE10
468 | CID000003291,N03AD01
469 | CID000003715,M01AB01,C01EB03
470 | CID000072467,L01DC01
471 | CID000050294,R01AC07,R03BC03,S01GX04
472 | CID000002646,J01DC10
473 | CID000003749,C09CA04
474 | CID000003746,R01AX03,R03BB01
475 | CID000002610,J01DB05
476 | CID000004440,N02CC02
477 | CID000002732,C03BA04
478 | CID000064147,J05AB14
479 | CID000004509,J01XE01
480 | CID000444033,R03BA08
481 | CID000004506,N05CD03,N05CD02
482 | CID005311039,J02AX04
483 | CID000003519,C02AC02
484 | CID000003518,C02CC02,S01EX01
485 | CID000004211,L01XX23
486 | CID000003512,D01AA08,D01BA01
487 | CID000003510,A04AA02
488 | CID000000159,B01AC09
489 | CID000000158,G02AD02
490 | CID000004178,C01BB02
491 | CID005481350,J05AF07
492 | CID000001935,N06AA18,N06DA01
493 | CID000004170,C03BA08
494 | CID000004171,C07AB02,C07AB52
495 | CID000004173,A01AB17,D06BX01,G01AF01,J01XD01,P01AB01
496 | CID000027686,R01AC01,A07EB01,S01GX01,R03BC01,D11AH03
497 | CID000003226,N01AB04
498 | CID000150610,J01DH03
499 | CID000071158,N07BB03
500 | CID000657298,H03BA02
501 | CID000006691,B01AA03
502 | CID005311181,B01AC11
503 | CID005381226,J04AB02
504 | CID000036811,C01CA07
505 | CID000004196,G03XB01
506 | CID000005359,M01AE07
507 | CID000004192,N05CD08
508 | CID000005352,M01AB02
509 | CID000060184,C09AA04
510 | CID000002133,D07AC11
511 | CID000002130,N04BB01
512 | CID000002131,N07AA30
513 | CID000060879,C09CA02
514 | CID000060871,J05AF08
515 | CID000004991,N07AA02
516 | CID000005645,A05AA02
517 | CID000005647,J05AB11
518 | CID000002171,J01CA04
519 | CID000003355,C01BC04
520 | CID000003354,G04BD02
521 | CID000051263,N01AH02
522 | CID000003350,D11AX10,G04CB01
523 | CID000005002,N05AH04
524 | CID000005005,C09AA06
525 | CID000004771,A08AA01
526 | CID005353980,A07EC01
527 | CID000000942,N07BA01,A11HA01,C04AC01,C10AD02
528 | CID006447131,D11AH02
529 | CID000003292,N03AB01
530 | CID000003779,C01CA02,D08AX05
531 | CID000110635,G04BE08
532 | CID000110634,G04BE09
533 | CID000041744,L01DB09
534 | CID000477468,J02AX05
535 | CID000002749,D01AE14,G01AX12
536 | CID000002585,C07AG02
537 | CID000145068,R07AX01
538 | CID000004510,C01DA02
539 | CID000048175,D07AC21
540 | CID000004205,N06AX11
541 | CID000004200,A01AB23,J01AA08
542 | CID000005381,D05AX05
543 | CID005361912,J04AB04
544 | CID000002520,C08DA01
545 | CID000039765,M03AC03
546 | CID000003403,C10AA04
547 | CID000042113,N01AB07
548 | CID000005344,J01EB05,S01AB02
549 | CID000005342,M04AB02
550 | CID000004834,J01CA12
551 | CID000123606,N02CC05
552 | CID000060865,S01GX09,R01AC08
553 | CID000005650,C09CA03
554 | CID000005651,A07AA09,J01XA01
555 | CID006435110,P02CF01
556 | CID000002216,S01EA03
557 | CID000002215,G04BE07,N04BC07
558 | CID000003348,R06AX26
559 | CID000003342,M01AE04
560 | CID000003340,C01CA19
561 | CID000006058,A16AA04
562 | CID000004748,N05AB03
563 | CID000004740,C04AD03
564 | CID000000951,C01CA03
565 | CID000002366,N07CA01
566 | CID000002369,C07AB05,S01ED02
567 | CID000025419,M05BA02
568 | CID000006476,N03AD03
569 | CID000003763,N01AB06
570 | CID000059708,N03AX14
571 | CID000003161,A01AB22,J01AA02
572 | CID000003168,N05AD08
573 | CID000000888,A12CC08,B05XA11,A06AD03,B05XA10,A06AD19,A02AA04,A12CC09,A02AA01,A06AD02,A12CC10,A12CC01,A12CC04,B05CB03,A12CC07,G04BX01,A12CC03,A12CC02,A02AA03,A06AD04,A02AA02,A12CC05,A02AA05,A06AD01,C10AX07,A12CC06,A12CC30,D11AX05,V04CC02,B05XA05
574 | CID000041774,A10BF01
575 | CID000002733,M03BB03
576 | CID000000401,J04AB01
577 | CID000002751,C09AA08
578 | CID000002756,A02BA01
579 | CID005281007,L01AX04
580 | CID000005358,N02CC01
581 | CID000444013,V08CA06
582 | CID000004463,J05AG01
583 | CID000004236,N06BA07
584 | CID000104741,L02BA03
585 | CID000003883,A02BC03
586 | CID000003405,B03BB01
587 | CID000003406,V03AB34
588 | CID000002978,A04AD10
589 | CID000002973,V03AC01
590 | CID000002250,C10AA05
591 | CID002761171,J04AD03
592 | CID000002021,J01XX04
593 | CID000002022,D06BB03,J05AB01,S01AD03,J05AB11
594 | CID003000502,J04AB30
595 | CID000003385,L01BC02
596 | CID000002118,N05BA12
597 | CID000004828,C07AA03,C07AA14,C07AA17
598 | CID000004829,A10BG03
599 | CID000001125,A16AX07
600 | CID000123631,L01XE02
601 | CID000004919,N04AA04
602 | CID000060852,M05BA06
603 | CID000005487,M03BX02
604 | CID000005486,B01AC17
605 | CID000047320,M03AC11
606 | CID000005245,M05BA07
607 | CID000005466,N03AG06
608 | CID000003379,R01AD04,R03BA03
609 | CID000003373,V03AB25
610 | CID000003372,N05AB02
611 | CID006398525,A02BX02
612 | CID000005064,J05AB04
613 | CID000005453,L01AC01
614 | CID000002375,L02BB03
615 | CID000004212,L01DB07
616 | CID000002370,N07AB02
617 | CID000002983,D06AA01,J01AA01
618 | CID000003793,J02AC02
619 | CID000003151,A03FA03
620 | CID000004689,A07AA06
621 | CID000003157,C02CA04
622 | CID000003156,R07AB01
623 | CID000003158,N06AA12
624 | CID000082146,L01XX25
625 | CID000002769,A03FA02
626 | CID000004539,J01MA06,S01AE02
627 | CID000002762,J01MB06
628 | CID000002564,R06AA08
629 | CID000003562,N01AB01
630 | CID000000187,S01EB09
631 | CID000060854,N05AE04
632 | CID000003899,L04AA13
633 | CID000003890,S01EE01
634 | CID000156419,H05BX01
635 | CID000003419,C09AA09
636 | CID000003417,J01XX01
637 | CID000003414,J05AD01
638 | CID000003410,R03AC13
639 | CID000002786,D10AF01,G01AA10,J01FF01
640 | CID000002781,R06AA04,D04AA14
641 | CID000004075,A07EC02
642 | CID000003964,N05AH01
643 | CID000003962,C10AA02
644 | CID000003961,C09CA01
645 | CID000027991,H01BA02
646 | CID000001546,L01BB04
647 | CID000004036,M01AG04,M02AA18
648 | CID000004030,P02CA01
649 | CID000004032,C02BB01
650 | CID000003333,C08CA02
651 | CID000002162,C08CA01
652 | CID000004411,C07AA12
653 | CID000004819,N07AX01,S01EB01
654 | CID005311027,S01EE03
655 | CID000060843,L01BA04
656 | CID000005496,J01GB01,S01AA12
657 | CID000009034,N01AF01,N05CA15
658 | CID000005253,C07AA07,C07AA57
659 | CID000003365,D01AC15,J02AC01
660 | CID000003366,D01AE21,J02AX01
661 | CID000003367,L01BB05
662 | CID000005472,B01AC05
663 | CID000005478,C07AA06,S01ED01
664 | CID000005479,J01XD02,P01AB02
665 | CID005282044,J01AA12
666 | CID000005078,N02CC04
667 | CID000060937,M05BA05
668 | CID000005076,J05AE03
669 | CID000002308,A07EA07,D07AC15,R01AD01,R03BA01
670 | CID000005672,L01CA04
671 | CID000047725,L02AE03
672 | CID000010548,S01EB03
673 | CID000003784,C08CA03
674 | CID000003780,C01DA08,D03AX08
675 | CID000003148,A04AA04
676 | CID000003143,L01CD02
677 | CID000006256,S01AD02
678 | CID000001206,N06BA03
679 | CID000014888,L01XX27
680 | CID000002578,L01AD01
681 | CID000003559,N05AD01
682 | CID000000191,C01EB10
683 | CID000004253,G04BE07,N02AA01,N02AA04,N02AA09,N04BC07,R05DA01,R05DA05,S01XA06,D10AX30
684 | CID000051577,A10BF02
685 | CID000001972,A01AB04,A07AA07,G01AA03,J02AA01
686 | CID000001971,J05AF06
687 | CID000001978,C07AB04
688 | CID000054786,B01AC21
689 | CID000032797,D07AD01
690 | CID000047472,G01AF15
691 | CID000002955,J04BA02,D10AX05
692 | CID000002958,L01DB02
693 | CID000002794,J04BA01
694 | CID000003950,L01AD02
695 | CID000003957,R06AX13,R06AX27,C09AA03
696 | CID000003954,A07DA03,A07DA05
697 | CID005362070,A07EC04
698 | CID000125017,N06AX23
699 | CID000002474,N01BB01,N01BB10
700 | CID000004513,A02BA04
701 | CID000002713,A01AB03,B05CA02,D08AC02,D09AA12,R02AA05,S01AX09,S02AA09,S03AA04,D08AE02
702 | CID000005314,M03AB01
703 | CID000000453,A06AD16,B05BC01,B05CX04,V08AC04
704 | CID000077999,A10BG02
705 | CID000002478,L01AB01
706 | CID000002173,J01CA01,J01CA02,J01CA06,J01CA14,J01CA15,S01AA19
707 | CID000002170,N06AA17
708 | CID000002177,J05AE05,J05AE07
709 | CID000002179,L01XX01
710 | CID000003222,C09AA02
711 | CID000001148,N06AX02
712 | CID000005267,C03DA01
713 | CID000031378,H02AA02
714 | CID000003310,L01CB01
715 | CID000005052,C02AA02
716 | CID000005040,L04AA10
717 | CID000074989,P01AX06
718 | CID000005665,N03AG04
719 | CID000002311,C09AA07
720 | CID000002315,C03AA01
721 | CID005473385,L01BB03
722 | CID000104803,M03AC10
723 | CID000002405,C07AB07
724 | CID000000853,C10AX01,H03AA01
725 | CID000060612,N05CM18
726 | CID000060613,J05AB12
727 | CID000504578,A01AB08,A07AA01,B05CA09,D06AX04,J01GB05,R02AB01,S01AA03,S02AA07,S03AA01,D09AA01,S01AA07
728 | CID005487301,A06AX06
729 | CID003062316,L01XE06
730 | CID000000298,D06AX02,D10AF03,G01AA05,J01BA01,S01AA01,S02AA01,S03AA08
731 | CID000003040,J01CF01
732 | CID000003043,J05AF02
733 | CID000003042,A03AA07
734 | CID000002541,C09CA06
735 | CID000083786,M04AA01
736 | CID000002895,M03BX08
737 | CID005493444,C09XA02
738 | CID000004727,M01CC01
739 | CID005362420,S01LA01
740 | CID000027661,C01DA14
741 | CID000000450,G03CA01,G03CA03,L02AA02,L02AA03
742 | CID000000681,C01CA04
743 | CID000002802,N03AE01
744 | CID000002800,G03GB02
745 | CID000003948,J01MA07,S01AE04
746 | CID000004425,V03AB15
747 |
--------------------------------------------------------------------------------
/src/COGNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._C import device
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 | import numpy as np
7 | from torch.nn.modules.linear import Linear
8 | from data.processing import process_visit_lg2
9 |
10 | from layers import SelfAttend
11 | from layers import GraphConvolution
12 |
13 |
14 | class COGNet(nn.Module):
15 | """在CopyDrug_batch基础上将medication的encode部分修改为transformer encoder"""
16 | def __init__(self, voc_size, ehr_adj, ddi_adj, ddi_mask_H, emb_dim=64, device=torch.device('cpu:0')):
17 | super(COGNet, self).__init__()
18 | self.voc_size = voc_size
19 | self.emb_dim = emb_dim
20 | self.device = device
21 | self.nhead = 2
22 | self.SOS_TOKEN = voc_size[2] # start of sentence
23 | self.END_TOKEN = voc_size[2]+1 # end of sentence; these two extra indices are added for the medication embedding
24 | self.MED_PAD_TOKEN = voc_size[2]+2 # padding index in the embedding matrix (all-zero vector)
25 | self.DIAG_PAD_TOKEN = voc_size[0]+2
26 | self.PROC_PAD_TOKEN = voc_size[1]+2
27 |
28 | self.tensor_ddi_mask_H = torch.FloatTensor(ddi_mask_H).to(device)
29 |
30 | # dig_num * emb_dim
31 | self.diag_embedding = nn.Sequential(
32 | nn.Embedding(voc_size[0]+3, emb_dim, self.DIAG_PAD_TOKEN),
33 | nn.Dropout(0.3)
34 | )
35 |
36 | # proc_num * emb_dim
37 | self.proc_embedding = nn.Sequential(
38 | nn.Embedding(voc_size[1]+3, emb_dim, self.PROC_PAD_TOKEN),
39 | nn.Dropout(0.3)
40 | )
41 |
42 | # med_num * emb_dim
43 | self.med_embedding = nn.Sequential(
44 | # padding_idx, which maps to the all-zero vector
45 | nn.Embedding(voc_size[2]+3, emb_dim, self.MED_PAD_TOKEN),
46 | nn.Dropout(0.3)
47 | )
48 |
49 | # encodes the medications of the previous visit
50 | # self.medication_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, dim_feedforward=emb_dim*8, batch_first=True, dropout=0.2)
51 | self.medication_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, batch_first=True, dropout=0.2)
52 | # encodes the diseases and procedures of the current visit
53 | # self.diagnoses_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, dim_feedforward=emb_dim*8, batch_first=True, dropout=0.2)
54 | # self.procedure_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, dim_feedforward=emb_dim*8, batch_first=True, dropout=0.2)
55 | self.diagnoses_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, batch_first=True, dropout=0.2)
56 | self.procedure_encoder = nn.TransformerEncoderLayer(emb_dim, self.nhead, batch_first=True, dropout=0.2)
57 | # self.enc_gru = nn.GRU(emb_dim, emb_dim, batch_first=True, bidirectional=True)
58 |
59 | # self.ehr_gcn = GCN(
60 | # voc_size=voc_size[2], emb_dim=emb_dim, adj=ehr_adj, device=device)
61 | # self.ddi_gcn = GCN(
62 | # voc_size=voc_size[2], emb_dim=emb_dim, adj=ddi_adj, device=device)
63 | self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
64 |
65 | self.gcn = GCN(voc_size=voc_size[2], emb_dim=emb_dim, ehr_adj=ehr_adj, ddi_adj=ddi_adj, device=device)
66 | self.inter = nn.Parameter(torch.FloatTensor(1))
67 |
68 | # aggregate diag and proc within a single visit into a visit-level representation
69 | self.diag_self_attend = SelfAttend(emb_dim)
70 | self.proc_self_attend = SelfAttend(emb_dim)
71 |
72 | self.decoder = MedTransformerDecoder(emb_dim, self.nhead, dim_feedforward=emb_dim*2, dropout=0.2,
73 | layer_norm_eps=1e-5)
74 |
75 | # encodes the diagnoses of each visit
76 |
77 | # used to generate the medication sequence
78 | self.dec_gru = nn.GRU(emb_dim*3, emb_dim, batch_first=True)
79 |
80 | self.diag_attn = nn.Linear(emb_dim*2, 1)
81 | self.proc_attn = nn.Linear(emb_dim*2, 1)
82 | self.W_diag_attn = nn.Linear(emb_dim, emb_dim)
83 | self.W_proc_attn = nn.Linear(emb_dim, emb_dim)
84 | self.W_diff_attn = nn.Linear(emb_dim, emb_dim)
85 | self.W_diff_proc_attn = nn.Linear(emb_dim, emb_dim)
86 |
87 | # weights
88 | self.Ws = nn.Linear(emb_dim*2, emb_dim) # only used at initial stage
89 | self.Wo = nn.Linear(emb_dim, voc_size[2]+2) # generate mode
90 | # self.Wc = nn.Linear(emb_dim*2, emb_dim) # copy mode
91 | self.Wc = nn.Linear(emb_dim, emb_dim) # copy mode
92 |
93 | self.W_dec = nn.Linear(emb_dim, emb_dim)
94 | self.W_stay = nn.Linear(emb_dim, emb_dim)
95 | self.W_proc_dec = nn.Linear(emb_dim, emb_dim)
96 | self.W_proc_stay = nn.Linear(emb_dim, emb_dim)
97 |
98 | # switch network to calculate the generation probability
99 | self.W_z = nn.Linear(emb_dim, 1)
100 |
101 |
102 | self.weight = nn.Parameter(torch.tensor([0.3]), requires_grad=True)
103 | # bipartite local embedding
104 | self.bipartite_transform = nn.Sequential(
105 | nn.Linear(emb_dim, ddi_mask_H.shape[1])
106 | )
107 | self.bipartite_output = MaskLinear(
108 | ddi_mask_H.shape[1], voc_size[2], False)
109 |
110 | def encode(self, diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20):
111 | device = self.device
112 | # parallelized over the batch and visit dimensions (temporal ordering is not modeled here); each medication sequence is still predicted step by step
113 | batch_size, max_visit_num, max_med_num = medications.size()
114 | max_diag_num = diseases.size()[2]
115 | max_proc_num = procedures.size()[2]
116 |
117 | ############################ Data preprocessing #########################
118 | # 1. Encode the current diseases and procedures
119 | input_disease_embdding = self.diag_embedding(diseases).view(batch_size * max_visit_num, max_diag_num, self.emb_dim) # [batch, seq, max_diag_num, emb]
120 | input_proc_embedding = self.proc_embedding(procedures).view(batch_size * max_visit_num, max_proc_num, self.emb_dim) # [batch, seq, max_proc_num, emb]
121 | d_enc_mask_matrix = d_mask_matrix.view(batch_size * max_visit_num, max_diag_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_diag_num,1) # [batch*seq, nhead, input_length, output_length]
122 | d_enc_mask_matrix = d_enc_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_diag_num, max_diag_num)
123 | p_enc_mask_matrix = p_mask_matrix.view(batch_size * max_visit_num, max_proc_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_proc_num,1)
124 | p_enc_mask_matrix = p_enc_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_proc_num, max_proc_num)
125 | input_disease_embdding = self.diagnoses_encoder(input_disease_embdding, src_mask=d_enc_mask_matrix).view(batch_size, max_visit_num, max_diag_num, self.emb_dim)
126 | input_proc_embedding = self.procedure_encoder(input_proc_embedding, src_mask=p_enc_mask_matrix).view(batch_size, max_visit_num, max_proc_num, self.emb_dim)
127 |
128 | # 1.1 encode visit-level diag and proc representations
129 | visit_diag_embedding = self.diag_self_attend(input_disease_embdding.view(batch_size * max_visit_num, max_diag_num, -1), d_mask_matrix.view(batch_size * max_visit_num, -1))
130 | visit_proc_embedding = self.proc_self_attend(input_proc_embedding.view(batch_size * max_visit_num, max_proc_num, -1), p_mask_matrix.view(batch_size * max_visit_num, -1))
131 | visit_diag_embedding = visit_diag_embedding.view(batch_size, max_visit_num, -1)
132 | visit_proc_embedding = visit_proc_embedding.view(batch_size, max_visit_num, -1)
133 |
134 | # 1.3 Compute the visit-level attention scores
135 | # [batch_size, max_visit_num, max_visit_num]
136 | cross_visit_scores = self.calc_cross_visit_scores(visit_diag_embedding, visit_proc_embedding)
137 |
138 |
139 | # 3. Build last_seq_medication, the medications of the previous visit; the first visit has no predecessor, so it is padded with 0 (any value works, it is never used)
140 | last_seq_medication = torch.full((batch_size, 1, max_med_num), 0).to(device)
141 | last_seq_medication = torch.cat([last_seq_medication, medications[:, :-1, :]], dim=1)
142 | # m_mask_matrix needs to be shifted in the same way
143 | last_m_mask = torch.full((batch_size, 1, max_med_num), -1e9).to(device) # large negative value so softmax assigns these positions (almost) no probability
144 | last_m_mask = torch.cat([last_m_mask, m_mask_matrix[:, :-1, :]], dim=1)
145 | # encode last_seq_medication
146 | last_seq_medication_emb = self.med_embedding(last_seq_medication)
147 | last_m_enc_mask = last_m_mask.view(batch_size * max_visit_num, max_med_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1,self.nhead,max_med_num,1)
148 | last_m_enc_mask = last_m_enc_mask.view(batch_size * max_visit_num * self.nhead, max_med_num, max_med_num)
149 | encoded_medication = self.medication_encoder(last_seq_medication_emb.view(batch_size * max_visit_num, max_med_num, self.emb_dim), src_mask=last_m_enc_mask) # (batch*seq, max_med_num, emb_dim)
150 | encoded_medication = encoded_medication.view(batch_size, max_visit_num, max_med_num, self.emb_dim)
151 |
152 | # vocab_size, emb_size
153 | ehr_embedding, ddi_embedding = self.gcn()
154 | drug_memory = ehr_embedding - ddi_embedding * self.inter
155 | drug_memory_padding = torch.zeros((3, self.emb_dim), device=self.device).float()
156 | drug_memory = torch.cat([drug_memory, drug_memory_padding], dim=0)
157 |
158 | return input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory
159 |
160 | def decode(self, input_medications, input_disease_embedding, input_proc_embedding, last_medication_embedding, last_medications, cross_visit_scores,
161 | d_mask_matrix, p_mask_matrix, m_mask_matrix, last_m_mask, drug_memory):
162 | """
163 | input_medications: [batch_size, max_visit_num, max_med_num + 1], prefixed with SOS_TOKEN
164 | """
165 | batch_size = input_medications.size(0)
166 | max_visit_num = input_medications.size(1)
167 | max_med_num = input_medications.size(2)
168 | max_diag_num = input_disease_embedding.size(2)
169 | max_proc_num = input_proc_embedding.size(2)
170 |
171 | input_medication_embs = self.med_embedding(input_medications).view(batch_size * max_visit_num, max_med_num, -1)
172 | # input_medication_embs = self.dropout_emb(input_medication_embs)
173 | input_medication_memory = drug_memory[input_medications].view(batch_size * max_visit_num, max_med_num, -1)
174 |
175 | # m_sos_mask = torch.zeros((batch_size, max_visit_num, 1), device=self.device).float() # zero mask keeps the SOS position visible to attention
176 | m_self_mask = m_mask_matrix
177 |
178 | last_m_enc_mask = m_self_mask.view(batch_size * max_visit_num, max_med_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_med_num, 1)
179 | medication_self_mask = last_m_enc_mask.view(batch_size * max_visit_num * self.nhead, max_med_num, max_med_num)
180 | m2d_mask_matrix = d_mask_matrix.view(batch_size * max_visit_num, max_diag_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_med_num, 1)
181 | m2d_mask_matrix = m2d_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_med_num, max_diag_num)
182 | m2p_mask_matrix = p_mask_matrix.view(batch_size * max_visit_num, max_proc_num).unsqueeze(dim=1).unsqueeze(dim=1).repeat(1, self.nhead, max_med_num,1)
183 | m2p_mask_matrix = m2p_mask_matrix.view(batch_size * max_visit_num * self.nhead, max_med_num, max_proc_num)
184 |
185 | dec_hidden = self.decoder(input_medication_embedding=input_medication_embs, input_medication_memory=input_medication_memory,
186 | input_disease_embdding=input_disease_embedding.view(batch_size * max_visit_num, max_diag_num, -1),
187 | input_proc_embedding=input_proc_embedding.view(batch_size * max_visit_num, max_proc_num, -1),
188 | input_medication_self_mask=medication_self_mask,
189 | d_mask=m2d_mask_matrix,
190 | p_mask=m2p_mask_matrix)
191 |
192 | score_g = self.Wo(dec_hidden) # (batch * max_visit_num, max_med_num, voc_size[2]+2)
193 | score_g = score_g.view(batch_size, max_visit_num, max_med_num, -1)
194 | prob_g = F.softmax(score_g, dim=-1)
195 | score_c = self.copy_med(dec_hidden.view(batch_size, max_visit_num, max_med_num, -1), last_medication_embedding, last_m_mask, cross_visit_scores)
196 | # (batch_size, max_visit_num * input_med_num, max_visit_num * max_med_num)
197 |
198 | ###### case study
199 | # this assumes batch_size == 1
200 | # notes on the indexing below:
201 | # 1. take the attention of the newest generated drug over historical drugs, hence -1 on the 3rd dimension
202 | # 2. take the copy scores of the last visit, hence -1 on the 2nd dimension
203 | # 3. take the last visit's attention over the drugs of the second-to-last visit, hence the last max_med_num entries on the 4th dimension
204 | # score_c_buf = score_c.view(batch_size, max_visit_num, max_med_num, -1)
205 | # score_c_buf = score_c_buf[0, -1, -1, :] # visit_num * (visit_num * max_med_num)
206 | # max_med_num_in_last = len(score_c_buf) // max_visit_num
207 | # print(score_c_buf[-max_med_num_in_last:])
208 | prob_c_to_g = torch.zeros_like(prob_g).to(self.device).view(batch_size, max_visit_num * max_med_num, -1) # (batch, max_visit_num * input_med_num, voc_size[2]+2)
209 |
210 | # use a scatter operation instead of nested loops
211 | # add the values in score_c into prob_c_to_g according to the indices in last_seq_medication
212 | copy_source = last_medications.view(batch_size, 1, -1).repeat(1, max_visit_num * max_med_num, 1)
213 | prob_c_to_g.scatter_add_(2, copy_source, score_c)
214 | prob_c_to_g = prob_c_to_g.view(batch_size, max_visit_num, max_med_num, -1)
215 |
216 | generate_prob = torch.sigmoid(self.W_z(dec_hidden)).view(batch_size, max_visit_num, max_med_num, 1) # torch.sigmoid replaces the deprecated F.sigmoid
217 | prob = prob_g * generate_prob + prob_c_to_g * (1. - generate_prob)
218 | prob[:, 0, :, :] = prob_g[:, 0, :, :] # the first visit has no last_medication, so only the generation probability is used
219 |
220 | return torch.log(prob)
221 |
222 | # def forward(self, input, last_input=None, max_len=20):
223 | def forward(self, diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix, seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20):
224 | device = self.device
225 | # parallelized over the batch and visit dimensions (temporal ordering is not modeled here); each medication sequence is still predicted step by step
226 | batch_size, max_seq_length, max_med_num = medications.size()
227 | max_diag_num = diseases.size()[2]
228 | max_proc_num = procedures.size()[2]
229 |
230 | input_disease_embdding, input_proc_embedding, encoded_medication, cross_visit_scores, last_seq_medication, last_m_mask, drug_memory = self.encode(diseases, procedures, medications, d_mask_matrix, p_mask_matrix, m_mask_matrix,
231 | seq_length, dec_disease, stay_disease, dec_disease_mask, stay_disease_mask, dec_proc, stay_proc, dec_proc_mask, stay_proc_mask, max_len=20)
232 |
233 | # 4. Build the medications fed to the decoder for teacher forcing; note one extra position is added because an END_TOKEN is also generated
234 | input_medication = torch.full((batch_size, max_seq_length, 1), self.SOS_TOKEN).to(device) # [batch_size, seq, 1]
235 | input_medication = torch.cat([input_medication, medications], dim=2) # [batch_size, seq, max_med_num + 1]
236 |
237 | m_sos_mask = torch.zeros((batch_size, max_seq_length, 1), device=self.device).float() # zero mask: the SOS position must stay visible to attention
238 | m_mask_matrix = torch.cat([m_sos_mask, m_mask_matrix], dim=-1)
239 |
240 | output_logits = self.decode(input_medication, input_disease_embdding, input_proc_embedding, encoded_medication, last_seq_medication, cross_visit_scores,
241 | d_mask_matrix, p_mask_matrix, m_mask_matrix, last_m_mask, drug_memory)
242 |
243 | # 5. Add the DDI loss
244 | # output_logits_part = torch.exp(output_logits[:, :, :, :-2] + m_mask_matrix.unsqueeze(-1)) # drop SOS and EOS
245 | # output_logits_part = torch.mean(output_logits_part, dim=2)
246 | # neg_pred_prob1 = output_logits_part.unsqueeze(-1)
247 | # neg_pred_prob2 = output_logits_part.unsqueeze(-2)
248 | # neg_pred_prob = neg_pred_prob1 * neg_pred_prob2 # bach * seq * max_med_num * all_med_num * all_med_num
249 | # batch_neg = 0.0005 * neg_pred_prob.mul(self.tensor_ddi_adj).sum()
250 | # return output_logits, batch_neg
251 | return output_logits
252 |
253 | def calc_cross_visit_scores(self, visit_diag_embedding, visit_proc_embedding):
254 | """
255 | visit_diag_embedding: (batch * visit_num * emb)
256 | visit_proc_embedding: (batch * visit_num * emb)
257 | """
258 | max_visit_num = visit_diag_embedding.size(1)
259 | batch_size = visit_diag_embedding.size(0)
260 |
261 | # the mask lets each visit attend only to the visits before it
262 | mask = (torch.triu(torch.ones((max_visit_num, max_visit_num), device=self.device)) == 1).transpose(0, 1) # lower-triangular matrix
263 | mask = mask.float().masked_fill(mask == 0, -1e9).masked_fill(mask == 1, float(0.0))
264 | mask = mask.unsqueeze(0).repeat(batch_size, 1, 1) # batch * max_visit_num * max_visit_num
265 |
266 | # shift each visit's keys back by one position
267 | padding = torch.zeros((batch_size, 1, self.emb_dim), device=self.device).float()
268 | diag_keys = torch.cat([padding, visit_diag_embedding[:, :-1, :]], dim=1) # batch * max_visit_num * emb
269 | proc_keys = torch.cat([padding, visit_proc_embedding[:, :-1, :]], dim=1)
270 |
271 | # score between each visit and all of its preceding visits
272 | diag_scores = torch.matmul(visit_diag_embedding, diag_keys.transpose(-2, -1)) \
273 | / math.sqrt(visit_diag_embedding.size(-1))
274 | proc_scores = torch.matmul(visit_proc_embedding, proc_keys.transpose(-2, -1)) \
275 | / math.sqrt(visit_proc_embedding.size(-1))
276 | # note: the 1st visit's scores are not zero!
277 | scores = F.softmax(diag_scores + proc_scores + mask, dim=-1)
278 |
279 | ###### case study
280 | # zero out the 0-th value, then renormalize
281 | # scores_buf = scores
282 | # scores_buf[:, :, 0] = 0.
283 | # scores_buf = scores_buf / torch.sum(scores_buf, dim=2, keepdim=True)
284 |
285 | # print(scores_buf)
286 | return scores
287 |
288 | def copy_med(self, decode_input_hiddens, last_medications, last_m_mask, cross_visit_scores):
289 | """
290 | decode_input_hiddens: [batch_size, max_visit_num, input_med_num, emb_size]
291 | last_medications: [batch_size, max_visit_num, max_med_num, emb_size]
292 | last_m_mask: [batch_size, max_visit_num, max_med_num]
293 | cross_visit_scores: [batch_size, max_visit_num, max_visit_num]
294 | """
295 | max_visit_num = decode_input_hiddens.size(1)
296 | input_med_num = decode_input_hiddens.size(2)
297 | max_med_num = last_medications.size(2)
298 | copy_query = self.Wc(decode_input_hiddens).view(-1, max_visit_num*input_med_num, self.emb_dim)
299 | attn_scores = torch.matmul(copy_query, last_medications.view(-1, max_visit_num*max_med_num, self.emb_dim).transpose(-2, -1)) / math.sqrt(self.emb_dim)
300 | med_mask = last_m_mask.view(-1, 1, max_visit_num * max_med_num).repeat(1, max_visit_num * input_med_num, 1)
301 | # [batch_size, max_visit_num * input_med_num, max_visit_num * max_med_num]
302 | attn_scores = F.softmax(attn_scores + med_mask, dim=-1)
303 |
304 | # (batch_size, max_visit_num * input_med_num, max_visit_num)
305 | visit_scores = cross_visit_scores.repeat(1, 1, input_med_num).view(-1, max_visit_num * input_med_num, max_visit_num)
306 |
307 | # (batch_size, max_visit_num * input_med_num, max_visit_num * max_med_num)
308 | visit_scores = visit_scores.unsqueeze(-1).repeat(1, 1, 1, max_med_num).view(-1, max_visit_num * input_med_num, max_visit_num * max_med_num)
309 |
310 | scores = torch.mul(attn_scores, visit_scores).clamp(min=1e-9)
311 | row_scores = scores.sum(dim=-1, keepdim=True)
312 | scores = scores / row_scores # (batch_size, max_visit_num * input_med_num, max_visit_num * max_med_num)
313 |
314 | return scores
315 |
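# --- Editor's sketch (not part of the original file): a standalone illustration of
# --- the scatter_add_ trick used in decode(): copy scores indexed by source
# --- medication IDs are accumulated into vocabulary-sized probabilities.
def _demo_scatter_copy():
    vocab_size = 5
    prob = torch.zeros(1, 2, vocab_size)                 # [batch, positions, vocab]
    source_ids = torch.tensor([[[3, 1, 3], [3, 1, 3]]])  # copy sources per position
    score_c = torch.tensor([[[0.2, 0.5, 0.3], [1.0, 0.0, 0.0]]])
    prob.scatter_add_(2, source_ids, score_c)
    return prob  # position 0: id 1 -> 0.5, id 3 -> 0.5; position 1: id 3 -> 1.0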
316 |
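# --- Editor's sketch (not part of the original file): the causal visit mask built
# --- in calc_cross_visit_scores, shown for 3 visits; the -1e9 entries vanish
# --- after softmax, so each visit attends only to itself and earlier visits.
def _demo_visit_mask(max_visit_num=3):
    mask = (torch.triu(torch.ones(max_visit_num, max_visit_num)) == 1).transpose(0, 1)
    # result: [[ 0, -1e9, -1e9],
    #          [ 0,    0, -1e9],
    #          [ 0,    0,    0]]
    return mask.float().masked_fill(mask == 0, -1e9).masked_fill(mask == 1, 0.0)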
317 | class MedTransformerDecoder(nn.Module):
318 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
319 | layer_norm_eps=1e-5) -> None:
320 | super(MedTransformerDecoder, self).__init__()
321 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
322 | self.m2d_multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
323 | self.m2p_multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
324 | # Implementation of Feedforward model
325 | self.linear1 = nn.Linear(d_model, dim_feedforward)
326 | self.dropout = nn.Dropout(dropout)
327 | self.linear2 = nn.Linear(dim_feedforward, d_model)
328 |
329 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
330 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
331 | self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
332 | self.dropout1 = nn.Dropout(dropout)
333 | self.dropout2 = nn.Dropout(dropout)
334 | self.dropout3 = nn.Dropout(dropout)
335 |
336 | self.activation = nn.ReLU()
337 | self.nhead = nhead
338 |
339 | # self.align = nn.Linear(d_model, d_model)
340 |
341 | def forward(self, input_medication_embedding, input_medication_memory, input_disease_embdding, input_proc_embedding,
342 | input_medication_self_mask, d_mask, p_mask):
343 | r"""Pass the inputs (and mask) through the decoder layer.
344 | Args:
345 | input_medication_embedding: [*, max_med_num+1, embedding_size]
346 | Shape:
347 | see the docs in Transformer class.
348 | """
349 | input_len = input_medication_embedding.size(0)
350 | tgt_len = input_medication_embedding.size(1)
351 |
352 | # [batch_size*visit_num, max_med_num+1, max_med_num+1]
353 | subsequent_mask = self.generate_square_subsequent_mask(tgt_len, input_len * self.nhead, input_disease_embdding.device)
354 | self_attn_mask = subsequent_mask + input_medication_self_mask
355 |
356 | x = input_medication_embedding + input_medication_memory
357 |
358 | x = self.norm1(x + self._sa_block(x, self_attn_mask))
359 | # attentioned_disease_embedding = self._m2d_mha_block(x, input_disease_embdding, d_mask)
360 | # attentioned_proc_embedding = self._m2p_mha_block(x, input_proc_embedding, p_mask)
361 | # x = self.norm3(x + self._ff_block(torch.cat([attentioned_disease_embedding, self.align(attentioned_proc_embedding)], dim=-1)))
362 | x = self.norm2(x + self._m2d_mha_block(x, input_disease_embdding, d_mask) + self._m2p_mha_block(x, input_proc_embedding, p_mask))
363 | x = self.norm3(x + self._ff_block(x))
364 |
365 | return x
366 |
367 | # self-attention block
368 | def _sa_block(self, x, attn_mask):
369 | x = self.self_attn(x, x, x,
370 | attn_mask=attn_mask,
371 | need_weights=False)[0]
372 | return self.dropout1(x)
373 |
374 | # multihead attention block
375 | def _m2d_mha_block(self, x, mem, attn_mask):
376 | x = self.m2d_multihead_attn(x, mem, mem,
377 | attn_mask=attn_mask,
378 | need_weights=False)[0]
379 | return self.dropout2(x)
380 |
381 | def _m2p_mha_block(self, x, mem, attn_mask):
382 | x = self.m2p_multihead_attn(x, mem, mem,
383 | attn_mask=attn_mask,
384 | need_weights=False)[0]
385 | return self.dropout2(x)
386 |
387 | # feed forward block
388 | def _ff_block(self, x):
389 | x = self.linear2(self.dropout(self.activation(self.linear1(x))))
390 | return self.dropout3(x)
391 |
392 | def generate_square_subsequent_mask(self, sz: int, batch_size: int, device):
393 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
394 | Unmasked positions are filled with float(0.0).
395 | """
396 | mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
397 | mask = mask.float().masked_fill(mask == 0, -1e9).masked_fill(mask == 1, float(0.0))
398 | mask = mask.unsqueeze(0).repeat(batch_size, 1, 1)
399 | return mask
400 |
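# --- Editor's smoke test (not part of the original file) for MedTransformerDecoder;
# --- all sizes are arbitrary. The masks are additive floats shaped
# --- [batch*nhead, tgt_len, src_len], as nn.MultiheadAttention expects.
def _demo_med_decoder():
    d_model, nhead, tgt_len = 8, 2, 3
    dec = MedTransformerDecoder(d_model, nhead, dim_feedforward=16)
    med = torch.randn(1, tgt_len, d_model)   # [batch*visit, max_med_num+1, emb]
    mem = torch.randn(1, tgt_len, d_model)   # matching drug-memory embeddings
    diag = torch.randn(1, 4, d_model)        # 4 diagnosis tokens
    proc = torch.randn(1, 5, d_model)        # 5 procedure tokens
    out = dec(med, mem, diag, proc,
              torch.zeros(nhead, tgt_len, tgt_len),
              torch.zeros(nhead, tgt_len, 4),
              torch.zeros(nhead, tgt_len, 5))
    return out.shape  # torch.Size([1, 3, 8])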
401 |
402 | class PositionEmbedding(nn.Module):
403 | """
404 | We assume that the sequence length is at most 512.
405 | """
406 | def __init__(self, emb_size, max_length=512):
407 | super(PositionEmbedding, self).__init__()
408 | self.max_length = max_length
409 | self.embedding_layer = nn.Embedding(max_length, emb_size)
410 |
411 | def forward(self, batch_size, seq_length, device):
412 | assert(seq_length <= self.max_length)
413 | ids = torch.arange(0, seq_length).long().to(torch.device(device))
414 | ids = ids.unsqueeze(0).repeat(batch_size, 1)
415 | emb = self.embedding_layer(ids)
416 | return emb
417 |
418 |
419 | class MaskLinear(nn.Module):
420 | def __init__(self, in_features, out_features, bias=True):
421 | super(MaskLinear, self).__init__()
422 | self.in_features = in_features
423 | self.out_features = out_features
424 | self.weight = nn.parameter.Parameter(torch.FloatTensor(in_features, out_features))
425 | if bias:
426 | self.bias = nn.parameter.Parameter(torch.FloatTensor(out_features))
427 | else:
428 | self.register_parameter('bias', None)
429 | self.reset_parameters()
430 |
431 | def reset_parameters(self):
432 | stdv = 1. / math.sqrt(self.weight.size(1))
433 | self.weight.data.uniform_(-stdv, stdv)
434 | if self.bias is not None:
435 | self.bias.data.uniform_(-stdv, stdv)
436 |
437 | def forward(self, input, mask):
438 | weight = torch.mul(self.weight, mask)
439 | output = torch.mm(input, weight)
440 |
441 | if self.bias is not None:
442 | return output + self.bias
443 | else:
444 | return output
445 |
446 | def __repr__(self):
447 | return self.__class__.__name__ + ' (' \
448 | + str(self.in_features) + ' -> ' \
449 | + str(self.out_features) + ')'
450 |
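# --- Editor's sketch (not part of the original file): MaskLinear masks the weight
# --- matrix elementwise before the matmul, so only allowed feature-to-output
# --- connections contribute; the mask here is an arbitrary example.
def _demo_mask_linear():
    layer = MaskLinear(4, 3, bias=False)
    mask = torch.FloatTensor(np.eye(4, 3))  # keep only "diagonal" connections
    x = torch.randn(2, 4)
    return layer(x, mask).shape             # torch.Size([2, 3])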
451 |
452 | class GCN(nn.Module):
453 | def __init__(self, voc_size, emb_dim, ehr_adj, ddi_adj, device=torch.device('cpu:0')):
454 | super(GCN, self).__init__()
455 | self.voc_size = voc_size
456 | self.emb_dim = emb_dim
457 | self.device = device
458 |
459 | ehr_adj = self.normalize(ehr_adj + np.eye(ehr_adj.shape[0]))
460 | ddi_adj = self.normalize(ddi_adj + np.eye(ddi_adj.shape[0]))
461 |
462 | self.ehr_adj = torch.FloatTensor(ehr_adj).to(device)
463 | self.ddi_adj = torch.FloatTensor(ddi_adj).to(device)
464 | self.x = torch.eye(voc_size).to(device)
465 |
466 | self.gcn1 = GraphConvolution(voc_size, emb_dim)
467 | self.dropout = nn.Dropout(p=0.3)
468 | self.gcn2 = GraphConvolution(emb_dim, emb_dim)
469 | self.gcn3 = GraphConvolution(emb_dim, emb_dim)
470 |
471 | def forward(self):
472 | ehr_node_embedding = self.gcn1(self.x, self.ehr_adj)
473 | ddi_node_embedding = self.gcn1(self.x, self.ddi_adj)
474 |
475 | ehr_node_embedding = F.relu(ehr_node_embedding)
476 | ddi_node_embedding = F.relu(ddi_node_embedding)
477 | ehr_node_embedding = self.dropout(ehr_node_embedding)
478 | ddi_node_embedding = self.dropout(ddi_node_embedding)
479 |
480 | ehr_node_embedding = self.gcn2(ehr_node_embedding, self.ehr_adj)
481 | ddi_node_embedding = self.gcn3(ddi_node_embedding, self.ddi_adj)
482 | return ehr_node_embedding, ddi_node_embedding
483 |
484 | def normalize(self, mx):
485 | """Row-normalize sparse matrix"""
486 | rowsum = np.array(mx.sum(1))
487 | r_inv = np.power(rowsum, -1).flatten()
488 | r_inv[np.isinf(r_inv)] = 0.
489 | r_mat_inv = np.diagflat(r_inv)
490 | mx = r_mat_inv.dot(mx)
491 | return mx
492 |
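# --- Editor's sketch (not part of the original file): the row normalization used
# --- by GCN, applied after adding self-loops as in GCN.__init__.
def _demo_row_normalize():
    adj = np.array([[0., 1.], [1., 0.]]) + np.eye(2)  # add self-loops
    r_inv = 1.0 / adj.sum(1)                          # inverse row sums
    return np.diagflat(r_inv).dot(adj)                # every row now sums to 1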
493 |
494 | class policy_network(nn.Module):
495 | def __init__(self, in_dim, out_dim, hidden_dim):
496 | super(policy_network, self).__init__()
497 | self.layers = nn.Sequential(
498 | nn.Linear(in_dim, hidden_dim),
499 | nn.ReLU(),
500 | nn.Linear(hidden_dim, out_dim)
501 | )
502 |
503 | def forward(self, x):
504 | return self.layers(x)
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # from torch._C import contiguous_format
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from dnc import DNC
7 | from layers import GraphConvolution
8 | import math
9 | from torch.nn.parameter import Parameter
10 |
11 | '''
12 | Our model
13 | '''
14 |
15 |
16 | class GCN(nn.Module):
17 | def __init__(self, voc_size, emb_dim, adj, device=torch.device('cpu:0')):
18 | super(GCN, self).__init__()
19 | self.voc_size = voc_size
20 | self.emb_dim = emb_dim
21 | self.device = device
22 |
23 | adj = self.normalize(adj + np.eye(adj.shape[0]))
24 |
25 | self.adj = torch.FloatTensor(adj).to(device)
26 | self.x = torch.eye(voc_size).to(device)
27 |
28 | self.gcn1 = GraphConvolution(voc_size, emb_dim)
29 | self.dropout = nn.Dropout(p=0.3)
30 | self.gcn2 = GraphConvolution(emb_dim, emb_dim)
31 |
32 | def forward(self):
33 | node_embedding = self.gcn1(self.x, self.adj)
34 | node_embedding = F.relu(node_embedding)
35 | node_embedding = self.dropout(node_embedding)
36 | node_embedding = self.gcn2(node_embedding, self.adj)
37 | return node_embedding
38 |
39 | def normalize(self, mx):
40 | """Row-normalize sparse matrix"""
41 | rowsum = np.array(mx.sum(1))
42 | r_inv = np.power(rowsum, -1).flatten()
43 | r_inv[np.isinf(r_inv)] = 0.
44 | r_mat_inv = np.diagflat(r_inv)
45 | mx = r_mat_inv.dot(mx)
46 | return mx
47 |
48 |
49 | class MaskLinear(nn.Module):
50 | def __init__(self, in_features, out_features, bias=True):
51 | super(MaskLinear, self).__init__()
52 | self.in_features = in_features
53 | self.out_features = out_features
54 | self.weight = Parameter(torch.FloatTensor(in_features, out_features))
55 | if bias:
56 | self.bias = Parameter(torch.FloatTensor(out_features))
57 | else:
58 | self.register_parameter('bias', None)
59 | self.reset_parameters()
60 |
61 | def reset_parameters(self):
62 | stdv = 1. / math.sqrt(self.weight.size(1))
63 | self.weight.data.uniform_(-stdv, stdv)
64 | if self.bias is not None:
65 | self.bias.data.uniform_(-stdv, stdv)
66 |
67 | def forward(self, input, mask):
68 | weight = torch.mul(self.weight, mask)
69 | output = torch.mm(input, weight)
70 |
71 | if self.bias is not None:
72 | return output + self.bias
73 | else:
74 | return output
75 |
76 | def __repr__(self):
77 | return self.__class__.__name__ + ' (' \
78 | + str(self.in_features) + ' -> ' \
79 | + str(self.out_features) + ')'
80 |
81 |
82 | class MolecularGraphNeuralNetwork(nn.Module):
83 | def __init__(self, N_fingerprint, dim, layer_hidden, device):
84 | super(MolecularGraphNeuralNetwork, self).__init__()
85 | self.device = device
86 | self.embed_fingerprint = nn.Embedding(
87 | N_fingerprint, dim).to(self.device)
88 | self.W_fingerprint = nn.ModuleList([nn.Linear(dim, dim).to(self.device)
89 | for _ in range(layer_hidden)])
90 | self.layer_hidden = layer_hidden
91 |
92 | def pad(self, matrices, pad_value):
93 | """Pad the list of matrices
94 | with a pad_value (e.g., 0) for batch proc essing.
95 | For example, given a list of matrices [A, B, C],
96 | we obtain a new matrix [A00, 0B0, 00C],
97 | where 0 is the zero (i.e., pad value) matrix.
98 | """
99 | shapes = [m.shape for m in matrices]
100 | M, N = sum([s[0] for s in shapes]), sum([s[1] for s in shapes])
101 | zeros = torch.FloatTensor(np.zeros((M, N))).to(self.device)
102 | pad_matrices = pad_value + zeros
103 | i, j = 0, 0
104 | for k, matrix in enumerate(matrices):
105 | m, n = shapes[k]
106 | pad_matrices[i:i+m, j:j+n] = matrix
107 | i += m
108 | j += n
109 | return pad_matrices
110 |
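# Illustration (toy shapes): pad([A, B], 0) with A of shape (2, 2) and B of
# shape (1, 1) returns the (3, 3) block-diagonal matrix
#   [[A, 0],
#    [0, B]]
# so every molecule's adjacency stays isolated inside one batched matrix.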
111 | def update(self, matrix, vectors, layer):
112 | hidden_vectors = torch.relu(self.W_fingerprint[layer](vectors))
113 | return hidden_vectors + torch.mm(matrix, hidden_vectors)
114 |
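# Note: update is one residual message-passing step: each node keeps its
# transformed state relu(W_l(h)) and adds the neighbor states aggregated
# through the block-diagonal adjacency matrix.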
115 | def sum(self, vectors, axis):
116 | sum_vectors = [torch.sum(v, 0) for v in torch.split(vectors, axis)]
117 | return torch.stack(sum_vectors)
118 |
119 | def mean(self, vectors, axis):
120 | mean_vectors = [torch.mean(v, 0) for v in torch.split(vectors, axis)]
121 | return torch.stack(mean_vectors)
122 |
123 | def forward(self, inputs):
124 | """Cat or pad each input data for batch processing."""
125 | fingerprints, adjacencies, molecular_sizes = inputs
126 | fingerprints = torch.cat(fingerprints)
127 | adjacencies = self.pad(adjacencies, 0)
128 |
129 | """MPNN layer (update the fingerprint vectors)."""
130 | fingerprint_vectors = self.embed_fingerprint(fingerprints)
131 | for l in range(self.layer_hidden):
132 | hs = self.update(adjacencies, fingerprint_vectors, l)
133 | # fingerprint_vectors = F.normalize(hs, 2, 1) # normalize.
134 | fingerprint_vectors = hs
135 |
136 | """Molecular vector by sum or mean of the fingerprint vectors."""
137 | molecular_vectors = self.sum(fingerprint_vectors, molecular_sizes)
138 | # molecular_vectors = self.mean(fingerprint_vectors, molecular_sizes)
139 |
140 | return molecular_vectors
141 |
142 |
143 | class SafeDrugModel(nn.Module):
144 | def __init__(self, vocab_size, ddi_adj, ddi_mask_H, MPNNSet, N_fingerprints, average_projection, emb_dim=256, device=torch.device('cpu:0')):
145 | super(SafeDrugModel, self).__init__()
146 |
147 | self.device = device
148 |
149 | # pre-embedding
150 | self.embeddings = nn.ModuleList(
151 | [nn.Embedding(vocab_size[i], emb_dim) for i in range(2)])
152 | self.dropout = nn.Dropout(p=0.5)
153 | self.encoders = nn.ModuleList(
154 | [nn.GRU(emb_dim, emb_dim, batch_first=True) for _ in range(2)])
155 | self.query = nn.Sequential(
156 | nn.ReLU(),
157 | nn.Linear(2 * emb_dim, emb_dim)
158 | )
159 |
160 | # bipartite local embedding
161 | self.bipartite_transform = nn.Sequential(
162 | nn.Linear(emb_dim, ddi_mask_H.shape[1])
163 | )
164 | self.bipartite_output = MaskLinear(
165 | ddi_mask_H.shape[1], vocab_size[2], False)
166 |
167 | # MPNN global embedding
168 | self.MPNN_molecule_Set = list(zip(*MPNNSet))
169 |
170 | self.MPNN_emb = MolecularGraphNeuralNetwork(
171 | N_fingerprints, emb_dim, layer_hidden=2, device=device).forward(self.MPNN_molecule_Set)
172 | self.MPNN_emb = torch.mm(average_projection.to(
173 | device=self.device), self.MPNN_emb.to(device=self.device))
174 | # self.MPNN_emb.to(device=self.device)
175 | self.MPNN_emb = self.MPNN_emb.clone().detach().requires_grad_(True)
176 | self.MPNN_output = nn.Linear(vocab_size[2], vocab_size[2])
177 | self.MPNN_layernorm = nn.LayerNorm(vocab_size[2])
178 |
179 | # graphs, bipartite matrix
180 | self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
181 | self.tensor_ddi_mask_H = torch.FloatTensor(ddi_mask_H).to(device)
182 | self.init_weights()
183 |
184 | def forward(self, input):
185 |
186 | # patient health representation
187 | i1_seq = []
188 | i2_seq = []
189 |
190 | def sum_embedding(embedding):
191 | return embedding.sum(dim=1).unsqueeze(dim=0) # (1,1,dim)
192 | for adm in input:
193 | i1 = sum_embedding(self.dropout(self.embeddings[0](
194 | torch.LongTensor(adm[0]).unsqueeze(dim=0).to(self.device)))) # (1,1,dim)
195 | i2 = sum_embedding(self.dropout(self.embeddings[1](
196 | torch.LongTensor(adm[1]).unsqueeze(dim=0).to(self.device))))
197 | i1_seq.append(i1)
198 | i2_seq.append(i2)
199 | i1_seq = torch.cat(i1_seq, dim=1) # (1,seq,dim)
200 | i2_seq = torch.cat(i2_seq, dim=1) # (1,seq,dim)
201 |
202 | o1, h1 = self.encoders[0](
203 | i1_seq
204 | )
205 | o2, h2 = self.encoders[1](
206 | i2_seq
207 | )
208 | patient_representations = torch.cat(
209 | [o1, o2], dim=-1).squeeze(dim=0) # (seq, dim*2)
210 | query = self.query(patient_representations)[-1:, :] # (seq, dim)
211 |
212 | # MPNN embedding
213 | MPNN_match = F.sigmoid(torch.mm(query, self.MPNN_emb.t()))
214 | MPNN_att = self.MPNN_layernorm(
215 | MPNN_match + self.MPNN_output(MPNN_match))
216 |
217 | # local embedding
218 | bipartite_emb = self.bipartite_output(
219 | F.sigmoid(self.bipartite_transform(query)), self.tensor_ddi_mask_H.t())
220 |
221 | result = torch.mul(bipartite_emb, MPNN_att)
222 |
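# Note: DDI regularizer. The outer product of the predicted probabilities
# approximates how likely each drug pair is to be co-prescribed; masking it
# with the DDI adjacency and summing penalizes known interacting pairs
# (0.0005 is the penalty weight used here).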
223 | neg_pred_prob = F.sigmoid(result)
224 | neg_pred_prob = neg_pred_prob.t() * neg_pred_prob # (voc_size, voc_size)
225 |
226 | batch_neg = 0.0005 * neg_pred_prob.mul(self.tensor_ddi_adj).sum()
227 |
228 | return result, batch_neg
229 |
230 | def init_weights(self):
231 | """Initialize weights."""
232 | initrange = 0.1
233 | for item in self.embeddings:
234 | item.weight.data.uniform_(-initrange, initrange)
235 |
236 |
237 | class DMNC(nn.Module):
238 | def __init__(self, vocab_size, emb_dim=64, device=torch.device('cpu:0')):
239 | super(DMNC, self).__init__()
240 | K = len(vocab_size)
241 | self.K = K
242 | self.vocab_size = vocab_size
243 | self.device = device
244 |
245 | self.token_start = vocab_size[2]
246 | self.token_end = vocab_size[2] + 1
247 |
248 | self.embeddings = nn.ModuleList(
249 | [nn.Embedding(vocab_size[i] if i != 2 else vocab_size[2] + 2, emb_dim) for i in range(K)])
250 | self.dropout = nn.Dropout(p=0.5)
251 |
252 | self.encoders = nn.ModuleList([DNC(
253 | input_size=emb_dim,
254 | hidden_size=emb_dim,
255 | rnn_type='gru',
256 | num_layers=1,
257 | num_hidden_layers=1,
258 | nr_cells=16,
259 | cell_size=emb_dim,
260 | read_heads=1,
261 | batch_first=True,
262 | gpu_id=0,
263 | independent_linears=False
264 | ) for _ in range(K - 1)])
265 |
266 | self.decoder = nn.GRU(emb_dim + emb_dim * 2, emb_dim * 2,
267 | batch_first=True) # input: (y, r1, r2,) hidden: (hidden1, hidden2)
268 | self.interface_weighting = nn.Linear(
269 | emb_dim * 2, 2 * (emb_dim + 1 + 3)) # 2 read head (key, str, mode)
270 | self.decoder_r2o = nn.Linear(2 * emb_dim, emb_dim * 2)
271 |
272 | self.output = nn.Linear(emb_dim * 2, vocab_size[2] + 2)
273 |
274 | def forward(self, input, i1_state=None, i2_state=None, h_n=None, max_len=20):
275 | # input (3, code)
276 | i1_input_tensor = self.embeddings[0](
277 | torch.LongTensor(input[0]).unsqueeze(dim=0).to(self.device)) # (1, seq, codes)
278 | i2_input_tensor = self.embeddings[1](
279 | torch.LongTensor(input[1]).unsqueeze(dim=0).to(self.device)) # (1, seq, codes)
280 |
281 | o1, (ch1, m1, r1) = \
282 | self.encoders[0](i1_input_tensor, (None, None, None)
283 | if i1_state is None else i1_state)
284 | o2, (ch2, m2, r2) = \
285 | self.encoders[1](i2_input_tensor, (None, None, None)
286 | if i2_state is None else i2_state)
287 |
288 | # save memory state
289 | i1_state = (ch1, m1, r1)
290 | i2_state = (ch2, m2, r2)
291 |
292 | predict_sequence = [self.token_start] + input[2]
293 | if h_n is None:
294 | h_n = torch.cat([ch1[0], ch2[0]], dim=-1)
295 |
296 | output_logits = []
297 | r1 = r1.unsqueeze(dim=0)
298 | r2 = r2.unsqueeze(dim=0)
299 |
300 | if self.training:
301 | for item in predict_sequence:
302 | # teacher force predict drug
303 | item_tensor = self.embeddings[2](
304 | torch.LongTensor([item]).unsqueeze(dim=0).to(self.device)) # (1, seq, codes)
305 |
306 | o3, h_n = self.decoder(
307 | torch.cat([item_tensor, r1, r2], dim=-1), h_n)
308 | read_keys, read_strengths, read_modes = self.decode_read_variable(
309 | h_n.squeeze(0))
310 |
311 | # read from i1_mem, i2_mem and i3_mem
312 | r1, _ = self.read_from_memory(self.encoders[0],
313 | read_keys[:, 0, :].unsqueeze(
314 | dim=1),
315 | read_strengths[:, 0].unsqueeze(
316 | dim=1),
317 | read_modes[:, 0, :].unsqueeze(dim=1), i1_state[1])
318 |
319 | r2, _ = self.read_from_memory(self.encoders[1],
320 | read_keys[:, 1, :].unsqueeze(
321 | dim=1),
322 | read_strengths[:, 1].unsqueeze(
323 | dim=1),
324 | read_modes[:, 1, :].unsqueeze(dim=1), i2_state[1])
325 |
326 | output = self.decoder_r2o(torch.cat([r1, r2], dim=-1))
327 | output = self.output(output + o3).squeeze(dim=0)
328 | output_logits.append(output)
329 | else:
330 | item_tensor = self.embeddings[2](
331 | torch.LongTensor([self.token_start]).unsqueeze(dim=0).to(self.device)) # (1, seq, codes)
332 | for idx in range(max_len):
333 | # predict
334 | # teacher force predict drug
335 | o3, h_n = self.decoder(
336 | torch.cat([item_tensor, r1, r2], dim=-1), h_n)
337 | read_keys, read_strengths, read_modes = self.decode_read_variable(
338 | h_n.squeeze(0))
339 |
340 | # read from i1_mem, i2_mem and i3_mem
341 | r1, _ = self.read_from_memory(self.encoders[0],
342 | read_keys[:, 0, :].unsqueeze(
343 | dim=1),
344 | read_strengths[:, 0].unsqueeze(
345 | dim=1),
346 | read_modes[:, 0, :].unsqueeze(dim=1), i1_state[1])
347 |
348 | r2, _ = self.read_from_memory(self.encoders[1],
349 | read_keys[:, 1, :].unsqueeze(
350 | dim=1),
351 | read_strengths[:, 1].unsqueeze(
352 | dim=1),
353 | read_modes[:, 1, :].unsqueeze(dim=1), i2_state[1])
354 |
355 | output = self.decoder_r2o(torch.cat([r1, r2], dim=-1))
356 | output = self.output(output + o3).squeeze(dim=0)
357 | output = F.softmax(output, dim=-1)
358 | output_logits.append(output)
359 |
360 | input_token = torch.argmax(output, dim=-1)
361 | input_token = input_token.item()
362 | item_tensor = self.embeddings[2](
363 | torch.LongTensor([input_token]).unsqueeze(dim=0).to(self.device)) # (1, seq, codes)
364 |
365 | return torch.cat(output_logits, dim=0), i1_state, i2_state, h_n
366 |
367 | def read_from_memory(self, dnc, read_key, read_str, read_mode, m_hidden):
368 | read_vectors, hidden = dnc.memories[0].read(
369 | read_key, read_str, read_mode, m_hidden)
370 | return read_vectors, hidden
371 |
372 | def decode_read_variable(self, input):
373 | w = 64
374 | r = 2
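# Note: w is hardcoded to the default emb_dim (64) and r to the two read
# heads; changing emb_dim in __init__ without updating w here would break
# the slicing below.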
375 | b = input.size(0)
376 |
377 | input = self.interface_weighting(input)
378 | # r read keys (b * w * r)
379 | read_keys = F.tanh(input[:, :r * w].contiguous().view(b, r, w))
380 | # r read strengths (b * r)
381 | read_strengths = F.softplus(
382 | input[:, r * w:r * w + r].contiguous().view(b, r))
383 | # read modes (b * 3*r)
384 | read_modes = F.softmax(
385 | input[:, (r * w + r):].contiguous().view(b, r, 3), -1)
386 | return read_keys, read_strengths, read_modes
387 |
388 |
389 | class GAMENet(nn.Module):
390 | def __init__(self, vocab_size, ehr_adj, ddi_adj, emb_dim=64, device=torch.device('cpu:0'), ddi_in_memory=True):
391 | super(GAMENet, self).__init__()
392 | K = len(vocab_size)
393 | self.K = K
394 | self.vocab_size = vocab_size
395 | self.device = device
396 | self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
397 | self.ddi_in_memory = ddi_in_memory
398 | self.embeddings = nn.ModuleList(
399 | [nn.Embedding(vocab_size[i], emb_dim) for i in range(K-1)])
400 | self.dropout = nn.Dropout(p=0.5)
401 |
402 | self.encoders = nn.ModuleList(
403 | [nn.GRU(emb_dim, emb_dim * 2, batch_first=True) for _ in range(K-1)])
404 |
405 | self.query = nn.Sequential(
406 | nn.ReLU(),
407 | nn.Linear(emb_dim * 4, emb_dim),
408 | )
409 |
410 | self.ehr_gcn = GCN(
411 | voc_size=vocab_size[2], emb_dim=emb_dim, adj=ehr_adj, device=device)
412 | self.ddi_gcn = GCN(
413 | voc_size=vocab_size[2], emb_dim=emb_dim, adj=ddi_adj, device=device)
414 | self.inter = nn.Parameter(torch.FloatTensor(1))
415 |
416 | self.output = nn.Sequential(
417 | nn.ReLU(),
418 | nn.Linear(emb_dim * 3, emb_dim * 2),
419 | nn.ReLU(),
420 | nn.Linear(emb_dim * 2, vocab_size[2])
421 | )
422 |
423 | self.init_weights()
424 |
425 | def forward(self, input):
426 | # input (adm, 3, codes)
427 |
428 | # generate medical embeddings and queries
429 | i1_seq = []
430 | i2_seq = []
431 |
432 | def mean_embedding(embedding):
433 | return embedding.mean(dim=1).unsqueeze(dim=0) # (1,1,dim)
434 | for adm in input:
435 | i1 = mean_embedding(self.dropout(self.embeddings[0](
436 | torch.LongTensor(adm[0]).unsqueeze(dim=0).to(self.device)))) # (1,1,dim)
437 | i2 = mean_embedding(self.dropout(self.embeddings[1](
438 | torch.LongTensor(adm[1]).unsqueeze(dim=0).to(self.device))))
439 | i1_seq.append(i1)
440 | i2_seq.append(i2)
441 | i1_seq = torch.cat(i1_seq, dim=1) # (1,seq,dim)
442 | i2_seq = torch.cat(i2_seq, dim=1) # (1,seq,dim)
443 |
444 | o1, h1 = self.encoders[0](
445 | i1_seq
446 | ) # o1:(1, seq, dim*2) hi:(1,1,dim*2)
447 | o2, h2 = self.encoders[1](
448 | i2_seq
449 | )
450 | patient_representations = torch.cat(
451 | [o1, o2], dim=-1).squeeze(dim=0) # (seq, dim*4)
452 | queries = self.query(patient_representations) # (seq, dim)
453 |
454 | # graph memory module
455 | '''I:generate current input'''
456 | query = queries[-1:] # (1,dim)
457 |
458 | '''G:generate graph memory bank and insert history information'''
459 | if self.ddi_in_memory:
460 | drug_memory = self.ehr_gcn() - self.ddi_gcn() * self.inter # (size, dim)
461 | else:
462 | drug_memory = self.ehr_gcn()
463 |
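# Note: drug_memory fuses co-prescription structure (the EHR graph) with
# interaction structure (the DDI graph); the learned scalar self.inter
# controls how strongly DDI information is subtracted from the memory bank.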
464 | if len(input) > 1:
465 | history_keys = queries[:(queries.size(0)-1)] # (seq-1, dim)
466 |
467 | history_values = np.zeros((len(input)-1, self.vocab_size[2]))
468 | for idx, adm in enumerate(input):
469 | if idx == len(input)-1:
470 | break
471 | history_values[idx, adm[2]] = 1
472 | history_values = torch.FloatTensor(
473 | history_values).to(self.device) # (seq-1, size)
474 |
475 | '''O:read from global memory bank and dynamic memory bank'''
476 | key_weights1 = F.softmax(
477 | torch.mm(query, drug_memory.t()), dim=-1) # (1, size)
478 | fact1 = torch.mm(key_weights1, drug_memory) # (1, dim)
479 |
480 | if len(input) > 1:
481 | visit_weight = F.softmax(
482 | torch.mm(query, history_keys.t()), dim=-1) # (1, seq-1)
483 | weighted_values = visit_weight.mm(history_values) # (1, size)
484 | fact2 = torch.mm(weighted_values, drug_memory) # (1, dim)
485 | else:
486 | fact2 = fact1
487 | '''R:convert O and predict'''
488 | output = self.output(
489 | torch.cat([query, fact1, fact2], dim=-1)) # (1, dim)
490 |
491 | if self.training:
492 | neg_pred_prob = F.sigmoid(output)
493 | neg_pred_prob = neg_pred_prob.t() * neg_pred_prob # (voc_size, voc_size)
494 | batch_neg = neg_pred_prob.mul(self.tensor_ddi_adj).mean()
495 |
496 | return output, batch_neg
497 | else:
498 | return output
499 |
500 | def init_weights(self):
501 | """Initialize weights."""
502 | initrange = 0.1
503 | for item in self.embeddings:
504 | item.weight.data.uniform_(-initrange, initrange)
505 |
506 | self.inter.data.uniform_(-initrange, initrange)
507 |
508 |
509 | class Leap(nn.Module):
510 | def __init__(self, voc_size, emb_dim=64, device=torch.device('cpu:0')):
511 | super(Leap, self).__init__()
512 | self.voc_size = voc_size
513 | self.device = device
514 | self.SOS_TOKEN = voc_size[2] # start of sentence
515 | self.END_TOKEN = voc_size[2]+1 # end token; the two newly added codes are both for the medication embedding
516 |
517 | # dig_num * emb_dim
518 | self.enc_embedding = nn.Sequential(
519 | nn.Embedding(voc_size[0], emb_dim, ),
520 | nn.Dropout(0.3)
521 | )
522 |
523 | # med_num * emb_dim
524 | self.dec_embedding = nn.Sequential(
525 | nn.Embedding(voc_size[2] + 2, emb_dim, ),
526 | nn.Dropout(0.3)
527 | )
528 |
529 | self.dec_gru = nn.GRU(emb_dim*2, emb_dim, batch_first=True)
530 |
531 | self.attn = nn.Linear(emb_dim*2, 1)
532 |
533 | self.output = nn.Linear(emb_dim, voc_size[2]+2)
534 |
535 | def forward(self, input, max_len=20):
536 | device = self.device
537 | # input (3, codes)
538 | input_tensor = torch.LongTensor(input[0]).to(device)
539 | # (len, dim)
540 | # encode the diagnoses
541 | input_embedding = self.enc_embedding(
542 | input_tensor.unsqueeze(dim=0)).squeeze(dim=0)
543 |
544 | output_logits = []
545 | hidden_state = None
546 | if self.training:
547 | # training phase
548 | # compute a step for each already-known medication, in the spirit of teacher forcing
549 | for med_code in [self.SOS_TOKEN] + input[2]:
550 | dec_input = torch.LongTensor(
551 | [med_code]).unsqueeze(dim=0).to(device)
552 | dec_input = self.dec_embedding(dec_input).squeeze(
553 | dim=0) # (1,dim) look up the embedding of this medication
554 |
555 | if hidden_state is None: # use the hidden state carried over from the previous admission
556 | hidden_state = dec_input
557 | hidden_state_repeat = hidden_state.repeat(
558 | input_embedding.size(0), 1) # (len, dim)
559 |
560 | combined_input = torch.cat(
561 | [hidden_state_repeat, input_embedding], dim=-1) # (len, dim*2)
562 | # (1, len) attention weights of this medication over all diagnoses
563 | attn_weight = F.softmax(self.attn(combined_input).t(), dim=-1)
564 | input_embedding = attn_weight.mm(
565 | input_embedding) # (1, dim) weighted sum
566 |
567 | _, hidden_state = self.dec_gru(torch.cat(
568 | [input_embedding, dec_input], dim=-1).unsqueeze(dim=0), hidden_state.unsqueeze(dim=0))
569 | hidden_state = hidden_state.squeeze(dim=0) # (1,dim)
570 |
571 | # (1, med_num) logits over every medication at the current position
572 | output_logits.append(self.output(F.relu(hidden_state)))
573 | return torch.cat(output_logits, dim=0)
574 |
575 | else:
576 | # inference: input[2], i.e. the ground-truth medications, must not be used here
577 | # cap the maximum generated length (tune it to the data)
578 | for di in range(max_len):
579 | if di == 0:
580 | dec_input = torch.LongTensor([[self.SOS_TOKEN]]).to(
581 | device) # the first step uses SOS; later steps feed back the previous prediction
582 | dec_input = self.dec_embedding(
583 | dec_input).squeeze(dim=0) # (1,dim)
584 | if hidden_state is None:
585 | hidden_state = dec_input
586 | hidden_state_repeat = hidden_state.repeat(
587 | input_embedding.size(0), 1) # (len, dim)
588 | combined_input = torch.cat(
589 | [hidden_state_repeat, input_embedding], dim=-1) # (len, dim*2)
590 | attn_weight = F.softmax(
591 | self.attn(combined_input).t(), dim=-1) # (1, len)
592 | input_embedding = attn_weight.mm(input_embedding) # (1, dim)
593 | _, hidden_state = self.dec_gru(torch.cat([input_embedding, dec_input], dim=-1).unsqueeze(dim=0),
594 | hidden_state.unsqueeze(dim=0))
595 | hidden_state = hidden_state.squeeze(dim=0) # (1,dim)
596 | output = self.output(F.relu(hidden_state))
597 | # .data reads the raw tensor; take the most likely token at the current position
598 | topv, topi = output.data.topk(1)
599 | output_logits.append(F.softmax(output, dim=-1))
600 | dec_input = topi.detach()
601 | return torch.cat(output_logits, dim=0)
602 |
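# Usage sketch (illustrative only; toy vocabulary sizes, not from this repo):
#   voc_size = (10, 8, 12)              # (#diagnoses, #procedures, #medications)
#   model = Leap(voc_size).eval()
#   visit = [[1, 3, 5], [0, 2], [4, 7]] # one admission: (diag, proc, med) codes
#   with torch.no_grad():
#       probs = model(visit, max_len=5) # (5, voc_size[2] + 2) softmax scores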
603 |
604 | class Retain(nn.Module):
605 | def __init__(self, voc_size, emb_size=64, device=torch.device('cpu:0')):
606 | super(Retain, self).__init__()
607 | self.device = device
608 | self.voc_size = voc_size
609 | self.emb_size = emb_size
610 | self.input_len = voc_size[0] + voc_size[1] + voc_size[2]
611 | self.output_len = voc_size[2]
612 |
613 | self.embedding = nn.Sequential(
614 | nn.Embedding(self.input_len + 1, self.emb_size,
615 | padding_idx=self.input_len),
616 | nn.Dropout(0.5)
617 | )
618 |
619 | self.alpha_gru = nn.GRU(emb_size, emb_size, batch_first=True)
620 | self.beta_gru = nn.GRU(emb_size, emb_size, batch_first=True)
621 |
622 | self.alpha_li = nn.Linear(emb_size, 1)
623 | self.beta_li = nn.Linear(emb_size, emb_size)
624 |
625 | self.output = nn.Linear(emb_size, self.output_len)
626 |
627 | def forward(self, input):
628 | device = self.device
629 | # input: (visit, 3, codes )
630 | max_len = max([(len(v[0]) + len(v[1]) + len(v[2])) for v in input])
631 | input_np = []
632 | for visit in input:
633 | input_tmp = []
634 | input_tmp.extend(visit[0])
635 | input_tmp.extend(list(np.array(visit[1]) + self.voc_size[0]))
636 | input_tmp.extend(
637 | list(np.array(visit[2]) + self.voc_size[0] + self.voc_size[1]))
638 | if len(input_tmp) < max_len:
639 | input_tmp.extend([self.input_len]*(max_len - len(input_tmp)))
640 |
641 | input_np.append(input_tmp)
642 |
643 | visit_emb = self.embedding(torch.LongTensor(
644 | input_np).to(device)) # (visit, max_len, emb)
645 | visit_emb = torch.sum(visit_emb, dim=1) # (visit, emb)
646 |
647 | g, _ = self.alpha_gru(visit_emb.unsqueeze(dim=0)) # g: (1, visit, emb)
648 | h, _ = self.beta_gru(visit_emb.unsqueeze(dim=0)) # h: (1, visit, emb)
649 |
650 | g = g.squeeze(dim=0) # (visit, emb)
651 | h = h.squeeze(dim=0) # (visit, emb)
652 | attn_g = F.softmax(self.alpha_li(g), dim=0) # (visit, 1) attention over visits; dim=-1 would be a no-op on a size-1 dimension
653 | attn_h = F.tanh(self.beta_li(h)) # (visit, emb)
654 |
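# Note: RETAIN's two-level attention: attn_g assigns a scalar weight to each
# visit, attn_h gates individual embedding dimensions; their product
# re-weights every visit embedding before it is summed into the context c.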
655 | c = attn_g * attn_h * visit_emb # (visit, emb)
656 | c = torch.sum(c, dim=0).unsqueeze(dim=0) # (1, emb)
657 |
658 | return self.output(c)
659 |
660 |
661 | class Leap_batch(nn.Module):
662 | def __init__(self, voc_size, emb_dim=64, device=torch.device('cpu:0')):
663 | super(Leap_batch, self).__init__()
664 | self.voc_size = voc_size
665 | self.emb_dim = emb_dim
666 | self.device = device
667 | self.SOS_TOKEN = voc_size[2] # start of sentence
668 | self.END_TOKEN = voc_size[2]+1 # end token; the two newly added codes are both for the medication embedding
669 | self.MED_PAD_TOKEN = voc_size[2]+2 # padding index in the embedding matrix (maps to an all-zero vector)
670 | self.DIAG_PAD_TOKEN = voc_size[0]+2
671 | self.PROC_PAD_TOKEN = voc_size[1]+2
672 | # dig_num * emb_dim
673 | self.diag_embedding = nn.Sequential(
674 | nn.Embedding(voc_size[0]+3, emb_dim, self.DIAG_PAD_TOKEN),
675 | nn.Dropout(0.3)
676 | )
677 |
678 | # proc_num * emb_dim
679 | self.proc_embedding = nn.Sequential(
680 | nn.Embedding(voc_size[1]+3, emb_dim, self.PROC_PAD_TOKEN),
681 | nn.Dropout(0.3)
682 | )
683 |
684 | # med_num * emb_dim
685 | self.med_embedding = nn.Sequential(
686 | # padding_idx maps to the zero vector
687 | nn.Embedding(voc_size[2] + 3, emb_dim, self.MED_PAD_TOKEN),
688 | nn.Dropout(0.3)
689 | )
690 |
691 | # encodes the previous visit's medications
692 | self.enc_gru = nn.GRU(emb_dim, emb_dim, batch_first=True, bidirectional=True)
693 |
694 | # generates the medication sequence
695 | self.dec_gru = nn.GRU(emb_dim*2, emb_dim, batch_first=True)
696 |
697 | self.attn = nn.Linear(emb_dim*2, 1)
698 |
699 | # self.output = nn.Linear(emb_dim, voc_size[2]+2)
700 | # self.output2 = nn.Linear(emb_dim, voc_size[2]+2)
701 |
702 | # weights
703 | self.Ws = nn.Linear(emb_dim*2, emb_dim) # only used at initial stage
704 | self.Wo = nn.Linear(emb_dim, voc_size[2]+2) # generate mode
705 | self.Wc = nn.Linear(emb_dim*2, emb_dim) # copy mode
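# Note: only the generate-mode projection Wo is used in forward below;
# Ws and Wc are declared for the copy mechanism but are not wired in here.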
706 |
707 | def encoder(self, x):
708 | # input: (med_num)
709 | embedded = self.med_embedding(x)  # this class defines med_embedding, not dec_embedding
710 | out, h = self.enc_gru(embedded) # out: [b x seq x hid*2] (biRNN)
711 | return out, h
712 |
713 | # def forward(self, input, last_input=None, max_len=20):
714 | def forward(self, diseases, procedures, medications, d_mask_matrix, m_mask_matrix, seq_length, max_len=20):
715 | device = self.device
716 | # parallelize over the batch and seq dimensions (temporal information is ignored for now); each medication sequence is still decoded step by step
717 | batch_size, max_seq_length, max_med_num = medications.size()
718 | max_diag_num = diseases.size()[2]
719 | hidden_state = None
720 |
721 | # print("diasease size", diseases.size())
722 | # print("proc size", procedures.size())
723 | # print("med size", medications.size())
724 | input_disease_embdding = self.diag_embedding(diseases) # [batch, seq, max_d_num, emb]
725 | input_med_embedding = self.med_embedding(medications) # [batch, seq, max_med_num, emb]
726 |
727 | # prepend a last_seq_medication holding the previous visit's medications for each step; the first visit has no predecessor, so pad it with 0 (any value works, it is never used)
728 | last_seq_medication = torch.full((batch_size, 1, max_med_num), 0, dtype=torch.long).to(device)
729 | last_seq_medication = torch.cat([last_seq_medication, medications[:, :-1, :]], dim=1)
730 | # m_mask_matrix must be shifted in the same way
731 | last_m_mask = torch.full((batch_size, 1, max_med_num), -1e9).to(device) # a large negative value, so softmax assigns it near-zero probability
732 | last_m_mask = torch.cat([last_m_mask, m_mask_matrix[:, :-1, :]], dim=1)
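# Illustration: for visits [m1, m2, m3] the shifted history is [PAD, m1, m2],
# so the decoder for each visit can attend to the medications prescribed at
# the previous visit.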
733 |
734 | last_seq_medication_emb = self.med_embedding(last_seq_medication)
735 | # print(last_seq_medication.size(), input_med_embedding.size())
736 | # encode the previous visit; note that this uses the embedding of last_seq_medication
737 | # (batch*seq, max_med_num, emb_dim*2)
738 | encoded_disease, _ = self.enc_gru(last_seq_medication_emb.view(batch_size * max_seq_length, max_med_num, self.emb_dim))
739 |
740 | # likewise prepend an SOS column to build last_medication for sequence generation; note the extra position along the last dimension
741 | last_medication = torch.full((batch_size, max_seq_length, 1), self.SOS_TOKEN, dtype=torch.long).to(device) # [batch_size, seq, 1]
742 | last_medication = torch.cat([last_medication, medications], dim=2) # [batch_size, seq, max_med_num + 1]
743 | # print(last_medication.size(), medications.size())
744 |
745 | hidden_state = None
746 | # pre-allocate the output tensor
747 | if self.training:
748 | output_logits = torch.zeros(batch_size, max_seq_length, max_med_num+1, self.voc_size[2]+2).to(device)
749 | loop_size=max_med_num+1
750 | else:
751 | output_logits = torch.zeros(batch_size, max_seq_length, max_len, self.voc_size[2]+2).to(device)
752 | loop_size=max_len
753 |
754 | # generate the output position by position
755 | for i in range(loop_size):
756 | if self.training:
757 | dec_input = self.med_embedding(last_medication[:,:,i]) # (batch, seq, emb_dim) embedding of the previous medication
758 | else:
759 | if i==0:
760 | dec_input = self.med_embedding(last_medication[:,:,0])
761 | elif i==max_len:
762 | break
763 | else:
764 | # at inference time, only the previous prediction is available
765 | dec_input=self.med_embedding(dec_input)
766 |
767 | if hidden_state is None:
768 | hidden_state=dec_input
769 |
770 | # attend over the current diagnoses to compute hidden_state (batch, seq, emb_dim)
771 | # print(dec_input.size())
772 | hidden_state_repeat = hidden_state.unsqueeze(dim=2).repeat(1,1,max_diag_num,1) # (batch, seq, max_diag_num, emb_dim)
773 | # print(hidden_state_repeat.size())
774 | combined_input=torch.cat([hidden_state_repeat, input_disease_embdding], dim=-1) # (batch, seq, max_diag_num, emb_dim*2)
775 | """这里attn_score结果需要根据mask来加上一个较大的负值,使得对应softmax值接近0,来避免分散注意力"""
776 | attn_score = self.attn(combined_input).squeeze(dim=-1) # (batch, seq, max_diag_num, 1) -> (batch, seq, max_diag_num)
777 | attn_score = attn_score + d_mask_matrix
778 |
779 | attn_weight=F.softmax(attn_score, dim=-1).unsqueeze(dim=2) # (batch, seq, 1, max_diag_num) attention weights
780 | # print(attn_weight.size())
781 | input_embedding=torch.matmul(attn_weight, input_disease_embdding).squeeze(dim=2) # (batch, seq, emb_dim)
782 |
783 | # reshape before feeding into dec_gru
784 | input_embedding_buf = input_embedding.view(batch_size * max_seq_length, 1, -1)
785 | dec_input_buf = dec_input.view(batch_size * max_seq_length, 1, -1)
786 | # print(input_embedding_buf.size())
787 | # print(dec_input_buf.size())
788 | hidden_state_buf = hidden_state.view(1, batch_size * max_seq_length, -1)
789 | _, hidden_state_buf = self.dec_gru(torch.cat([input_embedding_buf, dec_input_buf], dim=-1), hidden_state_buf)
790 | # print(hidden_state_buf.size())
791 | hidden_state = hidden_state_buf.view(batch_size, max_seq_length, -1) # (batch, seq, emb_dim)
792 | # print(hidden_state.size())
793 |
794 | score_g = self.Wo(hidden_state) # (batch, seq, voc_size[2]+2)
795 |
796 | prob = torch.log_softmax(score_g, dim=-1)
797 | output_logits[:, :, i, :] = prob
798 |
799 | if not self.training:
800 | # take the most likely token at the current position
801 | _, topi = torch.topk(prob, 1, dim=-1)
802 | dec_input=topi.detach()
803 |
804 | return output_logits
805 |
806 |
807 | class MICRON(nn.Module):
808 | def __init__(self, vocab_size, ddi_adj, emb_dim=256, device=torch.device('cpu:0')):
809 | super(MICRON, self).__init__()
810 |
811 | self.device = device
812 |
813 | # pre-embedding
814 | self.embeddings = nn.ModuleList(
815 | [nn.Embedding(vocab_size[i], emb_dim) for i in range(2)])
816 | self.dropout = nn.Dropout(p=0.5)
817 |
818 | self.health_net = nn.Sequential(
819 | nn.Linear(2 * emb_dim, emb_dim)
820 | )
821 |
822 | #
823 | self.prescription_net = nn.Sequential(
824 | nn.Linear(emb_dim, emb_dim * 4),
825 | nn.ReLU(),
826 | nn.Linear(emb_dim * 4, vocab_size[2])
827 | )
828 |
829 | # graphs, bipartite matrix
830 | self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
831 | self.init_weights()
832 |
833 | def forward(self, input):
834 |
835 | # patient health representation
836 | def sum_embedding(embedding):
837 | return embedding.sum(dim=1).unsqueeze(dim=0) # (1,1,dim)
838 |
839 | diag_emb = sum_embedding(self.dropout(self.embeddings[0](torch.LongTensor(input[-1][0]).unsqueeze(dim=0).to(self.device)))) # (1,1,dim)
840 | prod_emb = sum_embedding(self.dropout(self.embeddings[1](torch.LongTensor(input[-1][1]).unsqueeze(dim=0).to(self.device))))
841 | # diag_emb = torch.cat(diag_emb, dim=1) #(1,seq,dim)
842 | # prod_emb = torch.cat(prod_emb, dim=1) #(1,seq,dim)
843 |
844 | if len(input) < 2:
845 | diag_emb_last = diag_emb * torch.tensor(0.0)
846 | prod_emb_last = prod_emb * torch.tensor(0.0)
847 | else:
848 | diag_emb_last = sum_embedding(self.dropout(self.embeddings[0](torch.LongTensor(input[-2][0]).unsqueeze(dim=0).to(self.device)))) # (1,1,dim)
849 | prod_emb_last = sum_embedding(self.dropout(self.embeddings[1](torch.LongTensor(input[-2][1]).unsqueeze(dim=0).to(self.device))))
850 | # diag_emb_last = torch.cat(diag_emb_last, dim=1) #(1,seq,dim)
851 | # prod_emb_last = torch.cat(prod_emb_last, dim=1) #(1,seq,dim)
852 |
853 | health_representation = torch.cat([diag_emb, prod_emb], dim=-1).squeeze(dim=0) # (seq, dim*2)
854 | health_representation_last = torch.cat([diag_emb_last, prod_emb_last], dim=-1).squeeze(dim=0) # (seq, dim*2)
855 |
856 | health_rep = self.health_net(health_representation)[-1:, :] # (seq, dim)
857 | health_rep_last = self.health_net(health_representation_last)[-1:, :] # (seq, dim)
858 | health_residual_rep = health_rep - health_rep_last
859 |
860 | # drug representation
861 | drug_rep = self.prescription_net(health_rep)
862 | drug_rep_last = self.prescription_net(health_rep_last)
863 | drug_residual_rep = self.prescription_net(health_residual_rep)
864 |
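# Note: MICRON's residual update predicts the change in prescription from the
# change in health representation; the reconstruction loss below keeps
# "last prescription + residual" consistent with the prescription predicted
# directly from the current visit.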
865 | # reconstruction loss
866 | rec_loss = 1 / self.tensor_ddi_adj.shape[0] * torch.sum(torch.pow((F.sigmoid(drug_rep) - F.sigmoid(drug_rep_last + drug_residual_rep)), 2))
867 |
868 | # ddi_loss
869 | neg_pred_prob = F.sigmoid(drug_rep)
870 | neg_pred_prob = neg_pred_prob.t() * neg_pred_prob # (voc_size, voc_size)
871 |
872 | batch_neg = 1 / self.tensor_ddi_adj.shape[0] * neg_pred_prob.mul(self.tensor_ddi_adj).sum()
873 | return drug_rep, drug_rep_last, drug_residual_rep, batch_neg, rec_loss
874 |
875 | def init_weights(self):
876 | """Initialize weights."""
877 | initrange = 0.1
878 | for item in self.embeddings:
879 | item.weight.data.uniform_(-initrange, initrange)
880 |
--------------------------------------------------------------------------------