├── README.md ├── aJittor.sh ├── checkpoint └── .gitignore ├── config.py ├── data.py ├── data └── scape │ ├── cot_nb.mat │ ├── cotweight.mat │ ├── feature.mat │ ├── geodesic.mat │ ├── meshinfo.mat │ ├── neighbor.mat │ └── weight.mat ├── jittor_geometric ├── __init__.py ├── data │ ├── __init__.py │ ├── data.py │ ├── dataset.py │ ├── download.py │ ├── in_memory_dataset.py │ └── makedirs.py ├── datasets │ ├── __init__.py │ └── planetoid.py ├── io │ ├── __init__.py │ ├── planetoid.py │ └── txt_array.py ├── nn │ ├── __init__.py │ ├── conv │ │ ├── __init__.py │ │ ├── cheb_conv.py │ │ ├── gcn2_conv.py │ │ ├── gcn_conv.py │ │ ├── message_passing.py │ │ ├── sg_conv.py │ │ └── utils │ │ │ ├── __pycache__ │ │ │ ├── inspector.cpython-37.pyc │ │ │ └── typing.cpython-37.pyc │ │ │ ├── inspector.py │ │ │ └── typing.py │ └── inits.py ├── transforms │ ├── __init__.py │ └── normalize_features.py ├── typing.py └── utils │ ├── __init__.py │ ├── degree.py │ ├── get_laplacian.py │ ├── isolated.py │ ├── loop.py │ └── num_nodes.py ├── mainJittor.py ├── network.py ├── networkJittor.py ├── result └── .gitignore ├── utils.py └── utilsJittor.py /README.md: -------------------------------------------------------------------------------- 1 | # MeshVAE_neural_editing 2 | 3 | Run ./aJittor.sh to train the network. -------------------------------------------------------------------------------- /aJittor.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 mainJittor.py \ 2 | --dataset=scape \ 3 | --logfolder=scape \ 4 | --mode=train \ 5 | --net_type=VAE \ 6 | --lambda0=1 \ 7 | --lambda1=5000 \ 8 | --lambda2=1 \ 9 | --lambda3=0.001 \ 10 | --lambda4=1 \ 11 | --ac_type=tanh 12 | -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/checkpoint/.gitignore -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | class Config(): 5 | 6 | start_ind={'face':346,'swing':135,'jump':135,'scape':36,'humanoida':138,'swingalign':135,'jumpalign':135,'horse':44,'fat2':400, 'fat': 846} 7 | end_ind={'face':385,'swing':150,'jump':150,'scape':72,'humanoida':154,'swingalign':150,'jumpalign':150,'horse':49, 'fat2':469, 'fat': 941} 8 | 9 | th={'face':10,'swing':40,'jump':100,'scape':15,'humanoid':1.5,'swingalign':50,'jumpalign':70,'horse':50, 'fat2':100, 'fat': 10} 10 | 11 | def __init__(self, FLAGS): 12 | self.latent = FLAGS.latent 13 | self.finaldim = FLAGS.finaldim 14 | self.epoch = FLAGS.epoch 15 | self.weight_type = FLAGS.weight_type 16 | self.layer_num = FLAGS.layer_num 17 | self.mode = FLAGS.mode 18 | self.cp_name = FLAGS.cp_name 19 | # self.test_file = FLAGS.test_file 20 | self.d_type = FLAGS.d_type 21 | self.lambda0 = FLAGS.lambda0 22 | self.lambda1 = FLAGS.lambda1 23 | self.lambda2 = FLAGS.lambda2 24 | self.lambda3 = FLAGS.lambda3 25 | self.lambda4 = FLAGS.lambda4 26 | self.logfolder = FLAGS.logfolder 27 | self.net_type = FLAGS.net_type 28 | self.seed = FLAGS.seed 29 | self.stddev = FLAGS.std 30 | dataname=FLAGS.dataset 31 | 32 | self.dataname = dataname 33 | self.featurefile='./data/'+dataname+'/feature.mat' 34 | self.neighbourfile='./data/'+dataname+'/neighbor.mat' 35 | 
self.distancefile='./data/'+dataname+'/geodesic.mat' 36 | self.start_idx=self.start_ind[dataname] 37 | self.end_idx=self.end_ind[dataname] 38 | if self.th.__contains__(dataname): 39 | self.threshold=self.th[dataname] 40 | else: 41 | self.threshold = FLAGS.th 42 | self.K = FLAGS.K 43 | self.conv_type = FLAGS.conv_type 44 | self.ac_type=FLAGS.ac_type 45 | 46 | if self.weight_type=='normal': 47 | self.weightfile='./data/'+dataname+'/weight.mat' 48 | else: 49 | self.weightfile='./data/'+dataname+'/cotweight.mat' 50 | 51 | if not os.path.isdir('./checkpoint/'+self.logfolder) and self.mode=='train': 52 | os.mkdir('./checkpoint/'+self.logfolder) 53 | self.meshinfofile = './data/'+dataname+'/meshinfo.mat' 54 | 55 | syn_list = FLAGS.syn_list 56 | self.comp_idx = syn_list[0:int(len(syn_list)/3)] 57 | self.max_min = syn_list[int(len(syn_list)/3):int(len(syn_list)/3*2)] 58 | self.comp_weight = syn_list[int(len(syn_list)/3*2):] -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import scipy.io as sio 4 | import numpy as np 5 | import os 6 | import scipy 7 | 8 | 9 | 10 | 11 | class Data(): 12 | def __init__(self, config): 13 | 14 | self.feature, self.logrmin, self.logrmax, self.smin, self.smax, self.pointnum = self.load_data(config.featurefile) 15 | if os.path.exists('./data/idx/'+config.dataname+'.dat'): 16 | split_idx = pickle.load(open('./data/idx/'+config.dataname+'.dat','rb')) 17 | else: 18 | split_idx = range(len(self.feature)) 19 | 20 | start_idx = config.start_idx 21 | end_idx = config.end_idx 22 | 23 | 24 | train_data = self.feature 25 | valid_data = self.feature 26 | # if config.dataname == 'scape': 27 | # valid_data = [self.feature[i] for i in split_idx[start_idx:end_idx]] 28 | # train_data = [self.feature[i] for i in split_idx[0:start_idx]] 29 | # else: 30 | # train_data = [self.feature[i] for i in split_idx[start_idx:end_idx]] 31 | # valid_data = [self.feature[i] for i in split_idx[0:start_idx]] 32 | 33 | self.train_data = np.asarray(train_data,dtype='float32') 34 | self.valid_data = np.asarray(valid_data,dtype='float32') 35 | 36 | self.neighbour, self.degrees, self.maxdegree = self.load_neighbour(config.neighbourfile, self.pointnum) 37 | self.geodesic_weight = self.load_geodesic_weight(config.distancefile, self.pointnum) 38 | weight=self.load_weight(config.weightfile, self.pointnum) 39 | weight=weight.astype('float32') 40 | self.weight=scipy.sparse.csr_matrix(weight) 41 | # print(config.meshinfofile) 42 | # cc() 43 | self.load_meshinfo(config.meshinfofile) 44 | 45 | def load_meshinfo(self, path): 46 | meshinfo = sio.loadmat(path) 47 | self.face = meshinfo['f'].T 48 | self.vertices = meshinfo['v'] 49 | self.vdiff = meshinfo['vdiff'] 50 | self.recon = meshinfo['recon'] 51 | 52 | 53 | def load_weight(self, path, pointnum, name='weight'): 54 | data = sio.loadmat(path) 55 | data = data[name] 56 | 57 | weight = np.zeros((pointnum,pointnum)).astype('float32') 58 | weight = data 59 | 60 | return weight 61 | def load_geodesic_weight(self, path, pointnum, name='point_geodesic'): 62 | 63 | 64 | data = sio.loadmat(path) 65 | data = data[name] 66 | 67 | distance = np.zeros((pointnum, pointnum)).astype('float32') 68 | distance = data 69 | 70 | return distance 71 | def load_neighbour(self, path, pointnum, name='neighbour'): 72 | data = sio.loadmat(path) 73 | data = data[name] 74 | maxdegree = data.shape[1] 75 | neighbour = np.zeros((pointnum, 
maxdegree)).astype('float32') 76 | neighbour = data 77 | degree = np.zeros((neighbour.shape[0], 1)).astype('float32') 78 | for i in range(neighbour.shape[0]): 79 | degree[i] = np.count_nonzero(neighbour[i]) 80 | return neighbour,degree,maxdegree 81 | def load_data(self, path): 82 | resultmax = 0.95 83 | resultmin = -0.95 84 | 85 | data = sio.loadmat(path) 86 | logr = data['FLOGRNEW'] 87 | s = data['FS'] 88 | pointnum=logr.shape[1] 89 | logrmin = logr.min() 90 | logrmin = logrmin - 1e-6 91 | logrmax = logr.max() 92 | logrmax = logrmax + 1e-6 93 | smin = s.min() 94 | smin = smin- 1e-6 95 | smax = s.max() 96 | smax = smax + 1e-6 97 | 98 | rnew = (resultmax-resultmin)*(logr-logrmin)/(logrmax - logrmin) + resultmin 99 | snew = (resultmax-resultmin)*(s - smin)/(smax-smin) + resultmin 100 | 101 | feature = np.concatenate((rnew,snew),axis = 2) 102 | 103 | f = np.zeros_like(feature).astype('float32') 104 | f = feature 105 | 106 | return f, logrmin, logrmax, smin, smax,pointnum 107 | def load_test_data(self,path): 108 | resultmax = 0.95 109 | resultmin = -0.95 110 | 111 | data = sio.loadmat(path) 112 | logr = data['FLOGRNEW'] 113 | s = data['FS'] 114 | 115 | 116 | 117 | rnew = (resultmax-resultmin)*(logr-self.logrmin)/(self.logrmax - self.logrmin) + resultmin 118 | snew = (resultmax-resultmin)*(s - self.smin)/(self.smax-self.smin) + resultmin 119 | 120 | feature = np.concatenate((rnew,snew),axis = 2) 121 | 122 | f = np.zeros_like(feature).astype('float32') 123 | f = feature 124 | return f 125 | 126 | 127 | def recover_data(self, recover_feature, logrmin, logrmax, smin, smax, pointnum): 128 | logr = recover_feature[:,:,0:3] 129 | s = recover_feature[:,:,3:9] 130 | 131 | resultmax = 0.95 132 | resultmin = -0.95 133 | 134 | s = (smax - smin) * (s - resultmin) / (resultmax - resultmin) + smin 135 | logr = (logrmax - logrmin) * (logr - resultmin) / (resultmax - resultmin) + logrmin 136 | 137 | return s, logr -------------------------------------------------------------------------------- /data/scape/cot_nb.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/data/scape/cot_nb.mat -------------------------------------------------------------------------------- /data/scape/cotweight.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/data/scape/cotweight.mat -------------------------------------------------------------------------------- /data/scape/feature.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/data/scape/feature.mat -------------------------------------------------------------------------------- /data/scape/geodesic.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/data/scape/geodesic.mat -------------------------------------------------------------------------------- /data/scape/meshinfo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/data/scape/meshinfo.mat 
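A note on the .mat archives under data/scape: they are binary MATLAB files, and data.py expects specific variable names inside each one (FLOGRNEW and FS in feature.mat, neighbour in neighbor.mat, point_geodesic in geodesic.mat, weight in weight.mat and cotweight.mat, and f, v, vdiff, recon in meshinfo.mat). The snippet below is only an illustrative sanity check of those keys, not part of the training pipeline; it assumes scipy is available and the files sit at their default paths.

import scipy.io as sio

# Variable names are taken from the loaders in data.py; adjust the paths if the
# data lives elsewhere. cot_nb.mat is shipped alongside these files but is not
# referenced by config.py or data.py, so it is not checked here.
expected_keys = {
    './data/scape/feature.mat': ['FLOGRNEW', 'FS'],
    './data/scape/neighbor.mat': ['neighbour'],
    './data/scape/geodesic.mat': ['point_geodesic'],
    './data/scape/weight.mat': ['weight'],
    './data/scape/cotweight.mat': ['weight'],
    './data/scape/meshinfo.mat': ['f', 'v', 'vdiff', 'recon'],
}
for path, keys in expected_keys.items():
    mat = sio.loadmat(path)
    missing = [k for k in keys if k not in mat]
    print(path, 'OK' if not missing else 'missing: {}'.format(missing))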
-------------------------------------------------------------------------------- /data/scape/neighbor.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/data/scape/neighbor.mat -------------------------------------------------------------------------------- /data/scape/weight.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/data/scape/weight.mat -------------------------------------------------------------------------------- /jittor_geometric/__init__.py: -------------------------------------------------------------------------------- 1 | from types import ModuleType 2 | from importlib import import_module 3 | 4 | 5 | __version__ = '0.0.1' 6 | 7 | __all__ = [ 8 | 'jittor_geometric', 9 | '__version__', 10 | ] 11 | -------------------------------------------------------------------------------- /jittor_geometric/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import Data 2 | from .dataset import Dataset 3 | from .in_memory_dataset import InMemoryDataset 4 | from .download import download_url 5 | 6 | __all__ = [ 7 | 'Data', 8 | 'Dataset', 9 | 'InMemoryDataset', 10 | 'download_url', 11 | ] 12 | -------------------------------------------------------------------------------- /jittor_geometric/data/data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import logging 4 | import collections 5 | 6 | import jittor as jt 7 | from jittor import Var 8 | from jittor_geometric.utils import (contains_isolated_nodes, 9 | contains_self_loops) 10 | 11 | from ..utils.num_nodes import maybe_num_nodes 12 | 13 | __num_nodes_warn_msg__ = ( 14 | 'The number of nodes in your data object can only be inferred by its {} ' 15 | 'indices, and hence may result in unexpected batch-wise behavior, e.g., ' 16 | 'in case there exists isolated nodes. Please consider explicitly setting ' 17 | 'the number of nodes for this data object by assigning it to ' 18 | 'data.num_nodes.') 19 | 20 | 21 | def size_repr(key, item, indent=0): 22 | indent_str = ' ' * indent 23 | if isinstance(item, Var) and item.ndim == 0: 24 | out = item.item() 25 | elif isinstance(item, Var): 26 | out = str(list(item.size())) 27 | # careful 28 | # elif isinstance(item, SparseTensor): 29 | # out = str(item.sizes())[:-1] + f', nnz={item.nnz()}]' 30 | elif isinstance(item, list) or isinstance(item, tuple): 31 | out = str([len(item)]) 32 | elif isinstance(item, dict): 33 | lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] 34 | out = '{\n' + ',\n'.join(lines) + '\n' + indent_str + '}' 35 | elif isinstance(item, str): 36 | out = f'"{item}"' 37 | else: 38 | out = str(item) 39 | 40 | return f'{indent_str}{key}={out}' 41 | 42 | 43 | class Data(object): 44 | r"""A plain old python object modeling a single graph with various 45 | (optional) attributes: 46 | 47 | Args: 48 | x (Var, optional): Node feature matrix with shape :obj:`[num_nodes, 49 | num_node_features]`. (default: :obj:`None`) 50 | edge_index (Var.int32, optional): Graph connectivity in COO format 51 | with shape :obj:`[2, num_edges]`. (default: :obj:`None`) 52 | edge_attr (Var, optional): Edge feature matrix with shape 53 | :obj:`[num_edges, num_edge_features]`. 
(default: :obj:`None`) 54 | y (Var, optional): Graph or node targets with arbitrary shape. 55 | (default: :obj:`None`) 56 | pos (Var, optional): Node position matrix with shape 57 | :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) 58 | normal (Var, optional): Normal vector matrix with shape 59 | :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) 60 | face (Var.int32, optional): Face adjacency matrix with shape 61 | :obj:`[3, num_faces]`. (default: :obj:`None`) 62 | 63 | The data object is not restricted to these attributes and can be extented 64 | by any other additional data. 65 | 66 | Example:: 67 | 68 | data = Data(x=x, edge_index=edge_index) 69 | data.train_idx = jt.array([...], dtype=Var.int32) 70 | data.test_mask = jt.array([...], dtype=Var.bool) 71 | """ 72 | 73 | def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, 74 | pos=None, normal=None, face=None, **kwargs): 75 | self.x = x 76 | self.edge_index = edge_index 77 | self.edge_attr = edge_attr 78 | self.y = y 79 | self.pos = pos 80 | self.normal = normal 81 | self.face = face 82 | for key, item in kwargs.items(): 83 | if key == 'num_nodes': 84 | self.__num_nodes__ = item 85 | else: 86 | self[key] = item 87 | 88 | if edge_index is not None and edge_index.dtype != Var.int32: 89 | raise ValueError( 90 | (f'Argument `edge_index` needs to be of type `Var.int32` but ' 91 | f'found type `{edge_index.dtype}`.')) 92 | 93 | if face is not None and face.dtype != Var.int32: 94 | raise ValueError( 95 | (f'Argument `face` needs to be of type `Var int32` but found ' 96 | f'type `{face.dtype}`.')) 97 | 98 | # if jittor_geometric.is_debug_enabled(): 99 | # self.debug() 100 | 101 | @classmethod 102 | def from_dict(cls, dictionary): 103 | r"""Creates a data object from a python dictionary.""" 104 | data = cls() 105 | 106 | for key, item in dictionary.items(): 107 | data[key] = item 108 | 109 | # if jittor_geometric.is_debug_enabled(): 110 | # data.debug() 111 | 112 | return data 113 | 114 | def to_dict(self): 115 | return {key: item for key, item in self} 116 | 117 | def to_namedtuple(self): 118 | keys = self.keys 119 | DataTuple = collections.namedtuple('DataTuple', keys) 120 | return DataTuple(*[self[key] for key in keys]) 121 | 122 | def __getitem__(self, key): 123 | r"""Gets the data of the attribute :obj:`key`.""" 124 | return getattr(self, key, None) 125 | 126 | def __setitem__(self, key, value): 127 | """Sets the attribute :obj:`key` to :obj:`value`.""" 128 | setattr(self, key, value) 129 | 130 | @property 131 | def keys(self): 132 | r"""Returns all names of graph attributes.""" 133 | keys = [key for key in self.__dict__.keys() if self[key] is not None] 134 | keys = [key for key in keys if key[:2] != '__' and key[-2:] != '__'] 135 | return keys 136 | 137 | def __len__(self): 138 | r"""Returns the number of all present attributes.""" 139 | return len(self.keys) 140 | 141 | def __contains__(self, key): 142 | r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the 143 | data.""" 144 | return key in self.keys 145 | 146 | def __iter__(self): 147 | r"""Iterates over all present attributes in the data, yielding their 148 | attribute names and content.""" 149 | for key in sorted(self.keys): 150 | yield key, self[key] 151 | 152 | def __call__(self, *keys): 153 | r"""Iterates over all attributes :obj:`*keys` in the data, yielding 154 | their attribute names and content. 
155 | If :obj:`*keys` is not given this method will iterative over all 156 | present attributes.""" 157 | for key in sorted(self.keys) if not keys else keys: 158 | if key in self: 159 | yield key, self[key] 160 | 161 | def __cat_dim__(self, key, value): 162 | r"""Returns the dimension for which :obj:`value` of attribute 163 | :obj:`key` will get concatenated when creating batches. 164 | 165 | .. note:: 166 | 167 | This method is for internal use only, and should only be overridden 168 | if the batch concatenation process is corrupted for a specific data 169 | attribute. 170 | """ 171 | # Concatenate `*index*` and `*face*` attributes in the last dimension. 172 | if bool(re.search('(index|face)', key)): 173 | return -1 174 | # By default, concatenate sparse matrices diagonally. 175 | # careful 176 | # elif isinstance(value, SparseTensor): 177 | # return (0, 1) 178 | return 0 179 | 180 | def __inc__(self, key, value): 181 | r"""Returns the incremental count to cumulatively increase the value 182 | of the next attribute of :obj:`key` when creating batches. 183 | 184 | .. note:: 185 | 186 | This method is for internal use only, and should only be overridden 187 | if the batch concatenation process is corrupted for a specific data 188 | attribute. 189 | """ 190 | # Only `*index*` and `*face*` attributes should be cumulatively summed 191 | # up when creating batches. 192 | return self.num_nodes if bool(re.search('(index|face)', key)) else 0 193 | 194 | @property 195 | def num_nodes(self): 196 | r"""Returns or sets the number of nodes in the graph. 197 | 198 | .. note:: 199 | The number of nodes in your data object is typically automatically 200 | inferred, *e.g.*, when node features :obj:`x` are present. 201 | In some cases however, a graph may only be given by its edge 202 | indices :obj:`edge_index`. 203 | Jittor Geometric then *guesses* the number of nodes 204 | according to :obj:`edge_index.max().item() + 1`, but in case there 205 | exists isolated nodes, this number has not to be correct and can 206 | therefore result in unexpected batch-wise behavior. 207 | Thus, we recommend to set the number of nodes in your data object 208 | explicitly via :obj:`data.num_nodes = ...`. 209 | You will be given a warning that requests you to do so. 210 | """ 211 | if hasattr(self, '__num_nodes__'): 212 | return self.__num_nodes__ 213 | for key, item in self('x', 'pos', 'normal', 'batch'): 214 | # careful 215 | # if isinstance(item, SparseTensor): 216 | # return item.size(0) 217 | # else: 218 | return item.size(self.__cat_dim__(key, item)) 219 | if hasattr(self, 'adj'): 220 | return self.adj.size(0) 221 | if hasattr(self, 'adj_t'): 222 | return self.adj_t.size(1) 223 | if self.face is not None: 224 | logging.warning(__num_nodes_warn_msg__.format('face')) 225 | return maybe_num_nodes(self.face) 226 | if self.edge_index is not None: 227 | logging.warning(__num_nodes_warn_msg__.format('edge')) 228 | return maybe_num_nodes(self.edge_index) 229 | return None 230 | 231 | @num_nodes.setter 232 | def num_nodes(self, num_nodes): 233 | self.__num_nodes__ = num_nodes 234 | 235 | @property 236 | def num_edges(self): 237 | """ 238 | Returns the number of edges in the graph. 239 | For undirected graphs, this will return the number of bi-directional 240 | edges, which is double the amount of unique edges. 
241 | """ 242 | for key, item in self('edge_index', 'edge_attr'): 243 | return item.size(self.__cat_dim__(key, item)) 244 | for key, item in self('adj', 'adj_t'): 245 | return item.nnz() 246 | return None 247 | 248 | @property 249 | def num_faces(self): 250 | r"""Returns the number of faces in the mesh.""" 251 | if self.face is not None: 252 | return self.face.size(self.__cat_dim__('face', self.face)) 253 | return None 254 | 255 | @property 256 | def num_node_features(self): 257 | r"""Returns the number of features per node in the graph.""" 258 | if self.x is None: 259 | return 0 260 | return 1 if self.x.ndim == 1 else self.x.size(1) 261 | 262 | @property 263 | def num_features(self): 264 | r"""Alias for :py:attr:`~num_node_features`.""" 265 | return self.num_node_features 266 | 267 | @property 268 | def num_edge_features(self): 269 | r"""Returns the number of features per edge in the graph.""" 270 | if self.edge_attr is None: 271 | return 0 272 | return 1 if self.edge_attr.ndim == 1 else self.edge_attr.size(1) 273 | 274 | # careful 275 | # def is_coalesced(self): 276 | # r"""Returns :obj:`True`, if edge indices are ordered and do not contain 277 | # duplicate entries.""" 278 | # edge_index, _ = coalesce(self.edge_index, None, self.num_nodes, 279 | # self.num_nodes) 280 | # return self.edge_index.numel() == edge_index.numel() and ( 281 | # self.edge_index != edge_index).sum().item() == 0 282 | 283 | # def coalesce(self): 284 | # r""""Orders and removes duplicated entries from edge indices.""" 285 | # self.edge_index, self.edge_attr = coalesce(self.edge_index, 286 | # self.edge_attr, 287 | # self.num_nodes, 288 | # self.num_nodes) 289 | # return self 290 | 291 | def contains_isolated_nodes(self): 292 | r"""Returns :obj:`True`, if the graph contains isolated nodes.""" 293 | return contains_isolated_nodes(self.edge_index, self.num_nodes) 294 | 295 | def contains_self_loops(self): 296 | """Returns :obj:`True`, if the graph contains self-loops.""" 297 | return contains_self_loops(self.edge_index) 298 | 299 | # def is_undirected(self): 300 | # r"""Returns :obj:`True`, if graph edges are undirected.""" 301 | # return is_undirected(self.edge_index, self.edge_attr, self.num_nodes) 302 | 303 | # def is_directed(self): 304 | # r"""Returns :obj:`True`, if graph edges are directed.""" 305 | # return not self.is_undirected() 306 | 307 | def __apply__(self, item, func): 308 | if isinstance(item, Var): 309 | return func(item) 310 | # careful 311 | # elif isinstance(item, SparseTensor): 312 | # # Not all apply methods are supported for `SparseTensor`, e.g., 313 | # # `contiguous()`. We can get around it by capturing the exception. 314 | # try: 315 | # return func(item) 316 | # except AttributeError: 317 | # return item 318 | elif isinstance(item, (tuple, list)): 319 | return [self.__apply__(v, func) for v in item] 320 | elif isinstance(item, dict): 321 | return {k: self.__apply__(v, func) for k, v in item.items()} 322 | else: 323 | return item 324 | 325 | def apply(self, func, *keys): 326 | r"""Applies the function :obj:`func` to all Var attributes 327 | :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to 328 | all present attributes. 329 | """ 330 | for key, item in self(*keys): 331 | self[key] = self.__apply__(item, func) 332 | return self 333 | 334 | def contiguous(self, *keys): 335 | r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. 
336 | If :obj:`*keys` is not given, all present attributes are ensured to 337 | have a contiguous memory layout.""" 338 | return self.apply(lambda x: x.contiguous(), *keys) 339 | 340 | def cpu(self, *keys): 341 | r"""Copies all attributes :obj:`*keys` to CPU memory. 342 | If :obj:`*keys` is not given, the conversion is applied to all present 343 | attributes.""" 344 | return self.apply(lambda x: x.cpu(), *keys) 345 | 346 | def cuda(self, device=None, non_blocking=False, *keys): 347 | r"""Copies all attributes :obj:`*keys` to CUDA memory. 348 | If :obj:`*keys` is not given, the conversion is applied to all present 349 | attributes.""" 350 | return self.apply( 351 | lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys) 352 | 353 | def clone(self): 354 | r"""Performs a deep-copy of the data object.""" 355 | return self.__class__.from_dict({ 356 | k: v.clone() if isinstance(v, Var) else copy.deepcopy(v) 357 | for k, v in self.__dict__.items() 358 | }) 359 | 360 | def pin_memory(self, *keys): 361 | r"""Copies all attributes :obj:`*keys` to pinned memory. 362 | If :obj:`*keys` is not given, the conversion is applied to all present 363 | attributes.""" 364 | return self.apply(lambda x: x.pin_memory(), *keys) 365 | 366 | def debug(self): 367 | if self.edge_index is not None: 368 | if self.edge_index.dtype != Var.int32: 369 | raise RuntimeError( 370 | ('Expected edge indices of dtype {}, but found dtype ' 371 | ' {}').format(Var.int32, self.edge_index.dtype)) 372 | 373 | if self.face is not None: 374 | if self.face.dtype != Var.int32: 375 | raise RuntimeError( 376 | ('Expected face indices of dtype {}, but found dtype ' 377 | ' {}').format(Var.int32, self.face.dtype)) 378 | 379 | if self.edge_index is not None: 380 | if self.edge_index.ndim != 2 or self.edge_index.size(0) != 2: 381 | raise RuntimeError( 382 | ('Edge indices should have shape [2, num_edges] but found' 383 | ' shape {}').format(self.edge_index.size())) 384 | 385 | if self.edge_index is not None and self.num_nodes is not None: 386 | if self.edge_index.numel() > 0: 387 | min_index = self.edge_index.min() 388 | max_index = self.edge_index.max() 389 | else: 390 | min_index = max_index = 0 391 | if min_index < 0 or max_index > self.num_nodes - 1: 392 | raise RuntimeError( 393 | ('Edge indices must lay in the interval [0, {}]' 394 | ' but found them in the interval [{}, {}]').format( 395 | self.num_nodes - 1, min_index, max_index)) 396 | 397 | if self.face is not None: 398 | if self.face.ndim != 2 or self.face.size(0) != 3: 399 | raise RuntimeError( 400 | ('Face indices should have shape [3, num_faces] but found' 401 | ' shape {}').format(self.face.size())) 402 | 403 | if self.face is not None and self.num_nodes is not None: 404 | if self.face.numel() > 0: 405 | min_index = self.face.min() 406 | max_index = self.face.max() 407 | else: 408 | min_index = max_index = 0 409 | if min_index < 0 or max_index > self.num_nodes - 1: 410 | raise RuntimeError( 411 | ('Face indices must lay in the interval [0, {}]' 412 | ' but found them in the interval [{}, {}]').format( 413 | self.num_nodes - 1, min_index, max_index)) 414 | 415 | if self.edge_index is not None and self.edge_attr is not None: 416 | if self.edge_index.size(1) != self.edge_attr.size(0): 417 | raise RuntimeError( 418 | ('Edge indices and edge attributes hold a differing ' 419 | 'number of edges, found {} and {}').format( 420 | self.edge_index.size(), self.edge_attr.size())) 421 | 422 | if self.x is not None and self.num_nodes is not None: 423 | if self.x.size(0) != 
self.num_nodes: 424 | raise RuntimeError( 425 | ('Node features should hold {} elements in the first ' 426 | 'dimension but found {}').format(self.num_nodes, 427 | self.x.size(0))) 428 | 429 | if self.pos is not None and self.num_nodes is not None: 430 | if self.pos.size(0) != self.num_nodes: 431 | raise RuntimeError( 432 | ('Node positions should hold {} elements in the first ' 433 | 'dimension but found {}').format(self.num_nodes, 434 | self.pos.size(0))) 435 | 436 | if self.normal is not None and self.num_nodes is not None: 437 | if self.normal.size(0) != self.num_nodes: 438 | raise RuntimeError( 439 | ('Node normals should hold {} elements in the first ' 440 | 'dimension but found {}').format(self.num_nodes, 441 | self.normal.size(0))) 442 | 443 | def __repr__(self): 444 | cls = str(self.__class__.__name__) 445 | has_dict = any([isinstance(item, dict) for _, item in self]) 446 | 447 | if not has_dict: 448 | info = [size_repr(key, item) for key, item in self] 449 | return '{}({})'.format(cls, ', '.join(info)) 450 | else: 451 | info = [size_repr(key, item, indent=2) for key, item in self] 452 | return '{}(\n{}\n)'.format(cls, ',\n'.join(info)) 453 | -------------------------------------------------------------------------------- /jittor_geometric/data/dataset.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import logging 4 | import os.path as osp 5 | 6 | import jittor as jt 7 | from jittor import dataset, Var 8 | from .makedirs import makedirs 9 | 10 | 11 | def to_list(x): 12 | if not isinstance(x, (tuple, list)): 13 | x = [x] 14 | return x 15 | 16 | 17 | def files_exist(files): 18 | return len(files) != 0 and all(osp.exists(f) for f in files) 19 | 20 | 21 | def __repr__(obj): 22 | if obj is None: 23 | return 'None' 24 | return re.sub('(<.*?)\\s.*(>)', r'\1\2', obj.__repr__()) 25 | 26 | 27 | class Dataset(dataset.Dataset): 28 | r""" 29 | Args: 30 | root (string, optional): Root directory where the dataset should be 31 | saved. (optional: :obj:`None`) 32 | transform (callable, optional): A function/transform that takes in an 33 | :obj:`torch_geometric.data.Data` object and returns a transformed 34 | version. The data object will be transformed before every access. 35 | (default: :obj:`None`) 36 | pre_transform (callable, optional): A function/transform that takes in 37 | an :obj:`jittor_geometric.data.Data` object and returns a 38 | transformed version. The data object will be transformed before 39 | being saved to disk. (default: :obj:`None`) 40 | pre_filter (callable, optional): A function that takes in an 41 | :obj:`jittor_geometric.data.Data` object and returns a boolean 42 | value, indicating whether the data object should be included in the 43 | final dataset. 
(default: :obj:`None`) 44 | """ 45 | @property 46 | def raw_file_names(self): 47 | r"""The name of the files to find in the :obj:`self.raw_dir` folder in 48 | order to skip the download.""" 49 | raise NotImplementedError 50 | 51 | @property 52 | def processed_file_names(self): 53 | r"""The name of the files to find in the :obj:`self.processed_dir` 54 | folder in order to skip the processing.""" 55 | raise NotImplementedError 56 | 57 | def download(self): 58 | r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" 59 | raise NotImplementedError 60 | 61 | def process(self): 62 | r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" 63 | raise NotImplementedError 64 | 65 | def len(self): 66 | raise NotImplementedError 67 | 68 | def get(self, idx): 69 | r"""Gets the data object at index :obj:`idx`.""" 70 | raise NotImplementedError 71 | 72 | def __init__(self, root=None, transform=None, pre_transform=None, 73 | pre_filter=None): 74 | super(Dataset, self).__init__() 75 | 76 | if isinstance(root, str): 77 | root = osp.expanduser(osp.normpath(root)) 78 | 79 | self.root = root 80 | self.transform = transform 81 | self.pre_transform = pre_transform 82 | self.pre_filter = pre_filter 83 | self.__indices__ = None 84 | 85 | if 'download' in self.__class__.__dict__.keys(): 86 | self._download() 87 | 88 | if 'process' in self.__class__.__dict__.keys(): 89 | self._process() 90 | 91 | def indices(self): 92 | if self.__indices__ is not None: 93 | return self.__indices__ 94 | else: 95 | return range(len(self)) 96 | 97 | @property 98 | def raw_dir(self): 99 | return osp.join(self.root, 'raw') 100 | 101 | @property 102 | def processed_dir(self): 103 | return osp.join(self.root, 'processed') 104 | 105 | @property 106 | def num_node_features(self): 107 | r"""Returns the number of features per node in the dataset.""" 108 | return self[0].num_node_features 109 | 110 | @property 111 | def num_features(self): 112 | r"""Alias for :py:attr:`~num_node_features`.""" 113 | return self.num_node_features 114 | 115 | @property 116 | def num_edge_features(self): 117 | r"""Returns the number of features per edge in the dataset.""" 118 | return self[0].num_edge_features 119 | 120 | @property 121 | def raw_paths(self): 122 | r"""The filepaths to find in order to skip the download.""" 123 | files = to_list(self.raw_file_names) 124 | return [osp.join(self.raw_dir, f) for f in files] 125 | 126 | @property 127 | def processed_paths(self): 128 | r"""The filepaths to find in the :obj:`self.processed_dir` 129 | folder in order to skip the processing.""" 130 | files = to_list(self.processed_file_names) 131 | return [osp.join(self.processed_dir, f) for f in files] 132 | 133 | def _download(self): 134 | if files_exist(self.raw_paths): # pragma: no cover 135 | return 136 | 137 | makedirs(self.raw_dir) 138 | self.download() 139 | 140 | def _process(self): 141 | f = osp.join(self.processed_dir, 'pre_transform.pkl') 142 | if osp.exists(f) and jt.load(f) != __repr__(self.pre_transform): 143 | logging.warning( 144 | 'The `pre_transform` argument differs from the one used in ' 145 | 'the pre-processed version of this dataset. 
If you really ' 146 | 'want to make use of another pre-processing technique, make ' 147 | 'sure to delete `{}` first.'.format(self.processed_dir)) 148 | f = osp.join(self.processed_dir, 'pre_filter.pkl') 149 | if osp.exists(f) and jt.load(f) != __repr__(self.pre_filter): 150 | logging.warning( 151 | 'The `pre_filter` argument differs from the one used in the ' 152 | 'pre-processed version of this dataset. If you really want to ' 153 | 'make use of another pre-fitering technique, make sure to ' 154 | 'delete `{}` first.'.format(self.processed_dir)) 155 | 156 | if files_exist(self.processed_paths): # pragma: no cover 157 | return 158 | 159 | print('Processing...') 160 | 161 | makedirs(self.processed_dir) 162 | self.process() 163 | 164 | path = osp.join(self.processed_dir, 'pre_transform.pkl') 165 | jt.save(__repr__(self.pre_transform), path) 166 | path = osp.join(self.processed_dir, 'pre_filter.pkl') 167 | jt.save(__repr__(self.pre_filter), path) 168 | 169 | print('Done!') 170 | 171 | def __len__(self): 172 | r"""The number of examples in the dataset.""" 173 | if self.__indices__ is not None: 174 | return len(self.__indices__) 175 | return self.len() 176 | 177 | def __getitem__(self, idx): 178 | r"""Gets the data object at index :obj:`idx` and transforms it (in case 179 | a :obj:`self.transform` is given). 180 | In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a 181 | tuple, a Var int32 or a Var bool, will return a subset of the 182 | dataset at the specified indices.""" 183 | if isinstance(idx, int): 184 | data = self.get(self.indices()[idx]) 185 | data = data if self.transform is None else self.transform(data) 186 | return data 187 | else: 188 | return self.index_select(idx) 189 | 190 | def index_select(self, idx): 191 | indices = self.indices() 192 | 193 | if isinstance(idx, slice): 194 | indices = indices[idx] 195 | elif isinstance(idx, Var): 196 | if idx.dtype == Var.int32: 197 | if len(idx.shape) == 0: 198 | idx = idx.unsqueeze(0) 199 | return self.index_select(idx.tolist()) 200 | elif idx.dtype == Var.bool or idx.dtype == Var.uint8: 201 | return self.index_select( 202 | jt.flatten(idx.nonzero(idx)).tolist()) 203 | elif isinstance(idx, list) or isinstance(idx, tuple): 204 | indices = [indices[i] for i in idx] 205 | else: 206 | raise IndexError( 207 | 'Only integers, slices (`:`), list, tuples, and long or bool ' 208 | 'Vars are valid indices (got {}).'.format( 209 | type(idx).__name__)) 210 | 211 | dataset = copy.copy(self) 212 | dataset.__indices__ = indices 213 | return dataset 214 | 215 | def shuffle(self, return_perm=False): 216 | r"""Randomly shuffles the examples in the dataset. 217 | 218 | Args: 219 | return_perm (bool, optional): If set to :obj:`True`, will 220 | additionally return the random permutation used to shuffle the 221 | dataset. 
(default: :obj:`False`) 222 | """ 223 | perm = jt.randperm(len(self)) 224 | dataset = self.index_select(perm) 225 | return (dataset, perm) if return_perm is True else dataset 226 | 227 | def __repr__(self): # pragma: no cover 228 | return f'{self.__class__.__name__}({len(self)})' 229 | -------------------------------------------------------------------------------- /jittor_geometric/data/download.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import ssl 4 | import os.path as osp 5 | from six.moves import urllib 6 | 7 | from .makedirs import makedirs 8 | 9 | 10 | def download_url(url, folder, log=True): 11 | r"""Downloads the content of an URL to a specific folder. 12 | 13 | Args: 14 | url (string): The url. 15 | folder (string): The folder. 16 | log (bool, optional): If :obj:`False`, will not print anything to the 17 | console. (default: :obj:`True`) 18 | """ 19 | 20 | filename = url.rpartition('/')[2].split('?')[0] 21 | path = osp.join(folder, filename) 22 | 23 | if osp.exists(path): # pragma: no cover 24 | if log: 25 | print('Using exist file', filename) 26 | return path 27 | 28 | if log: 29 | print('Downloading', url) 30 | 31 | makedirs(folder) 32 | 33 | context = ssl._create_unverified_context() 34 | data = urllib.request.urlopen(url, context=context) 35 | 36 | with open(path, 'wb') as f: 37 | f.write(data.read()) 38 | 39 | return path 40 | -------------------------------------------------------------------------------- /jittor_geometric/data/in_memory_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from itertools import repeat, product 3 | 4 | import jittor as jt 5 | from jittor import Var 6 | from jittor_geometric.data import Dataset 7 | 8 | 9 | class InMemoryDataset(Dataset): 10 | r"""Dataset base class for creating graph datasets which fit completely 11 | into CPU memory. 12 | 13 | Args: 14 | root (string, optional): Root directory where the dataset should be 15 | saved. (default: :obj:`None`) 16 | transform (callable, optional): A function/transform that takes in an 17 | :obj:`jittor_geometric.data.Data` object and returns a transformed 18 | version. The data object will be transformed before every access. 19 | (default: :obj:`None`) 20 | pre_transform (callable, optional): A function/transform that takes in 21 | an :obj:`jittor_geometric.data.Data` object and returns a 22 | transformed version. The data object will be transformed before 23 | being saved to disk. (default: :obj:`None`) 24 | pre_filter (callable, optional): A function that takes in an 25 | :obj:`jittor_geometric.data.Data` object and returns a boolean 26 | value, indicating whether the data object should be included in the 27 | final dataset. 
(default: :obj:`None`) 28 | """ 29 | 30 | @property 31 | def raw_file_names(self): 32 | r"""The name of the files to find in the :obj:`self.raw_dir` folder in 33 | order to skip the download.""" 34 | raise NotImplementedError 35 | 36 | @property 37 | def processed_file_names(self): 38 | r"""The name of the files to find in the :obj:`self.processed_dir` 39 | folder in order to skip the processing.""" 40 | raise NotImplementedError 41 | 42 | def download(self): 43 | r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" 44 | raise NotImplementedError 45 | 46 | def process(self): 47 | r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" 48 | raise NotImplementedError 49 | 50 | def __init__(self, root=None, transform=None, pre_transform=None, 51 | pre_filter=None): 52 | super(InMemoryDataset, self).__init__(root, transform, pre_transform, 53 | pre_filter) 54 | self.data, self.slices = None, None 55 | self.__data_list__ = None 56 | 57 | @property 58 | def num_classes(self): 59 | r"""The number of classes in the dataset.""" 60 | if self.data.y is None: 61 | return 0 62 | elif self.data.y.ndim == 1: 63 | return int(self.data.y.max().item()) + 1 64 | else: 65 | return self.data.y.size(-1) 66 | 67 | def len(self): 68 | for item in self.slices.values(): 69 | return len(item) - 1 70 | return 0 71 | 72 | def get(self, idx): 73 | if hasattr(self, '__data_list__'): 74 | if self.__data_list__ is None: 75 | self.__data_list__ = self.len() * [None] 76 | else: 77 | data = self.__data_list__[idx] 78 | if data is not None: 79 | return copy.copy(data) 80 | 81 | data = self.data.__class__() 82 | if hasattr(self.data, '__num_nodes__'): 83 | data.num_nodes = self.data.__num_nodes__[idx] 84 | 85 | for key in self.data.keys: 86 | item, slices = self.data[key], self.slices[key] 87 | start, end = slices[idx].item(), slices[idx + 1].item() 88 | if isinstance(item, Var): 89 | s = list(repeat(slice(None), item.ndim)) 90 | cat_dim = self.data.__cat_dim__(key, item) 91 | if cat_dim is None: 92 | cat_dim = 0 93 | s[cat_dim] = slice(start, end) 94 | elif start + 1 == end: 95 | s = slices[start] 96 | else: 97 | s = slice(start, end) 98 | data[key] = item[tuple(s)] 99 | 100 | if hasattr(self, '__data_list__'): 101 | self.__data_list__[idx] = copy.copy(data) 102 | 103 | return data 104 | 105 | @staticmethod 106 | def collate(data_list): 107 | r"""Collates a python list of data objects to the internal storage 108 | format of :class:`torch_geometric.data.InMemoryDataset`.""" 109 | keys = data_list[0].keys 110 | data = data_list[0].__class__() 111 | 112 | for key in keys: 113 | data[key] = [] 114 | slices = {key: [0] for key in keys} 115 | 116 | for item, key in product(data_list, keys): 117 | data[key].append(item[key]) 118 | if isinstance(item[key], Var) and item[key].ndim > 0: 119 | cat_dim = item.__cat_dim__(key, item[key]) 120 | cat_dim = 0 if cat_dim is None else cat_dim 121 | s = slices[key][-1] + item[key].size(cat_dim) 122 | else: 123 | s = slices[key][-1] + 1 124 | slices[key].append(s) 125 | 126 | if hasattr(data_list[0], '__num_nodes__'): 127 | data.__num_nodes__ = [] 128 | for item in data_list: 129 | data.__num_nodes__.append(item.num_nodes) 130 | 131 | for key in keys: 132 | item = data_list[0][key] 133 | if isinstance(item, Var) and len(data_list) > 1: 134 | if item.ndim > 0: 135 | cat_dim = data.__cat_dim__(key, item) 136 | cat_dim = 0 if cat_dim is None else cat_dim 137 | data[key] = jt.concat(data[key], dim=cat_dim) 138 | else: 139 | data[key] = jt.stack(data[key]) 140 | elif 
isinstance(item, Var): # Don't duplicate attributes... 141 | data[key] = data[key][0] 142 | elif isinstance(item, int) or isinstance(item, float): 143 | data[key] = jt.array(data[key]) 144 | 145 | slices[key] = jt.array(slices[key], dtype=Var.int32) 146 | 147 | return data, slices 148 | 149 | def copy(self, idx=None): 150 | if idx is None: 151 | data_list = [self.get(i) for i in range(len(self))] 152 | else: 153 | data_list = [self.get(i) for i in idx] 154 | dataset = copy.copy(self) 155 | dataset.__indices__ = None 156 | dataset.__data_list__ = data_list 157 | dataset.data, dataset.slices = self.collate(data_list) 158 | return dataset 159 | -------------------------------------------------------------------------------- /jittor_geometric/data/makedirs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import errno 4 | 5 | 6 | def makedirs(path): 7 | try: 8 | os.makedirs(osp.expanduser(osp.normpath(path))) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST and osp.isdir(path): 11 | raise e 12 | -------------------------------------------------------------------------------- /jittor_geometric/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .planetoid import Planetoid 3 | 4 | __all__ = [ 5 | 'Planetoid', 6 | ] 7 | 8 | classes = __all__ 9 | -------------------------------------------------------------------------------- /jittor_geometric/datasets/planetoid.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from jittor_geometric.io import read_planetoid_data 4 | from jittor_geometric.data import InMemoryDataset, download_url 5 | import jittor as jt 6 | from jittor import init 7 | 8 | 9 | class Planetoid(InMemoryDataset): 10 | r"""The citation network datasets "Cora", "CiteSeer" and "PubMed" from the 11 | `"Revisiting Semi-Supervised Learning with Graph Embeddings" 12 | `_ paper. 13 | Nodes represent documents and edges represent citation links. 14 | Training, validation and test splits are given by binary masks. 15 | 16 | Args: 17 | root (string): Root directory where the dataset should be saved. 18 | name (string): The name of the dataset (:obj:`"Cora"`, 19 | :obj:`"CiteSeer"`, :obj:`"PubMed"`). 20 | split (string): The type of dataset split 21 | (:obj:`"public"`, :obj:`"full"`, :obj:`"random"`). 22 | If set to :obj:`"public"`, the split will be the public fixed split 23 | from the 24 | `"Revisiting Semi-Supervised Learning with Graph Embeddings" 25 | `_ paper. 26 | If set to :obj:`"full"`, all nodes except those in the validation 27 | and test sets will be used for training (as in the 28 | `"FastGCN: Fast Learning with Graph Convolutional Networks via 29 | Importance Sampling" `_ paper). 30 | If set to :obj:`"random"`, train, validation, and test sets will be 31 | randomly generated, according to :obj:`num_train_per_class`, 32 | :obj:`num_val` and :obj:`num_test`. (default: :obj:`"public"`) 33 | num_train_per_class (int, optional): The number of training samples 34 | per class in case of :obj:`"random"` split. (default: :obj:`20`) 35 | num_val (int, optional): The number of validation samples in case of 36 | :obj:`"random"` split. (default: :obj:`500`) 37 | num_test (int, optional): The number of test samples in case of 38 | :obj:`"random"` split. 
(default: :obj:`1000`) 39 | transform (callable, optional): A function/transform that takes in an 40 | :obj:`torch_geometric.data.Data` object and returns a transformed 41 | version. The data object will be transformed before every access. 42 | (default: :obj:`None`) 43 | pre_transform (callable, optional): A function/transform that takes in 44 | an :obj:`torch_geometric.data.Data` object and returns a 45 | transformed version. The data object will be transformed before 46 | being saved to disk. (default: :obj:`None`) 47 | """ 48 | 49 | url = 'https://github.com/kimiyoung/planetoid/raw/master/data' 50 | 51 | def __init__(self, root, name, split="public", num_train_per_class=20, 52 | num_val=500, num_test=1000, transform=None, 53 | pre_transform=None): 54 | self.name = name 55 | 56 | super(Planetoid, self).__init__(root, transform, pre_transform) 57 | self.data, self.slices = jt.load(self.processed_paths[0]) 58 | self.split = split 59 | assert self.split in ['public', 'full', 'random'] 60 | 61 | if split == 'full': 62 | data = self.get(0) 63 | init(data.train_mask, True) 64 | data.train_mask[jt.logical_or( 65 | data.val_mask, data.test_mask)] = False 66 | self.data, self.slices = self.collate([data]) 67 | 68 | elif split == 'random': 69 | data = self.get(0) 70 | init(data.train_mask, False) 71 | for c in range(self.num_classes): 72 | idx = (data.y == c).nonzero().view(-1) 73 | idx = idx[jt.randperm(idx.size(0))[:num_train_per_class]] 74 | data.train_mask[idx] = True 75 | 76 | remaining = jt.logical_not(data.train_mask).nonzero().view(-1) 77 | remaining = remaining[jt.randperm(remaining.size(0))] 78 | 79 | init(data.val_mask, False) 80 | data.val_mask[remaining[:num_val]] = True 81 | 82 | init(data.test_mask, False) 83 | data.test_mask[remaining[num_val:num_val + num_test]] = True 84 | 85 | self.data, self.slices = self.collate([data]) 86 | 87 | @property 88 | def raw_dir(self): 89 | return osp.join(self.root, self.name, 'raw') 90 | 91 | @property 92 | def processed_dir(self): 93 | return osp.join(self.root, self.name, 'processed') 94 | 95 | @property 96 | def raw_file_names(self): 97 | names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] 98 | return ['ind.{}.{}'.format(self.name.lower(), name) for name in names] 99 | 100 | @property 101 | def processed_file_names(self): 102 | return 'data.pkl' 103 | 104 | def download(self): 105 | for name in self.raw_file_names: 106 | download_url('{}/{}'.format(self.url, name), self.raw_dir) 107 | 108 | def process(self): 109 | data = read_planetoid_data(self.raw_dir, self.name) 110 | data = data if self.pre_transform is None else self.pre_transform(data) 111 | jt.save(self.collate([data]), self.processed_paths[0]) 112 | 113 | def __repr__(self): 114 | return '{}()'.format(self.name) 115 | -------------------------------------------------------------------------------- /jittor_geometric/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .txt_array import parse_txt_array, read_txt_array 2 | from .planetoid import read_planetoid_data 3 | 4 | __all__ = [ 5 | 'parse_txt_array', 6 | 'read_txt_array', 7 | 'read_planetoid_data', 8 | ] 9 | -------------------------------------------------------------------------------- /jittor_geometric/io/planetoid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | from itertools import repeat 4 | import numpy as np 5 | import jittor as jt 6 | from jittor import Var 7 | # from 
torch_sparse import coalesce as coalesce_fn, SparseTensor 8 | from jittor_geometric.data import Data 9 | from jittor_geometric.io import read_txt_array 10 | from jittor_geometric.utils import remove_self_loops 11 | 12 | try: 13 | import cPickle as pickle 14 | except ImportError: 15 | import pickle 16 | 17 | 18 | def read_planetoid_data(folder, prefix): 19 | names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] 20 | items = [read_file(folder, prefix, name) for name in names] 21 | x, tx, allx, y, ty, ally, graph, test_index = items 22 | train_index = jt.arange(y.size(0), dtype=Var.int32) 23 | val_index = jt.arange(y.size(0), y.size(0) + 500, dtype=Var.int32) 24 | sorted_test_index = test_index.argsort()[1] 25 | 26 | if prefix.lower() == 'citeseer': 27 | # There are some isolated nodes in the Citeseer graph, resulting in 28 | # none consecutive test indices. We need to identify them and add them 29 | # as zero vectors to `tx` and `ty`. 30 | len_test_indices = (test_index.max() - test_index.min()).item() + 1 31 | 32 | tx_ext = jt.zeros((len_test_indices, tx.size(1))) 33 | tx_ext[sorted_test_index - test_index.min(), :] = tx 34 | ty_ext = jt.zeros((len_test_indices, ty.size(1))) 35 | ty_ext[sorted_test_index - test_index.min(), :] = ty 36 | 37 | tx, ty = tx_ext, ty_ext 38 | 39 | if prefix.lower() == 'nell.0.001': 40 | tx_ext = jt.zeros((len(graph) - allx.size(0), x.size(1))) 41 | tx_ext[sorted_test_index - allx.size(0)] = tx 42 | 43 | ty_ext = jt.zeros((len(graph) - ally.size(0), y.size(1))) 44 | ty_ext[sorted_test_index - ally.size(0)] = ty 45 | 46 | tx, ty = tx_ext, ty_ext 47 | 48 | x = jt.concat([allx, tx], dim=0) 49 | x[test_index] = x[sorted_test_index] 50 | 51 | # Creating feature vectors for relations. 52 | row, col, value = SparseTensor.from_dense(x).coo() 53 | rows, cols, values = [row], [col], [value] 54 | 55 | mask1 = index_to_mask(test_index, size=len(graph)) 56 | mask2 = index_to_mask(jt.arange(allx.size(0), len(graph)), 57 | size=len(graph)) 58 | mask = jt.logical_or(jt.logical_not(mask1), jt.logical_not(mask2)) 59 | isolated_index = mask.nonzero(as_tuple=False).view(-1)[allx.size(0):] 60 | 61 | rows += [isolated_index] 62 | cols += [jt.arange(isolated_index.size(0)) + x.size(1)] 63 | values += [jt.ones((isolated_index.size(0)))] 64 | 65 | x = SparseTensor(row=jt.concat(rows), col=jt.concat(cols), 66 | value=jt.concat(values)) 67 | else: 68 | x = jt.concat([allx, tx], dim=0) 69 | x[test_index] = x[sorted_test_index] 70 | y = jt.concat([ally, ty], dim=0).argmax(dim=1)[0] 71 | y[test_index] = y[sorted_test_index] 72 | 73 | train_mask = index_to_mask(train_index, size=y.size(0)) 74 | val_mask = index_to_mask(val_index, size=y.size(0)) 75 | test_mask = index_to_mask(test_index, size=y.size(0)) 76 | 77 | edge_index = edge_index_from_dict(graph, num_nodes=y.size(0)) 78 | 79 | data = Data(x=x, edge_index=edge_index, y=y) 80 | data.train_mask = train_mask 81 | data.val_mask = val_mask 82 | data.test_mask = test_mask 83 | return data 84 | 85 | 86 | def read_file(folder, prefix, name): 87 | path = osp.join(folder, 'ind.{}.{}'.format(prefix.lower(), name)) 88 | if name == 'test.index': 89 | return read_txt_array(path, dtype=Var.int32) 90 | 91 | with open(path, 'rb') as f: 92 | if sys.version_info > (3, 0): 93 | out = pickle.load(f, encoding='latin1') 94 | else: 95 | out = pickle.load(f) 96 | 97 | if name == 'graph': 98 | return out 99 | 100 | out = out.todense() if hasattr(out, 'todense') else out 101 | if isinstance(out, np.matrix): 102 | out = np.asarray(out) 103 | out = 
jt.array(out) 104 | return out 105 | 106 | # careful 107 | 108 | 109 | def edge_index_from_dict(graph_dict, num_nodes=None, coalesce=False): 110 | row, col = [], [] 111 | for key, value in graph_dict.items(): 112 | row += repeat(key, len(value)) 113 | col += value 114 | edge_index = jt.stack([jt.array(row), jt.array(col)], dim=0) 115 | if coalesce: 116 | # NOTE: There are some duplicated edges and self loops in the datasets. 117 | # Other implementations do not remove them! 118 | edge_index, _ = remove_self_loops(edge_index) 119 | edge_index, _ = coalesce_fn(edge_index, None, num_nodes, num_nodes) 120 | return edge_index 121 | 122 | 123 | def index_to_mask(index, size): 124 | mask = jt.zeros((size, ), dtype=Var.bool) 125 | mask[index] = 1 126 | return mask 127 | -------------------------------------------------------------------------------- /jittor_geometric/io/txt_array.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | 4 | def parse_txt_array(src, sep=None, start=0, end=None, dtype=None): 5 | src = [[float(x) for x in line.split(sep)[start:end]] for line in src] 6 | src = jt.array(src, dtype=dtype).squeeze(1) 7 | return src 8 | 9 | 10 | def read_txt_array(path, sep=None, start=0, end=None, dtype=None): 11 | with open(path, 'r') as f: 12 | src = f.read().split('\n')[:-1] 13 | return parse_txt_array(src, sep, start, end, dtype) 14 | -------------------------------------------------------------------------------- /jittor_geometric/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv import * # noqa 2 | 3 | __all__ = [ 4 | 'Sequential', 5 | 'MetaLayer', 6 | 'DataParallel', 7 | 'Reshape', 8 | ] 9 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .message_passing import MessagePassing 2 | from .gcn_conv import GCNConv 3 | from .cheb_conv import ChebConv 4 | from .sg_conv import SGConv 5 | from .gcn2_conv import GCN2Conv 6 | 7 | __all__ = [ 8 | 'MessagePassing', 9 | 'GCNConv', 10 | 'ChebConv', 11 | 'SGConv', 12 | 'GCN2Conv', 13 | ] 14 | 15 | classes = __all__ 16 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/cheb_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from jittor_geometric.typing import OptVar 3 | 4 | import jittor as jt 5 | from jittor import Var 6 | from jittor_geometric.nn.conv import MessagePassing 7 | from jittor_geometric.utils import remove_self_loops, add_self_loops 8 | from jittor_geometric.utils import get_laplacian 9 | 10 | from ..inits import glorot, zeros 11 | 12 | 13 | class ChebConv(MessagePassing): 14 | r"""The chebyshev spectral graph convolutional operator from the 15 | `"Convolutional Neural Networks on Graphs with Fast Localized Spectral 16 | Filtering" `_ paper 17 | 18 | .. math:: 19 | \mathbf{X}^{\prime} = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot 20 | \mathbf{\Theta}^{(k)} 21 | 22 | where :math:`\mathbf{Z}^{(k)}` is computed recursively by 23 | 24 | .. 
math:: 25 | \mathbf{Z}^{(1)} &= \mathbf{X} 26 | 27 | \mathbf{Z}^{(2)} &= \mathbf{\hat{L}} \cdot \mathbf{X} 28 | 29 | \mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot 30 | \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)} 31 | 32 | and :math:`\mathbf{\hat{L}}` denotes the scaled and normalized Laplacian 33 | :math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}`. 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, K, normalization='sym', 37 | bias=True, **kwargs): 38 | kwargs.setdefault('aggr', 'add') 39 | super(ChebConv, self).__init__(**kwargs) 40 | 41 | assert K > 0 42 | assert normalization in [None, 'sym', 'rw'], 'Invalid normalization' 43 | 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.normalization = normalization 47 | self.weight = jt.ones((K, in_channels, out_channels)) 48 | if bias: 49 | self.bias = jt.ones((out_channels,)) 50 | else: 51 | self.bias = None 52 | 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | glorot(self.weight) 57 | zeros(self.bias) 58 | 59 | def __norm__(self, edge_index, num_nodes: Optional[int], 60 | edge_weight: OptVar, normalization: Optional[str], 61 | lambda_max, dtype: Optional[int] = None, 62 | batch: OptVar = None): 63 | 64 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 65 | 66 | edge_index, edge_weight = get_laplacian(edge_index, edge_weight, 67 | normalization, dtype, 68 | num_nodes) 69 | 70 | if batch is not None and lambda_max.numel() > 1: 71 | lambda_max = lambda_max[batch[edge_index[0]]] 72 | 73 | edge_weight = (2.0 * edge_weight) / lambda_max 74 | # edge_weight.masked_fill((edge_weight == float('inf')).int32(), 0) 75 | # edge_weight.masked_fill((edge_weight == float('-inf')).int32(), 0) 76 | for i in range(edge_weight.shape[0]): 77 | if edge_weight[i] == float('inf'): 78 | edge_weight[i] = 0 79 | # print('edge_weight: ', edge_weight.shape, 80 | # edge_weight.min(), edge_weight.max()) 81 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 82 | fill_value=-1., 83 | num_nodes=num_nodes) 84 | assert edge_weight is not None 85 | 86 | return edge_index, edge_weight 87 | 88 | def execute(self, x, edge_index, edge_weight: OptVar = None, 89 | batch: OptVar = None, lambda_max: OptVar = None): 90 | """""" 91 | # # add batch operation 92 | # if (len(x.shape) == 3 and len(self.weight.shape) == 3): 93 | # bs = x.shape[0] 94 | # self.weight = self.weight.unsqueeze(1).repeat(1,3,1,1) 95 | if self.normalization != 'sym' and lambda_max is None: 96 | raise ValueError('You need to pass `lambda_max` to `execute() in`' 97 | 'case the normalization is non-symmetric.') 98 | 99 | if lambda_max is None: 100 | lambda_max = Var([2.0]) 101 | if not isinstance(lambda_max, Var): 102 | lambda_max = Var([lambda_max]) 103 | assert lambda_max is not None 104 | 105 | edge_index, norm = self.__norm__(edge_index, x.size(self.node_dim), 106 | edge_weight, self.normalization, 107 | lambda_max, dtype=x.dtype, 108 | batch=batch) 109 | 110 | Tx_0 = x 111 | # Tx_1 = x # Dummy. 
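# Added note on the block below: the first Chebyshev term is applied as
# out = Tx_0 @ weight[0] on the batched features, and x is then flattened from
# [batch, nodes, channels] to [nodes, batch*channels] so that propagate(), which
# aggregates over the node dimension, can run the recurrence for the whole batch
# in one pass; each Tx_k is reshaped back to [batch, nodes, channels] before it
# is multiplied by weight[k].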
112 | out = jt.matmul(Tx_0, self.weight[0]) 113 | # print('self weight:', self.weight) 114 | 115 | x = x.permute(1,0,2) 116 | x = x.reshape(-1, x.shape[-2]*x.shape[-1]) 117 | Tx_0 = x 118 | 119 | if self.weight.size(0) > 1: 120 | # print('norm: ', norm.shape, 121 | # norm.min(), norm.max()) 122 | Tx_1 = self.propagate(edge_index, x=x, norm=norm, size=None) 123 | Tx_1_reshaped = Tx_1.reshape(Tx_1.shape[0],-1,self.in_channels).permute(1,0,2) 124 | # print('Tx_1: ', Tx_1.shape, Tx_1.min(), Tx_1.max()) 125 | out = out + jt.matmul(Tx_1_reshaped, self.weight[1]) 126 | 127 | for k in range(2, self.weight.size(0)): 128 | Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm, size=None) 129 | Tx_2 = 2. * Tx_2 - Tx_0 130 | 131 | Tx_2_reshaped = Tx_2.reshape(Tx_2.shape[0],-1,self.in_channels).permute(1,0,2) 132 | out = out + jt.matmul(Tx_2_reshaped, self.weight[k]) 133 | Tx_0, Tx_1 = Tx_1, Tx_2 134 | 135 | if self.bias is not None: 136 | out += self.bias 137 | return out 138 | 139 | def message(self, x_j, norm): 140 | # res = norm.reshape(-1, 1) * x_j.permute(1,0,2) 141 | # return res.permute(1,0,2) 142 | return norm.reshape(-1, 1) * x_j 143 | 144 | def __repr__(self): 145 | return '{}({}, {}, K={}, normalization={})'.format( 146 | self.__class__.__name__, self.in_channels, self.out_channels, 147 | self.weight.size(0), self.normalization) 148 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/gcn2_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | from jittor_geometric.typing import Adj, OptVar 3 | 4 | from math import log 5 | 6 | import jittor as jt 7 | from jittor import Var 8 | from jittor_geometric.nn.conv import MessagePassing 9 | from jittor_geometric.nn.conv.gcn_conv import gcn_norm 10 | 11 | from ..inits import glorot 12 | 13 | 14 | def addmm(input_, mat1, mat2, *, beta=1, alpha=1, out=None): 15 | return beta*input_ + alpha*(mat1 @ mat2) 16 | 17 | 18 | class GCN2Conv(MessagePassing): 19 | r"""The graph convolutional operator with initial residual connections and 20 | identity mapping (GCNII) from the `"Simple and Deep Graph Convolutional 21 | Networks" `_ paper 22 | 23 | .. math:: 24 | \mathbf{X}^{\prime} = \left( (1 - \alpha) \mathbf{\hat{P}}\mathbf{X} + 25 | \alpha \mathbf{X^{(0)}}\right) \left( (1 - \beta) \mathbf{I} + \beta 26 | \mathbf{\Theta} \right) 27 | 28 | with :math:`\mathbf{\hat{P}} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 29 | \mathbf{\hat{D}}^{-1/2}`, where 30 | :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency 31 | matrix with inserted self-loops and 32 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix, 33 | and :math:`\mathbf{X}^{(0)}` being the initial feature representation. 34 | Here, :math:`\alpha` models the strength of the initial residual 35 | connection, while :math:`\beta` models the strength of the identity 36 | mapping. 37 | The adjacency matrix can include other values than :obj:`1` representing 38 | edge weights via the optional :obj:`edge_weight` Var. 
39 | """ 40 | 41 | _cached_edge_index: Optional[Tuple[Var, Var]] 42 | 43 | def __init__(self, channels: int, alpha: float, theta: float = None, 44 | layer: int = None, shared_weights: bool = True, 45 | cached: bool = False, add_self_loops: bool = True, 46 | normalize: bool = True, **kwargs): 47 | 48 | kwargs.setdefault('aggr', 'add') 49 | super(GCN2Conv, self).__init__(**kwargs) 50 | 51 | self.channels = channels 52 | self.alpha = alpha 53 | self.beta = 1. 54 | if theta is not None or layer is not None: 55 | assert theta is not None and layer is not None 56 | self.beta = log(theta / layer + 1) 57 | self.cached = cached 58 | self.normalize = normalize 59 | self.add_self_loops = add_self_loops 60 | 61 | self._cached_edge_index = None 62 | 63 | self.weight1 = jt.random((channels, channels)) 64 | 65 | if shared_weights: 66 | self.weight2 = None 67 | else: 68 | self.weight2 = jt.random((channels, channels)) 69 | 70 | self.reset_parameters() 71 | 72 | def reset_parameters(self): 73 | glorot(self.weight1) 74 | glorot(self.weight2) 75 | self._cached_edge_index = None 76 | 77 | def execute(self, x: Var, x_0: Var, edge_index: Adj, 78 | edge_weight: OptVar = None) -> Var: 79 | """""" 80 | 81 | if self.normalize: 82 | if isinstance(edge_index, Var): 83 | cache = self._cached_edge_index 84 | if cache is None: 85 | edge_index, edge_weight = gcn_norm( # yapf: disable 86 | edge_index, edge_weight, x.size(self.node_dim), False, 87 | self.add_self_loops, dtype=x.dtype) 88 | if self.cached: 89 | self._cached_edge_index = (edge_index, edge_weight) 90 | else: 91 | edge_index, edge_weight = cache[0], cache[1] 92 | # propagate_type: (x: Var, edge_weight: OptVar) 93 | x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None) 94 | x.multiply(1 - self.alpha) 95 | x_0 = self.alpha * x_0[:x.size(0)] 96 | if self.weight2 is None: 97 | out = x.add(x_0) 98 | out = addmm(out, out, self.weight1, beta=1. - self.beta, 99 | alpha=self.beta) 100 | else: 101 | out = addmm(x, x, self.weight1, beta=1. - self.beta, 102 | alpha=self.beta) 103 | out += addmm(x_0, x_0, self.weight2, beta=1. - self.beta, 104 | alpha=self.beta) 105 | return out 106 | 107 | def message(self, x_j: Var, edge_weight: Var) -> Var: 108 | # return edge_weight.view(-1, 1) * x_jd 109 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 110 | 111 | def __repr__(self): 112 | return '{}({}, alpha={}, beta={})'.format(self.__class__.__name__, 113 | self.channels, self.alpha, 114 | self.beta) 115 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/gcn_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | from jittor_geometric.typing import Adj, OptVar 3 | 4 | import jittor as jt 5 | from jittor import Var 6 | from jittor_geometric.nn.conv import MessagePassing 7 | from jittor_geometric.utils import add_remaining_self_loops 8 | from jittor_geometric.utils.num_nodes import maybe_num_nodes 9 | 10 | from ..inits import glorot, zeros 11 | 12 | 13 | def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, 14 | add_self_loops=True, dtype=None): 15 | 16 | fill_value = 2. if improved else 1. 
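    # GCN normalization: any missing self-loops are added with weight
    # `fill_value`, and the returned edge weights correspond to
    # D_hat^{-1/2} * A_hat * D_hat^{-1/2}, where A_hat is the self-loop
    # augmented adjacency matrix and D_hat its degree matrix.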
17 | 18 | if isinstance(edge_index, Var): 19 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 20 | 21 | if edge_weight is None: 22 | edge_weight = jt.ones((edge_index.size(1), )) 23 | 24 | if add_self_loops: 25 | edge_index, tmp_edge_weight = add_remaining_self_loops( 26 | edge_index, edge_weight, fill_value, num_nodes) 27 | assert tmp_edge_weight is not None 28 | edge_weight = tmp_edge_weight 29 | 30 | row, col = edge_index[0], edge_index[1] 31 | shape = list(edge_weight.shape) 32 | shape[0] = num_nodes 33 | deg = jt.zeros(shape) 34 | deg = jt.scatter(deg, 0, col, src=edge_weight, reduce='add') 35 | deg_inv_sqrt = deg.pow(-0.5) 36 | deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0) 37 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 38 | 39 | 40 | class GCNConv(MessagePassing): 41 | r"""The graph convolutional operator from the `"Semi-supervised 42 | Classification with Graph Convolutional Networks" 43 | `_ paper 44 | """ 45 | 46 | _cached_edge_index: Optional[Tuple[Var, Var]] 47 | 48 | def __init__(self, in_channels: int, out_channels: int, 49 | improved: bool = False, cached: bool = False, 50 | add_self_loops: bool = True, normalize: bool = True, 51 | bias: bool = True, **kwargs): 52 | 53 | kwargs.setdefault('aggr', 'add') 54 | super(GCNConv, self).__init__(**kwargs) 55 | 56 | self.in_channels = in_channels 57 | self.out_channels = out_channels 58 | self.improved = improved 59 | self.cached = cached 60 | self.add_self_loops = add_self_loops 61 | self.normalize = normalize 62 | 63 | self._cached_edge_index = None 64 | self._cached_adj_t = None 65 | 66 | self.weight = jt.random((in_channels, out_channels)) 67 | 68 | if bias: 69 | self.bias = jt.random((out_channels,)) 70 | 71 | else: 72 | self.bias = None 73 | 74 | self.reset_parameters() 75 | 76 | def reset_parameters(self): 77 | glorot(self.weight) 78 | zeros(self.bias) 79 | self._cached_edge_index = None 80 | self._cached_adj_t = None 81 | 82 | def execute(self, x: Var, edge_index: Adj, 83 | edge_weight: OptVar = None) -> Var: 84 | """""" 85 | 86 | if self.normalize: 87 | if isinstance(edge_index, Var): 88 | cache = self._cached_edge_index 89 | if cache is None: 90 | edge_index, edge_weight = gcn_norm( 91 | edge_index, edge_weight, x.size(self.node_dim), 92 | self.improved, self.add_self_loops) 93 | if self.cached: 94 | self._cached_edge_index = (edge_index, edge_weight) 95 | else: 96 | edge_index, edge_weight = cache[0], cache[1] 97 | x = x @ self.weight 98 | out = self.propagate(edge_index, x=x, edge_weight=edge_weight, 99 | size=None) 100 | if self.bias is not None: 101 | out += self.bias 102 | 103 | return out 104 | 105 | def message(self, x_j: Var, edge_weight: OptVar) -> Var: 106 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 107 | 108 | def __repr__(self): 109 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 110 | self.out_channels) 111 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/message_passing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | from typing import List, Optional, Set 5 | from inspect import Parameter 6 | 7 | import jittor as jt 8 | from jittor import nn, Module 9 | from jittor import Var 10 | from jittor_geometric.typing import Adj, Size 11 | 12 | from .utils.inspector import Inspector 13 | 14 | 15 | class MessagePassing(Module): 16 | special_args: Set[str] = { 17 | 
'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size', 18 | 'size_i', 'size_j', 'ptr', 'index', 'dim_size' 19 | } 20 | 21 | def __init__(self, aggr: Optional[str] = "add", 22 | flow: str = "source_to_target", node_dim: int = -2): 23 | 24 | super(MessagePassing, self).__init__() 25 | 26 | self.aggr = aggr 27 | assert self.aggr in ['add', 'mean', 'max', None] 28 | 29 | self.flow = flow 30 | assert self.flow in ['source_to_target', 'target_to_source'] 31 | 32 | self.node_dim = node_dim 33 | 34 | self.inspector = Inspector(self) 35 | self.inspector.inspect(self.message) 36 | self.inspector.inspect(self.aggregate, pop_first=True) 37 | self.inspector.inspect(self.update, pop_first=True) 38 | 39 | self.__user_args__ = self.inspector.keys( 40 | ['message', 'aggregate', 'update']).difference(self.special_args) 41 | 42 | def __check_input__(self, edge_index, size): 43 | the_size: List[Optional[int]] = [None, None] 44 | 45 | if isinstance(edge_index, Var): 46 | assert edge_index.dtype == Var.int32 47 | assert edge_index.ndim == 2 48 | assert edge_index.size(0) == 2 49 | if size is not None: 50 | the_size[0] = size[0] 51 | the_size[1] = size[1] 52 | return the_size 53 | 54 | raise ValueError( 55 | ('`MessagePassing.propagate` only supports `jittor Var int32` of ' 56 | 'shape `[2, num_messages]`')) 57 | 58 | def __set_size__(self, size: List[Optional[int]], dim: int, src: Var): 59 | the_size = size[dim] 60 | if the_size is None: 61 | size[dim] = src.size(self.node_dim) 62 | elif the_size != src.size(self.node_dim): 63 | raise ValueError( 64 | (f'Encountered Var with size {src.size(self.node_dim)} in ' 65 | f'dimension {self.node_dim}, but expected size {the_size}.')) 66 | 67 | def __lift__(self, src, edge_index, dim): 68 | if isinstance(edge_index, Var): 69 | index = edge_index[dim] 70 | return src[(slice(None),)*self.node_dim+(index,)] 71 | raise ValueError 72 | 73 | def __collect__(self, args, edge_index, size, kwargs): 74 | i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1) 75 | 76 | out = {} 77 | for arg in args: 78 | if arg[-2:] not in ['_i', '_j']: 79 | out[arg] = kwargs.get(arg, Parameter.empty) 80 | else: 81 | dim = 0 if arg[-2:] == '_j' else 1 82 | data = kwargs.get(arg[:-2], Parameter.empty) 83 | if isinstance(data, (tuple, list)): 84 | assert len(data) == 2 85 | if isinstance(data[1 - dim], Var): 86 | self.__set_size__(size, 1 - dim, data[1 - dim]) 87 | data = data[dim] 88 | 89 | if isinstance(data, Var): 90 | self.__set_size__(size, dim, data) 91 | data = self.__lift__(data, edge_index, 92 | j if arg[-2:] == '_j' else i) 93 | 94 | out[arg] = data 95 | 96 | if isinstance(edge_index, Var): 97 | out['adj_t'] = None 98 | out['edge_index'] = edge_index 99 | out['edge_index_i'] = edge_index[i] 100 | out['edge_index_j'] = edge_index[j] 101 | out['ptr'] = None 102 | 103 | out['index'] = out['edge_index_i'] 104 | out['size'] = size 105 | out['size_i'] = size[1] or size[0] 106 | out['size_j'] = size[0] or size[1] 107 | out['dim_size'] = out['size_i'] 108 | 109 | return out 110 | 111 | def propagate(self, edge_index: Adj, size: Size = None, **kwargs): 112 | 113 | size = self.__check_input__(edge_index, size) 114 | 115 | if isinstance(edge_index, Var): 116 | coll_dict = self.__collect__(self.__user_args__, edge_index, size, 117 | kwargs) 118 | 119 | msg_kwargs = self.inspector.distribute('message', coll_dict) 120 | out = self.message(**msg_kwargs) 121 | 122 | aggr_kwargs = self.inspector.distribute('aggregate', coll_dict) 123 | out = self.aggregate(out, **aggr_kwargs) 124 | 125 | 
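            # Final stage of the message -> aggregate -> update pipeline:
            # message() built per-edge messages, aggregate() scattered them
            # onto the target nodes, and update() now post-processes the
            # per-node result before it is returned.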
update_kwargs = self.inspector.distribute('update', coll_dict) 126 | return self.update(out, **update_kwargs) 127 | 128 | def message(self, x_j: Var) -> Var: 129 | r"""Constructs messages from node :math:`j` to node :math:`i` 130 | in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in 131 | :obj:`edge_index`. 132 | This function can take any argument as input which was initially 133 | passed to :meth:`propagate`. 134 | Furthermore, Var passed to :meth:`propagate` can be mapped to the 135 | respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or 136 | :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. 137 | """ 138 | return x_j 139 | 140 | def aggregate(self, inputs: Var, index: Var, 141 | ptr: Optional[Var] = None, 142 | dim_size: Optional[int] = None) -> Var: 143 | # return input 144 | shape = list(inputs.shape) 145 | shape[self.node_dim] = dim_size 146 | out = jt.zeros(shape) 147 | 148 | return jt.scatter(out, 0, index, src=inputs, reduce=self.aggr) 149 | 150 | def update(self, inputs: Var) -> Var: 151 | r"""Updates node embeddings in analogy to 152 | :math:`\gamma_{\mathbf{\Theta}}` for each node 153 | :math:`i \in \mathcal{V}`. 154 | Takes in the output of aggregation as first argument and any argument 155 | which was initially passed to :meth:`propagate`. 156 | """ 157 | return inputs 158 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/sg_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from jittor_geometric.typing import Adj, OptVar 3 | 4 | import jittor as jt 5 | from jittor import Var 6 | from jittor.nn import Linear 7 | from jittor_geometric.nn.conv import MessagePassing 8 | from jittor_geometric.nn.conv.gcn_conv import gcn_norm 9 | 10 | 11 | class SGConv(MessagePassing): 12 | r"""The simple graph convolutional operator from the `"Simplifying Graph 13 | Convolutional Networks" `_ paper 14 | 15 | .. math:: 16 | \mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 17 | \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta}, 18 | 19 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 20 | adjacency matrix with inserted self-loops and 21 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 22 | The adjacency matrix can include other values than :obj:`1` representing 23 | edge weights via the optional :obj:`edge_weight` Var. 
24 | 25 | """ 26 | 27 | _cached_x: Optional[Var] 28 | 29 | def __init__(self, in_channels: int, out_channels: int, K: int = 1, 30 | cached: bool = False, add_self_loops: bool = True, 31 | bias: bool = True, **kwargs): 32 | kwargs.setdefault('aggr', 'add') 33 | super(SGConv, self).__init__(**kwargs) 34 | 35 | self.in_channels = in_channels 36 | self.out_channels = out_channels 37 | self.K = K 38 | self.cached = cached 39 | self.add_self_loops = add_self_loops 40 | 41 | self._cached_x = None 42 | 43 | self.lin = Linear(in_channels, out_channels, bias=bias) 44 | 45 | def execute(self, x: Var, edge_index: Adj, 46 | edge_weight: OptVar = None) -> Var: 47 | """""" 48 | cache = self._cached_x 49 | if cache is None: 50 | if isinstance(edge_index, Var): 51 | edge_index, edge_weight = gcn_norm( 52 | edge_index, edge_weight, x.size(self.node_dim), False, 53 | self.add_self_loops, dtype=x.dtype) 54 | for k in range(self.K): 55 | x = self.propagate(edge_index, x=x, edge_weight=edge_weight, 56 | size=None) 57 | if self.cached: 58 | self._cached_x = x 59 | else: 60 | x = cache 61 | return self.lin(x) 62 | 63 | def message(self, x_j: Var, edge_weight: Var) -> Var: 64 | return edge_weight.view(-1, 1) * x_j 65 | 66 | def __repr__(self): 67 | return '{}({}, {}, K={})'.format(self.__class__.__name__, 68 | self.in_channels, self.out_channels, 69 | self.K) 70 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/utils/__pycache__/inspector.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/jittor_geometric/nn/conv/utils/__pycache__/inspector.cpython-37.pyc -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/utils/__pycache__/typing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/jittor_geometric/nn/conv/utils/__pycache__/typing.cpython-37.pyc -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/utils/inspector.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inspect 3 | from collections import OrderedDict 4 | from typing import Dict, List, Any, Optional, Callable, Set 5 | 6 | from .typing import parse_types 7 | 8 | 9 | class Inspector(object): 10 | def __init__(self, base_class: Any): 11 | self.base_class: Any = base_class 12 | self.params: Dict[str, Dict[str, Any]] = {} 13 | 14 | def inspect(self, func: Callable, 15 | pop_first: bool = False) -> Dict[str, Any]: 16 | params = inspect.signature(func).parameters 17 | params = OrderedDict(params) 18 | if pop_first: 19 | params.popitem(last=False) 20 | self.params[func.__name__] = params 21 | 22 | def keys(self, func_names: Optional[List[str]] = None) -> Set[str]: 23 | keys = [] 24 | for func in func_names or list(self.params.keys()): 25 | keys += self.params[func].keys() 26 | return set(keys) 27 | 28 | def __implements__(self, cls, func_name: str) -> bool: 29 | if cls.__name__ == 'MessagePassing': 30 | return False 31 | if func_name in cls.__dict__.keys(): 32 | return True 33 | return any(self.__implements__(c, func_name) for c in cls.__bases__) 34 | 35 | def implements(self, func_name: str) -> bool: 36 | return 
self.__implements__(self.base_class.__class__, func_name) 37 | 38 | def types(self, func_names: Optional[List[str]] = None) -> Dict[str, str]: 39 | out: Dict[str, str] = {} 40 | for func_name in func_names or list(self.params.keys()): 41 | func = getattr(self.base_class, func_name) 42 | arg_types = parse_types(func)[0][0] 43 | for key in self.params[func_name].keys(): 44 | if key in out and out[key] != arg_types[key]: 45 | raise ValueError( 46 | (f'Found inconsistent types for argument {key}. ' 47 | f'Expected type {out[key]} but found type ' 48 | f'{arg_types[key]}.')) 49 | out[key] = arg_types[key] 50 | return out 51 | 52 | def distribute(self, func_name, kwargs: Dict[str, Any]): 53 | out = {} 54 | for key, param in self.params[func_name].items(): 55 | data = kwargs.get(key, inspect.Parameter.empty) 56 | if data is inspect.Parameter.empty: 57 | if param.default is inspect.Parameter.empty: 58 | raise TypeError(f'Required parameter {key} is empty.') 59 | data = param.default 60 | out[key] = data 61 | return out 62 | 63 | 64 | def func_header_repr(func: Callable, keep_annotation: bool = True) -> str: 65 | source = inspect.getsource(func) 66 | signature = inspect.signature(func) 67 | 68 | if keep_annotation: 69 | return ''.join(re.split(r'(\).*?:.*?\n)', source, 70 | maxsplit=1)[:2]).strip() 71 | 72 | params_repr = ['self'] 73 | for param in signature.parameters.values(): 74 | params_repr.append(param.name) 75 | if param.default is not inspect.Parameter.empty: 76 | params_repr[-1] += f'={param.default}' 77 | 78 | return f'def {func.__name__}({", ".join(params_repr)}):' 79 | 80 | 81 | def func_body_repr(func: Callable, keep_annotation: bool = True) -> str: 82 | source = inspect.getsource(func) 83 | body_repr = re.split(r'\).*?:.*?\n', source, maxsplit=1)[1] 84 | if not keep_annotation: 85 | body_repr = re.sub(r'\s*# type:.*\n', '', body_repr) 86 | return body_repr 87 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/utils/typing.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inspect 3 | import pyparsing as pp 4 | from itertools import product 5 | from collections import OrderedDict 6 | from typing import Callable, Tuple, Dict, List 7 | 8 | 9 | def split_types_repr(types_repr: str) -> List[str]: 10 | out = [] 11 | i = depth = 0 12 | for j, char in enumerate(types_repr): 13 | if char == '[': 14 | depth += 1 15 | elif char == ']': 16 | depth -= 1 17 | elif char == ',' and depth == 0: 18 | out.append(types_repr[i:j].strip()) 19 | i = j + 1 20 | out.append(types_repr[i:].strip()) 21 | return out 22 | 23 | 24 | def sanitize(type_repr: str): 25 | type_repr = re.sub(r'<class \'(.*)\'>', r'\1', type_repr) 26 | type_repr = type_repr.replace('typing.', '') 27 | type_repr = type_repr.replace('torch_sparse.tensor.', '') 28 | type_repr = type_repr.replace('Adj', 'Union[Tensor, SparseTensor]') 29 | 30 | # Replace `Union[..., NoneType]` by `Optional[...]`.
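    # e.g. a hypothetical 'Union[Var, NoneType]' is rewritten below into
    # 'Optional[Var]' by parsing the bracket structure with pyparsing and
    # renaming `Union` nodes whose members include `NoneType`.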
31 | sexp = pp.nestedExpr(opener='[', closer=']') 32 | tree = sexp.parseString(f'[{type_repr.replace(",", " ")}]').asList()[0] 33 | 34 | def union_to_optional_(tree): 35 | for i in range(len(tree)): 36 | e, n = tree[i], tree[i + 1] if i + 1 < len(tree) else [] 37 | if e == 'Union' and n[-1] == 'NoneType': 38 | tree[i] = 'Optional' 39 | tree[i + 1] = tree[i + 1][:-1] 40 | elif e == 'Union' and 'NoneType' in n: 41 | idx = n.index('NoneType') 42 | n[idx] = [n[idx - 1]] 43 | n[idx - 1] = 'Optional' 44 | elif isinstance(e, list): 45 | tree[i] = union_to_optional_(e) 46 | return tree 47 | 48 | tree = union_to_optional_(tree) 49 | type_repr = re.sub(r'\'|\"', '', str(tree)[1:-1]).replace(', [', '[') 50 | 51 | return type_repr 52 | 53 | 54 | def param_type_repr(param) -> str: 55 | if param.annotation is inspect.Parameter.empty: 56 | return 'jittor.Var' 57 | return sanitize(re.split(r':|='.strip(), str(param))[1]) 58 | 59 | 60 | def return_type_repr(signature) -> str: 61 | return_type = signature.return_annotation 62 | if return_type is inspect.Parameter.empty: 63 | return 'jittor.Var' 64 | elif str(return_type)[:6] != '<class': 65 | return sanitize(str(return_type)) 66 | elif return_type.__module__ == 'builtins': 67 | return return_type.__name__ 68 | else: 69 | return f'{return_type.__module__}.{return_type.__name__}' 70 | 71 | 72 | def parse_types(func: Callable) -> List[Tuple[Dict[str, str], str]]: 73 | source = inspect.getsource(func) 74 | signature = inspect.signature(func) 75 | 76 | # Parse `# type: (...) -> ...` annotation. Note that it is allowed to pass 77 | # multiple `# type:` annotations in `forward()`. 78 | iterator = re.finditer(r'#\s*type:\s*\((.*)\)\s*->\s*(.*)\s*\n', source) 79 | matches = list(iterator) 80 | 81 | if len(matches) > 0: 82 | out = [] 83 | args = list(signature.parameters.keys()) 84 | for match in matches: 85 | arg_types_repr, return_type = match.groups() 86 | arg_types = split_types_repr(arg_types_repr) 87 | arg_types = OrderedDict((k, v) for k, v in zip(args, arg_types)) 88 | return_type = return_type.split('#')[0].strip() 89 | out.append((arg_types, return_type)) 90 | return out 91 | 92 | # Alternatively, parse annotations using the inspected signature.
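    # (e.g. an argument annotated `x_j: Var` maps to 'Var', while an
    # unannotated parameter falls back to 'jittor.Var'.)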
93 | else: 94 | ps = signature.parameters 95 | arg_types = OrderedDict((k, param_type_repr(v)) for k, v in ps.items()) 96 | return [(arg_types, return_type_repr(signature))] 97 | 98 | 99 | def resolve_types(arg_types: Dict[str, str], 100 | return_type_repr: str) -> List[Tuple[List[str], str]]: 101 | out = [] 102 | for type_repr in arg_types.values(): 103 | if type_repr[:5] == 'Union': 104 | out.append(split_types_repr(type_repr[6:-1])) 105 | else: 106 | out.append([type_repr]) 107 | return [(x, return_type_repr) for x in product(*out)] 108 | -------------------------------------------------------------------------------- /jittor_geometric/nn/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import jittor as jt 4 | from jittor import init 5 | 6 | 7 | def uniform(size, var): 8 | if var is not None: 9 | bound = 1.0 / math.sqrt(size) 10 | init.uniform_(var, -bound, bound) 11 | 12 | 13 | def kaiming_uniform(var, fan, a): 14 | if var is not None: 15 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 16 | init.uniform_(var, -bound, bound) 17 | 18 | 19 | def glorot(var): 20 | if var is not None: 21 | stdv = math.sqrt(6.0 / (var.size(-2) + var.size(-1))) 22 | init.uniform_(var, -stdv, stdv) 23 | 24 | 25 | def zeros(var): 26 | if var is not None: 27 | init.constant_(var, 0) 28 | 29 | 30 | def ones(var): 31 | if var is not None: 32 | init.constant_(var, 1) 33 | 34 | 35 | def normal(var, mean, std): 36 | if var is not None: 37 | var.assign(jt.normal(mean, std, size=var.size)) 38 | 39 | 40 | def reset(nn): 41 | def _reset(item): 42 | if hasattr(item, 'reset_parameters'): 43 | item.reset_parameters() 44 | 45 | if nn is not None: 46 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 47 | for item in nn.children(): 48 | _reset(item) 49 | else: 50 | _reset(nn) 51 | -------------------------------------------------------------------------------- /jittor_geometric/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .normalize_features import NormalizeFeatures 2 | 3 | __all__ = [ 4 | 'NormalizeFeatures', 5 | ] 6 | 7 | classes = __all__ 8 | -------------------------------------------------------------------------------- /jittor_geometric/transforms/normalize_features.py: -------------------------------------------------------------------------------- 1 | class NormalizeFeatures(object): 2 | r"""Row-normalizes node features to sum-up to one.""" 3 | 4 | def __call__(self, data): 5 | data.x = data.x / data.x.sum(1, keepdims=True).clamp(min_v=1) 6 | return data 7 | 8 | def __repr__(self): 9 | return '{}()'.format(self.__class__.__name__) 10 | -------------------------------------------------------------------------------- /jittor_geometric/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Union 2 | 3 | from jittor import Var 4 | 5 | Adj = Optional[Var] 6 | OptVar = Optional[Var] 7 | PairVar = Tuple[Var, Var] 8 | OptPairVar = Tuple[Var, Optional[Var]] 9 | PairOptVar = Tuple[Optional[Var], Optional[Var]] 10 | Size = Optional[Tuple[int, int]] 11 | NoneType = Optional[Var] 12 | -------------------------------------------------------------------------------- /jittor_geometric/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .degree import degree 2 | from .loop import (contains_self_loops, remove_self_loops, 3 | segregate_self_loops, add_self_loops, 4 
| add_remaining_self_loops) 5 | from .isolated import contains_isolated_nodes, remove_isolated_nodes 6 | from .get_laplacian import get_laplacian 7 | 8 | __all__ = [ 9 | 'degree', 10 | 'contains_self_loops', 11 | 'remove_self_loops', 12 | 'segregate_self_loops', 13 | 'add_self_loops', 14 | 'add_remaining_self_loops', 15 | 'contains_isolated_nodes', 16 | 'remove_isolated_nodes', 17 | 'get_laplacian', 18 | ] 19 | 20 | classes = __all__ 21 | -------------------------------------------------------------------------------- /jittor_geometric/utils/degree.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jittor as jt 4 | 5 | from .num_nodes import maybe_num_nodes 6 | 7 | 8 | def degree(index, num_nodes: Optional[int] = None, 9 | dtype: Optional[int] = None): 10 | r"""Computes the (unweighted) degree of a given one-dimensional index 11 | Var. 12 | """ 13 | N = maybe_num_nodes(index, num_nodes) 14 | out = jt.zeros((N, ), dtype=dtype) 15 | one = jt.ones((index.size(0), ), dtype=out.dtype) 16 | return out.scatter_add_(0, index, one) 17 | -------------------------------------------------------------------------------- /jittor_geometric/utils/get_laplacian.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jittor as jt 4 | from jittor import Var 5 | from jittor_geometric.utils import add_self_loops, remove_self_loops 6 | 7 | from .num_nodes import maybe_num_nodes 8 | 9 | 10 | def get_laplacian(edge_index, edge_weight: Optional[Var] = None, 11 | normalization: Optional[str] = None, 12 | dtype: Optional[int] = None, 13 | num_nodes: Optional[int] = None): 14 | r""" Computes the graph Laplacian of the graph given by :obj:`edge_index` 15 | and optional :obj:`edge_weight`. 16 | 17 | Args: 18 | edge_index (Var int32): The edge indices. 19 | edge_weight (Var, optional): One-dimensional edge weights. 20 | (default: :obj:`None`) 21 | normalization (str, optional): The normalization scheme for the graph 22 | Laplacian (default: :obj:`None`): 23 | 24 | 1. :obj:`None`: No normalization 25 | :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` 26 | 27 | 2. :obj:`"sym"`: Symmetric normalization 28 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} 29 | \mathbf{D}^{-1/2}` 30 | 31 | 3. :obj:`"rw"`: Random-walk normalization 32 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` 33 | dtype (Var.dtype, optional): The desired data type of returned Var 34 | in case :obj:`edge_weight=None`. (default: :obj:`None`) 35 | num_nodes (int, optional): The number of nodes, *i.e.* 36 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 37 | """ 38 | 39 | if normalization is not None: 40 | assert normalization in ['sym', 'rw'] # 'Invalid normalization' 41 | 42 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 43 | 44 | if edge_weight is None: 45 | edge_weight = jt.ones((edge_index.size(1)), dtype=dtype) 46 | 47 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 48 | 49 | row, col = edge_index[0], edge_index[1] 50 | shape = list(edge_weight.shape) 51 | shape[0] = num_nodes 52 | deg = jt.zeros(shape) 53 | deg = jt.scatter(deg, 0, row, src=edge_weight, reduce='add') 54 | if normalization is None: 55 | # L = D - A. 56 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) 57 | edge_weight = jt.concat([-edge_weight, deg], dim=0) 58 | elif normalization == 'sym': 59 | # Compute A_norm = -D^{-1/2} A D^{-1/2}. 
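        # (The minus sign is applied further down: edge_weight below holds
        # D^{-1/2} A D^{-1/2}, and add_self_loops(-edge_weight, fill_value=1.)
        # then assembles L = I - D^{-1/2} A D^{-1/2}.)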
60 | deg_inv_sqrt = deg.pow(-0.5) 61 | # deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0) 62 | 63 | for i in range(deg_inv_sqrt.shape[0]): 64 | if deg_inv_sqrt[i] == float('inf'): 65 | deg_inv_sqrt[i] = 0 66 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 67 | 68 | # L = I - A_norm. 69 | edge_index, tmp = add_self_loops(edge_index, -edge_weight, 70 | fill_value=1., num_nodes=num_nodes) 71 | assert tmp is not None 72 | edge_weight = tmp 73 | else: 74 | # Compute A_norm = -D^{-1} A. 75 | deg_inv = 1.0 / deg 76 | deg_inv.masked_fill(deg_inv == float('inf'), 0) 77 | edge_weight = deg_inv[row] * edge_weight 78 | 79 | # L = I - A_norm. 80 | edge_index, tmp = add_self_loops(edge_index, -edge_weight, 81 | fill_value=1., num_nodes=num_nodes) 82 | assert tmp is not None 83 | edge_weight = tmp 84 | 85 | return edge_index, edge_weight 86 | -------------------------------------------------------------------------------- /jittor_geometric/utils/isolated.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import Var 3 | from jittor_geometric.utils import remove_self_loops, segregate_self_loops 4 | 5 | from .num_nodes import maybe_num_nodes 6 | 7 | 8 | def contains_isolated_nodes(edge_index, num_nodes=None): 9 | r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains 10 | isolated nodes. 11 | 12 | Args: 13 | edge_index (Var int32): The edge indices. 14 | num_nodes (int, optional): The number of nodes, *i.e.* 15 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 16 | 17 | :rtype: bool 18 | """ 19 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 20 | (row, col), _ = remove_self_loops(edge_index) 21 | 22 | return jt.unique(jt.concat((row, col))).size(0) < num_nodes 23 | 24 | 25 | def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None): 26 | r"""Removes the isolated nodes from the graph given by :attr:`edge_index` 27 | with optional edge attributes :attr:`edge_attr`. 28 | In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter 29 | out isolated node features later on. 30 | Self-loops are preserved for non-isolated nodes. 31 | 32 | Args: 33 | edge_index (Var int32): The edge indices. 34 | edge_attr (Var, optional): Edge weights or multi-dimensional 35 | edge features. (default: :obj:`None`) 36 | num_nodes (int, optional): The number of nodes, *i.e.* 37 | :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`) 38 | 39 | :rtype: (Var int32, Var, Var bool) 40 | """ 41 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 42 | 43 | out = segregate_self_loops(edge_index, edge_attr) 44 | edge_index, edge_attr, loop_edge_index, loop_edge_attr = out 45 | 46 | mask = jt.zeros((num_nodes), dtype=Var.bool) 47 | mask[edge_index.view(-1)] = 1 48 | 49 | assoc = jt.full((num_nodes, ), -1, dtype=Var.int32) 50 | assoc[mask] = jt.arange(mask.sum()) 51 | edge_index = assoc[edge_index] 52 | 53 | loop_mask = jt.zeros_like(mask) 54 | loop_mask[loop_edge_index[0]] = 1 55 | loop_mask = loop_mask & mask 56 | loop_assoc = jt.full_like(assoc, -1) 57 | loop_assoc[loop_edge_index[0]] = jt.arange(loop_edge_index.size(1)) 58 | loop_idx = loop_assoc[loop_mask] 59 | loop_edge_index = assoc[loop_edge_index[:, loop_idx]] 60 | 61 | edge_index = jt.concat([edge_index, loop_edge_index], dim=1) 62 | 63 | if edge_attr is not None: 64 | loop_edge_attr = loop_edge_attr[loop_idx] 65 | edge_attr = jt.concat([edge_attr, loop_edge_attr], dim=0) 66 | 67 | return edge_index, edge_attr, mask 68 | -------------------------------------------------------------------------------- /jittor_geometric/utils/loop.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jittor as jt 4 | 5 | from .num_nodes import maybe_num_nodes 6 | from jittor import Var, init 7 | 8 | 9 | def contains_self_loops(edge_index): 10 | r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains 11 | self-loops. 12 | 13 | Args: 14 | edge_index (Var int32): The edge indices. 15 | 16 | :rtype: bool 17 | """ 18 | mask = edge_index[0] == edge_index[1] 19 | return mask.sum().item() > 0 20 | 21 | 22 | def remove_self_loops(edge_index, edge_attr: Optional[Var] = None): 23 | r"""Removes every self-loop in the graph given by :attr:`edge_index`, so 24 | that :math:`(i,i) \not\in \mathcal{E}` for every :math:`i \in \mathcal{V}`. 25 | 26 | Args: 27 | edge_index (Var int32): The edge indices. 28 | edge_attr (Var, optional): Edge weights or multi-dimensional 29 | edge features. (default: :obj:`None`) 30 | 31 | :rtype: (:class:`Var int32`, :class:`Var`) 32 | """ 33 | mask = edge_index[0] != edge_index[1] 34 | edge_index = edge_index[:, mask] 35 | if edge_attr is None: 36 | return edge_index, None 37 | else: 38 | return edge_index, edge_attr[mask] 39 | 40 | 41 | def segregate_self_loops(edge_index, edge_attr: Optional[Var] = None): 42 | r"""Segregates self-loops from the graph. 43 | 44 | Args: 45 | edge_index (Var int32): The edge indices. 46 | edge_attr (Var, optional): Edge weights or multi-dimensional 47 | edge features. (default: :obj:`None`) 48 | 49 | :rtype: (:class:`Var int32`, :class:`Var`, :class:`Var int32`, 50 | :class:`Var`) 51 | """ 52 | 53 | mask = edge_index[0] != edge_index[1] 54 | inv_mask = jt.logical_not(mask) 55 | 56 | loop_edge_index = edge_index[:, inv_mask] 57 | loop_edge_attr = None if edge_attr is None else edge_attr[inv_mask] 58 | edge_index = edge_index[:, mask] 59 | edge_attr = None if edge_attr is None else edge_attr[mask] 60 | 61 | return edge_index, edge_attr, loop_edge_index, loop_edge_attr 62 | 63 | 64 | def add_self_loops(edge_index, edge_weight: Optional[Var] = None, 65 | fill_value: float = 1., num_nodes: Optional[int] = None): 66 | r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node 67 | :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`. 
68 | In case the graph is weighted, self-loops will be added with edge weights 69 | denoted by :obj:`fill_value`. 70 | 71 | Args: 72 | edge_index (Var int32): The edge indices. 73 | edge_weight (Var, optional): One-dimensional edge weights. 74 | (default: :obj:`None`) 75 | fill_value (float, optional): If :obj:`edge_weight` is not :obj:`None`, 76 | will add self-loops with edge weights of :obj:`fill_value` to the 77 | graph. (default: :obj:`1.`) 78 | num_nodes (int, optional): The number of nodes, *i.e.* 79 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 80 | 81 | :rtype: (:class:`Var int32`, :class:`Var`) 82 | """ 83 | N = maybe_num_nodes(edge_index, num_nodes) 84 | 85 | loop_index = jt.arange(0, N, dtype=Var.int32) 86 | loop_index = loop_index.unsqueeze(0).repeat(2, 1) 87 | 88 | if edge_weight is not None: 89 | assert edge_weight.numel() == edge_index.size(1) 90 | loop_weight = init.constant((N, ), edge_weight.dtype, fill_value) 91 | edge_weight = jt.concat([edge_weight, loop_weight], dim=0) 92 | 93 | edge_index = jt.concat([edge_index, loop_index], dim=1) 94 | 95 | return edge_index, edge_weight 96 | 97 | 98 | def add_remaining_self_loops(edge_index, 99 | edge_weight: Optional[Var] = None, 100 | fill_value: float = 1., 101 | num_nodes: Optional[int] = None): 102 | r"""Adds remaining self-loop :math:`(i,i) \in \mathcal{E}` to every node 103 | :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`. 104 | In case the graph is weighted and already contains a few self-loops, only 105 | non-existent self-loops will be added with edge weights denoted by 106 | :obj:`fill_value`. 107 | 108 | Args: 109 | edge_index (Var int32): The edge indices. 110 | edge_weight (Var, optional): One-dimensional edge weights. 111 | (default: :obj:`None`) 112 | fill_value (float, optional): If :obj:`edge_weight` is not :obj:`None`, 113 | will add self-loops with edge weights of :obj:`fill_value` to the 114 | graph. (default: :obj:`1.`) 115 | num_nodes (int, optional): The number of nodes, *i.e.* 116 | :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`) 117 | 118 | :rtype: (:class:`Var int32`, :class:`Var`) 119 | """ 120 | N = maybe_num_nodes(edge_index, num_nodes) 121 | row, col = edge_index[0], edge_index[1] 122 | mask = row != col 123 | 124 | loop_index = jt.arange(0, N, dtype=row.dtype) 125 | loop_index = loop_index.unsqueeze(0).repeat(2, 1) 126 | edge_index = jt.concat([edge_index[:, mask], loop_index], dim=1) 127 | 128 | if edge_weight is not None: 129 | inv_mask = jt.logical_not(mask) 130 | loop_weight = init.constant((N, ), edge_weight.dtype, fill_value) 131 | remaining_edge_weight = edge_weight[inv_mask] 132 | if remaining_edge_weight.numel() > 0: 133 | loop_weight[row[inv_mask]] = remaining_edge_weight 134 | edge_weight = jt.concat([edge_weight[mask], loop_weight], dim=0) 135 | 136 | return edge_index, edge_weight 137 | -------------------------------------------------------------------------------- /jittor_geometric/utils/num_nodes.py: -------------------------------------------------------------------------------- 1 | from jittor import Var 2 | 3 | 4 | def maybe_num_nodes(edge_index, num_nodes=None): 5 | if num_nodes is not None: 6 | return num_nodes 7 | elif isinstance(edge_index, Var): 8 | return int(edge_index.max()) + 1 9 | else: 10 | return max(edge_index.size(0), edge_index.size(1)) 11 | -------------------------------------------------------------------------------- /mainJittor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import time 5 | import cv2 6 | from data import Data 7 | from config import Config 8 | from networkJittor import Model 9 | 10 | import jittor as jt 11 | from jittor import init 12 | from jittor import nn 13 | jt.flags.use_cuda = 1 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--mode", type=str, default="train", choices=['train', 'recon']) 17 | parser.add_argument("--logfolder", type=str, default="temp", help="save checkpoint dir") 18 | parser.add_argument("--dataset", type=str, default="scape", help="the training dataset name") 19 | 20 | # ckp path 21 | parser.add_argument("--cp_name",type=str, default="best.model", help="the checkpoint name") 22 | 23 | # test input 24 | parser.add_argument("--test_file",type=str, default="test.mat", help="the test file") 25 | 26 | parser.add_argument("--lr",type=float, default=0.01,help="learning rate") 27 | parser.add_argument("--epoch",type=int, default=30000, help="the training epoch") 28 | 29 | # training loss weights 30 | parser.add_argument("--lambda0",type=float, default=1,help="reconstruction loss") 31 | parser.add_argument("--lambda1",type=float, default=1000,help="sparsity constraints") 32 | parser.add_argument("--lambda2",type=float, default=1,help="weights norm loss") 33 | parser.add_argument("--lambda3",type=float, default=1,help="trainable d loss") 34 | parser.add_argument("--lambda4",type=float, default=0.01,help="KL loss") 35 | parser.add_argument("--std",type=float, default=1,help="std") 36 | 37 | # other network setting 38 | parser.add_argument("--finaldim",type=int, default=9, help="the final layer dimension") 39 | parser.add_argument("--latent",type=int, default=50, help="the latent dimension") 40 | parser.add_argument("--K",type=int, default=3, help="the graph convolution parameter K") 41 | parser.add_argument("--layer_num",type=int, default=1, help="number of convolution layers") 42 | parser.add_argument("--th",type=int, default=10, help="the start valid threshold") 43 | 44 | 
parser.add_argument("--seed",type=int, default=1, help="random seed") 45 | 46 | #sparse constrain type 47 | parser.add_argument("--d_type", type=str, default='dynamic', choices=['fix', 'dynamic'], help='which sprase constrain to use') 48 | # adjacency matrix type 49 | parser.add_argument("--weight_type", type=str, default='normal', choices=['normal', 'cot'], help='normal or cotangent adjacency matrix') 50 | # convolution type 51 | parser.add_argument("--conv_type", type=str, default='spectral', choices=['spectral', 'spatial'], help='spectral or spatial convolution') 52 | # activation function type 53 | parser.add_argument("--ac_type", type=str, default='tanh', choices=['none', 'tanh', 'selu'], help='actiation function type') 54 | # network structure type 55 | parser.add_argument("--net_type", type=str, default='VAE', choices=['VAE', 'AE'], help='network structure type') 56 | # synthesis input [component_id, max or min, weight] 57 | parser.add_argument("--syn_list",nargs='+', type=int, default=[0,0,0], help='synthesis input') 58 | 59 | parser.add_argument("--deform_weight", type=int, default=10, help='weight of defrom') 60 | parser.add_argument("--deform_lr", type=float, default=0.01, help='weight of defrom') 61 | parser.add_argument("--deform_epoch", type=int, default=1000, help='weight of defrom') 62 | 63 | opt = parser.parse_args() 64 | print(opt) 65 | config = Config(opt) 66 | 67 | data = Data(config) 68 | 69 | model = Model(config, data) 70 | optimizer = model.optimizer 71 | 72 | from jittor.dataset import Dataset 73 | class MyDataset(Dataset): 74 | def __init__(self, mydata): 75 | super().__init__() 76 | self.data = mydata 77 | 78 | def __getitem__(self, k): 79 | return self.data[k] 80 | 81 | def __len__(self,): 82 | return len(self.data) 83 | 84 | # train_loader = MyDataset(data.train_data).set_attrs(batch_size=16, shuffle=True) 85 | # train_loader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True) 86 | 87 | # ---------- 88 | # Training 89 | # ---------- 90 | 91 | for epoch in range(config.epoch): 92 | # for i, (imgs, _) in enumerate(train_loader): 93 | # for i, train_input in enumerate(train_loader): 94 | sta = time.time() 95 | train_input = jt.array(data.train_data).stop_grad().astype(jt.float32) 96 | # train_input = jt.array(data.train_data).astype(jt.float32) 97 | # ----------------- 98 | # Train Generator 99 | # ----------------- 100 | KL_loss, Generation_loss, laplacian_norm, weights_norm, dloss = model(train_input) 101 | # sumLoss = KL_loss + Generation_loss + laplacian_norm 102 | # sumLoss = KL_loss + Generation_loss 103 | sumLoss = Generation_loss + laplacian_norm + weights_norm + dloss 104 | 105 | optimizer.step(sumLoss) 106 | 107 | # --------------------- 108 | # Train Discriminator 109 | # --------------------- 110 | 111 | print( 112 | "[Epoch %d/%d] [KL loss: %f] [Gen loss: %f] [Laplacian loss: %f] [weights_norm: %f] [dloss: %f] [Time: %f]" 113 | % (epoch, config.epoch, KL_loss.data, Generation_loss.data, laplacian_norm.data, weights_norm.data, dloss.data, time.time() - sta) 114 | ) 115 | # print("Epoch %s has done ..." 
% epoch) -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | from utils import * 4 | from data import * 5 | from f2v import Feature2Vertex 6 | import os 7 | 8 | 9 | 10 | class Model(): 11 | def __init__(self, config, data): 12 | self.global_step = tf.Variable( 13 | initial_value = 0, 14 | name = 'global_step', 15 | trainable = False) 16 | self.pointnum = data.pointnum 17 | 18 | self.maxdegree = data.maxdegree 19 | self.L = Laplacian(data.weight) 20 | self.L = rescale_L(self.L, lmax=2) 21 | self.config = config 22 | self.data = data 23 | self.f2v = Feature2Vertex(self.data) 24 | constant_b_min = 0.2 25 | constant_b_max = 0.4 26 | tf.set_random_seed(self.config.seed) 27 | self.train_input = tf.constant(data.train_data,dtype = 'float32',name='train_input') 28 | self.valid_input = tf.constant(data.valid_data,dtype = 'float32',name='valid_input') 29 | self.KL_weights = tf.placeholder(tf.float32, (), name = 'KL_weights') 30 | self.inputs = tf.placeholder(tf.float32, [None, self.pointnum, 9], name = 'input_mesh') 31 | self.random_input = tf.placeholder(tf.float32, [None, self.config.latent], name='random_input') 32 | self.nb = tf.constant(data.neighbour, dtype='int32', shape=[self.pointnum, self.maxdegree], name='nb_relation') 33 | self.degrees = tf.constant(data.degrees, dtype = 'float32', shape=[self.pointnum, 1], name = 'degrees') 34 | self.embedding_inputs = tf.placeholder(tf.float32, [None, self.config.latent], name = 'embedding_inputs') 35 | self.laplacian = tf.constant(data.geodesic_weight,dtype = 'float32', shape =(self.pointnum, self.pointnum), name = 'geodesic_weight') 36 | self.finaldim = config.finaldim 37 | self.layer_num = config.layer_num 38 | for i in range(self.layer_num): 39 | if self.config.conv_type=='spatial': 40 | n, e = self.get_conv_weights(9, self.finaldim, name = 'convw'+str(i)) 41 | setattr(self, 'n'+str(i+1), n) 42 | setattr(self, 'e'+str(i+1), e) 43 | else: 44 | setattr(self, 'W'+str(i*2+1), tf.get_variable("conv_weight"+str(i*2+1),[9*self.config.K, self.finaldim], tf.float32, tf.random_normal_initializer(stddev=0.02, seed = self.config.seed))) 45 | # setattr(self, 'W'+str(i*2+2), tf.get_variable("conv_weight"+str(i*2+2),[self.finaldim*self.config.K, 9], tf.float32, tf.random_normal_initializer(stddev=0.02, seed = self.config.seed))) 46 | 47 | if self.config.d_type=='dynamic': 48 | self.b0=tf.get_variable("b0",[self.config.latent,1],tf.float32,tf.constant_initializer(0.2)) 49 | self.fcparams = tf.get_variable("weights", [self.pointnum*self.finaldim, self.config.latent], tf.float32, tf.random_normal_initializer(stddev=0.02, seed = self.config.seed)) 50 | self.stdparams = tf.get_variable("stdweights", [self.pointnum*self.finaldim, self.config.latent], tf.float32, tf.random_normal_initializer(stddev=0.02, seed = self.config.seed)) 51 | self.fcparams_group = tf.transpose(tf.reshape(self.fcparams, [self.pointnum, self.finaldim, self.config.latent]), perm = [2, 0, 1]) 52 | 53 | self.selfdot = tf.reduce_sum(tf.pow(self.fcparams_group, 2.0), axis = 2) 54 | self.maxdimension = tf.argmax(self.selfdot, axis = 1) 55 | if self.config.dataname =='fat': 56 | self.maxdimension = [5459, 5712, 5467, 5717, 5016, 5048 ,5074, 5060, 4880, 4889, 4912 ,5365, 2000 ,2081, 2008, 2252 , 1691, 1602, 1721 ,1595 ,1505 ,1830 ,1389 ,1429 ,3511, 3050, 335 , 414 , 6765, 4589, 4393, 4985 , 3365, 1103, 907 , 1512 , 6786 ,4594 ,4931, 6513, 3387 
,1108 ,989 ,1246, 3380, 1115 ,973, 1454, 3146 ,3500] 57 | elif self.config.dataname =='scape': 58 | self.maxdimension = [920, 975, 937, 880, 1406, 1328, 1354, 1362, 1651, 1583, 1742, 1627, 1034, 1084, 1093, 1060, 1377, 1292 ,1309, 1337,1668, 1606, 1544, 1735, 1399, 1877, 2110, 2156 , 198, 380, 651, 1020, 203 , 334 , 603 , 946, 202 , 390 , 607, 911,219, 372 , 614, 855 , 190, 361 , 625 , 905, 762 , 1219] 59 | self.maxlaplacian = tf.gather(self.laplacian, self.maxdimension) 60 | if self.config.d_type=='fix': 61 | self.laplacian_new=tf.minimum(tf.maximum(tf.div(tf.add(self.maxlaplacian,-constant_b_min),constant_b_max-constant_b_min),0),1) 62 | else: 63 | self.laplacian_new=tf.maximum(binarize(self.maxlaplacian-self.b0),0) 64 | 65 | self.laplacian_norm = self.config.lambda1*tf.reduce_mean(tf.reduce_sum(tf.sqrt(self.selfdot) * self.laplacian_new, 1)) 66 | 67 | def log_normal_pdf(sample, mean, logvar, raxis=1): 68 | log2pi = tf.math.log(2. * np.pi) 69 | return tf.reduce_sum( 70 | -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi), 71 | axis=raxis) 72 | 73 | 74 | if self.config.net_type == 'VAE': 75 | self.mean, self.std, self.weights_norm, _ = self.encoder_vae(self.train_input, train = True) 76 | eps = tf.random.normal(self.mean.shape, stddev=self.config.stddev, seed = self.config.seed) 77 | self.decoder_input = self.mean + tf.exp(self.std*.5)*eps 78 | logpz = log_normal_pdf(self.decoder_input, 0., 0.) 79 | logqz_x = log_normal_pdf(self.decoder_input, self.mean, self.std) 80 | self.KL_loss = -tf.reduce_mean(logpz - logqz_x) 81 | # self.decoder_input = self.mean + self.std*eps 82 | # self.KL_loss = 0.5 * tf.reduce_mean(tf.reduce_sum(tf.square(self.mean) + tf.square(self.std) - tf.log(1e-8 + tf.square(self.std)) - 1, 1)) 83 | self.test_mean,self.test_std = self.encoder_vae(self.valid_input, train = False) 84 | test_eps = tf.random.normal(self.test_mean.shape, stddev=self.config.stddev, seed = self.config.seed) 85 | self.test_encode = self.test_mean + tf.exp(self.test_std*.5)*test_eps 86 | # self.test_encode = self.test_mean + self.test_std * test_eps 87 | self.feed_encode,self.feed_std=self.encoder_vae(self.inputs, train = False) 88 | 89 | #self.test_KL_loss = 0.5 * tf.reduce_mean(tf.reduce_sum(tf.square(self.test_mean) + tf.square(self.test_std) - tf.log(1e-8 + tf.square(self.test_std)) - 1, 1)) 90 | else: 91 | self.decoder_input, self.weights_norm, _ = self.encoder(self.train_input, train = True) 92 | self.KL_loss=tf.constant(0,dtype = 'float32',shape=[1],name='KL_loss') 93 | self.test_encode = self.encoder(self.valid_input, train = False) 94 | self.feed_encode=self.encoder(self.inputs, train = False) 95 | 96 | self.KL_loss = self.KL_loss * self.KL_weights 97 | # self.KL_loss = self.config.lambda4 * self.KL_loss 98 | self.weights_norm = self.config.lambda2 * self.weights_norm 99 | 100 | self.decode = self.decoder(self.decoder_input, train = True) 101 | self.test_decode = self.decoder(self.test_encode, train = False) 102 | 103 | # self.test_generation_loss = tf.reduce_sum(tf.pow(self.valid_input-self.test_decode, 2.0), [1,2]) 104 | 105 | self.test_generation_loss = tf.reduce_mean(tf.reduce_sum(tf.pow(self.valid_input-self.test_decode, 2.0), [1,2])) 106 | 107 | 108 | self.feed_decode = self.decoder(self.feed_encode,train = False) 109 | self.embedding_output=self.decoder(self.embedding_inputs,train = False) 110 | self.random_decoder = self.decoder(self.random_input, train=False) 111 | 112 | self.generation_loss = tf.reduce_mean(tf.reduce_sum(tf.pow(self.train_input-self.decode, 2.0), 
[1,2])) 113 | 114 | 115 | if self.config.d_type=='dynamic': 116 | self.dloss=self.config.lambda3 * (tf.reduce_sum(tf.maximum(self.b0,0))) 117 | else: 118 | self.dloss=tf.constant(0,dtype = 'float32',shape=[1],name='dloss') 119 | 120 | self.loss = self.config.lambda0 * self.generation_loss + self.laplacian_norm + self.weights_norm + self.dloss + self.KL_loss 121 | 122 | 123 | 124 | 125 | learning_rate = tf.train.exponential_decay( 126 | 0.001, 127 | self.global_step, 128 | decay_steps=4000, 129 | decay_rate=0.9, 130 | staircase=True) 131 | self.learning_rate = tf.maximum(learning_rate,1e-5) 132 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step) 133 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = None) 134 | 135 | 136 | def spatial_conv(self, input_feature, input_dim, output_dim,nb_weights, edge_weights, name = 'meshconv', training = True, special_activation = False, no_activation = False, bn = True): 137 | with tf.variable_scope(name) as scope: 138 | 139 | padding_feature = tf.zeros([tf.shape(input_feature)[0], 1, input_dim], tf.float32) 140 | 141 | padded_input = tf.concat([padding_feature, input_feature], 1) 142 | 143 | def compute_nb_feature(input_f): 144 | return tf.gather(input_f, self.nb) 145 | 146 | total_nb_feature = tf.map_fn(compute_nb_feature, padded_input) 147 | mean_nb_feature = tf.reduce_sum(total_nb_feature, axis = 2)/self.degrees 148 | 149 | # nb_weights = tf.get_variable("nb_weights", [input_dim, output_dim], tf.float32, tf.random_normal_initializer(stddev=0.02)) 150 | nb_bias = tf.get_variable("nb_bias", [output_dim], tf.float32, initializer=tf.constant_initializer(0.0)) 151 | nb_feature = tf.tensordot(mean_nb_feature, nb_weights, [[2],[0]]) + nb_bias 152 | 153 | # edge_weights = tf.get_variable("edge_weights", [input_dim, output_dim], tf.float32, tf.random_normal_initializer(stddev=0.02)) 154 | edge_bias = tf.get_variable("edge_bias", [output_dim], tf.float32, initializer=tf.constant_initializer(0.0)) 155 | edge_feature = tf.tensordot(input_feature, edge_weights, [[2],[0]]) + edge_bias 156 | 157 | total_feature = edge_feature + nb_feature 158 | 159 | if bn == False: 160 | fb = total_feature 161 | else: 162 | fb = self.batch_norm_wrapper(total_feature, is_training = training) 163 | 164 | if no_activation == True: 165 | fa = fb 166 | elif special_activation == False: 167 | fa = self.leaky_relu(fb) 168 | else: 169 | fa = tf.tanh(fb) 170 | 171 | return fa 172 | 173 | def get_conv_weights(self, input_dim, output_dim, name = 'convweight'): 174 | with tf.variable_scope(name) as scope: 175 | n = tf.get_variable("nb_weights", [input_dim, output_dim], tf.float32, tf.random_normal_initializer(stddev=0.02, seed = self.config.seed)) 176 | e = tf.get_variable("edge_weights", [input_dim, output_dim], tf.float32, tf.random_normal_initializer(stddev=0.02, seed = self.config.seed)) 177 | 178 | return n, e 179 | 180 | def encoder_vae(self, input_feature, train = True): 181 | with tf.variable_scope("encoder_vae") as scope: 182 | if(train == False): 183 | scope.reuse_variables() 184 | conv = input_feature 185 | for i in range(self.layer_num): 186 | if self.config.conv_type=='spatial': 187 | conv = self.spatial_conv(conv, 9, self.finaldim,getattr(self, 'n'+str(i+1)), getattr(self, 'e'+str(i+1)), name = 'conv'+str(i+1), special_activation = True, no_activation= False if i==self.layer_num-1 else True, training = train, bn = False) 188 | 189 | else: 190 | 191 | 
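                # Spectral branch: spectral_conv is defined in utils.py (not
                # shown here) and is assumed to apply a Chebyshev-style filter
                # of order config.K over the rescaled Laplacian self.L, mapping
                # the 9-channel per-vertex features to self.finaldim channels.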
conv=spectral_conv(conv,self.L,self.finaldim,self.config.K,getattr(self,'W'+str(i*2+1)), name='conv' + str(i+1), activation = self.config.ac_type) 192 | 193 | 194 | l0 = tf.reshape(conv, [tf.shape(conv)[0], self.pointnum*self.finaldim]) 195 | 196 | l1 = tf.matmul(l0, self.fcparams) 197 | std = tf.matmul(l0,self.stdparams) 198 | # std = 2*tf.sigmoid(tf.matmul(l0,self.stdparams)) 199 | 200 | if train == True: 201 | weights_maximum = tf.reduce_max(tf.abs(l1), 0) - 5 202 | zeros = tf.zeros_like(weights_maximum) 203 | weights_norm = tf.reduce_mean(tf.maximum(weights_maximum, zeros)) 204 | return l1,std, weights_norm,conv 205 | else: 206 | return l1,std 207 | def encoder(self, input_feature, train = True): 208 | with tf.variable_scope("encoder") as scope: 209 | if(train == False): 210 | scope.reuse_variables() 211 | conv = input_feature 212 | for i in range(self.layer_num): 213 | if self.config.conv_type=='spatial': 214 | conv = self.spatial_conv(conv, 9, self.finaldim,getattr(self, 'n'+str(i+1)), getattr(self, 'e'+str(i+1)), name = 'conv'+str(i+1), special_activation = True, no_activation= False if i==self.layer_num-1 else True, training = train, bn = False) 215 | 216 | else: 217 | conv=spectral_conv(conv,self.L,self.finaldim,self.config.K,getattr(self,'W'+str(i*2+1)), name='conv' + str(i+1), activation = self.config.ac_type if i==self.layer_num-1 else 'none') 218 | 219 | l0 = tf.reshape(conv, [tf.shape(conv)[0], self.pointnum*self.finaldim]) 220 | 221 | l1 = tf.matmul(l0, self.fcparams) 222 | 223 | 224 | if train == True: 225 | weights_maximum = tf.reduce_max(tf.abs(l1), 0) - 5 226 | zeros = tf.zeros_like(weights_maximum) 227 | weights_norm = tf.reduce_mean(tf.maximum(weights_maximum, zeros)) 228 | return l1, weights_norm,conv 229 | else: 230 | return l1 231 | 232 | def decoder(self, latent_tensor, train = True): 233 | with tf.variable_scope("decoder") as scope: 234 | if(train == False): 235 | scope.reuse_variables() 236 | 237 | l1 = tf.matmul(latent_tensor, tf.transpose(self.fcparams)) 238 | l2 = tf.reshape(l1, [tf.shape(l1)[0], self.pointnum, self.finaldim]) 239 | conv = l2 240 | for i in range(self.layer_num): 241 | if self.config.conv_type=='spatial': 242 | conv = self.spatial_conv(conv, self.finaldim, 9, tf.transpose(getattr(self, 'n'+str(i+1))), tf.transpose(getattr(self, 'e'+str(i+1))), name = 'conv'+str(i+1), training = train, special_activation = True, bn = False) 243 | else: 244 | conv=spectral_conv(conv,self.L,9,self.config.K,getattr(self,'W'+str(i*2+1)),name='conv'+str(i+1), activation=self.config.ac_type) 245 | return conv 246 | 247 | def train(self): 248 | config = tf.ConfigProto() 249 | config.gpu_options.allow_growth = True 250 | with tf.Session(config=config) as sess: 251 | file = open('./checkpoint/'+self.config.logfolder+'/script_result.txt', 'w') 252 | vfile = open('./checkpoint/'+self.config.logfolder+'/script_result_valid.txt', 'w') 253 | file2=open('./checkpoint/'+self.config.logfolder+'/b.txt', 'w') 254 | 255 | if os.path.exists('./checkpoint/'+self.config.logfolder+'/checkpoint'): 256 | self.saver.restore(sess, tf.train.latest_checkpoint('./checkpoint/'+self.config.logfolder)) 257 | print('restore!') 258 | else: 259 | tf.global_variables_initializer().run() 260 | # x=sess.run(self.train_input) 261 | # print(sess.run(self.valid_input)) 262 | # xx() 263 | 264 | global_step = sess.run(self.global_step) 265 | valid_best = float('inf') 266 | file.write("d_type:%s,lambda0:%f,lambda1:%f,lambda2:%f,lambda3:%f,lambda4:%f,k:%d,start_idx:%d,end_idx:%d" % (self.config.d_type, 
self.config.lambda0, self.config.lambda1, self.config.lambda2, self.config.lambda3, self.config.lambda4, self.config.K, self.config.start_idx, self.config.end_idx)) 267 | 268 | for epoch in range(global_step,self.config.epoch): 269 | for step in range(1): 270 | # KL_weights = 0 271 | # if epoch >= 15000: 272 | KL_weights = self.config.lambda4 273 | cost_generation, cost_kl, cost_norm, cost_weights, cost_d, _, lr =sess.run([self.generation_loss, self.KL_loss, self.laplacian_norm, self.weights_norm, self.dloss, self.optimizer, self.learning_rate], feed_dict={self.KL_weights:KL_weights}) 274 | 275 | # x=sess.run(getattr(self, 'n1')) 276 | # print(x) 277 | # xx() 278 | # if epoch > 15000 and epoch % 1000 == 0: 279 | # self.saver.save(sess, self.config.logfolder +'/'+str(epoch)+'.model') 280 | cost_valid = sess.run(self.test_generation_loss) 281 | # cost_valid = cost_valid[23] 282 | 283 | print("Epoch: [%5d|%5d] lr: %.5f generation_loss: %.8f validation: %.8f norm_loss: %.8f weight_loss: %.8f dloss: %.8f klloss: %.8f" % (epoch+1,step+1,lr, cost_generation, cost_valid, cost_norm, cost_weights,cost_d, cost_kl)) 284 | file.write("Epoch: [%5d|%5d] lr: %.5f generation_loss: %.8f validation: %.8f norm_loss: %.8f weight_loss: %.8f dloss: %.8f klloss: %.8f\n" % (epoch+1,step+1,lr, cost_generation, cost_valid, cost_norm, cost_weights,cost_d,cost_kl)) 285 | # print(cost_valid) 286 | # cc() 287 | if cost_generation + cost_kl < 50: 288 | if cost_generation + cost_kl < valid_best: 289 | valid_best = cost_generation + cost_kl - 0.1 290 | print('Save best!') 291 | self.saver.save(sess, './checkpoint/'+self.config.logfolder +'/'+'best.model') 292 | vfile.write("save best!\nEpoch: [%5d|%5d] generation_loss: %.8f, validation: %.8f, norm_loss: %.8f weight_loss: %.8f dloss: %.8f klloss: %.8f\n" % (epoch+1,step+1, cost_generation, cost_valid, cost_norm,cost_weights,cost_d, cost_kl)) 293 | if self.config.d_type=='dynamic': 294 | file2.write(str(self.b0.eval())) 295 | file2.write('\n') 296 | else: 297 | pass 298 | 299 | 300 | 301 | def export_obj(self, out, v, f): 302 | with open(out, 'w') as fout: 303 | for i in range(v.shape[0]): 304 | fout.write('v %f %f %f\n' % (v[i, 0], v[i, 1], v[i, 2])) 305 | for i in range(f.shape[0]): 306 | fout.write('f %d %d %d\n' % (f[i, 0], f[i, 1], f[i, 2])) 307 | 308 | def show_embedding(self, logfolder, cp): 309 | savedir='./result/'+logfolder+'/' 310 | if not os.path.exists(savedir): 311 | os.makedirs(savedir) 312 | config = tf.ConfigProto() 313 | config.gpu_options.allow_growth = True 314 | 315 | with tf.Session(config=config) as sess: 316 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 317 | embedding = sess.run(self.feed_encode, feed_dict = {self.inputs: self.data.feature}) 318 | sio.savemat(savedir+'/embedding.mat', {'embedding':embedding}) 319 | 320 | 321 | 322 | def deform(self, logfolder, cp, test_file, deform_weight=10, deform_lr = 0.01, total_epoch = 1000): 323 | # self.show_embedding(logfolder, cp) 324 | savedir='./result/'+logfolder+'/' 325 | if not os.path.exists(savedir): 326 | os.makedirs(savedir) 327 | if not os.path.exists(savedir+'/edit_'+test_file[:-4]): 328 | os.makedirs(savedir+'/edit_'+test_file[:-4]) 329 | config = tf.ConfigProto() 330 | config.gpu_options.allow_growth = True 331 | 332 | with tf.Session(config=config) as sess: 333 | # run_metadata = tf.RunMetadata() 334 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 335 | 336 | feature = self.data.load_test_data(test_file) 337 | data = sio.loadmat(test_file) 338 | control_point 
= tf.squeeze(tf.constant(data['cp'],dtype = 'int32',name='control_point')) 339 | other_point = np.setdiff1d(np.arange(self.data.pointnum),data['cp']) 340 | # other_point = tf.squeeze(tf.constant(np.setdiff1d(np.arange(self.data.pointnum),data['cp']),dtype = 'int32',name='other_point')) 341 | # print(other_point.shape) 342 | # cc() 343 | target_coor = tf.constant(data['target_coor'],dtype = 'float32',name='target_coor') 344 | embedding = sess.run(self.feed_encode, feed_dict = {self.inputs: feature}) 345 | 346 | global_step = tf.Variable( 347 | initial_value = 0, 348 | name = 'edit_global_step', 349 | trainable = False) 350 | learning_rate = tf.train.exponential_decay( 351 | deform_lr, 352 | global_step, 353 | decay_steps=4000, 354 | decay_rate=0.9, 355 | staircase=True) 356 | 357 | embedding_inputs = tf.Variable(initial_value=embedding, name='embedding', shape = [1, self.config.latent], dtype='float32',trainable=True) 358 | recon_feature = self.decoder(embedding_inputs, train = False) 359 | T = self.f2v.ftoT(recon_feature, use_identity=True) 360 | 361 | 362 | recon_vertex = self.f2v.Ttov(T) 363 | temp = tf.pow(tf.gather(recon_vertex, control_point, axis = 1) - target_coor, 2.0) 364 | if len(temp.shape) == 2: 365 | edit_loss = tf.reduce_mean(tf.reduce_sum(temp, [1])) #* 100000 366 | else: 367 | edit_loss = tf.reduce_mean(tf.reduce_sum(temp, [1,2])) #* 10000 368 | 369 | edit_loss = deform_weight * edit_loss 370 | # print(edit_loss) 371 | # preserve_loss = tf.reduce_mean(tf.reduce_sum(tf.pow(tf.gather(recon_vertex, other_point, axis = 1) - self.data.vertices[other_point,:], 2.0), [1,2])) 372 | # print(feature[:,other_point,:].shape) 373 | # print(tf.gather(recon_feature, other_point, axis = 1).shape) 374 | preserve_loss = tf.reduce_mean(tf.reduce_sum(tf.pow(tf.gather(recon_feature, other_point, axis = 1) - feature[:,other_point,:], 2.0), [1,2])) 375 | edit_optimizer = tf.train.GradientDescentOptimizer(learning_rate) 376 | # edit_optimizer = tf.train.AdamOptimizer(learning_rate) 377 | total_loss = edit_loss# + preserve_loss 378 | edit_optim = edit_optimizer.minimize(total_loss, var_list=[embedding_inputs], global_step = global_step) 379 | intial_op = tf.variables_initializer([global_step,embedding_inputs] + edit_optimizer.variables()) 380 | sess.run(intial_op) 381 | # T = sess.run(T) 382 | # print(rf.shape) 383 | 384 | # total_epoch = 1000 385 | best = np.inf 386 | embedding_seq = [] 387 | import time 388 | t1 = time.time() 389 | for epoch in range(total_epoch): 390 | edit_loss_print, preserve_loss_print,emb, lr, _ = sess.run([edit_loss,preserve_loss, embedding_inputs, learning_rate, edit_optim]) 391 | 392 | # embedding = sess.run(embedding_inputs) 393 | # embedding_seq.append(embedding) 394 | # if edit_loss_print+preserve_loss_print < best: 395 | # if edit_loss_print < best and preserve_loss_print < 30: 396 | # best = (edit_loss_print+preserve_loss_print)*0.9 397 | # best = (edit_loss_print)*0.9 398 | # vert = sess.run(recon_vertex).squeeze() 399 | # self.export_obj(savedir+'/'+str(epoch+1)+'_best.obj',vert, self.data.face) 400 | print("Epoch: [%5d] lr: %.5f edit_loss: %.8f preserve_loss: %.8f" % (epoch+1, lr, edit_loss_print, preserve_loss_print)) 401 | # print(emb) 402 | t2 = time.time() 403 | print('time:',t2-t1) 404 | vert = sess.run(recon_vertex).squeeze() 405 | 406 | # vert2 = sess.run(recon_vertex2, feed_dict = {self.inputs: self.data.feature}) 407 | # print(vert2[0]) 408 | 409 | 410 | 
self.export_obj(savedir+'/edit_'+test_file[:-4]+'/'+str(deform_weight)+'_'+str(deform_lr)+'_'+str(total_epoch)+'.obj',vert, self.data.face)
411 |             # sio.savemat(savedir+'/edit_'+test_file[:-4]+'/'+'/seq_emb'+str(deform_weight)+'_'+str(deform_lr)+'_'+str(total_epoch)+'.mat', {'embedding':np.array(embedding),'RS':rs, 'RLOGR':rlogr,'T':T,'R':R,'S':S,'sum_T':sum_T,'vert':vert,'bdiff':bdiff})
412 | 
413 |             # from tensorflow.python.client import timeline
414 |             # tl = timeline.Timeline(run_metadata.step_stats)
415 |             # ctf = tl.generate_chrome_trace_format()
416 |             # with open(savedir+'/edit_'+test_file[:-4]+'/'+'timeline.json', 'w') as f:
417 |             # f.write(ctf)
418 | 
419 | 
420 |     def inter_from_embedding(self, logfolder, cp, test_file):
421 |         savedir='./result/'+logfolder + '/inter/'
422 |         if not os.path.exists(savedir):
423 |             os.makedirs(savedir)
424 |         config = tf.ConfigProto()
425 |         config.gpu_options.allow_growth = True
426 |         with tf.Session(config=config) as sess:
427 |             self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp)
428 |             embedding = sio.loadmat(test_file)['embedding']
429 |             inter_result = np.linspace(embedding[0], embedding[1], 10)
430 |             recover = sess.run(self.feed_decode, feed_dict = {self.feed_encode:inter_result})
431 |             rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum)
432 |             sio.savemat(savedir+'/inter_from_embedding.mat', {'RS':rs, 'RLOGR':rlogr})
433 | 
434 |     def inter_from_mat(self, logfolder, cp, test_file):
435 |         savedir='./result/'+logfolder + '/inter/'
436 |         if not os.path.exists(savedir):
437 |             os.makedirs(savedir)
438 |         config = tf.ConfigProto()
439 |         config.gpu_options.allow_growth = True
440 |         with tf.Session(config=config) as sess:
441 |             self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp)
442 |             embedding = sess.run(self.feed_encode, feed_dict = {self.inputs: self.data.load_test_data(test_file)})
443 |             inter_result = np.linspace(embedding[0], embedding[1], 10)
444 |             recover = sess.run(self.feed_decode, feed_dict = {self.feed_encode:inter_result})
445 |             rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum)
446 |             sio.savemat(savedir+'/inter_from_mat.mat', {'RS':rs, 'RLOGR':rlogr})
447 | 
448 |     def recover_test(self, logfolder, cp, test_file):
449 |         savedir='./result/'+logfolder+'/'
450 |         if not os.path.exists(savedir):
451 |             os.makedirs(savedir)
452 |         config = tf.ConfigProto()
453 |         config.gpu_options.allow_growth = True
454 |         with tf.Session(config=config) as sess:
455 |             self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp)
456 |             recover = sess.run(self.feed_decode, feed_dict = {self.inputs: self.data.load_test_data(test_file)})
457 |             rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum)
458 |             sio.savemat(savedir+'/test.mat', {'RS':rs, 'RLOGR':rlogr})
459 | 
460 |     def embedd_inter(self, logfolder, cp):
461 |         savedir='./result/'+logfolder+'/'
462 |         if not os.path.exists(savedir):
463 |             os.makedirs(savedir)
464 |         config = tf.ConfigProto()
465 |         config.gpu_options.allow_growth = True
466 |         with tf.Session(config=config) as sess:
467 |             self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp)
468 |             embedding = sess.run(self.feed_encode, feed_dict = {self.inputs: self.data.feature})
469 |             down_id = np.array([8, 9, 50, 52, 56]) - 1
470 |             up_id = np.array([14, 24, 30, 43]) - 1
471 |             for i in range(len(down_id)):
472 |                 for j in
range(len(up_id)): 473 | inter_result = np.linspace(embedding[down_id[i]], embedding[up_id[j]], 10) 474 | diff = inter_result[1] - inter_result[0] 475 | all_point = np.array([inter_result[0]-diff*k for k in range(5,0,-1)] + inter_result.tolist() + [inter_result[-1]+diff*k for k in range(1,6)]) 476 | # cc() 477 | recover = sess.run(self.feed_decode, feed_dict = {self.feed_encode:all_point}) 478 | rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum) 479 | sio.savemat(savedir+'/'+str(down_id[i])+'_'+str(up_id[j])+'.mat', {'RS':rs, 'RLOGR':rlogr}) 480 | 481 | 482 | def random_generation(self, logfolder, cp): 483 | savedir='./result/'+logfolder+'/' 484 | if not os.path.exists(savedir): 485 | os.makedirs(savedir) 486 | config = tf.ConfigProto() 487 | config.gpu_options.allow_growth = True 488 | with tf.Session(config=config) as sess: 489 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 490 | embedding = sess.run(self.feed_encode, feed_dict = {self.inputs: self.data.feature}) 491 | # for i in range(5): 492 | i=4 493 | scalar = 0.1+0.1*i 494 | random_input = embedding[68] + gaussian(128, self.config.latent, var = scalar) 495 | recover = sess.run(self.random_decoder, feed_dict = {self.random_input:random_input}) 496 | rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum) 497 | sio.savemat(savedir+'/random'+str(i)+'.mat', {'RS':rs, 'RLOGR':rlogr,'emb':random_input}) 498 | 499 | def recover_mesh(self, logfolder, cp): 500 | savedir='./result/'+logfolder+'/' 501 | if not os.path.exists(savedir): 502 | os.makedirs(savedir) 503 | config = tf.ConfigProto() 504 | config.gpu_options.allow_growth = True 505 | with tf.Session(config=config) as sess: 506 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 507 | if os.path.exists('./data/idx/'+self.config.dataname+'.dat'): 508 | split_idx = pickle.load(open('./data/idx/'+self.config.dataname+'.dat','rb')) 509 | else: 510 | split_idx=ss[self.config.dataname] 511 | # random_input = gaussian(len(feature), self.config.latent) 512 | recover, embedding = sess.run([self.feed_decode, self.feed_encode], feed_dict = {self.inputs: self.data.feature}) 513 | rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum) 514 | sio.savemat(savedir+'/recover.mat', {'RS':rs, 'RLOGR':rlogr,'x':split_idx, 'embedding':embedding}) 515 | return 516 | 517 | def interpolate(self,logfolder, cp): 518 | config = tf.ConfigProto() 519 | config.gpu_options.allow_growth = True 520 | savedir='./result/'+logfolder + '/inter/' 521 | if not os.path.exists(savedir): 522 | os.makedirs(savedir) 523 | with tf.Session(config=config) as sess: 524 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 525 | data = sio.loadmat(savedir+'comp_latent.mat') 526 | ref = data['ref'][0,:] 527 | comp = data['comp'] 528 | new_latent=[] 529 | for k in range(0,50): 530 | for i in range(3,16,4): 531 | step = i/10.0 532 | new_latent.append(ref + (comp[k][0] - ref) * step) 533 | new_latent.append(ref + (comp[k][1] - ref) * step) 534 | recover = sess.run(self.embedding_output, feed_dict = {self.embedding_inputs: np.squeeze(new_latent)}) 535 | rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum) 536 | sio.savemat(savedir+'inter.mat',{'RS':rs,'RLOGR':rlogr}) 537 | 538 | def 
individual_dimension(self, logfolder, cp, first = 0): 539 | savedir='./result/'+logfolder+'/dimension/' 540 | if not os.path.exists(savedir): 541 | os.makedirs(savedir) 542 | config = tf.ConfigProto() 543 | config.gpu_options.allow_growth = True 544 | with tf.Session(config=config) as sess: 545 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 546 | 547 | embedding = sess.run(self.feed_encode, feed_dict = {self.inputs: self.data.feature}) 548 | min_embedding = np.amin(embedding, axis = 0) 549 | 550 | max_embedding = np.amax(embedding, axis = 0) 551 | 552 | 553 | def generate_embedding_input(min, max, dimension, rest): 554 | x = np.zeros((25, self.config.latent)).astype('float32') 555 | 556 | for idx in range(self.config.latent): 557 | if idx == dimension: 558 | x[:, idx] = np.linspace(min[idx], max[idx], num = 25) 559 | else: 560 | x[:, idx] = rest[idx] 561 | 562 | return x 563 | comp=[] 564 | for idx in range(self.config.latent): 565 | embedding_data = generate_embedding_input(min_embedding, max_embedding, idx, embedding[first, :]) 566 | 567 | recover = sess.run(self.embedding_output, feed_dict = {self.embedding_inputs: embedding_data}) 568 | comp.append(embedding_data[[0,24],:,]) 569 | rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum) 570 | sio.savemat(savedir+'dimension'+str(idx+1)+'.mat', {'RS':rs, 'RLOGR':rlogr,'embedding':embedding_data}) 571 | if not os.path.exists('./result/'+logfolder+'/inter/'): 572 | os.makedirs('./result/'+logfolder+'/inter/') 573 | sio.savemat('./result/'+logfolder+'/inter/comp_latent.mat', {'comp':comp,'ref':embedding}) 574 | 575 | def component_view(self, logfolder, cp): 576 | savedir='./result/'+logfolder+'/component/' 577 | if not os.path.exists(savedir): 578 | os.makedirs(savedir) 579 | config = tf.ConfigProto() 580 | config.gpu_options.allow_growth = True 581 | with tf.Session(config=config) as sess: 582 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 583 | 584 | component = sess.run([self.selfdot])[0] 585 | 586 | sio.savemat(savedir+'component.mat',{'component':component}) 587 | 588 | def synthesize(self,logfolder, cp, comp_idx, max_min, comp_weight): 589 | config = tf.ConfigProto() 590 | config.gpu_options.allow_growth = True 591 | savedir='./result/'+logfolder + '/synthesize/' 592 | if not os.path.exists(savedir): 593 | os.makedirs(savedir) 594 | with tf.Session(config=config) as sess: 595 | self.saver.restore(sess, './checkpoint/'+logfolder + '/' + cp) 596 | data = sio.loadmat('./result/'+logfolder+'/inter/'+'comp_latent.mat') 597 | ref = data['ref'] 598 | comp = data['comp'] 599 | latent = [] 600 | new_latent=ref[0,:] 601 | idx=0 602 | for i,j in zip(comp_idx,max_min): 603 | print(i,j) 604 | new_latent[int(i)] = comp[int(i)][int(j)][int(i)]*float(comp_weight[idx]) 605 | latent.append(comp[int(i)][int(j)]*float(comp_weight[idx])) 606 | idx = idx + 1 607 | latent.append(new_latent) 608 | recover = sess.run(self.embedding_output, feed_dict = {self.embedding_inputs: latent}) 609 | rs, rlogr = self.data.recover_data(recover, self.data.logrmin, self.data.logrmax, self.data.smin, self.data.smax, self.data.pointnum) 610 | sio.savemat(savedir+'synthesize.mat',{'RS':rs,'RLOGR':rlogr}) -------------------------------------------------------------------------------- /networkJittor.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn 3 | import os 4 | 5 | import numpy 
as np 6 | import scipy.sparse 7 | import scipy.sparse.linalg 8 | import scipy.spatial.distance 9 | import utilsJittor as utilsJT 10 | 11 | from jittor_geometric.nn import ChebConv 12 | 13 | def Laplacian(W, normalized=True): 14 | """Return the Laplacian of the weigth matrix.""" 15 | 16 | # Degree matrix. 17 | d = W.sum(axis=0) 18 | # Laplacian matrix. 19 | if not normalized: 20 | D = scipy.sparse.diags(d.A.squeeze(), 0) 21 | L = D - W 22 | else: 23 | #d += np.spacing(np.array(0.0, W.dtype)) 24 | d = 1 / np.sqrt(d) 25 | D = scipy.sparse.diags(d.A.squeeze(), 0) 26 | I = scipy.sparse.identity(d.size, dtype=W.dtype) 27 | L = I - D * W * D 28 | 29 | # assert np.abs(L - L.T).mean() < 1e-9 30 | assert type(L) is scipy.sparse.csr_matrix 31 | return L 32 | def rescale_L(L, lmax=2): 33 | """Rescale the Laplacian eigenvalues in [-1,1].""" 34 | M, M = L.shape 35 | I = scipy.sparse.identity(M, format='csr', dtype=L.dtype) 36 | L /= lmax / 2 37 | L -= I 38 | return L 39 | 40 | class Encoder(nn.Module): 41 | def __init__(self, input_dim, final_dim, layer_num): 42 | super(Encoder, self).__init__() 43 | # self.nb = nn.Linear(input_dim, final_dim) 44 | # self.edge = nn.Linear(input_dim, final_dim) 45 | self.fc1 = nn.Linear(input_dim, 100) 46 | self.fc2 = nn.Linear(100, final_dim) 47 | self.norm = nn.BatchNorm1d(final_dim) 48 | self.action = nn.Tanh() 49 | 50 | def execute(self, input_feature): 51 | x = self.fc1(input_feature) 52 | x = self.fc2(x) 53 | x = self.norm(x) 54 | x = self.action(x) 55 | return x 56 | 57 | class SpatialConv(nn.Module): 58 | def __init__(self, input_dim, final_dim, nb, degrees): 59 | super(SpatialConv, self).__init__() 60 | self.nb = nb # [numV, 10] 61 | self.degrees = degrees # [numV, 1] 62 | self.nb_fc = nn.Linear(input_dim, final_dim) 63 | self.edge_fc = nn.Linear(input_dim, final_dim) 64 | self.norm = nn.BatchNorm1d(final_dim) 65 | self.action = nn.Tanh() 66 | self.numC = final_dim 67 | self.gather = utilsJT.SYTgather() 68 | 69 | def execute(self, input_feature): # [b,Vxc] 70 | bs = input_feature.shape[0] 71 | numV = input_feature.shape[1] // self.numC 72 | numC = self.numC 73 | input_feature = input_feature.reshape(bs, numV, numC) 74 | # padding 75 | input_feature_ = jt.contrib.concat([jt.zeros([bs, 1, numC]), input_feature], dim=1) # # [b,V+1,c] 76 | 77 | # index_gather = self.nb.unsqueeze(0).unsqueeze(-1).expand([bs, numV, 10, numC]) 78 | # index_gather = jt.broadcast(self.nb, shape=(bs, numV, 10, numC), dims=[0,-1]) 79 | 80 | # nb_feature = [] 81 | # for idx in range(numV): 82 | # nb_feature.append(jt.gather(input_feature_, dim = 1, index=index_gather[:,idx]).unsqueeze(1)) 83 | # nb_feature = jt.contrib.concat(nb_feature, dim=1) 84 | 85 | nb_feature = self.gather(input_feature_, self.nb) 86 | 87 | # input_feature_gather = input_feature_.unsqueeze(1).expand([bs, numV, numV+1, numC]) 88 | # input_feature_gather = jt.broadcast(input_feature_, shape=(bs, numV, numV+1, numC), dims=[1]) 89 | # nb_feature = jt.gather(input_feature_gather, dim = 2, index=index_gather) # [bs,numV,10,numC] 90 | 91 | mean_nb_feature = nb_feature.sum(dim=2) / self.degrees # [bs,numV,numC] 92 | 93 | nb_feature = self.nb_fc(mean_nb_feature.reshape(bs*numV, -1)) 94 | edge_feature = self.edge_fc(input_feature.reshape(bs*numV, -1)) 95 | 96 | total_feature = edge_feature + nb_feature 97 | 98 | x = self.norm(total_feature) 99 | x = self.action(x) 100 | return x.reshape(bs, -1) 101 | 102 | class SpectralConv(nn.Module): 103 | def __init__(self, input_dim, final_dim, K, nb): 104 | super(SpectralConv, 
self).__init__()
105 |         self.conv = ChebConv(input_dim, final_dim, K=K)
106 |         edge_index = utilsJT.genEdge(nb.data)
107 |         self.edge_index = jt.array(edge_index).astype(jt.int32).stop_grad()
108 |         self.numC = final_dim
109 | 
110 |     def execute(self, input_feature):
111 |         bs = input_feature.shape[0]
112 |         numV = input_feature.shape[1] // self.numC
113 |         numC = self.numC
114 |         input_feature = input_feature.reshape(bs, numV, numC) # [b,V,c]
115 |         # output = [self.conv(x, self.edge_index) for x in input_feature] # [b,V,c], [2, numE], does not support batched operation...
116 |         # return jt.stack(output).reshape(bs, -1)
117 |         output = self.conv(input_feature, self.edge_index) # [b,V,c], [2, numE], supports batched operation...
118 |         return output.reshape(bs, -1)
119 | 
120 | 
121 | class Decoder(nn.Module):
122 |     def __init__(self, input_dim, final_dim, layer_num):
123 |         super(Decoder, self).__init__()
124 |         self.fc1 = nn.Linear(final_dim, 100)
125 |         self.fc2 = nn.Linear(100, input_dim)
126 |         self.norm = nn.BatchNorm1d(input_dim)
127 |         self.action = nn.Tanh()
128 | 
129 |     def execute(self, input_feature):
130 |         x = self.fc1(input_feature)
131 |         x = self.fc2(x)
132 |         x = self.norm(x)
133 |         x = self.action(x)
134 |         return x
135 | 
136 | 
137 | class Model(nn.Module):
138 |     def __init__(self, config, data):
139 |         self.global_step = 0
140 |         self.pointnum = data.pointnum
141 |         self.maxdegree = data.maxdegree
142 |         self.L = Laplacian(data.weight)
143 |         self.L = rescale_L(self.L, lmax=2)
144 |         self.config = config
145 |         self.data = data
146 |         self.constant_b_min = 0.2
147 |         self.constant_b_max = 0.4
148 |         jt.set_seed(self.config.seed)
149 | 
150 |         self.inputdim = 9
151 |         self.finaldim = config.finaldim
152 |         self.layer_num = config.layer_num
153 | 
154 |         self.nb = jt.array(data.neighbour).astype(jt.int32).stop_grad() # [V,10]
155 |         self.degrees = jt.array(data.degrees).astype(jt.float32).stop_grad() # [V,1]
156 |         self.laplacian = jt.array(data.geodesic_weight).astype(jt.float32).stop_grad() # [V,V]
157 | 
158 |         # self.C = jt.zeros([self.pointnum*self.finaldim, self.config.latent]).astype(jt.float32)
159 |         # self.Cstd = jt.zeros([self.pointnum*self.finaldim, self.config.latent]).astype(jt.float32)
160 |         self.C = 0.02 * jt.randn([self.pointnum*self.finaldim, self.config.latent]).astype(jt.float32)
161 |         self.Cstd = 0.02 * jt.randn([self.pointnum*self.finaldim, self.config.latent]).astype(jt.float32)
162 | 
163 |         ############## fc
164 |         # self.encoder = Encoder(self.pointnum*self.finaldim, self.pointnum*self.finaldim, self.layer_num)
165 |         # self.decoder = Decoder(self.pointnum*self.finaldim, self.pointnum*self.finaldim, self.layer_num)
166 | 
167 |         ############# spatial conv
168 |         # self.encoder = SpatialConv(self.finaldim, self.finaldim, self.nb, self.degrees)
169 |         # self.decoder = SpatialConv(self.finaldim, self.finaldim, self.nb, self.degrees)
170 | 
171 |         ############# spectral conv
172 |         self.encoder = SpectralConv(self.finaldim, self.finaldim, self.config.K, self.nb)
173 |         self.decoder = SpectralConv(self.finaldim, self.finaldim, self.config.K, self.nb)
174 | 
175 |         self.b0 = jt.ones([self.config.latent,1]).astype(jt.float32) * 0.2
176 |         self.binarize = utilsJT.binarize()
177 | 
178 |         self.printGrad = utilsJT.printGrad()
179 | 
180 |         self.criterionL2 = nn.MSELoss()
181 | 
182 |         ### optimizer
183 |         parameters = self.encoder.parameters()
184 |         parameters += [self.C, self.Cstd, self.b0]
185 |         parameters += self.decoder.parameters()
186 |         self.optimizer = nn.Adam(parameters, lr=0.001)
187 | 
188 |     def log_normal_pdf(self, sample, mean, logvar, raxis=1):
189 |
log2pi = jt.log(2. * np.pi) 190 | return jt.sum( 191 | -.5 * ((sample - mean) ** 2. * jt.exp(-logvar) + logvar + log2pi), 192 | dim=raxis) 193 | 194 | def reparameterization(self, mu, logvar): 195 | std = jt.exp(logvar / 2) 196 | sampled_z = jt.array(np.random.normal(0, 1, (mu.shape[0], self.config.latent))).float32().stop_grad() 197 | z = sampled_z * std + mu 198 | return z 199 | 200 | def execute(self, input): 201 | bs = input.shape[0] 202 | input = input.reshape(bs,-1) # reshape for the fc encoder 203 | 204 | x = self.encoder(input) # [bs, Vx9] 205 | mu = jt.matmul(x, self.C) # [bs, V*feature] x [V*feature, latent] 206 | sigma = jt.matmul(x, self.Cstd) 207 | x = self.reparameterization(mu, sigma) 208 | logpz = self.log_normal_pdf(x, 0, 0) 209 | logqz_x = self.log_normal_pdf(x, mu, sigma) 210 | 211 | weights_norm = jt.max(jt.abs(mu), dim=0) - 5 # [latent] 212 | weights_norm = jt.mean(jt.maximum(weights_norm, 0)) # float 213 | weights_norm = self.config.lambda2 * weights_norm 214 | 215 | y = jt.matmul(x, self.C.transpose(1,0)) # [bs, V*feature] 216 | y = self.decoder(y) 217 | 218 | # KL_loss = -(logpz - logqz_x).mean() 219 | # copy from https://wiseodd.github.io/techblog/2017/01/24/vae-pytorch/ (kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var)) 220 | KL_loss = 0.5 * jt.sum(jt.exp(sigma) + mu**2 - 1 - sigma) * self.config.lambda4 221 | Generation_loss = self.criterionL2(input, y) * self.config.lambda0 222 | # loss = KL_loss + Generation_loss + self.laplacian_norm 223 | 224 | fcparams_group = self.C.reshape(self.pointnum, self.finaldim, self.config.latent).permute(2, 0, 1) # [latentNum, pointNum, featureDim] 225 | selfdot = jt.sum(jt.pow(fcparams_group, 2.0), dim = 2) # [latentNum, pointNum] 226 | maxdimension = jt.argmax(selfdot, dim = 1)[0] # [latentNum] 227 | maxlaplacian = jt.gather(self.laplacian, dim=0, index=jt.array(maxdimension)) # [latentNum, pointNum] 228 | # laplacian_new = jt.min(jt.max(jt.divide(jt.add(maxlaplacian,-self.constant_b_min),self.constant_b_max-self.constant_b_min),0, keepdims=True),1,keepdims=True) 229 | 230 | ### static 231 | # laplacian_new = jt.clamp((maxlaplacian-self.constant_b_min)/(self.constant_b_max-self.constant_b_min), 0, 1) 232 | ### dynamic 233 | # print("b0:", self.b0, self.b0.is_stop_grad()) 234 | 235 | # self.b00 = self.printGrad(self.b0) 236 | # maxlaplacian = self.printGrad(maxlaplacian) 237 | 238 | laplacian_new = jt.maximum(self.binarize(maxlaplacian-self.b0), 0) 239 | 240 | laplacian_norm = self.config.lambda1 * jt.mean(jt.sum(jt.sqrt(selfdot) * laplacian_new, 1)) 241 | 242 | dloss=self.config.lambda3 * (jt.sum(jt.maximum(self.b0,0))) 243 | # laplacian_norm = 0 244 | 245 | return KL_loss, Generation_loss, laplacian_norm, weights_norm, dloss 246 | -------------------------------------------------------------------------------- /result/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IGLICT/MeshVAE_neural_editing/bbc54ad7cc2cad7f51bdec324bc0ff21e894bb8e/result/.gitignore -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import scipy.sparse 4 | import scipy.sparse.linalg 5 | import scipy.spatial.distance 6 | 7 | from tensorflow.python.framework import ops 8 | ''' 9 | spectral conv code from https://github.com/mdeff/cnn_graph 10 | ''' 11 | 12 | 13 | 14 | def spectral_conv(x, L, 
Fout, K,W,name='graph_conv',activation='tanh'): 15 | with tf.variable_scope(name) as scope: 16 | N, M, Fin = x.get_shape() 17 | L = L.tocoo() 18 | indices = np.column_stack((L.row, L.col)) 19 | L = tf.SparseTensor(indices, L.data, L.shape) 20 | L = tf.sparse_reorder(L) 21 | # Transform to Chebyshev basis 22 | x0 = tf.transpose(x, perm=[1, 2, 0]) # M x Fin x N 23 | x0 = tf.reshape(x0, [M, -1]) # M x Fin*N 24 | x = tf.expand_dims(x0, 0) # 1 x M x Fin*N 25 | def concat(x, x_): 26 | x_ = tf.expand_dims(x_, 0) # 1 x M x Fin*N 27 | return tf.concat([x, x_], axis=0) # K x M x Fin*N 28 | if K > 1: 29 | x1 = tf.sparse_tensor_dense_matmul(L, x0) 30 | x = concat(x, x1) 31 | for k in range(2, K): 32 | x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0 # M x Fin*N 33 | x = concat(x, x2) 34 | x0, x1 = x1, x2 35 | x = tf.reshape(x, [K, M, Fin, -1]) # K x M x Fin x N 36 | x = tf.transpose(x, perm=[3,1,2,0]) # N x M x Fin x K 37 | x = tf.reshape(x, [-1, Fin*K]) # N*M x Fin*K 38 | # Filter: Fin*Fout filters of order K, i.e. one filterbank per feature pair. 39 | x = tf.matmul(x, W) # N*M x Fout 40 | x = tf.reshape(x, [-1, M, Fout]) 41 | if activation=='tanh': 42 | x = tf.tanh(x) 43 | elif activation == 'selu': 44 | x = tf.keras.activations.selu(x) 45 | elif activation =='spp': 46 | x = self.softplusplus(x) 47 | elif activation =='lrelu': 48 | x = self.leaky_relu(x) 49 | elif activation =='none': 50 | x = x 51 | return x # N x M x Fout 52 | 53 | def Laplacian(W, normalized=True): 54 | """Return the Laplacian of the weigth matrix.""" 55 | 56 | # Degree matrix. 57 | d = W.sum(axis=0) 58 | # Laplacian matrix. 59 | if not normalized: 60 | D = scipy.sparse.diags(d.A.squeeze(), 0) 61 | L = D - W 62 | else: 63 | #d += np.spacing(np.array(0.0, W.dtype)) 64 | d = 1 / np.sqrt(d) 65 | D = scipy.sparse.diags(d.A.squeeze(), 0) 66 | I = scipy.sparse.identity(d.size, dtype=W.dtype) 67 | L = I - D * W * D 68 | 69 | # assert np.abs(L - L.T).mean() < 1e-9 70 | assert type(L) is scipy.sparse.csr_matrix 71 | return L 72 | def rescale_L(L, lmax=2): 73 | """Rescale the Laplacian eigenvalues in [-1,1].""" 74 | M, M = L.shape 75 | I = scipy.sparse.identity(M, format='csr', dtype=L.dtype) 76 | L /= lmax / 2 77 | L -= I 78 | return L 79 | 80 | 81 | def leaky_relu(input_, alpha = 0.02): 82 | return tf.maximum(input_, alpha*input_) 83 | 84 | def softplusplus(input_, alpha=0.02): 85 | return tf.log(1.0+tf.exp(input_*(1.0-alpha)))+alpha*input_-tf.log(2.0) 86 | 87 | def linear(input_, input_size, output_size, name='Linear', stddev=0.02, bias_start=0.0): 88 | with tf.variable_scope(name) as scope: 89 | matrix = tf.get_variable("weights", [input_size, output_size], tf.float32, 90 | tf.random_normal_initializer(stddev=0.02)) 91 | bias = tf.get_variable("bias", [output_size], tf.float32, 92 | initializer=tf.constant_initializer(bias_start)) 93 | 94 | return tf.matmul(input_, matrix) + bias 95 | 96 | def batch_norm_wrapper(inputs, name = 'batch_norm',is_training = False, decay = 0.9, epsilon = 1e-5): 97 | with tf.variable_scope(name) as scope: 98 | if is_training == True: 99 | scale = tf.get_variable('scale', dtype=tf.float32, trainable=True, initializer=tf.ones([inputs.get_shape()[-1]],dtype=tf.float32)) 100 | beta = tf.get_variable('beta', dtype=tf.float32, trainable=True, initializer=tf.zeros([inputs.get_shape()[-1]],dtype=tf.float32)) 101 | pop_mean = tf.get_variable('overallmean', dtype=tf.float32,trainable=False, initializer=tf.zeros([inputs.get_shape()[-1]],dtype=tf.float32)) 102 | pop_var = tf.get_variable('overallvar', 
dtype=tf.float32, trainable=False, initializer=tf.ones([inputs.get_shape()[-1]],dtype=tf.float32))
103 |         else:
104 |             scope.reuse_variables()
105 |             scale = tf.get_variable('scale', dtype=tf.float32, trainable=True)
106 |             beta = tf.get_variable('beta', dtype=tf.float32, trainable=True)
107 |             pop_mean = tf.get_variable('overallmean', dtype=tf.float32, trainable=False)
108 |             pop_var = tf.get_variable('overallvar', dtype=tf.float32, trainable=False)
109 | 
110 |         if is_training:
111 |             axis = list(range(len(inputs.get_shape()) - 1))
112 |             batch_mean, batch_var = tf.nn.moments(inputs,axis)
113 |             train_mean = tf.assign(pop_mean,pop_mean * decay + batch_mean * (1 - decay))
114 |             train_var = tf.assign(pop_var,pop_var * decay + batch_var * (1 - decay))
115 |             with tf.control_dependencies([train_mean, train_var]):
116 |                 return tf.nn.batch_normalization(inputs,batch_mean, batch_var, beta, scale, epsilon)
117 |         else:
118 |             return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)
119 | 
120 | 
121 | 
122 | 
123 | 
124 | 
125 | def binarize(x):
126 |     """
127 |     Clip and binarize tensor using the straight through estimator (STE) for the gradient.
128 |     """
129 |     g = tf.get_default_graph()
130 | 
131 |     with ops.name_scope("Binarized") as name:
132 |         with g.gradient_override_map({"Sign": "Identity"}):
133 |             return tf.sign(x)
134 | 
135 | def gaussian(batch_size, n_dim, mean=0, var=1, n_labels=10, use_label_info=False):
136 |     if use_label_info:
137 |         if n_dim != 2:
138 |             raise Exception("n_dim must be 2.")
139 | 
140 |         def sample(n_labels):
141 |             x, y = np.random.normal(mean, var, (2,))
142 |             angle = np.angle((x - mean) + 1j * (y - mean), deg=True)
143 | 
144 |             label = (int(n_labels * angle)) // 360
145 | 
146 |             if label < 0:
147 |                 label += n_labels
148 | 
149 |             return np.array([x, y]).reshape((2,)), label
150 | 
151 |         z = np.empty((batch_size, n_dim), dtype=np.float32)
152 |         z_id = np.empty((batch_size, 1), dtype=np.int32)
153 |         for batch in range(batch_size):
154 |             for zi in range(int(n_dim / 2)):
155 |                 a_sample, a_label = sample(n_labels)
156 |                 z[batch, zi * 2:zi * 2 + 2] = a_sample
157 |                 z_id[batch] = a_label
158 |         return z, z_id
159 |     else:
160 |         z = np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32)
161 |         return z
--------------------------------------------------------------------------------
/utilsJittor.py:
--------------------------------------------------------------------------------
1 | import jittor as jt
2 | from jittor import Function
3 | import numpy as np
4 | 
5 | class SYTgather(Function):
6 |     def execute(self, feature, index):
7 |         # feature [bs, numV+1, numC]
8 |         # index [numV, 10]
9 |         # output [bs, numV, 10, numC]
10 |         self.save_vars = feature, index
11 |         bs, numV, numC = feature.shape
12 |         numV -= 1
13 | 
14 |         return jt.code([bs,numV,10,numC], feature.dtype, [feature,index],
15 |             cpu_header="""
16 |             #include <iostream>
17 |             using namespace std;
18 |             """,
19 |             cpu_src='''
20 |             for (int batchID=0; batchID < out_shape0; batchID++) {
21 |                 for (int vID = 0; vID < out_shape1; vID++) {
22 |                     for (int nbID = 0; nbID < out_shape2; nbID++) {
23 |                         for (int cID = 0; cID < out_shape3; cID++) {
24 |                             @out(batchID, vID, nbID, cID) = @in0(batchID, @in1(vID,nbID), cID);
25 |                         }
26 |                     }
27 |                 }
28 |             }
29 |             ''')
30 | 
31 |     def grad(self, grad_x):
32 |         feature, index = self.save_vars
33 | 
34 |         return jt.code(feature.shape, feature.dtype, [feature, index, grad_x],
35 |             cpu_header="""
36 |             #include <iostream>
37 |             #include <vector>
38 |             using namespace std;
39 |             """,
40 |             cpu_src='''
41 |             vector<vector<int>> degrees(out_shape0, vector<int>(out_shape1,0));
42 | 
43 |             for (int batchID=0; batchID < out_shape0; batchID++) {
44 |                 for (int vID = 0; vID < out_shape1; vID++) {
45 |                     for (int cID = 0; cID < out_shape2; cID++) {
46 |                         @out(batchID, vID, cID) = 0;
47 |                     }
48 |                 }
49 |             }
50 | 
51 |             for (int batchID=0; batchID < in2_shape0; batchID++) {
52 |                 for (int vID = 0; vID < in2_shape1; vID++) {
53 |                     for (int nbID = 0; nbID < in2_shape2; nbID++) {
54 |                         for (int cID = 0; cID < in2_shape3; cID++) {
55 |                             @out(batchID, @in1(vID,nbID), cID) += @in2(batchID, vID, nbID, cID);
56 |                             degrees[batchID][@in1(vID,nbID)] += 1;
57 |                         }
58 |                     }
59 |                 }
60 |             }
61 | 
62 |             for (int batchID=0; batchID < out_shape0; batchID++) {
63 |                 for (int vID = 0; vID < out_shape1; vID++) {
64 |                     if (degrees[batchID][vID] > 0) {
65 |                         for (int cID = 0; cID < out_shape2; cID++) {
66 |                             @out(batchID, vID, cID) = @out(batchID, vID, cID) / degrees[batchID][vID];
67 |                         }
68 |                     }
69 |                 }
70 |             }
71 | 
72 |             ''')
73 | 
74 | 
75 | class binarize(Function):
76 |     def execute(self, x):
77 |         self.x = x
78 |         return jt.nn.sign(x)
79 | 
80 |     def grad(self, grad):
81 |         # print("grad:", grad.shape, grad)
82 |         return grad
83 | 
84 | 
85 | def genEdge(nb):
86 |     '''
87 |     convert the neighbour matrix nb (numV, 10) into an edge index array (2, numEdge)
88 |     '''
89 |     edge_index = []
90 |     for startID, nbIDs in enumerate(nb):
91 |         nbIDs = nbIDs[nbIDs > 0] - 1
92 |         edge_index += [np.array([startID, nbID]) for nbID in nbIDs]
93 | 
94 |     return np.stack(edge_index).transpose()
95 | 
96 | 
97 | class printGrad(Function):
98 |     def execute(self, x):
99 |         return x
100 | 
101 |     def grad(self, grad):
102 |         print("grad:", grad.shape, grad)
103 |         return grad
--------------------------------------------------------------------------------
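
The following is a minimal sketch (not part of the repository files above) of how utilsJittor.genEdge turns the 1-indexed, zero-padded neighbour matrix loaded from neighbor.mat into the (2, numEdge) edge index that SpectralConv passes to ChebConv. The 3-vertex triangle below is made-up toy data, and the import assumes Jittor is installed, since utilsJittor imports jittor at module level.

import numpy as np
from utilsJittor import genEdge

# Toy neighbour matrix for a 3-vertex triangle (hypothetical data, not from the
# scape dataset): entries are 1-indexed neighbour ids, 0 marks padding.
nb = np.array([[2, 3, 0],
               [1, 3, 0],
               [1, 2, 0]])

edge_index = genEdge(nb)   # drops the zero padding and shifts ids to 0-based
print(edge_index.shape)    # (2, 6): one column per directed edge
print(edge_index)
# [[0 0 1 1 2 2]
#  [1 2 0 2 0 1]]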