├── drop_row_col.m ├── drop_zero_row_col.m ├── square_P.m ├── cr_demo.py ├── dataset.py ├── relaxation.py ├── prolongation_functions.py ├── utils.py ├── README.md ├── tf_sparse_utils.py ├── spec_cluster.py ├── block_periodic_delaunay.m ├── test_baseline.py ├── graph_net_model.py ├── test_model.py ├── ruge_stuben_custom_solver.py ├── configs.py ├── cr_solver.py ├── model.py ├── multigrid_utils.py ├── data.py └── train.py /drop_row_col.m: -------------------------------------------------------------------------------- 1 | function [rows, cols, values] = drop_row_col(A_rows, A_cols, A_values, total_size, indices) 2 | total_size = double(total_size); 3 | 4 | A = sparse(A_rows, A_cols, A_values, total_size, total_size); 5 | A(indices, :) = []; 6 | A(:, indices) = []; 7 | [rows, cols, values] = find(A); 8 | end -------------------------------------------------------------------------------- /drop_zero_row_col.m: -------------------------------------------------------------------------------- 1 | function [rows, cols, values] = drop_zero_row_col(A_rows, A_cols, A_values, total_size) 2 | total_size = double(total_size); 3 | 4 | A = sparse(A_rows, A_cols, A_values, total_size, total_size); 5 | A(~any(A,2), :) = []; 6 | A(:, ~any(A,1)) = []; 7 | [rows, cols, values] = find(A); 8 | end -------------------------------------------------------------------------------- /square_P.m: -------------------------------------------------------------------------------- 1 | function [rows, cols] = square_P(P_rows, P_cols, P_values, total_size, coarse_nodes) 2 | total_size = double(total_size); 3 | P_num_rows = total_size; 4 | [~, P_num_cols] = size(coarse_nodes); 5 | 6 | P = sparse(P_rows, P_cols, P_values, P_num_rows, P_num_cols); 7 | P_square = sparse(total_size, total_size); 8 | P_square(:, coarse_nodes) = P; 9 | [rows, cols] = find(P_square); 10 | end 11 | 12 | -------------------------------------------------------------------------------- /cr_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyamg 3 | import matplotlib.pyplot as plt 4 | 5 | from configs import CR_TEST 6 | from cr_solver import cr_solver 7 | 8 | size = 33**2 9 | grid_size = int(np.sqrt(size)) 10 | A = pyamg.gallery.poisson((grid_size, grid_size), type='FE', format='csr') 11 | # A = pyamg.gallery.stencil_grid(pyamg.gallery.diffusion_stencil_2d(epsilon=0.01, type='FE'), 12 | # (grid_size, grid_size), format='csr') 13 | 14 | xx = np.arange(0, grid_size, dtype=float) 15 | x, y = np.meshgrid(xx, xx) 16 | V = np.concatenate([[x.ravel()], [y.ravel()]], axis=0).T 17 | 18 | solver = cr_solver(A, 19 | keep=True, max_levels=2, 20 | CF=CR_TEST.data_config.splitting) 21 | print(solver) 22 | splitting = solver.levels[0].splitting 23 | 24 | C = np.where(splitting == 0)[0] 25 | F = np.where(splitting == 1)[0] 26 | plt.scatter(V[C, 0], V[C, 1], marker='s', s=12, 27 | color=[232.0 / 255, 74.0 / 255, 39.0 / 255]) 28 | plt.scatter(V[F, 0], V[F, 1], marker='s', s=12, 29 | color=[19.0 / 255, 41.0 / 255, 75.0 / 255]) 30 | plt.show() 31 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from sklearn.utils import shuffle 2 | 3 | 4 | class DataSet: 5 | def __init__(self, As, Ss, coarse_nodes_list, baseline_P_list): 6 | self.As = As 7 | self.Ss = Ss 8 | self.coarse_nodes_list = coarse_nodes_list 9 | self.baseline_P_list = baseline_P_list 10 | 11 | def shuffle(self): 
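        # Note: sklearn.utils.shuffle applies one shared random permutation to every
        # list passed in, so each A stays paired with its S, coarse-node list, and baseline P.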
12 | As, Ss, coarse_nodes_list, baseline_P_list = shuffle(self.As, 13 | self.Ss, 14 | self.coarse_nodes_list, 15 | self.baseline_P_list) 16 | return DataSet(As, Ss, coarse_nodes_list, baseline_P_list) 17 | 18 | def __getitem__(self, item): 19 | return DataSet( 20 | self.As[item], 21 | self.Ss[item], 22 | self.coarse_nodes_list[item], 23 | self.baseline_P_list[item] 24 | ) 25 | 26 | def __add__(self, other): 27 | return DataSet( 28 | self.As + other.As, 29 | self.Ss + other.Ss, 30 | self.coarse_nodes_list + other.coarse_nodes_list, 31 | self.baseline_P_list + other.baseline_P_list 32 | ) 33 | -------------------------------------------------------------------------------- /relaxation.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import tensorflow as tf 3 | 4 | import utils 5 | 6 | 7 | def relaxation_matrices(As, tensor=False): 8 | # computes the iteration matrix of the relaxation, here Gauss-Seidel is used. 9 | # This function is called on each block separately. 10 | num_As = len(As) 11 | grid_sizes = [A.shape[0] for A in As] 12 | Bs = [A.toarray() for A in As] 13 | for B, grid_size in zip(Bs, grid_sizes): 14 | B[utils.tril_indices(grid_size)[0], utils.tril_indices(grid_size)[1]] = 0. # B is the upper part of A 15 | res = [] 16 | if tensor: 17 | for i in range(num_As): 18 | res.append(tf.linalg.triangular_solve(As[i].toarray(), 19 | -Bs[i], 20 | lower=True)) 21 | else: 22 | for i in range(num_As): 23 | res.append(scipy.linalg.solve_triangular(a=As[i].toarray(), 24 | b=-Bs[i], 25 | lower=True, unit_diagonal=False, 26 | overwrite_b=True, debug=None, check_finite=False).astype( 27 | As[i].dtype)) 28 | return res 29 | -------------------------------------------------------------------------------- /prolongation_functions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from scipy.sparse import csr_matrix 3 | 4 | from model import csrs_to_graphs_tuple, graphs_tuple_to_sparse_tensor, to_prolongation_matrix_csr 5 | from tf_sparse_utils import sparse_tensor_to_csr 6 | 7 | 8 | def graphs_tuple_to_csr(graphs_tuple): 9 | row_indices = graphs_tuple.senders.numpy() 10 | col_indices = graphs_tuple.receivers.numpy() 11 | data = tf.squeeze(graphs_tuple.edges).numpy() 12 | num_nodes = graphs_tuple.n_node.numpy()[0] 13 | shape = (num_nodes, num_nodes) 14 | return csr_matrix((data, (row_indices, col_indices)), shape=shape) 15 | 16 | 17 | def model(A, coarse_nodes, baseline_P, C, graph_model, matlab_engine=None, normalize_rows_by_node=False, 18 | edge_indicators=True, node_indicators=True): 19 | with tf.device(":/gpu:0"): 20 | graphs_tuple = csrs_to_graphs_tuple([A], matlab_engine, coarse_nodes_list=[coarse_nodes], 21 | baseline_P_list=[baseline_P], 22 | edge_indicators=edge_indicators, 23 | node_indicators=node_indicators) 24 | output_graph = graph_model(graphs_tuple) 25 | P_square_tensor = graphs_tuple_to_sparse_tensor(output_graph) 26 | nodes_tensor = tf.squeeze(output_graph.nodes) 27 | nodes = nodes_tensor.numpy() 28 | 29 | P_square_csr = sparse_tensor_to_csr(P_square_tensor) 30 | P_csr = to_prolongation_matrix_csr(P_square_csr, coarse_nodes, baseline_P, nodes, 31 | normalize_rows_by_node=normalize_rows_by_node) 32 | return P_csr 33 | 34 | 35 | def baseline(A, splitting, baseline_P, C): 36 | return baseline_P 37 | 38 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import os 4 | import glob 5 | from functools import lru_cache 6 | 7 | import numpy as np 8 | from collections import Counter 9 | 10 | 11 | def chunks(l, n): 12 | """Yield successive n-sized chunks from l.""" 13 | for i in range(0, len(l), n): 14 | yield l[i:i + n] 15 | 16 | 17 | def most_frequent_splitting(splittings): 18 | """Given a list of numpy array, returns the most frequent one""" 19 | list_of_tuples = [tuple(splitting) for splitting in splittings] # we need a list of immutable types 20 | counter = Counter(list_of_tuples) 21 | most_frequent_tuple = counter.most_common(1)[0][0] 22 | return np.array(most_frequent_tuple) 23 | 24 | 25 | def create_results_dir(run_name): 26 | results_dir = 'results/' + run_name 27 | os.makedirs(results_dir) 28 | 29 | # make a copy of all Python files, for reproducibility 30 | local_dir = os.path.dirname(__file__) 31 | for py_file in glob.glob(local_dir + '/*.py'): 32 | shutil.copy(py_file, results_dir) 33 | 34 | 35 | def write_config_file(run_name, config, seed): 36 | results_dir = 'results/' + run_name 37 | config_dict = {'train_config': config.train_config.__dict__, 38 | 'data_config': config.data_config.__dict__, 39 | 'model_config': config.model_config.__dict__, 40 | 'run_config': config.run_config.__dict__, 41 | 'seed': seed} 42 | with open(f'{results_dir}/config.json', 'w') as outfile: 43 | json.dump(config_dict, outfile) 44 | 45 | 46 | @lru_cache(maxsize=None) 47 | def tril_indices(grid_size): 48 | """Cached version of np.tril_indices used for creating relaxation matrices""" 49 | return np.tril_indices(grid_size) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Algebraic Multigrid using Graph Neural Networks 2 | Code for reproducing the experimental results in our paper: 3 | https://arxiv.org/abs/2003.05744 4 | 5 | ## Requirements 6 | * Python >= 3.6 7 | * Tensorflow >= 1.14 8 | * NumPy 9 | * PyAMG 10 | * Graph Nets: https://github.com/deepmind/graph_nets 11 | * MATLAB >= R2019a 12 | * MATLAB engine for Python: https://www.mathworks.com/help/matlab/matlab_external/install-the-matlab-engine-for-python.html 13 | * Requires modifying internals of Python library for efficient passing of NumPy arrays, as described here: https://stackoverflow.com/a/45290997 14 | * tqdm 15 | * Fire: https://github.com/google/python-fire 16 | * scikit-learn 17 | * Meshpy: https://documen.tician.de/meshpy/index.html 18 | 19 | 20 | ## Training 21 | ### Graph Laplacian 22 | ``` 23 | python train.py 24 | ``` 25 | Model checkpoint is saved at 'training_dir/*model_id*', where *model_id* is a randomly generated 5 digit string. 26 | 27 | Tensorboard log files are outputted to 'tb_dir/*model_id*'. 28 | 29 | A copy of the .py files and a JSON file that describes the configuration are saved to 'results/*model_id*'. 30 | 31 | A random seed can be specified by setting a `-seed` argument. 
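For example, to train with a fixed seed (the seed value here is arbitrary):
```
python train.py -seed 42
```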
32 | ### Spectral clustering 33 | ``` 34 | python train.py -config SPEC_CLUSTERING_TRAIN -eval-config SPEC_CLUSTERING_EVAL 35 | ``` 36 | 37 | ### Ablation study 38 | ``` 39 | python train.py -config GRAPH_LAPLACIAN_ABLATION_MLP2 40 | python train.py -config GRAPH_LAPLACIAN_ABLATION_MP2 41 | python train.py -config GRAPH_LAPLACIAN_ABLATION_NO_CONCAT 42 | python train.py -config GRAPH_LAPLACIAN_ABLATION_NO_INDICATORS 43 | ``` 44 | Other model configurations and hyper-parameters can be trained by creating `Config` objects in `configs.py`, and setting the appropriate `-config` argument. 45 | 46 | ## Evaluation 47 | ### Graph Laplacian lognormal distribution 48 | ``` 49 | python test_model.py -model-name 12345 50 | ``` 51 | Replace `12345` by the *model_id* of a previously trained model. 52 | 53 | Results are saved at 'results/*model_id*'. 54 | 55 | ### Graph Laplacian uniform distribution 56 | ``` 57 | python test_model.py -model-name 12345 -config GRAPH_LAPLACIAN_UNIFORM_TEST 58 | ``` 59 | 60 | ### Finite element 61 | ``` 62 | python test_model.py -model-name 12345 -config FINITE_ELEMENT_TEST 63 | ``` 64 | 65 | ### Spectral clustering 66 | ``` 67 | python spec_cluster.py -model-name 12345 68 | ``` -------------------------------------------------------------------------------- /tf_sparse_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from scipy.sparse import coo_matrix 4 | 5 | 6 | def to_sparse(a): 7 | # tf.contrib.layers.dense_to_sparse has a memory leak issue, see: 8 | # https://github.com/tensorflow/tensorflow/issues/17590 9 | zero = tf.constant(0, a.dtype) 10 | indices = tf.where(tf.not_equal(a, zero)) 11 | values = tf.gather_nd(a, indices) 12 | shape = a.shape 13 | return tf.SparseTensor(indices, values, shape) 14 | 15 | 16 | def to_dense(a): 17 | # the tf.sparse.to_dense function does not support back-propagation, see: 18 | # https://github.com/tensorflow/tensorflow/issues/6391 19 | # bug is fixed in Tensorflow 1.14: https://github.com/tensorflow/tensorflow/releases/tag/v1.14.0 20 | # for older versions, workaround is performing addition with dense zero tensor 21 | # return tf.sparse.add(a, tf.zeros(a.shape, dtype=a.dtype)) 22 | return tf.sparse.to_dense(a) 23 | 24 | 25 | def csr_to_sparse_tensor(a): 26 | # https://stackoverflow.com/a/43130239 27 | a_coo = a.tocoo() 28 | a_coo.eliminate_zeros() 29 | indices = np.mat([a_coo.row, a_coo.col]).transpose() 30 | tensor = tf.SparseTensor(indices, a_coo.data, a_coo.shape) 31 | return tf.sparse.reorder(tensor) 32 | # return tensor 33 | 34 | 35 | def sparse_tensor_to_csr(a): 36 | indices = a.indices.numpy() 37 | rows = indices[:, 0] 38 | cols = indices[:, 1] 39 | data = a.values.numpy() 40 | shape = (a.shape[0].value, a.shape[1].value) 41 | a_coo = coo_matrix((data, (rows, cols)), shape=shape) 42 | return a_coo.tocsr() 43 | 44 | 45 | def sparse_multiply(a, b): 46 | # there are multiple ways to do it in Tensorflow, see: 47 | # https://www.tensorflow.org/api_docs/python/tf/sparse/sparse_dense_matmul 48 | # https://stackoverflow.com/questions/34030140/is-sparse-tensor-multiplication-implemented-in-tensorflow 49 | dense_a = to_dense(a) 50 | dense_b = to_dense(b) 51 | return tf.matmul(dense_a, dense_b, a_is_sparse=True, b_is_sparse=True) 52 | 53 | 54 | def pad_diagonal(a, padded_length): 55 | # given square matrix "a", pad with 1's on diagonal until matrix is of size "padded_length" 56 | a_length = a.shape[0].value 57 | if a_length > padded_length: 58 | 
raise RuntimeError(f"padded length {padded_length} is larger than matrix length {a_length}") 59 | if a_length == padded_length: 60 | return a 61 | 62 | a_indices = a.indices 63 | new_range = tf.range(a_length, padded_length, dtype=tf.int64) 64 | new_indices = tf.tile(tf.expand_dims(new_range, axis=1), [1, 2]) 65 | padded_indices = tf.concat([a_indices, new_indices], axis=0) 66 | 67 | a_values = a.values 68 | new_values = tf.ones(padded_length - a_length, dtype=tf.float64) 69 | padded_values = tf.concat([a_values, new_values], axis=0) 70 | 71 | padded_shape = [padded_length, padded_length] 72 | return tf.SparseTensor(indices=padded_indices, values=padded_values, dense_shape=padded_shape) 73 | 74 | -------------------------------------------------------------------------------- /spec_cluster.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | 4 | import fire 5 | import numpy as np 6 | import tensorflow as tf 7 | from scipy.sparse.linalg import lobpcg 8 | from tqdm import tqdm 9 | import matlab.engine 10 | 11 | import configs 12 | from data import generate_A_spec_cluster 13 | from model import get_model 14 | from prolongation_functions import model, baseline 15 | from ruge_stuben_custom_solver import ruge_stuben_custom_solver 16 | 17 | 18 | def precond_test(model_name=None, test_config='GRAPH_LAPLACIAN_TEST', seed=1): 19 | if model_name is None: 20 | raise RuntimeError("model name required") 21 | model_name = str(model_name) 22 | matlab_engine = matlab.engine.start_matlab() 23 | 24 | # fix random seeds for reproducibility 25 | np.random.seed(seed) 26 | tf.random.set_random_seed(seed) 27 | matlab_engine.eval(f'rng({seed})') 28 | 29 | test_config = getattr(configs, test_config).test_config 30 | config_file = f"results/{model_name}/config.json" 31 | with open(config_file) as f: 32 | data = json.load(f) 33 | model_config = configs.ModelConfig(**data['model_config']) 34 | run_config = configs.RunConfig(**data['run_config']) 35 | 36 | graph_model = get_model(model_name, model_config, run_config, matlab_engine) 37 | 38 | max_levels = 2 39 | cycle = 'V' 40 | size = 1000 41 | num_samples = 100 42 | num_clusters = 2 43 | dimension = 2 44 | gamma = None 45 | distribution = 'moons' 46 | 47 | model_prolongation = partial(model, graph_model=graph_model, normalize_rows_by_node=False, 48 | matlab_engine=matlab_engine) 49 | baseline_prolongation = baseline 50 | 51 | model_wins = 0 52 | base_wins = 0 53 | ties = 0 54 | total_model_iters = 0 55 | total_base_iters = 0 56 | for _ in tqdm(range(num_samples)): 57 | A = generate_A_spec_cluster(size, num_clusters, unit_std=True, dim=dimension, 58 | dist=distribution, gamma=gamma, distance=False, n_neighbors=10) 59 | 60 | model_solver = ruge_stuben_custom_solver(A, model_prolongation, 61 | strength=test_config.strength, 62 | presmoother=test_config.presmoother, 63 | postsmoother=test_config.postsmoother, 64 | CF=test_config.splitting, 65 | max_levels=max_levels) 66 | 67 | model_precond = model_solver.aspreconditioner(cycle=cycle) 68 | 69 | base_solver = ruge_stuben_custom_solver(A, baseline_prolongation, 70 | strength=test_config.strength, 71 | presmoother=test_config.presmoother, 72 | postsmoother=test_config.postsmoother, 73 | CF=test_config.splitting, 74 | max_levels=max_levels) 75 | base_precond = base_solver.aspreconditioner(cycle=cycle) 76 | 77 | x0 = np.random.uniform(size=[A.shape[0], num_clusters]) 78 | try: 79 | _, _, model_res_norms = lobpcg(A, x0, M=model_precond, 
tol=1.e-12, maxiter=100, 80 | largest=False, retResidualNormsHistory=True) 81 | 82 | _, _, base_res_norms = lobpcg(A, x0, M=base_precond, tol=1.e-12, maxiter=100, 83 | largest=False, retResidualNormsHistory=True) 84 | except np.linalg.LinAlgError: 85 | print("error") 86 | continue 87 | 88 | model_iters = len(model_res_norms) 89 | base_iters = len(base_res_norms) 90 | if model_iters < base_iters: 91 | model_wins += 1 92 | elif model_iters > base_iters: 93 | base_wins += 1 94 | else: 95 | ties += 1 96 | 97 | total_model_iters += model_iters 98 | total_base_iters += base_iters 99 | 100 | print(f"model wins: {model_wins}") 101 | print(f"base wins: {base_wins}") 102 | print(f"win ratio: {model_wins / (model_wins + base_wins)}") 103 | print(f"avg model iters: {total_model_iters / num_samples}") 104 | print(f"avg base iters: {total_base_iters / num_samples}") 105 | print(f"iters ratio: {total_model_iters / total_base_iters}") 106 | 107 | 108 | if __name__ == '__main__': 109 | config = tf.ConfigProto() 110 | config.gpu_options.allow_growth = True 111 | tf.enable_eager_execution(config=config) 112 | 113 | fire.Fire(precond_test) 114 | -------------------------------------------------------------------------------- /block_periodic_delaunay.m: -------------------------------------------------------------------------------- 1 | function [A, points] = block_periodic_delaunay(k,b) 2 | % Create b by b block - doubly periodic triangulation with k vertices per 3 | % block. Put random lognormal coefficients on the edges and create graph 4 | % Laplacian. 5 | if (b < 3) 6 | disp('WARNING: b cannot be less than 3, resetting b = 3') 7 | b = 3; 8 | end 9 | 10 | % Python Matlab engine convert integers to int64, we need to convert 11 | % back to double 12 | k = double(k); 13 | b = double(b); 14 | 15 | tri = Shilush(k); % 3 by 3 identical blocks, each with the same k randomly 16 | % distributed vertices. Return the Delaunay 17 | % triangulation 18 | A = Compute_Matrix(tri.ConnectivityList,k,b); % Compute a block-doubly periodic graph Laplacian 19 | % with random coefficients on the edges, with 20 | % the edges defined by tri. 21 | points = tri.Points; 22 | end 23 | 24 | 25 | function A = Compute_Matrix(tri,k,b) 26 | 27 | A = sparse(b*k,b*k); 28 | 29 | % For convenience, first create the Laplacian matrix of the 3 by 3 block triangulation. 30 | % Then use only the middle block to construct A. 31 | B = sparse(9*k,9*k); 32 | 33 | % Make a standard log-normal random matrix of size k by 9*k. 34 | % Only part of it will be used 35 | R = -lognrnd(0,1,k,9*k); 36 | 37 | tri = sort(tri,2); 38 | for i = 1:length(tri) % Run over all the triangles 39 | t = tri(i,:); % For convenience, store triangle i in array t 40 | B(t(1),t(2)) = R(mod(t(1)-1,k)+1,t(2)-t(1)); % Maintains periodicity in the middle block 41 | B(t(1),t(3)) = R(mod(t(1)-1,k)+1,t(3)-t(1)); % Maintains periodicity in the middle block 42 | B(t(2),t(3)) = R(mod(t(2)-1,k)+1,t(3)-t(2)); % Maintains periodicity in the middle block 43 | % Symmetrize 44 | B(t(2),t(1)) = B(t(1),t(2)); 45 | B(t(3),t(1)) = B(t(1),t(3)); 46 | B(t(3),t(2)) = B(t(2),t(3)); 47 | end 48 | 49 | % Plug the rows of B corresponding to the (2,2) block into A in the proper places, 50 | % given that B has 3 by 3 blocks, while A has b by b blocks. 
This will 51 | % define the (2,2) block of A 52 | 53 | A((b+1)*k+1:(b+2)*k,1:3*k) = B((3+1)*k+1:(3+2)*k,1:3*k); 54 | A((b+1)*k+1:(b+2)*k,b*k+1:(b+3)*k) = B((3+1)*k+1:(3+2)*k,3*k+1:(3+3)*k); 55 | A((b+1)*k+1:(b+2)*k,2*b*k+1:(2*b+3)*k) = B((3+1)*k+1:(3+2)*k,2*3*k+1:(2*3+3)*k); 56 | 57 | % Now create the rest of the doubly periodic A from its (2,2) block. 58 | for ib = 0:b^2-1 % Run over the blocks, starting from 0 for convenience. 59 | % ib = b+1 corresponds to block (2,2). 60 | A(ib*k+1:ib*k+k,ib*k+1:ib*k+k) = A((b+1)*k+1:(b+1)*k+k,(b+1)*k+1:(b+1)*k+k); 61 | if (mod(ib,b) == 0) % North block 62 | A(ib*k+1:ib*k+k,(ib-1+b)*k+1:(ib-1+b)*k+k) = A((b+1)*k+1:(b+1)*k+k,b*k+1:b*k+k); 63 | else 64 | A(ib*k+1:ib*k+k,(ib-1)*k+1:(ib-1)*k+k) = A((b+1)*k+1:(b+1)*k+k,b*k+1:b*k+k); 65 | end 66 | if (mod(ib,b) == b-1) % South block 67 | A(ib*k+1:ib*k+k,(ib+1-b)*k+1:(ib+1-b)*k+k) = A((b+1)*k+1:(b+1)*k+k,(b+2)*k+1:(b+2)*k+k); 68 | else 69 | A(ib*k+1:ib*k+k,(ib+1)*k+1:(ib+1)*k+k) = A((b+1)*k+1:(b+1)*k+k,(b+2)*k+1:(b+2)*k+k); 70 | end 71 | if (ib < b) % West block 72 | A(ib*k+1:ib*k+k,(ib-b+b^2)*k+1:(ib-b+b^2)*k+k) = A((b+1)*k+1:(b+1)*k+k,k+1:k+k); 73 | else 74 | A(ib*k+1:ib*k+k,(ib-b)*k+1:(ib-b)*k+k) = A((b+1)*k+1:(b+1)*k+k,k+1:k+k); 75 | end 76 | if (ib >= b^2-b) % East block 77 | A(ib*k+1:ib*k+k,(ib+b-b^2)*k+1:(ib+b-b^2)*k+k) = A((b+1)*k+1:(b+1)*k+k,(2*b+1)*k+1:(2*b+1)*k+k); 78 | else 79 | A(ib*k+1:ib*k+k,(ib+b)*k+1:(ib+b)*k+k) = A((b+1)*k+1:(b+1)*k+k,(2*b+1)*k+1:(2*b+1)*k+k); 80 | end 81 | if (ib == 0) % NorthWest block 82 | A(ib*k+1:ib*k+k,(b^2-1)*k+1:(b^2-1)*k+k) = A((b+1)*k+1:(b+1)*k+k,1:k); 83 | elseif (mod(ib,b) == 0) 84 | A(ib*k+1:ib*k+k,(ib-b-1+b)*k+1:(ib-b-1+b)*k+k) = A((b+1)*k+1:(b+1)*k+k,1:k); 85 | elseif (ib < b) 86 | A(ib*k+1:ib*k+k,(ib-b-1+b^2)*k+1:(ib-b-1+b^2)*k+k) = A((b+1)*k+1:(b+1)*k+k,1:k); 87 | else 88 | A(ib*k+1:ib*k+k,(ib-b-1)*k+1:(ib-b-1)*k+k) = A((b+1)*k+1:(b+1)*k+k,1:k); 89 | end 90 | if (ib == b-1) % SouthWest block 91 | A(ib*k+1:ib*k+k,(b^2-b)*k+1:(b^2-b)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*k+1:2*k+k); 92 | elseif (mod(ib,b) == b-1) 93 | A(ib*k+1:ib*k+k,(ib-b+1-b)*k+1:(ib-b+1-b)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*k+1:2*k+k); 94 | elseif (ib < b) 95 | A(ib*k+1:ib*k+k,(ib-b+1+b^2)*k+1:(ib-b+1+b^2)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*k+1:2*k+k); 96 | else 97 | A(ib*k+1:ib*k+k,(ib-b+1)*k+1:(ib-b+1)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*k+1:2*k+k); 98 | end 99 | if (ib == b^2-b) % NorthEast block 100 | A(ib*k+1:ib*k+k,(b-1)*k+1:(b-1)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*b*k+1:2*b*k+k); 101 | elseif (mod(ib,b) == 0) 102 | A(ib*k+1:ib*k+k,(ib+b-1+b)*k+1:(ib+b-1+b)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*b*k+1:2*b*k+k); 103 | elseif (ib >= b^2-b) 104 | A(ib*k+1:ib*k+k,(ib+b-1-b^2)*k+1:(ib+b-1-b^2)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*b*k+1:2*b*k+k); 105 | else 106 | A(ib*k+1:ib*k+k,(ib+b-1)*k+1:(ib+b-1)*k+k) = A((b+1)*k+1:(b+1)*k+k,2*b*k+1:2*b*k+k); 107 | end 108 | if (ib == b^2-1) % SouthEast block 109 | A(ib*k+1:ib*k+k,1:k) = A((b+1)*k+1:(b+1)*k+k,(2*b+2)*k+1:(2*b+2)*k+k); 110 | elseif (mod(ib,b) == b-1) 111 | A(ib*k+1:ib*k+k,(ib+b+1-b)*k+1:(ib+b+1-b)*k+k) = A((b+1)*k+1:(b+1)*k+k,(2*b+2)*k+1:(2*b+2)*k+k); 112 | elseif (ib >= b^2-b) 113 | A(ib*k+1:ib*k+k,(ib+b+1-b^2)*k+1:(ib+b+1-b^2)*k+k) = A((b+1)*k+1:(b+1)*k+k,(2*b+2)*k+1:(2*b+2)*k+k); 114 | else 115 | A(ib*k+1:ib*k+k,(ib+b+1)*k+1:(ib+b+1)*k+k) = A((b+1)*k+1:(b+1)*k+k,(2*b+2)*k+1:(2*b+2)*k+k); 116 | end 117 | 118 | end 119 | 120 | % Zerosum 121 | for i = 1:length(A) 122 | A(i,i) = -sum(A(i,:)); 123 | end 124 | 125 | % Python Matlab engine does not support sparse arrays 126 | A 
= full(A); 127 | end 128 | 129 | function tri = Shilush(k) 130 | 131 | x = rand(k,1); 132 | y = rand(k,1); 133 | X = [x-1;x-1;x-1;x;x;x;x+1;x+1;x+1]; 134 | Y = [y-1;y;y+1;y-1;y;y+1;y-1;y;y+1]; 135 | tri = delaunayTriangulation(X,Y); 136 | 137 | end 138 | -------------------------------------------------------------------------------- /test_baseline.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import numpy as np 3 | from pyamg import ruge_stuben_solver, smoothed_aggregation_solver, rootnode_solver 4 | from tqdm import tqdm 5 | 6 | import configs 7 | from cr_solver import cr_solver 8 | from data import generate_A 9 | 10 | 11 | def test_size(size, test_config): 12 | baseline_errors_div_diff = [] 13 | operator_complexities = [] 14 | 15 | fp_threshold = test_config.fp_threshold 16 | strength = test_config.strength 17 | presmoother = test_config.presmoother 18 | postsmoother = test_config.postsmoother 19 | coarse_solver = test_config.coarse_solver 20 | 21 | cycle = test_config.cycle 22 | splitting = test_config.splitting 23 | num_runs = test_config.num_runs 24 | dist = test_config.dist 25 | max_levels = test_config.max_levels 26 | iterations = test_config.iterations 27 | load_data = test_config.load_data 28 | 29 | block_periodic = False 30 | root_num_blocks = 1 31 | 32 | if load_data: 33 | if dist == 'lognormal_laplacian_periodic': 34 | As = np.load(f"test_data_dir/delaunay_periodic_logn_num_As_{100}_num_points_{size}.npy") 35 | elif dist == 'lognormal_complex_fem': 36 | As = np.load(f"test_data_dir/fe_hole_logn_num_As_{100}_num_points_{size}.npy") 37 | else: 38 | raise NotImplementedError() 39 | 40 | for i in tqdm(range(num_runs)): 41 | if load_data: 42 | A = As[i] 43 | else: 44 | A = generate_A(size, dist, block_periodic, root_num_blocks) 45 | 46 | num_unknowns = A.shape[0] 47 | x0 = np.random.normal(loc=0.0, scale=1.0, size=num_unknowns) 48 | b = np.zeros((A.shape[0])) 49 | 50 | baseline_residuals = [] 51 | 52 | if splitting is 'CR' or splitting[0] is 'CR': 53 | baseline_solver = cr_solver(A, 54 | presmoother=presmoother, 55 | postsmoother=postsmoother, 56 | keep=True, max_levels=max_levels, 57 | CF=splitting, 58 | coarse_solver=coarse_solver) 59 | elif splitting is 'SA': 60 | baseline_solver = smoothed_aggregation_solver(A, 61 | strength=strength, 62 | presmoother=presmoother, 63 | postsmoother=postsmoother, 64 | max_levels=max_levels, 65 | keep=True, 66 | coarse_solver=coarse_solver) 67 | elif splitting is 'rootnode': 68 | baseline_solver = rootnode_solver(A, 69 | strength=strength, 70 | presmoother=presmoother, 71 | postsmoother=postsmoother, 72 | max_levels=max_levels, 73 | keep=True, 74 | coarse_solver=coarse_solver) 75 | else: 76 | baseline_solver = ruge_stuben_solver(A, 77 | strength=strength, 78 | interpolation='direct', 79 | presmoother=presmoother, 80 | postsmoother=postsmoother, 81 | keep=True, max_levels=max_levels, 82 | CF=splitting, 83 | coarse_solver=coarse_solver) 84 | 85 | operator_complexities.append(baseline_solver.operator_complexity()) 86 | 87 | _ = baseline_solver.solve(b, x0=x0, tol=0.0, maxiter=iterations, cycle=cycle, 88 | residuals=baseline_residuals) 89 | baseline_residuals = np.array(baseline_residuals) 90 | baseline_residuals = baseline_residuals[baseline_residuals > fp_threshold] 91 | baseline_factor = baseline_residuals[-1] / baseline_residuals[-2] 92 | baseline_errors_div_diff.append(baseline_factor) 93 | 94 | baseline_errors_div_diff = np.array(baseline_errors_div_diff) 95 | 
baseline_errors_div_diff_mean = np.mean(baseline_errors_div_diff) 96 | baseline_errors_div_diff_std = np.std(baseline_errors_div_diff) 97 | 98 | operator_complexity_mean = np.mean(operator_complexities) 99 | operator_complexity_std = np.std(operator_complexities) 100 | 101 | if type(splitting) == tuple: 102 | splitting_str = splitting[0] + '_' + '_'.join([f'{key}_{value}' for key, value in splitting[1].items()]) 103 | else: 104 | splitting_str = splitting 105 | results_file = open( 106 | f"results/baseline/{dist}_{num_unknowns}_cycle_{cycle}_max_levels_{max_levels}_split_{splitting_str}_results.txt", 'w') 107 | print(f"cycle: {cycle}, max levels: {max_levels}", file=results_file) 108 | 109 | print(f"asymptotic error factor baseline: {baseline_errors_div_diff_mean:.4f} ± {baseline_errors_div_diff_std:.5f}", 110 | file=results_file) 111 | 112 | print(f"num unknowns: {num_unknowns}") 113 | print(f"asymptotic error factor baseline: {baseline_errors_div_diff_mean:.4f} ± {baseline_errors_div_diff_std:.5f}") 114 | 115 | print(f"operator complexity: {operator_complexity_mean:.4f} ± {operator_complexity_std:.5f}") 116 | print(f"operator complexity: {operator_complexity_mean:.4f} ± {operator_complexity_std:.5f}", 117 | file=results_file) 118 | 119 | results_file.close() 120 | 121 | 122 | def test_baseline(config='GRAPH_LAPLACIAN_TEST', seed=1): 123 | # fix random seeds for reproducibility 124 | np.random.seed(seed) 125 | 126 | test_config = getattr(configs, config).test_config 127 | 128 | for size in test_config.test_sizes: 129 | test_size(size, test_config) 130 | 131 | 132 | if __name__ == '__main__': 133 | fire.Fire(test_baseline) 134 | -------------------------------------------------------------------------------- /graph_net_model.py: -------------------------------------------------------------------------------- 1 | import sonnet as snt 2 | import graph_nets as gn 3 | import tensorflow as tf 4 | from graph_nets import modules 5 | from functools import partial 6 | 7 | 8 | class EncodeProcessDecodeNonRecurrent(snt.AbstractModule): 9 | """ 10 | similar to EncodeProcessDecode, but with non-recurrent core 11 | see docs for EncodeProcessDecode 12 | """ 13 | 14 | def __init__(self, 15 | num_cores=3, 16 | edge_output_size=None, 17 | node_output_size=None, 18 | global_output_size=None, 19 | global_block=True, 20 | latent_size=16, 21 | num_layers=2, 22 | concat_encoder=True, 23 | name="EncodeProcessDecodeNonRecurrent"): 24 | super(EncodeProcessDecodeNonRecurrent, self).__init__(name=name) 25 | self._encoder = MLPGraphIndependent(latent_size=latent_size, num_layers=num_layers) 26 | self._cores = [MLPGraphNetwork(latent_size=latent_size, num_layers=num_layers, 27 | global_block=global_block) for _ in range(num_cores)] 28 | self._decoder = MLPGraphIndependent(latent_size=latent_size, num_layers=num_layers) 29 | self.concat_encoder = concat_encoder 30 | # Transforms the outputs into the appropriate shapes. 
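        # Each optional output head below is a single Linear layer; if a size is left
        # as None, the corresponding field (edges, nodes, or globals) is passed through
        # unchanged by the final GraphIndependent transform.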
31 | if edge_output_size is None: 32 | edge_fn = None 33 | else: 34 | edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output") 35 | if node_output_size is None: 36 | node_fn = None 37 | else: 38 | node_fn = lambda: snt.Linear(node_output_size, name="node_output") 39 | if global_output_size is None: 40 | global_fn = None 41 | else: 42 | global_fn = lambda: snt.Linear(global_output_size, name="global_output") 43 | with self._enter_variable_scope(): 44 | self._output_transform = modules.GraphIndependent(edge_fn, node_fn, 45 | global_fn) 46 | 47 | def _build(self, input_op): 48 | latent = self._encoder(input_op) 49 | latent0 = latent 50 | for i in range(len(self._cores)): 51 | if self.concat_encoder: 52 | core_input = gn.utils_tf.concat([latent0, latent], axis=1) 53 | else: 54 | core_input = latent 55 | latent = self._cores[i](core_input) 56 | return self._output_transform(self._decoder(latent)) 57 | 58 | 59 | class MLPGraphNetwork(snt.AbstractModule): 60 | """GraphNetwork with MLP edge, node, and global models.""" 61 | 62 | def __init__(self, latent_size=16, num_layers=2, global_block=True, last_round=False, 63 | name="MLPGraphNetwork"): 64 | super(MLPGraphNetwork, self).__init__(name=name) 65 | partial_make_mlp_model = partial(make_mlp_model, latent_size=latent_size, num_layers=num_layers, 66 | last_round_edges=False) 67 | if last_round: 68 | partial_make_mlp_model_edges = partial(make_mlp_model, latent_size=latent_size, num_layers=num_layers, 69 | last_round_edges=True) 70 | else: 71 | partial_make_mlp_model_edges = partial_make_mlp_model 72 | 73 | with self._enter_variable_scope(): 74 | if global_block: 75 | self._network = modules.GraphNetwork(partial_make_mlp_model_edges, partial_make_mlp_model, 76 | partial_make_mlp_model, 77 | edge_block_opt={ 78 | "use_globals": True 79 | }, 80 | node_block_opt={ 81 | "use_globals": True 82 | }, 83 | global_block_opt={ 84 | "use_globals": True, 85 | "edges_reducer": tf.unsorted_segment_mean, 86 | "nodes_reducer": tf.unsorted_segment_mean 87 | }) 88 | else: 89 | self._network = modules.GraphNetwork(partial_make_mlp_model_edges, partial_make_mlp_model, 90 | make_identity_model, 91 | edge_block_opt={ 92 | "use_globals": False 93 | }, 94 | node_block_opt={ 95 | "use_globals": False 96 | }, 97 | global_block_opt={ 98 | "use_globals": False, 99 | }) 100 | 101 | def _build(self, inputs): 102 | return self._network(inputs) 103 | 104 | 105 | class MLPGraphIndependent(snt.AbstractModule): 106 | """GraphIndependent with MLP edge, node, and global models.""" 107 | 108 | def __init__(self, latent_size=16, num_layers=2, name="MLPGraphIndependent"): 109 | super(MLPGraphIndependent, self).__init__(name=name) 110 | 111 | partial_make_mlp_model = partial(make_mlp_model, latent_size=latent_size, num_layers=num_layers, 112 | last_round_edges=False) 113 | 114 | with self._enter_variable_scope(): 115 | self._network = modules.GraphIndependent( 116 | edge_model_fn=partial_make_mlp_model, 117 | node_model_fn=partial_make_mlp_model, 118 | global_model_fn=partial_make_mlp_model) 119 | 120 | def _build(self, inputs): 121 | return self._network(inputs) 122 | 123 | 124 | def make_mlp_model(latent_size=16, num_layers=2, last_round_edges=False): 125 | """Instantiates a new MLP, followed by LayerNorm. 126 | 127 | The parameters of each new MLP are not shared with others generated by 128 | this function. 129 | 130 | Returns: 131 | A Sonnet module which contains the MLP and LayerNorm. 
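    When ``last_round_edges`` is True, a final layer of width 1 is appended so the
    MLP emits a single scalar per edge.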
132 | """ 133 | if last_round_edges: 134 | return snt.nets.MLP([latent_size] * num_layers + [1], activate_final=False) 135 | else: 136 | return snt.Sequential([ 137 | snt.nets.MLP([latent_size] * num_layers, activate_final=False) 138 | ]) 139 | 140 | 141 | class IdentityModule(snt.AbstractModule): 142 | def _build(self, inputs): 143 | return tf.identity(inputs) 144 | 145 | 146 | def make_identity_model(): 147 | return IdentityModule() 148 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | 4 | import fire 5 | import matlab.engine 6 | import numpy as np 7 | import tensorflow as tf 8 | from tqdm import tqdm 9 | 10 | import configs 11 | from data import generate_A 12 | from model import get_model 13 | from prolongation_functions import model, baseline 14 | from ruge_stuben_custom_solver import ruge_stuben_custom_solver 15 | 16 | 17 | def test_size(model_name, graph_model, size, test_config, run_config, matlab_engine): 18 | model_prolongation = partial(model, graph_model=graph_model, 19 | normalize_rows_by_node=run_config.normalize_rows_by_node, 20 | edge_indicators=run_config.edge_indicators, 21 | node_indicators=run_config.node_indicators, 22 | matlab_engine=matlab_engine) 23 | baseline_prolongation = baseline 24 | 25 | model_errors_div_diff = [] 26 | baseline_errors_div_diff = [] 27 | 28 | fp_threshold = test_config.fp_threshold 29 | strength = test_config.strength 30 | presmoother = test_config.presmoother 31 | postsmoother = test_config.postsmoother 32 | coarse_solver = test_config.coarse_solver 33 | 34 | cycle = test_config.cycle 35 | splitting = test_config.splitting 36 | dist = test_config.dist 37 | num_runs = test_config.num_runs 38 | max_levels = test_config.max_levels 39 | iterations = test_config.iterations 40 | load_data = test_config.load_data 41 | 42 | block_periodic = False 43 | root_num_blocks = 1 44 | 45 | if load_data: 46 | if dist == 'lognormal_laplacian_periodic': 47 | As = np.load(f"test_data_dir/delaunay_periodic_logn_num_As_{100}_num_points_{size}.npy") 48 | elif dist == 'lognormal_complex_fem': 49 | As = np.load(f"test_data_dir/fe_hole_logn_num_As_{100}_num_points_{size}.npy") 50 | else: 51 | raise NotImplementedError() 52 | 53 | for i in tqdm(range(num_runs)): 54 | if load_data: 55 | A = As[i] 56 | else: 57 | A = generate_A(size, dist, block_periodic, root_num_blocks) 58 | 59 | num_unknowns = A.shape[0] 60 | x0 = np.random.normal(loc=0.0, scale=1.0, size=num_unknowns) 61 | b = np.zeros((A.shape[0])) 62 | 63 | model_residuals = [] 64 | baseline_residuals = [] 65 | 66 | model_solver = ruge_stuben_custom_solver(A, model_prolongation, 67 | strength=strength, 68 | presmoother=presmoother, 69 | postsmoother=postsmoother, 70 | keep=True, max_levels=max_levels, 71 | CF=splitting, 72 | coarse_solver=coarse_solver) 73 | 74 | _ = model_solver.solve(b, x0=x0, tol=0.0, maxiter=iterations, cycle=cycle, 75 | residuals=model_residuals) 76 | model_residuals = np.array(model_residuals) 77 | model_residuals = model_residuals[model_residuals > fp_threshold] 78 | model_factor = model_residuals[-1] / model_residuals[-2] 79 | model_errors_div_diff.append(model_factor) 80 | 81 | baseline_solver = ruge_stuben_custom_solver(A, baseline_prolongation, 82 | strength=strength, 83 | presmoother=presmoother, 84 | postsmoother=postsmoother, 85 | keep=True, max_levels=max_levels, 86 | CF=splitting, 87 | 
coarse_solver=coarse_solver) 88 | 89 | _ = baseline_solver.solve(b, x0=x0, tol=0.0, maxiter=iterations, cycle=cycle, 90 | residuals=baseline_residuals) 91 | baseline_residuals = np.array(baseline_residuals) 92 | baseline_residuals = baseline_residuals[baseline_residuals > fp_threshold] 93 | baseline_factor = baseline_residuals[-1] / baseline_residuals[-2] 94 | baseline_errors_div_diff.append(baseline_factor) 95 | 96 | model_errors_div_diff = np.array(model_errors_div_diff) 97 | baseline_errors_div_diff = np.array(baseline_errors_div_diff) 98 | model_errors_div_diff_mean = np.mean(model_errors_div_diff) 99 | model_errors_div_diff_std = np.std(model_errors_div_diff) 100 | baseline_errors_div_diff_mean = np.mean(baseline_errors_div_diff) 101 | baseline_errors_div_diff_std = np.std(baseline_errors_div_diff) 102 | 103 | if type(splitting) == tuple: 104 | splitting_str = splitting[0] + '_'+ '_'.join([f'{key}_{value}' for key, value in splitting[1].items()]) 105 | else: 106 | splitting_str = splitting 107 | results_file = open( 108 | f"results/{model_name}/{dist}_{num_unknowns}_cycle_{cycle}_max_levels_{max_levels}_split_{splitting_str}_results.txt", 109 | 'w') 110 | print(f"cycle: {cycle}, max levels: {max_levels}", file=results_file) 111 | print(f"asymptotic error factor model: {model_errors_div_diff_mean:.4f} ± {model_errors_div_diff_std:.5f}", 112 | file=results_file) 113 | 114 | print(f"asymptotic error factor baseline: {baseline_errors_div_diff_mean:.4f} ± {baseline_errors_div_diff_std:.5f}", 115 | file=results_file) 116 | model_success_rate = sum(model_errors_div_diff < baseline_errors_div_diff) / num_runs 117 | print(f"model success rate: {model_success_rate}", 118 | file=results_file) 119 | 120 | print(f"num unknowns: {num_unknowns}") 121 | print(f"asymptotic error factor model: {model_errors_div_diff_mean:.4f} ± {model_errors_div_diff_std:.5f}") 122 | print(f"asymptotic error factor baseline: {baseline_errors_div_diff_mean:.4f} ± {baseline_errors_div_diff_std:.5f}") 123 | print(f"model success rate: {model_success_rate}") 124 | 125 | results_file.close() 126 | 127 | 128 | def test_model(model_name=None, test_config='GRAPH_LAPLACIAN_TEST', seed=1): 129 | if model_name is None: 130 | raise RuntimeError("model name required") 131 | model_name = str(model_name) 132 | matlab_engine = matlab.engine.start_matlab() 133 | 134 | # fix random seeds for reproducibility 135 | np.random.seed(seed) 136 | tf.random.set_random_seed(seed) 137 | matlab_engine.eval(f'rng({seed})') 138 | 139 | test_config = getattr(configs, test_config).test_config 140 | config_file = f"results/{model_name}/config.json" 141 | with open(config_file) as f: 142 | data = json.load(f) 143 | model_config = configs.ModelConfig(**data['model_config']) 144 | run_config = configs.RunConfig(**data['run_config']) 145 | 146 | model = get_model(model_name, model_config, run_config, matlab_engine) 147 | 148 | for size in test_config.test_sizes: 149 | test_size(model_name, model, size, test_config, run_config, 150 | matlab_engine) 151 | 152 | 153 | if __name__ == '__main__': 154 | config = tf.ConfigProto() 155 | config.gpu_options.allow_growth = True 156 | tf.enable_eager_execution(config=config) 157 | 158 | fire.Fire(test_model) 159 | -------------------------------------------------------------------------------- /ruge_stuben_custom_solver.py: -------------------------------------------------------------------------------- 1 | """Classical AMG (Ruge-Stuben AMG).""" 2 | from __future__ import absolute_import 3 | 4 | import math 5 | from 
warnings import warn 6 | from scipy.sparse import csr_matrix, isspmatrix_csr, SparseEfficiencyWarning 7 | 8 | from pyamg.multilevel import multilevel_solver 9 | from pyamg.relaxation.smoothing import change_smoothers 10 | from pyamg.strength import classical_strength_of_connection, \ 11 | symmetric_strength_of_connection, evolution_strength_of_connection, \ 12 | distance_strength_of_connection, energy_based_strength_of_connection, \ 13 | algebraic_distance, affinity_distance 14 | 15 | from pyamg.classical.interpolate import direct_interpolation 16 | from pyamg.classical import split 17 | from pyamg.classical.cr import CR 18 | 19 | import numpy as np 20 | 21 | __all__ = ['ruge_stuben_custom_solver'] 22 | 23 | 24 | # similar to "ruge_stuben_solver" in pyamg, but with additional prolongation function parameter 25 | def ruge_stuben_custom_solver(A, prolongation_function, 26 | strength=('classical', {'theta': 0.25}), 27 | CF='RS', 28 | presmoother=('gauss_seidel', {'sweep': 'forward'}), 29 | postsmoother=('gauss_seidel', {'sweep': 'forward'}), 30 | max_levels=10, max_coarse=10, keep=False, **kwargs): 31 | """Create a multilevel solver using Classical AMG (Ruge-Stuben AMG). 32 | 33 | Parameters 34 | ---------- 35 | A : csr_matrix 36 | Square matrix in CSR format 37 | prolongation_function : function 38 | receives matrix A, splitting, and the baseline prolongation matrix P 39 | outputs prolongation matrix P 40 | strength : ['symmetric', 'classical', 'evolution', 'distance', 'algebraic_distance','affinity', 'energy_based', None] 41 | Method used to determine the strength of connection between unknowns 42 | of the linear system. Method-specific parameters may be passed in 43 | using a tuple, e.g. strength=('symmetric',{'theta' : 0.25 }). If 44 | strength=None, all nonzero entries of the matrix are considered strong. 45 | CF : string 46 | Method used for coarse grid selection (C/F splitting) 47 | Supported methods are RS, PMIS, PMISc, CLJP, CLJPc, and CR. 48 | presmoother : string or dict 49 | Method used for presmoothing at each level. Method-specific parameters 50 | may be passed in using a tuple, e.g. 51 | presmoother=('gauss_seidel',{'sweep':'symmetric}), the default. 52 | postsmoother : string or dict 53 | Postsmoothing method with the same usage as presmoother 54 | max_levels: integer 55 | Maximum number of levels to be used in the multilevel solver. 56 | max_coarse: integer 57 | Maximum number of variables permitted on the coarse grid. 58 | keep: bool 59 | Flag to indicate keeping extra operators in the hierarchy for 60 | diagnostics. For example, if True, then strength of connection (C) and 61 | tentative prolongation (T) are kept. 62 | 63 | Returns 64 | ------- 65 | ml : multilevel_solver 66 | Multigrid hierarchy of matrices and prolongation operators 67 | 68 | Examples 69 | -------- 70 | >>> from pyamg.gallery import poisson 71 | >>> from pyamg import ruge_stuben_solver 72 | >>> A = poisson((10,),format='csr') 73 | >>> ml = ruge_stuben_solver(A,max_coarse=3) 74 | 75 | Notes 76 | ----- 77 | "coarse_solver" is an optional argument and is the solver used at the 78 | coarsest grid. The default is a pseudo-inverse. Most simply, 79 | coarse_solver can be one of ['splu', 'lu', 'cholesky, 'pinv', 80 | 'gauss_seidel', ... ]. Additionally, coarse_solver may be a tuple 81 | (fn, args), where fn is a string such as ['splu', 'lu', ...] or a callable 82 | function, and args is a dictionary of arguments to be passed to fn. 83 | See [2001TrOoSc]_ for additional details. 
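    At each level, ``prolongation_function`` is called as
    ``prolongation_function(A, coarse_nodes, baseline_P, C)``, where ``coarse_nodes``
    are the indices selected by the C/F splitting, ``baseline_P`` is the
    direct-interpolation operator, and ``C`` is the strength-of-connection matrix;
    the matrix it returns is used as the prolongation operator in the Galerkin product.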
84 | 85 | 86 | References 87 | ---------- 88 | .. [2001TrOoSc] Trottenberg, U., Oosterlee, C. W., and Schuller, A., 89 | "Multigrid" San Diego: Academic Press, 2001. Appendix A 90 | 91 | See Also 92 | -------- 93 | aggregation.smoothed_aggregation_solver, multilevel_solver, 94 | aggregation.rootnode_solver 95 | 96 | """ 97 | levels = [multilevel_solver.level()] 98 | 99 | # convert A to csr 100 | if not isspmatrix_csr(A): 101 | try: 102 | A = csr_matrix(A) 103 | warn("Implicit conversion of A to CSR", 104 | SparseEfficiencyWarning) 105 | except BaseException: 106 | raise TypeError('Argument A must have type csr_matrix, \ 107 | or be convertible to csr_matrix') 108 | # preprocess A 109 | A = A.asfptype() 110 | if A.shape[0] != A.shape[1]: 111 | raise ValueError('expected square matrix') 112 | 113 | levels[-1].A = A 114 | 115 | while len(levels) < max_levels and levels[-1].A.shape[0] > max_coarse: 116 | extend_hierarchy(levels, strength, CF, keep, prolongation_function) 117 | 118 | ml = multilevel_solver(levels, **kwargs) 119 | change_smoothers(ml, presmoother, postsmoother) 120 | return ml 121 | 122 | 123 | # internal function 124 | def extend_hierarchy(levels, strength, CF, keep, prolongation_function): 125 | """Extend the multigrid hierarchy.""" 126 | 127 | def unpack_arg(v): 128 | if isinstance(v, tuple): 129 | return v[0], v[1] 130 | else: 131 | return v, {} 132 | 133 | A = levels[-1].A 134 | 135 | # Compute the strength-of-connection matrix C, where larger 136 | # C[i,j] denote stronger couplings between i and j. 137 | fn, kwargs = unpack_arg(strength) 138 | if fn == 'symmetric': 139 | C = symmetric_strength_of_connection(A, **kwargs) 140 | elif fn == 'classical': 141 | C = classical_strength_of_connection(A, **kwargs) 142 | elif fn == 'distance': 143 | C = distance_strength_of_connection(A, **kwargs) 144 | elif (fn == 'ode') or (fn == 'evolution'): 145 | C = evolution_strength_of_connection(A, **kwargs) 146 | elif fn == 'energy_based': 147 | C = energy_based_strength_of_connection(A, **kwargs) 148 | elif fn == 'algebraic_distance': 149 | C = algebraic_distance(A, **kwargs) 150 | elif fn == 'affinity': 151 | C = affinity_distance(A, **kwargs) 152 | elif fn is None: 153 | C = A 154 | else: 155 | raise ValueError('unrecognized strength of connection method: %s' % 156 | str(fn)) 157 | 158 | # Generate the C/F splitting 159 | fn, kwargs = unpack_arg(CF) 160 | if fn == 'RS': 161 | splitting = split.RS(C, **kwargs) 162 | elif fn == 'PMIS': 163 | splitting = split.PMIS(C, **kwargs) 164 | elif fn == 'PMISc': 165 | splitting = split.PMISc(C, **kwargs) 166 | elif fn == 'CLJP': 167 | splitting = split.CLJP(C, **kwargs) 168 | elif fn == 'CLJPc': 169 | splitting = split.CLJPc(C, **kwargs) 170 | elif fn == 'CR': 171 | splitting = CR(C, **kwargs) 172 | else: 173 | raise ValueError('unknown C/F splitting method (%s)' % CF) 174 | 175 | # Generate the interpolation matrix that maps from the coarse-grid to the 176 | # fine-grid 177 | baseline_P = direct_interpolation(A, C, splitting) 178 | coarse_nodes = np.nonzero(splitting)[0] 179 | 180 | # Create a prolongation matrix with the same coarse-fine splitting and sparsity pattern 181 | # as the baseline 182 | P = prolongation_function(A, coarse_nodes, baseline_P, C) 183 | 184 | # Generate the restriction matrix that maps from the fine-grid to the 185 | # coarse-grid 186 | R = P.T.tocsr() 187 | 188 | # Store relevant information for this level 189 | if keep: 190 | levels[-1].C = C # strength of connection matrix 191 | levels[-1].splitting = splitting # C/F 
splitting 192 | 193 | levels[-1].P = P # prolongation operator 194 | levels[-1].R = R # restriction operator 195 | 196 | levels.append(multilevel_solver.level()) 197 | 198 | # Form next level through Galerkin product 199 | A = R * A * P 200 | levels[-1].A = A 201 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | class DataConfig: 2 | def __init__(self, dist='lognormal_laplacian_periodic', block_periodic=True, 3 | num_unknowns=8 ** 2, root_num_blocks=4, splitting='CLJP', add_diag=False, 4 | load_data=True, save_data=False): 5 | self.dist = dist # see function 'generate_A()' for possible distributions 6 | self.block_periodic = block_periodic 7 | self.num_unknowns = num_unknowns 8 | self.root_num_blocks = root_num_blocks 9 | self.splitting = splitting 10 | self.add_diag = add_diag 11 | self.load_data = load_data 12 | self.save_data = save_data 13 | 14 | 15 | class ModelConfig: 16 | def __init__(self, mp_rounds=3, global_block=False, latent_size=64, mlp_layers=4, concat_encoder=True): 17 | self.mp_rounds = mp_rounds 18 | self.global_block = global_block 19 | self.latent_size = latent_size 20 | self.mlp_layers = mlp_layers 21 | self.concat_encoder = concat_encoder 22 | 23 | 24 | class RunConfig: 25 | def __init__(self, node_indicators=True, edge_indicators=True, normalize_rows=True, normalize_rows_by_node=False): 26 | self.node_indicators = node_indicators 27 | self.edge_indicators = edge_indicators 28 | self.normalize_rows = normalize_rows 29 | self.normalize_rows_by_node = normalize_rows_by_node 30 | 31 | 32 | class TestConfig: 33 | def __init__(self, dist='lognormal_laplacian_periodic', splitting='CLJP', 34 | test_sizes=(1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072), 35 | load_data=True, num_runs=100, cycle='W', 36 | max_levels=12, iterations=81, fp_threshold=1e-10, strength=('classical', {'theta': 0.25}), 37 | presmoother=('gauss_seidel', {'sweep': 'forward', 'iterations': 1}), 38 | postsmoother=('gauss_seidel', {'sweep': 'forward', 'iterations': 1}), 39 | coarse_solver='pinv2'): 40 | self.dist = dist 41 | self.splitting = splitting 42 | self.test_sizes = test_sizes 43 | self.load_data = load_data 44 | self.num_runs = num_runs 45 | self.cycle = cycle 46 | self.max_levels = max_levels 47 | self.iterations = iterations 48 | self.fp_threshold = fp_threshold 49 | self.strength = strength 50 | self.presmoother = presmoother 51 | self.postsmoother = postsmoother 52 | self.coarse_solver = coarse_solver 53 | # self.coarse_solver = ('gauss_seidel', {'iterations': 20}) 54 | 55 | 56 | class TrainConfig: 57 | def __init__(self, samples_per_run=256, num_runs=1000, batch_size=32, learning_rate=3e-3, fourier=True, 58 | coarsen=False, checkpoint_dir='./training_dir', tensorboard_dir='./tb_dir', load_model=False): 59 | self.samples_per_run = samples_per_run 60 | self.num_runs = num_runs 61 | self.batch_size = batch_size 62 | self.learning_rate = learning_rate 63 | self.fourier = fourier 64 | self.coarsen = coarsen 65 | self.checkpoint_dir = checkpoint_dir 66 | self.tensorboard_dir = tensorboard_dir 67 | self.load_model = load_model 68 | 69 | 70 | class Config: 71 | def __init__(self): 72 | self.data_config = DataConfig() 73 | self.model_config = ModelConfig() 74 | self.run_config = RunConfig() 75 | self.test_config = TestConfig() 76 | self.train_config = TrainConfig() 77 | 78 | 79 | GRAPH_LAPLACIAN_TEST = Config() 80 | 81 | COMPLEX_FEM_TEST = Config() 82 | 
COMPLEX_FEM_TEST.test_config.dist = 'lognormal_complex_fem' 83 | COMPLEX_FEM_TEST.test_config.fp_threshold = 0 84 | 85 | GRAPH_LAPLACIAN_RS_TEST = Config() 86 | GRAPH_LAPLACIAN_RS_TEST.test_config.splitting = 'RS' 87 | 88 | GRAPH_LAPLACIAN_RS_SECOND_PASS_TEST = Config() 89 | GRAPH_LAPLACIAN_RS_SECOND_PASS_TEST.test_config.splitting = ('RS', {'second_pass': True}) 90 | 91 | GRAPH_LAPLACIAN_PMIS_TEST = Config() 92 | GRAPH_LAPLACIAN_PMIS_TEST.test_config.splitting = 'PMIS' 93 | 94 | GRAPH_LAPLACIAN_PMISc_TEST = Config() 95 | GRAPH_LAPLACIAN_PMISc_TEST.test_config.splitting = 'PMISc' 96 | 97 | GRAPH_LAPLACIAN_CLJPc_TEST = Config() 98 | GRAPH_LAPLACIAN_CLJPc_TEST.test_config.splitting = 'CLJPc' 99 | 100 | GRAPH_LAPLACIAN_SA_TEST = Config() 101 | GRAPH_LAPLACIAN_SA_TEST.test_config.splitting = 'SA' 102 | 103 | GRAPH_LAPLACIAN_ROOTNODE_TEST = Config() 104 | GRAPH_LAPLACIAN_ROOTNODE_TEST.test_config.splitting = 'rootnode' 105 | 106 | 107 | GRAPH_LAPLACIAN_TRAIN = Config() 108 | GRAPH_LAPLACIAN_TRAIN.data_config.dist = 'lognormal_laplacian_periodic' 109 | 110 | GRAPH_LAPLACIAN_ABLATION_MP2 = Config() 111 | GRAPH_LAPLACIAN_ABLATION_MP2.data_config.dist = 'lognormal_laplacian_periodic' 112 | GRAPH_LAPLACIAN_ABLATION_MP2.model_config.mp_rounds = 2 113 | 114 | GRAPH_LAPLACIAN_ABLATION_MLP2 = Config() 115 | GRAPH_LAPLACIAN_ABLATION_MLP2.data_config.dist = 'lognormal_laplacian_periodic' 116 | GRAPH_LAPLACIAN_ABLATION_MLP2.model_config.mlp_layers = 2 117 | 118 | GRAPH_LAPLACIAN_ABLATION_NO_CONCAT = Config() 119 | GRAPH_LAPLACIAN_ABLATION_NO_CONCAT.data_config.dist = 'lognormal_laplacian_periodic' 120 | GRAPH_LAPLACIAN_ABLATION_NO_CONCAT.model_config.concat_encoder = False 121 | 122 | GRAPH_LAPLACIAN_ABLATION_NO_INDICATORS = Config() 123 | GRAPH_LAPLACIAN_ABLATION_NO_INDICATORS.data_config.dist = 'lognormal_laplacian_periodic' 124 | GRAPH_LAPLACIAN_ABLATION_NO_INDICATORS.run_config.node_indicators = False 125 | GRAPH_LAPLACIAN_ABLATION_NO_INDICATORS.run_config.edge_indicators = False 126 | 127 | GRAPH_LAPLACIAN_EVAL = Config() 128 | GRAPH_LAPLACIAN_EVAL.data_config.block_periodic = False 129 | GRAPH_LAPLACIAN_EVAL.data_config.num_unknowns = 4096 130 | GRAPH_LAPLACIAN_EVAL.data_config.dist = 'lognormal_laplacian' 131 | GRAPH_LAPLACIAN_EVAL.data_config.load_data = False 132 | GRAPH_LAPLACIAN_EVAL.train_config.fourier = False 133 | 134 | SPEC_CLUSTERING_TRAIN = Config() 135 | SPEC_CLUSTERING_TRAIN.data_config.dist = 'spectral_clustering' 136 | SPEC_CLUSTERING_TRAIN.data_config.num_unknowns = 1024 137 | SPEC_CLUSTERING_TRAIN.data_config.block_periodic = False 138 | SPEC_CLUSTERING_TRAIN.data_config.add_diag = True 139 | SPEC_CLUSTERING_TRAIN.train_config.coarsen = False 140 | SPEC_CLUSTERING_TRAIN.train_config.fourier = False 141 | 142 | SPEC_CLUSTERING_EVAL = Config() 143 | SPEC_CLUSTERING_EVAL.data_config.dist = 'spectral_clustering' 144 | SPEC_CLUSTERING_EVAL.data_config.num_unknowns = 4096 145 | SPEC_CLUSTERING_EVAL.data_config.block_periodic = False 146 | SPEC_CLUSTERING_EVAL.data_config.add_diag = True 147 | SPEC_CLUSTERING_EVAL.train_config.coarsen = False 148 | SPEC_CLUSTERING_EVAL.train_config.fourier = False 149 | 150 | 151 | 152 | GRAPH_LAPLACIAN_UNIFORM_TEST = Config() 153 | GRAPH_LAPLACIAN_UNIFORM_TEST.data_config.block_periodic = False 154 | GRAPH_LAPLACIAN_UNIFORM_TEST.data_config.dist = 'uniform_laplacian' 155 | 156 | FINITE_ELEMENT_TEST = Config() 157 | FINITE_ELEMENT_TEST.data_config.block_periodic = False 158 | FINITE_ELEMENT_TEST.data_config.dist = 'finite_element' 159 | 160 | # should 
replicate results from "Compatible Relaxation and Coarsening in Algebraic Multigrid" (2009) 161 | CR_TEST = Config() 162 | CR_TEST.data_config.splitting = ('CR', {'verbose': True, 163 | 'method': 'habituated', 164 | 'nu': 2, 165 | 'thetacr': 0.5, 166 | 'thetacs': [0.3 ** 2, 0.5], 167 | 'maxiter': 20}) 168 | # CR_TEST.data_config.dist = 'poisson' 169 | # CR_TEST.data_config.dist = 'aniso' 170 | CR_TEST.data_config.dist = 'lognormal_laplacian' 171 | # CR_TEST.data_config.dist = 'example' 172 | CR_TEST.test_config.num_runs = 10 173 | CR_TEST.test_config.test_sizes = (1024, 2048, 4096, 8192,) 174 | # CR_TEST.test_config.test_sizes = ('airfoil', 'bar', 'knot', 'local_disc_galerkin_diffusion', 175 | # 'recirc_flow', 'unit_cube', 'unit_square') 176 | # CR_TEST.test_config.fp_threshold = 0 177 | # CR_TEST.test_config.coarse_solver = ('gauss_seidel', {'iterations': 200}) 178 | 179 | CR_TEST.test_config.presmoother = ('gauss_seidel', {'sweep': 'forward', 'iterations': 1}) 180 | # CR_TEST.test_config.postsmoother = ('gauss_seidel', {'sweep': 'backward', 'iterations': 1}) 181 | CR_TEST.test_config.coarse_solver = 'pinv2' 182 | CR_TEST.test_config.cycle = 'V' 183 | CR_TEST.test_config.iterations = 40 184 | CR_TEST.test_config.max_levels = 2 185 | -------------------------------------------------------------------------------- /cr_solver.py: -------------------------------------------------------------------------------- 1 | from warnings import warn 2 | 3 | import numpy as np 4 | import scipy.sparse as sparse 5 | from pyamg.classical import split 6 | from pyamg.classical.cr import CR 7 | from pyamg.multilevel import multilevel_solver 8 | from pyamg.relaxation.smoothing import change_smoothers 9 | from pyamg.classical.interpolate import distance_two_interpolation, direct_interpolation, standard_interpolation 10 | from pyamg.strength import classical_strength_of_connection 11 | from scipy.sparse import csr_matrix, isspmatrix_csr, SparseEfficiencyWarning 12 | 13 | 14 | def cr_solver(A, 15 | CF='CR', l=40, maxp=40000, theta_a=0.25*0, 16 | presmoother=('gauss_seidel', {'sweep': 'symmetric'}), 17 | postsmoother=('gauss_seidel', {'sweep': 'symmetric'}), 18 | max_levels=10, max_coarse=10, keep=False, **kwargs): 19 | """Create a multilevel solver using Classical AMG (Ruge-Stuben AMG). 20 | 21 | Parameters 22 | ---------- 23 | A : csr_matrix 24 | Square matrix in CSR format 25 | CF : string 26 | Method used for coarse grid selection (C/F splitting) 27 | Supported methods are RS, PMIS, PMISc, CLJP, CLJPc, and CR. 28 | presmoother : string or dict 29 | Method used for presmoothing at each level. Method-specific parameters 30 | may be passed in using a tuple, e.g. 31 | presmoother=('gauss_seidel',{'sweep':'symmetric}), the default. 32 | postsmoother : string or dict 33 | Postsmoothing method with the same usage as presmoother 34 | max_levels: integer 35 | Maximum number of levels to be used in the multilevel solver. 36 | max_coarse: integer 37 | Maximum number of variables permitted on the coarse grid. 38 | keep: bool 39 | Flag to indicate keeping extra operators in the hierarchy for 40 | diagnostics. For example, if True, then strength of connection (C) and 41 | tentative prolongation (T) are kept. 
42 | 43 | Returns 44 | ------- 45 | ml : multilevel_solver 46 | Multigrid hierarchy of matrices and prolongation operators 47 | 48 | Examples 49 | -------- 50 | >>> from pyamg.gallery import poisson 51 | >>> from pyamg import ruge_stuben_solver 52 | >>> A = poisson((10,),format='csr') 53 | >>> ml = ruge_stuben_solver(A,max_coarse=3) 54 | 55 | Notes 56 | ----- 57 | "coarse_solver" is an optional argument and is the solver used at the 58 | coarsest grid. The default is a pseudo-inverse. Most simply, 59 | coarse_solver can be one of ['splu', 'lu', 'cholesky, 'pinv', 60 | 'gauss_seidel', ... ]. Additionally, coarse_solver may be a tuple 61 | (fn, args), where fn is a string such as ['splu', 'lu', ...] or a callable 62 | function, and args is a dictionary of arguments to be passed to fn. 63 | See [2001TrOoSc]_ for additional details. 64 | 65 | 66 | References 67 | ---------- 68 | .. [2001TrOoSc] Trottenberg, U., Oosterlee, C. W., and Schuller, A., 69 | "Multigrid" San Diego: Academic Press, 2001. Appendix A 70 | 71 | See Also 72 | -------- 73 | aggregation.smoothed_aggregation_solver, multilevel_solver, 74 | aggregation.rootnode_solver 75 | 76 | """ 77 | levels = [multilevel_solver.level()] 78 | 79 | # convert A to csr 80 | if not isspmatrix_csr(A): 81 | try: 82 | A = csr_matrix(A) 83 | warn("Implicit conversion of A to CSR", 84 | SparseEfficiencyWarning) 85 | except BaseException: 86 | raise TypeError('Argument A must have type csr_matrix, \ 87 | or be convertible to csr_matrix') 88 | # preprocess A 89 | A = A.asfptype() 90 | if A.shape[0] != A.shape[1]: 91 | raise ValueError('expected square matrix') 92 | 93 | levels[-1].A = A 94 | 95 | while len(levels) < max_levels and levels[-1].A.shape[0] > max_coarse: 96 | extend_hierarchy(levels, CF, l, maxp, theta_a, keep) 97 | 98 | ml = multilevel_solver(levels, **kwargs) 99 | change_smoothers(ml, presmoother, postsmoother) 100 | return ml 101 | 102 | 103 | def extend_hierarchy(levels, CF, l, maxp, theta_a, keep): 104 | """Extend the multigrid hierarchy.""" 105 | 106 | def unpack_arg(v): 107 | if isinstance(v, tuple): 108 | return v[0], v[1] 109 | else: 110 | return v, {} 111 | 112 | A = levels[-1].A 113 | 114 | # Generate the C/F splitting 115 | fn, kwargs = unpack_arg(CF) 116 | if fn == 'CR': 117 | splitting = CR(A, **kwargs) 118 | else: 119 | raise ValueError('unknown C/F splitting method (%s)' % CF) 120 | 121 | # rs_C = classical_strength_of_connection(A, theta=0.25) 122 | # rs_splitting = split.RS(rs_C) 123 | # rs_P = direct_interpolation(A.copy(), rs_C.copy(), rs_splitting.copy()) 124 | # 125 | # rs_P_sparsity = rs_P.copy() 126 | # rs_P_sparsity.data[:] = 1 127 | # 128 | # rs_fine = np.where(rs_splitting == 0)[0] 129 | # rs_coarse = np.where(rs_splitting == 1)[0] 130 | # rs_A_fc = A[rs_fine][:, rs_coarse] 131 | # rs_W = rs_P[rs_fine] 132 | # my_rs_P, my_rs_W = my_direct_interpolation(rs_A_fc, A, rs_W, rs_coarse, rs_fine) 133 | # 134 | # my_rs_P_sparsity = my_rs_P.copy() 135 | # my_rs_P_sparsity.data[:] = 1 136 | # 137 | # rs_A_sparsity = A[:, rs_coarse].copy() 138 | # rs_A_sparsity.data[:] = 1 139 | 140 | # Generate the interpolation matrix that maps from the coarse-grid to the 141 | # fine-grid 142 | P = truncation_interpolation(A, splitting, l, maxp, theta_a) 143 | # P = optimal_interpolation(A, splitting) 144 | # P = rs_P 145 | 146 | # Generate the restriction matrix that maps from the fine-grid to the 147 | # coarse-grid 148 | R = P.T.tocsr() 149 | 150 | # Store relevant information for this level 151 | if keep: 152 | levels[-1].splitting = 
splitting # C/F splitting 153 | 154 | levels[-1].P = P # prolongation operator 155 | levels[-1].R = R # restriction operator 156 | 157 | levels.append(multilevel_solver.level()) 158 | 159 | # Form next level through Galerkin product 160 | A = R * A * P 161 | levels[-1].A = A 162 | 163 | 164 | def optimal_interpolation(A, splitting): 165 | fine = np.where(splitting == 0)[0] 166 | coarse = np.where(splitting == 1)[0] 167 | 168 | A_ff = A[fine][:, fine] 169 | A_fc = A[fine][:, coarse] 170 | 171 | W = -sparse.linalg.inv(A_ff) @ A_fc 172 | 173 | np_W = W.toarray() 174 | P = np.zeros(A.shape) 175 | for i in range(A_fc.shape[1]): 176 | P[fine, coarse[i]] = np_W[:, i] 177 | 178 | np.fill_diagonal(P, 1) 179 | P = P[:, coarse] 180 | P = csr_matrix(P) 181 | return P 182 | 183 | 184 | def my_direct_interpolation(A_fc, A, sparsity, coarse, fine): 185 | sparsity = sparsity.copy() 186 | sparsity.data[:] = 1.0 187 | sparsity = sparsity.multiply(A_fc) 188 | 189 | A_zerodiag = A - sparse.diags(A.diagonal()) 190 | # A_zerodiag = A 191 | A_rowsums = np.array(A_zerodiag.sum(axis=1))[:, 0] 192 | sparsity_rowsums = np.array(sparsity.sum(axis=1))[:, 0] 193 | 194 | W = -A_fc.multiply(A_rowsums[fine, None]) / A.diagonal()[fine, None] / sparsity_rowsums[:, None] 195 | 196 | np_W = np.array(W) 197 | n = A_fc.shape[0] + A_fc.shape[1] 198 | P_square = np.zeros((n, n)) 199 | for i in range(W.shape[1]): 200 | P_square[fine, coarse[i]] = np_W[:, i] 201 | 202 | np.fill_diagonal(P_square, 1) 203 | P_square = csr_matrix(P_square) 204 | 205 | P = P_square[:, coarse] 206 | return P, csr_matrix(W) 207 | 208 | 209 | # from "Compatible Relaxation and Coarsening in Algebraic Multigrid" (2009) 210 | def truncation_interpolation(A, splitting, l, maxp, theta_a): 211 | # eq. 3.2 212 | fine = np.where(splitting == 0)[0] 213 | coarse = np.where(splitting == 1)[0] 214 | 215 | A_ff = A[fine][:, fine] 216 | A_fc = A[fine][:, coarse] 217 | 218 | D_ff = sparse.diags(A_ff.diagonal()).tocsr() 219 | D_ffinv = D_ff.power(-1) 220 | 221 | # eq. 4.8 222 | omega = 1 / gershgorin_bound(D_ffinv @ A_ff) 223 | 224 | # eq. 4.7 225 | W = -weighted_jacobi_cr(omega, D_ffinv, A_fc, A_ff, l) 226 | 227 | # W_star = -sparse.linalg.inv(A_ff) @ A_fc 228 | 229 | W = keep_largest_per_row(W, maxp) 230 | # W = keep_largest_per_row(W_star, maxp) 231 | 232 | # eq. 4.9 233 | sparsity = keep_thres_per_row(W, theta_a) 234 | # sparsity = W_star 235 | sparsity.data = np.ones_like(sparsity.data) 236 | 237 | # my_P = my_direct_interpolation(A_fc, A, sparsity, coarse, fine) 238 | 239 | # eq. 
4.10 240 | # TODO: implement more efficient indexing by passing to matlab 241 | np_sparsity = sparsity.toarray() 242 | P_sparsity = np.zeros(A.shape) 243 | for i in range(sparsity.shape[1]): 244 | P_sparsity[fine, coarse[i]] = np_sparsity[:, i] 245 | 246 | np.fill_diagonal(P_sparsity, 0) 247 | P_sparsity = csr_matrix(P_sparsity) 248 | 249 | # P = distance_two_interpolation(A.copy(), P_sparsity.copy(), splitting.copy()) 250 | P = direct_interpolation(A.copy(), P_sparsity.copy(), splitting.copy()) 251 | # P = standard_interpolation(A.copy(), P_sparsity.copy(), splitting.copy()) 252 | 253 | return P 254 | # return my_P 255 | 256 | 257 | def gershgorin_bound(M): 258 | return abs(M).sum(axis=1).max() 259 | 260 | 261 | def weighted_jacobi_cr(omega, D_ffinv, A_fc, A_ff, l): 262 | W = csr_matrix((A_fc.shape[0], A_fc.shape[1])) 263 | for _ in range(l): 264 | W = W + omega * D_ffinv @ (A_fc - A_ff @ W) 265 | W.eliminate_zeros() 266 | return W 267 | 268 | 269 | def keep_largest_per_row(M, maxp): 270 | nrows = M.shape[0] 271 | for i in range(nrows): 272 | # Get the row slice, not a copy, only the non zero elements 273 | row_array = M.data[M.indptr[i]: M.indptr[i + 1]] 274 | if row_array.shape[0] <= maxp: 275 | # Not more than maxp elements 276 | continue 277 | 278 | # only take the maxp last elements in the sorted indices 279 | row_array[np.argsort(row_array)[:-maxp]] = 0 280 | M.eliminate_zeros() 281 | return M 282 | 283 | 284 | def keep_thres_per_row(M, theta_a): 285 | nrows = M.shape[0] 286 | for i in range(nrows): 287 | # Get the row slice, not a copy, only the non zero elements 288 | row_array = M.data[M.indptr[i]: M.indptr[i + 1]] 289 | 290 | threshold = theta_a * max(abs(row_array)) 291 | row_array[np.where(abs(row_array) <= threshold)] = 0 292 | M.eliminate_zeros() 293 | return M 294 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import graph_nets as gn 2 | import matlab 3 | import numpy as np 4 | import tensorflow as tf 5 | from scipy.sparse import csr_matrix 6 | 7 | from data import As_poisson_grid 8 | from graph_net_model import EncodeProcessDecodeNonRecurrent 9 | 10 | 11 | def get_model(model_name, model_config, run_config, matlab_engine, train=False, train_config=None): 12 | dummy_input = As_poisson_grid(1, 7 ** 2)[0] 13 | checkpoint_dir = './training_dir/' + model_name 14 | graph_model, optimizer, global_step = load_model(checkpoint_dir, dummy_input, model_config, 15 | run_config, 16 | matlab_engine, get_optimizer=train, 17 | train_config=train_config) 18 | if train: 19 | return graph_model, optimizer, global_step 20 | else: 21 | return graph_model 22 | 23 | 24 | def load_model(checkpoint_dir, dummy_input, model_config, run_config, matlab_engine, get_optimizer=True, 25 | train_config=None): 26 | tf.enable_eager_execution() 27 | model = create_model(model_config) 28 | 29 | # we have to use the model at least once to get the list of variables 30 | model(csrs_to_graphs_tuple([dummy_input], matlab_engine, coarse_nodes_list=np.array([[0, 1]]), 31 | baseline_P_list=[tf.convert_to_tensor(dummy_input.toarray()[:, [0, 1]])], 32 | node_indicators=run_config.node_indicators, 33 | edge_indicators=run_config.edge_indicators)) 34 | 35 | variables = model.get_all_variables() 36 | variables_dict = {variable.name: variable for variable in variables} 37 | if get_optimizer: 38 | global_step = tf.train.get_or_create_global_step() 39 | decay_steps = 100 40 | decay_rate = 1.0 41 
| learning_rate = tf.train.exponential_decay(train_config.learning_rate, global_step, decay_steps, decay_rate) 42 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 43 | 44 | checkpoint = tf.train.Checkpoint(**variables_dict, optimizer=optimizer, global_step=global_step) 45 | else: 46 | optimizer = None 47 | global_step = None 48 | checkpoint = tf.train.Checkpoint(**variables_dict) 49 | latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir) 50 | if latest_checkpoint is None: 51 | raise RuntimeError(f'training_dir {checkpoint_dir} does not exist') 52 | checkpoint.restore(latest_checkpoint) 53 | return model, optimizer, global_step 54 | 55 | 56 | def create_model(model_config): 57 | with tf.device('/gpu:0'): 58 | return EncodeProcessDecodeNonRecurrent(num_cores=model_config.mp_rounds, edge_output_size=1, 59 | node_output_size=1, global_block=model_config.global_block, 60 | latent_size=model_config.latent_size, 61 | num_layers=model_config.mlp_layers, 62 | concat_encoder=model_config.concat_encoder) 63 | 64 | 65 | def csrs_to_graphs_tuple(csrs, matlab_engine, node_feature_size=128, coarse_nodes_list=None, baseline_P_list=None, 66 | node_indicators=True, edge_indicators=True): 67 | dtype = tf.float64 68 | 69 | # build up the arguments for the GraphsTuple constructor 70 | n_node = tf.convert_to_tensor([csr.shape[0] for csr in csrs]) 71 | n_edge = tf.convert_to_tensor([csr.nnz for csr in csrs]) 72 | 73 | if not edge_indicators: 74 | numpy_edges = np.concatenate([csr.data for csr in csrs]) 75 | edges = tf.expand_dims(tf.convert_to_tensor(numpy_edges, dtype=dtype), axis=1) 76 | else: 77 | edge_encodings_list = [] 78 | for csr, coarse_nodes, baseline_P in zip(csrs, coarse_nodes_list, baseline_P_list): 79 | if tf.is_tensor(baseline_P): 80 | baseline_P = csr_matrix(baseline_P.numpy()) 81 | 82 | baseline_P_rows, baseline_P_cols = P_square_sparsity_pattern(baseline_P, baseline_P.shape[0], 83 | coarse_nodes, matlab_engine) 84 | coo = csr.tocoo() 85 | 86 | # construct numpy structured arrays, where each element is a tuple (row,col), so that we can later use 87 | # the numpy set function in1d() 88 | baseline_P_indices = np.core.records.fromarrays([baseline_P_rows, baseline_P_cols], dtype='i,i') 89 | coo_indices = np.core.records.fromarrays([coo.row, coo.col], dtype='i,i') 90 | 91 | same_indices = np.in1d(coo_indices, baseline_P_indices, assume_unique=True) 92 | baseline_edges = same_indices.astype(np.float64) 93 | non_baseline_edges = (~same_indices).astype(np.float64) 94 | 95 | edge_encodings = np.stack([coo.data, baseline_edges, non_baseline_edges]).T 96 | edge_encodings_list.append(edge_encodings) 97 | numpy_edges = np.concatenate(edge_encodings_list) 98 | edges = tf.convert_to_tensor(numpy_edges, dtype=dtype) 99 | 100 | # COO format for sparse matrices contains a list of row indices and a list of column indices 101 | coos = [csr.tocoo() for csr in csrs] 102 | senders_numpy = np.concatenate([coo.row for coo in coos]) 103 | senders = tf.convert_to_tensor(senders_numpy) 104 | receivers_numpy = np.concatenate([coo.col for coo in coos]) 105 | receivers = tf.convert_to_tensor(receivers_numpy) 106 | 107 | # see the source of _concatenate_data_dicts for explanation 108 | offsets = gn.utils_tf._compute_stacked_offsets(n_node, n_edge) 109 | senders += offsets 110 | receivers += offsets 111 | 112 | if not node_indicators: 113 | nodes = None 114 | else: 115 | node_encodings_list = [] 116 | for csr, coarse_nodes in zip(csrs, coarse_nodes_list): 117 | coarse_indices = 
np.in1d(range(csr.shape[0]), coarse_nodes, assume_unique=True) 118 | 119 | coarse_node_encodings = coarse_indices.astype(np.float64) 120 | fine_node_encodings = (~coarse_indices).astype(np.float64) 121 | node_encodings = np.stack([coarse_node_encodings, fine_node_encodings]).T 122 | 123 | node_encodings_list.append(node_encodings) 124 | 125 | numpy_nodes = np.concatenate(node_encodings_list) 126 | nodes = tf.convert_to_tensor(numpy_nodes, dtype=dtype) 127 | 128 | graphs_tuple = gn.graphs.GraphsTuple( 129 | nodes=nodes, 130 | edges=edges, 131 | globals=None, 132 | receivers=receivers, 133 | senders=senders, 134 | n_node=n_node, 135 | n_edge=n_edge 136 | ) 137 | if not node_indicators: 138 | graphs_tuple = gn.utils_tf.set_zero_node_features(graphs_tuple, 1, dtype=dtype) 139 | 140 | graphs_tuple = gn.utils_tf.set_zero_global_features(graphs_tuple, node_feature_size, dtype=dtype) 141 | 142 | return graphs_tuple 143 | 144 | 145 | def P_square_sparsity_pattern(P, size, coarse_nodes, matlab_engine): 146 | P_coo = P.tocoo() 147 | P_rows = matlab.double((P_coo.row + 1)) 148 | P_cols = matlab.double((P_coo.col + 1)) 149 | P_values = matlab.double(P_coo.data) 150 | coarse_nodes = matlab.double((coarse_nodes + 1)) 151 | rows, cols = matlab_engine.square_P(P_rows, P_cols, P_values, size, coarse_nodes, nargout=2) 152 | rows = np.array(rows._data).reshape(rows.size, order='F') - 1 153 | cols = np.array(cols._data).reshape(cols.size, order='F') - 1 154 | rows, cols = rows.T[0], cols.T[0] 155 | return rows, cols 156 | 157 | 158 | def graphs_tuple_to_sparse_tensor(graphs_tuple): 159 | senders = graphs_tuple.senders 160 | receivers = graphs_tuple.receivers 161 | indices = tf.cast(tf.stack([senders, receivers], axis=1), tf.int64) 162 | 163 | # first element in the edge feature is the value, the other elements are metadata 164 | values = tf.squeeze(graphs_tuple.edges[:, 0]) 165 | 166 | shape = tf.concat([graphs_tuple.n_node, graphs_tuple.n_node], axis=0) 167 | shape = tf.cast(shape, tf.int64) 168 | 169 | matrix = tf.sparse.SparseTensor(indices, values, shape) 170 | # reordering is required because the pyAMG coarsening step does not preserve indices order 171 | matrix = tf.sparse.reorder(matrix) 172 | 173 | return matrix 174 | 175 | 176 | def to_prolongation_matrix_csr(matrix, coarse_nodes, baseline_P, nodes, normalize_rows=True, 177 | normalize_rows_by_node=False): 178 | """ 179 | sparse version of the above function, for when the dense matrix is too large to fit in GPU memory 180 | used only for inference, so no need for backpropagation, inputs are csr matrices 181 | """ 182 | # prolongation from coarse point to itself should be identity. 
This corresponds to 1's on the diagonal 183 | matrix.setdiag(np.ones(matrix.shape[0])) 184 | 185 | # select only columns corresponding to coarse nodes 186 | matrix = matrix[:, coarse_nodes] 187 | 188 | # set sparsity pattern (interpolatory sets) to be of baseline prolongation 189 | baseline_P_mask = (baseline_P != 0).astype(np.float64) 190 | matrix = matrix.multiply(baseline_P_mask) 191 | matrix.eliminate_zeros() 192 | 193 | if normalize_rows: 194 | if normalize_rows_by_node: 195 | baseline_row_sum = nodes 196 | else: 197 | baseline_row_sum = baseline_P.sum(axis=1) 198 | baseline_row_sum = np.array(baseline_row_sum)[:, 0] 199 | 200 | matrix_row_sum = np.array(matrix.sum(axis=1))[:, 0] 201 | # https://stackoverflow.com/a/12238133 202 | matrix_copy = matrix.copy() 203 | matrix_copy.data /= matrix_row_sum.repeat(np.diff(matrix_copy.indptr)) 204 | matrix_copy.data *= baseline_row_sum.repeat(np.diff(matrix_copy.indptr)) 205 | matrix = matrix_copy 206 | return matrix 207 | 208 | 209 | def to_prolongation_matrix_tensor(matrix, coarse_nodes, baseline_P, nodes, 210 | normalize_rows=True, 211 | normalize_rows_by_node=False): 212 | dtype = tf.float64 213 | matrix = tf.cast(matrix, dtype) 214 | matrix = tf.sparse.to_dense(matrix) 215 | 216 | # prolongation from coarse point to itself should be identity. This corresponds to 1's on the diagonal 217 | matrix = tf.linalg.set_diag(matrix, tf.ones(matrix.shape[0], dtype=dtype)) 218 | 219 | # select only columns corresponding to coarse nodes 220 | matrix = tf.gather(matrix, coarse_nodes, axis=1) 221 | 222 | # set sparsity pattern (interpolatory sets) to be of baseline prolongation 223 | baseline_zero_mask = tf.cast(tf.not_equal(baseline_P, tf.zeros_like(baseline_P)), dtype) 224 | matrix = matrix * baseline_zero_mask 225 | 226 | if normalize_rows: 227 | if normalize_rows_by_node: 228 | baseline_row_sum = nodes 229 | else: 230 | baseline_row_sum = tf.reduce_sum(baseline_P, axis=1) 231 | baseline_row_sum = tf.cast(baseline_row_sum, dtype) 232 | 233 | matrix_row_sum = tf.reduce_sum(matrix, axis=1) 234 | matrix_row_sum = tf.cast(matrix_row_sum, dtype) 235 | 236 | # there might be a few rows that are all 0's - corresponding to fine points that are not connected to any 237 | # coarse point. 
We use "divide_no_nan" to let these rows remain 0's 238 | matrix = tf.math.divide_no_nan(matrix, tf.reshape(matrix_row_sum, (-1, 1))) 239 | 240 | matrix = matrix * tf.reshape(baseline_row_sum, (-1, 1)) 241 | return matrix 242 | 243 | 244 | def graphs_tuple_to_sparse_matrices(graphs_tuple, return_nodes=False): 245 | num_graphs = int(graphs_tuple.n_node.shape[0]) 246 | graphs = [gn.utils_tf.get_graph(graphs_tuple, i) 247 | for i in range(num_graphs)] 248 | 249 | matrices = [graphs_tuple_to_sparse_tensor(graph) for graph in graphs] 250 | 251 | if return_nodes: 252 | nodes_list = [tf.squeeze(graph.nodes) for graph in graphs] 253 | return matrices, nodes_list 254 | else: 255 | return matrices 256 | 257 | 258 | -------------------------------------------------------------------------------- /multigrid_utils.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import matlab.engine 4 | import numpy as np 5 | import pyamg 6 | import scipy.linalg 7 | import tensorflow as tf 8 | from pyamg.classical import direct_interpolation 9 | from scipy.sparse import csr_matrix 10 | 11 | from utils import chunks, most_frequent_splitting 12 | 13 | 14 | def frob_norm(a, power=1): 15 | if power == 1: 16 | return tf.norm(a, axis=[-2, -1]) 17 | else: 18 | curr_power = a 19 | for i in range(power - 1): 20 | curr_power = a @ curr_power 21 | return tf.norm(curr_power, axis=[-2, -1]) ** (1 / power) 22 | 23 | 24 | def compute_coarse_A(R, A, P): 25 | return R @ A @ P 26 | 27 | 28 | def compute_coarse_As(padded_Rs, padded_As, padded_Ps): 29 | RAs = padded_Rs @ padded_As 30 | RAPs = RAs @ padded_Ps 31 | return RAPs 32 | 33 | 34 | def compute_Cs(padded_As, Is, padded_Ps, padded_Rs, coarse_As_inv): 35 | RAs = padded_Rs @ padded_As 36 | coarse_A_inv_RAs = coarse_As_inv @ RAs 37 | P_coarse_A_inv_RAs = padded_Ps @ coarse_A_inv_RAs 38 | Cs = Is - P_coarse_A_inv_RAs 39 | return Cs 40 | 41 | 42 | def two_grid_error_matrices(padded_As, padded_Ps, padded_Rs, padded_Ss): 43 | batch_size = padded_As.shape[0].value 44 | padded_length = padded_As.shape[1].value 45 | Is = tf.eye(padded_length, batch_shape=[batch_size], dtype=padded_As.dtype) 46 | coarse_As = compute_coarse_As(padded_Rs, padded_As, padded_Ps) 47 | coarse_As_inv = tf.linalg.inv(coarse_As) 48 | Cs = compute_Cs(padded_As, Is, padded_Ps, padded_Rs, coarse_As_inv) 49 | Ms = padded_Ss @ Cs @ padded_Ss 50 | return Ms 51 | 52 | 53 | def two_grid_error_matrix(A, P, R, S): 54 | I = tf.eye(A.shape[0].value, dtype=A.dtype) 55 | coarse_A = compute_coarse_A(R, A, P) 56 | coarse_A_inv = tf.linalg.inv(coarse_A) 57 | C = compute_C(A, I, P, R, coarse_A_inv) 58 | M = S @ C @ S 59 | return M 60 | 61 | 62 | def compute_C(A, I, P, R, coarse_A_inv): 63 | RA = R @ A 64 | coarse_A_inv_RA = coarse_A_inv @ RA 65 | P_coarse_A_inv_RA = P @ coarse_A_inv_RA 66 | C = I - P_coarse_A_inv_RA 67 | return C 68 | 69 | 70 | def block_diag_multiply(W_conj_t, As, W): 71 | return W_conj_t @ As @ W 72 | 73 | 74 | def extract_diag_blocks(block_diag_As, block_size, root_num_blocks, single_matrix=False): 75 | """extracts the block matrices on the diagonal""" 76 | if single_matrix: 77 | return [ 78 | block_diag_As[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size] 79 | for i in range(root_num_blocks ** 2)] 80 | else: 81 | return [ 82 | block_diag_As[:, i * block_size:(i + 1) * block_size, i * block_size:(i + 1) * block_size] 83 | for i in range(root_num_blocks ** 2)] 84 | 85 | 86 | def block_diagonalize_A_fast(As, 
root_num_blocks, tensor=False): 87 | """Returns root_num_blocks**2 matrices that represent the block diagonalization of A""" 88 | if tensor: 89 | total_size = As.shape[1].value 90 | else: 91 | total_size = As.shape[1] 92 | block_size = total_size // root_num_blocks 93 | 94 | double_W, double_W_conj_t = create_double_W(block_size, root_num_blocks, tensor) 95 | block_diag_A = block_diag_multiply(double_W_conj_t, As, double_W) 96 | 97 | small_block_size = block_size // root_num_blocks 98 | blocks = extract_diag_blocks(block_diag_A, small_block_size, root_num_blocks) 99 | 100 | if tensor: 101 | return tf.stack(blocks, axis=1) 102 | else: 103 | return [csr_matrix(block) for block_list in blocks for block in block_list] 104 | 105 | 106 | def block_diagonalize_A_single(A, root_num_blocks, tensor=False): 107 | """Returns root_num_blocks**2 matrices that represent the block diagonalization of A""" 108 | if tensor: 109 | total_size = A.shape[0].value 110 | else: 111 | total_size = A.shape[0] 112 | block_size = total_size // root_num_blocks 113 | 114 | double_W, double_W_conj_t = create_double_W(block_size, root_num_blocks, tensor) 115 | block_diag_A = block_diag_multiply(double_W_conj_t, A, double_W) 116 | 117 | small_block_size = block_size // root_num_blocks 118 | blocks = extract_diag_blocks(block_diag_A, small_block_size, root_num_blocks, single_matrix=True) 119 | blocks = blocks[1:] # ignore zero mode block 120 | 121 | if tensor: 122 | return tf.stack(blocks, axis=0) 123 | else: 124 | return [csr_matrix(block) for block_list in blocks for block in block_list] 125 | 126 | 127 | def block_diagonalize_A(A, root_num_blocks): 128 | """Returns root_num_blocks**2 matrices that represent the block diagonalization of A""" 129 | block_size = A.shape[0] // root_num_blocks 130 | 131 | # block-diagonalize each of the blocks in the first row of blocks (no need to block-diagonalize all blocks, 132 | # because A is block-circulant) 133 | small_W, small_W_conj_t = create_W_matrix(block_size // root_num_blocks, root_num_blocks) 134 | small_diagonalized_blocks = [] 135 | for i in range(root_num_blocks): 136 | small_block = A[:block_size, i * block_size:(i + 1) * block_size] 137 | small_diagonalized_block = small_W_conj_t @ small_block @ small_W 138 | small_diagonalized_blocks.append(small_diagonalized_block) 139 | 140 | # arrange the block-diagonalized blocks into a block-circulant matrix 141 | block_list = [] 142 | for shift in range(root_num_blocks): 143 | shifted_list = np.roll(small_diagonalized_blocks, shift, axis=0) 144 | block_list.append(list(shifted_list)) 145 | small_block_diagonalized_A = np.block(block_list) 146 | 147 | # block-diagonalize the block-circulant matrix, extract the resulting blocks and stack them 148 | double_block_diagonalized_A = block_diagonalize_1d_circulant(small_block_diagonalized_A, root_num_blocks) 149 | small_blocks = [ 150 | double_block_diagonalized_A[b, i:i + block_size // root_num_blocks, i:i + block_size // root_num_blocks] 151 | for i in range(0, A.shape[0] // root_num_blocks, block_size // root_num_blocks) 152 | for b in range(root_num_blocks)] 153 | return np.stack(small_blocks) 154 | 155 | 156 | def pad_P(P, coarse_nodes): 157 | total_size = P.shape[0].value 158 | zero_column = tf.zeros([total_size], dtype=tf.float64) 159 | P_cols = tf.unstack(P, axis=1) 160 | full_P_cols = [] 161 | curr_P_col = 0 162 | is_coarse = np.in1d(range(total_size), coarse_nodes, assume_unique=True) 163 | for col_index in range(total_size): 164 | if is_coarse[col_index]: 165 | column = 
P_cols[curr_P_col] 166 | curr_P_col += 1 167 | else: 168 | column = zero_column 169 | full_P_cols.append(column) 170 | 171 | full_P = tf.transpose(tf.stack(full_P_cols)) 172 | full_P = tf.cast(full_P, tf.complex128) 173 | return full_P 174 | 175 | 176 | def block_diagonalize_P(P, root_num_blocks, coarse_nodes): 177 | """ 178 | Returns root_num_blocks**2 matrices that represent the block diagonalization of P 179 | Only works on block-periodic prolongation matrices 180 | """ 181 | total_size = P.shape[0].value 182 | block_size = total_size // root_num_blocks 183 | 184 | # we build the padded P matrix column by column, I couldn't find a more efficient way 185 | full_P = pad_P(P, coarse_nodes) 186 | 187 | double_W, double_W_conj_t = create_double_W(block_size, root_num_blocks, True) 188 | block_diag_full_P = block_diag_multiply(double_W_conj_t, full_P, double_W) 189 | 190 | small_block_size = block_size // root_num_blocks 191 | blocks = extract_diag_blocks(block_diag_full_P, small_block_size, root_num_blocks, single_matrix=True) 192 | blocks = blocks[1:] # ignore zero mode block 193 | 194 | block_coarse_nodes = coarse_nodes[:len(coarse_nodes) // root_num_blocks**2] 195 | blocks = [tf.gather(block, block_coarse_nodes, axis=1) for block in blocks] 196 | 197 | return blocks 198 | 199 | 200 | def block_diagonalize_1d_circulant(A, root_num_blocks): 201 | """ 202 | Returns root_num_blocks matrices that represent the block diagonalization of A 203 | We apply this function recursively to block-diagonalize 2d-block-circulant matrices 204 | Refer to docs/block_fourier_analysis.pdf for notation and details 205 | """ 206 | total_size = A.shape[0] 207 | block_size = total_size // root_num_blocks 208 | W, W_conj_t = create_W_matrix(block_size, root_num_blocks) 209 | block_diagonal_matrix = W_conj_t @ A @ W 210 | 211 | # extract the block matrices on the diagonal 212 | blocks = [block_diagonal_matrix[i:i + block_size, i:i + block_size] for i in range(0, total_size, block_size)] 213 | return np.stack(blocks) 214 | 215 | 216 | @lru_cache(maxsize=None) 217 | def create_W_matrix(block_size, root_num_blocks, tensor=False): 218 | """ 219 | Returns a matrix that block-diagonalizes a block-circulant matrix 220 | Refer to docs/block_fourier_analysis.pdf for notation and details 221 | """ 222 | total_size = block_size * root_num_blocks 223 | dft_matrix = scipy.linalg.dft(total_size) 224 | dft_matrix_first_b_columns = dft_matrix[:, :root_num_blocks] 225 | 226 | columns = [] 227 | for i in range(root_num_blocks): 228 | for k in range(block_size): 229 | col_mask = np.ones(total_size, np.bool) 230 | col_mask[k:total_size:block_size] = 0 231 | column = np.copy(dft_matrix_first_b_columns[:, i]) 232 | column[col_mask] = 0 233 | columns.append(column) 234 | W = np.stack(columns, axis=1) 235 | W /= np.sqrt(root_num_blocks) 236 | W_conj_t = W.conj().T 237 | 238 | if tensor: 239 | W, W_conj_t = tf.convert_to_tensor(W), tf.convert_to_tensor(W_conj_t) 240 | return W, W_conj_t 241 | 242 | 243 | @lru_cache(maxsize=None) 244 | def create_double_W(block_size, root_num_blocks, tensor=False): 245 | big_W, _ = create_W_matrix(block_size, root_num_blocks) 246 | small_W, _ = create_W_matrix(block_size // root_num_blocks, root_num_blocks) 247 | small_W_block = scipy.linalg.block_diag(*[small_W] * root_num_blocks) 248 | double_W = small_W_block @ big_W 249 | double_W_conj_t = double_W.conj().T 250 | 251 | if tensor: 252 | double_W, double_W_conj_t = tf.convert_to_tensor(double_W), tf.convert_to_tensor(double_W_conj_t) 253 | return double_W, 
double_W_conj_t 254 | 255 | 256 | def test_create_W_matrix(): 257 | """Check if W matrix is unitary""" 258 | W, W_conj_T = create_W_matrix(3, 4) 259 | I = W_conj_T @ W 260 | print(np.all(np.isclose(I, np.eye(3 * 4)))) 261 | 262 | 263 | def test_block_diagonalize_1d_circulant(): 264 | """Check if eigenvalues of block matrices are the same as eigenvalues of original block-circulant matrix""" 265 | matlab_engine = matlab.engine.start_matlab() 266 | matlab_engine.eval('rng(1)') # fix random seed for reproducibility 267 | 268 | def generate_A_delaunay_block_periodic_lognormal(num_unknowns_per_block, root_num_blocks, matlab_engine): 269 | """Poisson equation on triangular mesh, with lognormal coefficients, and block periodic boundary conditions""" 270 | # points are correct only for 3x3 number of blocks 271 | A_matlab, points_matlab = matlab_engine.block_periodic_delaunay(num_unknowns_per_block, root_num_blocks, 272 | nargout=2) 273 | A_numpy = np.array(A_matlab._data).reshape(A_matlab.size, order='F') 274 | points_numpy = np.array(points_matlab._data).reshape(points_matlab.size, order='F') 275 | return csr_matrix(A_numpy), points_numpy 276 | 277 | A, _ = generate_A_delaunay_block_periodic_lognormal(3, 4, matlab_engine) 278 | A = A.toarray() 279 | blocks = block_diagonalize_1d_circulant(A, 4) 280 | A_eigs = np.sort(np.linalg.eigvals(A)) 281 | block_eigs = np.sort(np.linalg.eigvals(blocks).flatten()) 282 | print(np.all(np.isclose(A_eigs, block_eigs))) 283 | 284 | 285 | def test_block_diagonalize_A(): 286 | """Check if eigenvalues of block matrices are the same as eigenvalues of original block-circulant matrix""" 287 | matlab_engine = matlab.engine.start_matlab() 288 | matlab_engine.eval('rng(1)') # fix random seed for reproducibility 289 | 290 | def generate_A_delaunay_block_periodic_lognormal(num_unknowns_per_block, root_num_blocks, matlab_engine): 291 | """Poisson equation on triangular mesh, with lognormal coefficients, and block periodic boundary conditions""" 292 | # points are correct only for 3x3 number of blocks 293 | A_matlab, points_matlab = matlab_engine.block_periodic_delaunay(num_unknowns_per_block, root_num_blocks, 294 | nargout=2) 295 | A_numpy = np.array(A_matlab._data).reshape(A_matlab.size, order='F') 296 | points_numpy = np.array(points_matlab._data).reshape(points_matlab.size, order='F') 297 | return csr_matrix(A_numpy), points_numpy 298 | 299 | A, _ = generate_A_delaunay_block_periodic_lognormal(15, 4, matlab_engine) 300 | A = A.toarray() 301 | blocks = block_diagonalize_A(A, 4) 302 | 303 | # check if eigenvalues are identical 304 | A_eigs = np.sort(np.linalg.eigvals(A)) 305 | block_eigs = np.sort(np.linalg.eigvals(blocks).flatten()) 306 | print(np.all(np.isclose(A_eigs, block_eigs))) 307 | 308 | 309 | def test_block_diagonalize_A_fast(): 310 | """Check if eigenvalues of block matrices are the same as eigenvalues of original block-circulant matrix""" 311 | matlab_engine = matlab.engine.start_matlab() 312 | matlab_engine.eval('rng(1)') # fix random seed for reproducibility 313 | 314 | def generate_A_delaunay_block_periodic_lognormal(num_unknowns_per_block, root_num_blocks, matlab_engine): 315 | """Poisson equation on triangular mesh, with lognormal coefficients, and block periodic boundary conditions""" 316 | # points are correct only for 3x3 number of blocks 317 | A_matlab = matlab_engine.block_periodic_delaunay(num_unknowns_per_block, root_num_blocks, 318 | nargout=1) 319 | A_numpy = np.array(A_matlab._data).reshape(A_matlab.size, order='F') 320 | return csr_matrix(A_numpy) 
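    # The remainder of this test stacks a batch of block-periodic matrices, block-diagonalizes
    # them with block_diagonalize_A_fast, and checks that the sorted union of the blocks'
    # eigenvalues matches the sorted eigenvalues of the original matrices (a correct block
    # diagonalization is a unitary similarity transform, so it preserves the spectrum).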
321 | 322 | batch_size = 32 323 | As = [generate_A_delaunay_block_periodic_lognormal(5, 3, matlab_engine) for i in range(batch_size)] 324 | As = [A.toarray() for A in As] 325 | As = tf.stack(As) 326 | As = tf.cast(As, dtype=tf.complex128) 327 | blocks = block_diagonalize_A_fast(As, 3, tensor=True).numpy() 328 | 329 | # check if eigenvalues are identical 330 | A_eigs = np.sort(np.linalg.eigvals(As.numpy()).flatten()) 331 | block_eigs = np.sort(np.linalg.eigvals(blocks).flatten()) 332 | print(np.all(np.isclose(A_eigs, block_eigs))) 333 | 334 | 335 | def test_block_diagonalize_P(): 336 | """Check if eigenvalues of block matrices are the same as eigenvalues of original block-circulant matrix""" 337 | matlab_engine = matlab.engine.start_matlab() 338 | matlab_engine.eval('rng(1)') # fix random seed for reproducibility 339 | 340 | def generate_A_delaunay_block_periodic_lognormal(num_unknowns_per_block, root_num_blocks, matlab_engine): 341 | """Poisson equation on triangular mesh, with lognormal coefficients, and block periodic boundary conditions""" 342 | # points are correct only for 3x3 number of blocks 343 | A_matlab = matlab_engine.block_periodic_delaunay(num_unknowns_per_block, root_num_blocks, 344 | nargout=1) 345 | A_numpy = np.array(A_matlab._data).reshape(A_matlab.size, order='F') 346 | return csr_matrix(A_numpy) 347 | 348 | num_unknowns_per_block = 64 349 | root_num_blocks = 3 350 | A = generate_A_delaunay_block_periodic_lognormal(num_unknowns_per_block, root_num_blocks, matlab_engine) 351 | # A = A + 0.1 * scipy.sparse.diags(np.ones(num_unknowns_per_block * root_num_blocks**2)) 352 | 353 | orig_solver = pyamg.ruge_stuben_solver(A, max_levels=2, max_coarse=1, CF='CLJP', keep=True) 354 | orig_splitting = orig_solver.levels[0].splitting 355 | block_splitting = list(chunks(orig_splitting, num_unknowns_per_block)) 356 | common_block_splitting = most_frequent_splitting(block_splitting) 357 | repeated_splitting = np.tile(common_block_splitting, root_num_blocks ** 2) 358 | 359 | # we recompute the Ruge-Stuben prolongation matrix with the modified splitting, and the original strength 360 | # matrix. We assume the strength matrix is block-circulant (because A is block-circulant) 361 | C = orig_solver.levels[0].C 362 | P = direct_interpolation(A, C, repeated_splitting) 363 | P = tf.convert_to_tensor(P.toarray(), dtype=tf.float64) 364 | 365 | P_blocks = block_diagonalize_P(P, root_num_blocks, repeated_splitting.nonzero()) 366 | P_blocks = P_blocks.numpy() 367 | 368 | A_c = P.numpy().T @ A.toarray() @ P.numpy() 369 | # double_W, double_W_conj_t = create_double_W(num_unknowns_per_block * root_num_blocks, root_num_blocks) 370 | # A_c_full_block_diag = double_W_conj_t @ A_c_full @ double_W 371 | 372 | # P_full_block_diag = double_W_conj_t @ full_P @ double_W 373 | # A_block_diag = double_W_conj_t @ A.toarray() @ double_W 374 | # A_c_full_block_diag_2 = P_full_block_diag.conj().T @ A_block_diag @ P_full_block_diag 375 | 376 | tf_A = tf.cast(tf.stack([A.toarray()]), tf.complex128) 377 | A_blocks = block_diagonalize_A_fast(tf_A, root_num_blocks, True).numpy()[0][1:] # ignore the first zero mode block 378 | 379 | def relaxation_matrices(As, w=0.8): 380 | I = np.eye(As[0].shape[0]) 381 | res = [I - w * np.diag(1 / (np.diag(A))) @ A for A in As] 382 | 383 | # computes the iteration matrix of the relaxation, here Gauss-Seidel is used. 384 | # This function is called on each block seperately. 
385 | # num_As = len(As) 386 | # grid_sizes = [A.shape[0] for A in As] 387 | # Bs = [A.copy() for A in As] 388 | # for B, grid_size in zip(Bs, grid_sizes): 389 | # B[np.tril_indices(grid_size, 0)[0], np.tril_indices(grid_size, 0)[1]] = 0. # B is the upper part of A 390 | # res = [] 391 | # for i in tqdm(range(num_As)): # range(A.shape[0] // batch_size): 392 | # res.append(scipy.linalg.solve_triangular(a=As[i], 393 | # b=-Bs[i], 394 | # lower=True, unit_diagonal=False, 395 | # overwrite_b=False, debug=None, check_finite=True).astype( 396 | # np.float64)) 397 | return res 398 | 399 | S = relaxation_matrices([A.toarray()])[0] 400 | S_blocks = relaxation_matrices(A_blocks) 401 | 402 | # A_c = P.numpy().T @ A.toarray() @ P.numpy() 403 | A_c_blocks = P_blocks.transpose([0, 2, 1]).conj() @ A_blocks @ P_blocks 404 | 405 | A = A.toarray() 406 | C = np.eye(A.shape[0]) - P.numpy() @ np.linalg.inv(A_c) @ P.numpy().T @ A 407 | M = S @ C @ S 408 | 409 | I = np.eye(A_blocks[0].shape[0]) 410 | C_blocks = [I - P_block @ np.linalg.inv(A_c_block) @ P_block.conj().T @ A_block 411 | for (P_block, A_c_block, A_block) in zip(P_blocks, A_c_blocks, A_blocks)] 412 | M_blocks = [S_block @ C_block @ S_block for (S_block, C_block) in zip(S_blocks, C_blocks)] 413 | 414 | # # extract only elements that correspond to coarse nodes 415 | # A_c_blocks = A_c_blocks[:, common_block_splitting.nonzero()[0][:, None], common_block_splitting.nonzero()[0]] 416 | 417 | 418 | A_c_block_eigs = np.sort(np.linalg.eigvals(A_c_blocks).flatten()) 419 | A_c_eigs = np.sort(np.linalg.eigvals(A_c)) 420 | 421 | C_block_eigs = np.sort(np.linalg.eigvals(C_blocks).flatten()) 422 | C_eigs = np.sort(np.linalg.eigvals(C)) 423 | 424 | 425 | M_block_eigs = np.sort(np.linalg.eigvals(M_blocks).flatten()) 426 | M_eigs = np.sort(np.linalg.eigvals(M)) 427 | 428 | pass 429 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | from functools import lru_cache 4 | 5 | import matlab 6 | import meshpy.triangle as triangle 7 | import numpy as np 8 | import pyamg 9 | import scipy 10 | from scipy.sparse import csr_matrix, lil_matrix 11 | from scipy.sparse.csgraph import laplacian as csgraph_laplacian 12 | from scipy.spatial.qhull import Delaunay 13 | from sklearn import datasets 14 | from sklearn.neighbors import kneighbors_graph 15 | from sklearn.preprocessing import StandardScaler 16 | 17 | 18 | def generate_A(size, dist, block_periodic, root_num_blocks, add_diag=False, matlab_engine=None): 19 | if dist is 'lognormal_laplacian': 20 | A = generate_A_delaunay_dirichlet_lognormal(size, matlab_engine=matlab_engine) 21 | elif dist is 'lognormal_laplacian_periodic': 22 | if block_periodic: 23 | A, _ = generate_A_delaunay_block_periodic_lognormal(size, root_num_blocks, matlab_engine) 24 | else: 25 | A = generate_A_delaunay_periodic_lognormal(size, uniform=False) 26 | elif dist is 'lognormal_complex_fem': 27 | A = A_dirichlet_finite_element_quality(size, matlab_engine, hole=True) 28 | elif dist is 'spectral_clustering': 29 | A = generate_A_spec_cluster(size, add_diag=add_diag) 30 | elif dist is 'poisson': 31 | grid_size = int(np.sqrt(size)) 32 | A = pyamg.gallery.poisson((grid_size, grid_size), type='FE') 33 | elif dist is 'aniso': 34 | grid_size = int(np.sqrt(size)) 35 | # stencil = pyamg.gallery.diffusion_stencil_2d(epsilon=0.01, type='FE') 36 | stencil = 
pyamg.gallery.diffusion_stencil_2d(epsilon=0.01, theta=np.pi / 3, type='FE') 37 | A = pyamg.gallery.stencil_grid(stencil, (grid_size, grid_size), format='csr') 38 | elif dist is 'example': 39 | A = pyamg.gallery.load_example(size)['A'] 40 | return A 41 | 42 | 43 | def drop_zero_row_col_matlab(A, matlab_engine): 44 | size = A.shape[0] 45 | A_coo = A.tocoo() 46 | A_rows = matlab.double((A_coo.row + 1)) 47 | A_cols = matlab.double((A_coo.col + 1)) 48 | A_values = matlab.double(A_coo.data) 49 | rows, cols, values = matlab_engine.drop_zero_row_col(A_rows, A_cols, A_values, size, nargout=3) 50 | rows = np.array(rows._data).reshape(rows.size, order='F') - 1 51 | cols = np.array(cols._data).reshape(cols.size, order='F') - 1 52 | values = np.array(values._data).reshape(values.size, order='F') 53 | rows, cols, values = rows.T[0], cols.T[0], values.T[0] 54 | rows, cols = rows.astype(np.int), cols.astype(np.int) 55 | return csr_matrix((values, (rows, cols))) 56 | 57 | 58 | def drop_zero_row_col(A): 59 | # https://stackoverflow.com/a/35905815 60 | return A[A.getnnz(1) > 0][:, A.getnnz(0) > 0] 61 | 62 | 63 | def generate_A_delaunay_dirichlet_lognormal(num_points, constant_coefficients=False, uniform=False, 64 | matlab_engine=None): 65 | """ 66 | Poisson equation on triangular mesh, with lognormal coefficients 67 | We create a triangulation of random points on the square from (-1,-1) to (2,2), 68 | and we look only at points that lie inside the unit square. Each point that has an edge to a point 69 | outside the unit square, we designate as a boundary 70 | the total number of points in the grid, including boundaries, is num_points 71 | the number of unknowns is the number of points minus the number of boundaries, which is variable 72 | """ 73 | rand_points = np.random.uniform([-1, -1], [2, 2], [num_points * 3 ** 2, 2]) 74 | 75 | # remove points that lie inside the unit square 76 | unit_square_indices = np.where( 77 | (rand_points[:, 0] >= 0) & 78 | (rand_points[:, 0] <= 1) & 79 | (rand_points[:, 1] >= 0) & 80 | (rand_points[:, 1] <= 1) 81 | )[0] 82 | 83 | rand_points = np.delete(rand_points, unit_square_indices, axis=0) 84 | 85 | # add back exactly num_unknowns points to the unit square 86 | rand_points_unit_square = np.random.uniform([0, 0], [1, 1], [num_points, 2]) 87 | rand_points = np.concatenate([rand_points_unit_square, rand_points]) 88 | 89 | tri = Delaunay(rand_points) 90 | # vertex_neighbor_vertices is used to get neighbors: 91 | # https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.spatial.Delaunay.vertex_neighbor_vertices.html 92 | index_pointers = tri.vertex_neighbor_vertices[0] 93 | indices = tri.vertex_neighbor_vertices[1] 94 | 95 | # random coefficients must be negative numbers, size must be at least number of edges 96 | if constant_coefficients: 97 | random_values = -np.ones(shape=[tri.nsimplex * 3]) 98 | else: 99 | if uniform: 100 | random_values = -np.random.uniform(size=[tri.nsimplex * 3]) 101 | else: 102 | random_values = -np.exp(np.random.normal(size=[tri.nsimplex * 3])) 103 | A = lil_matrix((num_points, num_points), dtype=np.float64) 104 | 105 | # go through every point in the unit square, and add a random coefficient to it's inside neighbors. 
106 | # if there is one or more outside neighbors, record the number so later we'll add this number of coefficient 107 | # to the diagonal 108 | boundary_indices = [] 109 | for vertex_id in range(num_points): 110 | neighbors = indices[index_pointers[vertex_id]:index_pointers[vertex_id + 1]] 111 | outside_neighbors = neighbors[np.where(neighbors >= num_points)] 112 | if len(outside_neighbors) > 0: 113 | boundary_indices.append(vertex_id) 114 | 115 | num_boundary_neighbors_dict = {} 116 | edge_counter = 0 117 | is_boundary = np.in1d(range(num_points), boundary_indices, assume_unique=True) 118 | for vertex_id in range(num_points): 119 | # if vertex is boundary, do not include in A 120 | # if np.isin(vertex_id, boundary_indices): 121 | # continue 122 | if is_boundary[vertex_id]: 123 | continue 124 | 125 | neighbors = indices[index_pointers[vertex_id]:index_pointers[vertex_id + 1]] 126 | internal_neighbors = np.setdiff1d(neighbors, boundary_indices, assume_unique=True) 127 | num_boundary_neighbors = len(neighbors) - len(internal_neighbors) 128 | if num_boundary_neighbors > 0: 129 | num_boundary_neighbors_dict[vertex_id] = num_boundary_neighbors 130 | 131 | for neighbor_id in np.sort(internal_neighbors): 132 | A[vertex_id, neighbor_id] = random_values[edge_counter] 133 | A[neighbor_id, vertex_id] = random_values[edge_counter] 134 | edge_counter += 1 135 | 136 | # set row sums to be zero 137 | row_sums = A.sum(axis=0) 138 | for i in range(num_points): 139 | A[i, i] = -row_sums[0, i] 140 | 141 | for boundary_id, num_boundary_neighbors in num_boundary_neighbors_dict.items(): 142 | diagonal_value = -random_values[edge_counter:edge_counter + num_boundary_neighbors].sum() 143 | edge_counter += num_boundary_neighbors 144 | A[boundary_id, boundary_id] += diagonal_value 145 | 146 | # drop zero rows and columns 147 | if matlab_engine: 148 | A = drop_zero_row_col_matlab(A, matlab_engine) 149 | else: 150 | A = drop_zero_row_col(A).tocsr() 151 | return A 152 | 153 | 154 | def generate_A_spec_cluster(num_unknowns, add_diag=False, num_clusters=2, unit_std=False, dim=2, dist='gauss', gamma=None, 155 | distance=False, return_x=False, n_neighbors=10): 156 | """ 157 | Similar params to https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html 158 | With spectral clustering 159 | """ 160 | centers = num_clusters 161 | if num_clusters == 2 and not unit_std: 162 | cluster_std = [1.0, 2.5] # looks good, sometimes graph is connected, sometimes not 163 | size_factor = 1 164 | else: 165 | cluster_std = 1.0 166 | size_factor = num_unknowns / 1000 167 | center_box = [-10 * size_factor, 10 * size_factor] 168 | norm_laplacian = True 169 | 170 | if dist == 'gauss': 171 | X, y = datasets.make_blobs(n_samples=num_unknowns, n_features=dim, centers=centers, 172 | cluster_std=cluster_std, center_box=center_box) 173 | elif dist == 'moons': 174 | X, y = datasets.make_moons(n_samples=num_unknowns, noise=.05) 175 | elif dist == 'circles': 176 | X, y = datasets.make_circles(n_samples=num_unknowns, noise=.05, factor=.5) 177 | elif dist == 'random': 178 | X = np.random.rand(num_unknowns, dim) 179 | X = StandardScaler().fit_transform(X) 180 | 181 | if distance: 182 | mode = 'distance' 183 | else: 184 | mode = 'connectivity' 185 | connectivity = kneighbors_graph(X, n_neighbors=n_neighbors, mode=mode, 186 | include_self=True) 187 | if gamma is not None: 188 | np.exp(-(gamma * connectivity.data) ** 2, out=connectivity.data) 189 | affinity_matrix = 0.5 * (connectivity + connectivity.T) 190 | 191 | laplacian, dd = 
csgraph_laplacian(affinity_matrix, normed=norm_laplacian, 192 | return_diag=True) 193 | # set diagonal to 1 if normed 194 | if norm_laplacian: 195 | diag_idx = (laplacian.row == laplacian.col) 196 | laplacian.data[diag_idx] = 1 197 | 198 | if add_diag: 199 | small_diag = scipy.sparse.diags(np.random.uniform(0, 0.02, num_unknowns)) 200 | laplacian += small_diag 201 | 202 | if return_x: 203 | return X, laplacian 204 | else: 205 | return laplacian 206 | 207 | 208 | def generate_A_delaunay_block_periodic_lognormal(num_unknowns_per_block, root_num_blocks, matlab_engine): 209 | """Poisson equation on triangular mesh, with lognormal coefficients, and block periodic boundary conditions""" 210 | # points are correct only for 3x3 number of blocks 211 | A_matlab, points_matlab = matlab_engine.block_periodic_delaunay(num_unknowns_per_block, root_num_blocks, nargout=2) 212 | A_numpy = np.array(A_matlab._data).reshape(A_matlab.size, order='F') 213 | points_numpy = np.array(points_matlab._data).reshape(points_matlab.size, order='F') 214 | return csr_matrix(A_numpy), points_numpy 215 | 216 | 217 | def As_poisson_grid(num_As, num_unknowns, constant_coefficients=False): 218 | grid_size = int(math.sqrt(num_unknowns)) 219 | if grid_size ** 2 != num_unknowns: 220 | raise RuntimeError("num_unknowns must be a square number") 221 | stencils = poisson_dirichlet_stencils(num_As, grid_size, constant_coefficients=constant_coefficients) 222 | A_idx, stencil_idx = compute_A_indices(grid_size) 223 | matrices = [] 224 | for stencil in stencils: 225 | matrix = csr_matrix(arg1=(stencil.reshape((-1)), (A_idx[:, 0], A_idx[:, 1])), 226 | shape=(grid_size ** 2, grid_size ** 2)) 227 | matrix.eliminate_zeros() 228 | matrices.append(matrix) 229 | return matrices 230 | 231 | 232 | @lru_cache(maxsize=None) 233 | def create_hole_mesh(num_unknowns): 234 | def round_trip_connect(start, end): 235 | return [(i, i + 1) for i in range(start, end)] + [(end, start)] 236 | 237 | points = [(1 / 4, 0), (1 / 4, 1 / 4), (-1 / 4, 1 / 4), (-1 / 4, -1 / 4), (1 / 4, -1 / 4), (1 / 4, 0)] 238 | facets = round_trip_connect(0, len(points) - 1) 239 | 240 | circ_start = len(points) 241 | points.extend( 242 | (1.5 * np.cos(angle), 1.5 * np.sin(angle)) 243 | for angle in np.linspace(0, 2 * np.pi, 30, endpoint=False)) 244 | 245 | facets.extend(round_trip_connect(circ_start, len(points) - 1)) 246 | maximum_area = 0.5 / num_unknowns 247 | 248 | def needs_refinement(vertices, area): 249 | bary = np.sum(np.array(vertices), axis=0) / 3 250 | max_area = maximum_area + (np.linalg.norm(bary, np.inf) - 1 / 4) * maximum_area * 10 251 | return bool(area > max_area) 252 | 253 | info = triangle.MeshInfo() 254 | info.set_points(points) 255 | info.set_holes([(0, 0)]) 256 | info.set_facets(facets) 257 | 258 | mesh = triangle.build(info, refinement_func=needs_refinement) 259 | return mesh 260 | 261 | 262 | @lru_cache(maxsize=None) 263 | def create_quality_mesh(num_unknowns): 264 | mesh_info = triangle.MeshInfo() 265 | points = [(1.5, 1.5), (-1.5, 1.5), (-1.5, -1.5), (1.5, -1.5)] 266 | segments = [(i, i + 1) for i in range(3)] + [(3, 0)] 267 | mesh_info.set_points(points) 268 | mesh_info.set_facets(segments) 269 | mesh = triangle.build(mesh_info, max_volume=3 / (num_unknowns * 2 ** 2), min_angle=30) 270 | return mesh 271 | 272 | 273 | def vertex_to_tris_map(simplices): 274 | M = defaultdict(set) 275 | for i, tri in enumerate(simplices): 276 | for point in tri: 277 | M[point].add(i) 278 | return M 279 | 280 | 281 | def A_dirichlet_finite_element_quality(num_unknowns, 
matlab_engine, constant_coeffs=False, hole=False, uniform=False): 282 | if hole: 283 | mesh = create_hole_mesh(num_unknowns) 284 | else: 285 | mesh = create_quality_mesh(num_unknowns) 286 | mesh_points = np.array(mesh.points) 287 | num_points = mesh_points.shape[0] 288 | 289 | tris = np.array(mesh.elements) 290 | vertex_to_tris = vertex_to_tris_map(tris) 291 | 292 | if constant_coeffs: 293 | coeffs = np.ones(tris.shape[0]) 294 | else: 295 | if uniform: 296 | coeffs = np.random.uniform(size=[tris.shape[0]]) 297 | else: 298 | coeffs = np.exp(np.random.normal(scale=0.5, size=[tris.shape[0]])) 299 | # np.clip(coeffs, 0.1, 5, out=coeffs) 300 | # coeffs = np.random.uniform(size=tris.shape[0]) 301 | 302 | x0 = mesh_points[tris[:, 0], 0] 303 | x1 = mesh_points[tris[:, 1], 0] 304 | x2 = mesh_points[tris[:, 2], 0] 305 | y0 = mesh_points[tris[:, 0], 1] 306 | y1 = mesh_points[tris[:, 1], 1] 307 | y2 = mesh_points[tris[:, 2], 1] 308 | 309 | a = x1 - x0 310 | b = y1 - y0 311 | c = x2 - x0 312 | d = y2 - y0 313 | 314 | s = np.abs((a * d - b * c) / 2) 315 | det_A = x0 * y1 - x0 * y2 - x1 * y0 + x1 * y2 + x2 * y0 - x2 * y1 316 | 317 | coeffs_per_tri = coeffs * s / det_A ** 2 318 | 319 | x = mesh_points[:, 0] 320 | y = mesh_points[:, 1] 321 | 322 | # lil_matrix is faster than dok_matrix in this case 323 | A = lil_matrix((num_points, num_points), dtype=np.float64) 324 | for p in range(num_points): 325 | point_tris = vertex_to_tris[p] 326 | for point_tri in point_tris: 327 | tri_points = tris[point_tri, :] 328 | tri_coeff = coeffs_per_tri[point_tri] 329 | ind_others = tri_points[tri_points != p] 330 | p1 = ind_others[0] 331 | p2 = ind_others[1] 332 | 333 | A[p, p] = A[p, p] + tri_coeff * ((y[p2] - y[p1]) ** 2 + (x[p2] - x[p1]) ** 2) 334 | A[p, p1] = A[p, p1] + tri_coeff * ((y[p1] - y[p2]) * (y[p2] - y[p]) + (x[p1] - x[p2]) * (x[p2] - x[p])) 335 | A[p, p2] = A[p, p2] + tri_coeff * ((y[p2] - y[p1]) * (y[p1] - y[p]) + (x[p2] - x[p1]) * (x[p1] - x[p])) 336 | 337 | if hole: 338 | outside_indices = np.where( 339 | (np.linalg.norm(mesh_points, axis=1) >= 1) | 340 | (np.linalg.norm(mesh_points, axis=1) <= 1 / 3) 341 | )[0] 342 | else: 343 | outside_indices = np.where( 344 | (mesh_points[:, 0] <= 0) | 345 | (mesh_points[:, 0] >= 1) | 346 | (mesh_points[:, 1] <= 0) | 347 | (mesh_points[:, 1] >= 1) 348 | )[0] 349 | A = drop_row_col_matlab(A, outside_indices, matlab_engine) 350 | return A.tocsr() 351 | 352 | 353 | def poisson_dirichlet_stencils(num_stencils, grid_size, constant_coefficients=False): 354 | stencil = np.zeros((num_stencils, grid_size, grid_size, 3, 3)) 355 | 356 | if constant_coefficients: 357 | diffusion_coeff = np.ones(shape=[num_stencils, grid_size, grid_size]) 358 | else: 359 | diffusion_coeff = np.exp(np.random.normal(size=[num_stencils, grid_size, grid_size])) 360 | 361 | jm1 = [(i - 1) % grid_size for i in range(grid_size)] 362 | stencil[:, :, :, 1, 2] = -1. / 6 * (diffusion_coeff[:, jm1] + diffusion_coeff) 363 | stencil[:, :, :, 2, 1] = -1. / 6 * (diffusion_coeff + diffusion_coeff[:, :, jm1]) 364 | stencil[:, :, :, 2, 0] = -1. / 3 * diffusion_coeff[:, :, jm1] 365 | stencil[:, :, :, 2, 2] = -1. 
/ 3 * diffusion_coeff 366 | 367 | jp1 = [(i + 1) % grid_size for i in range(grid_size)] 368 | 369 | stencil[:, :, :, 1, 0] = stencil[:, :, jm1, 1, 2] 370 | stencil[:, :, :, 0, 0] = stencil[:, jm1][:, :, jm1][:, :, :, 2, 2] 371 | stencil[:, :, :, 0, 1] = stencil[:, jm1][:, :, :, 2, 1] 372 | stencil[:, :, :, 0, 2] = stencil[:, jm1][:, :, jp1][:, :, :, 2, 0] 373 | stencil[:, :, :, 1, 1] = -np.sum(np.sum(stencil, axis=4), axis=3) 374 | 375 | stencil[:, :, 0, :, 0] = 0. 376 | stencil[:, :, -1, :, -1] = 0. 377 | stencil[:, 0, :, 0, :] = 0. 378 | stencil[:, -1, :, -1, :] = 0. 379 | return stencil 380 | 381 | 382 | @lru_cache(maxsize=None) 383 | def compute_A_indices(grid_size): 384 | K = map_2_to_1(grid_size=grid_size) 385 | A_idx = [] 386 | stencil_idx = [] 387 | for i in range(grid_size): 388 | for j in range(grid_size): 389 | I = int(K[i, j, 1, 1]) 390 | for k in range(3): 391 | for m in range(3): 392 | J = int(K[i, j, k, m]) 393 | A_idx.append([I, J]) 394 | stencil_idx.append([i, j, k, m]) 395 | return np.array(A_idx), stencil_idx 396 | 397 | 398 | def map_2_to_1(grid_size): 399 | # maps 2D coordinates to the corresponding 1D coordinate in the matrix. 400 | k = np.zeros((grid_size, grid_size, 3, 3)) 401 | M = np.reshape(np.arange(grid_size ** 2), (grid_size, grid_size)).T 402 | M = np.concatenate([M, M], 0) 403 | M = np.concatenate([M, M], 1) 404 | for i in range(3): 405 | I = (i - 1) % grid_size 406 | for j in range(3): 407 | J = (j - 1) % grid_size 408 | k[:, :, i, j] = M[I:I + grid_size, J:J + grid_size] 409 | return k 410 | 411 | 412 | def create_mesh(num_nodes): 413 | # create uniformly random points on the unit square 414 | rand_points = np.random.uniform([0, 0], [1, 1], [num_nodes, 2]) 415 | return Delaunay(rand_points) 416 | 417 | 418 | def generate_A_delaunay_periodic_lognormal(num_unknowns, uniform=False): 419 | tri = create_mesh(num_unknowns) 420 | verts = tri.vertices 421 | if uniform: 422 | random_values = -np.random.uniform(size=[tri.nsimplex, 3]) 423 | else: 424 | random_values = -np.exp(np.random.normal(size=[tri.nsimplex, 3])) # must be negative numbers 425 | A = lil_matrix((num_unknowns, num_unknowns), dtype=np.float64) 426 | 427 | for i in range(tri.nsimplex): 428 | vert = verts[i] 429 | A[vert[0], vert[1]] = random_values[i, 0] 430 | A[vert[0], vert[2]] = random_values[i, 1] 431 | A[vert[1], vert[2]] = random_values[i, 2] 432 | 433 | # symmetrize 434 | A[vert[1], vert[0]] = random_values[i, 0] 435 | A[vert[2], vert[0]] = random_values[i, 1] 436 | A[vert[2], vert[1]] = random_values[i, 2] 437 | 438 | # set row sums to be zero 439 | row_sums = A.sum(axis=0) 440 | for i in range(num_unknowns): 441 | A[i, i] = -row_sums[0, i] 442 | return A.tocsr() 443 | 444 | 445 | def drop_row_col_matlab(A, indices, matlab_engine): 446 | size = A.shape[0] 447 | A_coo = A.tocoo() 448 | A_rows = matlab.double((A_coo.row + 1)) 449 | A_cols = matlab.double((A_coo.col + 1)) 450 | indices = matlab.double((indices + 1)) 451 | A_values = matlab.double(A_coo.data) 452 | rows, cols, values = matlab_engine.drop_row_col(A_rows, A_cols, A_values, size, indices, nargout=3) 453 | rows = np.array(rows._data).reshape(rows.size, order='F') - 1 454 | cols = np.array(cols._data).reshape(cols.size, order='F') - 1 455 | values = np.array(values._data).reshape(values.size, order='F') 456 | rows, cols, values = rows.T[0], cols.T[0], values.T[0] 457 | rows, cols = rows.astype(np.int), cols.astype(np.int) 458 | return csr_matrix((values, (rows, cols))) 459 | 
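# Minimal usage sketch (assumes only the 'poisson' branch of generate_A, which needs no
# MATLAB engine): build a small FE Poisson matrix and confirm it is symmetric.
# Note that generate_A compares `dist` with `is`, so pass a plain string literal.
if __name__ == '__main__':
    A_demo = generate_A(size=64, dist='poisson', block_periodic=False, root_num_blocks=3)
    print(A_demo.shape, np.abs((A_demo - A_demo.T).toarray()).max())  # expect (64, 64) and 0.0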
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import string 5 | 6 | import fire 7 | import matlab.engine 8 | import numpy as np 9 | import pyamg 10 | import tensorflow as tf 11 | from pyamg.classical import direct_interpolation 12 | from scipy.sparse import csr_matrix 13 | from tqdm import tqdm 14 | 15 | import configs 16 | from data import generate_A 17 | from dataset import DataSet 18 | from model import csrs_to_graphs_tuple, create_model, graphs_tuple_to_sparse_matrices, to_prolongation_matrix_tensor 19 | from multigrid_utils import block_diagonalize_A_single, block_diagonalize_P, two_grid_error_matrices, frob_norm, \ 20 | two_grid_error_matrix, compute_coarse_A 21 | from relaxation import relaxation_matrices 22 | from utils import create_results_dir, write_config_file, most_frequent_splitting, chunks 23 | 24 | 25 | def create_dataset(num_As, data_config, run=0, matlab_engine=None): 26 | if data_config.load_data: 27 | As_filename = f"data_dir/periodic_delaunay_num_As_{num_As}_num_points_{data_config.num_unknowns}" \ 28 | f"_rnb_{data_config.root_num_blocks}_epoch_{run}.npy" 29 | if not os.path.isfile(As_filename): 30 | raise RuntimeError(f"file {As_filename} not found") 31 | As = np.load(As_filename) 32 | 33 | # workaround for data generated with both matrices and point coordinates 34 | if len(As.shape) == 1: 35 | As = list(As) 36 | elif len(As.shape) == 2: 37 | As = list(As[0]) 38 | else: 39 | As = [generate_A(data_config.num_unknowns, 40 | data_config.dist, 41 | data_config.block_periodic, 42 | data_config.root_num_blocks, 43 | add_diag=data_config.add_diag, 44 | matlab_engine=matlab_engine) for _ in range(num_As)] 45 | 46 | if data_config.save_data: 47 | As_filename = f"data_dir/periodic_delaunay_num_As_{num_As}_num_points_{data_config.num_unknowns}" \ 48 | f"_rnb_{data_config.root_num_blocks}_epoch_{run}.npy" 49 | np.save(As_filename, As) 50 | return create_dataset_from_As(As, data_config) 51 | 52 | 53 | def create_dataset_from_As(As, data_config): 54 | if data_config.block_periodic: 55 | Ss = [None] * len(As) # relaxation matrices are only created per block when calling loss() 56 | else: 57 | Ss = relaxation_matrices(As) 58 | if data_config.block_periodic: 59 | orig_solvers = [pyamg.ruge_stuben_solver(A, max_levels=2, keep=True, CF=data_config.splitting) 60 | for A in As] 61 | # for efficient Fourier analysis, we require that each block contains the same sparsity pattern - set of 62 | # coarse nodes, and interpolatory set for each node. The AMG C/F splitting algorithms do not output the same 63 | # splitting for each block, but the blocks are relatively similar to each other. Taking the most common set 64 | # of coarse nodes and repeating it for each block might be a good strategy 65 | splittings = [] 66 | baseline_P_list = [] 67 | for i in range(len(As)): 68 | # visualize_cf_splitting(As[i], Vs[i], orig_splittings[i]) 69 | 70 | orig_splitting = orig_solvers[i].levels[0].splitting 71 | block_splittings = list(chunks(orig_splitting, data_config.num_unknowns)) 72 | common_block_splitting = most_frequent_splitting(block_splittings) 73 | repeated_splitting = np.tile(common_block_splitting, data_config.root_num_blocks ** 2) 74 | splittings.append(repeated_splitting) 75 | 76 | # we recompute the Ruge-Stuben prolongation matrix with the modified splitting, and the original strength 77 | # matrix. 
We assume the strength matrix is block-circulant (because A is block-circulant) 78 | A = As[i] 79 | C = orig_solvers[i].levels[0].C 80 | P = direct_interpolation(A, C, repeated_splitting) 81 | baseline_P_list.append(tf.convert_to_tensor(P.toarray(), dtype=tf.float64)) 82 | 83 | coarse_nodes_list = [np.nonzero(splitting)[0] for splitting in splittings] 84 | 85 | else: 86 | solvers = [pyamg.ruge_stuben_solver(A, max_levels=2, keep=True, CF=data_config.splitting) 87 | for A in As] 88 | baseline_P_list = [solver.levels[0].P for solver in solvers] 89 | baseline_P_list = [tf.convert_to_tensor(P.toarray(), dtype=tf.float64) for P in baseline_P_list] 90 | splittings = [solver.levels[0].splitting for solver in solvers] 91 | coarse_nodes_list = [np.nonzero(splitting)[0] for splitting in splittings] 92 | return DataSet(As, Ss, coarse_nodes_list, baseline_P_list) 93 | 94 | 95 | def loss(dataset, A_graphs_tuple, P_graphs_tuple, 96 | run_config, train_config, data_config): 97 | As = graphs_tuple_to_sparse_matrices(A_graphs_tuple) 98 | Ps_square, nodes_list = graphs_tuple_to_sparse_matrices(P_graphs_tuple, True) 99 | 100 | if train_config.fourier: 101 | As = [tf.cast(tf.sparse.to_dense(A), tf.complex128) for A in As] 102 | block_As = [block_diagonalize_A_single(A, data_config.root_num_blocks, tensor=True) for A in As] 103 | block_Ss = relaxation_matrices([csr_matrix(A.numpy()) for block_A in block_As for A in block_A]) 104 | 105 | batch_size = len(dataset.coarse_nodes_list) 106 | total_norm = tf.Variable(0.0, dtype=tf.float64) 107 | for i in range(batch_size): 108 | if train_config.fourier: 109 | num_blocks = data_config.root_num_blocks ** 2 - 1 110 | 111 | P_square = Ps_square[i] 112 | coarse_nodes = dataset.coarse_nodes_list[i] 113 | baseline_P = dataset.baseline_P_list[i] 114 | nodes = nodes_list[i] 115 | P = to_prolongation_matrix_tensor(P_square, coarse_nodes, baseline_P, nodes, 116 | normalize_rows=run_config.normalize_rows, 117 | normalize_rows_by_node=run_config.normalize_rows_by_node) 118 | block_P = block_diagonalize_P(P, data_config.root_num_blocks, coarse_nodes) 119 | 120 | As = tf.stack(block_As[i]) 121 | Ps = tf.stack(block_P) 122 | Rs = tf.transpose(Ps, perm=[0, 2, 1], conjugate=True) 123 | Ss = tf.convert_to_tensor(block_Ss[num_blocks * i:num_blocks * (i + 1)]) 124 | 125 | Ms = two_grid_error_matrices(As, Ps, Rs, Ss) 126 | M = Ms[-1] # for logging 127 | block_norms = tf.abs(frob_norm(Ms, power=1)) 128 | 129 | block_max_norm = tf.reduce_max(block_norms) 130 | total_norm = total_norm + block_max_norm 131 | 132 | else: 133 | A = tf.sparse.to_dense(As[i]) 134 | P_square = Ps_square[i] 135 | coarse_nodes = dataset.coarse_nodes_list[i] 136 | baseline_P = dataset.baseline_P_list[i] 137 | nodes = nodes_list[i] 138 | P = to_prolongation_matrix_tensor(P_square, coarse_nodes, baseline_P, nodes, 139 | normalize_rows=run_config.normalize_rows, 140 | normalize_rows_by_node=run_config.normalize_rows_by_node) 141 | R = tf.transpose(P) 142 | S = tf.convert_to_tensor(dataset.Ss[i]) 143 | 144 | M = two_grid_error_matrix(A, P, R, S) 145 | 146 | norm = frob_norm(M, power=1) 147 | total_norm = total_norm + norm 148 | 149 | return total_norm / batch_size, M # M is chosen randomly - the last in the batch 150 | 151 | 152 | def save_model_and_optimizer(checkpoint_prefix, model, optimizer, global_step): 153 | variables = model.get_all_variables() 154 | variables_dict = {variable.name: variable for variable in variables} 155 | checkpoint = tf.train.Checkpoint(**variables_dict, optimizer=optimizer, 
global_step=global_step) 156 | checkpoint.save(file_prefix=checkpoint_prefix) 157 | return checkpoint 158 | 159 | 160 | def train_run(run_dataset, run, batch_size, config, 161 | model, optimizer, global_step, checkpoint_prefix, 162 | eval_dataset, eval_A_graphs_tuple, eval_config, matlab_engine): 163 | num_As = len(run_dataset.As) 164 | if num_As % batch_size != 0: 165 | raise RuntimeError("batch size must divide training data size") 166 | 167 | run_dataset = run_dataset.shuffle() 168 | num_batches = num_As // batch_size 169 | loop = tqdm(range(num_batches)) 170 | for batch in loop: 171 | start_index = batch * batch_size 172 | end_index = start_index + batch_size 173 | batch_dataset = run_dataset[start_index:end_index] 174 | 175 | batch_A_graphs_tuple = csrs_to_graphs_tuple(batch_dataset.As, matlab_engine, 176 | coarse_nodes_list=batch_dataset.coarse_nodes_list, 177 | baseline_P_list=batch_dataset.baseline_P_list, 178 | node_indicators=config.run_config.node_indicators, 179 | edge_indicators=config.run_config.edge_indicators) 180 | 181 | with tf.GradientTape() as tape: 182 | with tf.device('/gpu:0'): 183 | batch_P_graphs_tuple = model(batch_A_graphs_tuple) 184 | frob_loss, M = loss(batch_dataset, batch_A_graphs_tuple, batch_P_graphs_tuple, 185 | config.run_config, config.train_config, config.data_config) 186 | 187 | print(f"frob loss: {frob_loss.numpy()}") 188 | save_every = max(1000 // batch_size, 1) 189 | if batch % save_every == 0: 190 | checkpoint = save_model_and_optimizer(checkpoint_prefix, model, optimizer, global_step) 191 | 192 | # we don't call .get_variables() because the model is Sequential/custom, 193 | # see docs for Sequential.get_variables() 194 | variables = model.get_all_variables() 195 | grads = tape.gradient(frob_loss, variables) 196 | 197 | global_step.assign_add(batch_size - 1) # apply_gradients increments global_step by 1 198 | optimizer.apply_gradients(zip(grads, variables), 199 | global_step=global_step) 200 | 201 | record_tb(M, run, num_As, batch, batch_size, frob_loss, grads, loop, model, 202 | variables, eval_dataset, eval_A_graphs_tuple, eval_config) 203 | return checkpoint 204 | 205 | 206 | def record_tb_loss(frob_loss): 207 | with tf.contrib.summary.record_summaries_every_n_global_steps(1): 208 | tf.contrib.summary.scalar('loss', frob_loss) 209 | 210 | 211 | def record_tb_params(batch_size, grads, loop, variables): 212 | with tf.contrib.summary.record_summaries_every_n_global_steps(1): 213 | if loop.avg_time is not None: 214 | tf.contrib.summary.scalar('seconds_per_batch', tf.convert_to_tensor(loop.avg_time)) 215 | 216 | for i in range(len(variables)): 217 | variable = variables[i] 218 | variable_name = variable.name 219 | grad = grads[i] 220 | if grad is not None: 221 | tf.contrib.summary.scalar(variable_name + '_grad', tf.norm(grad) / batch_size) 222 | tf.contrib.summary.histogram(variable_name + '_grad_histogram', grad / batch_size) 223 | tf.contrib.summary.scalar(variable_name + '_grad_fraction_dead', tf.nn.zero_fraction(grad)) 224 | tf.contrib.summary.scalar(variable_name + '_value', tf.norm(variable)) 225 | tf.contrib.summary.histogram(variable_name + '_value_histogram', variable) 226 | 227 | 228 | def record_tb_spectral_radius(M, model, eval_dataset, eval_A_graphs_tuple, eval_config): 229 | with tf.contrib.summary.record_summaries_every_n_global_steps(1): 230 | spectral_radius = np.abs(np.linalg.eigvals(M.numpy())).max() 231 | tf.contrib.summary.scalar('spectral_radius', spectral_radius) 232 | 233 | with tf.device('/gpu:0'): 234 | eval_P_graphs_tuple = 
model(eval_A_graphs_tuple) 235 | eval_loss, eval_M = loss(eval_dataset, eval_A_graphs_tuple, eval_P_graphs_tuple, 236 | eval_config.run_config, 237 | eval_config.train_config, 238 | eval_config.data_config) 239 | 240 | eval_spectral_radius = np.abs(np.linalg.eigvals(eval_M.numpy())).max() 241 | tf.contrib.summary.scalar('eval_loss', eval_loss) 242 | tf.contrib.summary.scalar('eval_spectral_radius', eval_spectral_radius) 243 | 244 | 245 | def record_tb(M, run, num_As, batch, batch_size, frob_loss, grads, loop, model, 246 | variables, eval_dataset, eval_A_graphs_tuple, eval_config): 247 | batch = run * num_As + batch 248 | 249 | record_loss_every = max(1 // batch_size, 1) 250 | if batch % record_loss_every == 0: 251 | record_tb_loss(frob_loss) 252 | 253 | record_params_every = max(300 // batch_size, 1) 254 | if batch % record_params_every == 0: 255 | record_tb_params(batch_size, grads, loop, variables) 256 | 257 | record_spectral_every = max(300 // batch_size, 1) 258 | if batch % record_spectral_every == 0: 259 | record_tb_spectral_radius(M, model, eval_dataset, eval_A_graphs_tuple, eval_config) 260 | 261 | 262 | def clone_model(model, model_config, run_config, matlab_engine): 263 | clone = create_model(model_config) 264 | 265 | dummy_A = pyamg.gallery.poisson((7, 7), type='FE', format='csr') 266 | dummy_input = csrs_to_graphs_tuple([dummy_A], matlab_engine, coarse_nodes_list=np.array([[0, 1]]), 267 | baseline_P_list=[tf.convert_to_tensor(dummy_A.toarray()[:, [0, 1]])], 268 | node_indicators=run_config.node_indicators, 269 | edge_indicators=run_config.edge_indicators) 270 | clone(dummy_input) 271 | [var_clone.assign(var_orig) for var_clone, var_orig in zip(clone.get_all_variables(), model.get_all_variables())] 272 | return clone 273 | 274 | 275 | def coarsen_As(fine_dataset, model, run_config, matlab_engine, batch_size=64): 276 | # computes the Galerkin operator P^(T)AP on each of the A matrices in a batch, using the Prolongation 277 | # outputted from the model 278 | As = fine_dataset.As 279 | coarse_nodes_list = fine_dataset.coarse_nodes_list 280 | baseline_P_list = fine_dataset.baseline_P_list 281 | 282 | batch_size = min(batch_size, len(As)) 283 | num_batches = len(As) // batch_size 284 | 285 | batched_As = list(chunks(As, batch_size)) 286 | batched_coarse_nodes_list = list(chunks(coarse_nodes_list, batch_size)) 287 | batched_baseline_P_list = list(chunks(baseline_P_list, batch_size)) 288 | A_graphs_tuple_batches = [csrs_to_graphs_tuple(batch_As, matlab_engine, coarse_nodes_list=batch_coarse_nodes_list, 289 | baseline_P_list=batch_baseline_P_list, 290 | node_indicators=run_config.node_indicators, 291 | edge_indicators=run_config.edge_indicators 292 | ) 293 | for batch_As, batch_coarse_nodes_list, batch_baseline_P_list 294 | in zip(batched_As, batched_coarse_nodes_list, batched_baseline_P_list)] 295 | 296 | Ps_square = [] 297 | nodes_list = [] 298 | for batch in tqdm(range(num_batches)): 299 | A_graphs_tuple = A_graphs_tuple_batches[batch] 300 | P_graphs_tuple = model(A_graphs_tuple) 301 | P_square_batch, nodes_batch = graphs_tuple_to_sparse_matrices(P_graphs_tuple, return_nodes=True) 302 | Ps_square.extend(P_square_batch) 303 | nodes_list.extend(nodes_batch) 304 | 305 | coarse_As = [] 306 | for i in tqdm(range(len(As))): 307 | P_square = Ps_square[i] 308 | nodes = nodes_list[i] 309 | coarse_nodes = coarse_nodes_list[i] 310 | baseline_P = baseline_P_list[i] 311 | P = to_prolongation_matrix_tensor(P_square, coarse_nodes, baseline_P, nodes) 312 | R = tf.transpose(P) 313 | A_csr = As[i] 314 
| A = tf.convert_to_tensor(A_csr.toarray(), dtype=tf.float64) 315 | tensor_coarse_A = compute_coarse_A(R, A, P) 316 | coarse_A = csr_matrix(tensor_coarse_A.numpy()) 317 | coarse_As.append(coarse_A) 318 | return coarse_As 319 | 320 | 321 | def create_coarse_dataset(fine_dataset, model, data_config, run_config, matlab_engine): 322 | As = coarsen_As(fine_dataset, model, run_config, matlab_engine) 323 | return create_dataset_from_As(As, data_config) 324 | 325 | 326 | def train(config='GRAPH_LAPLACIAN_TRAIN', eval_config='GRAPH_LAPLACIAN_EVAL', seed=1): 327 | config = getattr(configs, config) 328 | eval_config = getattr(configs, eval_config) 329 | eval_config.run_config = config.run_config 330 | 331 | matlab_engine = matlab.engine.start_matlab() 332 | 333 | # fix random seeds for reproducibility 334 | np.random.seed(seed) 335 | tf.random.set_random_seed(seed) 336 | matlab_engine.eval(f'rng({seed})') 337 | 338 | batch_size = min(config.train_config.samples_per_run, config.train_config.batch_size) 339 | 340 | # we measure the performance of the model over time on one larger instance that is not optimized for 341 | eval_dataset = create_dataset(1, eval_config.data_config) 342 | eval_A_graphs_tuple = csrs_to_graphs_tuple(eval_dataset.As, matlab_engine, 343 | coarse_nodes_list=eval_dataset.coarse_nodes_list, 344 | baseline_P_list=eval_dataset.baseline_P_list, 345 | node_indicators=eval_config.run_config.node_indicators, 346 | edge_indicators=eval_config.run_config.edge_indicators 347 | ) 348 | 349 | if config.train_config.load_model: 350 | raise NotImplementedError() 351 | else: 352 | model = create_model(config.model_config) 353 | global_step = tf.train.get_or_create_global_step() 354 | optimizer = tf.train.AdamOptimizer(learning_rate=config.train_config.learning_rate) 355 | 356 | run_name = ''.join(random.choices(string.digits, k=5)) # to make the run_name string unique 357 | create_results_dir(run_name) 358 | write_config_file(run_name, config, seed) 359 | 360 | checkpoint_prefix = os.path.join(config.train_config.checkpoint_dir + '/' + run_name, 'ckpt') 361 | log_dir = config.train_config.tensorboard_dir + '/' + run_name 362 | writer = tf.contrib.summary.create_file_writer(log_dir) 363 | writer.set_as_default() 364 | 365 | for run in range(config.train_config.num_runs): 366 | # we create the data before the training loop starts for efficiency, 367 | # at the loop we only slice batches and convert to tensors 368 | run_dataset = create_dataset(config.train_config.samples_per_run, config.data_config, 369 | run=run, matlab_engine=matlab_engine) 370 | 371 | checkpoint = train_run(run_dataset, run, batch_size, config, 372 | model, optimizer, global_step, 373 | checkpoint_prefix, 374 | eval_dataset, eval_A_graphs_tuple, eval_config, 375 | matlab_engine) 376 | checkpoint.save(file_prefix=checkpoint_prefix) 377 | 378 | if config.train_config.coarsen: 379 | old_model = clone_model(model, config.model_config, config.run_config, matlab_engine) 380 | 381 | for run in range(config.train_config.num_runs): 382 | run_dataset = create_dataset(config.train_config.samples_per_run, config.data_config, 383 | run=run, matlab_engine=matlab_engine) 384 | 385 | fine_data_config = copy.deepcopy(config.data_config) 386 | # RS coarsens to roughly 1/3 of the size of the grid, CLJP to roughly 1/2 387 | fine_data_config.num_unknowns = config.data_config.num_unknowns * 2 388 | fine_run_dataset = create_dataset(config.train_config.samples_per_run, 389 | fine_data_config, 390 | run=run, 391 | matlab_engine=matlab_engine) 392 | 
coarse_run_dataset = create_coarse_dataset(fine_run_dataset, old_model, 393 | config.data_config, 394 | config.run_config, 395 | matlab_engine=matlab_engine) 396 | 397 | combined_run_dataset = run_dataset + coarse_run_dataset 398 | combined_run_dataset = combined_run_dataset.shuffle() 399 | 400 | checkpoint = train_run(combined_run_dataset, run, batch_size, config, 401 | model, optimizer, global_step, 402 | checkpoint_prefix, 403 | eval_dataset, eval_A_graphs_tuple, eval_config, 404 | matlab_engine) 405 | checkpoint.save(file_prefix=checkpoint_prefix) 406 | 407 | 408 | if __name__ == '__main__': 409 | tf_config = tf.ConfigProto() 410 | tf_config.gpu_options.allow_growth = True 411 | tf.enable_eager_execution(config=tf_config) 412 | 413 | fire.Fire(train) 414 | --------------------------------------------------------------------------------
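Usage note: train.py hands train() to fire.Fire, so a run can be launched from the shell with the config names defined in configs.py, for example python train.py --config GRAPH_LAPLACIAN_TRAIN --eval_config GRAPH_LAPLACIAN_EVAL --seed 1 (these are also the defaults in the function signature).

For orientation, the sketch below reproduces in plain NumPy and PyAMG the quantity that loss() drives down: the Frobenius norm of the two-grid error-propagation matrix. It assumes the textbook composition M = S (I - P (R A P)^{-1} R A) S with a Gauss-Seidel relaxation matrix S = -(D + L)^{-1} U and the Galerkin coarse operator R A P; the exact composition used during training lives in multigrid_utils.two_grid_error_matrix and relaxation.relaxation_matrices and may differ in details (number of relaxation sweeps, Fourier block handling), so treat this as an illustration rather than a drop-in replacement.

import numpy as np
import pyamg

# small model problem and a classical Ruge-Stuben hierarchy to get a baseline P
A = pyamg.gallery.poisson((16, 16), format='csr')
solver = pyamg.ruge_stuben_solver(A, max_levels=2, keep=True)
P = solver.levels[0].P.toarray()
R = P.T
A_dense = A.toarray()
n = A_dense.shape[0]

# Gauss-Seidel iteration matrix S = -(D + L)^{-1} U
lower = np.tril(A_dense)          # D + L: lower triangle including the diagonal
upper = np.triu(A_dense, k=1)     # U: strictly upper triangle
S = np.linalg.solve(lower, -upper)

# coarse-grid correction sandwiched between pre- and post-relaxation
coarse_A = R @ A_dense @ P        # Galerkin coarse operator
CGC = np.eye(n) - P @ np.linalg.solve(coarse_A, R @ A_dense)
M = S @ CGC @ S

print("Frobenius norm of M:", np.linalg.norm(M, 'fro'))
print("spectral radius of M:", np.abs(np.linalg.eigvals(M)).max())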