├── .idea ├── Finite-expression-method.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── README.md ├── fex ├── Conservationlaw │ ├── computational_tree.py │ ├── controller_conservative.py │ ├── function.py │ ├── scripts.py │ ├── tools.py │ └── utils │ │ ├── __init__.py │ │ ├── eval.py │ │ ├── images │ │ ├── cifar.png │ │ └── imagenet.png │ │ ├── logger.py │ │ ├── misc.py │ │ └── visualize.py ├── Poisson │ ├── .idea │ │ ├── .gitignore │ │ ├── inspectionProfiles │ │ │ └── profiles_settings.xml │ │ ├── misc.xml │ │ ├── modules.xml │ │ └── poissoneqn.iml │ ├── computational_tree.py │ ├── controller_poisson.py │ ├── function.py │ ├── scripts.py │ ├── tools.py │ └── utils │ │ ├── __init__.py │ │ ├── eval.py │ │ ├── images │ │ ├── cifar.png │ │ └── imagenet.png │ │ ├── logger.py │ │ ├── misc.py │ │ └── visualize.py └── Schrodinger │ ├── computational_tree.py │ ├── controller_cubic_sh_firstdeflation_thenintegral.py │ ├── function.py │ ├── scripts.py │ ├── tools.py │ └── utils │ ├── __init__.py │ ├── eval.py │ ├── images │ ├── cifar.png │ └── imagenet.png │ ├── logger.py │ ├── misc.py │ └── visualize.py ├── fexrl.png └── nn ├── Conservationlaw ├── scrpts.py ├── train.py └── utils │ ├── __init__.py │ ├── eval.py │ ├── logger.py │ ├── misc.py │ └── visualize.py ├── Poisson ├── .idea │ ├── .gitignore │ ├── conservativelaw.iml │ ├── inspectionProfiles │ │ ├── Project_Default.xml │ │ └── profiles_settings.xml │ ├── misc.xml │ └── modules.xml ├── scrpts.py ├── train.py └── utils │ ├── __init__.py │ ├── eval.py │ ├── logger.py │ ├── misc.py │ └── visualize.py └── Schrodinger ├── scripts.py ├── train_integral.py └── utils ├── __init__.py ├── eval.py ├── logger.py ├── misc.py └── visualize.py /.idea/Finite-expression-method.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 13 | 14 | 16 | 17 | 19 | 20 | 21 | 24 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 1660276797692 44 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Finite-expression-method 2 | ![GitHub](https://img.shields.io/github/license/gbup-group/DIANet.svg) 3 | 4 | By [Senwei Liang](https://leungsamwai.github.io) and [Haizhao Yang](https://haizhaoyang.github.io/) 5 | 6 | This repo is the implementation of "Finite Expression Method for Solving High-Dimensional Partial Differential Equations" [[paper]](https://arxiv.org/abs/2206.10121). 7 | 8 | ## Introduction 9 | 10 | Finite expression method (FEX) is a new methodology that seeks an approximate PDE solution in the space of functions with finitely many analytic expressions. 
This repo provides a deep reinforcement learning method to implement FEX for high-dimensional PDEs of various dimensions. 11 | 12 | ![image](fexrl.png) 13 | 14 | ## Environment 15 | * [PyTorch 1.0](http://pytorch.org/) 16 | 17 | ## Code structure 18 | 19 | ``` 20 | Finite-expression-method 21 | │ README.md <-- You are here 22 | │ 23 | └─── fex ----> three numerical examples with FEX 24 | │ │ Poisson 25 | │ │ Schrodinger 26 | │ │ Conservationlaw 27 | │ 28 | └─── nn ----> three numerical examples with NN 29 | │ Poisson 30 | │ Schrodinger 31 | │ Conservationlaw 32 | ``` 33 | ## Citing FEX 34 | If you find our code helpful for your research, please cite 35 | ``` 36 | @article{liang2022finite, 37 | title={Finite Expression Method for Solving High-Dimensional Partial Differential Equations}, 38 | author={Liang, Senwei and Yang, Haizhao}, 39 | journal={arXiv preprint arXiv:2206.10121}, 40 | year={2022} 41 | } 42 | ``` 43 | ## Acknowledgments 44 | 45 | We thank [bearpaw](https://github.com/bearpaw) for his [DL framework](https://github.com/bearpaw/pytorch-classification) and Taehoon Kim for his [RL framework](https://github.com/carpedm20/ENAS-pytorch).
-------------------------------------------------------------------------------- /fex/Conservationlaw/computational_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import function as func 3 | unary = func.unary_functions 4 | binary = func.binary_functions 5 | unary_functions_str = func.unary_functions_str 6 | binary_functions_str = func.binary_functions_str 7 | 8 | class BinaryTree(object): 9 | def __init__(self,item,is_unary=True): 10 | self.key=item 11 | self.is_unary=is_unary 12 | self.leftChild=None 13 | self.rightChild=None 14 | def insertLeft(self,item, is_unary=True): 15 | if self.leftChild==None: 16 | self.leftChild=BinaryTree(item, is_unary) 17 | else: 18 | t=BinaryTree(item) 19 | t.leftChild=self.leftChild 20 | self.leftChild=t 21 | def insertRight(self,item, is_unary=True): 22 | if self.rightChild==None: 23 | self.rightChild=BinaryTree(item, is_unary) 24 | else: 25 | t=BinaryTree(item) 26 | t.rightChild=self.rightChild 27 | self.rightChild=t 28 | 29 | def compute_by_tree(tree, x): 30 | ''' judge whether a emtpy tree, if yes, that means the leaves and call the unary operation ''' 31 | if tree.leftChild == None and tree.rightChild == None: 32 | return tree.key(x) 33 | elif tree.leftChild == None and tree.rightChild is not None: 34 | return tree.key(compute_by_tree(tree.rightChild, x)) 35 | elif tree.leftChild is not None and tree.rightChild == None: 36 | return tree.key(compute_by_tree(tree.leftChild, x)) 37 | else: 38 | return tree.key(compute_by_tree(tree.leftChild, x), compute_by_tree(tree.rightChild, x)) 39 | 40 | def inorder(tree, actions): 41 | global count 42 | if tree: 43 | inorder(tree.leftChild, actions) 44 | action = actions[count].item() 45 | if tree.is_unary: 46 | action = action 47 | tree.key = unary[action] 48 | # print(count, action, func.unary_functions_str[action]) 49 | else: 50 | action = action 51 | tree.key = binary[action] 52 | # print(count, action, 
func.binary_functions_str[action]) 53 | count = count + 1 54 | inorder(tree.rightChild, actions) 55 | 56 | # def inorder(tree): 57 | # if tree: 58 | # inorder(tree.leftChild) 59 | # print(tree.key) 60 | # inorder(tree.rightChild) 61 | 62 | count = 0 63 | def inorder_w_idx(tree): 64 | global count 65 | if tree: 66 | inorder_w_idx(tree.leftChild) 67 | print(tree.key, count) 68 | count = count + 1 69 | inorder_w_idx(tree.rightChild) 70 | 71 | def basic_tree(): 72 | tree = BinaryTree('', False) 73 | tree.insertLeft('', False) 74 | tree.leftChild.insertLeft('', True) 75 | tree.leftChild.insertRight('', True) 76 | tree.insertRight('', False) 77 | tree.rightChild.insertLeft('', True) 78 | tree.rightChild.insertRight('', True) 79 | return tree 80 | 81 | def get_function(actions): 82 | global count 83 | count = 0 84 | computation_tree = basic_tree() 85 | inorder(computation_tree, actions) 86 | count = 0 # 置零 87 | return computation_tree 88 | 89 | def inorder_test(tree, actions): 90 | global count 91 | if tree: 92 | inorder(tree.leftChild, actions) 93 | action = actions[count].item() 94 | print(action) 95 | if tree.is_unary: 96 | action = action 97 | tree.key = unary[action] 98 | # print(count, action, func.unary_functions_str[action]) 99 | else: 100 | action = action 101 | tree.key = binary[action] 102 | # print(count, action, func.binary_functions_str[action]) 103 | count = count + 1 104 | inorder(tree.rightChild, actions) 105 | 106 | if __name__ =='__main__': 107 | # tree = BinaryTree(np.add) 108 | # tree.insertLeft(np.multiply) 109 | # tree.leftChild.insertLeft(np.cos) 110 | # tree.leftChild.insertRight(np.sin) 111 | # tree.insertRight(np.sin) 112 | # print(compute_by_tree(tree, 30)) # np.sin(30)*np.cos(30)+np.sin(30) 113 | # inorder(tree) 114 | # inorder_w_idx(tree) 115 | import torch 116 | bs_action = [torch.LongTensor([10]), torch.LongTensor([2]),torch.LongTensor([9]),torch.LongTensor([1]),torch.LongTensor([0]),torch.LongTensor([2]),torch.LongTensor([6])] 117 | 118 | 
function = lambda x: compute_by_tree(get_function(bs_action), x) 119 | x = torch.FloatTensor([[-1], [1]]) 120 | 121 | count = 0 122 | tr = basic_tree() 123 | inorder_test(tr, bs_action) 124 | count = 0 125 | 126 | -------------------------------------------------------------------------------- /fex/Conservationlaw/function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import sin, cos, exp 4 | import math 5 | 6 | scale = math.pi/4 7 | 8 | def LHS_pde(u, x, dim_set): 9 | 10 | v = torch.ones(u.shape).cuda() 11 | ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0] 12 | coefficient = torch.ones(1, dim_set).cuda() 13 | coefficient = -coefficient 14 | coefficient[:,0] = (dim_set-1)*scale 15 | ux_mul_coef = ux*coefficient 16 | LHS = torch.sum(ux_mul_coef, 1, keepdim=True) 17 | return LHS 18 | 19 | def RHS_pde(x): 20 | bs = x.size(0) 21 | return torch.zeros(bs, 1).cuda() 22 | 23 | def true_solution(x): 24 | dim = x.size(1) 25 | coefficient = torch.ones(1, dim).cuda()*scale 26 | coefficient[:, 0] = 1 27 | return torch.sin(torch.sum(x*coefficient, dim=1, keepdim=True)) 28 | 29 | 30 | unary_functions = [lambda x: 0*x**2, 31 | lambda x: 1+0*x**2, 32 | lambda x: x+0*x**2, 33 | lambda x: x**2, 34 | lambda x: x**3, 35 | lambda x: x**4, 36 | torch.exp, 37 | torch.sin, 38 | torch.cos,] 39 | 40 | binary_functions = [lambda x,y: x+y, 41 | lambda x,y: x*y, 42 | lambda x,y: x-y] 43 | 44 | 45 | unary_functions_str = ['({}*(0)+{})', 46 | '({}*(1)+{})', 47 | # '5', 48 | '({}*{}+{})', 49 | # '-{}', 50 | '({}*({})**2+{})', 51 | '({}*({})**3+{})', 52 | '({}*({})**4+{})', 53 | # '({})**5', 54 | '({}*exp({})+{})', 55 | '({}*sin({})+{})', 56 | '({}*cos({})+{})',] 57 | # 'ref({})', 58 | # 'exp(-({})**2/2)'] 59 | 60 | unary_functions_str_leaf= ['(0)', 61 | '(1)', 62 | # '5', 63 | '({})', 64 | # '-{}', 65 | '(({})**2)', 66 | '(({})**3)', 67 | '(({})**4)', 68 | # '({})**5', 69 | 
'(exp({}))', 70 | '(sin({}))', 71 | '(cos({}))',] 72 | 73 | 74 | binary_functions_str = ['(({})+({}))', 75 | '(({})*({}))', 76 | '(({})-({}))'] 77 | 78 | if __name__ == '__main__': 79 | batch_size = 200 80 | left = -1 81 | right = 1 82 | points = (torch.rand(batch_size, 1)) * (right - left) + left 83 | x = torch.autograd.Variable(points.cuda(), requires_grad=True) 84 | function = true_solution 85 | 86 | ''' 87 | PDE loss 88 | ''' 89 | LHS = LHS_pde(function(x), x) 90 | RHS = RHS_pde(x) 91 | pde_loss = torch.nn.functional.mse_loss(LHS, RHS) 92 | 93 | ''' 94 | boundary loss 95 | ''' 96 | bc_points = torch.FloatTensor([[left], [right]]).cuda() 97 | bc_value = true_solution(bc_points) 98 | bd_loss = torch.nn.functional.mse_loss(function(bc_points), bc_value) 99 | 100 | print('pde loss: {} -- boundary loss: {}'.format(pde_loss.item(), bd_loss.item())) -------------------------------------------------------------------------------- /fex/Conservationlaw/scripts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | 5 | gpus = [1,7,8,9,2,3,4,5,6]*5 6 | idx = 0 7 | 8 | for i in range(5): 9 | for dim in [42, 48, 54]: 10 | gpu = gpus[idx] 11 | idx += 1 12 | os.system('screen python controller_conservative.py --epoch 2000 --bs 10 --greedy 0.1 --gpu '+str(gpu)+' --ckpt t_range_0_1_2ksearch_int20k_bd4kDim'+str(dim)+' --tree depth2_sub --random_step 3 --lr 0.002 --dim '+str(dim)+' --base 200000 --left -1 --right 1 --domainbs 20000 --bdbs 4000') 13 | time.sleep(500) -------------------------------------------------------------------------------- /fex/Conservationlaw/tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from collections import defaultdict 4 | import collections 5 | from datetime import datetime 6 | import os 7 | import json 8 | import logging 9 | 10 | import numpy as np 11 | import torch 12 | from 
torch.autograd import Variable 13 | 14 | ########################## 15 | # Torch 16 | ########################## 17 | 18 | def detach(h): 19 | if type(h) == Variable: 20 | return Variable(h.data) 21 | else: 22 | return tuple(detach(v) for v in h) 23 | 24 | def get_variable(inputs, cuda=False, **kwargs): 25 | if type(inputs) in [list, np.ndarray]: 26 | inputs = torch.Tensor(inputs) 27 | if cuda: 28 | out = Variable(inputs.cuda(), **kwargs) 29 | else: 30 | out = Variable(inputs, **kwargs) 31 | return out 32 | 33 | def update_lr(optimizer, lr): 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = lr 36 | 37 | def batchify(data, bsz, use_cuda): 38 | # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py 39 | nbatch = data.size(0) // bsz 40 | data = data.narrow(0, 0, nbatch * bsz) 41 | data = data.view(bsz, -1).t().contiguous() 42 | if use_cuda: 43 | data = data.cuda() 44 | return data 45 | 46 | 47 | ########################## 48 | # ETC 49 | ########################## 50 | 51 | Node = collections.namedtuple('Node', ['id', 'name']) 52 | 53 | 54 | class keydefaultdict(defaultdict): 55 | def __missing__(self, key): 56 | if self.default_factory is None: 57 | raise KeyError(key) 58 | else: 59 | ret = self[key] = self.default_factory(key) 60 | return ret 61 | 62 | 63 | def to_item(x): 64 | """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" 65 | if isinstance(x, (float, int)): 66 | return x 67 | 68 | if float(torch.__version__[0:3]) < 0.4: 69 | assert (x.dim() == 1) and (len(x) == 1) 70 | return x[0] 71 | 72 | return x.item() 73 | 74 | 75 | def get_logger(name=__file__, level=logging.INFO): 76 | logger = logging.getLogger(name) 77 | 78 | if getattr(logger, '_init_done__', None): 79 | logger.setLevel(level) 80 | return logger 81 | 82 | logger._init_done__ = True 83 | logger.propagate = False 84 | logger.setLevel(level) 85 | 86 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 
87 | handler = logging.StreamHandler() 88 | handler.setFormatter(formatter) 89 | handler.setLevel(0) 90 | 91 | del logger.handlers[:] 92 | logger.addHandler(handler) 93 | 94 | return logger 95 | 96 | 97 | logger = get_logger() 98 | 99 | 100 | def prepare_dirs(args): 101 | """Sets the directories for the model, and creates those directories. 102 | 103 | Args: 104 | args: Parsed from `argparse` in the `config` module. 105 | """ 106 | if args.load_path: 107 | if args.load_path.startswith(args.log_dir): 108 | args.model_dir = args.load_path 109 | else: 110 | if args.load_path.startswith(args.dataset): 111 | args.model_name = args.load_path 112 | else: 113 | args.model_name = "{}_{}".format(args.dataset, args.load_path) 114 | else: 115 | args.model_name = "{}_{}".format(args.dataset, get_time()) 116 | 117 | if not hasattr(args, 'model_dir'): 118 | args.model_dir = os.path.join(args.log_dir, args.model_name) 119 | args.data_path = os.path.join(args.data_dir, args.dataset) 120 | 121 | for path in [args.log_dir, args.data_dir, args.model_dir]: 122 | if not os.path.exists(path): 123 | makedirs(path) 124 | 125 | def get_time(): 126 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 127 | 128 | def save_args(args): 129 | param_path = os.path.join(args.model_dir, "params.json") 130 | 131 | logger.info("[*] MODEL dir: %s" % args.model_dir) 132 | logger.info("[*] PARAM path: %s" % param_path) 133 | 134 | with open(param_path, 'w') as fp: 135 | json.dump(args.__dict__, fp, indent=4, sort_keys=True) 136 | 137 | def save_dag(args, dag, name): 138 | save_path = os.path.join(args.model_dir, name) 139 | logger.info("[*] Save dag : {}".format(save_path)) 140 | json.dump(dag, open(save_path, 'w')) 141 | 142 | def load_dag(args): 143 | load_path = os.path.join(args.dag_path) 144 | logger.info("[*] Load dag : {}".format(load_path)) 145 | with open(load_path) as f: 146 | dag = json.load(f) 147 | dag = {int(k): [Node(el[0], el[1]) for el in v] for k, v in dag.items()} 148 | 
save_dag(args, dag, "dag.json") 149 | draw_network(dag, os.path.join(args.model_dir, "dag.png")) 150 | return dag 151 | 152 | def makedirs(path): 153 | if not os.path.exists(path): 154 | logger.info("[*] Make directories : {}".format(path)) 155 | os.makedirs(path) 156 | 157 | def remove_file(path): 158 | if os.path.exists(path): 159 | logger.info("[*] Removed: {}".format(path)) 160 | os.remove(path) 161 | 162 | def backup_file(path): 163 | root, ext = os.path.splitext(path) 164 | new_path = "{}.backup_{}{}".format(root, get_time(), ext) 165 | 166 | os.rename(path, new_path) 167 | logger.info("[*] {} has backup: {}".format(path, new_path)) 168 | -------------------------------------------------------------------------------- /fex/Conservationlaw/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | # from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /fex/Conservationlaw/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- 
/fex/Conservationlaw/utils/images/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Conservationlaw/utils/images/cifar.png -------------------------------------------------------------------------------- /fex/Conservationlaw/utils/images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Conservationlaw/utils/images/imagenet.png -------------------------------------------------------------------------------- /fex/Conservationlaw/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in 
enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | def append(self, numbers): 61 | assert len(self.names) == len(numbers), 'Numbers do not match names' 62 | for index, num in enumerate(numbers): 63 | if isinstance(num, int): 64 | self.file.write("{}".format(num)) 65 | elif isinstance(num, str): 66 | self.file.write("{}".format(num)) 67 | else: 68 | self.file.write("{0:.8f}".format(num)) 69 | self.file.write('\t') 70 | self.numbers[self.names[index]].append(num) 71 | self.file.write('\n') 72 | self.file.flush() 73 | 74 | def plot(self, names=None): 75 | names = self.names if names == None else names 76 | numbers = self.numbers 77 | for _, name in enumerate(names): 78 | x = np.arange(len(numbers[name])) 79 | plt.plot(x, np.asarray(numbers[name])) 80 | plt.legend([self.title + '(' + name + ')' for name in names]) 81 | plt.grid(True) 82 | 83 | def close(self): 84 | if self.file is not None: 85 | self.file.close() 86 | 87 | class LoggerMonitor(object): 88 | '''Load and visualize multiple logs.''' 89 | def __init__ (self, paths): 90 | '''paths is a distionary with {name:filepath} pair''' 91 | self.loggers = [] 92 | for title, path in paths.items(): 93 | logger = Logger(path, title=title, resume=True) 94 | self.loggers.append(logger) 95 | 96 | def plot(self, names=None): 97 | plt.figure() 98 | plt.subplot(121) 99 | legend_text = [] 100 | for logger in self.loggers: 101 
| legend_text += plot_overlap(logger, names) 102 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 103 | plt.grid(True) 104 | 105 | if __name__ == '__main__': 106 | # # Example 107 | # logger = Logger('test.txt') 108 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 109 | 110 | # length = 100 111 | # t = np.arange(length) 112 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 113 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 114 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 115 | 116 | # for i in range(0, length): 117 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 118 | # logger.plot() 119 | 120 | # Example: logger monitor 121 | paths = { 122 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 123 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 124 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 125 | } 126 | 127 | field = ['Valid Acc.'] 128 | 129 | monitor = LoggerMonitor(paths) 130 | monitor.plot(names=field) 131 | savefig('test.eps') -------------------------------------------------------------------------------- /fex/Conservationlaw/utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count 
-------------------------------------------------------------------------------- /fex/Conservationlaw/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | 
plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = 
torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /fex/Poisson/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /fex/Poisson/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /fex/Poisson/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /fex/Poisson/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /fex/Poisson/.idea/poissoneqn.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /fex/Poisson/computational_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import function as func 3 | unary = func.unary_functions 4 | binary = func.binary_functions 5 | unary_functions_str = func.unary_functions_str 6 | binary_functions_str = func.binary_functions_str 7 | 8 | class BinaryTree(object): 9 | def __init__(self,item,is_unary=True): 10 | self.key=item 11 | self.is_unary=is_unary 12 | self.leftChild=None 13 | self.rightChild=None 14 | def insertLeft(self,item, is_unary=True): 15 | if self.leftChild==None: 16 | 
self.leftChild=BinaryTree(item, is_unary) 17 | else: 18 | t=BinaryTree(item) 19 | t.leftChild=self.leftChild 20 | self.leftChild=t 21 | def insertRight(self,item, is_unary=True): 22 | if self.rightChild==None: 23 | self.rightChild=BinaryTree(item, is_unary) 24 | else: 25 | t=BinaryTree(item) 26 | t.rightChild=self.rightChild 27 | self.rightChild=t 28 | 29 | def compute_by_tree(tree, x): 30 | ''' judge whether a emtpy tree, if yes, that means the leaves and call the unary operation ''' 31 | if tree.leftChild == None and tree.rightChild == None: 32 | return tree.key(x) 33 | elif tree.leftChild == None and tree.rightChild is not None: 34 | return tree.key(compute_by_tree(tree.rightChild, x)) 35 | elif tree.leftChild is not None and tree.rightChild == None: 36 | return tree.key(compute_by_tree(tree.leftChild, x)) 37 | else: 38 | return tree.key(compute_by_tree(tree.leftChild, x), compute_by_tree(tree.rightChild, x)) 39 | 40 | def inorder(tree, actions): 41 | global count 42 | if tree: 43 | inorder(tree.leftChild, actions) 44 | action = actions[count].item() 45 | if tree.is_unary: 46 | action = action 47 | tree.key = unary[action] 48 | # print(count, action, func.unary_functions_str[action]) 49 | else: 50 | action = action 51 | tree.key = binary[action] 52 | # print(count, action, func.binary_functions_str[action]) 53 | count = count + 1 54 | inorder(tree.rightChild, actions) 55 | 56 | # def inorder(tree): 57 | # if tree: 58 | # inorder(tree.leftChild) 59 | # print(tree.key) 60 | # inorder(tree.rightChild) 61 | 62 | count = 0 63 | def inorder_w_idx(tree): 64 | global count 65 | if tree: 66 | inorder_w_idx(tree.leftChild) 67 | print(tree.key, count) 68 | count = count + 1 69 | inorder_w_idx(tree.rightChild) 70 | 71 | def basic_tree(): 72 | tree = BinaryTree('', False) 73 | tree.insertLeft('', False) 74 | tree.leftChild.insertLeft('', True) 75 | tree.leftChild.insertRight('', True) 76 | tree.insertRight('', False) 77 | tree.rightChild.insertLeft('', True) 78 | 
tree.rightChild.insertRight('', True) 79 | return tree 80 | 81 | def get_function(actions): 82 | global count 83 | count = 0 84 | computation_tree = basic_tree() 85 | inorder(computation_tree, actions) 86 | count = 0 # 置零 87 | return computation_tree 88 | 89 | def inorder_test(tree, actions): 90 | global count 91 | if tree: 92 | inorder(tree.leftChild, actions) 93 | action = actions[count].item() 94 | print(action) 95 | if tree.is_unary: 96 | action = action 97 | tree.key = unary[action] 98 | # print(count, action, func.unary_functions_str[action]) 99 | else: 100 | action = action 101 | tree.key = binary[action] 102 | # print(count, action, func.binary_functions_str[action]) 103 | count = count + 1 104 | inorder(tree.rightChild, actions) 105 | 106 | if __name__ =='__main__': 107 | # tree = BinaryTree(np.add) 108 | # tree.insertLeft(np.multiply) 109 | # tree.leftChild.insertLeft(np.cos) 110 | # tree.leftChild.insertRight(np.sin) 111 | # tree.insertRight(np.sin) 112 | # print(compute_by_tree(tree, 30)) # np.sin(30)*np.cos(30)+np.sin(30) 113 | # inorder(tree) 114 | # inorder_w_idx(tree) 115 | import torch 116 | bs_action = [torch.LongTensor([10]), torch.LongTensor([2]),torch.LongTensor([9]),torch.LongTensor([1]),torch.LongTensor([0]),torch.LongTensor([2]),torch.LongTensor([6])] 117 | 118 | function = lambda x: compute_by_tree(get_function(bs_action), x) 119 | x = torch.FloatTensor([[-1], [1]]) 120 | 121 | count = 0 122 | tr = basic_tree() 123 | inorder_test(tr, bs_action) 124 | count = 0 125 | 126 | -------------------------------------------------------------------------------- /fex/Poisson/controller_poisson.py: -------------------------------------------------------------------------------- 1 | """A module with NAS controller-related code.""" 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import tools 6 | import scipy 7 | from utils import Logger, mkdir_p 8 | import os 9 | import torch.nn as nn 10 | from computational_tree import 
BinaryTree 11 | import function as func 12 | import argparse 13 | import random 14 | import math 15 | 16 | parser = argparse.ArgumentParser(description='NAS') 17 | 18 | parser.add_argument('--left', default=-1, type=float) 19 | parser.add_argument('--right', default=1, type=float) 20 | parser.add_argument('--epoch', default=2000, type=int) 21 | parser.add_argument('--bs', default=1, type=int) 22 | parser.add_argument('--greedy', default=0, type=float) 23 | parser.add_argument('--random_step', default=0, type=float) 24 | parser.add_argument('--ckpt', default='', type=str) 25 | parser.add_argument('--gpu', default=0, type=int) 26 | parser.add_argument('--dim', default=20, type=int) 27 | parser.add_argument('--tree', default='depth2', type=str) 28 | parser.add_argument('--lr', default=1e-2, type=float) 29 | parser.add_argument('--percentile', default=0.5, type=float) 30 | parser.add_argument('--base', default=100, type=int) 31 | parser.add_argument('--domainbs', default=1000, type=int) 32 | parser.add_argument('--bdbs', default=1000, type=int) 33 | args = parser.parse_args() 34 | 35 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 36 | 37 | unary = func.unary_functions 38 | binary = func.binary_functions 39 | unary_functions_str = func.unary_functions_str 40 | unary_functions_str_leaf = func.unary_functions_str_leaf 41 | binary_functions_str = func.binary_functions_str 42 | 43 | left = args.left 44 | right = args.right 45 | dim = args.dim 46 | 47 | def get_boundary(num_pts, dim): 48 | 49 | bd_pts = (torch.rand(num_pts, dim).cuda()) * (args.right - args.left) + args.left 50 | 51 | num_half = num_pts//2 52 | xlst = torch.arange(0, num_half) 53 | ylst = torch.randint(0, dim, (num_half,)) 54 | bd_pts[xlst, ylst] = args.left 55 | 56 | xlst = torch.arange(num_half, num_pts) 57 | ylst = torch.randint(0, dim, (num_half,)) 58 | bd_pts[xlst, ylst] = args.right 59 | 60 | return bd_pts 61 | 62 | 63 | class candidate(object): 64 | def __init__(self, action, expression, error): 
65 | self.action = action 66 | self.expression = expression 67 | self.error = error 68 | 69 | class SaveBuffer(object): 70 | def __init__(self, max_size): 71 | self.max_size = max_size 72 | self.candidates = [] 73 | 74 | def num_candidates(self): 75 | return len(self.candidates) 76 | 77 | def add_new(self, candidate): 78 | flag = 1 79 | action_idx = None 80 | for idx, old_candidate in enumerate(self.candidates): 81 | if candidate.action == old_candidate.action and candidate.error < old_candidate.error: # 如果判断出来和之前的action一样的话,就不去做 82 | flag = 1 83 | action_idx = idx 84 | break 85 | elif candidate.action == old_candidate.action: 86 | flag = 0 87 | 88 | if flag == 1: 89 | if action_idx is not None: 90 | print(action_idx) 91 | self.candidates.pop(action_idx) 92 | self.candidates.append(candidate) 93 | self.candidates = sorted(self.candidates, key=lambda x: x.error) # from small to large 94 | 95 | if len(self.candidates) > self.max_size: 96 | self.candidates.pop(-1) # remove the last one 97 | 98 | if args.tree == 'depth2': 99 | def basic_tree(): 100 | tree = BinaryTree('', False) 101 | 102 | tree.insertLeft('', True) 103 | tree.leftChild.insertLeft('', False) 104 | tree.leftChild.leftChild.insertLeft('', True) 105 | tree.leftChild.leftChild.insertRight('', True) 106 | 107 | tree.insertRight('', True) 108 | tree.rightChild.insertLeft('', False) 109 | tree.rightChild.leftChild.insertLeft('', True) 110 | tree.rightChild.leftChild.insertRight('', True) 111 | return tree 112 | 113 | elif args.tree == 'depth1': 114 | def basic_tree(): 115 | 116 | tree = BinaryTree('', False) 117 | tree.insertLeft('', True) 118 | tree.insertRight('', True) 119 | 120 | return tree 121 | 122 | elif args.tree == 'depth2_rml': 123 | def basic_tree(): 124 | tree = BinaryTree('', False) 125 | 126 | tree.insertLeft('', True) 127 | tree.leftChild.insertLeft('', True) 128 | 129 | tree.insertRight('', True) 130 | tree.rightChild.insertLeft('', True) 131 | 132 | return tree 133 | 134 | elif args.tree == 
'depth2_rmu': 135 | print('**************************rmu**************************') 136 | def basic_tree(): 137 | tree = BinaryTree('', False) 138 | 139 | tree.insertLeft('', True) 140 | tree.leftChild.insertLeft('', False) 141 | tree.leftChild.leftChild.insertLeft('', True) 142 | tree.leftChild.leftChild.insertRight('', True) 143 | 144 | tree.insertRight('', False) 145 | tree.rightChild.insertLeft('', True) 146 | tree.rightChild.insertRight('', True) 147 | 148 | return tree 149 | 150 | elif args.tree == 'depth2_rmu2': 151 | print('**************************rmu2**************************') 152 | def basic_tree(): 153 | tree = BinaryTree('', False) 154 | 155 | tree.insertLeft('', True) 156 | tree.leftChild.insertLeft('', False) 157 | tree.leftChild.leftChild.insertLeft('', True) 158 | tree.leftChild.leftChild.insertRight('', True) 159 | 160 | tree.insertRight('', True) 161 | # tree.rightChild.insertLeft('', True) 162 | # tree.rightChild.insertRight('', True) 163 | 164 | return tree 165 | 166 | elif args.tree == 'depth3': 167 | def basic_tree(): 168 | tree = BinaryTree('', False) 169 | 170 | tree.insertLeft('', True) 171 | tree.leftChild.insertLeft('', False) 172 | tree.leftChild.leftChild.insertLeft('', True) 173 | tree.leftChild.leftChild.leftChild.insertLeft('', False) 174 | tree.leftChild.leftChild.leftChild.leftChild.insertLeft('', True) 175 | tree.leftChild.leftChild.leftChild.leftChild.insertRight('', True) 176 | 177 | tree.leftChild.leftChild.insertRight('', True) 178 | tree.leftChild.leftChild.rightChild.insertLeft('', False) 179 | tree.leftChild.leftChild.rightChild.leftChild.insertLeft('', True) 180 | tree.leftChild.leftChild.rightChild.leftChild.insertRight('', True) 181 | 182 | tree.insertRight('', True) 183 | tree.rightChild.insertLeft('', False) 184 | tree.rightChild.leftChild.insertLeft('', True) 185 | tree.rightChild.leftChild.leftChild.insertLeft('', False) 186 | tree.rightChild.leftChild.leftChild.leftChild.insertLeft('', True) 187 | 
tree.rightChild.leftChild.leftChild.leftChild.insertRight('', True) 188 | 189 | tree.rightChild.leftChild.insertRight('', True) 190 | tree.rightChild.leftChild.rightChild.insertLeft('', False) 191 | tree.rightChild.leftChild.rightChild.leftChild.insertLeft('', True) 192 | tree.rightChild.leftChild.rightChild.leftChild.insertRight('', True) 193 | return tree 194 | 195 | elif args.tree == 'depth2_sub': 196 | print('**************************sub**************************') 197 | def basic_tree(): 198 | tree = BinaryTree('', True) 199 | 200 | tree.insertLeft('', False) 201 | tree.leftChild.insertLeft('', True) 202 | tree.leftChild.insertRight('', True) 203 | 204 | # tree.rightChild.insertLeft('', True) 205 | # tree.rightChild.insertRight('', True) 206 | 207 | return tree 208 | 209 | structure = [] 210 | 211 | def inorder_structure(tree): 212 | global structure 213 | if tree: 214 | inorder_structure(tree.leftChild) 215 | structure.append(tree.is_unary) 216 | inorder_structure(tree.rightChild) 217 | inorder_structure(basic_tree()) 218 | print('tree structure', structure) 219 | 220 | structure_choice = [] 221 | for is_unary in structure: 222 | if is_unary == True: 223 | structure_choice.append(len(unary)) 224 | else: 225 | structure_choice.append(len(binary)) 226 | print('tree structure choices', structure_choice) 227 | 228 | if args.tree == 'depth1': 229 | def basic_tree(): 230 | tree = BinaryTree('', False) 231 | 232 | tree.insertLeft('', True) 233 | tree.insertRight('', True) 234 | 235 | return tree 236 | 237 | elif args.tree == 'depth2': 238 | def basic_tree(): 239 | tree = BinaryTree('', False) 240 | 241 | tree.insertLeft('', True) 242 | tree.leftChild.insertLeft('', False) 243 | tree.leftChild.leftChild.insertLeft('', True) 244 | tree.leftChild.leftChild.insertRight('', True) 245 | 246 | tree.insertRight('', True) 247 | tree.rightChild.insertLeft('', False) 248 | tree.rightChild.leftChild.insertLeft('', True) 249 | tree.rightChild.leftChild.insertRight('', True) 250 
| return tree 251 | 252 | elif args.tree == 'depth3': 253 | def basic_tree(): 254 | tree = BinaryTree('', False) 255 | 256 | tree.insertLeft('', True) 257 | tree.leftChild.insertLeft('', False) 258 | tree.leftChild.leftChild.insertLeft('', True) 259 | tree.leftChild.leftChild.leftChild.insertLeft('', False) 260 | tree.leftChild.leftChild.leftChild.leftChild.insertLeft('', True) 261 | tree.leftChild.leftChild.leftChild.leftChild.insertRight('', True) 262 | 263 | tree.leftChild.leftChild.insertRight('', True) 264 | tree.leftChild.leftChild.rightChild.insertLeft('', False) 265 | tree.leftChild.leftChild.rightChild.leftChild.insertLeft('', True) 266 | tree.leftChild.leftChild.rightChild.leftChild.insertRight('', True) 267 | 268 | tree.insertRight('', True) 269 | tree.rightChild.insertLeft('', False) 270 | tree.rightChild.leftChild.insertLeft('', True) 271 | tree.rightChild.leftChild.leftChild.insertLeft('', False) 272 | tree.rightChild.leftChild.leftChild.leftChild.insertLeft('', True) 273 | tree.rightChild.leftChild.leftChild.leftChild.insertRight('', True) 274 | 275 | tree.rightChild.leftChild.insertRight('', True) 276 | tree.rightChild.leftChild.rightChild.insertLeft('', False) 277 | tree.rightChild.leftChild.rightChild.leftChild.insertLeft('', True) 278 | tree.rightChild.leftChild.rightChild.leftChild.insertRight('', True) 279 | return tree 280 | 281 | structure = [] 282 | leaves_index = [] 283 | leaves = 0 284 | count = 0 285 | 286 | def inorder_structure(tree): 287 | global structure, leaves, count, leaves_index 288 | if tree: 289 | inorder_structure(tree.leftChild) 290 | structure.append(tree.is_unary) 291 | if tree.leftChild is None and tree.rightChild is None: 292 | leaves = leaves + 1 293 | leaves_index.append(count) 294 | count = count + 1 295 | inorder_structure(tree.rightChild) 296 | 297 | 298 | inorder_structure(basic_tree()) 299 | 300 | print('leaves index:', leaves_index) 301 | 302 | print('tree structure:', structure, 'leaves num:', leaves) 303 | 304 | 
structure_choice = [] 305 | for is_unary in structure: 306 | if is_unary == True: 307 | structure_choice.append(len(unary)) 308 | else: 309 | structure_choice.append(len(binary)) 310 | print('tree structure choices', structure_choice) 311 | 312 | def reset_params(tree_params): 313 | for v in tree_params: 314 | # v.data.fill_(0.01) 315 | v.data.normal_(0.0, 0.1) 316 | 317 | def inorder(tree, actions): 318 | global count 319 | if tree: 320 | inorder(tree.leftChild, actions) 321 | action = actions[count].item() 322 | if tree.is_unary: 323 | action = action 324 | tree.key = unary[action] 325 | # print(count, action, func.unary_functions_str[action]) 326 | else: 327 | action = action 328 | tree.key = binary[action] 329 | # print(count, action, func.binary_functions_str[action]) 330 | count = count + 1 331 | inorder(tree.rightChild, actions) 332 | 333 | def inorder_visualize(tree, actions, trainable_tree): 334 | global count, leaves_cnt 335 | if tree: 336 | leftfun = inorder_visualize(tree.leftChild, actions, trainable_tree) 337 | action = actions[count].item() 338 | # print('123', tree.key) 339 | if tree.is_unary:# and not tree.key.is_leave: 340 | if count not in leaves_index: 341 | midfun = unary_functions_str[action] 342 | a = trainable_tree.learnable_operator_set[count][action].a.item() 343 | b = trainable_tree.learnable_operator_set[count][action].b.item() 344 | else: 345 | midfun = unary_functions_str_leaf[action] 346 | else: 347 | midfun = binary_functions_str[action] 348 | 349 | count = count + 1 350 | rightfun = inorder_visualize(tree.rightChild, actions, trainable_tree) 351 | if leftfun is None and rightfun is None: 352 | w = [] 353 | for i in range(dim): 354 | w.append(trainable_tree.linear[leaves_cnt].weight[0][i].item()) 355 | bias = trainable_tree.linear[leaves_cnt].bias[0].item() 356 | leaves_cnt = leaves_cnt + 1 357 | ## -------------------------------------- input variable element wise ---------------------------- 358 | expression = '' 359 | for i in 
range(0, dim): 360 | # print(midfun) 361 | x_expression = midfun.format('x'+str(i)) 362 | expression = expression + ('{:.4f}*{}'+'+').format(w[i], x_expression) 363 | expression = expression+'{:.4f}'.format(bias) 364 | expression = '('+expression+')' 365 | # print('visualize', count, leaves_cnt, action) 366 | return expression 367 | elif leftfun is not None and rightfun is None: 368 | if '(0)' in midfun or '(1)' in midfun: 369 | return midfun.format('{:.4f}'.format(a), '{:.4f}'.format(b)) 370 | else: 371 | return midfun.format('{:.4f}'.format(a), leftfun, '{:.4f}'.format(b)) 372 | elif tree.leftChild is None and tree.rightChild is not None: 373 | if '(0)' in midfun or '(1)' in midfun: 374 | return midfun.format('{:.4f}'.format(a), '{:.4f}'.format(b)) 375 | else: 376 | return midfun.format('{:.4f}'.format(a), rightfun, '{:.4f}'.format(b)) 377 | else: 378 | return midfun.format(leftfun, rightfun) 379 | else: 380 | return None 381 | 382 | def get_function(actions): 383 | global count 384 | count = 0 385 | computation_tree = basic_tree() 386 | inorder(computation_tree, actions) 387 | count = 0 388 | return computation_tree 389 | 390 | def inorder_params(tree, actions, unary_choices): 391 | global count 392 | if tree: 393 | inorder_params(tree.leftChild, actions, unary_choices) 394 | action = actions[count].item() 395 | if tree.is_unary: 396 | action = action 397 | tree.key = unary_choices[count][action] 398 | else: 399 | action = action 400 | tree.key = unary_choices[count][len(unary)+action] 401 | count = count + 1 402 | inorder_params(tree.rightChild, actions, unary_choices) 403 | 404 | def get_function_trainable_params(actions, unary_choices): 405 | global count 406 | count = 0 407 | computation_tree = basic_tree() 408 | inorder_params(computation_tree, actions, unary_choices) 409 | count = 0 410 | return computation_tree 411 | 412 | class unary_operation(nn.Module): 413 | def __init__(self, operator, is_leave): 414 | super(unary_operation, self).__init__() 415 | 
self.unary = operator 416 | if not is_leave: 417 | self.a = nn.Parameter(torch.Tensor(1).cuda()) 418 | self.a.data.fill_(1) 419 | self.b = nn.Parameter(torch.Tensor(1).cuda()) 420 | self.b.data.fill_(0) 421 | self.is_leave = is_leave 422 | 423 | def forward(self, x): 424 | if self.is_leave: 425 | return self.unary(x) 426 | else: 427 | return self.a*self.unary(x)+self.b 428 | 429 | class binary_operation(nn.Module): 430 | def __init__(self, operator): 431 | super(binary_operation, self).__init__() 432 | self.binary = operator 433 | def forward(self, x, y): 434 | return self.binary(x, y) 435 | 436 | leaves_cnt = 0 437 | 438 | def compute_by_tree(tree, linear, x): 439 | if tree.leftChild == None and tree.rightChild == None: # leaf node 440 | global leaves_cnt 441 | transformation = linear[leaves_cnt] 442 | leaves_cnt = leaves_cnt + 1 443 | return transformation(tree.key(x)) 444 | elif tree.leftChild is None and tree.rightChild is not None: 445 | return tree.key(compute_by_tree(tree.rightChild, linear, x)) 446 | elif tree.leftChild is not None and tree.rightChild is None: 447 | return tree.key(compute_by_tree(tree.leftChild, linear, x)) 448 | else: 449 | return tree.key(compute_by_tree(tree.leftChild, linear, x), compute_by_tree(tree.rightChild, linear, x)) 450 | 451 | class learnable_compuatation_tree(nn.Module): 452 | def __init__(self): 453 | super(learnable_compuatation_tree, self).__init__() 454 | self.learnable_operator_set = {} 455 | for i in range(len(structure)): 456 | self.learnable_operator_set[i] = [] 457 | is_leave = i in leaves_index 458 | for j in range(len(unary)): 459 | self.learnable_operator_set[i].append(unary_operation(unary[j], is_leave)) 460 | for j in range(len(binary)): 461 | self.learnable_operator_set[i].append(binary_operation(binary[j])) 462 | self.linear = [] 463 | for num, i in enumerate(range(leaves)): 464 | linear_module = torch.nn.Linear(dim, 1, bias=True).cuda() #set only one variable 465 | linear_module.weight.data.normal_(0, 
1/math.sqrt(dim)) 466 | linear_module.bias.data.fill_(0) 467 | self.linear.append(linear_module) 468 | 469 | def forward(self, x, bs_action): 470 | # print(len(bs_action)) 471 | global leaves_cnt 472 | leaves_cnt = 0 473 | function = lambda y: compute_by_tree(get_function_trainable_params(bs_action, self.learnable_operator_set), self.linear, y) 474 | out = function(x) 475 | leaves_cnt = 0 476 | return out 477 | 478 | class Controller(torch.nn.Module): 479 | def __init__(self): 480 | torch.nn.Module.__init__(self) 481 | 482 | self.softmax_temperature = 5.0 483 | self.tanh_c = 2.5 484 | self.mode = True 485 | 486 | self.input_size = 20 487 | self.hidden_size = 50 488 | self.output_size = sum(structure_choice) 489 | 490 | self._fc_controller = nn.Sequential( 491 | nn.Linear(self.input_size,self.hidden_size), 492 | nn.ReLU(inplace=True), 493 | nn.Linear(self.hidden_size,self.output_size)) 494 | 495 | def forward(self,x): 496 | logits = self._fc_controller(x) 497 | 498 | logits /= self.softmax_temperature 499 | 500 | # exploration # ?? 501 | if self.mode == 'train': 502 | logits = (self.tanh_c*F.tanh(logits)) 503 | 504 | return logits 505 | 506 | def sample(self, batch_size=1, step=0): 507 | """Samples a set of `args.num_blocks` many computational nodes from the 508 | controller, where each node is made up of an activation function, and 509 | each node except the last also includes a previous node. 
510 | """ 511 | 512 | # [B, L, H] 513 | inputs = torch.zeros(batch_size, self.input_size).cuda() 514 | log_probs = [] 515 | actions = [] 516 | total_logits = self.forward(inputs) 517 | 518 | cumsum = np.cumsum([0]+structure_choice) 519 | for idx in range(len(structure_choice)): 520 | logits = total_logits[:, cumsum[idx]:cumsum[idx+1]] 521 | 522 | probs = F.softmax(logits, dim=-1) 523 | log_prob = F.log_softmax(logits, dim=-1) 524 | # print(probs) 525 | if step>=args.random_step: 526 | action = probs.multinomial(num_samples=1).data 527 | else: 528 | action = torch.randint(0, structure_choice[idx], size=(batch_size, 1)).cuda() 529 | # print('old', action) 530 | if args.greedy is not 0: 531 | for k in range(args.bs): 532 | if np.random.rand(1) hyperparams['discount'] > 0: 681 | rewards = discount(rewards, hyperparams['discount']) 682 | 683 | base = args.base 684 | rewards[rewards > base] = base 685 | rewards[rewards != rewards] = 1e10 686 | error = rewards 687 | rewards = 1 / (1 + torch.sqrt(rewards)) 688 | 689 | batch_smallest = error.min() 690 | batch_min_idx = torch.argmin(error) 691 | batch_min_action = [v[batch_min_idx] for v in actions] 692 | 693 | batch_best_formula = formulas[batch_min_idx] 694 | 695 | candidates.add_new(candidate(action=batch_min_action, expression=batch_best_formula, error=batch_smallest)) 696 | 697 | for candidate_ in candidates.candidates: 698 | print('error:{} action:{} formula:{}'.format(candidate_.error.item(), [v.item() for v in candidate_.action], candidate_.expression)) 699 | 700 | # moving average baseline 701 | if baseline is None: 702 | baseline = (rewards).mean() 703 | else: 704 | decay = hyperparams['ema_baseline_decay'] 705 | baseline = decay * baseline + (1 - decay) * (rewards).mean() 706 | 707 | argsort = torch.argsort(rewards.squeeze(1), descending=True) 708 | # print(error, argsort) 709 | # print(rewards.size(), rewards.squeeze(1), torch.argsort(rewards.squeeze(1)), rewards[argsort]) 710 | # policy loss 711 | num = 
int(args.bs * args.percentile) 712 | rewards_sort = rewards[argsort] 713 | adv = rewards_sort - rewards_sort[num:num + 1, 0:] # - baseline 714 | # print(error, argsort, rewards_sort, adv) 715 | log_probs_sort = log_probs[argsort] 716 | # print('adv', adv) 717 | loss = -log_probs_sort[:num] * tools.get_variable(adv[:num], True, requires_grad=False) 718 | loss = (loss.sum(1)).mean() 719 | 720 | # update 721 | controller_optim.zero_grad() 722 | loss.backward() 723 | 724 | if hyperparams['controller_grad_clip'] > 0: 725 | torch.nn.utils.clip_grad_norm(model.parameters(), 726 | hyperparams['controller_grad_clip']) 727 | Controller_optim.step() 728 | 729 | min_error = error.min().item() 730 | if smallest_error>min_error: 731 | smallest_error = min_error 732 | 733 | min_idx = torch.argmin(error) 734 | min_action = [v[min_idx] for v in actions] 735 | best_formula = formulas[min_idx] 736 | 737 | 738 | log = 'Step: {step}| Loss: {loss:.4f}| Action: {act} |Baseline: {base:.4f}| ' \ 739 | 'Reward {re:.4f} | {error:.8f} {formula}'.format(loss=loss.item(), base=baseline, act=binary_code, 740 | re=(rewards).mean(), step=step, formula=best_formula, 741 | error=smallest_error) 742 | print('********************************************************************************************************') 743 | print(log) 744 | print('********************************************************************************************************') 745 | if (step + 1) % 1 == 0: 746 | logger.append([step + 1, loss.item(), baseline, rewards.mean(), smallest_error, best_formula]) 747 | 748 | for candidate_ in candidates.candidates: 749 | print('error:{} action:{} formula:{}'.format(candidate_.error.item(), [v.item() for v in candidate_.action], 750 | candidate_.expression)) 751 | logger.append([666, 0, 0, 0, candidate_.error.item(), candidate_.expression]) 752 | 753 | finetune = 20000 754 | global count, leaves_cnt 755 | for candidate_ in candidates.candidates: 756 | trainable_tree = 
learnable_compuatation_tree() 757 | trainable_tree = trainable_tree.cuda() 758 | 759 | params = [] 760 | for idx, v in enumerate(trainable_tree.learnable_operator_set): 761 | if idx not in leaves_index: 762 | for modules in trainable_tree.learnable_operator_set[v]: 763 | for param in modules.parameters(): 764 | params.append(param) 765 | for module in trainable_tree.linear: 766 | for param in module.parameters(): 767 | params.append(param) 768 | 769 | reset_params(params) 770 | tree_optim = torch.optim.Adam(params, lr=1e-2) 771 | 772 | for current_iter in range(finetune): 773 | error = best_error(candidate_.action, trainable_tree) 774 | tree_optim.zero_grad() 775 | error.backward() 776 | 777 | tree_optim.step() 778 | 779 | count = 0 780 | leaves_cnt = 0 781 | formula = inorder_visualize(basic_tree(), candidate_.action, trainable_tree) 782 | leaves_cnt = 0 783 | count = 0 784 | suffix = 'Finetune-- Iter {current_iter} Error {error:.5f} Formula {formula}'.format(current_iter=current_iter, error=error, formula=formula) 785 | if (current_iter + 1) % 100 == 0: 786 | logger.append([current_iter, 0, 0, 0, error.item(), formula]) 787 | 788 | cosine_lr(tree_optim, 1e-2, current_iter, finetune) 789 | print(suffix) 790 | 791 | numerators = [] 792 | denominators = [] 793 | 794 | for i in range(1000): 795 | print(i) 796 | x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left 797 | sq_de = torch.mean((func.true_solution(x))**2) 798 | sq_nu = torch.mean((func.true_solution(x)-trainable_tree(x, candidate_.action)) ** 2) 799 | numerators.append(sq_nu.item()) 800 | denominators.append(sq_de.item()) 801 | 802 | relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators)) 803 | print('relative l2 error: ', relative_l2) 804 | logger.append(['relative_l2', 0, 0, 0, relative_l2, 0]) 805 | 806 | def cosine_lr(opt, base_lr, e, epochs): 807 | lr = 0.5 * base_lr * (math.cos(math.pi * e / epochs) + 1) 808 | for param_group in opt.param_groups: 809 | 
param_group["lr"] = lr 810 | return lr 811 | 812 | if __name__ == '__main__': 813 | controller = Controller().cuda() 814 | hyperparams = {} 815 | 816 | hyperparams['controller_max_step'] = args.epoch 817 | hyperparams['discount'] = 1.0 818 | hyperparams['ema_baseline_decay'] = 0.95 819 | hyperparams['controller_lr'] = args.lr 820 | hyperparams['entropy_mode'] = 'reward' 821 | hyperparams['controller_grad_clip'] = 0#10 822 | hyperparams['checkpoint'] = args.ckpt 823 | if not os.path.isdir(hyperparams['checkpoint']): 824 | mkdir_p(hyperparams['checkpoint']) 825 | controller_optim = torch.optim.Adam(controller.parameters(), lr= hyperparams['controller_lr']) 826 | 827 | trainable_tree = learnable_compuatation_tree() 828 | trainable_tree = trainable_tree.cuda() 829 | 830 | params = [] 831 | for idx, v in enumerate(trainable_tree.learnable_operator_set): 832 | if idx not in leaves_index: 833 | for modules in trainable_tree.learnable_operator_set[v]: 834 | for param in modules.parameters(): 835 | params.append(param) 836 | for module in trainable_tree.linear: 837 | for param in module.parameters(): 838 | params.append(param) 839 | 840 | train_controller(controller, controller_optim, trainable_tree, params, hyperparams) 841 | -------------------------------------------------------------------------------- /fex/Poisson/function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import sin, cos, exp 4 | import math 5 | 6 | def LHS_pde(u, x, dim_set): 7 | 8 | v = torch.ones(u.shape).cuda() 9 | bs = x.size(0) 10 | ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0] 11 | uxx = torch.zeros(bs, dim_set).cuda() 12 | for i in range(dim_set): 13 | ux_tem = ux[:, i:i+1] 14 | uxx_tem = torch.autograd.grad(ux_tem, x, grad_outputs=v, create_graph=True)[0] 15 | uxx[:, i] = uxx_tem[:, i] 16 | LHS = -torch.sum(uxx, dim=1, keepdim=True) 17 | return LHS 18 | 19 | def RHS_pde(x): 20 | bs 
= x.size(0) 21 | dim = x.size(1) 22 | return -dim*torch.ones(bs, 1).cuda() 23 | 24 | def true_solution(x): 25 | return 0.5*torch.sum(x**2, dim=1, keepdim=True)#1 / (2 * x[:, 0:1] + x[:, 1:2]-5) 26 | 27 | 28 | unary_functions = [lambda x: 0*x**2, 29 | lambda x: 1+0*x**2, 30 | lambda x: x+0*x**2, 31 | lambda x: x**2, 32 | lambda x: x**3, 33 | lambda x: x**4, 34 | torch.exp, 35 | torch.sin, 36 | torch.cos,] 37 | 38 | binary_functions = [lambda x,y: x+y, 39 | lambda x,y: x*y, 40 | lambda x,y: x-y] 41 | 42 | 43 | unary_functions_str = ['({}*(0)+{})', 44 | '({}*(1)+{})', 45 | # '5', 46 | '({}*{}+{})', 47 | # '-{}', 48 | '({}*({})**2+{})', 49 | '({}*({})**3+{})', 50 | '({}*({})**4+{})', 51 | # '({})**5', 52 | '({}*exp({})+{})', 53 | '({}*sin({})+{})', 54 | '({}*cos({})+{})',] 55 | # 'ref({})', 56 | # 'exp(-({})**2/2)'] 57 | 58 | unary_functions_str_leaf= ['(0)', 59 | '(1)', 60 | # '5', 61 | '({})', 62 | # '-{}', 63 | '(({})**2)', 64 | '(({})**3)', 65 | '(({})**4)', 66 | # '({})**5', 67 | '(exp({}))', 68 | '(sin({}))', 69 | '(cos({}))',] 70 | 71 | 72 | binary_functions_str = ['(({})+({}))', 73 | '(({})*({}))', 74 | '(({})-({}))'] 75 | 76 | if __name__ == '__main__': 77 | batch_size = 200 78 | left = -1 79 | right = 1 80 | points = (torch.rand(batch_size, 1)) * (right - left) + left 81 | x = torch.autograd.Variable(points.cuda(), requires_grad=True) 82 | function = true_solution 83 | 84 | ''' 85 | PDE loss 86 | ''' 87 | LHS = LHS_pde(function(x), x) 88 | RHS = RHS_pde(x) 89 | pde_loss = torch.nn.functional.mse_loss(LHS, RHS) 90 | 91 | ''' 92 | boundary loss 93 | ''' 94 | bc_points = torch.FloatTensor([[left], [right]]).cuda() 95 | bc_value = true_solution(bc_points) 96 | bd_loss = torch.nn.functional.mse_loss(function(bc_points), bc_value) 97 | 98 | print('pde loss: {} -- boundary loss: {}'.format(pde_loss.item(), bd_loss.item())) -------------------------------------------------------------------------------- /fex/Poisson/scripts.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | 5 | for i in range(6): 6 | for dim in [10, 20, 30, 40, 50]: 7 | gpu = random.randint(0,4) 8 | os.system('screen python controller_poisson.py --epoch 1000 --bs 10 --greedy 0.1 --gpu '+str(gpu)+' --ckpt Dim'+str(dim)+' --tree depth2_sub --random_step 3 --lr 0.002 --dim '+str(dim)+' --base 200000 --left -1 --right 1 --domainbs 5000 --bdbs 1000') 9 | time.sleep(100) -------------------------------------------------------------------------------- /fex/Poisson/tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from collections import defaultdict 4 | import collections 5 | from datetime import datetime 6 | import os 7 | import json 8 | import logging 9 | 10 | import numpy as np 11 | import torch 12 | from torch.autograd import Variable 13 | 14 | ########################## 15 | # Torch 16 | ########################## 17 | 18 | def detach(h): 19 | if type(h) == Variable: 20 | return Variable(h.data) 21 | else: 22 | return tuple(detach(v) for v in h) 23 | 24 | def get_variable(inputs, cuda=False, **kwargs): 25 | if type(inputs) in [list, np.ndarray]: 26 | inputs = torch.Tensor(inputs) 27 | if cuda: 28 | out = Variable(inputs.cuda(), **kwargs) 29 | else: 30 | out = Variable(inputs, **kwargs) 31 | return out 32 | 33 | def update_lr(optimizer, lr): 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = lr 36 | 37 | def batchify(data, bsz, use_cuda): 38 | # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py 39 | nbatch = data.size(0) // bsz 40 | data = data.narrow(0, 0, nbatch * bsz) 41 | data = data.view(bsz, -1).t().contiguous() 42 | if use_cuda: 43 | data = data.cuda() 44 | return data 45 | 46 | 47 | ########################## 48 | # ETC 49 | ########################## 50 | 51 | Node = 
collections.namedtuple('Node', ['id', 'name']) 52 | 53 | 54 | class keydefaultdict(defaultdict): 55 | def __missing__(self, key): 56 | if self.default_factory is None: 57 | raise KeyError(key) 58 | else: 59 | ret = self[key] = self.default_factory(key) 60 | return ret 61 | 62 | 63 | def to_item(x): 64 | """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" 65 | if isinstance(x, (float, int)): 66 | return x 67 | 68 | if float(torch.__version__[0:3]) < 0.4: 69 | assert (x.dim() == 1) and (len(x) == 1) 70 | return x[0] 71 | 72 | return x.item() 73 | 74 | 75 | def get_logger(name=__file__, level=logging.INFO): 76 | logger = logging.getLogger(name) 77 | 78 | if getattr(logger, '_init_done__', None): 79 | logger.setLevel(level) 80 | return logger 81 | 82 | logger._init_done__ = True 83 | logger.propagate = False 84 | logger.setLevel(level) 85 | 86 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 87 | handler = logging.StreamHandler() 88 | handler.setFormatter(formatter) 89 | handler.setLevel(0) 90 | 91 | del logger.handlers[:] 92 | logger.addHandler(handler) 93 | 94 | return logger 95 | 96 | 97 | logger = get_logger() 98 | 99 | 100 | def prepare_dirs(args): 101 | """Sets the directories for the model, and creates those directories. 102 | 103 | Args: 104 | args: Parsed from `argparse` in the `config` module. 
105 | """ 106 | if args.load_path: 107 | if args.load_path.startswith(args.log_dir): 108 | args.model_dir = args.load_path 109 | else: 110 | if args.load_path.startswith(args.dataset): 111 | args.model_name = args.load_path 112 | else: 113 | args.model_name = "{}_{}".format(args.dataset, args.load_path) 114 | else: 115 | args.model_name = "{}_{}".format(args.dataset, get_time()) 116 | 117 | if not hasattr(args, 'model_dir'): 118 | args.model_dir = os.path.join(args.log_dir, args.model_name) 119 | args.data_path = os.path.join(args.data_dir, args.dataset) 120 | 121 | for path in [args.log_dir, args.data_dir, args.model_dir]: 122 | if not os.path.exists(path): 123 | makedirs(path) 124 | 125 | def get_time(): 126 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 127 | 128 | def save_args(args): 129 | param_path = os.path.join(args.model_dir, "params.json") 130 | 131 | logger.info("[*] MODEL dir: %s" % args.model_dir) 132 | logger.info("[*] PARAM path: %s" % param_path) 133 | 134 | with open(param_path, 'w') as fp: 135 | json.dump(args.__dict__, fp, indent=4, sort_keys=True) 136 | 137 | def save_dag(args, dag, name): 138 | save_path = os.path.join(args.model_dir, name) 139 | logger.info("[*] Save dag : {}".format(save_path)) 140 | json.dump(dag, open(save_path, 'w')) 141 | 142 | def load_dag(args): 143 | load_path = os.path.join(args.dag_path) 144 | logger.info("[*] Load dag : {}".format(load_path)) 145 | with open(load_path) as f: 146 | dag = json.load(f) 147 | dag = {int(k): [Node(el[0], el[1]) for el in v] for k, v in dag.items()} 148 | save_dag(args, dag, "dag.json") 149 | draw_network(dag, os.path.join(args.model_dir, "dag.png")) 150 | return dag 151 | 152 | def makedirs(path): 153 | if not os.path.exists(path): 154 | logger.info("[*] Make directories : {}".format(path)) 155 | os.makedirs(path) 156 | 157 | def remove_file(path): 158 | if os.path.exists(path): 159 | logger.info("[*] Removed: {}".format(path)) 160 | os.remove(path) 161 | 162 | def 
backup_file(path): 163 | root, ext = os.path.splitext(path) 164 | new_path = "{}.backup_{}{}".format(root, get_time(), ext) 165 | 166 | os.rename(path, new_path) 167 | logger.info("[*] {} has backup: {}".format(path, new_path)) 168 | -------------------------------------------------------------------------------- /fex/Poisson/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | # from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /fex/Poisson/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /fex/Poisson/utils/images/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Poisson/utils/images/cifar.png -------------------------------------------------------------------------------- /fex/Poisson/utils/images/imagenet.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Poisson/utils/images/imagenet.png -------------------------------------------------------------------------------- /fex/Poisson/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | 
self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | def append(self, numbers): 61 | assert len(self.names) == len(numbers), 'Numbers do not match names' 62 | for index, num in enumerate(numbers): 63 | if isinstance(num, int): 64 | self.file.write("{}".format(num)) 65 | elif isinstance(num, str): 66 | self.file.write("{}".format(num)) 67 | else: 68 | self.file.write("{0:.8f}".format(num)) 69 | self.file.write('\t') 70 | self.numbers[self.names[index]].append(num) 71 | self.file.write('\n') 72 | self.file.flush() 73 | 74 | def plot(self, names=None): 75 | names = self.names if names == None else names 76 | numbers = self.numbers 77 | for _, name in enumerate(names): 78 | x = np.arange(len(numbers[name])) 79 | plt.plot(x, np.asarray(numbers[name])) 80 | plt.legend([self.title + '(' + name + ')' for name in names]) 81 | plt.grid(True) 82 | 83 | def close(self): 84 | if self.file is not None: 85 | self.file.close() 86 | 87 | class LoggerMonitor(object): 88 | '''Load and visualize multiple logs.''' 89 | def __init__ (self, paths): 90 | '''paths is a distionary with {name:filepath} pair''' 91 | self.loggers = [] 92 | for title, path in paths.items(): 93 | logger = Logger(path, title=title, resume=True) 94 | self.loggers.append(logger) 95 | 96 | def plot(self, names=None): 97 | plt.figure() 98 | plt.subplot(121) 99 | legend_text = [] 100 | for logger in self.loggers: 101 | legend_text += plot_overlap(logger, names) 102 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
103 | plt.grid(True) 104 | 105 | if __name__ == '__main__': 106 | # # Example 107 | # logger = Logger('test.txt') 108 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 109 | 110 | # length = 100 111 | # t = np.arange(length) 112 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 113 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 114 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 115 | 116 | # for i in range(0, length): 117 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 118 | # logger.plot() 119 | 120 | # Example: logger monitor 121 | paths = { 122 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 123 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 124 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 125 | } 126 | 127 | field = ['Valid Acc.'] 128 | 129 | monitor = LoggerMonitor(paths) 130 | monitor.plot(names=field) 131 | savefig('test.eps') -------------------------------------------------------------------------------- /fex/Poisson/utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count 
-------------------------------------------------------------------------------- /fex/Poisson/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | 
plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 
| # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /fex/Schrodinger/computational_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import function as func 3 | unary = func.unary_functions 4 | binary = func.binary_functions 5 | unary_functions_str = func.unary_functions_str 6 | binary_functions_str = func.binary_functions_str 7 | 8 | class BinaryTree(object): 9 | def __init__(self,item,is_unary=True): 10 | self.key=item 11 | self.is_unary=is_unary 12 | self.leftChild=None 13 | self.rightChild=None 14 | def insertLeft(self,item, is_unary=True): 15 | if self.leftChild==None: 16 | self.leftChild=BinaryTree(item, is_unary) 17 | else: 18 | t=BinaryTree(item) 19 | t.leftChild=self.leftChild 20 | self.leftChild=t 21 | def insertRight(self,item, is_unary=True): 22 | if self.rightChild==None: 23 | self.rightChild=BinaryTree(item, is_unary) 24 | else: 25 | t=BinaryTree(item) 26 | t.rightChild=self.rightChild 27 | self.rightChild=t 28 | 29 | def compute_by_tree(tree, x): 30 | ''' judge whether a emtpy tree, if yes, that means the leaves and call the unary operation ''' 31 | if tree.leftChild == None and tree.rightChild == None: 32 | return tree.key(x) 33 | elif tree.leftChild == None and tree.rightChild is not None: 34 | return tree.key(compute_by_tree(tree.rightChild, x)) 35 | elif tree.leftChild is not None and tree.rightChild == None: 36 | return tree.key(compute_by_tree(tree.leftChild, x)) 37 | else: 38 | return tree.key(compute_by_tree(tree.leftChild, x), compute_by_tree(tree.rightChild, x)) 39 | 40 | def inorder(tree, actions): 41 | global count 42 | if tree: 43 | inorder(tree.leftChild, actions) 44 | action = actions[count].item() 45 | if tree.is_unary: 46 | action = action 47 | tree.key = unary[action] 48 | # print(count, action, func.unary_functions_str[action]) 49 | else: 
50 | action = action 51 | tree.key = binary[action] 52 | # print(count, action, func.binary_functions_str[action]) 53 | count = count + 1 54 | inorder(tree.rightChild, actions) 55 | 56 | # def inorder(tree): 57 | # if tree: 58 | # inorder(tree.leftChild) 59 | # print(tree.key) 60 | # inorder(tree.rightChild) 61 | 62 | count = 0 63 | def inorder_w_idx(tree): 64 | global count 65 | if tree: 66 | inorder_w_idx(tree.leftChild) 67 | print(tree.key, count) 68 | count = count + 1 69 | inorder_w_idx(tree.rightChild) 70 | 71 | def basic_tree(): 72 | tree = BinaryTree('', False) 73 | tree.insertLeft('', False) 74 | tree.leftChild.insertLeft('', True) 75 | tree.leftChild.insertRight('', True) 76 | tree.insertRight('', False) 77 | tree.rightChild.insertLeft('', True) 78 | tree.rightChild.insertRight('', True) 79 | return tree 80 | 81 | def get_function(actions): 82 | global count 83 | count = 0 84 | computation_tree = basic_tree() 85 | inorder(computation_tree, actions) 86 | count = 0 # 置零 87 | return computation_tree 88 | 89 | def inorder_test(tree, actions): 90 | global count 91 | if tree: 92 | inorder(tree.leftChild, actions) 93 | action = actions[count].item() 94 | print(action) 95 | if tree.is_unary: 96 | action = action 97 | tree.key = unary[action] 98 | # print(count, action, func.unary_functions_str[action]) 99 | else: 100 | action = action 101 | tree.key = binary[action] 102 | # print(count, action, func.binary_functions_str[action]) 103 | count = count + 1 104 | inorder(tree.rightChild, actions) 105 | 106 | if __name__ =='__main__': 107 | # tree = BinaryTree(np.add) 108 | # tree.insertLeft(np.multiply) 109 | # tree.leftChild.insertLeft(np.cos) 110 | # tree.leftChild.insertRight(np.sin) 111 | # tree.insertRight(np.sin) 112 | # print(compute_by_tree(tree, 30)) # np.sin(30)*np.cos(30)+np.sin(30) 113 | # inorder(tree) 114 | # inorder_w_idx(tree) 115 | import torch 116 | bs_action = [torch.LongTensor([10]), 
torch.LongTensor([2]),torch.LongTensor([9]),torch.LongTensor([1]),torch.LongTensor([0]),torch.LongTensor([2]),torch.LongTensor([6])] 117 | 118 | function = lambda x: compute_by_tree(get_function(bs_action), x) 119 | x = torch.FloatTensor([[-1], [1]]) 120 | 121 | count = 0 122 | tr = basic_tree() 123 | inorder_test(tr, bs_action) 124 | count = 0 125 | 126 | -------------------------------------------------------------------------------- /fex/Schrodinger/function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import sin, cos, exp 4 | import math 5 | 6 | c = 3 7 | 8 | def LHS_pde(u, x, dim_set): 9 | 10 | v = torch.ones(u.shape).cuda() 11 | bs = x.size(0) 12 | ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0] 13 | uxx = torch.zeros(bs, dim_set).cuda() 14 | for i in range(dim_set): 15 | ux_tem = ux[:, i].reshape([x.size()[0], 1]) 16 | uxx_tem = torch.autograd.grad(ux_tem, x, grad_outputs=v, create_graph=True)[0] 17 | uxx[:, i] = uxx_tem[:, i] 18 | 19 | LHS = -torch.sum(uxx, dim=1, keepdim=True) 20 | V = -exp(2/dim_set*torch.sum(cos(x), dim=1, keepdim=True))/(c**2)+torch.sum(sin(x)**2/(dim_set**2), dim=1, keepdim=True)-torch.sum(cos(x)/(dim_set), dim=1, keepdim=True) 21 | return LHS+u**3+V*u#ux+uy#uxx+uyy 22 | 23 | # def RHS_pde(x): 24 | # # return -torch.sin(x) 25 | # # return 2*torch.cos(x)-x*torch.sin(x) 26 | # return torch.sin(x)# 2*torch.cos(x**2)-4*x**2*torch.sin(x**2) #torch.sin(x) 27 | # 28 | # def true_solution(x): 29 | # return x-torch.sin(x)#torch.sin(x**2)#torch.sin(x**2)#torch.sin(x)+x # -1*x+0*x**2# 30 | 31 | def RHS_pde(x): 32 | # return 3 * torch.cos(2 * x[:, 0:1] + x[:, 1:2]) 33 | # return 3 * torch.cos(2 * x[:, 0:1] + x[:, 1:2]) + 3*x[:, 0:1] 34 | # return 3 * torch.exp(2 * x[:, 0:1] + x[:, 1:2]) 35 | # dim = x.size(1) 36 | bs = x.size(0) 37 | return torch.zeros(bs, 1).cuda() 38 | # dim = x.size(1) 39 | # coefficient = 2 * math.pi * 
torch.ones(1, dim).cuda() 40 | # coefficient[:, 0] = 1 41 | # print(x.size(), coefficient.size(), coefficient) 42 | # return -dim*math.pi*torch.cos(torch.sum(x * coefficient, dim=1, keepdim=True)) 43 | 44 | def true_solution(x): 45 | # return torch.sin(2*x[:,0:1]+x[:,1:2]) 46 | # return torch.sin(2 * x[:, 0:1] + x[:, 1:2]) + 1.5*x[:, 0:1]**2 47 | dim = x.size(1) 48 | # coefficient[:,0] = 1 49 | # print(x.size(), coefficient.size(), coefficient) 50 | return exp(1/dim*torch.sum(cos(x), dim=1, keepdim=True))/(c) 51 | 52 | 53 | 54 | # def RHS_pde(x): 55 | # # return -torch.sin(x) 56 | # # return 2*torch.cos(x)-x*torch.sin(x) 57 | # return torch.sin(x)+2# 2*torch.cos(x**2)-4*x**2*torch.sin(x**2) #torch.sin(x) 58 | # 59 | # def true_solution(x): 60 | # return x-torch.sin(x)+x**2 61 | 62 | # unary_functions = [lambda x: 0*x**2, 63 | # lambda x: 1+0*x**2, 64 | # # lambda x: 5+0*x**2, 65 | # lambda x: x+0*x**2, 66 | # # lambda x: -x+0*x**2, 67 | # lambda x: x**2, 68 | # lambda x: x**3, 69 | # lambda x: x**4, 70 | # # lambda x: x**5, 71 | # torch.exp, 72 | # torch.sin, 73 | # torch.cos,] 74 | # # torch.erf, 75 | # # lambda x: torch.exp(-x**2/2)] 76 | 77 | unary_functions = [lambda x: 0*x**2, 78 | lambda x: 1+0*x**2, 79 | lambda x: x+0*x**2, 80 | lambda x: x**2, 81 | lambda x: x**3, 82 | lambda x: x**4, 83 | torch.exp, 84 | torch.sin, 85 | torch.cos,] 86 | 87 | binary_functions = [lambda x,y: x+y, 88 | lambda x,y: x*y, 89 | lambda x,y: x-y] 90 | 91 | 92 | unary_functions_str = ['({}*(0)+{})', 93 | '({}*(1)+{})', 94 | # '5', 95 | '({}*{}+{})', 96 | # '-{}', 97 | '({}*({})**2+{})', 98 | '({}*({})**3+{})', 99 | '({}*({})**4+{})', 100 | # '({})**5', 101 | '({}*exp({})+{})', 102 | '({}*sin({})+{})', 103 | '({}*cos({})+{})',] 104 | # 'ref({})', 105 | # 'exp(-({})**2/2)'] 106 | 107 | unary_functions_str_leaf= ['(0)', 108 | '(1)', 109 | # '5', 110 | '({})', 111 | # '-{}', 112 | '(({})**2)', 113 | '(({})**3)', 114 | '(({})**4)', 115 | # '({})**5', 116 | '(exp({}))', 117 | '(sin({}))', 
118 | '(cos({}))',] 119 | 120 | 121 | binary_functions_str = ['(({})+({}))', 122 | '(({})*({}))', 123 | '(({})-({}))'] 124 | 125 | if __name__ == '__main__': 126 | batch_size = 200 127 | left = -1 128 | right = 1 129 | points = (torch.rand(batch_size, 1)) * (right - left) + left 130 | x = torch.autograd.Variable(points.cuda(), requires_grad=True) 131 | function = true_solution 132 | 133 | ''' 134 | PDE loss 135 | ''' 136 | LHS = LHS_pde(function(x), x) 137 | RHS = RHS_pde(x) 138 | pde_loss = torch.nn.functional.mse_loss(LHS, RHS) 139 | 140 | ''' 141 | boundary loss 142 | ''' 143 | bc_points = torch.FloatTensor([[left], [right]]).cuda() 144 | bc_value = true_solution(bc_points) 145 | bd_loss = torch.nn.functional.mse_loss(function(bc_points), bc_value) 146 | 147 | print('pde loss: {} -- boundary loss: {}'.format(pde_loss.item(), bd_loss.item())) -------------------------------------------------------------------------------- /fex/Schrodinger/scripts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | 5 | 6 | for i in range(2): 7 | for dim in [6, 12, 18, 24, 30]: 8 | gpu = random.randint(1,9) 9 | os.system('screen python controller_cubic_sh_firstdeflation_thenintegral.py --epoch 1000 --bs 10 --greedy 0.1 --gpu '+str(gpu)+' --ckpt Firstdeflat_thenintegral_epoch1k_Dim'+str(dim)+' --tree depth2_sub --random_step 3 --lr 0.002 --dim '+str(dim)+' --base 200000 --left -1 --right 1 --domainbs 2000 --intbs 10000') 10 | time.sleep(100) -------------------------------------------------------------------------------- /fex/Schrodinger/tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from collections import defaultdict 4 | import collections 5 | from datetime import datetime 6 | import os 7 | import json 8 | import logging 9 | 10 | import numpy as np 11 | import torch 12 | from torch.autograd import 
Variable 13 | 14 | ########################## 15 | # Torch 16 | ########################## 17 | 18 | def detach(h): 19 | if type(h) == Variable: 20 | return Variable(h.data) 21 | else: 22 | return tuple(detach(v) for v in h) 23 | 24 | def get_variable(inputs, cuda=False, **kwargs): 25 | if type(inputs) in [list, np.ndarray]: 26 | inputs = torch.Tensor(inputs) 27 | if cuda: 28 | out = Variable(inputs.cuda(), **kwargs) 29 | else: 30 | out = Variable(inputs, **kwargs) 31 | return out 32 | 33 | def update_lr(optimizer, lr): 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = lr 36 | 37 | def batchify(data, bsz, use_cuda): 38 | # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py 39 | nbatch = data.size(0) // bsz 40 | data = data.narrow(0, 0, nbatch * bsz) 41 | data = data.view(bsz, -1).t().contiguous() 42 | if use_cuda: 43 | data = data.cuda() 44 | return data 45 | 46 | 47 | ########################## 48 | # ETC 49 | ########################## 50 | 51 | Node = collections.namedtuple('Node', ['id', 'name']) 52 | 53 | 54 | class keydefaultdict(defaultdict): 55 | def __missing__(self, key): 56 | if self.default_factory is None: 57 | raise KeyError(key) 58 | else: 59 | ret = self[key] = self.default_factory(key) 60 | return ret 61 | 62 | 63 | def to_item(x): 64 | """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" 65 | if isinstance(x, (float, int)): 66 | return x 67 | 68 | if float(torch.__version__[0:3]) < 0.4: 69 | assert (x.dim() == 1) and (len(x) == 1) 70 | return x[0] 71 | 72 | return x.item() 73 | 74 | 75 | def get_logger(name=__file__, level=logging.INFO): 76 | logger = logging.getLogger(name) 77 | 78 | if getattr(logger, '_init_done__', None): 79 | logger.setLevel(level) 80 | return logger 81 | 82 | logger._init_done__ = True 83 | logger.propagate = False 84 | logger.setLevel(level) 85 | 86 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 87 | handler = 
logging.StreamHandler() 88 | handler.setFormatter(formatter) 89 | handler.setLevel(0) 90 | 91 | del logger.handlers[:] 92 | logger.addHandler(handler) 93 | 94 | return logger 95 | 96 | 97 | logger = get_logger() 98 | 99 | 100 | def prepare_dirs(args): 101 | """Sets the directories for the model, and creates those directories. 102 | 103 | Args: 104 | args: Parsed from `argparse` in the `config` module. 105 | """ 106 | if args.load_path: 107 | if args.load_path.startswith(args.log_dir): 108 | args.model_dir = args.load_path 109 | else: 110 | if args.load_path.startswith(args.dataset): 111 | args.model_name = args.load_path 112 | else: 113 | args.model_name = "{}_{}".format(args.dataset, args.load_path) 114 | else: 115 | args.model_name = "{}_{}".format(args.dataset, get_time()) 116 | 117 | if not hasattr(args, 'model_dir'): 118 | args.model_dir = os.path.join(args.log_dir, args.model_name) 119 | args.data_path = os.path.join(args.data_dir, args.dataset) 120 | 121 | for path in [args.log_dir, args.data_dir, args.model_dir]: 122 | if not os.path.exists(path): 123 | makedirs(path) 124 | 125 | def get_time(): 126 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 127 | 128 | def save_args(args): 129 | param_path = os.path.join(args.model_dir, "params.json") 130 | 131 | logger.info("[*] MODEL dir: %s" % args.model_dir) 132 | logger.info("[*] PARAM path: %s" % param_path) 133 | 134 | with open(param_path, 'w') as fp: 135 | json.dump(args.__dict__, fp, indent=4, sort_keys=True) 136 | 137 | def save_dag(args, dag, name): 138 | save_path = os.path.join(args.model_dir, name) 139 | logger.info("[*] Save dag : {}".format(save_path)) 140 | json.dump(dag, open(save_path, 'w')) 141 | 142 | def load_dag(args): 143 | load_path = os.path.join(args.dag_path) 144 | logger.info("[*] Load dag : {}".format(load_path)) 145 | with open(load_path) as f: 146 | dag = json.load(f) 147 | dag = {int(k): [Node(el[0], el[1]) for el in v] for k, v in dag.items()} 148 | save_dag(args, dag, 
"dag.json") 149 | draw_network(dag, os.path.join(args.model_dir, "dag.png")) 150 | return dag 151 | 152 | def makedirs(path): 153 | if not os.path.exists(path): 154 | logger.info("[*] Make directories : {}".format(path)) 155 | os.makedirs(path) 156 | 157 | def remove_file(path): 158 | if os.path.exists(path): 159 | logger.info("[*] Removed: {}".format(path)) 160 | os.remove(path) 161 | 162 | def backup_file(path): 163 | root, ext = os.path.splitext(path) 164 | new_path = "{}.backup_{}{}".format(root, get_time(), ext) 165 | 166 | os.rename(path, new_path) 167 | logger.info("[*] {} has backup: {}".format(path, new_path)) 168 | -------------------------------------------------------------------------------- /fex/Schrodinger/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | # from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /fex/Schrodinger/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /fex/Schrodinger/utils/images/cifar.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Schrodinger/utils/images/cifar.png -------------------------------------------------------------------------------- /fex/Schrodinger/utils/images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Schrodinger/utils/images/imagenet.png -------------------------------------------------------------------------------- /fex/Schrodinger/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in 
self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | def append(self, numbers): 61 | assert len(self.names) == len(numbers), 'Numbers do not match names' 62 | for index, num in enumerate(numbers): 63 | if isinstance(num, int): 64 | self.file.write("{}".format(num)) 65 | elif isinstance(num, str): 66 | self.file.write("{}".format(num)) 67 | else: 68 | self.file.write("{0:.8f}".format(num)) 69 | self.file.write('\t') 70 | self.numbers[self.names[index]].append(num) 71 | self.file.write('\n') 72 | self.file.flush() 73 | 74 | def plot(self, names=None): 75 | names = self.names if names == None else names 76 | numbers = self.numbers 77 | for _, name in enumerate(names): 78 | x = np.arange(len(numbers[name])) 79 | plt.plot(x, np.asarray(numbers[name])) 80 | plt.legend([self.title + '(' + name + ')' for name in names]) 81 | plt.grid(True) 82 | 83 | def close(self): 84 | if self.file is not None: 85 | self.file.close() 86 | 87 | class LoggerMonitor(object): 88 | '''Load and visualize multiple logs.''' 89 | def __init__ (self, paths): 90 | '''paths is a distionary with {name:filepath} pair''' 91 | self.loggers = [] 92 | for title, path in paths.items(): 93 | logger = Logger(path, title=title, resume=True) 94 | self.loggers.append(logger) 95 | 96 | def plot(self, names=None): 97 | plt.figure() 98 | plt.subplot(121) 99 | legend_text = [] 100 | for logger in self.loggers: 101 | legend_text += plot_overlap(logger, names) 102 | plt.legend(legend_text, 
bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 103 | plt.grid(True) 104 | 105 | if __name__ == '__main__': 106 | # # Example 107 | # logger = Logger('test.txt') 108 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 109 | 110 | # length = 100 111 | # t = np.arange(length) 112 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 113 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 114 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 115 | 116 | # for i in range(0, length): 117 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 118 | # logger.plot() 119 | 120 | # Example: logger monitor 121 | paths = { 122 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 123 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 124 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 125 | } 126 | 127 | field = ['Valid Acc.'] 128 | 129 | monitor = LoggerMonitor(paths) 130 | monitor.plot(names=field) 131 | savefig('test.eps') -------------------------------------------------------------------------------- /fex/Schrodinger/utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count 
# ===== /fex/Schrodinger/utils/visualize.py =====
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from .misc import *

__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']


def upsampling(x, scale_factor):
    """Resize a batch by ``scale_factor`` with nearest-neighbour interpolation.

    The mask helpers below call this with an (N, C, H, W) mask; it was not
    defined anywhere before, so those helpers raised NameError.
    """
    return F.interpolate(x, scale_factor=scale_factor, mode='nearest')


def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)):
    """Undo per-channel normalisation of a (3, H, W) tensor; return an (H, W, 3) array.

    Works on a copy so the caller's tensor is left unchanged.
    """
    img = img.clone()
    for i in range(0, 3):
        img[i] = img[i] * std[i] + mean[i]  # unnormalize
    npimg = img.numpy()
    return np.transpose(npimg, (1, 2, 0))


