├── README.md
├── geometric_solver.py
├── test.py
├── training.py
├── model_periodicBC.py
├── model_dirichletBC.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Learning to optimize multigrid solvers
2 | Code for the paper "Learning to optimize multigrid PDE solvers", available at https://arxiv.org/abs/1902.10248.
3 | 
4 | The notation of the grid points is as follows: in any grid-related tensor, the (0,0) cell corresponds to the leftmost, bottommost grid cell. The (I,J) cell then corresponds to the grid cell located in the I'th column and J'th row of the grid.
5 | 
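6 | For example (an illustrative sketch, not code from this repository), a batch of 3x3 stencils is stored as a tensor of shape (num, grid_size, grid_size, 3, 3), as in utils.two_d_stencil, and is indexed by this convention:
7 | 
8 | ```python
9 | import numpy as np
10 | 
11 | stencils = np.zeros((1, 8, 8, 3, 3))  # one sample on an 8x8 grid
12 | bottom_left = stencils[0, 0, 0]       # 3x3 stencil of the leftmost, bottommost cell
13 | col2_row5 = stencils[0, 2, 5]         # 3x3 stencil of the cell in column 2, row 5
14 | ```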
--------------------------------------------------------------------------------
/geometric_solver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from warnings import warn
3 | from scipy.sparse import csr_matrix, isspmatrix_csr, SparseEfficiencyWarning
4 | 
5 | from pyamg.multilevel import multilevel_solver
6 | from pyamg.relaxation.smoothing import change_smoothers
7 | 
8 | 
9 | # similar to "ruge_stuben_solver" in pyamg
10 | def geometric_solver(A, prolongation_function, prolongation_args,
11 |                      presmoother=('gauss_seidel', {'sweep': 'forward'}),
12 |                      postsmoother=('gauss_seidel', {'sweep': 'forward'}),
13 |                      max_levels=10, max_coarse=10, **kwargs):
14 |     """Create a multilevel solver using geometric AMG.
15 | 
16 |     Parameters
17 |     ----------
18 |     A : csr_matrix
19 |         Square matrix in CSR format
20 |     prolongation_function : callable
21 |         Called as prolongation_function(A, prolongation_args) on each level;
22 |         returns the prolongation matrix P that maps from the next coarser
23 |         grid to the current grid.
24 |     prolongation_args
25 |         Extra arguments forwarded to prolongation_function.
26 |     presmoother : string or dict
27 |         Method used for presmoothing at each level. Method-specific parameters
28 |         may be passed in using a tuple, e.g.
29 |         presmoother=('gauss_seidel', {'sweep': 'forward'}), the default.
30 |     postsmoother : string or dict
31 |         Postsmoothing method with the same usage as presmoother
32 |     max_levels : integer
33 |         Maximum number of levels to be used in the multilevel solver.
34 |     max_coarse : integer
35 |         Maximum number of variables permitted on the coarse grid.
36 | 
37 |     Returns
38 |     -------
39 |     ml : multilevel_solver
40 |         Multigrid hierarchy of matrices and prolongation operators
41 | 
42 |     Notes
43 |     -----
44 |     "coarse_solver" is an optional argument and is the solver used at the
45 |     coarsest grid. The default is a pseudo-inverse. Most simply,
46 |     coarse_solver can be one of ['splu', 'lu', 'cholesky', 'pinv',
47 |     'gauss_seidel', ...]. Additionally, coarse_solver may be a tuple
48 |     (fn, args), where fn is a string such as ['splu', 'lu', ...] or a callable
49 |     function, and args is a dictionary of arguments to be passed to fn.
50 |     See [2001TrOoSc]_ for additional details.
51 | 
52 |     References
53 |     ----------
54 |     .. [2001TrOoSc] Trottenberg, U., Oosterlee, C. W., and Schuller, A.,
55 |        "Multigrid", San Diego: Academic Press, 2001, Appendix A.
56 | 
57 |     See Also
58 |     --------
59 |     aggregation.smoothed_aggregation_solver, multilevel_solver,
60 |     aggregation.rootnode_solver
61 | 
62 |     """
63 |     levels = [multilevel_solver.level()]
64 | 
65 |     # convert A to csr
66 |     if not isspmatrix_csr(A):
67 |         try:
68 |             A = csr_matrix(A)
69 |             warn("Implicit conversion of A to CSR",
70 |                  SparseEfficiencyWarning)
71 |         except BaseException:
72 |             raise TypeError('Argument A must have type csr_matrix, '
73 |                             'or be convertible to csr_matrix')
74 |     # preprocess A
75 |     A = A.asfptype()
76 |     if A.shape[0] != A.shape[1]:
77 |         raise ValueError('expected square matrix')
78 | 
79 |     levels[-1].A = A
80 | 
81 |     while len(levels) < max_levels and levels[-1].A.shape[0] > max_coarse:
82 |         extend_hierarchy(levels, prolongation_function, prolongation_args)
83 | 
84 |     ml = multilevel_solver(levels, **kwargs)
85 |     change_smoothers(ml, presmoother, postsmoother)
86 |     return ml
87 | 
88 | 
89 | # internal function
90 | def extend_hierarchy(levels, prolongation_fn, prolongation_args):
91 |     """Extend the multigrid hierarchy."""
92 |     A = levels[-1].A
93 | 
94 |     # Generate the interpolation matrix that maps from the coarse grid to the
95 |     # fine grid
96 |     P = prolongation_fn(A, prolongation_args)
97 | 
98 |     # Generate the restriction matrix that maps from the fine grid to the
99 |     # coarse grid
100 |     R = P.T.tocsr()
101 | 
102 |     levels[-1].P = P  # prolongation operator
103 |     levels[-1].R = R  # restriction operator
104 | 
105 |     levels.append(multilevel_solver.level())
106 | 
107 |     # Form the next level through the Galerkin product
108 |     A = R * A * P
109 |     A = A.astype(np.float64)  # drop the (numerically zero) imaginary part
110 |     levels[-1].A = A
--------------------------------------------------------------------------------
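A minimal usage sketch of geometric_solver (illustrative only, not part of the repository): the pairwise-aggregation prolongation below is a hypothetical stand-in for the network-based prolongation that the rest of this code builds.

import numpy as np
import scipy.sparse as sp
from geometric_solver import geometric_solver

def pair_aggregation_prolongation(A, args):
    # Hypothetical P: merge unknowns 2i and 2i+1 with unit weights.
    n = A.shape[0]
    rows = np.arange(n)
    return sp.csr_matrix((np.ones(n), (rows, rows // 2)), shape=(n, n // 2))

A = sp.diags([-1., 2., -1.], [-1, 0, 1], shape=(64, 64), format='csr')
ml = geometric_solver(A, pair_aggregation_prolongation, None,
                      max_levels=3, max_coarse=8)
x = ml.solve(np.ones(64), tol=1e-8)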
/test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from utils import Utils
4 | import matplotlib.pyplot as plt
5 | import argparse
6 | import scipy
7 | from tqdm import tqdm
8 | 
9 | tf.enable_eager_execution()
10 | DEVICE = '/gpu:0'
11 | if __name__ == "__main__":
12 |     parser = argparse.ArgumentParser()
13 |     # the grid size should be 2^n - 1
14 |     parser.add_argument('--grid-size', default=255, type=int, help="grid side length; should be 2^n - 1")
15 |     parser.add_argument('--num-test-samples', default=100, type=int, help="number of test problems")
16 |     parser.add_argument('--boundary', default='dirichlet', type=str, help="'dirichlet' or 'periodic'")
17 |     parser.add_argument('--compute-spectral-radius', action='store_true', help="also estimate spectral radii")
18 |     parser.add_argument('--bb-row-normalize', action='store_true', help="")
19 | 
20 |     args = parser.parse_args()
21 | 
22 |     num_cycles = 41
23 |     utils = Utils(grid_size=args.grid_size, device=DEVICE, bc=args.boundary)
24 |     if args.boundary == 'dirichlet':
25 |         from model_dirichletBC import Pnetwork
26 |     else:
27 |         from model_periodicBC import Pnetwork
28 |     m = Pnetwork(grid_size=args.grid_size, device=DEVICE)
29 | 
30 |     checkpoint_dir = './training_dir'
31 | 
32 |     with tf.device(DEVICE):
33 |         lr = tf.Variable([3.4965356e-05])
34 |         optimizer = tf.train.AdamOptimizer(lr)
35 |         root = tf.train.Checkpoint(optimizer=optimizer, model=m, optimizer_step=tf.train.get_or_create_global_step())
36 |         root.restore(tf.train.latest_checkpoint(checkpoint_dir))
37 | 
38 |     black_box_residual_norms = []
39 |     black_box_errors = []
40 |     black_box_frob_norms = []
41 |     black_box_spectral_radii = []
42 |     net_residual_norms = []
43 |     net_errors = []
44 |     network_spectral_radii = []
45 |     network_frob_norms = []
46 |     A_stencils_test = utils.two_d_stencil(num=args.num_test_samples, grid_size=args.grid_size, epsilon=0.0)
47 |     for A_stencil in tqdm(A_stencils_test):
48 |         A_matrix = utils.compute_csr_matrices(stencils=A_stencil, grid_size=args.grid_size)
49 | 
50 |         A_stencil_tf = tf.convert_to_tensor(value=[A_stencil], dtype=tf.double)
51 |         b = np.zeros(shape=(args.grid_size ** 2, 1))
52 | 
53 |         initial = np.random.normal(loc=0.0, scale=1.0, size=args.grid_size ** 2)
54 |         initial = initial[:, np.newaxis]
55 | 
56 |         _, residual_norms, error_norms, solver = utils.solve_with_model(model=m,
57 |                                                                         A_matrices=A_matrix, b=b,
58 |                                                                         initial_guess=initial,
59 |                                                                         max_iterations=num_cycles,
60 |                                                                         max_depth=int(np.log2(args.grid_size)) - 1,
61 |                                                                         blackbox=True,
62 |                                                                         w_cycle=True)
63 |         black_box_errors.append(error_norms)
64 |         black_box_residual_norms.append(residual_norms)
65 |         if args.compute_spectral_radius:
66 |             I = np.eye(args.grid_size ** 2, dtype=np.double)
67 |             P = solver.levels[0].P
68 |             R = solver.levels[0].R
69 |             A = solver.levels[0].A
70 |             C = I - P @ scipy.sparse.linalg.inv(R @ A @ P) @ R @ A
71 | 
72 |             L = scipy.sparse.tril(A)
73 |             S = I - scipy.sparse.linalg.inv(L) @ A
74 |             M = S @ C @ S
75 |             black_box_frob_norms.append(scipy.linalg.norm(M))
76 |             eigs, _ = scipy.sparse.linalg.eigs(M)
77 |             black_box_spectral_radius = np.abs(eigs).max()  # spectral radius = max |eigenvalue|
78 |             black_box_spectral_radii.append(black_box_spectral_radius)
79 | 
80 |         x, residual_norms, error_norms, solver = utils.solve_with_model(model=m,
81 |                                                                         A_matrices=A_matrix, b=b, initial_guess=initial,
82 |                                                                         max_iterations=num_cycles,
83 |                                                                         max_depth=int(np.log2(args.grid_size)) - 1,
84 |                                                                         blackbox=False,
85 |                                                                         w_cycle=True)
86 | 
87 |         net_errors.append(error_norms)
88 |         net_residual_norms.append(residual_norms)
89 |         if args.compute_spectral_radius:
90 |             I = np.eye(args.grid_size ** 2, dtype=np.double)
91 |             P = solver.levels[0].P
92 |             R = solver.levels[0].R
93 |             A = solver.levels[0].A
94 |             C = I - P @ scipy.sparse.linalg.inv(R @ A @ P) @ R @ A
95 | 
96 |             L = scipy.sparse.tril(A)
97 |             S = I - scipy.sparse.linalg.inv(L) @ A
98 |             M = S @ C @ S
99 |             network_frob_norms.append(scipy.linalg.norm(M))
100 |             eigs, _ = scipy.sparse.linalg.eigs(M)
101 |             network_spectral_radius = np.abs(eigs).max()  # spectral radius = max |eigenvalue|
102 |             network_spectral_radii.append(network_spectral_radius)
103 | 
104 |     net_errors = np.array(net_errors)
105 |     net_errors_log = np.log2(net_errors)
106 |     net_errors_div_diff = 2 ** np.diff(net_errors_log)
107 |     net_errors_mean = np.mean(net_errors, axis=0)
108 |     net_errors_std = np.std(net_errors, axis=0)
109 |     net_errors_div_diff_mean = np.mean(net_errors_div_diff, axis=0)
110 |     net_errors_div_diff_std = np.std(net_errors_div_diff, axis=0)
111 | 
112 |     black_box_errors = np.array(black_box_errors)
113 |     black_box_errors_log = np.log2(black_box_errors)
114 |     black_box_errors_div_diff = 2 ** np.diff(black_box_errors_log)
115 |     black_box_errors_mean = np.mean(black_box_errors, axis=0)
116 |     black_box_errors_std = np.std(black_box_errors, axis=0)
117 |     black_box_errors_div_diff_mean = np.mean(black_box_errors_div_diff, axis=0)
118 |     black_box_errors_div_diff_std = np.std(black_box_errors_div_diff, axis=0)
119 | 
120 |     plt.figure()
121 |     plt.plot(np.arange(len(net_errors_mean), dtype=int), net_errors_mean, label='nn')
122 |     plt.plot(np.arange(len(black_box_errors_mean), dtype=int), black_box_errors_mean, label='black box')
123 |     plt.xticks(np.arange(len(black_box_errors_mean), step=10))
124 |     plt.xlabel('iteration number')
125 |     plt.ylabel('error l2 norm')
126 |     plt.yscale('log')
127 |     plt.legend()
128 |     plt.savefig('results/test_p.png')
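# net_errors_div_diff[k] = e_{k+1} / e_k, the per-cycle error reduction factor,
# recovered via 2 ** diff(log2(e)); e.g., errors [1.0, 0.25, 0.0625] give
# factors [0.25, 0.25]. The mean of its final entry is reported below as the
# asymptotic error factor.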
129 | 
130 |     plt.figure()
131 |     plt.plot(np.arange(len(net_errors_div_diff_mean), dtype=int), net_errors_div_diff_mean, label='nn')
132 |     plt.plot(np.arange(len(black_box_errors_div_diff_mean), dtype=int), black_box_errors_div_diff_mean,
133 |              label='black box')
134 |     plt.xticks(np.arange(len(black_box_errors_mean), step=10))
135 |     plt.xlabel('iteration number')
136 |     plt.ylabel('error reduction factor')
137 |     plt.legend()
138 |     plt.savefig('results/test_div_diff.png')
139 | 
140 |     results_file = open("results/results.txt", 'w')
141 |     print(f"network asymptotic error factor: {net_errors_div_diff_mean[-1]:.4f} ± {net_errors_div_diff_std[-1]:.5f}",
142 |           file=results_file)
143 |     print(
144 |         f"black box asymptotic error factor: {black_box_errors_div_diff_mean[-1]:.4f} ± {black_box_errors_div_diff_std[-1]:.5f}",
145 |         file=results_file)
146 |     net_success_rate = sum(net_errors_div_diff[:, -1] < black_box_errors_div_diff[:, -1]) / args.num_test_samples
147 |     print(f"network success rate: {100 * net_success_rate:.1f}%",
148 |           file=results_file)
149 |     if args.compute_spectral_radius:
150 |         network_spectral_radii = np.array(network_spectral_radii)
151 |         network_spectral_mean = network_spectral_radii.mean()
152 |         network_spectral_std = network_spectral_radii.std()
153 |         network_frob_norms = np.array(network_frob_norms)
154 |         network_frob_mean = network_frob_norms.mean()
155 |         network_frob_std = network_frob_norms.std()
156 |         black_box_spectral_radii = np.array(black_box_spectral_radii)
157 |         black_box_spectral_mean = black_box_spectral_radii.mean()
158 |         black_box_spectral_std = black_box_spectral_radii.std()
159 |         black_box_frob_norms = np.array(black_box_frob_norms)
160 |         black_box_frob_mean = black_box_frob_norms.mean()
161 |         black_box_frob_std = black_box_frob_norms.std()
162 |         print(f"network spectral radius: {network_spectral_mean:.4f} ± {network_spectral_std:.5f}",
163 |               file=results_file)
164 |         print(f"network frobenius norm: {network_frob_mean:.4f} ± {network_frob_std:.5f}",
165 |               file=results_file)
166 |         print(f"black box spectral radius: {black_box_spectral_mean:.4f} ± {black_box_spectral_std:.5f}",
167 |               file=results_file)
168 |         print(f"black box frobenius norm: {black_box_frob_mean:.4f} ± {black_box_frob_std:.5f}",
169 |               file=results_file)
170 |     results_file.close()
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import argparse
5 | import random
6 | import string
7 | from utils import Utils
8 | from tqdm import tqdm
9 | from tensorboardX import SummaryWriter
10 | 
11 | tf.enable_eager_execution()
12 | 
13 | DEVICE = "/cpu:0"
14 | 
15 | num_training_samples = 10 * 16384
16 | num_test_samples = 128
17 | grid_size = 8
18 | n_test, n_train = 32, 8
19 | checkpoint_dir = './training_dir'
20 | checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
21 | 
22 | with tf.device(DEVICE):
23 |     lr_ = 1.2e-5
24 |     lr = tf.Variable(lr_)
25 |     optimizer = tf.train.AdamOptimizer(lr)
26 | 
27 | 
28 | def loss(model, n, A_stencil, A_matrices, S_matrices, index=None, pos=-1., phase="Training", epoch=-1, grid_size=8,
29 |          remove=True):
30 |     with tf.device(DEVICE):
31 |         A_matrices = tf.conj(A_matrices)
32 |         S_matrices = tf.conj(S_matrices)
33 |         pi = tf.constant(np.pi)
34 |         theta_x = np.array([i * 2 * pi / n for i in range(-n // (grid_size * 2) + 1, n // (grid_size * 2) + 1)])
35 |         with tf.device(DEVICE):
36 |             if phase == 
"Test" and epoch == 0: 37 | P_stencil = model(A_stencil, True) 38 | P_matrix = utils.compute_p2LFA(P_stencil, n, grid_size) 39 | P_matrix = tf.transpose(P_matrix, [2, 0, 1, 3, 4]) 40 | P_matrix_t = tf.transpose(P_matrix, [0, 1, 2, 4, 3], conjugate=True) 41 | A_c = tf.matmul(tf.matmul(P_matrix_t, A_matrices), P_matrix) 42 | 43 | index_to_remove = len(theta_x) * (-1 + n // (2 * grid_size)) + n // (2 * grid_size) - 1 44 | A_c = tf.reshape(A_c, (-1, int(theta_x.shape[0]) ** 2, (grid_size // 2) ** 2, (grid_size // 2) ** 2)) 45 | A_c_removed = tf.concat([A_c[:, :index_to_remove], A_c[:, index_to_remove + 1:]], 1) 46 | P_matrix_t_reshape = tf.reshape(P_matrix_t, 47 | (-1, int(theta_x.shape[0]) ** 2, (grid_size // 2) ** 2, grid_size ** 2)) 48 | P_matrix_reshape = tf.reshape(P_matrix, 49 | (-1, int(theta_x.shape[0]) ** 2, grid_size ** 2, (grid_size // 2) ** 2)) 50 | A_matrices_reshaped = tf.reshape(A_matrices, 51 | (-1, int(theta_x.shape[0]) ** 2, grid_size ** 2, grid_size ** 2)) 52 | A_matrices_removed = tf.concat( 53 | [A_matrices_reshaped[:, :index_to_remove], A_matrices_reshaped[:, index_to_remove + 1:]], 1) 54 | 55 | P_matrix_removed = tf.concat( 56 | [P_matrix_reshape[:, :index_to_remove], P_matrix_reshape[:, index_to_remove + 1:]], 1) 57 | P_matrix_t_removed = tf.concat( 58 | [P_matrix_t_reshape[:, :index_to_remove], P_matrix_t_reshape[:, index_to_remove + 1:]], 1) 59 | 60 | A_coarse_inv_removed = tf.matrix_solve(A_c_removed, P_matrix_t_removed) 61 | 62 | CGC_removed = tf.eye(grid_size ** 2, dtype=tf.complex128) \ 63 | - tf.matmul(tf.matmul(P_matrix_removed, A_coarse_inv_removed), A_matrices_removed) 64 | S_matrices_reshaped = tf.reshape(S_matrices, 65 | (-1, int(theta_x.shape[0]) ** 2, grid_size ** 2, grid_size ** 2)) 66 | S_removed = tf.concat( 67 | [S_matrices_reshaped[:, :index_to_remove], S_matrices_reshaped[:, index_to_remove + 1:]], 1) 68 | iteration_matrix = tf.matmul(tf.matmul(CGC_removed, S_removed), S_removed) 69 | loss_test = tf.reduce_mean(tf.reduce_mean(tf.reduce_sum(tf.square(tf.abs(iteration_matrix)), [2, 3]), 1)) 70 | return tf.constant([0.]), loss_test.numpy() 71 | if index is not None: 72 | P_stencil = model(A_stencil, index=index, pos=pos, phase=phase) 73 | else: 74 | P_stencil = model(A_stencil, phase=phase) 75 | 76 | if not (phase == "Test" and epoch == 0): 77 | P_matrix = utils.compute_p2LFA(P_stencil, n, grid_size) 78 | 79 | P_matrix = tf.transpose(P_matrix, [2, 0, 1, 3, 4]) 80 | P_matrix_t = tf.transpose(P_matrix, [0, 1, 2, 4, 3], conjugate=True) 81 | 82 | A_c = tf.matmul(tf.matmul(P_matrix_t, A_matrices), P_matrix) 83 | index_to_remove = len(theta_x) * (-1 + n // (2 * grid_size)) + n // (2 * grid_size) - 1 84 | A_c = tf.reshape(A_c, (-1, int(theta_x.shape[0]) ** 2, (grid_size // 2) ** 2, (grid_size // 2) ** 2)) 85 | A_c_removed = tf.concat([A_c[:, :index_to_remove], A_c[:, index_to_remove + 1:]], 1) 86 | P_matrix_t_reshape = tf.reshape(P_matrix_t, 87 | (-1, int(theta_x.shape[0]) ** 2, (grid_size // 2) ** 2, grid_size ** 2)) 88 | P_matrix_reshape = tf.reshape(P_matrix, 89 | (-1, int(theta_x.shape[0]) ** 2, grid_size ** 2, (grid_size // 2) ** 2)) 90 | A_matrices_reshaped = tf.reshape(A_matrices, 91 | (-1, int(theta_x.shape[0]) ** 2, grid_size ** 2, grid_size ** 2)) 92 | A_matrices_removed = tf.concat( 93 | [A_matrices_reshaped[:, :index_to_remove], A_matrices_reshaped[:, index_to_remove + 1:]], 1) 94 | 95 | P_matrix_removed = tf.concat( 96 | [P_matrix_reshape[:, :index_to_remove], P_matrix_reshape[:, index_to_remove + 1:]], 1) 97 | P_matrix_t_removed = tf.concat( 98 | 
[P_matrix_t_reshape[:, :index_to_remove], P_matrix_t_reshape[:, index_to_remove + 1:]], 1)
99 |             A_coarse_inv_removed = tf.matrix_solve(A_c_removed, P_matrix_t_removed)
100 | 
101 |             CGC_removed = tf.eye(grid_size ** 2, dtype=tf.complex128) \
102 |                           - tf.matmul(tf.matmul(P_matrix_removed, A_coarse_inv_removed), A_matrices_removed)
103 |             S_matrices_reshaped = tf.reshape(S_matrices,
104 |                                              (-1, int(theta_x.shape[0]) ** 2, grid_size ** 2, grid_size ** 2))
105 |             S_removed = tf.concat(
106 |                 [S_matrices_reshaped[:, :index_to_remove], S_matrices_reshaped[:, index_to_remove + 1:]], 1)
107 |             iteration_matrix_all = tf.matmul(tf.matmul(CGC_removed, S_removed), S_removed)
108 | 
109 |         if remove:
110 |             if phase != 'Test':
111 |                 iteration_matrix = iteration_matrix_all
112 |                 for _ in range(0):  # no-op: the squaring below is disabled
113 |                     iteration_matrix = tf.matmul(iteration_matrix_all, iteration_matrix_all)
114 |             else:
115 |                 iteration_matrix = iteration_matrix_all
116 |             loss = tf.reduce_mean(
117 |                 tf.reduce_max(tf.reduce_sum(tf.square(tf.abs(iteration_matrix)), [2, 3]), 1))
118 |         else:
119 |             loss = tf.reduce_mean(
120 |                 tf.reduce_mean(tf.reduce_sum(tf.square(tf.abs(iteration_matrix_all)), [2, 3]), 1))
121 | 
122 |         print("Real loss: ", loss.numpy())
123 |         real_loss = loss.numpy()
124 |         return loss, real_loss
125 | 
126 | 
127 | def grad(model, n, A_stencil, A_matrices, S_matrices, phase="Training", epoch=-1, grid_size=8, remove=True):
128 |     with tf.GradientTape() as tape:
129 |         loss_value, real_loss = loss(model, n, A_stencil, A_matrices, S_matrices,
130 |                                      phase=phase, epoch=epoch, grid_size=grid_size, remove=remove)
131 |     return tape.gradient(loss_value, model.variables), real_loss  # differentiate w.r.t. the passed model's variables
132 | 
133 | 
134 | if __name__ == "__main__":
135 |     parser = argparse.ArgumentParser()
136 |     parser.add_argument('--verbose', action='store_true', help="")
137 |     parser.add_argument('--use-gpu', action='store_true', default=True, help="run on the GPU (enabled by default)")
138 |     parser.add_argument('--grid-size', default=8, type=int, help="")
139 |     parser.add_argument('--batch-size', default=32, type=int, help="")
140 |     parser.add_argument('--n-epochs', default=2, type=int, help="")
141 |     parser.add_argument('--bc', default='periodic')
142 | 
143 |     args = parser.parse_args()
144 | 
145 |     if args.use_gpu:
146 |         DEVICE = "/gpu:0"
147 | 
148 |     utils = Utils(grid_size=args.grid_size, device=DEVICE, bc=args.bc)
149 |     grid_size = args.grid_size  # keep the module-level grid_size in sync with the CLI argument
150 |     random_string = ''.join(random.choices(string.digits, k=5))  # to make the run_name string unique
151 |     run_name = f"regularization_grid_size={args.grid_size}_batch_size={args.batch_size}_{random_string}"
152 |     writer = SummaryWriter(log_dir='runs/' + run_name)
153 | 
154 |     if args.bc == 'periodic':
155 |         from model_periodicBC import Pnetwork
156 |     else:
157 |         from model_dirichletBC import Pnetwork
158 | 
159 |     # define network
160 |     m = Pnetwork(grid_size=grid_size, device=DEVICE)
161 | 
162 |     root = tf.train.Checkpoint(optimizer=optimizer, model=m, optimizer_step=tf.train.get_or_create_global_step())
163 | 
164 |     with tf.device(DEVICE):
165 |         pi = tf.constant(np.pi)
166 |         ci = tf.to_complex128(1j)
167 | 
168 |     A_stencils_test, A_matrices_test, S_matrices_test, num_of_modes = utils.get_A_S_matrices(num_test_samples, np.pi,
169 |                                                                                              grid_size, n_test)
170 | 
171 |     with tf.device(DEVICE):
172 |         A_stencils_test = tf.convert_to_tensor(A_stencils_test, dtype=tf.double)
173 |         A_matrices_test = tf.convert_to_tensor(A_matrices_test, dtype=tf.complex128)
174 |         S_matrices_test = tf.reshape(tf.convert_to_tensor(S_matrices_test, dtype=tf.complex128),
175 |                                      (-1, num_of_modes, num_of_modes, grid_size ** 2, 
grid_size ** 2)) 176 | 177 | A_stencils_train = np.array(utils.two_d_stencil(num_training_samples)) 178 | n_train_list = [16, 16, 32] 179 | initial_epsi = 1e-0 180 | 181 | numiter = -1 182 | for j in range(len(n_train_list)): 183 | A_stencils = A_stencils_train.copy() 184 | n_train = n_train_list[j] 185 | 186 | theta_x = np.array( 187 | [i * 2 * pi / n_train for i in range(-n_train // (2 * grid_size) + 1, n_train // (2 * grid_size) + 1)]) 188 | theta_y = np.array( 189 | [i * 2 * pi / n_train for i in range(-n_train // (2 * grid_size) + 1, n_train // (2 * grid_size) + 1)]) 190 | 191 | for epoch in range(args.n_epochs): 192 | print("epoch: {}".format(epoch)) 193 | order = np.random.permutation(num_training_samples) 194 | 195 | _, blackbox_test_loss = grad(model=m, n=n_test, A_stencil=A_stencils_test, 196 | A_matrices=A_matrices_test, S_matrices=S_matrices_test, 197 | phase="Test", epoch=0, grid_size=grid_size) 198 | 199 | if epoch % 1 == 0: # change to save once every X epochs 200 | root.save(file_prefix=checkpoint_prefix) 201 | 202 | for iter in tqdm(range(num_training_samples // args.batch_size)): 203 | numiter += 1 204 | idx = np.random.choice(A_stencils.shape[0], args.batch_size, replace=False) 205 | A_matrices = np.stack( 206 | [[utils.compute_A(A_stencils[idx], tx, ty, 1j, grid_size=grid_size) for tx in theta_x] for ty in 207 | theta_y]) 208 | A_matrices = A_matrices.transpose((2, 0, 1, 3, 4)) 209 | 210 | S_matrices = np.reshape(utils.compute_S(A_matrices.reshape((-1, grid_size ** 2, grid_size ** 2))), 211 | (-1, theta_x.shape[0], theta_x.shape[0], grid_size ** 2, grid_size ** 2)) 212 | with tf.device(DEVICE): 213 | A_stencils_tensor = tf.convert_to_tensor(A_stencils[idx], dtype=tf.double) 214 | A_matrices_tensor = tf.convert_to_tensor(A_matrices, dtype=tf.complex128) 215 | S_matrices_tensor = tf.convert_to_tensor(S_matrices, dtype=tf.complex128) 216 | 217 | _, blackbox_train_loss = grad(m, n_train, A_stencils_tensor, A_matrices_tensor, S_matrices_tensor, 218 | epoch=0, 219 | grid_size=grid_size, remove=True, phase="Test") 220 | grads, real_loss = grad(m, n_train, A_stencils_tensor, A_matrices_tensor, 221 | S_matrices_tensor, grid_size=grid_size, remove=True, phase="p") 222 | writer.add_scalar('loss', real_loss, numiter) 223 | writer.add_scalar('blackbox_train_loss', blackbox_train_loss, numiter) 224 | writer.add_scalar('blackbox_test_loss', blackbox_test_loss, numiter) 225 | optimizer.apply_gradients(zip(grads, m.variables), tf.train.get_or_create_global_step()) 226 | 227 | # add coarse grid problems: 228 | if j > 0: 229 | num_training_samples = num_training_samples // 2 230 | temp = utils.create_coarse_training_set(m, pi, num_training_samples) 231 | A_stencils_train = np.concatenate( 232 | [np.array(utils.two_d_stencil(num_training_samples)), temp], axis=0) 233 | num_training_samples = A_stencils_train.shape[0] 234 | -------------------------------------------------------------------------------- /model_periodicBC.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | class Pnetwork(tf.keras.Model): 6 | def __init__(self, grid_size=8, device="/cpu:0"): 7 | super(Pnetwork, self).__init__() 8 | self.grid_size = grid_size 9 | self.device = device 10 | with tf.device(device): 11 | width = 100 12 | self.linear0 = tf.keras.layers.Dense(width,kernel_regularizer=tf.keras.regularizers.l2(1e-7), 13 | use_bias=False) 14 | self.num_layers = 100 15 | for i in range(1, self.num_layers): 16 | setattr(self, 
"linear%i" % i, tf.keras.layers.Dense(width, use_bias=False, 17 | kernel_regularizer=tf.keras.regularizers.l2(1e-7), 18 | kernel_initializer=tf.initializers.truncated_normal( 19 | stddev=i ** (-1 / 2) * np.sqrt(2. / width)))) 20 | setattr(self, "bias_1%i" % i, tf.Variable([0.], dtype=tf.float64)) 21 | setattr(self, "linear%i" % (i + 1), tf.keras.layers.Dense(width, use_bias=False, 22 | kernel_regularizer=tf.keras.regularizers.l2( 23 | 1e-7), 24 | kernel_initializer=tf.zeros_initializer())) 25 | setattr(self, "bias_2%i" % i, tf.Variable([0.], dtype=tf.float64)) 26 | setattr(self, "bias_3%i" % i, tf.Variable([0.], dtype=tf.float64)) 27 | setattr(self, "bias_4%i" % i, tf.Variable([0.], dtype=tf.float64)) 28 | setattr(self, "multiplier_%i" % i, tf.Variable([1.], dtype=tf.float64)) 29 | 30 | self.output_layer = tf.keras.layers.Dense(4, use_bias=True, kernel_regularizer=tf.keras.regularizers.l2(1e-5)) 31 | self.new_output = tf.Variable(0.5*tf.random_normal(shape=[2*2*2*8], dtype=tf.float64), dtype=tf.float64) 32 | 33 | def call(self, inputs, black_box=False, index=None, pos=-1., phase='Training'): 34 | 35 | with tf.device(self.device): 36 | batch_size = inputs.shape[0] 37 | right_contributions_input = tf.gather(params=inputs, 38 | indices=[i for i in range(1, self.grid_size, 2)], axis=1) 39 | right_contributions_input = tf.gather(params=right_contributions_input, 40 | indices=[i for i in range(0, self.grid_size, 2)], axis=2) 41 | idx = [(i-1) % self.grid_size for i in range(0, self.grid_size, 2)] 42 | left_contributions_input = tf.gather(params=inputs, indices=idx, axis=1) 43 | left_contributions_input = tf.gather(params=left_contributions_input, 44 | indices=[i for i in range(0, self.grid_size, 2)], axis=2) 45 | left_contributions_input = tf.reshape(tensor=left_contributions_input, 46 | shape=(-1, self.grid_size//2, self.grid_size//2, 3, 3)) 47 | 48 | up_contributions_input = tf.gather(params=inputs, indices=[i for i in range(0, self.grid_size, 2)], axis=1) 49 | up_contributions_input = tf.gather(params=up_contributions_input, 50 | indices=[i for i in range(1, self.grid_size, 2)], axis=2) 51 | up_contributions_input = tf.reshape(tensor=up_contributions_input, 52 | shape=(-1, self.grid_size//2, self.grid_size//2, 3, 3)) 53 | 54 | down_contributions_input = tf.gather(params=inputs, 55 | indices=[i for i in range(0,self.grid_size,2)], axis=1) 56 | down_contributions_input = tf.gather(params=down_contributions_input, indices=idx, axis=2) 57 | down_contributions_input = tf.reshape(tensor=down_contributions_input, 58 | shape=(-1,self.grid_size//2,self.grid_size//2,3,3)) 59 | # 60 | center_contributions_input = tf.gather(params=inputs, 61 | indices=[i for i in range(0, self.grid_size, 2)], axis=1) 62 | center_contributions_input = tf.gather(params=center_contributions_input, indices=[i for i in range(0, self.grid_size, 2)], axis=2) 63 | center_contributions_input = tf.reshape(tensor=center_contributions_input, 64 | shape=(-1, self.grid_size // 2, self.grid_size // 2, 3, 3)) 65 | 66 | inputs_combined = tf.concat([right_contributions_input, left_contributions_input, 67 | up_contributions_input, down_contributions_input, 68 | center_contributions_input], 0) 69 | 70 | flattended = tf.reshape(inputs_combined, (-1, 9)) 71 | 72 | temp = (self.grid_size//2)**2 73 | 74 | # bug then augmented with doubled grid size 75 | flattended = tf.concat([flattended[:batch_size*temp], 76 | flattended[temp*batch_size:temp*2*batch_size], 77 | flattended[temp*2*batch_size:temp*3*batch_size], 78 | 
flattended[temp*3*batch_size:temp*4*batch_size], 79 | flattended[temp*4*batch_size:]],-1) 80 | 81 | x = self.linear0(flattended) 82 | x = tf.nn.relu(x) 83 | for i in range(1, self.num_layers, 2): 84 | x1 = getattr(self, "bias_1%i" % i) + x 85 | x1 = getattr(self, "linear%i" % i)(x1) 86 | x1 = x1 + getattr(self, "bias_2%i" % i) + x1 87 | x1 = tf.nn.relu(x1) 88 | x1 = x1 + getattr(self, "bias_3%i" % i) + x1 89 | x1 = getattr(self, "linear%i" % (i + 1))(x1) 90 | x1 = tf.multiply(x1, getattr(self, "multiplier_%i" % i)) 91 | x = x + x1 92 | x = x + getattr(self, "bias_4%i" % i) 93 | x = tf.nn.relu(x) 94 | 95 | x = self.output_layer(x) 96 | 97 | if index is not None: 98 | indices = tf.constant([[index]]) 99 | updates = [tf.to_double(pos)] 100 | shape = tf.constant([2*2*2*8]) 101 | scatter = tf.scatter_nd(indices,updates,shape) 102 | x = self.new_output+tf.reshape(scatter,(-1,2,2,8)) 103 | ld_contribution = x[:, :, :, 0] 104 | left_contributions_output = x[:, :, :, 1] 105 | lu_contribution = x[:, :, :, 2] 106 | down_contributions_output = x[:, :, :, 3] 107 | up_contributions_output = x[:, :, :, 4] 108 | ones = tf.ones_like(up_contributions_output) 109 | right_contributions_output = x[:, :, :, 6] 110 | rd_contribution = x[:, :, :, 5] 111 | ru_contribution = x[:, :, :, 7] 112 | first_row = tf.concat( 113 | [tf.expand_dims(ld_contribution, -1), tf.expand_dims(left_contributions_output, -1), 114 | tf.expand_dims(lu_contribution, -1)], -1) 115 | second_row = tf.concat([tf.expand_dims(down_contributions_output, -1), 116 | tf.expand_dims(ones, -1), tf.expand_dims(up_contributions_output, -1)], -1) 117 | third_row = tf.concat( 118 | [tf.expand_dims(rd_contribution, -1), tf.expand_dims(right_contributions_output, -1), 119 | tf.expand_dims(ru_contribution, -1)], -1) 120 | 121 | output = tf.stack([first_row, second_row, third_row], 0) 122 | output = tf.transpose(output, (1, 2, 3, 0, 4)) 123 | if not black_box: 124 | return tf.to_complex128(output) 125 | else: 126 | x = tf.reshape(x, (-1, self.grid_size//2, self.grid_size//2,4)) 127 | if black_box: 128 | up_contributions_output = tf.gather(inputs,[i for i in range(0,self.grid_size,2)],axis=1) 129 | up_contributions_output = tf.gather(up_contributions_output, 130 | [i for i in range(1,self.grid_size,2)], axis=2) 131 | up_contributions_output = -tf.reduce_sum(up_contributions_output[:,:,:,:,0],axis=-1)/tf.reduce_sum(up_contributions_output[:,:,:,:,1],axis=-1) 132 | 133 | left_contributions_output = tf.gather(inputs, idx, axis=1) 134 | left_contributions_output = tf.gather(left_contributions_output, 135 | [i for i in range(0,self.grid_size,2)], axis=2) 136 | left_contributions_output = -tf.reduce_sum(left_contributions_output[:, :, :, 2, :], 137 | axis=-1) / tf.reduce_sum( 138 | left_contributions_output[:, :, :, 1, :], axis=-1) 139 | 140 | right_contributions_output = tf.gather(inputs, [i for i in range(1,self.grid_size,2)], axis=1) 141 | right_contributions_output = tf.gather(right_contributions_output, [i for i in range(0,self.grid_size,2)], axis=2) 142 | right_contributions_output = -tf.reduce_sum(right_contributions_output[:, :, :, 0, :], 143 | axis=-1) / tf.reduce_sum( 144 | right_contributions_output[:, :, :, 1, :], axis=-1) 145 | down_contributions_output = tf.gather(inputs, [i for i in range(0,self.grid_size,2)], axis=1) 146 | down_contributions_output = tf.gather(down_contributions_output, idx, axis=2) 147 | down_contributions_output = -tf.reduce_sum(down_contributions_output[:, :, :, :, 2], 148 | axis=-1) / tf.reduce_sum( 149 | 
down_contributions_output[:, :, :, :, 1], axis=-1) 150 | else: 151 | jm1 = [(i - 1) % (self.grid_size // 2) for i in range(self.grid_size // 2)] 152 | jp1 = [(i + 1) % (self.grid_size // 2) for i in range(self.grid_size // 2)] 153 | right_contributions_output = x[:,:,:,0]/(tf.gather(x[:, :, :, 1],jp1,axis=1)+x[:,:,:,0]) 154 | left_contributions_output = x[:,:,:,1]/(x[:,:,:,1]+tf.gather(x[:, :, :, 0],jm1,axis=1)) 155 | up_contributions_output = x[:,:,:,2]/(x[:,:,:,2]+tf.gather(x[:, :, :, 3],jp1,axis=2)) 156 | down_contributions_output = x[:,:,:,3]/(tf.gather(x[:, :, :, 2],jm1,axis=2)+x[:,:,:,3]) 157 | ones = tf.ones_like(down_contributions_output) 158 | 159 | #based on rule 2 given rule 1: 160 | up_right_contribution = tf.gather(inputs,[i for i in range(1,self.grid_size,2)],axis=1) 161 | up_right_contribution = tf.gather(up_right_contribution, [i for i in range(1,self.grid_size,2)], axis=2) 162 | up_right_contribution = up_right_contribution [:,:,:,0,1] 163 | right_up_contirbution = tf.gather(inputs, [i for i in range(1,self.grid_size,2)], axis=1) 164 | right_up_contirbution = tf.gather(right_up_contirbution, [i for i in range(1,self.grid_size,2)], axis=2) 165 | right_up_contirbution_additional_term = right_up_contirbution[:, :, :, 0, 0] 166 | right_up_contirbution = right_up_contirbution[:,:,:,1,0] 167 | ru_center_ = tf.gather(inputs, [i for i in range(1,self.grid_size,2)], axis=1) 168 | ru_center_ = tf.gather(ru_center_, [i for i in range(1,self.grid_size,2)], axis=2) 169 | ru_center_ = ru_center_[:,:,:,1,1] 170 | ru_contribution = -tf.expand_dims((right_up_contirbution_additional_term+ 171 | tf.multiply(right_up_contirbution,right_contributions_output) +\ 172 | tf.multiply(up_right_contribution,up_contributions_output))/ru_center_, -1) 173 | 174 | up_left_contribution = tf.gather(inputs, idx, axis=1) 175 | up_left_contribution = tf.gather(up_left_contribution, [i for i in range(1,self.grid_size,2)], axis=2) 176 | up_left_contribution = up_left_contribution[:, :, :, 2, 1] 177 | left_up_contirbution = tf.gather(inputs, idx, axis=1) 178 | left_up_contirbution = tf.gather(left_up_contirbution, [i for i in range(1,self.grid_size,2)], axis=2) 179 | left_up_contirbution_addtional_term = left_up_contirbution[:, :, :, 2, 0] 180 | left_up_contirbution = left_up_contirbution[:, :, :, 1, 0] 181 | lu_center_ = tf.gather(inputs, idx, axis=1) 182 | lu_center_ = tf.gather(lu_center_, [i for i in range(1,self.grid_size,2)], axis=2) 183 | lu_center_ = lu_center_[:, :, :, 1, 1] 184 | lu_contribution = -tf.expand_dims((left_up_contirbution_addtional_term+ 185 | tf.multiply(up_left_contribution , up_contributions_output) + \ 186 | tf.multiply(left_up_contirbution , left_contributions_output)) / lu_center_, -1) 187 | 188 | down_left_contribution = tf.gather(inputs, idx, axis=1) 189 | down_left_contribution = tf.gather(down_left_contribution, idx, axis=2) 190 | down_left_contribution = down_left_contribution[:, :, :, 2, 1] 191 | left_down_contirbution = tf.gather(inputs, idx, axis=1) 192 | left_down_contirbution = tf.gather(left_down_contirbution, idx, axis=2) 193 | left_down_contirbution_additional_term = left_down_contirbution[:, :, :, 2, 2] 194 | left_down_contirbution = left_down_contirbution[:, :, :, 1, 2] 195 | ld_center_ = tf.gather(inputs, idx, axis=1) 196 | ld_center_ = tf.gather(ld_center_, idx, axis=2) 197 | ld_center_ = ld_center_[:, :, :, 1, 1] 198 | ld_contribution = -tf.expand_dims((left_down_contirbution_additional_term+ 199 | tf.multiply(down_left_contribution , down_contributions_output) + \ 
200 | tf.multiply(left_down_contirbution , left_contributions_output)) / ld_center_,-1) 201 | 202 | down_right_contribution = tf.gather(inputs, [i for i in range(1,self.grid_size,2)], axis=1) 203 | down_right_contribution = tf.gather(down_right_contribution, idx, axis=2) 204 | down_right_contribution = down_right_contribution[:, :, :, 0, 1] 205 | right_down_contirbution = tf.gather(inputs, [i for i in range(1,self.grid_size,2)], axis=1) 206 | right_down_contirbution = tf.gather(right_down_contirbution, idx, axis=2) 207 | right_down_contirbution_addtional_term = right_down_contirbution[:, :, :, 0, 2] 208 | right_down_contirbution = right_down_contirbution[:, :, :, 1, 2] 209 | rd_center_ = tf.gather(inputs, [i for i in range(1,self.grid_size,2)], axis=1) 210 | rd_center_ = tf.gather(rd_center_, idx, axis=2) 211 | rd_center_ = rd_center_[:, :, :, 1, 1] 212 | rd_contribution = -tf.expand_dims((right_down_contirbution_addtional_term+tf.multiply(down_right_contribution , down_contributions_output) + \ 213 | tf.multiply(right_down_contirbution , right_contributions_output)) / rd_center_,-1) 214 | 215 | first_row = tf.concat([ld_contribution, tf.expand_dims(left_contributions_output,-1), 216 | lu_contribution], -1) 217 | second_row = tf.concat([tf.expand_dims(down_contributions_output,-1), 218 | tf.expand_dims(ones, -1), tf.expand_dims(up_contributions_output, -1)], -1) 219 | third_row = tf.concat([rd_contribution, tf.expand_dims(right_contributions_output, -1), 220 | ru_contribution], -1) 221 | 222 | output = tf.stack([first_row, second_row, third_row], 0) 223 | output = tf.transpose(output, (1, 2, 3, 0, 4)) 224 | 225 | return tf.to_complex128(output) -------------------------------------------------------------------------------- /model_dirichletBC.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | class Pnetwork(tf.keras.Model): 8 | def __init__(self, grid_size=8, device="/cpu:0"): 9 | super(Pnetwork, self).__init__() 10 | self.grid_size = grid_size 11 | self.device = device 12 | with tf.device(self.device): 13 | self.num_layers = 100 14 | self.linear0 = tf.keras.layers.Dense(100, kernel_regularizer=tf.keras.regularizers.l2(1e-7), 15 | use_bias=False) 16 | for i in range(1, self.num_layers): 17 | setattr(self, "linear%i" % i, tf.keras.layers.Dense(100, use_bias=False, 18 | kernel_regularizer=tf.keras.regularizers.l2(1e-7), 19 | kernel_initializer=tf.initializers.truncated_normal( 20 | stddev=i ** (-1 / 2) * np.sqrt(2. 
/ 100)))) 21 | setattr(self, "bias_1%i" % i, tf.Variable([0.], dtype=tf.float64)) 22 | setattr(self, "linear%i" % (i + 1), tf.keras.layers.Dense(100, use_bias=False, 23 | kernel_regularizer=tf.keras.regularizers.l2( 24 | 1e-7), 25 | kernel_initializer=tf.zeros_initializer())) 26 | setattr(self, "bias_2%i" % i, tf.Variable([0.], dtype=tf.float64)) 27 | setattr(self, "bias_3%i" % i, tf.Variable([0.], dtype=tf.float64)) 28 | setattr(self, "bias_4%i" % i, tf.Variable([0.], dtype=tf.float64)) 29 | setattr(self, "multiplier_%i" % i, tf.Variable([1.], dtype=tf.float64)) 30 | 31 | self.output_layer = tf.keras.layers.Dense(4, use_bias=True, 32 | kernel_regularizer=tf.keras.regularizers.l2(1e-5)) 33 | self.new_output = tf.Variable(0.5 * tf.random_normal(shape=[2 * 2 * 2 * 8], dtype=tf.float64), 34 | dtype=tf.float64) 35 | 36 | # print(len(self.layers)) 37 | 38 | def call(self, inputs, black_box=False, index=None, pos=-1., phase='Training'): 39 | with tf.device(self.device): 40 | batch_size = inputs.shape[0] 41 | right_contributions_input = tf.gather(params=inputs, 42 | indices=[i for i in range(2, self.grid_size, 2)], axis=1) 43 | right_contributions_input = tf.gather(params=right_contributions_input, 44 | indices=[i for i in range(1, self.grid_size, 2)], axis=2) 45 | idx = [i for i in range(0, self.grid_size - 1, 2)] 46 | left_contributions_input = tf.gather(params=inputs, indices=idx, axis=1) 47 | left_contributions_input = tf.gather(params=left_contributions_input, 48 | indices=[i for i in range(1, self.grid_size, 2)], axis=2) 49 | left_contributions_input = tf.reshape(tensor=left_contributions_input, 50 | shape=(-1, self.grid_size // 2, self.grid_size // 2, 3, 3)) 51 | 52 | up_contributions_input = tf.gather(params=inputs, indices=[i for i in range(1, self.grid_size, 2)], axis=1) 53 | up_contributions_input = tf.gather(params=up_contributions_input, 54 | indices=[i for i in range(2, self.grid_size, 2)], axis=2) 55 | up_contributions_input = tf.reshape(tensor=up_contributions_input, 56 | shape=(-1, self.grid_size // 2, self.grid_size // 2, 3, 3)) 57 | 58 | down_contributions_input = tf.gather(params=inputs, 59 | indices=[i for i in range(1, self.grid_size, 2)], axis=1) 60 | down_contributions_input = tf.gather(params=down_contributions_input, indices=idx, axis=2) 61 | down_contributions_input = tf.reshape(tensor=down_contributions_input, 62 | shape=(-1, self.grid_size // 2, self.grid_size // 2, 3, 3)) 63 | # 64 | center_contributions_input = tf.gather(params=inputs, 65 | indices=[i for i in range(1, self.grid_size, 2)], axis=1) 66 | center_contributions_input = tf.gather(params=center_contributions_input, 67 | indices=[i for i in range(1, self.grid_size, 2)], 68 | axis=2) 69 | center_contributions_input = tf.reshape(tensor=center_contributions_input, 70 | shape=(-1, self.grid_size // 2, self.grid_size // 2, 3, 3)) 71 | 72 | inputs_combined = tf.concat([right_contributions_input, left_contributions_input, 73 | up_contributions_input, down_contributions_input, 74 | center_contributions_input], 0) 75 | 76 | flattended = tf.reshape(inputs_combined, (-1, 9)) 77 | temp = (self.grid_size // 2) ** 2 78 | 79 | flattended = tf.concat([flattended[:batch_size * temp], 80 | flattended[temp * batch_size:temp * 2 * batch_size], 81 | flattended[temp * 2 * batch_size:temp * 3 * batch_size], 82 | flattended[temp * 3 * batch_size:temp * 4 * batch_size], 83 | flattended[temp * 4 * batch_size:]], -1) 84 | 85 | if not black_box: 86 | x = self.linear0(flattended) 87 | x = tf.nn.relu(x) 88 | for i in range(1, 
self.num_layers, 2): 89 | x1 = getattr(self, "bias_1%i" % i) + x 90 | x1 = getattr(self, "linear%i" % i)(x1) 91 | x1 = x1 + getattr(self, "bias_2%i" % i) + x1 92 | x1 = tf.nn.relu(x1) 93 | x1 = x1 + getattr(self, "bias_3%i" % i) + x1 94 | x1 = getattr(self, "linear%i" % (i + 1))(x1) 95 | x1 = tf.multiply(x1, getattr(self, "multiplier_%i" % i)) 96 | x = x + x1 97 | x = x + getattr(self, "bias_4%i" % i) 98 | x = tf.nn.relu(x) 99 | 100 | x = self.output_layer(x) 101 | 102 | if index is not None: 103 | indices = tf.constant([[index]]) 104 | updates = [tf.to_double(pos)] 105 | shape = tf.constant([2 * 2 * 2 * 8]) 106 | scatter = tf.scatter_nd(indices, updates, shape) 107 | x = self.new_output + tf.reshape(scatter, (-1, 2, 2, 8)) 108 | ld_contribution = x[:, :, :, 0] 109 | left_contributions_output = x[:, :, :, 1] 110 | lu_contribution = x[:, :, :, 2] 111 | down_contributions_output = x[:, :, :, 3] 112 | up_contributions_output = x[:, :, :, 4] 113 | ones = tf.ones_like(up_contributions_output) 114 | right_contributions_output = x[:, :, :, 6] 115 | rd_contribution = x[:, :, :, 5] 116 | ru_contribution = x[:, :, :, 7] 117 | first_row = tf.concat( 118 | [tf.expand_dims(ld_contribution, -1), tf.expand_dims(left_contributions_output, -1), 119 | tf.expand_dims(lu_contribution, -1)], -1) 120 | second_row = tf.concat([tf.expand_dims(down_contributions_output, -1), 121 | tf.expand_dims(ones, -1), tf.expand_dims(up_contributions_output, -1)], -1) 122 | third_row = tf.concat( 123 | [tf.expand_dims(rd_contribution, -1), tf.expand_dims(right_contributions_output, -1), 124 | tf.expand_dims(ru_contribution, -1)], -1) 125 | 126 | output = tf.stack([first_row, second_row, third_row], 0) 127 | output = tf.transpose(output, (1, 2, 3, 0, 4)) 128 | if not black_box: 129 | return tf.to_complex128(output) 130 | else: 131 | if not black_box: 132 | x = tf.reshape(x, (-1, self.grid_size // 2, self.grid_size // 2, 4)) 133 | if black_box: 134 | up_contributions_output = tf.gather(inputs, [i for i in range(1, self.grid_size, 2)], axis=1) 135 | up_contributions_output = tf.gather(up_contributions_output, 136 | [i for i in range(2, self.grid_size, 2)], axis=2) 137 | up_contributions_output = -tf.reduce_sum(up_contributions_output[:, :, :, :, 0], 138 | axis=-1) / tf.reduce_sum( 139 | up_contributions_output[:, :, :, :, 1], axis=-1) 140 | left_contributions_output = tf.gather(inputs, idx, axis=1) 141 | left_contributions_output = tf.gather(left_contributions_output, 142 | [i for i in range(1, self.grid_size, 2)], axis=2) 143 | left_contributions_output = -tf.reduce_sum(left_contributions_output[:, :, :, 2, :], 144 | axis=-1) / tf.reduce_sum( 145 | left_contributions_output[:, :, :, 1, :], axis=-1) 146 | right_contributions_output = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 147 | right_contributions_output = tf.gather(right_contributions_output, 148 | [i for i in range(1, self.grid_size, 2)], axis=2) 149 | right_contributions_output = -tf.reduce_sum(right_contributions_output[:, :, :, 0, :], 150 | axis=-1) / tf.reduce_sum( 151 | right_contributions_output[:, :, :, 1, :], axis=-1) 152 | 153 | down_contributions_output = tf.gather(inputs, [i for i in range(1, self.grid_size, 2)], axis=1) 154 | down_contributions_output = tf.gather(down_contributions_output, idx, axis=2) 155 | down_contributions_output = -tf.reduce_sum(down_contributions_output[:, :, :, :, 2], 156 | axis=-1) / tf.reduce_sum( 157 | down_contributions_output[:, :, :, :, 1], axis=-1) 158 | else: 159 | jm1 = [(i - 0) % (self.grid_size // 2) 
for i in range(self.grid_size // 2 - 1)] 160 | jp1 = [(i + 1) % (self.grid_size // 2) for i in range(self.grid_size // 2 - 1)] 161 | right_contributions_output = x[:, :-1, :, 0] / (tf.gather(x[:, :, :, 1], jp1, axis=1) + x[:, :-1, :, 0]) 162 | left_contributions_output = x[:, 1:, :, 1] / (x[:, 1:, :, 1] + tf.gather(x[:, :, :, 0], jm1, axis=1)) 163 | up_contributions_output = x[:, :, :-1, 2] / (x[:, :, :-1, 2] + tf.gather(x[:, :, :, 3], jp1, axis=2)) 164 | down_contributions_output = x[:, :, 1:, 3] / (tf.gather(x[:, :, :, 2], jm1, axis=2) + x[:, :, 1:, 3]) 165 | 166 | # complete right with black box: 167 | right_contributions_output_bb = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 168 | right_contributions_output_bb = tf.gather(right_contributions_output_bb, 169 | [i for i in range(1, self.grid_size, 2)], axis=2) 170 | right_contributions_output_bb = -tf.reduce_sum(right_contributions_output_bb[:, :, :, 0, :], 171 | axis=-1) / tf.reduce_sum( 172 | right_contributions_output_bb[:, :, :, 1, :], axis=-1) 173 | right_contributions_output_bb = tf.reshape(right_contributions_output_bb[:, -1, :], (1, 1, -1)) 174 | right_contributions_output = tf.concat([right_contributions_output, right_contributions_output_bb], 175 | axis=1) 176 | left_contributions_output_bb = tf.gather(inputs, idx, axis=1) 177 | left_contributions_output_bb = tf.gather(left_contributions_output_bb, 178 | [i for i in range(1, self.grid_size, 2)], axis=2) 179 | left_contributions_output_bb = -tf.reduce_sum(left_contributions_output_bb[:, :, :, 2, :], 180 | axis=-1) / tf.reduce_sum( 181 | left_contributions_output_bb[:, :, :, 1, :], axis=-1) 182 | left_contributions_output_bb = tf.reshape(left_contributions_output_bb[:, 0, :], (1, 1, -1)) 183 | left_contributions_output = tf.concat([left_contributions_output_bb, left_contributions_output], axis=1) 184 | up_contributions_output_bb = tf.gather(inputs, [i for i in range(1, self.grid_size, 2)], axis=1) 185 | up_contributions_output_bb = tf.gather(up_contributions_output_bb, 186 | [i for i in range(2, self.grid_size, 2)], axis=2) 187 | up_contributions_output_bb = -tf.reduce_sum(up_contributions_output_bb[:, :, :, :, 0], 188 | axis=-1) / tf.reduce_sum( 189 | up_contributions_output_bb[:, :, :, :, 1], axis=-1) 190 | up_contributions_output_bb = tf.reshape(up_contributions_output_bb[:, :, -1], (1, -1, 1)) 191 | up_contributions_output = tf.concat([up_contributions_output, up_contributions_output_bb], axis=-1) 192 | down_contributions_output_bb = tf.gather(inputs, [i for i in range(1, self.grid_size, 2)], axis=1) 193 | down_contributions_output_bb = tf.gather(down_contributions_output_bb, idx, axis=2) 194 | down_contributions_output_bb = -tf.reduce_sum(down_contributions_output_bb[:, :, :, :, 2], 195 | axis=-1) / tf.reduce_sum( 196 | down_contributions_output_bb[:, :, :, :, 1], axis=-1) 197 | down_contributions_output_bb = tf.reshape(down_contributions_output_bb[:, :, 0], (1, -1, 1)) 198 | down_contributions_output = tf.concat([down_contributions_output_bb, down_contributions_output], 199 | axis=-1) 200 | ones = tf.ones_like(down_contributions_output) 201 | idx = [i for i in range(0, self.grid_size - 1, 2)] 202 | 203 | # based on rule 2 given rule 1: 204 | # x,y = np.ix_([3, 1], [1, 3]) 205 | up_right_contribution = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 206 | up_right_contribution = tf.gather(up_right_contribution, [i for i in range(2, self.grid_size, 2)], axis=2) 207 | up_right_contribution = up_right_contribution[:, :, :, 0, 1] 208 | 
right_up_contirbution = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 209 | right_up_contirbution = tf.gather(right_up_contirbution, [i for i in range(2, self.grid_size, 2)], axis=2) 210 | right_up_contirbution_additional_term = right_up_contirbution[:, :, :, 0, 0] 211 | right_up_contirbution = right_up_contirbution[:, :, :, 1, 0] 212 | ru_center_ = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 213 | ru_center_ = tf.gather(ru_center_, [i for i in range(2, self.grid_size, 2)], axis=2) 214 | ru_center_ = ru_center_[:, :, :, 1, 1] 215 | ru_contribution = -tf.expand_dims((right_up_contirbution_additional_term + 216 | tf.multiply(right_up_contirbution, right_contributions_output) + \ 217 | tf.multiply(up_right_contribution, 218 | up_contributions_output)) / ru_center_, -1) 219 | 220 | # x,y = np.ix_([3, 1], [3, 1]) 221 | up_left_contribution = tf.gather(inputs, idx, axis=1) 222 | up_left_contribution = tf.gather(up_left_contribution, [i for i in range(2, self.grid_size, 2)], axis=2) 223 | up_left_contribution = up_left_contribution[:, :, :, 2, 1] 224 | left_up_contirbution = tf.gather(inputs, idx, axis=1) 225 | left_up_contirbution = tf.gather(left_up_contirbution, [i for i in range(2, self.grid_size, 2)], axis=2) 226 | left_up_contirbution_addtional_term = left_up_contirbution[:, :, :, 2, 0] 227 | left_up_contirbution = left_up_contirbution[:, :, :, 1, 0] 228 | lu_center_ = tf.gather(inputs, idx, axis=1) 229 | lu_center_ = tf.gather(lu_center_, [i for i in range(2, self.grid_size, 2)], axis=2) 230 | lu_center_ = lu_center_[:, :, :, 1, 1] 231 | lu_contribution = -tf.expand_dims((left_up_contirbution_addtional_term + 232 | tf.multiply(up_left_contribution, up_contributions_output) + \ 233 | tf.multiply(left_up_contirbution, 234 | left_contributions_output)) / lu_center_, -1) 235 | 236 | # x,y = np.ix_([1, 3], [3, 1]) 237 | down_left_contribution = tf.gather(inputs, idx, axis=1) 238 | down_left_contribution = tf.gather(down_left_contribution, idx, axis=2) 239 | down_left_contribution = down_left_contribution[:, :, :, 2, 1] 240 | left_down_contirbution = tf.gather(inputs, idx, axis=1) 241 | left_down_contirbution = tf.gather(left_down_contirbution, idx, axis=2) 242 | left_down_contirbution_additional_term = left_down_contirbution[:, :, :, 2, 2] 243 | left_down_contirbution = left_down_contirbution[:, :, :, 1, 2] 244 | ld_center_ = tf.gather(inputs, idx, axis=1) 245 | ld_center_ = tf.gather(ld_center_, idx, axis=2) 246 | ld_center_ = ld_center_[:, :, :, 1, 1] 247 | ld_contribution = -tf.expand_dims((left_down_contirbution_additional_term + 248 | tf.multiply(down_left_contribution, down_contributions_output) + \ 249 | tf.multiply(left_down_contirbution, 250 | left_contributions_output)) / ld_center_, -1) 251 | 252 | # x,y = np.ix_([1, 3], [1, 3]) 253 | down_right_contribution = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 254 | down_right_contribution = tf.gather(down_right_contribution, idx, axis=2) 255 | down_right_contribution = down_right_contribution[:, :, :, 0, 1] 256 | right_down_contirbution = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 257 | right_down_contirbution = tf.gather(right_down_contirbution, idx, axis=2) 258 | right_down_contirbution_addtional_term = right_down_contirbution[:, :, :, 0, 2] 259 | right_down_contirbution = right_down_contirbution[:, :, :, 1, 2] 260 | rd_center_ = tf.gather(inputs, [i for i in range(2, self.grid_size, 2)], axis=1) 261 | rd_center_ = tf.gather(rd_center_, idx, 
axis=2) 262 | rd_center_ = rd_center_[:, :, :, 1, 1] 263 | rd_contribution = -tf.expand_dims((right_down_contirbution_addtional_term + tf.multiply( 264 | down_right_contribution, down_contributions_output) + \ 265 | tf.multiply(right_down_contirbution, 266 | right_contributions_output)) / rd_center_, -1) 267 | 268 | first_row = tf.concat([ld_contribution, tf.expand_dims(left_contributions_output, -1), 269 | lu_contribution], -1) 270 | second_row = tf.concat([tf.expand_dims(down_contributions_output, -1), 271 | tf.expand_dims(ones, -1), tf.expand_dims(up_contributions_output, -1)], -1) 272 | third_row = tf.concat([rd_contribution, tf.expand_dims(right_contributions_output, -1), 273 | ru_contribution], -1) 274 | 275 | output = tf.stack([first_row, second_row, third_row], 0) 276 | output = tf.transpose(output, (1, 2, 3, 0, 4)) 277 | return tf.to_complex128(output) 278 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import csr_matrix 3 | import scipy.sparse.linalg 4 | import tensorflow as tf 5 | import scipy 6 | import math 7 | from functools import partial 8 | from tqdm import tqdm 9 | import pyamg 10 | from geometric_solver import geometric_solver 11 | 12 | 13 | class memoize(object): 14 | """cache the return value of a method 15 | 16 | This class is meant to be used as a decorator of methods. The return value 17 | from a given method invocation will be cached on the instance whose method 18 | was invoked. All arguments passed to a method decorated with memoize must 19 | be hashable. 20 | 21 | If a memoized method is invoked directly on its class the result will not 22 | be cached. Instead the method will be invoked like a static method: 23 | class Obj(object): 24 | @memoize 25 | def add_to(self, arg): 26 | return self + arg 27 | Obj.add_to(1) # not enough arguments 28 | Obj.add_to(1, 2) # returns 3, result is not cached 29 | """ 30 | 31 | def __init__(self, func): 32 | self.func = func 33 | 34 | def __get__(self, obj, objtype=None): 35 | if obj is None: 36 | return self.func 37 | return partial(self, obj) 38 | 39 | def __call__(self, *args, **kw): 40 | obj = args[0] 41 | try: 42 | cache = obj.__cache 43 | except AttributeError: 44 | cache = obj.__cache = {} 45 | key = (self.func, args[1:], frozenset(kw.items())) 46 | try: 47 | res = cache[key] 48 | except KeyError: 49 | res = cache[key] = self.func(*args, **kw) 50 | return res 51 | 52 | 53 | class Utils(object): 54 | def __init__(self, grid_size=8, device="/cpu:0", bc='dirichlet'): 55 | self.grid_size = grid_size 56 | self.device = device 57 | self.bc = bc 58 | 59 | @staticmethod 60 | def two_d_stencil_dirichletBC(num, epsilon, grid_size=8): 61 | epsi = epsilon * np.ones((grid_size, grid_size)) 62 | stencil = np.zeros((num, grid_size, grid_size, 3, 3)) 63 | 64 | diffusion_coeff = np.exp(np.random.normal(size=[num, grid_size, grid_size])) 65 | 66 | jm1 = [(i - 1) % grid_size for i in range(grid_size)] 67 | stencil[:, :, :, 1, 2] = -1. / 6 * (diffusion_coeff[:, jm1] + diffusion_coeff) 68 | stencil[:, :, :, 2, 1] = -1. / 6 * (diffusion_coeff + diffusion_coeff[:, :, jm1]) 69 | stencil[:, :, :, 2, 0] = -1. / 3 * diffusion_coeff[:, :, jm1] 70 | stencil[:, :, :, 2, 2] = -1. 
/ 3 * diffusion_coeff
71 | 
72 |         jp1 = [(i + 1) % grid_size for i in range(grid_size)]
73 | 
74 |         stencil[:, :, :, 1, 0] = stencil[:, :, jm1, 1, 2]
75 |         stencil[:, :, :, 0, 0] = stencil[:, jm1][:, :, jm1][:, :, :, 2, 2]
76 |         stencil[:, :, :, 0, 1] = stencil[:, jm1][:, :, :, 2, 1]
77 |         stencil[:, :, :, 0, 2] = stencil[:, jm1][:, :, jp1][:, :, :, 2, 0]
78 |         stencil[:, :, :, 1, 1] = -np.sum(np.sum(stencil, axis=4), axis=3) + epsi
79 | 
80 |         stencil[:, :, 0, :, 0] = 0.
81 |         stencil[:, :, -1, :, -1] = 0.
82 |         stencil[:, 0, :, 0, :] = 0.
83 |         stencil[:, -1, :, -1, :] = 0.
84 |         return stencil
85 | 
86 |     @staticmethod
87 |     def two_d_stencil_periodicBC(num, epsilon=0.0, epsilon_sparse=False, grid_size=8):
88 |         # creates the discretization stencils of 2D diffusion
89 |         # problems whose coefficients are drawn from a log-normal distribution.
90 | 
91 |         if not epsilon_sparse:
92 |             epsi = epsilon * np.ones((grid_size, grid_size))
93 |         else:  # a single epsilon value for each grid, to simulate boundaries
94 |             epsilon_coord = np.random.randint(grid_size, size=num)
95 |             epsi = np.zeros((num, grid_size, grid_size))
96 |             for i in range(num):
97 |                 epsi[i, epsilon_coord[i], epsilon_coord[i]] = np.exp(np.random.normal(loc=0.0, scale=1.0))
98 |         stencil = np.zeros((num, grid_size, grid_size, 3, 3))
99 |         diffusion_coeff = np.exp(np.random.normal(loc=0.0, scale=1.0, size=[num, grid_size, grid_size]))
100 | 
101 |         # lists of plus/minus 1 coordinates, modulo grid size
102 |         jm1 = [(i - 1) % grid_size for i in range(grid_size)]
103 |         jp1 = [(i + 1) % grid_size for i in range(grid_size)]
104 | 
105 |         stencil[:, :, :, 1, 2] = -1. / 6 * (diffusion_coeff[:, jm1] + diffusion_coeff)
106 |         stencil[:, :, :, 2, 1] = -1. / 6 * (diffusion_coeff + diffusion_coeff[:, :, jm1])
107 |         stencil[:, :, :, 2, 0] = -1. / 3 * diffusion_coeff[:, :, jm1]
108 |         stencil[:, :, :, 2, 2] = -1. / 3 * diffusion_coeff
109 |         stencil[:, :, :, 1, 0] = stencil[:, :, jm1, 1, 2]
110 |         stencil[:, :, :, 0, 0] = stencil[:, jm1][:, :, jm1][:, :, :, 2, 2]
111 |         stencil[:, :, :, 0, 1] = stencil[:, jm1][:, :, :, 2, 1]
112 |         stencil[:, :, :, 0, 2] = stencil[:, jm1][:, :, jp1][:, :, :, 2, 0]
113 |         stencil[:, :, :, 1, 1] = -np.sum(np.sum(stencil, axis=4), axis=3) + epsi
114 |         return stencil
115 | 
116 |     def two_d_stencil(self, num, epsilon=0.0, epsilon_sparse=False, grid_size=8):
117 |         if self.bc == 'dirichlet':
118 |             return self.two_d_stencil_dirichletBC(num=num, epsilon=epsilon, grid_size=grid_size)
119 |         elif self.bc == 'periodic':
120 |             return self.two_d_stencil_periodicBC(num=num, epsilon=epsilon, epsilon_sparse=epsilon_sparse, grid_size=grid_size)
121 | 
122 |     def map_2_to_1(self, grid_size=8):
123 |         # maps 2D coordinates to the corresponding 1D coordinate in the matrix.
124 |         k = np.zeros((grid_size, grid_size, 3, 3))
125 |         M = np.reshape(np.arange(grid_size ** 2), (grid_size, grid_size)).T
126 |         M = np.concatenate([M, M], 0)
127 |         M = np.concatenate([M, M], 1)
128 |         for i in range(3):
129 |             I = (i - 1) % grid_size
130 |             for j in range(3):
131 |                 J = (j - 1) % grid_size
132 |                 k[:, :, i, j] = M[I:I + grid_size, J:J + grid_size]
133 |         return k
134 | 
135 |     def compute_S(self, A):
136 |         # computes the iteration matrix of the relaxation; here Gauss-Seidel is used.
137 |         # This function is called on each block separately.
138 |         n = A.shape[-1]
139 |         B = np.copy(A)
140 |         B[:, np.tril_indices(n, 0)[0], np.tril_indices(n, 0)[1]] = 0. 
135 |     def compute_S(self, A):
136 |         # computes the iteration matrix S = -(D+L)^{-1} U of the relaxation; here Gauss-Seidel is used.
137 |         # This function is called on each block separately.
138 |         n = A.shape[-1]
139 |         B = np.copy(A)
140 |         B[:, np.tril_indices(n, 0)[0], np.tril_indices(n, 0)[1]] = 0.  # B is the strictly upper triangular part U of A
141 |         res = []
142 |         for i in range(A.shape[0]):
143 |             res.append(scipy.linalg.solve_triangular(a=A[i],
144 |                                                      b=-B[i],
145 |                                                      lower=True, unit_diagonal=False,
146 |                                                      overwrite_b=False, check_finite=True))
147 |         return np.stack(res, 0)
148 | 
149 |     def compute_A(self, stencils, tx, ty, ci, grid_size=8):
150 |         # compute the diagonal block of the discretization matrix that corresponds
151 |         # to the (tx, ty) Fourier mode, using Theorem 1.
152 |         K = self.map_2_to_1(grid_size=grid_size)
153 |         batch_size = stencils.shape[0]
154 |         A = np.zeros((batch_size, grid_size ** 2, grid_size ** 2), dtype=np.complex128)
155 |         X, Y = np.meshgrid(np.arange(-1, 2), np.arange(-1, 2))
156 |         fourier_component = np.exp(-ci * (tx * X + ty * Y))
157 |         fourier_component = np.reshape(fourier_component, (1, 3, 3))
158 | 
159 |         for i in range(grid_size):
160 |             for j in range(grid_size):
161 |                 I = int(K[i, j, 1, 1])
162 |                 for k in range(3):
163 |                     for m in range(3):
164 |                         J = int(K[i, j, k, m])
165 |                         A[:, I, J] = stencils[:, i, j, k, m] * fourier_component[:, k, m]
166 |         return A
167 | 
168 |     @memoize
169 |     def compute_A_indices(self, grid_size):
170 |         K = self.map_2_to_1(grid_size=grid_size)
171 |         A_idx = []
172 |         stencil_idx = []
173 |         for i in range(grid_size):
174 |             for j in range(grid_size):
175 |                 I = int(K[i, j, 1, 1])
176 |                 for k in range(3):
177 |                     for m in range(3):
178 |                         J = int(K[i, j, k, m])
179 |                         A_idx.append([I, J])
180 |                         stencil_idx.append([i, j, k, m])
181 |         return np.array(A_idx), stencil_idx
182 | 
183 |     def compute_csr_matrices(self, stencils, grid_size=8):
184 |         A_idx, stencil_idx = self.compute_A_indices(grid_size)
185 |         if len(stencils.shape) == 5:  # a batch of stencils
186 |             matrices = []
187 |             for stencil in stencils:
188 |                 matrices.append(csr_matrix(arg1=(stencil.reshape((-1)), (A_idx[:, 0], A_idx[:, 1])),
189 |                                            shape=(grid_size ** 2, grid_size ** 2)))
190 |             return np.asarray(matrices)
191 |         else:
192 |             return csr_matrix(arg1=(stencils.reshape((-1)), (A_idx[:, 0], A_idx[:, 1])),
193 |                               shape=(grid_size ** 2, grid_size ** 2))
194 | 
195 |     def compute_p2(self, P_stencil, grid_size):
196 |         indexes = self.get_p_matrix_indices_one(grid_size)
197 |         P = csr_matrix(arg1=(P_stencil.numpy().reshape(-1), (indexes[:, 0], indexes[:, 1])),
198 |                        shape=(grid_size ** 2, (grid_size // 2) ** 2))
199 | 
200 |         return P
201 | 
202 |     @memoize
203 |     def get_p_matrix_indices_one(self, grid_size):
204 |         K = self.map_2_to_1(grid_size=grid_size)
205 |         indices = []
206 |         for ic in range(grid_size // 2):
207 |             i = 2 * ic + 1
208 |             for jc in range(grid_size // 2):
209 |                 j = 2 * jc + 1
210 |                 J = int(grid_size // 2 * jc + ic)
211 |                 for k in range(3):
212 |                     for m in range(3):
213 |                         I = int(K[i, j, k, m])
214 |                         indices.append([I, J])
215 | 
216 |         return np.array(indices)
217 | 
218 |     def get_A_S_matrices(self, num: int, pi: float, grid_size: int, n_size: int):
219 |         """
220 |         :param num: number of samples to test
221 |         :param pi: the constant pi (e.g. np.pi), passed in by the caller
222 |         :param grid_size: size of the block grid, i.e. stencils are grid_size x grid_size
223 |         :param n_size: size of the full fine grid; together with grid_size it determines the sampled Fourier modes
224 |         :return: the test stencils, the corresponding A and S matrices, and the number of Fourier modes per axis
225 |         """
226 |         theta_x = np.array(
227 |             [i * 2 * pi / n_size for i in range(-n_size // (2 * grid_size) + 1, n_size // (2 * grid_size) + 1)])
228 |         theta_y = np.array(
229 |             [i * 2 * pi / n_size for i in range(-n_size // (2 * grid_size) + 1, n_size // (2 * grid_size) + 1)])
230 |         A_stencils_test = np.array(self.two_d_stencil(num))
231 |         A_matrices_test, S_matrices_test = self.create_matrices(A_stencils_test, grid_size, theta_x, theta_y)
232 |         return A_stencils_test, A_matrices_test, S_matrices_test, len(theta_x)
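    # Usage sketch for the LFA test set above (sizes are hypothetical, not
    # taken from the repository's experiments):
    #
    #   utils = Utils(grid_size=8, bc='periodic')
    #   stencils, A_test, S_test, n_modes = utils.get_A_S_matrices(32, np.pi, 8, 32)
    #   # A_test[s, p, q] is the 64x64 block of sample s at mode (theta_y[p], theta_x[q]),
    #   # and S_test holds the matching Gauss-Seidel iteration matrices.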
233 | 
234 |     def create_matrices(self, A_stencil, grid_size, theta_x, theta_y):
235 |         A_matrices = np.stack(
236 |             [[self.compute_A(A_stencil, tx, ty, 1j, grid_size=grid_size) for tx in theta_x] for ty in theta_y])
237 |         A_matrices = A_matrices.transpose((2, 0, 1, 3, 4))  # -> (batch, n_theta_y, n_theta_x, N, N)
238 |         S_matrices = self.compute_S(A_matrices.reshape((-1, grid_size ** 2, grid_size ** 2)))
239 |         return A_matrices, S_matrices
240 | 
241 |     @memoize
242 |     def idx_array(self, x):
243 |         I, J, batch_size, num_modes = x
244 |         return np.array([[[[i1, i2, ell, I, J] for ell in range(batch_size)] for i2
245 |                           in range(num_modes)] for i1
246 |                          in range(num_modes)]).reshape(-1, 5).astype(np.int32)
247 | 
248 |     @memoize
249 |     def get_p_matrix_indices(self, x):
250 |         batch_size, grid_size = x
251 |         K = self.map_2_to_1(grid_size=grid_size)
252 |         value_indices = []
253 |         indices = []
254 |         for n in range(batch_size):
255 |             for ic in range(grid_size // 2):
256 |                 i = 2 * ic
257 |                 for jc in range(grid_size // 2):
258 |                     j = 2 * jc
259 |                     J = int(grid_size // 2 * jc + ic)
260 |                     for k in range(3):
261 |                         for m in range(3):
262 |                             I = int(K[i, j, k, m])
263 |                             value_indices.append([n, ic, jc, k, m])
264 |                             indices.append([n, I, J])
265 |         return indices, value_indices
266 | 
267 |     def compute_p2_sparse(self, P_stencil, n, grid_size):
268 |         with tf.device(self.device):
269 |             indexes, values_indices = self.get_p_matrix_indices((n, grid_size))
270 |             P = tf.SparseTensor(indices=indexes, values=tf.gather_nd(P_stencil, values_indices),
271 |                                 dense_shape=(n, grid_size ** 2, (grid_size // 2) ** 2))
272 |         return P
273 | 
274 |     def compute_sparse_matrix(self, stencils, batch_size, grid_size):
275 |         with tf.device(self.device):
276 |             indexes, values_indices = self.get_indices_compute_A((batch_size, grid_size))
277 |             tau = tf.SparseTensor(indices=indexes,
278 |                                   values=tf.gather_nd(params=stencils, indices=values_indices),
279 |                                   dense_shape=(batch_size, grid_size ** 2, grid_size ** 2))
280 |         return tau
281 | 
282 |     def compute_dense_matrix(self, stencils, batch_size, grid_size):
283 |         with tf.device(self.device):
284 |             indexes, values_indices = self.get_indices_compute_A((batch_size, grid_size))
285 |             tau = tf.scatter_nd(indices=indexes,
286 |                                 updates=tf.gather_nd(params=stencils, indices=values_indices),
287 |                                 shape=(batch_size, grid_size ** 2, grid_size ** 2))
288 |         return tau
289 | 
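    # The two assemblers above are interchangeable representations of the
    # same batched operator (sketch; assumes the scatter indices are
    # distinct, which get_indices_compute_A yields for grid_size >= 3):
    #
    #   A_sparse = utils.compute_sparse_matrix(stencils, batch_size, grid_size)
    #   A_dense = utils.compute_dense_matrix(stencils, batch_size, grid_size)
    #   # tf.sparse_tensor_to_dense(A_sparse) equals A_dense entrywise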
290 |     def compute_p2LFA(self, P_stencil, n, grid_size):
291 |         batch_size = P_stencil.get_shape().as_list()[0]
292 |         K = self.map_2_to_1(grid_size=grid_size)
293 |         pi = np.pi
294 |         theta_x = np.array([i * 2 * pi / n for i in range(-n // (grid_size * 2) + 1, n // (grid_size * 2) + 1)])
295 |         theta_y = np.array([i * 2 * pi / n for i in range(-n // (grid_size * 2) + 1, n // (grid_size * 2) + 1)])
296 |         num_modes = theta_x.shape[0]
297 | 
298 |         X, Y = np.meshgrid(np.arange(-1, 2), np.arange(-1, 2))
299 |         with tf.device(self.device):
300 |             P = tf.zeros((len(theta_y), len(theta_x), batch_size, grid_size ** 2, (grid_size // 2) ** 2),
301 |                          dtype=tf.complex128)
302 |             modes = np.array([[np.exp(-1j * (tx * X + ty * Y)) for tx in theta_x] for ty in theta_y])
303 |             fourier_component = tf.to_complex128(np.tile(modes, (batch_size, 1, 1, 1, 1)))
304 |             for ic in range(grid_size // 2):
305 |                 i = 2 * ic  # ic is the index on the coarse grid, and i is the index on the fine grid
306 |                 for jc in range(grid_size // 2):
307 |                     j = 2 * jc  # jc is the index on the coarse grid, and j is the index on the fine grid
308 |                     J = int(grid_size // 2 * jc + ic)
309 |                     for k in range(3):
310 |                         for m in range(3):
311 |                             I = int(K[i, j, k, m])
312 |                             a = fourier_component[:, :, :, k, m] * tf.reshape(P_stencil[:, ic, jc, k, m], (-1, 1, 1))
313 |                             a = tf.transpose(a, (1, 2, 0))
314 | 
315 |                             P = P + tf.to_complex128(
316 |                                 tf.scatter_nd(indices=tf.constant(self.idx_array((I, J, int(batch_size), num_modes))),
317 |                                               updates=tf.ones(batch_size * (num_modes ** 2)),
318 |                                               shape=tf.constant([num_modes, num_modes, batch_size, grid_size ** 2,
319 |                                                                  (grid_size // 2) ** 2]))) \
320 |                                 * tf.reshape(a, (theta_x.shape[0], theta_y.shape[0], batch_size, 1, 1))
321 |         return P
322 | 
323 |     def compute_stencil(self, A, grid_size):
324 |         if isinstance(A, (tf.Tensor, tf.SparseTensor, tf.Variable)):
325 |             indices, _ = self.get_indices_compute_A((A.shape.as_list()[0], grid_size))
326 |             stencil = tf.reshape(tf.gather_nd(A, indices), (A.shape[0], grid_size, grid_size, 3, 3))
327 |             return stencil
328 |         else:
329 |             indices = self.get_indices_compute_A_one(grid_size)
330 |             stencil = np.array(A[indices[:, 0], indices[:, 1]]).reshape((grid_size, grid_size, 3, 3))
331 |             return tf.to_double(stencil)
332 | 
333 |     @memoize
334 |     def get_indices_compute_A_one(self, grid_size):
335 |         indices = []
336 |         K = self.map_2_to_1(grid_size=grid_size)
337 |         for i in range(grid_size):
338 |             for j in range(grid_size):
339 |                 I = int(K[i, j, 1, 1])
340 |                 for k in range(3):
341 |                     for m in range(3):
342 |                         J = int(K[i, j, k, m])
343 |                         indices.append([I, J])
344 | 
345 |         return np.array(indices)
346 | 
347 |     @memoize
348 |     def get_indices_compute_A(self, x):
349 |         batch_size, grid_size = x
350 |         indices = []
351 |         value_indices = []
352 |         K = self.map_2_to_1(grid_size=grid_size)
353 |         for n in range(batch_size):
354 |             for i in range(grid_size):
355 |                 for j in range(grid_size):
356 |                     I = int(K[i, j, 1, 1])
357 |                     for k in range(3):
358 |                         for m in range(3):
359 |                             J = int(K[i, j, k, m])
360 |                             indices.append([n, I, J])
361 |                             value_indices.append([n, i, j, k, m])
362 |         return indices, value_indices
363 | 
364 |     def compute_coarse_matrix(self, model, A_stencil, A_matrices, grid_size, bb=True):
365 |         with tf.device(self.device):
366 |             if bb:
367 |                 P_stencil = model(inputs=A_stencil, black_box=True)
368 |             else:
369 |                 P_stencil = model(inputs=A_stencil, black_box=False, phase="Test")
370 |             P_matrix = self.compute_p2(P_stencil, grid_size)
371 |             P_matrix_t = P_matrix.transpose()
372 |             A_c = P_matrix_t @ A_matrices @ P_matrix  # Galerkin coarse operator
373 |             return A_c, self.compute_stencil(A_c, (grid_size // 2)), P_matrix, P_matrix_t
374 | 
375 |     def compute_coarse_matrixLFA(self, model, n, A_stencil, A_matrices, grid_size, bb=True):
376 |         with tf.device(self.device):
377 |             if bb:
378 |                 P_stencil = model(inputs=A_stencil, black_box=True)
379 |             else:
380 |                 P_stencil = model(inputs=A_stencil, black_box=False, phase="Test")
381 |             P_matrix = (self.compute_p2LFA(P_stencil, n, grid_size)).numpy()
382 |             P_matrix_t = (tf.transpose(P_matrix, (0, 1, 2, 4, 3))).numpy()  # transpose only the last two (per-mode matrix) axes of the 5-D array
383 |             A_c = P_matrix_t @ A_matrices @ P_matrix
384 |             return A_c, self.compute_stencil(A_c, (grid_size // 2)), P_matrix, P_matrix_t
385 | 
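    # Shape check for the Galerkin products above (illustrative, for an 8x8
    # fine grid): P is 64 x 16, so A_c = P^T A P is 16 x 16, the operator of
    # the 4x4 coarse grid, and compute_stencil(A_c, 4) re-extracts a
    # (4, 4, 3, 3) stencil from it.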
386 |     def compute_coarse_matrix_sparse(self, model, A_stencil, A_matrices, grid_size, bb=True):
387 |         if bb:
388 |             P_stencil = model(inputs=A_stencil, black_box=True)
389 |         else:
390 |             P_stencil = model(inputs=A_stencil, black_box=False, phase="Test")
391 |         P_matrix = tf.to_double(self.compute_p2_sparse(P_stencil, P_stencil.shape.as_list()[0], grid_size))
392 |         P_matrix_t = tf.sparse_transpose(P_matrix, [0, 2, 1])
393 |         A_matrices = tf.squeeze(A_matrices)
394 |         temp = tf.sparse_tensor_to_dense(P_matrix_t)
395 |         q = tf.matmul(temp, tf.to_double(A_matrices))
396 |         A_c = tf.transpose(tf.matmul(temp, tf.transpose(q, [0, 2, 1])), [0, 2, 1])  # A_c = P^T A P
397 |         return A_c, self.compute_stencil(tf.squeeze(A_c), (grid_size // 2)), P_matrix, P_matrix_t
398 | 
399 |     def create_coarse_training_set(self, m, pi, num_training_samples, bb=False, epsilon_sparse=False):
400 |         m.grid_size = 16  # instead of 8
401 |         stencils = []
402 |         additional_num_training_samples = num_training_samples
403 |         theta_x = np.array(
404 |             [i * 2 * pi / 8 for i in range(-8 // (2 * 8) + 1, 8 // (2 * 8) + 1)])
405 |         theta_y = np.array(
406 |             [i * 2 * pi / 8 for i in range(-8 // (2 * 8) + 1, 8 // (2 * 8) + 1)])
407 | 
408 |         A_stencils_ = self.two_d_stencil(additional_num_training_samples, grid_size=16, epsilon_sparse=epsilon_sparse)
409 | 
410 |         batch_size = 128
411 |         for i in tqdm(range(additional_num_training_samples // batch_size)):
412 |             A_matrices_ = np.stack(
413 |                 [[self.compute_A(A_stencils_[i * batch_size:(i + 1) * batch_size], tx, ty, 1j, grid_size=16) for tx in
414 |                   theta_x] for ty in theta_y])
415 |             A_matrices_ = A_matrices_.transpose((2, 0, 1, 3, 4))
416 |             A_stencils_temp = tf.convert_to_tensor(A_stencils_[i * batch_size:(i + 1) * batch_size], dtype=tf.double)
417 |             A_matrices__temp = tf.convert_to_tensor(A_matrices_, dtype=tf.complex128)
418 |             A_c, A_c_stencil, _, _ = self.compute_coarse_matrix_sparse(m, A_stencils_temp, A_matrices__temp, 16,
419 |                                                                        bb=bb)
420 |             A_c_stencil = A_c_stencil.numpy()
421 |             stencils.append(A_c_stencil)
422 |         m.grid_size = 8
423 | 
424 |         return np.concatenate(stencils)
425 | 
426 |     def mg_levels(self, model, n, A_stencil, A_matrices, grid_size, max_depth=3, bb=False):
427 |         res = {'A0': A_matrices}
428 |         for i in range(max_depth):
429 |             A_matrices, A_stencil, P, _ = self.compute_coarse_matrix(model, A_stencil,
430 |                                                                      A_matrices, grid_size // (2 ** i), bb=bb)
431 |             model.grid_size = model.grid_size // 2
432 |             A_stencil = tf.convert_to_tensor([A_stencil])
433 |             res['A' + str(i + 1)] = A_matrices
434 |             res['P' + str(i)] = P
435 |         return res
436 | 
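    # The hierarchy returned above is keyed by level (a sketch for
    # max_depth=2 and a 16x16 fine operator): res['A0'] is the fine operator,
    # res['P0'] the 256x64 prolongation, res['A1'] = P0^T A0 P0,
    # res['P1'] the 64x16 prolongation, and res['A2'] the coarsest operator.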
437 |     def solve_with_model(self, model, A_matrices, b, initial_guess, max_iterations, max_depth=3, blackbox=False,
438 |                          w_cycle=False):
439 |         def prolongation_fn(A, args):
440 |             is_blackbox = args["is_blackbox"]
441 |             grid_size = int(math.sqrt(A.shape[0]))
442 |             indices = self.get_indices_compute_A_one(grid_size)
443 |             A_stencil = np.array(A[indices[:, 0], indices[:, 1]]).reshape((grid_size, grid_size, 3, 3))
444 |             model.grid_size = grid_size  # TODO: infer grid_size automatically
445 | 
446 |             tf_A_stencil = tf.convert_to_tensor([A_stencil])
447 |             with tf.device(self.device):
448 |                 if is_blackbox:
449 |                     P_stencil = model(inputs=tf_A_stencil, black_box=True)
450 |                 else:
451 |                     P_stencil = model(inputs=tf_A_stencil, black_box=False, phase="Test")
452 |             return self.compute_p2(P_stencil, grid_size).astype(np.double)  # imaginary part should be zero
453 | 
454 |         prolongation_args = {"is_blackbox": blackbox}
455 | 
456 |         error_norms = []
457 | 
458 |         # the solver calls this after each cycle; ||x_k|| is the error norm only when the exact solution is zero (e.g. b = 0)
459 |         def error_callback(x_k):
460 |             error_norms.append(pyamg.util.linalg.norm(x_k))
461 | 
462 |         solver = geometric_solver(A_matrices, prolongation_fn, prolongation_args,
463 |                                   max_levels=max_depth)
464 | 
465 |         if w_cycle:
466 |             cycle = 'W'
467 |         else:
468 |             cycle = 'V'
469 |         residual_norms = []
470 |         x = solver.solve(b, x0=initial_guess, maxiter=max_iterations, cycle=cycle, residuals=residual_norms, tol=0,
471 |                          callback=error_callback)
472 |         return x, residual_norms, error_norms, solver
473 | 
--------------------------------------------------------------------------------