├── .DS_Store
├── README.md
├── analysis_utils.py
├── data_utils.py
├── diversity.ipynb
├── figure_data.ipynb
├── figure_result.ipynb
├── generation_utils.py
├── graph_utils.py
├── model_utils.py
├── polymer_generation_playground.ipynb
├── result.ipynb
├── run_vae.py
├── saliency_utils.py
├── topo_analysis
│   ├── .DS_Store
│   ├── gen_branch_30_all.pickle
│   ├── gen_branch_50_all.pickle
│   ├── gen_comb_30_all.pickle
│   ├── gen_comb_50_all.pickle
│   ├── gen_cyclic_30_all.pickle
│   ├── gen_dendrimer_7.5_all.pickle
│   ├── gen_star_30_all.pickle
│   ├── gen_star_7.5_all.pickle
│   ├── latent_space.pickle
│   ├── saliency.pickle
│   ├── umap.pickle
│   ├── umap_-3_12_6_12.pickle
│   ├── umap_1_4_1_16.pickle
│   ├── umap_gnn.pickle
│   ├── umap_reg_false_cls_false.pickle
│   ├── umap_reg_false_cls_true.pickle
│   ├── umap_reg_true_cls_false.pickle
│   └── umap_topo.pickle
├── topo_result
│   ├── desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64.h5
│   ├── desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64.h5
│   └── gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 100]_0.01_64.h5
└── website
    ├── .DS_Store
    └── abstract.png

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # poly-topoGNN-vae
2 | 
3 | ![alt text](./website/abstract.png "Abstract")
4 | 
5 | ## Overview
6 | All content here is available under the CC BY-NC 4.0 license.
7 | 
8 | Please cite the relevant references where appropriate.
9 | 
10 | The repository consists of Python files and Jupyter notebooks. All figures in the paper can be reproduced by running the `figure_data.ipynb` and `figure_result.ipynb` notebooks. The files and their respective functions are listed below.
11 | 
12 | - `analysis_utils.py` contains functions to compute accuracy metrics from model predictions (e.g., reconstruction balanced accuracy, $\langle R_\mathrm{g}^2 \rangle$ regression $R^2$, and topology classification $F_1$), as well as functions for hyperparameter selection based on the validation dataset.
13 | - `data_utils.py` contains a function to extract topological descriptors from a `networkx.Graph` object and another function to load and split the data for training.
14 | - `graph_utils.py` contains functions to clean up graphs generated by the VAE so that they are more realistic.
15 | - `saliency_utils.py` computes the saliency map of a model to quantify the importance of each topological descriptor for $\langle R_\mathrm{g}^2 \rangle$ regression.
16 | - `model_utils.py` contains all encoder and decoder models and auxiliary functions for VAE training.
17 | 
18 | The Jupyter notebooks are:
19 | - `result.ipynb` contains all the procedures to generate the result pickle files. Note that the VAE sampling layer is inherently random, so reproducing the results with different seeds may lead to slightly different values.
20 | - `diversity.ipynb` contains model diversity measurements using the [Vendi Score](https://github.com/vertaix/Vendi-Score).
21 | - `figure_data.ipynb` reproduces all figures related to the data (e.g., the $\langle R_\mathrm{g}^2 \rangle$ distribution and topology visualization).
22 | - `figure_result.ipynb` reproduces all figures related to the results (e.g., reconstruction accuracy, property-guided polymer generation, and viscosity analysis).
23 | 
24 | In addition, if you are only interested in property-guided polymer generation, the playground notebook `polymer_generation_playground.ipynb` lets you set an arbitrary target $\langle R_\mathrm{g}^2 \rangle$ and topology and visualize the generated polymer topologies.
25 | 
26 | 
27 | ## References
28 | 
29 | ### Download Data
30 | The data is publicly available at ...
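Once downloaded, the data file can be inspected directly. `load_data` in `data_utils.py` reads six objects from it one after another (`x`, `y`, `topo_desc`, `topo_class`, `poly_param`, `graph`), which suggests they were written with consecutive `pickle.dump` calls. The sketch below assumes that layout; `read_sequential_pickle` is a hypothetical helper, not part of the repository.

```python
import pickle

# Assumed file layout: the order in which load_data() in data_utils.py reads the objects.
FIELDS = ("x", "y", "topo_desc", "topo_class", "poly_param", "graph")


def read_sequential_pickle(path, fields=FIELDS):
    """Hypothetical helper: read objects stored with consecutive pickle.dump calls."""
    data = {}
    with open(path, "rb") as handle:
        for name in fields:
            # Each pickle.load call consumes exactly one object from the stream.
            data[name] = pickle.load(handle)
    return data
```

For example, `read_sequential_pickle(os.path.join(DATA_DIR, 'rg2.pickle'))['topo_desc']` would return the raw descriptors before `load_data` standardizes them with `StandardScaler`.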
31 | 
32 | Make sure to change the directory paths (e.g., `DATA_DIR` and `WEIGHT_DIR`) in each file to your own directories.
33 | 
34 | ### Install Packages
35 | ```text
36 | python = 3.8
37 | tensorflow = 2.5.0
38 | spektral = 1.3.0
39 | networkx = 2.8.4
40 | scikit-learn = 1.2.2
41 | proplot = 0.9.7
42 | vendi-score = 0.0.3
43 | ```
44 | 
45 | ### Code
46 | To train the VAE with hyperparameter tuning, submit a SLURM job array script like the one below (prepend an interpreter line such as `#!/bin/bash`, which `sbatch` requires as the first line of the script):
47 | ```bash
48 | #SBATCH --job-name=job_name      # create a short name for your job
49 | #SBATCH --output=slurm.%a.out    # stdout file
50 | #SBATCH --error=slurm.%a.err     # stderr file
51 | #SBATCH --nodes=1                # node count
52 | #SBATCH --ntasks=1               # total number of tasks across all nodes
53 | #SBATCH --cpus-per-task=1        # cpu-cores per task (>1 if multi-threaded tasks)
54 | #SBATCH --mem-per-cpu=8G         # memory per cpu-core (4G is default)
55 | #SBATCH --gres=gpu:1             # number of gpus per node
56 | #SBATCH --time=48:00:00          # total run time limit (HH:MM:SS)
57 | #SBATCH --array=0-8              # job array with index values 0, 1, ..., 8
58 | 
59 | echo "My SLURM_ARRAY_TASK_ID is $SLURM_ARRAY_TASK_ID"
60 | echo "Executing on the machine:" $(hostname)
61 | 
62 | export PATH="${PATH}:/usr/local/nvidia/bin:/usr/local/cuda/bin"
63 | 
64 | module purge
65 | module load anaconda3/your_conda_module
66 | conda activate your_conda_env
67 | 
68 | python run_vae.py > "python-$SLURM_ARRAY_TASK_ID.out" 2> "python-$SLURM_ARRAY_TASK_ID.err"
69 | ```
70 | 
--------------------------------------------------------------------------------
/analysis_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import pickle
4 | 
5 | import numpy as np
6 | import sklearn.metrics as skm
7 | 
8 | from tqdm import tqdm
9 | from model_utils import get_spec, train_vae, latent_model
10 | from data_utils import load_data
11 | 
12 | WEIGHT_DIR = '/scratch/gpfs/sj0161/topo_result/'  # change to your directory
13 | DATA_DIR = '/scratch/gpfs/sj0161/topo_data/'  # change to your directory
14 | 
15 | def get_metrics(model, enc_type,
                     if_reg, if_cls,
16 |                 x_train, y_train, c_train, l_train,
17 |                 x_valid, y_valid, c_valid, l_valid,
18 |                 x_test, y_test, c_test, l_test,
19 |                 n_repeat=5, if_bacc=False):
20 | 
21 |     rmse = []
22 |     r2 = []
23 |     f1 = []
24 |     acc = []
25 |     bacc = []
26 |     x_pred_trains = []
27 |     x_pred_valids = []
28 |     x_pred_tests = []
29 |     y_pred_trains = []
30 |     y_pred_valids = []
31 |     y_pred_tests = []
32 |     c_pred_trains = []
33 |     c_pred_valids = []
34 |     c_pred_tests = []
35 | 
36 |     for i in range(n_repeat):
37 |         if enc_type == 'cnn' or enc_type == 'dnn':
38 |             in_train = x_train
39 |             in_valid = x_valid
40 |             in_test = x_test
41 |         elif enc_type == 'gnn':
42 |             in_train = [x_train, x_train]
43 |             in_valid = [x_valid, x_valid]
44 |             in_test = [x_test, x_test]
45 |         elif enc_type == 'desc_dnn':
46 |             in_train = l_train
47 |             in_valid = l_valid
48 |             in_test = l_test
49 |         elif enc_type == 'desc_gnn':
50 |             in_train = [[x_train, x_train], l_train]
51 |             in_valid = [[x_valid, x_valid], l_valid]
52 |             in_test = [[x_test, x_test], l_test]
53 | 
54 |         if not if_reg and not if_cls:
55 |             x_pred_test = model.predict(in_test, verbose=0)
56 |             x_pred_train = model.predict(in_train, verbose=0)
57 |             x_pred_valid = model.predict(in_valid, verbose=0)
58 |             x_pred_trains.append(x_pred_train)
59 |             x_pred_valids.append(x_pred_valid)
60 |             x_pred_tests.append(x_pred_test)
61 |         elif if_reg and not if_cls:
62 |             x_pred_test, y_pred_test = model.predict(in_test, verbose=0)
63 |             x_pred_train, y_pred_train = model.predict(in_train, verbose=0)
64 |             x_pred_valid, y_pred_valid = model.predict(in_valid, verbose=0)
65 |             x_pred_trains.append(x_pred_train)
66 |             x_pred_valids.append(x_pred_valid)
67 |             x_pred_tests.append(x_pred_test)
68 |             y_pred_tests.append(y_pred_test)
69 |             y_pred_trains.append(y_pred_train)
70 |             y_pred_valids.append(y_pred_valid)
71 |         elif not if_reg and if_cls:
72 |             x_pred_test, c_pred_test = model.predict(in_test, verbose=0)
73 |             x_pred_train, c_pred_train = model.predict(in_train, verbose=0)
74 |             x_pred_valid,
c_pred_valid = model.predict(in_valid, verbose=0)
75 |             x_pred_trains.append(x_pred_train)
76 |             x_pred_valids.append(x_pred_valid)
77 |             x_pred_tests.append(x_pred_test)
78 |             c_pred_tests.append(c_pred_test)
79 |             c_pred_trains.append(c_pred_train)
80 |             c_pred_valids.append(c_pred_valid)
81 |         else:
82 |             x_pred_test, y_pred_test, c_pred_test = model.predict(in_test, verbose=0)
83 |             x_pred_train, y_pred_train, c_pred_train = model.predict(in_train, verbose=0)
84 |             x_pred_valid, y_pred_valid, c_pred_valid = model.predict(in_valid, verbose=0)
85 |             x_pred_trains.append(x_pred_train)
86 |             x_pred_valids.append(x_pred_valid)
87 |             x_pred_tests.append(x_pred_test)
88 |             c_pred_tests.append(c_pred_test)
89 |             c_pred_trains.append(c_pred_train)
90 |             c_pred_valids.append(c_pred_valid)
91 |             y_pred_tests.append(y_pred_test)
92 |             y_pred_trains.append(y_pred_train)
93 |             y_pred_valids.append(y_pred_valid)
94 | 
95 |         if if_cls:
96 |             c_pred_test = np.argmax(c_pred_test, axis=1)
97 |             c_pred_train = np.argmax(c_pred_train, axis=1)
98 |             c_pred_valid = np.argmax(c_pred_valid, axis=1)
99 | 
100 |         # RMSE and R2 need a regression head; keep NaN placeholders otherwise
101 |         rmse.append([np.nan] * 3)
102 |         r2.append([np.nan] * 3)
103 |         if if_reg:
104 |             rmse[i] = [skm.mean_squared_error(y_train, y_pred_train) ** 0.5,
105 |                        skm.mean_squared_error(y_valid, y_pred_valid) ** 0.5,
106 |                        skm.mean_squared_error(y_test, y_pred_test) ** 0.5]
107 |             r2[i] = [skm.r2_score(y_train, y_pred_train),
108 |                      skm.r2_score(y_valid, y_pred_valid),
109 |                      skm.r2_score(y_test, y_pred_test)]
110 |         f1.append([])
111 |         acc.append([])
112 | 
113 |         if if_cls:
114 |             f1[i].append(skm.f1_score(c_train, c_pred_train, average='weighted'))
115 |             f1[i].append(skm.f1_score(c_valid, c_pred_valid, average='weighted'))
116 |             f1[i].append(skm.f1_score(c_test, c_pred_test, average='weighted'))
117 |             acc[i].append(skm.accuracy_score(c_train, c_pred_train))
118 |             acc[i].append(skm.accuracy_score(c_valid, c_pred_valid))
119 |             acc[i].append(skm.accuracy_score(c_test, c_pred_test))
120
|
121 |         if if_bacc:
122 |             xt1 = x_train.ravel()
123 |             xp1 = np.round(x_pred_train.ravel())
124 |             xt2 = x_valid.ravel()
125 |             xp2 = np.round(x_pred_valid.ravel())
126 |             xt3 = x_test.ravel()
127 |             xp3 = np.round(x_pred_test.ravel())
128 |             bacc.append([])
129 |             bacc[i].append(skm.balanced_accuracy_score(xt1, xp1))
130 |             bacc[i].append(skm.balanced_accuracy_score(xt2, xp2))
131 |             bacc[i].append(skm.balanced_accuracy_score(xt3, xp3))
132 | 
133 |     rmse = np.array(rmse)
134 |     r2 = np.array(r2)
135 |     f1 = np.array(f1)
136 |     bacc = np.array(bacc)
137 |     acc = np.array(acc)
138 |     rmse_m = None
139 |     r2_m = None
140 |     f1_m = None
141 |     acc_m = None
142 | 
143 |     to_str = lambda x: f"{np.mean(x):0.2f}+/-{np.std(x):0.2f}"
144 |     print(f"RMSE: Train {to_str(rmse[:,0])} Valid {to_str(rmse[:,1])} Test {to_str(rmse[:,2])}")
145 |     rmse_m = rmse.mean(axis=0)
146 |     if if_reg:
147 |         print(f"R2: Train {to_str(r2[:,0])} Valid {to_str(r2[:,1])} Test {to_str(r2[:,2])}")
148 |         r2_m = r2.mean(axis=0)
149 |     if if_cls:
150 |         print(f"F1: Train {to_str(f1[:,0])} Valid {to_str(f1[:,1])} Test {to_str(f1[:,2])}")
151 |         f1_m = f1.mean(axis=0)
152 |     if if_bacc:
153 |         print(f"BACC: Train {to_str(bacc[:,0])} Valid {to_str(bacc[:,1])} Test {to_str(bacc[:,2])}")
154 | 
155 |     if if_reg and if_cls:
156 |         train_out = (x_pred_trains, y_pred_trains, c_pred_trains)
157 |         valid_out = (x_pred_valids, y_pred_valids, c_pred_valids)
158 |         test_out = (x_pred_tests, y_pred_tests, c_pred_tests)
159 |     elif if_reg and not if_cls:
160 |         train_out = (x_pred_trains, y_pred_trains)
161 |         valid_out = (x_pred_valids, y_pred_valids)
162 |         test_out = (x_pred_tests, y_pred_tests)
163 |     elif not if_reg and if_cls:
164 |         train_out = (x_pred_trains, c_pred_trains)
165 |         valid_out = (x_pred_valids, c_pred_valids)
166 |         test_out = (x_pred_tests, c_pred_tests)
167 |     else:
168 |         train_out = x_pred_trains
169 |         valid_out = x_pred_valids
170 |         test_out = x_pred_tests
171 | 
172 | 
173 |     return train_out, valid_out, test_out, bacc, rmse, r2, f1,
acc
174 | 
175 | 
176 | def get_val_metrics(encoder):
177 |     """
178 |     Get validation metrics for a specific encoder.
179 | 
180 |     Args:
181 |         encoder (str): The encoder identifier.
182 | 
183 |     Returns:
184 |         tuple: A tuple containing arrays of various metrics, including elbos, baccs, kls, cls, rls, rec, rmses, r2s, f1s, and selected files.
185 |     """
186 |     ((x_train, y_train, c_train, l_train, graph_train),
187 |      (x_valid, y_valid, c_valid, l_valid, graph_valid),
188 |      (x_test, y_test, c_test, l_test, graph_test),
189 |      NAMES, SCALER, LE) = load_data(os.path.join(DATA_DIR, 'rg2.pickle'), fold=0, if_validation=True)
190 | 
191 |     graph_all = np.concatenate((graph_train, graph_valid, graph_test))
192 | 
193 |     files1 = sorted(glob.glob(WEIGHT_DIR + f"/{encoder}*True*True*.pickle"))
194 |     files2 = sorted(glob.glob(WEIGHT_DIR + f"/{encoder}*True*True*metric.pickle"))
195 |     files = sorted(set(files1) - set(files2))  # training histories only, in a reproducible order
196 | 
197 |     elbos = []
198 |     kls = []
199 |     baccs = []
200 |     cls = []
201 |     rls = []
202 |     files_select = []
203 |     rmses = []
204 |     r2s = []
205 |     f1s = []
206 | 
207 |     for file in tqdm(files, total=len(files)):
208 |         with open(file, 'rb') as handle:
209 |             hist = pickle.load(handle)
210 |         if "val_decoder_acc" in file:
211 |             idx = np.argmax(hist["val_decoder_acc"])
212 |         elif "val_decoder_loss" in file:
213 |             idx = np.argmin(hist["val_decoder_loss"])
214 |         elif "val_loss" in file:
215 |             idx = np.argmin(hist["val_loss"])
216 | 
217 |         if os.path.exists(file.split(".pickle")[0]+"_metric.pickle"):
218 |             with open(file.split(".pickle")[0]+"_metric.pickle", 'rb') as handle:
219 |                 rmse = pickle.load(handle)
220 |                 r2 = pickle.load(handle)
221 |                 f1 = pickle.load(handle)
222 | 
223 |             elbo = hist["val_decoder_loss"][idx]
224 |             bacc = hist["val_decoder_acc"][idx]
225 |             kl = hist["val_kl_loss"][idx]
226 |             cl = hist["val_classifier_loss"][idx]
227 |             rl = hist["val_regressor_loss"][idx]
228 | 
229 |             # Recompute the validation KL divergence from the latent mean/log-variance of the model described by the .h5 file (overrides the logged val_kl_loss)
230 |             h5_file = file.split(".pickle")[0]+".h5"
231 |             ENCODER,
DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(h5_file)
232 | 
233 |             model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,
234 |                                            x_train, x_valid, y_train, y_valid, c_train, c_valid,
235 |                                            l_train, l_valid, 1.0, weights, LR, BS, False)
236 | 
237 |             latent_valid = latent_model(model, data=[x_valid, l_valid], enc_type=ENCODER, mean_var=True)
238 | 
239 |             z_mean = latent_valid[0]
240 |             z_log_var = latent_valid[1]
241 | 
242 |             kl_total = - .5 * np.sum(1 + z_log_var -
243 |                                      np.square(z_mean) -
244 |                                      np.exp(z_log_var), axis=-1) * 1
245 | 
246 |             kl = kl_total.mean()
247 | 
248 | 
249 |             if np.isnan(elbo).any() or np.isnan(bacc).any():
250 |                 continue
251 |             else:
252 |                 elbos.append(elbo)
253 |                 baccs.append(bacc)
254 |                 kls.append(kl)
255 |                 cls.append(cl)
256 |                 rls.append(rl)
257 |                 rmses.append(rmse[1])
258 |                 r2s.append(r2[1])
259 |                 f1s.append(f1[1])
260 |                 files_select.append(file)
261 |         else:
262 |             continue
263 | 
264 |     elbos = np.array(elbos)
265 |     baccs = np.array(baccs)
266 |     kls = np.array(kls)
267 |     cls = np.array(cls)
268 |     rls = np.array(rls)
269 |     rec = elbos-kls
270 |     rmses = np.array(rmses)
271 |     r2s = np.array(r2s)
272 |     f1s = np.array(f1s)
273 | 
274 |     return elbos, baccs, kls, cls, rls, rec, rmses, r2s, f1s, files_select
275 | 
276 | 
277 | def minmax(arr, mins=None, maxs=None):
278 |     """
279 |     Perform min-max scaling on an array.
280 | 
281 |     Args:
282 |         arr (numpy.ndarray): The input array to be scaled.
283 |         mins (float, optional): The custom minimum value for scaling. Default is None.
284 |         maxs (float, optional): The custom maximum value for scaling. Default is None.
285 | 
286 |     Returns:
287 |         numpy.ndarray: The scaled array.
288 |     """
289 |     # Determine minimum and maximum values for scaling
290 |     if mins is not None:
291 |         min_val = mins
292 |     else:
293 |         min_val = np.min(arr)
294 | 
295 |     if maxs is not None:
296 |         max_val = maxs
297 |     else:
298 |         max_val = np.max(arr)
299 | 
300 |     # Perform min-max scaling; a constant input maps to zeros instead of dividing by zero
301 |     arr = np.asarray(arr, dtype=float)
302 |     span = max_val - min_val
303 |     return (arr - min_val) / span if span != 0 else np.zeros_like(arr)
304 | 
305 | 
306 | def pareto_frontier(baccs, r2s, f1s, kls, limits):
307 |     """
308 |     Find the Pareto frontier of a set of data points with multiple objectives. Every objective is minimized, so pass quantities for which lower is better.
309 | 
310 |     Args:
311 |         baccs (list): List of balanced accuracy values.
312 |         r2s (list): List of R2 values.
313 |         f1s (list): List of F1 scores.
314 |         kls (list): List of KL divergence values.
315 |         limits (list): Upper limit for each objective; only Pareto points strictly below every limit are returned.
316 | 
317 |     Returns:
318 |         tuple: A tuple containing two elements:
319 |             - pareto_indices (list): List of indices of Pareto points.
320 |             - pareto_front (numpy.ndarray): Array of Pareto points satisfying the specified limits.
321 |     """
322 |     combined = list(zip(baccs, r2s, f1s, kls, range(len(baccs))))
323 |     pareto_front = []
324 | 
325 |     for point in combined:
326 |         is_dominated = False
327 | 
328 |         for other_point in combined:
329 |             if all(other <= point_dim for other, point_dim in zip(other_point[:4], point[:4])) and any(other < point_dim for other, point_dim in zip(other_point[:4], point[:4])):
330 |                 is_dominated = True
331 |                 break
332 | 
333 |         if not is_dominated:
334 |             pareto_front.append(point)
335 | 
336 |     pareto_indices = [
337 |         point[4]
338 |         for point in pareto_front
339 |         if all(point_dim < limit for point_dim, limit in zip(point[:4], limits))
340 |     ]
341 | 
342 |     pareto_front = [
343 |         point[:5]
344 |         for point in pareto_front
345 |         if all(point_dim < limit for point_dim, limit in zip(point[:4], limits))
346 |     ]
347 | 
348 |     return pareto_indices, np.array(pareto_front)
349 | 
350 | 
351 | def closest_to_origin(pareto_front):
352 |     """
353 |     Find the point in the Pareto front closest to the origin after min-max scaling each objective.
354 | 
355 |     Args:
356 |         pareto_front (array-like): Pareto points as returned by `pareto_frontier`, each given as four objectives followed by the point's original index.
357 | 
358 |     Returns:
359 |         tuple: A tuple containing two elements:
360 |             - closest_point (tuple): The min-max-scaled objectives of the closest point.
361 |             - closest_idx (int): The original data index stored with the closest point.
362 |     """
363 |     baccs, r2s, f1s, kls, idx = zip(*pareto_front)
364 | 
365 |     baccs = minmax(baccs)
366 |     r2s = minmax(r2s)
367 |     f1s = minmax(f1s)
368 |     kls = minmax(kls)
369 | 
370 |     scaled_pareto_front = list(zip(baccs, r2s, f1s, kls, idx))
371 |     closest_tuple = min(scaled_pareto_front, key=lambda point: np.linalg.norm(np.array(point[:-1])))
372 | 
373 |     closest_point = closest_tuple[:-1]
374 |     closest_idx = closest_tuple[-1]
375 | 
376 |     return closest_point, int(closest_idx)
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import networkx as nx
4 | 
5 | from sklearn.preprocessing import StandardScaler, LabelEncoder
6 | from sklearn.model_selection import StratifiedKFold
7 | 
8 | def get_desc(G):
9 |     """
10 |     Compute various network measures for a given networkx graph.
11 | 
12 |     Args:
13 |         G (networkx.Graph): The input graph for which network measures will be computed.
14 | 
15 |     Returns:
16 |         numpy.ndarray: An array containing the following network measures in order:
17 |             1. Number of nodes
18 |             2. Number of edges
19 |             3. Algebraic connectivity
20 |             4. Diameter
21 |             5. Radius
22 |             6. Average degree
23 |             7. Average neighbor degree
24 |             8. Network density
25 |             9. Mean degree centrality
26 |             10. Mean betweenness centrality
27 |             11.
Degree assortativity coefficient
28 |     """
29 |     x1 = nx.number_of_nodes(G)
30 |     x2 = nx.number_of_edges(G)
31 |     x3 = nx.algebraic_connectivity(G)
32 |     x4 = nx.diameter(G)
33 |     x5 = nx.radius(G)
34 |     degrees = [degree for _, degree in G.degree()]
35 |     x6 = sum(degrees) / len(G.nodes())
36 |     x7 = np.mean(list(nx.average_neighbor_degree(G).values()))
37 |     x8 = nx.density(G)
38 |     x9 = np.mean(list(nx.degree_centrality(G).values()))
39 |     x10 = np.mean(list(nx.betweenness_centrality(G).values()))
40 |     x11 = nx.degree_assortativity_coefficient(G)
41 | 
42 |     return np.array([x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11])
43 | 
44 | def load_data(data_dir, fold, n_fold=5, if_validation=False, verbose=False):
45 |     """
46 |     Load and preprocess data from the specified pickle file.
47 | 
48 |     Args:
49 |         data_dir (str): Path to the data pickle file (e.g., `rg2.pickle`).
50 |         fold (int): Index of the fold to be used as the test set.
51 |         n_fold (int, optional): Number of folds to split the data into. Default is 5.
52 |         if_validation (bool, optional): Whether to include a validation set. Default is False.
53 |         verbose (bool, optional): Whether to print the sizes of the splits. Default is False.
54 |     Returns:
55 |         tuple: Tuple containing training, validation (optional), and test datasets,
56 |             along with topo_class names, scaler, and label encoder.
57 |     """
58 | 
59 |     with open(data_dir, 'rb') as handle:
60 |         x, y, topo_desc, topo_class, poly_param, graph = [pickle.load(handle) for _ in range(6)]
61 | 
62 |     # x: graph feature
63 |     # y: rg2 value
64 |     # topo_desc: topological descriptors
65 |     # topo_class: topology classes
66 |     # poly_param: polymer generation parameters
67 |     # graph: networkx objects
68 | 
69 |     # preprocessing
70 |     y = y[..., 0]
71 | 
72 |     SCALER = StandardScaler()
73 |     topo_desc = SCALER.fit_transform(topo_desc)
74 | 
75 |     topo_class[topo_class == 'astar'] = 'star'
76 |     topo_desc = np.where(np.isnan(topo_desc), -2, topo_desc)  # only the degree assortativity (range [-1, 1]) can be NaN; fill NaNs with -2
77 | 
78 |     le = LabelEncoder()
79 |     topo_class = le.fit_transform(topo_class)
80 |     NAMES = le.classes_
81 | 
82 |     # shuffle every array with the same seeded permutation so that samples stay aligned
83 |     x = np.random.RandomState(0).permutation(x)
84 |     y = np.random.RandomState(0).permutation(y)
85 |     topo_class = np.random.RandomState(0).permutation(topo_class)
86 |     topo_desc = np.random.RandomState(0).permutation(topo_desc)
87 |     poly_param = np.random.RandomState(0).permutation(poly_param)
88 |     graph = np.random.RandomState(0).permutation(graph)
89 | 
90 |     # stratified K-fold split; the fold with index `fold` is used for testing
91 |     skf = StratifiedKFold(n_splits=n_fold)
92 |     count = -1
93 |     for _, (train_idx, test_idx) in enumerate(skf.split(x, topo_class)):
94 |         datasets = [x, y, topo_desc, topo_class, graph]
95 |         train_data = [data[train_idx] for data in datasets]
96 |         test_data = [data[test_idx] for data in datasets]
97 | 
98 |         x_train, y_train, l_train, c_train, graph_train = train_data
99 |         x_test, y_test, l_test, c_test, graph_test = test_data
100 | 
101 |         if if_validation:
102 |             skf2 = StratifiedKFold(n_splits=n_fold)
103 |             train_idx2, valid_idx = next(iter(skf2.split(x_train, c_train)))
104 |             datasets2 = [x_train, y_train, l_train, c_train, graph_train]
105 | 
106 |             x_valid, y_valid, l_valid, c_valid, graph_valid = ([data[valid_idx] for data in datasets2])
107 |             x_train, y_train, l_train, c_train, graph_train = ([data[train_idx2]
for data in datasets2]) 108 | 109 | count += 1 110 | if count == fold: 111 | break 112 | 113 | if if_validation: 114 | if verbose: 115 | print(f'Train: {len(x_train)} Valid: {len(x_valid)} Test: {len(x_test)}') 116 | return ((x_train, y_train, c_train, l_train, graph_train), 117 | (x_valid, y_valid, c_valid, l_valid, graph_valid), 118 | (x_test, y_test, c_test, l_test, graph_test), 119 | NAMES, SCALER, le) 120 | 121 | else: 122 | if verbose: 123 | print(f'Train: {len(x_train)} Test: {len(x_test)}') 124 | return ((x_train, y_train, c_train, l_train, graph_train), 125 | (x_test, y_test, c_test, l_test, graph_test), 126 | NAMES, SCALER, le) 127 | -------------------------------------------------------------------------------- /diversity.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "05a46f9f", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/sj0161/.conda/envs/py38torch113/lib/python3.8/site-packages/umap/distances.py:1063: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", 14 | " @numba.jit()\n", 15 | "/home/sj0161/.conda/envs/py38torch113/lib/python3.8/site-packages/umap/distances.py:1071: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. 
See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", 16 | " @numba.jit()\n", 17 | "/home/sj0161/.conda/envs/py38torch113/lib/python3.8/site-packages/umap/distances.py:1086: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", 18 | " @numba.jit()\n", 19 | "/home/sj0161/.conda/envs/py38torch113/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 20 | " from .autonotebook import tqdm as notebook_tqdm\n", 21 | "/home/sj0161/.conda/envs/py38torch113/lib/python3.8/site-packages/umap/umap_.py:660: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", 22 | " @numba.jit()\n", 23 | "2023-11-16 12:39:54.910654: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 24 | "2023-11-16 12:39:55.015032: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 25 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 26 | "2023-11-16 12:39:58.800579: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "# Import standard libraries\n", 32 | "import os\n", 33 | "import ast\n", 34 | "import glob\n", 35 | "import pickle\n", 36 | "import platform\n", 37 | "import copy\n", 38 | "from timeit import default_timer as timer\n", 39 | "\n", 40 | "# Import third-party libraries\n", 41 | "import numpy as np\n", 42 | "import pandas as pd\n", 43 | "import networkx as nx\n", 44 | "import matplotlib.pyplot as plt\n", 45 | "import proplot as pplt\n", 46 | "import umap\n", 47 | "import seaborn as sn\n", 48 | "import tensorflow as tf\n", 49 | "from tensorflow.keras import layers, Model, callbacks\n", 50 | "from tensorflow.keras import backend as K\n", 51 | "from tensorflow.keras.utils import to_categorical\n", 52 | "\n", 53 | "from vendi_score import vendi\n", 54 | "\n", 55 | "import sklearn.manifold as skma\n", 56 | "import sklearn.metrics as skm\n", 57 | "import sklearn.decomposition as skd\n", 58 | "\n", 59 | "from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder\n", 60 | "from sklearn.linear_model import LinearRegression\n", 61 | "from sklearn.model_selection import train_test_split, StratifiedKFold\n", 62 | "from sklearn.metrics import accuracy_score\n", 63 | "from scipy.spatial.distance import pdist, squareform\n", 64 | "from spektral.layers import GINConvBatch, GlobalAttentionPool, GlobalMaxPool, GlobalAttnSumPool\n", 65 | "\n", 66 | "# Import local 
modules\n", 67 | "from topo_sim.model import KLDivergenceLayer, Sampling\n", 68 | "\n", 69 | "# Configuration for file paths\n", 70 | "DATA_DIR = '/home/sj0161/complex_polymer/complex_polymer/temp/' # TODO: change this\n", 71 | "PLOT_DIR = '../fig/'\n", 72 | "WEIGHT_DIR = '/scratch/gpfs/sj0161/20230829/'\n", 73 | "\n", 74 | "# Set plot configurations\n", 75 | "pplt.rc['figure.facecolor'] = 'white'\n", 76 | "\n", 77 | "# Initialize color cycle\n", 78 | "COLORS = []\n", 79 | "colors1 = pplt.Cycle('default')\n", 80 | "colors2 = pplt.Cycle('538')\n", 81 | "\n", 82 | "for color in colors1:\n", 83 | " COLORS.append(color['color'])\n", 84 | "\n", 85 | "for color in colors2:\n", 86 | " COLORS.append(color['color'])\n", 87 | "\n", 88 | "# Handle warnings\n", 89 | "import warnings\n", 90 | "warnings.filterwarnings('ignore')\n", 91 | "\n", 92 | "# Some constants\n", 93 | "LATENT_DIM = 8" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 2, 99 | "id": "bcd65bdf", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def load_data(data_dir, fold, n_fold=5, if_validation=False):\n", 104 | " \"\"\"\n", 105 | " Load and preprocess data from the specified directory.\n", 106 | "\n", 107 | " Args:\n", 108 | " data_dir (str): Directory path where the data is stored.\n", 109 | " fold (int): Index of the fold to be used as the test set.\n", 110 | " n_fold (int, optional): Number of folds to split the data into. Default is 5.\n", 111 | " if_validation (bool, optional): Whether to include a validation set. 
Default is False.\n", 112 | "\n", 113 | " Returns:\n", 114 | " tuple: Tuple containing training, validation (optional), and test datasets,\n", 115 | " along with topo_class names, scaler, and label encoder.\n", 116 | " \"\"\"\n", 117 | " \n", 118 | " with open(data_dir, 'rb') as handle:\n", 119 | " x, y, topo_desc, topo_class, poly_param, graph = [pickle.load(handle) for _ in range(6)]\n", 120 | " \n", 121 | " # x: graph feature\n", 122 | " # y: rg2 value\n", 123 | " # topo_desc: topological descriptors\n", 124 | " # topo_class: topology classes\n", 125 | " # poly_param: polymer generation parameters\n", 126 | " # graph: networkx objects\n", 127 | " \n", 128 | " # preprocessing\n", 129 | " y = y[..., 0]\n", 130 | " \n", 131 | " SCALER = StandardScaler()\n", 132 | " topo_desc = SCALER.fit_transform(topo_desc)\n", 133 | "\n", 134 | " topo_class[topo_class == 'astar'] = 'star'\n", 135 | " topo_desc = np.where(np.isnan(topo_desc), -2, topo_desc) # only node assortativity has 0, should be [-1, 1]\n", 136 | "\n", 137 | " le = LabelEncoder()\n", 138 | " topo_class = le.fit_transform(topo_class)\n", 139 | " NAMES = le.classes_\n", 140 | " \n", 141 | " # random shuffle\n", 142 | " x = np.random.RandomState(0).permutation(x)\n", 143 | " y = np.random.RandomState(0).permutation(y)\n", 144 | " topo_class = np.random.RandomState(0).permutation(topo_class)\n", 145 | " topo_desc = np.random.RandomState(0).permutation(topo_desc)\n", 146 | " poly_param = np.random.RandomState(0).permutation(poly_param)\n", 147 | " graph = np.random.RandomState(0).permutation(graph)\n", 148 | "\n", 149 | " # we just use one fold for testing\n", 150 | " skf = StratifiedKFold(n_splits=n_fold)\n", 151 | " count = -1\n", 152 | " for _, (train_idx, test_idx) in enumerate(skf.split(x, topo_class)):\n", 153 | " train_data = [data[train_idx] for data in [x, y, topo_desc, topo_class, graph]]\n", 154 | " test_data = [data[test_idx] for data in [x, y, topo_desc, topo_class, graph]]\n", 155 | " x_train, 
y_train, l_train, c_train, graph_train = train_data\n", 156 | " x_test, y_test, l_test, c_test, graph_test = test_data\n", 157 | "\n", 158 | " if if_validation:\n", 159 | " skf2 = StratifiedKFold(n_splits=n_fold)\n", 160 | " train_idx2, valid_idx = next(iter(skf2.split(x_train, c_train)))\n", 161 | " x_valid, y_valid, l_valid, c_valid, graph_valid = (\n", 162 | " [data[valid_idx] for data in [x_train, y_train, l_train, c_train, graph_train]])\n", 163 | " x_train, y_train, l_train, c_train, graph_train = (\n", 164 | " [data[train_idx2] for data in [x_train, y_train, l_train, c_train, graph_train]])\n", 165 | "\n", 166 | " \n", 167 | " count += 1\n", 168 | " if count == fold:\n", 169 | " break\n", 170 | "\n", 171 | " if if_validation:\n", 172 | " print(f'Train: {len(x_train)} Valid: {len(x_valid)} Test: {len(x_test)}')\n", 173 | " return ((x_train, y_train, c_train, l_train, graph_train),\n", 174 | " (x_valid, y_valid, c_valid, l_valid, graph_valid),\n", 175 | " (x_test, y_test, c_test, l_test, graph_test),\n", 176 | " NAMES, SCALER, le)\n", 177 | " \n", 178 | " else:\n", 179 | " print(f'Train: {len(x_train)} Test: {len(x_test)}')\n", 180 | " return ((x_train, y_train, c_train, l_train, graph_train),\n", 181 | " (x_test, y_test, c_test, l_test, graph_test),\n", 182 | " NAMES, SCALER, le)\n", 183 | "\n", 184 | " \n", 185 | "def graph_to_lap_spec(graphs):\n", 186 | " lap_spec_data = []\n", 187 | " for G in graphs:\n", 188 | " lap_spec = nx.laplacian_spectrum(G)\n", 189 | " lap_spec_zero_pad = np.zeros((100,))\n", 190 | " lap_spec_zero_pad[:len(lap_spec)] = lap_spec\n", 191 | " lap_spec_data.append(lap_spec_zero_pad)\n", 192 | " return np.array(lap_spec_data)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "dc1ee376", 198 | "metadata": {}, 199 | "source": [ 200 | "### Vendi score evaluation for the whole dataset" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 3, 206 | "id": "2fd23fe5", 207 | "metadata": {}, 208 | 
"outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "Train: 858 Valid: 215 Test: 269\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "((x_train, y_train, c_train, l_train, graph_train),\n", 219 | "(x_valid, y_valid, c_valid, l_valid, graph_valid),\n", 220 | "(x_test, y_test, c_test, l_test, graph_test),\n", 221 | "NAMES, SCALER, LE) = load_data(os.path.join(DATA_DIR, 'rg2.pickle'), fold=0, if_validation=True)\n", 222 | "\n", 223 | "graph_all = np.concatenate((graph_train, graph_valid, graph_test))" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 4, 229 | "id": "703e634e", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "# convert all graphs into graph eigen spectra\n", 234 | "graph_total = [graph_train, graph_valid, graph_test]\n", 235 | "\n", 236 | "lap_spec_data = []\n", 237 | "\n", 238 | "for graphs in graph_total:\n", 239 | " for G in graphs:\n", 240 | " lap_spec = nx.laplacian_spectrum(G)\n", 241 | " lap_spec_zero_pad = np.zeros((100,))\n", 242 | " lap_spec_zero_pad[:len(lap_spec)] = lap_spec\n", 243 | " lap_spec_data.append(lap_spec_zero_pad)\n", 244 | " \n", 245 | "lap_spec_data = np.array(lap_spec_data)\n", 246 | "\n", 247 | "with open(\"../result/lap_spec_data.pickle\", \"wb\") as handle:\n", 248 | " pickle.dump(lap_spec_data, handle)" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 5, 254 | "id": "a8d0666f", 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "with open(\"../result/lap_spec_data.pickle\", \"rb\") as handle:\n", 259 | " lap_spec_data = pickle.load(handle)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 6, 265 | "id": "c459e54b", 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "name": "stdout", 270 | "output_type": "stream", 271 | "text": [ 272 | "Dataset Vendi Score: 2.0968\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "print(f\"Dataset Vendi Score: 
{vendi.score_dual(lap_spec_data):0.4f}\")" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "id": "c0ff6b38", 283 | "metadata": {}, 284 | "source": [ 285 | "### Vendi score evaluation for the latent space" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 7, 291 | "id": "dc2c9743", 292 | "metadata": {}, 293 | "outputs": [ 294 | { 295 | "name": "stdout", 296 | "output_type": "stream", 297 | "text": [ 298 | "../result/latent_space_desc_gnn_cnn.pickle\n", 299 | "Dataset Vendi Score: 7.3225 \n", 300 | "\n", 301 | "../result/latent_space_gnn_cnn.pickle\n", 302 | "Dataset Vendi Score: 7.4370 \n", 303 | "\n", 304 | "../result/latent_space_desc_dnn_cnn.pickle\n", 305 | "Dataset Vendi Score: 7.0863 \n", 306 | "\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "files = [\n", 312 | " \"../result/latent_space_desc_gnn_cnn.pickle\",\n", 313 | " \"../result/latent_space_gnn_cnn.pickle\",\n", 314 | " \"../result/latent_space_desc_dnn_cnn.pickle\"\n", 315 | "]\n", 316 | "\n", 317 | "for file in files:\n", 318 | " with open(file, \"rb\") as handle:\n", 319 | " latent_data = pickle.load(handle)\n", 320 | " print(file)\n", 321 | " print(f\"Dataset Vendi Score: {vendi.score_dual(latent_data):0.4f} \\n\")" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 8, 327 | "id": "88a1dbe9", 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "../result/latent_space_False_False.pickle\n", 335 | "Dataset Vendi Score: 5.8532 \n", 336 | "\n", 337 | "../result/latent_space_False_True.pickle\n", 338 | "Dataset Vendi Score: 6.3171 \n", 339 | "\n", 340 | "../result/latent_space_True_False.pickle\n", 341 | "Dataset Vendi Score: 5.3128 \n", 342 | "\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "files = [\n", 348 | " \"../result/latent_space_False_False.pickle\",\n", 349 | " \"../result/latent_space_False_True.pickle\",\n", 350 | " 
\"../result/latent_space_True_False.pickle\"\n", 351 | "]\n", 352 | "\n", 353 | "for file in files:\n", 354 | " with open(file, \"rb\") as handle:\n", 355 | " latent_data = pickle.load(handle)\n", 356 | " print(file)\n", 357 | " print(f\"Dataset Vendi Score: {vendi.score_dual(latent_data):0.4f} \\n\")" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "4e62f7bf", 363 | "metadata": {}, 364 | "source": [ 365 | "### Vendi score evaluation for the random generation based on different models" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 9, 371 | "id": "c40133e2", 372 | "metadata": {}, 373 | "outputs": [ 374 | { 375 | "name": "stdout", 376 | "output_type": "stream", 377 | "text": [ 378 | "Dataset Vendi Score: 5.0684\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "with open(\"../result/no_valid_random_gen_desc_gnn_cnn.pickle\", \"rb\") as handle:\n", 384 | " gen_data = pickle.load(handle)\n", 385 | " \n", 386 | "gen_clean_graph = [gen_data[i][2] for i in range(len(gen_data))]\n", 387 | "\n", 388 | "lap_spec_data = graph_to_lap_spec(gen_clean_graph)\n", 389 | "\n", 390 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 10, 396 | "id": "10352c87", 397 | "metadata": {}, 398 | "outputs": [ 399 | { 400 | "name": "stdout", 401 | "output_type": "stream", 402 | "text": [ 403 | "Dataset Vendi Score: 4.9580\n" 404 | ] 405 | } 406 | ], 407 | "source": [ 408 | "with open(\"../result/no_valid_random_gen_gnn_cnn.pickle\", \"rb\") as handle:\n", 409 | " gen_data = pickle.load(handle)\n", 410 | " \n", 411 | "gen_clean_graph = [gen_data[i][2] for i in range(len(gen_data))]\n", 412 | "\n", 413 | "lap_spec_data = graph_to_lap_spec(gen_clean_graph)\n", 414 | "\n", 415 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 11, 421 | "id": 
"4312a29c", 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "name": "stdout", 426 | "output_type": "stream", 427 | "text": [ 428 | "Dataset Vendi Score: 4.3305\n" 429 | ] 430 | } 431 | ], 432 | "source": [ 433 | "with open(\"../result/no_valid_random_gen_desc_dnn_cnn.pickle\", \"rb\") as handle:\n", 434 | " gen_data = pickle.load(handle)\n", 435 | " \n", 436 | "gen_clean_graph = [gen_data[i][2] for i in range(len(gen_data))]\n", 437 | "\n", 438 | "lap_spec_data = graph_to_lap_spec(gen_clean_graph)\n", 439 | "\n", 440 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "id": "d7d89c19", 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "with open(\"../result/no_valid_random_gen_desc_gnn_cnn.pickle\", \"rb\") as handle:\n", 451 | " gen_data = pickle.load(handle)\n", 452 | " \n", 453 | "gen_clean_graph = [gen_data[i][2] for i in range(len(gen_data))]\n", 454 | "\n", 455 | "lap_spec_data = graph_to_lap_spec(gen_clean_graph)\n", 456 | "\n", 457 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "id": "0cb22a63", 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [] 467 | } 468 | ], 469 | "metadata": { 470 | "kernelspec": { 471 | "display_name": "py38torch113 [~/.conda/envs/py38torch113/]", 472 | "language": "python", 473 | "name": "conda_py38torch113" 474 | }, 475 | "language_info": { 476 | "codemirror_mode": { 477 | "name": "ipython", 478 | "version": 3 479 | }, 480 | "file_extension": ".py", 481 | "mimetype": "text/x-python", 482 | "name": "python", 483 | "nbconvert_exporter": "python", 484 | "pygments_lexer": "ipython3", 485 | "version": "3.8.16" 486 | } 487 | }, 488 | "nbformat": 4, 489 | "nbformat_minor": 5 490 | } 491 | -------------------------------------------------------------------------------- 
/generation_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import networkx as nx 5 | import tensorflow as tf 6 | 7 | from graph_utils import graph_anneal_break_largest_circle 8 | from data_utils import load_data, get_desc 9 | 10 | WEIGHT_DIR = '/scratch/gpfs/sj0161/topo_result/' # change to your directory 11 | DATA_DIR = '/scratch/gpfs/sj0161/topo_data/' # change to your directory 12 | 13 | LATENT_DIM = 8 14 | 15 | 16 | def reg_cls(model, z, data, enc_type='gnn'): 17 | """ 18 | Perform regression and classification using a given model. 19 | 20 | Args: 21 | model (tf.keras.Model): The neural network model. 22 | z (numpy.ndarray): Input data for regression and classification. 23 | data (numpy.ndarray or list): Input data for encoding (depends on 'enc_type' argument). 24 | enc_type (str, optional): The encoding type, either 'gnn', 'desc_gnn', or other. Defaults to 'gnn'. 25 | 26 | Returns: 27 | Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]: A tuple containing: 28 | - generated_reg (numpy.ndarray): Predictions from the regression model. 29 | - generated_cls (numpy.ndarray): Predictions from the classification model. 30 | - cleaned_reg (numpy.ndarray): Predictions after cleaning from the regression model. 31 | - cleaned_cls (numpy.ndarray): Predictions after cleaning from the classification model. 
32 | """ 33 | # Regression Model 34 | regressor = model.get_layer('regressor') 35 | r_in = tf.keras.Input(shape=(LATENT_DIM,), name='reg_in') 36 | for i, l in enumerate(regressor.layers): 37 | if i == 0: 38 | x = l(r_in) 39 | else: 40 | x = l(x) 41 | reg_model = tf.keras.Model(inputs=r_in, outputs=x) 42 | generated_reg = reg_model.predict(z, verbose=0) 43 | 44 | # Classification Model 45 | classifier = model.get_layer('classifier') 46 | c_in = tf.keras.Input(shape=(LATENT_DIM,), name='cls_in') 47 | for i, l in enumerate(classifier.layers): 48 | if i == 0: 49 | x = l(c_in) 50 | else: 51 | x = l(x) 52 | cls_model = tf.keras.Model(inputs=c_in, outputs=x) 53 | generated_cls = cls_model.predict(z, verbose=0) 54 | 55 | # Encoding Type Handling 56 | if enc_type == 'gnn': 57 | _, cleaned_reg, cleaned_cls = model.predict([data, data], verbose=0) 58 | elif enc_type == 'desc_gnn': 59 | _, cleaned_reg, cleaned_cls = model.predict([[data[0], data[0]], data[1]], verbose=0) 60 | else: 61 | _, cleaned_reg, cleaned_cls = model.predict(data, verbose=0) 62 | 63 | return generated_reg, generated_cls, cleaned_reg, cleaned_cls 64 | 65 | def polymer_generation(model, latent_mean, latent_log_var, enc_type='gnn'): 66 | """ 67 | Generate and clean a polymer graph using a given model. 68 | 69 | Args: 70 | model (tf.keras.Model): The neural network model. 71 | latent_mean (numpy.ndarray): Mean of the latent space. 72 | latent_log_var (numpy.ndarray): Log variance of the latent space. 73 | enc_type (str, optional): Encoding type: 'gnn', 'desc_dnn', or 'desc_gnn'. Defaults to 'gnn'. 74 | 75 | Returns: 76 | Tuple[networkx.Graph, networkx.Graph, str, str, float, float, float]: 77 | - G_pre_clean (networkx.Graph): Generated polymer graph before cleaning. 78 | - G_post_clean (networkx.Graph): Generated polymer graph after cleaning. 79 | - cls_pre_clean (str): Predicted class before cleaning. 80 | - cls_post_clean (str): Predicted class after cleaning. 
81 | - reg_pre_clean (float): Regression result before cleaning. 82 | - reg_post_clean_m (float): Mean of regression result after cleaning. 83 | - reg_post_clean_s (float): Standard deviation of regression result after cleaning. 84 | """ 85 | 86 | (_, _, _, NAMES, SCALER, _) = load_data(os.path.join(DATA_DIR, 'rg2.pickle'), fold=0, if_validation=True) 87 | 88 | # Get the decoder sub-model 89 | decoder = model.get_layer('decoder') 90 | 91 | # Define input tensor for generation 92 | d_in = tf.keras.Input(shape=(LATENT_DIM,), name='gen_d_in') 93 | 94 | # Pass the input through the decoder layers 95 | x = d_in 96 | for layer in decoder.layers: 97 | x = layer(x) 98 | 99 | # Define generation model 100 | gen_model = tf.keras.Model(inputs=d_in, outputs=x) 101 | 102 | # Prepare latent space 103 | if len(latent_mean.shape) == 1: 104 | latent_mean = latent_mean[None, ...] 105 | elif len(latent_mean.shape) == 2 and latent_mean.shape[0] != 1: 106 | raise Exception("Only one latent sample can be generated at a time.") 107 | 108 | epsilon = np.random.normal(size=(1, LATENT_DIM)) 109 | 110 | # Generate data 111 | if latent_log_var is None: 112 | z_sample = latent_mean 113 | else: 114 | z_sample = latent_mean + np.exp(0.5 * latent_log_var) * epsilon 115 | 116 | sampled_data = np.round(gen_model.predict(z_sample, verbose=0))[0] 117 | 118 | # Convert data to graphs 119 | G_pre_clean = nx.from_numpy_array(sampled_data) 120 | G_post_clean = graph_anneal_break_largest_circle(sampled_data) 121 | 122 | if enc_type == 'desc_dnn': 123 | data = get_desc(G_post_clean)[None, ] 124 | data = SCALER.transform(data) # scale first, as in load_data (StandardScaler passes NaN through) 125 | data = np.where(np.isnan(data), -2, data) # then flag undefined descriptors with -2 126 | 127 | elif enc_type == 'desc_gnn': 128 | adjs = np.zeros((1, 100, 100)) 129 | adj_ = nx.to_numpy_array(G_post_clean) 130 | adjs[0, :len(adj_), :len(adj_)] = adj_ 131 | data1 = adjs 132 | 133 | data2 = get_desc(G_post_clean)[None, ] 134 | data2 = SCALER.transform(data2) # same order as load_data: scale, then replace NaN 135 | data2 = np.where(np.isnan(data2), -2, data2) 136 | 137 | data = [data1, data2]
138 | 139 | else: 140 | adjs = np.zeros((1, 100, 100)) 141 | adj_ = nx.to_numpy_array(G_post_clean) 142 | adjs[0, :len(adj_), :len(adj_)] = adj_ 143 | data = adjs 144 | 145 | (reg_pre_clean, cls_pre_clean, reg_post_clean, cls_post_clean) = reg_cls(model, z_sample, data, enc_type) 146 | 147 | cls_pre_clean = NAMES[np.argmax(cls_pre_clean, axis=1)] 148 | cls_post_clean = NAMES[np.argmax(cls_post_clean, axis=1)] 149 | 150 | unique_cls, counts = np.unique(cls_pre_clean, return_counts=True) 151 | cls_pre_clean = unique_cls[np.argmax(counts)] 152 | 153 | unique_cls, counts = np.unique(cls_post_clean, return_counts=True) 154 | cls_post_clean = unique_cls[np.argmax(counts)] 155 | 156 | reg_pre_clean = reg_pre_clean[0][0] 157 | 158 | reg_post_clean_m = np.mean(reg_post_clean) 159 | reg_post_clean_s = np.std(reg_post_clean) 160 | 161 | return G_pre_clean, G_post_clean, cls_pre_clean, cls_post_clean, reg_pre_clean, reg_post_clean_m, reg_post_clean_s 162 | 163 | 164 | def check_valid(gen_reg, cln_reg_m, gen_cls, cln_cls, threshold=2.0): 165 | """ 166 | Check the validity of generated and cleaned data based on regression and classification results. 167 | 168 | Args: 169 | gen_reg (float): Regression result from generated data. 170 | cln_reg_m (float): Mean of regression result from cleaned data. 171 | gen_cls (str): Classification result from generated data. 172 | cln_cls (str): Classification result from cleaned data. 173 | threshold (float, optional): Threshold for comparing regression results. Defaults to 2.0. 174 | 175 | Returns: 176 | bool: True if both regression and classification results are valid, False otherwise. 
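Example (illustrative sketch, not part of the original module): a generated polymer is accepted only under two conditions. First, the rg2 predicted directly from the latent vector and the mean rg2 predicted after re-encoding the cleaned graph must differ by less than `threshold`. Second, both predicted topology classes must agree. The helper `filter_valid` and the sample tuples below are hypothetical.

```python
import numpy as np

def check_valid(gen_reg, cln_reg_m, gen_cls, cln_cls, threshold=2.0):
    # Same rule as above: regression predictions agree within `threshold`
    # and cleaning does not change the predicted topology class.
    flag1 = np.abs(gen_reg - cln_reg_m) < threshold
    flag2 = gen_cls == cln_cls
    return bool(flag1 and flag2)

def filter_valid(samples, threshold=2.0):
    # Hypothetical helper: indices of (gen_reg, cln_reg_m, gen_cls, cln_cls) tuples that pass.
    return [i for i, s in enumerate(samples) if check_valid(*s, threshold=threshold)]

samples = [
    (10.0, 11.5, 'star', 'star'),  # |diff| = 1.5 < 2.0 and same class -> valid
    (10.0, 12.5, 'star', 'star'),  # |diff| = 2.5 >= 2.0 -> invalid
    (10.0, 10.2, 'star', 'comb'),  # class changed by cleaning -> invalid
]
print(filter_valid(samples))  # [0]
```

Loosening `threshold` only relaxes the regression condition. The class-agreement condition is unaffected.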
177 | """ 178 | flag1 = np.abs(gen_reg - cln_reg_m) < threshold 179 | flag2 = gen_cls == cln_cls 180 | return flag1 and flag2 -------------------------------------------------------------------------------- /graph_utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | 4 | def remove_component(graph, component): 5 | """ 6 | Remove a specified component from the graph. 7 | 8 | Args: 9 | graph (numpy.ndarray): A numpy array representing the adjacency matrix of the graph. 10 | component (list): A list of nodes representing the component to be removed. 11 | 12 | Returns: 13 | numpy.ndarray: The modified graph with the specified component removed. 14 | """ 15 | for node in component: 16 | graph[node, :] = 0 17 | graph[:, node] = 0 18 | return graph 19 | 20 | 21 | def dfs_component(graph, node, visited): 22 | """ 23 | Perform a depth-first search to find all nodes in the connected component of a given node. 24 | 25 | Args: 26 | graph (numpy.ndarray): A numpy array representing the adjacency matrix of the graph. 27 | node (int): The node from which the DFS starts. 28 | visited (list of bool): A list indicating whether each node has been visited. 29 | 30 | Returns: 31 | list: A list of nodes that are in the same connected component as the starting node. 32 | """ 33 | component = [] 34 | stack = [node] 35 | 36 | while stack: 37 | current_node = stack.pop() 38 | if not visited[current_node]: 39 | visited[current_node] = True 40 | component.append(current_node) 41 | for neighbor in range(len(graph[current_node])): 42 | if graph[current_node][neighbor] == 1 and not visited[neighbor]: 43 | stack.append(neighbor) 44 | return component 45 | 46 | 47 | def graph_anneal(graph): 48 | """ 49 | Modify the graph by removing small connected components and updating the adjacency matrix. 50 | 51 | Args: 52 | graph (numpy.ndarray): A numpy array representing the adjacency matrix of the graph. 
53 | 54 | Returns: 55 | numpy.ndarray: The adjacency matrix with components of fewer than 10 nodes (including isolated nodes) removed. 56 | """ 57 | num_nodes = len(graph) 58 | visited = [False] * num_nodes 59 | 60 | for i in range(num_nodes): 61 | if not visited[i]: 62 | component = dfs_component(graph, i, visited) 63 | if len(component) < 10: 64 | graph = remove_component(graph, component) 65 | 66 | # Update graph to include only nodes in larger components 67 | mask = np.any(graph, axis=1) 68 | graph = graph[mask][:, mask] 69 | 70 | return graph 71 | 72 | 73 | def remove_edge(graph, v, u): 74 | """ 75 | Remove an undirected edge by zeroing both entries of the adjacency matrix. 76 | 77 | Args: 78 | graph (numpy.ndarray): A numpy array representing the adjacency matrix of the graph. 79 | v (int): The first node in the edge. 80 | u (int): The second node in the edge. 81 | 82 | Returns: 83 | numpy.ndarray: The modified graph with the specified edge removed. 84 | """ 85 | graph[v][u] = 0 86 | graph[u][v] = 0 87 | return graph 88 | 89 | 90 | def break_edges_keep_largest_circle(graph, adj, max_cycle_length=80): 91 | """ 92 | Break short cycles one at a time so that only large rings (at least `max_cycle_length` nodes) remain. 93 | 94 | Args: 95 | graph (networkx.Graph): A NetworkX graph object. 96 | adj (numpy.ndarray): Edge weights indexed like the nodes of `graph`; the lowest-weight edge of each short cycle is removed. 97 | max_cycle_length (int, optional): Cycles with fewer nodes than this are broken. 98 | Default is 80. 99 | 100 | Returns: 101 | networkx.Graph: The input graph with edges removed so that no cycle in its 102 | cycle basis is shorter than `max_cycle_length`.
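Example (illustrative sketch): the loop below mirrors the cycle-breaking logic above on a toy graph, a triangle (shorter than `max_cycle_length`) attached to a six-membered ring. The triangle loses its lowest-weight edge and the ring survives. The toy graph, its weights, and the name `break_short_cycles` are made up for illustration.

```python
import networkx as nx
import numpy as np

def break_short_cycles(graph, adj, max_cycle_length=80):
    # Repeatedly take the shortest cycle-basis cycle below the length limit
    # and delete its lowest-weight edge, until no such cycle is left.
    while True:
        cycles = sorted((c for c in nx.cycle_basis(graph) if len(c) < max_cycle_length), key=len)
        if not cycles:
            return graph
        cycle = cycles[0]
        edges = [(cycle[i], cycle[(i + 1) % len(cycle)]) for i in range(len(cycle))]
        graph.remove_edge(*min(edges, key=lambda e: adj[e[0]][e[1]]))

# Triangle 0-1-2 joined by edge 2-3 to the six-membered ring 3-4-5-6-7-8.
G = nx.Graph([(0, 1), (1, 2), (2, 0), (2, 3),
              (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 3)])
adj = np.ones((9, 9))
adj[0, 2] = adj[2, 0] = 0.55  # weakest triangle edge
adj[0, 1] = adj[1, 0] = 0.60

G = break_short_cycles(G, adj, max_cycle_length=5)
print(G.has_edge(0, 2), len(nx.cycle_basis(G)))  # False 1
```

Only cycles returned by `nx.cycle_basis` are inspected, which is why the limit is phrased in terms of the cycle basis rather than all cycles.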
103 | """ 104 | while True: 105 | cycles = sorted([cycle for cycle in nx.cycle_basis(graph) if len(cycle) < max_cycle_length], key=len) 106 | 107 | if len(cycles) < 1: 108 | break 109 | 110 | cycle = cycles[0] 111 | min_weight = float('inf') 112 | edge_to_remove = None 113 | 114 | for i in range(len(cycle)): 115 | node1 = cycle[i] 116 | node2 = cycle[(i+1) % len(cycle)] 117 | 118 | edge_weight = adj[node1][node2] 119 | 120 | if edge_weight < min_weight: 121 | min_weight = edge_weight 122 | edge_to_remove = (node1, node2) 123 | 124 | if edge_to_remove: 125 | graph.remove_edge(*edge_to_remove) 126 | 127 | return graph 128 | 129 | 130 | def keep_largest_connected_component(graph): 131 | """ 132 | Retain only the largest connected component of the given graph. 133 | 134 | Args: 135 | graph (networkx.Graph): A NetworkX graph object. 136 | 137 | Returns: 138 | networkx.Graph: A NetworkX graph object representing the largest connected 139 | component of the input graph. 140 | """ 141 | connected_components = list(nx.connected_components(graph)) 142 | 143 | if len(connected_components) == 0: 144 | return graph 145 | largest_component = max(connected_components, key=len) 146 | largest_component_graph = graph.subgraph(largest_component).copy() 147 | 148 | return largest_component_graph 149 | 150 | 151 | def graph_anneal_break_largest_circle(adj): 152 | """ 153 | Clean a decoded adjacency matrix: drop small components, break cycles shorter than 80 nodes, and keep the largest connected component. 154 | 155 | Args: 156 | adj (numpy.ndarray): A numpy array representing the adjacency matrix of the graph. 157 | 158 | Returns: 159 | networkx.Graph: The largest connected component left after removing small 160 | components and breaking cycles shorter than 80 nodes.
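Example (simplified sketch, not the module's code path): the full cleanup applied to a decoded matrix. It uses networkx's `connected_components` in place of the DFS helpers above. The probability matrix is sliced with the same node list as the annealed graph, so edge weights stay aligned with the relabeled nodes. The toy 20-node matrix and the names `clean_decoded` and `link` are hypothetical.

```python
import networkx as nx
import numpy as np

def clean_decoded(adj, min_size=10, max_cycle_length=80):
    a = np.round(adj)
    G = nx.from_numpy_array(a)
    # 1) anneal: drop components with fewer than `min_size` nodes (isolated nodes included)
    keep = sorted(n for comp in nx.connected_components(G) if len(comp) >= min_size for n in comp)
    G = nx.from_numpy_array(a[np.ix_(keep, keep)])
    w = adj[np.ix_(keep, keep)]  # weights re-indexed like the relabeled graph
    # 2) break every cycle-basis cycle shorter than `max_cycle_length`, weakest edge first
    while True:
        cycles = sorted((c for c in nx.cycle_basis(G) if len(c) < max_cycle_length), key=len)
        if not cycles:
            break
        c = cycles[0]
        edges = [(c[i], c[(i + 1) % len(c)]) for i in range(len(c))]
        G.remove_edge(*min(edges, key=lambda e: w[e[0], e[1]]))
    # 3) keep the largest connected component
    if G.number_of_nodes() == 0:
        return G
    return G.subgraph(max(nx.connected_components(G), key=len)).copy()

adj = np.zeros((20, 20))

def link(i, j, p):
    adj[i, j] = adj[j, i] = p

for i, j in [(0, 1), (1, 2), (0, 2)]:  # 3-node fragment: too small, removed
    link(i, j, 0.9)
for i in range(3, 14):                 # 12-node chain 3-4-...-14
    link(i, i + 1, 0.9)
link(12, 14, 0.6)                      # spurious bond closing the triangle 12-13-14

G = clean_decoded(adj)
print(G.number_of_nodes(), G.number_of_edges(), nx.is_tree(G))  # 12 11 True
```

Nodes 15 to 19 are empty padding rows. Like the annealing step above, the sketch drops them because they form single-node components.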
161 | """ 162 | rounded = np.round(adj) 163 | a = graph_anneal(rounded) # graph_anneal zeroes the rows/columns of small components of `rounded` in place 164 | mask = np.any(rounded, axis=1) # nodes kept by graph_anneal, in their original order 165 | G = nx.from_numpy_array(a) 166 | G = break_edges_keep_largest_circle(G, adj[mask][:, mask]) # index weights like G's relabeled nodes 167 | G = keep_largest_connected_component(G) 168 | return G 169 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' 3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 4 | 5 | import ast 6 | import pickle 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from timeit import default_timer as timer 11 | from tensorflow.keras import layers, callbacks 12 | from tensorflow.keras import backend as K 13 | from tensorflow.keras.utils import to_categorical 14 | from spektral.layers import GINConvBatch 15 | 16 | LATENT_DIM = 8 17 | WEIGHT_DIR = '/scratch/gpfs/sj0161/topo_result/' 18 | 19 | class Sampling(layers.Layer): 20 | def __init__(self, seed=None, **kwargs): 21 | super(Sampling, self).__init__(**kwargs) 22 | self.seed = seed 23 | 24 | def call(self, inputs): 25 | z_mean, z_log_var = inputs 26 | batch = tf.shape(z_mean)[0] 27 | dim = tf.shape(z_mean)[1] 28 | epsilon = tf.keras.backend.random_normal(shape=(batch, dim), seed=self.seed) 29 | return z_mean + tf.exp(0.5 * z_log_var) * epsilon 30 | 31 | class KLDivergenceLayer(layers.Layer): 32 | def __init__(self, *args, **kwargs): 33 | self.is_placeholder = True 34 | super(KLDivergenceLayer, self).__init__(*args, **kwargs) 35 | 36 | def call(self, inputs, beta=1.0): 37 | z_mean, z_log_var = inputs 38 | kl_batch = - .5 * K.sum(1 + z_log_var - 39 | K.square(z_mean) - 40 | K.exp(z_log_var), axis=-1) * beta 41 | kl_loss = K.mean(kl_batch) 42 | self.add_loss(kl_loss, inputs=inputs) 43 | self.add_metric(kl_loss, aggregation='mean', name='kl_loss') 44 | return inputs 45 | 46 | def encoder_gnn(input_shape=(100, 100)): 47 | """ 48 | Constructs a GNN encoder with the specified input shape.
49 | 50 | Args: 51 | input_shape (tuple, optional): Shape of the input adjacency matrix and feature matrix. 52 | Defaults to (100, 100). 53 | 54 | Returns: 55 | tuple: A tuple containing the input tensors and the output tensor of the encoder. 56 | """ 57 | A = tf.keras.Input(shape=input_shape, sparse=False, name='a_1') 58 | F = tf.keras.Input(shape=input_shape, name='f_1') 59 | 60 | x = GINConvBatch(32, activation='relu')([F, A]) 61 | x = GINConvBatch(32, activation='relu')([x, A]) 62 | 63 | x = layers.Flatten(name='flatten')(x) 64 | x_out = layers.Dense(32, activation='relu')(x) 65 | 66 | return [A, F], x_out 67 | 68 | 69 | def encoder_desc_dnn(input_shape=(11,)): 70 | """ 71 | Constructs a DNN encoder for descriptor data with the specified input shape. 72 | 73 | Args: 74 | input_shape (tuple, optional): The shape of the input descriptor data. Defaults to (11,). 75 | 76 | Returns: 77 | tuple: A tuple containing the input tensor and the output tensor of the encoder. 78 | """ 79 | x_in = tf.keras.Input(shape=input_shape, name='x_in') 80 | 81 | x = layers.Dense(32, activation='relu')(x_in) 82 | x_out = layers.Dense(32, activation='relu')(x) 83 | 84 | return x_in, x_out 85 | 86 | 87 | def encoder_desc_gnn(input_shape=[(100, 100), (11,)]): 88 | """ 89 | Constructs a GNN-based encoder with an additional DNN branch for descriptor data. 90 | 91 | Args: 92 | input_shape (list of tuple, optional): List containing the shapes of the adjacency 93 | matrix, feature matrix, and descriptor data. 94 | Defaults to [(100, 100), (11,)]. 95 | 96 | Returns: 97 | tuple: A tuple containing the input tensors and the concatenated output tensor 98 | from the GNN and DNN branches of the encoder. 
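Example (simplified sketch, not spektral's implementation): one graph-isomorphism (GIN) update, h_i' = MLP((1 + eps) * h_i + sum of h_j over the neighbours j of i), written in NumPy. The encoders in this module receive the padded adjacency matrix both as `A` and as the node-feature matrix `F` (callers pass `[x, x]`), so each node's initial features are its adjacency row. Here eps is fixed and a single random dense layer with ReLU stands in for the MLP. spektral's `GINConvBatch` may differ in such details (for example a learnable eps or a deeper MLP). The name `gin_layer` is hypothetical.

```python
import numpy as np

def gin_layer(X, A, W, b, eps=0.0):
    # Aggregate: (1 + eps) * own features + sum of neighbour features (A @ X),
    # then one dense layer with ReLU as a stand-in for the GIN MLP.
    H = (1.0 + eps) * X + A @ X
    return np.maximum(H @ W + b, 0.0)

rng = np.random.default_rng(0)
n_nodes, n_channels = 100, 32

# Padded adjacency of a 5-node linear chain; rows 5..99 are zero padding.
A = np.zeros((n_nodes, n_nodes))
for i in range(4):
    A[i, i + 1] = A[i + 1, i] = 1.0
X = A.copy()  # node features = adjacency rows, as with the [x, x] inputs

W1, b1 = rng.normal(size=(n_nodes, n_channels)), np.zeros(n_channels)
W2, b2 = rng.normal(size=(n_channels, n_channels)), np.zeros(n_channels)
H = gin_layer(gin_layer(X, A, W1, b1), A, W2, b2)  # two stacked layers, as in the encoders
print(H.shape)              # (100, 32)
print(np.abs(H[5:]).max())  # 0.0: padded nodes stay zero because the sketch's biases are zero
```

In the Keras encoders, the (100, 32) node embeddings are then flattened and passed to a 32-unit dense layer.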
99 | """ 100 | A1 = tf.keras.Input(shape=input_shape[0], sparse=False, name='a_1') 101 | F1 = tf.keras.Input(shape=input_shape[0], name='f_1') 102 | x_in = tf.keras.Input(shape=input_shape[1], name='x_in') 103 | 104 | x1 = GINConvBatch(32, activation='relu')([F1, A1]) 105 | x1 = GINConvBatch(32, activation='relu')([x1, A1]) 106 | 107 | x2 = layers.Dense(32, activation='relu')(x_in) 108 | 109 | x_out2 = layers.Dense(32, activation='relu')(x2) 110 | x1 = layers.Flatten(name='flatten')(x1) 111 | x_out1 = layers.Dense(32, activation='relu')(x1) 112 | 113 | x_out = layers.Concatenate()([x_out1, x_out2]) 114 | 115 | return [[A1, F1], x_in], x_out 116 | 117 | def latent_space(x_out, beta=1.0): 118 | """ 119 | Constructs the latent space for a VAE, including the KL divergence layer. 120 | 121 | Args: 122 | x_out (tf.Tensor): Output tensor from the encoder. 123 | beta (float, optional): Weight for the KL divergence term in the loss function. Defaults to 1.0. 124 | 125 | Returns: 126 | tuple: Tensors representing the mean and log variance of the latent space. 127 | """ 128 | z_mean = layers.Dense(LATENT_DIM, name="z1")(x_out) 129 | z_log_var = layers.Dense(LATENT_DIM, name="z2")(x_out) 130 | z_mean, z_log_var = KLDivergenceLayer(name='kl')([z_mean, z_log_var], beta=beta) 131 | return z_mean, z_log_var 132 | 133 | def regressor_dnn(): 134 | """ 135 | Constructs a dense neural network (DNN) based regressor model. 136 | 137 | Returns: 138 | tf.keras.Model: A Keras model for regression, mapping from the latent space to a single output. 139 | """ 140 | r_in = layers.Input(shape=(LATENT_DIM,)) 141 | x = layers.Dense(32, activation='relu', name="r1")(r_in) 142 | x = layers.Dense(32, activation='relu', name="r2")(x) 143 | r_out = layers.Dense(1, activation='linear', name="r3")(x) 144 | regressor = tf.keras.Model(r_in, r_out, name='regressor') 145 | return regressor 146 | 147 | 148 | def classifier_dnn(): 149 | """ 150 | Constructs a dense neural network (DNN) based classifier model. 
151 | 152 | Returns: 153 | tf.keras.Model: A Keras model for classification, mapping from the latent space to a softmax output. 154 | """ 155 | c_in = layers.Input(shape=(LATENT_DIM,)) 156 | x = layers.Dense(32, activation='relu', name="c1")(c_in) 157 | x = layers.Dense(32, activation='relu', name="c2")(x) 158 | c_out = layers.Dense(6, activation='softmax', name="c3")(x) 159 | classifier = tf.keras.Model(c_in, c_out, name='classifier') 160 | return classifier 161 | 162 | 163 | def decoder_cnn(): 164 | """ 165 | Constructs a convolutional neural network (CNN) based decoder model. 166 | 167 | Returns: 168 | tf.keras.Model: A Keras model for decoding, mapping from the latent space to the reconstructed output. 169 | """ 170 | d_in = layers.Input(shape=(LATENT_DIM,)) 171 | x = layers.Dense(32, activation='relu', name="d1")(d_in) 172 | x = layers.Dense(32, activation='relu', name="d2")(x) 173 | x = layers.Dense(25 * 25 * 64, activation='relu', name="d3")(x) 174 | x = layers.Reshape((25, 25, 64), name="d4")(x) 175 | x = layers.Conv2DTranspose(32, (3, 3), strides=2, padding='same', activation='relu', name=f"d5")(x) 176 | x = layers.Conv2DTranspose(32, (3, 3), strides=2, padding='same', activation='relu', name=f"d6")(x) 177 | x = layers.Conv2DTranspose(1, (3, 3), padding='same', activation='sigmoid', name="d7")(x) 178 | d_out = layers.Reshape((100, 100), name="d8")(x) 179 | decoder = tf.keras.Model(d_in, d_out, name='decoder') 180 | return decoder 181 | 182 | 183 | def get_callbacks(weight_path, monitor='val_decoder_acc', patience=200): 184 | """ 185 | Generates model training callbacks. 186 | 187 | Args: 188 | weight_path (str): Path to save model weights. 189 | monitor (str): Metric to monitor for performance (default 'val_decoder_acc'). 190 | patience (int): Number of epochs to wait for improvement before stopping (default 200). 191 | 192 | Returns: 193 | list: ModelCheckpoint and EarlyStopping callbacks. 
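Example (simplified model of the callback behaviour, not Keras code): `mode` is inferred from the metric name, 'max' for accuracies and 'min' otherwise. The checkpoint keeps weights from the best epoch. Training stops once `patience` consecutive epochs fail to improve on that best value. The helpers `infer_mode` and `simulate_early_stopping` and the toy histories are hypothetical. The simulation ignores Keras options such as `min_delta` and `baseline`.

```python
def infer_mode(monitor):
    # Same rule as get_callbacks: accuracy-like metrics are maximised.
    return 'max' if 'acc' in monitor else 'min'

def simulate_early_stopping(history, monitor, patience):
    # Returns (best_epoch, stop_epoch); stop_epoch is None if training runs to the end.
    if infer_mode(monitor) == 'max':
        better = lambda a, b: a > b
    else:
        better = lambda a, b: a < b
    best_epoch, wait = 0, 0
    for epoch in range(1, len(history)):
        if better(history[epoch], history[best_epoch]):
            best_epoch, wait = epoch, 0   # checkpoint would save weights here
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None

val_acc = [0.50, 0.60, 0.55, 0.58, 0.59, 0.57, 0.61]
print(simulate_early_stopping(val_acc, 'val_decoder_acc', patience=3))   # (1, 4)
print(simulate_early_stopping([1.0, 0.8, 0.9], 'val_loss', patience=3))  # (1, None)
```

In the first history, the late improvement at epoch 6 is never reached, because three non-improving epochs in a row already trigger the stop.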
194 | """ 195 | mode = 'max' if 'acc' in monitor else 'min' 196 | checkpoint = callbacks.ModelCheckpoint(weight_path, monitor=monitor, mode=mode, 197 | save_weights_only=True, save_best_only=True, verbose=0) 198 | early_stop = callbacks.EarlyStopping(monitor=monitor, mode=mode, patience=patience) 199 | 200 | return [checkpoint, early_stop] 201 | 202 | 203 | def get_model(beta=1.0, enc_type='gnn', dec_type='cnn', if_reg=True, if_cls=True, seed=42): 204 | """ 205 | Constructs and returns a VAE model based on the specified encoder 206 | and decoder types, with optional regression and classification tasks. 207 | 208 | Args: 209 | beta (float, optional): Weight of the KL divergence term in the VAE loss. Default is 1.0. 210 | enc_type (str, optional): Encoder type: 'gnn', 'desc_dnn', or 'desc_gnn'. Default is 'gnn'. 211 | dec_type (str, optional): Decoder type; only 'cnn' is implemented. Default is 'cnn'. 212 | if_reg (bool, optional): Whether to add the rg2 regression head. Default is True. 213 | if_cls (bool, optional): Whether to add the topology classification head. Default is True. 214 | seed (int, optional): Seed for the random number generator in the sampling layer. Default is 42. 215 | 216 | Returns: 217 | tf.keras.Model: The VAE, whose outputs are the decoded adjacency matrix followed by the regression and/or classification outputs, if enabled.
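Example (NumPy sketch of what the `Sampling` and `KLDivergenceLayer` layers above compute, not the Keras code itself): a latent sample is drawn with the reparameterization z = mu + exp(0.5 * log_var) * eps, where eps ~ N(0, I). The KL penalty per sample is -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), scaled by `beta` and averaged over the batch. The helper names `sample_z` and `kl_loss` are hypothetical.

```python
import numpy as np

def sample_z(z_mean, z_log_var, rng):
    # Reparameterization trick, as in Sampling.call
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps

def kl_loss(z_mean, z_log_var, beta=1.0):
    # KL to N(0, I) summed over latent dims, scaled by beta, averaged over the batch
    kl = -0.5 * np.sum(1 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=-1) * beta
    return kl.mean()

rng = np.random.default_rng(0)
mu = np.array([[1.0, 0.0], [0.0, 0.0]])  # batch of 2, latent dim 2 (the model uses LATENT_DIM = 8)
log_var = np.zeros_like(mu)

print(kl_loss(mu, log_var))            # 0.25, i.e. (0.5 + 0.0) / 2
print(kl_loss(mu, log_var, beta=4.0))  # 1.0
z = sample_z(np.zeros((10000, 2)), np.log(np.full((10000, 2), 4.0)), rng)
print(z.std(axis=0))                   # close to 2 in each dim: exp(0.5 * log 4) = 2
```

A larger `beta` pulls the encoder means toward zero and the variances toward one more strongly, at the cost of reconstruction accuracy.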
218 | """ 219 | 220 | if enc_type == 'gnn': 221 | x_in, x_out = encoder_gnn() 222 | 223 | elif enc_type == 'desc_dnn': 224 | x_in, x_out = encoder_desc_dnn() 225 | 226 | elif enc_type == 'desc_gnn': 227 | x_in, x_out = encoder_desc_gnn() 228 | 229 | else: 230 | raise Exception("Option not available.") 231 | 232 | z_mean, z_log_var = latent_space(x_out, beta=beta) 233 | z = Sampling(seed=seed)([z_mean, z_log_var]) 234 | 235 | if dec_type == 'cnn': 236 | dec = decoder_cnn() 237 | else: 238 | raise Exception("Option not available.") 239 | 240 | d_out = dec(z) 241 | 242 | if if_reg: 243 | reg = regressor_dnn() 244 | r_out = reg(z) 245 | 246 | if if_cls: 247 | cls = classifier_dnn() 248 | c_out = cls(z) 249 | 250 | if if_reg and if_cls: 251 | model = tf.keras.Model(inputs=x_in, outputs=[d_out, r_out, c_out]) 252 | 253 | elif if_reg and not if_cls: 254 | model = tf.keras.Model(inputs=x_in, outputs=[d_out, r_out]) 255 | 256 | elif not if_reg and if_cls: 257 | model = tf.keras.Model(inputs=x_in, outputs=[d_out, c_out]) 258 | 259 | elif not if_reg and not if_cls: 260 | model = tf.keras.Model(inputs=x_in, outputs=[d_out]) 261 | 262 | return model 263 | 264 | def get_spec(file): 265 | """ 266 | Extract specifications from a filename. 267 | 268 | Args: 269 | file (str): The file name from which specifications are to be extracted. 270 | 271 | Returns: 272 | tuple: A tuple containing the following extracted specifications: 273 | - ENCODER (str): The encoder specification. 274 | - DECODER (str): The decoder specification. 275 | - MONITOR (str): The monitor specification. 276 | - IF_REG (bool): Flag indicating if regression is used. 277 | - IF_CLS (bool): Flag indicating if classification is used. 278 | - weights (dict): Weights for loss functions. 279 | - LR (float): Learning rate. 280 | - BS (int): Batch size. 
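Example: the weight file names encode the run configuration as underscore-separated fields. The order appears to follow the `get_file_names` arguments in `train_vae`: encoder, decoder, date, latent dimension, monitor, regression flag, classification flag, beta, loss weights, learning rate, batch size. The self-contained copy below (`parse_spec`, a hypothetical name) follows the same logic and is applied to a file name from `topo_result/`. Note that the beta field (`root3[2]`) is read past but not returned.

```python
import ast

def parse_spec(file):
    # Mirrors get_spec: split the basename on '_' and read fields by position.
    root = file.split("/")[-1].split("_")
    if "desc" in root:                    # two-part encoder name, e.g. desc_gnn
        encoder, root2 = "_".join(root[:2]), root[2:]
    else:
        encoder, root2 = root[0], root[1:]
    decoder = root2[0]
    if len(root2) == 12:                  # three-part monitor, e.g. val_decoder_acc
        monitor, root3 = "_".join(root2[3:6]), root2[6:]
    else:                                 # two-part monitor, e.g. val_loss
        monitor, root3 = "_".join(root2[3:5]), root2[5:]
    if_reg, if_cls = root3[0] == "True", root3[1] == "True"
    weights = ast.literal_eval(root3[3])  # root3[2] (beta) is skipped
    lr, bs = float(root3[4]), int(root3[5].split(".h5")[0])
    return encoder, decoder, monitor, if_reg, if_cls, weights, lr, bs

name = "desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64.h5"
print(parse_spec(name))
# ('desc_gnn', 'cnn', 'val_decoder_acc', True, True, [1.0, 1, 1], 0.001, 64)
```

The positional parsing relies on the weight list containing no underscores, which holds for the `str(list)` form used in these names.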
281 | """ 282 | root = file.split("/")[-1].split("_") 283 | if "desc" in root: 284 | ENCODER = "_".join(root[:2]) 285 | root2 = root[2:] 286 | else: 287 | ENCODER = root[0] 288 | root2 = root[1:] 289 | 290 | DECODER = root2[0] 291 | 292 | if len(root2) == 12: 293 | MONITOR = "_".join(root2[3:6]) 294 | root3 = root2[6:] 295 | else: 296 | MONITOR = "_".join(root2[3:5]) 297 | root3 = root2[5:] 298 | 299 | IF_REG = root3[0] == "True" 300 | IF_CLS = root3[1] == "True" 301 | 302 | weights = ast.literal_eval(root3[3]) 303 | 304 | LR = float(root3[4]) 305 | BS = int(root3[5].split(".h5")[0]) 306 | return ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS 307 | 308 | 309 | 310 | def train_vae(enc_type, dec_type, monitor, if_reg, if_cls, x_train, x_valid, y_train, y_valid, 311 | c_train, c_valid, l_train, l_valid, beta=1.0, weights=[1.0, 1.0, 1.0], 312 | lr=0.001, bs=32, if_train=False, n_class=6, n_epoch=1000, date='20230828'): 313 | """ 314 | This function trains a VAE model. It handles different types of encoders and decoders, 315 | allows for regression and classification tasks, and accommodates various training configurations. 316 | The function can also load pre-trained weights instead of training from scratch. 317 | 318 | Args: 319 | enc_type (str): The type of encoder to use. 320 | dec_type (str): The type of decoder to use. 321 | monitor (str): The parameter to monitor during training. 322 | if_reg (bool): Flag indicating if regression is to be performed. 323 | if_cls (bool): Flag indicating if classification is to be performed. 324 | x_train (numpy.ndarray): Training graph features. 325 | x_valid (numpy.ndarray): Validation graph features. 326 | y_train (numpy.ndarray): Training rg2. 327 | y_valid (numpy.ndarray): Validation rg2. 328 | c_train (numpy.ndarray): Training topology class. 329 | c_valid (numpy.ndarray): Validation topology class. 
330 | l_train (numpy.ndarray): Training topological descriptors 331 | l_valid (numpy.ndarray): Validation topological descriptors 332 | beta (float, optional): The beta parameter for VAE. Default is 1.0. 333 | weights (list, optional): List of weights for different components of the loss function. Default is [1.0, 1.0, 1.0]. 334 | lr (float, optional): Learning rate. Default is 0.001. 335 | bs (int, optional): Batch size. Default is 32. 336 | if_train (bool, optional): Flag indicating if the model should be trained or loaded. Default is False. 337 | n_class (int, optional): Number of classes for classification. Default is 6. 338 | n_epoch (int, optional): Number of epochs for training. Default is 1000. 339 | date (str, optional): Date string used for file naming. Default is '20230828'. 340 | 341 | Returns: 342 | tuple: A tuple containing the trained or loaded model and the path to the history file. 343 | """ 344 | 345 | K.clear_session() 346 | 347 | t0 = timer() 348 | 349 | if enc_type == 'desc_dnn': 350 | in_train = np.copy(l_train) 351 | in_valid = np.copy(l_valid) 352 | 353 | elif enc_type == 'desc_gnn': 354 | in_train = [[np.copy(x_train), np.copy(x_train)], np.copy(l_train)] 355 | in_valid = [[np.copy(x_valid), np.copy(x_valid)], np.copy(l_valid)] 356 | 357 | elif enc_type == 'gnn': 358 | in_train = [np.copy(x_train), np.copy(x_train)] 359 | in_valid = [np.copy(x_valid), np.copy(x_valid)] 360 | 361 | else: 362 | in_train = np.copy(x_train) 363 | in_valid = np.copy(x_valid) 364 | 365 | model = get_model(beta=beta, enc_type=enc_type, dec_type=dec_type, if_reg=if_reg, if_cls=if_cls) 366 | model, loss_weights = compile_model(model, lr=lr, if_reg=if_reg, if_cls=if_cls, weights=weights) 367 | 368 | weight_name, hist_name = get_file_names(enc_type, dec_type, date, LATENT_DIM, monitor, 369 | if_reg, if_cls, beta, loss_weights, lr, bs) 370 | 371 | 372 | if (not if_reg) and (not if_cls) and ('acc' in monitor): 373 | monitor = 'val_acc' 374 | 375 | if if_train: 376 | c1, 
c2 = get_callbacks(os.path.join(WEIGHT_DIR, weight_name), monitor=monitor) 377 | 378 | if not if_reg and not if_cls: 379 | train_label = x_train 380 | valid_label = x_valid 381 | 382 | else: 383 | train_label = [x_train] 384 | valid_label = [x_valid] 385 | 386 | if if_reg: 387 | train_label.append(y_train) 388 | valid_label.append(y_valid) 389 | 390 | if if_cls: 391 | train_label.append(to_categorical(c_train, n_class)) 392 | valid_label.append(to_categorical(c_valid, n_class)) 393 | 394 | hist = model.fit(in_train, train_label, validation_data=(in_valid, valid_label), 395 | callbacks=[c1, c2], epochs=n_epoch, verbose=0, batch_size=bs) 396 | 397 | with open(os.path.join(WEIGHT_DIR, hist_name), 'wb') as handle: 398 | pickle.dump(hist.history, handle) 399 | else: 400 | model.load_weights(os.path.join(WEIGHT_DIR, weight_name)) 401 | 402 | t1 = timer() 403 | 404 | print(weight_name + f' finished in {t1-t0:0.2f} sec ...') 405 | 406 | return model, os.path.join(WEIGHT_DIR, hist_name) 407 | 408 | 409 | def rec_loss(y_true, y_pred): 410 | """ 411 | Reconstruction loss for the binary 100 x 100 adjacency matrix, treated as per-entry binary classification. 412 | 413 | Args: 414 | y_true (tf.Tensor): True adjacency matrices. 415 | y_pred (tf.Tensor): Predicted edge probabilities. 416 | 417 | Returns: 418 | tf.Tensor: Per-graph binary cross-entropy; the mean over the 100 * 100 entries is multiplied by 10000, which turns it into a sum over entries. 419 | """ 420 | y_true_ = K.reshape(y_true, (-1, 100 * 100)) 421 | y_pred_ = K.reshape(y_pred, (-1, 100 * 100)) 422 | loss = tf.keras.losses.binary_crossentropy(y_true_, y_pred_) * 10000 423 | return loss 424 | 425 | def reg_loss(y_true, y_pred): 426 | """ 427 | Regression loss using mean absolute error. 428 | 429 | Args: 430 | y_true (tf.Tensor): True values. 431 | y_pred (tf.Tensor): Predicted values. 432 | 433 | Returns: 434 | tf.Tensor: Mean absolute error loss. 435 | """ 436 | return tf.keras.losses.mean_absolute_error(y_true, y_pred) 437 | 438 | def cls_loss(y_true, y_pred): 439 | """ 440 | Classification loss using categorical cross-entropy. 
441 | 442 | Args: 443 | y_true (tf.Tensor): True class labels (one-hot encoded). 444 | y_pred (tf.Tensor): Predicted class probabilities. 445 | 446 | Returns: 447 | tf.Tensor: Categorical cross-entropy loss. 448 | """ 449 | return tf.keras.losses.categorical_crossentropy(y_true, y_pred) 450 | 451 | def acc(y_true, y_pred): 452 | """ 453 | Compute the balanced accuracy of binary predictions (here, the entries of the reconstructed adjacency matrix). 454 | 455 | Balanced accuracy is the mean of sensitivity (true positive rate) and specificity 456 | (true negative rate), so it stays informative when one class (e.g. absent edges) dominates. 457 | 458 | Args: 459 | y_true (tf.Tensor): True labels (ground truth). 460 | y_pred (tf.Tensor): Predicted probabilities or binary predictions; values are rounded to 0 or 1. 461 | 462 | Returns: 463 | tf.Tensor: Balanced accuracy score. 464 | """ 465 | y_true_flat = K.flatten(y_true) 466 | y_pred_flat = K.flatten(y_pred) 467 | y_pred_bin = K.round(y_pred_flat) 468 | TP = K.sum(y_true_flat * y_pred_bin) 469 | FP = K.sum((1-y_true_flat) * y_pred_bin) 470 | TN = K.sum((1-y_true_flat) * (1-y_pred_bin)) 471 | FN = K.sum(y_true_flat * (1-y_pred_bin)) 472 | sensitivity = TP / (TP + FN + K.epsilon()) 473 | specificity = TN / (TN + FP + K.epsilon()) 474 | balanced_accuracy = (sensitivity + specificity) / 2 475 | return balanced_accuracy 476 | 477 | 478 | def get_file_names(enc_type, dec_type, date, dim, monitor, if_reg, if_cls, beta, loss_weights, lr, bs): 479 | """ 480 | Generate the weight and history file names from the training configuration. 481 | 482 | Args: 483 | enc_type (str): The encoder type. 484 | dec_type (str): The decoder type. 485 | date (str): The date or timestamp. 486 | dim (int): The latent-space dimensionality. 487 | monitor (str): The metric monitored during training. 488 | if_reg (bool): Whether the regression head is used. 489 | if_cls (bool): Whether the classification head is used. 490 | beta (float): The VAE beta parameter. 491 | loss_weights (list): List of loss weights. 492 | lr (float): Learning rate. 493 | bs (int): Batch size. 
494 | 495 | Returns: 496 | tuple of str: A tuple containing two file names in the format (model_file_name, pickle_file_name). 497 | """ 498 | base_name = f"{enc_type}_{dec_type}_{date}_{dim}_{monitor}_{if_reg}_{if_cls}_{beta}_{loss_weights}_{lr}_{bs}" 499 | return base_name + ".h5", base_name + ".pickle" 500 | 501 | 502 | 503 | def compile_model(model, lr=0.001, if_reg=True, if_cls=True, weights=[1.0, 1.0, 1.0]): 504 | """ 505 | Compile a Keras model with specified configuration. 506 | 507 | Args: 508 | model (tf.keras.Model): The Keras model to compile. 509 | lr (float, optional): Learning rate for the optimizer. Default is 0.001. 510 | if_reg (bool, optional): Whether to include a regression loss. Default is True. 511 | if_cls (bool, optional): Whether to include a classification loss. Default is True. 512 | weights (list of float, optional): Loss weights for different components. 513 | Default is [1.0, 1.0, 1.0]; at least 2 elements are required, and weights[1] is reused for the classification loss if weights[2] is missing. 514 | 515 | Returns: 516 | tf.keras.Model: Compiled Keras model. 517 | list of float: Loss weights used during compilation. 518 | """ 519 | loss = [rec_loss] 520 | loss_weights = [weights[0]] 521 | if if_reg: 522 | loss.append(reg_loss) 523 | loss_weights.append(weights[1]) 524 | if if_cls: 525 | loss.append(cls_loss) 526 | try: 527 | loss_weights.append(weights[2]) 528 | except IndexError: 529 | loss_weights.append(weights[1]) 530 | 531 | if not if_reg and not if_cls: 532 | model.compile( 533 | optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=lr), 534 | loss=loss, 535 | metrics={'decoder': acc} 536 | ) 537 | else: 538 | model.compile( 539 | optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=lr), 540 | loss=loss, 541 | loss_weights=loss_weights, 542 | metrics={'decoder': acc} 543 | ) 544 | return model, loss_weights 545 | 546 | 547 | def latent_model(model, data, enc_type='gnn', mean_var=False): 548 | """ 549 | Generates latent space representations from a given model and data. 
550 | 551 | Args: 552 | model (tf.keras.Model): The model from which to generate the latent representations. 553 | data (list): The input data as [adjacency matrices, topological descriptors]; which entries are used depends on the encoder type. 554 | enc_type (str, optional): The type of encoder used in the model. Defaults to 'gnn'. 555 | mean_var (bool, optional): If True, returns the outputs of the 'z1' and 'z2' layers (latent mean and log-variance) instead of a sampled latent vector. 556 | Defaults to False. 557 | 558 | Returns: 559 | numpy.ndarray or tuple: The sampled latent vectors (stochastic, since they come from the sampling layer), or a tuple of the mean and log-variance outputs. 560 | """ 561 | # Prepare input based on encoder type 562 | if enc_type == 'gnn': 563 | x_in = [data[0], data[0]] 564 | elif enc_type == 'desc_gnn': 565 | x_in = [[data[0], data[0]], data[1]] 566 | elif enc_type == 'desc_dnn': 567 | x_in = data[1] 568 | else: 569 | x_in = data[0] 570 | 571 | # Generate latent space representations 572 | if mean_var: 573 | l1_model = tf.keras.Model(inputs=model.input, outputs=model.get_layer('z1').output) 574 | l1 = l1_model.predict(x_in, verbose=0) 575 | l2_model = tf.keras.Model(inputs=model.input, outputs=model.get_layer('z2').output) 576 | l2 = l2_model.predict(x_in, verbose=0) 577 | return l1, l2 578 | else: 579 | l_model = tf.keras.Model(inputs=model.input, outputs=model.get_layer('sampling').output) 580 | l = l_model.predict(x_in, verbose=0) 581 | return l 582 | 583 | -------------------------------------------------------------------------------- /polymer_generation_playground.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "a5d7bdcc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import warnings\n", 11 | "warnings.filterwarnings(\"ignore\")\n", 12 | "\n", 13 | "import os\n", 14 | "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n", 15 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n", 16 | "\n", 17 | "import pickle\n", 18 | "\n", 19 | "import 
proplot as pplt\n", 20 | "from graph_utils import *\n", 21 | "from data_utils import *\n", 22 | "from model_utils import *\n", 23 | "from analysis_utils import *\n", 24 | "from saliency_utils import *\n", 25 | "from generation_utils import *\n", 26 | "\n", 27 | "DATA_DIR = '/scratch/gpfs/sj0161/topo_data/'\n", 28 | "WEIGHT_DIR = '/scratch/gpfs/sj0161/topo_result/'\n", 29 | "ANALYSIS_DIR = '/scratch/gpfs/sj0161/topo_analysis/'\n", 30 | "\n", 31 | "\n", 32 | "pplt.rc['figure.facecolor'] = 'white'\n", 33 | "\n", 34 | "COLORS = []\n", 35 | "colors1 = pplt.Cycle('default')\n", 36 | "colors2 = pplt.Cycle('538')\n", 37 | "\n", 38 | "for color in colors1:\n", 39 | " COLORS.append(color['color'])\n", 40 | "\n", 41 | "for color in colors2:\n", 42 | " COLORS.append(color['color'])\n", 43 | "\n", 44 | "LATENT_DIM = 8" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "id": "bc4b783d", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64.h5 finished in 1.05 sec ...\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "# Load latent space\n", 63 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space.pickle\"), \"rb\") as handle:\n", 64 | " latent_train = pickle.load(handle)\n", 65 | " latent_valid = pickle.load(handle)\n", 66 | " latent_test = pickle.load(handle)\n", 67 | " \n", 68 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 69 | "\n", 70 | "# Load data label\n", 71 | "((x_train, y_train, c_train, l_train, graph_train),\n", 72 | "(x_valid, y_valid, c_valid, l_valid, graph_valid),\n", 73 | "(x_test, y_test, c_test, l_test, graph_test),\n", 74 | "NAMES, SCALER, LE) = load_data(os.path.join(DATA_DIR, 'rg2.pickle'), fold=0, if_validation=True)\n", 75 | "\n", 76 | "graph_all = np.concatenate((graph_train, graph_valid, graph_test))\n", 77 | "y_all = np.concatenate((y_train, 
y_valid, y_test))\n", 78 | "c_all = np.concatenate((c_train, c_valid, c_test))\n", 79 | "\n", 80 | "# Load TopoGNN model\n", 81 | "file = \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\"\n", 82 | "\n", 83 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 84 | "\n", 85 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 86 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 87 | " l_train, l_valid, 1.0, weights, LR, BS, False)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "b2ae9abb", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "def check_isomorphism(graph_list, new_graph):\n", 98 | " for graph in graph_list:\n", 99 | " if nx.is_isomorphic(graph, new_graph):\n", 100 | " return True\n", 101 | " return False\n", 102 | "\n", 103 | "\n", 104 | "def rg_latent_vector(l, y_train, c_train, poly_type='branch', target_rg=40):\n", 105 | " \n", 106 | " idx = np.where(NAMES == poly_type)[0][0]\n", 107 | "\n", 108 | " a = l[np.where(c_train == idx)[0]]\n", 109 | " y = y_train[np.where(c_train == idx)[0]]\n", 110 | " \n", 111 | " if np.abs(y - target_rg).min() < 1:\n", 112 | " \n", 113 | " idx2 = np.where(np.abs(y - target_rg) < 1)[0]\n", 114 | "\n", 115 | " return a[idx2], y[idx2]\n", 116 | " \n", 117 | " else:\n", 118 | " \n", 119 | " return None, None" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 19, 125 | "id": "351c7a23", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "def gen_prop_polymer(target_rg=30, target_top=\"branch\", max_iter=1000):\n", 130 | " outputs = []\n", 131 | " graphs = []\n", 132 | " \n", 133 | "\n", 134 | " latent_vector, rg_dataset = rg_latent_vector(latent_all, y_all, c_all, target_top, target_rg)\n", 135 | " \n", 136 | " if latent_vector is None:\n", 137 | " raise ValueError(\"No dataset polymer of this topology has an rg2 within 1 of the target; choose a different target_rg.\")\n", 138 | "\n", 139 | " for i 
in range(max_iter):\n", 140 | " noise = np.random.normal(0, 1, (1, LATENT_DIM)) * 0.1\n", 141 | " \n", 142 | " num = len(latent_vector)\n", 143 | "\n", 144 | " for j in range(num):\n", 145 | " K.clear_session()\n", 146 | " d_in = latent_vector[j, ...] + noise\n", 147 | " graph_raw, graph, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s = polymer_generation(model, d_in, None, ENCODER)\n", 148 | "\n", 149 | " flag1 = np.abs(gen_reg - cln_reg_m) < 2\n", 150 | " flag2 = np.abs(gen_reg - target_rg) < 2\n", 151 | " flag3 = np.abs(cln_reg_m - target_rg) < 2\n", 152 | " flag4 = gen_cls == cln_cls\n", 153 | " flag5 = cln_cls == target_top\n", 154 | "\n", 155 | " x_clean_ = nx.to_numpy_array(graph)\n", 156 | " n_clean = len(x_clean_)\n", 157 | " x_clean = np.zeros((1, 100, 100))\n", 158 | " x_clean[0, :n_clean, :n_clean] = x_clean_\n", 159 | " x_clean = x_clean.astype(\"int\")\n", 160 | "\n", 161 | " l_clean = get_desc(graph)[None, ...]\n", 162 | " l_clean = SCALER.transform(l_clean)\n", 163 | "\n", 164 | " d_clean = latent_model(model, data=[x_clean, l_clean], enc_type=ENCODER, mean_var=False)\n", 165 | "\n", 166 | " if flag1 and flag2 and flag3 and flag4 and flag5: \n", 167 | " if len(graphs) > 0:\n", 168 | " if not check_isomorphism(graphs, graph) and not check_isomorphism(graph_all, graph):\n", 169 | " graphs.append(graph)\n", 170 | " outputs.append([d_in, graph_raw, graph, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s, latent_vector[j, ...], d_in, d_clean])\n", 171 | " elif not check_isomorphism(graph_all, graph):\n", 172 | " graphs.append(graph)\n", 173 | " outputs.append([d_in, graph_raw, graph, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s, latent_vector[j, ...], d_in, d_clean])\n", 174 | "\n", 175 | " print(f\"{len(outputs)}/{num * max_iter} found ...\", end=\"\\r\")\n", 176 | " # check latent space distance\n", 177 | " z_cleans = []\n", 178 | " z_raws = []\n", 179 | " rmses = []\n", 180 | " new_outputs = []\n", 181 | " \n", 182 | " for i in range(len(outputs)):\n", 183 | " graph = outputs[i][2]\n", 184 | " 
x_clean_ = nx.to_numpy_array(graph)\n", 185 | " n_clean = len(x_clean_)\n", 186 | " x_clean = np.zeros((1, 100, 100))\n", 187 | " x_clean[0, :n_clean, :n_clean] = x_clean_\n", 188 | " x_clean = x_clean.astype(\"int\")\n", 189 | "\n", 190 | " l_clean = get_desc(graph)[None, ...]\n", 191 | " l_clean = SCALER.transform(l_clean)\n", 192 | "\n", 193 | " z_clean = latent_model(model, data=[x_clean, l_clean], enc_type=ENCODER, mean_var=False).squeeze()\n", 194 | "\n", 195 | " z_raw = outputs[i][0].squeeze()\n", 196 | " rmse = skm.mean_absolute_error(z_raw, z_clean)\n", 197 | "\n", 198 | " if rmse < 1:\n", 199 | " z_cleans.append(z_clean)\n", 200 | " z_raws.append(z_raw)\n", 201 | " rmses.append(rmse)\n", 202 | " new_outputs.append(outputs[i]+[z_clean])\n", 203 | " \n", 204 | " print(f\"{len(new_outputs)}/{len(outputs)} latent space check passed ...\")\n", 205 | " \n", 206 | " return new_outputs" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 27, 212 | "id": "baf8c1d0", 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "3/3 latent space check passed ...\n" 220 | ] 221 | } 222 | ], 223 | "source": [ 224 | "target_rg = 26.5\n", 225 | "target_top = \"star\"\n", 226 | "\n", 227 | "outputs = gen_prop_polymer(target_rg=target_rg, target_top=target_top, max_iter=10)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 30, 233 | "id": "5147185e", 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "data": { 238 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAqMAAAEECAYAAAAcbj2jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAB7CAAAewgFu0HU+AAAppklEQVR4nO3de3BV1dn48SfJCZcQyI0QLiYBJApYoQpKKwKRSlvbn/5sa20rvtZqO9Xx9VdbcKR0OlBvM9S2g7YzpbVYvKFD7cwrar3UigGxLyBqBRENcknQSAiJ3EJCLuv3Rzw7ycnZ5+xz9m3tfb6fmcwgZ++TRcxz9rPWetZaWUopJQAAAIAPsv1uAAAAADIXySgAAAB8QzIKAAAA35CMAgAAwDckowAAAPANySgAAAB8QzIKAAAA35CMAgAAwDckowAAAPANySgAAAB8QzIKAAAA35CMAgAAwDcko0jq5MmTUldXJ42NjdLd3e13c4DAOnr0qNTX18vp06f9bgqAFHR3d8snn3wizc3NfjcllEhGHfLMM8/IlVdeKePGjZPBgwdLUVGRzJ49W1auXCmnTp0yvW/GjBmSlZWV9OvJJ59Mq12bNm2Sa665RiorK2XIkCEyYsQImTFjhtx1113y6aefmt7X3t4uK1askClTpkh+fr5UVlZKWVmZFBYWyre+9S3Zvn17Wu0B3KJrDB4/flyWLFkiY8eOlcLCQqmoqJCCggL55je/Kbt37/atXYDTdI3BvhoaGiQ7O1uuvfZaS9fv379fFi5cKPn5+TJmzBgpKSmRsrIyueOOO+TkyZO224PPKNjS1dWlrr32WiUipl9TpkxRH374Ydz7CwoKEt4b/XriiSdSbtvPf/7zhO85btw4tW3btgH3nTp1Sl100UUJ741EIurhhx9OuU2A03SOwfr6ejVp0iTT9xw2bJjavHmz5+0CnKRzDMa67777lIiohQsXJr12+/btCdv2uc99TrW0tNhuE5QiGbXp17/+tfGLOW/ePPX000+rd999V9XU1Kibb75ZZWVlKRFRVVVV6vjx4/3ubWpqMu594YUXVG1trelX7L3JrFu3znjv6dOnqyeffFLt3LlT/fvf/1ZLlixRgwYNUiKiSktL1cGDB/vd+4tf/MK498orr1Tbtm1Tx44dUwcPHlSrV69WpaWlSkTUoEGD1J49e2z/DAE7dI3Bjo4ONWPGDCUiqqysTK1du1a1tLSopqYm9de//lUVFhYqEVEVFRWqtbXVs3YBTtM1BmMdPHjQeH4lS0ZbW1tVZWWlEhFVWFio1q5dq44dO6YaGhrUsmXLjH/TVVddZatN6EEyakN7e7vxi71gwQLV2dk54Jo1a9YYgXb33Xf3e23r1q3GaydOnHC0bdOmTTN6bidPnhzw+ssvv6xycnKUiKgf/vCHxt93d3cb/6Yrrrgi7nvX1NQYgbh06VJH2w2kQucY/POf/6xEROXl5aldu3YNeP3ZZ581vnfsLIOb7QKcpHMMKqVUXV2deu6559SiRYtUSUmJ8b2SJaP333+/ce1LL7004PVf/vKXSkRUVlaW+s9//uN4uzMNyagNGzduNH5ZX3vtNdPrvvjFLyoRUZ///Of7/f0TTzyhRESNGjXK0XbV1dUZ7XrsscdMr/ve976nREQVFRWp7u5upZRStbW1xr3PPPOM6b0TJkxQIqIuv/xyR9sOpELXGFRKqSlTpigRUYsXLza95qtf/aqaPn26Wr58uWftApykcwy2tLSYTrEnS0ZnzpypRERdfPHFcV8/duyYGjZsmBIRdccddzje9kzDAiYbduzYISIigwYNklmzZpled/HFF4uISG1tbb+/37t3r4iITJgwwZV2iYjMmTMnabtaWlrkyJEjIiLyySefGK9XVVWZ3jtixAgREcnKyrLVVsAOXWNwz5498t5774mIyMKFC02ve/755+Xtt9+WZcuWedIuwGm6xqCISCQSkXnz5vX7KigoSHpfc3OzsUj3G9/
4Rtxrhg8fLl/60pdEROTFF190rtEZKuJ3A4IsusVDSUmJRCLmP8rjx4+LiIhSqt/ff/jhhyIiMnHiROP1Q4cOSXd3t5SWlkpubq6tdomIjB49Omm7+rbt4osvHtDOWA0NDcaD9vzzz0+rjYATdI3BTZs2iYhIXl6eTJ8+PeX73WoX4DRdY1BEJD8/X1599dV+f1ddXS01NTUJ73v33XeNds6cOdP0upkzZ8r69etl9+7dopRicMYGRkZtWLp0qXR0dEh9fb3pNV1dXfLSSy+JiMjkyZP7vRbtERYUFMiiRYukrKxMxowZI+PGjZMRI0bIZZddljRo4lm4cKF0dHRIR0eHDBo0yPS65557TkR6PkRKS0tNr1NKSVtbm9TV1cnf//53+fKXvyynT5+W8vJyueWWW1JuH+AUXWMw2lmbMGGCZGVlybp162T+/PlSUlIiQ4YMkaqqKrntttvk448/jnu/W+0CnKZrDNqxZ88e48/jx483va6yslJERNra2uTgwYNuNyvc/KkOyBxLliwxalRWrFjR77Xy8vKkW1lkZWWpe+65x/F2rVq1yvgeN998c8Jr+66uj35dccUVat++fY63C3CaHzH43e9+V4mI+sIXvqBuuOEG0/cuKipSmzZtGnC/n58NgNN0eg7Omzcvac3o7373O0uLqvouQmQRkz0koy45fvy4uvHGG41f1DPPPLPfqvb29naVnZ2tRETl5OSoRYsWqffff1+1t7erhoYG9dBDD6kxY8YY969bt86RdrW3t6ulS5ca37ukpER99NFHCe+Jl4xeeOGFauPGjY60CXCDnzH49a9/3XhfEVGTJ09Wa9asUTt27FDbtm1Ty5cvV0OHDjVisO/2an59NgBO0/E5aCUZvfvuu43vGW93gKh//etfxnWvv/667bZlMpJRFzz66KNq7Nixxi/p6NGjB2ztcvDgQTVr1iw1a9Ys9be//S3u++zbt08VFxcbQWzXP/7xD3XWWWcZ7crPz1c1NTVJ7zty5Iiqra1VGzduVCtXrjQ+HAYPHqz++c9/2m4X4DS/Y3D+/PnG954xY0bc/RFfeuklY4u0W265xZN2AV7xOwbNWElG77nnHqPdp0+fNr3uxRdfNK4zO7wC1pCMOmjHjh1q9uzZ/UYQq6urVX19fdrveddddxnv9c4776T1HgcOHFBXXHFFv3ZNnz5d7dy5M633a25uNrZ2Ovfcc9N6D8ANusTgZZddZtyTaAYhet2YMWM8aRfgNl1i0IyVZHTlypXG9zt27JjpdU8//bRx3dtvv22rXZmOBUwOUErJihUr5Pzzz5fNmzeLiEh5ebk89thjsmHDBjnjjDPSfu9LL73U+HPfomqrHn74YZk6daqsX79eRHoWKz3wwAPyxhtvyDnnnJNWm4qKiuQnP/mJiPRs60HhNvymWwwOHz5cRHpW08+ePdv0uurqahHp2aHi6NGjrrcLcItuMWhHSUmJ8edDhw6ZXtf3teLiYlfbFHZs7WRTZ2enfP/735e1a9eKiMiQIUNk6dKlsnjxYhk6dKjt9x85cqTx5/b29pTuXbJkiaxYsUJERLKzs+XWW2+VZcuWSVFRkek969evlzfffFPKy8vlxhtvNL1uypQpxp8/+ugjWx80gB06xmB0Be7w4cMlO9u8z983FltbWy3tgWinXYAbdIxBO/rusb1v3z6ZNGlS3Ov2798vIj2dTp6B9pCM2rR48WIjAM8991x54oknLI04vvjii9LQ0CDl5eXGxrnxNDU1GX8uKyuz3K6VK1caiWhFRYU8/vjjxqbDibzyyity//33y5QpUxImoydPnjT+7MSHDZAuHWMw+v2bm5ulra1NhgwZEve6xsZGEek5PCI6suL2ZwPgNB1j0I5zzjlHcnNzpaOjQ7Zs2SILFiyIe93WrVtFpOffzB6jNvlcJhBoO3bsMBYgzJw5U7W0tFi+97bbblMiogoLC1Vra6vpdb/61a+UiKjc3Ny4iyDiaWxsVHl
5eUpE1IQJE1RdXZ3ldj3wwANKRFQkElFHjhwxve6nP/2psYip7+pIwEu6xuBHH31ktOupp54yvS56ROJ5553nSbsAp+kag2as1IwqpdQll1yiRERNmzYt7utNTU0qNzdXiYi68847bbUJLGCyZdGiRUpEVHZ29oBVgsm88cYbRuFz35W0fb333nuqoKBAiYi67rrrLL/373//e+O9X3jhhZTatWvXLuPen/3sZ3Gv2blzp3Em77e//e2U3h9wkq4xqJRSCxYsUCKixo8frw4dOjTg9b/85S/G97///vs9axfgJJ1jMB6ryejatWuNtj366KMDXv/xj3+sREQNGjRI7d+/33a7Mh3JqA2zZs1SIqKmTp2qamtrk37F/sJeffXVxi/71772NbV+/Xq1c+dO9frrr6s777zTCMCysrJ++xAqpVR9fb0aN26cGjdunLrqqqv6vfad73xHiYgqKChQH3zwgaW29dV35f3VV1+tnn/+ebVr1y61ZcsWde+996rCwkLj/WPvBbykawwq1fOgjY6cVFZWqtWrVxvvfeuttxr7K5533nmqra3NsXYBXtI5BuOxmox2dXWpmTNnGgnnvffeq9566y312muvqeuuu85o86JFi1L7gSEuklEbSkpK+m1fkeyrsrKy3/3Hjh0zAsPsa/z48XFPdti3b59xzbx58/q9NmPGjJTaFVut0dTUZASh2dfo0aPZVw2+0zUGox577DE1aNAg0/eePn163GTSTrsAL+keg7GsJqNKKVVXV6cmTZpk2q7LL79cdXR0WPq+SIytnWw4duyYrfuHDx8ur7zyijzyyCPyla98RUpLSyUSiUhhYaFcdNFF8pvf/EbeffddmTZtmqftKikpkc2bN8uf/vQn4zztaLtmzZol99xzj+zatUsuuugiW98HsEvXGIxauHChvPXWW3LjjTfK+PHjZfDgwVJQUCCzZ8+WP/zhD7J161YZN26c5+0CnKJ7DNpRXl4ub775pixbtkymTp0qeXl5UlhYKLNnz5bVq1fL008/LZEI68CdkKWUUn43AgAAAJmJkVEAAAD4hmQUAAAAviEZBQAAgG9IRgEAAOAbklEAAAD4hmQUAAAAviEZBQAAgG9IRgEAAOAbklEAAAD4hmQUAAAAviEZBQAAgG9IRgEAAOAbklEAAAD4hmQUAAAAviEZBQAAgG9IRgEAAOAbklEAAAD4hmQUAAAAviEZBQAAgG9IRgEAAOAbklEAAAD4hmQUAAAAviEZBQAAgG9IRgEAAOAbklEAAAD4hmQUAAAAvon43QBddXZ1y6baJjlw5KRUlgyTOVUjJZJD7g4AAOAkktE49jSekBvWbJO65lbj7yqK8+Sh6y+QSaPyfWwZAABAuGQppZTfjdBJZ1e3zP9tTb9ENKqiOE82LK6WnOwsH1oGhBuzEQCQmRgZjbGptiluIioiUtfcKhs/OCyXTB7lcauAcGM2AgAyF8MOMQ4cOWnrdQCp6ezqHpCIivR0/m5Ys026upm8AYAwIxmNUVkyLOHruUwbAo6yMhsBAAgvMqsYc6pGSkVxnunrq2o+ZKQGcFCy2YZ9TSc8agkAOK+zq1s27G6UNZv3yYbdjdLZ1e13k7RDzWiMSE623DR3oiz9n51xX69vOSUbdh+SS6eO9rhlQDglm43446t7Ze5Zo6gdBRA41MNbw8hoHKeT9Fpuf+od2dPIaA3ghGSzEYdPtFM7CiBwqIe3jmQ0jmQjNS2tHfwiAQ6J5GTLQ9dfIKXDB5teU9fcKht2H/KwVQBgD/Xw1pGMxpFspEaEXyTASZNG5ctNcycmvIYZCQBBkqzend15epGMxhEdqSnKy014HQsrAOdMLE1cP8WMBICg2NN4Qv746t6E1ySbhc0kJKMmJo3Kl/uump7wmj++upeRGsAhzEgACINorejhE+2m11QU58ncs0o9bJXeSEYTqD67lIUVgEeYkQAQBq++32haKyoiUjp8sDx0/QUcLd4HyWgCVhd
WMFIDOIMZCQBBtqfxhNz+1DsJr7l53kS2dYpBMpqElYUVjNQAzmFGArCPjda9F52eb2ntSHjdhJEkorHY9N6CZAsr2JQbcE50RuJ7D/6vHD4ev+YqOiNxyeRRHrcO0B8brfsj0VZOUdSKxsfIqAVsyg14ixkJ9zBiFm5stO6Pzq5uefm9xHshF+XlUitqgpFRCxipAbzHjITzGDELv2Qbrf/+lVr570smSSSHsSinxIureH5z1XTizAS/jRZZGal540CLR60Bwo8ZCWcxYpYZkm2kvvLlWpn/2xoWATrELK5iVRTnSTWDVaZIRlOQbKRm3bZ6PtABh3BMqLM4mjAzWNlInQ6Ic6zWiTI9nxjJaArmVI1M+GA8fKKdD3TAQRwT6gwr9WwcTRgOVg6PEOmdsqdm2J5ktevXXFguGxZXMz2fBMloCiI52XL1jDMSXsOiivhYNIF0cUyoPXsaT8j839bI41vqEl7H0YThEJ1RsJKQMmVvj5UjPxdMHc2IqAUkoymaOb444etsyD1Q9GH4gzXbZPkzu+QHa7bxAQjLOCY0fanUs7HdTHhMGpUvryyaJ7ddWpX0Wqbs08ORn84iGU0RiypSw6IJ2MUxoelLdiyhCPVsYRXJyZb/vmSS5Sn7DbsbPWhVeHDkp7NIRlPEEaHWdXZ1yx827GHRBGzjmNDUWTmWkHq2cEtlyv72p/5D/FjEkZ/OIxlNAxtyJxedml/5cm3C61g0Aas4JtQ6q8cSUs8Wflan7Km9toYjP91BMpomKxtyZ2ov02qdmgiLJmAdsxLWJJuRiKKeLXNYnbJnhX1yHPnpDpLRNFE7as5KsIoQsEgdsxKJWZ2R4FjCzGO19poV9okl+3whttJDMpomRmnMWUkGWDSBdDErEV8qMxIcS5iZemqvpyW9jgWm8VnZyonYSg/JqA2M0gxkJVhvu7SKRRNIG7MS8aUyI8GxhJmr+uxRllfYZ+JgihmrWzkRW+khGbWJUZpeVoP11vlVjIgibRwTGh8zErAilRX2/9z1CfWjn0nW2WMrJ3tIRm1ilKaHlUUTBCucwjGh/TEjgVRYXWG/dms99aOfSdbZYysne0hGbaJ21PqiCYIVTuKY0B7MSCAdqaywz4Q4SsRKZ4+tnOwhGXVAJteOprJogmCFkzgmlBkJ2GN1yj6TT2jq7ey1mV7DzjD2kYw6JBNrR63uZyhCsMJ5mX5MKDMScEJ0yn7hrIqE12XqCU29taLxO3N09pxBMuqQTKsdtfogFGHRBNyTqceEMiMBJ0VysuXSKWUJr8mUspe+Oru65eX3Ei+EpLPnDJJRh2RS7WgqD0IWTcBtmXhMKAdLwGlWy14y5YSm6IDL41vqEl5HZ88ZJKMOypTa0VQehCyagNsyqSMoYm20RoQZCaSGE5p6WR1wobPnHJJRh2VC7Sj7GUI3mdIRtDpaw4wE0sEJTT1efb/RUiLKM845JKMOC3vtKPsZQldh7wimMlrDjATSleknNO1pPCG3P/VOwmuuubCcZ5zDSEYdFubTYdjPEDoLe0eQ0Rp4IZNPaIo+41paOxJet2DqaGLMYSSjLgjj6TDsZwjdhbl2lNEaeClTT2iysh6COlF3kIy6JEynw7CfIYIijLWjjNbAD5l4QlOyz4aivFwGXFxCMuqSsJwOw36GCJqw1Y4yWgO/pHJCU9C3fLKyHuI3V01nwMUlJKMuCcPpMJywhCAKU+2olW2cGK2Bm6ye0BTkLZ+sroeonjzKw1ZlFpJRFwX5dBhOWEJQWa0d1X0kx+o2TozWwG1WTmgSCe6UfbLZB9ZDuI9k1GVBPB2GE5YQdFZqR3UeyUllGydGa+AFK6VnIsEoP+uLIz/1QDLqsiCu8OWEJYRBstpREX1HctjGCbpJZcunvYf16+DFw5Gf+iAZ9UDQVvhywhLCIKgjOWzjBF1Z3fJp1UY9y8/64shPvZCMeiQoK3w5YQlhkcpIji6dQbZxgu6sbPl0+Lh+5WexrO5SwaC
LN0hGPRKEFb6csISwsTqSo0tnkG2cEARG+Vl+cBcJJuuAMvvgLZJRj+heO8oJSwgrSyM5mnQG2cYJQTFpVL7cXB3MRYJWZgCZffAWyaiHdK0d5YQlhJ3unUG2cUIQWVnYo9siQaszgMw+eItk1GO61Y5ywhIyha6dQbZxQlAFcZEge4rqiWTUY7rVjqayjRM9RQSdjp1BK6ecsZACOkplkeA/d32iRf1osg4nM4D+IBn1mE7ThVZq1ER4ECI8dOoMWi2PYSEFdGZ1keDarfW+149aqRVlBtAfJKM+0GG60GqNGts4IUx06QymUh7DQgrozsoiQRF/60epFdUbyahP/JwuTKVGjW2cEDZWOoMvv3fI1SlFymMQNlan7P3a8olaUb2RjPrEz+lCjhpEpkvWGXx8S52rU4qccoYwik7ZL5xVkfA6r7d84vx5/ZGM+sSv6UKOGgSsrQJ2a0qRU84QZpGcbLl0SlnS67yasuf8+WAgGfWR17WjHDUI9EhlStHJDiGnnCET6LLlE+fPBwfJqM+8rB3lqEGgV3RK8ZoLyxNe52SHkLo1ZAJdtnzi/PngIBn1mVe1oxw1CAwUycmWBVNHJ7zGqQ4hdWvIJDps+cT588FBMuozL2pHOWoQMOdFh5C6NWQiP7d84vz5YCEZ1YCbW81w1CCQmNsdQurWkMn8qM9mT9HgIRnVhBtbzXDUIGCNWx1CYhCwvuWTU/Wj1GYHD8moJpzeaoajBoHUON0hJAaBXla2fHKifpTa7GAiGdWEk1MZHDUIpM7JDiExCAzk9v6+1GYHF8moRpzaaoajBoHUpdIh3LA78cgLMQgMlFqMNab03tRmBxvJqGbsbjVjZYpChBo1IB6rHcLbn3qHGATSYLV+9Pan/pPSdD17igYbyaiG0t1qxuoUBUcNAuasdAhbWjuIQSBNVupHzWLMDHuKBhvJqIbS2WomlSkKjhoEErNa20YMAulJJ8bMsKdo8JGMairVc+uZogCcE+0QFuXlJryu73ZPxCBgndUYS7bdE3uKhgPJqMasnltvpUaNKQogNZNG5ct9V01PeE3f7Z6YJgRS0xNj0xJek2y7J/YUDYeI3w2Aueg0hlmgHT7RLv+1eotEsrOkvuVUwvdiigJIXfXZpQljUKRnKvG/Vm+Rjs7Em3UTg8BA1WePshRjN6zZJhsWV/eLIfYUDQ9GRjVmpXa04Whb0kSUKQogPVa3omk42iZNJ0+bvk4MAvGlu8c2e4qGC8mo5qzUjiZCjRpgj9XtnswwTQgkZnW7pwNHTooIe4qGEcloACSrHTVDjRrgDCvbPZlhmhBIzsp2T5+e6pDOrm4WC4YQyWgAWNkCI575k8sIRMAhX5xYLKX55iUzZpgmBKxJ9qxb+XKtzP9tjWzbfyTh+zAQEzwkowFgtaYm1p3P7krpBAsA8e1pPCFfXrkp4fYx8TBNCKTmx3MnJtzuqa65VVbVJN5TdNSIIaJU6mfbwz9Ziv9jgdF2ulN+/Nh2qfmgyfI9pfmD5ebqM2XCyGEyp2qkRHLofwDJRKcCDxw5KWcUDZVfPbMr6ULBWNFpQkZngOT2NJ6wVAdqFfEXLCSjmur7MKwsGSZjC4fIjx7ZbitQK4rz5MHrZsrHn54y3pcEFegfb7k52bKq5sOUk8++/u/nx8q9V35Otu5vIdYAE9G429fUc4JSqjMPyTAYExwko5ro+zAclJMtqzbu7Zd4RrKzpNPiGb2JxL4PvUdkOqdHZKKINcCcW3FnhsEYvZGM+sCNUU876D0ikzgxBZ+u8qKhsuyKc+RgcysPQ2Qct0dCk4ntIJYXDZWb5p0pHV3dxKPPSEY9kGzUMydbpKtbiYj/K98JToSN01PwTiLeEFYDB12Gyo8eecO3QRcrmL3wD8moC5Iln25xaiq/L4ITQaNz8pkMU4kIg3hT8D3Pp25JddClojhPrjxvrDzwrz0OtzI+Zi/8QTJqk99T7kV5uXLr/EkyYWS+XHRmibz+4RH
Ze/iErNq4Vw4fd2YKhOBEULhZh1ZRnCfL/s9UebOuRdZtP+hYfMViKhFB5MYUfHQwZHxJnsz/bY0vo6rEnzdIRlPk1UIjq8YUDJHX7pg/YHN7856p/bYRnNCJF3VosTMEfT8HcrKyZPkz70qXR2FP/EE3TnYCS4cPlpvnTZQJI/Nl7lmlxrNtT+MJ+faq16WltcP03msuLJczS/MdHYyJRfy5g2Q0Cb+m3FPx1+svkEsmjxrw97GjtuMKh8oPXajZITjhJS+m4SuK8+SmuRPl9Ge/030firFt6RmxOSl+1XwztQ8/uDkSalYW9vKuT+SHj2w3vT/6LHRzMCZem4k/+0hGY+hS79k3wattPCGPb6kzvXf55VPl+tkTLH2f6L/P6an8vkhO4RSvFkFEp+DrW1oTJp+xNuxulB+s2Wb6+m2XVknh0Fw5oyhP7nx2l2ufJUztw01uxqHZSKhZO+b8eoM0fHpKJKv/dbGzhF4NxogQf07I+GRUh5HPeD2rvkGZ7IFnNjKajFf7vBGYSIdXoxt2Fumt2bxPlj+zy/T1vh1Fr/dV7IuFiEhXvN/bnOws6UpjMZLdvXdTSUbN7vfjeU/8JZdxyagOyWdscmalN2hWvB3JzpJ//L85ctbo4Wm1xY+fB4EJM17Vf1qZgrci1Y6inzXn7CcMq9yagn/wuhny8adtcQddrHB6YKb/nsPuzl4Qf4mFOhn1a6V7siH7dB5+739yXL7+wKa4D6qK4jzZsLg67QdqX14FJ4EJEb3qP9PRdrpTPrf8Jens6h4wUhPJzpJdd35VBkXMf7e9nErsi9kK9OVWHKYyBW9FKjMR6WC20D+hTUa9LGDuK9mUe7rcmqpPxovgJDAzh1d7gDr9EDTjRlz6MVtBDGYutz7j3ZgB8+I56NdsYaYvggpVMurF4pxYTox6WuF2jzAR94Oz/+lTPBjDyatRBy/LQLyISz+m9onBcHOrHMbtTqDdmtF0v6cX8Zfpi6AifjfADh1GENwcdemrsmSYrdftiORk9+ttXn1BucM/9/4/v/qWU/KL/9lp/HemBWWYuPHQi7cIwq1p+GS8iMtE8efW1D4xGC5uzEjYXYwUFF7FX2xCm2kxGKiR0eTJp/Pnu/uVfMbyo0eYStu87BSEPSiDzO1peCcWQThJl7gkBmHGjRkJv+LQr3K1ZFgIbJ/2I6PR/8lv7G+OcwRfbPJpPxB0ST6DxP2R0/4yrceoMy9qQONN/Z09eoSj3yPozGLQrZIlYlBvbsxI6BCHB46ctPW6W7yOPxGRuuZW+d6f/zc0C4G1Gxl1u4fhxkp3L+jaI7TC61peHoze8KIGVPfefxDi0o/9TYlBb3kxI6FDHAYh3mKxQt8a35NRL6fe3Vrp7gU/FzA5iQdjsLm9D6if9Z/pCEpc+r2/MjHorOTJp73npq5xqEtZTKpYhJic58mo13Wfffez1CWg0hHEHqEZHozB4cWIi44PPauCGpfEYHC51aH3ajs0O4xk9Ogpic0TdE5GY/mxv7DuMedpMur1qJguUwtOsLu5ts54MOrFqNM+0CLrttXHjH7a7ywG4aFnVVBHamIRg3rzYkYiCM/KoHb+rMj0RYieJaOJjrR0SlDqP9MR5iCM5feDMSgfzE7xaiN6kfD9bMOSjMYiBv3FjER8QSmLcUKmJaeerabfVNvk+A8yzMlnLF1XEbrB69X5seqaW+WGNdscO2JVZ27PVgT1oWfVptomaTjaNiARFRFpONomGz84HMhOIjHon0yehk/Gz/22veb3Dhledwg9S0adSpZK8wfL1ReUy8zKosAGVDoyKQhj+fFgrGtuDWwikYybU35hTz5jZUonkRh0F9Pw1nxxYvFni326JbZcKJKdJbMnjfSnYR6IxuAlk0fJvLNHuV7y6HWH0LNkNN1kKZNGPxMZWzjEdMVdRXGezD2r1IdW+cOrB2NYEom+3Bh1CcOIS7oytZNIDNrDNHx6/r23+bNn4MB/S2e3ks17mjK
[remaining base64-encoded PNG data omitted]", 239 | "text/plain": [ 240 | "Figure(nrows=1, ncols=3, refwidth=1.0, refheight=1.0)" 241 | ] 242 | }, 243 | "metadata": { 244 | "image/png": { 245 | "height": 130, 246 | "width": 337 247 | } 248 | }, 249 | "output_type": "display_data" 250 | } 251 | ], 252 | "source": [ 253 | "fig, ax = pplt.subplots(ncols=len(outputs), nrows=1, refwidth=1, refheight=1)\n", 254 | "\n", 255 | "for i in range(len(outputs)):\n", 256 | " nx.draw(outputs[i][2], pos=nx.kamada_kawai_layout(outputs[i][2]), \n", 257 | " node_size=5, ax=ax[i], node_color=COLORS[0])\n", 258 | " ax[i].set_title(f\"{str(outputs[i][6])[:5]}\")" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 32, 264 | "id": "75467dd7", 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | "6/6 latent space check passed ...\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "target_rg = 40\n", 277 | "target_top = \"star\"\n", 278 | "\n", 279 | "outputs = gen_prop_polymer(target_rg=target_rg, target_top=target_top, max_iter=10)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 33, 285 | "id": "35ba2bc2", 286 | "metadata": {}, 287 | "outputs": [ 288 | { 289 | "data": { 290 | "image/png": "[base64-encoded PNG data omitted]", 291 | "text/plain": [ 292 | "Figure(nrows=1, ncols=6, refwidth=1.0, refheight=1.0)" 293 | ] 294 | }, 295 | "metadata": { 296 | "image/png": { 297 | "height": 130, 298 | "width": 675 299 | } 300 | }, 301 | "output_type": "display_data" 302 | } 303 | ], 304 | "source": [ 305 | "fig, ax = pplt.subplots(ncols=len(outputs), nrows=1, refwidth=1, refheight=1)\n", 306 | "\n", 307 | "for i in range(len(outputs)):\n", 308 | " nx.draw(outputs[i][2], 
pos=nx.kamada_kawai_layout(outputs[i][2]), \n", 309 | " node_size=5, ax=ax[i], node_color=COLORS[0])\n", 310 | " ax[i].set_title(f\"{str(outputs[i][6])[:5]}\")" 311 | ] 312 | } 313 | ], 314 | "metadata": { 315 | "kernelspec": { 316 | "display_name": "base", 317 | "language": "python", 318 | "name": "python3" 319 | }, 320 | "language_info": { 321 | "codemirror_mode": { 322 | "name": "ipython", 323 | "version": 3 324 | }, 325 | "file_extension": ".py", 326 | "mimetype": "text/x-python", 327 | "name": "python", 328 | "nbconvert_exporter": "python", 329 | "pygments_lexer": "ipython3", 330 | "version": "3.8.17" 331 | } 332 | }, 333 | "nbformat": 4, 334 | "nbformat_minor": 5 335 | } 336 | -------------------------------------------------------------------------------- /result.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "id": "e0da5b19", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import warnings\n", 11 | "warnings.filterwarnings(\"ignore\")\n", 12 | "\n", 13 | "import os\n", 14 | "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n", 15 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n", 16 | "\n", 17 | "import glob\n", 18 | "import pickle\n", 19 | "import platform\n", 20 | "import copy\n", 21 | "import umap\n", 22 | "\n", 23 | "import proplot as pplt\n", 24 | "import pandas as pd\n", 25 | "\n", 26 | "from tqdm import tqdm\n", 27 | "from vendi_score import vendi\n", 28 | "from graph_utils import *\n", 29 | "from data_utils import *\n", 30 | "from model_utils import *\n", 31 | "from analysis_utils import *\n", 32 | "from saliency_utils import *\n", 33 | "from generation_utils import *\n", 34 | "\n", 35 | "DATA_DIR = '/scratch/gpfs/sj0161/topo_data/'\n", 36 | "WEIGHT_DIR = '/scratch/gpfs/sj0161/topo_result/'\n", 37 | "ANALYSIS_DIR = '/scratch/gpfs/sj0161/topo_analysis/'\n", 38 | "\n", 39 | "\n", 40 | "pplt.rc['figure.facecolor'] = 'white'\n", 
41 | "\n", 42 | "COLORS = []\n", 43 | "colors1 = pplt.Cycle('default')\n", 44 | "colors2 = pplt.Cycle('538')\n", 45 | "\n", 46 | "for color in colors1:\n", 47 | " COLORS.append(color['color'])\n", 48 | "\n", 49 | "for color in colors2:\n", 50 | " COLORS.append(color['color'])\n", 51 | "\n", 52 | "LATENT_DIM = 8" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "4a0d4303", 58 | "metadata": {}, 59 | "source": [ 60 | "# Load Data" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 2, 66 | "id": "29724377", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "((x_train, y_train, c_train, l_train, graph_train),\n", 71 | "(x_valid, y_valid, c_valid, l_valid, graph_valid),\n", 72 | "(x_test, y_test, c_test, l_test, graph_test),\n", 73 | "NAMES, SCALER, LE) = load_data(os.path.join(DATA_DIR, 'rg2.pickle'), fold=0, if_validation=True)\n", 74 | "\n", 75 | "graph_all = np.concatenate((graph_train, graph_valid, graph_test))" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "id": "01d93500", 81 | "metadata": {}, 82 | "source": [ 83 | "# Hyperparameter Selection Based on Validation Dataset" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 6, 89 | "id": "445d4a7f", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Train: 858 Valid: 215 Test: 269\n" 97 | ] 98 | }, 99 | { 100 | "name": "stderr", 101 | "output_type": "stream", 102 | "text": [ 103 | "100%|██████████| 675/675 [07:42<00:00, 1.46it/s]\n" 104 | ] 105 | }, 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "Train: 858 Valid: 215 Test: 269\n" 111 | ] 112 | }, 113 | { 114 | "name": "stderr", 115 | "output_type": "stream", 116 | "text": [ 117 | "100%|██████████| 675/675 [07:22<00:00, 1.52it/s]\n" 118 | ] 119 | }, 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "Train: 858 Valid: 215 Test: 269\n" 125 | ] 126 | 
}, 127 | { 128 | "name": "stderr", 129 | "output_type": "stream", 130 | "text": [ 131 | "100%|██████████| 675/675 [04:48<00:00, 2.34it/s]\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "# List of encoders to iterate through\n", 137 | "encoders = [\"desc_gnn\", \"gnn\", \"desc_dnn\"]\n", 138 | "\n", 139 | "results = []\n", 140 | "\n", 141 | "for encoder in encoders:\n", 142 | " elbos, baccs, kls, cls, rls, rec, rmses, r2s, f1s, files_select = get_val_metrics(encoder=encoder)\n", 143 | " \n", 144 | " results.append([elbos, baccs, kls, cls, rls, rec, rmses, r2s, f1s, files_select])\n", 145 | " \n", 146 | "with open(os.path.join(ANALYSIS_DIR, \"hyper_select_result_all.pickle\"), \"wb\") as handle:\n", 147 | " pickle.dump(results, handle)\n", 148 | " " 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 11, 154 | "id": "5e1503f1", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "with open(os.path.join(ANALYSIS_DIR, \"hyper_select_result_all.pickle\"), \"rb\") as handle:\n", 159 | " results = pickle.load(handle)\n", 160 | "\n", 161 | "# Concatenate arrays from all results\n", 162 | "baccs = np.concatenate([result[1] for result in results])\n", 163 | "r2s = np.concatenate([result[7] for result in results])\n", 164 | "f1s = np.concatenate([result[8] for result in results])\n", 165 | "kls = np.concatenate([result[2] for result in results])\n", 166 | "rmses = np.concatenate([result[6] for result in results])\n", 167 | "files_select = np.concatenate([result[9] for result in results])\n", 168 | "\n", 169 | "# Define cumulative boundary indices separating the results of each encoder\n", 170 | "idx = [0, \n", 171 | " len(results[0][0]), \n", 172 | " len(results[0][0]) + len(results[1][0]), \n", 173 | " len(results[0][0]) + len(results[1][0]) + len(results[2][0])]\n", 174 | "\n", 175 | "# Calculate the Pareto front and find the best points\n", 176 | "pareto_indices, pareto_front = pareto_frontier(1 - baccs, 1 - r2s, 1 - f1s, kls, limits=[0.1, 
0.1, 0.1, np.inf])\n", 177 | " " 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 18, 183 | "id": "6ae53cad", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "data = pd.DataFrame(columns=[\n", 188 | " \"Encoder\", \"File Name\", \"BACC\", \"RMSE\", \"R2\", \"F1\", \"KL\", \"Distance to Origin\"\n", 189 | "])\n", 190 | "\n", 191 | "for i, encoder in enumerate(encoders):\n", 192 | " \n", 193 | " idx_temp = np.where((np.array(pareto_indices) >= idx[i]) & (np.array(pareto_indices) < idx[i+1]))[0]\n", 194 | " \n", 195 | " pareto_front_temp = pareto_front[idx_temp]\n", 196 | " best_point, best_idx = closest_to_origin(pareto_front_temp)\n", 197 | "\n", 198 | " # Extract information for the best point\n", 199 | " encoder_name = encoder\n", 200 | " file_name = files_select[best_idx].split('/')[-1].split('.pickle')[0]\n", 201 | " bacc = baccs[best_idx]\n", 202 | " rmse = rmses[best_idx]\n", 203 | " r2 = r2s[best_idx]\n", 204 | " f1 = f1s[best_idx]\n", 205 | " kl = kls[best_idx]\n", 206 | " distance_to_origin = np.linalg.norm(np.array(best_point[:]))\n", 207 | "\n", 208 | " # Create a dictionary with the data for the current encoder\n", 209 | " row_data = {\n", 210 | " \"Encoder\": encoder_name,\n", 211 | " \"File Name\": file_name,\n", 212 | " \"BACC\": bacc,\n", 213 | " \"RMSE\": rmse,\n", 214 | " \"R2\": r2,\n", 215 | " \"F1\": f1,\n", 216 | " \"KL\": kl,\n", 217 | " \"Distance to Origin\": distance_to_origin\n", 218 | " }\n", 219 | "\n", 220 | " # Append the row to the DataFrame (DataFrame.append was removed in pandas 2.0)\n", 221 | " data = pd.concat([data, pd.DataFrame([row_data])], ignore_index=True)\n", 222 | "\n", 223 | "data.to_csv(os.path.join(ANALYSIS_DIR, \"hyper_select_result.csv\"), index=False)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 19, 229 | "id": "22c9fcc4", 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "data": { 234 | "text/html": [ 235 | "
\n", 236 | "\n", 249 | "\n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | "
\n", 299 | "
" 300 | ], 301 | "text/plain": [ 302 | " Encoder File Name BACC \\\n", 303 | "0 desc_gnn desc_gnn_cnn_20230828_8_val_decoder_acc_True_T... 0.943852 \n", 304 | "1 gnn gnn_cnn_20230828_8_val_decoder_acc_True_True_1... 0.944785 \n", 305 | "2 desc_dnn desc_dnn_cnn_20230828_8_val_decoder_acc_True_T... 0.928105 \n", 306 | "\n", 307 | " RMSE R2 F1 KL Distance to Origin \n", 308 | "0 1.467548 0.991504 0.995348 18.724371 0.382945 \n", 309 | "1 3.044410 0.963435 0.976836 15.601805 0.842681 \n", 310 | "2 1.140703 0.994867 0.995348 16.041761 0.399231 " 311 | ] 312 | }, 313 | "execution_count": 19, 314 | "metadata": {}, 315 | "output_type": "execute_result" 316 | } 317 | ], 318 | "source": [ 319 | "data" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "63db1e0c", 325 | "metadata": {}, 326 | "source": [ 327 | "# Results of the Best Model of Each Encoder Type" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "id": "2e57e06c", 333 | "metadata": {}, 334 | "source": [ 335 | "### With 10 Repetitions Considering the Randomness of Sampling Layer" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 20, 341 | "id": "6737529a", 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "name": "stdout", 346 | "output_type": "stream", 347 | "text": [ 348 | "RMSE: Train 1.19+/-0.03 Valid 1.47+/-0.05 Test 1.49+/-0.06\n", 349 | "R2: Train 0.99+/-0.00 Valid 0.99+/-0.00 Test 0.99+/-0.00\n", 350 | "F1: Train 1.00+/-0.00 Valid 0.98+/-0.01 Test 0.96+/-0.01\n", 351 | "BACC: Train 1.00+/-0.00 Valid 0.94+/-0.00 Test 0.94+/-0.00\n", 352 | "RMSE: Train 2.38+/-0.05 Valid 3.16+/-0.06 Test 3.14+/-0.07\n", 353 | "R2: Train 0.98+/-0.00 Valid 0.96+/-0.00 Test 0.96+/-0.00\n", 354 | "F1: Train 1.00+/-0.00 Valid 0.97+/-0.00 Test 0.98+/-0.00\n", 355 | "BACC: Train 0.99+/-0.00 Valid 0.94+/-0.00 Test 0.94+/-0.00\n", 356 | "RMSE: Train 1.05+/-0.03 Valid 1.27+/-0.10 Test 1.99+/-0.06\n", 357 | "R2: Train 1.00+/-0.00 Valid 0.99+/-0.00 Test 0.98+/-0.00\n", 358 | 
"F1: Train 1.00+/-0.00 Valid 1.00+/-0.00 Test 0.97+/-0.00\n", 359 | "BACC: Train 0.97+/-0.00 Valid 0.92+/-0.00 Test 0.92+/-0.00\n" 360 | ] 361 | } 362 | ], 363 | "source": [ 364 | "files = [\n", 365 | " \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\",\n", 366 | " \"gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 100]_0.01_64\",\n", 367 | " \"desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64\",\n", 368 | "]\n", 369 | "\n", 370 | "baccs = []\n", 371 | "rmses = []\n", 372 | "r2s = []\n", 373 | "f1s = []\n", 374 | "accs = []\n", 375 | "train_outs = []\n", 376 | "valid_outs = []\n", 377 | "test_outs = []\n", 378 | "\n", 379 | "for file in files:\n", 380 | " ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 381 | "\n", 382 | " model, pickle_file = train_vae(ENCODER, DECODER, MONITOR,\n", 383 | " IF_REG, IF_CLS,\n", 384 | " x_train, x_valid,\n", 385 | " y_train, y_valid,\n", 386 | " c_train, c_valid,\n", 387 | " l_train, l_valid,\n", 388 | " 1.0, weights,\n", 389 | " LR, BS,\n", 390 | " False, ) \n", 391 | "\n", 392 | " train_out, valid_out, test_out, bacc, rmse, r2, f1, acc = get_metrics(model, ENCODER, IF_REG, IF_CLS,\n", 393 | " x_train, y_train, c_train, l_train,\n", 394 | " x_valid, y_valid, c_valid, l_valid,\n", 395 | " x_test, y_test, c_test, l_test,\n", 396 | " n_repeat=10, if_bacc=True)\n", 397 | " baccs.append(bacc)\n", 398 | " rmses.append(rmse)\n", 399 | " r2s.append(r2)\n", 400 | " f1s.append(f1)\n", 401 | " accs.append(acc)\n", 402 | " train_outs.append(train_out)\n", 403 | " valid_outs.append(valid_out)\n", 404 | " test_outs.append(test_out)\n", 405 | " \n", 406 | "baccs = np.array(baccs)\n", 407 | "rmses = np.array(rmses)\n", 408 | "r2s = np.array(r2s)\n", 409 | "f1s = np.array(f1s)\n", 410 | "accs = np.array(accs)\n", 411 | "\n", 412 | "\n", 413 | "with open(os.path.join(ANALYSIS_DIR, \"accuracy_metric.pickle\"), \"wb\") as handle:\n", 414 | " 
pickle.dump(baccs, handle)\n", 415 | " pickle.dump(rmses, handle)\n", 416 | " pickle.dump(r2s, handle)\n", 417 | " pickle.dump(f1s, handle)\n", 418 | " pickle.dump(accs, handle)\n", 419 | "\n", 420 | "with open(os.path.join(ANALYSIS_DIR, \"all_outs.pickle\"), \"wb\") as handle:\n", 421 | " pickle.dump(train_outs, handle)\n", 422 | " pickle.dump(valid_outs, handle)\n", 423 | " pickle.dump(test_outs, handle)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "id": "de2a9c45", 429 | "metadata": {}, 430 | "source": [ 431 | "### With 1 Repetition for Sample Test Set Prediction" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 22, 437 | "id": "110af222", 438 | "metadata": {}, 439 | "outputs": [ 440 | { 441 | "name": "stdout", 442 | "output_type": "stream", 443 | "text": [ 444 | "RMSE: Train 1.16+/-0.00 Valid 1.56+/-0.00 Test 1.50+/-0.00\n", 445 | "R2: Train 1.00+/-0.00 Valid 0.99+/-0.00 Test 0.99+/-0.00\n", 446 | "F1: Train 1.00+/-0.00 Valid 0.99+/-0.00 Test 0.97+/-0.00\n", 447 | "BACC: Train 1.00+/-0.00 Valid 0.94+/-0.00 Test 0.94+/-0.00\n", 448 | "RMSE: Train 2.42+/-0.00 Valid 3.21+/-0.00 Test 3.26+/-0.00\n", 449 | "R2: Train 0.98+/-0.00 Valid 0.96+/-0.00 Test 0.96+/-0.00\n", 450 | "F1: Train 1.00+/-0.00 Valid 0.97+/-0.00 Test 0.97+/-0.00\n", 451 | "BACC: Train 0.99+/-0.00 Valid 0.94+/-0.00 Test 0.94+/-0.00\n", 452 | "RMSE: Train 1.06+/-0.00 Valid 1.22+/-0.00 Test 1.91+/-0.00\n", 453 | "R2: Train 1.00+/-0.00 Valid 0.99+/-0.00 Test 0.99+/-0.00\n", 454 | "F1: Train 1.00+/-0.00 Valid 1.00+/-0.00 Test 0.97+/-0.00\n", 455 | "BACC: Train 0.97+/-0.00 Valid 0.92+/-0.00 Test 0.92+/-0.00\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "files = [\n", 461 | " \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\",\n", 462 | " \"gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 100]_0.01_64\",\n", 463 | " \"desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64\",\n", 464 | "]\n", 465 
| "\n", 466 | "test_outs = []\n", 467 | "\n", 468 | "for file in files:\n", 469 | " ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 470 | "\n", 471 | " model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 472 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 473 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 474 | "\n", 475 | " train_out, valid_out, test_out, bacc, rmse, r2, f1, acc = get_metrics(model, ENCODER, IF_REG, IF_CLS,\n", 476 | " x_train, y_train, c_train, l_train,\n", 477 | " x_valid, y_valid, c_valid, l_valid,\n", 478 | " x_test, y_test, c_test, l_test,\n", 479 | " n_repeat=1,if_bacc=True)\n", 480 | " test_outs.append(test_out)\n", 481 | "\n", 482 | "with open(os.path.join(ANALYSIS_DIR, \"test_out.pickle\"), \"wb\") as handle:\n", 483 | " pickle.dump(test_outs, handle)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "id": "a6e29647", 489 | "metadata": {}, 490 | "source": [ 491 | "# Saliency Map Calculation" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 24, 497 | "id": "48c7b4c6", 498 | "metadata": {}, 499 | "outputs": [ 500 | { 501 | "name": "stdout", 502 | "output_type": "stream", 503 | "text": [ 504 | "RMSE: Train 1.16+/-0.00 Valid 1.56+/-0.00 Test 1.50+/-0.00\n", 505 | "R2: Train 1.00+/-0.00 Valid 0.99+/-0.00 Test 0.99+/-0.00\n", 506 | "F1: Train 1.00+/-0.00 Valid 0.99+/-0.00 Test 0.97+/-0.00\n", 507 | "RMSE: Train 1.06+/-0.00 Valid 1.22+/-0.00 Test 1.91+/-0.00\n", 508 | "R2: Train 1.00+/-0.00 Valid 0.99+/-0.00 Test 0.99+/-0.00\n", 509 | "F1: Train 1.00+/-0.00 Valid 1.00+/-0.00 Test 0.97+/-0.00\n" 510 | ] 511 | } 512 | ], 513 | "source": [ 514 | "files = [\n", 515 | " \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\",\n", 516 | " \"desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64\",\n", 517 | "]\n", 518 | "\n", 519 | "grads = []\n", 520 | "\n", 521 | "for file in 
files:\n", 522 | " ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 523 | "\n", 524 | " model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 525 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 526 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 527 | "\n", 528 | " train_out, valid_out, test_out, bacc, rmse, r2, f1, acc = get_metrics(model, ENCODER, IF_REG, IF_CLS,\n", 529 | " x_train, y_train, c_train, l_train,\n", 530 | " x_valid, y_valid, c_valid, l_valid,\n", 531 | " x_test, y_test, c_test, l_test,\n", 532 | " n_repeat=1, if_bacc=False)\n", 533 | "\n", 534 | " grad = compute_saliency(model, x_train, y_train, l_train, c_train, \n", 535 | " output_index=1, enc_type=ENCODER, if_reg=IF_REG, if_cls=IF_CLS)\n", 536 | " \n", 537 | " grads.append(grad)\n", 538 | "\n", 539 | "with open(os.path.join(ANALYSIS_DIR, \"saliency.pickle\"), \"wb\") as handle:\n", 540 | " pickle.dump(grads, handle)\n", 541 | " pickle.dump(files, handle)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "markdown", 546 | "id": "72beb168", 547 | "metadata": {}, 548 | "source": [ 549 | "# Latent Space Calculation" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "id": "80228cd6", 555 | "metadata": {}, 556 | "source": [ 557 | "### Different Encoders" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 25, 563 | "id": "84cd4cd2", 564 | "metadata": {}, 565 | "outputs": [], 566 | "source": [ 567 | "files = [\n", 568 | " \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\",\n", 569 | " \"gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 100]_0.01_64\",\n", 570 | " \"desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64\",\n", 571 | "]\n", 572 | "\n", 573 | "for file in files:\n", 574 | " ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 575 | "\n", 576 | " model, pickle_file = train_vae(ENCODER, DECODER, 
MONITOR, IF_REG, IF_CLS,\n", 577 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 578 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 579 | "\n", 580 | " latent_train = latent_model(model, data=[x_train, l_train], enc_type=ENCODER, mean_var=False)\n", 581 | " latent_valid = latent_model(model, data=[x_valid, l_valid], enc_type=ENCODER, mean_var=False)\n", 582 | " latent_test = latent_model(model, data=[x_test, l_test], enc_type=ENCODER, mean_var=False)\n", 583 | " \n", 584 | " short_name = file.split(\"_20230828\")[0]\n", 585 | " \n", 586 | " if short_name == \"desc_gnn_cnn\":\n", 587 | " short_name = \"\"\n", 588 | " elif short_name == \"gnn_cnn\":\n", 589 | " short_name = \"_gnn\"\n", 590 | " elif short_name == \"desc_dnn_cnn\":\n", 591 | " short_name = \"_topo\"\n", 592 | "\n", 593 | " with open(os.path.join(ANALYSIS_DIR, f\"latent_space{short_name}.pickle\"), \"wb\") as handle:\n", 594 | " pickle.dump(latent_train, handle)\n", 595 | " pickle.dump(latent_valid, handle)\n", 596 | " pickle.dump(latent_test, handle)" 597 | ] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "id": "3ff786dc", 602 | "metadata": {}, 603 | "source": [ 604 | "### Auxiliary Tasks" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": 26, 610 | "id": "c231f9d9", 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "files = [\"desc_gnn_cnn_20230828_8_val_loss_False_True_1.0_[1.0, 1]_0.001_64\",\n", 615 | " \"desc_gnn_cnn_20230828_8_val_loss_True_False_1.0_[1.0, 1]_0.001_64\",\n", 616 | " \"desc_gnn_cnn_20230828_8_val_loss_False_False_1.0_[1.0]_0.001_64\",\n", 617 | " \n", 618 | "]\n", 619 | "\n", 620 | "\n", 621 | "for file in files:\n", 622 | " ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 623 | "\n", 624 | " model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 625 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 626 | " l_train, l_valid, 1.0, weights, LR, 
BS, False) \n", 627 | "\n", 628 | " latent_train = latent_model(model, data=[x_train, l_train], enc_type=ENCODER, mean_var=False)\n", 629 | " latent_valid = latent_model(model, data=[x_valid, l_valid], enc_type=ENCODER, mean_var=False)\n", 630 | " latent_test = latent_model(model, data=[x_test, l_test], enc_type=ENCODER, mean_var=False)\n", 631 | " \n", 632 | " short_name = \"_reg_\" + file.split(\"_\")[-6].lower() + \"_cls_\" + file.split(\"_\")[-5].lower()\n", 633 | "\n", 634 | " with open(os.path.join(ANALYSIS_DIR, f\"latent_space{short_name}.pickle\"), \"wb\") as handle:\n", 635 | " pickle.dump(latent_train, handle)\n", 636 | " pickle.dump(latent_valid, handle)\n", 637 | " pickle.dump(latent_test, handle)" 638 | ] 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "id": "5919cd26", 643 | "metadata": {}, 644 | "source": [ 645 | "# UMAP Generation" 646 | ] 647 | }, 648 | { 649 | "cell_type": "markdown", 650 | "id": "c0b15c5c", 651 | "metadata": {}, 652 | "source": [ 653 | "### UMAP TopoGNN" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": 12, 659 | "id": "cf400b39", 660 | "metadata": {}, 661 | "outputs": [], 662 | "source": [ 663 | "# Load model\n", 664 | "file = \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\"\n", 665 | "\n", 666 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 667 | "\n", 668 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 669 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 670 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 671 | "\n", 672 | "# Load latent space\n", 673 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space.pickle\"), \"rb\") as handle:\n", 674 | " latent_train = pickle.load(handle)\n", 675 | " latent_valid = pickle.load(handle)\n", 676 | " latent_test = pickle.load(handle)\n", 677 | " \n", 678 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 
679 | "\n", 680 | "# UMAP transformation\n", 681 | "u = umap.UMAP(n_components=2, random_state=0)\n", 682 | "z = u.fit_transform(latent_all)\n", 683 | " \n", 684 | "with open(os.path.join(ANALYSIS_DIR, \"umap.pickle\"), \"wb\") as handle:\n", 685 | " pickle.dump(z, handle)\n", 686 | " pickle.dump(u, handle)" 687 | ] 688 | }, 689 | { 690 | "cell_type": "markdown", 691 | "id": "053af9a5", 692 | "metadata": {}, 693 | "source": [ 694 | "### UMAP GNN" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 13, 700 | "id": "c810f50c", 701 | "metadata": {}, 702 | "outputs": [], 703 | "source": [ 704 | "# Load model\n", 705 | "file = \"gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 100]_0.01_64\"\n", 706 | "\n", 707 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 708 | "\n", 709 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 710 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 711 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 712 | "\n", 713 | "# Load latent space\n", 714 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space_gnn.pickle\"), \"rb\") as handle:\n", 715 | " latent_train = pickle.load(handle)\n", 716 | " latent_valid = pickle.load(handle)\n", 717 | " latent_test = pickle.load(handle)\n", 718 | " \n", 719 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 720 | "\n", 721 | "# UMAP transformation\n", 722 | "u = umap.UMAP(n_components=2, random_state=0)\n", 723 | "z = u.fit_transform(latent_all)\n", 724 | " \n", 725 | "with open(os.path.join(ANALYSIS_DIR, \"umap_gnn.pickle\"), \"wb\") as handle:\n", 726 | " pickle.dump(z, handle)\n", 727 | " pickle.dump(u, handle)" 728 | ] 729 | }, 730 | { 731 | "cell_type": "markdown", 732 | "id": "d327f1bc", 733 | "metadata": {}, 734 | "source": [ 735 | "### UMAP Topo" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": null, 741 | "id": 
"881174b1", 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [ 745 | "# Load model\n", 746 | "file = \"desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64\"\n", 747 | "\n", 748 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 749 | "\n", 750 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 751 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 752 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 753 | "\n", 754 | "# Load latent space\n", 755 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space_topo.pickle\"), \"rb\") as handle:\n", 756 | " latent_train = pickle.load(handle)\n", 757 | " latent_valid = pickle.load(handle)\n", 758 | " latent_test = pickle.load(handle)\n", 759 | " \n", 760 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 761 | "\n", 762 | "# UMAP transformation\n", 763 | "u = umap.UMAP(n_components=2, random_state=0)\n", 764 | "z = u.fit_transform(latent_all)\n", 765 | " \n", 766 | "with open(os.path.join(ANALYSIS_DIR, \"umap_topo.pickle\"), \"wb\") as handle:\n", 767 | " pickle.dump(z, handle)\n", 768 | " pickle.dump(u, handle)" 769 | ] 770 | }, 771 | { 772 | "cell_type": "markdown", 773 | "id": "3ef9cc5e", 774 | "metadata": {}, 775 | "source": [ 776 | "### UMAP No Classification No Regression" 777 | ] 778 | }, 779 | { 780 | "cell_type": "code", 781 | "execution_count": 15, 782 | "id": "58d8e2a7", 783 | "metadata": {}, 784 | "outputs": [], 785 | "source": [ 786 | "# Load model\n", 787 | "file = \"desc_gnn_cnn_20230828_8_val_loss_False_False_1.0_[1.0]_0.001_64\"\n", 788 | "\n", 789 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 790 | "\n", 791 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 792 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 793 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 794 | "\n", 
795 | "# Load latent space\n", 796 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space_reg_false_cls_false.pickle\"), \"rb\") as handle:\n", 797 | " latent_train = pickle.load(handle)\n", 798 | " latent_valid = pickle.load(handle)\n", 799 | " latent_test = pickle.load(handle)\n", 800 | " \n", 801 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 802 | "\n", 803 | "# UMAP transformation\n", 804 | "u = umap.UMAP(n_components=2, random_state=0)\n", 805 | "z = u.fit_transform(latent_all)\n", 806 | " \n", 807 | "with open(os.path.join(ANALYSIS_DIR, \"umap_reg_false_cls_false.pickle\"), \"wb\") as handle:\n", 808 | " pickle.dump(z, handle)\n", 809 | " pickle.dump(u, handle)" 810 | ] 811 | }, 812 | { 813 | "cell_type": "markdown", 814 | "id": "875e7b10", 815 | "metadata": {}, 816 | "source": [ 817 | "### UMAP No Classification Only" 818 | ] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "execution_count": 16, 823 | "id": "4f9bac07", 824 | "metadata": {}, 825 | "outputs": [], 826 | "source": [ 827 | "# Load model\n", 828 | "file = \"desc_gnn_cnn_20230828_8_val_loss_True_False_1.0_[1.0, 1]_0.001_64\"\n", 829 | "\n", 830 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 831 | "\n", 832 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 833 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 834 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 835 | "\n", 836 | "# Load latent space\n", 837 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space_reg_true_cls_false.pickle\"), \"rb\") as handle:\n", 838 | " latent_train = pickle.load(handle)\n", 839 | " latent_valid = pickle.load(handle)\n", 840 | " latent_test = pickle.load(handle)\n", 841 | " \n", 842 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 843 | "\n", 844 | "# UMAP transformation\n", 845 | "u = umap.UMAP(n_components=2, random_state=0)\n", 846 | "z = 
u.fit_transform(latent_all)\n", 847 | " \n", 848 | "with open(os.path.join(ANALYSIS_DIR, \"umap_reg_true_cls_false.pickle\"), \"wb\") as handle:\n", 849 | " pickle.dump(z, handle)\n", 850 | " pickle.dump(u, handle)" 851 | ] 852 | }, 853 | { 854 | "cell_type": "markdown", 855 | "id": "344b69fe", 856 | "metadata": {}, 857 | "source": [ 858 | "### UMAP No Regression Only" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": 17, 864 | "id": "b85cf842", 865 | "metadata": {}, 866 | "outputs": [], 867 | "source": [ 868 | "# Load model\n", 869 | "file = \"desc_gnn_cnn_20230828_8_val_loss_False_True_1.0_[1.0, 1]_0.001_64\"\n", 870 | "\n", 871 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 872 | "\n", 873 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 874 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 875 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 876 | "\n", 877 | "# Load latent space\n", 878 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space_reg_false_cls_true.pickle\"), \"rb\") as handle:\n", 879 | " latent_train = pickle.load(handle)\n", 880 | " latent_valid = pickle.load(handle)\n", 881 | " latent_test = pickle.load(handle)\n", 882 | " \n", 883 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 884 | "\n", 885 | "# UMAP transformation\n", 886 | "u = umap.UMAP(n_components=2, random_state=0)\n", 887 | "z = u.fit_transform(latent_all)\n", 888 | " \n", 889 | "with open(os.path.join(ANALYSIS_DIR, \"umap_reg_false_cls_true.pickle\"), \"wb\") as handle:\n", 890 | " pickle.dump(z, handle)\n", 891 | " pickle.dump(u, handle)" 892 | ] 893 | }, 894 | { 895 | "cell_type": "markdown", 896 | "id": "adfa609e", 897 | "metadata": {}, 898 | "source": [ 899 | "# Polymer Topology Generation Using TopoGNN Latent Space UMAP" 900 | ] 901 | }, 902 | { 903 | "cell_type": "markdown", 904 | "id": "cee0b9af", 905 | "metadata": {}, 906 | 
"source": [ 907 | "### TopoGNN Fix Z1, Increase Z2." 908 | ] 909 | }, 910 | { 911 | "cell_type": "code", 912 | "execution_count": 19, 913 | "id": "5c22fb78", 914 | "metadata": {}, 915 | "outputs": [], 916 | "source": [ 917 | "def interpolate(start, end, num_points):\n", 918 | " \"\"\"\n", 919 | " Interpolate between two points in n-dimensional space.\n", 920 | "\n", 921 | " Parameters:\n", 922 | " start (array-like): The starting point as an array-like object.\n", 923 | " end (array-like): The ending point as an array-like object.\n", 924 | " num_points (int): The number of points to interpolate between start and end.\n", 925 | "\n", 926 | " Returns:\n", 927 | " np.ndarray: An array containing the interpolated points between start and end.\n", 928 | " \"\"\"\n", 929 | " start = np.array(start)\n", 930 | " end = np.array(end)\n", 931 | " t_values = np.linspace(0, 1, num_points)\n", 932 | " points = [(1-t)*start + t*end for t in t_values]\n", 933 | " return np.array(points)\n", 934 | "\n", 935 | "\n", 936 | "\n", 937 | "file = \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\"\n", 938 | "\n", 939 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 940 | "\n", 941 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 942 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 943 | " l_train, l_valid, 1.0, weights, LR, BS, False) \n", 944 | " \n", 945 | "with open(os.path.join(ANALYSIS_DIR, \"umap.pickle\"), \"rb\") as handle:\n", 946 | " _ = pickle.load(handle)\n", 947 | " u = pickle.load(handle)" 948 | ] 949 | }, 950 | { 951 | "cell_type": "code", 952 | "execution_count": 10, 953 | "id": "44a7a76e", 954 | "metadata": {}, 955 | "outputs": [], 956 | "source": [ 957 | "start = np.array([1, 4])\n", 958 | "end = np.array([1, 16])\n", 959 | "num_points = 20\n", 960 | "max_iter = 50\n", 961 | "\n", 962 | "points = interpolate(start, end, num_points)\n", 963 | "\n", 964 | "outputs = 
[]\n", 965 | "\n", 966 | "for i in range(num_points):\n", 967 | " point = points[i]\n", 968 | " \n", 969 | " for _ in range(max_iter):\n", 970 | " point_ = point + np.random.normal(0, 0.2, size=(2,)) # random perturbation in a neighborhood\n", 971 | " data = u.inverse_transform(point_[None, ...])\n", 972 | " \n", 973 | " G0, G, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s = polymer_generation(model, data, None, ENCODER)\n", 974 | " \n", 975 | " if check_valid(gen_reg, cln_reg_m, gen_cls, cln_cls):\n", 976 | " outputs.append([data, point_, G0, G, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s])\n", 977 | " break\n", 978 | " \n", 979 | "with open(os.path.join(ANALYSIS_DIR, \"umap_1_4_1_16.pickle\"), \"wb\") as handle:\n", 980 | " pickle.dump(outputs, handle)" 981 | ] 982 | }, 983 | { 984 | "cell_type": "markdown", 985 | "id": "63eab4d1", 986 | "metadata": {}, 987 | "source": [ 988 | "### TopoGNN Fix Z2, Increase Z1." 989 | ] 990 | }, 991 | { 992 | "cell_type": "code", 993 | "execution_count": 11, 994 | "id": "ef305232", 995 | "metadata": {}, 996 | "outputs": [], 997 | "source": [ 998 | "start = np.array([-3, 12])\n", 999 | "end = np.array([6, 12])\n", 1000 | "num_points = 20\n", 1001 | "max_iter = 50\n", 1002 | "\n", 1003 | "points = interpolate(start, end, num_points)\n", 1004 | "\n", 1005 | "outputs = []\n", 1006 | "\n", 1007 | "for i in range(num_points):\n", 1008 | " point = points[i]\n", 1009 | " \n", 1010 | " for _ in range(max_iter):\n", 1011 | " point_ = point + np.random.normal(0, 0.2, size=(2,))\n", 1012 | " data = u.inverse_transform(point_[None, ...])\n", 1013 | " \n", 1014 | " G0, G, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s = polymer_generation(model, data, None, ENCODER)\n", 1015 | " \n", 1016 | " if check_valid(gen_reg, cln_reg_m, gen_cls, cln_cls):\n", 1017 | " outputs.append([data, point_, G0, G, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s])\n", 1018 | " break\n", 1019 | " \n", 1020 | "with open(os.path.join(ANALYSIS_DIR, 
\"umap_-3_12_6_12.pickle\"), \"wb\") as handle:\n", 1021 | " pickle.dump(outputs, handle)" 1022 | ] 1023 | }, 1024 | { 1025 | "cell_type": "markdown", 1026 | "id": "d2e34b56", 1027 | "metadata": {}, 1028 | "source": [ 1029 | "# Property Guided Topology Generation" 1030 | ] 1031 | }, 1032 | { 1033 | "cell_type": "code", 1034 | "execution_count": 23, 1035 | "id": "327952db", 1036 | "metadata": {}, 1037 | "outputs": [], 1038 | "source": [ 1039 | "def check_isomorphism(graph_list, new_graph):\n", 1040 | " for graph in graph_list:\n", 1041 | " if nx.is_isomorphic(graph, new_graph):\n", 1042 | " return True\n", 1043 | " return False\n", 1044 | "\n", 1045 | "def rg_latent_vector(l, y_train, c_train, poly_type='branch', target_rg=40):\n", 1046 | " \n", 1047 | " idx = np.where(NAMES == poly_type)[0][0]\n", 1048 | "\n", 1049 | " a = l[np.where(c_train == idx)[0]]\n", 1050 | " y = y_train[np.where(c_train == idx)[0]]\n", 1051 | " \n", 1052 | " if np.abs(y - target_rg).min() < 1:\n", 1053 | " \n", 1054 | " idx2 = np.where(np.abs(y - target_rg) < 1)[0]\n", 1055 | "\n", 1056 | " return a[idx2], y[idx2]\n", 1057 | " \n", 1058 | " else:\n", 1059 | " \n", 1060 | " return None, None" 1061 | ] 1062 | }, 1063 | { 1064 | "cell_type": "code", 1065 | "execution_count": 31, 1066 | "id": "085d2fa2", 1067 | "metadata": {}, 1068 | "outputs": [], 1069 | "source": [ 1070 | "# Load latent space\n", 1071 | "with open(os.path.join(ANALYSIS_DIR, \"latent_space.pickle\"), \"rb\") as handle:\n", 1072 | " latent_train = pickle.load(handle)\n", 1073 | " latent_valid = pickle.load(handle)\n", 1074 | " latent_test = pickle.load(handle)\n", 1075 | " \n", 1076 | "latent_all = np.concatenate((latent_train, latent_valid, latent_test), axis=0)\n", 1077 | "\n", 1078 | "# Load data label\n", 1079 | "((x_train, y_train, c_train, l_train, graph_train),\n", 1080 | "(x_valid, y_valid, c_valid, l_valid, graph_valid),\n", 1081 | "(x_test, y_test, c_test, l_test, graph_test),\n", 1082 | "NAMES, SCALER, LE) = 
load_data(os.path.join(DATA_DIR, 'rg2.pickle'), fold=0, if_validation=True)\n", 1083 | "\n", 1084 | "graph_all = np.concatenate((graph_train, graph_valid, graph_test))\n", 1085 | "y_all = np.concatenate((y_train, y_valid, y_test))\n", 1086 | "c_all = np.concatenate((c_train, c_valid, c_test))\n", 1087 | "\n", 1088 | "# Load TopoGNN model\n", 1089 | "file = \"desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64\"\n", 1090 | "\n", 1091 | "ENCODER, DECODER, MONITOR, IF_REG, IF_CLS, weights, LR, BS = get_spec(file)\n", 1092 | "\n", 1093 | "model, pickle_file = train_vae(ENCODER, DECODER, MONITOR, IF_REG, IF_CLS,\n", 1094 | " x_train, x_valid, y_train, y_valid, c_train, c_valid,\n", 1095 | " l_train, l_valid, 1.0, weights, LR, BS, False)" 1096 | ] 1097 | }, 1098 | { 1099 | "cell_type": "code", 1100 | "execution_count": 38, 1101 | "id": "810ae1cd", 1102 | "metadata": {}, 1103 | "outputs": [], 1104 | "source": [ 1105 | "def gen_prop_polymer(target_rg=30, target_top=\"branch\", max_iter=1000):\n", 1106 | " outputs = []\n", 1107 | " graphs = []\n", 1108 | "\n", 1109 | " latent_vector, rg_dataset = rg_latent_vector(latent_all, y_all, c_all, target_top, target_rg)\n", 1110 | " \n", 1111 | " if latent_vector is None:\n", 1112 | " raise ValueError(\"No polymers of this topology in the dataset have rg2 within 1 of target_rg; the target is too large or too small.\")\n", 1113 | "\n", 1114 | " for i in range(max_iter):\n", 1115 | " noise = np.random.normal(0, 1, (1, 8)) * 0.1\n", 1116 | " num = len(latent_vector)\n", 1117 | "\n", 1118 | " for j in range(num):\n", 1119 | " K.clear_session()\n", 1120 | " d_in = latent_vector[j, ...] 
+ noise\n", 1121 | " graph_raw, graph, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s = polymer_generation(model, d_in, None, ENCODER)\n", 1122 | "\n", 1123 | " flag1 = np.abs(gen_reg - cln_reg_m) < 2\n", 1124 | " flag2 = np.abs(gen_reg - target_rg) < 2\n", 1125 | " flag3 = np.abs(cln_reg_m - target_rg) < 2\n", 1126 | " flag4 = gen_cls == cln_cls\n", 1127 | " flag5 = cln_cls == target_top\n", 1128 | "\n", 1129 | " x_clean_ = nx.to_numpy_array(graph)\n", 1130 | " n_clean = len(x_clean_)\n", 1131 | " x_clean = np.zeros((1, 100, 100))\n", 1132 | " x_clean[0, :n_clean, :n_clean] = x_clean_\n", 1133 | " x_clean = x_clean.astype(\"int\")\n", 1134 | "\n", 1135 | " l_clean = get_desc(graph)[None, ...]\n", 1136 | " l_clean = SCALER.transform(l_clean)\n", 1137 | "\n", 1138 | " d_clean = latent_model(model, data=[x_clean, l_clean], enc_type=ENCODER, mean_var=False)\n", 1139 | "\n", 1140 | " if flag1 and flag2 and flag3 and flag4 and flag5: \n", 1141 | " if len(graphs) > 0:\n", 1142 | " if not check_isomorphism(graphs, graph) and not check_isomorphism(graph_all, graph):\n", 1143 | " graphs.append(graph)\n", 1144 | " outputs.append([d_in, graph_raw, graph, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s, latent_vector[j, ...], d_in, d_clean])\n", 1145 | " else:\n", 1146 | " graphs.append(graph)\n", 1147 | " outputs.append([d_in, graph_raw, graph, gen_cls, cln_cls, gen_reg, cln_reg_m, cln_reg_s, latent_vector[j, ...], d_in, d_clean])\n", 1148 | "\n", 1149 | " \n", 1150 | " # check latent space distance\n", 1151 | " z_cleans = []\n", 1152 | " z_raws = []\n", 1153 | " rmses = []\n", 1154 | " new_outputs = []\n", 1155 | " \n", 1156 | " for i in range(len(outputs)):\n", 1157 | " graph = outputs[i][2]\n", 1158 | " x_clean_ = nx.to_numpy_array(graph)\n", 1159 | " n_clean = len(x_clean_)\n", 1160 | " x_clean = np.zeros((1, 100, 100))\n", 1161 | " x_clean[0, :n_clean, :n_clean] = x_clean_\n", 1162 | " x_clean = x_clean.astype(\"int\")\n", 1163 | "\n", 1164 | " l_clean = 
get_desc(graph)[None, ...]\n", 1165 | " l_clean = SCALER.transform(l_clean)\n", 1166 | "\n", 1167 | " z_clean = latent_model(model, data=[x_clean, l_clean], enc_type=ENCODER, mean_var=False).squeeze()\n", 1168 | "\n", 1169 | " z_raw = outputs[i][0].squeeze()\n", 1170 | " rmse = skm.mean_absolute_error(z_raw, z_clean)\n", 1171 | "\n", 1172 | " if rmse < 1:\n", 1173 | " z_cleans.append(z_clean)\n", 1174 | " z_raws.append(z_raw)\n", 1175 | " rmses.append(rmse)\n", 1176 | " new_outputs.append(outputs[i]+[z_clean])\n", 1177 | " \n", 1178 | " return new_outputs" 1179 | ] 1180 | }, 1181 | { 1182 | "cell_type": "code", 1183 | "execution_count": 45, 1184 | "id": "b79e6b3d", 1185 | "metadata": {}, 1186 | "outputs": [], 1187 | "source": [ 1188 | "target_rg = 30\n", 1189 | "\n", 1190 | "for target_top in [\"star\", \"branch\", \"comb\", \"cyclic\"]:\n", 1191 | " new_outputs = gen_prop_polymer(target_rg=target_rg, target_top=target_top, max_iter=1000)\n", 1192 | "\n", 1193 | " with open(os.path.join(ANALYSIS_DIR, f\"gen_{target_top}_{target_rg}_all.pickle\"), \"wb\") as handle:\n", 1194 | " pickle.dump(new_outputs, handle)" 1195 | ] 1196 | }, 1197 | { 1198 | "cell_type": "code", 1199 | "execution_count": 46, 1200 | "id": "b4c605dd", 1201 | "metadata": {}, 1202 | "outputs": [], 1203 | "source": [ 1204 | "target_rg = 50\n", 1205 | "\n", 1206 | "for target_top in [\"branch\", \"comb\"]:\n", 1207 | " new_outputs = gen_prop_polymer(target_rg=target_rg, target_top=target_top, max_iter=1000)\n", 1208 | "\n", 1209 | " with open(os.path.join(ANALYSIS_DIR, f\"gen_{target_top}_{target_rg}_all.pickle\"), \"wb\") as handle:\n", 1210 | " pickle.dump(new_outputs, handle)" 1211 | ] 1212 | }, 1213 | { 1214 | "cell_type": "code", 1215 | "execution_count": 47, 1216 | "id": "11effa11", 1217 | "metadata": {}, 1218 | "outputs": [], 1219 | "source": [ 1220 | "target_rg = 7.5\n", 1221 | "\n", 1222 | "for target_top in [\"star\", \"dendrimer\"]:\n", 1223 | " new_outputs = 
gen_prop_polymer(target_rg=target_rg, target_top=target_top, max_iter=1000)\n", 1224 | "\n", 1225 | " with open(os.path.join(ANALYSIS_DIR, f\"gen_{target_top}_{target_rg}_all.pickle\"), \"wb\") as handle:\n", 1226 | " pickle.dump(new_outputs, handle)" 1227 | ] 1228 | }, 1229 | { 1230 | "cell_type": "markdown", 1231 | "id": "01b2a819", 1232 | "metadata": {}, 1233 | "source": [ 1234 | "# Diversity Calculation Using Vendi Score" 1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "code", 1239 | "execution_count": 10, 1240 | "id": "f437e998", 1241 | "metadata": {}, 1242 | "outputs": [], 1243 | "source": [ 1244 | "# convert all graphs into graph eigen spectra\n", 1245 | "graph_total = [graph_train, graph_valid, graph_test]\n", 1246 | "\n", 1247 | "lap_spec_data = []\n", 1248 | "\n", 1249 | "for graphs in graph_total:\n", 1250 | " for G in graphs:\n", 1251 | " lap_spec = nx.laplacian_spectrum(G)\n", 1252 | " lap_spec_zero_pad = np.zeros((100,))\n", 1253 | " lap_spec_zero_pad[:len(lap_spec)] = lap_spec\n", 1254 | " lap_spec_data.append(lap_spec_zero_pad)\n", 1255 | " \n", 1256 | "lap_spec_data = np.array(lap_spec_data)\n", 1257 | "\n", 1258 | "with open(os.path.join(ANALYSIS_DIR, \"lap_spec_data.pickle\"), \"wb\") as handle:\n", 1259 | " pickle.dump(lap_spec_data, handle)" 1260 | ] 1261 | }, 1262 | { 1263 | "cell_type": "markdown", 1264 | "id": "b634b4c3", 1265 | "metadata": {}, 1266 | "source": [ 1267 | "### Dataset Vendi Score" 1268 | ] 1269 | }, 1270 | { 1271 | "cell_type": "code", 1272 | "execution_count": 4, 1273 | "id": "f95d0b61", 1274 | "metadata": {}, 1275 | "outputs": [ 1276 | { 1277 | "name": "stdout", 1278 | "output_type": "stream", 1279 | "text": [ 1280 | "Dataset Vendi Score: 2.0968\n" 1281 | ] 1282 | } 1283 | ], 1284 | "source": [ 1285 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "markdown", 1290 | "id": "81899ae6", 1291 | "metadata": {}, 1292 | "source": [ 1293 | "### Vendi Score 
for the Latent Space" 1294 | ] 1295 | }, 1296 | { 1297 | "cell_type": "code", 1298 | "execution_count": 7, 1299 | "id": "5a5bf305", 1300 | "metadata": {}, 1301 | "outputs": [ 1302 | { 1303 | "name": "stdout", 1304 | "output_type": "stream", 1305 | "text": [ 1306 | "latent_space_desc_gnn_cnn\n", 1307 | "Dataset Vendi Score: 7.3225 \n", 1308 | "\n", 1309 | "latent_space_gnn_cnn\n", 1310 | "Dataset Vendi Score: 7.4370 \n", 1311 | "\n", 1312 | "latent_space_desc_dnn_cnn\n", 1313 | "Dataset Vendi Score: 7.0863 \n", 1314 | "\n" 1315 | ] 1316 | } 1317 | ], 1318 | "source": [ 1319 | "files = [\n", 1320 | " \"latent_space_desc_gnn_cnn.pickle\",\n", 1321 | " \"latent_space_gnn_cnn.pickle\",\n", 1322 | " \"latent_space_desc_dnn_cnn.pickle\"\n", 1323 | "]\n", 1324 | "\n", 1325 | "for file in files:\n", 1326 | " with open(os.path.join(ANALYSIS_DIR, file), \"rb\") as handle:\n", 1327 | " latent_data = pickle.load(handle)\n", 1328 | " print(file.split(\".pickle\")[0])\n", 1329 | " print(f\"Dataset Vendi Score: {vendi.score_dual(latent_data):0.4f} \\n\")" 1330 | ] 1331 | }, 1332 | { 1333 | "cell_type": "code", 1334 | "execution_count": 8, 1335 | "id": "2ccb5698", 1336 | "metadata": {}, 1337 | "outputs": [ 1338 | { 1339 | "name": "stdout", 1340 | "output_type": "stream", 1341 | "text": [ 1342 | "latent_space_False_False\n", 1343 | "Dataset Vendi Score: 5.8532 \n", 1344 | "\n", 1345 | "latent_space_False_True\n", 1346 | "Dataset Vendi Score: 6.3171 \n", 1347 | "\n", 1348 | "latent_space_True_False\n", 1349 | "Dataset Vendi Score: 5.3128 \n", 1350 | "\n" 1351 | ] 1352 | } 1353 | ], 1354 | "source": [ 1355 | "files = [\n", 1356 | " \"latent_space_False_False.pickle\",\n", 1357 | " \"latent_space_False_True.pickle\",\n", 1358 | " \"latent_space_True_False.pickle\"\n", 1359 | "]\n", 1360 | "\n", 1361 | "for file in files:\n", 1362 | " with open(os.path.join(ANALYSIS_DIR, file), \"rb\") as handle:\n", 1363 | " latent_data = pickle.load(handle)\n", 1364 | " 
print(file.split(\".pickle\")[0])\n", 1365 | " print(f\"Dataset Vendi Score: {vendi.score_dual(latent_data):0.4f} \\n\")" 1366 | ] 1367 | }, 1368 | { 1369 | "cell_type": "markdown", 1370 | "id": "1b2f9469", 1371 | "metadata": {}, 1372 | "source": [ 1373 | "### Vendi Score for Random Generation Based on Different Models" 1374 | ] 1375 | }, 1376 | { 1377 | "cell_type": "code", 1378 | "execution_count": 13, 1379 | "id": "eef0825e", 1380 | "metadata": {}, 1381 | "outputs": [ 1382 | { 1383 | "name": "stdout", 1384 | "output_type": "stream", 1385 | "text": [ 1386 | "Dataset Vendi Score: 5.0684\n" 1387 | ] 1388 | } 1389 | ], 1390 | "source": [ 1391 | "with open(os.path.join(ANALYSIS_DIR, \"no_valid_random_gen_desc_gnn_cnn.pickle\"), \"rb\") as handle:\n", 1392 | " gen_data = pickle.load(handle)\n", 1393 | " \n", 1394 | "gen_clean_graph = [gen_data[i][2] for i in range(len(gen_data))]\n", 1395 | "\n", 1396 | "lap_spec_data = []\n", 1397 | "\n", 1398 | "for G in gen_clean_graph:\n", 1399 | " lap_spec = nx.laplacian_spectrum(G)\n", 1400 | " lap_spec_zero_pad = np.zeros((100,))\n", 1401 | " lap_spec_zero_pad[:len(lap_spec)] = lap_spec\n", 1402 | " lap_spec_data.append(lap_spec_zero_pad)\n", 1403 | "\n", 1404 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 1405 | ] 1406 | }, 1407 | { 1408 | "cell_type": "code", 1409 | "execution_count": 14, 1410 | "id": "2972ddcc", 1411 | "metadata": {}, 1412 | "outputs": [ 1413 | { 1414 | "name": "stdout", 1415 | "output_type": "stream", 1416 | "text": [ 1417 | "Dataset Vendi Score: 4.9580\n" 1418 | ] 1419 | } 1420 | ], 1421 | "source": [ 1422 | "with open(os.path.join(ANALYSIS_DIR, \"no_valid_random_gen_gnn_cnn.pickle\"), \"rb\") as handle:\n", 1423 | " gen_data = pickle.load(handle)\n", 1424 | " \n", 1425 | "gen_clean_graph = [gen_data[i][2] for i in range(len(gen_data))]\n", 1426 | "\n", 1427 | "lap_spec_data = []\n", 1428 | "\n", 1429 | "for G in gen_clean_graph:\n", 1430 | " lap_spec = 
nx.laplacian_spectrum(G)\n", 1431 | " lap_spec_zero_pad = np.zeros((100,))\n", 1432 | " lap_spec_zero_pad[:len(lap_spec)] = lap_spec\n", 1433 | " lap_spec_data.append(lap_spec_zero_pad)\n", 1434 | "\n", 1435 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 1436 | ] 1437 | }, 1438 | { 1439 | "cell_type": "code", 1440 | "execution_count": 16, 1441 | "id": "e3e6a06f", 1442 | "metadata": {}, 1443 | "outputs": [ 1444 | { 1445 | "name": "stdout", 1446 | "output_type": "stream", 1447 | "text": [ 1448 | "Dataset Vendi Score: 4.3305\n" 1449 | ] 1450 | } 1451 | ], 1452 | "source": [ 1453 | "with open(os.path.join(ANALYSIS_DIR, \"no_valid_random_gen_desc_dnn_cnn.pickle\"), \"rb\") as handle:\n", 1454 | " gen_data = pickle.load(handle)\n", 1455 | " \n", 1456 | "gen_clean_graph = [gen_data[i][2] for i in range(len(gen_data))]\n", 1457 | "\n", 1458 | "lap_spec_data = []\n", 1459 | "\n", 1460 | "for G in gen_clean_graph:\n", 1461 | " lap_spec = nx.laplacian_spectrum(G)\n", 1462 | " lap_spec_zero_pad = np.zeros((100,))\n", 1463 | " lap_spec_zero_pad[:len(lap_spec)] = lap_spec\n", 1464 | " lap_spec_data.append(lap_spec_zero_pad)\n", 1465 | "\n", 1466 | "print(f\"Dataset Vendi Score: {vendi.score_dual(lap_spec_data):0.4f}\")" 1467 | ] 1468 | } 1469 | ], 1470 | "metadata": { 1471 | "kernelspec": { 1472 | "display_name": "py38torch113 [~/.conda/envs/py38torch113/]", 1473 | "language": "python", 1474 | "name": "conda_py38torch113" 1475 | }, 1476 | "language_info": { 1477 | "codemirror_mode": { 1478 | "name": "ipython", 1479 | "version": 3 1480 | }, 1481 | "file_extension": ".py", 1482 | "mimetype": "text/x-python", 1483 | "name": "python", 1484 | "nbconvert_exporter": "python", 1485 | "pygments_lexer": "ipython3", 1486 | "version": "3.8.16" 1487 | } 1488 | }, 1489 | "nbformat": 4, 1490 | "nbformat_minor": 5 1491 | } 1492 | -------------------------------------------------------------------------------- /run_vae.py: 
--------------------------------------------------------------------------------
import warnings
warnings.filterwarnings('ignore')

import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import pickle
import itertools
import numpy as np

from timeit import default_timer as timer

from graph_utils import *
from data_utils import *
from model_utils import *
from analysis_utils import *
from saliency_utils import *
from generation_utils import *

DATA_DIR = '/scratch/gpfs/sj0161/topo_data/'
WEIGHT_DIR = '/scratch/gpfs/sj0161/topo_result/'
ANALYSIS_DIR = '/scratch/gpfs/sj0161/topo_analysis/'


def train_vae(enc_type, dec_type, monitor,
              if_reg, if_cls,
              x_train, x_valid,
              y_train, y_valid,
              c_train, c_valid,
              l_train, l_valid,
              beta=1.0, weights=[1.0, 1.0, 1.0],
              lr=0.001, bs=32,
              if_train=False):
    K.clear_session()
    t0 = timer()

    if enc_type == 'desc_dnn':
        in_train = np.copy(l_train)
        in_valid = np.copy(l_valid)
    elif enc_type == 'desc_gnn':
        in_train = [[np.copy(x_train), np.copy(x_train)], np.copy(l_train)]
        in_valid = [[np.copy(x_valid), np.copy(x_valid)], np.copy(l_valid)]
    elif enc_type == 'gnn':
        in_train = [np.copy(x_train), np.copy(x_train)]
        in_valid = [np.copy(x_valid), np.copy(x_valid)]
    else:
        in_train = np.copy(x_train)
        in_valid = np.copy(x_valid)

    model = get_model(beta=beta, enc_type=enc_type, dec_type=dec_type,
                      if_reg=if_reg, if_cls=if_cls)

    model, loss_weights = compile_model(model, lr=lr, if_reg=if_reg, if_cls=if_cls, weights=weights)

    weight_name, hist_name = get_file_names(enc_type, dec_type, "20230828",
                                            LATENT_DIM, monitor,
                                            if_reg, if_cls, beta, loss_weights, lr, bs)


    print(weight_name + ' started ...')

    if not if_reg and not if_cls and 'acc' in monitor:
        monitor = 'val_acc'

    if if_train:
        c1, c2 = get_callbacks(os.path.join(WEIGHT_DIR, weight_name), monitor=monitor)

        if not if_reg and not if_cls:
            train_label = x_train
            valid_label = x_valid
        else:
            train_label = [x_train]
            valid_label = [x_valid]

        if if_reg:
            train_label.append(y_train)
            valid_label.append(y_valid)

        if if_cls:
            train_label.append(to_categorical(c_train, 6))
            valid_label.append(to_categorical(c_valid, 6))

        hist = model.fit(
            in_train, train_label, validation_data=(in_valid, valid_label),
            callbacks=[c1, c2], epochs=1000, verbose=0, batch_size=bs)

        with open(os.path.join(WEIGHT_DIR, hist_name), 'wb') as handle:
            pickle.dump(hist.history, handle)
    else:
        model.load_weights(os.path.join(WEIGHT_DIR, weight_name))

    t1 = timer()

    print(weight_name + f' finished in {t1-t0:0.2f} sec ...')

    return model, os.path.join(WEIGHT_DIR, hist_name)

if __name__ == '__main__':
    (x_train, y_train, c_train, l_train, graph_train), \
    (x_valid, y_valid, c_valid, l_valid, graph_valid), \
    (x_test, y_test, c_test, l_test, graph_test), \
    NAMES, SCALER = load_data(fold=0, if_validation=True)

    has_nan = np.isnan(l_train).any()
    has_inf = np.isinf(l_train).any()

    print("l_train contains NaN values:", has_nan)
    print("l_train contains inf values:", has_inf)

    K.clear_session()

    LATENT_DIM = 8
    DECODER = "cnn"

    idx_slurm = int(os.environ["SLURM_ARRAY_TASK_ID"])

    encs = ["desc_gnn", "gnn", "desc_dnn"]
    mons = ["val_decoder_loss", "val_decoder_acc", "val_loss"]

    encmons = list(itertools.product(encs, mons)) # shape 9

    ENCODER = encmons[idx_slurm][0]
    MONITOR = encmons[idx_slurm][1]

    for IF_REG in [False, True]:
        for IF_CLS in [False, True]:
            if IF_REG == False or IF_CLS == False:
                for LR in [1e-4, 1e-3, 1e-2]:
                    for BS in [32, 64, 128]:
                        for rw in [0.01, 0.1, 1, 10, 100]:
                            for cw in [0.01, 0.1, 1, 10, 100]:
                                if IF_REG == False and IF_CLS == False:
                                    if MONITOR == "val_decoder_acc":
                                        MONITOR = "val_acc"
                                    elif MONITOR == "val_decoder_loss":
                                        MONITOR = "val_loss"

                                weight_name, hist_name = get_file_names(ENCODER, DECODER, "20230829",
                                                                        LATENT_DIM, MONITOR,
                                                                        IF_REG, IF_CLS, 1.0, [1.0, rw, cw], LR, BS)



                                if os.path.exists(os.path.join(WEIGHT_DIR, hist_name)):
                                    if_train = False
                                else:
                                    if_train = True

                                model, pickle_file = train_vae(ENCODER, DECODER, MONITOR,
                                                               IF_REG, IF_CLS,
                                                               x_train, x_valid,
                                                               y_train, y_valid,
                                                               c_train, c_valid,
                                                               l_train, l_valid,
                                                               1.0, [1.0, rw, cw],
                                                               LR, BS,
                                                               if_train)

                                with open(pickle_file, 'rb') as handle:
                                    hist = pickle.load(handle)
--------------------------------------------------------------------------------
/saliency_utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from tensorflow.keras.utils import to_categorical

def compute_saliency(model, x_train, y_train, l_train, c_train, output_index, enc_type, if_reg, if_cls):
    """
    Computes the saliency map for the given model and data.

    Args:
        model (tf.keras.Model): The model for which to compute the saliency.
        x_train (numpy.ndarray): Input feature data.
        y_train (numpy.ndarray): Regression target data.
        l_train (numpy.ndarray): Additional input data for descriptor models.
        c_train (numpy.ndarray): Classification target data.
        output_index (int): The index of the output to use for computing saliency.
        enc_type (str): The type of encoder used in the model ('desc_gnn' or 'desc_dnn').
        if_reg (bool): Flag indicating if the model includes regression.
        if_cls (bool): Flag indicating if the model includes classification.

    Returns:
        numpy.ndarray or None: The computed saliency map as a NumPy array, or None if regression is not included.
    """
    if if_reg:
        # Prepare input data based on encoder type
        if enc_type == 'desc_gnn':
            input_data = [[tf.convert_to_tensor(x_train), tf.convert_to_tensor(x_train)],
                          tf.convert_to_tensor(l_train)]
        elif enc_type == 'desc_dnn':
            input_data = tf.convert_to_tensor(l_train)

        # Prepare target data based on classification flag
        if if_cls:
            target_data = [tf.convert_to_tensor(x_train),
                           tf.convert_to_tensor(y_train),
                           tf.convert_to_tensor(to_categorical(c_train, 6))]
        else:
            target_data = [tf.convert_to_tensor(x_train),
                           tf.convert_to_tensor(y_train)]

        # Compute gradients for saliency
        with tf.GradientTape() as tape:
            tape.watch(input_data)
            output = model(input_data)
            loss = model.loss[output_index](target_data[output_index], output[output_index])

        # Get gradients for the relevant input
        if enc_type == 'desc_gnn':
            grads = tape.gradient(loss, input_data[1])
        elif enc_type == 'desc_dnn':
            grads = tape.gradient(loss, input_data)

        # Normalize gradients
        grads = tf.abs(grads)
        grads = (grads - tf.reduce_min(grads)) / (tf.reduce_max(grads) - tf.reduce_min(grads))

        return grads.numpy()
    else:
        return None
--------------------------------------------------------------------------------
/topo_analysis/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/.DS_Store
--------------------------------------------------------------------------------
/topo_analysis/gen_branch_30_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_branch_30_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/gen_branch_50_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_branch_50_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/gen_comb_30_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_comb_30_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/gen_comb_50_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_comb_50_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/gen_cyclic_30_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_cyclic_30_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/gen_dendrimer_7.5_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_dendrimer_7.5_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/gen_star_30_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_star_30_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/gen_star_7.5_all.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/gen_star_7.5_all.pickle
--------------------------------------------------------------------------------
/topo_analysis/latent_space.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/latent_space.pickle
--------------------------------------------------------------------------------
/topo_analysis/saliency.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/saliency.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap_-3_12_6_12.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap_-3_12_6_12.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap_1_4_1_16.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap_1_4_1_16.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap_gnn.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap_gnn.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap_reg_false_cls_false.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap_reg_false_cls_false.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap_reg_false_cls_true.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap_reg_false_cls_true.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap_reg_true_cls_false.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap_reg_true_cls_false.pickle
--------------------------------------------------------------------------------
/topo_analysis/umap_topo.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_analysis/umap_topo.pickle
--------------------------------------------------------------------------------
/topo_result/desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_result/desc_dnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 0.01]_0.01_64.h5
--------------------------------------------------------------------------------
/topo_result/desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_result/desc_gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 1]_0.001_64.h5
--------------------------------------------------------------------------------
/topo_result/gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 100]_0.01_64.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/topo_result/gnn_cnn_20230828_8_val_decoder_acc_True_True_1.0_[1.0, 1, 100]_0.01_64.h5
--------------------------------------------------------------------------------
/website/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/website/.DS_Store
--------------------------------------------------------------------------------
/website/abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webbtheosim/poly-topoGNN-vae/af84a08489db014786aa5d5f5227915d58ab7d57/website/abstract.png
--------------------------------------------------------------------------------
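The diversity cells in the notebook above convert each graph to its Laplacian spectrum, zero-pad it to length 100, and pass the stacked vectors to `vendi.score_dual`. Below is a NumPy-only sketch of that pipeline. It assumes each graph is undirected and unweighted and is supplied as a dense adjacency matrix rather than a `networkx` graph. It also assumes that `score_dual` uses a cosine (row-normalized linear) kernel with order q = 1. Check the Vendi Score repository before relying on this sketch to reproduce the reported numbers. `padded_laplacian_spectrum` and `vendi_score_linear` are hypothetical helper names, not functions from this repository.

```python
import numpy as np


def padded_laplacian_spectrum(adj, size=100):
    """Zero-padded Laplacian eigenvalues of an undirected, unweighted graph.

    Hypothetical helper: mirrors the notebook's nx.laplacian_spectrum plus
    zero-padding step, but takes a dense adjacency matrix so only NumPy is needed.
    """
    adj = np.asarray(adj, dtype=float)
    lap = np.diag(adj.sum(axis=1)) - adj       # graph Laplacian L = D - A
    eigvals = np.linalg.eigvalsh(lap)          # real, ascending (L is symmetric)
    padded = np.zeros(size)
    padded[: len(eigvals)] = eigvals           # pad with zeros up to `size`
    return padded


def vendi_score_linear(X):
    """Order-1 Vendi Score with a cosine-similarity kernel, via the dual form.

    Assumption: this is meant to resemble vendi.score_dual; it is not that code.
    """
    X = np.asarray(X, dtype=float)
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    X = X / np.where(norms == 0, 1.0, norms)   # unit-normalize each row
    n = X.shape[0]
    # The nonzero eigenvalues of X^T X / n equal those of the n x n kernel
    # matrix K / n with K = X X^T, so the d x d problem is solved instead.
    w = np.linalg.eigvalsh(X.T @ X / n)
    w = w[w > 1e-12]                           # drop numerical zeros (0 log 0 = 0)
    entropy = -np.sum(w * np.log(w))           # Shannon entropy of the eigenvalues
    return float(np.exp(entropy))


if __name__ == "__main__":
    # A 3-node path graph has Laplacian eigenvalues 0, 1, 3.
    path = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
    spec = padded_laplacian_spectrum(path)
    print(spec[:4])
    print(vendi_score_linear(np.tile(spec, (5, 1))))  # identical rows
    print(vendi_score_linear(np.eye(4)))              # orthogonal rows
```

Identical spectra give a score of 1, and n mutually orthogonal rows give n. This matches the reading of the Vendi Score as an effective number of distinct samples, and it is a quick sanity check before comparing dataset, latent-space, and generated-graph scores.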