def gauss(x, a, b, c):
    """Gaussian bump ``a * exp(-(x - b)^2 / (2 c^2))``, applied elementwise."""
    return torch.exp(-torch.pow(torch.add(x, -b), 2).div(2 * c * c)).mul(a)


def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image '''
    if x.dim() == 2:
        # (H, W) -> (1, H, W); torch.unsqueeze has no `out=` argument.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[1] = gauss(x, 1, .5, .3)
        cl[2] = gauss(x, 1, .2, .3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:, 0, :, :] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[:, 1, :, :] = gauss(x, 1, .5, .3)
        cl[:, 2, :, :] = gauss(x, 1, .2, .3)
    else:
        raise ValueError('colorize expects a 2-, 3- or 4-D tensor, got %d-D' % x.dim())
    return cl


def show_batch(images, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Show a grid of the (normalised) image batch."""
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.imshow(images)
    plt.show()


def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot the image grid above the same images blended 30/70 with an upsampled mask."""
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    mask = upsampling(mask, scale_factor=im_size / mask_size)
    mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')


def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot the image grid, then one blended row per mask in ``masklist``."""
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(1 + len(masklist), 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i in range(len(masklist)):
        mask = masklist[i].data.cpu()
        mask_size = mask.size(2)
        mask = upsampling(mask, scale_factor=im_size / mask_size)
        mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
        plt.subplot(1 + len(masklist), 1, i + 2)
        plt.imshow(mask)
        plt.axis('off')


# x = torch.zeros(1, 3, 3)
# out = colorize(x)
# out_im = make_image(out)
# plt.imshow(out_im)
# plt.show()

# ===== /fexrl.png =====
# (binary asset) https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fexrl.png

# ===== /nn/Conservationlaw/scrpts.py =====
import os
import time
import random


def main():
    """Launch 6 trials x 5 dimensions of train.py, each on a random GPU, 200 s apart."""
    for i in range(6):
        for dim in [6, 12, 18, 24, 30]:
            gpu = random.randint(1, 9)
            os.system('screen python train.py --gpu-id ' + str(gpu)
                      + ' --trainbs 5000 --bdbs 1000 --optim adam --lr 0.001 --iters 15000 --dim ' + str(dim)
                      + ' --checkpoint ckpts_ReLU_pidiv4/dim' + str(dim) + '_trial' + str(i)
                      + ' --weight 100 --left -1 --right 1')
            time.sleep(200)


# Only launch jobs when run as a script, not when imported.
if __name__ == '__main__':
    main()

# ===== /nn/Conservationlaw/train.py =====
import argparse
import os
import time
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import math
import numpy as np
from torch import sin, cos, exp
import torch.nn.functional as F
from utils import Logger, AverageMeter, mkdir_p
from numpy.polynomial.legendre import leggauss
# torch.set_default_tensor_type(torch.DoubleTensor)

parser = argparse.ArgumentParser(description='PyTorch Density Function Training')
# Datasets
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--iters', default=30000, type=int, metavar='N', help='number of total iterations to run')
parser.add_argument('--dim', default=2, type=int, help='input dimension (column 0 is time)')
parser.add_argument('--trainbs', default=10, type=int, metavar='N', help='train batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', default=0, type=float, metavar='W', help='weight decay')
parser.add_argument('--function', default='resnet', type=str, help='function to approximate')
parser.add_argument('--optim', default='adam', type=str, help="optimizer: 'SGD' selects SGD, anything else Adam")
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint')
parser.add_argument('--bdbs', default=1000, type=int, help='number of boundary (t = 0) points per step')
parser.add_argument('--weight', default=100, type=float, help='weight of the boundary loss')

# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='path to save checkpoint (default: checkpoint)')

# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')

# Device options
parser.add_argument('--gpu-id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')

# region
parser.add_argument('--left', default=-3, type=float, help='left boundary of square region')
parser.add_argument('--right', default=3, type=float, help='right boundary of square region')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}
print(state)

# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)


def get_boundary(num_pts, dim):
    """Sample ``num_pts`` points in [left, right)^dim on the GPU with column 0 (t) set to 0."""
    bd_pts = (torch.rand(num_pts, dim).cuda()) * (args.right - args.left) + args.left
    bd_pts[:, 0] = 0
    return bd_pts


scale = math.pi / 4  # alternative tried: math.pi


def LHS_pde(u, x, dim_set):
    """Residual operator: (dim_set - 1) * scale * du/dx_0 - sum_{i>=1} du/dx_i, shape (batch, 1).

    ``u`` must have been computed from ``x`` with ``x.requires_grad`` set.
    Tensors are created on ``x``'s device, so this no longer requires CUDA.
    """
    v = torch.ones_like(u)
    ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0]
    coefficient = -torch.ones(1, dim_set, device=x.device, dtype=x.dtype)
    coefficient[:, 0] = (dim_set - 1) * scale
    return torch.sum(ux * coefficient, 1, keepdim=True)


def RHS_pde(x):
    """Zero right-hand side, shape (batch, 1), on ``x``'s device."""
    return x.new_zeros(x.size(0), 1)


def true_solution(x):
    """u(x) = sin(x_0 + scale * sum_{i>=1} x_i), which makes LHS_pde vanish."""
    dim = x.size(1)
    coefficient = torch.full((1, dim), scale, device=x.device, dtype=x.dtype)
    coefficient[:, 0] = 1
    return torch.sin(torch.sum(x * coefficient, dim=1, keepdim=True))


class ResNet(nn.Module):
    """Three residual blocks of two width-``m`` ReLU layers, then a linear read-out to 1 value."""
    def __init__(self, m):
        super(ResNet, self).__init__()
        # forward() zero-pads the input to this width for the first skip connection.
        self.width = m
        self.fc1 = nn.Linear(args.dim, m)
        self.fc2 = nn.Linear(m, m)

        self.fc3 = nn.Linear(m, m)
        self.fc4 = nn.Linear(m, m)

        self.fc5 = nn.Linear(m, m)
        self.fc6 = nn.Linear(m, m)

        self.outlayer = nn.Linear(m, 1, bias=True)

        # initialize the network: zero biases, N(0, 0.1) weights
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()
                module.weight.data.normal_(0.0, 0.1)

    def forward(self, x):
        # map every coordinate from [left, right] to [-1, 1]
        x1 = (x - args.left) * 2 / (args.right - args.left) + (-1)
        # use this module's own width (previously the global `m` was used here)
        s = torch.nn.functional.pad(x1, (0, self.width - args.dim))

        y = F.relu(self.fc1(x1))
        y = F.relu(self.fc2(y))
        y = y + s

        s = y
        y = F.relu(self.fc3(y))
        y = F.relu(self.fc4(y))
        y = y + s

        s = y
        y = F.relu(self.fc5(y))
        y = F.relu(self.fc6(y))
        y = y + s

        return self.outlayer(y)


'''
HyperParams Setting for Network
'''
m = 50  # number of hidden size

# for ResNet
Ix = torch.zeros([1, m]).cuda()
Ix[0, 0] = 1


def main():
    """Train the ResNet PDE solver, save a checkpoint, then log a Monte Carlo relative L2 error."""
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    model = ResNet(m)
    model = model.cuda()
    cudnn.benchmark = True

    with open(args.checkpoint + "/Config.txt", 'w+') as f:
        for (k, v) in args._get_kwargs():
            f.write(k + ' : ' + str(v) + '\n')

    print('Total params: %.2f' % (sum(p.numel() for p in model.parameters())))

    """
    Define Residual Methods and Optimizer
    """
    criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.wd)
    title = ''

    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    logger.set_names(['Learning Rate', 'Losses', 'pdeloss', 'bdloss'])

    # Train
    for step in range(0, args.iters):
        lr = cosine_lr(optimizer, args.lr, step, args.iters)
        losses, pdeloss, bdloss = train(model, criterion, optimizer, use_cuda, step, lr)
        logger.append([lr, losses, pdeloss, bdloss])

    # save model
    save_checkpoint({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, checkpoint=args.checkpoint)

    # Relative L2 error over 1000 batches of 1e5 random points. No gradients are
    # needed here, so skip building autograd graphs for 1e8 forward evaluations.
    model.eval()
    numerators = []
    denominators = []
    with torch.no_grad():
        for i in range(1000):
            print(i)
            t = torch.rand(100000, 1).cuda()
            x1 = (torch.rand(100000, args.dim - 1).cuda()) * (args.right - args.left) + args.left
            x = torch.cat((t, x1), 1)
            u_true = true_solution(x)
            denominators.append(torch.mean(u_true ** 2).item())
            numerators.append(torch.mean((u_true - model(x)) ** 2).item())

    relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators))
    print('relative l2 error: ', relative_l2)
    logger.append(['relative_l2', relative_l2, 0, 0])

    logger.close()


def train(model, criterion, optimizer, use_cuda, iter, lr):
    """Run one optimisation step on a fresh random batch.

    Interior points have t in [0, 1) and other coordinates in [left, right);
    boundary points come from get_boundary (t = 0). The loss is
    MSE(PDE residual, 0) + args.weight * MSE(true, predicted) on the boundary.
    ``criterion`` and ``use_cuda`` are accepted but unused.

    Returns:
        (total loss, PDE loss, boundary loss) as Python floats.
    """
    model.train()
    end = time.time()
    # points sampling: rand in [0, 1) --> [left, right) for the spatial columns
    t = torch.rand(args.trainbs, 1).cuda()
    x1 = (torch.rand(args.trainbs, args.dim - 1).cuda()) * (args.right - args.left) + args.left
    x = torch.cat((t, x1), 1)
    x.requires_grad = True

    bd_pts = get_boundary(args.bdbs, args.dim)
    bc_true = true_solution(bd_pts)
    bd_nn = model(bd_pts)
    bd_error = torch.nn.functional.mse_loss(bc_true, bd_nn)
    function_error = torch.nn.functional.mse_loss(LHS_pde(model(x), x, args.dim), RHS_pde(x))
    loss = function_error + args.weight * bd_error

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    batch_time = time.time() - end
    suffix = '{iter:.1f} {lr:.8f}| Batch: {bt:.3f}s | Loss: {loss:.8f} | pdeloss: {pdeloss:.8f} | bd loss {bd: .8f} |'.format(
        bt=batch_time, loss=loss.item(), iter=iter, lr=lr, pdeloss=function_error.item(), bd=bd_error.item())
    print(suffix)
    return loss.item(), function_error.item(), bd_error.item()


def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Save ``state`` with torch.save to ``<checkpoint>/<filename>``."""
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)


def cosine_lr(opt, base_lr, e, epochs):
    """Set and return the cosine-annealed learning rate for step ``e`` of ``epochs``."""
    lr = 0.5 * base_lr * (math.cos(math.pi * e / epochs) + 1)
    for param_group in opt.param_groups:
        param_group["lr"] = lr
    return lr


if __name__ == '__main__':
    main()

# ===== /nn/Conservationlaw/utils/__init__.py =====
"""Useful utils
"""
from .misc import *
from .logger import *
from .visualize import *
from .eval import *

# progress bar
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
# from progress.bar import Bar as Bar

# ===== /nn/Conservationlaw/utils/eval.py =====
__all__ = ['accuracy']


def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: score tensor indexed as (batch, classes); top-k is taken on dim 1.
        target: true class index per sample (flattened with ``view(1, -1)``).
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` inherits pred's transposed layout.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

# ===== /nn/Conservationlaw/utils/logger.py =====
# A simple torch style logger
# (C) Wei YANG 2017
import os
import sys

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; the text log itself does not need it
    plt = None

__all__ = ['Logger', 'LoggerMonitor', 'savefig']


def _pyplot():
    """Return ``matplotlib.pyplot``, raising a clear error if it is not installed."""
    if plt is None:
        raise ImportError('matplotlib is required for plotting logs')
    return plt


def _parse_cell(text):
    """Convert a cell read back from a log file to float when numeric, else keep the string."""
    try:
        return float(text)
    except ValueError:
        return text


def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname``; ``dpi`` defaults to 150."""
    _pyplot().savefig(fname, dpi=150 if dpi is None else dpi)


def plot_overlap(logger, names=None):
    """Plot the selected columns of ``logger``; return their legend labels."""
    pyplot = _pyplot()
    names = logger.names if names is None else names
    for name in names:
        series = logger.numbers[name]
        pyplot.plot(np.arange(len(series)), np.asarray(series))
    return [logger.title + '(' + name + ')' for name in names]


class Logger(object):
    '''Save training process to a tab-separated log file with a simple plot function.'''
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                # Re-read the existing log, then reopen it in append mode.
                with open(fpath, 'r') as f:
                    self.names = f.readline().rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for line in f:
                        cells = line.rstrip().split('\t')
                        for name, cell in zip(self.names, cells):
                            self.numbers[name].append(_parse_cell(cell))
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        """Write the header row and reset the in-memory history (also when resuming)."""
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        """Write one row; ints and strings verbatim, everything else as ``%.8f``."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            if isinstance(num, (int, str)):
                self.file.write("{}".format(num))
            else:
                self.file.write("{0:.8f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        """Plot the selected columns (all by default) against row index."""
        pyplot = _pyplot()
        names = self.names if names is None else names
        for name in names:
            series = self.numbers[name]
            pyplot.plot(np.arange(len(series)), np.asarray(series))
        pyplot.legend([self.title + '(' + name + ')' for name in names])
        pyplot.grid(True)

    def close(self):
        if self.file is not None:
            self.file.close()


class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__(self, paths):
        '''``paths`` is a dictionary of ``{title: filepath}`` pairs.'''
        self.loggers = [Logger(path, title=title, resume=True)
                        for title, path in paths.items()]

    def plot(self, names=None):
        pyplot = _pyplot()
        pyplot.figure()
        pyplot.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        pyplot.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        pyplot.grid(True)


if __name__ == '__main__':
    # Example: logger monitor
    paths = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')

# ===== /nn/Conservationlaw/utils/misc.py =====
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - mkdir_p: create a directory (and parents) if it does not exist.
    - AverageMeter: running average of a scalar.
'''
import errno
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable

