├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── baselines ├── codi │ ├── diffusion_continuous.py │ ├── diffusion_discrete.py │ ├── evaluation.py │ ├── main.py │ ├── models │ │ ├── __init__.py │ │ ├── layers.py │ │ └── tabular_unet.py │ ├── sample.py │ ├── tabular_dataload.py │ ├── tabular_transformer.py │ └── utils.py ├── goggle │ ├── GoggleModel.py │ ├── data_utils.py │ ├── main.py │ ├── model │ │ ├── Encoder.py │ │ ├── Goggle.py │ │ ├── GoggleLoss.py │ │ ├── GraphDecoder.py │ │ ├── GraphInputProcessor.py │ │ ├── LearnedGraph.py │ │ ├── RGCNConv.py │ │ └── __pycache__ │ │ │ ├── Encoder.cpython-310.pyc │ │ │ ├── Goggle.cpython-310.pyc │ │ │ ├── GoggleLoss.cpython-310.pyc │ │ │ ├── GraphDecoder.cpython-310.pyc │ │ │ ├── GraphInputProcessor.cpython-310.pyc │ │ │ └── LearnedGraph.cpython-310.pyc │ └── sample.py ├── great │ ├── main.py │ ├── models │ │ ├── __pycache__ │ │ │ ├── great.cpython-310.pyc │ │ │ ├── great_dataset.cpython-310.pyc │ │ │ ├── great_start.cpython-310.pyc │ │ │ ├── great_trainer.cpython-310.pyc │ │ │ └── great_utils.cpython-310.pyc │ │ ├── great.py │ │ ├── great_dataset.py │ │ ├── great_start.py │ │ ├── great_trainer.py │ │ └── great_utils.py │ ├── post_process.py │ ├── sample.py │ └── utils.py ├── smote │ └── main.py ├── stasy │ ├── configs │ │ ├── __pycache__ │ │ │ ├── config.cpython-310.pyc │ │ │ ├── config.cpython-39.pyc │ │ │ ├── default_tabular_configs.cpython-310.pyc │ │ │ └── default_tabular_configs.cpython-39.pyc │ │ ├── config.py │ │ └── default_tabular_configs.py │ ├── datasets.py │ ├── likelihood.py │ ├── losses.py │ ├── main.py │ ├── models │ │ ├── __pycache__ │ │ │ ├── ema.cpython-310.pyc │ │ │ ├── ema.cpython-39.pyc │ │ │ ├── layers.cpython-310.pyc │ │ │ ├── layers.cpython-39.pyc │ │ │ ├── layerspp.cpython-310.pyc │ │ │ ├── layerspp.cpython-39.pyc │ │ │ ├── ncsnpp_tabular.cpython-310.pyc │ │ │ ├── ncsnpp_tabular.cpython-39.pyc │ │ │ ├── utils.cpython-310.pyc │ │ │ └── utils.cpython-39.pyc │ │ ├── ema.py │ │ ├── layers.py │ │ ├── layerspp.py │ │ ├── ncsnpp_tabular.py │ │ ├── tabular_utils.py │ │ └── utils.py │ ├── sample.py │ ├── sample_steps.py │ ├── sampling.py │ ├── sde_lib.py │ └── utils.py └── tabddpm │ ├── configs │ ├── adult.toml │ ├── beijing.toml │ ├── default.toml │ ├── magic.toml │ ├── news.toml │ └── shoppers.toml │ ├── main_sample.py │ ├── main_train.py │ ├── models │ ├── __pycache__ │ │ ├── gaussian_multinomial_distribution.cpython-310.pyc │ │ ├── modules.cpython-310.pyc │ │ └── utils.cpython-310.pyc │ ├── gaussian_multinomial_distribution.py │ ├── modules.py │ └── utils.py │ ├── sample.py │ └── train.py ├── data └── Info │ ├── adult.json │ ├── beijing.json │ ├── default.json │ ├── magic.json │ ├── news.json │ └── shoppers.json ├── download_dataset.py ├── eval ├── eval_dcr.py ├── eval_density.py ├── eval_detection.py ├── eval_mle.py ├── eval_quality.py └── mle │ ├── mle.py │ ├── tabular_dataload.py │ └── tabular_transformer.py ├── eval_impute.py ├── images ├── density.jpg ├── heat_map.jpg ├── nfe1.jpg ├── radar.jpg └── tabsyn_model.jpg ├── impute.py ├── main.py ├── process_dataset.py ├── requirements.txt ├── src ├── __init__.py ├── data.py ├── deep.py ├── env.py ├── metrics.py └── util.py ├── tabsyn ├── diffusion_utils.py ├── latent_utils.py ├── main.py ├── model.py ├── sample.py └── vae │ ├── main.py │ └── model.py ├── utils.py └── utils_train.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon 
Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /baselines/codi/diffusion_continuous.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | def extract(v, t, x_shape): 7 | out = torch.gather(v, index=t, dim=0).float() 8 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) 9 | 10 | class GaussianDiffusionTrainer(nn.Module): 11 | def __init__(self, model, beta_1, beta_T, T): 12 | super().__init__() 13 | 14 | self.model = model 15 | self.T = T 16 | betas = torch.linspace(beta_1, beta_T, T, dtype=torch.float64).double() 17 | alphas = 1. - betas 18 | self.register_buffer('betas', betas) 19 | alphas_bar = torch.cumprod(alphas, dim=0) 20 | 21 | self.register_buffer( 22 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar)) 23 | self.register_buffer( 24 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar)) 25 | self.register_buffer( 26 | 'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar)) 27 | self.register_buffer( 28 | 'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1)) 29 | 30 | def make_x_t(self, x_0_con, t, noise): 31 | x_t_con = ( 32 | extract(self.sqrt_alphas_bar, t, x_0_con.shape) * x_0_con + 33 | extract(self.sqrt_one_minus_alphas_bar, t, x_0_con.shape) * noise) 34 | return x_t_con 35 | 36 | def predict_xstart_from_eps(self, x_t, t, eps): 37 | assert x_t.shape == eps.shape 38 | return ( 39 | extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t - 40 | extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps 41 | ) 42 | 43 | 44 | class GaussianDiffusionSampler(nn.Module): 45 | def __init__(self, model, beta_1, beta_T, T, 46 | mean_type='eps', var_type='fixedlarge'): 47 | assert mean_type in ['xprev' 'xstart', 'epsilon'] 48 | assert var_type in ['fixedlarge', 'fixedsmall'] 49 | super().__init__() 50 | 51 | self.model = model 52 | self.T = T 53 | self.mean_type = mean_type 54 | self.var_type = var_type 55 | 56 | betas = torch.linspace(beta_1, beta_T, T, dtype=torch.float64).double() 57 | 58 | alphas = 1. - betas 59 | self.register_buffer( 60 | 'betas', betas) 61 | alphas_bar = torch.cumprod(alphas, dim=0) 62 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T] 63 | 64 | self.register_buffer( 65 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar)) 66 | self.register_buffer( 67 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar)) 68 | 69 | # calculations for diffusion q(x_t | x_{t-1}) and others 70 | self.register_buffer( 71 | 'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar)) 72 | self.register_buffer( 73 | 'sqrt_recipm1_alphas_bar', torch.sqrt(1. 
/ alphas_bar - 1)) 74 | 75 | # calculations for posterior q(x_{t-1} | x_t, x_0) 76 | self.register_buffer( 77 | 'posterior_var', 78 | self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar)) 79 | # below: log calculation clipped because the posterior variance is 0 at 80 | # the beginning of the diffusion chain 81 | self.register_buffer( 82 | 'posterior_log_var_clipped', 83 | torch.log( 84 | torch.cat([self.posterior_var[1:2], self.posterior_var[1:]]))) 85 | self.register_buffer( 86 | 'posterior_mean_coef1', 87 | torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar)) 88 | self.register_buffer( 89 | 'posterior_mean_coef2', 90 | torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar)) 91 | 92 | def q_mean_variance(self, x_0, x_t, t): 93 | """ 94 | Compute the mean and variance of the diffusion posterior 95 | q(x_{t-1} | x_t, x_0) 96 | """ 97 | assert x_0.shape == x_t.shape 98 | posterior_mean = ( 99 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 + 100 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 101 | ) 102 | posterior_log_var_clipped = extract( 103 | self.posterior_log_var_clipped, t, x_t.shape) 104 | return posterior_mean, posterior_log_var_clipped 105 | 106 | def predict_xstart_from_eps(self, x_t, t, eps): 107 | assert x_t.shape == eps.shape 108 | return ( 109 | extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t - 110 | extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps 111 | ) 112 | 113 | 114 | def p_mean_variance(self, x_t, t, cond, trans): 115 | # below: only log_variance is used in the KL computations 116 | model_log_var = { 117 | # for fixedlarge, we set the initial (log-)variance like so to 118 | # get a better decoder log likelihood 119 | 'fixedlarge': torch.log(torch.cat([self.posterior_var[1:2], 120 | self.betas[1:]])), 121 | 'fixedsmall': self.posterior_log_var_clipped, 122 | }[self.var_type] 123 | model_log_var = extract(model_log_var, t, x_t.shape) 124 | 125 | # Mean parameterization 126 | if self.mean_type == 'epsilon': # the model predicts epsilon 127 | eps = self.model(x_t, t, cond) 128 | x_0 = self.predict_xstart_from_eps(x_t, t, eps=eps) 129 | model_mean, _ = self.q_mean_variance(x_0, x_t, t) 130 | else: 131 | raise NotImplementedError(self.mean_type) 132 | 133 | return model_mean, model_log_var -------------------------------------------------------------------------------- /baselines/codi/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /baselines/codi/models/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Common layers for defining score networks. 18 | """ 19 | import math 20 | import torch.nn as nn 21 | import torch 22 | import torch.nn.functional as F 23 | import numpy as np 24 | 25 | def get_act(FLAGS): 26 | if FLAGS.activation.lower() == 'elu': 27 | return nn.ELU() 28 | elif FLAGS.activation.lower() == 'relu': 29 | return nn.ReLU() 30 | elif FLAGS.activation.lower() == 'lrelu': 31 | return nn.LeakyReLU(negative_slope=0.2) 32 | elif FLAGS.activation.lower() == 'swish': 33 | return nn.SiLU() 34 | elif FLAGS.activation.lower() == 'tanh': 35 | return nn.Tanh() 36 | elif FLAGS.activation.lower() == 'softplus': 37 | return nn.Softplus() 38 | else: 39 | raise NotImplementedError('activation function does not exist!') 40 | 41 | def variance_scaling(scale, mode, distribution, 42 | in_axis=1, out_axis=0, 43 | dtype=torch.float32, 44 | device='cpu'): 45 | def _compute_fans(shape, in_axis=1, out_axis=0): 46 | receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] 47 | fan_in = shape[in_axis] * receptive_field_size 48 | fan_out = shape[out_axis] * receptive_field_size 49 | return fan_in, fan_out 50 | 51 | def init(shape, dtype=dtype, device=device): 52 | fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) 53 | if mode == "fan_in": 54 | denominator = fan_in 55 | elif mode == "fan_out": 56 | denominator = fan_out 57 | elif mode == "fan_avg": 58 | denominator = (fan_in + fan_out) / 2 59 | else: 60 | raise ValueError( 61 | "invalid mode for variance scaling initializer: {}".format(mode)) 62 | variance = scale / denominator 63 | if distribution == "normal": 64 | return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) 65 | elif distribution == "uniform": 66 | return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) 
* np.sqrt(3 * variance) 67 | else: 68 | raise ValueError("invalid distribution for variance scaling initializer") 69 | 70 | return init 71 | 72 | def default_init(scale=1.): 73 | """The same initialization used in DDPM.""" 74 | scale = 1e-10 if scale == 0 else scale 75 | return variance_scaling(scale, 'fan_avg', 'uniform') 76 | 77 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): 78 | assert len(timesteps.shape) == 1 79 | half_dim = embedding_dim // 2 80 | emb = math.log(max_positions) / (half_dim - 1) 81 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 82 | emb = timesteps.float()[:, None] * emb[None, :] 83 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 84 | if embedding_dim % 2 == 1: 85 | emb = F.pad(emb, (0, 1), mode='constant') 86 | assert emb.shape == (timesteps.shape[0], embedding_dim) 87 | return emb 88 | 89 | class Encoder(nn.Module): 90 | def __init__(self, encoder_dim, tdim, FLAGS): 91 | super(Encoder, self).__init__() 92 | self.encoding_blocks = nn.ModuleList() 93 | for i in range(len(encoder_dim)): 94 | if (i+1)==len(encoder_dim): break 95 | encoding_block = EncodingBlock(encoder_dim[i], encoder_dim[i+1], tdim, FLAGS) 96 | self.encoding_blocks.append(encoding_block) 97 | 98 | def forward(self, x, t): 99 | skip_connections = [] 100 | for encoding_block in self.encoding_blocks: 101 | x, skip_connection = encoding_block(x, t) 102 | skip_connections.append(skip_connection) 103 | return skip_connections, x 104 | 105 | class EncodingBlock(nn.Module): 106 | def __init__(self, dim_in, dim_out, tdim, FLAGS): 107 | super(EncodingBlock, self).__init__() 108 | self.layer1 = nn.Sequential( 109 | nn.Linear(dim_in, dim_out), 110 | get_act(FLAGS) 111 | ) 112 | self.temb_proj = nn.Sequential( 113 | nn.Linear(tdim, dim_out), 114 | get_act(FLAGS) 115 | ) 116 | self.layer2 = nn.Sequential( 117 | nn.Linear(dim_out, dim_out), 118 | get_act(FLAGS) 119 | ) 120 | 121 | def forward(self, x, t): 122 | x = self.layer1(x).clone() 123 | x += self.temb_proj(t) 124 | x = self.layer2(x) 125 | skip_connection = x 126 | return x, skip_connection 127 | 128 | class Decoder(nn.Module): 129 | def __init__(self, decoder_dim, tdim, FLAGS): 130 | super(Decoder, self).__init__() 131 | self.decoding_blocks = nn.ModuleList() 132 | for i in range(len(decoder_dim)): 133 | if (i+1)==len(decoder_dim): break 134 | decoding_block = DecodingBlock(decoder_dim[i], decoder_dim[i+1], tdim, FLAGS) 135 | self.decoding_blocks.append(decoding_block) 136 | 137 | def forward(self, skip_connections, x, t): 138 | zipped = zip(reversed(skip_connections), self.decoding_blocks) 139 | for skip_connection, decoding_block in zipped: 140 | x = decoding_block(skip_connection, x, t) 141 | return x 142 | 143 | class DecodingBlock(nn.Module): 144 | def __init__(self, dim_in, dim_out, tdim, FLAGS): 145 | super(DecodingBlock, self).__init__() 146 | self.layer1 = nn.Sequential( 147 | nn.Linear(dim_in*2, dim_in), 148 | get_act(FLAGS) 149 | ) 150 | self.temb_proj = nn.Sequential( 151 | nn.Linear(tdim, dim_in), 152 | get_act(FLAGS) 153 | ) 154 | self.layer2 = nn.Sequential( 155 | nn.Linear(dim_in, dim_out), 156 | get_act(FLAGS) 157 | ) 158 | 159 | def forward(self, skip_connection, x, t): 160 | 161 | x = torch.cat((skip_connection, x), dim=1) 162 | x = self.layer1(x).clone() 163 | x += self.temb_proj(t) 164 | x = self.layer2(x) 165 | 166 | return x -------------------------------------------------------------------------------- /baselines/codi/models/tabular_unet.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | 18 | from baselines.codi.models import layers 19 | import torch.nn as nn 20 | import torch 21 | 22 | get_act = layers.get_act 23 | default_initializer = layers.default_init 24 | 25 | class tabularUnet(nn.Module): 26 | def __init__(self, FLAGS): 27 | super().__init__() 28 | 29 | self.embed_dim = FLAGS.nf 30 | tdim = self.embed_dim*4 31 | self.act = get_act(FLAGS) 32 | 33 | modules = [] 34 | modules.append(nn.Linear(self.embed_dim, tdim)) 35 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 36 | nn.init.zeros_(modules[-1].bias) 37 | modules.append(nn.Linear(tdim, tdim)) 38 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 39 | nn.init.zeros_(modules[-1].bias) 40 | 41 | cond = FLAGS.cond_size 42 | cond_out = (FLAGS.input_size)//2 43 | if cond_out < 2: 44 | cond_out = FLAGS.input_size 45 | modules.append(nn.Linear(cond, cond_out)) 46 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 47 | nn.init.zeros_(modules[-1].bias) 48 | 49 | self.all_modules = nn.ModuleList(modules) 50 | 51 | dim_in = FLAGS.input_size + cond_out 52 | dim_out = list(FLAGS.encoder_dim)[0] 53 | self.inputs = nn.Linear(dim_in, dim_out) # input layer 54 | 55 | self.encoder = layers.Encoder(list(FLAGS.encoder_dim), tdim, FLAGS) # encoder 56 | 57 | dim_in = list(FLAGS.encoder_dim)[-1] 58 | dim_out = list(FLAGS.encoder_dim)[-1] 59 | self.bottom_block = nn.Linear(dim_in, dim_out) #bottom_layer 60 | 61 | self.decoder = layers.Decoder(list(reversed(FLAGS.encoder_dim)), tdim, FLAGS) #decoder 62 | 63 | dim_in = list(FLAGS.encoder_dim)[0] 64 | dim_out = FLAGS.output_size 65 | self.outputs = nn.Linear(dim_in, dim_out) #output layer 66 | 67 | 68 | def forward(self, x, time_cond, cond): 69 | 70 | modules = self.all_modules 71 | m_idx = 0 72 | 73 | #time embedding 74 | temb = layers.get_timestep_embedding(time_cond, self.embed_dim) 75 | temb = modules[m_idx](temb) 76 | m_idx += 1 77 | temb= self.act(temb) 78 | temb = modules[m_idx](temb) 79 | m_idx += 1 80 | 81 | #condition layer 82 | cond = modules[m_idx](cond) 83 | m_idx += 1 84 | 85 | x = torch.cat([x, cond], dim=1).float() 86 | inputs = self.inputs(x) #input layer 87 | skip_connections, encoding = self.encoder(inputs, temb) 88 | encoding = self.bottom_block(encoding) 89 | encoding = self.act(encoding) 90 | x = self.decoder(skip_connections, encoding, temb) 91 | outputs = self.outputs(x) 92 | 93 | return outputs 94 | -------------------------------------------------------------------------------- /baselines/codi/tabular_dataload.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Return training and evaluation/test datasets from config files.""" 18 | import torch 19 | import numpy as np 20 | import pandas as pd 21 | import json 22 | import logging 23 | import os 24 | 25 | from baselines.codi.tabular_transformer import GeneralTransformer 26 | 27 | CATEGORICAL = "categorical" 28 | CONTINUOUS = "continuous" 29 | 30 | LOGGER = logging.getLogger(__name__) 31 | 32 | DATA_PATH = os.path.join(os.path.dirname(__file__), 'tabular_datasets') 33 | 34 | def _load_json(path): 35 | with open(path) as json_file: 36 | return json.load(json_file) 37 | 38 | 39 | def _load_file(filename, loader): 40 | local_path = os.path.join(DATA_PATH, filename) 41 | 42 | if loader == np.load: 43 | return loader(local_path, allow_pickle=True) 44 | return loader(local_path) 45 | 46 | 47 | def _get_columns(metadata): 48 | categorical_columns = list() 49 | 50 | for column_idx, column in enumerate(metadata['columns']): 51 | if column['type'] == CATEGORICAL: 52 | categorical_columns.append(column_idx) 53 | 54 | return categorical_columns 55 | 56 | 57 | # def load_data(name, benchmark=False): 58 | # data = _load_file(name + '.npz', np.load) 59 | # meta = _load_file(name + '.json', _load_json) 60 | 61 | # categorical_columns = _get_columns(meta) 62 | # train = data['train'] 63 | # test = data['test'] 64 | 65 | 66 | # return train, test, (categorical_columns, meta) 67 | 68 | def load_data(name): 69 | data_dir = f'data/{name}' 70 | info_path = f'{data_dir}/info.json' 71 | 72 | train = pd.read_csv(f'{data_dir}/train.csv').to_numpy() 73 | test = pd.read_csv(f'{data_dir}/test.csv').to_numpy() 74 | 75 | with open(f'{data_dir}/info.json', 'r') as f: 76 | info = json.load(f) 77 | 78 | task_type = info['task_type'] 79 | 80 | num_cols = info['num_col_idx'] 81 | cat_cols = info['cat_col_idx'] 82 | target_cols = info['target_col_idx'] 83 | 84 | if task_type != 'regression': 85 | cat_cols = cat_cols + target_cols 86 | 87 | return train, test, (cat_cols, info) 88 | 89 | 90 | def get_dataset(FLAGS, evaluation=False): 91 | 92 | batch_size = FLAGS.training_batch_size if not evaluation else FLAGS.eval_batch_size 93 | 94 | if batch_size % torch.cuda.device_count() != 0: 95 | raise ValueError(f'Batch sizes ({batch_size} must be divided by' 96 | f'the number of devices ({torch.cuda.device_count()})') 97 | 98 | 99 | # Create dataset builders for tabular data. 
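# The steps below split the raw training matrix into continuous and categorical
# column blocks, fit a GeneralTransformer on each (min-max scaling to [-1, 1]
# for the continuous block, one-hot encoding for the categorical block), and
# return the transformed blocks together with the raw splits, the fitted
# transformers, and the column-index lists.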
100 | train, test, cols = load_data(FLAGS.dataname) 101 | cols_idx = list(np.arange(train.shape[1])) 102 | dis_idx = cols[0] 103 | con_idx = [x for x in cols_idx if x not in dis_idx] 104 | 105 | #split continuous and categorical 106 | train_con = train[:,con_idx] 107 | train_dis = train[:,dis_idx] 108 | 109 | #new index 110 | cat_idx_ = list(np.arange(train_dis.shape[1]))[:len(cols[0])] 111 | 112 | transformer_con = GeneralTransformer() 113 | transformer_dis = GeneralTransformer() 114 | 115 | transformer_con.fit(train_con, []) 116 | transformer_dis.fit(train_dis, cat_idx_) 117 | 118 | train_con_data = transformer_con.transform(train_con) 119 | train_dis_data = transformer_dis.transform(train_dis) 120 | 121 | 122 | return train, train_con_data, train_dis_data, test, (transformer_con, transformer_dis, cols[1]), con_idx, dis_idx 123 | -------------------------------------------------------------------------------- /baselines/codi/tabular_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | CATEGORICAL = "categorical" 5 | CONTINUOUS = "continuous" 6 | 7 | class Transformer: 8 | 9 | @staticmethod 10 | def get_metadata(data, categorical_columns=tuple()): 11 | meta = [] 12 | 13 | df = pd.DataFrame(data) 14 | for index in df: 15 | column = df[index] 16 | 17 | if index in categorical_columns: 18 | mapper = column.value_counts().index.tolist() 19 | meta.append({ 20 | "name": index, 21 | "type": CATEGORICAL, 22 | "size": len(mapper), 23 | "i2s": mapper 24 | }) 25 | else: 26 | meta.append({ 27 | "name": index, 28 | "type": CONTINUOUS, 29 | "min": column.min(), 30 | "max": column.max(), 31 | }) 32 | 33 | return meta 34 | 35 | def fit(self, data, categorical_columns=tuple()): 36 | raise NotImplementedError 37 | 38 | def transform(self, data): 39 | raise NotImplementedError 40 | 41 | def inverse_transform(self, data): 42 | raise NotImplementedError 43 | 44 | 45 | class GeneralTransformer(Transformer): 46 | 47 | def __init__(self, act='tanh'): 48 | self.act = act 49 | self.meta = None 50 | self.output_dim = None 51 | 52 | def fit(self, data, categorical_columns=tuple()): 53 | self.meta = self.get_metadata(data, categorical_columns) 54 | self.output_dim = 0 55 | for info in self.meta: 56 | if info['type'] in [CONTINUOUS]: 57 | self.output_dim += 1 58 | else: 59 | self.output_dim += info['size'] 60 | 61 | def transform(self, data): 62 | data_t = [] 63 | self.output_info = [] 64 | for id_, info in enumerate(self.meta): 65 | col = data[:, id_] 66 | if info['type'] == CONTINUOUS: 67 | col = (col - (info['min'])) / (info['max'] - info['min']) 68 | if self.act == 'tanh': 69 | col = col * 2 - 1 70 | data_t.append(col.reshape([-1, 1])) 71 | self.output_info.append((1, self.act)) 72 | 73 | else: 74 | col_t = np.zeros([len(data), info['size']]) 75 | idx = list(map(info['i2s'].index, col)) 76 | col_t[np.arange(len(data)), idx] = 1 77 | data_t.append(col_t) 78 | self.output_info.append((info['size'], 'softmax')) 79 | 80 | return np.concatenate(data_t, axis=1) 81 | 82 | def inverse_transform(self, data): 83 | 84 | if self.meta[0]['type'] == CONTINUOUS: 85 | data_t = np.zeros([len(data), len(self.meta)]) 86 | else: 87 | dtype = np.dtype('U50') 88 | data_t = np.empty([len(data), len(self.meta)], dtype=dtype) 89 | 90 | 91 | data = data.copy() 92 | for id_, info in enumerate(self.meta): 93 | 94 | if info['type'] == CONTINUOUS: 95 | current = data[:, 0] 96 | data = data[:, 1:] 97 | 98 | if self.act == 'tanh': 99 | current = 
(current + 1) / 2 100 | 101 | current = np.clip(current, 0, 1) 102 | data_t[:, id_] = current * (info['max'] - info['min']) + info['min'] 103 | 104 | else: 105 | current = data[:, :info['size']] 106 | data = data[:, info['size']:] 107 | idx = np.argmax(current, axis=1) 108 | recovered = list(map(info['i2s'].__getitem__, idx)) 109 | 110 | data_t[:, id_] = recovered 111 | return data_t 112 | -------------------------------------------------------------------------------- /baselines/codi/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import pandas as pd 5 | 6 | def warmup_lr(step): 7 | return min(step, 5000) / 5000 8 | 9 | def infiniteloop(dataloader): 10 | while True: 11 | for _, y in enumerate(dataloader): 12 | yield y 13 | 14 | def apply_activate(data, output_info): 15 | data_t = [] 16 | st = 0 17 | for item in output_info: 18 | if item[1] == 'tanh': 19 | ed = st + item[0] 20 | data_t.append(torch.tanh(data[:, st:ed])) 21 | st = ed 22 | elif item[1] == 'sigmoid': 23 | ed = st + item[0] 24 | data_t.append(data[:,st:ed]) 25 | st = ed 26 | elif item[1] == 'softmax': 27 | ed = st + item[0] 28 | data_t.append(F.softmax(data[:, st:ed])) 29 | st = ed 30 | else: 31 | assert 0 32 | return torch.cat(data_t, dim=1) 33 | 34 | def log_sample_categorical(logits, num_classes): 35 | full_sample = [] 36 | k=0 37 | for i in range(len(num_classes)): 38 | logits_column = logits[:,k:num_classes[i]+k] 39 | k+=num_classes[i] 40 | uniform = torch.rand_like(logits_column) 41 | gumbel_noise = -torch.log(-torch.log(uniform+1e-30)+1e-30) 42 | sample = (gumbel_noise + logits_column).argmax(dim=1) 43 | col_t =np.zeros(logits_column.shape) 44 | col_t[np.arange(logits_column.shape[0]), sample.detach().cpu()] = 1 45 | full_sample.append(col_t) 46 | full_sample = torch.tensor(np.concatenate(full_sample, axis=1)) 47 | log_sample = torch.log(full_sample.float().clamp(min=1e-30)) 48 | return log_sample 49 | 50 | 51 | def sampling_with(x_T_con, log_x_T_dis, net_sampler, trainer_dis, trans, FLAGS): 52 | x_t_con = x_T_con 53 | x_t_dis = log_x_T_dis 54 | 55 | for time_step in reversed(range(FLAGS.T)): 56 | t = x_t_con.new_ones([x_t_con.shape[0], ], dtype=torch.long) * time_step 57 | mean, log_var = net_sampler.p_mean_variance(x_t=x_t_con, t=t, cond = x_t_dis.to(x_t_con.device), trans=trans) 58 | if time_step > 0: 59 | noise = torch.randn_like(x_t_con) 60 | elif time_step == 0: 61 | noise = 0 62 | x_t_minus_1_con = mean + torch.exp(0.5 * log_var) * noise 63 | x_t_minus_1_con = torch.clip(x_t_minus_1_con, -1., 1.) 
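# The continuous reverse step above draws x_{t-1} for the numerical columns from
# the Gaussian posterior and clips it to [-1, 1]; the call below advances the
# discrete (log-one-hot) columns by one reverse step of the discrete diffusion
# (see diffusion_discrete.py), conditioned on the current continuous state, so
# the two chains co-evolve at every timestep.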
64 | x_t_minus_1_dis = trainer_dis.p_sample(x_t_dis, t, x_t_con) 65 | x_t_con = x_t_minus_1_con 66 | x_t_dis = x_t_minus_1_dis 67 | 68 | return x_t_con, x_t_dis 69 | 70 | def training_with(x_0_con, x_0_dis, trainer, trainer_dis, ns_con, ns_dis, categories, FLAGS): 71 | 72 | t = torch.randint(FLAGS.T, size=(x_0_con.shape[0], ), device=x_0_con.device) 73 | pt = torch.ones_like(t).float() / FLAGS.T 74 | 75 | #co-evolving training and predict positive samples 76 | noise = torch.randn_like(x_0_con) 77 | x_t_con = trainer.make_x_t(x_0_con, t, noise) 78 | log_x_start = torch.log(x_0_dis.float().clamp(min=1e-30)) 79 | x_t_dis = trainer_dis.q_sample(log_x_start=log_x_start, t=t) 80 | eps = trainer.model(x_t_con, t, x_t_dis.to(x_t_con.device)) 81 | ps_0_con = trainer.predict_xstart_from_eps(x_t_con, t, eps=eps) 82 | con_loss = F.mse_loss(eps, noise, reduction='none') 83 | con_loss = con_loss.mean() 84 | kl, ps_0_dis = trainer_dis.compute_Lt(log_x_start, x_t_dis, t, x_t_con) 85 | ps_0_dis = torch.exp(ps_0_dis) 86 | kl_prior = trainer_dis.kl_prior(log_x_start) 87 | dis_loss = (kl / pt + kl_prior).mean() 88 | 89 | # negative condition -> predict negative samples 90 | noise_ns = torch.randn_like(ns_con) 91 | ns_t_con = trainer.make_x_t(ns_con, t, noise_ns) 92 | log_ns_start = torch.log(ns_dis.float().clamp(min=1e-30)) 93 | ns_t_dis = trainer_dis.q_sample(log_x_start=log_ns_start, t=t) 94 | eps_ns = trainer.model(x_t_con, t, ns_t_dis.to(ns_t_dis.device)) 95 | ns_0_con = trainer.predict_xstart_from_eps(x_t_con, t, eps=eps_ns) 96 | _, ns_0_dis = trainer_dis.compute_Lt(log_x_start, x_t_dis, t, ns_t_con) 97 | ns_0_dis = torch.exp(ns_0_dis) 98 | 99 | # contrastive learning loss 100 | triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2) 101 | triplet_con = triplet_loss(x_0_con, ps_0_con, ns_0_con) 102 | st=0 103 | triplet_dis = [] 104 | for item in categories: 105 | ed = st + item 106 | ps_dis = F.cross_entropy(ps_0_dis[:, st:ed], torch.argmax(x_0_dis[:, st:ed], dim=-1).long(), reduction='none') 107 | ns_dis = F.cross_entropy(ns_0_dis[:, st:ed], torch.argmax(x_0_dis[:, st:ed], dim=-1).long(), reduction='none') 108 | 109 | triplet_dis.append(max((ps_dis-ns_dis).mean()+1,0)) 110 | st = ed 111 | triplet_dis = sum(triplet_dis)/len(triplet_dis) 112 | return con_loss, triplet_con, dis_loss, triplet_dis 113 | 114 | def make_negative_condition(x_0_con, x_0_dis): 115 | 116 | device = x_0_con.device 117 | x_0_con = x_0_con.detach().cpu().numpy() 118 | x_0_dis = x_0_dis.detach().cpu().numpy() 119 | 120 | nsc_raw = pd.DataFrame(x_0_con) 121 | nsd_raw = pd.DataFrame(x_0_dis) 122 | nsc = np.array(nsc_raw.sample(frac=1, replace = False).reset_index(drop=True)) 123 | nsd = np.array(nsd_raw.sample(frac=1, replace = False).reset_index(drop=True)) 124 | ns_con = nsc 125 | ns_dis = nsd 126 | 127 | return torch.tensor(ns_con).to(device), torch.tensor(ns_dis).to(device) -------------------------------------------------------------------------------- /baselines/goggle/data_utils.py: -------------------------------------------------------------------------------- 1 | # Standard imports 2 | import random 3 | 4 | # 3rd party 5 | import numpy as np 6 | import torch 7 | from sklearn.model_selection import train_test_split 8 | from torch.utils.data import DataLoader, TensorDataset 9 | 10 | 11 | def seed_worker(worker_id): 12 | worker_seed = torch.initial_seed() % 2**32 13 | np.random.seed(worker_seed) 14 | random.seed(worker_seed) 15 | 16 | 17 | def get_dataloader(X, batch_size, seed): 18 | X_train, X_val = train_test_split(X, 
test_size=0.2, random_state=seed) 19 | 20 | train_dataset = TensorDataset(torch.Tensor(X_train.values)) 21 | val_dataset = TensorDataset(torch.Tensor(X_val.values)) 22 | 23 | g = torch.Generator() 24 | g.manual_seed(seed) 25 | torch.manual_seed(seed) 26 | 27 | dataloader = { 28 | "train": DataLoader( 29 | train_dataset, 30 | batch_size=batch_size, 31 | shuffle=True, 32 | worker_init_fn=seed_worker, 33 | generator=g, 34 | ), 35 | "val": DataLoader( 36 | val_dataset, 37 | batch_size=batch_size, 38 | shuffle=True, 39 | worker_init_fn=seed_worker, 40 | generator=g, 41 | ), 42 | } 43 | return dataloader -------------------------------------------------------------------------------- /baselines/goggle/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | from torch.utils.data import DataLoader 5 | 6 | import argparse 7 | import warnings 8 | import json 9 | 10 | from utils_train import preprocess 11 | 12 | from baselines.goggle.GoggleModel import GoggleModel 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | def main(args): 18 | 19 | dataname = args.dataname 20 | device = args.device 21 | 22 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 23 | dataset_dir = f'data/{dataname}' 24 | ckpt_dir = f'{curr_dir}/ckpt/{dataname}' 25 | 26 | if not os.path.exists(ckpt_dir): 27 | os.makedirs(ckpt_dir) 28 | 29 | with open(f'{dataset_dir}/info.json', 'r') as f: 30 | info = json.load(f) 31 | 32 | task_type = info['task_type'] 33 | 34 | dataset = preprocess(dataset_dir, task_type = task_type, cat_encoding = 'one-hot') 35 | X_train = torch.tensor(dataset.X_num['train']) 36 | 37 | gen = GoggleModel( 38 | ds_name=dataname, 39 | input_dim=X_train.shape[1], 40 | encoder_dim=2048, 41 | encoder_l=4, 42 | het_encoding=True, 43 | decoder_dim=2048, 44 | decoder_l=4, 45 | threshold=0.1, 46 | decoder_arch="gcn", 47 | graph_prior=None, 48 | prior_mask=None, 49 | device=device, 50 | beta=1, 51 | learning_rate=0.01, 52 | seed=42, 53 | ) 54 | print(gen.model) 55 | print(gen.model.learned_graph.graph.shape) 56 | 57 | num_params = sum(p.numel() for p in gen.model.encoder.parameters() if p.requires_grad) 58 | print(f'Number of parameters in encoder: {num_params}') 59 | 60 | start_time = time.time() 61 | train_loader = DataLoader(X_train, batch_size=gen.batch_size, shuffle=True) 62 | gen.fit(train_loader, f'{ckpt_dir}/model.pt') 63 | end_time = time.time() 64 | 65 | print(f'Training time: {end_time - start_time}') 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser(description='GOGGLE') 69 | 70 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 71 | parser.add_argument('--gpu', type=int, default=0, help='GPU index.') 72 | 73 | args = parser.parse_args() 74 | 75 | # check cuda 76 | if args.gpu != -1 and torch.cuda.is_available(): 77 | args.device = f'cuda:{args.gpu}' 78 | else: 79 | args.device = 'cpu' -------------------------------------------------------------------------------- /baselines/goggle/model/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, input_dim, encoder_dim, encoder_l, device): 7 | super(Encoder, self).__init__() 8 | encoder = nn.ModuleList([nn.Linear(input_dim, encoder_dim), nn.ReLU()]) 9 | for _ in range(encoder_l - 2): 10 | encoder_dim_ = int(encoder_dim / 2) 11 | encoder.append(nn.Linear(encoder_dim, 
encoder_dim_)) 12 | encoder.append(nn.ReLU()) 13 | encoder_dim = encoder_dim_ 14 | self.encoder = nn.Sequential(*encoder) 15 | self.encode_mu = nn.Linear(encoder_dim, input_dim) 16 | self.encode_logvar = nn.Linear(encoder_dim, input_dim) 17 | 18 | def reparameterize(self, mu, logvar): 19 | std = torch.exp(0.5 * logvar) 20 | eps = torch.randn_like(std) 21 | return mu + eps * std 22 | 23 | def forward(self, x): 24 | 25 | h = self.encoder(x) 26 | mu_z, logvar_z = self.encode_mu(h), self.encode_logvar(h) 27 | z = self.reparameterize(mu_z, logvar_z) 28 | return z, (mu_z, logvar_z) -------------------------------------------------------------------------------- /baselines/goggle/model/Goggle.py: -------------------------------------------------------------------------------- 1 | # 3rd party 2 | import torch 3 | from torch import nn 4 | 5 | # Goggle 6 | from baselines.goggle.model.Encoder import Encoder 7 | from baselines.goggle.model.GraphDecoder import GraphDecoderHet, GraphDecoderHomo 8 | from baselines.goggle.model.GraphInputProcessor import ( 9 | GraphInputProcessorHet, 10 | GraphInputProcessorHomo, 11 | ) 12 | from baselines.goggle.model.LearnedGraph import LearnedGraph 13 | 14 | 15 | class Goggle(nn.Module): 16 | def __init__( 17 | self, 18 | input_dim, 19 | encoder_dim=64, 20 | encoder_l=2, 21 | het_encoding=True, 22 | decoder_dim=64, 23 | decoder_l=2, 24 | threshold=0.1, 25 | decoder_arch="gcn", 26 | graph_prior=None, 27 | prior_mask=None, 28 | device="cpu", 29 | ): 30 | super(Goggle, self).__init__() 31 | self.input_dim = input_dim 32 | self.device = device 33 | self.learned_graph = LearnedGraph( 34 | input_dim, graph_prior, prior_mask, threshold, device 35 | ) 36 | self.encoder = Encoder(input_dim, encoder_dim, encoder_l, device) 37 | if decoder_arch == "het": 38 | n_edge_types = input_dim * input_dim 39 | self.graph_processor = GraphInputProcessorHet( 40 | input_dim, decoder_dim, n_edge_types, het_encoding, device 41 | ) 42 | self.decoder = GraphDecoderHet(decoder_dim, decoder_l, n_edge_types, device) 43 | else: 44 | self.graph_processor = GraphInputProcessorHomo( 45 | input_dim, decoder_dim, het_encoding, device 46 | ) 47 | self.decoder = GraphDecoderHomo( 48 | decoder_dim, decoder_l, decoder_arch, device 49 | ) 50 | 51 | def forward(self, x, iter): 52 | z, (mu_z, logvar_z) = self.encoder(x) 53 | b_size, _ = z.shape 54 | adj = self.learned_graph(iter) 55 | graph_input = self.graph_processor(z, adj) 56 | x_hat = self.decoder(graph_input, b_size) 57 | 58 | return x_hat, adj, mu_z, logvar_z 59 | 60 | def sample(self, count): 61 | with torch.no_grad(): 62 | mu = torch.zeros(self.input_dim) 63 | sigma = torch.ones(self.input_dim) 64 | q = torch.distributions.Normal(mu, sigma) 65 | z = q.rsample(sample_shape=torch.Size([count])).squeeze().to(self.device) 66 | 67 | self.learned_graph.eval() 68 | self.graph_processor.eval() 69 | self.decoder.eval() 70 | 71 | adj = self.learned_graph(100) 72 | graph_input = self.graph_processor(z, adj) 73 | synth_x = self.decoder(graph_input, count) 74 | 75 | return synth_x -------------------------------------------------------------------------------- /baselines/goggle/model/GoggleLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class GoggleLoss(nn.Module): 6 | def __init__(self, alpha=1, beta=0, graph_prior=None, device="cpu"): 7 | super(GoggleLoss, self).__init__() 8 | self.mse_loss = nn.MSELoss(reduction="sum") 9 | self.device = device 10 | self.alpha = alpha 
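# alpha weights the KL-divergence term and beta weights the graph penalty
# (L1 sparsity, or deviation from graph_prior when one is given) in forward() below.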
11 | self.beta = beta 12 | if graph_prior is not None: 13 | self.use_prior = True 14 | self.graph_prior = ( 15 | torch.Tensor(graph_prior).requires_grad_(False).to(device) 16 | ) 17 | else: 18 | self.use_prior = False 19 | 20 | def forward(self, x_recon, x, mu, logvar, graph): 21 | loss_mse = self.mse_loss(x_recon, x) 22 | loss_kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 23 | if self.use_prior: 24 | loss_graph = (graph - self.graph_prior).norm(p=1) / torch.numel(graph) 25 | else: 26 | loss_graph = graph.norm(p=1) / torch.numel(graph) 27 | 28 | loss = loss_mse + self.alpha * loss_kld + self.beta * loss_graph 29 | 30 | return loss, loss_mse, loss_kld, loss_graph -------------------------------------------------------------------------------- /baselines/goggle/model/GraphDecoder.py: -------------------------------------------------------------------------------- 1 | # 3rd Party 2 | from dgl.nn import GraphConv 3 | from torch import nn 4 | 5 | # Goggle 6 | # from .RGCNConv import RGCNConv 7 | 8 | 9 | class GraphDecoderHomo(nn.Module): 10 | def __init__(self, decoder_dim, decoder_l, decoder_arch, device): 11 | super(GraphDecoderHomo, self).__init__() 12 | decoder = nn.ModuleList([]) 13 | 14 | if decoder_arch == "gcn": 15 | for i in range(decoder_l): 16 | if i == decoder_l - 1: 17 | decoder.append( 18 | GraphConv(decoder_dim, 1, norm="both", weight=True, bias=True) 19 | ) 20 | else: 21 | decoder_dim_ = int(decoder_dim / 2) 22 | decoder.append( 23 | GraphConv( 24 | decoder_dim, 25 | decoder_dim_, 26 | norm="both", 27 | weight=True, 28 | bias=True, 29 | activation=nn.Tanh(), 30 | ) 31 | ) 32 | decoder_dim = decoder_dim_ 33 | elif decoder_arch == "sage": 34 | for i in range(decoder_l): 35 | if i == decoder_l - 1: 36 | decoder.append( 37 | SAGEConv(decoder_dim, 1, aggregator_type="mean", bias=True) 38 | ) 39 | else: 40 | decoder_dim_ = int(decoder_dim / 2) 41 | decoder.append( 42 | SAGEConv( 43 | decoder_dim, 44 | decoder_dim_, 45 | aggregator_type="mean", 46 | bias=True, 47 | activation=nn.Tanh(), 48 | ) 49 | ) 50 | decoder_dim = decoder_dim_ 51 | else: 52 | raise Exception("decoder can only be {het|gcn|sage}") 53 | 54 | self.decoder = nn.Sequential(*decoder) 55 | 56 | def forward(self, graph_input, b_size): 57 | b_z, b_adj, b_edge_weight = graph_input 58 | 59 | for layer in self.decoder: 60 | b_z = layer(b_adj, feat=b_z, edge_weight=b_edge_weight) 61 | 62 | x_hat = b_z.reshape(b_size, -1) 63 | 64 | return x_hat 65 | 66 | 67 | class GraphDecoderHet(nn.Module): 68 | def __init__(self, decoder_dim, decoder_l, n_edge_types, device): 69 | super(GraphDecoderHet, self).__init__() 70 | decoder = nn.ModuleList([]) 71 | 72 | for i in range(decoder_l): 73 | if i == decoder_l - 1: 74 | decoder.append( 75 | RGCNConv( 76 | decoder_dim, 77 | 1, 78 | num_relations=n_edge_types + 1, 79 | root_weight=False, 80 | ) 81 | ) 82 | else: 83 | decoder_dim_ = int(decoder_dim / 2) 84 | decoder.append( 85 | RGCNConv( 86 | decoder_dim, 87 | decoder_dim_, 88 | num_relations=n_edge_types + 1, 89 | root_weight=False, 90 | ) 91 | ) 92 | decoder.append(nn.ReLU()) 93 | decoder_dim = decoder_dim_ 94 | 95 | self.decoder = nn.Sequential(*decoder) 96 | 97 | def forward(self, graph_input, b_size): 98 | b_z, b_edge_index, b_edge_weights, b_edge_types = graph_input 99 | 100 | h = b_z 101 | for layer in self.decoder: 102 | if not isinstance(layer, nn.ReLU): 103 | h = layer(h, b_edge_index, b_edge_types, b_edge_weights) 104 | else: 105 | h = layer(h) 106 | 107 | x_hat = h.reshape(b_size, -1) 108 | 109 | return x_hat 
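# A minimal usage sketch under stated assumptions (not part of the repository):
# it wires GraphDecoderHomo to GraphInputProcessorHomo from GraphInputProcessor.py
# below, with hypothetical sizes chosen only for illustration, and it assumes the
# goggle baseline's dependencies (dgl, torch_geometric) are installed. Only the
# "gcn" branch is exercised, which relies solely on the GraphConv import above.
import torch
from baselines.goggle.model.GraphDecoder import GraphDecoderHomo
from baselines.goggle.model.GraphInputProcessor import GraphInputProcessorHomo

input_dim, decoder_dim, batch = 8, 64, 4                      # hypothetical sizes
processor = GraphInputProcessorHomo(input_dim, decoder_dim,
                                    het_encoding=True, device="cpu")
decoder = GraphDecoderHomo(decoder_dim, decoder_l=2,
                           decoder_arch="gcn", device="cpu")

z = torch.randn(batch, input_dim)         # one latent scalar per feature/node
adj = torch.rand(input_dim, input_dim)    # dense adjacency, e.g. from LearnedGraph
graph_input = processor(z, adj)           # -> (node features, batched dgl graph, edge weights)
x_hat = decoder(graph_input, b_size=batch)
print(x_hat.shape)                        # torch.Size([4, 8])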
-------------------------------------------------------------------------------- /baselines/goggle/model/GraphInputProcessor.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | from torch import nn 4 | from torch_geometric.utils import dense_to_sparse 5 | 6 | 7 | class GraphInputProcessorHomo(nn.Module): 8 | def __init__(self, input_dim, decoder_dim, het_encoding, device): 9 | super(GraphInputProcessorHomo, self).__init__() 10 | self.device = device 11 | self.het_encoding = het_encoding 12 | 13 | if het_encoding: 14 | feat_dim = input_dim + 1 15 | else: 16 | feat_dim = 1 17 | 18 | self.embedding_functions = [] 19 | for _ in range(input_dim): 20 | self.embedding_functions.append( 21 | nn.Sequential(nn.Linear(feat_dim, decoder_dim), nn.Tanh()).to(device) 22 | ) 23 | 24 | def forward(self, z, adj): 25 | """ 26 | Prepares embeddings for graph decoding 27 | Parameters: 28 | z (Tensor): feature embeddings 29 | adj (Tensor): adjacency matrix 30 | iter (int): training iteration 31 | Returns: 32 | b_z (Tensor): dense feature matrix, shape = (b_size*n_nodes, n_feats) 33 | b_adj (Tensor): batched adjacency matrix 34 | b_edge_weight (Sparse Tensor): sparse edge weights, shape = (n_edges) 35 | """ 36 | b_z = z.unsqueeze(-1) 37 | b_size, n_nodes, _ = b_z.shape 38 | 39 | if self.het_encoding: 40 | one_hot_encoding = torch.eye(n_nodes).to(self.device) 41 | b_encoding = torch.stack([one_hot_encoding for _ in range(b_size)], dim=0) 42 | b_z = torch.cat([b_z, b_encoding], dim=-1) 43 | 44 | b_z = [f(b_z[:, i]) for i, f in enumerate(self.embedding_functions)] 45 | b_z = torch.stack(b_z, dim=1) 46 | b_z = torch.flatten(b_z, start_dim=0, end_dim=1) 47 | 48 | edge_index = adj.nonzero().t() 49 | row, col = edge_index 50 | edge_weight = adj[row, col] 51 | 52 | g = dgl.graph((edge_index[0], edge_index[1])) 53 | b_adj = dgl.batch([g] * b_size) 54 | b_edge_weight = edge_weight.repeat(b_size) 55 | 56 | return (b_z, b_adj, b_edge_weight) 57 | 58 | 59 | class GraphInputProcessorHet(nn.Module): 60 | def __init__(self, input_dim, decoder_dim, n_edge_types, het_encoding, device): 61 | super(GraphInputProcessorHet, self).__init__() 62 | self.n_edge_types = n_edge_types 63 | self.device = device 64 | self.het_encoding = het_encoding 65 | 66 | if het_encoding: 67 | feat_dim = input_dim + 1 68 | else: 69 | feat_dim = 1 70 | 71 | self.embedding_functions = [] 72 | for _ in range(input_dim): 73 | self.embedding_functions.append( 74 | nn.Sequential(nn.Linear(feat_dim, decoder_dim), nn.Tanh()).to(device) 75 | ) 76 | 77 | def forward(self, z, adj): 78 | """ 79 | Prepares embeddings for graph decoding 80 | Parameters: 81 | z (Tensor): feature embeddings 82 | adj (Tensor): adjacency matrix 83 | het_encoding (bool): use of heterogeneous encoding 84 | Returns: 85 | b_z (Tensor): dense feature matrix, shape = (b_size*n_nodes, n_feats) 86 | b_adj (Tensor): batched adjacency matrix, shape = (b_size, n_nodes, n_nodes) 87 | b_edge_index (Sparse Tensor): sparse edge index, shape = (2, n_edges) 88 | b_edge_weights (Sparse Tensor): sparse edge weights, shape = (n_edges) 89 | b_edge_types (Sparse Tensor): sparse edge type, shape = (n_edges) 90 | """ 91 | b_size, n_nodes = z.shape 92 | 93 | b_z = z.unsqueeze(-1) 94 | 95 | if self.het_encoding: 96 | one_hot_encoding = torch.eye(n_nodes).to(self.device) 97 | b_encoding = torch.stack([one_hot_encoding for _ in range(b_size)], dim=0) 98 | b_z = torch.cat([b_z, b_encoding], dim=-1) 99 | 100 | b_z = [f(b_z[:, i]) for i, f in 
enumerate(self.embedding_functions)] 101 | b_z = torch.stack(b_z, dim=1) 102 | b_size, n_nodes, n_feats = b_z.shape 103 | 104 | n_edge_types = self.n_edge_types 105 | edge_types = torch.arange(1, n_edge_types + 1, 1).reshape(n_nodes, n_nodes) 106 | 107 | b_adj = torch.stack([adj for _ in range(b_size)], dim=0) 108 | 109 | b_edge_index, b_edge_weights = dense_to_sparse(b_adj) 110 | r, c = b_edge_index 111 | b_edge_types = edge_types[r % n_nodes, c % n_nodes] 112 | b_z = b_z.reshape(b_size * n_nodes, n_feats) 113 | 114 | return (b_z, b_edge_index, b_edge_weights, b_edge_types) -------------------------------------------------------------------------------- /baselines/goggle/model/LearnedGraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LearnedGraph(nn.Module): 6 | def __init__(self, input_dim, graph_prior, prior_mask, threshold, device): 7 | super(LearnedGraph, self).__init__() 8 | 9 | self.graph = nn.Parameter( 10 | torch.zeros(input_dim, input_dim, requires_grad=True, device=device) 11 | ) 12 | 13 | if all(i is not None for i in [graph_prior, prior_mask]): 14 | self.graph_prior = ( 15 | graph_prior.detach().clone().requires_grad_(False).to(device) 16 | ) 17 | self.prior_mask = ( 18 | prior_mask.detach().clone().requires_grad_(False).to(device) 19 | ) 20 | self.use_prior = True 21 | else: 22 | self.use_prior = False 23 | 24 | self.act = nn.Sigmoid() 25 | self.threshold = nn.Threshold(threshold, 0) 26 | self.device = device 27 | 28 | def forward(self, iter): 29 | if self.use_prior: 30 | graph = ( 31 | self.prior_mask * self.graph_prior + (1 - self.prior_mask) * self.graph 32 | ) 33 | else: 34 | graph = self.graph 35 | 36 | graph = self.act(graph) 37 | graph = graph.clone() 38 | graph = graph * ( 39 | torch.ones(graph.shape[0]).to(self.device) 40 | - torch.eye(graph.shape[0]).to(self.device) 41 | ) + torch.eye(graph.shape[0]).to(self.device) 42 | 43 | if iter > 50: 44 | graph = self.threshold(graph) 45 | else: 46 | graph = graph 47 | 48 | return graph -------------------------------------------------------------------------------- /baselines/goggle/model/__pycache__/Encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/goggle/model/__pycache__/Encoder.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/goggle/model/__pycache__/Goggle.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/goggle/model/__pycache__/Goggle.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/goggle/model/__pycache__/GoggleLoss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/goggle/model/__pycache__/GoggleLoss.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/goggle/model/__pycache__/GraphDecoder.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/goggle/model/__pycache__/GraphDecoder.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/goggle/model/__pycache__/GraphInputProcessor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/goggle/model/__pycache__/GraphInputProcessor.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/goggle/model/__pycache__/LearnedGraph.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/goggle/model/__pycache__/LearnedGraph.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/goggle/sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torch 4 | 5 | import argparse 6 | import warnings 7 | import json 8 | import time 9 | from utils_train import preprocess 10 | 11 | from baselines.goggle.GoggleModel import GoggleModel 12 | import json 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | def recover_data(syn_num, syn_cat, info): 18 | 19 | target_col_idx = info['target_col_idx'] 20 | if info['task_type'] == 'regression': 21 | syn_target = syn_num[:, :len(target_col_idx)] 22 | syn_num = syn_num[:, len(target_col_idx):] 23 | 24 | else: 25 | print(syn_cat.shape) 26 | syn_target = syn_cat[:, :len(target_col_idx)] 27 | syn_cat = syn_cat[:, len(target_col_idx):] 28 | 29 | 30 | num_col_idx = info['num_col_idx'] 31 | cat_col_idx = info['cat_col_idx'] 32 | target_col_idx = info['target_col_idx'] 33 | 34 | 35 | idx_mapping = info['idx_mapping'] 36 | idx_mapping = {int(key): value for key, value in idx_mapping.items()} 37 | 38 | syn_df = pd.DataFrame() 39 | 40 | if info['task_type'] == 'regression': 41 | for i in range(len(num_col_idx) + len(cat_col_idx) + len(target_col_idx)): 42 | if i in set(num_col_idx): 43 | syn_df[i] = syn_num[:, idx_mapping[i]] 44 | elif i in set(cat_col_idx): 45 | syn_df[i] = syn_cat[:, idx_mapping[i] - len(num_col_idx)] 46 | else: 47 | syn_df[i] = syn_target[:, idx_mapping[i] - len(num_col_idx) - len(cat_col_idx)] 48 | 49 | 50 | else: 51 | for i in range(len(num_col_idx) + len(cat_col_idx) + len(target_col_idx)): 52 | if i in set(num_col_idx): 53 | syn_df[i] = syn_num[:, idx_mapping[i]] 54 | elif i in set(cat_col_idx): 55 | syn_df[i] = syn_cat[:, idx_mapping[i] - len(num_col_idx)] 56 | else: 57 | syn_df[i] = syn_target[:, idx_mapping[i] - len(num_col_idx) - len(cat_col_idx)] 58 | 59 | return syn_df 60 | 61 | 62 | def main(args): 63 | dataname = args.dataname 64 | device = args.device 65 | save_path = args.save_path 66 | 67 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 68 | dataset_dir = f'data/{dataname}' 69 | ckpt_dir = f'{curr_dir}/ckpt/{dataname}' 70 | 71 | if not os.path.exists(ckpt_dir): 72 | os.makedirs(ckpt_dir) 73 | 74 | with open(f'{dataset_dir}/info.json', 'r') as f: 75 | info = json.load(f) 76 | 77 | task_type = info['task_type'] 78 | 79 | dataset = preprocess(dataset_dir, task_type = task_type, cat_encoding = 'one-hot') 80 | X_train = torch.tensor(dataset.X_num['train']) 81 | 82 | num_inverse = 
dataset.num_transform.inverse_transform 83 | cat_inverse = dataset.cat_transform.inverse_transform 84 | 85 | 86 | gen = GoggleModel( 87 | ds_name=dataname, 88 | input_dim=X_train.shape[1], 89 | encoder_dim=2048, 90 | encoder_l=4, 91 | het_encoding=True, 92 | decoder_dim=2048, 93 | decoder_l=4, 94 | threshold=0.1, 95 | decoder_arch="gcn", 96 | graph_prior=None, 97 | prior_mask=None, 98 | device=device, 99 | beta=1, 100 | learning_rate=0.01, 101 | seed=42, 102 | ) 103 | 104 | gen.model.load_state_dict(torch.load(f'{ckpt_dir}/model.pt')) 105 | 106 | start_time = time.time() 107 | 108 | samples = gen.sample(X_train) 109 | 110 | task_type = info['task_type'] 111 | num_col_idx = info['num_col_idx'] 112 | cat_col_idx = info['cat_col_idx'] 113 | target_col_idx = info['target_col_idx'] 114 | 115 | n_num_feat = len(num_col_idx) 116 | n_cat_feat = len(cat_col_idx) 117 | 118 | if task_type == 'regression': 119 | n_num_feat += len(target_col_idx) 120 | else: 121 | n_cat_feat += len(target_col_idx) 122 | 123 | syn_data_num = samples[:, :n_num_feat] 124 | cat_sample = samples[:, n_num_feat:] 125 | 126 | syn_num = num_inverse(syn_data_num) 127 | syn_cat = cat_inverse(cat_sample) 128 | 129 | syn_df = recover_data(syn_num, syn_cat, info) 130 | 131 | idx_name_mapping = info['idx_name_mapping'] 132 | idx_name_mapping = {int(key): value for key, value in idx_name_mapping.items()} 133 | 134 | syn_df.rename(columns = idx_name_mapping, inplace=True) 135 | syn_df.to_csv(save_path, index = False) 136 | 137 | end_time = time.time() 138 | print(f'Sampling time = {end_time - start_time}') 139 | print('Saving sampled data to {}'.format(save_path)) 140 | 141 | 142 | if __name__ == '__main__': 143 | parser = argparse.ArgumentParser(description='Training ') 144 | 145 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 146 | parser.add_argument('--gpu', type=int, default=0, help='GPU index.') 147 | 148 | args = parser.parse_args() 149 | 150 | # check cuda 151 | if args.gpu != -1 and torch.cuda.is_available(): 152 | args.device = f'cuda:{args.gpu}' 153 | else: 154 | args.device = 'cpu' -------------------------------------------------------------------------------- /baselines/great/main.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import os 4 | import argparse 5 | 6 | from baselines.great.models.great import GReaT 7 | 8 | def main(args): 9 | 10 | dataname = args.dataname 11 | batch_size = args.bs 12 | dataset_path = f'data/{dataname}/train.csv' 13 | train_df = pd.read_csv(dataset_path) 14 | 15 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 16 | ckpt_dir = f'{curr_dir}/ckpt/{dataname}' 17 | 18 | if not os.path.exists(ckpt_dir): 19 | os.makedirs(ckpt_dir) 20 | 21 | great = GReaT("distilgpt2", 22 | epochs=100, 23 | save_steps=2000, 24 | logging_steps=50, 25 | experiment_dir=f"{curr_dir}/ckpt/{dataname}", 26 | batch_size=batch_size, 27 | #lr_scheduler_type="constant", # Specify the learning rate scheduler 28 | #learning_rate=5e-5 # Set the inital learning rate 29 | ) 30 | 31 | trainer = great.fit(train_df) 32 | great.save(ckpt_dir) 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser(description='GReaT') 37 | 38 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 39 | parser.add_argument('--bs', type=int, default=16, help='(Maximum) batch size') 40 | args = parser.parse_args() 
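Aside: the `__main__` block of baselines/great/main.py above parses arguments but never calls `main(args)`, presumably because the baseline entry points are driven from the repository-level main.py. For a quick standalone smoke test of the same GReaT API, a minimal sketch (placeholder paths and a reduced epoch count; not the project's official workflow) could look like this:

import pandas as pd
from baselines.great.models.great import GReaT

# Placeholder dataset path; any CSV with the expected columns would do.
train_df = pd.read_csv('data/adult/train.csv')

great = GReaT('distilgpt2',
              epochs=1,                 # reduced from 100 for a quick smoke test
              save_steps=2000,
              logging_steps=50,
              experiment_dir='ckpt/adult_demo',
              batch_size=16)

great.fit(train_df)                     # fine-tunes distilgpt2 on textified table rows
great.save('ckpt/adult_demo')           # writes the checkpoint, as main() does above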
-------------------------------------------------------------------------------- /baselines/great/models/__pycache__/great.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/great/models/__pycache__/great.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/great/models/__pycache__/great_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/great/models/__pycache__/great_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/great/models/__pycache__/great_start.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/great/models/__pycache__/great_start.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/great/models/__pycache__/great_trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/great/models/__pycache__/great_trainer.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/great/models/__pycache__/great_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/great/models/__pycache__/great_utils.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/great/models/great_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing as tp 3 | 4 | from datasets import Dataset 5 | from dataclasses import dataclass 6 | from transformers import DataCollatorWithPadding 7 | 8 | 9 | class GReaTDataset(Dataset): 10 | """ GReaT Dataset 11 | 12 | The GReaTDataset overwrites the _getitem function of the HuggingFace Dataset Class to include the permutation step. 13 | 14 | Attributes: 15 | tokenizer (AutoTokenizer): Tokenizer from HuggingFace 16 | """ 17 | def set_tokenizer(self, tokenizer): 18 | """ Set the Tokenizer 19 | 20 | Args: 21 | tokenizer: Tokenizer from HuggingFace 22 | """ 23 | self.tokenizer = tokenizer 24 | 25 | def _getitem(self, key: tp.Union[int, slice, str], decoded: bool = True, **kwargs) -> tp.Union[tp.Dict, tp.List]: 26 | """ Get Item from Tabular Data 27 | 28 | Get one instance of the tabular data, permuted, converted to text and tokenized. 29 | """ 30 | # If int, what else? 
31 | row = self._data.fast_slice(key, 1) 32 | 33 | shuffle_idx = list(range(row.num_columns)) 34 | random.shuffle(shuffle_idx) 35 | 36 | shuffled_text = ", ".join( 37 | ["%s is %s" % (row.column_names[i], str(row.columns[i].to_pylist()[0]).strip()) for i in shuffle_idx] 38 | ) 39 | 40 | tokenized_text = self.tokenizer(shuffled_text) 41 | return tokenized_text 42 | 43 | def __getitems__(self, keys: tp.Union[int, slice, str, list]): 44 | if isinstance(keys, list): 45 | return [self._getitem(key) for key in keys] 46 | else: 47 | return self._getitem(keys) 48 | 49 | @dataclass 50 | class GReaTDataCollator(DataCollatorWithPadding): 51 | """ GReaT Data Collator 52 | 53 | Overwrites the DataCollatorWithPadding to also pad the labels and not only the input_ids 54 | """ 55 | def __call__(self, features: tp.List[tp.Dict[str, tp.Any]]): 56 | batch = self.tokenizer.pad( 57 | features, 58 | padding=self.padding, 59 | max_length=self.max_length, 60 | pad_to_multiple_of=self.pad_to_multiple_of, 61 | return_tensors=self.return_tensors, 62 | ) 63 | batch["labels"] = batch["input_ids"].clone() 64 | return batch -------------------------------------------------------------------------------- /baselines/great/models/great_start.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import typing as tp 4 | 5 | 6 | def _pad(x, length: int, pad_value=50256): 7 | """ 8 | Prepend the pad value until the array reaches the specific length 9 | """ 10 | return [pad_value] * (length - len(x)) + x 11 | 12 | 13 | # 14 | def _pad_tokens(tokens): 15 | """ 16 | Checks that all tensors in the list have the same length, pads them if necessary to the max length 17 | 18 | Args: 19 | tokens: List of Tensors 20 | 21 | Returns: 22 | List of Tensors, where each Tensor has the same length 23 | """ 24 | max_length = len(max(tokens, key=len)) 25 | tokens = [_pad(t, max_length) for t in tokens] 26 | return tokens 27 | 28 | 29 | class GReaTStart: 30 | """ Abstract super class GReaT Start 31 | 32 | GReaT Start creates tokens to start the generation process. 33 | 34 | Attributes: 35 | tokenizer (AutoTokenizer): Tokenizer, automatically downloaded from llm-checkpoint 36 | """ 37 | def __init__(self, tokenizer): 38 | """ 39 | Initializes the super class. 40 | 41 | Args: 42 | tokenizer: Tokenizer from the HuggingFace library 43 | """ 44 | self.tokenizer = tokenizer 45 | 46 | def get_start_tokens(self, n_samples: int) -> tp.List[tp.List[int]]: 47 | """ Get Start Tokens 48 | 49 | Creates starting points for the generation process 50 | 51 | Args: 52 | n_samples: Number of start prompts to create 53 | 54 | Returns: 55 | List of n_sample lists with tokens 56 | """ 57 | raise NotImplementedError("This has to be overwritten but the subclasses") 58 | 59 | 60 | class CategoricalStart(GReaTStart): 61 | """ Categorical Starting Feature 62 | 63 | A categorical column with its categories is used as starting point. 
64 | 65 | Attributes: 66 | start_col (str): Name of the categorical column 67 | population (list[str]): Possible values the column can take 68 | weights (list[float]): Probabilities for the individual categories 69 | 70 | """ 71 | def __init__(self, tokenizer, start_col: str, start_col_dist: dict): 72 | """ Initializes the Categorical Start 73 | 74 | Args: 75 | tokenizer: Tokenizer from the HuggingFace library 76 | start_col: Name of the categorical column 77 | start_col_dist: Distribution of the categorical column (dict of form {"Cat A": 0.8, "Cat B": 0.2}) 78 | """ 79 | super().__init__(tokenizer) 80 | 81 | assert isinstance(start_col, str), "" 82 | assert isinstance(start_col_dist, dict), "" 83 | 84 | self.start_col = start_col 85 | self.population = list(start_col_dist.keys()) 86 | self.weights = list(start_col_dist.values()) 87 | 88 | def get_start_tokens(self, n_samples): 89 | start_words = random.choices(self.population, self.weights, k=n_samples) 90 | start_text = [self.start_col + " is " + str(s) + "," for s in start_words] 91 | start_tokens = _pad_tokens(self.tokenizer(start_text)["input_ids"]) 92 | return start_tokens 93 | 94 | 95 | class ContinuousStart(GReaTStart): 96 | """ Continuous Starting Feature 97 | 98 | A continuous column with some noise is used as starting point. 99 | 100 | Attributes: 101 | start_col (str): Name of the continuous column 102 | start_col_dist (list[float]): The continuous column from the train data set 103 | noise (float): Size of noise that is added to each value 104 | decimal_places (int): Number of decimal places the continuous values have 105 | """ 106 | def __init__(self, tokenizer, start_col: str, start_col_dist: tp.List[float], 107 | noise: float = .01, decimal_places: int = 5): 108 | """ Initializes the Continuous Start 109 | 110 | Args: 111 | tokenizer: Tokenizer from the HuggingFace library 112 | start_col: Name of the continuous column 113 | start_col_dist: The continuous column from the train data set 114 | noise: Size of noise that is added to each value 115 | decimal_places: Number of decimal places the continuous values have 116 | """ 117 | super().__init__(tokenizer) 118 | 119 | assert isinstance(start_col, str), "" 120 | assert isinstance(start_col_dist, list), "" 121 | 122 | self.start_col = start_col 123 | self.start_col_dist = start_col_dist 124 | self.noise = noise 125 | self.decimal_places = decimal_places 126 | 127 | def get_start_tokens(self, n_samples): 128 | start_words = random.choices(self.start_col_dist, k=n_samples) 129 | # start_words += np.random.normal(size=n_samples) * self.noise # add noise to start words 130 | start_text = [self.start_col + " is " + format(s, f".{self.decimal_places}f") + "," for s in start_words] 131 | start_tokens = _pad_tokens(self.tokenizer(start_text)["input_ids"]) 132 | return start_tokens 133 | 134 | 135 | class RandomStart(GReaTStart): 136 | """ Random Starting Features 137 | 138 | Random column names are used as start point. Can be used if no distribution of any column is known. 
139 | 140 | Attributes: 141 | all_columns (List[str]): Names of all columns 142 | """ 143 | def __init__(self, tokenizer, all_columns: tp.List[str]): 144 | """ Initializes the Random Start 145 | 146 | Args: 147 | tokenizer: Tokenizer from the HuggingFace library 148 | all_columns: Names of all columns 149 | """ 150 | super().__init__(tokenizer) 151 | self.all_columns = all_columns 152 | 153 | def get_start_tokens(self, n_samples): 154 | start_words = random.choices(self.all_columns, k=n_samples) 155 | start_text = [s + " is " for s in start_words] 156 | start_tokens = _pad_tokens(self.tokenizer(start_text)["input_ids"]) 157 | return start_tokens -------------------------------------------------------------------------------- /baselines/great/models/great_trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from transformers import Trainer 8 | 9 | 10 | def _seed_worker(_): 11 | """ 12 | Helper function to set worker seed during Dataloader initialization. 13 | """ 14 | worker_seed = torch.initial_seed() % 2**32 15 | random.seed(worker_seed) 16 | np.random.seed(worker_seed) 17 | torch.manual_seed(worker_seed) 18 | torch.cuda.manual_seed_all(worker_seed) 19 | 20 | 21 | class GReaTTrainer(Trainer): 22 | """ GReaT Trainer 23 | 24 | Overwrites the get_train_dataloader methode of the HuggingFace Trainer to not remove the "unused" columns - 25 | they are needed later! 26 | """ 27 | def get_train_dataloader(self) -> DataLoader: 28 | if self.train_dataset is None: 29 | raise ValueError("Trainer: training requires a train_dataset.") 30 | 31 | data_collator = self.data_collator 32 | train_dataset = self.train_dataset # self._remove_unused_columns(self.train_dataset, description="training") 33 | train_sampler = self._get_train_sampler() 34 | 35 | return DataLoader( 36 | train_dataset, 37 | batch_size=self._train_batch_size, 38 | sampler=train_sampler, 39 | collate_fn=data_collator, 40 | drop_last=self.args.dataloader_drop_last, 41 | num_workers=self.args.dataloader_num_workers, 42 | pin_memory=self.args.dataloader_pin_memory, 43 | worker_init_fn=_seed_worker, 44 | ) -------------------------------------------------------------------------------- /baselines/great/models/great_utils.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | 7 | from transformers import AutoTokenizer 8 | 9 | 10 | def _array_to_dataframe(data: tp.Union[pd.DataFrame, np.ndarray], columns=None) -> pd.DataFrame: 11 | """ Converts a Numpy Array to a Pandas DataFrame 12 | 13 | Args: 14 | data: Pandas DataFrame or Numpy NDArray 15 | columns: If data is a Numpy Array, columns needs to be a list of all column names 16 | 17 | Returns: 18 | Pandas DataFrame with the given data 19 | """ 20 | if isinstance(data, pd.DataFrame): 21 | return data 22 | 23 | assert isinstance(data, np.ndarray), "Input needs to be a Pandas DataFrame or a Numpy NDArray" 24 | assert columns, "To convert the data into a Pandas DataFrame, a list of column names has to be given!" 25 | assert len(columns) == len(data[0]), \ 26 | "%d column names are given, but array has %d columns!" 
% (len(columns), len(data[0])) 27 | 28 | return pd.DataFrame(data=data, columns=columns) 29 | 30 | 31 | def _get_column_distribution(df: pd.DataFrame, col: str) -> tp.Union[list, dict]: 32 | """ Returns the distribution of a given column. If continuous, returns a list of all values. 33 | If categorical, returns a dictionary in form {"A": 0.6, "B": 0.4} 34 | 35 | Args: 36 | df: pandas DataFrame 37 | col: name of the column 38 | 39 | Returns: 40 | Distribution of the column 41 | """ 42 | if df[col].dtype == "float": 43 | col_dist = df[col].to_list() 44 | else: 45 | col_dist = df[col].value_counts(1).to_dict() 46 | return col_dist 47 | 48 | 49 | def _convert_tokens_to_text(tokens: tp.List[torch.Tensor], tokenizer: AutoTokenizer) -> tp.List[str]: 50 | """ Decodes the tokens back to strings 51 | 52 | Args: 53 | tokens: List of tokens to decode 54 | tokenizer: Tokenizer used for decoding 55 | 56 | Returns: 57 | List of decoded strings 58 | """ 59 | # Convert tokens to text 60 | text_data = [tokenizer.decode(t) for t in tokens] 61 | 62 | # Clean text 63 | text_data = [d.replace("<|endoftext|>", "") for d in text_data] 64 | text_data = [d.replace("\n", " ") for d in text_data] 65 | text_data = [d.replace("\r", "") for d in text_data] 66 | 67 | return text_data 68 | 69 | 70 | def _convert_text_to_tabular_data(text: tp.List[str], df_gen: pd.DataFrame) -> pd.DataFrame: 71 | """ Converts the sentences back to tabular data 72 | 73 | Args: 74 | text: List of the tabular data in text form 75 | df_gen: Pandas DataFrame where the tabular data is appended 76 | 77 | Returns: 78 | Pandas DataFrame with the tabular data from the text appended 79 | """ 80 | columns = df_gen.columns.to_list() 81 | 82 | # Convert text to tabular data 83 | for t in text: 84 | features = t.split(",") 85 | td = dict.fromkeys(columns) 86 | 87 | # Transform all features back to tabular data 88 | for f in features: 89 | values = f.strip().split(" is ") 90 | if values[0] in columns and not td[values[0]]: 91 | try: 92 | td[values[0]] = [values[1]] 93 | except IndexError: 94 | #print("An Index Error occurred - if this happends a lot, consider fine-tuning your model further.") 95 | pass 96 | df_gen = pd.concat([df_gen, pd.DataFrame(td)], ignore_index=True, axis=0) 97 | return df_gen 98 | 99 | class bcolors: 100 | HEADER = '\033[95m' 101 | OKBLUE = '\033[94m' 102 | OKGREEN = '\033[92m' 103 | WARNING = '\033[93m' 104 | FAIL = '\033[91m' 105 | ENDC = '\033[0m' -------------------------------------------------------------------------------- /baselines/great/post_process.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import json 4 | 5 | def add_space_before_string(s): 6 | for _ in range(len(s)): 7 | s = s.strip(' ') 8 | 9 | return ' ' + s 10 | 11 | def post_process_adult(syn_path): 12 | dataname = 'adult' 13 | 14 | syn_path = f'synthetic/{dataname}/great_{i}.csv' 15 | 16 | data_dir = f'data/{dataname}' 17 | info_path = f'{data_dir}/info.json' 18 | 19 | with open(info_path, 'r') as f: 20 | info = json.load(f) 21 | 22 | cat_col_idx = info['cat_col_idx'] 23 | 24 | syn_data = pd.read_csv(syn_path) 25 | columns = syn_data.columns 26 | 27 | for i, name in enumerate(columns): 28 | if i in cat_col_idx: 29 | syn_data[name] = syn_data[name].apply(add_space_before_string) 30 | 31 | syn_data.to_csv(syn_path, index=False) 32 | 33 | 34 | -------------------------------------------------------------------------------- /baselines/great/sample.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | 4 | import os 5 | 6 | import argparse 7 | import json 8 | 9 | from baselines.great.models.great import GReaT 10 | from baselines.great.models.great_utils import _array_to_dataframe 11 | 12 | 13 | def main(args): 14 | 15 | dataname = args.dataname 16 | 17 | dataset_path = f'data/{dataname}/train.csv' 18 | info_path = f'data/{dataname}/info.json' 19 | 20 | with open(info_path, 'r') as f: 21 | info = json.load(f) 22 | train_df = pd.read_csv(dataset_path) 23 | 24 | 25 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 26 | 27 | great = GReaT("distilgpt2", 28 | epochs=200, 29 | save_steps=2000, 30 | logging_steps=50, 31 | experiment_dir="ckpt/adult", 32 | batch_size=24, 33 | #lr_scheduler_type="constant", # Specify the learning rate scheduler 34 | #learning_rate=5e-5 # Set the inital learning rate 35 | ) 36 | 37 | model_save_path = f'{curr_dir}/ckpt/{dataname}/model.pt' 38 | great.model.load_state_dict(torch.load(model_save_path)) 39 | 40 | great.load_finetuned_model(f"{curr_dir}/ckpt/{dataname}/model.pt") 41 | 42 | df = _array_to_dataframe(train_df, columns=None) 43 | great._update_column_information(df) 44 | great._update_conditional_information(df, conditional_col=None) 45 | 46 | 47 | n_samples = info['train_num'] 48 | 49 | samples = great.sample(n_samples, k=100, device=args.device) 50 | samples.head() 51 | save_path = args.save_path 52 | samples.to_csv(save_path, index = False) 53 | 54 | 55 | print('Saving sampled data to {}'.format(save_path)) 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser(description='GReaT') 59 | 60 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 61 | parser.add_argument('--bs', type=int, default=16, help='(Maximum) batch size') 62 | args = parser.parse_args() -------------------------------------------------------------------------------- /baselines/great/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class CustomFormatter(logging.Formatter): 5 | 6 | grey = "\x1b[39;20m" 7 | yellow = "\x1b[33;20m" 8 | red = "\x1b[31;20m" 9 | bold_red = "\x1b[31;1m" 10 | reset = "\x1b[0m" 11 | format = "%(asctime)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" 12 | 13 | FORMATS = { 14 | logging.DEBUG: grey + format + reset, 15 | logging.INFO: grey + format + reset, 16 | logging.WARNING: yellow + format + reset, 17 | logging.ERROR: red + format + reset, 18 | logging.CRITICAL: bold_red + format + reset 19 | } 20 | 21 | def format(self, record): 22 | log_fmt = self.FORMATS.get(record.levelno) 23 | formatter = logging.Formatter(log_fmt) 24 | return formatter.format(record) 25 | 26 | 27 | def set_logging_level(level=logging.INFO): 28 | logger = logging.getLogger() 29 | logger.setLevel(level) 30 | 31 | ch = logging.StreamHandler() 32 | ch.setLevel(level) 33 | ch.setFormatter(CustomFormatter()) 34 | 35 | logger.addHandler(ch) 36 | 37 | return logger -------------------------------------------------------------------------------- /baselines/stasy/configs/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/configs/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- 
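Stepping back to baselines/great/utils.py above: the module only defines the colored formatter and `set_logging_level`, so a hypothetical usage sketch (the log messages below are made up for illustration) would be:

import logging
from baselines.great.utils import set_logging_level

logger = set_logging_level(logging.INFO)                    # attaches a handler using CustomFormatter
logger.info('loaded fine-tuned GReaT checkpoint')           # rendered with the grey format
logger.warning('CUDA not available, falling back to CPU')   # rendered with the yellow format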
/baselines/stasy/configs/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/configs/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /baselines/stasy/configs/__pycache__/default_tabular_configs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/configs/__pycache__/default_tabular_configs.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/stasy/configs/__pycache__/default_tabular_configs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/configs/__pycache__/default_tabular_configs.cpython-39.pyc -------------------------------------------------------------------------------- /baselines/stasy/configs/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | 18 | from baselines.stasy.configs.default_tabular_configs import get_default_configs 19 | 20 | 21 | def get_config(name): 22 | config = get_default_configs() 23 | 24 | config.data.dataset = name 25 | config.training.batch_size = 1000 26 | config.eval.batch_size = 1000 27 | config.data.image_size = 77 28 | 29 | # training 30 | training = config.training 31 | training.sde = 'vesde' 32 | training.continuous = True 33 | training.reduce_mean = True 34 | training.n_iters = 100000 35 | training.tolerance = 1e-03 36 | training.hutchinson_type = "Rademacher" 37 | training.retrain_type = "median" 38 | 39 | # sampling 40 | sampling = config.sampling 41 | sampling.method = 'ode' 42 | sampling.predictor = 'euler_maruyama' 43 | sampling.corrector = 'none' 44 | 45 | # model 46 | model = config.model 47 | model.layer_type = 'concatsquash' 48 | model.name = 'ncsnpp_tabular' 49 | model.scale_by_sigma = False 50 | model.ema_rate = 0.9999 51 | model.activation = 'elu' 52 | 53 | model.nf = 64 54 | model.hidden_dims = (1024, 2048, 1024, 1024) 55 | # model.hidden_dims = (256, 512, 1024, 1024, 512, 256) 56 | model.conditional = True 57 | model.embedding_type = 'fourier' 58 | model.fourier_scale = 16 59 | model.conv_size = 3 60 | 61 | model.sigma_min = 0.01 62 | model.sigma_max = 10. 
63 | 64 | # test 65 | test = config.test 66 | test.n_iter = 1 67 | 68 | # optim 69 | optim = config.optim 70 | optim.lr = 2e-3 71 | 72 | 73 | return config -------------------------------------------------------------------------------- /baselines/stasy/configs/default_tabular_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | 4 | 5 | def get_default_configs(): 6 | config = ml_collections.ConfigDict() 7 | 8 | config.seed = 42 9 | config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 10 | config.baseline = False 11 | 12 | # training 13 | config.training = training = ml_collections.ConfigDict() 14 | config.training.batch_size = 1000 15 | training.epoch = 10000 16 | training.snapshot_freq = 300 17 | training.eval_freq = 100 18 | training.snapshot_freq_for_preemption = 100 19 | training.snapshot_sampling = True 20 | training.likelihood_weighting = False 21 | training.continuous = True 22 | training.reduce_mean = False 23 | training.eps = 1e-05 24 | training.loss_weighting = False 25 | training.spl = True 26 | training.lambda_ = 0.5 27 | 28 | #fine_tune 29 | training.eps_iters = 50 30 | training.fine_tune_epochs = 50 31 | training.retrain_type = 'median' 32 | training.hutchinson_type = 'Rademacher' 33 | training.tolerance = 1e-03 34 | 35 | # sampling 36 | config.sampling = sampling = ml_collections.ConfigDict() 37 | sampling.n_steps_each = 1 38 | sampling.noise_removal = True 39 | sampling.probability_flow = False 40 | sampling.snr = 0.16 41 | 42 | # evaluation 43 | config.eval = evaluate = ml_collections.ConfigDict() 44 | evaluate.num_samples = 22560 45 | 46 | # data 47 | config.data = data = ml_collections.ConfigDict() 48 | data.centered = False 49 | data.uniform_dequantization = False 50 | 51 | # model 52 | config.model = model = ml_collections.ConfigDict() 53 | model.sigma_min = 0.01 54 | model.sigma_max = 10. 55 | model.num_scales = 50 56 | model.alpha0 = 0.3 57 | model.beta0 = 0.95 58 | 59 | # optimization 60 | config.optim = optim = ml_collections.ConfigDict() 61 | optim.weight_decay = 0 62 | optim.optimizer = 'Adam' 63 | optim.lr = 2e-3 64 | optim.beta1 = 0.9 65 | optim.eps = 1e-8 66 | optim.warmup = 5000 67 | optim.grad_clip = 1. 68 | 69 | # test 70 | config.test = test = ml_collections.ConfigDict() 71 | 72 | return config -------------------------------------------------------------------------------- /baselines/stasy/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Return training and evaluation/test datasets from config files.""" 18 | import torch 19 | import numpy as np 20 | 21 | def get_data_scaler(config): 22 | """Data normalizer. 
Assume data are always in [0, 1].""" 23 | if config.data.centered: 24 | return lambda x: x * 2. - 1. 25 | else: 26 | return lambda x: x 27 | 28 | 29 | def get_data_inverse_scaler(config): 30 | """Inverse data normalizer.""" 31 | if config.data.centered: 32 | return lambda x: (x + 1.) / 2. 33 | else: 34 | return lambda x: x 35 | 36 | -------------------------------------------------------------------------------- /baselines/stasy/likelihood.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | # pytype: skip-file 18 | """Various sampling methods.""" 19 | 20 | import torch 21 | import numpy as np 22 | from scipy import integrate 23 | from models import utils as mutils 24 | 25 | 26 | def get_div_fn(fn): 27 | """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.""" 28 | 29 | def div_fn(x, t, eps): 30 | 31 | grad_fn_eps_list = [] 32 | for epsilon in eps: 33 | with torch.enable_grad(): 34 | x.requires_grad_(True) 35 | fn_eps = torch.sum(fn(x, t) * epsilon) 36 | grad_fn_eps = torch.autograd.grad(fn_eps, x)[0] 37 | 38 | x.requires_grad_(False) 39 | grad_fn_eps_list.append(torch.sum(grad_fn_eps * epsilon, dim=tuple(range(1, len(x.shape))))) 40 | 41 | return torch.mean(torch.stack(grad_fn_eps_list), 0) 42 | 43 | 44 | return div_fn 45 | 46 | 47 | def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher', 48 | rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5): 49 | 50 | def drift_fn(model, x, t): 51 | """The drift function of the reverse-time SDE.""" 52 | score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True) 53 | # Probability flow ODE is a special case of Reverse SDE 54 | rsde = sde.reverse(score_fn, probability_flow=True) 55 | return rsde.sde(x, t)[0] 56 | 57 | def div_fn(model, x, t, noise): 58 | return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise) 59 | 60 | def likelihood_fn(model, data, eps_iters=1): 61 | with torch.no_grad(): 62 | shape = data.shape 63 | if hutchinson_type == 'Gaussian': 64 | epsilon = [torch.randn_like(data) for k in range(eps_iters)] 65 | elif hutchinson_type == 'Rademacher': 66 | epsilon = [torch.randint_like(data, low=0, high=2).float() * 2 - 1. 
for k in range(eps_iters)] 67 | else: 68 | raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.") 69 | 70 | def ode_func(t, x): 71 | sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32) 72 | vec_t = torch.ones(sample.shape[0], device=sample.device) * t 73 | drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t)) 74 | logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon)) 75 | return np.concatenate([drift, logp_grad], axis=0) 76 | 77 | init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0) 78 | solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method) 79 | 80 | nfe = solution.nfev 81 | zp = solution.y[:, -1] 82 | 83 | z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float64) 84 | delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float64) 85 | prior_logp = sde.prior_logp(z).view(shape[0], -1).sum(1, keepdim=False) 86 | 87 | ll = prior_logp + delta_logp 88 | return ll, z, nfe 89 | 90 | return likelihood_fn -------------------------------------------------------------------------------- /baselines/stasy/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | import baselines.stasy.datasets as datasets 7 | from baselines.stasy.utils import save_checkpoint, restore_checkpoint, apply_activate 8 | import baselines.stasy.losses as losses 9 | from baselines.stasy.models import ncsnpp_tabular 10 | from baselines.stasy.models import utils as mutils 11 | from baselines.stasy.models.ema import ExponentialMovingAverage 12 | import baselines.stasy.sde_lib as sde_lib 13 | from baselines.stasy.configs.config import get_config 14 | 15 | import os 16 | import json 17 | import argparse 18 | import warnings 19 | import time 20 | 21 | from utils_train import preprocess 22 | 23 | warnings.filterwarnings("ignore") 24 | 25 | 26 | def main(args): 27 | dataname = args.dataname 28 | 29 | config = get_config(dataname) 30 | 31 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 32 | ckpt_dir = f'{curr_dir}/ckpt/{dataname}' 33 | if not os.path.exists(ckpt_dir): 34 | os.makedirs(ckpt_dir) 35 | 36 | dataset_dir = f'data/{dataname}' 37 | 38 | with open(f'{dataset_dir}/info.json', 'r') as f: 39 | info = json.load(f) 40 | 41 | task_type = info['task_type'] 42 | 43 | dataset = preprocess(dataset_dir, task_type = task_type, cat_encoding = 'one-hot') 44 | train_z = torch.tensor(dataset.X_num['train']) 45 | 46 | config.data.image_size = train_z.shape[1] 47 | print(config.data.image_size) 48 | # Initialize model. 
49 | config.device = torch.device(f'cuda:{args.gpu}') 50 | score_model = mutils.create_model(config) 51 | print(score_model) 52 | num_params = sum(p.numel() for p in score_model.parameters()) 53 | print("the number of parameters", num_params) 54 | 55 | 56 | ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) 57 | 58 | # optimizer 59 | optimizer = losses.get_optimizer(config, score_model.parameters()) 60 | state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, epoch=0) 61 | 62 | initial_step = int(state['epoch']) 63 | 64 | batch_size = config.training.batch_size 65 | 66 | shuffle_buffer_size = 10000 67 | num_epochs = None 68 | 69 | 70 | train_data = train_z 71 | train_iter = DataLoader(train_data, 72 | batch_size=config.training.batch_size, 73 | shuffle=True, 74 | num_workers=4) 75 | 76 | 77 | scaler = datasets.get_data_scaler(config) 78 | inverse_scaler = datasets.get_data_inverse_scaler(config) 79 | 80 | # Setup SDEs 81 | if config.training.sde.lower() == 'vpsde': 82 | sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) 83 | sampling_eps = 1e-3 84 | elif config.training.sde.lower() == 'subvpsde': 85 | sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) 86 | sampling_eps = 1e-3 87 | elif config.training.sde.lower() == 'vesde': 88 | sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) 89 | sampling_eps = 1e-5 90 | else: 91 | raise NotImplementedError(f"SDE {config.training.sde} unknown.") 92 | logging.info(score_model) 93 | 94 | 95 | optimize_fn = losses.optimization_manager(config) 96 | continuous = config.training.continuous 97 | reduce_mean = config.training.reduce_mean 98 | likelihood_weighting = config.training.likelihood_weighting 99 | 100 | train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn, 101 | reduce_mean=reduce_mean, continuous=continuous, 102 | likelihood_weighting=likelihood_weighting, workdir=ckpt_dir, spl=config.training.spl, 103 | alpha0=config.model.alpha0, beta0=config.model.beta0) 104 | 105 | best_loss = np.inf 106 | 107 | 108 | for epoch in range(initial_step, config.training.epoch+1): 109 | start_time = time.time() 110 | state['epoch'] += 1 111 | 112 | batch_loss = 0 113 | batch_num = 0 114 | for iteration, batch in enumerate(train_iter): 115 | batch = batch.to(config.device).float() 116 | 117 | num_sample = batch.shape[0] 118 | batch_num += num_sample 119 | loss = train_step_fn(state, batch) 120 | 121 | batch_loss += loss.item() * num_sample 122 | 123 | batch_loss = batch_loss / batch_num 124 | print("epoch: %d, iter: %d, training_loss: %.5e" % (epoch, iteration, batch_loss)) 125 | 126 | if batch_loss < best_loss: 127 | best_loss = batch_loss 128 | save_checkpoint(os.path.join(ckpt_dir, 'model.pth'), state) 129 | 130 | if epoch % 1000 == 0: 131 | save_checkpoint(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth'), state) 132 | 133 | end_time = time.time() 134 | # print("training time: %.5f" % (end_time - start_time)) 135 | 136 | if __name__ == '__main__': 137 | 138 | parser = argparse.ArgumentParser(description='STASY') 139 | 140 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 141 | parser.add_argument('--gpu', type=int, default=0, help='GPU device number.') 142 | 143 | args = parser.parse_args() -------------------------------------------------------------------------------- 
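A small standalone sketch of the noise-level grid this STaSy configuration implies (sigma_min=0.01, sigma_max=10.0, num_scales=50, taken from the config files above); it mirrors the `get_sigmas` helper defined later in baselines/stasy/models/utils.py and is meant only to make the geometric spacing concrete:

import numpy as np

sigma_min, sigma_max, num_scales = 0.01, 10.0, 50   # values from the STaSy config above
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))

print(sigmas[0], sigmas[-1])   # 10.0 ... 0.01, geometrically spaced VESDE noise levels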
/baselines/stasy/models/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/layers.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/layers.cpython-39.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/layerspp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/layerspp.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/layerspp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/layerspp.cpython-39.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/ncsnpp_tabular.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/ncsnpp_tabular.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/ncsnpp_tabular.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/ncsnpp_tabular.cpython-39.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/stasy/models/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /baselines/stasy/models/ema.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import unicode_literals 3 | 4 | import torch 5 | 6 | 7 | class ExponentialMovingAverage: 8 | """ 9 | Maintains (exponential) moving average of a set of parameters. 10 | """ 11 | 12 | def __init__(self, parameters, decay, use_num_updates=True): 13 | """ 14 | Args: 15 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 16 | `model.parameters()`. 17 | decay: The exponential decay. 18 | use_num_updates: Whether to use number of updates when computing 19 | averages. 20 | """ 21 | if decay < 0.0 or decay > 1.0: 22 | raise ValueError('Decay must be between 0 and 1') 23 | self.decay = decay 24 | self.num_updates = 0 if use_num_updates else None 25 | self.shadow_params = [p.clone().detach() 26 | for p in parameters if p.requires_grad] 27 | self.collected_params = [] 28 | 29 | def update(self, parameters): 30 | """ 31 | Update currently maintained parameters. 32 | 33 | Call this every time the parameters are updated, such as the result of 34 | the `optimizer.step()` call. 35 | 36 | Args: 37 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 38 | parameters used to initialize this object. 39 | """ 40 | decay = self.decay 41 | if self.num_updates is not None: 42 | self.num_updates += 1 43 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 44 | one_minus_decay = 1.0 - decay 45 | with torch.no_grad(): 46 | parameters = [p for p in parameters if p.requires_grad] 47 | for s_param, param in zip(self.shadow_params, parameters): 48 | s_param.sub_(one_minus_decay * (s_param - param)) 49 | 50 | def copy_to(self, parameters): 51 | """ 52 | Copy current parameters into given collection of parameters. 53 | 54 | Args: 55 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 56 | updated with the stored moving averages. 57 | """ 58 | parameters = [p for p in parameters if p.requires_grad] 59 | for s_param, param in zip(self.shadow_params, parameters): 60 | if param.requires_grad: 61 | param.data.copy_(s_param.data) 62 | 63 | def store(self, parameters): 64 | """ 65 | Save the current parameters for restoring later. 66 | 67 | Args: 68 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 69 | temporarily stored. 70 | """ 71 | self.collected_params = [param.clone() for param in parameters] 72 | 73 | def restore(self, parameters): 74 | """ 75 | Restore the parameters stored with the `store` method. 76 | Useful to validate the model with EMA parameters without affecting the 77 | original optimization process. Store the parameters before the 78 | `copy_to` method. After validation (or model saving), use this to 79 | restore the former parameters. 80 | 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | 88 | def state_dict(self): 89 | return dict(decay=self.decay, num_updates=self.num_updates, 90 | shadow_params=self.shadow_params) 91 | 92 | def load_state_dict(self, state_dict): 93 | self.decay = state_dict['decay'] 94 | self.num_updates = state_dict['num_updates'] 95 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /baselines/stasy/models/layerspp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Layers for defining NCSN++. 18 | """ 19 | import torch.nn as nn 20 | import torch 21 | import torch.nn.functional as F 22 | import numpy as np 23 | 24 | 25 | class GaussianFourierProjection(nn.Module): 26 | """Gaussian Fourier embeddings for noise levels.""" 27 | 28 | def __init__(self, embedding_size=256, scale=1.0): 29 | super().__init__() 30 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 31 | 32 | def forward(self, x): 33 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 34 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 35 | -------------------------------------------------------------------------------- /baselines/stasy/models/ncsnpp_tabular.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | 18 | from torch.nn.functional import embedding 19 | from . 
import utils, layers, layerspp 20 | import torch.nn as nn 21 | import torch 22 | 23 | get_act = layers.get_act 24 | default_initializer = layers.default_init 25 | 26 | 27 | NONLINEARITIES = { 28 | "elu": nn.ELU(), 29 | "relu": nn.ReLU(), 30 | "lrelu": nn.LeakyReLU(negative_slope=0.2), 31 | "swish": nn.SiLU(), 32 | "tanh": nn.Tanh(), 33 | "softplus": nn.Softplus(), 34 | } 35 | 36 | 37 | @utils.register_model(name='ncsnpp_tabular') 38 | class NCSNpp(nn.Module): 39 | """NCSN++ model""" 40 | 41 | def __init__(self, config): 42 | super().__init__() 43 | base_layer = { 44 | "ignore": layers.IgnoreLinear, 45 | "squash": layers.SquashLinear, 46 | "concat": layers.ConcatLinear, 47 | "concat_v2": layers.ConcatLinear_v2, 48 | "concatsquash": layers.ConcatSquashLinear, 49 | "blend": layers.BlendLinear, 50 | "concatcoord": layers.ConcatLinear, 51 | } 52 | 53 | self.config = config 54 | self.act = get_act(config) 55 | self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) 56 | self.hidden_dims = config.model.hidden_dims 57 | 58 | self.nf = nf = config.model.nf 59 | 60 | self.conditional = conditional = config.model.conditional 61 | self.embedding_type = embedding_type = config.model.embedding_type.lower() 62 | 63 | modules = [] 64 | if embedding_type == 'fourier': 65 | assert config.training.continuous, "Fourier features are only used for continuous training." 66 | 67 | modules.append(layerspp.GaussianFourierProjection( 68 | embedding_size=nf, scale=config.model.fourier_scale 69 | )) 70 | embed_dim = 2 * nf 71 | 72 | elif embedding_type == 'positional': 73 | embed_dim = nf 74 | 75 | else: 76 | pass 77 | 78 | if conditional: 79 | modules.append(nn.Linear(embed_dim, nf * 4)) 80 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 81 | nn.init.zeros_(modules[-1].bias) 82 | modules.append(nn.Linear(nf * 4, nf * 4)) 83 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 84 | nn.init.zeros_(modules[-1].bias) 85 | 86 | dim = config.data.image_size 87 | for item in list(config.model.hidden_dims): 88 | modules += [ 89 | base_layer[config.model.layer_type](dim, item) 90 | ] 91 | dim += item 92 | modules.append(NONLINEARITIES[config.model.activation]) 93 | 94 | modules.append(nn.Linear(dim, config.data.image_size)) 95 | self.all_modules = nn.ModuleList(modules) 96 | 97 | def forward(self, x, time_cond): 98 | modules = self.all_modules 99 | m_idx = 0 100 | if self.embedding_type == 'fourier': 101 | used_sigmas = time_cond 102 | temb = modules[m_idx](torch.log(used_sigmas)) 103 | m_idx += 1 104 | 105 | elif self.embedding_type == 'positional': 106 | timesteps = time_cond 107 | used_sigmas = self.sigmas[time_cond.long()] 108 | temb = layers.get_timestep_embedding(time_cond, self.nf) 109 | 110 | else: 111 | pass 112 | 113 | if self.conditional: 114 | temb = modules[m_idx](temb) 115 | m_idx += 1 116 | temb = modules[m_idx](self.act(temb)) 117 | m_idx += 1 118 | else: 119 | temb = None 120 | 121 | temb = x 122 | for _ in range(len(self.hidden_dims)): 123 | temb1 = modules[m_idx](t=time_cond, x=temb) 124 | temb = torch.cat([temb1, temb], dim=1) 125 | m_idx += 1 126 | temb = modules[m_idx](temb) 127 | m_idx += 1 128 | 129 | h = modules[m_idx](temb) 130 | 131 | if self.config.model.scale_by_sigma: 132 | used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) 133 | h = h / used_sigmas 134 | 135 | return h -------------------------------------------------------------------------------- /baselines/stasy/models/tabular_utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.mixture import BayesianGaussianMixture 4 | 5 | CATEGORICAL = "categorical" 6 | CONTINUOUS = "continuous" 7 | ORDINAL = "ordinal" 8 | 9 | class Transformer: 10 | 11 | @staticmethod 12 | def get_metadata(data, categorical_columns=tuple(), ordinal_columns=tuple()): 13 | meta = [] 14 | 15 | df = pd.DataFrame(data) 16 | for index in df: 17 | column = df[index] 18 | 19 | if index in categorical_columns: 20 | mapper = column.value_counts().index.tolist() 21 | meta.append({ 22 | "name": index, 23 | "type": CATEGORICAL, 24 | "size": len(mapper), 25 | "i2s": mapper 26 | }) 27 | elif index in ordinal_columns: 28 | value_count = list(dict(column.value_counts()).items()) 29 | value_count = sorted(value_count, key=lambda x: -x[1]) 30 | mapper = list(map(lambda x: x[0], value_count)) 31 | meta.append({ 32 | "name": index, 33 | "type": ORDINAL, 34 | "size": len(mapper), 35 | "i2s": mapper 36 | }) 37 | else: 38 | meta.append({ 39 | "name": index, 40 | "type": CONTINUOUS, 41 | "min": column.min(), 42 | "max": column.max(), 43 | }) 44 | 45 | return meta 46 | 47 | def fit(self, data, categorical_columns=tuple(), ordinal_columns=tuple()): 48 | raise NotImplementedError 49 | 50 | def transform(self, data): 51 | raise NotImplementedError 52 | 53 | def inverse_transform(self, data): 54 | raise NotImplementedError 55 | 56 | 57 | class GeneralTransformer(Transformer): 58 | """Continuous and ordinal columns are normalized to [0, 1]. 59 | Discrete columns are converted to a one-hot vector. 60 | """ 61 | 62 | def __init__(self, act='sigmoid'): 63 | self.act = act 64 | self.meta = None 65 | self.output_dim = None 66 | 67 | def fit(self, data, categorical_columns=tuple(), ordinal_columns=tuple()): 68 | self.meta = self.get_metadata(data, categorical_columns, ordinal_columns) 69 | self.output_dim = 0 70 | for info in self.meta: 71 | if info['type'] in [CONTINUOUS, ORDINAL]: 72 | self.output_dim += 1 73 | else: 74 | self.output_dim += info['size'] 75 | 76 | def transform(self, data): 77 | data_t = [] 78 | self.output_info = [] 79 | for id_, info in enumerate(self.meta): 80 | col = data[:, id_] 81 | if info['type'] == CONTINUOUS: 82 | col = (col - (info['min'])) / (info['max'] - info['min']) 83 | if self.act == 'tanh': 84 | col = col * 2 - 1 85 | data_t.append(col.reshape([-1, 1])) 86 | self.output_info.append((1, self.act)) 87 | 88 | elif info['type'] == ORDINAL: 89 | col = col / info['size'] 90 | if self.act == 'tanh': 91 | col = col * 2 - 1 92 | data_t.append(col.reshape([-1, 1])) 93 | self.output_info.append((1, self.act)) 94 | 95 | else: 96 | col_t = np.zeros([len(data), info['size']]) 97 | idx = list(map(info['i2s'].index, col)) 98 | col_t[np.arange(len(data)), idx] = 1 99 | data_t.append(col_t) 100 | self.output_info.append((info['size'], 'softmax')) 101 | 102 | return np.concatenate(data_t, axis=1) 103 | 104 | def inverse_transform(self, data): 105 | data_t = np.zeros([len(data), len(self.meta)]) 106 | 107 | data = data.copy() 108 | for id_, info in enumerate(self.meta): 109 | if info['type'] == CONTINUOUS: 110 | current = data[:, 0] 111 | data = data[:, 1:] 112 | 113 | if self.act == 'tanh': 114 | current = (current + 1) / 2 115 | 116 | current = np.clip(current, 0, 1) 117 | data_t[:, id_] = current * (info['max'] - info['min']) + info['min'] 118 | 119 | elif info['type'] == ORDINAL: 120 | current = data[:, 0] 121 | data = data[:, 1:] 122 | 123 | if self.act == 'tanh': 
124 | current = (current + 1) / 2 125 | 126 | current = current * info['size'] 127 | current = np.round(current).clip(0, info['size'] - 1) 128 | data_t[:, id_] = current 129 | else: 130 | current = data[:, :info['size']] 131 | data = data[:, info['size']:] 132 | idx = np.argmax(current, axis=1) 133 | data_t[:, id_] = list(map(info['i2s'].__getitem__, idx)) 134 | 135 | return data_t -------------------------------------------------------------------------------- /baselines/stasy/models/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions and modules related to model definition. 17 | """ 18 | 19 | import torch 20 | 21 | import baselines.stasy.sde_lib as sde_lib 22 | import numpy as np 23 | 24 | 25 | _MODELS = {} 26 | 27 | 28 | def register_model(cls=None, *, name=None): 29 | """A decorator for registering model classes.""" 30 | 31 | def _register(cls): 32 | if name is None: 33 | local_name = cls.__name__ 34 | else: 35 | local_name = name 36 | if local_name in _MODELS: 37 | raise ValueError(f'Already registered model with name: {local_name}') 38 | _MODELS[local_name] = cls 39 | return cls 40 | 41 | if cls is None: 42 | return _register 43 | else: 44 | return _register(cls) 45 | 46 | 47 | def get_model(name): 48 | return _MODELS[name] 49 | 50 | 51 | def get_sigmas(config): 52 | """Get sigmas --- the set of noise levels for SMLD from config files. 53 | Args: 54 | config: A ConfigDict object parsed from the config file 55 | Returns: 56 | sigmas: a jax numpy arrary of noise levels 57 | """ 58 | sigmas = np.exp( 59 | np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales)) 60 | 61 | return sigmas 62 | 63 | 64 | def get_ddpm_params(config): 65 | num_diffusion_timesteps = 1000 66 | beta_start = config.model.beta_min / config.model.num_scales 67 | beta_end = config.model.beta_max / config.model.num_scales 68 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 69 | 70 | alphas = 1. - betas 71 | alphas_cumprod = np.cumprod(alphas, axis=0) 72 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 73 | sqrt_1m_alphas_cumprod = np.sqrt(1. 
- alphas_cumprod) 74 | 75 | return { 76 | 'betas': betas, 77 | 'alphas': alphas, 78 | 'alphas_cumprod': alphas_cumprod, 79 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 80 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, 81 | 'beta_min': beta_start * (num_diffusion_timesteps - 1), 82 | 'beta_max': beta_end * (num_diffusion_timesteps - 1), 83 | 'num_diffusion_timesteps': num_diffusion_timesteps 84 | } 85 | 86 | 87 | def create_model(config): 88 | """Create the score model.""" 89 | model_name = config.model.name 90 | score_model = get_model(model_name)(config) 91 | score_model = score_model.to(config.device) 92 | # score_model = torch.nn.DataParallel(score_model) 93 | return score_model 94 | 95 | 96 | def get_model_fn(model, train=False): 97 | """Create a function to give the output of the score-based model. 98 | 99 | Args: 100 | model: The score model. 101 | train: `True` for training and `False` for evaluation. 102 | 103 | Returns: 104 | A model function. 105 | """ 106 | 107 | def model_fn(x, labels): 108 | """Compute the output of the score-based model. 109 | 110 | Args: 111 | x: A mini-batch of input data. 112 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 113 | for different models. 114 | 115 | Returns: 116 | A tuple of (model output, new mutable states) 117 | """ 118 | if not train: 119 | model.eval() 120 | return model(x, labels) 121 | else: 122 | model.train() 123 | return model(x, labels) 124 | 125 | return model_fn 126 | 127 | 128 | def get_score_fn(sde, model, train=False, continuous=False): 129 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 130 | 131 | Args: 132 | sde: An `sde_lib.SDE` object that represents the forward SDE. 133 | model: A score model. 134 | train: `True` for training and `False` for evaluation. 135 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 136 | 137 | Returns: 138 | A score function. 
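For VP and sub-VP SDEs, the returned function divides the model output by the marginal standard deviation and negates it; for VE SDEs the raw model output is used as the score directly.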
139 | """ 140 | model_fn = get_model_fn(model, train=train) 141 | 142 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 143 | def score_fn(x, t): 144 | if continuous or isinstance(sde, sde_lib.subVPSDE): 145 | labels = t * (sde.N - 1) 146 | score = model_fn(x, labels) 147 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 148 | else: 149 | labels = t * (sde.N - 1) 150 | score = model_fn(x, labels) 151 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 152 | score = -score / std[:, None] 153 | 154 | return score 155 | 156 | elif isinstance(sde, sde_lib.VESDE): 157 | def score_fn(x, t): 158 | if continuous: 159 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 160 | else: 161 | labels = sde.T - t 162 | labels *= sde.N - 1 163 | labels = torch.round(labels).long() 164 | 165 | score = model_fn(x, labels) 166 | return score 167 | 168 | else: 169 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 170 | 171 | return score_fn 172 | 173 | 174 | def to_flattened_numpy(x): 175 | """Flatten a torch tensor `x` and convert it to numpy.""" 176 | return x.detach().cpu().numpy().reshape((-1,)) 177 | 178 | 179 | def from_flattened_numpy(x, shape): 180 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 181 | return torch.from_numpy(x.reshape(shape)) -------------------------------------------------------------------------------- /baselines/stasy/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | def restore_checkpoint(ckpt_dir, state, device): 9 | 10 | loaded_state = torch.load(ckpt_dir, map_location=device) 11 | state['optimizer'].load_state_dict(loaded_state['optimizer']) 12 | state['model'].load_state_dict(loaded_state['model'], strict=False) 13 | state['ema'].load_state_dict(loaded_state['ema']) 14 | state['step'] = loaded_state['step'] 15 | try: 16 | state['epoch'] = loaded_state['epoch'] 17 | except: pass 18 | return state 19 | 20 | 21 | def save_checkpoint(ckpt_dir, state): 22 | saved_state = { 23 | 'optimizer': state['optimizer'].state_dict(), 24 | 'model': state['model'].state_dict(), 25 | 'ema': state['ema'].state_dict(), 26 | 'step': state['step'], 27 | 'epoch': state['epoch'], 28 | } 29 | torch.save(saved_state, ckpt_dir) 30 | 31 | 32 | def apply_activate(data, output_info): 33 | data_t = [] 34 | st = 0 35 | for item in output_info: 36 | if item[1] == 'tanh': 37 | ed = st + item[0] 38 | data_t.append(torch.tanh(data[:, st:ed])) 39 | st = ed 40 | elif item[1] == 'sigmoid': 41 | ed = st + item[0] 42 | data_t.append(data[:,st:ed]) 43 | st = ed 44 | elif item[1] == 'softmax': 45 | ed = st + item[0] 46 | data_t.append(F.softmax(data[:, st:ed])) 47 | 48 | st = ed 49 | else: 50 | assert 0 51 | return torch.cat(data_t, dim=1) -------------------------------------------------------------------------------- /baselines/tabddpm/configs/adult.toml: -------------------------------------------------------------------------------- 1 | parent_dir = "configs/adult" 2 | model_save_path = "ckpt/adult" 3 | sample_save_path = "sampled/adult" 4 | real_data_path = "Data/adult" 5 | 6 | num_numerical_features = 6 7 | model_type = "mlp" 8 | task_type = "binclass" 9 | 10 | [model_params] 11 | num_classes = 2 12 | is_y_cond = false 13 | 14 | [model_params.rtdl_params] 15 | d_layers = [ 16 | 1024, 17 | 2048, 18 | 2048, 19 | 1024, 20 | ] 21 | dropout = 0.0 22 | 23 | 
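# The sections below follow the same layout as the other dataset configs:
# [diffusion_params] sets the number of reverse-diffusion steps and the loss used for the
# Gaussian (numerical) part, [train.main] holds the optimizer settings, [train.T] the
# preprocessing transformations, and [sample] how many synthetic rows to draw.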
[diffusion_params] 24 | num_timesteps = 1000 25 | gaussian_loss_type = "mse" 26 | 27 | [train.main] 28 | steps = 100000 29 | lr = 0.001809824563637657 30 | weight_decay = 5e-4 31 | batch_size = 4096 32 | 33 | [train.T] 34 | seed = 0 35 | normalization = "quantile" 36 | num_nan_policy = "mean" 37 | cat_nan_policy = "__none__" 38 | cat_min_frequency = "__none__" 39 | cat_encoding = "__none__" 40 | y_policy = "default" 41 | 42 | [sample] 43 | num_samples = 32561 44 | batch_size = 10000 45 | seed = 0 46 | 47 | [eval.type] 48 | eval_model = "mlp" 49 | eval_type = "synthetic" 50 | 51 | [eval.T] 52 | seed = 0 53 | normalization = "quantile" 54 | num_nan_policy = "__none__" 55 | cat_nan_policy = "__none__" 56 | cat_min_frequency = "__none__" 57 | cat_encoding = "one-hot" 58 | y_policy = "default" -------------------------------------------------------------------------------- /baselines/tabddpm/configs/beijing.toml: -------------------------------------------------------------------------------- 1 | parent_dir = "configs/beijing" 2 | model_save_path = "ckpt/beijing" 3 | sample_save_path = "sampled/beijing" 4 | real_data_path = "Data/beijing" 5 | num_numerical_features = 6 6 | task_type = "regression" 7 | model_type = "mlp" 8 | seed = 0 9 | 10 | [model_params] 11 | num_classes = 2 12 | is_y_cond = false 13 | 14 | [model_params.rtdl_params] 15 | d_layers = [ 16 | 1024, 17 | 2048, 18 | 2048, 19 | 1024, 20 | ] 21 | dropout = 0.0 22 | 23 | [diffusion_params] 24 | num_timesteps = 1000 25 | gaussian_loss_type = "mse" 26 | 27 | [train.main] 28 | steps = 100000 29 | lr = 0.001809824563637657 30 | weight_decay = 0.0 31 | batch_size = 4096 32 | 33 | [train.T] 34 | seed = 0 35 | normalization = "quantile" 36 | num_nan_policy = "mean" 37 | cat_nan_policy = "__none__" 38 | cat_min_frequency = "__none__" 39 | cat_encoding = "__none__" 40 | y_policy = "default" 41 | 42 | [sample] 43 | num_samples = 37581 44 | batch_size = 10000 45 | seed = 0 46 | 47 | [eval.type] 48 | eval_model = "mlp" 49 | eval_type = "synthetic" 50 | 51 | [eval.T] 52 | seed = 0 53 | normalization = "quantile" 54 | num_nan_policy = "__none__" 55 | cat_nan_policy = "__none__" 56 | cat_min_frequency = "__none__" 57 | cat_encoding = "one-hot" 58 | y_policy = "default" -------------------------------------------------------------------------------- /baselines/tabddpm/configs/default.toml: -------------------------------------------------------------------------------- 1 | parent_dir = "configs/default" 2 | model_save_path = "ckpt/default" 3 | sample_save_path = "sampled/default" 4 | real_data_path = "Data/default" 5 | num_numerical_features = 14 6 | task_type = "binclass" 7 | model_type = "mlp" 8 | seed = 0 9 | device = "cuda:0" 10 | 11 | [model_params] 12 | num_classes = 2 13 | is_y_cond = false 14 | 15 | [model_params.rtdl_params] 16 | d_layers = [ 17 | 1024, 18 | 2048, 19 | 2048, 20 | 1024, 21 | ] 22 | dropout = 0.0 23 | 24 | [diffusion_params] 25 | num_timesteps = 1000 26 | gaussian_loss_type = "mse" 27 | 28 | [train.main] 29 | steps = 100000 30 | lr = 0.001809824563637657 31 | weight_decay = 0.0 32 | batch_size = 4096 33 | 34 | [train.T] 35 | seed = 0 36 | normalization = "quantile" 37 | num_nan_policy = "mean" 38 | cat_nan_policy = "__none__" 39 | cat_min_frequency = "__none__" 40 | cat_encoding = "__none__" 41 | y_policy = "default" 42 | 43 | [sample] 44 | num_samples = 27000 45 | batch_size = 10000 46 | seed = 0 47 | 48 | [eval.type] 49 | eval_model = "mlp" 50 | eval_type = "synthetic" 51 | 52 | [eval.T] 53 | seed = 0 54 | normalization = 
"quantile" 55 | num_nan_policy = "__none__" 56 | cat_nan_policy = "__none__" 57 | cat_min_frequency = "__none__" 58 | cat_encoding = "one-hot" 59 | y_policy = "default" -------------------------------------------------------------------------------- /baselines/tabddpm/configs/magic.toml: -------------------------------------------------------------------------------- 1 | parent_dir = "configs/magic" 2 | model_save_path = "ckpt/magic" 3 | sample_save_path = "sampled/magic" 4 | real_data_path = "Data/magic" 5 | num_numerical_features = 10 6 | task_type = "binclass" 7 | model_type = "mlp" 8 | seed = 0 9 | 10 | [model_params] 11 | num_classes = 2 12 | is_y_cond = false 13 | 14 | [model_params.rtdl_params] 15 | d_layers = [ 16 | 1024, 17 | 2048, 18 | 2048, 19 | 1024, 20 | ] 21 | dropout = 0.0 22 | 23 | [diffusion_params] 24 | num_timesteps = 1000 25 | gaussian_loss_type = "mse" 26 | 27 | [train.main] 28 | steps = 100000 29 | lr = 0.001809824563637657 30 | weight_decay = 0.0 31 | batch_size = 4096 32 | 33 | [train.T] 34 | seed = 0 35 | normalization = "quantile" 36 | num_nan_policy = "mean" 37 | cat_nan_policy = "__none__" 38 | cat_min_frequency = "__none__" 39 | cat_encoding = "__none__" 40 | y_policy = "default" 41 | 42 | [sample] 43 | num_samples = 17117 44 | batch_size = 10000 45 | seed = 0 46 | 47 | [eval.type] 48 | eval_model = "mlp" 49 | eval_type = "synthetic" 50 | 51 | [eval.T] 52 | seed = 0 53 | normalization = "quantile" 54 | num_nan_policy = "__none__" 55 | cat_nan_policy = "__none__" 56 | cat_min_frequency = "__none__" 57 | cat_encoding = "one-hot" 58 | y_policy = "default" -------------------------------------------------------------------------------- /baselines/tabddpm/configs/news.toml: -------------------------------------------------------------------------------- 1 | parent_dir = "configs/news" 2 | model_save_path = "ckpt/news" 3 | sample_save_path = "sampled/news" 4 | real_data_path = "Data/news" 5 | num_numerical_features = 45 6 | task_type = "regression" 7 | model_type = "mlp" 8 | seed = 0 9 | device = "cuda:0" 10 | 11 | [model_params] 12 | num_classes = 2 13 | is_y_cond = false 14 | 15 | [model_params.rtdl_params] 16 | d_layers = [ 17 | 1024, 18 | 2048, 19 | 2048, 20 | 1024, 21 | ] 22 | dropout = 0.0 23 | 24 | [diffusion_params] 25 | num_timesteps = 1000 26 | gaussian_loss_type = "mse" 27 | 28 | [train.main] 29 | steps = 100000 30 | lr = 0.001809824563637657 31 | weight_decay = 0.0 32 | batch_size = 4096 33 | 34 | [train.T] 35 | seed = 0 36 | normalization = "quantile" 37 | num_nan_policy = "mean" 38 | cat_nan_policy = "__none__" 39 | cat_min_frequency = "__none__" 40 | cat_encoding = "__none__" 41 | y_policy = "default" 42 | 43 | [sample] 44 | num_samples = 35679 45 | batch_size = 10000 46 | seed = 0 47 | 48 | [eval.type] 49 | eval_model = "mlp" 50 | eval_type = "synthetic" 51 | 52 | [eval.T] 53 | seed = 0 54 | normalization = "quantile" 55 | num_nan_policy = "__none__" 56 | cat_nan_policy = "__none__" 57 | cat_min_frequency = "__none__" 58 | cat_encoding = "one-hot" 59 | y_policy = "default" -------------------------------------------------------------------------------- /baselines/tabddpm/configs/shoppers.toml: -------------------------------------------------------------------------------- 1 | parent_dir = "configs/shoppers" 2 | model_save_path = "ckpt/shoppers" 3 | sample_save_path = "sampled/shoppers" 4 | real_data_path = "Data/shoppers" 5 | num_numerical_features = 10 6 | task_type = "binclass" 7 | model_type = "mlp" 8 | seed = 0 9 | device = "cuda:0" 10 | 11 | 
[model_params] 12 | num_classes = 2 13 | is_y_cond = false 14 | 15 | [model_params.rtdl_params] 16 | d_layers = [ 17 | 1024, 18 | 2048, 19 | 2048, 20 | 1024, 21 | ] 22 | dropout = 0.0 23 | 24 | [diffusion_params] 25 | num_timesteps = 1000 26 | gaussian_loss_type = "mse" 27 | 28 | [train.main] 29 | steps = 100000 30 | lr = 0.001809824563637657 31 | weight_decay = 0.0 32 | batch_size = 4096 33 | 34 | [train.T] 35 | seed = 0 36 | normalization = "quantile" 37 | num_nan_policy = "mean" 38 | cat_nan_policy = "__none__" 39 | cat_min_frequency = "__none__" 40 | cat_encoding = "__none__" 41 | y_policy = "default" 42 | 43 | [sample] 44 | num_samples = 11097 45 | batch_size = 10000 46 | seed = 0 47 | 48 | [eval.type] 49 | eval_model = "mlp" 50 | eval_type = "synthetic" 51 | 52 | [eval.T] 53 | seed = 0 54 | normalization = "quantile" 55 | num_nan_policy = "__none__" 56 | cat_nan_policy = "__none__" 57 | cat_min_frequency = "__none__" 58 | cat_encoding = "one-hot" 59 | y_policy = "default" -------------------------------------------------------------------------------- /baselines/tabddpm/main_sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from baselines.tabddpm.sample import sample 4 | 5 | import src 6 | 7 | 8 | def main(args): 9 | dataname = args.dataname 10 | device = f'cuda:{args.gpu}' 11 | 12 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 13 | config_path = f'{curr_dir}/configs/{dataname}.toml' 14 | model_save_path = f'{curr_dir}/ckpt/{dataname}' 15 | real_data_path = f'data/{dataname}' 16 | sample_save_path = args.save_path 17 | 18 | args.train = True 19 | 20 | raw_config = src.load_config(config_path) 21 | 22 | ''' 23 | Modification of configs 24 | ''' 25 | print('START SAMPLING') 26 | 27 | sample( 28 | num_samples=raw_config['sample']['num_samples'], 29 | batch_size=raw_config['sample']['batch_size'], 30 | disbalance=raw_config['sample'].get('disbalance', None), 31 | **raw_config['diffusion_params'], 32 | model_save_path=model_save_path, 33 | sample_save_path=sample_save_path, 34 | real_data_path=real_data_path, 35 | task_type=raw_config['task_type'], 36 | model_type=raw_config['model_type'], 37 | model_params=raw_config['model_params'], 38 | T_dict=raw_config['train']['T'], 39 | num_numerical_features=raw_config['num_numerical_features'], 40 | device=device, 41 | ddim=args.ddim, 42 | steps=args.steps 43 | ) 44 | 45 | if __name__ == '__main__': 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--dataname', type = str, default = 'adult') 49 | parser.add_argument('--gpu', type = int, default=0) 50 | parser.add_argument('--ddim', action = 'store_true', default = False, help='Whether to use ddim sampling.') 51 | parser.add_argument('--steps', type=int, default = 1000) 52 | 53 | args = parser.parse_args() -------------------------------------------------------------------------------- /baselines/tabddpm/main_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from baselines.tabddpm.train import train 5 | 6 | import src 7 | 8 | 9 | def main(args): 10 | 11 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 12 | dataname = args.dataname 13 | device = f'cuda:{args.gpu}' 14 | 15 | config_path = f'{curr_dir}/configs/{dataname}.toml' 16 | model_save_path = f'{curr_dir}/ckpt/{dataname}' 17 | real_data_path = f'data/{dataname}' 18 | 19 | if not os.path.exists(model_save_path): 20 | os.makedirs(model_save_path) 21 
| 22 | args.train = True 23 | raw_config = src.load_config(config_path) 24 | 25 | ''' 26 | Modification of configs 27 | ''' 28 | print('START TRAINING') 29 | 30 | train( 31 | **raw_config['train']['main'], 32 | **raw_config['diffusion_params'], 33 | model_save_path=model_save_path, 34 | real_data_path=real_data_path, 35 | task_type=raw_config['task_type'], 36 | model_type=raw_config['model_type'], 37 | model_params=raw_config['model_params'], 38 | T_dict=raw_config['train']['T'], 39 | num_numerical_features=raw_config['num_numerical_features'], 40 | device=device 41 | ) 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--config', metavar='FILE') 46 | parser.add_argument('--dataname', type = str, default = 'adult') 47 | parser.add_argument('--gpu', type = int, default=0) 48 | 49 | args = parser.parse_args() -------------------------------------------------------------------------------- /baselines/tabddpm/models/__pycache__/gaussian_multinomial_distribution.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/tabddpm/models/__pycache__/gaussian_multinomial_distribution.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/tabddpm/models/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/tabddpm/models/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/tabddpm/models/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/baselines/tabddpm/models/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /baselines/tabddpm/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch.profiler import record_function 5 | from inspect import isfunction 6 | 7 | def normal_kl(mean1, logvar1, mean2, logvar2): 8 | """ 9 | Compute the KL divergence between two gaussians. 10 | 11 | Shapes are automatically broadcasted, so batches can be compared to 12 | scalars, among other use cases. 13 | """ 14 | tensor = None 15 | for obj in (mean1, logvar1, mean2, logvar2): 16 | if isinstance(obj, torch.Tensor): 17 | tensor = obj 18 | break 19 | assert tensor is not None, "at least one argument must be a Tensor" 20 | 21 | # Force variances to be Tensors. Broadcasting helps convert scalars to 22 | # Tensors, but it does not work for torch.exp(). 23 | logvar1, logvar2 = [ 24 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 25 | for x in (logvar1, logvar2) 26 | ] 27 | 28 | return 0.5 * ( 29 | -1.0 30 | + logvar2 31 | - logvar1 32 | + torch.exp(logvar1 - logvar2) 33 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 34 | ) 35 | 36 | def approx_standard_normal_cdf(x): 37 | """ 38 | A fast approximation of the cumulative distribution function of the 39 | standard normal. 
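Uses the tanh-based approximation 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).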
40 | """ 41 | return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) 42 | 43 | 44 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 45 | """ 46 | Compute the log-likelihood of a Gaussian distribution discretizing to a 47 | given image. 48 | 49 | :param x: the target images. It is assumed that this was uint8 values, 50 | rescaled to the range [-1, 1]. 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | assert x.shape == means.shape == log_scales.shape 56 | centered_x = x - means 57 | inv_stdv = torch.exp(-log_scales) 58 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 59 | cdf_plus = approx_standard_normal_cdf(plus_in) 60 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 61 | cdf_min = approx_standard_normal_cdf(min_in) 62 | log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) 63 | log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) 64 | cdf_delta = cdf_plus - cdf_min 65 | log_probs = torch.where( 66 | x < -0.999, 67 | log_cdf_plus, 68 | torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), 69 | ) 70 | assert log_probs.shape == x.shape 71 | return log_probs 72 | 73 | def sum_except_batch(x, num_dims=1): 74 | ''' 75 | Sums all dimensions except the first. 76 | 77 | Args: 78 | x: Tensor, shape (batch_size, ...) 79 | num_dims: int, number of batch dims (default=1) 80 | 81 | Returns: 82 | x_sum: Tensor, shape (batch_size,) 83 | ''' 84 | return x.reshape(*x.shape[:num_dims], -1).sum(-1) 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | def ohe_to_categories(ohe, K): 93 | K = torch.from_numpy(K) 94 | indices = torch.cat([torch.zeros((1,)), K.cumsum(dim=0)], dim=0).int().tolist() 95 | res = [] 96 | for i in range(len(indices) - 1): 97 | res.append(ohe[:, indices[i]:indices[i+1]].argmax(dim=1)) 98 | return torch.stack(res, dim=1) 99 | 100 | def log_1_min_a(a): 101 | return torch.log(1 - a.exp() + 1e-40) 102 | 103 | 104 | def log_add_exp(a, b): 105 | maximum = torch.max(a, b) 106 | return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) 107 | 108 | def exists(x): 109 | return x is not None 110 | 111 | def extract(a, t, x_shape): 112 | b, *_ = t.shape 113 | t = t.to(a.device) 114 | out = a.gather(-1, t) 115 | while len(out.shape) < len(x_shape): 116 | out = out[..., None] 117 | return out.expand(x_shape) 118 | 119 | def default(val, d): 120 | if exists(val): 121 | return val 122 | return d() if isfunction(d) else d 123 | 124 | def log_categorical(log_x_start, log_prob): 125 | return (log_x_start.exp() * log_prob).sum(dim=1) 126 | 127 | def index_to_log_onehot(x, num_classes): 128 | onehots = [] 129 | for i in range(len(num_classes)): 130 | onehots.append(F.one_hot(x[:, i], num_classes[i])) 131 | 132 | x_onehot = torch.cat(onehots, dim=1) 133 | log_onehot = torch.log(x_onehot.float().clamp(min=1e-30)) 134 | return log_onehot 135 | 136 | def log_sum_exp_by_classes(x, slices): 137 | device = x.device 138 | res = torch.zeros_like(x) 139 | for ixs in slices: 140 | res[:, ixs] = torch.logsumexp(x[:, ixs], dim=1, keepdim=True) 141 | 142 | assert x.size() == res.size() 143 | 144 | return res 145 | 146 | @torch.jit.script 147 | def log_sub_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 148 | m = torch.maximum(a, b) 149 | return 
torch.log(torch.exp(a - m) - torch.exp(b - m) + 1e-10) + m 150 | 151 | @torch.jit.script 152 | def sliced_logsumexp(x, slices): 153 | 154 | 155 | lse = torch.logcumsumexp( 156 | torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float('inf')), 157 | dim=-1) 158 | 159 | slice_starts = slices[:-1] 160 | slice_ends = slices[1:] 161 | 162 | slice_lse = log_sub_exp(lse[:, slice_ends], lse[:, slice_starts]) 163 | 164 | slice_lse_repeated = torch.repeat_interleave( 165 | slice_lse, 166 | slice_ends - slice_starts, 167 | dim=-1 168 | ) 169 | 170 | return slice_lse_repeated 171 | 172 | def log_onehot_to_index(log_x): 173 | return log_x.argmax(1) 174 | 175 | class FoundNANsError(BaseException): 176 | """Found NANs during sampling""" 177 | def __init__(self, message='Found NANs during sampling.'): 178 | super(FoundNANsError, self).__init__(message) -------------------------------------------------------------------------------- /baselines/tabddpm/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | import json 6 | import time 7 | 8 | from baselines.tabddpm.models.gaussian_multinomial_distribution import GaussianMultinomialDiffusion 9 | from baselines.tabddpm.models.modules import MLPDiffusion 10 | 11 | import src 12 | from utils_train import make_dataset 13 | 14 | @torch.no_grad() 15 | def split_num_cat_target(syn_data, info, num_inverse, cat_inverse): 16 | task_type = info['task_type'] 17 | 18 | num_col_idx = info['num_col_idx'] 19 | cat_col_idx = info['cat_col_idx'] 20 | target_col_idx = info['target_col_idx'] 21 | 22 | n_num_feat = len(num_col_idx) 23 | n_cat_feat = len(cat_col_idx) 24 | 25 | if task_type == 'regression': 26 | n_num_feat += len(target_col_idx) 27 | else: 28 | n_cat_feat += len(target_col_idx) 29 | 30 | syn_num = syn_data[:, :n_num_feat] 31 | syn_cat = syn_data[:, n_num_feat:] 32 | 33 | syn_num = num_inverse(syn_num).astype(np.float32) 34 | syn_cat = cat_inverse(syn_cat) 35 | 36 | 37 | if info['task_type'] == 'regression': 38 | syn_target = syn_num[:, :len(target_col_idx)] 39 | syn_num = syn_num[:, len(target_col_idx):] 40 | 41 | else: 42 | print(syn_cat.shape) 43 | syn_target = syn_cat[:, :len(target_col_idx)] 44 | syn_cat = syn_cat[:, len(target_col_idx):] 45 | 46 | return syn_num, syn_cat, syn_target 47 | 48 | def recover_data(syn_num, syn_cat, syn_target, info): 49 | 50 | num_col_idx = info['num_col_idx'] 51 | cat_col_idx = info['cat_col_idx'] 52 | target_col_idx = info['target_col_idx'] 53 | 54 | 55 | idx_mapping = info['idx_mapping'] 56 | idx_mapping = {int(key): value for key, value in idx_mapping.items()} 57 | 58 | syn_df = pd.DataFrame() 59 | 60 | if info['task_type'] == 'regression': 61 | for i in range(len(num_col_idx) + len(cat_col_idx) + len(target_col_idx)): 62 | if i in set(num_col_idx): 63 | syn_df[i] = syn_num[:, idx_mapping[i]] 64 | elif i in set(cat_col_idx): 65 | syn_df[i] = syn_cat[:, idx_mapping[i] - len(num_col_idx)] 66 | else: 67 | syn_df[i] = syn_target[:, idx_mapping[i] - len(num_col_idx) - len(cat_col_idx)] 68 | 69 | 70 | else: 71 | for i in range(len(num_col_idx) + len(cat_col_idx) + len(target_col_idx)): 72 | if i in set(num_col_idx): 73 | syn_df[i] = syn_num[:, idx_mapping[i]] 74 | elif i in set(cat_col_idx): 75 | syn_df[i] = syn_cat[:, idx_mapping[i] - len(num_col_idx)] 76 | else: 77 | syn_df[i] = syn_target[:, idx_mapping[i] - len(num_col_idx) - len(cat_col_idx)] 78 | 79 | return syn_df 80 | 81 | def get_model( 82 | model_name, 83 | 
model_params, 84 | n_num_features, 85 | category_sizes 86 | ): 87 | print(model_name) 88 | if model_name == 'mlp': 89 | model = MLPDiffusion(**model_params) 90 | else: 91 | raise "Unknown model!" 92 | return model 93 | 94 | def to_good_ohe(ohe, X): 95 | indices = np.cumsum([0] + ohe._n_features_outs) 96 | Xres = [] 97 | for i in range(1, len(indices)): 98 | x_ = np.max(X[:, indices[i - 1]:indices[i]], axis=1) 99 | t = X[:, indices[i - 1]:indices[i]] - x_.reshape(-1, 1) 100 | Xres.append(np.where(t >= 0, 1, 0)) 101 | return np.hstack(Xres) 102 | 103 | def sample( 104 | model_save_path, 105 | sample_save_path, 106 | real_data_path, 107 | batch_size = 2000, 108 | num_samples = 0, 109 | task_type = 'binclass', 110 | model_type = 'mlp', 111 | model_params = None, 112 | num_timesteps = 1000, 113 | gaussian_loss_type = 'mse', 114 | scheduler = 'cosine', 115 | T_dict = None, 116 | num_numerical_features = 0, 117 | disbalance = None, 118 | device = torch.device('cuda:0'), 119 | change_val = False, 120 | ddim = False, 121 | steps = 1000, 122 | ): 123 | 124 | T = src.Transformations(**T_dict) 125 | 126 | D = make_dataset( 127 | real_data_path, 128 | T, 129 | task_type = task_type, 130 | change_val = False, 131 | ) 132 | 133 | K = np.array(D.get_category_sizes('train')) 134 | if len(K) == 0 or T_dict['cat_encoding'] == 'one-hot': 135 | K = np.array([0]) 136 | 137 | num_numerical_features_ = D.X_num['train'].shape[1] if D.X_num is not None else 0 138 | d_in = np.sum(K) + num_numerical_features_ 139 | model_params['d_in'] = int(d_in) 140 | model = get_model( 141 | model_type, 142 | model_params, 143 | num_numerical_features_, 144 | category_sizes=D.get_category_sizes('train') 145 | ) 146 | 147 | model_path =f'{model_save_path}/model.pt' 148 | 149 | model.load_state_dict( 150 | torch.load(model_path, map_location="cpu") 151 | ) 152 | 153 | 154 | diffusion = GaussianMultinomialDiffusion( 155 | K, 156 | num_numerical_features=num_numerical_features_, 157 | denoise_fn=model, num_timesteps=num_timesteps, 158 | gaussian_loss_type=gaussian_loss_type, scheduler=scheduler, device=device 159 | ) 160 | 161 | diffusion.to(device) 162 | diffusion.eval() 163 | 164 | start_time = time.time() 165 | if not ddim: 166 | x_gen = diffusion.sample_all(num_samples, batch_size, ddim=False) 167 | else: 168 | x_gen = diffusion.sample_all(num_samples, batch_size, ddim=True, steps = steps) 169 | 170 | 171 | print('Shape', x_gen.shape) 172 | 173 | syn_data = x_gen 174 | num_inverse = D.num_transform.inverse_transform 175 | cat_inverse = D.cat_transform.inverse_transform 176 | 177 | info_path = f'{real_data_path}/info.json' 178 | 179 | with open(info_path, 'r') as f: 180 | info = json.load(f) 181 | 182 | syn_num, syn_cat, syn_target = split_num_cat_target(syn_data, info, num_inverse, cat_inverse) 183 | syn_df = recover_data(syn_num, syn_cat, syn_target, info) 184 | 185 | idx_name_mapping = info['idx_name_mapping'] 186 | idx_name_mapping = {int(key): value for key, value in idx_name_mapping.items()} 187 | 188 | syn_df.rename(columns = idx_name_mapping, inplace=True) 189 | end_time = time.time() 190 | 191 | print('Sampling time:', end_time - start_time) 192 | 193 | save_path = sample_save_path 194 | syn_df.to_csv(save_path, index = False) -------------------------------------------------------------------------------- /data/Info/adult.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "adult", 3 | "task_type": "binclass", 4 | "header": null, 5 | "column_names": [ 6 | "age", 7 | 
"workclass", 8 | "fnlwgt", 9 | "education", 10 | "education.num", 11 | "marital.status", 12 | "occupation", 13 | "relationship", 14 | "race", 15 | "sex", 16 | "capital.gain", 17 | "capital.loss", 18 | "hours.per.week", 19 | "native.country", 20 | "income" 21 | ], 22 | "num_col_idx": [ 23 | 0, 24 | 2, 25 | 4, 26 | 10, 27 | 11, 28 | 12 29 | ], 30 | "cat_col_idx": [ 31 | 1, 32 | 3, 33 | 5, 34 | 6, 35 | 7, 36 | 8, 37 | 9, 38 | 13 39 | ], 40 | "target_col_idx": [ 41 | 14 42 | ], 43 | "file_type": "csv", 44 | "data_path": "data/adult/adult.data", 45 | "test_path": "data/adult/adult.test", 46 | "column_info": { 47 | "age": "float", 48 | "workclass": "str", 49 | "fnlwgt": "float", 50 | "education": "str", 51 | "education.num": "float", 52 | "marital.status": "str", 53 | "occupation": "str", 54 | "relationship": "str", 55 | "race": "str", 56 | "sex": "str", 57 | "capital.gain": "float", 58 | "capital.loss": "float", 59 | "hours.per.week": "float", 60 | "native.country": "str", 61 | "income": "str" 62 | }, 63 | "train_num": 32561, 64 | "test_num": 16281 65 | } -------------------------------------------------------------------------------- /data/Info/beijing.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "beijing", 3 | "task_type": "regression", 4 | "header": "infer", 5 | "column_names": null, 6 | "num_col_idx": [5,6,7,9,10,11], 7 | "cat_col_idx": [0,1,2,3,8], 8 | "target_col_idx": [4], 9 | "file_type": "csv", 10 | "raw_data_path": "data/beijing/PRSA_data_2010.1.1-2014.12.31.csv", 11 | "test_path": null, 12 | "data_path": "data/beijing/beijing.csv" 13 | } 14 | 15 | -------------------------------------------------------------------------------- /data/Info/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "default", 3 | "task_type": "binclass", 4 | "header": "infer", 5 | "column_names": null, 6 | "num_col_idx": [0,4,11,12,13,14,15,16,17,18,19,20,21,22], 7 | "cat_col_idx": [1,2,3,5,6,7,8,9,10], 8 | "target_col_idx": [23], 9 | 10 | "file_type": "xls", 11 | "data_path": "data/default/default of credit card clients.xls", 12 | "test_path": null 13 | } 14 | 15 | -------------------------------------------------------------------------------- /data/Info/magic.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "magic", 3 | "task_type": "binclass", 4 | "header": "infer", 5 | "column_names": ["Length", "Width", "Size", "Conc", "Conc1", "Asym", "M3Long", "M3Trans", "Alpha", "Dist", "class"], 6 | "num_col_idx": [0,1,2,3,4,5,6,7,8,9], 7 | "cat_col_idx": [], 8 | "target_col_idx": [10], 9 | 10 | "file_type": "csv", 11 | "data_path": "data/magic/magic04.data", 12 | "test_path": null 13 | } 14 | 15 | -------------------------------------------------------------------------------- /data/Info/news.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "news", 3 | "task_type": "regression", 4 | "header": "infer", 5 | "column_names": null, 6 | "num_col_idx": [ 7 | 0, 8 | 1, 9 | 2, 10 | 3, 11 | 4, 12 | 5, 13 | 6, 14 | 7, 15 | 8, 16 | 9, 17 | 10, 18 | 11, 19 | 12, 20 | 13, 21 | 14, 22 | 15, 23 | 16, 24 | 17, 25 | 18, 26 | 19, 27 | 20, 28 | 21, 29 | 22, 30 | 23, 31 | 24, 32 | 25, 33 | 26, 34 | 27, 35 | 28, 36 | 29, 37 | 30, 38 | 31, 39 | 32, 40 | 33, 41 | 34, 42 | 35, 43 | 36, 44 | 37, 45 | 38, 46 | 39, 47 | 40, 48 | 41, 49 | 42, 50 | 43, 51 | 44 52 | ], 53 | "cat_col_idx": [ 54 | 46, 55 | 47 56 | ], 57 
| "target_col_idx": [ 58 | 45 59 | ], 60 | "file_type": "csv", 61 | "raw_data_path": "data/news/OnlineNewsPopularity/OnlineNewsPopularity.csv", 62 | "test_path": null, 63 | "data_path": "data/news/news.csv" 64 | } -------------------------------------------------------------------------------- /data/Info/shoppers.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "shoppers", 3 | "task_type": "binclass", 4 | "header": "infer", 5 | "column_names": null, 6 | "num_col_idx": [0,1,2,3,4,5,6,7,8,9], 7 | "cat_col_idx": [10,11,12,13,14,15,16], 8 | "target_col_idx": [17], 9 | 10 | "file_type": "csv", 11 | "data_path": "data/shoppers/online_shoppers_intention.csv", 12 | "test_path": null 13 | } 14 | 15 | -------------------------------------------------------------------------------- /download_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy 3 | import pandas as pd 4 | from urllib import request 5 | import shutil 6 | import zipfile 7 | 8 | DATA_DIR = 'data' 9 | 10 | 11 | NAME_URL_DICT_UCI = { 12 | 'adult': 'https://archive.ics.uci.edu/static/public/2/adult.zip', 13 | 'default': 'https://archive.ics.uci.edu/static/public/350/default+of+credit+card+clients.zip', 14 | 'magic': 'https://archive.ics.uci.edu/static/public/159/magic+gamma+telescope.zip', 15 | 'shoppers': 'https://archive.ics.uci.edu/static/public/468/online+shoppers+purchasing+intention+dataset.zip', 16 | 'beijing': 'https://archive.ics.uci.edu/static/public/381/beijing+pm2+5+data.zip', 17 | 'news': 'https://archive.ics.uci.edu/static/public/332/online+news+popularity.zip' 18 | } 19 | 20 | def unzip_file(zip_filepath, dest_path): 21 | with zipfile.ZipFile(zip_filepath, 'r') as zip_ref: 22 | zip_ref.extractall(dest_path) 23 | 24 | 25 | def download_from_uci(name): 26 | 27 | print(f'Start processing dataset {name} from UCI.') 28 | save_dir = f'{DATA_DIR}/{name}' 29 | if not os.path.exists(save_dir): 30 | os.makedirs(save_dir) 31 | 32 | url = NAME_URL_DICT_UCI[name] 33 | request.urlretrieve(url, f'{save_dir}/{name}.zip') 34 | print(f'Finish downloading dataset from {url}, data has been saved to {save_dir}.') 35 | 36 | unzip_file(f'{save_dir}/{name}.zip', save_dir) 37 | print(f'Finish unzipping {name}.') 38 | 39 | else: 40 | print('Aready downloaded.') 41 | 42 | if __name__ == '__main__': 43 | for name in NAME_URL_DICT_UCI.keys(): 44 | download_from_uci(name) 45 | -------------------------------------------------------------------------------- /eval/eval_dcr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pandas as pd 4 | import json 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | from utils_train import preprocess, TabularDataset 10 | from sklearn.preprocessing import OneHotEncoder 11 | 12 | pd.options.mode.chained_assignment = None 13 | 14 | import argparse 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dataname', type=str, default='adult') 18 | parser.add_argument('--model', type=str, default='model') 19 | parser.add_argument('--path', type=str, default = None, help='The file path of the synthetic data') 20 | 21 | args = parser.parse_args() 22 | 23 | 24 | if __name__ == '__main__': 25 | 26 | dataname = args.dataname 27 | model = args.model 28 | 29 | if not args.path: 30 | syn_path = f'synthetic/{dataname}/{model}.csv' 31 | else: 32 | 
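# an explicit --path argument overrides the default synthetic/{dataname}/{model}.csv location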
syn_path = args.path 33 | 34 | real_path = f'synthetic/{dataname}/real.csv' 35 | test_path = f'synthetic/{dataname}/test.csv' 36 | 37 | data_dir = f'data/{dataname}' 38 | 39 | with open(f'{data_dir}/info.json', 'r') as f: 40 | info = json.load(f) 41 | 42 | syn_data = pd.read_csv(syn_path) 43 | real_data = pd.read_csv(real_path) 44 | test_data = pd.read_csv(test_path) 45 | 46 | num_col_idx = info['num_col_idx'] 47 | cat_col_idx = info['cat_col_idx'] 48 | target_col_idx = info['target_col_idx'] 49 | 50 | task_type = info['task_type'] 51 | if task_type == 'regression': 52 | num_col_idx += target_col_idx 53 | else: 54 | cat_col_idx += target_col_idx 55 | 56 | num_ranges = [] 57 | 58 | real_data.columns = list(np.arange(len(real_data.columns))) 59 | syn_data.columns = list(np.arange(len(real_data.columns))) 60 | test_data.columns = list(np.arange(len(real_data.columns))) 61 | for i in num_col_idx: 62 | num_ranges.append(real_data[i].max() - real_data[i].min()) 63 | 64 | num_ranges = np.array(num_ranges) 65 | 66 | 67 | num_real_data = real_data[num_col_idx] 68 | cat_real_data = real_data[cat_col_idx] 69 | num_syn_data = syn_data[num_col_idx] 70 | cat_syn_data = syn_data[cat_col_idx] 71 | num_test_data = test_data[num_col_idx] 72 | cat_test_data = test_data[cat_col_idx] 73 | 74 | num_real_data_np = num_real_data.to_numpy() 75 | cat_real_data_np = cat_real_data.to_numpy().astype('str') 76 | num_syn_data_np = num_syn_data.to_numpy() 77 | cat_syn_data_np = cat_syn_data.to_numpy().astype('str') 78 | num_test_data_np = num_test_data.to_numpy() 79 | cat_test_data_np = cat_test_data.to_numpy().astype('str') 80 | 81 | encoder = OneHotEncoder() 82 | encoder.fit(cat_real_data_np) 83 | 84 | 85 | cat_real_data_oh = encoder.transform(cat_real_data_np).toarray() 86 | cat_syn_data_oh = encoder.transform(cat_syn_data_np).toarray() 87 | cat_test_data_oh = encoder.transform(cat_test_data_np).toarray() 88 | 89 | num_real_data_np = num_real_data_np / num_ranges 90 | num_syn_data_np = num_syn_data_np / num_ranges 91 | num_test_data_np = num_test_data_np / num_ranges 92 | 93 | real_data_np = np.concatenate([num_real_data_np, cat_real_data_oh], axis=1) 94 | syn_data_np = np.concatenate([num_syn_data_np, cat_syn_data_oh], axis=1) 95 | test_data_np = np.concatenate([num_test_data_np, cat_test_data_oh], axis=1) 96 | 97 | if torch.cuda.is_available(): 98 | device = 'cuda' 99 | else: 100 | device = 'cpu' 101 | 102 | real_data_th = torch.tensor(real_data_np).to(device) 103 | syn_data_th = torch.tensor(syn_data_np).to(device) 104 | test_data_th = torch.tensor(test_data_np).to(device) 105 | 106 | dcrs_real = [] 107 | dcrs_test = [] 108 | batch_size = 100 109 | 110 | batch_syn_data_np = syn_data_np[i*batch_size: (i+1) * batch_size] 111 | 112 | for i in range((syn_data_th.shape[0] // batch_size) + 1): 113 | if i != (syn_data_th.shape[0] // batch_size): 114 | batch_syn_data_th = syn_data_th[i*batch_size: (i+1) * batch_size] 115 | else: 116 | batch_syn_data_th = syn_data_th[i*batch_size:] 117 | 118 | dcr_real = (batch_syn_data_th[:, None] - real_data_th).abs().sum(dim = 2).min(dim = 1).values 119 | dcr_test = (batch_syn_data_th[:, None] - test_data_th).abs().sum(dim = 2).min(dim = 1).values 120 | dcrs_real.append(dcr_real) 121 | dcrs_test.append(dcr_test) 122 | 123 | dcrs_real = torch.cat(dcrs_real) 124 | dcrs_test = torch.cat(dcrs_test) 125 | 126 | 127 | score = (dcrs_real < dcrs_test).nonzero().shape[0] / dcrs_real.shape[0] 128 | 129 | print('DCR Score, a value closer to 0.5 is better') 130 | print(f'{dataname}-{model}, DCR 
Score = {score}') -------------------------------------------------------------------------------- /eval/eval_density.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | 5 | import json 6 | 7 | # Metrics 8 | from sdmetrics.reports.single_table import QualityReport, DiagnosticReport 9 | 10 | 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--dataname', type=str, default='adult') 15 | parser.add_argument('--model', type=str, default='tabsyn') 16 | parser.add_argument('--path', type=str, default = None, help='The file path of the synthetic data') 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | def reorder(real_data, syn_data, info): 22 | num_col_idx = info['num_col_idx'] 23 | cat_col_idx = info['cat_col_idx'] 24 | target_col_idx = info['target_col_idx'] 25 | 26 | task_type = info['task_type'] 27 | if task_type == 'regression': 28 | num_col_idx += target_col_idx 29 | else: 30 | cat_col_idx += target_col_idx 31 | 32 | real_num_data = real_data[num_col_idx] 33 | real_cat_data = real_data[cat_col_idx] 34 | 35 | new_real_data = pd.concat([real_num_data, real_cat_data], axis=1) 36 | new_real_data.columns = range(len(new_real_data.columns)) 37 | 38 | syn_num_data = syn_data[num_col_idx] 39 | syn_cat_data = syn_data[cat_col_idx] 40 | 41 | new_syn_data = pd.concat([syn_num_data, syn_cat_data], axis=1) 42 | new_syn_data.columns = range(len(new_syn_data.columns)) 43 | 44 | 45 | metadata = info['metadata'] 46 | 47 | columns = metadata['columns'] 48 | metadata['columns'] = {} 49 | 50 | inverse_idx_mapping = info['inverse_idx_mapping'] 51 | 52 | 53 | for i in range(len(new_real_data.columns)): 54 | if i < len(num_col_idx): 55 | metadata['columns'][i] = columns[num_col_idx[i]] 56 | else: 57 | metadata['columns'][i] = columns[cat_col_idx[i-len(num_col_idx)]] 58 | 59 | 60 | return new_real_data, new_syn_data, metadata 61 | 62 | if __name__ == '__main__': 63 | 64 | dataname = args.dataname 65 | model = args.model 66 | 67 | if not args.path: 68 | syn_path = f'synthetic/{dataname}/{model}.csv' 69 | else: 70 | syn_path = args.path 71 | 72 | real_path = f'synthetic/{dataname}/real.csv' 73 | 74 | data_dir = f'data/{dataname}' 75 | print(syn_path) 76 | 77 | with open(f'{data_dir}/info.json', 'r') as f: 78 | info = json.load(f) 79 | 80 | syn_data = pd.read_csv(syn_path) 81 | real_data = pd.read_csv(real_path) 82 | 83 | save_dir = f'eval/density/{dataname}/{model}' 84 | if not os.path.exists(save_dir): 85 | os.makedirs(save_dir) 86 | 87 | real_data.columns = range(len(real_data.columns)) 88 | syn_data.columns = range(len(syn_data.columns)) 89 | 90 | metadata = info['metadata'] 91 | metadata['columns'] = {int(key): value for key, value in metadata['columns'].items()} 92 | 93 | new_real_data, new_syn_data, metadata = reorder(real_data, syn_data, info) 94 | 95 | qual_report = QualityReport() 96 | qual_report.generate(new_real_data, new_syn_data, metadata) 97 | 98 | diag_report = DiagnosticReport() 99 | diag_report.generate(new_real_data, new_syn_data, metadata) 100 | 101 | quality = qual_report.get_properties() 102 | diag = diag_report.get_properties() 103 | 104 | Shape = quality['Score'][0] 105 | Trend = quality['Score'][1] 106 | 107 | with open(f'{save_dir}/quality.txt', 'w') as f: 108 | f.write(f'{Shape}\n') 109 | f.write(f'{Trend}\n') 110 | 111 | Quality = (Shape + Trend) / 2 112 | 113 | shapes = qual_report.get_details(property_name='Column Shapes') 114 | trends = 
qual_report.get_details(property_name='Column Pair Trends') 115 | coverages = diag_report.get_details('Coverage') 116 | 117 | 118 | shapes.to_csv(f'{save_dir}/shape.csv') 119 | trends.to_csv(f'{save_dir}/trend.csv') 120 | coverages.to_csv(f'{save_dir}/coverage.csv') 121 | -------------------------------------------------------------------------------- /eval/eval_detection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pandas as pd 4 | import os 5 | import sys 6 | 7 | import json 8 | import pickle 9 | 10 | # Metrics 11 | from sdmetrics import load_demo 12 | from sdmetrics.single_table import LogisticDetection 13 | 14 | from matplotlib import pyplot as plt 15 | 16 | import argparse 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataname', type=str, default='adult') 22 | parser.add_argument('--model', type=str, default='real') 23 | 24 | args = parser.parse_args() 25 | 26 | def reorder(real_data, syn_data, info): 27 | num_col_idx = info['num_col_idx'] 28 | cat_col_idx = info['cat_col_idx'] 29 | target_col_idx = info['target_col_idx'] 30 | 31 | task_type = info['task_type'] 32 | if task_type == 'regression': 33 | num_col_idx += target_col_idx 34 | else: 35 | cat_col_idx += target_col_idx 36 | 37 | real_num_data = real_data[num_col_idx] 38 | real_cat_data = real_data[cat_col_idx] 39 | 40 | new_real_data = pd.concat([real_num_data, real_cat_data], axis=1) 41 | new_real_data.columns = range(len(new_real_data.columns)) 42 | 43 | syn_num_data = syn_data[num_col_idx] 44 | syn_cat_data = syn_data[cat_col_idx] 45 | 46 | new_syn_data = pd.concat([syn_num_data, syn_cat_data], axis=1) 47 | new_syn_data.columns = range(len(new_syn_data.columns)) 48 | 49 | 50 | metadata = info['metadata'] 51 | 52 | columns = metadata['columns'] 53 | metadata['columns'] = {} 54 | 55 | inverse_idx_mapping = info['inverse_idx_mapping'] 56 | 57 | 58 | for i in range(len(new_real_data.columns)): 59 | if i < len(num_col_idx): 60 | metadata['columns'][i] = columns[num_col_idx[i]] 61 | else: 62 | metadata['columns'][i] = columns[cat_col_idx[i-len(num_col_idx)]] 63 | 64 | 65 | return new_real_data, new_syn_data, metadata 66 | 67 | if __name__ == '__main__': 68 | 69 | dataname = args.dataname 70 | model = args.model 71 | 72 | syn_path = f'synthetic/{dataname}/{model}.csv' 73 | real_path = f'synthetic/{dataname}/real.csv' 74 | 75 | data_dir = f'data/{dataname}' 76 | print(syn_path) 77 | 78 | with open(f'{data_dir}/info.json', 'r') as f: 79 | info = json.load(f) 80 | 81 | syn_data = pd.read_csv(syn_path) 82 | real_data = pd.read_csv(real_path) 83 | 84 | save_dir = f'eval/density/{dataname}/{model}' 85 | if not os.path.exists(save_dir): 86 | os.makedirs(save_dir) 87 | 88 | real_data.columns = range(len(real_data.columns)) 89 | syn_data.columns = range(len(syn_data.columns)) 90 | 91 | metadata = info['metadata'] 92 | metadata['columns'] = {int(key): value for key, value in metadata['columns'].items()} 93 | 94 | new_real_data, new_syn_data, metadata = reorder(real_data, syn_data, info) 95 | 96 | # qual_report.generate(new_real_data, new_syn_data, metadata) 97 | 98 | score = LogisticDetection.compute( 99 | real_data=new_real_data, 100 | synthetic_data=new_syn_data, 101 | metadata=metadata 102 | ) 103 | 104 | print(f'{dataname}, {model}: {score}') -------------------------------------------------------------------------------- /eval/eval_mle.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pandas as pd 4 | import os 5 | import sys 6 | 7 | import json 8 | from mle.mle import get_evaluator 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | 14 | import argparse 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dataname', type=str, default='adult') 18 | parser.add_argument('--model', type=str, default='real') 19 | parser.add_argument('--path', type=str, default = None, help='The file path of the synthetic data') 20 | 21 | args = parser.parse_args() 22 | 23 | # def preprocess(train, test, info) 24 | 25 | # def norm_data(data, ) 26 | 27 | if __name__ == '__main__': 28 | 29 | dataname = args.dataname 30 | model = args.model 31 | 32 | if not args.path: 33 | train_path = f'synthetic/{dataname}/{model}.csv' 34 | else: 35 | train_path = args.path 36 | test_path = f'synthetic/{dataname}/test.csv' 37 | 38 | train = pd.read_csv(train_path).to_numpy() 39 | test = pd.read_csv(test_path).to_numpy() 40 | 41 | with open(f'data/{dataname}/info.json', 'r') as f: 42 | info = json.load(f) 43 | 44 | task_type = info['task_type'] 45 | 46 | evaluator = get_evaluator(task_type) 47 | 48 | if task_type == 'regression': 49 | best_r2_scores, best_rmse_scores = evaluator(train, test, info) 50 | 51 | overall_scores = {} 52 | for score_name in ['best_r2_scores', 'best_rmse_scores']: 53 | overall_scores[score_name] = {} 54 | 55 | scores = eval(score_name) 56 | for method in scores: 57 | name = method['name'] 58 | method.pop('name') 59 | overall_scores[score_name][name] = method 60 | 61 | else: 62 | best_f1_scores, best_weighted_scores, best_auroc_scores, best_acc_scores, best_avg_scores = evaluator(train, test, info) 63 | 64 | overall_scores = {} 65 | for score_name in ['best_f1_scores', 'best_weighted_scores', 'best_auroc_scores', 'best_acc_scores', 'best_avg_scores']: 66 | overall_scores[score_name] = {} 67 | 68 | scores = eval(score_name) 69 | for method in scores: 70 | name = method['name'] 71 | method.pop('name') 72 | overall_scores[score_name][name] = method 73 | 74 | if not os.path.exists(f'eval/mle/{dataname}'): 75 | os.makedirs(f'eval/mle/{dataname}') 76 | 77 | save_path = f'eval/mle/{dataname}/{model}.json' 78 | print('Saving scores to ', save_path) 79 | with open(save_path, "w") as json_file: 80 | json.dump(overall_scores, json_file, indent=4, separators=(", ", ": ")) 81 | 82 | -------------------------------------------------------------------------------- /eval/eval_quality.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import sys 5 | import json 6 | 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from utils_train import preprocess, TabularDataset 9 | from sklearn.preprocessing import OneHotEncoder 10 | from synthcity.metrics import eval_detection, eval_performance, eval_statistical 11 | from synthcity.plugins.core.dataloader import GenericDataLoader 12 | 13 | pd.options.mode.chained_assignment = None 14 | 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataname', type=str, default='adult') 19 | parser.add_argument('--model', type=str, default='model') 20 | parser.add_argument('--path', type=str, default = None, help='The file path of the synthetic data') 21 | 22 
| 23 | args = parser.parse_args() 24 | 25 | 26 | if __name__ == '__main__': 27 | 28 | dataname = args.dataname 29 | model = args.model 30 | 31 | if not args.path: 32 | syn_path = f'synthetic/{dataname}/{model}.csv' 33 | else: 34 | syn_path = args.path 35 | real_path = f'synthetic/{dataname}/real.csv' 36 | 37 | data_dir = f'data/{dataname}' 38 | 39 | print(syn_path) 40 | 41 | 42 | with open(f'{data_dir}/info.json', 'r') as f: 43 | info = json.load(f) 44 | 45 | syn_data = pd.read_csv(syn_path) 46 | real_data = pd.read_csv(real_path) 47 | 48 | 49 | ''' Special treatment for default dataset and CoDi model ''' 50 | 51 | real_data.columns = range(len(real_data.columns)) 52 | syn_data.columns = range(len(syn_data.columns)) 53 | 54 | num_col_idx = info['num_col_idx'] 55 | cat_col_idx = info['cat_col_idx'] 56 | target_col_idx = info['target_col_idx'] 57 | if info['task_type'] == 'regression': 58 | num_col_idx += target_col_idx 59 | else: 60 | cat_col_idx += target_col_idx 61 | 62 | num_real_data = real_data[num_col_idx] 63 | cat_real_data = real_data[cat_col_idx] 64 | 65 | num_real_data_np = num_real_data.to_numpy() 66 | cat_real_data_np = cat_real_data.to_numpy().astype('str') 67 | 68 | 69 | num_syn_data = syn_data[num_col_idx] 70 | cat_syn_data = syn_data[cat_col_idx] 71 | 72 | num_syn_data_np = num_syn_data.to_numpy() 73 | 74 | # cat_syn_data_np = np.array 75 | cat_syn_data_np = cat_syn_data.to_numpy().astype('str') 76 | if (dataname == 'default' or dataname == 'news') and model[:4] == 'codi': 77 | cat_syn_data_np = cat_syn_data.astype('int').to_numpy().astype('str') 78 | 79 | elif model[:5] == 'great': 80 | if dataname == 'shoppers': 81 | cat_syn_data_np[:, 1] = cat_syn_data[11].astype('int').to_numpy().astype('str') 82 | cat_syn_data_np[:, 2] = cat_syn_data[12].astype('int').to_numpy().astype('str') 83 | cat_syn_data_np[:, 3] = cat_syn_data[13].astype('int').to_numpy().astype('str') 84 | 85 | max_data = cat_real_data[14].max() 86 | 87 | cat_syn_data.loc[cat_syn_data[14] > max_data, 14] = max_data 88 | # cat_syn_data[14] = cat_syn_data[14].apply(lambda x: threshold if x > max_data else x) 89 | 90 | cat_syn_data_np[:, 4] = cat_syn_data[14].astype('int').to_numpy().astype('str') 91 | cat_syn_data_np[:, 4] = cat_syn_data[14].astype('int').to_numpy().astype('str') 92 | 93 | elif dataname in ['default', 'faults', 'beijing']: 94 | 95 | columns = cat_real_data.columns 96 | for i, col in enumerate(columns): 97 | if (cat_real_data[col].dtype == 'int'): 98 | 99 | max_data = cat_real_data[col].max() 100 | min_data = cat_real_data[col].min() 101 | 102 | cat_syn_data.loc[cat_syn_data[col] > max_data, col] = max_data 103 | cat_syn_data.loc[cat_syn_data[col] < min_data, col] = min_data 104 | 105 | cat_syn_data_np[:, i] = cat_syn_data[col].astype('int').to_numpy().astype('str') 106 | 107 | else: 108 | cat_syn_data_np = cat_syn_data.to_numpy().astype('str') 109 | 110 | else: 111 | cat_syn_data_np = cat_syn_data.to_numpy().astype('str') 112 | 113 | encoder = OneHotEncoder() 114 | encoder.fit(cat_real_data_np) 115 | 116 | 117 | cat_real_data_oh = encoder.transform(cat_real_data_np).toarray() 118 | cat_syn_data_oh = encoder.transform(cat_syn_data_np).toarray() 119 | 120 | le_real_data = pd.DataFrame(np.concatenate((num_real_data_np, cat_real_data_oh), axis = 1)).astype(float) 121 | le_real_num = pd.DataFrame(num_real_data_np).astype(float) 122 | le_real_cat = pd.DataFrame(cat_real_data_oh).astype(float) 123 | 124 | 125 | le_syn_data = pd.DataFrame(np.concatenate((num_syn_data_np, cat_syn_data_oh), axis = 
1)).astype(float) 126 | le_syn_num = pd.DataFrame(num_syn_data_np).astype(float) 127 | le_syn_cat = pd.DataFrame(cat_syn_data_oh).astype(float) 128 | 129 | np.set_printoptions(precision=4) 130 | 131 | result = [] 132 | 133 | print('=========== All Features ===========') 134 | print('Data shape: ', le_syn_data.shape) 135 | 136 | X_syn_loader = GenericDataLoader(le_syn_data) 137 | X_real_loader = GenericDataLoader(le_real_data) 138 | 139 | quality_evaluator = eval_statistical.AlphaPrecision() 140 | qual_res = quality_evaluator.evaluate(X_real_loader, X_syn_loader) 141 | qual_res = { 142 | k: v for (k, v) in qual_res.items() if "naive" in k 143 | } # use the naive implementation of AlphaPrecision 144 | qual_score = np.mean(list(qual_res.values())) 145 | 146 | print('alpha precision: {:.6f}, beta recall: {:.6f}'.format(qual_res['delta_precision_alpha_naive'], qual_res['delta_coverage_beta_naive'] )) 147 | 148 | Alpha_Precision_all = qual_res['delta_precision_alpha_naive'] 149 | Beta_Recall_all = qual_res['delta_coverage_beta_naive'] 150 | 151 | save_dir = f'eval/quality/{dataname}' 152 | if not os.path.exists(save_dir): 153 | os.makedirs(save_dir) 154 | 155 | with open(f'{save_dir}/{model}.txt', 'w') as f: 156 | f.write(f'{Alpha_Precision_all}\n') 157 | f.write(f'{Beta_Recall_all}\n') 158 | -------------------------------------------------------------------------------- /eval/mle/tabular_dataload.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # pylint: skip-file 17 | """Return training and evaluation/test datasets from config files.""" 18 | import torch 19 | import numpy as np 20 | import pandas as pd 21 | from tabular_transformer import GeneralTransformer 22 | import json 23 | import logging 24 | import os 25 | 26 | CATEGORICAL = "categorical" 27 | CONTINUOUS = "continuous" 28 | 29 | LOGGER = logging.getLogger(__name__) 30 | 31 | DATA_PATH = os.path.join(os.path.dirname(__file__), 'tabular_datasets') 32 | 33 | def _load_json(path): 34 | with open(path) as json_file: 35 | return json.load(json_file) 36 | 37 | 38 | def _load_file(filename, loader): 39 | local_path = os.path.join(DATA_PATH, filename) 40 | 41 | if loader == np.load: 42 | return loader(local_path, allow_pickle=True) 43 | return loader(local_path) 44 | 45 | 46 | def _get_columns(metadata): 47 | categorical_columns = list() 48 | 49 | for column_idx, column in enumerate(metadata['columns']): 50 | if column['type'] == CATEGORICAL: 51 | categorical_columns.append(column_idx) 52 | 53 | return categorical_columns 54 | 55 | 56 | def load_data(name): 57 | data_dir = f'data/{name}' 58 | info_path = f'{data_dir}/info.json' 59 | 60 | train = pd.read_csv(f'{data_dir}/train.csv').to_numpy() 61 | test = pd.read_csv(f'{data_dir}/test.csv').to_numpy() 62 | 63 | with open(f'{data_dir}/info.json', 'r') as f: 64 | info = json.load(f) 65 | 66 | task_type = info['task_type'] 67 | 68 | num_cols = info['num_col_idx'] 69 | cat_cols = info['cat_col_idx'] 70 | target_cols = info['target_col_idx'] 71 | 72 | if task_type != 'regression': 73 | cat_cols = cat_cols + target_cols 74 | 75 | return train, test, (cat_cols, info) 76 | 77 | 78 | def get_dataset(FLAGS, evaluation=False): 79 | 80 | batch_size = FLAGS.training_batch_size if not evaluation else FLAGS.eval_batch_size 81 | 82 | if batch_size % torch.cuda.device_count() != 0: 83 | raise ValueError(f'Batch sizes ({batch_size} must be divided by' 84 | f'the number of devices ({torch.cuda.device_count()})') 85 | 86 | 87 | # Create dataset builders for tabular data. 
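# cols[0] holds the categorical column indices (with the target column appended for
# classification tasks), so every remaining column is treated as continuous. The two
# blocks are then transformed separately below: min-max scaling to [-1, 1] (the default
# 'tanh' activation of GeneralTransformer) for continuous columns, and one-hot encoding
# for categorical columns.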
88 | train, test, cols = load_data(FLAGS.dataname) 89 | cols_idx = list(np.arange(train.shape[1])) 90 | dis_idx = cols[0] 91 | con_idx = [x for x in cols_idx if x not in dis_idx] 92 | 93 | #split continuous and categorical 94 | train_con = train[:,con_idx] 95 | train_dis = train[:,dis_idx] 96 | 97 | #new index 98 | cat_idx_ = list(np.arange(train_dis.shape[1]))[:len(cols[0])] 99 | 100 | transformer_con = GeneralTransformer() 101 | transformer_dis = GeneralTransformer() 102 | 103 | transformer_con.fit(train_con, []) 104 | transformer_dis.fit(train_dis, cat_idx_) 105 | 106 | train_con_data = transformer_con.transform(train_con) 107 | train_dis_data = transformer_dis.transform(train_dis) 108 | 109 | 110 | return train, train_con_data, train_dis_data, test, (transformer_con, transformer_dis, cols[1]), con_idx, dis_idx 111 | -------------------------------------------------------------------------------- /eval/mle/tabular_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | CATEGORICAL = "categorical" 5 | CONTINUOUS = "continuous" 6 | 7 | class Transformer: 8 | 9 | @staticmethod 10 | def get_metadata(data, categorical_columns=tuple()): 11 | meta = [] 12 | 13 | df = pd.DataFrame(data) 14 | for index in df: 15 | column = df[index] 16 | 17 | if index in categorical_columns: 18 | mapper = column.value_counts().index.tolist() 19 | meta.append({ 20 | "name": index, 21 | "type": CATEGORICAL, 22 | "size": len(mapper), 23 | "i2s": mapper 24 | }) 25 | else: 26 | meta.append({ 27 | "name": index, 28 | "type": CONTINUOUS, 29 | "min": column.min(), 30 | "max": column.max(), 31 | }) 32 | 33 | return meta 34 | 35 | def fit(self, data, categorical_columns=tuple()): 36 | raise NotImplementedError 37 | 38 | def transform(self, data): 39 | raise NotImplementedError 40 | 41 | def inverse_transform(self, data): 42 | raise NotImplementedError 43 | 44 | 45 | class GeneralTransformer(Transformer): 46 | 47 | def __init__(self, act='tanh'): 48 | self.act = act 49 | self.meta = None 50 | self.output_dim = None 51 | 52 | def fit(self, data, categorical_columns=tuple()): 53 | self.meta = self.get_metadata(data, categorical_columns) 54 | self.output_dim = 0 55 | for info in self.meta: 56 | if info['type'] in [CONTINUOUS]: 57 | self.output_dim += 1 58 | else: 59 | self.output_dim += info['size'] 60 | 61 | def transform(self, data): 62 | data_t = [] 63 | self.output_info = [] 64 | for id_, info in enumerate(self.meta): 65 | col = data[:, id_] 66 | if info['type'] == CONTINUOUS: 67 | col = (col - (info['min'])) / (info['max'] - info['min']) 68 | if self.act == 'tanh': 69 | col = col * 2 - 1 70 | data_t.append(col.reshape([-1, 1])) 71 | self.output_info.append((1, self.act)) 72 | 73 | else: 74 | col_t = np.zeros([len(data), info['size']]) 75 | idx = list(map(info['i2s'].index, col)) 76 | col_t[np.arange(len(data)), idx] = 1 77 | data_t.append(col_t) 78 | self.output_info.append((info['size'], 'softmax')) 79 | 80 | return np.concatenate(data_t, axis=1) 81 | 82 | def inverse_transform(self, data): 83 | if self.meta[1]['type'] == CONTINUOUS: 84 | data_t = np.zeros([len(data), len(self.meta)]) 85 | else: 86 | dtype = np.dtype('U50') 87 | data_t = np.empty([len(data), len(self.meta)], dtype=dtype) 88 | 89 | 90 | data = data.copy() 91 | for id_, info in enumerate(self.meta): 92 | 93 | if info['type'] == CONTINUOUS: 94 | current = data[:, 0] 95 | data = data[:, 1:] 96 | 97 | if self.act == 'tanh': 98 | current = (current + 1) / 2 99 | 
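# Having undone the 'tanh' scaling, clip to [0, 1] and map back onto the column's original [min, max] range.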
100 | current = np.clip(current, 0, 1) 101 | data_t[:, id_] = current * (info['max'] - info['min']) + info['min'] 102 | 103 | else: 104 | current = data[:, :info['size']] 105 | data = data[:, info['size']:] 106 | idx = np.argmax(current, axis=1) 107 | recovered = list(map(info['i2s'].__getitem__, idx)) 108 | 109 | data_t[:, id_] = recovered 110 | return data_t 111 | -------------------------------------------------------------------------------- /eval_impute.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.preprocessing import OneHotEncoder 4 | from sklearn.metrics import f1_score, roc_auc_score 5 | import argparse 6 | 7 | 8 | parser = argparse.ArgumentParser(description='Missing Value Imputation') 9 | 10 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 11 | parser.add_argument('--col', type=int, default=0, help='Numerical Column to Impute') 12 | 13 | args = parser.parse_args() 14 | 15 | dataname = args.dataname 16 | col = args.col 17 | 18 | dataname = args.dataname 19 | 20 | data_dir = f'data/{dataname}' 21 | 22 | real_path = f'{data_dir}/test.csv' 23 | 24 | 25 | encoder = OneHotEncoder() 26 | 27 | real_data = pd.read_csv(real_path) 28 | target_col = real_data.columns[-1] 29 | real_target = real_data[target_col].to_numpy().reshape(-1,1) 30 | real_y = encoder.fit_transform(real_target).toarray() 31 | 32 | 33 | syn_y = [] 34 | for i in range(50): 35 | syn_path = f'impute/{i}.csv' 36 | syn_data = pd.read_csv(syn_path) 37 | target = syn_data[target_col].to_numpy().reshape(-1, 1) 38 | syn_y.append(encoder.transform(target).toarray()) 39 | 40 | syn_y = np.stack(syn_y).mean(0) 41 | 42 | 43 | micro_f1 = f1_score(real_y.argmax(axis=1), syn_y.argmax(axis=1), average='micro') 44 | auc = roc_auc_score(real_y, syn_y, average='micro') 45 | 46 | print("Micro-F1:", micro_f1) 47 | print("AUC:", auc) 48 | -------------------------------------------------------------------------------- /images/density.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/images/density.jpg -------------------------------------------------------------------------------- /images/heat_map.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/images/heat_map.jpg -------------------------------------------------------------------------------- /images/nfe1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/images/nfe1.jpg -------------------------------------------------------------------------------- /images/radar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/images/radar.jpg -------------------------------------------------------------------------------- /images/tabsyn_model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/tabsyn/cb5ac0f74ec36ee88e7a974a393dfbef50d42da7/images/tabsyn_model.jpg -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from utils import execute_function, get_args 3 | 4 | if __name__ == '__main__': 5 | args = get_args() 6 | if torch.cuda.is_available(): 7 | args.device = f'cuda:{args.gpu}' 8 | else: 9 | args.device = 'cpu' 10 | 11 | if not args.save_path: 12 | args.save_path = f'synthetic/{args.dataname}/{args.method}.csv' 13 | main_fn = execute_function(args.method, args.mode) 14 | 15 | main_fn(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | scikit-learn 4 | scipy 5 | icecream 6 | xlrd 7 | tomli-w 8 | zero 9 | category_encoders 10 | imbalanced-learn 11 | kaggle 12 | transformers 13 | datasets 14 | peft==0.9.0 15 | ml_collections 16 | sdmetrics==0.11.1 17 | prdc 18 | rdt 19 | openpyxl 20 | xgboost -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from icecream import install 3 | 4 | torch.set_num_threads(1) 5 | install() 6 | 7 | from . import env # noqa 8 | from .data import * # noqa 9 | from .deep import * # noqa 10 | from .env import * # noqa 11 | from .metrics import * # noqa 12 | from .util import * # noqa -------------------------------------------------------------------------------- /src/deep.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | from dataclasses import dataclass 3 | from typing import Any, Callable, Literal, cast 4 | 5 | # import rtdl 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | import zero 11 | from torch import Tensor 12 | 13 | from .util import TaskType 14 | 15 | 16 | def cos_sin(x: Tensor) -> Tensor: 17 | return torch.cat([torch.cos(x), torch.sin(x)], -1) 18 | 19 | 20 | @dataclass 21 | class PeriodicOptions: 22 | n: int # the output size is 2 * n 23 | sigma: float 24 | trainable: bool 25 | initialization: Literal['log-linear', 'normal'] 26 | 27 | 28 | class Periodic(nn.Module): 29 | def __init__(self, n_features: int, options: PeriodicOptions) -> None: 30 | super().__init__() 31 | if options.initialization == 'log-linear': 32 | coefficients = options.sigma ** (torch.arange(options.n) / options.n) 33 | coefficients = coefficients[None].repeat(n_features, 1) 34 | else: 35 | assert options.initialization == 'normal' 36 | coefficients = torch.normal(0.0, options.sigma, (n_features, options.n)) 37 | if options.trainable: 38 | self.coefficients = nn.Parameter(coefficients) # type: ignore[code] 39 | else: 40 | self.register_buffer('coefficients', coefficients) 41 | 42 | def forward(self, x: Tensor) -> Tensor: 43 | assert x.ndim == 2 44 | return cos_sin(2 * torch.pi * self.coefficients[None] * x[..., None]) 45 | 46 | 47 | def get_n_parameters(m: nn.Module): 48 | return sum(x.numel() for x in m.parameters() if x.requires_grad) 49 | 50 | 51 | def get_loss_fn(task_type: TaskType) -> Callable[..., Tensor]: 52 | return ( 53 | F.binary_cross_entropy_with_logits 54 | if task_type == TaskType.BINCLASS 55 | else F.cross_entropy 56 | if task_type == TaskType.MULTICLASS 57 | else F.mse_loss 58 | ) 59 | 60 | 61 | def default_zero_weight_decay_condition(module_name, module, parameter_name, parameter): 62 | del module_name, parameter 63 | return 
parameter_name.endswith('bias') or isinstance( 64 | module, 65 | ( 66 | nn.BatchNorm1d, 67 | nn.LayerNorm, 68 | nn.InstanceNorm1d, 69 | rtdl.CLSToken, 70 | rtdl.NumericalFeatureTokenizer, 71 | rtdl.CategoricalFeatureTokenizer, 72 | Periodic, 73 | ), 74 | ) 75 | 76 | 77 | def split_parameters_by_weight_decay( 78 | model: nn.Module, zero_weight_decay_condition=default_zero_weight_decay_condition 79 | ) -> list[dict[str, Any]]: 80 | parameters_info = {} 81 | for module_name, module in model.named_modules(): 82 | for parameter_name, parameter in module.named_parameters(): 83 | full_parameter_name = ( 84 | f'{module_name}.{parameter_name}' if module_name else parameter_name 85 | ) 86 | parameters_info.setdefault(full_parameter_name, ([], parameter))[0].append( 87 | zero_weight_decay_condition( 88 | module_name, module, parameter_name, parameter 89 | ) 90 | ) 91 | params_with_wd = {'params': []} 92 | params_without_wd = {'params': [], 'weight_decay': 0.0} 93 | for full_parameter_name, (results, parameter) in parameters_info.items(): 94 | (params_without_wd if any(results) else params_with_wd)['params'].append( 95 | parameter 96 | ) 97 | return [params_with_wd, params_without_wd] 98 | 99 | 100 | def make_optimizer( 101 | config: dict[str, Any], 102 | parameter_groups, 103 | ) -> optim.Optimizer: 104 | if config['optimizer'] == 'FT-Transformer-default': 105 | return optim.AdamW(parameter_groups, lr=1e-4, weight_decay=1e-5) 106 | return getattr(optim, config['optimizer'])( 107 | parameter_groups, 108 | **{x: config[x] for x in ['lr', 'weight_decay', 'momentum'] if x in config}, 109 | ) 110 | 111 | 112 | def get_lr(optimizer: optim.Optimizer) -> float: 113 | return next(iter(optimizer.param_groups))['lr'] 114 | 115 | 116 | def is_oom_exception(err: RuntimeError) -> bool: 117 | return any( 118 | x in str(err) 119 | for x in [ 120 | 'CUDA out of memory', 121 | 'CUBLAS_STATUS_ALLOC_FAILED', 122 | 'CUDA error: out of memory', 123 | ] 124 | ) 125 | 126 | 127 | def train_with_auto_virtual_batch( 128 | optimizer, 129 | loss_fn, 130 | step, 131 | batch, 132 | chunk_size: int, 133 | ) -> tuple[Tensor, int]: 134 | batch_size = len(batch) 135 | random_state = zero.random.get_state() 136 | loss = None 137 | while chunk_size != 0: 138 | try: 139 | zero.random.set_state(random_state) 140 | optimizer.zero_grad() 141 | if batch_size <= chunk_size: 142 | loss = loss_fn(*step(batch)) 143 | loss.backward() 144 | else: 145 | loss = None 146 | for chunk in zero.iter_batches(batch, chunk_size): 147 | chunk_loss = loss_fn(*step(chunk)) 148 | chunk_loss = chunk_loss * (len(chunk) / batch_size) 149 | chunk_loss.backward() 150 | if loss is None: 151 | loss = chunk_loss.detach() 152 | else: 153 | loss += chunk_loss.detach() 154 | except RuntimeError as err: 155 | if not is_oom_exception(err): 156 | raise 157 | chunk_size //= 2 158 | else: 159 | break 160 | if not chunk_size: 161 | raise RuntimeError('Not enough memory even for batch_size=1') 162 | optimizer.step() 163 | return cast(Tensor, loss), chunk_size 164 | 165 | 166 | def process_epoch_losses(losses: list[Tensor]) -> tuple[list[float], float]: 167 | losses_ = torch.stack(losses).tolist() 168 | return losses_, statistics.mean(losses_) -------------------------------------------------------------------------------- /src/env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Have not used in TabDDPM project. 
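Path helpers carried over from the original tab-ddpm code base, kept here for reference.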
3 | """ 4 | 5 | import datetime 6 | import os 7 | import shutil 8 | import typing as ty 9 | from pathlib import Path 10 | 11 | PROJ = Path('tab-ddpm/').absolute().resolve() 12 | EXP = PROJ / 'exp' 13 | DATA = PROJ / 'data' 14 | 15 | 16 | def get_path(path: ty.Union[str, Path]) -> Path: 17 | if isinstance(path, str): 18 | path = Path(path) 19 | if not path.is_absolute(): 20 | path = PROJ / path 21 | return path.resolve() 22 | 23 | 24 | def get_relative_path(path: ty.Union[str, Path]) -> Path: 25 | return get_path(path).relative_to(PROJ) 26 | 27 | 28 | def duplicate_path( 29 | src: ty.Union[str, Path], alternative_project_dir: ty.Union[str, Path] 30 | ) -> None: 31 | src = get_path(src) 32 | alternative_project_dir = get_path(alternative_project_dir) 33 | dst = alternative_project_dir / src.relative_to(PROJ) 34 | dst.parent.mkdir(parents=True, exist_ok=True) 35 | if dst.exists(): 36 | dst = dst.with_name( 37 | dst.name + '_' + datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 38 | ) 39 | (shutil.copytree if src.is_dir() else shutil.copyfile)(src, dst) -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from typing import Any, Optional, Tuple, Dict, Union, cast 3 | from functools import partial 4 | 5 | import numpy as np 6 | import scipy.special 7 | import sklearn.metrics as skm 8 | 9 | from . import util 10 | from .util import TaskType 11 | 12 | 13 | class PredictionType(enum.Enum): 14 | LOGITS = 'logits' 15 | PROBS = 'probs' 16 | 17 | class MetricsReport: 18 | def __init__(self, report: dict, task_type: TaskType): 19 | self._res = {k: {} for k in report.keys()} 20 | if task_type in (TaskType.BINCLASS, TaskType.MULTICLASS): 21 | self._metrics_names = ["acc", "f1"] 22 | for k in report.keys(): 23 | self._res[k]["acc"] = report[k]["accuracy"] 24 | self._res[k]["f1"] = report[k]["macro avg"]["f1-score"] 25 | if task_type == TaskType.BINCLASS: 26 | self._res[k]["roc_auc"] = report[k]["roc_auc"] 27 | self._metrics_names.append("roc_auc") 28 | 29 | elif task_type == TaskType.REGRESSION: 30 | self._metrics_names = ["r2", "rmse"] 31 | for k in report.keys(): 32 | self._res[k]["r2"] = report[k]["r2"] 33 | self._res[k]["rmse"] = report[k]["rmse"] 34 | else: 35 | raise "Unknown TaskType!" 
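# self._res maps split name -> {metric name -> value}; the getters below read from it.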
36 | 37 | def get_splits_names(self) -> list[str]: 38 | return self._res.keys() 39 | 40 | def get_metrics_names(self) -> list[str]: 41 | return self._metrics_names 42 | 43 | def get_metric(self, split: str, metric: str) -> float: 44 | return self._res[split][metric] 45 | 46 | def get_val_score(self) -> float: 47 | return self._res["val"]["r2"] if "r2" in self._res["val"] else self._res["val"]["f1"] 48 | 49 | def get_test_score(self) -> float: 50 | return self._res["test"]["r2"] if "r2" in self._res["test"] else self._res["test"]["f1"] 51 | 52 | def print_metrics(self) -> None: 53 | res = { 54 | "val": {k: np.around(self._res["val"][k], 4) for k in self._res["val"]}, 55 | "test": {k: np.around(self._res["test"][k], 4) for k in self._res["test"]} 56 | } 57 | 58 | print("*"*100) 59 | print("[val]") 60 | print(res["val"]) 61 | print("[test]") 62 | print(res["test"]) 63 | 64 | return res 65 | 66 | class SeedsMetricsReport: 67 | def __init__(self): 68 | self._reports = [] 69 | 70 | def add_report(self, report: MetricsReport) -> None: 71 | self._reports.append(report) 72 | 73 | def get_mean_std(self) -> dict: 74 | res = {k: {} for k in ["train", "val", "test"]} 75 | for split in self._reports[0].get_splits_names(): 76 | for metric in self._reports[0].get_metrics_names(): 77 | res[split][metric] = [x.get_metric(split, metric) for x in self._reports] 78 | 79 | agg_res = {k: {} for k in ["train", "val", "test"]} 80 | for split in self._reports[0].get_splits_names(): 81 | for metric in self._reports[0].get_metrics_names(): 82 | for k, f in [("count", len), ("mean", np.mean), ("std", np.std)]: 83 | agg_res[split][f"{metric}-{k}"] = f(res[split][metric]) 84 | self._res = res 85 | self._agg_res = agg_res 86 | 87 | return agg_res 88 | 89 | def print_result(self) -> dict: 90 | res = {split: {k: float(np.around(self._agg_res[split][k], 4)) for k in self._agg_res[split]} for split in ["val", "test"]} 91 | print("="*100) 92 | print("EVAL RESULTS:") 93 | print("[val]") 94 | print(res["val"]) 95 | print("[test]") 96 | print(res["test"]) 97 | print("="*100) 98 | return res 99 | 100 | def calculate_rmse( 101 | y_true: np.ndarray, y_pred: np.ndarray, std = None) -> float: 102 | rmse = skm.mean_squared_error(y_true, y_pred) ** 0.5 103 | if std is not None: 104 | rmse *= std 105 | return rmse 106 | 107 | 108 | def _get_labels_and_probs( 109 | y_pred: np.ndarray, task_type: TaskType, prediction_type: Optional[PredictionType] 110 | ) -> Tuple[np.ndarray, Optional[np.ndarray]]: 111 | assert task_type in (TaskType.BINCLASS, TaskType.MULTICLASS) 112 | 113 | if prediction_type is None: 114 | return y_pred, None 115 | 116 | if prediction_type == PredictionType.LOGITS: 117 | probs = ( 118 | scipy.special.expit(y_pred) 119 | if task_type == TaskType.BINCLASS 120 | else scipy.special.softmax(y_pred, axis=1) 121 | ) 122 | elif prediction_type == PredictionType.PROBS: 123 | probs = y_pred 124 | else: 125 | util.raise_unknown('prediction_type', prediction_type) 126 | 127 | assert probs is not None 128 | labels = np.round(probs) if task_type == TaskType.BINCLASS else probs.argmax(axis=1) 129 | return labels.astype('int64'), probs 130 | 131 | 132 | def calculate_metrics( 133 | y_true: np.ndarray, 134 | y_pred: np.ndarray, 135 | task_type: Union[str, TaskType], 136 | prediction_type: Optional[Union[str, PredictionType]], 137 | y_info: Dict[str, Any], 138 | ) -> Dict[str, Any]: 139 | # Example: calculate_metrics(y_true, y_pred, 'binclass', 'logits', {}) 140 | task_type = TaskType(task_type) 141 | if prediction_type is not None: 142 
| prediction_type = PredictionType(prediction_type) 143 | 144 | if task_type == TaskType.REGRESSION: 145 | assert prediction_type is None 146 | assert 'std' in y_info 147 | rmse = calculate_rmse(y_true, y_pred, y_info['std']) 148 | r2 = skm.r2_score(y_true, y_pred) 149 | result = {'rmse': rmse, 'r2': r2} 150 | else: 151 | labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type) 152 | result = cast( 153 | Dict[str, Any], skm.classification_report(y_true, labels, output_dict=True) 154 | ) 155 | if task_type == TaskType.BINCLASS: 156 | result['roc_auc'] = skm.roc_auc_score(y_true, probs) 157 | return result -------------------------------------------------------------------------------- /tabsyn/latent_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from utils_train import preprocess 7 | from tabsyn.vae.model import Decoder_model 8 | 9 | def get_input_train(args): 10 | dataname = args.dataname 11 | 12 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 13 | dataset_dir = f'data/{dataname}' 14 | 15 | with open(f'{dataset_dir}/info.json', 'r') as f: 16 | info = json.load(f) 17 | 18 | ckpt_dir = f'{curr_dir}/ckpt/{dataname}/' 19 | embedding_save_path = f'{curr_dir}/vae/ckpt/{dataname}/train_z.npy' 20 | train_z = torch.tensor(np.load(embedding_save_path)).float() 21 | 22 | train_z = train_z[:, 1:, :] 23 | B, num_tokens, token_dim = train_z.size() 24 | in_dim = num_tokens * token_dim 25 | 26 | train_z = train_z.view(B, in_dim) 27 | 28 | return train_z, curr_dir, dataset_dir, ckpt_dir, info 29 | 30 | 31 | def get_input_generate(args): 32 | dataname = args.dataname 33 | 34 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 35 | dataset_dir = f'data/{dataname}' 36 | ckpt_dir = f'{curr_dir}/ckpt/{dataname}' 37 | 38 | with open(f'{dataset_dir}/info.json', 'r') as f: 39 | info = json.load(f) 40 | 41 | task_type = info['task_type'] 42 | 43 | 44 | ckpt_dir = f'{curr_dir}/ckpt/{dataname}' 45 | 46 | _, _, categories, d_numerical, num_inverse, cat_inverse = preprocess(dataset_dir, task_type = task_type, inverse = True) 47 | 48 | embedding_save_path = f'{curr_dir}/vae/ckpt/{dataname}/train_z.npy' 49 | train_z = torch.tensor(np.load(embedding_save_path)).float() 50 | 51 | train_z = train_z[:, 1:, :] 52 | 53 | B, num_tokens, token_dim = train_z.size() 54 | in_dim = num_tokens * token_dim 55 | 56 | train_z = train_z.view(B, in_dim) 57 | pre_decoder = Decoder_model(2, d_numerical, categories, 4, n_head = 1, factor = 32) 58 | 59 | decoder_save_path = f'{curr_dir}/vae/ckpt/{dataname}/decoder.pt' 60 | pre_decoder.load_state_dict(torch.load(decoder_save_path)) 61 | 62 | info['pre_decoder'] = pre_decoder 63 | info['token_dim'] = token_dim 64 | 65 | return train_z, curr_dir, dataset_dir, ckpt_dir, info, num_inverse, cat_inverse 66 | 67 | 68 | 69 | @torch.no_grad() 70 | def split_num_cat_target(syn_data, info, num_inverse, cat_inverse, device): 71 | task_type = info['task_type'] 72 | 73 | num_col_idx = info['num_col_idx'] 74 | cat_col_idx = info['cat_col_idx'] 75 | target_col_idx = info['target_col_idx'] 76 | 77 | n_num_feat = len(num_col_idx) 78 | n_cat_feat = len(cat_col_idx) 79 | 80 | if task_type == 'regression': 81 | n_num_feat += len(target_col_idx) 82 | else: 83 | n_cat_feat += len(target_col_idx) 84 | 85 | 86 | pre_decoder = info['pre_decoder'] 87 | token_dim = info['token_dim'] 88 | 89 | syn_data = syn_data.reshape(syn_data.shape[0], -1, token_dim) 
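# The pre-trained VAE decoder maps the latent tokens back to feature space: x_hat_num
# holds the reconstructed numerical columns, and x_hat_cat is a list of per-column
# categorical logits that are converted to category ids via argmax below.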
90 | 91 | norm_input = pre_decoder(torch.tensor(syn_data)) 92 | x_hat_num, x_hat_cat = norm_input 93 | 94 | syn_cat = [] 95 | for pred in x_hat_cat: 96 | syn_cat.append(pred.argmax(dim = -1)) 97 | 98 | syn_num = x_hat_num.cpu().numpy() 99 | syn_cat = torch.stack(syn_cat).t().cpu().numpy() 100 | 101 | syn_num = num_inverse(syn_num) 102 | syn_cat = cat_inverse(syn_cat) 103 | 104 | if info['task_type'] == 'regression': 105 | syn_target = syn_num[:, :len(target_col_idx)] 106 | syn_num = syn_num[:, len(target_col_idx):] 107 | 108 | else: 109 | print(syn_cat.shape) 110 | syn_target = syn_cat[:, :len(target_col_idx)] 111 | syn_cat = syn_cat[:, len(target_col_idx):] 112 | 113 | return syn_num, syn_cat, syn_target 114 | 115 | def recover_data(syn_num, syn_cat, syn_target, info): 116 | 117 | num_col_idx = info['num_col_idx'] 118 | cat_col_idx = info['cat_col_idx'] 119 | target_col_idx = info['target_col_idx'] 120 | 121 | 122 | idx_mapping = info['idx_mapping'] 123 | idx_mapping = {int(key): value for key, value in idx_mapping.items()} 124 | 125 | syn_df = pd.DataFrame() 126 | 127 | if info['task_type'] == 'regression': 128 | for i in range(len(num_col_idx) + len(cat_col_idx) + len(target_col_idx)): 129 | if i in set(num_col_idx): 130 | syn_df[i] = syn_num[:, idx_mapping[i]] 131 | elif i in set(cat_col_idx): 132 | syn_df[i] = syn_cat[:, idx_mapping[i] - len(num_col_idx)] 133 | else: 134 | syn_df[i] = syn_target[:, idx_mapping[i] - len(num_col_idx) - len(cat_col_idx)] 135 | 136 | 137 | else: 138 | for i in range(len(num_col_idx) + len(cat_col_idx) + len(target_col_idx)): 139 | if i in set(num_col_idx): 140 | syn_df[i] = syn_num[:, idx_mapping[i]] 141 | elif i in set(cat_col_idx): 142 | syn_df[i] = syn_cat[:, idx_mapping[i] - len(num_col_idx)] 143 | else: 144 | syn_df[i] = syn_target[:, idx_mapping[i] - len(num_col_idx) - len(cat_col_idx)] 145 | 146 | return syn_df 147 | 148 | 149 | def process_invalid_id(syn_cat, min_cat, max_cat): 150 | syn_cat = np.clip(syn_cat, min_cat, max_cat) 151 | 152 | return syn_cat 153 | 154 | -------------------------------------------------------------------------------- /tabsyn/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch.utils.data import DataLoader 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | import argparse 7 | import warnings 8 | import time 9 | 10 | from tqdm import tqdm 11 | from tabsyn.model import MLPDiffusion, Model 12 | from tabsyn.latent_utils import get_input_train 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | def main(args): 18 | device = args.device 19 | 20 | train_z, _, _, ckpt_path, _ = get_input_train(args) 21 | 22 | print(ckpt_path) 23 | 24 | if not os.path.exists(ckpt_path): 25 | os.makedirs(ckpt_path) 26 | 27 | in_dim = train_z.shape[1] 28 | 29 | mean, std = train_z.mean(0), train_z.std(0) 30 | 31 | train_z = (train_z - mean) / 2 32 | train_data = train_z 33 | 34 | 35 | batch_size = 4096 36 | train_loader = DataLoader( 37 | train_data, 38 | batch_size = batch_size, 39 | shuffle = True, 40 | num_workers = 4, 41 | ) 42 | 43 | num_epochs = 10000 + 1 44 | 45 | denoise_fn = MLPDiffusion(in_dim, 1024).to(device) 46 | print(denoise_fn) 47 | 48 | num_params = sum(p.numel() for p in denoise_fn.parameters()) 49 | print("the number of parameters", num_params) 50 | 51 | model = Model(denoise_fn = denoise_fn, hid_dim = train_z.shape[1]).to(device) 52 | 53 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0) 54 | 
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20, verbose=True) 55 | 56 | model.train() 57 | 58 | best_loss = float('inf') 59 | patience = 0 60 | start_time = time.time() 61 | for epoch in range(num_epochs): 62 | 63 | pbar = tqdm(train_loader, total=len(train_loader)) 64 | pbar.set_description(f"Epoch {epoch+1}/{num_epochs}") 65 | 66 | batch_loss = 0.0 67 | len_input = 0 68 | for batch in pbar: 69 | inputs = batch.float().to(device) 70 | loss = model(inputs) 71 | 72 | loss = loss.mean() 73 | 74 | batch_loss += loss.item() * len(inputs) 75 | len_input += len(inputs) 76 | 77 | optimizer.zero_grad() 78 | loss.backward() 79 | optimizer.step() 80 | 81 | pbar.set_postfix({"Loss": loss.item()}) 82 | 83 | curr_loss = batch_loss/len_input 84 | scheduler.step(curr_loss) 85 | 86 | if curr_loss < best_loss: 87 | best_loss = curr_loss 88 | patience = 0 89 | torch.save(model.state_dict(), f'{ckpt_path}/model.pt') 90 | else: 91 | patience += 1 92 | if patience == 500: 93 | print('Early stopping') 94 | break 95 | 96 | if epoch % 1000 == 0: 97 | torch.save(model.state_dict(), f'{ckpt_path}/model_{epoch}.pt') 98 | 99 | end_time = time.time() 100 | print('Time: ', end_time - start_time) 101 | 102 | if __name__ == '__main__': 103 | 104 | parser = argparse.ArgumentParser(description='Training of TabSyn') 105 | 106 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 107 | parser.add_argument('--gpu', type=int, default=0, help='GPU index.') 108 | 109 | args = parser.parse_args() 110 | 111 | # check cuda 112 | if args.gpu != -1 and torch.cuda.is_available(): 113 | args.device = f'cuda:{args.gpu}' 114 | else: 115 | args.device = 'cpu' -------------------------------------------------------------------------------- /tabsyn/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim 8 | from torch import Tensor 9 | from tabsyn.diffusion_utils import EDMLoss 10 | 11 | ModuleType = Union[str, Callable[..., nn.Module]] 12 | 13 | class SiLU(nn.Module): 14 | def forward(self, x): 15 | return x * torch.sigmoid(x) 16 | 17 | class PositionalEmbedding(torch.nn.Module): 18 | def __init__(self, num_channels, max_positions=10000, endpoint=False): 19 | super().__init__() 20 | self.num_channels = num_channels 21 | self.max_positions = max_positions 22 | self.endpoint = endpoint 23 | 24 | def forward(self, x): 25 | freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) 26 | freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) 27 | freqs = (1 / self.max_positions) ** freqs 28 | x = x.ger(freqs.to(x.dtype)) 29 | x = torch.cat([x.cos(), x.sin()], dim=1) 30 | return x 31 | 32 | def reglu(x: Tensor) -> Tensor: 33 | """The ReGLU activation function from [1]. 34 | References: 35 | [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 36 | """ 37 | assert x.shape[-1] % 2 == 0 38 | a, b = x.chunk(2, dim=-1) 39 | return a * F.relu(b) 40 | 41 | 42 | def geglu(x: Tensor) -> Tensor: 43 | """The GEGLU activation function from [1]. 
44 | References: 45 | [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 46 | """ 47 | assert x.shape[-1] % 2 == 0 48 | a, b = x.chunk(2, dim=-1) 49 | return a * F.gelu(b) 50 | 51 | class ReGLU(nn.Module): 52 | """The ReGLU activation function from [shazeer2020glu]. 53 | 54 | Examples: 55 | .. testcode:: 56 | 57 | module = ReGLU() 58 | x = torch.randn(3, 4) 59 | assert module(x).shape == (3, 2) 60 | 61 | References: 62 | * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 63 | """ 64 | 65 | def forward(self, x: Tensor) -> Tensor: 66 | return reglu(x) 67 | 68 | 69 | class GEGLU(nn.Module): 70 | """The GEGLU activation function from [shazeer2020glu]. 71 | 72 | Examples: 73 | .. testcode:: 74 | 75 | module = GEGLU() 76 | x = torch.randn(3, 4) 77 | assert module(x).shape == (3, 2) 78 | 79 | References: 80 | * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 81 | """ 82 | 83 | def forward(self, x: Tensor) -> Tensor: 84 | return geglu(x) 85 | 86 | 87 | class FourierEmbedding(torch.nn.Module): 88 | def __init__(self, num_channels, scale=16): 89 | super().__init__() 90 | self.register_buffer('freqs', torch.randn(num_channels // 2) * scale) 91 | 92 | def forward(self, x): 93 | x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) 94 | x = torch.cat([x.cos(), x.sin()], dim=1) 95 | return x 96 | 97 | class MLPDiffusion(nn.Module): 98 | def __init__(self, d_in, dim_t = 512): 99 | super().__init__() 100 | self.dim_t = dim_t 101 | 102 | self.proj = nn.Linear(d_in, dim_t) 103 | 104 | self.mlp = nn.Sequential( 105 | nn.Linear(dim_t, dim_t * 2), 106 | nn.SiLU(), 107 | nn.Linear(dim_t * 2, dim_t * 2), 108 | nn.SiLU(), 109 | nn.Linear(dim_t * 2, dim_t), 110 | nn.SiLU(), 111 | nn.Linear(dim_t, d_in), 112 | ) 113 | 114 | self.map_noise = PositionalEmbedding(num_channels=dim_t) 115 | self.time_embed = nn.Sequential( 116 | nn.Linear(dim_t, dim_t), 117 | nn.SiLU(), 118 | nn.Linear(dim_t, dim_t) 119 | ) 120 | 121 | def forward(self, x, noise_labels, class_labels=None): 122 | emb = self.map_noise(noise_labels) 123 | emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos 124 | emb = self.time_embed(emb) 125 | 126 | x = self.proj(x) + emb 127 | return self.mlp(x) 128 | 129 | 130 | class Precond(nn.Module): 131 | def __init__(self, 132 | denoise_fn, 133 | hid_dim, 134 | sigma_min = 0, # Minimum supported noise level. 135 | sigma_max = float('inf'), # Maximum supported noise level. 136 | sigma_data = 0.5, # Expected standard deviation of the training data. 
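# (These defaults follow the EDM preconditioning of Karras et al., 2022; the
#  corresponding c_skip / c_out / c_in / c_noise scalings are computed in forward() below.)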
137 | ): 138 | super().__init__() 139 | 140 | self.hid_dim = hid_dim 141 | self.sigma_min = sigma_min 142 | self.sigma_max = sigma_max 143 | self.sigma_data = sigma_data 144 | ########### 145 | self.denoise_fn_F = denoise_fn 146 | 147 | def forward(self, x, sigma): 148 | 149 | x = x.to(torch.float32) 150 | 151 | sigma = sigma.to(torch.float32).reshape(-1, 1) 152 | dtype = torch.float32 153 | 154 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 155 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() 156 | c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() 157 | c_noise = sigma.log() / 4 158 | 159 | x_in = c_in * x 160 | F_x = self.denoise_fn_F((x_in).to(dtype), c_noise.flatten()) 161 | 162 | assert F_x.dtype == dtype 163 | D_x = c_skip * x + c_out * F_x.to(torch.float32) 164 | return D_x 165 | 166 | def round_sigma(self, sigma): 167 | return torch.as_tensor(sigma) 168 | 169 | 170 | class Model(nn.Module): 171 | def __init__(self, denoise_fn, hid_dim, P_mean=-1.2, P_std=1.2, sigma_data=0.5, gamma=5, opts=None, pfgmpp = False): 172 | super().__init__() 173 | 174 | self.denoise_fn_D = Precond(denoise_fn, hid_dim) 175 | self.loss_fn = EDMLoss(P_mean, P_std, sigma_data, hid_dim=hid_dim, gamma=5, opts=None) 176 | 177 | def forward(self, x): 178 | 179 | loss = self.loss_fn(self.denoise_fn_D, x) 180 | return loss.mean(-1).mean() 181 | -------------------------------------------------------------------------------- /tabsyn/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import argparse 4 | import warnings 5 | import time 6 | 7 | from tabsyn.model import MLPDiffusion, Model 8 | from tabsyn.latent_utils import get_input_generate, recover_data, split_num_cat_target 9 | from tabsyn.diffusion_utils import sample 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | 14 | def main(args): 15 | dataname = args.dataname 16 | device = args.device 17 | steps = args.steps 18 | save_path = args.save_path 19 | 20 | train_z, _, _, ckpt_path, info, num_inverse, cat_inverse = get_input_generate(args) 21 | in_dim = train_z.shape[1] 22 | 23 | mean = train_z.mean(0) 24 | 25 | denoise_fn = MLPDiffusion(in_dim, 1024).to(device) 26 | 27 | model = Model(denoise_fn = denoise_fn, hid_dim = train_z.shape[1]).to(device) 28 | 29 | model.load_state_dict(torch.load(f'{ckpt_path}/model.pt')) 30 | 31 | ''' 32 | Generating samples 33 | ''' 34 | start_time = time.time() 35 | 36 | num_samples = train_z.shape[0] 37 | sample_dim = in_dim 38 | 39 | x_next = sample(model.denoise_fn_D, num_samples, sample_dim) 40 | x_next = x_next * 2 + mean.to(device) 41 | 42 | syn_data = x_next.float().cpu().numpy() 43 | syn_num, syn_cat, syn_target = split_num_cat_target(syn_data, info, num_inverse, cat_inverse, args.device) 44 | 45 | syn_df = recover_data(syn_num, syn_cat, syn_target, info) 46 | 47 | idx_name_mapping = info['idx_name_mapping'] 48 | idx_name_mapping = {int(key): value for key, value in idx_name_mapping.items()} 49 | 50 | syn_df.rename(columns = idx_name_mapping, inplace=True) 51 | syn_df.to_csv(save_path, index = False) 52 | 53 | end_time = time.time() 54 | print('Time:', end_time - start_time) 55 | 56 | print('Saving sampled data to {}'.format(save_path)) 57 | 58 | if __name__ == '__main__': 59 | 60 | parser = argparse.ArgumentParser(description='Generation') 61 | 62 | parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.') 63 | parser.add_argument('--gpu', type=int, default=0, help='GPU 
index.') 64 | parser.add_argument('--epoch', type=int, default=None, help='Epoch.') 65 | parser.add_argument('--steps', type=int, default=None, help='Number of function evaluations.') 66 | 67 | args = parser.parse_args() 68 | 69 | # check cuda 70 | if args.gpu != -1 and torch.cuda.is_available(): 71 | args.device = f'cuda:{args.gpu}' 72 | else: 73 | args.device = 'cpu' -------------------------------------------------------------------------------- /utils_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import src 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class TabularDataset(Dataset): 9 | def __init__(self, X_num, X_cat): 10 | self.X_num = X_num 11 | self.X_cat = X_cat 12 | 13 | def __getitem__(self, index): 14 | this_num = self.X_num[index] 15 | this_cat = self.X_cat[index] 16 | 17 | sample = (this_num, this_cat) 18 | 19 | return sample 20 | 21 | def __len__(self): 22 | return self.X_num.shape[0] 23 | 24 | def preprocess(dataset_path, task_type = 'binclass', inverse = False, cat_encoding = None, concat = True): 25 | 26 | T_dict = {} 27 | 28 | T_dict['normalization'] = "quantile" 29 | T_dict['num_nan_policy'] = 'mean' 30 | T_dict['cat_nan_policy'] = None 31 | T_dict['cat_min_frequency'] = None 32 | T_dict['cat_encoding'] = cat_encoding 33 | T_dict['y_policy'] = "default" 34 | 35 | T = src.Transformations(**T_dict) 36 | 37 | dataset = make_dataset( 38 | data_path = dataset_path, 39 | T = T, 40 | task_type = task_type, 41 | change_val = False, 42 | concat = concat 43 | ) 44 | 45 | if cat_encoding is None: 46 | X_num = dataset.X_num 47 | X_cat = dataset.X_cat 48 | 49 | X_train_num, X_test_num = X_num['train'], X_num['test'] 50 | X_train_cat, X_test_cat = X_cat['train'], X_cat['test'] 51 | 52 | categories = src.get_categories(X_train_cat) 53 | d_numerical = X_train_num.shape[1] 54 | 55 | X_num = (X_train_num, X_test_num) 56 | X_cat = (X_train_cat, X_test_cat) 57 | 58 | 59 | if inverse: 60 | num_inverse = dataset.num_transform.inverse_transform 61 | cat_inverse = dataset.cat_transform.inverse_transform 62 | 63 | return X_num, X_cat, categories, d_numerical, num_inverse, cat_inverse 64 | else: 65 | return X_num, X_cat, categories, d_numerical 66 | else: 67 | return dataset 68 | 69 | 70 | def update_ema(target_params, source_params, rate=0.999): 71 | """ 72 | Update target parameters to be closer to those of source parameters using 73 | an exponential moving average. 74 | :param target_params: the target parameter sequence. 75 | :param source_params: the source parameter sequence. 76 | :param rate: the EMA rate (closer to 1 means slower). 
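Equivalent to target = rate * target + (1 - rate) * source, applied in-place.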
77 | """ 78 | for target, source in zip(target_params, source_params): 79 | target.detach().mul_(rate).add_(source.detach(), alpha=1 - rate) 80 | 81 | 82 | 83 | def concat_y_to_X(X, y): 84 | if X is None: 85 | return y.reshape(-1, 1) 86 | return np.concatenate([y.reshape(-1, 1), X], axis=1) 87 | 88 | 89 | def make_dataset( 90 | data_path: str, 91 | T: src.Transformations, 92 | task_type, 93 | change_val: bool, 94 | concat = True, 95 | ): 96 | 97 | # classification 98 | if task_type == 'binclass' or task_type == 'multiclass': 99 | X_cat = {} if os.path.exists(os.path.join(data_path, 'X_cat_train.npy')) else None 100 | X_num = {} if os.path.exists(os.path.join(data_path, 'X_num_train.npy')) else None 101 | y = {} if os.path.exists(os.path.join(data_path, 'y_train.npy')) else None 102 | 103 | for split in ['train', 'test']: 104 | X_num_t, X_cat_t, y_t = src.read_pure_data(data_path, split) 105 | if X_num is not None: 106 | X_num[split] = X_num_t 107 | if X_cat is not None: 108 | if concat: 109 | X_cat_t = concat_y_to_X(X_cat_t, y_t) 110 | X_cat[split] = X_cat_t 111 | if y is not None: 112 | y[split] = y_t 113 | else: 114 | # regression 115 | X_cat = {} if os.path.exists(os.path.join(data_path, 'X_cat_train.npy')) else None 116 | X_num = {} if os.path.exists(os.path.join(data_path, 'X_num_train.npy')) else None 117 | y = {} if os.path.exists(os.path.join(data_path, 'y_train.npy')) else None 118 | 119 | for split in ['train', 'test']: 120 | X_num_t, X_cat_t, y_t = src.read_pure_data(data_path, split) 121 | 122 | if X_num is not None: 123 | if concat: 124 | X_num_t = concat_y_to_X(X_num_t, y_t) 125 | X_num[split] = X_num_t 126 | if X_cat is not None: 127 | X_cat[split] = X_cat_t 128 | if y is not None: 129 | y[split] = y_t 130 | 131 | info = src.load_json(os.path.join(data_path, 'info.json')) 132 | 133 | D = src.Dataset( 134 | X_num, 135 | X_cat, 136 | y, 137 | y_info={}, 138 | task_type=src.TaskType(info['task_type']), 139 | n_classes=info.get('n_classes') 140 | ) 141 | 142 | if change_val: 143 | D = src.change_val(D) 144 | 145 | # def categorical_to_idx(feature): 146 | # unique_categories = np.unique(feature) 147 | # idx_mapping = {category: index for index, category in enumerate(unique_categories)} 148 | # idx_feature = np.array([idx_mapping[category] for category in feature]) 149 | # return idx_feature 150 | 151 | # for split in ['train', 'val', 'test']: 152 | # D.y[split] = categorical_to_idx(D.y[split].squeeze(1)) 153 | 154 | return src.transform_dataset(D, T, None) --------------------------------------------------------------------------------
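A minimal usage sketch for the helpers defined in utils_train.py above (a hypothetical example: it assumes data/adult has already been prepared by process_dataset.py and that its task type is 'binclass'):

    from torch.utils.data import DataLoader
    from utils_train import preprocess, TabularDataset

    # With inverse=False and cat_encoding=None, preprocess returns the transformed splits.
    X_num, X_cat, categories, d_numerical = preprocess('data/adult', task_type='binclass')
    X_train_num, X_test_num = X_num
    X_train_cat, X_test_cat = X_cat

    # Wrap the training split for mini-batch iteration.
    train_set = TabularDataset(X_train_num, X_train_cat)
    train_loader = DataLoader(train_set, batch_size=4096, shuffle=True)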