├── .idea
├── Finite-expression-method.iml
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
├── vcs.xml
└── workspace.xml
├── README.md
├── fex
├── Conservationlaw
│ ├── computational_tree.py
│ ├── controller_conservative.py
│ ├── function.py
│ ├── scripts.py
│ ├── tools.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── eval.py
│ │ ├── images
│ │ ├── cifar.png
│ │ └── imagenet.png
│ │ ├── logger.py
│ │ ├── misc.py
│ │ └── visualize.py
├── Poisson
│ ├── .idea
│ │ ├── .gitignore
│ │ ├── inspectionProfiles
│ │ │ └── profiles_settings.xml
│ │ ├── misc.xml
│ │ ├── modules.xml
│ │ └── poissoneqn.iml
│ ├── computational_tree.py
│ ├── controller_poisson.py
│ ├── function.py
│ ├── scripts.py
│ ├── tools.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── eval.py
│ │ ├── images
│ │ ├── cifar.png
│ │ └── imagenet.png
│ │ ├── logger.py
│ │ ├── misc.py
│ │ └── visualize.py
└── Schrodinger
│ ├── computational_tree.py
│ ├── controller_cubic_sh_firstdeflation_thenintegral.py
│ ├── function.py
│ ├── scripts.py
│ ├── tools.py
│ └── utils
│ ├── __init__.py
│ ├── eval.py
│ ├── images
│ ├── cifar.png
│ └── imagenet.png
│ ├── logger.py
│ ├── misc.py
│ └── visualize.py
├── fexrl.png
└── nn
├── Conservationlaw
├── scrpts.py
├── train.py
└── utils
│ ├── __init__.py
│ ├── eval.py
│ ├── logger.py
│ ├── misc.py
│ └── visualize.py
├── Poisson
├── .idea
│ ├── .gitignore
│ ├── conservativelaw.iml
│ ├── inspectionProfiles
│ │ ├── Project_Default.xml
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ └── modules.xml
├── scrpts.py
├── train.py
└── utils
│ ├── __init__.py
│ ├── eval.py
│ ├── logger.py
│ ├── misc.py
│ └── visualize.py
└── Schrodinger
├── scripts.py
├── train_integral.py
└── utils
├── __init__.py
├── eval.py
├── logger.py
├── misc.py
└── visualize.py
/.idea/Finite-expression-method.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 | 1660276797692
44 |
45 |
46 | 1660276797692
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Finite-expression-method
2 | 
3 |
4 | By [Senwei Liang](https://leungsamwai.github.io) and [Haizhao Yang](https://haizhaoyang.github.io/)
5 |
6 | This repo is the implementation of "Finite Expression Method for Solving High-Dimensional Partial Differential Equations" [[paper]](https://arxiv.org/abs/2206.10121).
7 |
8 | ## Introduction
9 |
10 | Finite expression method (FEX) is a new methodology that seeks an approximate PDE solution in the space of functions with finitely many analytic expressions. This repo provides a deep reinforcement learning implementation of FEX for several high-dimensional PDEs across a range of dimensions.
11 |
12 | 
13 |
14 | ## Environment
15 | * [PyTorch 1.0](http://pytorch.org/)
16 |
17 | ## Code structure
18 |
19 | ```
20 | Finite-expression-method
21 | │ README.md <-- You are here
22 | │
23 | └─── fex ----> three numerical examples with FEX
24 | │ │ Poisson
25 | │ │ Schrodinger
26 | │ │ Conservationlaw
27 | │
28 | └─── nn ----> three numerical examples with NN
29 | │ Poisson
30 | │ Schrodinger
31 | │ Conservationlaw
32 | ```
33 | ## Citing FEX
34 | If you find our code helpful for your research, please cite:
35 | ```
36 | @article{liang2022finite,
37 | title={Finite Expression Method for Solving High-Dimensional Partial Differential Equations},
38 | author={Liang, Senwei and Yang, Haizhao},
39 | journal={arXiv preprint arXiv:2206.10121},
40 | year={2022}
41 | }
42 | ```
43 | ## Acknowledgments
44 |
45 | We appreciate [bearpaw](https://github.com/bearpaw) for his [DL framework](https://github.com/bearpaw/pytorch-classification) and Taehoon Kim for his [RL framework](https://github.com/carpedm20/ENAS-pytorch).
--------------------------------------------------------------------------------
/fex/Conservationlaw/computational_tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import function as func
3 | unary = func.unary_functions
4 | binary = func.binary_functions
5 | unary_functions_str = func.unary_functions_str
6 | binary_functions_str = func.binary_functions_str
7 |
class BinaryTree(object):
    """Node of a binary expression tree.

    Attributes:
        key: payload of the node (an operator once the tree is filled in).
        is_unary: True when the node's operator takes one argument,
            False when it takes two.
        leftChild / rightChild: child nodes, or None when absent.
    """

    def __init__(self, item, is_unary=True):
        self.key = item
        self.is_unary = is_unary
        self.leftChild = None
        self.rightChild = None

    def insertLeft(self, item, is_unary=True):
        """Insert a new left child holding ``item``.

        If a left child already exists it is pushed down and becomes the
        left child of the new node.
        """
        # The push-down branch used to build the node without forwarding
        # ``is_unary``, silently turning every such node into a unary one.
        node = BinaryTree(item, is_unary)
        if self.leftChild is not None:
            node.leftChild = self.leftChild
        self.leftChild = node

    def insertRight(self, item, is_unary=True):
        """Insert a new right child holding ``item``.

        If a right child already exists it is pushed down and becomes the
        right child of the new node.
        """
        node = BinaryTree(item, is_unary)
        if self.rightChild is not None:
            node.rightChild = self.rightChild
        self.rightChild = node
28 |
def compute_by_tree(tree, x):
    """Evaluate the expression tree rooted at ``tree`` on the input ``x``.

    A node without children is a leaf and applies its operator to ``x``.
    A node with exactly one child applies its operator to that child's
    value. A node with two children applies its operator to
    (left value, right value).
    """
    left, right = tree.leftChild, tree.rightChild
    if left is not None and right is not None:
        return tree.key(compute_by_tree(left, x), compute_by_tree(right, x))
    if left is not None:
        return tree.key(compute_by_tree(left, x))
    if right is not None:
        return tree.key(compute_by_tree(right, x))
    return tree.key(x)
39 |
def inorder(tree, actions):
    """Fill in the operators of ``tree`` following an in-order traversal.

    The node visited at position ``count`` receives the operator selected by
    ``actions[count].item()``: looked up in ``unary`` when the node is flagged
    unary, otherwise in ``binary``. ``count`` is a module-level counter that is
    incremented here but never reset, so the caller has to zero it first.
    """
    global count
    if not tree:
        return
    inorder(tree.leftChild, actions)
    choice = actions[count].item()
    operators = unary if tree.is_unary else binary
    tree.key = operators[choice]
    count += 1
    inorder(tree.rightChild, actions)
55 |
56 | # def inorder(tree):
57 | # if tree:
58 | # inorder(tree.leftChild)
59 | # print(tree.key)
60 | # inorder(tree.rightChild)
61 |
count = 0


def inorder_w_idx(tree):
    """Print each node's key together with its in-order position.

    The position comes from the module-level ``count``, which is advanced
    once per visited node and is not reset by this function.
    """
    global count
    if not tree:
        return
    inorder_w_idx(tree.leftChild)
    print(tree.key, count)
    count += 1
    inorder_w_idx(tree.rightChild)
70 |
def basic_tree():
    """Build the fixed depth-2 template tree with empty ('') keys.

    Layout: a binary root with two binary children, each of which has two
    unary leaves (seven nodes in total).
    """
    root = BinaryTree('', False)
    root.insertLeft('', False)
    root.insertRight('', False)
    for inner in (root.leftChild, root.rightChild):
        inner.insertLeft('', True)
        inner.insertRight('', True)
    return root
80 |
def get_function(actions):
    """Return a fresh template tree whose operators are chosen by ``actions``.

    The module-level ``count`` used by ``inorder`` is zeroed before the
    traversal and again afterwards.
    """
    global count
    count = 0
    tree = basic_tree()
    inorder(tree, actions)
    count = 0  # reset to zero
    return tree
88 |
def inorder_test(tree, actions):
    """Debug variant of ``inorder`` that prints every selected action index.

    Nodes are visited in-order; the node at position ``count`` (module-level
    counter, not reset here) gets ``unary[action]`` or ``binary[action]``
    depending on its ``is_unary`` flag, and ``action`` is printed.

    The recursion previously called ``inorder`` for both subtrees, so only the
    root's action was ever printed; it now recurses into itself.
    """
    global count
    if not tree:
        return
    inorder_test(tree.leftChild, actions)
    action = actions[count].item()
    print(action)
    tree.key = unary[action] if tree.is_unary else binary[action]
    count = count + 1
    inorder_test(tree.rightChild, actions)
105 |
if __name__ == '__main__':
    # Manual smoke test: pick one action per node of the template tree and
    # run the printing traversal over it.
    import torch

    # NOTE(review): function.py in this directory lists 9 unary operators, so
    # the indices 10 and 9 below look out of range for the leaf positions —
    # confirm against the intended operator table.
    bs_action = [torch.LongTensor([index]) for index in (10, 2, 9, 1, 0, 2, 6)]

    function = lambda x: compute_by_tree(get_function(bs_action), x)
    x = torch.FloatTensor([[-1], [1]])

    count = 0
    tr = basic_tree()
    inorder_test(tr, bs_action)
    count = 0
--------------------------------------------------------------------------------
/fex/Conservationlaw/function.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import sin, cos, exp
4 | import math
5 |
scale = math.pi/4


def LHS_pde(u, x, dim_set):
    """Left-hand side of the PDE: a weighted sum of the first derivatives of u.

    Args:
        u: network/expression output evaluated at ``x`` (part of x's graph).
        x: input points with ``requires_grad=True``; the second axis must have
            ``dim_set`` entries.
        dim_set: number of input coordinates.

    Returns:
        Tensor with one column: ``(dim_set-1)*scale * du/dx_0 - sum_{i>0} du/dx_i``.

    Helper tensors are created on the device of the inputs instead of being
    forced onto CUDA, so the function also works for CPU tensors.
    """
    grad_seed = torch.ones_like(u)
    ux = torch.autograd.grad(u, x, grad_outputs=grad_seed, create_graph=True)[0]
    # -1 for every coordinate except the first, which gets (dim_set-1)*scale.
    coefficient = -torch.ones(1, dim_set, device=ux.device, dtype=ux.dtype)
    coefficient[:, 0] = (dim_set - 1) * scale
    return torch.sum(ux * coefficient, 1, keepdim=True)
18 |
def RHS_pde(x):
    """Right-hand side of the PDE: a zero column with one row per point in ``x``.

    The result is allocated on the device of ``x`` rather than unconditionally
    on CUDA, so CPU inputs are supported as well.
    """
    bs = x.size(0)
    return torch.zeros(bs, 1, device=x.device)
22 |
def true_solution(x):
    """Reference solution ``sin(x_0 + scale * sum_{i>0} x_i)`` as a column tensor.

    The coefficient row is built on the device of ``x`` instead of being
    forced onto CUDA, so CPU inputs are supported as well.
    """
    dim = x.size(1)
    coefficient = torch.ones(1, dim, device=x.device) * scale
    coefficient[:, 0] = 1  # the first coordinate enters with weight 1
    return torch.sin(torch.sum(x * coefficient, dim=1, keepdim=True))
28 |
29 |
# Operator tables indexed by the controller's action ids. The constant
# entries keep an ``x`` term (``0*x**2``) so the result still has x's shape
# and stays connected to x.
def _const_zero(x):
    return 0*x**2


def _const_one(x):
    return 1+0*x**2


def _identity(x):
    return x+0*x**2


def _square(x):
    return x**2


def _cube(x):
    return x**3


def _fourth_power(x):
    return x**4


unary_functions = [
    _const_zero,
    _const_one,
    _identity,
    _square,
    _cube,
    _fourth_power,
    torch.exp,
    torch.sin,
    torch.cos,
]


def _add(x, y):
    return x+y


def _multiply(x, y):
    return x*y


def _subtract(x, y):
    return x-y


binary_functions = [_add, _multiply, _subtract]
43 |
44 |
# Format templates, in the same order as the operator tables.
# Non-leaf unary templates take three fields.
unary_functions_str = [
    '({}*(0)+{})',
    '({}*(1)+{})',
    '({}*{}+{})',
    '({}*({})**2+{})',
    '({}*({})**3+{})',
    '({}*({})**4+{})',
    '({}*exp({})+{})',
    '({}*sin({})+{})',
    '({}*cos({})+{})',
]

# Leaf templates take at most one field.
unary_functions_str_leaf = [
    '(0)',
    '(1)',
    '({})',
    '(({})**2)',
    '(({})**3)',
    '(({})**4)',
    '(exp({}))',
    '(sin({}))',
    '(cos({}))',
]

# Binary templates take two fields.
binary_functions_str = [
    '(({})+({}))',
    '(({})*({}))',
    '(({})-({}))',
]
77 |
if __name__ == '__main__':
    # Sanity check: the reference solution should give (near) zero PDE and
    # boundary losses on random 1-D points.
    batch_size = 200
    left = -1
    right = 1
    points = (torch.rand(batch_size, 1)) * (right - left) + left
    x = torch.autograd.Variable(points.cuda(), requires_grad=True)
    function = true_solution

    # PDE loss. LHS_pde takes the number of coordinates as its third
    # argument; the call used to omit it and raised TypeError.
    LHS = LHS_pde(function(x), x, x.size(1))
    RHS = RHS_pde(x)
    pde_loss = torch.nn.functional.mse_loss(LHS, RHS)

    # Boundary loss at the two interval end points.
    bc_points = torch.FloatTensor([[left], [right]]).cuda()
    bc_value = true_solution(bc_points)
    bd_loss = torch.nn.functional.mse_loss(function(bc_points), bc_value)

    print('pde loss: {} -- boundary loss: {}'.format(pde_loss.item(), bd_loss.item()))
--------------------------------------------------------------------------------
/fex/Conservationlaw/scripts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
# Launch 5 repetitions of the search for each dimension, one `screen`
# session per run, walking through the GPU list in order.
gpus = [1, 7, 8, 9, 2, 3, 4, 5, 6] * 5
idx = 0

_command_template = (
    'screen python controller_conservative.py --epoch 2000 --bs 10 --greedy 0.1'
    ' --gpu {gpu}'
    ' --ckpt t_range_0_1_2ksearch_int20k_bd4kDim{dim}'
    ' --tree depth2_sub --random_step 3 --lr 0.002'
    ' --dim {dim}'
    ' --base 200000 --left -1 --right 1 --domainbs 20000 --bdbs 4000'
)

for i in range(5):
    for dim in [42, 48, 54]:
        gpu = gpus[idx]
        idx += 1
        os.system(_command_template.format(gpu=gpu, dim=dim))
        time.sleep(500)  # pause before starting the next run
--------------------------------------------------------------------------------
/fex/Conservationlaw/tools.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from collections import defaultdict
4 | import collections
5 | from datetime import datetime
6 | import os
7 | import json
8 | import logging
9 |
10 | import numpy as np
11 | import torch
12 | from torch.autograd import Variable
13 |
14 | ##########################
15 | # Torch
16 | ##########################
17 |
def detach(h):
    """Return ``h`` cut off from the autograd graph.

    ``h`` is either a tensor or an (arbitrarily nested) iterable of tensors;
    iterables come back as tuples with the same nesting.

    The tensor test uses ``isinstance``: since the Tensor/Variable merge,
    ``type(h)`` of a tensor is ``torch.Tensor``, so the former
    ``type(h) == Variable`` check failed and tensors were iterated instead.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(detach(v) for v in h)
23 |
def get_variable(inputs, cuda=False, **kwargs):
    """Wrap ``inputs`` in a ``Variable``.

    Plain lists and numpy arrays are first converted with ``torch.Tensor``.
    With ``cuda=True`` the data is moved to the GPU before wrapping. Extra
    keyword arguments are forwarded to ``Variable``.
    """
    if type(inputs) in (list, np.ndarray):
        inputs = torch.Tensor(inputs)
    payload = inputs.cuda() if cuda else inputs
    return Variable(payload, **kwargs)
32 |
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
36 |
def batchify(data, bsz, use_cuda):
    """Arrange a 1-D sequence into ``bsz`` columns.

    Elements that do not fill a whole row of ``bsz`` items are dropped from
    the end; the result has one column per batch stream. With ``use_cuda``
    the result is moved to the GPU.
    """
    # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py
    usable = (data.size(0) // bsz) * bsz
    trimmed = data.narrow(0, 0, usable)
    columns = trimmed.view(bsz, -1).t().contiguous()
    return columns.cuda() if use_cuda else columns
45 |
46 |
47 | ##########################
48 | # ETC
49 | ##########################
50 |
51 | Node = collections.namedtuple('Node', ['id', 'name'])
52 |
53 |
class keydefaultdict(defaultdict):
    """``defaultdict`` whose factory receives the missing key as its argument."""

    def __missing__(self, key):
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        value = factory(key)
        self[key] = value
        return value
61 |
62 |
def to_item(x):
    """Converts x, possibly scalar and possibly tensor, to a Python scalar."""
    if isinstance(x, (float, int)):
        return x

    # The version is read from the first three characters of the version string.
    legacy_torch = float(torch.__version__[0:3]) < 0.4
    if not legacy_torch:
        return x.item()

    assert (x.dim() == 1) and (len(x) == 1)
    return x[0]
73 |
74 |
def get_logger(name=__file__, level=logging.INFO):
    """Return the logger called ``name``, configured with one stream handler.

    On the first call for a name the logger gets a single ``StreamHandler``
    with a timestamped format, any existing handlers are removed and
    propagation is switched off. Later calls only update the level.
    """
    log = logging.getLogger(name)

    already_configured = getattr(log, '_init_done__', None)
    log.setLevel(level)
    if already_configured:
        return log

    log._init_done__ = True
    log.propagate = False

    stream = logging.StreamHandler()
    stream.setFormatter(logging.Formatter("%(asctime)s:%(levelname)s::%(message)s"))
    stream.setLevel(0)

    log.handlers[:] = []
    log.addHandler(stream)

    return log
95 |
96 |
97 | logger = get_logger()
98 |
99 |
def prepare_dirs(args):
    """Sets the directories for the model, and creates those directories.

    Reads ``load_path``, ``log_dir``, ``data_dir`` and ``dataset`` from
    ``args``; writes ``model_name`` and/or ``model_dir`` plus ``data_path``
    back onto it, then creates any of the log, data and model directories
    that do not exist yet.

    Args:
        args: Parsed from `argparse` in the `config` module.
    """
    load_path = args.load_path
    if not load_path:
        args.model_name = "{}_{}".format(args.dataset, get_time())
    elif load_path.startswith(args.log_dir):
        args.model_dir = load_path
    elif load_path.startswith(args.dataset):
        args.model_name = load_path
    else:
        args.model_name = "{}_{}".format(args.dataset, load_path)

    if not hasattr(args, 'model_dir'):
        args.model_dir = os.path.join(args.log_dir, args.model_name)
    args.data_path = os.path.join(args.data_dir, args.dataset)

    for directory in (args.log_dir, args.data_dir, args.model_dir):
        if not os.path.exists(directory):
            makedirs(directory)
124 |
def get_time():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    stamp_format = "%Y-%m-%d_%H-%M-%S"
    return datetime.now().strftime(stamp_format)
127 |
def save_args(args):
    """Write ``args.__dict__`` as sorted, indented JSON to ``<model_dir>/params.json``."""
    param_path = os.path.join(args.model_dir, "params.json")

    for message in ("[*] MODEL dir: %s" % args.model_dir,
                    "[*] PARAM path: %s" % param_path):
        logger.info(message)

    with open(param_path, 'w') as handle:
        json.dump(args.__dict__, handle, indent=4, sort_keys=True)
136 |
def save_dag(args, dag, name):
    """Serialise ``dag`` as JSON into the file ``name`` under ``args.model_dir``.

    The file is opened in a ``with`` block so the handle is closed (and the
    data flushed) as soon as the dump finishes; it used to be left open.
    """
    save_path = os.path.join(args.model_dir, name)
    logger.info("[*] Save dag : {}".format(save_path))
    with open(save_path, 'w') as fp:
        json.dump(dag, fp)
141 |
def load_dag(args):
    """Load a DAG from the JSON file at ``args.dag_path``.

    Keys are converted to ``int`` and every two-element entry to a ``Node``.
    The converted DAG is saved again as ``dag.json`` in the model directory,
    drawn to ``dag.png`` there, and returned.
    """
    load_path = os.path.join(args.dag_path)
    logger.info("[*] Load dag : {}".format(load_path))
    with open(load_path) as f:
        raw = json.load(f)
    dag = {}
    for key, entries in raw.items():
        dag[int(key)] = [Node(el[0], el[1]) for el in entries]
    save_dag(args, dag, "dag.json")
    # NOTE(review): draw_network is neither defined nor imported in this
    # module as shown — confirm where it is supposed to come from.
    draw_network(dag, os.path.join(args.model_dir, "dag.png"))
    return dag
151 |
def makedirs(path):
    """Create ``path`` (including parents) and log it, unless it already exists."""
    if os.path.exists(path):
        return
    logger.info("[*] Make directories : {}".format(path))
    os.makedirs(path)
156 |
def remove_file(path):
    """Delete ``path`` if it exists; the log line is written before the removal."""
    if not os.path.exists(path):
        return
    logger.info("[*] Removed: {}".format(path))
    os.remove(path)
161 |
def backup_file(path):
    """Rename ``path`` to ``<root>.backup_<timestamp><ext>`` and log the new name."""
    stem, suffix = os.path.splitext(path)
    new_path = "{}.backup_{}{}".format(stem, get_time(), suffix)

    os.rename(path, new_path)
    logger.info("[*] {} has backup: {}".format(path, new_path))
168 |
--------------------------------------------------------------------------------
/fex/Conservationlaw/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
8 | # progress bar
9 | import os, sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
11 | # from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/fex/Conservationlaw/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor, one row per sample and one column per class.
        target: 1-D tensor of true class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        List with one tensor per k holding the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and may be non-contiguous,
        # in which case view(-1) raises for k > 1; reshape copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
--------------------------------------------------------------------------------
/fex/Conservationlaw/utils/images/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Conservationlaw/utils/images/cifar.png
--------------------------------------------------------------------------------
/fex/Conservationlaw/utils/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Conservationlaw/utils/images/imagenet.png
--------------------------------------------------------------------------------
/fex/Conservationlaw/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname`` (150 dpi when ``dpi`` is None)."""
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
14 |
def plot_overlap(logger, names=None):
    """Plot the selected series of ``logger`` onto the current axes.

    ``names`` defaults to all of the logger's series. Returns the legend
    labels, formatted as ``<title>(<name>)``, in the order plotted.
    """
    if names is None:
        names = logger.names
    series = logger.numbers
    for name in names:
        steps = np.arange(len(series[name]))
        plt.plot(steps, np.asarray(series[name]))
    return [logger.title + '(' + name + ')' for name in names]
22 |
class Logger(object):
    '''Tab-separated training log with a simple plot helper.

    The first line of the file holds the column names, every later line one
    record. With ``resume=True`` an existing file is read back (values are
    kept as the strings found in the file) and then reopened for appending;
    otherwise the file is truncated and opened for writing.
    '''

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = title if title is not None else ''
        if fpath is None:
            return
        if not resume:
            self.file = open(fpath, 'w')
            return

        # Resume: rebuild names and columns from the existing file.
        self.file = open(fpath, 'r')
        header = self.file.readline()
        self.names = header.rstrip().split('\t')
        self.numbers = {}
        for column in self.names:
            self.numbers[column] = []

        for record in self.file:
            cells = record.rstrip().split('\t')
            for position, cell in enumerate(cells):
                self.numbers[self.names[position]].append(cell)
        self.file.close()
        self.file = open(fpath, 'a')

    def set_names(self, names):
        '''Start a fresh set of columns and write the header line.'''
        # Columns are re-initialised even when resuming.
        self.numbers = {}
        self.names = names
        for column in self.names:
            self.file.write(column)
            self.file.write('\t')
            self.numbers[column] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        '''Write one record; ints and strings verbatim, anything else with 8 decimals.'''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for position, value in enumerate(numbers):
            if isinstance(value, (int, str)):
                text = "{}".format(value)
            else:
                text = "{0:.8f}".format(value)
            self.file.write(text)
            self.file.write('\t')
            self.numbers[self.names[position]].append(value)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        '''Plot the selected columns (all by default) with a legend and grid.'''
        if names is None:
            names = self.names
        columns = self.numbers
        for column in names:
            steps = np.arange(len(columns[column]))
            plt.plot(steps, np.asarray(columns[column]))
        plt.legend([self.title + '(' + column + ')' for column in names])
        plt.grid(True)

    def close(self):
        '''Close the underlying file, if one was opened.'''
        if self.file is None:
            return
        self.file.close()
86 |
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''paths is a dictionary with {name:filepath} pair'''
        self.loggers = []
        for title, path in paths.items():
            self.loggers.append(Logger(path, title=title, resume=True))

    def plot(self, names=None):
        '''Overlay the selected series of every loaded log in one subplot.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
104 |
if __name__ == '__main__':
    # Example: single logger
    #   logger = Logger('test.txt')
    #   logger.set_names(['Train loss', 'Valid loss', 'Test loss'])
    #   logger.append([train_loss, valid_loss, test_loss])   # once per step
    #   logger.plot()

    # Example: logger monitor over three existing log files
    checkpoint_root = '/home/wyang/code/pytorch-classification/checkpoint/cifar10'
    paths = {}
    for run in ('resadvnet20', 'resadvnet32', 'resadvnet44'):
        paths[run] = checkpoint_root + '/' + run + '/log.txt'

    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')
--------------------------------------------------------------------------------
/fex/Conservationlaw/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std of a 3-channel image dataset.

    Each sample is loaded on its own (batch size 1); the per-sample channel
    mean and std are accumulated and divided by ``len(dataset)``.

    Args:
        dataset: map-style dataset yielding ``(image, target)`` pairs.
        num_workers: loader worker processes (default 2, the former fixed value).

    Returns:
        Tuple ``(mean, std)`` of 3-element tensors.
    '''
    # This module only binds `nn` and `init` at file level, so the bare name
    # `torch` used below was undefined; import it locally.
    import torch
    import torch.utils.data

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
33 |
def init_params(net):
    '''Initialise the Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    - Conv2d: Kaiming-normal weights (fan_out), zero bias.
    - BatchNorm2d: weight 1, bias 0 (skipped when the layer has no affine
      parameters).
    - Linear: normal weights with std 1e-3, zero bias.

    Bias presence is tested with ``is not None``: truth-testing a bias tensor
    with more than one element raises. The in-place ``*_`` initialisers
    replace the deprecated names without the underscore.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
48 |
def mkdir_p(path):
    '''Create ``path`` with parents; an already existing directory is not an error.'''
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        already_a_directory = exc.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise
58 |
class AverageMeter(object):
    """Tracks the latest value and the running weighted average.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, sum, count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/fex/Conservationlaw/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
    """Undo per-channel normalisation and return the image as an HWC numpy array.

    The first three channels of ``img`` are rescaled in place
    (``channel * std + mean``) before the axes are reordered from
    (C, H, W) to (H, W, C).
    """
    for channel in range(3):
        img[channel] = img[channel] * std[channel] + mean[channel]  # unnormalize
    return img.numpy().transpose((1, 2, 0))
17 |
def gauss(x,a,b,c):
    """Evaluate ``a * exp(-(x - b)**2 / (2 * c**2))`` element-wise on ``x``."""
    squared_offset = torch.pow(torch.add(x, -b), 2)
    exponent = -squared_offset.div(2*c*c)
    return torch.exp(exponent).mul(a)
20 |
def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image.

    Accepts a 2-D, 3-D or 4-D tensor; a 2-D input is first given a leading
    axis. Each of the three output channels is a Gaussian bump of ``x``.

    Raises:
        ValueError: if ``x`` has any other number of dimensions (this used to
            surface as an unbound-variable error at the return statement).
    '''
    if x.dim() == 2:
        # torch.unsqueeze has no `out=` argument; rebind the result instead.
        x = torch.unsqueeze(x, 0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[1] = gauss(x,1,.5,.3)
        cl[2] = gauss(x,1,.2,.3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        # NOTE(review): gauss(x) keeps x's 4-D shape while each target slice is
        # 3-D; this presumably only works for a batch of one — confirm.
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[:,1,:,:] = gauss(x,1,.5,.3)
        cl[:,2,:,:] = gauss(x,1,.2,.3)
    else:
        raise ValueError('colorize expects a 2-D, 3-D or 4-D tensor, got {} dimensions'.format(x.dim()))
    return cl
37 |
def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Tile a batch of images into one grid, unnormalise it and display it."""
    grid = torchvision.utils.make_grid(images)
    plt.imshow(make_image(grid, Mean, Std))
    plt.show()
42 |
43 |
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch (top) and the same batch blended with ``mask`` (bottom).

    The mask is upsampled by ``images.size(2) / mask.size(2)`` and blended as
    ``0.3 * image + 0.7 * mask``.

    The original called an ``upsampling`` helper that is not defined or
    imported in this module; ``nn.functional.interpolate`` with its default
    mode is used instead.
    NOTE(review): confirm that the default interpolation mode is the intended one.
    """
    im_size = images.size(2)

    # Unnormalised copy of the batch, kept for blending with the mask.
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

    # The blend is built from already unnormalised data, so make_image is
    # called with its default mean/std here.
    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')
72 |
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch followed by one blended row per mask in ``masklist``.

    Each mask is moved to the CPU, upsampled by
    ``images.size(2) / mask.size(2)`` and blended as
    ``0.3 * image + 0.7 * mask``.

    The original called an ``upsampling`` helper that is not defined or
    imported in this module; ``nn.functional.interpolate`` with its default
    mode is used instead.
    NOTE(review): confirm that the default interpolation mode is the intended one.
    """
    im_size = images.size(2)
    rows = 1 + len(masklist)

    # Unnormalised copy of the batch, kept for blending with the masks.
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(rows, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i in range(len(masklist)):
        mask = masklist[i].data.cpu()
        mask_size = mask.size(2)
        mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

        # The blend is built from already unnormalised data, so make_image is
        # called with its default mean/std here.
        mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
        plt.subplot(rows, 1, i+2)
        plt.imshow(mask)
        plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------
/fex/Poisson/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/fex/Poisson/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/fex/Poisson/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/fex/Poisson/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/fex/Poisson/.idea/poissoneqn.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/fex/Poisson/computational_tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import function as func
3 | unary = func.unary_functions
4 | binary = func.binary_functions
5 | unary_functions_str = func.unary_functions_str
6 | binary_functions_str = func.binary_functions_str
7 |
class BinaryTree(object):
    """Node of a binary expression tree.

    Attributes:
        key: payload of the node (an operator callable once the tree is filled).
        is_unary: True when the node holds a unary operator, False for binary.
        leftChild / rightChild: child nodes, or None when absent.
    """

    def __init__(self, item, is_unary=True):
        self.key = item
        self.is_unary = is_unary
        self.leftChild = None
        self.rightChild = None

    def insertLeft(self, item, is_unary=True):
        """Insert a new left child; an existing left child is pushed one level down."""
        # The new node always honours ``is_unary`` (previously the flag was
        # silently dropped when a left child already existed).
        node = BinaryTree(item, is_unary)
        if self.leftChild is not None:
            node.leftChild = self.leftChild
        self.leftChild = node

    def insertRight(self, item, is_unary=True):
        """Insert a new right child; an existing right child is pushed one level down."""
        node = BinaryTree(item, is_unary)
        if self.rightChild is not None:
            node.rightChild = self.rightChild
        self.rightChild = node
28 |
def compute_by_tree(tree, x):
    """Evaluate the expression tree rooted at ``tree`` on input ``x``.

    A node without children applies its operator to ``x`` directly, a node
    with a single child applies its operator to that child's value, and a
    node with two children applies its operator to both children's values.
    """
    left, right = tree.leftChild, tree.rightChild
    if left is None and right is None:
        return tree.key(x)
    if left is None:
        return tree.key(compute_by_tree(right, x))
    if right is None:
        return tree.key(compute_by_tree(left, x))
    return tree.key(compute_by_tree(left, x), compute_by_tree(right, x))
39 |
def inorder(tree, actions):
    """Fill ``tree`` with operators chosen by ``actions`` in in-order sequence.

    The module-level ``count`` is the index of the next action to consume;
    callers reset it before starting a traversal.
    """
    global count
    if not tree:
        return
    inorder(tree.leftChild, actions)
    chosen = actions[count].item()
    # Unary nodes index the unary operator table, binary nodes the binary one.
    operators = unary if tree.is_unary else binary
    tree.key = operators[chosen]
    count = count + 1
    inorder(tree.rightChild, actions)
55 |
56 | # def inorder(tree):
57 | # if tree:
58 | # inorder(tree.leftChild)
59 | # print(tree.key)
60 | # inorder(tree.rightChild)
61 |
# Running in-order position shared by the traversal helpers in this module.
count = 0


def inorder_w_idx(tree):
    """Print every node key together with its in-order index."""
    global count
    if not tree:
        return
    inorder_w_idx(tree.leftChild)
    print(tree.key, count)
    count = count + 1
    inorder_w_idx(tree.rightChild)
70 |
def basic_tree():
    """Build the default skeleton: a binary root over two binary nodes,
    each of which carries two unary leaves. All keys start as ''."""
    root = BinaryTree('', False)
    root.insertLeft('', False)
    root.insertRight('', False)
    for operator_node in (root.leftChild, root.rightChild):
        operator_node.insertLeft('', True)
        operator_node.insertRight('', True)
    return root
80 |
def get_function(actions):
    """Return a fresh ``basic_tree()`` whose nodes are filled from ``actions``."""
    global count
    count = 0  # start consuming actions from the first entry
    filled_tree = basic_tree()
    inorder(filled_tree, actions)
    count = 0  # reset to zero so the next traversal starts clean
    return filled_tree
88 |
def inorder_test(tree, actions):
    """Debug variant of ``inorder`` that prints each action as it is applied.

    Fills ``tree`` in in-order sequence from ``actions`` using the
    module-level ``count`` as the running action index.
    """
    global count
    if tree:
        # Recurse into this function (not ``inorder``) so that the action of
        # every node is printed, not only the root's.
        inorder_test(tree.leftChild, actions)
        action = actions[count].item()
        print(action)
        if tree.is_unary:
            tree.key = unary[action]
        else:
            tree.key = binary[action]
        count = count + 1
        inorder_test(tree.rightChild, actions)
105 |
if __name__ =='__main__':
    # Manual smoke test: fill a skeleton tree from a fixed action sequence
    # and print the actions as they are applied.
    import torch

    bs_action = [torch.LongTensor([index]) for index in (10, 2, 9, 1, 0, 2, 6)]

    function = lambda x: compute_by_tree(get_function(bs_action), x)
    x = torch.FloatTensor([[-1], [1]])

    count = 0
    tr = basic_tree()
    inorder_test(tr, bs_action)
    count = 0
125 |
126 |
--------------------------------------------------------------------------------
/fex/Poisson/controller_poisson.py:
--------------------------------------------------------------------------------
1 | """A module with NAS controller-related code."""
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import tools
6 | import scipy
7 | from utils import Logger, mkdir_p
8 | import os
9 | import torch.nn as nn
10 | from computational_tree import BinaryTree
11 | import function as func
12 | import argparse
13 | import random
14 | import math
15 |
parser = argparse.ArgumentParser(description='NAS')

# (flag, default, type) of every command-line option, in registration order.
_CLI_OPTIONS = (
    ('--left', -1, float),
    ('--right', 1, float),
    ('--epoch', 2000, int),
    ('--bs', 1, int),
    ('--greedy', 0, float),
    ('--random_step', 0, float),
    ('--ckpt', '', str),
    ('--gpu', 0, int),
    ('--dim', 20, int),
    ('--tree', 'depth2', str),
    ('--lr', 1e-2, float),
    ('--percentile', 0.5, float),
    ('--base', 100, int),
    ('--domainbs', 1000, int),
    ('--bdbs', 1000, int),
)
for _flag, _default, _type in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type)
args = parser.parse_args()

# Restrict CUDA to the GPU index given on the command line.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

# Short aliases for the operator tables and their printable templates.
unary = func.unary_functions
binary = func.binary_functions
unary_functions_str = func.unary_functions_str
unary_functions_str_leaf = func.unary_functions_str_leaf
binary_functions_str = func.binary_functions_str

# Domain bounds and dimension taken from the command line.
left = args.left
right = args.right
dim = args.dim
46 |
def get_boundary(num_pts, dim):
    """Sample ``num_pts`` points on the boundary of the box [args.left, args.right]^dim.

    Points are drawn uniformly inside the box; then, for every point, one
    randomly chosen coordinate is pinned to a face: ``args.left`` for the
    first ``num_pts // 2`` points and ``args.right`` for the remaining ones.

    Returns:
        A CUDA tensor built from ``torch.rand(num_pts, dim)``.
    """
    bd_pts = (torch.rand(num_pts, dim).cuda()) * (args.right - args.left) + args.left

    num_half = num_pts // 2
    rows = torch.arange(0, num_half)
    cols = torch.randint(0, dim, (num_half,))
    bd_pts[rows, cols] = args.left

    # The second group has num_pts - num_half rows, which is one more than
    # num_half when num_pts is odd; draw exactly that many column indices so
    # the row/column index tensors always have matching lengths.
    rows = torch.arange(num_half, num_pts)
    cols = torch.randint(0, dim, (num_pts - num_half,))
    bd_pts[rows, cols] = args.right

    return bd_pts
61 |
62 |
class candidate(object):
    """Plain record bundling an action sequence, its expression and its error."""

    def __init__(self, action, expression, error):
        self.action, self.expression, self.error = action, expression, error
68 |
class SaveBuffer(object):
    """Bounded pool of candidates kept sorted by ascending error."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.candidates = []

    def num_candidates(self):
        """Return how many candidates are currently stored."""
        return len(self.candidates)

    def add_new(self, candidate):
        """Offer ``candidate`` to the pool.

        A candidate whose action already exists is only accepted when its
        error is smaller than the stored one, in which case it replaces the
        stored entry. After insertion the pool is re-sorted by error and the
        worst entry is dropped if the pool exceeds ``max_size``.
        """
        accept = True
        replace_at = None
        for position, existing in enumerate(self.candidates):
            if not (candidate.action == existing.action):
                continue
            if candidate.error < existing.error:
                # Same action with a better error: replace the stored entry.
                accept = True
                replace_at = position
                break
            # Same action but not better: do not store a duplicate.
            accept = False

        if accept:
            if replace_at is not None:
                print(replace_at)
                self.candidates.pop(replace_at)
            self.candidates.append(candidate)
            # Smallest error first.
            self.candidates = sorted(self.candidates, key=lambda entry: entry.error)

        if len(self.candidates) > self.max_size:
            self.candidates.pop(-1)  # drop the entry with the largest error
97 |
# Select the skeleton-tree builder named by --tree. In every builder a node
# created with is_unary=False is a binary operator slot and a node created
# with is_unary=True is a unary operator slot; all keys start as ''.
if args.tree == 'depth2':
    def basic_tree():
        """Binary root; each side is unary -> binary -> two unary leaves."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.insertRight('', True)
        for branch in (root.leftChild, root.rightChild):
            branch.insertLeft('', False)
            branch.leftChild.insertLeft('', True)
            branch.leftChild.insertRight('', True)
        return root

elif args.tree == 'depth1':
    def basic_tree():
        """Binary root over two unary leaves."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.insertRight('', True)
        return root

elif args.tree == 'depth2_rml':
    def basic_tree():
        """Binary root; each side is a chain of two unary nodes."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.insertRight('', True)
        for branch in (root.leftChild, root.rightChild):
            branch.insertLeft('', True)
        return root

elif args.tree == 'depth2_rmu':
    print('**************************rmu**************************')
    def basic_tree():
        """Like 'depth2' but the right side is a bare binary node with two leaves."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.leftChild.insertLeft('', False)
        root.insertRight('', False)
        for pair_parent in (root.leftChild.leftChild, root.rightChild):
            pair_parent.insertLeft('', True)
            pair_parent.insertRight('', True)
        return root

elif args.tree == 'depth2_rmu2':
    print('**************************rmu2**************************')
    def basic_tree():
        """Like 'depth2' but the right side is a single unary leaf."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.leftChild.insertLeft('', False)
        pair_parent = root.leftChild.leftChild
        pair_parent.insertLeft('', True)
        pair_parent.insertRight('', True)
        root.insertRight('', True)
        return root

elif args.tree == 'depth3':
    def basic_tree():
        """Binary root; each side is unary -> binary whose two children are
        again unary -> binary -> two unary leaves."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.insertRight('', True)
        for branch in (root.leftChild, root.rightChild):
            branch.insertLeft('', False)
            inner = branch.leftChild
            inner.insertLeft('', True)
            inner.insertRight('', True)
            for sub_branch in (inner.leftChild, inner.rightChild):
                sub_branch.insertLeft('', False)
                sub_branch.leftChild.insertLeft('', True)
                sub_branch.leftChild.insertRight('', True)
        return root

elif args.tree == 'depth2_sub':
    print('**************************sub**************************')
    def basic_tree():
        """Unary root over one binary node with two unary leaves."""
        root = BinaryTree('', True)
        root.insertLeft('', False)
        root.leftChild.insertLeft('', True)
        root.leftChild.insertRight('', True)
        return root
208 |
# In-order list of is_unary flags of the selected skeleton tree.
structure = []


def inorder_structure(tree):
    """Append the ``is_unary`` flag of every node to ``structure`` in in-order."""
    global structure
    if not tree:
        return
    inorder_structure(tree.leftChild)
    structure.append(tree.is_unary)
    inorder_structure(tree.rightChild)


inorder_structure(basic_tree())
print('tree structure', structure)

# Number of operator choices available at each in-order position.
structure_choice = [len(unary) if flag == True else len(binary) for flag in structure]
print('tree structure choices', structure_choice)
227 |
# Builders for the three standard depths. For any other --tree value this
# chain defines nothing, so a previously defined ``basic_tree`` stays in effect.
if args.tree == 'depth1':
    def basic_tree():
        """Binary root over two unary leaves."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.insertRight('', True)
        return root

elif args.tree == 'depth2':
    def basic_tree():
        """Binary root; each side is unary -> binary -> two unary leaves."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.insertRight('', True)
        for branch in (root.leftChild, root.rightChild):
            branch.insertLeft('', False)
            branch.leftChild.insertLeft('', True)
            branch.leftChild.insertRight('', True)
        return root

elif args.tree == 'depth3':
    def basic_tree():
        """Binary root; each side is unary -> binary whose two children are
        again unary -> binary -> two unary leaves."""
        root = BinaryTree('', False)
        root.insertLeft('', True)
        root.insertRight('', True)
        for branch in (root.leftChild, root.rightChild):
            branch.insertLeft('', False)
            inner = branch.leftChild
            inner.insertLeft('', True)
            inner.insertRight('', True)
            for sub_branch in (inner.leftChild, inner.rightChild):
                sub_branch.insertLeft('', False)
                sub_branch.leftChild.insertLeft('', True)
                sub_branch.leftChild.insertRight('', True)
        return root
280 |
# Description of the selected skeleton tree, filled by inorder_structure():
structure = []      # is_unary flag of every node, in in-order
leaves_index = []   # in-order positions of the nodes without children
leaves = 0          # number of nodes without children
count = 0           # running in-order position


def inorder_structure(tree):
    """Record flags, leaf positions and leaf count of ``tree`` in in-order."""
    global structure, leaves, count, leaves_index
    if not tree:
        return
    inorder_structure(tree.leftChild)
    structure.append(tree.is_unary)
    if tree.leftChild is None and tree.rightChild is None:
        leaves = leaves + 1
        leaves_index.append(count)
    count = count + 1
    inorder_structure(tree.rightChild)


inorder_structure(basic_tree())

print('leaves index:', leaves_index)

print('tree structure:', structure, 'leaves num:', leaves)

# Number of operator choices available at each in-order position.
structure_choice = [len(unary) if flag == True else len(binary) for flag in structure]
print('tree structure choices', structure_choice)
311 |
def reset_params(tree_params):
    """Re-initialise every parameter in place from a normal(0, 0.1) distribution."""
    for parameter in tree_params:
        parameter.data.normal_(0.0, 0.1)
316 |
def inorder(tree, actions):
    """Fill ``tree`` with plain operators chosen by ``actions`` in in-order.

    The module-level ``count`` is the index of the next action to consume.
    """
    global count
    if not tree:
        return
    inorder(tree.leftChild, actions)
    chosen = actions[count].item()
    # Unary nodes index the unary operator table, binary nodes the binary one.
    operators = unary if tree.is_unary else binary
    tree.key = operators[chosen]
    count = count + 1
    inorder(tree.rightChild, actions)
332 |
def inorder_visualize(tree, actions, trainable_tree):
    """Render the expression selected by ``actions`` as a string.

    Walks ``tree`` in in-order, using the module-level ``count`` as the
    position of the current node and ``leaves_cnt`` as the index of the next
    leaf linear layer in ``trainable_tree.linear``. Returns ``None`` for an
    empty subtree.
    """
    global count, leaves_cnt
    if not tree:
        return None

    leftfun = inorder_visualize(tree.leftChild, actions, trainable_tree)

    action = actions[count].item()
    if not tree.is_unary:
        midfun = binary_functions_str[action]
    elif count in leaves_index:
        midfun = unary_functions_str_leaf[action]
    else:
        # Non-leaf unary operators carry a learnable scale ``a`` and shift ``b``.
        midfun = unary_functions_str[action]
        operator = trainable_tree.learnable_operator_set[count][action]
        a = operator.a.item()
        b = operator.b.item()
    count = count + 1

    rightfun = inorder_visualize(tree.rightChild, actions, trainable_tree)

    if leftfun is None and rightfun is None:
        # Leaf: the operator is applied to each input variable, then combined
        # by the leaf's linear layer (weights w, plus bias).
        layer = trainable_tree.linear[leaves_cnt]
        w = [layer.weight[0][i].item() for i in range(dim)]
        bias = layer.bias[0].item()
        leaves_cnt = leaves_cnt + 1
        terms = ''.join(
            ('{:.4f}*{}' + '+').format(w[i], midfun.format('x' + str(i)))
            for i in range(0, dim)
        )
        return '(' + terms + '{:.4f}'.format(bias) + ')'

    if leftfun is not None and rightfun is None:
        # Constant templates ('(0)' / '(1)') take no operand.
        if '(0)' in midfun or '(1)' in midfun:
            return midfun.format('{:.4f}'.format(a), '{:.4f}'.format(b))
        return midfun.format('{:.4f}'.format(a), leftfun, '{:.4f}'.format(b))

    if tree.leftChild is None and tree.rightChild is not None:
        if '(0)' in midfun or '(1)' in midfun:
            return midfun.format('{:.4f}'.format(a), '{:.4f}'.format(b))
        return midfun.format('{:.4f}'.format(a), rightfun, '{:.4f}'.format(b))

    return midfun.format(leftfun, rightfun)
381 |
def get_function(actions):
    """Return a fresh skeleton tree filled with the plain operators in ``actions``."""
    global count
    count = 0  # start consuming actions from the first entry
    filled_tree = basic_tree()
    inorder(filled_tree, actions)
    count = 0  # leave the shared counter clean for the next traversal
    return filled_tree
389 |
def inorder_params(tree, actions, unary_choices):
    """Fill ``tree`` in in-order with operator modules from ``unary_choices``.

    ``unary_choices[count]`` is indexed with the action for unary nodes and
    with ``len(unary) + action`` for binary nodes. The module-level ``count``
    is the in-order position of the current node.
    """
    global count
    if not tree:
        return
    inorder_params(tree.leftChild, actions, unary_choices)
    chosen = actions[count].item()
    if tree.is_unary:
        slot = chosen
    else:
        slot = len(unary) + chosen
    tree.key = unary_choices[count][slot]
    count = count + 1
    inorder_params(tree.rightChild, actions, unary_choices)
403 |
def get_function_trainable_params(actions, unary_choices):
    """Return a fresh skeleton tree filled with modules from ``unary_choices``."""
    global count
    count = 0  # start consuming actions from the first entry
    filled_tree = basic_tree()
    inorder_params(filled_tree, actions, unary_choices)
    count = 0  # leave the shared counter clean for the next traversal
    return filled_tree
411 |
class unary_operation(nn.Module):
    """Unary operator module.

    With ``is_leave`` true the module just applies ``operator``. Otherwise it
    owns two CUDA parameters, ``a`` (initialised to 1) and ``b`` (initialised
    to 0), and computes ``a * operator(x) + b``.
    """

    def __init__(self, operator, is_leave):
        super(unary_operation, self).__init__()
        self.unary = operator
        self.is_leave = is_leave
        if is_leave:
            return
        self.a = nn.Parameter(torch.Tensor(1).cuda())
        self.a.data.fill_(1)
        self.b = nn.Parameter(torch.Tensor(1).cuda())
        self.b.data.fill_(0)

    def forward(self, x):
        if not self.is_leave:
            return self.a*self.unary(x)+self.b
        return self.unary(x)
428 |
class binary_operation(nn.Module):
    """Module wrapper that applies a two-argument operator unchanged."""

    def __init__(self, operator):
        super(binary_operation, self).__init__()
        self.binary = operator

    def forward(self, x, y):
        result = self.binary(x, y)
        return result
435 |
# Index of the next leaf transformation to use while evaluating a tree.
leaves_cnt = 0


def compute_by_tree(tree, linear, x):
    """Evaluate the tree on ``x``.

    At a node without children the operator is applied to ``x`` and the
    result is passed through ``linear[leaves_cnt]``; the module-level
    ``leaves_cnt`` is then advanced. Inner nodes apply their operator to the
    values of their one or two children.
    """
    global leaves_cnt
    left, right = tree.leftChild, tree.rightChild
    if left is None and right is None:  # leaf node
        transformation = linear[leaves_cnt]
        leaves_cnt = leaves_cnt + 1
        return transformation(tree.key(x))
    if left is None:
        return tree.key(compute_by_tree(right, linear, x))
    if right is None:
        return tree.key(compute_by_tree(left, linear, x))
    return tree.key(compute_by_tree(left, linear, x), compute_by_tree(right, linear, x))
450 |
class learnable_compuatation_tree(nn.Module):
    """Holds every candidate operator module and the leaf linear layers.

    ``learnable_operator_set[i]`` lists, for in-order position ``i``, one
    ``unary_operation`` per unary operator followed by one
    ``binary_operation`` per binary operator. ``linear`` holds one
    ``Linear(dim, 1)`` layer per leaf. Both are plain Python containers.
    """

    def __init__(self):
        super(learnable_compuatation_tree, self).__init__()
        self.learnable_operator_set = {}
        for position in range(len(structure)):
            is_leave = position in leaves_index
            operators = [unary_operation(op, is_leave) for op in unary]
            operators.extend(binary_operation(op) for op in binary)
            self.learnable_operator_set[position] = operators

        self.linear = []
        for _ in range(leaves):
            layer = torch.nn.Linear(dim, 1, bias=True).cuda()  # one scalar output per leaf
            layer.weight.data.normal_(0, 1/math.sqrt(dim))
            layer.bias.data.fill_(0)
            self.linear.append(layer)

    def forward(self, x, bs_action):
        """Evaluate on ``x`` the expression selected by ``bs_action``."""
        global leaves_cnt
        leaves_cnt = 0
        selected_tree = get_function_trainable_params(bs_action, self.learnable_operator_set)
        out = compute_by_tree(selected_tree, self.linear, x)
        leaves_cnt = 0
        return out
477 |
478 | class Controller(torch.nn.Module):
479 | def __init__(self):
480 | torch.nn.Module.__init__(self)
481 |
482 | self.softmax_temperature = 5.0
483 | self.tanh_c = 2.5
484 | self.mode = True
485 |
486 | self.input_size = 20
487 | self.hidden_size = 50
488 | self.output_size = sum(structure_choice)
489 |
490 | self._fc_controller = nn.Sequential(
491 | nn.Linear(self.input_size,self.hidden_size),
492 | nn.ReLU(inplace=True),
493 | nn.Linear(self.hidden_size,self.output_size))
494 |
495 | def forward(self,x):
496 | logits = self._fc_controller(x)
497 |
498 | logits /= self.softmax_temperature
499 |
500 | # exploration # ??
501 | if self.mode == 'train':
502 | logits = (self.tanh_c*F.tanh(logits))
503 |
504 | return logits
505 |
506 | def sample(self, batch_size=1, step=0):
507 | """Samples a set of `args.num_blocks` many computational nodes from the
508 | controller, where each node is made up of an activation function, and
509 | each node except the last also includes a previous node.
510 | """
511 |
512 | # [B, L, H]
513 | inputs = torch.zeros(batch_size, self.input_size).cuda()
514 | log_probs = []
515 | actions = []
516 | total_logits = self.forward(inputs)
517 |
518 | cumsum = np.cumsum([0]+structure_choice)
519 | for idx in range(len(structure_choice)):
520 | logits = total_logits[:, cumsum[idx]:cumsum[idx+1]]
521 |
522 | probs = F.softmax(logits, dim=-1)
523 | log_prob = F.log_softmax(logits, dim=-1)
524 | # print(probs)
525 | if step>=args.random_step:
526 | action = probs.multinomial(num_samples=1).data
527 | else:
528 | action = torch.randint(0, structure_choice[idx], size=(batch_size, 1)).cuda()
529 | # print('old', action)
530 | if args.greedy is not 0:
531 | for k in range(args.bs):
532 | if np.random.rand(1) hyperparams['discount'] > 0:
681 | rewards = discount(rewards, hyperparams['discount'])
682 |
683 | base = args.base
684 | rewards[rewards > base] = base
685 | rewards[rewards != rewards] = 1e10
686 | error = rewards
687 | rewards = 1 / (1 + torch.sqrt(rewards))
688 |
689 | batch_smallest = error.min()
690 | batch_min_idx = torch.argmin(error)
691 | batch_min_action = [v[batch_min_idx] for v in actions]
692 |
693 | batch_best_formula = formulas[batch_min_idx]
694 |
695 | candidates.add_new(candidate(action=batch_min_action, expression=batch_best_formula, error=batch_smallest))
696 |
697 | for candidate_ in candidates.candidates:
698 | print('error:{} action:{} formula:{}'.format(candidate_.error.item(), [v.item() for v in candidate_.action], candidate_.expression))
699 |
700 | # moving average baseline
701 | if baseline is None:
702 | baseline = (rewards).mean()
703 | else:
704 | decay = hyperparams['ema_baseline_decay']
705 | baseline = decay * baseline + (1 - decay) * (rewards).mean()
706 |
707 | argsort = torch.argsort(rewards.squeeze(1), descending=True)
708 | # print(error, argsort)
709 | # print(rewards.size(), rewards.squeeze(1), torch.argsort(rewards.squeeze(1)), rewards[argsort])
710 | # policy loss
711 | num = int(args.bs * args.percentile)
712 | rewards_sort = rewards[argsort]
713 | adv = rewards_sort - rewards_sort[num:num + 1, 0:] # - baseline
714 | # print(error, argsort, rewards_sort, adv)
715 | log_probs_sort = log_probs[argsort]
716 | # print('adv', adv)
717 | loss = -log_probs_sort[:num] * tools.get_variable(adv[:num], True, requires_grad=False)
718 | loss = (loss.sum(1)).mean()
719 |
720 | # update
721 | controller_optim.zero_grad()
722 | loss.backward()
723 |
724 | if hyperparams['controller_grad_clip'] > 0:
725 | torch.nn.utils.clip_grad_norm(model.parameters(),
726 | hyperparams['controller_grad_clip'])
727 | Controller_optim.step()
728 |
729 | min_error = error.min().item()
730 | if smallest_error>min_error:
731 | smallest_error = min_error
732 |
733 | min_idx = torch.argmin(error)
734 | min_action = [v[min_idx] for v in actions]
735 | best_formula = formulas[min_idx]
736 |
737 |
738 | log = 'Step: {step}| Loss: {loss:.4f}| Action: {act} |Baseline: {base:.4f}| ' \
739 | 'Reward {re:.4f} | {error:.8f} {formula}'.format(loss=loss.item(), base=baseline, act=binary_code,
740 | re=(rewards).mean(), step=step, formula=best_formula,
741 | error=smallest_error)
742 | print('********************************************************************************************************')
743 | print(log)
744 | print('********************************************************************************************************')
745 | if (step + 1) % 1 == 0:
746 | logger.append([step + 1, loss.item(), baseline, rewards.mean(), smallest_error, best_formula])
747 |
748 | for candidate_ in candidates.candidates:
749 | print('error:{} action:{} formula:{}'.format(candidate_.error.item(), [v.item() for v in candidate_.action],
750 | candidate_.expression))
751 | logger.append([666, 0, 0, 0, candidate_.error.item(), candidate_.expression])
752 |
753 | finetune = 20000
754 | global count, leaves_cnt
755 | for candidate_ in candidates.candidates:
756 | trainable_tree = learnable_compuatation_tree()
757 | trainable_tree = trainable_tree.cuda()
758 |
759 | params = []
760 | for idx, v in enumerate(trainable_tree.learnable_operator_set):
761 | if idx not in leaves_index:
762 | for modules in trainable_tree.learnable_operator_set[v]:
763 | for param in modules.parameters():
764 | params.append(param)
765 | for module in trainable_tree.linear:
766 | for param in module.parameters():
767 | params.append(param)
768 |
769 | reset_params(params)
770 | tree_optim = torch.optim.Adam(params, lr=1e-2)
771 |
772 | for current_iter in range(finetune):
773 | error = best_error(candidate_.action, trainable_tree)
774 | tree_optim.zero_grad()
775 | error.backward()
776 |
777 | tree_optim.step()
778 |
779 | count = 0
780 | leaves_cnt = 0
781 | formula = inorder_visualize(basic_tree(), candidate_.action, trainable_tree)
782 | leaves_cnt = 0
783 | count = 0
784 | suffix = 'Finetune-- Iter {current_iter} Error {error:.5f} Formula {formula}'.format(current_iter=current_iter, error=error, formula=formula)
785 | if (current_iter + 1) % 100 == 0:
786 | logger.append([current_iter, 0, 0, 0, error.item(), formula])
787 |
788 | cosine_lr(tree_optim, 1e-2, current_iter, finetune)
789 | print(suffix)
790 |
791 | numerators = []
792 | denominators = []
793 |
794 | for i in range(1000):
795 | print(i)
796 | x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left
797 | sq_de = torch.mean((func.true_solution(x))**2)
798 | sq_nu = torch.mean((func.true_solution(x)-trainable_tree(x, candidate_.action)) ** 2)
799 | numerators.append(sq_nu.item())
800 | denominators.append(sq_de.item())
801 |
802 | relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators))
803 | print('relative l2 error: ', relative_l2)
804 | logger.append(['relative_l2', 0, 0, 0, relative_l2, 0])
805 |
def cosine_lr(opt, base_lr, e, epochs):
    """Set a cosine-annealed learning rate on every param group of ``opt``.

    The rate is ``0.5 * base_lr * (cos(pi * e / epochs) + 1)``; it is written
    to each group's ``"lr"`` entry and returned.
    """
    lr = 0.5 * base_lr * (math.cos(math.pi * e / epochs) + 1)
    for group in opt.param_groups:
        group["lr"] = lr
    return lr
811 |
if __name__ == '__main__':
    controller = Controller().cuda()

    # Settings consumed by train_controller.
    hyperparams = {
        'controller_max_step': args.epoch,
        'discount': 1.0,
        'ema_baseline_decay': 0.95,
        'controller_lr': args.lr,
        'entropy_mode': 'reward',
        'controller_grad_clip': 0,  # 10
        'checkpoint': args.ckpt,
    }
    if not os.path.isdir(hyperparams['checkpoint']):
        mkdir_p(hyperparams['checkpoint'])
    controller_optim = torch.optim.Adam(controller.parameters(), lr= hyperparams['controller_lr'])

    trainable_tree = learnable_compuatation_tree()
    trainable_tree = trainable_tree.cuda()

    # Collect the parameters of the operators at non-leaf positions, followed
    # by the parameters of every leaf linear layer.
    operator_set = trainable_tree.learnable_operator_set
    params = [
        param
        for idx, key in enumerate(operator_set)
        if idx not in leaves_index
        for modules in operator_set[key]
        for param in modules.parameters()
    ]
    params.extend(
        param
        for module in trainable_tree.linear
        for param in module.parameters()
    )

    train_controller(controller, controller_optim, trainable_tree, params, hyperparams)
841 |
--------------------------------------------------------------------------------
/fex/Poisson/function.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import sin, cos, exp
4 | import math
5 |
def LHS_pde(u, x, dim_set):
    """Return the negative sum of the second derivatives d^2u/dx_i^2, i < dim_set.

    Args:
        u: output tensor computed from ``x`` with autograd enabled.
        x: input tensor with ``requires_grad=True``; one column per variable.
        dim_set: number of leading columns of ``x`` to differentiate twice.

    Returns:
        Tensor of shape ``(x.size(0), 1)`` on the device of ``x``.
    """
    # Helper tensors follow the devices of the inputs instead of being forced
    # onto CUDA, so the function also works for CPU tensors.
    v = torch.ones(u.shape, device=u.device)
    bs = x.size(0)
    ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0]
    uxx = torch.zeros(bs, dim_set, device=x.device)
    for i in range(dim_set):
        ux_tem = ux[:, i:i+1]
        # Gradient of du/dx_i w.r.t. all inputs; only column i is the pure
        # second derivative that is kept.
        uxx_tem = torch.autograd.grad(ux_tem, x, grad_outputs=v, create_graph=True)[0]
        uxx[:, i] = uxx_tem[:, i]
    LHS = -torch.sum(uxx, dim=1, keepdim=True)
    return LHS
18 |
def RHS_pde(x):
    """Return the constant right-hand side ``-dim`` for every row of ``x``.

    Args:
        x: tensor with one row per sample and one column per variable.

    Returns:
        Tensor of shape ``(x.size(0), 1)`` filled with ``-x.size(1)``, placed
        on the device of ``x`` (no longer forced onto CUDA).
    """
    bs = x.size(0)
    dim = x.size(1)
    return -dim*torch.ones(bs, 1, device=x.device)
23 |
def true_solution(x):
    """Return ``0.5 * sum_i x_i**2`` for every row of ``x``, keeping the column axis."""
    squared_norm = torch.sum(x**2, dim=1, keepdim=True)
    return 0.5*squared_norm
26 |
27 |
# Candidate unary operators. The constant ones keep a ``0*x**2`` term so the
# result still depends on ``x`` in the autograd graph.
def _const_zero(x):
    return 0*x**2


def _const_one(x):
    return 1+0*x**2


def _identity(x):
    return x+0*x**2


def _square(x):
    return x**2


def _cube(x):
    return x**3


def _fourth_power(x):
    return x**4


unary_functions = [
    _const_zero,
    _const_one,
    _identity,
    _square,
    _cube,
    _fourth_power,
    torch.exp,
    torch.sin,
    torch.cos,
]


# Candidate binary operators.
def _add(x, y):
    return x+y


def _multiply(x, y):
    return x*y


def _subtract(x, y):
    return x-y


binary_functions = [_add, _multiply, _subtract]
41 |
42 |
# Printable templates, index-aligned with the operator tables.

# Non-leaf unary operators: slots are (scale, operand, shift); the two
# constant templates take only (scale, shift).
unary_functions_str = [
    '({}*(0)+{})',
    '({}*(1)+{})',
    '({}*{}+{})',
    '({}*({})**2+{})',
    '({}*({})**3+{})',
    '({}*({})**4+{})',
    '({}*exp({})+{})',
    '({}*sin({})+{})',
    '({}*cos({})+{})',
]

# Leaf unary operators: a single operand slot (none for the constants).
unary_functions_str_leaf = [
    '(0)',
    '(1)',
    '({})',
    '(({})**2)',
    '(({})**3)',
    '(({})**4)',
    '(exp({}))',
    '(sin({}))',
    '(cos({}))',
]

# Binary operators: slots are (left operand, right operand).
binary_functions_str = [
    '(({})+({}))',
    '(({})*({}))',
    '(({})-({}))',
]
75 |
if __name__ == '__main__':
    # Sanity check: the analytic solution should give (near) zero PDE and
    # boundary losses on a 1-D domain.
    batch_size = 200
    left = -1
    right = 1
    points = (torch.rand(batch_size, 1)) * (right - left) + left
    x = torch.autograd.Variable(points.cuda(), requires_grad=True)
    function = true_solution

    # PDE loss. LHS_pde requires the number of input dimensions as its third
    # argument; it was previously omitted, which raised a TypeError.
    LHS = LHS_pde(function(x), x, x.size(1))
    RHS = RHS_pde(x)
    pde_loss = torch.nn.functional.mse_loss(LHS, RHS)

    # Boundary loss, evaluated at the two end points of the interval.
    bc_points = torch.FloatTensor([[left], [right]]).cuda()
    bc_value = true_solution(bc_points)
    bd_loss = torch.nn.functional.mse_loss(function(bc_points), bc_value)

    print('pde loss: {} -- boundary loss: {}'.format(pde_loss.item(), bd_loss.item()))
--------------------------------------------------------------------------------
/fex/Poisson/scripts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
# Launch six rounds of runs; each round starts one run per problem dimension
# on a randomly chosen GPU index, pausing after every launch.
# NOTE(review): the nesting of time.sleep was reconstructed from a listing
# without indentation — confirm it belongs to the inner loop.
_COMMAND_TEMPLATE = (
    'screen python controller_poisson.py --epoch 1000 --bs 10 --greedy 0.1'
    ' --gpu {gpu} --ckpt Dim{dim} --tree depth2_sub --random_step 3 --lr 0.002'
    ' --dim {dim} --base 200000 --left -1 --right 1 --domainbs 5000 --bdbs 1000'
)

for i in range(6):
    for dim in [10, 20, 30, 40, 50]:
        gpu = random.randint(0,4)
        os.system(_COMMAND_TEMPLATE.format(gpu=str(gpu), dim=str(dim)))
        time.sleep(100)
--------------------------------------------------------------------------------
/fex/Poisson/tools.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from collections import defaultdict
4 | import collections
5 | from datetime import datetime
6 | import os
7 | import json
8 | import logging
9 |
10 | import numpy as np
11 | import torch
12 | from torch.autograd import Variable
13 |
14 | ##########################
15 | # Torch
16 | ##########################
17 |
def detach(h):
    """Return ``h`` cut off from the autograd graph.

    Args:
        h: a tensor, or an arbitrarily nested iterable of tensors.

    Returns:
        A detached tensor for tensor input; otherwise a tuple mirroring the
        nesting of ``h`` with every tensor detached.
    """
    # ``type(h) == Variable`` is never true for tensors on PyTorch >= 0.4, so
    # the old check fell through and tried to iterate the tensor itself.
    # ``Variable`` stays in the tuple for pre-0.4 installations.
    if isinstance(h, (torch.Tensor, Variable)):
        return h.detach()
    return tuple(detach(v) for v in h)
23 |
def get_variable(inputs, cuda=False, **kwargs):
    """Wrap ``inputs`` in a ``Variable``.

    Lists and NumPy arrays are first converted with ``torch.Tensor``. When
    ``cuda`` is true the data is moved to the GPU before wrapping. Extra
    keyword arguments are forwarded to ``Variable``.
    """
    if type(inputs) in (list, np.ndarray):
        inputs = torch.Tensor(inputs)
    payload = inputs.cuda() if cuda else inputs
    return Variable(payload, **kwargs)
32 |
def update_lr(optimizer, lr):
    """Set the ``'lr'`` entry of every parameter group of ``optimizer`` to ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
36 |
def batchify(data, bsz, use_cuda):
    """Arrange ``data`` into ``bsz`` columns, dropping the trailing remainder.

    The first ``(len // bsz) * bsz`` entries along dimension 0 are kept,
    viewed as ``bsz`` rows and transposed, so each column holds a contiguous
    run of the input. The result is moved to the GPU when ``use_cuda`` is true.
    """
    # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py
    usable = (data.size(0) // bsz) * bsz
    columns = data.narrow(0, 0, usable).view(bsz, -1).t().contiguous()
    return columns.cuda() if use_cuda else columns
45 |
46 |
47 | ##########################
48 | # ETC
49 | ##########################
50 |
51 | Node = collections.namedtuple('Node', ['id', 'name'])
52 |
53 |
class keydefaultdict(defaultdict):
    """``defaultdict`` whose factory is called with the missing key."""

    def __missing__(self, key):
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        value = factory(key)
        self[key] = value
        return value
61 |
62 |
def to_item(x):
    """Converts x, possibly scalar and possibly tensor, to a Python scalar."""
    if isinstance(x, (float, int)):
        return x

    # Releases before 0.4 have no ``.item()``; index the single element instead.
    legacy_torch = float(torch.__version__[0:3]) < 0.4
    if not legacy_torch:
        return x.item()

    assert (x.dim() == 1) and (len(x) == 1)
    return x[0]
73 |
74 |
def get_logger(name=__file__, level=logging.INFO):
    """Return the logger ``name``, configuring it on first use.

    On first use the logger stops propagating to its parents and gets a single
    ``StreamHandler`` with a timestamped format. Later calls only update the
    level.
    """
    log = logging.getLogger(name)
    configured = getattr(log, '_init_done__', None)
    log.setLevel(level)
    if configured:
        return log

    log._init_done__ = True
    log.propagate = False

    stream = logging.StreamHandler()
    stream.setFormatter(logging.Formatter("%(asctime)s:%(levelname)s::%(message)s"))
    stream.setLevel(0)

    # Drop any handler attached earlier so messages are emitted exactly once.
    del log.handlers[:]
    log.addHandler(stream)

    return log
95 |
96 |
97 | logger = get_logger()
98 |
99 |
def prepare_dirs(args):
    """Derive the model/data directories on ``args`` and create them.

    Args:
        args: namespace providing ``load_path``, ``log_dir``, ``data_dir`` and
            ``dataset``; ``model_name``/``model_dir``/``data_path`` are set
            on it.
    """
    load_path = args.load_path
    if not load_path:
        args.model_name = "{}_{}".format(args.dataset, get_time())
    elif load_path.startswith(args.log_dir):
        # A path already inside the log directory is used as the model dir.
        args.model_dir = load_path
    elif load_path.startswith(args.dataset):
        args.model_name = load_path
    else:
        args.model_name = "{}_{}".format(args.dataset, load_path)

    if not hasattr(args, 'model_dir'):
        args.model_dir = os.path.join(args.log_dir, args.model_name)
    args.data_path = os.path.join(args.data_dir, args.dataset)

    for path in (args.log_dir, args.data_dir, args.model_dir):
        if not os.path.exists(path):
            makedirs(path)
124 |
def get_time():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    return "{:%Y-%m-%d_%H-%M-%S}".format(datetime.now())
127 |
def save_args(args):
    """Write ``args.__dict__`` as indented, key-sorted JSON to ``<model_dir>/params.json``."""
    param_path = os.path.join(args.model_dir, "params.json")

    for message in ("[*] MODEL dir: %s" % args.model_dir,
                    "[*] PARAM path: %s" % param_path):
        logger.info(message)

    with open(param_path, 'w') as fp:
        json.dump(args.__dict__, fp, indent=4, sort_keys=True)
136 |
def save_dag(args, dag, name):
    """Serialise ``dag`` as JSON to ``<args.model_dir>/<name>``.

    Args:
        args: namespace providing ``model_dir``.
        dag: JSON-serialisable object to store.
        name: file name inside the model directory.
    """
    save_path = os.path.join(args.model_dir, name)
    logger.info("[*] Save dag : {}".format(save_path))
    # The context manager closes (and therefore flushes) the file even when
    # serialisation fails; the bare ``open`` left that to garbage collection.
    with open(save_path, 'w') as fp:
        json.dump(dag, fp)
141 |
def load_dag(args):
    """Load the DAG stored at ``args.dag_path`` and copy it into the model dir.

    The JSON file maps string keys to lists of ``[id, name]`` pairs; keys are
    converted to ``int`` and pairs to ``Node`` tuples. The result is saved
    again as ``dag.json`` in ``args.model_dir``.

    Returns:
        dict mapping ``int`` to a list of ``Node``.
    """
    load_path = os.path.join(args.dag_path)
    logger.info("[*] Load dag : {}".format(load_path))
    with open(load_path) as f:
        dag = json.load(f)
    dag = {int(k): [Node(el[0], el[1]) for el in v] for k, v in dag.items()}
    save_dag(args, dag, "dag.json")

    # ``draw_network`` is neither defined nor imported in this module, so the
    # unconditional call raised NameError after the DAG was already saved.
    # Render only when a drawing helper has been made available.
    draw = globals().get('draw_network')
    if draw is None:
        logger.warning("[!] draw_network is unavailable; skipping dag.png")
    else:
        draw(dag, os.path.join(args.model_dir, "dag.png"))
    return dag
151 |
def makedirs(path):
    """Create ``path`` (including parents) and log it; do nothing if it exists."""
    if os.path.exists(path):
        return
    logger.info("[*] Make directories : {}".format(path))
    os.makedirs(path)
156 |
def remove_file(path):
    """Delete ``path`` if it exists, logging the removal first."""
    if not os.path.exists(path):
        return
    logger.info("[*] Removed: {}".format(path))
    os.remove(path)
161 |
def backup_file(path):
    """Rename ``path`` to ``<root>.backup_<timestamp><ext>`` and log the new name."""
    stem, suffix = os.path.splitext(path)
    new_path = "{}.backup_{}{}".format(stem, get_time(), suffix)

    os.rename(path, new_path)
    logger.info("[*] {} has backup: {}".format(path, new_path))
168 |
--------------------------------------------------------------------------------
/fex/Poisson/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
8 | # progress bar
9 | import os, sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
11 | # from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/fex/Poisson/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: 2-D score tensor; the top-k search runs along dimension 1.
        target: 1-D tensor of true class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        List with one 0-dim tensor per k holding the percentage of rows whose
        target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` follows the transposed layout of ``pred``, so a slice of
        # it may be non-contiguous and ``view(-1)`` fails with a stride error.
        # ``reshape`` copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
--------------------------------------------------------------------------------
/fex/Poisson/utils/images/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Poisson/utils/images/cifar.png
--------------------------------------------------------------------------------
/fex/Poisson/utils/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Poisson/utils/images/imagenet.png
--------------------------------------------------------------------------------
/fex/Poisson/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname``; ``dpi`` defaults to 150."""
    resolution = dpi if dpi is not None else 150
    plt.savefig(fname, dpi=resolution)
14 |
def plot_overlap(logger, names=None):
    """Plot the selected series of ``logger`` onto the current axes.

    Args:
        logger: object exposing ``names``, ``numbers`` and ``title``.
        names: series to draw; all of ``logger.names`` when omitted.

    Returns:
        Legend labels of the form ``<title>(<name>)``, one per series.
    """
    if names is None:
        names = logger.names
    series = logger.numbers
    for name in names:
        steps = np.arange(len(series[name]))
        plt.plot(steps, np.asarray(series[name]))
    return ['{}({})'.format(logger.title, name) for name in names]
22 |
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: one header line with the column
    names followed by one line per ``append`` call. Every field, including
    the last one, is followed by a tab.
    '''

    def __init__(self, fpath, title=None, resume=False):
        """Open the log file.

        Args:
            fpath: path of the log file, or ``None`` to create a logger
                without a backing file.
            title: prefix used in plot legends.
            resume: when true, load names and values from the existing file
                and reopen it for appending; otherwise truncate it.
        """
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is None:
            return

        if resume:
            # Read the existing log inside a context manager so the read
            # handle is released even if parsing fails.
            with open(fpath, 'r') as existing:
                header = existing.readline()
                self.names = header.rstrip().split('\t')
                self.numbers = {name: [] for name in self.names}
                for line in existing:
                    # Values read back from disk stay strings.
                    fields = line.rstrip().split('\t')
                    for name, field in zip(self.names, fields):
                        self.numbers[name].append(field)
            self.file = open(fpath, 'a')
        else:
            self.file = open(fpath, 'w')

    def set_names(self, names):
        """Declare the column names and write the header line."""
        if self.resume:
            # Names and history were loaded from the file in ``__init__``.
            # Falling through would discard them and append a second header
            # line in the middle of the log.
            return
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        """Append one row; ``numbers`` must match ``self.names`` in length."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            if isinstance(num, (int, str)):
                self.file.write("{}".format(num))
            else:
                self.file.write("{0:.8f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        """Plot the selected series (all by default) with a legend and grid."""
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying file, if any."""
        if self.file is not None:
            self.file.close()
86 |
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''paths is a dictionary with {name:filepath} pair'''
        self.loggers = [
            Logger(path, title=title, resume=True)
            for title, path in paths.items()
        ]

    def plot(self, names=None):
        '''Draw the selected series of every loaded log on one subplot.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        # Anchor the legend outside the axes, to the right of the plot.
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
104 |
105 | if __name__ == '__main__':
106 | # # Example
107 | # logger = Logger('test.txt')
108 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
109 |
110 | # length = 100
111 | # t = np.arange(length)
112 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
113 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
114 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
115 |
116 | # for i in range(0, length):
117 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
118 | # logger.plot()
119 |
120 | # Example: logger monitor
121 | paths = {
122 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
123 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
124 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
125 | }
126 |
127 | field = ['Valid Acc.']
128 |
129 | monitor = LoggerMonitor(paths)
130 | monitor.plot(names=field)
131 | savefig('test.eps')
--------------------------------------------------------------------------------
/fex/Poisson/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the mean and std value of dataset.

    Each sample is loaded on its own (batch size 1) and indexed as
    ``inputs[:, i, :, :]`` for the first three channels; the per-sample
    channel mean and std are then averaged over the dataset.

    Args:
        dataset: dataset yielding ``(inputs, targets)`` pairs.
        num_workers: loader worker processes (default 2, the value that used
            to be hard-coded); pass 0 to load in the calling process.

    Returns:
        ``(mean, std)``, each a tensor of three values.
    '''
    # This module only binds ``nn``, ``init`` and ``Variable``; the bare name
    # ``torch`` was never imported, so import it here.
    import torch
    import torch.utils.data

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
33 |
def init_params(net):
    '''Init layer parameters.

    Conv2d weights get Kaiming-normal (fan-out) init, BatchNorm2d weights are
    set to 1, Linear weights are drawn from N(0, 1e-3), and every bias that
    exists is set to 0. Parameters are modified in place.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # ``if m.bias:`` is ambiguous for a tensor with several elements;
            # an absent bias is ``None``.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # Non-affine batch norm has no weight/bias to initialise.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
48 |
def mkdir_p(path):
    '''make dir if not exist'''
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        # Only an already existing directory is tolerated; any other failure
        # is re-raised unchanged.
        already_a_directory = exc.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise
58 |
class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/fex/Poisson/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
    """Un-normalise the first three channels of ``img`` and return it as HWC.

    ``img`` is modified in place (``channel * std + mean``); the returned
    NumPy array has its axes reordered from (C, H, W) to (H, W, C).
    """
    for channel in range(3):
        scale, shift = std[channel], mean[channel]
        img[channel] = img[channel] * scale + shift  # unnormalize
    return np.transpose(img.numpy(), (1, 2, 0))
17 |
def gauss(x,a,b,c):
    """Evaluate ``a * exp(-(x - b)^2 / (2 c^2))`` element-wise on ``x``."""
    squared_offset = torch.pow(torch.add(x, -b), 2)
    exponent = -squared_offset.div(2*c*c)
    return torch.exp(exponent).mul(a)
20 |
def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image

    Accepts a 2-D, 3-D or 4-D tensor. A 2-D input is given a leading
    dimension first. Each of the three output channels is a Gaussian bump
    (or a sum of bumps) of the input intensity.

    Raises:
        ValueError: if ``x`` is not 2-, 3- or 4-dimensional.
    '''
    if x.dim() == 2:
        # ``torch.unsqueeze`` takes no ``out=`` argument; rebind instead.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[1] = gauss(x,1,.5,.3)
        cl[2] = gauss(x,1,.2,.3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        # NOTE(review): unlike the 3-D branch, values above 1 are not clipped
        # here -- confirm whether that is intended.
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[:,1,:,:] = gauss(x,1,.5,.3)
        cl[:,2,:,:] = gauss(x,1,.2,.3)
    else:
        # Previously this fell through to an unbound ``cl``.
        raise ValueError('colorize expects a 2-D, 3-D or 4-D tensor, got {}-D'.format(x.dim()))
    return cl
37 |
def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Tile ``images`` into a grid, un-normalise it with ``Mean``/``Std`` and display it."""
    grid = torchvision.utils.make_grid(images)
    plt.imshow(make_image(grid, Mean, Std))
    plt.show()
42 |
43 |
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Draw the image grid and, below it, the images blended with ``mask``.

    Args:
        images: 4-D image batch; dimension 2 is taken as the image size.
        mask: 4-D mask batch; it is upsampled by ``image size / mask size``
            and broadcast to the image shape before blending.
        Mean, Std: per-channel values used to un-normalise the images.
    """
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    # The original called an ``upsampling`` helper that is not defined or
    # imported anywhere in this module (NameError).
    # NOTE(review): nearest-neighbour interpolation (the library default) is
    # assumed here -- confirm the intended mode.
    mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

    # Blend: 30% image, 70% mask.
    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')
72 |
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Draw the image grid followed by one blended row per mask in ``masklist``.

    Args:
        images: 4-D image batch; dimension 2 is taken as the image size.
        masklist: sequence of objects exposing ``.data`` (4-D mask batches);
            each is moved to the CPU, upsampled by ``image size / mask size``
            and blended with the images.
        Mean, Std: per-channel values used to un-normalise the images.
    """
    im_size = images.size(2)

    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(1+len(masklist), 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i, entry in enumerate(masklist):
        mask = entry.data.cpu()
        mask_size = mask.size(2)
        # The original called an ``upsampling`` helper that is not defined or
        # imported anywhere in this module (NameError).
        # NOTE(review): nearest-neighbour interpolation (the library default)
        # is assumed here -- confirm the intended mode.
        mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

        # Blend: 30% image, 70% mask.
        mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
        plt.subplot(1+len(masklist), 1, i+2)
        plt.imshow(mask)
        plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------
/fex/Schrodinger/computational_tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import function as func
3 | unary = func.unary_functions
4 | binary = func.binary_functions
5 | unary_functions_str = func.unary_functions_str
6 | binary_functions_str = func.binary_functions_str
7 |
class BinaryTree(object):
    """Node of an expression tree.

    Attributes are ``key`` (the node payload), ``is_unary`` (operator arity
    flag) and ``leftChild``/``rightChild`` (sub-trees, or ``None``).
    """

    def __init__(self, item, is_unary=True):
        self.key = item
        self.is_unary = is_unary
        self.leftChild = None
        self.rightChild = None

    def insertLeft(self, item, is_unary=True):
        """Insert a new left child; an existing left child becomes its left child."""
        # The push-down branch used to build ``BinaryTree(item)`` and silently
        # dropped ``is_unary``; the flag is now honoured in both cases.
        node = BinaryTree(item, is_unary)
        node.leftChild = self.leftChild
        self.leftChild = node

    def insertRight(self, item, is_unary=True):
        """Insert a new right child; an existing right child becomes its right child."""
        node = BinaryTree(item, is_unary)
        node.rightChild = self.rightChild
        self.rightChild = node
28 |
def compute_by_tree(tree, x):
    """Evaluate the expression tree rooted at ``tree`` on the input ``x``.

    A node without children applies its ``key`` to ``x`` directly. A node
    with one child applies ``key`` to that child's value. A node with two
    children applies ``key`` to the left and right values, in that order.
    """
    operands = [
        compute_by_tree(child, x)
        for child in (tree.leftChild, tree.rightChild)
        if child is not None
    ]
    if not operands:
        return tree.key(x)
    return tree.key(*operands)
39 |
def inorder(tree, actions):
    """Assign an operator to every node of ``tree`` during an in-order walk.

    The module-level ``count`` is the index of the next entry of ``actions``
    to consume. Each entry's ``.item()`` value selects from ``unary`` or
    ``binary`` according to the node's ``is_unary`` flag.
    """
    global count
    if not tree:
        return
    inorder(tree.leftChild, actions)
    action = actions[count].item()
    operators = unary if tree.is_unary else binary
    tree.key = operators[action]
    count += 1
    inorder(tree.rightChild, actions)
55 |
56 | # def inorder(tree):
57 | # if tree:
58 | # inorder(tree.leftChild)
59 | # print(tree.key)
60 | # inorder(tree.rightChild)
61 |
# Running in-order position shared by the traversal helpers in this module.
count = 0


def inorder_w_idx(tree):
    """Print every node's ``key`` together with its in-order index.

    The index comes from, and advances, the module-level ``count``.
    """
    global count
    if not tree:
        return
    inorder_w_idx(tree.leftChild)
    print(tree.key, count)
    count += 1
    inorder_w_idx(tree.rightChild)
70 |
def basic_tree():
    """Build the fixed depth-2 skeleton: three binary nodes over four unary leaves.

    Every key is the empty string.
    """
    root = BinaryTree('', False)
    root.insertLeft('', False)
    root.insertRight('', False)
    for inner in (root.leftChild, root.rightChild):
        inner.insertLeft('', True)
        inner.insertRight('', True)
    return root
80 |
def get_function(actions):
    """Return a fresh ``basic_tree`` whose operators are chosen by ``actions``.

    The module-level ``count`` cursor is zeroed before and after the walk.
    """
    global count
    count = 0
    skeleton = basic_tree()
    inorder(skeleton, actions)
    count = 0  # reset to zero
    return skeleton
88 |
def inorder_test(tree, actions):
    """Debug variant of ``inorder`` that prints each action index it consumes.

    Operators are assigned exactly as in ``inorder``; the module-level
    ``count`` is the position of the next entry of ``actions``.
    """
    global count
    if tree:
        # Recurse into this function, not ``inorder``: otherwise only the
        # root's action is printed and the children are assigned silently.
        inorder_test(tree.leftChild, actions)
        action = actions[count].item()
        print(action)
        if tree.is_unary:
            tree.key = unary[action]
        else:
            tree.key = binary[action]
        count = count + 1
        inorder_test(tree.rightChild, actions)
105 |
if __name__ =='__main__':
    import torch

    # One action per node of ``basic_tree`` in in-order position: positions
    # 0, 2, 4 and 6 are unary leaves, positions 1, 3 and 5 are binary nodes.
    # The previous demo used indices 10 and 9 for unary leaves, which are out
    # of range for the 9-entry ``unary`` list and raised IndexError.
    bs_action = [torch.LongTensor([idx]) for idx in (8, 2, 7, 1, 0, 2, 6)]

    function = lambda x: compute_by_tree(get_function(bs_action), x)
    x = torch.FloatTensor([[-1], [1]])
    print(function(x))

    count = 0
    tr = basic_tree()
    inorder_test(tr, bs_action)
    count = 0
125 |
126 |
--------------------------------------------------------------------------------
/fex/Schrodinger/function.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import sin, cos, exp
4 | import math
5 |
c = 3


def LHS_pde(u, x, dim_set):
    """Evaluate ``-Laplacian(u) + u**3 + V(x)*u`` at the sample points.

    Args:
        u: network/solution values computed from ``x``, shape ``(batch, 1)``.
        x: sample points with ``requires_grad`` enabled; the first ``dim_set``
            columns are differentiated.
        dim_set: number of spatial dimensions.

    Returns:
        Tensor of shape ``(batch, 1)``.
    """
    # Buffers follow the device and dtype of the inputs instead of being
    # forced onto CUDA, so the function also runs on CPU tensors.
    v = torch.ones_like(u)
    bs = x.size(0)
    ux = torch.autograd.grad(u, x, grad_outputs=v, create_graph=True)[0]
    uxx = x.new_zeros(bs, dim_set)
    for i in range(dim_set):
        ux_tem = ux[:, i].reshape([bs, 1])
        uxx_tem = torch.autograd.grad(ux_tem, x, grad_outputs=v, create_graph=True)[0]
        # Only the diagonal second derivative d2u/dx_i2 enters the Laplacian.
        uxx[:, i] = uxx_tem[:, i]

    LHS = -torch.sum(uxx, dim=1, keepdim=True)
    V = -exp(2/dim_set*torch.sum(cos(x), dim=1, keepdim=True))/(c**2)+torch.sum(sin(x)**2/(dim_set**2), dim=1, keepdim=True)-torch.sum(cos(x)/(dim_set), dim=1, keepdim=True)
    return LHS+u**3+V*u
22 |
23 | # def RHS_pde(x):
24 | # # return -torch.sin(x)
25 | # # return 2*torch.cos(x)-x*torch.sin(x)
26 | # return torch.sin(x)# 2*torch.cos(x**2)-4*x**2*torch.sin(x**2) #torch.sin(x)
27 | #
28 | # def true_solution(x):
29 | # return x-torch.sin(x)#torch.sin(x**2)#torch.sin(x**2)#torch.sin(x)+x # -1*x+0*x**2#
30 |
def RHS_pde(x):
    """Right-hand side of the PDE: zero at every sample point.

    Args:
        x: sample points; only the batch size ``x.size(0)`` is used.

    Returns:
        Zero tensor of shape ``(batch, 1)`` with the dtype and device of
        ``x`` (previously it was always placed on CUDA).
    """
    bs = x.size(0)
    return x.new_zeros(bs, 1)
43 |
def true_solution(x):
    """Reference solution ``exp(mean_i cos(x_i)) / c``, shape ``(batch, 1)``.

    The number of dimensions is read from ``x.size(1)``.
    """
    dim = x.size(1)
    cos_total = torch.sum(cos(x), dim=1, keepdim=True)
    return exp(1/dim*cos_total)/(c)
51 |
52 |
53 |
54 | # def RHS_pde(x):
55 | # # return -torch.sin(x)
56 | # # return 2*torch.cos(x)-x*torch.sin(x)
57 | # return torch.sin(x)+2# 2*torch.cos(x**2)-4*x**2*torch.sin(x**2) #torch.sin(x)
58 | #
59 | # def true_solution(x):
60 | # return x-torch.sin(x)+x**2
61 |
62 | # unary_functions = [lambda x: 0*x**2,
63 | # lambda x: 1+0*x**2,
64 | # # lambda x: 5+0*x**2,
65 | # lambda x: x+0*x**2,
66 | # # lambda x: -x+0*x**2,
67 | # lambda x: x**2,
68 | # lambda x: x**3,
69 | # lambda x: x**4,
70 | # # lambda x: x**5,
71 | # torch.exp,
72 | # torch.sin,
73 | # torch.cos,]
74 | # # torch.erf,
75 | # # lambda x: torch.exp(-x**2/2)]
76 |
# Unary operators, listed in the same order as ``unary_functions_str``. The
# constant entries keep a ``0*t**2`` term, so the result is still computed
# from the input tensor.
unary_functions = [
    lambda t: 0*t**2,
    lambda t: 1+0*t**2,
    lambda t: t+0*t**2,
    lambda t: t**2,
    lambda t: t**3,
    lambda t: t**4,
    torch.exp,
    torch.sin,
    torch.cos,
]
86 |
# Binary operators, listed in the same order as ``binary_functions_str``
# (addition, multiplication, subtraction).
binary_functions = [
    lambda lhs, rhs: lhs + rhs,
    lambda lhs, rhs: lhs * rhs,
    lambda lhs, rhs: lhs - rhs,
]
90 |
91 |
# Format templates for the unary operators, in the same order as
# ``unary_functions``. Every template wraps its operator core as
# ``({}*<core>+{})``; the ``{}`` slots are filled by the caller.
unary_functions_str = [
    '({}*' + core + '+{})'
    for core in (
        '(0)',
        '(1)',
        '{}',
        '({})**2',
        '({})**3',
        '({})**4',
        'exp({})',
        'sin({})',
        'cos({})',
    )
]
104 | # 'ref({})',
105 | # 'exp(-({})**2/2)']
106 |
# Format templates for the unary operators without the surrounding
# ``{}*...+{}`` slots: each entry is just the parenthesised operator.
unary_functions_str_leaf = [
    '(' + core + ')'
    for core in (
        '0',
        '1',
        '{}',
        '({})**2',
        '({})**3',
        '({})**4',
        'exp({})',
        'sin({})',
        'cos({})',
    )
]
119 |
120 |
# Format templates for the binary operators, in the same order as
# ``binary_functions``: ``+``, ``*``, ``-``.
binary_functions_str = ['(({})' + symbol + '({}))' for symbol in ('+', '*', '-')]
124 |
if __name__ == '__main__':
    # Sanity check: evaluate the PDE residual and the boundary mismatch of the
    # reference solution on random 1-D points drawn from [left, right).
    batch_size = 200
    left = -1
    right = 1
    points = (torch.rand(batch_size, 1)) * (right - left) + left
    x = torch.autograd.Variable(points.cuda(), requires_grad=True)
    function = true_solution

    # PDE loss. ``LHS_pde`` requires the number of dimensions as its third
    # argument; the two-argument call raised TypeError.
    LHS = LHS_pde(function(x), x, x.size(1))
    RHS = RHS_pde(x)
    pde_loss = torch.nn.functional.mse_loss(LHS, RHS)

    # Boundary loss, evaluated at the two interval end points.
    bc_points = torch.FloatTensor([[left], [right]]).cuda()
    bc_value = true_solution(bc_points)
    bd_loss = torch.nn.functional.mse_loss(function(bc_points), bc_value)

    print('pde loss: {} -- boundary loss: {}'.format(pde_loss.item(), bd_loss.item()))
--------------------------------------------------------------------------------
/fex/Schrodinger/scripts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
5 |
# Launch the Schrodinger controller for every problem dimension, two rounds in
# a row, each run on a randomly chosen GPU index in [1, 9].
# NOTE(review): indentation was lost in the extracted source; the sleep is
# assumed to follow every single launch -- confirm against the original file.
command_template = (
    'screen python controller_cubic_sh_firstdeflation_thenintegral.py'
    ' --epoch 1000 --bs 10 --greedy 0.1 --gpu {gpu}'
    ' --ckpt Firstdeflat_thenintegral_epoch1k_Dim{dim}'
    ' --tree depth2_sub --random_step 3 --lr 0.002 --dim {dim}'
    ' --base 200000 --left -1 --right 1 --domainbs 2000 --intbs 10000'
)

for _ in range(2):
    for dim in (6, 12, 18, 24, 30):
        gpu = random.randint(1, 9)
        os.system(command_template.format(gpu=gpu, dim=dim))
        time.sleep(100)
--------------------------------------------------------------------------------
/fex/Schrodinger/tools.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from collections import defaultdict
4 | import collections
5 | from datetime import datetime
6 | import os
7 | import json
8 | import logging
9 |
10 | import numpy as np
11 | import torch
12 | from torch.autograd import Variable
13 |
14 | ##########################
15 | # Torch
16 | ##########################
17 |
def detach(h):
    """Return ``h`` cut off from the autograd graph.

    Args:
        h: a tensor, or an arbitrarily nested iterable of tensors.

    Returns:
        A detached tensor for tensor input; otherwise a tuple mirroring the
        nesting of ``h`` with every tensor detached.
    """
    # ``type(h) == Variable`` is never true for tensors on PyTorch >= 0.4, so
    # the old check fell through and tried to iterate the tensor itself.
    # ``Variable`` stays in the tuple for pre-0.4 installations.
    if isinstance(h, (torch.Tensor, Variable)):
        return h.detach()
    return tuple(detach(v) for v in h)
23 |
def get_variable(inputs, cuda=False, **kwargs):
    """Wrap ``inputs`` in a ``Variable``.

    Lists and NumPy arrays are first converted with ``torch.Tensor``. When
    ``cuda`` is true the data is moved to the GPU before wrapping. Extra
    keyword arguments are forwarded to ``Variable``.
    """
    if type(inputs) in (list, np.ndarray):
        inputs = torch.Tensor(inputs)
    payload = inputs.cuda() if cuda else inputs
    return Variable(payload, **kwargs)
32 |
def update_lr(optimizer, lr):
    """Set the ``'lr'`` entry of every parameter group of ``optimizer`` to ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
36 |
def batchify(data, bsz, use_cuda):
    """Arrange ``data`` into ``bsz`` columns, dropping the trailing remainder.

    The first ``(len // bsz) * bsz`` entries along dimension 0 are kept,
    viewed as ``bsz`` rows and transposed, so each column holds a contiguous
    run of the input. The result is moved to the GPU when ``use_cuda`` is true.
    """
    # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py
    usable = (data.size(0) // bsz) * bsz
    columns = data.narrow(0, 0, usable).view(bsz, -1).t().contiguous()
    return columns.cuda() if use_cuda else columns
45 |
46 |
47 | ##########################
48 | # ETC
49 | ##########################
50 |
51 | Node = collections.namedtuple('Node', ['id', 'name'])
52 |
53 |
class keydefaultdict(defaultdict):
    """``defaultdict`` whose factory is called with the missing key."""

    def __missing__(self, key):
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        value = factory(key)
        self[key] = value
        return value
61 |
62 |
def to_item(x):
    """Converts x, possibly scalar and possibly tensor, to a Python scalar."""
    if isinstance(x, (float, int)):
        return x

    # Releases before 0.4 have no ``.item()``; index the single element instead.
    legacy_torch = float(torch.__version__[0:3]) < 0.4
    if not legacy_torch:
        return x.item()

    assert (x.dim() == 1) and (len(x) == 1)
    return x[0]
73 |
74 |
def get_logger(name=__file__, level=logging.INFO):
    """Return the logger ``name``, configuring it on first use.

    On first use the logger stops propagating to its parents and gets a single
    ``StreamHandler`` with a timestamped format. Later calls only update the
    level.
    """
    log = logging.getLogger(name)
    configured = getattr(log, '_init_done__', None)
    log.setLevel(level)
    if configured:
        return log

    log._init_done__ = True
    log.propagate = False

    stream = logging.StreamHandler()
    stream.setFormatter(logging.Formatter("%(asctime)s:%(levelname)s::%(message)s"))
    stream.setLevel(0)

    # Drop any handler attached earlier so messages are emitted exactly once.
    del log.handlers[:]
    log.addHandler(stream)

    return log
95 |
96 |
97 | logger = get_logger()
98 |
99 |
def prepare_dirs(args):
    """Derive the model/data directories on ``args`` and create them.

    Args:
        args: namespace providing ``load_path``, ``log_dir``, ``data_dir`` and
            ``dataset``; ``model_name``/``model_dir``/``data_path`` are set
            on it.
    """
    load_path = args.load_path
    if not load_path:
        args.model_name = "{}_{}".format(args.dataset, get_time())
    elif load_path.startswith(args.log_dir):
        # A path already inside the log directory is used as the model dir.
        args.model_dir = load_path
    elif load_path.startswith(args.dataset):
        args.model_name = load_path
    else:
        args.model_name = "{}_{}".format(args.dataset, load_path)

    if not hasattr(args, 'model_dir'):
        args.model_dir = os.path.join(args.log_dir, args.model_name)
    args.data_path = os.path.join(args.data_dir, args.dataset)

    for path in (args.log_dir, args.data_dir, args.model_dir):
        if not os.path.exists(path):
            makedirs(path)
124 |
def get_time():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    return "{:%Y-%m-%d_%H-%M-%S}".format(datetime.now())
127 |
def save_args(args):
    """Write ``args.__dict__`` as indented, key-sorted JSON to ``<model_dir>/params.json``."""
    param_path = os.path.join(args.model_dir, "params.json")

    for message in ("[*] MODEL dir: %s" % args.model_dir,
                    "[*] PARAM path: %s" % param_path):
        logger.info(message)

    with open(param_path, 'w') as fp:
        json.dump(args.__dict__, fp, indent=4, sort_keys=True)
136 |
def save_dag(args, dag, name):
    """Serialise ``dag`` as JSON to ``<args.model_dir>/<name>``.

    Args:
        args: namespace providing ``model_dir``.
        dag: JSON-serialisable object to store.
        name: file name inside the model directory.
    """
    save_path = os.path.join(args.model_dir, name)
    logger.info("[*] Save dag : {}".format(save_path))
    # The context manager closes (and therefore flushes) the file even when
    # serialisation fails; the bare ``open`` left that to garbage collection.
    with open(save_path, 'w') as fp:
        json.dump(dag, fp)
141 |
def load_dag(args):
    """Load the DAG stored at ``args.dag_path`` and copy it into the model dir.

    The JSON file maps string keys to lists of ``[id, name]`` pairs; keys are
    converted to ``int`` and pairs to ``Node`` tuples. The result is saved
    again as ``dag.json`` in ``args.model_dir``.

    Returns:
        dict mapping ``int`` to a list of ``Node``.
    """
    load_path = os.path.join(args.dag_path)
    logger.info("[*] Load dag : {}".format(load_path))
    with open(load_path) as f:
        dag = json.load(f)
    dag = {int(k): [Node(el[0], el[1]) for el in v] for k, v in dag.items()}
    save_dag(args, dag, "dag.json")

    # ``draw_network`` is neither defined nor imported in this module, so the
    # unconditional call raised NameError after the DAG was already saved.
    # Render only when a drawing helper has been made available.
    draw = globals().get('draw_network')
    if draw is None:
        logger.warning("[!] draw_network is unavailable; skipping dag.png")
    else:
        draw(dag, os.path.join(args.model_dir, "dag.png"))
    return dag
151 |
def makedirs(path):
    """Create ``path`` (including parents) and log it; do nothing if it exists."""
    if os.path.exists(path):
        return
    logger.info("[*] Make directories : {}".format(path))
    os.makedirs(path)
156 |
def remove_file(path):
    """Delete the file at `path` when it exists; do nothing otherwise."""
    if not os.path.exists(path):
        return
    logger.info("[*] Removed: {}".format(path))
    os.remove(path)
161 |
def backup_file(path):
    """Rename `path` to ``<root>.backup_<timestamp><ext>`` and log the move."""
    stem, extension = os.path.splitext(path)
    backup_path = "{}.backup_{}{}".format(stem, get_time(), extension)

    os.rename(path, backup_path)
    logger.info("[*] {} has backup: {}".format(path, backup_path))
168 |
--------------------------------------------------------------------------------
/fex/Schrodinger/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
# progress bar
# Appends a `progress` directory located next to this file to the import
# search path. The import that would use it is commented out below.
# NOTE(review): no `progress` directory is listed under this package in the
# repository tree - confirm whether this path entry is still needed.
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
# from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/fex/Schrodinger/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: Score tensor; the top-k search runs along dimension 1
            (assumed shape (batch, classes) - confirm against callers).
        target: Tensor of true class indices with one entry per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one zero-dimensional tensor per k holding the percentage
        of samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` raises for k > 1; `.reshape(-1)` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
--------------------------------------------------------------------------------
/fex/Schrodinger/utils/images/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Schrodinger/utils/images/cifar.png
--------------------------------------------------------------------------------
/fex/Schrodinger/utils/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fex/Schrodinger/utils/images/imagenet.png
--------------------------------------------------------------------------------
/fex/Schrodinger/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to `fname`.

    Args:
        fname: Destination passed straight to `plt.savefig`.
        dpi: Resolution; 150 is used when omitted.
    """
    # Identity test: `== None` would invoke a custom `__eq__` on the argument.
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
14 |
def plot_overlap(logger, names=None):
    """Plot the selected series of `logger` onto the current axes.

    Args:
        logger: Object exposing `names`, `numbers` (name -> list) and `title`.
        names: Series to draw; defaults to every name known to `logger`.

    Returns:
        Legend labels of the form ``<title>(<name>)``, one per plotted series.
    """
    # Identity test: `names == None` is element-wise for array-like inputs.
    if names is None:
        names = logger.names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]
22 |
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: one header row with the column
    names followed by one row per `append` call. Every field, including the
    last one in a row, is followed by a tab.

    With ``fpath=None`` nothing is written to disk and values are only kept
    in memory. With ``resume=True`` the existing file is parsed first (values
    are kept as the strings read from the file) and then reopened for append.
    '''
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Defined up front so append()/plot() never hit AttributeError when
        # set_names() has not been called or no file is attached.
        self.names = []
        self.numbers = {}
        if fpath is not None:
            if resume:
                self._load_existing(fpath)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def _load_existing(self, fpath):
        '''Fill `names` and `numbers` from a previously written log file.'''
        # `with` closes the handle even when a malformed row raises.
        with open(fpath, 'r') as existing:
            header = existing.readline()
            self.names = header.rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in existing:
                values = line.rstrip().split('\t')
                for index, value in enumerate(values):
                    self.numbers[self.names[index]].append(value)

    def _emit(self, text):
        '''Write and flush `text` when a file is attached; otherwise no-op.'''
        if self.file is not None:
            self.file.write(text)
            self.file.flush()

    def set_names(self, names):
        '''Set the column names, write the header row and reset the values.

        NOTE: resume mode is not special-cased; calling this on a resumed
        logger writes a second header row and discards the loaded values.
        '''
        self.names = names
        self.numbers = {name: [] for name in self.names}
        self._emit(''.join(name + '\t' for name in self.names) + '\n')

    def append(self, numbers):
        '''Record one row; `numbers` must have one entry per column name.

        `int` and `str` entries are written verbatim, everything else is
        formatted with eight decimals.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        fields = [
            "{}".format(num) if isinstance(num, (int, str))
            else "{0:.8f}".format(num)
            for num in numbers
        ]
        self._emit(''.join(field + '\t' for field in fields) + '\n')
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        '''Plot the selected series (all by default) with a legend and grid.'''
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the underlying file if one is attached.'''
        if self.file is not None:
            self.file.close()
86 |
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__(self, paths):
        '''`paths` is a dictionary of {title: log file path} pairs.'''
        self.loggers = [
            Logger(path, title=title, resume=True)
            for title, path in paths.items()
        ]

    def plot(self, names=None):
        '''Draw the selected series of every loaded log in one subplot.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
104 |
if __name__ == '__main__':
    # Demo that runs only when this module is executed directly.
    # The first example (a single Logger written and plotted) is disabled:
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    # NOTE: the log paths below are hard-coded absolute paths; the demo only
    # works where those files exist.
    paths = {
    'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
    'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
    'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    # Only this column of each log is plotted.
    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')
--------------------------------------------------------------------------------
/fex/Schrodinger/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the mean and std value of dataset.

    Args:
        dataset: Indexable collection of ``(input, target)`` pairs; the first
            three channels of each input are used (inputs are indexed as
            ``[:, i, :, :]`` after batching, i.e. assumed (C, H, W) per item).
        num_workers: Worker processes for the internal DataLoader.

    Returns:
        ``(mean, std)``: two length-3 tensors holding the per-channel mean and
        standard deviation, each averaged over the samples.
    '''
    # This module only imports `torch.nn` / `torch.nn.init` under aliases, so
    # the bare name `torch` must be brought into scope here.
    import torch
    import torch.utils.data

    # Sample order does not affect the averages, so no shuffling is needed.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
33 |
def init_params(net):
    '''Init layer parameters.

    Conv2d weights get Kaiming-normal (fan-out) initialisation, Linear
    weights a normal distribution with std 1e-3, BatchNorm2d weights are set
    to one; every bias that exists is set to zero.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` is ambiguous for tensors with more than one
            # element; the intent is "does this layer have a bias".
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # Weight and bias are None when the layer is built with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
48 |
def mkdir_p(path):
    '''Create `path` with missing parents; an existing directory is accepted.

    Any other OSError (including `path` existing as a non-directory) is
    re-raised.
    '''
    try:
        os.makedirs(path)
    except OSError as exc:
        already_a_directory = exc.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise
58 |
class AverageMeter(object):
    """Tracks the latest value and the running weighted average.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes are `val` (last value), `sum` (weighted sum), `count` (total
    weight) and `avg` (`sum / count`).
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` observed with weight `n` and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/fex/Schrodinger/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
    """Un-normalise an image tensor and return it as a NumPy array.

    Args:
        img: Tensor whose first dimension holds at least three channels
            (indexed 0..2); the array is transposed with axes (1, 2, 0).
        mean: Per-channel offset added after scaling.
        std: Per-channel scale.

    Returns:
        NumPy array with the channel axis moved to the end.
    """
    # Work on a private CPU copy so the caller's tensor is not modified.
    img = img.detach().cpu().clone()
    for i in range(0, 3):
        img[i] = img[i] * std[i] + mean[i]     # unnormalize
    npimg = img.numpy()
    return np.transpose(npimg, (1, 2, 0))
17 |
def gauss(x,a,b,c):
    """Evaluate ``a * exp(-(x - b)**2 / (2 * c**2))`` element-wise on `x`."""
    centered = torch.add(x, -b)
    exponent = torch.pow(centered, 2).div(2*c*c)
    return torch.exp(-exponent).mul(a)
20 |
def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image.

    Args:
        x: Tensor of rank 2 (H, W), rank 3 (1, H, W) or rank 4 (N, 1, H, W).

    Returns:
        Heatmap tensor of shape (3, H, W) for rank 2/3 input and
        (N, 3, H, W) for rank 4 input, with values clamped to at most 1.

    Raises:
        ValueError: If `x` has any other rank.
    '''
    if x.dim() == 2:
        # torch.unsqueeze has no `out=` argument; rebind instead.
        x = torch.unsqueeze(x, 0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[1] = gauss(x,1,.5,.3)
        cl[2] = gauss(x,1,.2,.3)
    elif x.dim() == 4:
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        # Drop the singleton channel axis so each (N, H, W) slice lines up
        # with the target for batches larger than one.
        gray = x.squeeze(1)
        cl[:,0,:,:] = gauss(gray,.5,.6,.2) + gauss(gray,1,.8,.3)
        cl[:,1,:,:] = gauss(gray,1,.5,.3)
        cl[:,2,:,:] = gauss(gray,1,.2,.3)
    else:
        raise ValueError(
            'colorize expects a 2-, 3- or 4-dimensional tensor, got {} dims'.format(x.dim()))
    # The first channel sums two bumps and can exceed 1; clamp in every branch.
    cl[cl.gt(1)] = 1
    return cl
37 |
def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Tile `images` into one grid, un-normalise it and display it."""
    grid = torchvision.utils.make_grid(images)
    plt.imshow(make_image(grid, Mean, Std))
    plt.show()
42 |
43 |
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch and, below it, the batch blended with `mask`.

    Args:
        images: Normalised image batch, indexed as (N, C, H, W) with C >= 3.
        mask: Mask batch that is resized to the image size and expanded to
            the shape of `images` (assumed (N, 1, h, w) - confirm).
        Mean: Per-channel offset used to un-normalise.
        Std: Per-channel scale used to un-normalise.
    """
    # Un-normalised copy kept for blending with the mask.
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    # The original called an `upsampling` helper that is not defined in this
    # module. Resize to the exact image size so `expand_as` always matches.
    # NOTE(review): nearest-neighbour mode is an assumption - confirm.
    mask = nn.functional.interpolate(mask, size=im_data.shape[-2:], mode='nearest')

    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')
72 |
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch followed by one blended row per mask.

    Args:
        images: Normalised image batch, indexed as (N, C, H, W) with C >= 3.
        masklist: Sequence of mask batches; each is moved to the CPU, resized
            to the image size and expanded to the shape of `images`
            (assumed (N, 1, h, w) each - confirm).
        Mean: Per-channel offset used to un-normalise.
        Std: Per-channel scale used to un-normalise.
    """
    # Un-normalised copy kept for blending with the masks.
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize

    rows = 1 + len(masklist)
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(rows, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i, raw_mask in enumerate(masklist):
        mask = raw_mask.data.cpu()
        # The original called an `upsampling` helper that is not defined in
        # this module. Resize to the exact image size so `expand_as` matches.
        # NOTE(review): nearest-neighbour mode is an assumption - confirm.
        mask = nn.functional.interpolate(mask, size=im_data.shape[-2:], mode='nearest')

        mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
        plt.subplot(rows, 1, i+2)
        plt.imshow(mask)
        plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------
/fexrl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeungSamWai/Finite-expression-method/2206549622c053e6d6a4eafc58662aea138f3219/fexrl.png
--------------------------------------------------------------------------------
/nn/Conservationlaw/scrpts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
# Launch the training sweep: 6 trials for each of the 5 problem dimensions.
for i in range(6):
    for dim in [6, 12, 18, 24, 30]:
        # GPU index drawn at random from 1..9 (randint includes both ends).
        gpu = random.randint(1,9)
        # Start train.py through `screen`; the checkpoint directory encodes
        # the dimension and the trial index.
        os.system('screen python train.py --gpu-id '+str(gpu)+' --trainbs 5000 --bdbs 1000 --optim adam --lr 0.001 --iters 15000 --dim '+str(dim)+' --checkpoint ckpts_ReLU_pidiv4/dim'+str(dim)+'_trial'+str(i)+' --weight 100 --left -1 --right 1')
        # Pause 200 seconds before the next launch.
        time.sleep(200)
--------------------------------------------------------------------------------
/nn/Conservationlaw/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import time
6 | import random
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim as optim
13 | import math
14 | import numpy as np
15 | from torch import sin, cos, exp
16 | import torch.nn.functional as F
17 | from utils import Logger, AverageMeter, mkdir_p
18 | from numpy.polynomial.legendre import leggauss
19 | # torch.set_default_tensor_type(torch.DoubleTensor)
20 |
# Command-line interface. Parsing happens at import time, so `args` is a
# module-level global read by the functions below.
parser = argparse.ArgumentParser(description='PyTorch Density Function Training')
# Datasets
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--iters', default=30000, type=int, metavar='N', help='number of total iterations to run')
parser.add_argument('--dim', default=2, type=int)
parser.add_argument('--trainbs', default=10, type=int, metavar='N', help='train batchsize')
# parser.add_argument('--bdbs', default=10, type=int, metavar='N', help='train batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', default=0, type=float, metavar='W', help='weight decay')
parser.add_argument('--function', default='resnet', type=str, help='function to approximate')
# The help text used to repeat the one of --function.
parser.add_argument('--optim', default='adam', type=str,
                    help="optimizer: 'SGD' selects SGD with momentum, any other value selects Adam")
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint')
parser.add_argument('--bdbs', default=1000, type=int)
parser.add_argument('--weight', default=100, type=float, help='weight')

# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='path to save checkpoint (default: checkpoint)')

# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')
# parser.add_argument('--exp', type=int, default='0', help='use exp last layer')

#Device options
parser.add_argument('--gpu-id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')

# region
parser.add_argument('--left', default=-3, type=float, help='left boundary of square region')
parser.add_argument('--right', default=3, type=float, help='right boundary of square region')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

print(state)

# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed: draw one when none was given, then seed Python and torch.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
70 |
def get_boundary(num_pts, dim):
    """Sample `num_pts` points with first coordinate 0.

    The remaining coordinates are uniform in [args.left, args.right]. The
    result is a CUDA tensor of shape (num_pts, dim).
    """
    width = args.right - args.left
    bd_pts = torch.rand(num_pts, dim).cuda() * width + args.left
    bd_pts[:, 0] = 0
    return bd_pts
75 |
# Coefficient applied to every coordinate except the first in `true_solution`;
# `LHS_pde` uses (dim - 1) * scale as the coefficient of the first derivative.
scale = math.pi/4#math.pi
77 |
def LHS_pde(u, x, dim_set):
    """Left-hand side of the PDE evaluated at the sample points.

    Computes the gradient of `u` with respect to `x` and returns the weighted
    sum of its components: weight (dim_set - 1) * scale for the first
    coordinate and -1 for all others. Output shape is (batch, 1).
    """
    grad_seed = torch.ones(u.shape).cuda()
    ux = torch.autograd.grad(u, x, grad_outputs=grad_seed, create_graph=True)[0]

    weights = -(torch.ones(1, dim_set).cuda())
    weights[:, 0] = (dim_set-1)*scale

    return torch.sum(ux*weights, 1, keepdim=True)
88 |
def RHS_pde(x):
    """Right-hand side of the PDE: a CUDA zero column with one row per sample."""
    return torch.zeros(x.size(0), 1).cuda()
92 |
def true_solution(x):
    """Reference solution sin(x_0 + scale * (x_1 + ... + x_{d-1})).

    Returns a tensor of shape (batch, 1).
    """
    weights = torch.ones(1, x.size(1)).cuda()*scale
    weights[:, 0] = 1
    phase = torch.sum(x*weights, dim=1, keepdim=True)
    return torch.sin(phase)
98 |
99 |
# Residual MLP: input size is `args.dim`, every hidden layer has width `m`.
class ResNet(nn.Module):
    """Fully connected network with three two-layer residual blocks.

    The input is rescaled from [args.left, args.right] to [-1, 1]. The first
    skip connection adds the rescaled input zero-padded to the hidden width,
    so the width must be at least `args.dim`. The output is one value per
    sample.
    """
    def __init__(self, m):
        super(ResNet, self).__init__()
        # Remember the width so forward() pads to this network's own size
        # rather than to the module-level `m`.
        self.m = m

        self.fc1 = nn.Linear(args.dim, m)
        self.fc2 = nn.Linear(m, m)

        self.fc3 = nn.Linear(m, m)
        self.fc4 = nn.Linear(m, m)

        self.fc5 = nn.Linear(m, m)
        self.fc6 = nn.Linear(m, m)

        self.outlayer = nn.Linear(m, 1, bias=True)

        # Zero every bias and draw every weight from N(0, 0.1^2).
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.normal_(0.0, 0.1)

    def forward(self, x):
        # Map the domain [left, right] onto [-1, 1].
        x1 = (x - args.left)*2/(args.right - args.left) + (-1)
        # Zero-pad the input up to the hidden width for the first skip.
        s = torch.nn.functional.pad(x1, (0, self.m - args.dim))

        y = self.fc1(x1)
        y = F.relu(y)
        y = self.fc2(y)
        y = F.relu(y)
        y = y + s

        s = y
        y = self.fc3(y)
        y = F.relu(y)
        y = self.fc4(y)
        y = F.relu(y)
        y = y + s

        s = y
        y = self.fc5(y)
        y = F.relu(y)
        y = self.fc6(y)
        y = F.relu(y)
        y = y + s

        output = self.outlayer(y)
        return output
152 |
'''
HyperParams Setting for Network
'''
# Hidden-layer width; passed to ResNet(m) in main().
m = 50 # number of hidden size

# for ResNet
# 1 x m row vector on the GPU with a one in the first entry.
# NOTE(review): `Ix` is not referenced anywhere else in this file.
Ix = torch.zeros([1,m]).cuda()
Ix[0,0] = 1
161 |
162 |
def main():
    """Train the network, save a checkpoint and log the relative L2 error.

    Steps: create the checkpoint directory, build the model and optimizer,
    write the configuration to ``Config.txt``, run `args.iters` training
    iterations with a cosine learning-rate schedule, save the final weights,
    then estimate the relative L2 error against `true_solution` on 1000
    random batches of 100000 points and append it to the log.
    """
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    model = ResNet(m)

    model = model.cuda()
    cudnn.benchmark = True

    with open(args.checkpoint + "/Config.txt", 'w+') as f:
        for (k, v) in args._get_kwargs():
            f.write(k + ' : ' + str(v) + '\n')

    print('Total params: %.2f' % (sum(p.numel() for p in model.parameters())))

    """
    Define Residual Methods and Optimizer
    """
    criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.wd)
    # Resume
    title = ''

    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    # Close the log file even when training or evaluation raises.
    try:
        logger.set_names(['Learning Rate', 'Losses', 'pdeloss', 'bdloss'])

        # Train and val
        for it in range(0, args.iters):

            lr = cosine_lr(optimizer, args.lr, it, args.iters)

            losses, pdeloss, intloss = train(model, criterion, optimizer, use_cuda, it, lr)
            logger.append([lr, losses, pdeloss, intloss])

        # save model
        save_checkpoint({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, checkpoint=args.checkpoint)

        numerators = []
        denominators = []

        # Evaluation needs no gradients; skipping the autograd graph keeps
        # memory bounded for the 100000-point batches.
        with torch.no_grad():
            for i in range(1000):
                print(i)
                t = torch.rand(100000, 1).cuda()
                x1 = (torch.rand(100000, args.dim - 1).cuda()) * (args.right - args.left) + args.left
                x = torch.cat((t, x1), 1)
                exact = true_solution(x)
                sq_de = torch.mean(exact ** 2)
                sq_nu = torch.mean((exact - model(x)) ** 2)
                numerators.append(sq_nu.item())
                denominators.append(sq_de.item())

        relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators))
        print('relative l2 error: ', relative_l2)
        # The label is stored in the first column so the row keeps four fields.
        logger.append(['relative_l2', relative_l2, 0, 0])
    finally:
        logger.close()
228 |
229 |
def train(model, criterion, optimizer, use_cuda, iter, lr):
    """Run one optimisation step and return the loss values.

    Interior points have a first coordinate uniform in [0, 1] and the rest
    uniform in [args.left, args.right]. The loss is the mean squared PDE
    residual plus `args.weight` times the mean squared boundary error.
    `criterion` and `use_cuda` are accepted but not used here.

    Returns:
        ``(total loss, PDE loss, boundary loss)`` as Python floats.
    """
    # switch to train mode
    model.train()
    started = time.time()

    # Interior sample: the range is [0,1] --> [left, right] for all but the
    # first coordinate.
    span = args.right - args.left
    t = torch.rand(args.trainbs, 1).cuda()
    rest = (torch.rand(args.trainbs, args.dim-1).cuda())*span+args.left
    x = torch.cat((t, rest), 1)
    x.requires_grad = True

    # Boundary term: compare the network with the exact solution.
    boundary_points = get_boundary(args.bdbs, args.dim)
    boundary_exact = true_solution(boundary_points)
    boundary_pred = model(boundary_points)
    bd_error = F.mse_loss(boundary_exact, boundary_pred)

    # PDE residual term on the interior points.
    function_error = F.mse_loss(LHS_pde(model(x), x, args.dim), RHS_pde(x))
    loss = function_error + args.weight * bd_error

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # measure elapsed time
    batch_time = time.time() - started
    loss_value = loss.item()
    pde_value = function_error.item()
    bd_value = bd_error.item()
    message = '{iter:.1f} {lr:.8f}| Batch: {bt:.3f}s | Loss: {loss:.8f} | pdeloss: {pdeloss:.8f} | bd loss {bd: .8f} |'.format(
        bt=batch_time, loss=loss_value, iter=iter, lr=lr, pdeloss=pde_value, bd=bd_value)
    print(message)
    return loss_value, pde_value, bd_value
263 |
def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialise `state` with `torch.save` to ``<checkpoint>/<filename>``."""
    torch.save(state, os.path.join(checkpoint, filename))
267 |
268 |
def cosine_lr(opt, base_lr, e, epochs):
    """Apply a half-cosine learning-rate schedule and return the new rate.

    The rate is ``0.5 * base_lr * (cos(pi * e / epochs) + 1)`` and is written
    to the "lr" entry of every parameter group of `opt`.
    """
    phase = math.pi * e / epochs
    lr = 0.5 * base_lr * (math.cos(phase) + 1)
    for group in opt.param_groups:
        group["lr"] = lr
    return lr
274 |
# Run the training and evaluation pipeline only when executed as a script.
if __name__ == '__main__':
    main()
277 |
--------------------------------------------------------------------------------
/nn/Conservationlaw/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
# progress bar
# Appends a `progress` directory located next to this file to the import
# search path. The import that would use it is commented out below.
# NOTE(review): no `progress` directory is listed under this package in the
# repository tree - confirm whether this path entry is still needed.
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
# from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/nn/Conservationlaw/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: Score tensor; the top-k search runs along dimension 1
            (assumed shape (batch, classes) - confirm against callers).
        target: Tensor of true class indices with one entry per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one zero-dimensional tensor per k holding the percentage
        of samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` raises for k > 1; `.reshape(-1)` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
--------------------------------------------------------------------------------
/nn/Conservationlaw/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to `fname`.

    Args:
        fname: Destination passed straight to `plt.savefig`.
        dpi: Resolution; 150 is used when omitted.
    """
    # Identity test: `== None` would invoke a custom `__eq__` on the argument.
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
14 |
def plot_overlap(logger, names=None):
    """Plot the selected series of `logger` onto the current axes.

    Args:
        logger: Object exposing `names`, `numbers` (name -> list) and `title`.
        names: Series to draw; defaults to every name known to `logger`.

    Returns:
        Legend labels of the form ``<title>(<name>)``, one per plotted series.
    """
    # Identity test: `names == None` is element-wise for array-like inputs.
    if names is None:
        names = logger.names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]
22 |
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: one header row with the column
    names followed by one row per `append` call. Every field, including the
    last one in a row, is followed by a tab.

    With ``fpath=None`` nothing is written to disk and values are only kept
    in memory. With ``resume=True`` the existing file is parsed first (values
    are kept as the strings read from the file) and then reopened for append.
    '''
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Defined up front so append()/plot() never hit AttributeError when
        # set_names() has not been called or no file is attached.
        self.names = []
        self.numbers = {}
        if fpath is not None:
            if resume:
                self._load_existing(fpath)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def _load_existing(self, fpath):
        '''Fill `names` and `numbers` from a previously written log file.'''
        # `with` closes the handle even when a malformed row raises.
        with open(fpath, 'r') as existing:
            header = existing.readline()
            self.names = header.rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in existing:
                values = line.rstrip().split('\t')
                for index, value in enumerate(values):
                    self.numbers[self.names[index]].append(value)

    def _emit(self, text):
        '''Write and flush `text` when a file is attached; otherwise no-op.'''
        if self.file is not None:
            self.file.write(text)
            self.file.flush()

    def set_names(self, names):
        '''Set the column names, write the header row and reset the values.

        NOTE: resume mode is not special-cased; calling this on a resumed
        logger writes a second header row and discards the loaded values.
        '''
        self.names = names
        self.numbers = {name: [] for name in self.names}
        self._emit(''.join(name + '\t' for name in self.names) + '\n')

    def append(self, numbers):
        '''Record one row; `numbers` must have one entry per column name.

        `int` and `str` entries are written verbatim, everything else is
        formatted with eight decimals.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        fields = [
            "{}".format(num) if isinstance(num, (int, str))
            else "{0:.8f}".format(num)
            for num in numbers
        ]
        self._emit(''.join(field + '\t' for field in fields) + '\n')
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        '''Plot the selected series (all by default) with a legend and grid.'''
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the underlying file if one is attached.'''
        if self.file is not None:
            self.file.close()
86 |
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__(self, paths):
        '''`paths` is a dictionary of {title: log file path} pairs.'''
        self.loggers = [
            Logger(path, title=title, resume=True)
            for title, path in paths.items()
        ]

    def plot(self, names=None):
        '''Draw the selected series of every loaded log in one subplot.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
104 |
if __name__ == '__main__':
    # Demo that runs only when this module is executed directly.
    # The first example (a single Logger written and plotted) is disabled:
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    # NOTE: the log paths below are hard-coded absolute paths; the demo only
    # works where those files exist.
    paths = {
    'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
    'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
    'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    # Only this column of each log is plotted.
    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')
--------------------------------------------------------------------------------
/nn/Conservationlaw/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the mean and std value of dataset.

    Args:
        dataset: Indexable collection of ``(input, target)`` pairs; the first
            three channels of each input are used (inputs are indexed as
            ``[:, i, :, :]`` after batching, i.e. assumed (C, H, W) per item).
        num_workers: Worker processes for the internal DataLoader.

    Returns:
        ``(mean, std)``: two length-3 tensors holding the per-channel mean and
        standard deviation, each averaged over the samples.
    '''
    # This module only imports `torch.nn` / `torch.nn.init` under aliases, so
    # the bare name `torch` must be brought into scope here.
    import torch
    import torch.utils.data

    # Sample order does not affect the averages, so no shuffling is needed.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
33 |
def init_params(net):
    '''Init layer parameters.

    Conv2d weights get Kaiming-normal (fan-out) initialisation, Linear
    weights a normal distribution with std 1e-3, BatchNorm2d weights are set
    to one; every bias that exists is set to zero.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` is ambiguous for tensors with more than one
            # element; the intent is "does this layer have a bias".
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # Weight and bias are None when the layer is built with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
48 |
def mkdir_p(path):
    '''Create `path` with missing parents; an existing directory is accepted.

    Any other OSError (including `path` existing as a non-directory) is
    re-raised.
    '''
    try:
        os.makedirs(path)
    except OSError as exc:
        already_a_directory = exc.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise
58 |
class AverageMeter(object):
    """Track the latest value and the running (weighted) average of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the current value, average, sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/nn/Conservationlaw/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
    """Un-normalise the first three channels of ``img`` and return an HWC array.

    ``img`` is modified in place (``channel * std + mean``); the returned
    numpy array has its axes permuted from (C, H, W) to (H, W, C).
    """
    for channel in range(3):
        img[channel] = img[channel] * std[channel] + mean[channel]  # unnormalize
    as_array = img.numpy()
    return as_array.transpose((1, 2, 0))
17 |
def gauss(x,a,b,c):
    """Evaluate ``a * exp(-(x - b)^2 / (2 c^2))`` element-wise on tensor ``x``."""
    shifted = torch.add(x, -b)
    exponent = torch.pow(shifted, 2).div(2*c*c)
    return torch.exp(-exponent).mul(a)
20 |
def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image

    Accepts a 2-D (H, W), 3-D (1, H, W) or 4-D (N, 1, H, W) tensor and
    returns a (3, H, W) tensor for the first two cases or an (N, 3, H, W)
    tensor for the batched case. Each colour channel is a Gaussian bump
    (see ``gauss``) of the input intensity.

    Raises:
        ValueError: if ``x`` has a rank other than 2, 3 or 4.
    '''
    if x.dim() == 2:
        # torch.unsqueeze has no ``out=`` argument; rebind the local instead.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[1] = gauss(x,1,.5,.3)
        cl[2] = gauss(x,1,.2,.3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        # Drop the channel axis so that (N, H, W) values are assigned into the
        # (N, H, W) colour planes; assigning (N, 1, H, W) fails for N > 1.
        gray = x[:, 0, :, :]
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:,0,:,:] = gauss(gray,.5,.6,.2) + gauss(gray,1,.8,.3)
        cl[:,1,:,:] = gauss(gray,1,.5,.3)
        cl[:,2,:,:] = gauss(gray,1,.2,.3)
        # NOTE(review): unlike the 3-D branch, values above 1 are not clamped
        # here -- kept as in the original; confirm whether that is intended.
    else:
        raise ValueError('colorize expects a 2-D, 3-D or 4-D tensor, got %d-D' % x.dim())
    return cl
37 |
def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Tile ``images`` into one grid, un-normalise it and display it."""
    grid = torchvision.utils.make_grid(images)
    plt.imshow(make_image(grid, Mean, Std))
    plt.show()
42 |
43 |
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch (top) and the same batch blended with ``mask`` (bottom).

    Args:
        images: image batch; ``images.size(2)`` is taken as the image size.
        mask: mask batch; ``mask.size(2)`` is taken as the mask size. It is
            resized to the image size and expanded to the shape of ``images``.
        Mean, Std: per-channel statistics used to un-normalise the images.
    """
    im_size = images.size(2)

    # Un-normalised copy of the images, used for blending with the mask.
    im_data = images.clone()
    for ch in range(0, 3):
        im_data[:,ch,:,:] = im_data[:,ch,:,:] * Std[ch] + Mean[ch]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    # The original called an ``upsampling`` helper that is not defined or
    # imported in this module; nn.functional.interpolate is used in its place.
    # NOTE(review): default (nearest) interpolation is assumed -- confirm.
    mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')
72 |
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch followed by one blended row per mask in ``masklist``.

    Args:
        images: image batch; ``images.size(2)`` is taken as the image size.
        masklist: sequence of mask batches; each entry is read through
            ``.data.cpu()``, resized to the image size and blended with the
            un-normalised images.
        Mean, Std: per-channel statistics used to un-normalise the images.
    """
    im_size = images.size(2)

    # Un-normalised copy of the images, used for blending with the masks.
    im_data = images.clone()
    for ch in range(0, 3):
        im_data[:,ch,:,:] = im_data[:,ch,:,:] * Std[ch] + Mean[ch]  # unnormalize

    n_rows = 1 + len(masklist)
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(n_rows, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for row, entry in enumerate(masklist):
        mask = entry.data.cpu()
        mask_size = mask.size(2)
        # The original called an ``upsampling`` helper that is not defined or
        # imported in this module; nn.functional.interpolate is used in its place.
        # NOTE(review): default (nearest) interpolation is assumed -- confirm.
        mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

        mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
        plt.subplot(n_rows, 1, row+2)
        plt.imshow(mask)
        plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------
/nn/Poisson/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/nn/Poisson/.idea/conservativelaw.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/nn/Poisson/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/nn/Poisson/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/nn/Poisson/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/nn/Poisson/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/nn/Poisson/scrpts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
5 | # for i in range(6):
6 | # for dim in [10, 20, 30, 40, 50]:
7 | # gpu = random.randint(1,7)
8 | # os.system('screen python train.py --gpu-id '+str(gpu)+' --trainbs 5000 --bdbs 1000 --optim adam --lr 0.001 --iters 15000 --dim '+str(dim)+' --checkpoint ckpts/dim'+str(dim)+'_trial'+str(i)+' --weight 100 --left -1 --right 1')
9 | # # print(name)
10 | # time.sleep(100)
11 |
12 |
13 | # for i in range(6):
14 | # for dim in [10, 20, 30, 40, 50]:
15 | # gpu = random.randint(0,9)
16 | # os.system('screen python train.py --gpu-id '+str(gpu)+' --trainbs 10000 --bdbs 5000 --optim adam --lr 0.001 --iters 15000 --dim '+str(dim)+' --checkpoint ckpts_largerbs_relu2/dim'+str(dim)+'_trial'+str(i)+' --weight 100 --left -1 --right 1')
17 | # # print(name)
18 | # time.sleep(100)
19 |
# Launch 6 trials for every problem dimension, each in its own `screen`
# session on a randomly chosen GPU, pausing 100 s between launches.
launch_template = (
    'screen python train.py --gpu-id {gpu} --trainbs 5000 --bdbs 1000 '
    '--optim adam --lr 0.001 --iters 15000 --dim {dim} '
    '--checkpoint ckpts_bs5k_1k_relu2/dim{dim}_trial{trial} '
    '--weight 100 --left -1 --right 1'
)
for i in range(6):
    for dim in [10, 20, 30, 40, 50]:
        gpu = random.randint(0,9)
        os.system(launch_template.format(gpu=gpu, dim=dim, trial=i))
        time.sleep(100)
26 |
27 |
28 | # for i in range(1):
29 | # for dim in [30]:
30 | # gpu = random.randint(1,3)
31 | # os.system('python train.py --gpu-id '+str(gpu)+' --trainbs 10000 --bdbs 5000 --optim adam --lr 0.001 --iters 15000 --dim '+str(dim)+' --checkpoint ckpts/test --weight 100 --left -1 --right 1')
32 | # print(name)
33 | # time.sleep(100)
34 |
--------------------------------------------------------------------------------
/nn/Poisson/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import time
6 | import random
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim as optim
13 | import math
14 | import numpy as np
15 | from torch import sin, cos, exp
16 | import torch.nn.functional as F
17 | from utils import Logger, AverageMeter, mkdir_p
18 | from numpy.polynomial.legendre import leggauss
19 | # torch.set_default_tensor_type(torch.DoubleTensor)
20 |
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='PyTorch Density Function Training')
# Datasets
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--iters', default=30000, type=int, metavar='N', help='number of total iterations to run')
parser.add_argument('--dim', default=2, type=int, help='input (spatial) dimension of the problem')
parser.add_argument('--trainbs', default=10, type=int, metavar='N', help='train batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', default=0, type=float, metavar='W', help='weight decay')
parser.add_argument('--function', default='resnet', type=str, help='function to approximate')
# main() selects SGD only for the exact string 'SGD'; every other value uses Adam.
parser.add_argument('--optim', default='adam', type=str, help="optimizer: 'SGD', anything else selects Adam")
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint')
parser.add_argument('--bdbs', default=1000, type=int, metavar='N', help='boundary batchsize')
parser.add_argument('--weight', default=100, type=float, help='weight of the boundary loss term')

# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='path to save checkpoint (default: checkpoint)')

# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')

# Device options
parser.add_argument('--gpu-id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')

# region
parser.add_argument('--left', default=-3, type=float, help='left boundary of square region')
parser.add_argument('--right', default=3, type=float, help='right boundary of square region')
53 |
# Parse the command line and echo the resulting configuration.
args = parser.parse_args()
state = dict(args._get_kwargs())

print(state)

# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed: draw one when none was given, then seed Python and torch.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
70 |
def get_boundary(num_pts, dim):
    """Sample ``num_pts`` points on the boundary of the cube [left, right]^dim.

    Points are first drawn uniformly inside the cube; then, for every point,
    one randomly chosen coordinate is pinned to ``args.left`` (first half of
    the points) or ``args.right`` (remaining points).

    Args:
        num_pts: number of boundary points.
        dim: spatial dimension.

    Returns:
        CUDA tensor of shape ``(num_pts, dim)``.
    """
    bd_pts = (torch.rand(num_pts, dim).cuda()) * (args.right - args.left) + args.left

    num_half = num_pts//2
    rows = torch.arange(0, num_half)
    cols = torch.randint(0, dim, (num_half,))
    bd_pts[rows, cols] = args.left

    # For an odd ``num_pts`` the second group holds one more row than the
    # first, so size the column indices from the row count, not num_half.
    rows = torch.arange(num_half, num_pts)
    cols = torch.randint(0, dim, (num_pts - num_half,))
    bd_pts[rows, cols] = args.right

    return bd_pts
85 |
def LHS_pde(u, x, dim_set):
    """Return the negative Laplacian ``-sum_i d^2u/dx_i^2`` via autograd.

    Args:
        u: network output evaluated at ``x``; must be part of a graph that
            depends on ``x``.
        x: input points with ``requires_grad=True``; rows are samples.
        dim_set: number of leading coordinates of ``x`` to differentiate in.

    Returns:
        Tensor of shape ``(x.size(0), 1)`` on the device/dtype of the
        gradients (the original hard-coded CUDA float32 buffers).
    """
    ux = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    uxx = torch.zeros(x.size(0), dim_set, dtype=ux.dtype, device=ux.device)
    for i in range(dim_set):
        ux_i = ux[:, i:i+1]
        # Differentiate du/dx_i once more and keep only its i-th component,
        # i.e. the diagonal Hessian entry.
        uxx_i = torch.autograd.grad(ux_i, x, grad_outputs=torch.ones_like(ux_i), create_graph=True)[0]
        uxx[:, i] = uxx_i[:, i]
    LHS = -torch.sum(uxx, dim=1, keepdim=True)
    return LHS
98 |
def RHS_pde(x):
    """Return the constant right-hand side ``-dim`` for every row of ``x``.

    Args:
        x: tensor whose first axis is the batch and second axis the dimension.

    Returns:
        Tensor of shape ``(x.size(0), 1)`` filled with ``-x.size(1)``, created
        on the device and with the dtype of ``x`` (the original always
        allocated it on CUDA).
    """
    bs = x.size(0)
    dim = x.size(1)
    return -dim*torch.ones(bs, 1, dtype=x.dtype, device=x.device)
103 |
def true_solution(x):
    """Reference solution ``0.5 * sum_i x_i^2`` per row, shape ``(batch, 1)``."""
    squared_norm = torch.sum(x**2, dim=1, keepdim=True)
    return 0.5*squared_norm
106 |
107 |
# Fully connected residual network; the input width is ``args.dim``.
# NOTE: a second class named ``ResNet`` is defined further down in this module
# and replaces this one at import time.
class ResNet(nn.Module):
    """Three residual blocks of two Linear layers each, with cubic activation.

    Inputs are rescaled from [args.left, args.right] to [-1, 1]; the rescaled
    input, zero-padded to the hidden width, is the skip connection of the
    first block. ``F.relu(y ** 3)`` equals ``relu(y) ** 3``.

    Args:
        m: hidden-layer width.
    """

    def __init__(self, m):
        super(ResNet, self).__init__()
        # Kept so forward() pads to this instance's width instead of reading
        # the module-level ``m``.
        self.width = m
        self.fc1 = nn.Linear(args.dim, m)
        self.fc2 = nn.Linear(m, m)

        self.fc3 = nn.Linear(m, m)
        self.fc4 = nn.Linear(m, m)

        self.fc5 = nn.Linear(m, m)
        self.fc6 = nn.Linear(m, m)

        self.outlayer = nn.Linear(m, 1, bias=True)

        # Zero biases, N(0, 0.1) weights for every Linear layer.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.normal_(0.0, 0.1)

    def forward(self, x):
        """Map points of shape ``(batch, args.dim)`` to outputs ``(batch, 1)``."""
        x1 = (x - args.left)*2/(args.right - args.left) + (-1)
        s = torch.nn.functional.pad(x1, (0, self.width - args.dim))

        y = self.fc1(x1)
        y = F.relu(y ** 3)
        y = self.fc2(y)
        y = F.relu(y ** 3)
        y = y + s

        s = y
        y = self.fc3(y)
        y = F.relu(y ** 3)
        y = self.fc4(y)
        y = F.relu(y ** 3)
        y = y + s

        s = y
        y = self.fc5(y)
        y = F.relu(y ** 3)
        y = self.fc6(y)
        y = F.relu(y ** 3)
        y = y + s

        output = self.outlayer(y)
        return output
160 |
161 |
# Fully connected residual network; the input width is ``args.dim``.
class ResNet(nn.Module):
    """Three residual blocks of two Linear layers each, with squared activation.

    Inputs are rescaled from [args.left, args.right] to [-1, 1]; the rescaled
    input, zero-padded to the hidden width, is the skip connection of the
    first block.

    NOTE(review): ``F.relu(y ** 2)`` equals ``y ** 2`` for every real ``y``,
    so the activation is a plain square rather than ``relu(y) ** 2`` --
    behaviour is kept as is; confirm which one was intended.

    Args:
        m: hidden-layer width.
    """

    def __init__(self, m):
        super(ResNet, self).__init__()
        # Kept so forward() pads to this instance's width instead of reading
        # the module-level ``m``.
        self.width = m
        self.fc1 = nn.Linear(args.dim, m)
        self.fc2 = nn.Linear(m, m)

        self.fc3 = nn.Linear(m, m)
        self.fc4 = nn.Linear(m, m)

        self.fc5 = nn.Linear(m, m)
        self.fc6 = nn.Linear(m, m)

        self.outlayer = nn.Linear(m, 1, bias=True)

        # Zero biases, N(0, 0.1) weights for every Linear layer.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.normal_(0.0, 0.1)

    def forward(self, x):
        """Map points of shape ``(batch, args.dim)`` to outputs ``(batch, 1)``."""
        x1 = (x - args.left)*2/(args.right - args.left) + (-1)
        s = torch.nn.functional.pad(x1, (0, self.width - args.dim))

        y = self.fc1(x1)
        y = F.relu(y ** 2)
        y = self.fc2(y)
        y = F.relu(y ** 2)
        y = y + s

        s = y
        y = self.fc3(y)
        y = F.relu(y ** 2)
        y = self.fc4(y)
        y = F.relu(y ** 2)
        y = y + s

        s = y
        y = self.fc5(y)
        y = F.relu(y ** 2)
        y = self.fc6(y)
        y = F.relu(y ** 2)
        y = y + s

        output = self.outlayer(y)
        return output
214 |
215 |
'''
HyperParams Setting for Network
'''
m = 50 # hidden-layer width; passed to ResNet(m) in main()

# for ResNet: 1 x m row vector with a single 1 in the first entry, allocated
# on the GPU when the module is imported.
# NOTE(review): Ix is not referenced anywhere else in this file -- confirm it is still needed.
Ix = torch.zeros([1,m]).cuda()
Ix[0,0] = 1
224 |
225 |
def main():
    """Train the network on the Poisson problem and report the relative L2 error.

    Steps: create the checkpoint directory, dump the configuration to
    ``Config.txt``, run ``args.iters`` training iterations with a cosine
    learning-rate schedule (logging every iteration), save the final model
    and optimizer state, then estimate the relative L2 error against
    ``true_solution`` from 1000 batches of 100000 uniformly sampled points.
    """
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    model = ResNet(m)
    model = model.cuda()
    cudnn.benchmark = True

    with open(args.checkpoint + "/Config.txt", 'w+') as f:
        for (k, v) in args._get_kwargs():
            f.write(k + ' : ' + str(v) + '\n')

    print('Total params: %.2f' % (sum(p.numel() for p in model.parameters())))

    # Define the loss criterion and the optimizer.
    criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.wd)

    title = ''
    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    try:
        logger.set_names(['Learning Rate', 'Losses', 'pdeloss', 'bdloss'])

        # Training loop.
        for step in range(0, args.iters):
            lr = cosine_lr(optimizer, args.lr, step, args.iters)
            losses, pdeloss, bdloss = train(model, criterion, optimizer, use_cuda, step, lr)
            logger.append([lr, losses, pdeloss, bdloss])

        # Save the final model and optimizer state.
        save_checkpoint({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, checkpoint=args.checkpoint)

        numerators = []
        denominators = []

        # Evaluation needs no gradients; no_grad avoids building an autograd
        # graph for each of the 100000-point batches.
        with torch.no_grad():
            for i in range(1000):
                print(i)
                x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left
                reference = true_solution(x)
                sq_de = torch.mean(reference ** 2)
                sq_nu = torch.mean((reference - model(x)) ** 2)
                numerators.append(sq_nu.item())
                denominators.append(sq_de.item())

        relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators))
        print('relative l2 error: ', relative_l2)
        logger.append(['relative_l2', relative_l2, 0, 0])
    finally:
        # Close the log file even when training or evaluation fails.
        logger.close()
289 |
290 |
def train(model, criterion, optimizer, use_cuda, iter, lr):
    """Run one optimisation step and return ``(loss, pde_loss, boundary_loss)``.

    The loss is the PDE residual MSE on freshly sampled interior points plus
    ``args.weight`` times the MSE between the network and ``true_solution``
    on sampled boundary points. ``criterion`` and ``use_cuda`` are accepted
    but not used; ``iter`` and ``lr`` only appear in the printed progress line.
    """
    # switch to train mode
    model.train()
    tic = time.time()

    # Interior points: uniform samples rescaled from [0, 1) to [left, right).
    x = (torch.rand(args.trainbs, args.dim).cuda())*(args.right-args.left)+args.left
    x.requires_grad = True

    # Boundary term.
    boundary_pts = get_boundary(args.bdbs, args.dim)
    boundary_target = true_solution(boundary_pts)
    boundary_pred = model(boundary_pts)
    bd_error = torch.nn.functional.mse_loss(boundary_target, boundary_pred)

    # PDE residual term.
    residual_lhs = LHS_pde(model(x), x, args.dim)
    function_error = torch.nn.functional.mse_loss(residual_lhs, RHS_pde(x))
    loss = function_error + args.weight * bd_error

    # compute gradient and take one optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # measure elapsed time and report
    batch_time = time.time() - tic
    loss_value = loss.item()
    pde_value = function_error.item()
    bd_value = bd_error.item()
    suffix = '{iter:.1f} {lr:.8f}| Batch: {bt:.3f}s | Loss: {loss:.8f} | pdeloss: {pdeloss:.8f} | bd loss {bd: .8f} |'.format(
        bt=batch_time, loss=loss_value, iter=iter, lr=lr, pdeloss=pde_value, bd=bd_value)
    print(suffix)
    return loss_value, pde_value, bd_value
321 |
def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialise ``state`` with ``torch.save`` to ``<checkpoint>/<filename>``."""
    torch.save(state, os.path.join(checkpoint, filename))
325 |
326 |
def cosine_lr(opt, base_lr, e, epochs):
    """Set and return the cosine-annealed learning rate for step ``e``.

    The rate is ``0.5 * base_lr * (cos(pi * e / epochs) + 1)`` and is written
    into every parameter group of ``opt``.
    """
    phase = math.pi * e / epochs
    lr = 0.5 * base_lr * (math.cos(phase) + 1)
    for group in opt.param_groups:
        group["lr"] = lr
    return lr
332 |
333 | if __name__ == '__main__':
334 | main()
335 |
--------------------------------------------------------------------------------
/nn/Poisson/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
8 | # progress bar
9 | import os, sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
11 | # from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/nn/Poisson/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor of shape ``(batch, classes)``.
        target: tensor of shape ``(batch,)`` with the true class indices.
        topk: iterable of ``k`` values.

    Returns:
        List with one tensor per ``k``: the percentage of samples whose true
        class is among the ``k`` highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor and may be
        # non-contiguous, in which case ``view(-1)`` raises; reshape copies
        # only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
--------------------------------------------------------------------------------
/nn/Poisson/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname`` (150 dpi by default)."""
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
14 |
def plot_overlap(logger, names=None):
    """Plot the selected series of ``logger`` on the current axes.

    Returns the legend labels, one ``title(name)`` string per plotted series.
    """
    if names is None:
        names = logger.names
    series = logger.numbers
    for name in names:
        steps = np.arange(len(series[name]))
        plt.plot(steps, np.asarray(series[name]))
    return [logger.title + '(' + name + ')' for name in names]
22 |
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: one header row with the column
    names followed by one row per ``append`` call. With ``resume=True`` the
    existing file is parsed back into ``names``/``numbers`` and then opened
    for appending.
    '''
    def __init__(self, fpath, title=None, resume=False):
        """Open (or resume) the log at ``fpath``; ``fpath=None`` opens no file."""
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                with open(fpath, 'r') as existing:
                    header = existing.readline()
                    self.names = header.rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for line in existing:
                        fields = line.rstrip().split('\t')
                        for name, field in zip(self.names, fields):
                            self.numbers[name].append(self._parse(field))
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    @staticmethod
    def _parse(field):
        """Convert a logged field back to float; keep non-numeric text as is."""
        try:
            return float(field)
        except ValueError:
            return field

    def set_names(self, names):
        """Define the column names and write the header row.

        On a resumed logger the names were already read from the file, so
        nothing is rewritten.
        """
        if self.resume:
            return
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            if self.file is not None:
                self.file.write(name)
                self.file.write('\t')
            self.numbers[name] = []
        if self.file is not None:
            self.file.write('\n')
            self.file.flush()

    def append(self, numbers):
        """Append one row; ints and strings are written verbatim, other values with 8 decimals."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            if self.file is not None:
                if isinstance(num, (int, str)):
                    self.file.write("{}".format(num))
                else:
                    self.file.write("{0:.8f}".format(num))
                self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        if self.file is not None:
            self.file.write('\n')
            self.file.flush()

    def plot(self, names=None):
        """Plot the selected (default: all) series against their row index."""
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying file if one is open."""
        if self.file is not None:
            self.file.close()
86 |
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__ (self, paths):
        '''paths is a dictionary with {name:filepath} pair'''
        self.loggers = [
            Logger(path, title=title, resume=True)
            for title, path in paths.items()
        ]

    def plot(self, names=None):
        '''Overlay the selected series of every loaded log in one subplot.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
104 |
105 | if __name__ == '__main__':
106 | # # Example
107 | # logger = Logger('test.txt')
108 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
109 |
110 | # length = 100
111 | # t = np.arange(length)
112 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
113 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
114 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
115 |
116 | # for i in range(0, length):
117 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
118 | # logger.plot()
119 |
120 | # Example: logger monitor
121 | paths = {
122 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
123 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
124 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
125 | }
126 |
127 | field = ['Valid Acc.']
128 |
129 | monitor = LoggerMonitor(paths)
130 | monitor.plot(names=field)
131 | savefig('test.eps')
--------------------------------------------------------------------------------
/nn/Poisson/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - init_params: net parameter initialization.
4 | - mkdir_p / AverageMeter: directory creation and running-average tracking.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
def get_mean_and_std(dataset):
    '''Compute the per-channel mean and std of a 3-channel image dataset.

    Every sample is loaded on its own (batch size 1); the mean and std of
    each of the first three channels are accumulated and finally divided
    by ``len(dataset)``, i.e. the result is the average of the per-sample
    statistics.

    Args:
        dataset: indexable dataset yielding ``(inputs, targets)`` pairs.

    Returns:
        Tuple ``(mean, std)`` of two tensors with three entries each.
    '''
    # Only ``torch.nn`` is imported at module level, which does not bind the
    # name ``torch``; import it here so the function is usable.
    import torch.utils.data

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
33 |
def init_params(net):
    '''Initialise the Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    * Conv2d: Kaiming-normal weights (``fan_out``), zero bias.
    * BatchNorm2d: weight 1, bias 0.
    * Linear: normal weights with std 1e-3, zero bias.

    Layers created without a bias (or, for batch norm, without affine
    parameters) are handled by skipping the missing tensors.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # ``if m.bias:`` would try to evaluate a multi-element tensor as a
            # bool and raise; the intended check is for presence.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
48 |
def mkdir_p(path):
    '''Create ``path`` (including parents); an existing directory is not an error.'''
    try:
        os.makedirs(path)
    except OSError as err:
        # Only "already exists, and it is a directory" is tolerated.
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise
58 |
class AverageMeter(object):
    """Track the latest value and the running (weighted) average of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the current value, average, sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/nn/Poisson/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
    """Un-normalise the first three channels of ``img`` and return an HWC array.

    ``img`` is modified in place (``channel * std + mean``); the returned
    numpy array has its axes permuted from (C, H, W) to (H, W, C).
    """
    for channel in range(3):
        img[channel] = img[channel] * std[channel] + mean[channel]  # unnormalize
    as_array = img.numpy()
    return as_array.transpose((1, 2, 0))
17 |
def gauss(x,a,b,c):
    """Evaluate ``a * exp(-(x - b)^2 / (2 c^2))`` element-wise on tensor ``x``."""
    shifted = torch.add(x, -b)
    exponent = torch.pow(shifted, 2).div(2*c*c)
    return torch.exp(-exponent).mul(a)
20 |
def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image

    Accepts a 2-D (H, W), 3-D (1, H, W) or 4-D (N, 1, H, W) tensor and
    returns a (3, H, W) tensor for the first two cases or an (N, 3, H, W)
    tensor for the batched case. Each colour channel is a Gaussian bump
    (see ``gauss``) of the input intensity.

    Raises:
        ValueError: if ``x`` has a rank other than 2, 3 or 4.
    '''
    if x.dim() == 2:
        # torch.unsqueeze has no ``out=`` argument; rebind the local instead.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[1] = gauss(x,1,.5,.3)
        cl[2] = gauss(x,1,.2,.3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        # Drop the channel axis so that (N, H, W) values are assigned into the
        # (N, H, W) colour planes; assigning (N, 1, H, W) fails for N > 1.
        gray = x[:, 0, :, :]
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:,0,:,:] = gauss(gray,.5,.6,.2) + gauss(gray,1,.8,.3)
        cl[:,1,:,:] = gauss(gray,1,.5,.3)
        cl[:,2,:,:] = gauss(gray,1,.2,.3)
        # NOTE(review): unlike the 3-D branch, values above 1 are not clamped
        # here -- kept as in the original; confirm whether that is intended.
    else:
        raise ValueError('colorize expects a 2-D, 3-D or 4-D tensor, got %d-D' % x.dim())
    return cl
37 |
def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Tile ``images`` into one grid, un-normalise it and display it."""
    grid = torchvision.utils.make_grid(images)
    plt.imshow(make_image(grid, Mean, Std))
    plt.show()
42 |
43 |
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch (top) and the same batch blended with ``mask`` (bottom).

    Args:
        images: image batch; ``images.size(2)`` is taken as the image size.
        mask: mask batch; ``mask.size(2)`` is taken as the mask size. It is
            resized to the image size and expanded to the shape of ``images``.
        Mean, Std: per-channel statistics used to un-normalise the images.
    """
    im_size = images.size(2)

    # Un-normalised copy of the images, used for blending with the mask.
    im_data = images.clone()
    for ch in range(0, 3):
        im_data[:,ch,:,:] = im_data[:,ch,:,:] * Std[ch] + Mean[ch]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    # The original called an ``upsampling`` helper that is not defined or
    # imported in this module; nn.functional.interpolate is used in its place.
    # NOTE(review): default (nearest) interpolation is assumed -- confirm.
    mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')
72 |
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image batch followed by one blended row per mask in ``masklist``.

    Args:
        images: image batch; ``images.size(2)`` is taken as the image size.
        masklist: sequence of mask batches; each entry is read through
            ``.data.cpu()``, resized to the image size and blended with the
            un-normalised images.
        Mean, Std: per-channel statistics used to un-normalise the images.
    """
    im_size = images.size(2)

    # Un-normalised copy of the images, used for blending with the masks.
    im_data = images.clone()
    for ch in range(0, 3):
        im_data[:,ch,:,:] = im_data[:,ch,:,:] * Std[ch] + Mean[ch]  # unnormalize

    n_rows = 1 + len(masklist)
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(n_rows, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for row, entry in enumerate(masklist):
        mask = entry.data.cpu()
        mask_size = mask.size(2)
        # The original called an ``upsampling`` helper that is not defined or
        # imported in this module; nn.functional.interpolate is used in its place.
        # NOTE(review): default (nearest) interpolation is assumed -- confirm.
        mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)

        mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
        plt.subplot(n_rows, 1, row+2)
        plt.imshow(mask)
        plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------
/nn/Schrodinger/scripts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 |
5 |
# Launch one `train_integral.py` run per (trial, dimension) pair inside a
# `screen` session, on a randomly chosen GPU id, pausing 120 s after each launch.
for trial in range(6):
    for dim in (6, 12, 18, 24, 30):
        gpu = random.choice([3,4,7])
        command = ''.join([
            'screen python train_integral.py --gpu-id ', str(gpu),
            ' --trainbs 2000 --intbs 10000 --optim adam --lr 0.001 --iters 15000 --dim ', str(dim),
            ' --checkpoint ckpts_ReLU2_intbs/Dim', str(dim), '_trial', str(trial),
            ' --weight 1 --left -1 --right 1',
        ])
        os.system(command)
        time.sleep(120)
--------------------------------------------------------------------------------
/nn/Schrodinger/train_integral.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import time
6 | import random
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim as optim
13 | import math
14 | import numpy as np
15 | from torch import sin, cos, exp
16 | import torch.nn.functional as F
17 | from utils import Logger, AverageMeter, mkdir_p
18 | from numpy.polynomial.legendre import leggauss
19 | # torch.set_default_tensor_type(torch.DoubleTensor)
20 |
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='PyTorch Density Function Training')
# Datasets
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--iters', default=30000, type=int, metavar='N', help='number of total iterations to run')
parser.add_argument('--dim', default=2, type=int)
parser.add_argument('--trainbs', default=2000, type=int, metavar='N', help='train batchsize')
parser.add_argument('--intbs', default=10000, type=int)
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', default=0, type=float, metavar='W', help='weight decay')
parser.add_argument('--function', default='resnet', type=str, help='function to approximate')
# 'SGD' selects SGD with momentum; any other value selects Adam (see main()).
parser.add_argument('--optim', default='adam', type=str, help="optimizer: 'SGD' or adam (default: adam)")
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint')

# Defaulted to 1.0 so that the loss `function_error + args.weight * integration`
# is well defined even when the flag is omitted (it used to be None).
parser.add_argument('--weight', type=float, default=1.0, help='weight of the integral term in the loss (default: 1.0)')

# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='path to save checkpoint (default: checkpoint)')

# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')

#Device options
parser.add_argument('--gpu-id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')

# region
parser.add_argument('--left', default=-3, type=float, help='left boundary of square region')
parser.add_argument('--right', default=3, type=float, help='right boundary of square region')

args = parser.parse_args()
state = dict(args._get_kwargs())

print(state)

# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed: draw one when none was given, then seed Python and torch.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
c = 3  # constant dividing the reference solution; also enters the potential V


def LHS_pde(u, x, dim_set):
    """Evaluate the PDE residual ``-laplace(u) + u**3 + V(x) * u``.

    Args:
        u: network output evaluated at ``x``; it is differentiated with
           respect to ``x`` and is used with one column per sample.
        x: sample points with ``requires_grad`` enabled, one row per sample
           and ``dim_set`` columns.
        dim_set: number of spatial dimensions to differentiate over.

    Returns:
        Tensor with one residual value per sample (column vector), still
        attached to the autograd graph so it can be used in a loss.

    Helper tensors are created on the device/dtype of the inputs rather than
    unconditionally on CUDA, so the function also works on CPU tensors.
    """
    ones = torch.ones_like(u)
    bs = x.size(0)
    ux = torch.autograd.grad(u, x, grad_outputs=ones, create_graph=True)[0]
    uxx = torch.zeros(bs, dim_set, dtype=ux.dtype, device=ux.device)
    for i in range(dim_set):
        # Differentiate d u / d x_i once more and keep only the i-th component,
        # i.e. the pure second derivative d^2 u / d x_i^2.
        ux_i = ux[:, i].reshape(bs, 1)
        uxx_i = torch.autograd.grad(ux_i, x, grad_outputs=ones, create_graph=True)[0]
        uxx[:, i] = uxx_i[:, i]

    neg_laplacian = -torch.sum(uxx, dim=1, keepdim=True)
    cos_x = cos(x)
    V = (-exp(2 / dim_set * torch.sum(cos_x, dim=1, keepdim=True)) / (c ** 2)
         + torch.sum(sin(x) ** 2 / (dim_set ** 2), dim=1, keepdim=True)
         - torch.sum(cos_x / dim_set, dim=1, keepdim=True))
    return neg_laplacian + u ** 3 + V * u
87 |
def RHS_pde(x):
    """Return the right-hand side of the PDE: a zero column, one row per sample.

    The tensor is allocated on the same device and with the same dtype as
    ``x`` instead of being forced onto CUDA.
    """
    bs = x.size(0)
    return torch.zeros(bs, 1, dtype=x.dtype, device=x.device)
91 |
def true_solution(x):
    """Reference solution ``exp(sum_i cos(x_i) / dim) / c``, one value per row of ``x``."""
    n_dims = x.size(1)
    exponent = 1 / n_dims * torch.sum(cos(x), dim=1, keepdim=True)
    return exp(exponent) / c
95 |
# Monte Carlo estimate of the mean of `true_solution` over the box
# [left, right]^dim: 500 batches of 100000 uniform samples, averaged.
# The result is the target of the integral term in the training loss.
integral_value = []

# No gradients are needed for this estimate, so skip autograd bookkeeping.
with torch.no_grad():
    for i in range(500):
        print(i)
        x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left
        integral_value.append(torch.mean(true_solution(x)).item())

integral_true = sum(integral_value) / len(integral_value)
106 |
class ResNet(nn.Module):
    """Fully connected residual network mapping ``args.dim`` inputs to one output.

    Three residual blocks of two ``Linear`` layers each (width ``m``) are
    followed by a linear output layer.  The first skip connection uses the
    rescaled input zero-padded up to the hidden width.

    Args:
        m: hidden width of every layer.
    """

    def __init__(self, m):
        super(ResNet, self).__init__()
        self.fc1 = nn.Linear(args.dim, m)
        self.fc2 = nn.Linear(m, m)

        self.fc3 = nn.Linear(m, m)
        self.fc4 = nn.Linear(m, m)

        self.fc5 = nn.Linear(m, m)
        self.fc6 = nn.Linear(m, m)

        self.outlayer = nn.Linear(m, 1, bias=True)

        # Zero biases and N(0, 0.1) weights for every linear layer.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.normal_(0.0, 0.1)

    def forward(self, x):
        # Map the sampling box [left, right] affinely onto [-1, 1].
        x1 = (x - args.left) * 2 / (args.right - args.left) + (-1)
        # Pad with zeros up to this instance's hidden width (read from fc1
        # rather than from a module-level global) so it can be added to the
        # first block's output.
        width = self.fc1.out_features
        s = torch.nn.functional.pad(x1, (0, width - args.dim))

        # NOTE(review): relu(y ** 2) equals y ** 2 because a square is never
        # negative; if relu(y) ** 2 was intended, confirm before changing it.
        y = self.fc1(x1)
        y = F.relu(y ** 2)
        y = self.fc2(y)
        y = F.relu(y ** 2)
        y = y + s

        s = y
        y = self.fc3(y)
        y = F.relu(y ** 2)
        y = self.fc4(y)
        y = F.relu(y ** 2)
        y = y + s

        s = y
        y = self.fc5(y)
        y = F.relu(y ** 2)
        y = self.fc6(y)
        y = F.relu(y ** 2)
        y = y + s

        output = self.outlayer(y)
        return output
159 |
# Hyper-parameter settings for the network.
m = 50  # hidden width passed to ResNet

# Row vector of length m on the GPU whose first entry is 1 and the rest 0.
# NOTE(review): `Ix` is not referenced in the visible part of this file.
Ix = torch.zeros(1, m).cuda()
Ix[0, 0] = 1
168 |
169 |
def main():
    """Train the network, save a checkpoint, and report the relative L2 error.

    Steps: create the checkpoint directory, build the model and optimizer,
    dump the parsed arguments to ``Config.txt``, run ``args.iters`` training
    iterations with a cosine learning-rate schedule while logging each one,
    save the model/optimizer state, then estimate the relative L2 error
    against ``true_solution`` from 1000 batches of 100000 uniform samples.

    ``--resume`` is parsed but not acted upon here.
    """
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    model = ResNet(m)

    model = model.cuda()
    cudnn.benchmark = True

    with open(args.checkpoint + "/Config.txt", 'w+') as f:
        for (k, v) in args._get_kwargs():
            f.write(k + ' : ' + str(v) + '\n')

    print('Total params: %.2f' % (sum(p.numel() for p in model.parameters())))

    # Residual criterion and optimizer.
    criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.wd)

    title = ''

    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    logger.set_names(['Learning Rate', 'Losses', 'pdeloss', 'intloss'])

    # Training loop.
    for it in range(0, args.iters):

        lr = cosine_lr(optimizer, args.lr, it, args.iters)

        losses, pdeloss, intloss = train(model, criterion, optimizer, use_cuda, it, lr)
        logger.append([lr, losses, pdeloss, intloss])

    # save model
    save_checkpoint({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, checkpoint=args.checkpoint)

    numerators = []
    denominators = []

    # Evaluation only: no gradients needed, and the exact solution is
    # evaluated once per batch instead of twice.
    with torch.no_grad():
        for i in range(1000):
            print(i)
            x = (torch.rand(100000, args.dim).cuda()) * (args.right - args.left) + args.left
            exact = true_solution(x)
            sq_de = torch.mean(exact ** 2)
            sq_nu = torch.mean((exact - model(x)) ** 2)
            numerators.append(sq_nu.item())
            denominators.append(sq_de.item())

    relative_l2 = math.sqrt(sum(numerators)) / math.sqrt(sum(denominators))
    print('relative l2 error: ', relative_l2)
    logger.append(['relative_l2', relative_l2, 0, 0])

    logger.close()
253 |
254 |
def train(model, criterion, optimizer, use_cuda, iter, lr):
    """Run one optimisation step and return ``(loss, pde_loss, integral_loss)``.

    A batch of ``args.trainbs`` points feeds the PDE residual and a second
    batch of ``args.intbs`` points feeds the integral constraint; both are
    drawn uniformly from the box ``[args.left, args.right]``.  ``criterion``
    and ``use_cuda`` are accepted but not used by this function; ``iter`` and
    ``lr`` are only printed.
    """
    model.train()
    started = time.time()

    # Sampling: map uniform [0, 1] draws onto [left, right].
    span = args.right - args.left
    x = torch.rand(args.trainbs, args.dim).cuda() * span + args.left
    x.requires_grad_(True)

    x_int = torch.rand(args.intbs, args.dim).cuda() * span + args.left
    x_int.requires_grad_(True)

    # Squared gap between the network's sample mean and the reference value.
    integral = torch.mean(model(x_int))
    integration = (integral - integral_true) ** 2
    # Mean squared PDE residual on the training batch.
    residual = LHS_pde(model(x), x, args.dim)
    function_error = F.mse_loss(residual, RHS_pde(x))
    loss = function_error + args.weight * integration

    # Gradient step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    elapsed = time.time() - started
    report = '{iter:.1f} {lr:.8f}| Batch: {bt:.3f}s | Loss: {loss:.8f} | pdeloss: {pdeloss:.8f} | Integral loss {integral: .8f} |'.format(
        bt=elapsed, loss=loss.item(), iter=iter, lr=lr, pdeloss=function_error.item(), integral=integration.item())
    print(report)
    return loss.item(), function_error.item(), integration.item()
287 |
def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialise ``state`` with ``torch.save`` to ``<checkpoint>/<filename>``."""
    torch.save(state, os.path.join(checkpoint, filename))
291 |
292 |
def cosine_lr(opt, base_lr, e, epochs):
    """Set and return the cosine-annealed learning rate for step ``e`` of ``epochs``.

    The rate goes from ``base_lr`` at ``e == 0`` down to 0 at ``e == epochs``
    and is written into every parameter group of ``opt``.
    """
    cosine = math.cos(math.pi * e / epochs)
    lr = 0.5 * base_lr * (cosine + 1)
    for group in opt.param_groups:
        group["lr"] = lr
    return lr
298 |
# Run the training pipeline only when the file is executed as a script.
if __name__ == '__main__':
    main()
301 |
--------------------------------------------------------------------------------
/nn/Schrodinger/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
8 | # progress bar
9 | import os, sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
11 | # from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/nn/Schrodinger/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k (in percent) for each k in ``topk``.

    Args:
        output: score tensor with one row per sample and one column per class.
        target: tensor of true class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        List of 0-dim tensors, one per entry of ``topk``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed layout of `pred`, so its first k
        # rows are not contiguous; reshape (unlike view) copes with that.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
--------------------------------------------------------------------------------
/nn/Schrodinger/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname`` (150 dpi unless given)."""
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
14 |
def plot_overlap(logger, names=None):
    """Plot the selected series of ``logger`` on the current axes.

    Returns the legend labels, one ``'<title>(<name>)'`` string per series.
    ``names`` defaults to every series the logger knows about.
    """
    selected = logger.names if names is None else names
    series = logger.numbers
    for name in selected:
        steps = np.arange(len(series[name]))
        plt.plot(steps, np.asarray(series[name]))
    return [logger.title + '(' + name + ')' for name in selected]
22 |
class Logger(object):
    '''Save training process to a tab-separated log file with a simple plot function.

    With ``resume=True`` the existing file is read back (header row, then one
    row per record) and reopened for appending; otherwise it is truncated.
    With ``fpath=None`` values are only kept in memory.
    '''
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Always defined, so append()/plot() also work without a backing file.
        self.names = []
        self.numbers = {}
        if fpath is not None:
            if resume:
                with open(fpath, 'r') as existing:
                    header = existing.readline()
                    self.names = header.rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}

                    for line in existing:
                        fields = line.rstrip().split('\t')
                        for i, field in enumerate(fields):
                            self.numbers[self.names[i]].append(self._parse(field))
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    @staticmethod
    def _parse(field):
        '''Return ``field`` as a float when it is numeric, else unchanged.'''
        try:
            return float(field)
        except ValueError:
            return field

    def set_names(self, names):
        '''Reset the stored series to ``names`` and write the header row.'''
        # NOTE(review): on a resumed logger this still discards the loaded
        # history and appends a second header row -- confirm that is intended.
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.numbers[name] = []
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in self.names) + '\n')
            self.file.flush()

    def append(self, numbers):
        '''Record one row; ints and strings are written verbatim, the rest with 8 decimals.'''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        cells = []
        for index, num in enumerate(numbers):
            if isinstance(num, (int, str)):
                cells.append("{}".format(num))
            else:
                cells.append("{0:.8f}".format(num))
            self.numbers[self.names[index]].append(num)
        if self.file is not None:
            self.file.write(''.join(cell + '\t' for cell in cells) + '\n')
            self.file.flush()

    def plot(self, names=None):
        '''Plot the selected series (all by default) with a legend and grid.'''
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the log file if one is open.'''
        if self.file is not None:
            self.file.close()
86 |
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__ (self, paths):
        '''paths is a dictionary of {name: filepath} pairs; each file is opened as a resumed Logger'''
        self.loggers = [Logger(path, title=title, resume=True)
                        for title, path in paths.items()]

    def plot(self, names=None):
        '''Overlay the selected series of every loaded log in the left half of a new figure.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
104 |
if __name__ == '__main__':
    # Demo entry point. The single-logger example below is disabled; the
    # active code loads several existing logs and plots one field from each.
    #
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    # NOTE(review): these are absolute paths on a specific machine; the demo
    # only works where those log files exist.
    paths = {
    'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
    'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
    'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    # Series to plot from every log.
    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')
--------------------------------------------------------------------------------
/nn/Schrodinger/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std of a dataset of 3-channel images.

    Each sample is loaded on its own (batch size 1); its per-channel mean and
    std are accumulated and finally divided by ``len(dataset)``, so the result
    is the average of per-image statistics.

    Args:
        dataset: indexable dataset yielding ``(image, target)`` pairs.
        num_workers: number of loader worker processes (0 loads in-process).

    Returns:
        ``(mean, std)``, two tensors of length 3.
    '''
    # This module only binds `nn` and `init`, never the name `torch` itself,
    # so import it here.
    import torch
    import torch.utils.data

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
33 |
def init_params(net):
    '''Init layer parameters in place.

    Conv2d: Kaiming-normal weights (fan_out), zero bias.
    BatchNorm2d: weight 1, bias 0.
    Linear: normal weights with std 1e-3, zero bias.
    Layers created without a bias (or without affine parameters) are handled.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would evaluate the truth value of a tensor, which
            # raises for more than one element; test for presence instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
48 |
def mkdir_p(path):
    '''Create ``path`` (with parents); an already existing directory is not an error.'''
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise
58 |
class AverageMeter(object):
    """Track the latest value together with a running sum, count and average.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/nn/Schrodinger/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
    """Un-normalise the first three channels of ``img`` in place and return it
    as a NumPy array with the channel axis moved last."""
    for ch in range(3):
        img[ch] = mean[ch] + std[ch] * img[ch]    # undo (x - mean) / std
    as_array = img.numpy()
    return np.transpose(as_array, (1, 2, 0))
17 |
def gauss(x,a,b,c):
    """Gaussian bump ``a * exp(-(x - b)^2 / (2 c^2))`` evaluated element-wise on ``x``."""
    shifted = torch.add(x, -b)
    exponent = -torch.pow(shifted, 2) / (2*c*c)
    return torch.exp(exponent) * a
20 |
def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image.

    Accepts a 2-D image, a 3-D image (returned with 3 channels first), or a
    4-D batch (returned with 3 channels in dimension 1).  Each colour channel
    is a sum of Gaussian bumps of the input intensity.

    Raises:
        ValueError: if ``x`` has fewer than 2 or more than 4 dimensions.
    '''
    if x.dim() == 2:
        # torch.unsqueeze takes no `out=` argument; build the 3-D view directly.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
        cl[1] = gauss(x,1,.5,.3)
        cl[2] = gauss(x,1,.2,.3)
        cl[cl.gt(1)] = 1
    elif x.dim() == 4:
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        # Drop a singleton channel axis so the values fit one colour plane
        # for any batch size, not only a batch of one.
        g = x.squeeze(1) if x.size(1) == 1 else x
        cl[:,0,:,:] = gauss(g,.5,.6,.2) + gauss(g,1,.8,.3)
        cl[:,1,:,:] = gauss(g,1,.5,.3)
        cl[:,2,:,:] = gauss(g,1,.2,.3)
        # NOTE(review): unlike the 3-D branch, values above 1 are not clipped
        # here -- confirm whether that difference is intended.
    else:
        raise ValueError('colorize expects a 2-, 3- or 4-dimensional tensor, got %d dimensions' % x.dim())
    return cl
37 |
def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Arrange ``images`` into a grid, un-normalise it with ``Mean``/``Std`` and display it."""
    grid = torchvision.utils.make_grid(images)
    picture = make_image(grid, Mean, Std)
    plt.imshow(picture)
    plt.show()
42 |
43 |
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image grid and, below it, the same images blended with ``mask``.

    The top subplot shows ``images`` passed through ``make_image`` with
    ``Mean``/``Std``.  The bottom subplot shows ``0.3 * image + 0.7 * mask``
    after ``mask`` has been rescaled by the ratio of the image size to the
    mask size (both read from dimension 2).
    """
    im_size = images.size(2)

    # Un-normalised copy of the batch, used only for blending with the mask.
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    mask_size = mask.size(2)
    # The former `upsampling` helper is not defined anywhere in this module;
    # resize with torch's interpolate instead (default mode: nearest).
    mask = nn.functional.interpolate(mask, scale_factor=im_size / mask_size)

    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')
72 |
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
    """Plot an image grid and, below it, one blended overlay grid per mask.

    The first subplot row shows ``images`` passed through ``make_image`` with
    ``Mean``/``Std``.  Each following row shows ``0.3 * image + 0.7 * mask``
    for one entry of ``masklist``, after the mask has been moved to the CPU
    and rescaled by the ratio of the image size to the mask size (both read
    from dimension 2).
    """
    im_size = images.size(2)
    n_rows = 1 + len(masklist)

    # Un-normalised copy of the batch, used only for blending with the masks.
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(n_rows, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for row, raw_mask in enumerate(masklist, start=2):
        mask = raw_mask.data.cpu()
        mask_size = mask.size(2)
        # The former `upsampling` helper is not defined anywhere in this
        # module; resize with torch's interpolate instead (default: nearest).
        mask = nn.functional.interpolate(mask, scale_factor=im_size / mask_size)

        mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
        plt.subplot(n_rows, 1, row)
        plt.imshow(mask)
        plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------