__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']


def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std (average of per-image stds) of a 3-channel image dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,
                                             num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Init Conv2d (Kaiming, fan_out), BatchNorm2d (1/0) and Linear (N(0, 1e-3)) layers.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `is not None`: truth-testing a multi-element bias tensor raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:  # affine=False layers have no parameters
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def mkdir_p(path):
    '''Make ``path`` (and parents) if it does not exist; error if it is a file.'''
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise


class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
# ===== /nn/Conservationlaw/utils/visualize.py =====
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from .misc import *

__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']


def upsampling(x, scale_factor):
    """Resize a batch by ``scale_factor`` with nearest-neighbour interpolation.

    The mask helpers below call this with an (N, C, H, W) mask; it was not
    defined anywhere before, so those helpers raised NameError.
    """
    return F.interpolate(x, scale_factor=scale_factor, mode='nearest')


def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)):
    """Undo per-channel normalisation of a (3, H, W) tensor; return an (H, W, 3) array.

    Works on a copy so the caller's tensor is left unchanged.
    """
    img = img.clone()
    for i in range(0, 3):
        img[i] = img[i] * std[i] + mean[i]  # unnormalize
    npimg = img.numpy()
    return np.transpose(npimg, (1, 2, 0))


def gauss(x, a, b, c):
    """Gaussian bump ``a * exp(-(x - b)^2 / (2 c^2))``, applied elementwise."""
    return torch.exp(-torch.pow(torch.add(x, -b), 2).div(2 * c * c)).mul(a)


