├── clinical_data
│   ├── lung_cancer
│   │   ├── lung_cancer.pdf
│   │   ├── test3_pd.csv
│   │   ├── visualization_breast.py
│   │   ├── test0_pd.csv
│   │   └── test4_pd.csv
│   └── breast_cancer
│       └── breast_test0_survival.csv
├── metrics.py
├── data
│   ├── collinearity.py
│   └── selection_bias.py
├── Logger.py
├── README.md
├── algorithm
│   ├── DWR.py
│   └── SRDO.py
├── model
│   ├── linear.py
│   ├── MLP.py
│   └── STG.py
├── requirements.txt
├── utils.py
├── exp_svi.py
├── simulated_fs.py
├── simulated.py
├── clinical_breast_OS.py
├── clinical_breast_RFS.py
├── clinical_lung_OS.py
├── mRNA_breast_OS.py
├── mRNA_mela_OS.py
├── clinical_lung_DFS.py
└── mRNA_HCC_OS.py

/clinical_data/lung_cancer/lung_cancer.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/googlebaba/StableCox/HEAD/clinical_data/lung_cancer/lung_cancer.pdf
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def get_metric_class(metric_name):
4 |     if metric_name not in globals():
5 |         raise NotImplementedError("Metric not found: {}".format(metric_name))
6 |     return globals()[metric_name]
7 | 
8 | def L1_beta_error(beta_hat, beta):
9 |     return np.sum(np.abs(beta_hat-beta))
10 | 
11 | def L2_beta_error(beta_hat, beta):
12 |     return np.sum((beta_hat-beta)**2)
--------------------------------------------------------------------------------
/data/collinearity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import ceil
3 | 
4 | def _gen_Cov(rho = 0.8, p = 10, s = 2):
5 |     Sigma = np.zeros((p, p))
6 |     for i in range(p//s):
7 |         Sigma[i*s:(i+1)*s, i*s:(i+1)*s] = rho
8 |     for i in range(p):
9 |         Sigma[i, i] = 1
10 |     return Sigma
11 | 
12 | def gen_collinearity_data(rho = 0.8, p = 10, s = 2, n = 2000): # s is the number of variables per group
13 |     beta_base = [1/5, -2/5, 3/5, -4/5, 1, -1/5, 2/5, -3/5, 4/5, -1,] # hard-coded coefficients
14 |     beta = beta_base * (ceil(p/len(beta_base)))
15 |     beta = np.reshape(beta[:p], (-1, 1))
16 |     Sigma = _gen_Cov(rho, p, s)
17 |     X = np.random.multivariate_normal([0]*p, Sigma, n)
18 |     samplecov = np.cov(X.T)
19 |     e_vals, e_vecs = np.linalg.eig(samplecov)
20 |     v = e_vecs[:, np.argmin(e_vals)].reshape((p, 1))
21 |     bx = np.dot(X, v)
22 |     fs = np.dot(X, beta) + np.reshape(np.prod(X[:,:3], axis=1), (-1, 1))
23 |     #fs = np.dot(X, beta) + bx
24 |     Y = fs + np.random.randn(n, 1)
25 |     return X, fs, Y
26 | 
--------------------------------------------------------------------------------
/Logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from utils import get_expname
5 | 
6 | # set logger
7 | class Logger():
8 |     def __init__(self, args):
9 |         self.txtname = time.strftime('%Y_%m%d_%H%M%S', time.localtime(time.time()))
10 |         self.expname = get_expname(args)
11 |         self.logpath = os.path.join(args.result_dir, self.expname, self.txtname)
12 |         os.makedirs(os.path.join(args.result_dir, self.expname), exist_ok=True)
13 | 
14 |         self.logger = logging.getLogger()
15 |         logging.getLogger('matplotlib.font_manager').disabled = True
16 |         self.logger.setLevel(logging.DEBUG)
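        # Two handlers with different thresholds: the file handler below keeps the
        # persistent log at INFO and above, while the stream handler echoes
        # everything, including DEBUG messages, to the console.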
17 |         self.fh = logging.FileHandler(self.logpath, mode='a')
18 |         self.fh.setLevel(logging.INFO)
19 |         self.ch = logging.StreamHandler()
20 |         self.ch.setLevel(logging.DEBUG)
21 |         formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")
22 |         self.fh.setFormatter(formatter)
23 |         self.ch.setFormatter(formatter)
24 |         self.logger.addHandler(self.fh)
25 |         self.logger.addHandler(self.ch)
26 | 
27 |     def log_args(self, args):
28 |         for k, v in vars(args).items():
29 |             self.logger.info('%s: %s' % (k, str(v)))
30 |         self.logger.info("")
31 | 
32 |     def debug(self, string):
33 |         self.logger.debug(string)
34 | 
35 |     def info(self, string):
36 |         self.logger.info(string)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![DOI](https://zenodo.org/badge/864352564.svg)](https://doi.org/10.5281/zenodo.13852489)
2 | 
3 | # System Requirements
4 | ## Hardware requirements
5 | The `Stable Cox` package requires only a standard computer with enough RAM to support the in-memory operations.
6 | 
7 | ## Software requirements
8 | ### OS requirements
9 | This package is supported for *Linux*. The package has been tested on the following system:
10 | + Linux: Ubuntu 18.04
11 | ## Install StableCox directly via pip
12 | pip install StableCox
13 | For a tutorial on the package, see: https://pypi.org/project/StableCox/0.3/
14 | 
15 | ### Python Dependencies
16 | `Stable Cox` mainly depends on the Python scientific stack.
17 | 
18 | ```
19 | lifelines==0.27.8
20 | numpy==1.20.3
21 | pandas==2.0.3
22 | scikit-learn==1.3.0
23 | ```
24 | 
25 | # Installation Guide
26 | conda create -n Stable_Cox python=3.8
27 | 
28 | source activate Stable_Cox
29 | 
30 | pip install -r requirements.txt
31 | 
32 | - Building the environment takes several minutes
33 | 
34 | # Run demo
35 | 
36 | ## omics data
37 | 
38 | ### Stable Cox
39 | python3 mRNA_HCC_OS.py --reweighting SRDO --paradigm regr --topN 10
40 | ### Cox PH
41 | python3 mRNA_HCC_OS.py --reweighting None --paradigm regr --topN 10
42 | 
43 | ## clinical data
44 | 
45 | ### Stable Cox
46 | python3 clinical_lung_OS.py --reweighting SRDO --paradigm regr
47 | ### Cox PH
48 | python3 clinical_lung_OS.py --reweighting None --paradigm regr
49 | 
50 | 
51 | ## simulated data
52 | 
53 | ### Stable Cox
54 | python3 simulated.py --reweighting SRDO --paradigm regr
55 | ### Cox PH
56 | python3 simulated.py --reweighting None --paradigm regr
57 | 
58 | 
59 | ## feature selection
60 | ### Stable Cox
61 | python3 simulated_fs.py --reweighting SRDO --paradigm regr --topN 5
62 | ### Cox PH
63 | python3 simulated_fs.py --reweighting None --paradigm regr --topN 5
64 | 
65 | - The expected running time ranges from several seconds to minutes, depending on the number of samples.
66 | 
67 | # License
68 | This project is licensed under the terms of the MIT license.
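# Appendix: minimal pipeline sketch

All of the demos above compose the same two steps: `SRDO` (`algorithm/SRDO.py`) learns sample weights that decorrelate the covariates, and `Weighted_cox` (`model/linear.py`) fits a Cox proportional-hazards model under those weights. The sketch below wires the two together on toy data; the synthetic DataFrame and the column names `time` and `status` are illustrative only, not part of any fixed API.

```python
import numpy as np
import pandas as pd
from algorithm.SRDO import SRDO
from model.linear import Weighted_cox

rng = np.random.RandomState(0)
X = pd.DataFrame(rng.randn(200, 5), columns=["x%d" % i for i in range(5)])
X["time"] = rng.exponential(1.0, 200)    # toy survival times
X["status"] = rng.binomial(1, 0.7, 200)  # 1 = event observed, 0 = censored

# Step 1: learn decorrelating sample weights on the covariates only
# (p_s is unused under the default "global" decorrelation).
W = SRDO(X.drop(columns=["time", "status"]).values, p_s=0)

# Step 2: fit a weighted Cox PH model under those weights. lifelines may warn
# that the weights are not integer counts; that is expected here.
cph = Weighted_cox(X, duration_col="time", event_col="status", W=W, pen=0.1)
cph.print_summary()
```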
69 | 
--------------------------------------------------------------------------------
/algorithm/DWR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim
3 | import numpy as np
4 | from utils import weighted_cov_torch
5 | 
6 | def decorr_loss(X, weight, cov_mask=None, order=1):
7 |     n = X.shape[0]
8 |     p = X.shape[1]
9 |     balance_loss = 0.0
10 |     for a in range(1, order+1):
11 |         for b in range(a, order+1):
12 |             if a != b:
13 |                 cov_mat = weighted_cov_torch(X**a, X**b, W=weight**2/n)
14 |             else:
15 |                 cov_mat = weighted_cov_torch(X**a, W=weight**2/n)
16 |             cov_mat = cov_mat**2
17 |             cov_mat = cov_mat * cov_mask
18 |             balance_loss += torch.sum(torch.sqrt(torch.sum(cov_mat, dim=1)-torch.diag(cov_mat) + 1e-10))
19 | 
20 |     loss_weight_sum = (torch.sum(weight * weight) - n) ** 2
21 |     loss_weight_l2 = torch.sum((weight * weight) ** 2)
22 |     loss = 2000.0 / p * balance_loss + 100 * loss_weight_sum + 0.0005 * loss_weight_l2 # hard-coded loss coefficients
23 |     return loss, balance_loss, loss_weight_sum, loss_weight_l2
24 | 
25 | def DWR(X, cov_mask=None, order=1, num_steps = 5000, lr = 0.01, tol=1e-8, loss_lb=0.001, iter_print=500, logger=None, device=None):
26 |     X = torch.tensor(X, dtype=torch.float, device=device)
27 |     n, p = X.shape
28 | 
29 |     if cov_mask is None:
30 |         cov_mask = torch.ones((p, p), device=device)
31 |     else:
32 |         cov_mask = torch.tensor(cov_mask, dtype=torch.float, device=device)
33 | 
34 |     weight = torch.ones(n, 1, device=device)
35 |     weight.requires_grad = True
36 |     optimizer = optim.Adam([weight,], lr = lr)
37 | 
38 |     loss_prev = 0.0
39 |     for i in range(num_steps):
40 |         optimizer.zero_grad()
41 |         loss, balance_loss, loss_s, loss_2 = decorr_loss(X, weight, cov_mask, order=order)
42 |         loss.backward()
43 |         optimizer.step()
44 |         if abs(loss-loss_prev) <= tol or balance_loss <= loss_lb:
45 |             break
46 |         if (i+1) % iter_print == 0:
47 |             logger.debug('iter %d: decorrelate loss %.6f balance loss %.6f loss_s %.6f loss_l2 %.6f' % (i+1, loss, balance_loss, loss_s, loss_2))
48 |         loss_prev = loss.item()  # track the previous loss so the convergence check above is meaningful
49 |     weight = (weight**2).cpu().detach().numpy()
50 |     weight /= np.sum(weight) # normalize: weights sum up to 1
51 |     return weight
52 | 
53 | 
54 | 
--------------------------------------------------------------------------------
/clinical_data/lung_cancer/test3_pd.csv:
--------------------------------------------------------------------------------
1 | ,Sex,Age,Location,"阻塞性肺炎/肺不张Obst pn
2 | or plugging","CT value
3 | Mean","CT value
4 | Std Dev","CT value
5 | P.major",CT V ratio,CT Kernel,Multiplicity多重性,Effusion胸膜积液,pos_LN,total_LN,LN ratio淋巴结比例,cT,cN,cM,pT病例T阶段,pN病理N阶段,pM,FVC用力肺活量 %PRED,FEV1 %PRED,FEV1一秒用力呼气量/FVC 用力肺活量(%) ,CEA癌胚抗原,post-op CTx术后CT,post-op RTx,Necrosis坏死_0,Necrosis坏死_1,Necrosis坏死_2,"Underlying
6 | lung_0","Underlying
7 | lung_1","Underlying
8 | lung_2","Underlying
9 | lung_3","Underlying
10 | lung_4",Bronchoscopy 支气管镜检_1,Bronchoscopy 支气管镜检_2,Bronchoscopy 支气管镜检_3,Bronchoscopy 支气管镜检_4,Differentiation分化_1,Differentiation分化_2,Differentiation分化_3,Differentiation分化_4,Smoking state_0,Smoking state_1,Smoking state_2,Smoking state_3,op type手术类型_1,op type手术类型_2,op type手术类型_3,op type手术类型_4,op type手术类型_5,Reccur.status,Reccur.months,Survival.status,Survival.months
11 | 
0,0.0,0.3191489361702128,0.0,0.0,0.830812854442344,0.14370546318289787,0.6314814814814814,0.7128487090466595,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.3333333333333333,0.0,0.0,0.0,0.0,0.4823529411764706,0.41025641025641024,0.417910447761194,0.0057670126874279125,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0,56.7333333333333,1,94.0 12 | 1,0.0,0.40425531914893614,0.0,1.0,0.551039697542533,0.13776722090261284,0.5703703703703704,0.502959791624409,0.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.0,0.3333333333333333,0.0,0.8,0.5555555555555556,0.2835820895522388,0.003171856978085352,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,61.6333333333333,1,61.6333333333333 13 | 2,0.0,0.5106382978723404,0.0,1.0,0.6483931947069943,0.2042755344418053,0.5759259259259258,0.589261607698294,1.0,0.0,0.0,0.0,0.0,0.0,0.6666666666666666,0.3333333333333333,0.0,0.6666666666666666,0.3333333333333333,0.0,0.36470588235294116,0.5726495726495726,0.7313432835820896,0.000922722029988466,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0,6.4,0,24.0 14 | 3,0.0,0.574468085106383,0.0,1.0,0.6880907372400756,0.4631828978622328,0.312962962962963,0.876420621028281,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.6666666666666666,0.3333333333333333,0.0,0.6666666666666666,0.3333333333333333,0.0,0.32941176470588235,0.2222222222222222,0.34328358208955223,0.002595155709342561,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1,47.6333333333333,1,47.6333333333333 15 | -------------------------------------------------------------------------------- /model/linear.py: -------------------------------------------------------------------------------- 1 | from random import sample 2 | from sklearn import linear_model 3 | from lifelines import CoxPHFitter 4 | import pandas as pd 5 | import numpy as np 6 | from lifelines import CoxPHFitter 7 | import pandas as pd 8 | import numpy as np 9 | import seaborn as sns 10 | import matplotlib.pyplot as plt 11 | from lifelines.statistics import logrank_test 12 | from lifelines import KaplanMeierFitter 13 | from lifelines import LogLogisticAFTFitter, WeibullAFTFitter, LogNormalAFTFitter 14 | def get_algorithm_class(algorithm_name): 15 | """Return the algorithm class with the given name.""" 16 | if algorithm_name not in globals(): 17 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 18 | return globals()[algorithm_name] 19 | 20 | def OLS(X, Y, W, **options): 21 | model = linear_model.LinearRegression(fit_intercept=False) 22 | model.fit(X, Y, sample_weight=W.reshape(-1)) 23 | return model 24 | 25 | def Lasso(X, Y, W, lam_backend=0.01, iters_train=1000, **options): 26 | model = linear_model.Lasso(alpha=lam_backend, fit_intercept=False, max_iter=iters_train) 27 | model.fit(X, Y, sample_weight=W.reshape(-1)) 28 | return model 29 | 30 | def Ridge(X, Y, W, lam_backend=0.01, iters_train=1000, **options): 31 | model = linear_model.Ridge(alpha=lam_backend, fit_intercept=False, max_iter=iters_train) 32 | model.fit(X, Y, sample_weight=W.reshape(-1)) 33 | return model 34 | 35 | 36 | def LogLogistic(X, duration_col, event_col, W, pen, **options): 37 | tmp = X[duration_col] 38 | tmp[tmp==0] = 0.0001 39 | llf = LogLogisticAFTFitter(penalizer=pen, fit_intercept=False).fit(X, duration_col=duration_col, event_col=event_col) 40 | return llf 41 | 42 | def 
Weibull(X, duration_col, event_col, W, pen, **options): 43 | tmp = X[duration_col] 44 | tmp[tmp==0] = 0.0001 45 | waf = WeibullAFTFitter(penalizer=pen, fit_intercept=False).fit(X, duration_col=duration_col, event_col=event_col) 46 | return waf 47 | def LogNormal(X, duration_col, event_col, W, pen, **options): 48 | tmp = X[duration_col] 49 | tmp[tmp==0] = 0.0001 50 | llf = LogNormalAFTFitter(penalizer=pen, fit_intercept=False).fit(X, duration_col=duration_col, event_col=event_col) 51 | return llf 52 | 53 | 54 | 55 | def Weighted_cox(X, duration_col, event_col, W, pen, **options): 56 | columns = X.columns 57 | all_X = np.concatenate((X, W), axis=1) 58 | 59 | all_X = pd.DataFrame(all_X, columns=list(columns)+["Weights"]) 60 | cph = CoxPHFitter(penalizer=pen) 61 | cph.fit(all_X, duration_col=duration_col, event_col=event_col, weights_col="Weights") 62 | 63 | return cph 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astor==0.8.1 2 | autograd==1.6.2 3 | autograd-gamma==0.5.0 4 | blessed==1.20.0 5 | Bottleneck==1.3.5 6 | Brotli==1.0.9 7 | certifi==2023.7.22 8 | cffi==1.15.1 9 | cgroups==0.1.0 10 | chardet==5.2.0 11 | charset-normalizer==2.0.4 12 | cloudpickle==3.0.0 13 | cmake==3.27.4.1 14 | colorama==0.4.6 15 | contextlib2==21.6.0 16 | contourpy==1.1.0 17 | cryptography==41.0.3 18 | cycler==0.11.0 19 | deepsurv==0.2.1 20 | dnspython==2.4.2 21 | ecos==2.0.12 22 | et-xmlfile==1.1.0 23 | exceptiongroup==1.2.0 24 | fairlearn==0.10.0 25 | filelock==3.11.0 26 | fonttools==4.42.1 27 | formulaic==0.6.4 28 | future==0.18.3 29 | gpustat==1.1.1 30 | graphlib-backport==1.0.3 31 | graphviz==0.8.4 32 | h5py==3.10.0 33 | hyperopt==0.1.2 34 | idna==3.4 35 | importlib-metadata==7.0.1 36 | importlib-resources==6.0.1 37 | iniconfig==2.0.0 38 | interface-meta==1.3.0 39 | Jinja2==3.1.2 40 | joblib==1.2.0 41 | json-tricks==3.17.3 42 | kiwisolver==1.4.5 43 | Lasagne==0.2.dev1 44 | lifelines==0.27.8 45 | lightgbm==4.3.0 46 | lit==16.0.6 47 | MarkupSafe==2.1.3 48 | matplotlib==3.7.2 49 | mkl-fft==1.3.8 50 | mkl-random==1.2.4 51 | mkl-service==2.4.0 52 | mpmath==1.3.0 53 | mxnet-mkl==1.6.0 54 | networkx==3.1 55 | nni==2.5 56 | numexpr==2.8.4 57 | numpy==1.20.3 58 | nvidia-cublas-cu11==11.10.3.66 59 | nvidia-cuda-cupti-cu11==11.7.101 60 | nvidia-cuda-nvrtc-cu11==11.7.99 61 | nvidia-cuda-runtime-cu11==11.7.99 62 | nvidia-cudnn-cu11==8.5.0.96 63 | nvidia-cufft-cu11==10.9.0.58 64 | nvidia-curand-cu11==10.2.10.91 65 | nvidia-cusolver-cu11==11.4.0.1 66 | nvidia-cusparse-cu11==11.7.4.91 67 | nvidia-ml-py==12.535.133 68 | nvidia-nccl-cu11==2.14.3 69 | nvidia-nvtx-cu11==11.7.91 70 | openpyxl==3.1.2 71 | osqp==0.6.3 72 | packaging==23.1 73 | pandas==2.0.3 74 | Pillow==10.0.0 75 | pip==23.2.1 76 | platformdirs==3.10.0 77 | pluggy==1.4.0 78 | pooch==1.7.0 79 | prettytable==3.9.0 80 | protobuf==5.26.0 81 | psutil==5.9.6 82 | pycparser==2.21 83 | pymongo==4.6.1 84 | pyOpenSSL==23.2.0 85 | pyparsing==3.0.9 86 | PySocks==1.7.1 87 | pytest==8.1.1 88 | python-dateutil==2.8.2 89 | PythonWebHDFS==0.2.3 90 | pytz==2023.3.post1 91 | PyYAML==6.0.1 92 | qdldl==0.1.7 93 | requests==2.31.0 94 | responses==0.24.1 95 | schema==0.7.5 96 | scikit-learn==1.3.0 97 | scikit-survival==0.22.1 98 | scipy==1.10.1 99 | seaborn==0.12.2 100 | setuptools==68.0.0 101 | simplejson==3.19.2 102 | six==1.16.0 103 | sklearn==0.0.post9 104 | sympy==1.12 105 | tensorboard-logger==0.1.0 106 | Theano==1.0.5 107 | 
threadpoolctl==2.2.0 108 | tomli==2.0.1 109 | torch==2.0.1 110 | torchaudio==2.0.2 111 | torchvision==0.15.2 112 | tqdm==4.66.1 113 | triton==2.0.0 114 | typeguard==4.1.2 115 | typing_extensions==4.7.1 116 | tzdata==2023.3 117 | ucimlrepo==0.0.3 118 | urllib3==1.26.18 119 | wcwidth==0.2.9 120 | websockets==12.0 121 | wheel==0.38.4 122 | whyshift==0.1.3 123 | wrapt==1.15.0 124 | xgboost==2.0.3 125 | zipp==3.11.0 126 | -------------------------------------------------------------------------------- /model/MLP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from torch.nn.modules.loss import MSELoss 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, input_units=10, hidden_units=[100, 100], output_units=1, initialization="uniform", bias=False, mask=None, device=None): 8 | super().__init__() 9 | if mask is None: 10 | self.mask = torch.ones(input_units, device=device) 11 | else: 12 | self.mask = torch.tensor(mask, dtype=torch.float, device=device) 13 | self.device = device 14 | #assert len(hidden_units) > 0 15 | units = [input_units, ] + hidden_units 16 | layerdict = OrderedDict() 17 | for i in range(len(units)-1): 18 | layerdict["layer%d"%i] = nn.Linear(units[i], units[i+1], bias=bias) 19 | layerdict["relu%d"%i] = nn.ReLU() 20 | self.features = nn.Sequential(layerdict) 21 | self.classifier = nn.Linear(units[-1], output_units, bias=bias) 22 | if initialization == "uniform": 23 | def weights_init(m): 24 | if isinstance(m, nn.Linear): 25 | m.weight.data.uniform_(-1, 1) 26 | if bias: 27 | m.bias.data.uniform_(-1, 1) 28 | self.apply(weights_init) 29 | 30 | def set_mask(self, mask): 31 | assert not (mask is None) 32 | self.mask = torch.tensor(mask, dtype=torch.float, device=self.device) 33 | 34 | def forward(self, X): 35 | X = X*self.mask 36 | X = self.features(X) 37 | return self.classifier(X) 38 | 39 | def predict(self, X): 40 | X = torch.tensor(X, dtype=torch.float, device=self.device) 41 | X = X*self.mask 42 | self.eval() 43 | Y_pred = self.forward(X) 44 | self.train() 45 | return Y_pred.cpu().detach().numpy() 46 | 47 | 48 | def train(model, X, Y, W=None, lr=0.05, num_iters = 1000, tol=1e-13, logger=None, device=None): 49 | model.to(device) 50 | X = torch.tensor(X, dtype=torch.float, device=device) 51 | Y = torch.tensor(Y, dtype=torch.float, device=device) 52 | W = torch.tensor(W, dtype=torch.float, device=device) 53 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 54 | 55 | loss_prev = 0.0 56 | for ite in range(num_iters): 57 | optimizer.zero_grad() 58 | Y_pred = model(X) 59 | loss = torch.matmul(W.T, (Y_pred-Y)**2) 60 | loss.backward() 61 | optimizer.step() 62 | if ite == 0 or (ite+1) % 1000 == 0: 63 | logger.debug("%d/%d: loss %.4f" % (ite, num_iters, loss.item())) 64 | if abs(loss-loss_prev) <= tol: 65 | break 66 | loss_prev = loss.data 67 | 68 | 69 | 70 | 71 | if __name__ == "__main__": 72 | 73 | model = MLP(hidden_units=[]) 74 | print(next(model.parameters()).device) 75 | -------------------------------------------------------------------------------- /clinical_data/breast_cancer/breast_test0_survival.csv: -------------------------------------------------------------------------------- 1 | Patient ID,Survival.months,Survival.status,Recurr.months,Recurr.status,Age at Diagnosis,Cohort,Lymph nodes examined positive,Mutation Count,Nottingham prognostic index,Tumor Size,Tumor Stage,Type of Breast Surgery_Breast 
Conserving,Cellularity_High,Cellularity_Low,Cellularity_Moderate,Chemotherapy_No,ER status measured by IHC_Positve,ER Status_Positive,Neoplasm Histologic Grade_1.0,Neoplasm Histologic Grade_2.0,Neoplasm Histologic Grade_3.0,HER2 status measured by SNP6_Gain,HER2 status measured by SNP6_Loss,HER2 status measured by SNP6_Neutral,HER2 Status_Negative,Hormone Therapy_Yes,Inferred Menopausal State_Pre,Primary Tumor Laterality_Right,PR Status_Negative,Radio Therapy_No,3-Gene classifier subtype_ER+/HER2- High Prolif,3-Gene classifier subtype_ER+/HER2- Low Prolif,3-Gene classifier subtype_ER-/HER2-,3-Gene classifier subtype_HER2+ 2 | MB-6080,54.1,1,27.07,1,0.7533247533247532,5.0,0.12195121951219512,0.06666666666666667,0.9309074159373558,0.16201117318435757,0.6666666666666667,0,1,0,0,1,1,1,0,0,1,0,0,1,1,1,0,1,0,0,1,0,0,0 3 | MB-6108,193.9666667,0,191.41,0,0.1665951665951665,5.0,0.0,0.022222222222222223,0.23445416858590512,0.09497206703910616,0.0,0,1,0,0,1,1,1,0,1,0,0,1,0,1,1,1,1,0,1,0,1,0,0 4 | MB-6113,124.5666667,1,122.34,1,0.1866151866151865,5.0,0.04878048780487805,0.06666666666666667,0.4656840165822202,0.10614525139664806,0.3333333333333333,0,1,0,0,0,0,1,0,1,0,1,0,0,0,0,1,0,1,1,0,0,0,1 5 | MB-6118,264.6,0,93.09,1,0.21092521092521088,5.0,0.0,0.15555555555555556,0.2330723169046522,0.07821229050279331,0.0,1,0,1,0,1,1,1,0,1,0,0,0,1,1,0,1,1,0,0,0,1,0,0 6 | MB-6143,255.6,0,252.24,0,0.4896324896324896,5.0,0.0,0.1111111111111111,0.4776600644864118,0.25139664804469275,0.3333333333333333,1,0,0,1,1,0,0,0,0,1,0,0,1,1,0,0,1,1,0,0,0,1,0 7 | MB-6183,79.13333333,1,78.09,0,0.5702845702845701,5.0,0.0,0.15555555555555556,0.23445416858590512,0.09497206703910616,0.0,0,1,0,0,1,1,1,0,1,0,0,0,1,1,0,0,0,0,1,1,0,0,0 8 | MB-6201,222.7,1,219.77,0,0.7556127556127555,5.0,0.0,0.08888888888888889,0.23076923076923078,0.05027932960893855,0.0,1,1,0,0,1,1,1,0,1,0,0,0,1,1,0,0,0,0,0,1,0,0,0 9 | MB-6208,221.2,0,218.29,0,0.4371514371514371,5.0,0.04878048780487805,0.15555555555555556,0.4656840165822202,0.10614525139664806,0.3333333333333333,1,1,0,0,1,1,1,0,1,0,0,0,1,1,1,0,0,1,0,1,0,0,0 10 | MB-6218,147.7666667,1,101.15,1,0.33333333333333326,5.0,0.0,0.13333333333333333,0.46660525103638867,0.11731843575418996,0.3333333333333333,0,1,0,0,1,1,1,0,0,1,1,0,0,1,0,1,0,1,1,0,1,0,0 11 | MB-6223,221.9333333,0,219.01,0,0.46432146432146426,5.0,0.04878048780487805,0.08888888888888889,0.6959926301243666,0.10614525139664806,0.3333333333333333,1,1,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,1,0,0,0,1,0 12 | MB-6224,35.46666667,1,26.51,1,0.5802945802945801,5.0,0.7560975609756098,0.06666666666666667,0.92169507139567,0.05027932960893855,0.3333333333333333,0,1,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,1,0 13 | MB-6225,117.8666667,0,116.32,0,0.6016016016016016,5.0,0.0,0.35555555555555557,0.23353293413173654,0.08379888268156425,0.0,0,1,0,0,1,1,1,0,1,0,0,0,1,1,0,0,1,0,1,0,1,0,0 14 | MB-6229,0.1,1,0.1,1,0.7002717002717002,5.0,0.0,0.08888888888888889,0.002763703362505754,0.07821229050279331,0.0,0,1,0,0,1,1,1,1,0,0,0,0,1,1,0,0,1,0,1,0,1,0,0 15 | MB-6237,105.2,1,103.82,0,0.7044187044187044,5.0,0.0,0.06666666666666667,0.4656840165822202,0.10614525139664806,0.0,0,1,0,0,1,0,0,0,0,1,1,0,0,1,0,0,1,1,1,0,0,1,0 16 | -------------------------------------------------------------------------------- /algorithm/SRDO.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.neural_network import MLPClassifier 4 | from sklearn.linear_model import LogisticRegression 5 | def column_wise_resampling(x, p_s, 
decorrelation_type="global", random_state=0, category_groups=None, **options):
6 |     """
7 |     Perform column-wise random resampling to break the joint distribution of p(x).
8 |     In practice, we can perform resampling without replacement (a.k.a. permutation) to retain all the data points of feature x_j.
9 |     Moreover, if the practitioner has some priors on which features should be permuted,
10 |     they can be passed through options by specifying 'sensitive_variables'; by default it contains all the features.
11 |     """
12 |     rng = np.random.RandomState(random_state)
13 |     n, p = x.shape
14 |     if 'sensitive_variables' in options:
15 |         sensitive_variables = options['sensitive_variables']
16 |     else:
17 |         sensitive_variables = [i for i in range(p)]
18 |     x_decorrelation = np.zeros([n, p])
19 |     if decorrelation_type == "global":
20 |         for i in sensitive_variables:
21 |             rand_idx = rng.permutation(n)
22 |             x_decorrelation[:, i] = x[rand_idx, i]
23 |     elif decorrelation_type == "group": # permute the first p_s columns jointly, the remaining columns independently
24 |         rand_idx = rng.permutation(n)
25 |         x_decorrelation[:, :p_s] = x[rand_idx, :p_s]
26 |         for i in range(p_s, p):
27 |             rand_idx = rng.permutation(n)
28 |             x_decorrelation[:, i] = x[rand_idx, i]
29 |     elif decorrelation_type == "category":
30 |         print("len", len(category_groups))
31 |         pre_j = 1
32 |         for i, j in enumerate(category_groups):
33 |             if j == pre_j:
34 |                 x_decorrelation[:, i] = x[:, i]
35 |                 pre_j = j
36 |                 continue
37 |             count = category_groups.count(j)
38 |             if count == 1:
39 |                 rand_idx = rng.permutation(n)
40 |                 x_decorrelation[:, i] = x[rand_idx, i]
41 |             else:
42 |                 print(i, count)
43 |                 rand_idx = rng.permutation(n)
44 |                 x_decorrelation[:, i:i+count] = x[rand_idx, i:i+count]
45 |             pre_j = j
46 | 
47 |     else:
48 |         raise NotImplementedError(decorrelation_type)
49 |     return x_decorrelation
50 | 
51 | def SRDO(x, p_s, decorrelation_type="global", solver = 'adam', hidden_layer_sizes = (100, 5), category_groups = None, max_iter = 500, random_state = 3):
52 |     """
53 |     Calculate new sample weights by density ratio estimation:
54 |                q(x)   P(x belongs to q(x) | x)
55 |        w(x) = ---- = ------------------------
56 |                p(x)   P(x belongs to p(x) | x)
57 |     """
58 |     n, p = x.shape
59 |     x_decorrelation = column_wise_resampling(x, p_s, decorrelation_type, category_groups = category_groups, random_state = random_state)
60 |     P = pd.DataFrame(x)
61 |     Q = pd.DataFrame(x_decorrelation)
62 |     corr_matrix = np.corrcoef(x, rowvar=False)
63 |     abs_corr_matrix = np.abs(corr_matrix)
64 |     sum_abs_corr = np.sum(abs_corr_matrix) - np.sum(np.diag(abs_corr_matrix))
65 |     P['src'] = 1 # 1 means source distribution
66 |     Q['src'] = 0 # 0 means target distribution
67 |     Z = pd.concat([P, Q], ignore_index=True, axis=0)
68 |     labels = Z['src'].values
69 |     Z = Z.drop('src', axis=1).values
70 |     P, Q = P.values, Q.values
71 |     # train a multi-layer perceptron to classify the source and target distribution
72 |     #clf = LogisticRegression(random_state=0, C=0.1)
73 |     clf = MLPClassifier(solver=solver, hidden_layer_sizes=hidden_layer_sizes, max_iter=max_iter, random_state=random_state)
74 |     clf.fit(Z, labels)
75 |     proba = clf.predict_proba(Z)[:len(P), 1]
76 |     weights = (1./proba) - 1. # calculate sample weights by density ratio
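    # The classifier above estimates P(src = 1 | x), i.e. how much x resembles the
    # original (correlated) design rather than the column-shuffled one, so
    # (1 - proba) / proba approximates the density ratio q(x)/p(x) used as the weight.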
77 |     weights /= np.sum(weights) # normalize: the weights sum to 1
78 |     weights = np.reshape(weights, [n,1])
79 |     return weights
80 | 
--------------------------------------------------------------------------------
/model/STG.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.optim as optim
5 | import numpy as np
6 | from utils import pretty
7 | from torch.autograd import grad
8 | 
9 | class LinearRegression(nn.Module):
10 |     def __init__(self, input_dim, output_dim=1):
11 |         super(LinearRegression, self).__init__()
12 |         self.linear = nn.Linear(input_dim, output_dim, bias=False)
13 |         self.weight_init()
14 | 
15 |     def weight_init(self):
16 |         torch.nn.init.xavier_uniform_(self.linear.weight)
17 | 
18 |     def forward(self, x):
19 |         return self.linear(x)
20 | 
21 | # Feature selection part
22 | class FeatureSelector(nn.Module):
23 |     def __init__(self, input_dim, sigma):
24 |         super(FeatureSelector, self).__init__()
25 |         self.mu = torch.nn.Parameter(0.00 * torch.randn(input_dim, ), requires_grad=True)
26 |         self.noise = torch.randn(self.mu.size())
27 |         self.sigma = sigma
28 |         self.input_dim = input_dim
29 | 
30 |     def renew(self):
31 |         self.mu = torch.nn.Parameter(0.00 * torch.randn(self.input_dim, ), requires_grad=True)
32 |         self.noise = torch.randn(self.mu.size())
33 | 
34 |     def forward(self, prev_x):
35 |         z = self.mu + self.sigma * self.noise.normal_() * self.training
36 |         stochastic_gate = self.hard_sigmoid(z)
37 |         new_x = prev_x * stochastic_gate
38 |         return new_x
39 | 
40 |     def hard_sigmoid(self, x):
41 |         return torch.clamp(x + 0.5, 0.0, 1.0)
42 | 
43 |     def regularizer(self, x):
44 |         return 0.5 * (1 + torch.erf(x / math.sqrt(2)))
45 | 
46 | 
47 | class STG:
48 |     def __init__(self, input_dim, output_dim, sigma=1.0, lam=0.1, hard_sum = 1.0):
49 |         self.backmodel = LinearRegression(input_dim, output_dim)
50 |         self.loss = nn.MSELoss()
51 |         self.featureSelector = FeatureSelector(input_dim, sigma)
52 |         self.reg = self.featureSelector.regularizer
53 |         self.lam = lam
54 |         self.mu = self.featureSelector.mu
55 |         self.sigma = self.featureSelector.sigma
56 |         #self.alpha = alpha
57 |         self.optimizer = optim.Adam([{'params': self.backmodel.parameters(), 'lr': 1e-3},
58 |                                      {'params': self.mu, 'lr': 3e-4}])
59 |         self.hard_sum = hard_sum
60 |         self.input_dim = input_dim
61 | 
62 |     def renew(self):
63 |         self.featureSelector.renew()
64 |         self.mu = self.featureSelector.mu
65 |         self.backmodel.weight_init()
66 |         self.optimizer = optim.Adam([{'params': self.backmodel.parameters(), 'lr': 1e-3},
67 |                                      {'params': self.mu, 'lr': 3e-4}])
68 | 
69 | 
70 |     def pretrain(self, X, Y, pretrain_epoch=100):
71 |         pre_optimizer = optim.Adam([{'params': self.backmodel.parameters(), 'lr': 1e-3}])
72 |         for i in range(pretrain_epoch):
73 |             self.optimizer.zero_grad()
74 |             pred = self.backmodel(X)
75 |             loss = self.loss(pred, Y.reshape(pred.shape))
76 |             loss.backward()
77 |             pre_optimizer.step()
78 | 
79 | 
80 |     def get_gates(self):
81 |         return self.mu+0.5
82 | 
83 |     def get_ratios(self):
84 |         return self.reg((self.mu + 0.5) / self.sigma)
85 | 
86 |     def get_params(self):
87 |         return self.backmodel.linear.weight
88 | 
89 |     def train(self, X, Y, W, epochs):
90 |         X = torch.tensor(X, dtype=torch.float)
91 |         Y = torch.tensor(Y, dtype=torch.float)
92 |         W = torch.tensor(W, dtype=torch.float)
93 |         self.renew()
94 |         self.pretrain(X, Y, 3000)
95 |         for epoch in range(1,epochs+1):
96 |             self.optimizer.zero_grad()
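            # Weighted ERM term under the decorrelation weights W, plus the expected
            # number of open gates (the Gaussian-CDF regularizer) squared, which
            # drives the stochastic gates toward a sparse feature selection.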
97 |             Y_pred = self.backmodel(self.featureSelector(X))
98 |             loss_erm = torch.matmul(W.T, (Y_pred-Y)**2)
99 |             reg = torch.sum(self.reg((self.mu + 0.5) / self.sigma))
100 |             loss = loss_erm + self.lam * reg**2
101 |             loss.backward()
102 |             self.optimizer.step()
103 |             if epoch % 1000 == 0:
104 |                 print("Epoch %d | Loss = %.4f | Ratio = %s | Theta = %s" %
105 |                       (epoch, loss, pretty(self.get_ratios()), pretty(self.get_params())))
106 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from math import ceil
4 | import torch
5 | from model.linear import get_algorithm_class
6 | from sklearn.metrics import mean_squared_error
7 | beta_base = [1/3, -2/3, 1, -1/3, 2/3, -1,] # hard-coded coefficients
8 | 
9 | def setup_seed(seed):
10 |     torch.manual_seed(seed)
11 |     torch.cuda.manual_seed_all(seed)
12 |     np.random.seed(seed)
13 |     random.seed(seed)
14 |     torch.backends.cudnn.deterministic = True
15 |     torch.backends.cudnn.benchmark = False
16 | 
17 | def weighted_cov(X, W):
18 |     '''
19 |     X: numpy array, (n, p)
20 |     W: numpy array, (n, 1), sum up to 1
21 |     '''
22 |     X_bar = np.matmul(X.T, W) # shape: (p, 1)
23 |     return np.matmul(X.T, W*X) - np.matmul(X_bar, X_bar.T)
24 | 
25 | def weighted_cov_torch(X, Y=None, W=None):
26 |     if Y is None:
27 |         X_bar = torch.matmul(X.T, W) # shape: (p, 1)
28 |         return torch.matmul(X.T, W*X) - torch.matmul(X_bar, X_bar.T)
29 |     else:
30 |         X_bar = torch.matmul(X.T, W) # shape: (p, 1)
31 |         Y_bar = torch.matmul(Y.T, W) # shape: (p, 1)
32 |         return torch.matmul(X.T, W*Y) - torch.matmul(X_bar, Y_bar.T)
33 | 
34 | 
35 | def weighted_corr(X, Y=None, W=None):
36 |     '''
37 |     X: numpy array, (n, p)
38 |     W: numpy array, (n, 1), sum up to 1
39 |     '''
40 |     if Y is None:
41 |         X_bar = np.matmul(X.T, W) # shape: (p, 1)
42 |         X_2_bar = np.matmul((X**2).T, W) # shape: (p, 1)
43 |         varX = X_2_bar - X_bar**2
44 |         return (np.matmul(X.T, W*X) - np.matmul(X_bar, X_bar.T)) / np.sqrt(np.matmul(varX, varX.T))
45 |     else:
46 |         X_bar = np.matmul(X.T, W) # shape: (p, 1)
47 |         Y_bar = np.matmul(Y.T, W)
48 |         X_2_bar = np.matmul((X**2).T, W) # shape: (p, 1)
49 |         Y_2_bar = np.matmul((Y**2).T, W)
50 |         varX = X_2_bar - X_bar**2
51 |         varY = Y_2_bar - Y_bar**2
52 |         return (np.matmul(X.T, W*Y) - np.matmul(X_bar, Y_bar.T)) / np.sqrt(np.matmul(varX, varY.T))
53 | 
54 | def pretty(vector):
55 |     if type(vector) is list:
56 |         vlist = vector
57 |     elif type(vector) is np.ndarray:
58 |         vlist = vector.reshape(-1).tolist()
59 |     else:
60 |         vlist = vector.view(-1).tolist()
61 |     return "[" + ", ".join("{:+.4f}".format(vi) for vi in vlist) + "]"
62 | 
63 | def get_cov_mask(select_ratio):
64 |     select_ratio = np.reshape(select_ratio, (-1, 1))
65 |     cov_mask = 1-np.matmul(select_ratio, select_ratio.T)
66 |     return cov_mask
67 | 
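# Worked mini-example for get_cov_mask (hypothetical values): with
# select_ratio = [1, 0], 1 - ratio @ ratio.T = [[0, 1], [1, 0]], so a pair of
# confidently selected features is exempt from further decorrelation, while any
# pair involving an unselected candidate is still penalized in DWR's balance loss.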
68 | 
69 | def get_beta_s(p_s):
70 |     beta_s = beta_base * (ceil(p_s/len(beta_base)))
71 |     return np.array(beta_s[:p_s])
72 | 
73 | def get_beta_collinearity(p):
74 |     beta_base = [1/5, -2/5, 3/5, -4/5, 1, -1/5, 2/5, -3/5, 4/5, -1,] # hard-coded coefficients
75 |     beta = beta_base * (ceil(p/len(beta_base)))
76 |     return np.array(beta[:p])
77 | 
78 | def get_expname(args):
79 |     order = "-order=%d"%args.order if args.reweighting == "DWR" else ""
80 |     MLP_gen = "-MLP_gen=%s" % "_".join([str(j) for j in args.hidden_units_gen]) if args.true_func == "MLP" else ""
81 |     MLP_backend = "-MLP_backend=%s" % "_".join([str(j) for j in args.hidden_units_backend]) if args.backend == "MLP" else ""
82 |     misspe = "-misspe=%s"%args.misspe if args.true_func == "linear" else ""
83 |     paradigm = "" if args.true_func == "linear" else "-paradigm=%s"%args.paradigm
84 |     backend = "fs=%s"%args.fs_type if args.paradigm == "fs" else "regr=%s" % args.backend
85 |     return "p=%d%s%s-n=%d-Vb_ratio=%.1f-%s%s-rtrain=%.1f-%s-%s-spurious=%s-%s%s%s"%(args.p, MLP_gen, MLP_backend, args.n, args.Vb_ratio, args.mode, misspe, args.r_train, args.reweighting, backend, args.spurious, args.decorrelation_type, order, paradigm)
86 | 
87 | def calc_var(beta, X, fx):
88 |     beta = np.reshape(beta, (-1, 1))
89 |     linear_term = np.matmul(X, beta)
90 |     nonlinear_term = fx - linear_term
91 |     return np.var(linear_term), np.var(nonlinear_term), np.var(fx)
92 | 
93 | def gen_Cov(p, rho):
94 |     cov = np.ones((p, p))*rho
95 |     for i in range(p):
96 |         cov[i, i] = 1
97 |     return cov
98 | 
99 | def gen_interaction_terms(X):
100 |     X_mean = X.mean(axis=0)
101 |     n, p = X.shape
102 |     p_ia = (p*(p-1))//2
103 |     X_ia = np.zeros((n, p_ia))
104 |     cnt = 0
105 |     for i in range(p):
106 |         for j in range(i+1, p):
107 |             X_ia[:, cnt] = (X[:,i]-X_mean[i])*(X[:, j]-X_mean[j])
108 |             cnt += 1
109 |     assert cnt == p_ia
110 |     return X_ia
111 | 
112 | def GridSearch(args, model_name, X_train, Y_train, W, X_val, Y_val):
113 |     model_func = get_algorithm_class(model_name)
114 |     best_MSE = 1e10
115 |     best_lam = -1.0
116 |     best_model = None
117 |     if model_name != "OLS":
118 |         for lam in args.lambda_grid:
119 |             model = model_func(X_train, Y_train, W, lam_backend=lam) # the backend models take (X, Y, W, ...)
120 |             MSE = mean_squared_error(Y_val, model.predict(X_val))
121 |             if MSE < best_MSE:
122 |                 best_MSE = MSE
123 |                 best_lam = lam
124 |                 best_model = model
125 |     else:
126 |         best_model = model_func(X_train, Y_train, W)
127 |     return best_model, best_lam
128 | 
129 | def BV_analysis(beta_hat_array, beta):
130 |     beta_hat_mean = np.mean(beta_hat_array, axis=0)
131 |     bias = np.sum((beta_hat_mean-beta)**2)
132 |     var = np.sum(np.diag(np.cov(beta_hat_array, rowvar=False)))
133 |     return bias, var
134 | 
135 | if __name__ == "__main__":
136 |     X = torch.randn(2000, 10)
137 |     print(weighted_cov(X.numpy(), W=np.ones((X.shape[0], 1))/X.shape[0]))
138 |     print(weighted_cov_torch(X, W=torch.ones((X.shape[0], 1))/X.shape[0]))
139 |     print(np.cov(X.numpy().T))
--------------------------------------------------------------------------------
/clinical_data/lung_cancer/visualization_breast.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | from sklearn.decomposition import PCA
4 | import numpy as np
5 | import seaborn as sns
6 | from sklearn.manifold import TSNE
7 | from random import sample
8 | from matplotlib.patches import Ellipse
9 | import matplotlib.transforms as transforms
10 | import matplotlib.lines as mlines
11 | from matplotlib.patches import FancyArrowPatch
12 | #training_pd_data = pd.read_csv('../nature_survival.csv', index_col=0)
13 | #test0_pd_data = pd.read_csv('../cell_survival.csv', index_col=0)
14 | #test1_pd_data = pd.read_csv('../NC_survival.csv', index_col=0)
15 | #training_pd_data = pd.read_csv('../breast_cancer/breast_train_survival.csv', index_col=0)
16 | #test0_pd_data = pd.read_csv('../breast_cancer/breast_test1_survival.csv', index_col=0)
17 | #test1_pd_data = pd.read_csv('../breast_cancer/breast_test2_survival.csv', index_col=0)
18 | 
19 | #training_data = pd.read_csv('./train_pd.csv', index_col=0)
20 | test0_data = pd.read_csv('./test4_pd_visual.csv', index_col=0)
21 | test1_data = pd.read_csv('./test5_pd_visual.csv', index_col=0)
22 | #test3_data = pd.read_csv('../medical/test3.csv', index_col=0)
23 | 
24 | #test4_data = pd.read_csv('../medical/test4.csv', index_col=0)
25 | #test5_data = pd.read_csv('../medical/test5.csv', index_col=0)
26 | #test6_data = pd.read_csv('../medical/test6.csv', index_col=0)
27 | #test7_data = pd.read_csv('../medical/test7.csv', index_col=0)
28 | 
29 | #training_pd_data = training_pd_data.drop(['Survival.months', 'Survival.status', 'Cohort'], axis=1)
30 | 
31 | #test0_pd_data = test0_pd_data.drop(['Survival.months', 'Survival.status', 'Cohort'], axis=1)
32 | #test1_pd_data = test1_pd_data.drop(['Survival.months', 'Survival.status', 'Cohort'], axis=1)
33 | 
34 | 
35 | #training_data = training_data.drop(['2', "Reccur.status", "Reccur.months"], axis=1)
36 | #test0_data = test0_data.drop(["Survival.status", "Survival.months"], axis=1)
37 | #test1_data = test1_data.drop(["Survival.status", "Survival.months"], axis=1)
38 | #test0_data.to_csv("./central.csv", index=False)
39 | #test1_data.to_csv("./peripheral.csv", index=False)
40 | #train_np = np.array(training_data[training_data.columns[:-2]])
41 | 
42 | def confidence_ellipse(x, y, ax, n_std=2.4477, facecolor='none', **kwargs):
43 |     """
44 |     Create a patch representing the n-standard-deviation covariance ellipse of the given data.
45 |     n_std: the number of standard deviations corresponding to a 95% confidence region; for 2D data, 2.4477 standard deviations cover about 95% of the points.
46 |     """
47 |     if x.size != y.size:
48 |         raise ValueError("x and y must be the same size")
49 | 
50 |     cov = np.cov(x, y)
51 |     pearson = cov[0, 1]/np.sqrt(cov[0, 0] * cov[1, 1]) # correlation coefficient
52 |     # use the correlation to derive the rotation angle and semi-axis lengths of the ellipse
53 |     rad_x = np.sqrt(1 + pearson)
54 |     rad_y = np.sqrt(1 - pearson)
55 |     ellipse_radius_x = np.sqrt(1 + pearson) * n_std
56 |     ellipse_radius_y = np.sqrt(1 - pearson) * n_std
57 |     ellipse = Ellipse((0, 0), width=ellipse_radius_x * 2, height=ellipse_radius_y * 2,
58 |                       facecolor=facecolor, **kwargs)
59 | 
60 |     # compute the rotation and scaling of the ellipse
61 |     scale_x = np.sqrt(cov[0, 0]) * n_std
62 |     scale_y = np.sqrt(cov[1, 1]) * n_std
63 |     mean_x = np.mean(x)
64 |     mean_y = np.mean(y)
65 | 
66 |     transf = transforms.Affine2D()\
67 |         .rotate_deg(45)\
68 |         .scale(scale_x, scale_y)\
69 |         .translate(mean_x, mean_y)
70 | 
71 |     ellipse.set_transform(transf + ax.transData)
72 |     return ax.add_patch(ellipse)
73 | 
74 | test0_np = np.array(test0_data[test0_data.columns[:-4]]) # drop the last four (outcome) columns
75 | test1_np = np.array(test1_data[test1_data.columns[:-4]])
76 | 
77 | 
78 | pca = PCA(n_components=2)
79 | tsne = TSNE(n_components=2, random_state=3)
80 | all_data = np.concatenate((test0_np, test1_np)) # embed both cohorts in a single t-SNE fit; separate fits would not be comparable
81 | all_data = np.nan_to_num(all_data)
82 | all_data = tsne.fit_transform(all_data)
83 | 
84 | test1 = all_data[:test0_np.shape[0]]
85 | test2 = all_data[test0_np.shape[0]:test0_np.shape[0]+test1_np.shape[0]]
86 | 
87 | 
88 | print(test1.shape)
89 | print(test2.shape)
90 | 
91 | camp2 = sns.color_palette("Set2")
92 | colors1 = sample(camp2, 4)
93 | fig, ax = plt.subplots()
94 | 
95 | colors1 = ['#F39B7FCC', '#91D1c2cC', '#ADB6B6B2', "#925E9FB2"]
96 | colors2 = ['#F39B7F50', '#91D1c250', '#ADB6B650', "#925E9F50"]
97 | 
98 | #plt.scatter(train[:, 0], train[:, 1], color=colors1[0], label=f'Cohort1')
99 | ax.scatter(test1[:, 0], test1[:, 1], color=colors1[1], s = 1, label=f'Cohort2')
100 | confidence_ellipse(test1[:, 0], test1[:, 1], ax, facecolor=colors2[1], n_std=1.5, edgecolor=colors1[1])
101 | 
102 | ax.scatter(test2[:, 0], test2[:, 1], color=colors1[2], s = 1, label=f'Cohort3')
103 | confidence_ellipse(test2[:, 0], test2[:, 1], ax, facecolor=colors2[2], n_std=1.5, edgecolor=colors1[2])
104 | 
105 | for spine in ax.spines.values():
106 |     spine.set_visible(False)
107 | 
108 | x_start, x_end = np.min(all_data[:, 0]), np.max(all_data[:, 0])
109 | y_start, y_end = np.min(all_data[:, 1]), np.max(all_data[:, 1])
110 | x_mid = (x_end - x_start) / 2
111 | y_mid = (y_end - y_start) / 2
112 | 
113 | ax.add_patch(FancyArrowPatch((x_start-8.3, y_start-8.7), (x_start-8.3, y_start+y_mid/2),
114 |                              arrowstyle='->', mutation_scale=10, color='k'))
115 | ax.add_patch(FancyArrowPatch((x_start-8.6, y_start-8.4), (x_start+x_mid/2-1, y_start-8.4),
116 |                              arrowstyle='->', mutation_scale=10, color='k'))
117 | #ax.add_patch(FancyArrowPatch((0, 0), (1, 0), transform=ax.transAxes, arrowstyle="->", color='k'))
118 | # add the y axis
119 | #ax.add_patch(FancyArrowPatch((0, 0), (0, 1), transform=ax.transAxes, arrowstyle="->", color='k'))
120 | 
121 | # remove the ticks
122 | ax.set_xticks([])
123 | ax.set_yticks([])
124 | # set the plot limits so the arrows do not run outside the figure
125 | legend_handle1 = mlines.Line2D([], [], color=colors1[1], marker='o', linestyle='None',
126 |                                markersize=10, label='Location (Central)')
127 | legend_handle2 = mlines.Line2D([], [], color=colors1[2], marker='o', linestyle='None',
128 |                                markersize=10, label='Location (Peripheral)')
129 | 
130 | 
131 | 
132 | #ax.set_xlim(-40, 30)
133 | #ax.set_ylim(-40, 40)
134 | plt.text(x_start-10, y_start-4, 'tSNE2', ha='center', rotation=90, fontsize=12)
135 | plt.text(x_start, y_start-11, 'tSNE1', ha='center', fontsize=12)
136 | #ax.set_xlabel("tSNE1")
137 | #ax.set_ylabel("tSNE2")
138 | plt.legend(handles=[legend_handle1, legend_handle2], loc='lower right', bbox_to_anchor=(1.2, 0), frameon=False, handletextpad=1)
139 | plt.savefig('./lung_cancer.pdf', dpi=400, bbox_inches = 'tight')
140 | plt.show()
141 | 
--------------------------------------------------------------------------------
/data/selection_bias.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import ceil
3 | from sklearn import linear_model
4 | from sklearn.metrics import mean_squared_error
5 | import random
6 | import seaborn as sns
7 | import matplotlib.pyplot as plt
8 | from model.MLP import MLP
9 | import torch
10 | import time
11 | from utils import gen_Cov, get_beta_s
12 | 
13 | def setup_seed(seed):
14 |     np.random.seed(seed)
15 |     random.seed(seed)
16 | 
17 | def _gen_data(p=10, r=1.7, n_single=2000, V_ratio=0.5, Vb_ratio=0.1, gener_method='cox_exp', true_func="linear", mlp=None, mode="S_|_V", misspe="poly", mms_strength=1.0, corr_s=0.9, corr_v=0.1, spurious="nonlinear", noise_variance=0.3, device=None, **options):
18 |     '''
19 |     p: dim of input X
20 |     r: bias rate
21 |     n: number of samples
22 |     '''
23 |     # dim of S, V, V_b
24 |     p_v = int(p*V_ratio)
25 |     p_s = p-p_v
26 |     p_b = int(p*Vb_ratio)
27 | 
28 |     # generate covariates
29 |     Z = np.random.randn(n_single, p)
30 |     S = np.zeros((n_single, p_s))
31 |     V = np.zeros((n_single, p_v))
32 |     if mode == "S_|_V":
33 |         V = np.random.randn(n_single, p_v)
34 |         for i in range(p_s):
35 |             S[:, i] = 0.8*Z[:, i] + 0.2*Z[:, i+1] # hard-coding
36 |     elif mode == "S->V":
37 |         for i in range(p_s):
38 |             S[:, i] = 0.8*Z[:, i] + 0.2*Z[:, i+1]
39 |         for j in range(p_v):
40 |             V[:, j] = 0.8*S[:, j] + 0.2*S[:, (j+1)%(p_s)] + np.random.randn(n_single)
41 |     elif mode == "V->S":
42 |         V = np.random.randn(n_single, p_v)
43 |         for j in range(p_s):
44 |             S[:, j] = 0.2*V[:, j] + 0.8*V[:, (j+1)%(p_v)] + np.random.randn(n_single)
45 |     elif mode == "collinearity":
46 |         Sigma_s = gen_Cov(p_s, corr_s)
47 |         Sigma_v = gen_Cov(p_v, corr_v)
48 |         S = np.random.multivariate_normal([0]*p_s, Sigma_s, n_single)
49 |         V = 
np.random.multivariate_normal([0]*p_v, Sigma_v, n_single) 50 | else: 51 | raise NotImplementedError 52 | 53 | # generate f(S) 54 | if true_func == "linear": 55 | beta_s = get_beta_s(p_s) 56 | beta_s = np.reshape(beta_s, (-1, 1)) 57 | linear_term = np.matmul(S, beta_s) 58 | if misspe == "poly": # hard-coding: S1·S2·S3 59 | nonlinear_term = np.reshape(np.prod(S[:,:3], axis=1), (-1, 1)) 60 | elif misspe == "exp": # hard-coding: exp(S1·S2·S3) 61 | nonlinear_term = np.reshape(np.exp(np.prod(S[:,:3], axis=1)), (-1, 1)) 62 | elif misspe == "None": 63 | nonlinear_term = 0 64 | else: 65 | raise NotImplementedError 66 | fs = linear_term + mms_strength*nonlinear_term 67 | elif true_func == "MLP": 68 | fs = mlp(torch.tensor(S, dtype=torch.float, device=device)).detach().cpu().numpy() 69 | elif true_func == "poly": 70 | fs = np.reshape(np.prod(S, axis=1), (-1, 1)) 71 | elif true_func == "exp": 72 | fs = np.reshape(np.exp(np.prod(S, axis=1)), (-1, 1)) 73 | else: 74 | raise NotImplementedError 75 | 76 | # generate spurious correlation 77 | if spurious == "nonlinear": 78 | D = np.abs(fs-r/abs(r)*V[:,-p_b:]) # dim: (n, p_b), select the last p_b dim of V as V_b 79 | elif spurious == "linear": 80 | D = np.abs(linear_term-r/abs(r)*V[:,-p_b:]) 81 | else: 82 | raise NotImplementedError 83 | Pr = np.power(abs(r), -5*np.sum(D, axis=1)) # probability of being selected for certain samples 84 | select = np.random.uniform(size=Pr.shape[0]) < Pr 85 | # select 86 | S = S[select, :] 87 | V = V[select, :] 88 | X = np.concatenate((S, V), axis=1) 89 | fs = fs[select, :] 90 | 91 | lambda_param = fs #+ np.random.randn(*fs.shape)*np.sqrt(noise_variance) 92 | lambda_param = lambda_param.reshape(lambda_param.shape[0]) 93 | 94 | 95 | U = np.random.uniform(0, 1, size=lambda_param.shape[0]) 96 | #noise = (np.random.rand(*lambda_param.shape)*np.sqrt(noise_variance)).reshape(-1) 97 | if gener_method == "cox_exp": 98 | tmp = np.exp(-lambda_param).reshape(-1) 99 | noise = (np.random.rand(*lambda_param.shape)*np.sqrt(noise_variance)).reshape(-1) 100 | Y = 1/0.5*(-np.log(U)*tmp)# + noise 101 | elif gener_method == "cox_weibull": 102 | tmp = np.exp(-lambda_param).reshape(-1) 103 | Y = np.power(-1/0.5*np.log(U)*tmp, 4) 104 | elif gener_method == "cox_Gompertz": 105 | Y = 1/2 * np.log(1-np.log(U)/np.exp(lambda_param)) 106 | elif gener_method == "exp_T": 107 | noise = np.random.exponential(scale=1, size=lambda_param.shape[0]) 108 | Y = np.exp(lambda_param) + noise 109 | elif gener_method == "log_T": 110 | noise = np.random.normal(0, 0.5, size=lambda_param.shape[0]) 111 | Y = np.exp(lambda_param + noise) 112 | elif gener_method == "poly": 113 | #lambda_param = lambda_param + abs(min(lambda_param)) + 0.001 114 | select = lambda_param > 0 115 | lambda_param = lambda_param[select] 116 | U = np.random.uniform(0, 1, size=lambda_param.shape[0]) 117 | S = S[select, :] 118 | V = V[select, :] 119 | X = np.concatenate((S, V), axis=1) 120 | Y = -np.log(U)/lambda_param 121 | return X, S, V, fs, Y 122 | 123 | def gen_selection_bias_data(args): 124 | n_total = args["n"] 125 | n_cur = 0 126 | S_list = [] 127 | V_list = [] 128 | fs_list = [] 129 | Y_list = [] 130 | while n_cur < n_total: 131 | _, S, V, fs, Y = _gen_data(n_single=n_total, **args) 132 | S_list.append(S) 133 | V_list.append(V) 134 | fs_list.append(fs) 135 | Y_list.append(Y) 136 | n_cur += Y.shape[0] 137 | S = np.concatenate(S_list, axis=0)[:n_total] 138 | V = np.concatenate(V_list, axis=0)[:n_total] 139 | fs = np.concatenate(fs_list, axis=0)[:n_total] 140 | Y = np.concatenate(Y_list, 
axis=0)[:n_total]
141 |     X = np.concatenate((S, V), axis=1)
142 | 
143 |     return X, S, V, fs, Y
144 | 
145 | 
146 | def data_split(X, split_ratio=0.8):
147 |     p_split = int(len(X)*split_ratio)
148 |     return X[:p_split], X[p_split:]
149 | 
150 | 
151 | 
152 | if __name__ == "__main__":
153 |     setup_seed(7)
154 |     X, S, V, fs, Y = gen_selection_bias_data({"p": 10, "r": 1.7, "n": 1000, "mode": "S->V", "misspe": "poly"}) # the generator expects a dict of options
155 |     print(S.shape)
156 |     print(V.shape)
157 |     print(fs.shape)
158 |     print(Y.shape)
159 |     corr = np.corrcoef(np.concatenate((X, fs, Y.reshape(-1, 1)), axis=1).T) # Y is 1-D, so reshape before stacking
160 |     ax = sns.heatmap(corr, cmap="YlGnBu")
161 |     plt.savefig("test.png")
162 | 
163 |     regr = linear_model.LinearRegression()
164 |     X_train, X_test = data_split(X)
165 |     Y_train, Y_test = data_split(Y)
166 |     regr.fit(X_train, Y_train)
167 |     Y_pred = regr.predict(X_test)
168 |     print(Y_pred.shape)
169 |     print("Coefficients: ", regr.coef_)
170 |     print("Mean squared error: %.2f" % mean_squared_error(Y_test, Y_pred))
171 | 
--------------------------------------------------------------------------------
/exp_svi.py:
--------------------------------------------------------------------------------
1 | from data.selection_bias import gen_selection_bias_data
2 | from algorithm.DWR import DWR
3 | from algorithm.SRDO import SRDO
4 | from model.linear import get_algorithm_class
5 | from metrics import get_metric_class
6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis
7 | from Logger import Logger
8 | from model.STG import STG
9 | 
10 | from sklearn.metrics import mean_squared_error
11 | import numpy as np
12 | import argparse
13 | import os
14 | import torch
15 | from collections import defaultdict as dd
16 | 
17 | def get_args():
18 |     parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19 | 
20 |     # data generation
21 |     parser.add_argument("--p", type=int, default=10, help="Input dim")
22 |     parser.add_argument("--n", type=int, default=2000, help="Sample size")
23 |     parser.add_argument("--V_ratio", type=float, default=0.5)
24 |     parser.add_argument("--Vb_ratio", type=float, default=0.1)
25 |     parser.add_argument("--true_func", choices=["linear",], default="linear")
26 |     parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity")
27 |     parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly")
28 |     parser.add_argument("--corr_s", type=float, default=0.9)
29 |     parser.add_argument("--corr_v", type=float, default=0.1)
30 |     parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecification strength")
31 |     parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear")
32 |     parser.add_argument("--r_train", type=float, default=2.5, help="Training bias rate")
33 |     parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3])
34 |     parser.add_argument("--noise_variance", type=float, default=0.3)
35 | 
36 |     # frontend reweighting
37 |     parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR")
38 |     parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global")
39 |     parser.add_argument("--order", type=int, default=1)
40 |     parser.add_argument("--iters_balance", type=int, default=20000)
41 | 
42 |     # backend model
43 |     parser.add_argument("--backend", choices=["OLS"], default="OLS")
44 |     parser.add_argument("--paradigm", 
choices=["fs",], default="fs") 45 | parser.add_argument("--iters_train", type=int, default=1000) 46 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 47 | parser.add_argument("--fs_type", choices=["SVI"], default="SVI") 48 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 49 | parser.add_argument("--mask_threshold", type=float, default=0.2) 50 | parser.add_argument("--lam_STG", type=float, default=3) 51 | parser.add_argument("--sigma_STG", type=float, default=0.1) 52 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 53 | parser.add_argument("--bv_analysis", action="store_true") 54 | # SVI 55 | parser.add_argument("--epoch_algorithm", type=int, default=10) 56 | parser.add_argument("--period_MA", type=int, default=3) 57 | 58 | 59 | # others 60 | parser.add_argument("--seed", type=int, default=3) 61 | parser.add_argument("--times", type=int, default=10) 62 | parser.add_argument("--result_dir", default="results") 63 | 64 | return parser.parse_args() 65 | 66 | def main(args, round, logger): 67 | setup_seed(args.seed + round) 68 | p = args.p 69 | p_v = int(p*args.V_ratio) 70 | p_s = p-p_v 71 | n = args.n 72 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | oracle_mask = [True,]*p_s + [False,]*p_v 74 | 75 | # generate train data 76 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}}) 77 | 78 | beta_s = get_beta_s(p_s) 79 | beta_v = np.zeros(p_v) 80 | beta = np.concatenate([beta_s, beta_v]) 81 | 82 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train) 83 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var)) 84 | 85 | # generate test data 86 | test_data = dict() 87 | for r_test in args.r_list: 88 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}}) 89 | test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test) 90 | 91 | results = dict() 92 | 93 | cov_mask = get_cov_mask(np.zeros(p)) 94 | select_ratio_MA = [] 95 | for epoch in range(args.epoch_algorithm): 96 | logger.debug("Epoch %d" % epoch) 97 | # reweighting 98 | logger.debug("cov_mask:\n" + str(cov_mask)) 99 | W = DWR(X_train, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 100 | # feature selection 101 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 102 | stg.train(X_train, Y_train, W=W, epochs=5000) 103 | select_ratio = stg.get_ratios().detach().numpy() 104 | if len(select_ratio_MA) >= args.period_MA: 105 | select_ratio_MA.pop(0) 106 | select_ratio_MA.append(select_ratio) 107 | select_ratio = sum(select_ratio_MA)/len(select_ratio_MA) 108 | logger.info("Select ratio: " + pretty(select_ratio)) 109 | logger.info("Current hard selection: " + str(np.array(select_ratio > args.mask_threshold, dtype=np.int64))) 110 | cov_mask = get_cov_mask(select_ratio) 111 | 112 | mask = select_ratio > args.mask_threshold 113 | if np.array(mask, dtype=np.int64).sum() == 0: 114 | logger.info("All variables are discarded!") 115 | assert False 116 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 117 | model_func = get_algorithm_class(args.backend) 118 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 119 | model.fit(X_train[:, mask], Y_train) 120 | 121 | # test 122 | RMSE_dict = dict() 123 | for r_test in args.r_list: 124 | 
X_test, S_test, V_test, fs_test, Y_test = test_data[r_test] 125 | RMSE_dict[r_test] = mean_squared_error(Y_test, model.predict(X_test[:,mask])) 126 | logger.info("Average RMSE: %.3f" % np.mean(list(RMSE_dict.values()))) 127 | logger.info("Error STD: %.3f" % np.std(list(RMSE_dict.values()))) 128 | logger.info("Error max: %.3f" % np.max(list(RMSE_dict.values()))) 129 | results["RMSE"] = RMSE_dict 130 | 131 | return results 132 | 133 | if __name__ == "__main__": 134 | args = get_args() 135 | setup_seed(args.seed) 136 | expname = get_expname(args) 137 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 138 | logger = Logger(args) 139 | logger.log_args(args) 140 | 141 | p = args.p 142 | p_v = int(p*args.V_ratio) 143 | p_s = p-p_v 144 | beta_s = get_beta_s(p_s) 145 | beta_v = np.zeros(p_v) 146 | beta = np.concatenate([beta_s, beta_v]) 147 | 148 | results_list = dd(list) 149 | for i in range(args.times): 150 | logger.info("Round %d" % i) 151 | results = main(args, i, logger) 152 | for k, v in results.items(): 153 | results_list[k].append(v) 154 | 155 | 156 | logger.info("Final Result:") 157 | for k, v in results_list.items(): 158 | if k == "RMSE": 159 | RMSE_dict = dict() 160 | for r_test in args.r_list: 161 | RMSE = [v[i][r_test] for i in range(args.times)] 162 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 163 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 164 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 165 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 166 | logger.info("Detailed RMSE:") 167 | for r_test in args.r_list: 168 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 169 | elif k == "beta_hat": 170 | beta_hat_array = np.array(v) 171 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 172 | logger.info("%s: %s" % (k, beta_hat_mean)) 173 | if args.bv_analysis: 174 | bv_dict = dict() 175 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 176 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 177 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 178 | for covariates in ["s", "v", "all"]: 179 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 180 | else: 181 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 182 | 183 | -------------------------------------------------------------------------------- /simulated_fs.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from lifelines import LogLogisticAFTFitter, WeibullAFTFitter, LogNormalAFTFitter 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | from lifelines import CoxPHFitter 18 | import seaborn as sns 19 | import matplotlib.pyplot as plt 20 | from lifelines.statistics import logrank_test 21 | from lifelines import KaplanMeierFitter 22 | from sksurv.util import Surv 23 | from lifelines.utils import concordance_index 24 | 25 | def get_args(): 26 | parser = 
26 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | 28 | # data generation 29 | parser.add_argument("--p", type=int, default=10, help="Input dim") 30 | parser.add_argument("--n", type=int, default=10000, help="Sample size") 31 | parser.add_argument("--V_ratio", type=float, default=0.5) 32 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 33 | parser.add_argument("--true_func", choices=["linear",], default="linear") 34 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="S_|_V") 35 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 36 | parser.add_argument("--corr_s", type=float, default=0.9) 37 | parser.add_argument("--corr_v", type=float, default=0.1) 38 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecification strength") 39 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 40 | parser.add_argument("--r_train", type=float, default=2.5, help="Bias rate r for the training environment") 41 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 42 | parser.add_argument("--noise_variance", type=float, default=0.1) 43 | 44 | # frontend reweighting 45 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 46 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 47 | parser.add_argument("--order", type=int, default=1) 48 | parser.add_argument("--iters_balance", type=int, default=20000) 49 | 50 | parser.add_argument("--topN", type=int, default=5) 51 | # backend model 52 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox", "LogLogistic"], default="Weighted_cox") 53 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 54 | parser.add_argument("--iters_train", type=int, default=1000) 55 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 56 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 57 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 58 | parser.add_argument("--mask_threshold", type=float, default=0.2) 59 | parser.add_argument("--lam_STG", type=float, default=3) 60 | parser.add_argument("--sigma_STG", type=float, default=0.1) 61 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 62 | parser.add_argument("--bv_analysis", action="store_true") 63 | 64 | parser.add_argument("--gener_method", choices=["cox_exp", "cox_weibull", "poly", "cox_Gompertz", "exp_T", "log_T"], default="cox_exp") 65 | # others 66 | parser.add_argument("--seed", type=int, default=3) 67 | parser.add_argument("--times", type=int, default=10) 68 | parser.add_argument("--result_dir", default="results") 69 | 70 | return parser.parse_args() 71 | 72 | 73 | 74 | def generate_indicator(Y, censored_rate = 0.1): # randomly right-censor a censored_rate fraction of samples, in place 75 | n = Y.shape[0] 76 | num_elements = int(n * censored_rate) 77 | indices = np.random.choice(n, size=num_elements, replace=False) 78 | random_values = np.random.uniform(0, Y[indices]) # censoring time drawn uniformly in [0, T] 79 | Y[indices] = random_values 80 | indicator = np.ones_like(Y) 81 | indicator[indices] = 0 # 0 = censored, 1 = event observed 82 | return Y, indicator 83 | 84 | 85 | 86 | def main(args, round, logger): 87 | setup_seed(args.seed + round) 88 | p = args.p 89 | p_v = int(p*args.V_ratio) 90 | p_s = p-p_v 91 | n = args.n
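`generate_indicator` above imposes random right-censoring in place: a `censored_rate` fraction of samples has its event time replaced by a Uniform(0, T) draw and its indicator set to 0. A toy usage sketch (the data here is illustrative):

```python
import numpy as np

np.random.seed(0)
T = np.random.exponential(scale=12.0, size=100)              # toy event times
Y, delta = generate_indicator(T.copy(), censored_rate=0.25)  # pass a copy: Y is mutated in place
# delta == 1: event observed at Y; delta == 0: censored at Y <= original T
print("observed-event fraction:", delta.mean())              # 0.75 here
```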
92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 93 | oracle_mask = [True,]*p_s + [False,]*p_v 94 | 95 | # generate train data 96 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}}) 97 | 98 | Y_train, Y_censored = generate_indicator(Y_train, censored_rate = 0.1) 99 | 100 | X_train_pd = pd.DataFrame(np.concatenate((Y_train.reshape((-1, 1)), Y_censored.reshape((-1,1)), X_train), axis=1), columns=["Survival.months", "Survival.status"]+list(range(0, X_train.shape[1]))) 101 | 102 | beta_s = get_beta_s(p_s) 103 | beta_v = np.zeros(p_v) 104 | beta = np.concatenate([beta_s, beta_v]) 105 | 106 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train) 107 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var)) 108 | 109 | # generate test data 110 | test_data = dict() 111 | for r_test in args.r_list: 112 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}}) 113 | Y_test, Y_censored = generate_indicator(Y_test, censored_rate = 0.1) 114 | X_test_pd = pd.DataFrame(np.concatenate((Y_test.reshape((-1, 1)), Y_censored.reshape((-1,1)), X_test), axis=1), columns=["Survival.months", "Survival.status"]+list(range(0, X_test.shape[1]))) 115 | test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test, X_test_pd) 116 | 117 | p = X_train.shape[1] 118 | if args.reweighting == "DWR": 119 | if args.decorrelation_type == "global": 120 | cov_mask = get_cov_mask(np.zeros(p)) 121 | elif args.decorrelation_type == "group": 122 | cov_mask = get_cov_mask(np.array(oracle_mask, float)) 123 | else: 124 | raise NotImplementedError 125 | W = DWR(X_train, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 126 | elif args.reweighting == "SRDO": 127 | W = SRDO(X_train, p_s, decorrelation_type=args.decorrelation_type, max_iter=args.iters_balance) 128 | else: 129 | W = np.ones((n, 1)) 130 | 131 | #W = np.ones((n, 1))/n 132 | mean_value = np.mean(W) 133 | W = W * (1/mean_value) 134 | 135 | 136 | results = dict() 137 | if args.paradigm == "regr": 138 | mask = [True,]*p 139 | model_func = get_algorithm_class(args.backend) 140 | model = model_func(X_train_pd, "Survival.months", "Survival.status", W, 0.00001, **vars(args)) 141 | 142 | elif args.paradigm == "fs": 143 | if args.fs_type == "STG": 144 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 145 | stg.train(X_train, Y_train, W=W, epochs=5000) 146 | select_ratio = stg.get_ratios().detach().numpy() 147 | logger.info("Select ratio: " + pretty(select_ratio)) 148 | mask = select_ratio > args.mask_threshold 149 | elif args.fs_type == "oracle": 150 | mask = oracle_mask 151 | elif args.fs_type == "None": 152 | mask = [True,]*p 153 | elif args.fs_type == "given": 154 | mask = np.array(args.mask_given, bool) 155 | else: 156 | raise NotImplementedError 157 | if np.array(mask, dtype=np.int64).sum() == 0: 158 | logger.info("All variables are discarded!") 159 | assert False 160 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 161 | model_func = get_algorithm_class(args.backend) 162 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 163 | model.fit(X_train[:, mask], Y_train) 164 | else: 165 | raise NotImplementedError 166 | 167 | # test 168 | summary = model.summary 169 | fs_sorted_indices = summary['p'].sort_values().head(args.topN).index 170 | 171 | 172 | 
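The block that follows keeps the `topN` covariates with the smallest Wald p-values from the weighted fit and refits an ordinary Cox model on just those columns. A self-contained sketch of the same select-then-refit pattern on a lifelines toy dataset:

```python
from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi

df = load_rossi()                      # toy survival data shipped with lifelines
cph = CoxPHFitter(penalizer=0.0001)
cph.fit(df, duration_col="week", event_col="arrest")
top = list(cph.summary["p"].sort_values().head(3).index)   # 3 smallest p-values
refit = CoxPHFitter(penalizer=0.0001)
refit.fit(df[["week", "arrest"] + top], duration_col="week", event_col="arrest")
print(refit.score(df[["week", "arrest"] + top], scoring_method="concordance_index"))
```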
X_train_pd = X_train_pd[list(X_train_pd.columns[:2])+list(fs_sorted_indices)] 173 | #cph = LogLogisticAFTFitter(penalizer=0.00001, fit_intercept=False) 174 | cph = CoxPHFitter(penalizer=0.0001) 175 | cph.fit(X_train_pd, duration_col='Survival.months', event_col='Survival.status') 176 | 177 | summary = cph.summary 178 | 179 | coef = summary["coef"] 180 | 181 | 182 | train_score = cph.score(X_train_pd) 183 | optimal_p_value = 0 184 | 185 | c_index_dict = dict() 186 | 187 | for r_test, test in test_data.items(): 188 | print("test ratio:", r_test) 189 | X_test, S_test, V_test, fs_test, Y_test, X_test_pd = test 190 | X_test_pd = X_test_pd[list(X_test_pd.columns[:2])+list(fs_sorted_indices)] 191 | c_index = cph.score(X_test_pd, scoring_method='concordance_index') 192 | c_index_dict[r_test] = c_index 193 | 194 | results["c_index"] = c_index_dict 195 | 196 | return results 197 | 198 | if __name__ == "__main__": 199 | args = get_args() 200 | setup_seed(args.seed) 201 | expname = get_expname(args) 202 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 203 | logger = Logger(args) 204 | logger.log_args(args) 205 | 206 | p = args.p 207 | p_v = int(p*args.V_ratio) 208 | p_s = p-p_v 209 | beta_s = get_beta_s(p_s) 210 | beta_v = np.zeros(p_v) 211 | beta = np.concatenate([beta_s, beta_v]) 212 | 213 | results_list = dd(list) 214 | for i in range(args.times): 215 | logger.info("Round %d" % i) 216 | results = main(args, i, logger) 217 | for k, v in results.items(): 218 | results_list[k].append(v) 219 | 220 | 221 | logger.info("Final Result:") 222 | for k, v in results_list.items(): 223 | if k == "c_index": 224 | RMSE_dict = dict() 225 | for r_test in args.r_list: 226 | RMSE = [v[i][r_test] for i in range(args.times)] 227 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 228 | logger.info("c_index average: %.3f" % (np.mean(list(RMSE_dict.values())))) 229 | logger.info("c_index std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 230 | logger.info("c_index max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 231 | logger.info("Detailed score:") 232 | str1 = "" 233 | for r_test in args.r_list: 234 | logger.info("%.1f: %.8f" % (r_test, RMSE_dict[r_test])) 235 | str1 += str(RMSE_dict[r_test]) + "\n" 236 | print(str1) 237 | 238 | 239 | -------------------------------------------------------------------------------- /simulated.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from lifelines import LogLogisticAFTFitter, WeibullAFTFitter, LogNormalAFTFitter 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | from lifelines import CoxPHFitter 18 | import seaborn as sns 19 | import matplotlib.pyplot as plt 20 | from lifelines.statistics import logrank_test 21 | from lifelines import KaplanMeierFitter 22 | from sksurv.util import Surv 23 | from lifelines.utils import concordance_index 24 | from scipy.stats import chi2 25 | def get_args(): 26 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", 
formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | 28 | # data generation 29 | parser.add_argument("--p", type=int, default=10, help="Input dim") 30 | parser.add_argument("--n", type=int, default=10000, help="Sample size") 31 | parser.add_argument("--V_ratio", type=float, default=0.5) 32 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 33 | parser.add_argument("--true_func", choices=["linear",], default="linear") 34 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="S_|_V") 35 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 36 | parser.add_argument("--corr_s", type=float, default=0.9) 37 | parser.add_argument("--corr_v", type=float, default=0.1) 38 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecification strength") 39 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 40 | parser.add_argument("--r_train", type=float, default=2.5, help="Bias rate r for the training environment") 41 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 42 | parser.add_argument("--noise_variance", type=float, default=0.1) 43 | 44 | # frontend reweighting 45 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 46 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 47 | parser.add_argument("--order", type=int, default=1) 48 | parser.add_argument("--iters_balance", type=int, default=2000) 49 | 50 | parser.add_argument("--topN", type=int, default=5) 51 | # backend model 52 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox", "LogLogistic"], default="Weighted_cox") 53 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 54 | parser.add_argument("--iters_train", type=int, default=1000) 55 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 56 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 57 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 58 | parser.add_argument("--mask_threshold", type=float, default=0.2) 59 | parser.add_argument("--lam_STG", type=float, default=3) 60 | parser.add_argument("--sigma_STG", type=float, default=0.1) 61 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 62 | parser.add_argument("--bv_analysis", action="store_true") 63 | 64 | parser.add_argument("--gener_method", choices=["cox_exp", "cox_weibull", "poly", "cox_Gompertz", "exp_T", "log_T"], default="cox_Gompertz") 65 | # others 66 | parser.add_argument("--seed", type=int, default=3) 67 | parser.add_argument("--times", type=int, default=10) 68 | parser.add_argument("--result_dir", default="results") 69 | 70 | return parser.parse_args() 71 | 72 | 73 | def generate_indicator(Y, censored_rate = 0.1): # randomly right-censor a censored_rate fraction of samples, in place 74 | n = Y.shape[0] 75 | num_elements = int(n * censored_rate) 76 | indices = np.random.choice(n, size=num_elements, replace=False) 77 | random_values = np.random.uniform(0, Y[indices]) # censoring time drawn uniformly in [0, T] 78 | Y[indices] = random_values 79 | indicator = np.ones_like(Y) 80 | indicator[indices] = 0 # 0 = censored, 1 = event observed 81 | return Y, indicator 82 | 83 | 84 | 85 | def main(args, round, logger): 86 | setup_seed(args.seed + round) 87 | p = args.p 88 | p_v = int(p*args.V_ratio) 89 | p_s = p-p_v 90 | n = args.n 91 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
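`--gener_method` above selects how survival times are generated from the linear predictor (exponential, Weibull, or Gompertz style baselines). For the exponential case, the standard inverse-transform recipe (Bender et al., 2005) gives T = -log(U) / (lambda * exp(X beta)); a minimal sketch, where `lambda_` is an assumed baseline rate rather than a value taken from this repo:

```python
import numpy as np

def draw_cox_exp_times(X, beta, lambda_=0.01, rng=None):
    # Under h(t|x) = lambda_ * exp(x @ beta), S(t|x) = exp(-lambda_ * t * exp(x @ beta)),
    # so inverting S at U ~ Uniform(0, 1) yields an exact draw from the Cox model.
    rng = rng or np.random.default_rng()
    U = rng.uniform(size=X.shape[0])
    return -np.log(U) / (lambda_ * np.exp(X @ beta))
```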
92 | oracle_mask = [True,]*p_s + [False,]*p_v 93 | 94 | # generate train data 95 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}}) 96 | 97 | Y_train, Y_censored = generate_indicator(Y_train, censored_rate = 0.1) 98 | X_train_pd = pd.DataFrame(np.concatenate((Y_train.reshape((-1, 1)), Y_censored.reshape((-1,1)), X_train), axis=1), columns=["Survival.months", "Survival.status"]+list(range(0, X_train.shape[1]))) 99 | 100 | beta_s = get_beta_s(p_s) 101 | beta_v = np.zeros(p_v) 102 | beta = np.concatenate([beta_s, beta_v]) 103 | 104 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train) 105 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var)) 106 | 107 | # generate test data 108 | test_data = dict() 109 | for r_test in args.r_list: 110 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}}) 111 | Y_test, Y_censored = generate_indicator(Y_test, censored_rate = 0.1) 112 | X_test_pd = pd.DataFrame(np.concatenate((Y_test.reshape((-1, 1)), Y_censored.reshape((-1,1)), X_test), axis=1), columns=["Survival.months", "Survival.status"]+list(range(0, X_test.shape[1]))) 113 | test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test, X_test_pd) 114 | 115 | p = X_train.shape[1] 116 | if args.reweighting == "DWR": 117 | if args.decorrelation_type == "global": 118 | cov_mask = get_cov_mask(np.zeros(p)) 119 | elif args.decorrelation_type == "group": 120 | cov_mask = get_cov_mask(np.array(oracle_mask, float)) 121 | else: 122 | raise NotImplementedError 123 | W = DWR(X_train, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 124 | elif args.reweighting == "SRDO": 125 | W = SRDO(X_train, p_s, decorrelation_type=args.decorrelation_type, max_iter=args.iters_balance) 126 | else: 127 | W = np.ones((n, 1)) 128 | 129 | #W = np.ones((n, 1))/n 130 | mean_value = np.mean(W) 131 | W = W * (1/mean_value) 132 | # W = np.clip(W, 0.1, 2) 133 | results = dict() 134 | if args.paradigm == "regr": 135 | mask = [True,]*p 136 | model_func = get_algorithm_class(args.backend) 137 | model = model_func(X_train_pd, "Survival.months", "Survival.status", W, 0.00001, **vars(args)) 138 | 139 | elif args.paradigm == "fs": 140 | if args.fs_type == "STG": 141 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 142 | stg.train(X_train, Y_train, W=W, epochs=5000) 143 | select_ratio = stg.get_ratios().detach().numpy() 144 | logger.info("Select ratio: " + pretty(select_ratio)) 145 | mask = select_ratio > args.mask_threshold 146 | elif args.fs_type == "oracle": 147 | mask = oracle_mask 148 | elif args.fs_type == "None": 149 | mask = [True,]*p 150 | elif args.fs_type == "given": 151 | mask = np.array(args.mask_given, bool) 152 | else: 153 | raise NotImplementedError 154 | if np.array(mask, dtype=np.int64).sum() == 0: 155 | logger.info("All variables are discarded!") 156 | assert False 157 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 158 | model_func = get_algorithm_class(args.backend) 159 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 160 | model.fit(X_train[:, mask], Y_train) 161 | else: 162 | raise NotImplementedError 163 | 164 | # test 165 | summary = model.summary 166 | fs_sorted_indices = summary['p'].sort_values().head(args.topN).index 167 | print("sorted_indices", fs_sorted_indices) 168 | p_value = summary['p'].tolist() 169 | 170 | #X_train_pd = 
X_train_pd[list(X_train_pd.columns[:2])+list(fs_sorted_indices)] 171 | #cph = LogLogisticAFTFitter(penalizer=0.00001, fit_intercept=False) 172 | #cph = CoxPHFitter(penalizer=0.0001) 173 | #cph.fit(new_train_pd, duration_col='Survival.months', event_col='Survival.status') 174 | 175 | summary = model.summary 176 | 177 | coef = summary["coef"] 178 | 179 | columns = X_train_pd.columns 180 | X_train_pd = np.concatenate((X_train_pd, W), axis=1) 181 | X_train_pd = pd.DataFrame(X_train_pd, columns=list(columns)+["Weights"]) 182 | 183 | 184 | train_score = model.score(X_train_pd) 185 | optimal_p_value = 0 186 | 187 | c_index_dict = dict() 188 | 189 | 190 | for r_test, test in test_data.items(): 191 | print("test ratio:", r_test) 192 | X_test, S_test, V_test, fs_test, Y_test, X_test_pd = test 193 | #X_test_pd = X_test_pd[list(X_test_pd.columns[:2])+list(fs_sorted_indices)] 194 | columns = X_test_pd.columns 195 | tmp = np.concatenate((X_test_pd, np.ones((n, 1))), axis=1) 196 | X_test_pd = pd.DataFrame(tmp, columns=list(columns)+["Weights"]) 197 | c_index = model.score(X_test_pd, scoring_method='concordance_index') 198 | c_index_dict[r_test] = c_index 199 | 200 | results["c_index"] = c_index_dict 201 | 202 | return results 203 | 204 | if __name__ == "__main__": 205 | args = get_args() 206 | setup_seed(args.seed) 207 | expname = get_expname(args) 208 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 209 | logger = Logger(args) 210 | logger.log_args(args) 211 | 212 | p = args.p 213 | p_v = int(p*args.V_ratio) 214 | p_s = p-p_v 215 | beta_s = get_beta_s(p_s) 216 | beta_v = np.zeros(p_v) 217 | beta = np.concatenate([beta_s, beta_v]) 218 | 219 | results_list = dd(list) 220 | for i in range(args.times): 221 | logger.info("Round %d" % i) 222 | results = main(args, i, logger) 223 | for k, v in results.items(): 224 | results_list[k].append(v) 225 | 226 | 227 | logger.info("Final Result:") 228 | for k, v in results_list.items(): 229 | 230 | if k == "c_index": 231 | RMSE_dict = dict() 232 | for r_test in args.r_list: 233 | RMSE = [v[i][r_test] for i in range(args.times)] 234 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 235 | logger.info("c_index average: %.3f" % (np.mean(list(RMSE_dict.values())))) 236 | logger.info("c_index std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 237 | logger.info("c_index max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 238 | logger.info("Detailed score:") 239 | str1 = "" 240 | for r_test in args.r_list: 241 | logger.info("%.1f: %.8f" % (r_test, RMSE_dict[r_test])) 242 | str1 += str(RMSE_dict[r_test]) + "\n" 243 | print(str1) 244 | 245 | 246 | -------------------------------------------------------------------------------- /clinical_breast_OS.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from sksurv.metrics import brier_score, cumulative_dynamic_auc 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | 18 | from lifelines import CoxPHFitter 19 | import seaborn as sns 20 | import matplotlib.pyplot 
as plt 21 | from lifelines.statistics import logrank_test 22 | from lifelines import KaplanMeierFitter 23 | from sksurv.util import Surv 24 | from lifelines.utils import concordance_index 25 | from sklearn.metrics import roc_auc_score 26 | from sklearn.metrics import accuracy_score 27 | duration_col = 'Survival.months' 28 | event_col = 'Survival.status' 29 | def get_args(): 30 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | 32 | # data generation 33 | parser.add_argument("--p", type=int, default=10, help="Input dim") 34 | parser.add_argument("--n", type=int, default=2000, help="Sample size") 35 | parser.add_argument("--V_ratio", type=float, default=0.5) 36 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 37 | parser.add_argument("--true_func", choices=["linear",], default="linear") 38 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity") 39 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 40 | parser.add_argument("--corr_s", type=float, default=0.9) 41 | parser.add_argument("--corr_v", type=float, default=0.1) 42 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecifction strength") 43 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 44 | parser.add_argument("--r_train", type=float, default=2.5, help="Input dim") 45 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 46 | parser.add_argument("--noise_variance", type=float, default=0.3) 47 | 48 | # frontend reweighting 49 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 50 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 51 | parser.add_argument("--order", type=int, default=1) 52 | parser.add_argument("--iters_balance", type=int, default=2000) 53 | 54 | parser.add_argument("--topN", type=int, default=5) 55 | # backend model 56 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox", "LogLogistic", "Weibull", "LogNormal"], default="Weighted_cox") 57 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 58 | parser.add_argument("--iters_train", type=int, default=5000) 59 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 60 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 61 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 62 | parser.add_argument("--mask_threshold", type=float, default=0.2) 63 | parser.add_argument("--lam_STG", type=float, default=3) 64 | parser.add_argument("--sigma_STG", type=float, default=0.1) 65 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 66 | parser.add_argument("--bv_analysis", action="store_true") 67 | 68 | # others 69 | parser.add_argument("--seed", type=int, default=3) 70 | parser.add_argument("--times", type=int, default=1) 71 | parser.add_argument("--result_dir", default="results") 72 | 73 | return parser.parse_args() 74 | 75 | 76 | 77 | def main(args, round, logger): 78 | setup_seed(args.seed + round) 79 | p = args.p 80 | p_v = int(p*args.V_ratio) 81 | p_s = p-p_v 82 | n = args.n 83 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 84 | 85 | # generate 
train data 86 | 87 | training_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_train_survival.csv', index_col=0) 88 | test0_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_test0_survival.csv', index_col=0) 89 | test1_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_test1_survival.csv', index_col=0) 90 | 91 | test2_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_test2_survival.csv', index_col=0) 92 | 93 | training_pd_data = training_pd_data.drop(['Recurr.months', 'Recurr.status', 'Cohort'], axis=1) 94 | 95 | test0_pd_data = test0_pd_data.drop(['Recurr.months', 'Recurr.status', 'Cohort'], axis=1) 96 | test1_pd_data = test1_pd_data.drop(['Recurr.months', 'Recurr.status', 'Cohort'], axis=1) 97 | 98 | test2_pd_data = test2_pd_data.drop(['Recurr.months', 'Recurr.status', 'Cohort'], axis=1) 99 | 100 | 101 | 102 | X_train_np = np.array(training_pd_data.iloc[:, 2:]) 103 | 104 | 105 | p = X_train_np.shape[1] 106 | n = X_train_np.shape[0] 107 | print("dim", p) 108 | if args.reweighting == "DWR": 109 | if args.decorrelation_type == "global": 110 | cov_mask = get_cov_mask(np.zeros(p)) 111 | elif args.decorrelation_type == "group": 112 | cov_mask = get_cov_mask(np.array(oracle_mask, float)) # NOTE: oracle_mask is never defined in this clinical script; the "group" option applies only to the simulated settings 113 | else: 114 | raise NotImplementedError 115 | W = DWR(X_train_np, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 116 | elif args.reweighting == "SRDO": 117 | W = SRDO(X_train_np, p_s, hidden_layer_sizes = (69, 15), decorrelation_type=args.decorrelation_type, max_iter=args.iters_balance) 118 | else: 119 | W = np.ones((n, 1))/n 120 | 121 | mean_value = np.mean(W) 122 | W = W * (1/mean_value) 123 | 124 | W = np.clip(W, 0.02, 2) 125 | columns = training_pd_data.columns 126 | X_train_pd = training_pd_data 127 | all_X = np.concatenate((X_train_pd, W), axis=1) 128 | 129 | X_train_test = pd.DataFrame(all_X, columns=list(columns)+["Weights"]) 130 | test0_W = np.ones((test0_pd_data.shape[0], 1)) 131 | 132 | X_test0_pd = np.concatenate((test0_pd_data, test0_W), axis=1) 133 | 134 | X_test0_pd = pd.DataFrame(X_test0_pd, columns=list(columns)+["Weights"]) 135 | 136 | test1_W = np.ones((test1_pd_data.shape[0], 1)) 137 | 138 | X_test1_pd = np.concatenate((test1_pd_data, test1_W), axis=1) 139 | 140 | X_test1_pd = pd.DataFrame(X_test1_pd, columns=list(columns)+["Weights"]) # build from the concatenated array (not test1_pd_data) so the Weights column is filled rather than NaN 141 | 142 | test2_W = np.ones((test2_pd_data.shape[0], 1)) 143 | 144 | X_test2_pd = np.concatenate((test2_pd_data, test2_W), axis=1) 145 | 146 | X_test2_pd = pd.DataFrame(X_test2_pd, columns=list(columns)+["Weights"]) # same fix as X_test1_pd above 147 | 148 | 149 | 150 | results = dict() 151 | if args.paradigm == "regr": 152 | mask = [True,]*p 153 | model_func = get_algorithm_class(args.backend) 154 | model = model_func(X_train_pd, duration_col, event_col, W, 0.002, **vars(args)) 155 | print("train score", model.score(X_train_test)) 156 | 157 | elif args.paradigm == "fs": # NOTE: this branch expects simulated arrays (X_train, Y_train) that are not constructed in this clinical script 158 | if args.fs_type == "STG": 159 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 160 | stg.train(X_train, Y_train, W=W, epochs=5000) 161 | select_ratio = stg.get_ratios().detach().numpy() 162 | logger.info("Select ratio: " + pretty(select_ratio)) 163 | mask = select_ratio > args.mask_threshold 164 | elif args.fs_type == "oracle": 165 | mask = oracle_mask 166 | elif args.fs_type == "None": 167 | mask = [True,]*p 168 | elif args.fs_type == "given": 169 | mask = np.array(args.mask_given, bool) 170 | else: 171 | raise NotImplementedError 172 | if np.array(mask, dtype=np.int64).sum() == 0: 173 | 
logger.info("All variables are discarded!") 174 | assert False 175 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 176 | model_func = get_algorithm_class(args.backend) 177 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 178 | model.fit(X_train[:, mask], Y_train) 179 | else: 180 | raise NotImplementedError 181 | 182 | # test 183 | summary = model.summary 184 | summary.to_csv("./breast_results_stable.csv") 185 | print("summary", summary) 186 | 187 | coef = summary["coef"] 188 | 189 | 190 | 191 | c_index_dict = [] 192 | 193 | 194 | c_index0 = model.score(X_test0_pd, scoring_method='concordance_index') 195 | print("test1") 196 | c_index1 = model.score(X_test1_pd, scoring_method='concordance_index') 197 | print(c_index1) 198 | c_index_dict.append(c_index1) 199 | 200 | print("test2") 201 | c_index2 = model.score(X_test2_pd, scoring_method='concordance_index') 202 | print(c_index2) 203 | c_index_dict.append(c_index2) 204 | 205 | 206 | 207 | print("c_index") 208 | mean_acc = np.mean(c_index_dict) 209 | std_acc = np.std(c_index_dict) 210 | worst_acc = min(c_index_dict) 211 | print(mean_acc) 212 | print(std_acc) 213 | print(worst_acc) 214 | 215 | 216 | 217 | return results 218 | 219 | if __name__ == "__main__": 220 | args = get_args() 221 | setup_seed(args.seed) 222 | expname = get_expname(args) 223 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 224 | logger = Logger(args) 225 | logger.log_args(args) 226 | 227 | p = args.p 228 | p_v = int(p*args.V_ratio) 229 | p_s = p-p_v 230 | beta_s = get_beta_s(p_s) 231 | beta_v = np.zeros(p_v) 232 | beta = np.concatenate([beta_s, beta_v]) 233 | results_list = dd(list) 234 | for i in range(args.times): 235 | logger.info("Round %d" % i) 236 | results = main(args, i, logger) 237 | for k, v in results.items(): 238 | results_list[k].append(v) 239 | 240 | 241 | logger.info("Final Result:") 242 | for k, v in results_list.items(): 243 | if k == "RMSE": 244 | RMSE_dict = dict() 245 | for r_test in args.r_list: 246 | RMSE = [v[i][r_test] for i in range(args.times)] 247 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 248 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 249 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 250 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 251 | logger.info("Detailed RMSE:") 252 | for r_test in args.r_list: 253 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 254 | elif k == "beta_hat": 255 | beta_hat_array = np.array(v) 256 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 257 | logger.info("%s: %s" % (k, beta_hat_mean)) 258 | if args.bv_analysis: 259 | bv_dict = dict() 260 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 261 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 262 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 263 | for covariates in ["s", "v", "all"]: 264 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 265 | else: 266 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 267 | 268 | 269 | -------------------------------------------------------------------------------- /clinical_breast_RFS.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import 
get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from sksurv.metrics import brier_score, cumulative_dynamic_auc 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | 18 | from lifelines import CoxPHFitter 19 | import seaborn as sns 20 | import matplotlib.pyplot as plt 21 | from lifelines.statistics import logrank_test 22 | from lifelines import KaplanMeierFitter 23 | from sksurv.util import Surv 24 | from lifelines.utils import concordance_index 25 | from sklearn.metrics import roc_auc_score 26 | from sklearn.metrics import accuracy_score 27 | duration_col = 'Recurr.months' 28 | event_col = 'Recurr.status' 29 | def get_args(): 30 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | 32 | # data generation 33 | parser.add_argument("--p", type=int, default=10, help="Input dim") 34 | parser.add_argument("--n", type=int, default=2000, help="Sample size") 35 | parser.add_argument("--V_ratio", type=float, default=0.5) 36 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 37 | parser.add_argument("--true_func", choices=["linear",], default="linear") 38 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity") 39 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 40 | parser.add_argument("--corr_s", type=float, default=0.9) 41 | parser.add_argument("--corr_v", type=float, default=0.1) 42 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecifction strength") 43 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 44 | parser.add_argument("--r_train", type=float, default=2.5, help="Input dim") 45 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 46 | parser.add_argument("--noise_variance", type=float, default=0.3) 47 | 48 | # frontend reweighting 49 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 50 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 51 | parser.add_argument("--order", type=int, default=1) 52 | parser.add_argument("--iters_balance", type=int, default=2000) 53 | 54 | parser.add_argument("--topN", type=int, default=5) 55 | # backend model 56 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox", "LogLogistic", "Weibull", "LogNormal"], default="Weighted_cox") 57 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 58 | parser.add_argument("--iters_train", type=int, default=5000) 59 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 60 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 61 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 62 | parser.add_argument("--mask_threshold", type=float, default=0.2) 63 | parser.add_argument("--lam_STG", type=float, default=3) 64 | parser.add_argument("--sigma_STG", type=float, default=0.1) 65 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 66 | 
parser.add_argument("--bv_analysis", action="store_true") 67 | 68 | # others 69 | parser.add_argument("--seed", type=int, default=3) 70 | parser.add_argument("--times", type=int, default=1) 71 | parser.add_argument("--result_dir", default="results") 72 | 73 | return parser.parse_args() 74 | 75 | 76 | def main(args, round, logger): 77 | setup_seed(args.seed + round) 78 | p = args.p 79 | p_v = int(p*args.V_ratio) 80 | p_s = p-p_v 81 | n = args.n 82 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 83 | oracle_mask = [True,]*p_s + [False,]*p_v 84 | 85 | # generate train data 86 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}}) 87 | 88 | beta_s = get_beta_s(p_s) 89 | beta_v = np.zeros(p_v) 90 | beta = np.concatenate([beta_s, beta_v]) 91 | 92 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train) 93 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var)) 94 | 95 | # generate test data 96 | test_data = dict() 97 | for r_test in args.r_list: 98 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}}) 99 | test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test) 100 | 101 | 102 | training_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_train_survival.csv', index_col=0) 103 | test0_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_test0_survival.csv', index_col=0) 104 | test1_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_test1_survival.csv', index_col=0) 105 | test2_pd_data = pd.read_csv('./clinical_data/breast_cancer/breast_test2_survival.csv', index_col=0) 106 | 107 | 108 | 109 | 110 | training_pd_data = training_pd_data.drop(['Survival.months', 'Survival.status', 'Cohort'], axis=1) 111 | 112 | test0_pd_data = test0_pd_data.drop(['Survival.months', 'Survival.status', 'Cohort'], axis=1) 113 | test1_pd_data = test1_pd_data.drop(['Survival.months', 'Survival.status', 'Cohort'], axis=1) 114 | 115 | test2_pd_data = test2_pd_data.drop(['Survival.months', 'Survival.status', 'Cohort'], axis=1) 116 | 117 | 118 | X_train_np = np.array(training_pd_data.iloc[:, 2:]) 119 | 120 | 121 | p = X_train_np.shape[1] 122 | n= X_train_np.shape[0] 123 | if args.reweighting == "DWR": 124 | if args.decorrelation_type == "global": 125 | cov_mask = get_cov_mask(np.zeros(p)) 126 | elif args.decorrelation_type == "group": 127 | cov_mask = get_cov_mask(np.array(oracle_mask, np.float)) 128 | else: 129 | raise NotImplementedError 130 | W = DWR(X_train_np, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 131 | elif args.reweighting == "SRDO": 132 | W = SRDO(X_train_np, p_s, hidden_layer_sizes = (15, 13), decorrelation_type=args.decorrelation_type, max_iter=args.iters_balance) 133 | else: 134 | W = np.ones((n, 1))/n 135 | 136 | mean_value = np.mean(W) 137 | W = W * (1/mean_value) 138 | #print("max*********", np.max(W)) 139 | #print("min********", np.min(W)) 140 | 141 | 142 | #W = (W - min_value)/(max_value-min_value) 143 | #W = W * 2 144 | #print("W", W) 145 | 146 | W = np.clip(W, 0.001, 2) 147 | columns = training_pd_data.columns 148 | X_train_pd = training_pd_data 149 | all_X = np.concatenate((X_train_pd, W), axis=1) 150 | 151 | X_train_test = pd.DataFrame(all_X, columns=list(columns)+["Weights"]) 152 | test0_W = np.ones((test0_pd_data.shape[0], 1)) 153 | 154 | X_test0_pd = np.concatenate((test0_pd_data, test0_W), axis=1) 155 | 
156 | X_test0_pd = pd.DataFrame(X_test0_pd, columns=list(columns)+["Weights"]) 157 | 158 | test1_W = np.ones((test1_pd_data.shape[0], 1)) 159 | 160 | X_test1_pd = np.concatenate((test1_pd_data, test1_W), axis=1) 161 | 162 | X_test1_pd = pd.DataFrame(X_test1_pd, columns=list(columns)+["Weights"]) # build from the concatenated array (not test1_pd_data) so the Weights column is filled rather than NaN 163 | 164 | test2_W = np.ones((test2_pd_data.shape[0], 1)) 165 | 166 | X_test2_pd = np.concatenate((test2_pd_data, test2_W), axis=1) 167 | 168 | X_test2_pd = pd.DataFrame(X_test2_pd, columns=list(columns)+["Weights"]) # same fix as X_test1_pd above 169 | 170 | 171 | 172 | 173 | results = dict() 174 | if args.paradigm == "regr": 175 | mask = [True,]*p 176 | model_func = get_algorithm_class(args.backend) 177 | model = model_func(X_train_pd, duration_col, event_col, W, 0.03, **vars(args)) 178 | print("train score", model.score(X_train_test)) 179 | 180 | elif args.paradigm == "fs": # NOTE: this branch expects simulated arrays (X_train, Y_train) that are not constructed in this clinical script 181 | if args.fs_type == "STG": 182 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 183 | stg.train(X_train, Y_train, W=W, epochs=5000) 184 | select_ratio = stg.get_ratios().detach().numpy() 185 | logger.info("Select ratio: " + pretty(select_ratio)) 186 | mask = select_ratio > args.mask_threshold 187 | elif args.fs_type == "oracle": 188 | mask = oracle_mask 189 | elif args.fs_type == "None": 190 | mask = [True,]*p 191 | elif args.fs_type == "given": 192 | mask = np.array(args.mask_given, bool) 193 | else: 194 | raise NotImplementedError 195 | if np.array(mask, dtype=np.int64).sum() == 0: 196 | logger.info("All variables are discarded!") 197 | assert False 198 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 199 | model_func = get_algorithm_class(args.backend) 200 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 201 | model.fit(X_train[:, mask], Y_train) 202 | else: 203 | raise NotImplementedError 204 | 205 | # test 206 | summary = model.summary 207 | #print("summary", summary) 208 | 209 | coef = summary["coef"] 210 | #print("training:") 211 | #plot_KM(cph2, sorted_indices, coef, X_train_pd, int(args.topN), "stable_cox_regression_nature.jpg") 212 | 213 | c_index_dict = [] 214 | 215 | c_index0 = model.score(X_test0_pd, scoring_method='concordance_index') # note: c_index0 is computed but not pooled into the summary list below 216 | print("test1") 217 | c_index1 = model.score(X_test1_pd, scoring_method='concordance_index') 218 | print(c_index1) 219 | c_index_dict.append(c_index1) 220 | 221 | print("test2") 222 | 223 | c_index2 = model.score(X_test2_pd, scoring_method='concordance_index') 224 | print(c_index2) 225 | c_index_dict.append(c_index2) 226 | 227 | 228 | 229 | print("c_index") 230 | mean_acc = np.mean(c_index_dict) 231 | std_acc = np.std(c_index_dict) 232 | worst_acc = min(c_index_dict) 233 | print(mean_acc) 234 | print(std_acc) 235 | print(worst_acc) 236 | 237 | return results 238 | 239 | if __name__ == "__main__": 240 | args = get_args() 241 | setup_seed(args.seed) 242 | expname = get_expname(args) 243 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 244 | logger = Logger(args) 245 | logger.log_args(args) 246 | 247 | p = args.p 248 | p_v = int(p*args.V_ratio) 249 | p_s = p-p_v 250 | beta_s = get_beta_s(p_s) 251 | beta_v = np.zeros(p_v) 252 | beta = np.concatenate([beta_s, beta_v]) 253 | 254 | results_list = dd(list) 255 | for i in range(args.times): 256 | logger.info("Round %d" % i) 257 | results = main(args, i, logger) 258 | for k, v in results.items(): 259 | results_list[k].append(v) 260 | 261 | 262 | logger.info("Final Result:") 263 | for k, v in results_list.items(): 264 | if k == "RMSE": 265 | RMSE_dict = dict() 266 | for r_test in 
args.r_list: 267 | RMSE = [v[i][r_test] for i in range(args.times)] 268 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 269 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 270 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 271 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 272 | logger.info("Detailed RMSE:") 273 | for r_test in args.r_list: 274 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 275 | elif k == "beta_hat": 276 | beta_hat_array = np.array(v) 277 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 278 | logger.info("%s: %s" % (k, beta_hat_mean)) 279 | if args.bv_analysis: 280 | bv_dict = dict() 281 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 282 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 283 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 284 | for covariates in ["s", "v", "all"]: 285 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 286 | else: 287 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 288 | 289 | 290 | -------------------------------------------------------------------------------- /clinical_lung_OS.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from sksurv.metrics import brier_score, cumulative_dynamic_auc 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | 18 | from lifelines import CoxPHFitter 19 | import seaborn as sns 20 | import matplotlib.pyplot as plt 21 | from lifelines.statistics import logrank_test 22 | from lifelines import KaplanMeierFitter 23 | from sksurv.util import Surv 24 | from lifelines.utils import concordance_index 25 | duration_col='Survival.months' 26 | event_col='Survival.status' 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 29 | 30 | # data generation 31 | parser.add_argument("--p", type=int, default=10, help="Input dim") 32 | parser.add_argument("--n", type=int, default=2000, help="Sample size") 33 | parser.add_argument("--V_ratio", type=float, default=0.5) 34 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 35 | parser.add_argument("--true_func", choices=["linear",], default="linear") 36 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity") 37 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 38 | parser.add_argument("--corr_s", type=float, default=0.9) 39 | parser.add_argument("--corr_v", type=float, default=0.1) 40 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecifction strength") 41 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 42 | parser.add_argument("--r_train", type=float, default=2.5, help="Input dim") 43 | parser.add_argument("--r_list", type=float, 
nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 44 | parser.add_argument("--noise_variance", type=float, default=0.3) 45 | 46 | # frontend reweighting 47 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 48 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 49 | parser.add_argument("--order", type=int, default=1) 50 | parser.add_argument("--iters_balance", type=int, default=3000) 51 | 52 | parser.add_argument("--topN", type=int, default=5) 53 | # backend model 54 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox"], default="Weighted_cox") 55 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 56 | parser.add_argument("--iters_train", type=int, default=1000) 57 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 58 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 59 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 60 | parser.add_argument("--mask_threshold", type=float, default=0.2) 61 | parser.add_argument("--lam_STG", type=float, default=3) 62 | parser.add_argument("--sigma_STG", type=float, default=0.1) 63 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 64 | parser.add_argument("--bv_analysis", action="store_true") 65 | 66 | # others 67 | parser.add_argument("--seed", type=int, default=3) 68 | parser.add_argument("--times", type=int, default=1) 69 | parser.add_argument("--result_dir", default="results") 70 | 71 | return parser.parse_args() 72 | 73 | 74 | def main(args, round, logger): 75 | setup_seed(args.seed + round) 76 | p = args.p 77 | p_v = int(p*args.V_ratio) 78 | p_s = p-p_v 79 | n = args.n 80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | oracle_mask = [True,]*p_s + [False,]*p_v 82 | 83 | training_data = pd.read_csv('./clinical_data/lung_cancer/train_pd.csv', index_col=0) 84 | test0_data = pd.read_csv('./clinical_data/lung_cancer/test0_pd.csv', index_col=0) 85 | test1_data = pd.read_csv('./clinical_data/lung_cancer/test1_pd.csv', index_col=0) 86 | test2_data = pd.read_csv('./clinical_data/lung_cancer/test2_pd.csv', index_col=0) 87 | test3_data = pd.read_csv('./clinical_data/lung_cancer/test3_pd.csv', index_col=0) 88 | 89 | test4_data = pd.read_csv('./clinical_data/lung_cancer/test4_pd.csv', index_col=0) 90 | test5_data = pd.read_csv('./clinical_data/lung_cancer/test5_pd.csv', index_col=0) 91 | test6_data = pd.read_csv('./clinical_data/lung_cancer/test6_pd.csv', index_col=0) 92 | test7_data = pd.read_csv('./clinical_data/lung_cancer/test7_pd.csv', index_col=0) 93 | 94 | training_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 95 | test0_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 96 | test1_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 97 | test2_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 98 | test3_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 99 | 100 | test4_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 101 | test5_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 102 | test6_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 103 | test7_data.drop(['Reccur.status', 'Reccur.months'], axis=1, inplace=True) 104 | 105 | 106 | X_train_pd = training_data 107 | tmp = 
training_data[training_data.columns[:-2]] 108 | 109 | X_train_np = np.array(training_data[training_data.columns[:-2]]) 110 | p = X_train_np.shape[1] 111 | n= X_train_np.shape[0] 112 | 113 | if args.reweighting == "DWR": 114 | if args.decorrelation_type == "global": 115 | cov_mask = get_cov_mask(np.zeros(p)) 116 | elif args.decorrelation_type == "group": 117 | cov_mask = get_cov_mask(np.array(oracle_mask, np.float)) 118 | else: 119 | raise NotImplementedError 120 | W = DWR(X_train_np, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 121 | elif args.reweighting == "SRDO": 122 | W = SRDO(X_train_np, p_s, hidden_layer_sizes = (37, 4), decorrelation_type="global", max_iter=args.iters_balance) 123 | else: 124 | W = np.ones((n, 1))/n 125 | 126 | mean_value = np.mean(W) 127 | W = W * (1/mean_value) 128 | 129 | W = np.clip(W, 0.3, 3) 130 | 131 | columns = X_train_pd.columns 132 | all_X = np.concatenate((X_train_pd, W), axis=1) 133 | X_train_test = pd.DataFrame(all_X, columns=list(columns)+["Weights"]) 134 | 135 | 136 | test_W = np.ones((test0_data.shape[0], 1)) 137 | test0_data = np.concatenate((test0_data, test_W), axis=1) 138 | test0_pd = pd.DataFrame(test0_data, columns=list(columns)+["Weights"]) 139 | 140 | 141 | test_W = np.ones((test1_data.shape[0], 1)) 142 | test1_data = np.concatenate((test1_data, test_W), axis=1) 143 | test1_pd = pd.DataFrame(test1_data, columns=list(columns)+["Weights"]) 144 | 145 | 146 | test_W = np.ones((test2_data.shape[0], 1)) 147 | test2_data = np.concatenate((test2_data, test_W), axis=1) 148 | test2_pd = pd.DataFrame(test2_data, columns=list(columns)+["Weights"]) 149 | 150 | 151 | test_W = np.ones((test3_data.shape[0], 1)) 152 | test3_data = np.concatenate((test3_data, test_W), axis=1) 153 | test3_pd = pd.DataFrame(test3_data, columns=list(columns)+["Weights"]) 154 | 155 | test_W = np.ones((test4_data.shape[0], 1)) 156 | test4_data = np.concatenate((test4_data, test_W), axis=1) 157 | test4_pd = pd.DataFrame(test4_data, columns=list(columns)+["Weights"]) 158 | 159 | test_W = np.ones((test5_data.shape[0], 1)) 160 | test5_data = np.concatenate((test5_data, test_W), axis=1) 161 | test5_pd = pd.DataFrame(test5_data, columns=list(columns)+["Weights"]) 162 | 163 | test_W = np.ones((test6_data.shape[0], 1)) 164 | test6_data = np.concatenate((test6_data, test_W), axis=1) 165 | test6_pd = pd.DataFrame(test6_data, columns=list(columns)+["Weights"]) 166 | 167 | test_W = np.ones((test7_data.shape[0], 1)) 168 | test7_data = np.concatenate((test7_data, test_W), axis=1) 169 | test7_pd = pd.DataFrame(test7_data, columns=list(columns)+["Weights"]) 170 | 171 | results = dict() 172 | if args.paradigm == "regr": 173 | mask = [True,]*p 174 | model_func = get_algorithm_class(args.backend) 175 | model = model_func(X_train_pd, duration_col, event_col, W, 0.002, **vars(args)) 176 | 177 | elif args.paradigm == "fs": 178 | if args.fs_type == "STG": 179 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 180 | stg.train(X_train, Y_train, W=W, epochs=5000) 181 | select_ratio = stg.get_ratios().detach().numpy() 182 | logger.info("Select ratio: " + pretty(select_ratio)) 183 | mask = select_ratio > args.mask_threshold 184 | elif args.fs_type == "oracle": 185 | mask = oracle_mask 186 | elif args.fs_type == "None": 187 | mask = [True,]*p 188 | elif args.fs_type == "given": 189 | mask = np.array(args.mask_given, np.bool) 190 | else: 191 | raise NotImplementedError 192 | if np.array(mask, dtype=np.int64).sum() == 0: 193 | logger.info("All variables 
are discarded!") 194 | assert False 195 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 196 | model_func = get_algorithm_class(args.backend) 197 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 198 | model.fit(X_train[:, mask], Y_train) 199 | else: 200 | raise NotImplementedError 201 | 202 | # test 203 | summary = model.summary 204 | coef = summary["coef"] 205 | 206 | c_index_dict = [] 207 | 208 | 209 | c_index0 = model.score(test0_pd, scoring_method='concordance_index') 210 | c_index_dict.append(c_index0) 211 | 212 | 213 | c_index1 = model.score(test1_pd, scoring_method='concordance_index') 214 | c_index_dict.append(c_index1) 215 | 216 | 217 | c_index4 = model.score(test4_pd, scoring_method='concordance_index') 218 | c_index_dict.append(c_index4) 219 | 220 | c_index5 = model.score(test5_pd, scoring_method='concordance_index') 221 | c_index_dict.append(c_index5) 222 | 223 | c_index6 = model.score(test6_pd, scoring_method='concordance_index') 224 | c_index_dict.append(c_index6) 225 | 226 | c_index7 = model.score(test7_pd, scoring_method='concordance_index') 227 | c_index_dict.append(c_index7) 228 | 229 | 230 | print("c_index") 231 | print("\n".join([str(s) for s in c_index_dict])) 232 | mean_acc = np.mean(c_index_dict) 233 | std_acc = np.std(c_index_dict) 234 | worst_acc = min(c_index_dict) 235 | print("mean_acc", mean_acc) 236 | print("std_acc", std_acc) 237 | print("worst_acc", worst_acc) 238 | 239 | return results 240 | 241 | if __name__ == "__main__": 242 | args = get_args() 243 | setup_seed(args.seed) 244 | expname = get_expname(args) 245 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 246 | logger = Logger(args) 247 | logger.log_args(args) 248 | 249 | p = args.p 250 | p_v = int(p*args.V_ratio) 251 | p_s = p-p_v 252 | beta_s = get_beta_s(p_s) 253 | beta_v = np.zeros(p_v) 254 | beta = np.concatenate([beta_s, beta_v]) 255 | 256 | results_list = dd(list) 257 | for i in range(args.times): 258 | logger.info("Round %d" % i) 259 | results = main(args, i, logger) 260 | for k, v in results.items(): 261 | results_list[k].append(v) 262 | 263 | 264 | logger.info("Final Result:") 265 | for k, v in results_list.items(): 266 | if k == "RMSE": 267 | RMSE_dict = dict() 268 | for r_test in args.r_list: 269 | RMSE = [v[i][r_test] for i in range(args.times)] 270 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 271 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 272 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 273 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 274 | logger.info("Detailed RMSE:") 275 | for r_test in args.r_list: 276 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 277 | elif k == "beta_hat": 278 | beta_hat_array = np.array(v) 279 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 280 | logger.info("%s: %s" % (k, beta_hat_mean)) 281 | if args.bv_analysis: 282 | bv_dict = dict() 283 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 284 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 285 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 286 | for covariates in ["s", "v", "all"]: 287 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 288 | else: 289 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 290 | 291 | 292 | -------------------------------------------------------------------------------- /mRNA_breast_OS.py: 
-------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from sksurv.metrics import brier_score, cumulative_dynamic_auc 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | 18 | from sklearn.model_selection import train_test_split 19 | from lifelines import CoxPHFitter 20 | import seaborn as sns 21 | import matplotlib.pyplot as plt 22 | from lifelines.statistics import logrank_test 23 | from lifelines import KaplanMeierFitter 24 | from sksurv.util import Surv 25 | from lifelines.utils import concordance_index 26 | from sklearn.metrics import roc_auc_score 27 | from sklearn.metrics import accuracy_score 28 | duration_col = 'Survival.months' 29 | event_col = 'Survival.status' 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 32 | 33 | # data generation 34 | parser.add_argument("--p", type=int, default=10, help="Input dim") 35 | parser.add_argument("--n", type=int, default=2000, help="Sample size") 36 | parser.add_argument("--V_ratio", type=float, default=0.5) 37 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 38 | parser.add_argument("--true_func", choices=["linear",], default="linear") 39 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity") 40 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 41 | parser.add_argument("--corr_s", type=float, default=0.9) 42 | parser.add_argument("--corr_v", type=float, default=0.1) 43 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecifction strength") 44 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 45 | parser.add_argument("--r_train", type=float, default=2.5, help="Input dim") 46 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 47 | parser.add_argument("--noise_variance", type=float, default=0.3) 48 | 49 | # frontend reweighting 50 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 51 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 52 | parser.add_argument("--order", type=int, default=1) 53 | parser.add_argument("--iters_balance", type=int, default=3000) 54 | 55 | parser.add_argument("--topN", type=int, default=5) 56 | # backend model 57 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox"], default="Weighted_cox") 58 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 59 | parser.add_argument("--iters_train", type=int, default=5000) 60 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 61 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 62 | parser.add_argument("--mask_given", type=int, nargs="+", 
default=[1,1,1,1,1,0,0,0,0,0])
63 | parser.add_argument("--mask_threshold", type=float, default=0.2)
64 | parser.add_argument("--lam_STG", type=float, default=3)
65 | parser.add_argument("--sigma_STG", type=float, default=0.1)
66 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"])
67 | parser.add_argument("--bv_analysis", action="store_true")
68 | 
69 | # others
70 | parser.add_argument("--seed", type=int, default=3)
71 | parser.add_argument("--times", type=int, default=1)
72 | parser.add_argument("--result_dir", default="results")
73 | 
74 | return parser.parse_args()
75 | 
76 | 
77 | def plot_KM(cph, sorted_indices, coef, selected_data):  # despite the name, no KM curve is drawn; this returns the test c-index
78 | tmp_sel = selected_data[list(sorted_indices)]
79 | coef_value = np.dot(np.array(tmp_sel), np.array(coef)).reshape((-1, 1))  # computed but unused
80 | 
81 | c_index = concordance_index(selected_data['Survival.months'], -cph.predict_partial_hazard(selected_data), selected_data['Survival.status'])
82 | 
83 | return c_index
84 | 
85 | 
86 | def main(args, round, logger):
87 | setup_seed(args.seed + round)
88 | p = args.p
89 | p_v = int(p*args.V_ratio)
90 | p_s = p-p_v
91 | n = args.n
92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 | oracle_mask = [True,]*p_s + [False,]*p_v
94 | 
95 | # generate train data
96 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}})
97 | 
98 | beta_s = get_beta_s(p_s)
99 | beta_v = np.zeros(p_v)
100 | beta = np.concatenate([beta_s, beta_v])
101 | 
102 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train)
103 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var))
104 | 
105 | # generate test data
106 | test_data = dict()
107 | for r_test in args.r_list:
108 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}})
109 | test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test)
110 | 
111 | training_pd_data = pd.read_csv('./omics_data/breast_cancer/train_median.csv', index_col=0)
112 | test0_pd_data = pd.read_csv('./omics_data/breast_cancer/test0_median.csv', index_col=0)
113 | test1_pd_data = pd.read_csv('./omics_data/breast_cancer/test1_median.csv', index_col=0)
114 | test2_pd_data = pd.read_csv('./omics_data/breast_cancer/test2_median.csv', index_col=0)
115 | test3_pd_data = pd.read_csv('./omics_data/breast_cancer/test3_median.csv', index_col=0)
116 | 
117 | 
118 | tmp_data = training_pd_data[training_pd_data.columns[2:]]
119 | feature_cols = tmp_data.columns
120 | # use all feature columns (the original comment here, "select the three columns with the largest variance", was a leftover)
121 | X_train_pd = pd.concat([training_pd_data[training_pd_data.columns[:2]], training_pd_data[feature_cols]], axis=1)
122 | X_train_np = np.array(training_pd_data[feature_cols])
123 | 
124 | X_test0_pd = pd.concat([test0_pd_data[test0_pd_data.columns[:2]], test0_pd_data[feature_cols]], axis=1)
125 | X_test1_pd = pd.concat([test1_pd_data[test1_pd_data.columns[:2]], test1_pd_data[feature_cols]], axis=1)
126 | 
127 | X_test2_pd = pd.concat([test2_pd_data[test2_pd_data.columns[:2]], test2_pd_data[feature_cols]], axis=1)
128 | X_test3_pd = pd.concat([test3_pd_data[test3_pd_data.columns[:2]], test3_pd_data[feature_cols]], axis=1)
129 | 
130 | 
131 | p = X_train_np.shape[1]
132 | n = X_train_np.shape[0]
133 | if args.reweighting == "DWR":
134 | if args.decorrelation_type == "global":
135 | cov_mask = get_cov_mask(np.zeros(p))
136 | elif args.decorrelation_type == "group":
137 | cov_mask = get_cov_mask(np.array(oracle_mask, float))  # plain float: the np.float alias is removed in newer NumPy
138 | else:
139 |
raise NotImplementedError 140 | W = DWR(X_train_np, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 141 | elif args.reweighting == "SRDO": 142 | W = SRDO(X_train_np, p_s, hidden_layer_sizes=(30, 2), decorrelation_type=args.decorrelation_type, max_iter=args.iters_balance) 143 | else: 144 | W = np.ones((n, 1))/n 145 | 146 | mean_value = np.mean(W) 147 | W = W * (1/mean_value) 148 | #print("max*********", np.max(W)) 149 | #print("min********", np.min(W)) 150 | 151 | 152 | #W = (W - min_value)/(max_value-min_value) 153 | #W = W * 2 154 | #print("W", W) 155 | 156 | W = np.clip(W, 0.03, 3) 157 | columns = X_train_pd.columns 158 | all_X = np.concatenate((X_train_pd, W), axis=1) 159 | 160 | X_train_test = pd.DataFrame(all_X, columns=list(columns)+["Weights"]) 161 | test0_W = np.ones((X_test0_pd.shape[0], 1)) 162 | 163 | X_test0_pd = np.concatenate((X_test0_pd, test0_W), axis=1) 164 | 165 | X_test0_pd = pd.DataFrame(X_test0_pd, columns=list(columns)+["Weights"]) 166 | 167 | test1_W = np.ones((X_test1_pd.shape[0], 1)) 168 | 169 | X_test1_pd = np.concatenate((X_test1_pd, test1_W), axis=1) 170 | 171 | X_test1_pd = pd.DataFrame(X_test1_pd, columns=list(columns)+["Weights"]) 172 | 173 | test2_W = np.ones((X_test2_pd.shape[0], 1)) 174 | 175 | X_test2_pd = np.concatenate((X_test2_pd, test2_W), axis=1) 176 | 177 | X_test2_pd = pd.DataFrame(X_test2_pd, columns=list(columns)+["Weights"]) 178 | 179 | test3_W = np.ones((X_test3_pd.shape[0], 1)) 180 | X_test3_pd = np.concatenate((X_test3_pd, test3_W), axis=1) 181 | X_test3_pd = pd.DataFrame(X_test3_pd, columns=list(columns)+["Weights"]) 182 | 183 | 184 | 185 | 186 | results = dict() 187 | if args.paradigm == "regr": 188 | mask = [True,]*p 189 | model_func = get_algorithm_class(args.backend) 190 | model = model_func(X_train_pd, duration_col, event_col, W, 0.03, **vars(args)) 191 | print("train score", model.score(X_train_test)) 192 | 193 | elif args.paradigm == "fs": 194 | if args.fs_type == "STG": 195 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 196 | stg.train(X_train, Y_train, W=W, epochs=5000) 197 | select_ratio = stg.get_ratios().detach().numpy() 198 | logger.info("Select ratio: " + pretty(select_ratio)) 199 | mask = select_ratio > args.mask_threshold 200 | elif args.fs_type == "oracle": 201 | mask = oracle_mask 202 | elif args.fs_type == "None": 203 | mask = [True,]*p 204 | elif args.fs_type == "given": 205 | mask = np.array(args.mask_given, np.bool) 206 | else: 207 | raise NotImplementedError 208 | if np.array(mask, dtype=np.int64).sum() == 0: 209 | logger.info("All variables are discarded!") 210 | assert False 211 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 212 | model_func = get_algorithm_class(args.backend) 213 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 214 | model.fit(X_train[:, mask], Y_train) 215 | else: 216 | raise NotImplementedError 217 | 218 | # test 219 | summary = model.summary 220 | sorted_indices = summary['p'].sort_values().head(args.topN).index 221 | 222 | print("\n".join([str(tmp) for tmp in list(sorted_indices)])) 223 | coef = summary.loc[list(sorted_indices)]["coef"] 224 | 225 | selected_X_train = pd.concat([X_train_pd[X_train_pd.columns[:2]], X_train_pd[sorted_indices]], axis=1) 226 | cph2 = CoxPHFitter(penalizer=0.0001) 227 | cph2.fit(selected_X_train, duration_col='Survival.months', event_col='Survival.status') 228 | 229 | summary = cph2.summary 230 | coef = summary["coef"] 231 | 232 | 233 | 234 | 235 | c_index_dict 
= [] 236 | print("test1:") 237 | c_index0 = plot_KM(cph2, sorted_indices, coef, X_test0_pd) 238 | print(c_index0) 239 | c_index_dict.append(c_index0) 240 | 241 | 242 | print("test2") 243 | c_index1 = plot_KM(cph2, sorted_indices, coef, X_test1_pd) 244 | print(c_index1) 245 | c_index_dict.append(c_index1) 246 | 247 | 248 | print("test3") 249 | c_index2 = plot_KM(cph2, sorted_indices, coef, X_test2_pd) 250 | print(c_index2) 251 | c_index_dict.append(c_index2) 252 | 253 | print("c_index") 254 | print(c_index_dict) 255 | mean_acc = np.mean(c_index_dict) 256 | std_acc = np.std(c_index_dict) 257 | worst_acc = min(c_index_dict) 258 | print(mean_acc) 259 | print(std_acc) 260 | print(worst_acc) 261 | 262 | 263 | 264 | 265 | return results 266 | 267 | if __name__ == "__main__": 268 | args = get_args() 269 | setup_seed(args.seed) 270 | expname = get_expname(args) 271 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 272 | logger = Logger(args) 273 | logger.log_args(args) 274 | 275 | p = args.p 276 | p_v = int(p*args.V_ratio) 277 | p_s = p-p_v 278 | beta_s = get_beta_s(p_s) 279 | beta_v = np.zeros(p_v) 280 | beta = np.concatenate([beta_s, beta_v]) 281 | 282 | results_list = dd(list) 283 | for i in range(args.times): 284 | logger.info("Round %d" % i) 285 | results = main(args, i, logger) 286 | for k, v in results.items(): 287 | results_list[k].append(v) 288 | 289 | 290 | logger.info("Final Result:") 291 | for k, v in results_list.items(): 292 | if k == "RMSE": 293 | RMSE_dict = dict() 294 | for r_test in args.r_list: 295 | RMSE = [v[i][r_test] for i in range(args.times)] 296 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 297 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 298 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 299 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 300 | logger.info("Detailed RMSE:") 301 | for r_test in args.r_list: 302 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 303 | elif k == "beta_hat": 304 | beta_hat_array = np.array(v) 305 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 306 | logger.info("%s: %s" % (k, beta_hat_mean)) 307 | if args.bv_analysis: 308 | bv_dict = dict() 309 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 310 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 311 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 312 | for covariates in ["s", "v", "all"]: 313 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 314 | else: 315 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 316 | 317 | 318 | -------------------------------------------------------------------------------- /mRNA_mela_OS.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from sksurv.metrics import brier_score, cumulative_dynamic_auc 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | 18 | from sklearn.model_selection 
import train_test_split 19 | from lifelines import CoxPHFitter 20 | import seaborn as sns 21 | import matplotlib.pyplot as plt 22 | from lifelines.statistics import logrank_test 23 | from lifelines import KaplanMeierFitter 24 | from sksurv.util import Surv 25 | from lifelines.utils import concordance_index 26 | from sklearn.metrics import roc_auc_score 27 | from sklearn.metrics import accuracy_score 28 | duration_col = 'Survival.months' 29 | event_col = 'Survival.status' 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 32 | 33 | # data generation 34 | parser.add_argument("--p", type=int, default=10, help="Input dim") 35 | parser.add_argument("--n", type=int, default=2000, help="Sample size") 36 | parser.add_argument("--V_ratio", type=float, default=0.5) 37 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 38 | parser.add_argument("--true_func", choices=["linear",], default="linear") 39 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity") 40 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 41 | parser.add_argument("--corr_s", type=float, default=0.9) 42 | parser.add_argument("--corr_v", type=float, default=0.1) 43 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecifction strength") 44 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 45 | parser.add_argument("--r_train", type=float, default=2.5, help="Input dim") 46 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 47 | parser.add_argument("--noise_variance", type=float, default=0.3) 48 | 49 | # frontend reweighting 50 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 51 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 52 | parser.add_argument("--order", type=int, default=1) 53 | parser.add_argument("--iters_balance", type=int, default=3000) 54 | 55 | parser.add_argument("--topN", type=int, default=5) 56 | # backend model 57 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox"], default="Weighted_cox") 58 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 59 | parser.add_argument("--iters_train", type=int, default=5000) 60 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 61 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 62 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 63 | parser.add_argument("--mask_threshold", type=float, default=0.2) 64 | parser.add_argument("--lam_STG", type=float, default=3) 65 | parser.add_argument("--sigma_STG", type=float, default=0.1) 66 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 67 | parser.add_argument("--bv_analysis", action="store_true") 68 | 69 | # others 70 | parser.add_argument("--seed", type=int, default=3) 71 | parser.add_argument("--times", type=int, default=1) 72 | parser.add_argument("--result_dir", default="results") 73 | 74 | return parser.parse_args() 75 | 76 | 77 | def plot_KM(cph, sorted_indices, coef, selected_data): 78 | tmp_sel = selected_data[list(sorted_indices)] 79 | coef_value = np.dot(np.array(tmp_sel), 
np.array(coef)).reshape((-1, 1))  # computed but unused
80 | 
81 | c_index = concordance_index(selected_data['Survival.months'], -cph.predict_partial_hazard(selected_data), selected_data['Survival.status'])
82 | 
83 | return c_index
84 | 
85 | def main(args, round, logger):
86 | setup_seed(args.seed + round)
87 | p = args.p
88 | p_v = int(p*args.V_ratio)
89 | p_s = p-p_v
90 | n = args.n
91 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92 | oracle_mask = [True,]*p_s + [False,]*p_v
93 | 
94 | # generate train data
95 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}})
96 | 
97 | beta_s = get_beta_s(p_s)
98 | beta_v = np.zeros(p_v)
99 | beta = np.concatenate([beta_s, beta_v])
100 | 
101 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train)
102 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var))
103 | 
104 | # generate test data
105 | test_data = dict()
106 | for r_test in args.r_list:
107 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}})
108 | test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test)
109 | 
110 | training_pd_data = pd.read_csv('./omics_data/mela_cancer/train_median.csv', index_col=0)
111 | test0_pd_data = pd.read_csv('./omics_data/mela_cancer/test0_median.csv', index_col=0)
112 | test1_pd_data = pd.read_csv('./omics_data/mela_cancer/test1_median.csv', index_col=0)
113 | test2_pd_data = pd.read_csv('./omics_data/mela_cancer/test2_median.csv', index_col=0)
114 | test3_pd_data = pd.read_csv('./omics_data/mela_cancer/test3_median.csv', index_col=0)
115 | 
116 | training_pd_data['Survival.months'] = training_pd_data['Survival.months']/30.0  # the raw values appear to be days; /30 converts them to months
117 | test0_pd_data['Survival.months'] = test0_pd_data['Survival.months']/30.0
118 | test1_pd_data['Survival.months'] = test1_pd_data['Survival.months']/30.0
119 | 
120 | test2_pd_data['Survival.months'] = test2_pd_data['Survival.months']/30.0
121 | test3_pd_data['Survival.months'] = test3_pd_data['Survival.months']/30.0
122 | 
123 | 
124 | 
125 | tmp_data = training_pd_data.columns[2:]
126 | # use all feature columns (the original comment here, "select the three columns with the largest variance", was a leftover)
127 | feature_cols = tmp_data
128 | 
129 | X_train_pd = pd.concat([training_pd_data[training_pd_data.columns[:2]], training_pd_data[feature_cols]], axis=1)
130 | X_train_np = np.array(training_pd_data[feature_cols])
131 | X_test0_pd = pd.concat([test0_pd_data[test0_pd_data.columns[:2]], test0_pd_data[feature_cols]], axis=1)
132 | X_test1_pd = pd.concat([test1_pd_data[test1_pd_data.columns[:2]], test1_pd_data[feature_cols]], axis=1)
133 | 
134 | X_test2_pd = pd.concat([test2_pd_data[test2_pd_data.columns[:2]], test2_pd_data[feature_cols]], axis=1)
135 | X_test3_pd = pd.concat([test3_pd_data[test3_pd_data.columns[:2]], test3_pd_data[feature_cols]], axis=1)
136 | 
137 | 
138 | p = X_train_np.shape[1]
139 | n = X_train_np.shape[0]
140 | if args.reweighting == "DWR":
141 | if args.decorrelation_type == "global":
142 | cov_mask = get_cov_mask(np.zeros(p))
143 | elif args.decorrelation_type == "group":
144 | cov_mask = get_cov_mask(np.array(oracle_mask, float))  # plain float: the np.float alias is removed in newer NumPy
145 | else:
146 | raise NotImplementedError
147 | W = DWR(X_train_np, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device)
148 | elif args.reweighting == "SRDO":
149 | W = SRDO(X_train_np, p_s, hidden_layer_sizes=(30, 9), decorrelation_type=args.decorrelation_type, max_iter=args.iters_balance)
150 | else:
151 | W = np.ones((n, 1))/n
152 | 
153 | mean_value = np.mean(W)
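# Note: the rescaling below normalizes W to mean 1, and the subsequent clip bounds
# how much any single sample can influence the weighted Cox partial likelihood.
# A common diagnostic for how much reweighting shrinks the data (hypothetical here,
# the script does not compute it) is Kish's effective sample size:
#     n_eff = np.sum(W) ** 2 / np.sum(W ** 2)   # equals n only when all weights are equal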
154 | W = W * (1/mean_value)
155 | 
156 | W = np.clip(W, 0.03, 3)
157 | columns = X_train_pd.columns
158 | all_X = np.concatenate((X_train_pd, W), axis=1)
159 | 
160 | X_train_test = pd.DataFrame(all_X, columns=list(columns)+["Weights"])
161 | test0_W = np.ones((X_test0_pd.shape[0], 1))
162 | 
163 | X_test0_pd = np.concatenate((X_test0_pd, test0_W), axis=1)
164 | 
165 | X_test0_pd = pd.DataFrame(X_test0_pd, columns=list(columns)+["Weights"])
166 | 
167 | test1_W = np.ones((X_test1_pd.shape[0], 1))
168 | 
169 | X_test1_pd = np.concatenate((X_test1_pd, test1_W), axis=1)
170 | 
171 | X_test1_pd = pd.DataFrame(X_test1_pd, columns=list(columns)+["Weights"])
172 | 
173 | test2_W = np.ones((X_test2_pd.shape[0], 1))
174 | 
175 | X_test2_pd = np.concatenate((X_test2_pd, test2_W), axis=1)
176 | 
177 | X_test2_pd = pd.DataFrame(X_test2_pd, columns=list(columns)+["Weights"])
178 | 
179 | test3_W = np.ones((X_test3_pd.shape[0], 1))
180 | X_test3_pd = np.concatenate((X_test3_pd, test3_W), axis=1)
181 | X_test3_pd = pd.DataFrame(X_test3_pd, columns=list(columns)+["Weights"])
182 | 
183 | 
184 | 
185 | results = dict()
186 | if args.paradigm == "regr":
187 | mask = [True,]*p
188 | model_func = get_algorithm_class(args.backend)
189 | model = model_func(X_train_pd, duration_col, event_col, W, 0.0003, **vars(args))
190 | print("train score", model.score(X_train_test))
191 | 
192 | elif args.paradigm == "fs":
193 | if args.fs_type == "STG":
194 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG)
195 | stg.train(X_train, Y_train, W=W, epochs=5000)
196 | select_ratio = stg.get_ratios().detach().numpy()
197 | logger.info("Select ratio: " + pretty(select_ratio))
198 | mask = select_ratio > args.mask_threshold
199 | elif args.fs_type == "oracle":
200 | mask = oracle_mask
201 | elif args.fs_type == "None":
202 | mask = [True,]*p
203 | elif args.fs_type == "given":
204 | mask = np.array(args.mask_given, bool)  # plain bool: the np.bool alias is removed in newer NumPy
205 | else:
206 | raise NotImplementedError
207 | if np.array(mask, dtype=np.int64).sum() == 0:
208 | logger.info("All variables are discarded!")
209 | assert False
210 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64)))
211 | model_func = get_algorithm_class(args.backend)
212 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args))
213 | model.fit(X_train[:, mask], Y_train)
214 | else:
215 | raise NotImplementedError
216 | 
217 | # test
218 | summary = model.summary
219 | sorted_indices = summary['p'].sort_values().head(args.topN).index
220 | 
221 | print("\n".join([str(tmp) for tmp in list(sorted_indices)]))
222 | coef = summary.loc[list(sorted_indices)]["coef"]
223 | selected_X_train = pd.concat([X_train_pd[X_train_pd.columns[:2]], X_train_pd[sorted_indices]], axis=1)
224 | cph2 = CoxPHFitter(penalizer=0.0001)
225 | cph2.fit(selected_X_train, duration_col='Survival.months', event_col='Survival.status')
226 | 
227 | summary = cph2.summary
228 | coef = summary["coef"]
229 | 
230 | c_index_dict = []
231 | print("test0:")
232 | c_index0 = plot_KM(cph2, sorted_indices, coef, X_test0_pd)
233 | print(c_index0)
234 | c_index_dict.append(c_index0)
235 | 
236 | print("test1")
237 | c_index1 = plot_KM(cph2, sorted_indices, coef, X_test1_pd)
238 | print(c_index1)
239 | c_index_dict.append(c_index1)
240 | 
241 | 
242 | print("test3")  # this block scores X_test3_pd; X_test2_pd is prepared above but never evaluated
243 | c_index3 = plot_KM(cph2, sorted_indices, coef, X_test3_pd)
244 | print(c_index3)
245 | c_index_dict.append(c_index3)
246 | 
247 | 
248 | print("c_index")
249 | mean_acc = np.mean(c_index_dict)
250 | std_acc = np.std(c_index_dict)
251 | worst_acc = min(c_index_dict)
252
| print(c_index_dict) 253 | print(mean_acc) 254 | print(std_acc) 255 | print(worst_acc) 256 | 257 | 258 | return results 259 | 260 | if __name__ == "__main__": 261 | args = get_args() 262 | setup_seed(args.seed) 263 | expname = get_expname(args) 264 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 265 | logger = Logger(args) 266 | logger.log_args(args) 267 | 268 | p = args.p 269 | p_v = int(p*args.V_ratio) 270 | p_s = p-p_v 271 | beta_s = get_beta_s(p_s) 272 | beta_v = np.zeros(p_v) 273 | beta = np.concatenate([beta_s, beta_v]) 274 | results_list = dd(list) 275 | for i in range(args.times): 276 | logger.info("Round %d" % i) 277 | results = main(args, i, logger) 278 | for k, v in results.items(): 279 | results_list[k].append(v) 280 | 281 | 282 | logger.info("Final Result:") 283 | for k, v in results_list.items(): 284 | if k == "RMSE": 285 | RMSE_dict = dict() 286 | for r_test in args.r_list: 287 | RMSE = [v[i][r_test] for i in range(args.times)] 288 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 289 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 290 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 291 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 292 | logger.info("Detailed RMSE:") 293 | for r_test in args.r_list: 294 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 295 | elif k == "beta_hat": 296 | beta_hat_array = np.array(v) 297 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 298 | logger.info("%s: %s" % (k, beta_hat_mean)) 299 | if args.bv_analysis: 300 | bv_dict = dict() 301 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 302 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 303 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 304 | for covariates in ["s", "v", "all"]: 305 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 306 | else: 307 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 308 | 309 | 310 | -------------------------------------------------------------------------------- /clinical_lung_DFS.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from sksurv.metrics import brier_score, cumulative_dynamic_auc 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | 18 | from lifelines import CoxPHFitter 19 | import seaborn as sns 20 | import matplotlib.pyplot as plt 21 | from lifelines.statistics import logrank_test 22 | from lifelines import KaplanMeierFitter 23 | from sksurv.util import Surv 24 | from lifelines.utils import concordance_index 25 | duration_col='Reccur.months' 26 | event_col='Reccur.status' 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 29 | 30 | # data generation 31 | parser.add_argument("--p", type=int, default=10, help="Input dim") 32 | 
parser.add_argument("--n", type=int, default=2000, help="Sample size") 33 | parser.add_argument("--V_ratio", type=float, default=0.5) 34 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 35 | parser.add_argument("--true_func", choices=["linear",], default="linear") 36 | parser.add_argument("--mode", choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity") 37 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 38 | parser.add_argument("--corr_s", type=float, default=0.9) 39 | parser.add_argument("--corr_v", type=float, default=0.1) 40 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecifction strength") 41 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 42 | parser.add_argument("--r_train", type=float, default=2.5, help="Input dim") 43 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 44 | parser.add_argument("--noise_variance", type=float, default=0.3) 45 | 46 | # frontend reweighting 47 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 48 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 49 | parser.add_argument("--order", type=int, default=1) 50 | parser.add_argument("--iters_balance", type=int, default=3000) 51 | 52 | parser.add_argument("--topN", type=int, default=5) 53 | # backend model 54 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox"], default="Weighted_cox") 55 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 56 | parser.add_argument("--iters_train", type=int, default=1000) 57 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 58 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 59 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 60 | parser.add_argument("--mask_threshold", type=float, default=0.2) 61 | parser.add_argument("--lam_STG", type=float, default=3) 62 | parser.add_argument("--sigma_STG", type=float, default=0.1) 63 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 64 | parser.add_argument("--bv_analysis", action="store_true") 65 | 66 | # others 67 | parser.add_argument("--seed", type=int, default=3) 68 | parser.add_argument("--times", type=int, default=1) 69 | parser.add_argument("--result_dir", default="results") 70 | 71 | return parser.parse_args() 72 | 73 | 74 | def main(args, round, logger): 75 | setup_seed(args.seed + round) 76 | p = args.p 77 | p_v = int(p*args.V_ratio) 78 | p_s = p-p_v 79 | n = args.n 80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | oracle_mask = [True,]*p_s + [False,]*p_v 82 | 83 | # generate train data 84 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}}) 85 | 86 | beta_s = get_beta_s(p_s) 87 | beta_v = np.zeros(p_v) 88 | beta = np.concatenate([beta_s, beta_v]) 89 | 90 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train) 91 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var)) 92 | 93 | # generate test data 94 | test_data = dict() 95 | for r_test in args.r_list: 96 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}}) 97 
| test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test) 98 | 99 | training_data = pd.read_csv('./clinical_data/lung_cancer/train_pd.csv', index_col=0) 100 | test0_data = pd.read_csv('./clinical_data/lung_cancer/test0_pd.csv', index_col=0) 101 | test1_data = pd.read_csv('./clinical_data/lung_cancer/test1_pd.csv', index_col=0) 102 | test2_data = pd.read_csv('./clinical_data/lung_cancer/test2_pd.csv', index_col=0) 103 | test3_data = pd.read_csv('./clinical_data/lung_cancer/test3_pd.csv', index_col=0) 104 | 105 | test4_data = pd.read_csv('./clinical_data/lung_cancer/test4_pd.csv', index_col=0) 106 | test5_data = pd.read_csv('./clinical_data/lung_cancer/test5_pd.csv', index_col=0) 107 | test6_data = pd.read_csv('./clinical_data/lung_cancer/test6_pd.csv', index_col=0) 108 | test7_data = pd.read_csv('./clinical_data/lung_cancer/test7_pd.csv', index_col=0) 109 | training_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 110 | test0_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 111 | test1_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 112 | test2_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 113 | test3_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 114 | 115 | test4_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 116 | test5_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 117 | test6_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 118 | test7_data.drop(['Survival.status', 'Survival.months'], axis=1, inplace=True) 119 | X_train_pd = training_data 120 | tmp = training_data[training_data.columns[:-2]] 121 | 122 | X_train_np = np.array(training_data[training_data.columns[:-2]]) 123 | p = X_train_np.shape[1] 124 | n= X_train_np.shape[0] 125 | 126 | if args.reweighting == "DWR": 127 | if args.decorrelation_type == "global": 128 | cov_mask = get_cov_mask(np.zeros(p)) 129 | elif args.decorrelation_type == "group": 130 | cov_mask = get_cov_mask(np.array(oracle_mask, np.float)) 131 | else: 132 | raise NotImplementedError 133 | W = DWR(X_train_np, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device) 134 | elif args.reweighting == "SRDO": 135 | W = SRDO(X_train_np, p_s, hidden_layer_sizes = (38, 9), decorrelation_type="global", max_iter=args.iters_balance) 136 | else: 137 | W = np.ones((n, 1))/n 138 | 139 | mean_value = np.mean(W) 140 | W = W * (1/mean_value) 141 | 142 | W = np.clip(W, 0.1, 4) 143 | 144 | columns = X_train_pd.columns 145 | all_X = np.concatenate((X_train_pd, W), axis=1) 146 | X_train_test = pd.DataFrame(all_X, columns=list(columns)+["Weights"]) 147 | 148 | 149 | test_W = np.ones((test0_data.shape[0], 1)) 150 | test0_data = np.concatenate((test0_data, test_W), axis=1) 151 | test0_pd = pd.DataFrame(test0_data, columns=list(columns)+["Weights"]) 152 | 153 | 154 | test_W = np.ones((test1_data.shape[0], 1)) 155 | test1_data = np.concatenate((test1_data, test_W), axis=1) 156 | test1_pd = pd.DataFrame(test1_data, columns=list(columns)+["Weights"]) 157 | 158 | 159 | test_W = np.ones((test2_data.shape[0], 1)) 160 | test2_data = np.concatenate((test2_data, test_W), axis=1) 161 | test2_pd = pd.DataFrame(test2_data, columns=list(columns)+["Weights"]) 162 | 163 | 164 | test_W = np.ones((test3_data.shape[0], 1)) 165 | test3_data = np.concatenate((test3_data, test_W), axis=1) 166 | test3_pd = pd.DataFrame(test3_data, columns=list(columns)+["Weights"]) 167 | 168 | 
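# Note: the eight near-identical blocks here each append a unit "Weights" column to
# one test split. A hypothetical helper (not part of this script) doing the same:
#
#     def with_unit_weights(frame, columns):
#         w = np.ones((frame.shape[0], 1))
#         return pd.DataFrame(np.concatenate((frame, w), axis=1),
#                             columns=list(columns) + ["Weights"])
#
#     # e.g. test4_pd = with_unit_weights(test4_data, columns), and so on for test5-7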
test_W = np.ones((test4_data.shape[0], 1)) 169 | test4_data = np.concatenate((test4_data, test_W), axis=1) 170 | test4_pd = pd.DataFrame(test4_data, columns=list(columns)+["Weights"]) 171 | 172 | test_W = np.ones((test5_data.shape[0], 1)) 173 | test5_data = np.concatenate((test5_data, test_W), axis=1) 174 | test5_pd = pd.DataFrame(test5_data, columns=list(columns)+["Weights"]) 175 | 176 | test_W = np.ones((test6_data.shape[0], 1)) 177 | test6_data = np.concatenate((test6_data, test_W), axis=1) 178 | test6_pd = pd.DataFrame(test6_data, columns=list(columns)+["Weights"]) 179 | 180 | test_W = np.ones((test7_data.shape[0], 1)) 181 | test7_data = np.concatenate((test7_data, test_W), axis=1) 182 | test7_pd = pd.DataFrame(test7_data, columns=list(columns)+["Weights"]) 183 | 184 | results = dict() 185 | if args.paradigm == "regr": 186 | mask = [True,]*p 187 | model_func = get_algorithm_class(args.backend) 188 | model = model_func(X_train_pd, duration_col, event_col, W, 0.0004, **vars(args)) 189 | 190 | elif args.paradigm == "fs": 191 | if args.fs_type == "STG": 192 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 193 | stg.train(X_train, Y_train, W=W, epochs=5000) 194 | select_ratio = stg.get_ratios().detach().numpy() 195 | logger.info("Select ratio: " + pretty(select_ratio)) 196 | mask = select_ratio > args.mask_threshold 197 | elif args.fs_type == "oracle": 198 | mask = oracle_mask 199 | elif args.fs_type == "None": 200 | mask = [True,]*p 201 | elif args.fs_type == "given": 202 | mask = np.array(args.mask_given, np.bool) 203 | else: 204 | raise NotImplementedError 205 | if np.array(mask, dtype=np.int64).sum() == 0: 206 | logger.info("All variables are discarded!") 207 | assert False 208 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 209 | model_func = get_algorithm_class(args.backend) 210 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 211 | model.fit(X_train[:, mask], Y_train) 212 | else: 213 | raise NotImplementedError 214 | 215 | # test 216 | summary = model.summary 217 | coef = summary["coef"] 218 | 219 | c_index_dict = [] 220 | 221 | c_index0 = model.score(test0_pd, scoring_method='concordance_index') 222 | c_index_dict.append(c_index0) 223 | 224 | 225 | c_index1 = model.score(test1_pd, scoring_method='concordance_index') 226 | c_index_dict.append(c_index1) 227 | 228 | 229 | c_index4 = model.score(test4_pd, scoring_method='concordance_index') 230 | c_index_dict.append(c_index4) 231 | 232 | 233 | c_index5 = model.score(test5_pd, scoring_method='concordance_index') 234 | c_index_dict.append(c_index5) 235 | 236 | 237 | c_index6 = model.score(test6_pd, scoring_method='concordance_index') 238 | c_index_dict.append(c_index6) 239 | 240 | c_index7 = model.score(test7_pd, scoring_method='concordance_index') 241 | c_index_dict.append(c_index7) 242 | 243 | print("c_index") 244 | print("\n".join([str(s) for s in c_index_dict])) 245 | mean_acc = np.mean(c_index_dict) 246 | std_acc = np.std(c_index_dict) 247 | worst_acc = min(c_index_dict) 248 | print("mean_acc", mean_acc) 249 | print("std_acc", std_acc) 250 | print("worst_acc", worst_acc) 251 | 252 | return results 253 | 254 | if __name__ == "__main__": 255 | args = get_args() 256 | setup_seed(args.seed) 257 | expname = get_expname(args) 258 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 259 | logger = Logger(args) 260 | logger.log_args(args) 261 | 262 | p = args.p 263 | p_v = int(p*args.V_ratio) 264 | p_s = p-p_v 265 | beta_s = get_beta_s(p_s) 266 | beta_v = np.zeros(p_v) 
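# Note: this is the ground-truth layout shared by all scripts in the repo -- the
# first p_s entries of beta are the stable coefficients from get_beta_s and the
# trailing p_v entries are zeros for the spurious covariates. With the defaults
# p = 10 and V_ratio = 0.5, beta = [b_1, ..., b_5, 0, 0, 0, 0, 0], which is the
# split that BV_analysis uses below when --bv_analysis is set.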
267 | beta = np.concatenate([beta_s, beta_v]) 268 | results_list = dd(list) 269 | for i in range(args.times): 270 | logger.info("Round %d" % i) 271 | results = main(args, i, logger) 272 | for k, v in results.items(): 273 | results_list[k].append(v) 274 | 275 | 276 | logger.info("Final Result:") 277 | for k, v in results_list.items(): 278 | if k == "RMSE": 279 | RMSE_dict = dict() 280 | for r_test in args.r_list: 281 | RMSE = [v[i][r_test] for i in range(args.times)] 282 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 283 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 284 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 285 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 286 | logger.info("Detailed RMSE:") 287 | for r_test in args.r_list: 288 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 289 | elif k == "beta_hat": 290 | beta_hat_array = np.array(v) 291 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 292 | logger.info("%s: %s" % (k, beta_hat_mean)) 293 | if args.bv_analysis: 294 | bv_dict = dict() 295 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 296 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 297 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 298 | for covariates in ["s", "v", "all"]: 299 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 300 | else: 301 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 302 | 303 | 304 | -------------------------------------------------------------------------------- /mRNA_HCC_OS.py: -------------------------------------------------------------------------------- 1 | from data.selection_bias import gen_selection_bias_data 2 | from algorithm.DWR import DWR 3 | from algorithm.SRDO import SRDO 4 | from model.linear import get_algorithm_class 5 | from metrics import get_metric_class 6 | from utils import setup_seed, get_beta_s, get_expname, calc_var, pretty, get_cov_mask, BV_analysis 7 | from Logger import Logger 8 | from model.STG import STG 9 | from sksurv.metrics import brier_score, cumulative_dynamic_auc 10 | from sklearn.metrics import mean_squared_error 11 | import numpy as np 12 | import argparse 13 | import os 14 | import torch 15 | from collections import defaultdict as dd 16 | import pandas as pd 17 | from sklearn.feature_selection import SelectKBest, f_classif 18 | from sklearn.model_selection import train_test_split 19 | from lifelines import CoxPHFitter 20 | import seaborn as sns 21 | import matplotlib.pyplot as plt 22 | from lifelines.statistics import logrank_test 23 | from lifelines import KaplanMeierFitter 24 | from sksurv.util import Surv 25 | from lifelines.utils import concordance_index 26 | from sklearn.metrics import roc_auc_score 27 | from sklearn.metrics import accuracy_score 28 | 29 | duration_col = 'Survival.months' 30 | event_col = 'Survival.status' 31 | def get_args(): 32 | parser = argparse.ArgumentParser(description="Script to launch sample reweighting experiments", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 33 | 34 | # data generation 35 | parser.add_argument("--p", type=int, default=10, help="Input dim") 36 | parser.add_argument("--n", type=int, default=2000, help="Sample size") 37 | parser.add_argument("--V_ratio", type=float, default=0.5) 38 | parser.add_argument("--Vb_ratio", type=float, default=0.1) 39 | parser.add_argument("--true_func", choices=["linear",], default="linear") 40 | parser.add_argument("--mode", 
choices=["S_|_V", "S->V", "V->S", "collinearity"], default="collinearity") 41 | parser.add_argument("--misspe", choices=["poly", "exp", "None"], default="poly") 42 | parser.add_argument("--corr_s", type=float, default=0.9) 43 | parser.add_argument("--corr_v", type=float, default=0.1) 44 | parser.add_argument("--mms_strength", type=float, default=1.0, help="model misspecifction strength") 45 | parser.add_argument("--spurious", choices=["nonlinear", "linear"], default="nonlinear") 46 | parser.add_argument("--r_train", type=float, default=2.5, help="Input dim") 47 | parser.add_argument("--r_list", type=float, nargs="+", default=[-3, -2, -1.7, -1.5, -1.3, 1.3, 1.5, 1.7, 2, 3]) 48 | parser.add_argument("--noise_variance", type=float, default=0.3) 49 | 50 | # frontend reweighting 51 | parser.add_argument("--reweighting", choices=["None", "DWR", "SRDO"], default="DWR") 52 | parser.add_argument("--decorrelation_type", choices=["global", "group"], default="global") 53 | parser.add_argument("--order", type=int, default=1) 54 | parser.add_argument("--iters_balance", type=int, default=3000) 55 | 56 | parser.add_argument("--topN", type=int, default=10) 57 | # backend model 58 | parser.add_argument("--backend", choices=["OLS", "Lasso", "Ridge", "Weighted_cox"], default="Weighted_cox") 59 | parser.add_argument("--paradigm", choices=["regr", "fs",], default="regr") 60 | parser.add_argument("--iters_train", type=int, default=5000) 61 | parser.add_argument("--lam_backend", type=float, default=0.01) # regularizer coefficient 62 | parser.add_argument("--fs_type", choices=["oracle", "None", "given", "STG"], default="STG") 63 | parser.add_argument("--mask_given", type=int, nargs="+", default=[1,1,1,1,1,0,0,0,0,0]) 64 | parser.add_argument("--mask_threshold", type=float, default=0.2) 65 | parser.add_argument("--lam_STG", type=float, default=3) 66 | parser.add_argument("--sigma_STG", type=float, default=0.1) 67 | parser.add_argument("--metrics", nargs="+", default=["L1_beta_error", "L2_beta_error"]) 68 | parser.add_argument("--bv_analysis", action="store_true") 69 | 70 | # others 71 | parser.add_argument("--seed", type=int, default=6) 72 | parser.add_argument("--times", type=int, default=1) 73 | parser.add_argument("--result_dir", default="results") 74 | 75 | return parser.parse_args() 76 | 77 | def optimal(coef_value, selected_data, per): 78 | 79 | percentile = np.percentile(coef_value, per) 80 | result = np.where(coef_value >= percentile, 1, 0) 81 | all_X = np.concatenate((selected_data[selected_data.columns[:2]], result), axis=1) 82 | all_X = pd.DataFrame(all_X, columns=list(selected_data.columns[:2])+["groups"]) 83 | cph = CoxPHFitter(penalizer=0.0001) 84 | cph.fit(all_X, duration_col='Recurr.months', event_col='Recurr.status') 85 | summary = cph.summary 86 | 87 | HR = summary["exp(coef)"] 88 | group1 = selected_data[result==1][selected_data.columns[:2]] 89 | group2 = selected_data[result==0][selected_data.columns[:2]] 90 | results = logrank_test(group1["Recurr.months"], group2["Recurr.months"], event_observed_A=group1['Recurr.status'], event_observed_B=group2["Recurr.status"]) 91 | p_value = results.summary['p'].iloc[0] 92 | 93 | return HR, p_value 94 | 95 | def plot_KM(cph, sorted_indices, coef, selected_data): 96 | tmp_sel = selected_data[list(sorted_indices)] 97 | coef_value = np.dot(np.array(tmp_sel), np.array(coef)).reshape((-1, 1)) 98 | 99 | c_index = concordance_index(selected_data['Survival.months'], -cph.predict_partial_hazard(selected_data), selected_data['Survival.status']) 100 | 101 | return 
c_index
102 | 
103 | def main(args, round, logger):
104 | setup_seed(args.seed + round)
105 | p = args.p
106 | p_v = int(p*args.V_ratio)
107 | p_s = p-p_v
108 | n = args.n
109 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
110 | oracle_mask = [True,]*p_s + [False,]*p_v
111 | 
112 | # generate train data
113 | X_train, S_train, V_train, fs_train, Y_train = gen_selection_bias_data({**vars(args),**{"r": args.r_train}})
114 | 
115 | beta_s = get_beta_s(p_s)
116 | beta_v = np.zeros(p_v)
117 | beta = np.concatenate([beta_s, beta_v])
118 | 
119 | linear_var, nonlinear_var, total_var = calc_var(beta_s, S_train, fs_train)
120 | logger.info("Linear term var: %.3f, Nonlinear term var: %.3f, total var: %.3f" % (linear_var, nonlinear_var, total_var))
121 | 
122 | # generate test data
123 | test_data = dict()
124 | for r_test in args.r_list:
125 | X_test, S_test, V_test, fs_test, Y_test = gen_selection_bias_data({**vars(args),**{"r": r_test}})
126 | test_data[r_test] = (X_test, S_test, V_test, fs_test, Y_test)
127 | 
128 | 
129 | 
130 | training_pd_data = pd.read_csv('./omics_data/HCC_cancer/train_median.csv', index_col=0)
131 | test0_pd_data = pd.read_csv('./omics_data/HCC_cancer/test0_median.csv', index_col=0)
132 | test1_pd_data = pd.read_csv('./omics_data/HCC_cancer/test1_median.csv', index_col=0)
133 | test2_pd_data = pd.read_csv('./omics_data/HCC_cancer/test2_median.csv', index_col=0)
134 | test3_pd_data = pd.read_csv('./omics_data/HCC_cancer/test3_median.csv', index_col=0)
135 | 
136 | 
137 | 
138 | feature_cols = training_pd_data.columns[2:]
139 | 
140 | #feature_cols = feature_cols[:args.topN]
141 | # use all feature columns (the original comment here, "select the three columns with the largest variance", was a leftover)
142 | 
143 | X_train_pd = pd.concat([training_pd_data[training_pd_data.columns[:2]], training_pd_data[feature_cols]], axis=1)
144 | X_train_np = np.array(training_pd_data[feature_cols])
145 | X_test0_pd = pd.concat([test0_pd_data[test0_pd_data.columns[:2]], test0_pd_data[feature_cols]], axis=1)
146 | X_test1_pd = pd.concat([test1_pd_data[test1_pd_data.columns[:2]], test1_pd_data[feature_cols]], axis=1)
147 | 
148 | X_test2_pd = pd.concat([test2_pd_data[test2_pd_data.columns[:2]], test2_pd_data[feature_cols]], axis=1)
149 | 
150 | 
151 | X_test3_pd = pd.concat([test3_pd_data[test3_pd_data.columns[:2]], test3_pd_data[feature_cols]], axis=1)
152 | 
153 | 
154 | p = X_train_np.shape[1]
155 | n = X_train_np.shape[0]
156 | if args.reweighting == "DWR":
157 | if args.decorrelation_type == "global":
158 | cov_mask = get_cov_mask(np.zeros(p))
159 | elif args.decorrelation_type == "group":
160 | cov_mask = get_cov_mask(np.array(oracle_mask, float))  # plain float: the np.float alias is removed in newer NumPy
161 | else:
162 | raise NotImplementedError
163 | W = DWR(X_train_np, cov_mask=cov_mask, order=args.order, num_steps=args.iters_balance, logger=logger, device=device)
164 | elif args.reweighting == "SRDO":
165 | W = SRDO(X_train_np, p_s, hidden_layer_sizes=(98, 11), decorrelation_type=args.decorrelation_type, max_iter=args.iters_balance)
166 | else:
167 | W = np.ones((n, 1))/n
168 | 
169 | mean_value = np.mean(W)
170 | W = W * (1/mean_value)
171 | #print("max*********", np.max(W))
172 | #print("min********", np.min(W))
173 | 
174 | 
175 | #W = (W - min_value)/(max_value-min_value)
176 | #W = W * 2
177 | #print("W", W)
178 | 
179 | W = np.clip(W, 0.4, 4)
180 | columns = X_train_pd.columns
181 | all_X = np.concatenate((X_train_pd, W), axis=1)
182 | 
183 | X_train_test = pd.DataFrame(all_X, columns=list(columns)+["Weights"])
184 | test0_W = np.ones((X_test0_pd.shape[0], 1))
185 | 
186 | X_test0_pd = np.concatenate((X_test0_pd, test0_W), axis=1)
187 | 
188 | X_test0_pd =
pd.DataFrame(X_test0_pd, columns=list(columns)+["Weights"]) 189 | 190 | test1_W = np.ones((X_test1_pd.shape[0], 1)) 191 | 192 | X_test1_pd = np.concatenate((X_test1_pd, test1_W), axis=1) 193 | 194 | X_test1_pd = pd.DataFrame(X_test1_pd, columns=list(columns)+["Weights"]) 195 | 196 | test2_W = np.ones((X_test2_pd.shape[0], 1)) 197 | 198 | X_test2_pd = np.concatenate((X_test2_pd, test2_W), axis=1) 199 | 200 | X_test2_pd = pd.DataFrame(X_test2_pd, columns=list(columns)+["Weights"]) 201 | 202 | test3_W = np.ones((X_test3_pd.shape[0], 1)) 203 | X_test3_pd = np.concatenate((X_test3_pd, test3_W), axis=1) 204 | X_test3_pd = pd.DataFrame(X_test3_pd, columns=list(columns)+["Weights"]) 205 | 206 | 207 | results = dict() 208 | if args.paradigm == "regr": 209 | mask = [True,]*p 210 | model_func = get_algorithm_class(args.backend) 211 | model = model_func(X_train_pd, duration_col, event_col, W, 0.0005, **vars(args)) 212 | print("train score", model.score(X_train_test)) 213 | 214 | elif args.paradigm == "fs": 215 | if args.fs_type == "STG": 216 | stg = STG(p, 1, sigma=args.sigma_STG, lam=args.lam_STG) 217 | stg.train(X_train, Y_train, W=W, epochs=5000) 218 | select_ratio = stg.get_ratios().detach().numpy() 219 | logger.info("Select ratio: " + pretty(select_ratio)) 220 | mask = select_ratio > args.mask_threshold 221 | elif args.fs_type == "oracle": 222 | mask = oracle_mask 223 | elif args.fs_type == "None": 224 | mask = [True,]*p 225 | elif args.fs_type == "given": 226 | mask = np.array(args.mask_given, np.bool) 227 | else: 228 | raise NotImplementedError 229 | if np.array(mask, dtype=np.int64).sum() == 0: 230 | logger.info("All variables are discarded!") 231 | assert False 232 | logger.info("Hard selection: " + str(np.array(mask, dtype=np.int64))) 233 | model_func = get_algorithm_class(args.backend) 234 | model = model_func(X_train, Y_train, np.ones((n, 1))/n, **vars(args)) 235 | model.fit(X_train[:, mask], Y_train) 236 | else: 237 | raise NotImplementedError 238 | 239 | # test 240 | summary = model.summary 241 | sorted_indices = summary['p'].sort_values().head(args.topN).index 242 | 243 | print("\n".join([str(tmp) for tmp in list(sorted_indices)])) 244 | coef = summary.loc[list(sorted_indices)]["coef"] 245 | 246 | selected_X_train = pd.concat([X_train_pd[X_train_pd.columns[:2]], X_train_pd[sorted_indices]], axis=1) 247 | cph2 = CoxPHFitter(penalizer=0.1) 248 | cph2.fit(selected_X_train, duration_col='Survival.months', event_col='Survival.status') 249 | 250 | summary = cph2.summary 251 | coef = summary["coef"] 252 | 253 | 254 | c_index_dict = [] 255 | print("test1") 256 | c_index1 = plot_KM(cph2, sorted_indices, coef, X_test1_pd) 257 | print(c_index1) 258 | 259 | c_index_dict.append(c_index1) 260 | 261 | print("test2") 262 | c_index2 = plot_KM(cph2, sorted_indices, coef, X_test2_pd) 263 | c_index_dict.append(c_index2) 264 | 265 | print(c_index2) 266 | 267 | print("test3") 268 | c_index3 = plot_KM(cph2, sorted_indices, coef, X_test3_pd) 269 | 270 | print(c_index3) 271 | 272 | c_index_dict.append(c_index3) 273 | 274 | 275 | print("c_index") 276 | mean_acc = np.mean(c_index_dict) 277 | std_acc = np.std(c_index_dict) 278 | worst_acc = min(c_index_dict) 279 | print(mean_acc) 280 | print(std_acc) 281 | print(worst_acc) 282 | 283 | 284 | 285 | return results 286 | 287 | if __name__ == "__main__": 288 | args = get_args() 289 | setup_seed(args.seed) 290 | expname = get_expname(args) 291 | os.makedirs(os.path.join(args.result_dir, expname), exist_ok=True) 292 | logger = Logger(args) 293 | logger.log_args(args) 294 
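# Note: seeding happens at two levels -- setup_seed(args.seed) above fixes the
# global RNG state for this process, and main() re-seeds with args.seed + round,
# so any single round of a multi-round run (--times > 1) can be reproduced in
# isolation.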
| 295 | p = args.p 296 | p_v = int(p*args.V_ratio) 297 | p_s = p-p_v 298 | beta_s = get_beta_s(p_s) 299 | beta_v = np.zeros(p_v) 300 | beta = np.concatenate([beta_s, beta_v]) 301 | results_list = dd(list) 302 | for i in range(args.times): 303 | logger.info("Round %d" % i) 304 | results = main(args, i, logger) 305 | for k, v in results.items(): 306 | results_list[k].append(v) 307 | 308 | 309 | logger.info("Final Result:") 310 | for k, v in results_list.items(): 311 | if k == "RMSE": 312 | RMSE_dict = dict() 313 | for r_test in args.r_list: 314 | RMSE = [v[i][r_test] for i in range(args.times)] 315 | RMSE_dict[r_test] = sum(RMSE)/len(RMSE) 316 | logger.info("RMSE average: %.3f" % (np.mean(list(RMSE_dict.values())))) 317 | logger.info("RMSE std: %.3f" % ((np.std(list(RMSE_dict.values()))))) 318 | logger.info("RMSE max: %.3f" % ((np.max(list(RMSE_dict.values()))))) 319 | logger.info("Detailed RMSE:") 320 | for r_test in args.r_list: 321 | logger.info("%.1f: %.3f" % (r_test, RMSE_dict[r_test])) 322 | elif k == "beta_hat": 323 | beta_hat_array = np.array(v) 324 | beta_hat_mean = np.mean(beta_hat_array, axis=0) 325 | logger.info("%s: %s" % (k, beta_hat_mean)) 326 | if args.bv_analysis: 327 | bv_dict = dict() 328 | bv_dict["s"] = BV_analysis(beta_hat_array[:,:p_s], beta[:p_s]) 329 | bv_dict["v"] = BV_analysis(beta_hat_array[:,p_s:], beta[p_s:]) 330 | bv_dict["all"] = BV_analysis(beta_hat_array, beta) 331 | for covariates in ["s", "v", "all"]: 332 | logger.info("Bias for %s: %.4f, variance for %s: %.4f" % (covariates, bv_dict[covariates][0], covariates, bv_dict[covariates][1])) 333 | else: 334 | logger.info("%s: %.3f" % (k, sum(v)/len(v))) 335 | 336 | 337 | -------------------------------------------------------------------------------- /clinical_data/lung_cancer/test0_pd.csv: -------------------------------------------------------------------------------- 1 | ,Sex,Age,Location,"阻塞性肺炎/肺不张Obst pn 2 | or plugging","CT value 3 | Mean","CT value 4 | Std Dev","CT value 5 | P.major",CT V ratio,CT Kernel,Multiplicity多重性,Effusion胸膜积液,pos_LN,total_LN,LN ratio淋巴结比例,cT,cN,cM,pT病例T阶段,pN病理N阶段,pM,FVC用力肺活量 %PRED,FEV1 %PRED,FEV1一秒用力呼气量/FVC 用力肺活量(%) ,CEA癌胚抗原,post-op CTx术后CT,post-op RTx,Necrosis坏死_0,Necrosis坏死_1,Necrosis坏死_2,"Underlying 6 | lung_0","Underlying 7 | lung_1","Underlying 8 | lung_2","Underlying 9 | lung_3","Underlying 10 | lung_4",Bronchoscopy 支气管镜检_1,Bronchoscopy 支气管镜检_2,Bronchoscopy 支气管镜检_3,Bronchoscopy 支气管镜检_4,Differentiation分化_1,Differentiation分化_2,Differentiation分化_3,Differentiation分化_4,Smoking state_0,Smoking state_1,Smoking state_2,Smoking state_3,op type手术类型_1,op type手术类型_2,op type手术类型_3,op type手术类型_4,op type手术类型_5,Reccur.status,Reccur.months,Survival.status,Survival.months 11 | 0,1.0,0.46808510638297873,0.0,1.0,0.3922495274102079,0.17814726840855108,0.4074074074074075,0.44558787865909766,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.4,0.36752136752136755,0.6268656716417911,0.004181084198385236,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1,2.16666666666667,0,2.16666666666667 12 | 
1,1.0,0.40425531914893614,0.0,1.0,0.3563327032136105,0.166270783847981,0.5685185185185185,0.32353830627417807,1.0,0.0,0.0,0.0,0.1794871794871795,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.3411764705882353,0.41025641025641024,0.7761194029850746,0.008362168396770472,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1,127.633333333333,1,127.633333333333 13 | 2,1.0,0.2553191489361702,0.0,1.0,0.3648393194706994,0.15676959619952496,0.562962962962963,0.3338973968009609,0.0,0.0,0.0,0.0,0.07692307692307693,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.4117647058823529,0.5128205128205128,0.8507462686567164,0.0008650519031141868,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,2.13333333333333,0,23.8 14 | 3,1.0,0.1276595744680851,0.0,1.0,0.7986767485822306,0.03919239904988124,0.6314814814814814,0.6847306192818725,0.0,0.0,0.0,0.0,0.15384615384615385,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.3764705882352941,0.2905982905982906,0.582089552238806,0.006920415224913494,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,84.6666666666667,0,84.6666666666667 15 | 4,1.0,0.46808510638297873,0.0,0.0,0.14839319470699433,0.14964370546318292,0.6074074074074074,0.12127409676251642,1.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.1411764705882353,0.18803418803418803,0.6119402985074627,0.002825836216839677,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,77.6666666666667,1,77.6666666666667 16 | 5,1.0,0.46808510638297873,0.0,1.0,0.6625708884688091,0.0641330166270784,0.4666666666666667,0.6869300601555564,0.0,0.0,0.0,0.0,0.0,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.3333333333333333,0.0,0.7294117647058823,0.48717948717948717,0.5074626865671642,0.0024221453287197234,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1,108.0,1,108.0 17 | 6,1.0,0.3404255319148936,0.0,0.0,0.2655954631379962,0.1959619952494062,0.48888888888888893,0.2699950842447464,1.0,0.0,0.0,0.0,0.1282051282051282,0.0,1.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.6823529411764706,0.6239316239316239,0.7313432835820896,0.00196078431372549,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,132.3,1,132.3 18 | 7,1.0,0.2765957446808511,0.0,1.0,0.6814744801512287,0.21852731591448932,0.6499999999999999,0.5704855335193963,1.0,0.0,0.0,0.0,0.07692307692307693,0.0,0.6666666666666666,0.3333333333333333,0.0,0.3333333333333333,0.3333333333333333,0.0,0.3058823529411765,0.41025641025641024,0.8059701492537313,0.009890426758938869,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,8.26666666666667,0,11.3 19 | 8,1.0,0.44680851063829785,0.0,1.0,0.604914933837429,0.3016627078384798,0.7685185185185185,0.4441437020395065,1.0,0.0,0.0,0.0,0.1282051282051282,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.5294117647058824,0.24786324786324787,0.16417910447761194,0.0057670126874279125,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,109.8,1,109.8 20 | 
9,1.0,0.23404255319148937,0.0,1.0,0.33931947069943286,0.2125890736342043,0.5092592592592593,0.33413694137996847,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.5647058823529412,0.5470085470085471,0.5074626865671642,0.0063437139561707025,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,128.533333333333,1,128.533333333333 21 | 10,1.0,0.2765957446808511,0.0,1.0,0.5028355387523629,0.08313539192399051,0.5444444444444445,0.47327675358458726,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.6352941176470588,0.47863247863247865,0.373134328358209,0.007900807381776238,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,105.466666666667,1,105.466666666667 22 | 11,1.0,0.14893617021276595,0.0,0.0,0.4480151228733459,0.26840855106888367,0.5611111111111111,0.41239791802277836,1.0,0.0,0.0,0.0,0.38461538461538464,0.0,0.3333333333333333,0.0,0.0,0.6666666666666666,0.0,0.0,0.18823529411764706,0.27350427350427353,0.47761194029850745,0.002566320645905421,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,109.266666666667,1,109.266666666667 23 | 12,1.0,0.3404255319148936,0.0,1.0,0.8128544423440452,0.3456057007125891,0.6981481481481482,0.6502644106596696,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.12941176470588237,0.3162393162393162,0.6567164179104478,0.0038638985005767013,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,71.1,1,71.1 24 | 13,1.0,0.425531914893617,1.0,0.0,0.5689981096408316,0.5083135391923992,0.5481481481481482,0.5336559111813246,1.0,1.0,0.0,0.0,0.3076923076923077,0.0,1.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.47058823529411764,0.36752136752136755,0.373134328358209,0.006949250288350633,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,110.433333333333,1,110.433333333333 25 | 14,1.0,0.425531914893617,0.0,1.0,0.4725897920604914,0.37529691211401434,0.5592592592592592,0.4363570923264115,1.0,0.0,0.0,0.14285714285714285,0.0,1.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.6666666666666666,0.0,0.3411764705882353,0.3162393162393162,0.373134328358209,0.0053344867358708185,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,112.266666666667,1,112.266666666667 26 | 15,1.0,0.425531914893617,0.0,0.0,0.4905482041587901,0.1163895486935867,0.7574074074074072,0.35939443611139005,1.0,0.0,0.0,0.0,0.1282051282051282,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.4588235294117647,0.41025641025641024,0.40298507462686567,0.00464244521337947,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,32.2,1,106.0 27 | 16,1.0,0.3191489361702128,0.0,1.0,0.5349716446124764,0.30641330166270786,0.5833333333333334,0.4805297646137765,1.0,1.0,0.0,0.0,0.23076923076923078,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.5764705882352941,0.5982905982905983,0.582089552238806,0.010380622837370242,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,11.2,1,106.633333333333 28 | 
17,1.0,0.1702127659574468,0.0,1.0,0.34877126654064267,0.43705463182897875,0.5462962962962964,0.32625095780815017,1.0,0.0,0.0,0.0,0.1794871794871795,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.4235294117647059,0.47863247863247865,0.5522388059701493,0.01046712802768166,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,46.4,1,46.4 29 | 18,1.0,0.3617021276595745,0.0,1.0,0.329867674858223,0.07363420427553444,0.46851851851851856,0.3444872806546442,0.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.6666666666666666,0.3333333333333333,0.0,0.6470588235294118,0.49572649572649574,0.31343283582089554,0.002970011534025375,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1,95.5333333333333,1,95.5333333333333 30 | 19,1.0,0.3191489361702128,0.0,1.0,0.2948960302457466,0.8087885985748219,0.6129629629629628,0.25032551746938875,1.0,0.0,0.0,0.2857142857142857,0.05128205128205128,0.666666666666667,0.3333333333333333,0.6666666666666666,0.0,0.3333333333333333,0.6666666666666666,0.0,0.17647058823529413,0.3418803418803419,0.582089552238806,0.0014705882352941176,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0,0.0,0,32.7333333333333 31 | 20,1.0,0.425531914893617,1.0,0.0,0.5841209829867675,0.45486935866983386,0.6148148148148148,0.5065742020467052,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.4470588235294118,0.4017094017094017,0.40298507462686567,0.0057670126874279125,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,77.5333333333333,1,106.2 32 | 21,0.0,0.3191489361702128,0.0,0.0,0.830812854442344,0.14370546318289787,0.6314814814814814,0.7128487090466595,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.3333333333333333,0.0,0.0,0.0,0.0,0.4823529411764706,0.41025641025641024,0.417910447761194,0.0057670126874279125,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0,56.7333333333333,1,94.0 33 | 22,1.0,0.425531914893617,0.0,0.0,0.3969754253308128,0.8420427553444181,0.2555555555555556,0.5708068705892841,1.0,0.0,0.0,0.0,1.0,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.0,0.0,0.15294117647058825,0.05128205128205128,0.11940298507462686,0.0016147635524798155,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,6.73333333333333,0,6.73333333333333 34 | 23,1.0,0.2978723404255319,1.0,0.0,0.7844990548204158,0.09382422802850358,0.4092592592592592,0.8728608782579917,0.0,0.0,0.0,0.0,0.1282051282051282,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.5176470588235295,0.5897435897435898,0.5970149253731343,0.009890426758938869,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,21.1666666666667,1,21.1666666666667 35 | 24,1.0,0.44680851063829785,1.0,0.0,0.6200378071833649,0.7731591448931117,0.3833333333333333,0.7179288681218096,1.0,0.0,0.0,0.0,0.23076923076923078,0.0,0.3333333333333333,0.0,0.0,0.6666666666666666,0.0,0.0,0.47058823529411764,0.36752136752136755,0.3283582089552239,0.007497116493656285,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,96.3666666666667,1,96.3666666666667 36 | 
25,1.0,0.44680851063829785,0.0,1.0,0.7561436672967863,0.4049881235154395,0.42777777777777776,0.8220751246805015,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.3333333333333333,0.6666666666666666,0.0,0.3333333333333333,0.6666666666666666,0.0,0.43529411764705883,0.5641025641025641,0.5671641791044776,0.0034313725490196074,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,98.2,1,98.2 37 | 26,1.0,0.0,0.0,1.0,0.388468809073724,0.12351543942992876,0.5425925925925926,0.3656506492928623,1.0,0.0,0.0,0.0,0.15384615384615385,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.611764705882353,0.5726495726495726,0.5223880597014925,0.001182237600922722,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,28.8333333333333,1,96.2 38 | 27,1.0,0.2765957446808511,1.0,0.0,0.14839319470699433,0.4049881235154395,0.2666666666666667,0.23734052353656607,1.0,0.0,0.0,0.0,0.20512820512820512,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.3411764705882353,0.5042735042735043,0.746268656716418,0.0030565167243367937,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,15.0666666666667,0,23.9333333333333 39 | 28,0.0,0.40425531914893614,0.0,1.0,0.551039697542533,0.13776722090261284,0.5703703703703704,0.502959791624409,0.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.0,0.3333333333333333,0.0,0.8,0.5555555555555556,0.2835820895522388,0.003171856978085352,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,61.6333333333333,1,61.6333333333333 40 | 29,1.0,0.2978723404255319,0.0,0.0,0.6011342155009451,0.3622327790973872,0.4074074074074075,0.6742369509268931,1.0,0.0,0.0,0.0,0.20512820512820512,0.0,0.3333333333333333,0.0,0.0,0.6666666666666666,0.0,0.0,0.3176470588235294,0.4358974358974359,0.582089552238806,0.0028835063437139563,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1,0.4,0,0.4 41 | 30,1.0,0.425531914893617,0.0,1.0,0.9319470699432892,0.21140142517814725,0.688888888888889,0.7555632649678744,0.0,0.0,0.0,0.0,0.10256410256410256,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.4823529411764706,0.48717948717948717,0.4626865671641791,0.00461361014994233,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,77.0333333333333,1,77.0333333333333 42 | 31,1.0,0.3191489361702128,1.0,0.0,0.32703213610586007,0.31591448931116395,0.48518518518518516,0.33339969615887177,1.0,0.0,0.0,0.0,0.23076923076923078,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.0,0.0,0.47058823529411764,0.6581196581196581,0.6716417910447762,0.003748558246828143,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,14.4666666666667,0,30.7 43 | 32,1.0,0.2553191489361702,0.0,1.0,0.2826086956521739,0.03800475059382422,0.2666666666666667,0.4117566025136392,0.0,0.0,0.0,0.14285714285714285,0.05128205128205128,0.333333333333333,0.3333333333333333,0.0,0.0,0.3333333333333333,0.6666666666666666,0.0,0.36470588235294116,0.21367521367521367,0.208955223880597,0.003171856978085352,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,13.5333333333333,0,39.2 44 | 
33,1.0,0.40425531914893614,0.0,0.0,0.833648393194707,0.11757719714964372,0.5759259259259258,0.7598406444050253,0.0,0.0,0.0,0.0,0.3076923076923077,0.0,0.6666666666666666,0.0,0.0,1.0,0.0,0.0,0.35294117647058826,0.4358974358974359,0.5522388059701493,0.00971741637831603,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,57.2666666666667,1,57.2666666666667 45 | 34,1.0,0.3829787234042553,0.0,0.0,0.5330812854442344,0.41686460807600956,0.8777777777777777,0.34623709029839295,1.0,0.0,0.0,0.0,0.15384615384615385,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.6666666666666666,0.0,0.5764705882352941,0.6837606837606838,0.6865671641791045,0.00720876585928489,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,21.4,0,37.4666666666667 46 | 35,1.0,0.3404255319148936,0.0,1.0,0.2930056710775047,0.2802850356294537,0.35000000000000003,0.36938222674767046,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.5647058823529412,0.5811965811965812,0.5074626865671642,0.006920415224913494,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,13.6333333333333,1,58.7 47 | 36,1.0,0.2553191489361702,1.0,0.0,0.5189035916824196,0.27197149643705465,0.19074074074074077,0.814717308226739,1.0,0.0,0.0,0.0,0.9487179487179487,0.0,0.6666666666666666,0.6666666666666666,0.0,0.6666666666666666,0.3333333333333333,0.0,0.4117647058823529,0.5128205128205128,0.6417910447761194,0.0008650519031141868,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1,55.7333333333333,1,55.7333333333333 48 | 37,1.0,0.44680851063829785,0.0,0.0,0.3213610586011342,0.31472684085510694,0.4759259259259259,0.3321454532190895,1.0,0.0,0.0,0.0,0.1282051282051282,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.0,0.0,0.6,0.41025641025641024,0.31343283582089554,0.002306805074971165,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,1.4,0,35.0666666666667 49 | 38,1.0,0.425531914893617,0.0,0.0,0.22117202268431,0.4049881235154395,0.40185185185185185,0.26078558528621165,1.0,0.0,0.0,0.0,0.10256410256410256,0.0,0.3333333333333333,0.6666666666666666,0.0,0.3333333333333333,0.6666666666666666,0.0,0.23529411764705882,0.3418803418803419,0.5373134328358209,0.0033448673587081895,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0.0,0,10.6 50 | 39,1.0,0.3617021276595745,0.0,1.0,0.1918714555765595,0.5902612826603326,0.44814814814814813,0.21048890887401536,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.38823529411764707,0.4700854700854701,0.5373134328358209,0.004296424452133795,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1,60.2333333333333,1,60.2333333333333 51 | 40,1.0,0.3404255319148936,0.0,1.0,0.18998109640831756,0.4133016627078386,0.7370370370370369,0.1273571535473733,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.6666666666666666,0.3333333333333333,0.0,0.0,0.6666666666666666,0.0,0.23529411764705882,0.24786324786324787,0.417910447761194,0.0019031141868512107,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0,13.5333333333333,1,48.8666666666667 52 | 
41,1.0,0.3617021276595745,1.0,0.0,0.8043478260869565,0.12589073634204276,0.6333333333333334,0.6883275330826187,0.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.6666666666666666,0.0,0.3333333333333333,0.6666666666666666,0.0,0.5058823529411764,0.5128205128205128,0.6119402985074627,0.006055363321799307,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,0.933333333333333,0,0.933333333333333 53 | 42,1.0,0.44680851063829785,0.0,1.0,0.3695652173913043,0.30760095011876487,0.3981481481481482,0.426580964970766,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.2823529411764706,0.24786324786324787,0.373134328358209,0.003892733564013841,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,36.8333333333333,1,50.2 54 | -------------------------------------------------------------------------------- /clinical_data/lung_cancer/test4_pd.csv: -------------------------------------------------------------------------------- 1 | ,Sex,Age,Location,"阻塞性肺炎/肺不张Obst pn 2 | or plugging","CT value 3 | Mean","CT value 4 | Std Dev","CT value 5 | P.major",CT V ratio,CT Kernel,Multiplicity多重性,Effusion胸膜积液,pos_LN,total_LN,LN ratio淋巴结比例,cT,cN,cM,pT病例T阶段,pN病理N阶段,pM,FVC用力肺活量 %PRED,FEV1 %PRED,FEV1一秒用力呼气量/FVC 用力肺活量(%) ,CEA癌胚抗原,post-op CTx术后CT,post-op RTx,Necrosis坏死_0,Necrosis坏死_1,Necrosis坏死_2,"Underlying 6 | lung_0","Underlying 7 | lung_1","Underlying 8 | lung_2","Underlying 9 | lung_3","Underlying 10 | lung_4",Bronchoscopy 支气管镜检_1,Bronchoscopy 支气管镜检_2,Bronchoscopy 支气管镜检_3,Bronchoscopy 支气管镜检_4,Differentiation分化_1,Differentiation分化_2,Differentiation分化_3,Differentiation分化_4,Smoking state_0,Smoking state_1,Smoking state_2,Smoking state_3,op type手术类型_1,op type手术类型_2,op type手术类型_3,op type手术类型_4,op type手术类型_5,Reccur.status,Reccur.months,Survival.status,Survival.months 11 | 0,1.0,0.5957446808510638,1.0,0.0,0.40548204158790163,0.25059382422802856,0.5407407407407407,0.3827567256511303,1.0,0.0,0.0,0.0,0.15384615384615385,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.23931623931623933,0.6119402985074627,0.008910034602076124,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,85.4,1,93.7 12 | 1,1.0,0.5106382978723404,1.0,0.0,0.5793950850661626,0.2434679334916865,0.7203703703703702,0.4465843474979311,0.0,0.0,0.0,0.0,0.15384615384615385,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3411764705882353,0.47863247863247865,0.8208955223880597,0.004786620530565167,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,116.6,1,116.6 13 | 2,1.0,0.5319148936170213,1.0,0.0,0.3062381852551984,0.2090261282660333,0.5537037037037037,0.28296212276386756,1.0,0.0,0.0,0.0,0.48717948717948717,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.7411764705882353,0.5555555555555556,0.5671641791044776,0.008679354094579006,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,103.3,0,103.3 14 | 
3,1.0,0.6808510638297872,1.0,1.0,0.4253308128544423,0.24703087885985747,0.4666666666666667,0.4434099617251588,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.6666666666666666,0.3333333333333333,0.0,0.6666666666666666,0.3333333333333333,0.0,0.5058823529411764,0.4444444444444444,0.582089552238806,0.008160322952710495,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,1.13333333333333,0,46.4666666666667 15 | 4,1.0,0.6808510638297872,1.0,0.0,0.7627599243856332,0.18646080760095016,0.7648148148148147,0.5692369439107492,1.0,0.0,0.0,0.0,0.07692307692307693,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.6941176470588235,0.6923076923076923,0.7164179104477612,0.03794694348327566,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,51.4666666666667,1,73.1333333333333 16 | 5,1.0,0.7659574468085106,1.0,0.0,0.72117202268431,0.3016627078384798,0.8037037037037036,0.5164676640122008,1.0,0.0,0.0,0.0,0.07692307692307693,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3058823529411765,0.3162393162393162,0.3582089552238806,0.009515570934256054,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,19.6333333333333,0,21.4333333333333 17 | 6,1.0,0.8297872340425532,1.0,0.0,0.5075614366729677,0.12351543942992876,0.7962962962962963,0.3573678386528621,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.32941176470588235,0.4017094017094017,0.4626865671641791,0.011534025374855825,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,10.9333333333333,0,15.6333333333333 18 | 7,1.0,0.48936170212765956,1.0,0.0,0.35349716446124757,0.19477434679334918,0.7037037037037037,0.2690883671430947,1.0,0.0,0.0,0.0,0.10256410256410256,0.0,0.3333333333333333,0.0,0.0,0.6666666666666666,0.3333333333333333,0.0,0.47058823529411764,0.4017094017094017,0.373134328358209,0.0022202998846597463,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,114.866666666667,1,114.866666666667 19 | 8,1.0,0.5531914893617021,1.0,0.0,0.6814744801512287,0.22090261282660337,0.8666666666666666,0.4574419717629692,1.0,0.0,0.0,0.0,0.07692307692307693,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.4117647058823529,0.5128205128205128,0.5522388059701493,0.0009515570934256056,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,32.5,1,118.033333333333 20 | 9,1.0,0.425531914893617,1.0,0.0,0.5689981096408316,0.5083135391923992,0.5481481481481482,0.5336559111813246,1.0,1.0,0.0,0.0,0.3076923076923077,0.0,1.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.47058823529411764,0.36752136752136755,0.373134328358209,0.006949250288350633,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,110.433333333333,1,110.433333333333 21 | 10,1.0,0.7872340425531915,1.0,0.0,0.7986767485822306,0.6140142517814728,1.0,0.48137355736227544,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.6666666666666666,0.0,0.0,0.24705882352941178,0.49572649572649574,0.5970149253731343,0.026845444059976933,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,39.4,0,39.4 22 | 
11,1.0,0.7659574468085106,1.0,0.0,0.5680529300567108,0.3634204275534442,0.5851851851851851,0.509662172846405,1.0,0.0,0.0,0.0,0.07692307692307693,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.6235294117647059,0.49572649572649574,0.31343283582089554,0.008073817762399077,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,2.2,0,2.2 23 | 12,1.0,0.425531914893617,1.0,0.0,0.5841209829867675,0.45486935866983386,0.6148148148148148,0.5065742020467052,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.4470588235294118,0.4017094017094017,0.40298507462686567,0.0057670126874279125,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,77.5333333333333,1,106.2 24 | 13,1.0,0.574468085106383,1.0,0.0,0.48487712665406424,0.7529691211401426,0.3796296296296297,0.5688794143182604,1.0,0.0,0.0,0.0,0.1282051282051282,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.2823529411764706,0.46153846153846156,0.7014925373134329,0.01557093425605536,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,13.0333333333333,0,21.5333333333333 25 | 14,1.0,0.7446808510638298,1.0,0.0,0.8960302457466919,0.1959619952494062,0.9814814814814815,0.5540078155959879,1.0,1.0,0.0,0.0,0.1794871794871795,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.47058823529411764,0.6239316239316239,0.582089552238806,0.04325259515570934,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,11.4,0,11.4 26 | 15,1.0,0.8085106382978723,1.0,0.0,0.5179584120982986,0.15201900237529692,0.5148148148148148,0.5060704822568018,0.0,0.0,0.0,0.0,0.0,0.0,0.6666666666666666,0.6666666666666666,0.0,0.6666666666666666,0.6666666666666666,0.0,0.4,0.3162393162393162,0.26865671641791045,0.0839677047289504,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0.0,0,6.1 27 | 16,1.0,0.8297872340425532,1.0,0.0,0.37145557655954625,0.1163895486935867,0.22592592592592592,0.5651447453258936,0.0,0.0,0.0,0.0,0.23076923076923078,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.4117647058823529,0.2222222222222222,0.08955223880597014,0.015282583621683967,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,73.4333333333333,0,73.4333333333333 28 | 17,1.0,0.48936170212765956,1.0,0.0,0.43667296786389403,0.320665083135392,0.6185185185185186,0.3738941840080305,1.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.6666666666666666,0.0,0.0,0.3333333333333333,0.0,0.0,0.23529411764705882,0.48717948717948717,0.746268656716418,0.04382929642445214,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0,38.8666666666667,0,54.5333333333333 29 | 18,1.0,0.2978723404255319,1.0,0.0,0.7844990548204158,0.09382422802850358,0.4092592592592592,0.8728608782579917,0.0,0.0,0.0,0.0,0.1282051282051282,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.5176470588235295,0.5897435897435898,0.5970149253731343,0.009890426758938869,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,21.1666666666667,1,21.1666666666667 30 | 
19,1.0,0.44680851063829785,1.0,0.0,0.6200378071833649,0.7731591448931117,0.3833333333333333,0.7179288681218096,1.0,0.0,0.0,0.0,0.23076923076923078,0.0,0.3333333333333333,0.0,0.0,0.6666666666666666,0.0,0.0,0.47058823529411764,0.36752136752136755,0.3283582089552239,0.007497116493656285,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,96.3666666666667,1,96.3666666666667 31 | 20,1.0,0.7446808510638298,1.0,0.0,0.3969754253308128,0.07125890736342044,0.9407407407407407,0.23173509001246279,1.0,0.0,0.0,0.0,0.38461538461538464,0.0,0.6666666666666666,0.3333333333333333,0.0,0.6666666666666666,0.0,0.0,0.32941176470588235,0.5299145299145299,0.582089552238806,0.002797001153402537,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,5.76666666666667,0,19.0666666666667 32 | 21,1.0,0.6170212765957447,1.0,0.0,0.3034026465028355,0.29928741092636585,0.5777777777777777,0.27089359953352976,1.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5058823529411764,0.47863247863247865,0.43283582089552236,0.004901960784313725,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,40.8333333333333,0,42.8 33 | 22,1.0,0.2765957446808511,1.0,0.0,0.14839319470699433,0.4049881235154395,0.2666666666666667,0.23734052353656607,1.0,0.0,0.0,0.0,0.20512820512820512,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.3411764705882353,0.5042735042735043,0.746268656716418,0.0030565167243367937,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,15.0666666666667,0,23.9333333333333 34 | 23,1.0,0.5957446808510638,1.0,1.0,0.030245746691871463,0.35510688836104515,0.40185185185185185,0.050485202019019966,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.5647058823529412,0.5299145299145299,0.417910447761194,0.006920415224913494,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,39.4333333333333,1,39.4333333333333 35 | 24,1.0,0.6595744680851063,1.0,0.0,0.26654064272211714,0.45961995249406185,0.8,0.17271588673248164,1.0,0.0,0.0,0.0,0.15384615384615385,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.35294117647058826,0.4700854700854701,0.5970149253731343,0.006257208765859284,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,1.1,0,1.1 36 | 25,1.0,0.7659574468085106,1.0,0.0,0.497164461247637,0.7339667458432305,0.5148148148148148,0.48575378406400593,1.0,0.0,0.0,0.0,0.02564102564102564,0.0,1.0,0.3333333333333333,0.0,1.0,0.3333333333333333,0.0,0.6588235294117647,0.6068376068376068,0.40298507462686567,0.004325259515570934,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,26.9666666666667,0,44.7333333333333 37 | 26,1.0,0.6595744680851063,1.0,0.0,0.6389413988657845,0.07482185273159145,0.4314814814814815,0.6932863375044205,0.0,0.0,0.0,0.0,0.38461538461538464,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.6705882352941176,0.5811965811965812,0.40298507462686567,0.1568627450980392,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1,24.4333333333333,0,24.4333333333333 38 | 
27,1.0,0.6808510638297872,1.0,0.0,0.34120982986767484,0.3182897862232779,0.35000000000000003,0.42577427151984126,1.0,0.0,0.0,0.0,0.07692307692307693,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.0,0.0,0.5647058823529412,0.6153846153846154,0.43283582089552236,0.00720876585928489,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,10.1333333333333,0,37.3 39 | 28,1.0,0.7446808510638298,1.0,0.0,0.33648393194706994,0.15914489311163899,0.3111111111111111,0.4473539358941256,0.0,0.0,0.0,0.0,0.1794871794871795,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.27058823529411763,0.29914529914529914,0.3880597014925373,0.006055363321799307,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,49.2,1,49.2 40 | 29,1.0,0.5319148936170213,1.0,0.0,0.13988657844990546,0.13776722090261284,0.38888888888888895,0.17579109199016796,0.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.32941176470588235,0.4358974358974359,0.5970149253731343,0.002595155709342561,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,81.8,0,30.2666666666667 41 | 30,1.0,0.3191489361702128,1.0,0.0,0.32703213610586007,0.31591448931116395,0.48518518518518516,0.33339969615887177,1.0,0.0,0.0,0.0,0.23076923076923078,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.0,0.0,0.47058823529411764,0.6581196581196581,0.6716417910447762,0.003748558246828143,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,14.4666666666667,0,30.7 42 | 31,1.0,0.5957446808510638,1.0,0.0,0.37240075614366724,0.12826603325415678,0.5444444444444445,0.34950301672007034,0.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.6666666666666666,0.0,0.5529411764705883,0.3247863247863248,0.22388059701492538,0.11303344867358707,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,8.93333333333333,0,25.8666666666667 43 | 32,1.0,0.5106382978723404,1.0,0.0,0.4820415879017012,0.1959619952494062,0.5388888888888889,0.45673660933131544,1.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.47058823529411764,0.1623931623931624,0.11940298507462686,0.015945790080738178,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,70.2666666666667,1,70.2666666666667 44 | 33,1.0,0.6382978723404256,1.0,0.0,0.4328922495274101,0.0653206650831354,0.6574074074074074,0.3533465062588599,0.0,0.0,0.0,1.0,0.1794871794871795,0.875,0.3333333333333333,0.3333333333333333,0.0,0.3333333333333333,0.3333333333333333,0.0,0.7764705882352941,0.6410256410256411,0.3880597014925373,0.005017301038062283,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,11.6666666666667,0,32.0 45 | 34,1.0,0.6595744680851063,1.0,0.0,0.20888468809073724,0.3800475059382423,0.2074074074074074,0.3545199635061001,1.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.3333333333333333,0.0,0.3333333333333333,0.3333333333333333,0.0,0.23529411764705882,0.3162393162393162,0.43283582089552236,0.008650519031141867,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0,7.2,1,49.2 46 | 
35,1.0,0.2553191489361702,1.0,0.0,0.5189035916824196,0.27197149643705465,0.19074074074074077,0.814717308226739,1.0,0.0,0.0,0.0,0.9487179487179487,0.0,0.6666666666666666,0.6666666666666666,0.0,0.6666666666666666,0.3333333333333333,0.0,0.4117647058823529,0.5128205128205128,0.6417910447761194,0.0008650519031141868,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1,55.7333333333333,1,55.7333333333333 47 | 36,1.0,0.7659574468085106,1.0,0.0,0.32230623818525517,0.37767220902612836,0.5425925925925926,0.3027537457967539,1.0,0.0,0.0,0.0,0.1282051282051282,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.0,0.0,0.23529411764705882,0.39316239316239315,0.5970149253731343,0.003748558246828143,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,63.5666666666667,1,63.5666666666667 48 | 37,1.0,0.8723404255319149,1.0,0.0,0.6134215500945179,0.1330166270783848,0.6203703703703703,0.5292458113510906,0.0,0.0,0.0,0.0,0.48717948717948717,0.0,0.3333333333333333,0.3333333333333333,0.0,0.3333333333333333,0.0,0.0,0.4470588235294118,0.4444444444444444,0.3582089552238806,0.014705882352941175,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,52.9666666666667,1,52.9666666666667 49 | 38,1.0,0.574468085106383,1.0,0.0,0.6077504725897921,0.13657957244655583,0.7037037037037037,0.4780379267214601,0.0,0.0,0.0,0.0,0.46153846153846156,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.6235294117647059,0.5299145299145299,0.43283582089552236,0.004901960784313725,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,59.5666666666667,1,59.5666666666667 50 | 39,1.0,0.7021276595744681,1.0,0.0,0.7032136105860113,0.34204275534441814,0.8074074074074074,0.5010270299573113,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2,0.24786324786324787,0.44776119402985076,0.00461361014994233,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,14.5666666666667,1,24.7 51 | 40,1.0,0.3617021276595745,1.0,0.0,0.8043478260869565,0.12589073634204276,0.6333333333333334,0.6883275330826187,0.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.6666666666666666,0.0,0.3333333333333333,0.6666666666666666,0.0,0.5058823529411764,0.5128205128205128,0.6119402985074627,0.006055363321799307,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,0.933333333333333,0,0.933333333333333 52 | 41,1.0,0.8085106382978723,1.0,0.0,0.5680529300567108,0.24109263657957244,0.4537037037037038,0.5999163603533563,0.0,0.0,0.0,0.0,0.1794871794871795,0.0,0.6666666666666666,0.0,0.0,0.6666666666666666,0.0,0.0,0.4117647058823529,0.49572649572649574,0.43283582089552236,0.0054498269896193765,1.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,13.8333333333333,0,22.5666666666667 53 | 42,1.0,0.7659574468085106,1.0,0.0,0.2684310018903591,0.25415676959619954,0.44814814814814813,0.2906343254462276,1.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.3333333333333333,0.0,0.0,0.0,0.6666666666666666,0.0,0.4117647058823529,0.4017094017094017,0.3880597014925373,0.011822376009227219,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,17.9333333333333,1,53.4666666666667 54 | 
43,1.0,0.6808510638297872,1.0,0.0,0.5689981096408316,0.2482185273159145,0.9333333333333331,0.35213520501680706,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3333333333333333,0.6666666666666666,0.0,0.4,0.37606837606837606,0.3880597014925373,0.003171856978085352,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,50.6,1,50.6 55 | 44,1.0,0.7659574468085106,1.0,0.0,0.717391304347826,0.08788598574821854,0.5000000000000001,0.7134938332866176,0.0,0.0,0.0,0.0,0.05128205128205128,0.0,0.0,0.0,0.0,0.3333333333333333,0.0,0.0,0.27058823529411763,0.4017094017094017,0.47761194029850745,0.005478662053056517,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0,12.8,1,51.1 56 | 45,1.0,0.7021276595744681,1.0,0.0,0.38185255198487705,0.17458432304038007,0.2592592592592593,0.5474481810185087,0.0,0.0,0.0,0.0,0.02564102564102564,0.0,0.3333333333333333,0.0,0.0,0.0,0.0,0.0,0.3176470588235294,0.358974358974359,0.40298507462686567,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1,44.1,0,44.1 57 | 46,1.0,0.5106382978723404,1.0,0.0,0.387523629489603,0.24228028503562946,0.5055555555555555,0.3834280517896962,1.0,0.0,0.0,0.0,0.20512820512820512,0.0,0.3333333333333333,0.0,0.0,0.3333333333333333,0.0,0.0,0.35294117647058826,0.358974358974359,0.4925373134328358,0.003460207612456747,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,11.7666666666667,1,28.1333333333333 58 | --------------------------------------------------------------------------------
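A note on the clinical CSV layout above: each split carries an unnamed integer index column, min-max-scaled clinical covariates (including one-hot blocks such as `Necrosis坏死_*`, `Differentiation分化_*`, and `op type手术类型_*`), and four outcome columns at the end: `Reccur.status`, `Reccur.months`, `Survival.status`, and `Survival.months`. Below is a minimal sketch of loading one split and fitting a plain Cox PH baseline with pandas and lifelines; the choice of penalizer and the dropping of the recurrence columns are illustrative assumptions, not the repo's own pipeline (see `clinical_lung_OS.py` / `clinical_lung_DFS.py` for that).

```python
import pandas as pd
from lifelines import CoxPHFitter

# Load one test split of the lung-cancer clinical data.
# The first CSV column is an unnamed row index, so use it as the DataFrame index;
# pandas handles the multi-line quoted header fields transparently.
df = pd.read_csv("clinical_data/lung_cancer/test4_pd.csv", index_col=0)

# For an overall-survival analysis, Survival.months is the duration and
# Survival.status the event indicator; drop the recurrence outcome columns
# so they are not treated as covariates.
surv = df.drop(columns=["Reccur.status", "Reccur.months"])

# A small ridge penalty helps with the collinear one-hot blocks (assumed value).
cph = CoxPHFitter(penalizer=0.1)
cph.fit(surv, duration_col="Survival.months", event_col="Survival.status")

# In-sample concordance of the fitted baseline model.
print(cph.concordance_index_)
```

The same pattern applies to the other splits (`test0_pd.csv`, `test3_pd.csv`) and, with `Reccur.months` / `Reccur.status` as duration and event, to the disease-free-survival setting.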