def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image '''
    if x.dim() == 2:
        # (H, W) -> (1, H, W); torch.unsqueeze has no `out=` argument.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[1] = gauss(x, 1, .5, .3)
        cl[2] = gauss(x, 1, .2, .3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:, 0, :, :] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[:, 1, :, :] = gauss(x, 1, .5, .3)
        cl[:, 2, :, :] = gauss(x, 1, .2, .3)
    else:
        raise ValueError('colorize expects a 2-, 3- or 4-D tensor, got %d-D' % x.dim())
    return cl


def show_batch(images, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Show a grid of the (normalised) image batch."""
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.imshow(images)
    plt.show()


def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot the image grid above the same images blended 30/70 with an upsampled mask."""
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    mask = upsampling(mask, scale_factor=im_size / mask_size)
    mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')


def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot the image grid, then one blended row per mask in ``masklist``."""
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(1 + len(masklist), 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i in range(len(masklist)):
        mask = masklist[i].data.cpu()
        mask_size = mask.size(2)
        mask = upsampling(mask, scale_factor=im_size / mask_size)
        mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
        plt.subplot(1 + len(masklist), 1, i + 2)
        plt.imshow(mask)
        plt.axis('off')


# x =
torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /nn/Poisson/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /nn/Poisson/.idea/conservativelaw.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /nn/Poisson/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 26 | -------------------------------------------------------------------------------- /nn/Poisson/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /nn/Poisson/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /nn/Poisson/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /nn/Poisson/scrpts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | 5 | # for i in range(6): 6 | # for dim in [10, 20, 30, 40, 50]: 7 | # gpu = random.randint(1,7) 8 | # os.system('screen python train.py --gpu-id '+str(gpu)+' --trainbs 5000 --bdbs 1000 --optim adam --lr 0.001 --iters 15000 --dim 
'+str(dim)+' --checkpoint ckpts/dim'+str(dim)+'_trial'+str(i)+' --weight 100 --left -1 --right 1') 9 | # # print(name) 10 | # time.sleep(100) 11 | 12 | 13 | # for i in range(6): 14 | # for dim in [10, 20, 30, 40, 50]: 15 | # gpu = random.randint(0,9) 16 | # os.system('screen python train.py --gpu-id '+str(gpu)+' --trainbs 10000 --bdbs 5000 --optim adam --lr 0.001 --iters 15000 --dim '+str(dim)+' --checkpoint ckpts_largerbs_relu2/dim'+str(dim)+'_trial'+str(i)+' --weight 100 --left -1 --right 1') 17 | # # print(name) 18 | # time.sleep(100) 19 | 20 | for i in range(6): 21 | for dim in [10, 20, 30, 40, 50]: 22 | gpu = random.randint(0,9) 23 | os.system('screen python train.py --gpu-id '+str(gpu)+' --trainbs 5000 --bdbs 1000 --optim adam --lr 0.001 --iters 15000 --dim '+str(dim)+' --checkpoint ckpts_bs5k_1k_relu2/dim'+str(dim)+'_trial'+str(i)+' --weight 100 --left -1 --right 1') 24 | # print(name) 25 | time.sleep(100) 26 | 27 | 28 | # for i in range(1): 29 | # for dim in [30]: 30 | # gpu = random.randint(1,3) 31 | # os.system('python train.py --gpu-id '+str(gpu)+' --trainbs 10000 --bdbs 5000 --optim adam --lr 0.001 --iters 15000 --dim '+str(dim)+' --checkpoint ckpts/test --weight 100 --left -1 --right 1') 32 | # print(name) 33 | # time.sleep(100) 34 | -------------------------------------------------------------------------------- /nn/Poisson/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim as optim 13 | import math 14 | import numpy as np 15 | from torch import sin, cos, exp 16 | import torch.nn.functional as F 17 | from utils import Logger, AverageMeter, mkdir_p 18 | from numpy.polynomial.legendre import leggauss 19 | # torch.set_default_tensor_type(torch.DoubleTensor) 20 | 
21 | parser = argparse.ArgumentParser(description='PyTorch Density Function Training') 22 | # Datasets 23 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 24 | help='number of data loading workers (default: 4)') 25 | # Optimization options 26 | parser.add_argument('--iters', default=30000, type=int, metavar='N', help='number of total iterations to run') 27 | parser.add_argument('--dim', default=2, type=int) 28 | parser.add_argument('--trainbs', default=10, type=int, metavar='N', help='train batchsize') 29 | # parser.add_argument('--bdbs', default=10, type=int, metavar='N', help='train batchsize') 30 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate') 31 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 32 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 33 | parser.add_argument('--wd', default=0, type=float, metavar='W', help='weight decay') 34 | parser.add_argument('--function', default='resnet', type=str, help='function to approximate') 35 | parser.add_argument('--optim', default='adam', type=str, help='function to approximate') 36 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint') 37 | parser.add_argument('--bdbs', default=1000, type=int) 38 | parser.add_argument('--weight', default=100, type=float, help='weight') 39 | 40 | # Checkpoints 41 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='path to save checkpoint (default: checkpoint)') 42 | 43 | # Miscs 44 | parser.add_argument('--manualSeed', type=int, help='manual seed') 45 | # parser.add_argument('--exp', type=int, default='0', help='use exp last layer') 46 | 47 | #Device options 48 | parser.add_argument('--gpu-id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES') 49 | 50 | # region 51 | 
parser.add_argument('--left', default=-3, type=float, help='left boundary of square region') 52 | parser.add_argument('--right', default=3, type=float, help='right boundary of square region') 53 | 54 | args = parser.parse_args() 55 | state = {k: v for k, v in args._get_kwargs()} 56 | 57 | print(state) 58 | 59 | # Use CUDA 60 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 61 | use_cuda = torch.cuda.is_available() 62 | 63 | # Random seed 64 | if args.manualSeed is None: 65 | args.manualSeed = random.randint(1, 10000) 66 | random.seed(args.manualSeed) 67 | torch.manual_seed(args.manualSeed) 68 | if use_cuda: 69 | torch.cuda.manual_seed_all(args.manualSeed) 70 | 71 | def get_boundary(num_pts, dim): 72 | 73 | bd_pts = (torch.rand(num_pts, dim).cuda()) * (args.right - args.left) + args.left 74 | 75 | num_half = num_pts//2 76 | xlst = torch.arange(0, num_half) 77 | ylst = torch.randint(0, dim, (num_half,)) 78 | bd_pts[xlst, ylst] = args.left 79 | 80 | xlst = torch.arange(num_half, num_pts) 81 | ylst = torch.randint(0, dim, (num_half,)) 82 | bd_pts[xlst, ylst] = args.right 83 | 84 | return bd_pts 85 | 86 | def LHS_pde(u, x, dim_set): 87 | 88 | v = torch.ones(u.shape).cuda() 89 | bs = x.size(0) 90 | ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0] 91 | uxx = torch.zeros(bs, dim_set).cuda() 92 | for i in range(dim_set): 93 | ux_tem = ux[:, i:i+1] 94 | uxx_tem = torch.autograd.grad(ux_tem, x, grad_outputs=v, create_graph=True)[0] 95 | uxx[:, i] = uxx_tem[:, i] 96 | LHS = -torch.sum(uxx, dim=1, keepdim=True) 97 | return LHS 98 | 99 | def RHS_pde(x): 100 | bs = x.size(0) 101 | dim = x.size(1) 102 | return -dim*torch.ones(bs, 1).cuda() 103 | 104 | def true_solution(x): 105 | return 0.5*torch.sum(x**2, dim=1, keepdim=True) 106 | 107 | 108 | # the input dimension is modified to 2 109 | class ResNet(nn.Module): 110 | def __init__(self, m): 111 | super(ResNet, self).__init__() 112 | self.fc1 = nn.Linear(args.dim, m) 113 | self.fc2 = nn.Linear(m, m) 114 | 115 
| self.fc3 = nn.Linear(m, m) 116 | self.fc4 = nn.Linear(m, m) 117 | 118 | self.fc5 = nn.Linear(m, m) 119 | self.fc6 = nn.Linear(m, m) 120 | 121 | self.outlayer = nn.Linear(m, 1, bias=True) 122 | 123 | # # initialize the network bias 124 | for idx, m in enumerate(self.modules()): 125 | if isinstance(m, nn.Linear): 126 | m.bias.data.zero_() 127 | m.weight.data.normal_(0.0, 0.1) 128 | # if args.use_bias != 0 and m.bias.data.shape[0] == 1: 129 | # m.bias.data.fill_(args.use_bias) 130 | # # print(m.weight.data.shape, m.bias.data.shape[0], m.bias.data) 131 | 132 | def forward(self, x): 133 | 134 | x1 = (x -args.left)*2/(args.right -args.left) + (-1) 135 | s = torch.nn.functional.pad(x1, (0, m - args.dim)) 136 | 137 | y = self.fc1(x1) 138 | y = F.relu(y ** 3)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) # a RELU(X) + b RELU(X)^2 139 | y = self.fc2(y) 140 | y = F.relu(y ** 3)#F.relu(y)#F.tanh(y)#F.relu(y ** 2)# 141 | y = y + s 142 | 143 | s = y 144 | y = self.fc3(y) 145 | y = F.relu(y ** 3)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 146 | y = self.fc4(y) 147 | y = F.relu(y ** 3)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 148 | y = y + s 149 | 150 | s = y 151 | y = self.fc5(y) 152 | y = F.relu(y ** 3)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 153 | y = self.fc6(y) 154 | y = F.relu(y ** 3)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 155 | y = y + s 156 | 157 | output = self.outlayer(y) 158 | # output = torch.exp(output) 159 | return output 160 | 161 | 162 | # the input dimension is modified to 2 163 | class ResNet(nn.Module): 164 | def __init__(self, m): 165 | super(ResNet, self).__init__() 166 | self.fc1 = nn.Linear(args.dim, m) 167 | self.fc2 = nn.Linear(m, m) 168 | 169 | self.fc3 = nn.Linear(m, m) 170 | self.fc4 = nn.Linear(m, m) 171 | 172 | self.fc5 = nn.Linear(m, m) 173 | self.fc6 = nn.Linear(m, m) 174 | 175 | self.outlayer = nn.Linear(m, 1, bias=True) 176 | 177 | # # initialize the network bias 178 | for idx, m in enumerate(self.modules()): 179 | if isinstance(m, nn.Linear): 180 | m.bias.data.zero_() 181 | 
m.weight.data.normal_(0.0, 0.1) 182 | # if args.use_bias != 0 and m.bias.data.shape[0] == 1: 183 | # m.bias.data.fill_(args.use_bias) 184 | # # print(m.weight.data.shape, m.bias.data.shape[0], m.bias.data) 185 | 186 | def forward(self, x): 187 | 188 | x1 = (x -args.left)*2/(args.right -args.left) + (-1) 189 | s = torch.nn.functional.pad(x1, (0, m - args.dim)) 190 | 191 | y = self.fc1(x1) 192 | y = F.relu(y ** 2)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) # a RELU(X) + b RELU(X)^2 193 | y = self.fc2(y) 194 | y = F.relu(y ** 2)#F.relu(y)#F.tanh(y)#F.relu(y ** 2)# 195 | y = y + s 196 | 197 | s = y 198 | y = self.fc3(y) 199 | y = F.relu(y ** 2)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 200 | y = self.fc4(y) 201 | y = F.relu(y ** 2)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 202 | y = y + s 203 | 204 | s = y 205 | y = self.fc5(y) 206 | y = F.relu(y ** 2)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 207 | y = self.fc6(y) 208 | y = F.relu(y ** 2)#F.relu(y)#F.tanh(y)#F.relu(y ** 2) 209 | y = y + s 210 | 211 | output = self.outlayer(y) 212 | # output = torch.exp(output) 213 | return output 214 | 215 | 216 | ''' 217 | HyperParams Setting for Network 218 | ''' 219 | m = 50 # number of hidden size 220 | 221 | # for ResNet 222 | Ix = torch.zeros([1,m]).cuda() 223 | Ix[0,0] = 1 224 | 225 | 226 | def main(): 227 | 228 | if not os.path.isdir(args.checkpoint): 229 | mkdir_p(args.checkpoint) 230 | 231 | model = ResNet(m) 232 | 233 | model = model.cuda() 234 | cudnn.benchmark = True 235 | 236 | if not os.path.isdir(args.checkpoint): 237 | mkdir_p(args.checkpoint) 238 | 239 | with open(args.checkpoint + "/Config.txt", 'w+') as f: 240 | for (k, v) in args._get_kwargs(): 241 | f.write(k + ' : ' + str(v) + '\n') 242 | 243 | print('Total params: %.2f' % (sum(p.numel() for p in model.parameters()))) 244 | 245 | """ 246 | Define Residual Methods and Optimizer 247 | """ 248 | criterion = nn.MSELoss() 249 | if args.optim == 'SGD': 250 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 251 | else: 
252 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.wd) 253 | # Resume 254 | title = '' 255 | 256 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 257 | logger.set_names(['Learning Rate', 'Losses', 'pdeloss', 'bdloss']) 258 | 259 | # Train and val 260 | for iter in range(0, args.iters): 261 | 262 | lr = cosine_lr(optimizer, args.lr, iter, args.iters) 263 | 264 | losses, pdeloss, intloss = train(model, criterion, optimizer, use_cuda, iter, lr) 265 | logger.append([lr, losses, pdeloss, intloss]) 266 | 267 | # save model 268 | save_checkpoint({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, checkpoint=args.checkpoint) 269 | 270 | numerators = [] 271 | denominators = [] 272 | 273 | for i in range(1000): 274 | print(i) 275 | x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left 276 | # print(true_solution(x).size(), model(x).size()) 277 | sq_de = torch.mean((true_solution(x)) ** 2) 278 | sq_nu = torch.mean((true_solution(x) - model(x)) ** 2) 279 | numerators.append(sq_nu.item()) 280 | denominators.append(sq_de.item()) 281 | 282 | relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators)) 283 | print('relative l2 error: ', relative_l2) 284 | logger.append(['relative_l2', relative_l2, 0, 0]) 285 | 286 | # logger.append([0, 0, relative_l2, 0, 0]) 287 | 288 | logger.close() 289 | 290 | 291 | def train(model, criterion, optimizer, use_cuda, iter, lr): 292 | 293 | # switch to train mode 294 | model.train() 295 | end = time.time() 296 | ''' 297 | points sampling 298 | ''' 299 | # the range is [0,1] --> [left, right] 300 | x = (torch.rand(args.trainbs, args.dim).cuda())*(args.right-args.left)+args.left 301 | x.requires_grad = True 302 | 303 | bd_pts = get_boundary(args.bdbs, args.dim) 304 | bc_true = true_solution(bd_pts) 305 | bd_nn = model(bd_pts) 306 | bd_error = torch.nn.functional.mse_loss(bc_true, bd_nn) 307 | function_error = 
torch.nn.functional.mse_loss(LHS_pde(model(x), x, args.dim), RHS_pde(x)) 308 | loss = function_error + args.weight * bd_error 309 | 310 | # compute gradient and do SGD step 311 | optimizer.zero_grad() 312 | loss.backward() 313 | optimizer.step() 314 | 315 | # measure elapsed time 316 | batch_time = time.time() - end 317 | suffix = '{iter:.1f} {lr:.8f}| Batch: {bt:.3f}s | Loss: {loss:.8f} | pdeloss: {pdeloss:.8f} | bd loss {bd: .8f} |'.format( 318 | bt=batch_time, loss=loss.item(), iter=iter, lr=lr, pdeloss=function_error.item(), bd=bd_error.item()) 319 | print(suffix) 320 | return loss.item(), function_error.item(), bd_error.item() 321 | 322 | def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 323 | filepath = os.path.join(checkpoint, filename) 324 | torch.save(state, filepath) 325 | 326 | 327 | def cosine_lr(opt, base_lr, e, epochs): 328 | lr = 0.5 * base_lr * (math.cos(math.pi * e / epochs) + 1) 329 | for param_group in opt.param_groups: 330 | param_group["lr"] = lr 331 | return lr 332 | 333 | if __name__ == '__main__': 334 | main() 335 | -------------------------------------------------------------------------------- /nn/Poisson/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | # from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /nn/Poisson/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | 
batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /nn/Poisson/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, 
names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | def append(self, numbers): 61 | assert len(self.names) == len(numbers), 'Numbers do not match names' 62 | for index, num in enumerate(numbers): 63 | if isinstance(num, int): 64 | self.file.write("{}".format(num)) 65 | elif isinstance(num, str): 66 | self.file.write("{}".format(num)) 67 | else: 68 | self.file.write("{0:.8f}".format(num)) 69 | self.file.write('\t') 70 | self.numbers[self.names[index]].append(num) 71 | self.file.write('\n') 72 | self.file.flush() 73 | 74 | def plot(self, names=None): 75 | names = self.names if names == None else names 76 | numbers = self.numbers 77 | for _, name in enumerate(names): 78 | x = np.arange(len(numbers[name])) 79 | plt.plot(x, np.asarray(numbers[name])) 80 | plt.legend([self.title + '(' + name + ')' for name in names]) 81 | plt.grid(True) 82 | 83 | def close(self): 84 | if self.file is not None: 85 | self.file.close() 86 | 87 | class LoggerMonitor(object): 88 | '''Load and visualize multiple logs.''' 89 | def __init__ (self, paths): 90 | '''paths is a distionary with {name:filepath} pair''' 91 | self.loggers = [] 92 | for title, path in paths.items(): 93 | logger = Logger(path, title=title, resume=True) 94 | self.loggers.append(logger) 95 | 96 | def plot(self, names=None): 97 | plt.figure() 98 | plt.subplot(121) 99 | legend_text = [] 100 | for logger in self.loggers: 101 | legend_text += plot_overlap(logger, names) 102 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
103 | plt.grid(True) 104 | 105 | if __name__ == '__main__': 106 | # # Example 107 | # logger = Logger('test.txt') 108 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 109 | 110 | # length = 100 111 | # t = np.arange(length) 112 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 113 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 114 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 115 | 116 | # for i in range(0, length): 117 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 118 | # logger.plot() 119 | 120 | # Example: logger monitor 121 | paths = { 122 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 123 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 124 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 125 | } 126 | 127 | field = ['Valid Acc.'] 128 | 129 | monitor = LoggerMonitor(paths) 130 | monitor.plot(names=field) 131 | savefig('test.eps') -------------------------------------------------------------------------------- /nn/Poisson/utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count 
-------------------------------------------------------------------------------- /nn/Poisson/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | 
plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 
# out = colorize(x)
# out_im = make_image(out)
# plt.imshow(out_im)
# plt.show()

# --------------------------------------------------------------------------------
# /nn/Schrodinger/scripts.py:
# --------------------------------------------------------------------------------
# Launch the NN-baseline sweep: 6 trials for every dimension in
# [6, 12, 18, 24, 30]; each train_integral.py run gets a GPU id drawn at random
# from [3, 4, 7], and the next launch waits 120 s.
import os
import time
import random


def launch_sweep():
    """Start one ``train_integral.py`` run per (trial, dimension) pair."""
    for i in range(0, 6):
        for dim in [6, 12, 18, 24, 30]:
            gpu = random.choice([3, 4, 7])
            # NOTE(review): without -d/-m, screen normally runs attached to this
            # terminal, so os.system would block until the run exits; confirm
            # whether `screen -dm` was intended.
            os.system('screen python train_integral.py --gpu-id ' + str(gpu)
                      + ' --trainbs 2000 --intbs 10000 --optim adam --lr 0.001 --iters 15000 --dim ' + str(dim)
                      + ' --checkpoint ckpts_ReLU2_intbs/Dim' + str(dim) + '_trial' + str(i)
                      + ' --weight 1 --left -1 --right 1')
            time.sleep(120)


# Guarded so that importing this module does not launch the sweep.
if __name__ == '__main__':
    launch_sweep()

# --------------------------------------------------------------------------------
# /nn/Schrodinger/train_integral.py:
# --------------------------------------------------------------------------------
# Train a ResNet with ReLU(y**2) activations to solve
#     -Laplace(u) + u**3 + V(x) * u = 0   on [left, right]**dim,
# whose exact solution is true_solution(x) = exp(mean_i cos(x_i)) / c.
# u = 0 also makes the residual vanish, so the loss adds a penalty matching the
# network's mean over the box to a Monte-Carlo estimate of the exact mean.
# (The Python-2 `from __future__ import print_function` line was dropped: it
# is a no-op on Python 3.)
import argparse
import os
import time
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import math
import numpy as np
from torch import sin, cos, exp
import torch.nn.functional as F
from utils import Logger, AverageMeter, mkdir_p
from numpy.polynomial.legendre import leggauss
# torch.set_default_tensor_type(torch.DoubleTensor)

parser = argparse.ArgumentParser(description='PyTorch Density Function Training')
# Datasets
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--iters', default=30000, type=int, metavar='N', help='number of total iterations to run')
parser.add_argument('--dim', default=2, type=int, help='spatial dimension of the problem')
parser.add_argument('--trainbs', default=2000, type=int, metavar='N', help='train batchsize')
parser.add_argument('--intbs', default=10000, type=int, help='number of points used for the integral penalty')
# parser.add_argument('--bdbs', default=10, type=int, metavar='N', help='train batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', default=0, type=float, metavar='W', help='weight decay')
parser.add_argument('--function', default='resnet', type=str, help='function to approximate')
parser.add_argument('--optim', default='adam', type=str, help="optimizer: 'SGD' selects SGD, anything else Adam")
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint')

# Without a default, omitting --weight made `args.weight * integration` fail.
parser.add_argument('--weight', type=float, default=1.0, help='weight of the integral penalty')

# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                    help='path to save checkpoint (default: checkpoint)')

# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')
# parser.add_argument('--exp', type=int, default='0', help='use exp last layer')

# Device options
parser.add_argument('--gpu-id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')

# region
parser.add_argument('--left', default=-3, type=float, help='left boundary of square region')
parser.add_argument('--right', default=3, type=float, help='right boundary of square region')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

print(state)

# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
# NOTE: computed for train()'s signature only; tensors below are moved to the
# GPU unconditionally with .cuda().
use_cuda = torch.cuda.is_available()

# Random seed
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

# Constant c in the exact solution exp(mean_i cos(x_i)) / c.
c = 3


def LHS_pde(u, x, dim_set):
    """Evaluate the residual operator -Laplace(u) + u**3 + V(x) * u.

    Args:
        u: network output at ``x``, shape (batch, 1) as produced by ``model(x)``.
        x: sample points of shape (batch, dim_set) with ``requires_grad=True``.
        dim_set: number of coordinates differentiated twice.

    Returns:
        Tensor of shape (batch, 1).
    """
    v = torch.ones_like(u)
    bs = x.size(0)
    ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0]
    # One extra backward pass per coordinate gives the diagonal of the Hessian.
    uxx = torch.zeros(bs, dim_set).cuda()
    for i in range(dim_set):
        ux_tem = ux[:, i].reshape([bs, 1])
        uxx_tem = torch.autograd.grad(ux_tem, x, grad_outputs=v, create_graph=True)[0]
        uxx[:, i] = uxx_tem[:, i]

    LHS = -torch.sum(uxx, dim=1, keepdim=True)
    # Potential chosen so that true_solution makes the residual vanish.
    V = (-exp(2 / dim_set * torch.sum(cos(x), dim=1, keepdim=True)) / (c ** 2)
         + torch.sum(sin(x) ** 2 / (dim_set ** 2), dim=1, keepdim=True)
         - torch.sum(cos(x) / (dim_set), dim=1, keepdim=True))
    return LHS + u ** 3 + V * u


def RHS_pde(x):
    """Right-hand side of the PDE: zeros of shape (batch, 1)."""
    bs = x.size(0)
    return torch.zeros(bs, 1).cuda()


def true_solution(x):
    """Exact solution exp(mean_i cos(x_i)) / c, shape (batch, 1)."""
    dim = x.size(1)
    return exp(1 / dim * torch.sum(cos(x), dim=1, keepdim=True)) / (c)


# Monte-Carlo estimate of the mean of the exact solution over the box
# (500 batches x 100000 uniform points). No gradients are needed, so no
# autograd graph is built.
integral_value = []
with torch.no_grad():
    for i in range(500):
        print(i)
        x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left
        value = torch.mean(true_solution(x))
        integral_value.append(value.item())

integral_true = sum(integral_value) / len(integral_value)


class ResNet(nn.Module):
    """Three residual blocks of width ``m`` with ReLU(y**2) activations.

    Inputs are mapped affinely from [left, right] to [-1, 1] and zero-padded to
    width ``m`` for the first skip connection. If args.dim > m, F.pad with a
    negative amount crops the input instead.
    """

    def __init__(self, m):
        super(ResNet, self).__init__()
        # Hidden width; forward() pads the input to this size. Previously the
        # global `m` was used there, which ignored the constructor argument.
        self.width = m
        self.fc1 = nn.Linear(args.dim, m)
        self.fc2 = nn.Linear(m, m)

        self.fc3 = nn.Linear(m, m)
        self.fc4 = nn.Linear(m, m)

        self.fc5 = nn.Linear(m, m)
        self.fc6 = nn.Linear(m, m)

        self.outlayer = nn.Linear(m, 1, bias=True)

        # Zero biases and N(0, 0.1**2) weights for every linear layer.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.normal_(0.0, 0.1)

    def forward(self, x):
        x1 = (x - args.left) * 2 / (args.right - args.left) + (-1)
        s = F.pad(x1, (0, self.width - args.dim))

        y = self.fc1(x1)
        y = F.relu(y ** 2)
        y = self.fc2(y)
        y = F.relu(y ** 2)
        y = y + s

        s = y
        y = self.fc3(y)
        y = F.relu(y ** 2)
        y = self.fc4(y)
        y = F.relu(y ** 2)
        y = y + s

        s = y
        y = self.fc5(y)
        y = F.relu(y ** 2)
        y = self.fc6(y)
        y = F.relu(y ** 2)
        y = y + s

        output = self.outlayer(y)
        return output


'''
HyperParams Setting for Network
'''
m = 50  # number of hidden size

# for ResNet (NOTE: not referenced elsewhere in this file)
Ix = torch.zeros([1, m]).cuda()
Ix[0, 0] = 1


def main():
    """Train the network, save a checkpoint and log the relative L2 error."""
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    model = ResNet(m)

    model = model.cuda()
    cudnn.benchmark = True

    # NOTE: --resume is parsed but not used by this script.

    with open(args.checkpoint + "/Config.txt", 'w+') as f:
        for (k, v) in args._get_kwargs():
            f.write(k + ' : ' + str(v) + '\n')

    print('Total params: %.2f' % (sum(p.numel() for p in model.parameters())))

    """
    Define Residual Methods and Optimizer
    """
    criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.wd)
    # Resume
    title = ''

    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    logger.set_names(['Learning Rate', 'Losses', 'pdeloss', 'intloss'])

    # Train and val
    for iter in range(0, args.iters):

        lr = cosine_lr(optimizer, args.lr, iter, args.iters)

        losses, pdeloss, intloss = train(model, criterion, optimizer, use_cuda, iter, lr)
        logger.append([lr, losses, pdeloss, intloss])

    # save model
    save_checkpoint({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                    checkpoint=args.checkpoint)

    # Relative L2 error against the exact solution from 1000 batches of 100000
    # uniform points. Evaluation only, so no autograd graph is built.
    numerators = []
    denominators = []
    model.eval()
    with torch.no_grad():
        for i in range(1000):
            print(i)
            x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left
            exact = true_solution(x)
            sq_de = torch.mean(exact ** 2)
            sq_nu = torch.mean((exact - model(x)) ** 2)
            numerators.append(sq_nu.item())
            denominators.append(sq_de.item())

    relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators))
    print('relative l2 error: ', relative_l2)
    logger.append(['relative_l2', relative_l2, 0, 0])

    logger.close()


def train(model, criterion, optimizer, use_cuda, iter, lr):
    """Run one optimisation step on freshly sampled points.

    The loss is mse(PDE residual, 0) + args.weight * (mean(model) - integral_true)**2.
    ``criterion`` and ``use_cuda`` are accepted but not used.

    Returns:
        (total loss, PDE loss, integral penalty) as Python floats.
    """
    # switch to train mode
    model.train()
    end = time.time()
    '''
    points sampling
    '''
    # the range is [0,1] --> [left, right]
    x = (torch.rand(args.trainbs, args.dim).cuda()) * (args.right - args.left) + args.left
    x.requires_grad = True

    x_int = (torch.rand(args.intbs, args.dim).cuda()) * (args.right - args.left) + args.left
    x_int.requires_grad = True

    integral = torch.mean(model(x_int))
    integration = (integral - integral_true) ** 2
    function_error = torch.nn.functional.mse_loss(LHS_pde(model(x), x, args.dim), RHS_pde(x))
    loss = function_error + args.weight * integration

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # measure elapsed time
    batch_time = time.time() - end
    suffix = '{iter:.1f} {lr:.8f}| Batch: {bt:.3f}s | Loss: {loss:.8f} | pdeloss: {pdeloss:.8f} | Integral loss {integral: .8f} |'.format(
        bt=batch_time, loss=loss.item(), iter=iter, lr=lr, pdeloss=function_error.item(), integral=integration.item())
    print(suffix)
    return loss.item(), function_error.item(), integration.item()


def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """torch.save ``state`` to ``checkpoint/filename``."""
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)


def cosine_lr(opt, base_lr, e, epochs):
    """Set every param group's lr to base_lr * (1 + cos(pi * e / epochs)) / 2 and return it."""
    lr = 0.5 * base_lr * (math.cos(math.pi * e / epochs) + 1)
    for param_group in opt.param_groups:
        param_group["lr"] = lr
    return lr


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /nn/Schrodinger/utils/__init__.py:
# --------------------------------------------------------------------------------
"""Useful utils
"""
from .misc import *
from .logger import *
from .visualize import *
from .eval import *

# progress bar
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
# from progress.bar import Bar as Bar

# --------------------------------------------------------------------------------
# /nn/Schrodinger/utils/eval.py:
# --------------------------------------------------------------------------------
# (Python-2 `from __future__` imports dropped: no-ops on Python 3.)

__all__ = ['accuracy']


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k (in percent) for the specified values of k.

    Args:
        output: class scores; top-k is taken along dim 1.
        target: true class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        List with one 0-dim tensor per k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` derives from a transposed tensor and may
        # be non-contiguous, in which case .view(-1) raises.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

# --------------------------------------------------------------------------------
# /nn/Schrodinger/utils/logger.py:
# --------------------------------------------------------------------------------
# A simple torch style logger
# (C) Wei YANG 2017
# (Python-2 `from __future__` import dropped: no-op on Python 3.)
import matplotlib.pyplot as plt
import os
import sys
import numpy as np

__all__ = ['Logger', 'LoggerMonitor', 'savefig']


def savefig(fname, dpi=None):
    """Save the current matplotlib figure (150 dpi unless ``dpi`` is given)."""
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)


def plot_overlap(logger, names=None):
    """Plot the named columns of ``logger`` and return their legend labels."""
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]


class Logger(object):
    '''Save training process to log file with simple plot function.

    The file is tab-separated: one header row of names, then one row per
    append() call.
    '''

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                # Reload the header and all rows (as strings), then reopen for appending.
                with open(fpath, 'r') as f:
                    name = f.readline()
                    self.names = name.rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}

                    for numbers in f:
                        numbers = numbers.rstrip().split('\t')
                        for i in range(0, len(numbers)):
                            self.numbers[self.names[i]].append(numbers[i])
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        """Write ``names`` as the header row and reset the recorded values.

        NOTE: this also runs when resuming, so a resumed log gets another
        header row and the values loaded in __init__ are discarded.
        """
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        """Write one row: ints/strs verbatim, everything else with 8 decimals."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            if isinstance(num, (int, str)):
                self.file.write("{}".format(num))
            else:
                self.file.write("{0:.8f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        """Plot the named columns (all by default) on the current axes."""
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        if self.file is not None:
            self.file.close()


class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''paths is a dictionary with {name:filepath} pair'''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # Only the loaded values are needed; release the append handle.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)


if __name__ == '__main__':
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    paths = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')

# --------------------------------------------------------------------------------
# /nn/Schrodinger/utils/misc.py:
# --------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - mkdir_p: create a directory (and its parents) if it does not exist.
    - AverageMeter: running average of a scalar.
'''
import errno
import os
import sys
import time
import math

# `torch` itself is needed by get_mean_and_std (torch.utils.data, torch.zeros);
# `import torch.nn as nn` only binds `nn`.
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable

__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']


def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.

    Returns the per-image channel means and per-image channel stds of the
    first 3 channels, each averaged over the dataset. Indexing as
    inputs[:, i, :, :] assumes samples are (C, H, W) images.
    '''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Init layer parameters.

    Conv2d: Kaiming-normal weights (fan_out), zero bias.
    BatchNorm2d: weight 1, bias 0.
    Linear: N(0, 1e-3**2) weights, zero bias.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` raises for multi-element tensors; test for presence.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def mkdir_p(path):
    '''make dir if not exist (like `mkdir -p`); raises if path exists as a non-directory'''
    os.makedirs(path, exist_ok=True)


class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# --------------------------------------------------------------------------------
# /nn/Schrodinger/utils/visualize.py:
# --------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from .misc import *

__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']


# functions to show an image
def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)):
    """Un-normalize the first 3 channels of ``img`` IN PLACE; return it as an (H, W, C) array."""
    for i in range(0, 3):
        img[i] = img[i] * std[i] + mean[i]  # unnormalize
    npimg = img.numpy()
    return np.transpose(npimg, (1, 2, 0))


def gauss(x, a, b, c):
    """Elementwise a * exp(-(x - b)**2 / (2 * c**2))."""
    return torch.exp(-torch.pow(torch.add(x, -b), 2).div(2 * c * c)).mul(a)


def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image '''
    if x.dim() == 2:
        # torch.unsqueeze has no `out=` argument; add the leading axis explicitly.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[1] = gauss(x, 1, .5, .3)
        cl[2] = gauss(x, 1, .2, .3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:, 0, :, :] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[:, 1, :, :] = gauss(x, 1, .5, .3)
        cl[:, 2, :, :] = gauss(x, 1, .2, .3)
    else:
        raise ValueError('colorize expects a 2-D, 3-D or 4-D tensor, got %d-D' % x.dim())
    return cl


def show_batch(images, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Show a batch of normalized images as one grid."""
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.imshow(images)
    plt.show()


def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot ``images`` above the same images blended 30/70 with ``mask``.

    ``images`` is indexed as (N, C, H, W). ``mask`` is resized by
    im_size / mask_size with nearest-neighbour interpolation. It is presumably
    (N, 1, h, w) so that expand_as(images) works — confirm.
    """
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    # `upsampling` was never defined or imported in this package; interpolate
    # (default mode 'nearest') performs the resize.
    mask = torch.nn.functional.interpolate(mask, scale_factor=im_size / mask_size)
    mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')


def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot ``images`` and, below it, one 30/70 blend per mask in ``masklist``.

    Same layout assumptions as show_mask_single for each mask.
    """
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(1 + len(masklist), 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i in range(len(masklist)):
        mask = masklist[i].data.cpu()
        mask_size = mask.size(2)
        # See show_mask_single: replaces the undefined `upsampling` helper.
        mask = torch.nn.functional.interpolate(mask, scale_factor=im_size / mask_size)
        mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
        plt.subplot(1 + len(masklist), 1, i + 2)
        plt.imshow(mask)
        plt.axis('off')


# x = torch.zeros(1, 3, 3)
# out = colorize(x)
# out_im = make_image(out)
# plt.imshow(out_im)
# plt.show()