import numpy as np
import os

# Directory that is scanned for attacked ``.npz`` meshes.
attacked_models_dir = "./"


def obj_name_from_filename(filename):
    """Return the ``T<id>`` object name encoded in an attacked-mesh filename.

    The id is the text that follows the first ``_T`` marker, cut at the next
    ``_`` if there is one and otherwise at the first ``.``.

    Returns ``None`` when the filename contains no ``_T`` marker, so the
    caller can skip the file instead of failing with an ``IndexError``.
    """
    parts = filename.split('_T')
    if len(parts) < 2:
        return None
    tail = parts[1]
    if len(tail.split('_')) > 1:
        return "T" + tail.split('_')[0]
    return "T" + tail.split('.')[0]


def write_obj(obj_path, obj_name, vertices, faces):
    """Write ``vertices`` and ``faces`` to ``obj_path`` as a Wavefront OBJ file.

    ``faces`` holds 0-based vertex indices; they are written 1-based.
    The file is opened with ``with`` so it is closed even if a write fails.
    """
    num_vertices = len(vertices)
    num_faces = len(faces)
    num_coords_texture = 0
    with open(obj_path, "w") as obj_file:
        obj_file.write("####\n#\n# OBJ File Generated by Meshlab\n#\n####\n# Object " + obj_name + ".obj\n#\n# Vertices: " + str(num_vertices) + "\n# Faces: " + str(num_faces) + "\n#\n####\n")
        for vertex in vertices:
            obj_file.write("v " + str(vertex[0]) + " " + str(vertex[1]) + " " + str(vertex[2]) + " \n")
        obj_file.write("# " + str(num_vertices) + " vertices, 0 vertices normals\n\n")
        for face in faces:
            obj_file.write("f " + str(face[0] + 1) + " " + str(face[1] + 1) + " " + str(face[2] + 1) + " \n")
        obj_file.write("# " + str(num_faces) + " faces, " + str(num_coords_texture) + " coords texture\n\n# End of File")


for filename in os.listdir(attacked_models_dir):
    if not filename.endswith(".npz"):
        continue
    curr_obj_name = obj_name_from_filename(filename)
    if curr_obj_name is None:
        # No '_T<id>' marker in the name: there is nothing to name the output after.
        continue
    # NOTE: allow_pickle=True unpickles object arrays, which can execute
    # arbitrary code. Only run this script on .npz files you trust.
    # The archive is resolved relative to the scanned directory, not the CWD.
    with np.load(os.path.join(attacked_models_dir, filename), encoding='latin1', allow_pickle=True) as mesh_data:
        vertices = mesh_data['vertices']
        faces = mesh_data['faces']
    # The .obj is written to the current working directory, as before.
    write_obj(curr_obj_name + ".obj", curr_obj_name, vertices, faces)
22 | network_task: 'manifold_classification' 23 | trained_only_2_classes: False 24 | train_several_classes: False 25 | 26 | # Manifold params 27 | #'sparse_only' #'manifold_only' #'both' 28 | sparse_or_manifold: 'manifold_only' 29 | non_zero_ratio: 2 30 | 31 | # Deform Training params 32 | attacking_weight: 0.01 33 | max_label_diff: 0.001 34 | pred_close_enough_to_target: 0.9 35 | max_iter: 20_000 36 | iter_2_change_weight: 1_000 37 | show_model_every: 100_001 38 | 39 | walk_len: 800 40 | num_walks_per_iter: 1 41 | use_last: True 42 | 43 | # logger options 44 | image_save_iter: 100 # How often do you want to save output images during training 45 | plot_iter: 10 46 | image_display_iter: 100 # How often do you want to display output images during training 47 | display_size: 15 # How many images do you want to display each time 48 | snapshot_save_iter: 10_000 # How often do you want to save trained models 49 | log_iter: 1 # How often do you want to log the training stats 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MeshWalker: Deep Mesh Understanding by Random Walks 2 | Created by [Amir Belder](mailto:amirbelder5@gmail.com). 3 | 4 | Based on: https://arxiv.org/abs/2202.07453 5 | 6 | ## Installation 7 | In order to run the network you need to create an environment; we recommend a conda environment. 8 | You have 2 choices, a 3.6 python env, and a 3.8 python env. 9 | In both cases, go to the desired folder and: 10 | For 3.6: 11 | 'pip install -r 3_6_requirements.txt' 12 | For 3.8: 13 | 'pip install -r 3_8_requirements.txt' 14 | 15 | ## Source codes and Raw data 16 | The source code of each attacked network, along with its raw data can be found at: 17 | - [MeshCNN](https://github.com/ranahanocka/MeshCNN) 18 | - [MeshWalker](https://github.com/alonlahav/meshWalker) 19 | - [MeshNet](https://github.com/iMoonLab/MeshNet) 20 | - [Pd-MeshNet](https://github.com/MIT-SPARK/PD-MeshNet) 21 | 22 | ## Data 23 | We have 2 kinds of datasets: the raw dataset of each network and the adjusted walker datasets that were used to train the imitating networks. 24 | The raw datasets can be found at the links above. 25 | For each attacked network: 26 | - Our adjusted dataset is made out of the raw data of each network. 
27 | - On both the train set and test set, we took the vertices and faces of each network after its simplification. 28 | - We took the predicted labels of each of the train set meshes. 29 | - We created a dataset in the MeshWalker format using the data collected above. 30 | 31 | These datasets can be found here: 32 | Add link 33 | 34 | The datasets could also be created by using 'dataset_prepare.py', 35 | where you will find a function for each dataset. 36 | You may need to adjust the raw dataset paths according to where you saved them on your computer. 37 | Processing will rearrange the dataset into `npz` files, with labels included and vertex neighbours added. 38 | 39 | Some of our results can be found [here](https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/MeshAdversarial/attacked_models_of_all_networks.zip). 40 | 41 | You can also download it from our [raw_datasets]() folder. 42 | Please run `bash ./scripts/download_raw_datasets.sh`. 43 | 44 | 45 | ### Processed 46 | To prepare the data, run `python dataset_prepare.py ` 47 | 48 | ## Training Imitating Networks 49 | ``` 50 | python imitating_network_train.py 51 | ``` 52 | In order to train the imitating networks you will have to change 3 values in the configuration YAML file: 53 | 54 | - arch: set to one of: 'WALKER', 'MESHCNN', 'PDMESHNET', 'MESHNET' 55 | - dataset: set to one of: 'SHREC11', 'MODELNET40' 56 | - dataset_path: According to where you chose to put the data 57 | 58 | Use tensorboard to show training results: `tensorboard ` 59 | 60 | 61 | ## Attacking 62 | To attack the data, run `python attack_mesh.py` 63 | In order to attack the different mesh models, you again need to set 2 values in the configuration file: 64 | - arch: set to one of: 'WALKER', 'MESHCNN', 'PDMESHNET', 'MESHNET' 65 | - dataset: set to one of: 'SHREC11', 'MODELNET40' 66 | 67 | Please note that you must have both the data you wish to attack and a trained imitating network. 
68 | You may need to change the directories of these two inside the attack_mesh file, according to where you saved them on your computer. 69 | 70 | The attacked meshes can be found in the folder above the current working directory. 71 | They will be saved according to the different networks, i.e.: '../attacks/imitating_network_name' 72 | 73 | ## Pretrained 74 | All five pretrained imitating networks can be found here: [pretrained](https://technionmail-my.sharepoint.com/personal/alon_lahav_campus_technion_ac_il/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Falon%5Flahav%5Fcampus%5Ftechnion%5Fac%5Fil%2FDocuments%2Fmesh%5Fwalker%2Fpretrained) models to run evaluation only. 75 | 76 | ## Results 77 | In order to check how the attack affected the original SOTA system, we advise saving the changed vertices and faces in an obj file. 78 | Each of the SOTA systems uses these obj files while testing. 79 | And so, by changing them, we ensure that each network performs all of its needed preprocessing. 80 | An example of such saving for SHREC11 can be found in the npz_to_obj.py file. 81 | This is not necessary for the Walker attacked files, as they are already in the right format. 
def get_seq_random_walk_random_global_jumps(mesh_extra, f0, seq_len):
    """Generate a random walk that occasionally jumps to a random vertex.

    Args:
        mesh_extra: mapping that provides ``'edges'`` (indexed by vertex id,
            yielding that vertex's neighbour ids) and ``'n_vertices'``.
        f0: index of the first vertex of the walk.
        seq_len: number of steps to take after ``f0``.

    Returns:
        ``(seq, jumps)``: ``seq`` is an int32 array of ``seq_len + 1`` vertex
        ids starting with ``f0``; ``jumps`` is a bool array of the same length
        whose first entry is always ``True``.

    At every step the walk moves to a random unvisited neighbour. With
    probability 1/100, or when no unvisited neighbour exists and backtracking
    is not possible, it jumps to a uniformly random vertex and forgets all
    visited vertices. Otherwise, when stuck, it backtracks along ``seq``.
    """
    MAX_BACKWARD_ALLOWED = np.inf  # 25 * 2
    nbrs = mesh_extra['edges']
    n_vertices = mesh_extra['n_vertices']
    seq = np.zeros((seq_len + 1,), dtype=np.int32)
    # The ``np.bool`` alias was removed from NumPy; builtin ``bool`` is the same dtype.
    jumps = np.zeros((seq_len + 1,), dtype=bool)
    # One extra trailing slot, permanently marked visited, so that a neighbour
    # id of -1 is never selected (presumably -1 pads short rows - confirm
    # against the code that builds 'edges').
    visited = np.zeros((n_vertices + 1,), dtype=bool)
    visited[-1] = True
    visited[f0] = True
    seq[0] = f0
    jumps[0] = True
    backward_steps = 1
    jump_prob = 1 / 100
    for i in range(1, seq_len + 1):
        this_nbrs = nbrs[seq[i - 1]]
        nodes_to_consider = [n for n in this_nbrs if not visited[n]]
        jump_now = np.random.binomial(1, jump_prob) or (backward_steps > MAX_BACKWARD_ALLOWED)
        if len(nodes_to_consider) and not jump_now:
            to_add = np.random.choice(nodes_to_consider)
            jump = False
            backward_steps = 1
        elif i > backward_steps and not jump_now:
            # Stuck: step back along the walk. ``jump`` deliberately keeps the
            # value of the previous step here, exactly as the original did.
            to_add = seq[i - backward_steps - 1]
            backward_steps += 2
        else:
            backward_steps = 1
            to_add = np.random.randint(n_vertices)
            jump = True
            # A global jump starts a fresh exploration.
            visited[...] = 0
            visited[-1] = True
        visited[to_add] = 1
        seq[i] = to_add
        jumps[i] = jump

    return seq, jumps
def visualize_model(vertices, faces_, title=' ', walk=None, opacity=1.0,
                    all_colors='white', face_colors=None, cmap=None, edge_colors=None, edge_color_a='white',
                    line_width=1, show_edges=True):
    """Render a triangle mesh with pyvista, optionally overlaying a walk.

    Args:
        vertices: vertex coordinates handed to ``pv.PolyData``.
        faces_: iterable of faces; each face must support ``.tolist()`` and
            is rendered as a 3-vertex cell.
        title: window title passed to ``Plotter.show``.
        walk: optional sequence of vertex indices. Consecutive pairs are drawn
            as line segments. ``None`` or fewer than two vertices draws no
            overlay (previously this raised).
        opacity, all_colors, face_colors, cmap, edge_color_a, line_width,
            show_edges: forwarded to ``Plotter.add_mesh`` for the surface.
        edge_colors: edge colour of the walk overlay, drawn 4x thicker than
            ``line_width``.
    """
    p = pv.Plotter()
    # pyvista expects a flat array where each cell is prefixed by its vertex count.
    faces = np.hstack([[3] + f.tolist() for f in faces_])
    surf = pv.PolyData(vertices, faces)
    p.add_mesh(surf, show_edges=show_edges, edge_color=edge_color_a, color=all_colors, opacity=opacity,
               smooth_shading=True, scalars=face_colors, cmap=cmap, line_width=line_width)
    # The default walk=None used to crash on len(None), and a single-vertex
    # walk crashed inside np.hstack([]); only overlay when there is a segment.
    if walk is not None and len(walk) > 1:
        walk_edges = np.hstack([[2, walk[i], walk[i + 1]] for i in range(len(walk) - 1)])
        walk_mesh = pv.PolyData(vertices, walk_edges)
        p.add_mesh(walk_mesh, show_edges=True, line_width=line_width * 4, edge_color=edge_colors)
    p.show(title=title)
def get_seq_random_walk_no_jumps(mesh_extra, f0, seq_len):
    """Generate a random walk that prefers unvisited neighbours and backtracks when stuck.

    Args:
        mesh_extra: mapping that provides ``'edges'`` (indexed by vertex id,
            yielding that vertex's neighbour ids) and ``'n_vertices'``.
        f0: index of the first vertex of the walk.
        seq_len: number of steps to take after ``f0``.

    Returns:
        ``(seq, jumps)``: ``seq`` is an int32 array of ``seq_len + 1`` vertex
        ids starting with ``f0``; ``jumps`` is a bool array of the same length
        whose first entry is always ``True``.

    A random vertex is drawn only when the walk is stuck and cannot backtrack
    any further; visited vertices are never forgotten.
    """
    nbrs = mesh_extra['edges']
    n_vertices = mesh_extra['n_vertices']
    seq = np.zeros((seq_len + 1,), dtype=np.int32)
    # The ``np.bool`` alias was removed from NumPy; builtin ``bool`` is the same dtype.
    jumps = np.zeros((seq_len + 1,), dtype=bool)
    # Extra trailing slot stays True so the -1 padding entries of 'edges'
    # always look visited and are never selected.
    visited = np.zeros((n_vertices + 1,), dtype=bool)
    visited[-1] = True
    visited[f0] = True
    seq[0] = f0
    jumps[0] = True
    backward_steps = 1
    for i in range(1, seq_len + 1):
        this_nbrs = nbrs[seq[i - 1]]
        nodes_to_consider = [n for n in this_nbrs if not visited[n]]
        if len(nodes_to_consider):
            to_add = np.random.choice(nodes_to_consider)
            jump = False
        elif i > backward_steps:
            # Stuck: step back along the walk. ``jump`` deliberately keeps the
            # value of the previous step here, exactly as the original did.
            to_add = seq[i - backward_steps - 1]
            backward_steps += 2
        else:
            to_add = np.random.randint(n_vertices)
            jump = True
        seq[i] = to_add
        jumps[i] = jump
        visited[to_add] = 1

    return seq, jumps
def get_seq_random_walk_local_jumps(mesh_extra, f0, seq_len):
    """Generate a random walk over the ``'kdtree_query'`` neighbour lists.

    Args:
        mesh_extra: mapping that provides ``'kdtree_query'`` (indexed by
            vertex id, yielding candidate next vertex ids) and
            ``'n_vertices'``.
        f0: index of the first vertex of the walk.
        seq_len: number of steps to take after ``f0``.

    Returns:
        ``(seq, jumps)``: ``seq`` is an int32 array of ``seq_len + 1`` vertex
        ids starting with ``f0``; ``jumps`` is a bool array that is ``True``
        exactly at the steps where a uniformly random vertex was drawn
        (``jumps[0]`` stays ``False``).
    """
    n_vertices = mesh_extra['n_vertices']
    kdtr = mesh_extra['kdtree_query']
    seq = np.zeros((seq_len + 1,), dtype=np.int32)
    # The ``np.bool`` alias was removed from NumPy; builtin ``bool`` is the same dtype.
    jumps = np.zeros((seq_len + 1,), dtype=bool)
    seq[0] = f0
    # Extra trailing slot stays True so an index of -1 always looks visited.
    visited = np.zeros((n_vertices + 1,), dtype=bool)
    visited[-1] = True
    visited[f0] = True
    for i in range(1, seq_len + 1):
        to_consider = [n for n in kdtr[seq[i - 1]] if not visited[n]]
        if len(to_consider):
            seq[i] = np.random.choice(to_consider)
            jumps[i] = False
        else:
            # Every candidate was already visited: jump anywhere and start a
            # fresh visited set (reset in place instead of reallocating).
            seq[i] = np.random.randint(n_vertices)
            jumps[i] = True
            visited[:] = False
            visited[-1] = True
        visited[seq[i]] = True

    return seq, jumps
def load_params(logdir):
    """Load the run parameters stored in ``<logdir>/params.txt``.

    The file is parsed as JSON and wrapped in an ``EasyDict``. Two values are
    then overridden: ``'vertex_indices'`` is appended to ``net_input`` and
    ``batch_size`` is forced to 1.

    Args:
        logdir: path of the run directory.

    Returns:
        The adjusted parameters as an ``EasyDict``.

    Raises:
        ValueError: if ``logdir`` is not a directory, or if ``params.txt``
            cannot be read, parsed or adjusted (the underlying error is
            chained as ``__cause__``).
    """
    # ================ Loading parameters ============== #
    if not os.path.isdir(logdir):
        # ``raise(ValueError, msg)`` raised a tuple, which is itself a TypeError.
        raise ValueError('{} is not a folder'.format(logdir))
    try:
        with open(os.path.join(logdir, 'params.txt')) as fp:
            params = EasyDict(json.load(fp))
        params.net_input += ['vertex_indices']
        params.batch_size = 1
    except Exception as err:
        # Any failure is reported uniformly, as the original bare ``except``
        # intended, but without swallowing KeyboardInterrupt/SystemExit.
        raise ValueError('Could not load params.txt from logdir') from err
    # ================================================== #
    return params
def create_gif_from_preds(path, title=''):
    """Assemble every ``.png`` in ``path`` into ``<title>_animated.gif``.

    Frames are ordered by the integer found in the third ``_``-separated
    token from the end of each file path (presumably the rank written by
    ``show_walk`` - confirm against the save code). Each frame is shown for
    1000 ms.

    Args:
        path: directory containing the ``.png`` frames; the gif is written
            into the same directory.
        title: prefix of the output file name.

    Returns:
        None. When ``path`` contains no ``.png`` files a message is printed
        and nothing is written.
    """
    files = [os.path.join(path, x) for x in os.listdir(path) if x.endswith('.png')]
    if not files:
        print('Did not find any .png images in {}'.format(path))
        # Without this return the code fell through to ims[0] and crashed.
        return
    files.sort(key=lambda fn: int(fn.split('_')[-3]))
    from PIL import Image
    ims = [Image.open(fn) for fn in files]
    try:
        ims[0].save(os.path.join(path, '{}_animated.gif'.format(title)), save_all=True,
                    append_images=ims[1:], duration=1000)
    finally:
        # Release the file handles held by the opened frames.
        for im in ims:
            im.close()
169 | 170 | 171 | if __name__ == '__main__': 172 | np.random.seed(4) 173 | compare_attention() 174 | # plot_attention() 175 | # npz_path = '/home/ran/mesh_walker/datasets/modelnet40_retrieval_split_0/train_desk_0018_000_simplified_to_4000.npz' 176 | # visualize_npz(npz_path) 177 | -------------------------------------------------------------------------------- /attack_mesh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utils 3 | 4 | #get hyper params from yaml 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--config', type=str, default='recon_config.yaml', help='Path to the config file.') 7 | opts = parser.parse_args() 8 | config = utils.get_config(opts.config) 9 | 10 | import numpy as np 11 | import os 12 | import re 13 | import attack_single_mesh 14 | 15 | #All data sets paths 16 | meshCnn_and_Pd_meshNet_shrec_path ='datasets_processed/meshCNN_and_PD_meshNet_source_data/' 17 | meshCnn_shrec_vertices_and_faces = 'datasets_processed/meshCNN_faces_vertices_labels/' 18 | pd_MeshNet_shrec_vertices_and_faces = 'datasets_processed/Pd_meshNet_faces_vertices_labels/' 19 | meshWalker_shrec_path = 'datasets_processed/walker_copycat_shrec11/' 20 | meshWalker_model_net_path = 'datasets_processed/walker_copycat_modelnet40/' 21 | mesh_net_path = 'datasets_processed/mesh_net_modelnet40/' 22 | 23 | if config['gpu_to_use'] >= 0: 24 | utils.set_single_gpu(config['gpu_to_use']) 25 | 26 | mesh_net_labels = ['night_stand', 'range_hood', 'plant', 'chair', 'tent', 27 | 'curtain', 'piano', 'dresser', 'desk', 'bed', 28 | 'sink', 'laptop', 'flower_pot', 'car', 'stool', 29 | 'vase', 'monitor', 'airplane', 'stairs', 'glass_box', 30 | 'bottle', 'guitar', 'cone', 'toilet', 'bathtub', 31 | 'wardrobe', 'radio', 'person', 'xbox', 'bowl', 32 | 'cup', 'door', 'tv_stand', 'mantel', 'sofa', 33 | 'keyboard', 'bookshelf', 'bench', 'table', 'lamp'] 34 | 35 | 36 | walker_shrec11_labels = [ 37 | 'armadillo', 'man', 
def get_dataset_path(config = None):
    """
    Returns the processed dataset folder that matches config['arch'] and config['dataset'].

    As a side effect, config['trained_model'] is overwritten with the imitating network
    folder that belongs to the same (arch, dataset) pair.

    Exits the process (SystemExit) if config is None or the pair is not supported.
    """
    # sys.exit replaces the bare `exit` helper: that one is installed by the site module
    # and is missing when the interpreter runs without it (e.g. `python -S`)
    import sys

    if config is None:
        sys.exit("Your configuration file is None... Exiting")

    # (arch, dataset, imitating network folder, processed dataset folder)
    supported = (
        ('WALKER', 'MODELNET40', 'trained_models/walker_modelnet_imitating_network', meshWalker_model_net_path),
        ('MESHNET', 'MODELNET40', 'trained_models/mesh_net_imitating_network', mesh_net_path),
        ('WALKER', 'SHREC11', 'trained_models/walker_shrec11_imitating_network', meshWalker_shrec_path),
        ('MESHCNN', 'SHREC11', 'trained_models/meshCNN_imitating_network', meshCnn_shrec_vertices_and_faces),
        ('PDMESHNET', 'SHREC11', 'trained_models/pd_meshnet_imitating_network', pd_MeshNet_shrec_vertices_and_faces),
    )
    for arch, dataset_name, trained_model, dataset_path in supported:
        # config['dataset'] is only read once the architecture matched
        if config['arch'] == arch and config['dataset'] == dataset_name:
            config['trained_model'] = trained_model
            return dataset_path

    sys.exit("Please provide a valid dataset name in recon file.")
def attack_walker_model_net_models(config=None):
    """
    Attacks the test meshes of the walker ModelNet40 dataset, one class after the other.

    For each of the 40 classes config['source_label'] is set to the class index, and every
    file of the dataset folder whose name contains both 'test' and the class name is attacked.
    Files whose name contains 'attacked' are skipped, as are files whose name (without the
    4 character extension) is contained in more than one file name of the class.
    Does nothing if config is None.
    """
    if config is None:
        return
    dataset_path = get_dataset_path(config=config)

    for class_index in range(40):
        config['source_label'] = class_index
        class_name = walker_model_net_labels[class_index]
        class_files = []
        for entry in os.listdir(path=dataset_path):
            if 'test' in entry and class_name in entry:
                class_files.append(entry)

        for model_name in class_files:
            if 'attacked' in str(model_name):
                continue
            stem = model_name[0:-4]
            n_matching_files = sum(1 for name in class_files if stem in name)
            if n_matching_files > 1:
                continue

            # The id is built from all the purely numeric parts of the file name
            tokens = re.split(pattern='_', string=model_name)
            tokens[-1] = tokens[-1][:-4]
            model_id = ''.join('_' + token for token in tokens if token.isdigit())

            _ = attack_single_mesh.attack_single_mesh(config=config, source_mesh=dataset_path+model_name,
                                                      id=model_id, labels=walker_model_net_labels)

    return
def attack_walker_shrec11_models(config = None):
    """
    Attacks the test meshes of the walker SHREC11 dataset, one class after the other.

    For each of the 30 classes config['source_label'] is set to the class index, and every
    file of the dataset folder whose name contains both 'test' and the class name is attacked.
    Files whose name contains 'attacked' are skipped, as are files whose name (without the
    4 character extension) is contained in more than one file name of the class.
    Does nothing if config is None.
    """
    if config is None:
        return
    dataset_path = get_dataset_path(config=config)

    for class_index in range(0, 30):
        config['source_label'] = class_index
        name_of_class = walker_shrec11_labels[class_index]
        files_to_attack = [file for file in os.listdir(path=dataset_path)
                           if 'test' in file and name_of_class in file]

        for model_name in files_to_attack:
            if 'attacked' in str(model_name):
                continue
            num_of_models = [name for name in files_to_attack if model_name[0:-4] in name]
            if len(num_of_models) > 1:
                continue
            name_parts = re.split(pattern='_', string=model_name)
            model_id = name_parts[2]
            # The labels handed over must be the list source_label indexes. The sorted
            # meshCNN / PD-MeshNet list has another order, which named the result folder
            # after a different class than the one being attacked.
            _ = attack_single_mesh.attack_single_mesh(config=config, source_mesh=dataset_path+model_name,
                                                      id=model_id, labels=walker_shrec11_labels)

    return
def deform_add_fields_and_dump_model(mesh_data, fileds_needed, out_fn, dump_model=True):
    """
    Keeps only the needed fields of the mesh, adds the supported missing ones and saves
    the result to out_fn (npz format) unless dump_model is False
    """
    m = {name: value for name, value in mesh_data.items() if name in fileds_needed}

    for field in fileds_needed:
        # Checked on every iteration, since prepare_edges_and_kdtree gets m and may fill it
        if field in m:
            continue
        if field in ('labels_fuzzy', 'walk_cache'):
            m[field] = np.zeros((0,))
        if field == 'kdtree_query':
            dataset_prepare.prepare_edges_and_kdtree(m)

    if not dump_model:
        return
    np.savez(out_fn, **m)


def get_res_path(config, id = -1, labels = None):
    """
    Builds the result path in which the mesh and its pictures are saved.
    Returns the path together with the name of the attacked network
    """
    if labels is None:
        exit("Error, no shrec labels")

    # The network name is the last part of the trained model path (empty if it ends with '/')
    net_name = config['trained_model'].split('/')[-1]
    prefix = '../attacks/' + net_name + '/' if len(net_name) > 0 else '../attacks/'
    res_path = prefix + labels[config['source_label']]
    if id != -1:
        res_path = res_path + '_' + str(id)
    return res_path, net_name
def define_network_and_its_params(config=None):
    """
    Loads the parameters stored next to the trained model (params), adjusts them for the
    attack, and loads the trained model itself (dnn_model) from its latest checkpoint
    """
    model_dir = config['trained_model']
    with open(model_dir + '/params.txt') as fp:
        params = EasyDict(json.load(fp))
    model_fn = tf.train.latest_checkpoint(model_dir)

    # Network parameters that are overridden for the attack
    attack_overrides = {
        'batch_size': 1,
        'seq_len': config['walk_len'],
        'n_walks_per_model': 8,
        'set_seq_len_by_n_faces': False,
        'data_augmentaion_vertices_functions': [],
        'label_per_step': False,
        'n_target_vrt_to_norm_walk': 0,
    }
    for name, value in attack_overrides.items():
        setattr(params, name, value)
    # The vertex indices of the walk are requested in addition to the stored net input
    params.net_input += ['vertex_indices']

    dataset.setup_features_params(params, params)
    dataset.mesh_data_to_walk_features.SET_SEED_WALK = False

    dnn_model = rnn_model.RnnManifoldWalkNet(params, params.n_classes, 3, model_fn,
                                             model_must_be_load=True, dump_model_visualization=False)

    return params, dnn_model
network_dnn_model = define_network_and_its_params(config=config) 146 | 147 | # Defining output path 148 | result_path, net_name = get_res_path(config=config, id=id, labels=labels) 149 | print("source label: ", config['source_label'], " output dir: ", result_path) 150 | if os.path.isdir(result_path) and config['use_last'] is False: 151 | shutil.rmtree(result_path) 152 | 153 | # Defining original mesh data - Either use the last saved in the folder or the original one 154 | orig_mesh_data_path = source_mesh 155 | if config['use_last'] is True: 156 | if os.path.exists(result_path+'/'+'last_model.npz'): # A previous model exists 157 | orig_mesh_data_path = result_path + '/last_model.npz' 158 | elif os.path.exists(source_mesh[0:-4] + '_attacked.npz'): # A previous attacked model exists 159 | orig_mesh_data_path = source_mesh[0:-4] + '_attacked.npz' 160 | 161 | orig_mesh_data = np.load(orig_mesh_data_path, encoding='latin1', allow_pickle=True) 162 | mesh_data = {k: v for k, v in orig_mesh_data.items()} 163 | 164 | # Defining parameters that keep track of the changes 165 | loss = [] 166 | cpos = None 167 | last_dev_res = 0 168 | last_plt_res = 0 169 | fields_needed = ['vertices', 'faces', 'edges', 'kdtree_query', 'label', 'labels', 'dataset_name', 'labels_fuzzy'] 170 | source_pred_list = [] 171 | x_axis = [] 172 | vertices_counter = np.ones(mesh_data['vertices'].shape) 173 | vertices_gradient_change_sum = np.zeros(mesh_data['vertices'].shape) 174 | num_times_wrong_classification = 0 175 | 176 | # Defining the attack 177 | kl_divergence_loss = tf.keras.losses.KLDivergence() 178 | w = config['attacking_weight'] 179 | if config['dataset'] == 'SHREC11': 180 | one_hot_original_label_vetor = tf.one_hot(config['source_label'], 30) 181 | elif config['dataset'] == 'MODELNET40': 182 | one_hot_original_label_vetor = tf.one_hot(config['source_label'], 40) 183 | else: 184 | one_hot_original_label_vetor = config['source_label'] 185 | 186 | 187 | # Time measurment parameter 188 | 
start_time_100_iters = time.time() 189 | 190 | for num_iter in range(config['max_iter']): 191 | # Extract features and labels 192 | features, labels = dataset.mesh_data_to_walk_features(mesh_data, network_params) 193 | ftrs = tf.cast(features[:, :, :3], tf.float32) # The walks features 194 | v_indices = features[0, :, 3].astype(np.int) # the vertices indices of the walk 195 | 196 | with tf.GradientTape() as tape: 197 | tape.watch(ftrs) 198 | pred = network_dnn_model(ftrs, classify=True, training=False) 199 | 200 | # Produce the attack 201 | attack = -1 * w * kl_divergence_loss(one_hot_original_label_vetor, pred) 202 | 203 | # Check the prediction of the network 204 | pred = tf.reduce_sum(pred, 0) 205 | pred /= network_params.n_walks_per_model 206 | source_pred_brfore_attack = (pred.numpy())[config['source_label']] 207 | 208 | gradients = tape.gradient(attack, ftrs) 209 | ftrs_after_attack_update = ftrs + gradients 210 | 211 | new_pred = network_dnn_model(ftrs_after_attack_update, classify=True, training=False) 212 | new_pred = tf.reduce_sum(new_pred, 0) 213 | new_pred /= network_params.n_walks_per_model 214 | 215 | # Check to see that we didn't update too much 216 | # We don't want the change to be too big, as it may result in intersections. 217 | # And so, we check to see if the change caused us to get closer to the target by more than 0.01. 
218 | # If so, we will divide the change so it won't change more than 0.01 219 | source_pred_after_attack = (new_pred.numpy())[config['source_label']] 220 | source_pred_abs_diff = abs(source_pred_brfore_attack - source_pred_after_attack) 221 | 222 | if source_pred_abs_diff > config['max_label_diff'] : 223 | # We update the gradients accordingly 224 | ratio = config['max_label_diff'] / source_pred_abs_diff 225 | gradients = gradients * ratio 226 | 227 | print("iter:", num_iter, " attack:", attack.numpy(), " w:", w, " source prec:", (pred.numpy())[config['source_label']], 228 | " max label:", np.argmax(pred)) 229 | 230 | if np.argmax(pred) != config['source_label']: 231 | num_times_wrong_classification += 1 232 | else: 233 | num_times_wrong_classification = 0 234 | 235 | loss.append(attack.numpy()) 236 | vertices_counter[v_indices] += 1 237 | vertices_gradient_change_sum[v_indices] += gradients[0].numpy() 238 | 239 | # Updating the mesh itself 240 | change = vertices_gradient_change_sum/vertices_counter 241 | mesh_data['vertices'] += change 242 | 243 | # If we got the wrong classification 10 times straight 244 | if num_times_wrong_classification > 10 * config['num_walks_per_iter']: 245 | if num_iter < 15: 246 | print("\n\nExiting.. 
Wrong model was loaded / Wrong labels were compared\n\n\n") 247 | return num_iter 248 | path = source_mesh if source_mesh is not None else None 249 | if result_path.__contains__('_meshCNN'): 250 | deform_add_fields_and_dump_model(mesh_data=mesh_data, fileds_needed=fields_needed, 251 | out_fn=path[0:-4] + 'meshCNN_attacked.npz') 252 | else: 253 | deform_add_fields_and_dump_model(mesh_data=mesh_data, fileds_needed=fields_needed, 254 | out_fn=path[0:-4] + '_attacked.npz') 255 | return num_iter 256 | 257 | # Saving pictures of the models 258 | if num_iter % 100 == 0: 259 | total_time_100_iters = time.time() - start_time_100_iters 260 | start_time_100_iters = time.time() 261 | 262 | preds_to_print_str = '' 263 | print('\n' + str(net_name) + '\n' + preds_to_print_str +'\n' 264 | + 'Time took for 100 iters: '+ str(total_time_100_iters) +'\n') 265 | 266 | curr_save_image_iter = num_iter - (num_iter % config['image_save_iter']) 267 | if curr_save_image_iter / config['image_save_iter'] >= last_dev_res + 1 or num_iter == 0: 268 | print(result_path) 269 | cpos = dump_mesh(mesh_data, result_path, cpos, num_iter, config['x_server_exists']) 270 | last_dev_res = num_iter / config['image_save_iter'] 271 | deform_add_fields_and_dump_model(mesh_data=mesh_data, fileds_needed=fields_needed, out_fn=result_path + '/last_model.npz')#"+ str(num_iter)) 272 | 273 | 274 | curr_plot_iter = num_iter - (num_iter % config['plot_iter']) 275 | if curr_plot_iter / config['plot_iter'] >= last_plt_res + 1 or num_iter == 0: 276 | plot_preditions(network_params, network_dnn_model, config, mesh_data, result_path, num_iter, x_axis, source_pred_list) 277 | last_plt_res = num_iter / config['plot_iter'] 278 | 279 | 280 | if config['show_model_every'] > 0 and num_iter % config['show_model_every'] == 0 and num_iter > 0: 281 | plt.plot(loss) 282 | plt.show() 283 | utils.visualize_model(mesh_data['vertices'], mesh_data['faces']) 284 | 285 | #res_path = config['result_path'] 286 | cmd = f'ffmpeg -framerate 24 -i 
def use_pretrained_model(config, run_name):
    '''
    Load the parameters saved with a trained model and point them to its latest checkpoint
    and to a run folder for run_name
    '''
    import json
    import tensorflow as tf

    model_dir = config['trained_model']
    with open(model_dir + '/params.txt') as params_file:
        params = EasyDict(json.load(params_file))

    params.net_start_from_prev_net = tf.train.latest_checkpoint(model_dir)
    run_dir = utils.get_run_folder(params.run_root_path + '/', '__' + run_name, params.cont_run_number)
    params.logdir = run_dir
    params.model_fn = run_dir + '/learned_model.keras'

    return params
= params.logdir + '/learned_model.keras' 46 | 47 | # Optimizer params 48 | params.optimizer_type = 'cycle' # sgd / adam / cycle 49 | params.learning_rate_dynamics = 'cycle' 50 | params.cycle_opt_prms = EasyDict({'initial_learning_rate': 1e-6, 51 | 'maximal_learning_rate': 1e-4, 52 | 'step_size': 10000}) 53 | params.n_models_per_test_epoch = 300 54 | params.gradient_clip_th = 1 55 | 56 | # Dataset params 57 | params.classes_indices_to_use = None 58 | params.train_dataset_size_limit = np.inf 59 | params.test_dataset_size_limit = np.inf 60 | params.network_task = network_task 61 | params.normalize_model = True 62 | params.sub_mean_for_data_augmentation = True 63 | params.datasets2use = {} 64 | params.test_data_augmentation = {} 65 | params.train_data_augmentation = {} 66 | params.aditional_network_params = [] 67 | params.cut_walk_at_deadend = False 68 | 69 | params.network_tasks = [params.network_task] 70 | params.features_extraction = False 71 | if params.network_task == 'classification': 72 | params.n_walks_per_model = 1 73 | # Amir - changed to False to see what happens 74 | params.one_label_per_model = True 75 | params.train_loss = ['cros_entr'] 76 | params.net = 'RnnWalkNet' 77 | elif params.network_task == 'manifold_classification': 78 | params.n_walks_per_model = 1 79 | params.one_label_per_model = True 80 | params.train_loss = ['manifold_cros_entr'] 81 | params.net = 'Manifold_RnnWalkNet' 82 | elif params.network_task == 'semantic_segmentation': 83 | params.n_walks_per_model = 4 84 | params.one_label_per_model = False 85 | params.train_loss = ['cros_entr'] 86 | elif params.network_task == 'unsupervised_classification': 87 | params.n_walks_per_model = 2 88 | params.one_label_per_model = True 89 | params.train_loss = ['triplet'] 90 | params.net = 'Unsupervised_RnnWalkNet' 91 | elif params.network_task == 'features_extraction': 92 | params.n_walks_per_model = 2 93 | params.one_label_per_model = True 94 | params.train_loss = ['triplet'] 95 | else: 96 | raise 
def modelnet_params(network_task, config=None):
    '''
    Define the parameters for the 'modelnet' run (40 classes).

    The dataset is read from config['dataset_path']. If config is None, or the path is
    missing / shorter than 2 characters, 'datasets_processed/walker_copycat_modelnet40' is used.
    '''
    params = set_up_default_params(network_task, 'modelnet', 0, config)
    params.n_classes = 40

    params.train_min_max_faces2use = [0, 4000]
    params.test_min_max_faces2use = [0, 4000]

    # config defaults to None in the signature, so it must not be indexed blindly.
    # An unset (None) path is treated like an empty one
    ds_path = config['dataset_path'] if config is not None else ''
    if not ds_path or len(ds_path) < 2:
        ds_path = 'datasets_processed/walker_copycat_modelnet40'
    params.datasets2use['train'] = [ds_path + '/*train*.npz']
    params.datasets2use['test'] = [ds_path + '/*test*.npz']

    params.seq_len = 800
    params.min_seq_len = int(params.seq_len / 2)

    params.full_accuracy_test = {'dataset_folder': params.datasets2use['test'][0],
                                 'labels': dataset_prepare.model_net_labels,
                                 'min_max_faces2use': params.test_min_max_faces2use,
                                 'n_walks_per_model': 16 * 4,
                                 }

    # Parameters to recheck:
    # Only the values that were actually in effect are kept; the earlier assignments
    # (500e3 iterations, 'xyz' input, non-halved learning rates) were always overwritten
    params.walk_alg = 'random_global_jumps'  # no_jumps / global_jumps
    params.iters_to_train = 2000e3
    params.net_input = ['dxdydz']

    # Both spellings are set, as the misspelled one is the name the default params define
    params.last_layer_actication = None
    params.last_layer_activation = None
    params.cycle_opt_prms = EasyDict({'initial_learning_rate': 1e-6 / 2,
                                      'maximal_learning_rate': 0.0005 / 2,
                                      'step_size': 10000})
    params.net_start_from_prev_net = None
    params.batch_size = 16

    return params
def shrec11_params(split_part, network_task, config = None):
    '''
    Define the parameters for the 'shrec11_<split_part>' run (30 classes).

    split_part is one of the following:
      10-10_A / 10-10_B / 10-10_C
      16-04_a / 16-04_b / 16-04_C

    The dataset is read from config['dataset_path']. If config is None, or the path is
    missing / shorter than 2 characters, 'datasets_processed/walker_copycat_shrec11/' is used.
    '''
    # NOTE(review): the original remark here derived seq_len = |V| / 2.5 = 100 from
    # |V| = 250 , |F| = 500, while the value set below is 200 - confirm which is intended
    params = set_up_default_params(network_task, 'shrec11_' + split_part, 0, config)
    params.n_classes = 30
    params.seq_len = 200
    params.min_seq_len = int(params.seq_len / 2)

    # config defaults to None in the signature, so it must not be indexed blindly.
    # An unset (None) path is treated like an empty one
    ds_path = config['dataset_path'] if config is not None else ''
    if not ds_path or len(ds_path) < 2:
        ds_path = 'datasets_processed/walker_copycat_shrec11/'
    params.datasets2use['train'] = [ds_path + '/*train*.npz']
    params.datasets2use['test'] = [ds_path + '/*test*.npz']

    params.train_data_augmentation = {'rotation': MAX_AUGMENTATION}
    params.last_layer_activation = None

    params.full_accuracy_test = {'dataset_folder': params.datasets2use['test'][0],
                                 'labels': dataset_prepare.shrec11_labels}

    params.net_using_from_prev_net = 'trained_models/meshCNN_imitating_network/learned_model2keep__00200010.keras'
    params.iters_to_train = 32e3

    return params
def print_enters(to_print):
    """
    Print to_print with a run of empty lines before and after it
    """
    padding = "\n\n\n\n"
    for chunk in (padding, to_print, padding):
        print(chunk)
'params.logdir :::: ', params.logdir, utils.color.END) 24 | print(utils.color.BOLD + utils.color.RED, os.getpid(), utils.color.END) 25 | utils.backup_python_files_and_params(params) 26 | 27 | # Set up datasets_processed for training and for test 28 | # ----------------------------------------- 29 | train_datasets = [] 30 | train_ds_iters = [] 31 | max_train_size = 0 32 | for i in range(len(params.datasets2use['train'])): 33 | this_train_dataset, n_trn_items = dataset.tf_mesh_dataset(params, params.datasets2use['train'][i], 34 | mode=params.network_tasks[i], 35 | size_limit=params.train_dataset_size_limit, 36 | shuffle_size=100, 37 | min_max_faces2use=params.train_min_max_faces2use, 38 | max_size_per_class=params.train_max_size_per_class, 39 | min_dataset_size=128, 40 | data_augmentation=params.train_data_augmentation) 41 | print('Train Dataset size:', n_trn_items) 42 | train_ds_iters.append(iter(this_train_dataset.repeat())) 43 | train_datasets.append(this_train_dataset) 44 | 45 | max_train_size = max(max_train_size, n_trn_items) 46 | train_epoch_size = max(8, int(max_train_size / params.n_walks_per_model / params.batch_size)) 47 | print('train_epoch_size:', train_epoch_size) 48 | if params.datasets2use['test'] is None: 49 | test_dataset = None 50 | n_tst_items = 0 51 | else: 52 | test_dataset, n_tst_items = dataset.tf_mesh_dataset(params, params.datasets2use['test'][0], 53 | mode=params.network_tasks[0], 54 | size_limit=params.test_dataset_size_limit, 55 | shuffle_size=100, 56 | min_max_faces2use=params.test_min_max_faces2use) 57 | test_ds_iter = iter(test_dataset.repeat()) 58 | print(' Test Dataset size:', n_tst_items) 59 | 60 | # Set up RNN model and optimizer 61 | # ------------------------------ 62 | if params.net_start_from_prev_net is not None: 63 | init_net_using = params.net_start_from_prev_net 64 | else: 65 | init_net_using = None 66 | 67 | if params.optimizer_type == 'adam': 68 | optimizer = tf.keras.optimizers.Adam(lr=params.learning_rate[0], 
clipnorm=params.gradient_clip_th) 69 | elif params.optimizer_type == 'cycle': 70 | @tf.function 71 | def _scale_fn(x): 72 | x_th = 500e3 / params.cycle_opt_prms.step_size 73 | if x < x_th: 74 | return 1.0 75 | else: 76 | return 0.5 77 | lr_schedule = tfa.optimizers.CyclicalLearningRate(initial_learning_rate=params.cycle_opt_prms.initial_learning_rate, 78 | maximal_learning_rate=params.cycle_opt_prms.maximal_learning_rate, 79 | step_size=params.cycle_opt_prms.step_size, 80 | scale_fn=_scale_fn, scale_mode="cycle", name="MyCyclicScheduler") 81 | optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipnorm=params.gradient_clip_th) 82 | elif params.optimizer_type == 'sgd': 83 | optimizer = tf.keras.optimizers.SGD(lr=params.learning_rate[0], decay=0, momentum=0.9, nesterov=True, 84 | clipnorm=params.gradient_clip_th) 85 | else: 86 | raise Exception('optimizer_type not supported: ' + params.optimizer_type) 87 | 88 | if params.net == 'RnnWalkNet': 89 | dnn_model = rnn_model.RnnWalkNet(params, params.n_classes, params.net_input_dim, init_net_using, 90 | optimizer=optimizer) 91 | elif params.net == "Manifold_RnnWalkNet": 92 | dnn_model = rnn_model.RnnManifoldWalkNet(params, params.n_classes, params.net_input_dim, init_net_using, 93 | optimizer=optimizer) 94 | 95 | # Other initializations 96 | # --------------------- 97 | time_msrs = {} 98 | time_msrs_names = ['train_step', 'get_train_data', 'test'] 99 | for name in time_msrs_names: 100 | time_msrs[name] = 0 101 | manifold_seg_train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='seg_train_accuracy') 102 | seg_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='seg_train_accuracy') 103 | 104 | train_log_names = ['seg_loss'] 105 | train_logs = {name: tf.keras.metrics.Mean(name=name) for name in train_log_names} 106 | train_logs['seg_train_accuracy'] = seg_train_accuracy 107 | 108 | # Train / test functions 109 | # ---------------------- 110 | if params.last_layer_actication is None: 111 | 
seg_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 112 | manifold_seg_loss = tf.keras.losses.KLDivergence(rom_logits=True) 113 | else: 114 | seg_loss = tf.keras.losses.SparseCategoricalCrossentropy() 115 | manifold_seg_loss = tf.keras.losses.KLDivergence() 116 | 117 | #@tf.function 118 | def train_step(model_ftrs_, labels_, one_label_per_model): 119 | sp = model_ftrs_.shape 120 | model_ftrs = tf.reshape(model_ftrs_, (-1, sp[-2], sp[-1])) 121 | with tf.GradientTape() as tape: 122 | if one_label_per_model: 123 | labels = tf.reshape(tf.transpose(tf.stack((labels_,) * params.n_walks_per_model)), (-1,)) 124 | predictions = dnn_model(model_ftrs) 125 | else: 126 | labels = tf.reshape(labels_, (-1, sp[-2])) 127 | skip = params.min_seq_len 128 | predictions = dnn_model(model_ftrs)[:, skip:] 129 | labels = labels[:, skip + 1:] 130 | 131 | if params.train_loss == ['manifold_cros_entr']: 132 | labels = label_to_one_hot(labels=labels, params=params) 133 | manifold_seg_train_accuracy(labels, predictions) 134 | loss = manifold_seg_loss(labels, predictions) 135 | else: 136 | seg_train_accuracy(labels, predictions) 137 | loss = seg_loss(labels, predictions) 138 | loss += tf.reduce_sum(dnn_model.losses) 139 | 140 | gradients = tape.gradient(loss, dnn_model.trainable_variables) 141 | optimizer.apply_gradients(zip(gradients, dnn_model.trainable_variables)) 142 | 143 | train_logs['seg_loss'](loss) 144 | 145 | return loss 146 | 147 | if params.train_loss == ['manifold_cros_entr']: 148 | test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy') 149 | else: 150 | test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') 151 | 152 | #@tf.function 153 | def test_step(model_ftrs_, labels_, one_label_per_model): 154 | sp = model_ftrs_.shape 155 | model_ftrs = tf.reshape(model_ftrs_, (-1, sp[-2], sp[-1])) 156 | if one_label_per_model: 157 | labels = tf.reshape(tf.transpose(tf.stack((labels_,) * params.n_walks_per_model)), (-1,)) 
158 | predictions = dnn_model(model_ftrs, training=False) 159 | else: 160 | labels = tf.reshape(labels_, (-1, sp[-2])) 161 | skip = params.min_seq_len 162 | if params.train_loss == ['manifold_cros_entr']: 163 | predictions = dnn_model(model_ftrs, training=False)[:, skip:] 164 | else: 165 | predictions = dnn_model(model_ftrs, training=False)[:, skip:] 166 | labels = labels[:, skip + 1:] 167 | 168 | if params.train_loss == ['manifold_cros_entr']: 169 | labels = label_to_one_hot(labels=labels, params=params) 170 | test_accuracy(labels, predictions) 171 | confusion = None 172 | return confusion 173 | # ------------------------------------- 174 | 175 | # Loop over training EPOCHs 176 | # ------------------------- 177 | one_label_per_model = params.one_label_per_model 178 | next_iter_to_log = 0 179 | e_time = 0 180 | accrcy_smoothed = tb_epoch = last_loss = None 181 | all_confusion = {} 182 | 183 | 184 | with tf.summary.create_file_writer(params.logdir).as_default(): 185 | epoch = 0 186 | while optimizer.iterations.numpy() < params.iters_to_train + train_epoch_size * 2: 187 | epoch += 1 188 | if epoch % 10 == 0: 189 | print(params.logdir) 190 | print(config['description']) 191 | str_to_print = str(os.getpid()) + ') Epoch' + str(epoch) + ', iter ' + str(optimizer.iterations.numpy()) 192 | 193 | # Save some logs & infos 194 | utils.save_model_if_needed(optimizer.iterations, dnn_model, params) 195 | if tb_epoch is not None: 196 | e_time = time.time() - tb_epoch 197 | tf.summary.scalar('time/one_epoch', e_time, step=optimizer.iterations) 198 | tf.summary.scalar('time/av_one_trn_itr', e_time / n_iters, step=optimizer.iterations) 199 | for name in time_msrs_names: 200 | if time_msrs[name]: # if there is something to save 201 | tf.summary.scalar('time/' + name, time_msrs[name], step=optimizer.iterations) 202 | time_msrs[name] = 0 203 | tb_epoch = time.time() 204 | n_iters = 0 205 | tf.summary.scalar(name="train/learning_rate", data=optimizer._decayed_lr(tf.float32), 
step=optimizer.iterations) 206 | tf.summary.scalar(name="mem/free", data=utils.check_mem_and_exit_if_full(), step=optimizer.iterations) 207 | 208 | str_to_print += '; LR: ' + str(optimizer._decayed_lr(tf.float32)) 209 | train_logs['seg_loss'].reset_states() 210 | tb = time.time() 211 | for iter_db in range(train_epoch_size): 212 | for dataset_id in range(len(train_datasets)): 213 | name, model_ftrs, labels = train_ds_iters[dataset_id].next() 214 | dataset_type = utils.get_dataset_type_from_name(name) 215 | if params.learning_rate_dynamics != 'stable': 216 | utils.update_lerning_rate_in_optimizer(0, params.learning_rate_dynamics, optimizer, params) 217 | time_msrs['get_train_data'] += time.time() - tb 218 | n_iters += 1 219 | tb = time.time() 220 | if params.train_loss[dataset_id] == 'cros_entr': 221 | train_step(model_ftrs, labels, one_label_per_model=one_label_per_model) 222 | loss2show = 'seg_loss' 223 | elif params.train_loss[dataset_id] == 'manifold_cros_entr': 224 | train_step(model_ftrs, labels, one_label_per_model=one_label_per_model) 225 | loss2show = 'seg_loss' 226 | else: 227 | raise Exception('Unsupported loss_type: ' + params.train_loss[dataset_id]) 228 | time_msrs['train_step'] += time.time() - tb 229 | tb = time.time() 230 | if iter_db == train_epoch_size - 1: 231 | str_to_print += ', TrnLoss: ' + str(round(train_logs[loss2show].result().numpy(), 2)) 232 | 233 | # Dump training info to tensorboard 234 | if optimizer.iterations >= next_iter_to_log: 235 | for k, v in train_logs.items(): 236 | if v.count.numpy() > 0: 237 | tf.summary.scalar('train/' + k, v.result(), step=optimizer.iterations) 238 | v.reset_states() 239 | next_iter_to_log += params.log_freq 240 | 241 | # Run test on part of the test set 242 | if test_dataset is not None: 243 | n_test_iters = 0 244 | tb = time.time() 245 | #for name, model_ftrs, labels in test_dataset: 246 | for i in range(n_tst_items): 247 | name, model_ftrs, labels = test_ds_iter.next() 248 | 249 | n_test_iters += 
def run_one_job(job, job_part, network_task):
    """Build the parameters for one job and run training/validation on it.

    Args:
        job: dataset/job name, matched case-insensitively. Supported values:
            'modelnet40', 'modelnet', 'shrec11', 'cubes', 'human_seg', 'coseg'.
        job_part: sub-selection forwarded to the 'shrec11' and 'coseg'
            parameter builders (for 'coseg': 'aliens', 'chairs' or 'vases');
            ignored by the other jobs.
        network_task: network task name forwarded to the parameter builder.

    Raises:
        ValueError: if ``job`` is not one of the supported names. Previously an
            unknown job reached ``train_val(params)`` with ``params`` unbound.

    Note:
        Reads the module-level ``config``, which this file assigns under
        ``if __name__ == '__main__'``.
    """
    job = job.lower()

    # Classifications
    if job == 'modelnet40' or job == 'modelnet':
        params = params_setting.modelnet_params(network_task, config)
    elif job == 'shrec11':
        params = params_setting.shrec11_params(job_part, network_task, config)
    elif job == 'cubes':
        params = params_setting.cubes_params(network_task, config)
    # Semantic Segmentations
    elif job == 'human_seg':
        params = params_setting.human_seg_params(network_task, config)
    elif job == 'coseg':
        # job_part can be : 'aliens' or 'chairs' or 'vases'
        params = params_setting.coseg_params(job_part, network_task, config)
    else:
        raise ValueError('Unsupported job: ' + repr(job))

    train_val(params)
def print_enters(to_print):
    """Print ``to_print`` framed by blocks of blank lines so it stands out in logs."""
    spacer = "\n" * 4
    for chunk in (spacer, to_print, spacer):
        print(chunk)
def norm_model(vertices):
    """Centre and rescale a vertex array in place; returns nothing.

    The array is shifted so the centre of its axis-aligned bounding box sits
    at the origin, then divided by its largest coordinate value (not by the
    largest vertex norm). When the function attribute
    ``norm_model.sub_mean_for_data_augmentation`` is truthy, the per-axis
    nan-mean is subtracted afterwards.
    """
    # Bounding-box centre = mean of the per-axis minimum and maximum corners.
    bbox_corners = np.stack((vertices.min(axis=0), vertices.max(axis=0)))
    vertices -= bbox_corners.mean(axis=0)

    # Model Norm: scale by the single largest coordinate.
    vertices /= vertices.max()

    if norm_model.sub_mean_for_data_augmentation:
        vertices -= np.nanmean(vertices, axis=0)
def data_augmentation_aspect_ratio(vertices):
    """Randomly stretch the model along each coordinate axis, in place.

    With 50% probability, each of the three coordinate columns of ``vertices``
    is multiplied by its own factor drawn uniformly from
    ``[1 - max_ratio, 1 + max_ratio]``, where ``max_ratio`` is the function
    attribute ``data_augmentation_aspect_ratio.max_ratio`` (assigned in
    ``setup_data_augmentation``). Otherwise the array is left untouched.

    Args:
        vertices: float array with one vertex per row and the three
            coordinates in columns; modified in place.
    """
    if np.random.randint(2):  # 50% chance
        max_ratio = data_augmentation_aspect_ratio.max_ratio
        for axis in range(3):
            r = np.random.uniform(1 - max_ratio, 1 + max_ratio)
            # Scale the whole coordinate column. Indexing with `vertices[axis]`
            # would instead rescale a single vertex (row) and leave the rest
            # of the model unchanged.
            vertices[:, axis] *= r
dataset_params.data_augmentaion_vertices_functions.append(data_augmentation_axes_rot) 136 | data_augmentation_axes_rot.hori_axes = data_augmentation['horisontal_90deg'] 137 | flip_axes_ = [0, 1, 2] 138 | data_augmentation_axes_rot.flip_axes = [0, 1, 2] 139 | data_augmentation_axes_rot.flip_axes[data_augmentation_axes_rot.hori_axes[0]] = flip_axes_[data_augmentation_axes_rot.hori_axes[1]] 140 | data_augmentation_axes_rot.flip_axes[data_augmentation_axes_rot.hori_axes[1]] = flip_axes_[data_augmentation_axes_rot.hori_axes[0]] 141 | if 'rotation' in data_augmentation.keys() and data_augmentation['rotation']: 142 | data_augmentation_rotation.max_rot_ang_deg = data_augmentation['rotation'] 143 | dataset_params.data_augmentaion_vertices_functions.append(data_augmentation_rotation) 144 | if 'aspect_ratio' in data_augmentation.keys() and data_augmentation['aspect_ratio']: 145 | data_augmentation_aspect_ratio.max_ratio = data_augmentation['aspect_ratio'] 146 | dataset_params.data_augmentaion_vertices_functions.append(data_augmentation_aspect_ratio) 147 | 148 | 149 | def setup_features_params(dataset_params, params): 150 | if params.uniform_starting_point: 151 | dataset_params.area = 'all' 152 | else: 153 | dataset_params.area = -1 154 | norm_model.sub_mean_for_data_augmentation = params.sub_mean_for_data_augmentation 155 | dataset_params.support_mesh_cnn_ftrs = False 156 | dataset_params.fill_features_functions = [] 157 | dataset_params.number_of_features = 0 158 | net_input = params.net_input 159 | if 'xyz' in net_input: 160 | dataset_params.fill_features_functions.append(fill_xyz_features) 161 | dataset_params.number_of_features += 3 162 | if 'dxdydz' in net_input: 163 | dataset_params.fill_features_functions.append(fill_dxdydz_features) 164 | dataset_params.number_of_features += 3 165 | if 'edge_meshcnn' in net_input: 166 | dataset_params.support_mesh_cnn_ftrs = True 167 | dataset_params.fill_features_functions.append(fill_edge_meshcnn_features) 168 | 
def get_starting_point(area, area_vertices_list, n_vertices, walk_id):
    """Pick the vertex index a walk starts from.

    Args:
        area: ``None`` to ignore areas; ``-1`` to draw from a random non-empty
            area; any other value to use the area selected by ``walk_id``.
        area_vertices_list: sequence of arrays of vertex indices, one per
            area, or ``None`` to ignore areas.
        n_vertices: number of vertices, used when areas are ignored.
        walk_id: index of the walk; selects the area
            ``walk_id % len(area_vertices_list)`` when ``area`` is not ``-1``.

    Returns:
        A vertex index: a uniform draw from ``range(n_vertices)`` when areas
        are ignored, otherwise a random element of the chosen area.

    Raises:
        ValueError: if every area is empty (the retry loop could never end).
    """
    if area is None or area_vertices_list is None:
        return np.random.randint(n_vertices)

    n_areas = len(area_vertices_list)
    if not any(np.size(c) for c in area_vertices_list):
        raise ValueError('get_starting_point: all areas are empty, no vertex to start from')

    if area == -1:
        candidates = np.zeros((0,))     # forces a random area to be drawn below
    else:
        candidates = area_vertices_list[walk_id % n_areas]

    # Fall back to random areas until a non-empty one is found. The draw is
    # bounded by the real number of areas rather than a hard-coded 9.
    while np.size(candidates) == 0:
        candidates = area_vertices_list[np.random.randint(n_areas)]
    return np.random.choice(candidates)
def generate_walk(fn, vertices, faces, edges, kdtree_query, labels_from_npz, params_idx):
    """Turn one mesh record into walk features (the body run via ``tf.py_function``).

    The tensor arguments are converted to numpy arrays, packed into a
    ``mesh_data`` dict and passed to ``mesh_data_to_walk_features`` together
    with the dataset parameters stored at ``dataset_params_list[params_idx[0]]``.

    Returns:
        ``(fn[0], features, labels)``, where ``labels`` are the per-step labels
        produced by the walk when ``label_per_step`` is set, and the untouched
        ``labels_from_npz`` otherwise.
    """
    dataset_params = dataset_params_list[params_idx[0].numpy()]
    per_step_labels = dataset_params.label_per_step

    tensors = (
        ('vertices', vertices),
        ('faces', faces),
        ('edges', edges),
        ('kdtree_query', kdtree_query),
    )
    mesh_data = {key: tensor.numpy() for key, tensor in tensors}
    if per_step_labels:
        mesh_data['labels'] = labels_from_npz.numpy()

    features, labels = mesh_data_to_walk_features(mesh_data, dataset_params)

    labels_return = labels if per_step_labels else labels_from_npz
    return fn[0], features, labels_return
mesh_extra['n_vertices'] = vertices.shape[0] 271 | if dataset_params.edges_needed: 272 | mesh_extra['edges'] = mesh_data['edges'] 273 | if dataset_params.kdtree_query_needed: 274 | mesh_extra['kdtree_query'] = mesh_data['kdtree_query'] 275 | 276 | features = np.zeros((dataset_params.n_walks_per_model, seq_len, dataset_params.number_of_features), dtype=np.float32) 277 | labels = np.zeros((dataset_params.n_walks_per_model, seq_len), dtype=np.int32) 278 | 279 | if mesh_data_to_walk_features.SET_SEED_WALK: 280 | np.random.seed(mesh_data_to_walk_features.SET_SEED_WALK) 281 | if dataset_params.network_task == 'self:triplets': 282 | neg_walk_f0 = np.random.randint(vertices.shape[0]) 283 | if 1: 284 | pos_walk_f0 = np.random.choice(mesh_data['far_vertices'][neg_walk_f0]) 285 | else: 286 | pos_walk_f0 = np.random.choice(mesh_data['mid_vertices'][neg_walk_f0]) 287 | for walk_id in range(dataset_params.n_walks_per_model): 288 | if dataset_params.network_task == 'self:triplets': 289 | if walk_id < dataset_params.n_walks_per_model / 2: 290 | f0 = neg_walk_f0 291 | else: 292 | f0 = pos_walk_f0 293 | else: 294 | f0 = np.random.randint(vertices.shape[0]) # TODO: to verify it works well! 
295 | if mesh_data_to_walk_features.SET_SEED_WALK: 296 | f0 = mesh_data_to_walk_features.SET_SEED_WALK 297 | 298 | if dataset_params.n_target_vrt_to_norm_walk and dataset_params.n_target_vrt_to_norm_walk < vertices.shape[0]: 299 | j = int(round(vertices.shape[0] / dataset_params.n_target_vrt_to_norm_walk)) 300 | else: 301 | j = 1 302 | seq, jumps = dataset_params.walk_function(mesh_extra, f0, seq_len * j) 303 | seq = seq[::j] 304 | if dataset_params.reverse_walk: 305 | seq = seq[::-1] 306 | jumps = jumps[::-1] 307 | 308 | f_idx = 0 309 | for fill_ftr_fun in dataset_params.fill_features_functions: 310 | f_idx = fill_ftr_fun(features[walk_id], f_idx, vertices, mesh_extra, seq, jumps, seq_len) 311 | if dataset_params.label_per_step: 312 | print("mesh labels shape: ", mesh_labels.shape) 313 | if dataset_params.network_task == 'self:triplets': 314 | labels[walk_id] = seq[1:seq_len + 1] 315 | else: 316 | labels[walk_id] = mesh_labels[seq[1:seq_len + 1]] 317 | 318 | return features, labels 319 | 320 | 321 | def get_file_names(pathname_expansion, min_max_faces2use): 322 | filenames_ = glob.glob(pathname_expansion) 323 | filenames = [] 324 | for fn in filenames_: 325 | try: 326 | n_faces = int(fn.split('.')[-2].split('_')[-1]) 327 | if n_faces > min_max_faces2use[1] or n_faces < min_max_faces2use[0]: 328 | continue 329 | except: 330 | pass 331 | filenames.append(fn) 332 | assert len(filenames) > 0, 'DATASET error: no files in directory to be used! 
def filter_fn_by_class(filenames_, classes_indices_to_use):
    """Keep only the files whose stored ``label`` is in the requested classes.

    Args:
        filenames_: iterable of ``.npz`` paths, each holding a ``label`` entry.
        classes_indices_to_use: collection of class indices to keep, or
            ``None`` to keep every file.

    Returns:
        A new list with the accepted file names, in their original order.
    """
    # No filter requested: nothing to inspect, so no file is opened.
    if classes_indices_to_use is None:
        return list(filenames_)

    filenames = []
    for fn in filenames_:
        # NOTE: allow_pickle=True can execute code embedded in a crafted file;
        # only use it on dataset files from a trusted source.
        # The context manager closes the archive, so scanning a large dataset
        # does not accumulate open file handles.
        with np.load(fn, encoding='latin1', allow_pickle=True) as mesh_data:
            label = mesh_data['label']
        if label in classes_indices_to_use:
            filenames.append(fn)
    return filenames
def dump_all_fns_to_file(filenames, params):
    """Write the dataset file names to a text file in the log directory (best effort).

    The first of ``<logdir>/dataset_files_00.txt`` .. ``dataset_files_09.txt``
    that does not exist yet is created, with one file name per line. Nothing
    is written when ``params`` has no ``logdir`` key or when all ten files
    already exist.

    Args:
        filenames: iterable of file names to record.
        params: mapping-like parameters object exposing ``keys()`` and, when
            present, the ``logdir`` attribute.
    """
    if 'logdir' not in params.keys():
        return

    for n in range(10):
        log_fn = params.logdir + '/dataset_files_' + str(n).zfill(2) + '.txt'
        if os.path.isfile(log_fn):
            continue
        try:
            with open(log_fn, 'w') as f:
                f.writelines(str(fn) + '\n' for fn in filenames)
        except OSError as err:
            # The dump is only a convenience log: report the failure and carry
            # on, but no longer hide unrelated errors behind a bare except.
            print('Could not write dataset file list to ' + log_fn + ': ' + str(err))
        break
get_file_names(pathname_expansion, min_max_faces2use) 442 | 443 | if params.classes_indices_to_use is not None: 444 | filenames = filter_fn_by_class(filenames, params.classes_indices_to_use) 445 | if max_size_per_class is not None: 446 | filenames = adjust_fn_list_by_size(filenames, max_size_per_class) 447 | 448 | if permute_file_names: 449 | filenames = np.random.permutation(filenames) 450 | else: 451 | filenames.sort() 452 | filenames = np.array(filenames) 453 | if size_limit < len(filenames): 454 | filenames = filenames[:size_limit] 455 | n_items = len(filenames) 456 | if len(filenames) < min_dataset_size: 457 | filenames = filenames.tolist() * (int(min_dataset_size / len(filenames)) + 1) 458 | 459 | if mode == 'classification': 460 | dataset_params_list[params_idx].label_per_step = False 461 | elif mode == 'manifold_classification': 462 | dataset_params_list[params_idx].label_per_step = False 463 | elif mode == 'semantic_segmentation': 464 | dataset_params_list[params_idx].label_per_step = True 465 | elif mode == 'unsupervised_classification': 466 | dataset_params_list[params_idx].label_per_step = False 467 | elif mode == 'features_extraction': 468 | dataset_params_list[params_idx].label_per_step = False 469 | elif mode == 'self:triplets': 470 | dataset_params_list[params_idx].label_per_step = True 471 | else: 472 | raise Exception('DS mode ?') 473 | 474 | dump_all_fns_to_file(filenames, params) 475 | 476 | def _open_npz_fn(*args): 477 | return OpenMeshDataset(args, params_idx) 478 | 479 | ds = tf.data.Dataset.from_tensor_slices(filenames) 480 | if shuffle_size: 481 | ds = ds.shuffle(shuffle_size) 482 | ds = ds.interleave(_open_npz_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) 483 | ds = ds.cache() 484 | ds = ds.map(generate_walk_py_fun, num_parallel_calls=tf.data.experimental.AUTOTUNE) 485 | ds = ds.batch(params.batch_size, drop_remainder=False) 486 | ds = ds.prefetch(tf.data.experimental.AUTOTUNE) 487 | 488 | return ds, n_items 489 | 490 | if __name__ 
def set_single_gpu(gpu_num_to_use = -1):
    """Set the CUDA device-selection environment variables.

    CUDA_DEVICE_ORDER is set to "PCI_BUS_ID" and CUDA_VISIBLE_DEVICES to the
    string form of `gpu_num_to_use` (default -1).
    """
    cuda_env = {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": str(gpu_num_to_use),
    }
    os.environ.update(cuda_env)
def get_gpu_temprature():
    """Return a GPU temperature parsed from `nvidia-smi`, or -1 if unavailable.

    The shell pipeline is unchanged; only the parsing differs from the
    original. The first output line that contains digits is used, so several
    output lines (presumably one per GPU - confirm on a multi-GPU host) are
    no longer glued into a single bogus number.

    Returns:
        int: the parsed value, or -1 when the command yields no digits
        (e.g. `nvidia-smi` is not installed).
    """
    # Use the pipe as a context manager so it is always closed.
    with os.popen("nvidia-smi -q | grep 'GPU Current Temp' | cut -d' ' -f 24") as pipe:
        output = pipe.read()
    for line in output.splitlines():
        digits = ''.join(filter(str.isdigit, line))
        if digits:
            return int(digits)
    return -1
def get_run_folder(root_dir, str2add='', cont_run_number=False):
    """Return the folder name for a new run, or the folder of a run to continue.

    Run folders are entries of `root_dir` named '<number>-<anything>'.
    Entries without a numeric prefix before the first '-' are ignored.

    Args:
        root_dir: Directory prefix holding the runs. It is concatenated as-is,
            so it should end with a path separator.
        str2add: Suffix appended to a newly generated folder name.
        cont_run_number: Number of an existing run to continue. A falsy value
            (the default) means "start a new run".

    Returns:
        str: `root_dir` + the matching existing folder when `cont_run_number`
        is found, otherwise `root_dir` + '<next number, 4 digits>-<date>' +
        `str2add`. Nothing is created on disk.
    """
    try:
        all_runs = os.listdir(root_dir)
    except OSError:
        # Missing/unreadable root: behave as if there are no runs yet.
        all_runs = []

    # Keep each id next to its folder name so they can never get out of sync
    # (the original indexed the unfiltered listing with a filtered index).
    runs = []
    for entry in all_runs:
        prefix, sep, _ = entry.partition('-')
        if sep and prefix.isdigit():
            runs.append((int(prefix), entry))

    if cont_run_number:
        for run_id, entry in runs:
            if run_id == cont_run_number:
                run_dir = root_dir + entry
                print('Continue to run at:', run_dir)
                return run_dir
        # Requested run not found: fall through and allocate a fresh number
        # instead of resetting to 0001, which could collide with an old run.

    n = max((run_id for run_id, _ in runs), default=0)
    now = datetime.datetime.now()
    return root_dir + str(n + 1).zfill(4) + '-' + now.strftime("%d.%m.%Y..%H.%M") + str2add
def update_lerning_rate_in_optimizer(n_times_loss_stable, method, optimizer, params):
    """Assign a new learning rate to `optimizer` according to `method`.

    'triangle' (only when `n_times_loss_stable` >= 1): the rate peaks at
    `params.learning_rate` in the middle of each `params.cyclic_lr_period`
    and falls linearly to `params.learning_rate * params.min_lr_factor`
    at the period edges.

    'steps': `params.learning_rate[i]` is assigned when the current iteration
    lies in [learning_rate_steps[i], learning_rate_steps[i + 1]).

    Any other method leaves the optimizer untouched.
    """
    step = optimizer.iterations.numpy()

    if method == 'triangle' and n_times_loss_stable >= 1:
        half_period = params.cyclic_lr_period / 2
        step = step % params.cyclic_lr_period
        dist_from_mid = np.abs(step - half_period)
        fraction_from_mid = np.abs(half_period - dist_from_mid) / half_period
        factor = fraction_from_mid + (1 - fraction_from_mid) * params.min_lr_factor
        optimizer.learning_rate.assign(params.learning_rate * factor)

    if method == 'steps':
        boundaries = params.learning_rate_steps
        for i in range(len(boundaries) - 1):
            if boundaries[i] <= step < boundaries[i + 1]:
                optimizer.learning_rate.assign(params.learning_rate[i])
def visualize_model_loop_on_each_color(vertices, faces, title=' ', vertex_colors_idx=None, cpos=None):
    """Show the model once per colour index, highlighting only that colour.

    For every distinct value in `vertex_colors_idx` except -1, a per-vertex
    colour array is built that holds that value at the matching vertices and
    -1 everywhere else, and `visualize_model` is called with it.

    Args:
        vertices: Array whose first dimension is the number of vertices.
        faces: Passed through to `visualize_model`.
        title: Window title, forwarded to `visualize_model`.
        vertex_colors_idx: Per-vertex colour index; -1 marks "no colour".
        cpos: Camera position, forwarded to `visualize_model`.
    """
    for c in np.unique(vertex_colors_idx):
        if c == -1:
            continue
        # Builtin int replaces the np.int alias removed from recent NumPy.
        v_colors = np.full(vertices.shape[0], -1, dtype=int)
        selected = np.where(vertex_colors_idx == c)
        v_colors[selected] = c
        print(c, selected[0].size)
        # Forward the caller's title (it was previously dropped for ' ').
        visualize_model(vertices, faces, title=title, vertex_colors_idx=v_colors, cpos=cpos)
def merge_some_models_for_visualization(models):
    """Concatenate several meshes into one, laid out side by side along X.

    Each model is either a dict with 'vertices' and 'faces' entries or an
    object exposing `.vertices` and `.triangles`. Every model after the first
    is shifted along X so that it starts to the right of everything merged so
    far, with a gap of 10% of its own X extent. Input arrays are not modified.

    Args:
        models: Non-empty sequence of models.

    Returns:
        (all_vertices, all_faces): stacked vertex array and stacked face
        array, with face indices offset to address the stacked vertices.
    """
    def _get_faces(mesh):
        if isinstance(mesh, dict):
            return mesh['faces']
        return np.asarray(mesh.triangles)

    def _get_vertices(mesh):
        if isinstance(mesh, dict):
            # Was mesh['faces']: dict models produced garbage geometry.
            return mesh['vertices']
        return np.asarray(mesh.vertices)

    all_faces = _get_faces(models[0])
    all_vertices = _get_vertices(models[0])
    x_shift = all_vertices.max(axis=0)[0]
    for model in models[1:]:
        # Offset face indices by the number of vertices merged so far.
        this_faces = _get_faces(model) + all_vertices.shape[0]
        all_faces = np.vstack((all_faces, this_faces))
        this_vertices = _get_vertices(model).copy()
        x_min = this_vertices[:, 0].min()
        x_extent = this_vertices[:, 0].max() - x_min
        this_vertices[:, 0] += x_shift - x_min + x_extent * 0.1
        all_vertices = np.vstack((all_vertices, this_vertices))
        x_shift = all_vertices.max(axis=0)[0]
    return all_vertices, all_faces
| else: 285 | n_obj2show = 2 286 | for pos in range(n_obj2show): 287 | faces = np.hstack([[3] + f.tolist() for f in faces_]) 288 | vertices = vertices_.copy() 289 | if dual_object is not None: 290 | v_shift = [0, 0, 0] 291 | if pos == 0: 292 | v_shift[dual_object] -= dual_object_shift 293 | else: 294 | vertices[:, 2 - dual_object] *= -1 295 | vertices[:, dual_object] *= -1 296 | v_shift[dual_object] += dual_object_shift 297 | vertices += v_shift 298 | surf = pv.PolyData(vertices, faces) 299 | if show_surface: 300 | p.add_mesh(surf, show_edges=show_edges, edge_color=edge_color_a, color=all_colors, opacity=opacity, smooth_shading=True, 301 | scalars=face_colors, cmap=cmap, line_width=line_width) 302 | if show_vertices: 303 | p.add_mesh(pv.PolyData(surf.points), point_size=point_size, render_points_as_spheres=True) 304 | if show_vertices and vertex_colors_idx is not None: 305 | if type(cmap) is list: 306 | colors = cmap 307 | else: 308 | colors = 'ygbkmcywr' 309 | for c in np.unique(vertex_colors_idx): 310 | if c != -1: 311 | idxs = np.where(vertex_colors_idx == c)[0] 312 | if v_size is None: 313 | p.add_mesh(pv.PolyData(surf.points[idxs]), color=colors[c], point_size=point_size, render_points_as_spheres=True) 314 | else: 315 | for i in idxs: 316 | p.add_mesh(pv.PolyData(surf.points[i]), color=colors[c % len(colors)], point_size=v_size[i], 317 | render_points_as_spheres=True) 318 | colors = ['blue', 'red', 'lime', 'orange', 'black', 'pink', 'yellow', 'lightblue', 'lightgreen'] 319 | if type(walk) is list: 320 | for i, w in enumerate(walk): 321 | all_edges = [[2, w[i], w[i + 1]] for i in range(len(w) - 1)] 322 | walk_edges = np.hstack([edge for edge in all_edges]) 323 | walk_mesh = pv.PolyData(vertices, walk_edges) 324 | p.add_mesh(walk_mesh, show_edges=True, line_width=line_width * 4, edge_color=colors[i]) 325 | elif walk is not None and walk.size > 1: 326 | all_edges = [[2, walk[i], walk[i + 1]] for i in range(len(walk) - 1)] 327 | if edge_colors is None: 328 | 
edge_colors = np.zeros_like(walk) 329 | elif edge_colors == 'use_cmap': 330 | if 0: 331 | walk_edges = np.hstack([edge for edge in all_edges]) 332 | walk_mesh = pv.PolyData(vertices, walk_edges) 333 | scalars = (np.arange(len(all_edges)) / len(all_edges) * 255).astype(np.int) 334 | p.add_mesh(walk_mesh, show_edges=True, line_width=line_width, cmap='Blues', scalars=scalars) 335 | else: 336 | for i, edge in enumerate(all_edges): 337 | walk_edges = np.array(edge) 338 | walk_mesh = pv.PolyData(vertices, walk_edges) 339 | edge_color = (1.0 - i / len(all_edges), 0.0, i / len(all_edges)) 340 | p.add_mesh(walk_mesh, show_edges=True, line_width=line_width, edge_color=edge_color) 341 | elif type(edge_colors) is str: 342 | walk_edges = np.hstack([edge for edge in all_edges]) 343 | walk_mesh = pv.PolyData(vertices, walk_edges) 344 | p.add_mesh(walk_mesh, show_edges=True, line_width=line_width * 40, edge_color=edge_colors) 345 | else: 346 | for this_e_color in range(np.max(edge_colors) + 1): 347 | this_edges = (edge_colors == this_e_color) 348 | if np.any(this_edges): 349 | walk_edges = np.hstack([edge for edge, clr in zip(all_edges, this_edges[1:]) if clr]) 350 | walk_mesh = pv.PolyData(vertices, walk_edges) 351 | p.add_mesh(walk_mesh, show_edges=True, edge_color=colors[this_e_color], line_width=line_width*10) 352 | if edge_colors_list is not None: 353 | t_mesh = trimesh.Trimesh(vertices=vertices, faces=faces_, process=False) 354 | for clr, edges in edge_colors_list.items(): 355 | if clr.find(':') == -1: 356 | vertices2use = vertices 357 | this_edges_ = [[2, e[0], e[1]] for e in edges] 358 | this_edges = np.hstack([edge for edge in this_edges_]) 359 | walk_mesh = pv.PolyData(vertices2use, this_edges) 360 | p.add_mesh(walk_mesh, show_edges=True, edge_color=clr, line_width=line_width*10) 361 | else: 362 | clr_1st, clr_2nd = clr.split(':') 363 | vertices2use = [] 364 | this_edges_1st_clr = [] 365 | this_edges_2nd_clr = [] 366 | for e in edges: 367 | mean_normal = 
def print_cpos(cpos):
    """Print a camera position as a nested, bracketed list rounded to 2 digits.

    Items are separated by ' , ' at both nesting levels. For a 3x3 input the
    output is identical to the previous implementation; other sizes no longer
    get dangling separators, because the separator placement is not tied to
    index 2 any more.

    Args:
        cpos: Iterable of iterables of numbers.
    """
    rows = ['[' + ' , '.join(str(round(n, 2)) for n in row) + ']' for row in cpos]
    print('[' + ' , '.join(rows) + ']')
421 | if iterations < iter_th: 422 | next_iter_to_keep = iterations * 2 423 | else: 424 | next_iter_to_keep = int(iterations / iter_th) * iter_th + iter_th 425 | if params.full_accuracy_test is not None: 426 | if params.network_task == 'semantic_segmentation': 427 | accuracy, _ = evaluate_segmentation.calc_accuracy_test(params=params, dnn_model=dnn_model, verbose_level=0, 428 | **params.full_accuracy_test) 429 | elif params.network_task == 'classification': 430 | accuracy, _ = evaluate_clustering.calc_accuracy_test(params=params, dnn_model=dnn_model, verbose_level=0, 431 | **params.full_accuracy_test) 432 | elif params.network_task == 'manifold_classification': 433 | accuracy, _ = evaluate_clustering.calc_accuracy_test(params=params, dnn_model=dnn_model, verbose_level=0, 434 | **params.full_accuracy_test) 435 | elif params.network_task == 'unsupervised_classification': 436 | accuracy, _ = evaluate_clustering.calc_accuracy_test(params=params, dnn_model=dnn_model, verbose_level=0, 437 | **params.full_accuracy_test) 438 | elif params.network_task == 'features_extraction': 439 | # no need to calc accuracy... model already saved. 
def get_model_name_from_npz_fn(npz_fn):
    """Derive (model_name, n_faces) from an npz file path.

    The file stem (the part between the last '/' and the last '.') is split on
    '_'. Unless the path contains '/shrec11', the first token is dropped.
    The model name is every token up to and including the first all-digit
    token; the face count is the last token converted to int.

    Raises:
        IndexError: if the file name has no '.' or no all-digit token.
        ValueError: if the last token is not an integer.
    """
    base_name = npz_fn.split('/')[-1]
    stem = base_name.split('.')[-2]
    tokens = stem.split('_')

    if npz_fn.find('/shrec11') == -1:
        tokens = tokens[1:]

    numeric_positions = [pos for pos, tok in enumerate(tokens) if tok.isdigit()]
    first_numeric = numeric_positions[0]
    model_name = '_'.join(tokens[:first_numeric + 1])
    n_faces = int(tokens[-1])

    return model_name, n_faces
def print_labels_names_and_indices(model_name):
    """Print 'index :  label' for every class of the named dataset.

    "shrec11" lists the first 30 entries of `shrec11_labels`, "modelnet40"
    the first 40 entries of `model_net_labels`. Any other name prints an
    error message and nothing else.
    """
    if model_name not in ("shrec11", "modelnet40"):
        print("utils Error - Unknown model name !!")
        return

    if model_name == "shrec11":
        labels, count = shrec11_labels, 30
    else:
        labels, count = model_net_labels, 40

    for idx in range(count):
        print(idx, ": ", labels[idx])
def show_walk(model, features, one_walk=False):
    """Visualize the walks stored in `features` on top of `model`.

    For each walk (first axis of `features`), the last feature channel is
    taken as the walk sequence and the one before it as the jump flags, and
    both are passed to `utils.visualize_model_walk`.

    Args:
        model: Mapping with 'vertices' and 'faces' entries.
        features: Array indexed as [walk, step, channel].
        one_walk: If true, stop after showing the first walk.
    """
    for wi in range(features.shape[0]):
        # Builtin int/bool replace the np.int / np.bool aliases that were
        # removed from recent NumPy releases.
        walk = features[wi, :, -1].astype(int)
        jumps = features[wi, :, -2].astype(bool)
        utils.visualize_model_walk(model['vertices'], model['faces'], walk, jumps)
        if one_walk:
            break
81 | classes2use = params.classes_indices_to_use 82 | 83 | 84 | print_details = verbose_level >= 2 85 | if params is None: 86 | with open(logdir + '/params.txt') as fp: 87 | params = EasyDict(json.load(fp)) 88 | if model_fn is not None: 89 | pass 90 | elif iter2use != 'last': 91 | model_fn = logdir + '/learned_model2keep--' + iter2use 92 | model_fn = model_fn.replace('//', '/') 93 | else: 94 | model_fn = tf.train.latest_checkpoint(logdir) 95 | if verbose_level and model_fn is not None: 96 | print(utils.color.BOLD + utils.color.BLUE + 'logdir : ', model_fn + utils.color.END) 97 | else: 98 | params = copy.deepcopy(params) 99 | params.batch_size = 1 100 | params.n_walks_per_model = n_walks_per_model 101 | 102 | if "modelnet" in params.logdir: 103 | model_name = "modelnet40" 104 | elif "mesh_net" in params.logdir: 105 | model_name = "mesh_net" 106 | elif "shrec" in params.logdir: 107 | model_name = "shrec11" 108 | else: 109 | print("Unknown model ! exiting...") 110 | exit(0) 111 | 112 | 113 | if 0: 114 | params.net_input.append('jump_indication') 115 | if 0: 116 | params.layer_sizes = None 117 | params.aditional_network_params = [] 118 | 119 | if seq_len: 120 | params.seq_len = seq_len 121 | if verbose_level: 122 | print('params.seq_len:', params.seq_len, ' ; n_walks_per_model:', n_walks_per_model) 123 | 124 | #Amir - check 800 after training over 200 125 | if params is None: 126 | params.seq_len = 200 127 | 128 | if SHOW_WALK: 129 | params.net_input += ['vertex_indices'] 130 | 131 | params.set_seq_len_by_n_faces = 1 132 | if dataset_folder: 133 | size_limit = np.inf # 200 134 | 135 | params.classes_indices_to_use = classes2use 136 | pathname_expansion = dataset_folder 137 | if 1: 138 | test_dataset, n_models_to_test = dataset.tf_mesh_dataset(params, pathname_expansion, mode=params.network_task, 139 | shuffle_size=0, size_limit=size_limit, permute_file_names=True, 140 | min_max_faces2use=min_max_faces2use, must_run_on_all=True, 141 | 
data_augmentation=data_augmentation) 142 | else: 143 | test_dataset = dataset.mesh_dataset_iterator(params, pathname_expansion, mode=params.network_task, 144 | shuffle_size=0, size_limit=size_limit, permute_file_names=True, 145 | min_max_faces2use=min_max_faces2use) 146 | iter(test_dataset).__next__() 147 | n_models_to_test = 1000 148 | else: 149 | test_dataset = get_model_names() 150 | test_dataset = np.random.permutation(test_dataset) 151 | n_models_to_test = len(test_dataset) 152 | 153 | if dnn_model is None: 154 | if params.net == "RnnWalkNet": 155 | dnn_model = rnn_model.RnnWalkNet(params, params.n_classes, params.net_input_dim - SHOW_WALK, model_fn, 156 | model_must_be_load=True, dump_model_visualization=False) 157 | elif params.net == "Manifold_RnnWalkNet": 158 | dnn_model = rnn_model.RnnManifoldWalkNet(params, params.n_classes, params.net_input_dim - SHOW_WALK, model_fn, 159 | model_must_be_load=True, dump_model_visualization=False) 160 | elif params.net == "Unsupervised_RnnWalkNet": 161 | dnn_model = rnn_model.Unsupervised_RnnWalkNet(params, params.n_classes, params.net_input_dim - SHOW_WALK, model_fn, 162 | model_must_be_load=True, dump_model_visualization=False) 163 | else: 164 | print("Net type is not familiar ! 
exiting..") 165 | exit(0) 166 | 167 | n_pos_all = 0 168 | n_classes = 40 169 | all_confusion = np.zeros((n_classes, n_classes), dtype=np.int) 170 | size_accuracy = [] 171 | ii = 0 172 | tb_all = time.time() 173 | res_per_n_faces = {} 174 | pred_per_model_name = {} 175 | dnn_inference_time = [] # 150mSec for 64 walks of 200 steps 176 | bad_pred = EasyDict({'n_comp': [], 'biggest_comp_area_ratio': []}) 177 | good_pred = EasyDict({'n_comp': [], 'biggest_comp_area_ratio': []}) 178 | 179 | utils.print_labels_names_and_indices(model_name) 180 | for i, data in tqdm(enumerate(test_dataset), disable=print_details, total=n_models_to_test): 181 | name, ftrs, gt = data 182 | model_fn = name.numpy()[0].decode() 183 | model_name, n_faces = utils.get_model_name_from_npz_fn(model_fn) 184 | print(model_name) 185 | assert ftrs.shape[0] == 1, 'Must have one model per batch for test' 186 | if WALK_LEN_PROP_TO_NUM_OF_TRIANLES: 187 | n2keep = int(n_faces / 2.5) 188 | ftrs = ftrs[:, :, :n2keep, :] 189 | ftrs = tf.reshape(ftrs, ftrs.shape[1:]) 190 | gt = gt.numpy()[0] 191 | predictions = None 192 | for i_f, this_target_n_faces in enumerate(target_n_faces): 193 | model = None 194 | if SHOW_WALK: 195 | if model is None: 196 | model = dataset.load_model_from_npz(model_fn) 197 | if model['vertices'].shape[0] < 1000: 198 | print(model_fn) 199 | print('nv: ', model['vertices'].shape[0]) 200 | #Amir - removed for modelnet40 201 | #show_walk(model, ftrs.numpy(), one_walk=1) 202 | ftrs = ftrs[:, :, :-1] 203 | ftr2use = ftrs.numpy() 204 | for k in [1, -1][:1]: # test augmentation: flip X axis (did not help) 205 | ftr2use[:, :, 0] *= k 206 | tb = time.time() 207 | if 0: 208 | jumps = ftr2use[:,:,3] 209 | jumps = np.hstack((jumps, np.ones((jumps.shape[0], 1)))) # To promise that one jump is found 210 | ftr2use = ftr2use[:,:,:3] 211 | first_jumps = [np.where(j)[0][0] for j in jumps] 212 | last_jumps = [0,0,0] # [np.where(j)[0][-2] - np.where(j)[0][-3] for j in jumps] 213 | plt.hist(first_jumps) 214 | 
plt.hist(last_jumps) 215 | model = dataset.load_model_from_npz(model_fn) 216 | plt.title('#vertices / faces : ' + str(model['vertices'].shape[0]) + ' / ' + str(model['faces'].shape[0])) 217 | plt.show() 218 | predictions_ = dnn_model(ftr2use, classify=True, training=False).numpy() 219 | te = time.time() - tb 220 | dnn_inference_time.append(te / n_walks_per_model * 1000) 221 | if 0:#len(dnn_inference_time) == 10: 222 | print(dnn_inference_time) 223 | plt.hist(dnn_inference_time[1:]) 224 | plt.xlabel('[mSec]') 225 | plt.show() 226 | if predictions is None: 227 | predictions = predictions_ 228 | else: 229 | predictions = np.vstack((predictions, predictions_)) 230 | 231 | mean_pred = np.mean(predictions, axis=0) 232 | max_hit = np.argmax(mean_pred) 233 | # Gals changes 234 | #model_name = labels[int(gt)] 235 | # End 236 | 237 | if model_name not in pred_per_model_name.keys(): 238 | pred_per_model_name[model_name] = [gt, np.zeros_like(mean_pred)] 239 | pred_per_model_name[model_name][1] += mean_pred 240 | str2add = '; n.unique models: ' + str(len(pred_per_model_name.keys())) 241 | ''' 242 | print("\n\n") 243 | print(pred_per_model_name) 244 | print("\n\n") 245 | ''' 246 | 247 | if n_faces not in res_per_n_faces.keys(): 248 | res_per_n_faces[n_faces] = [0, 0] 249 | res_per_n_faces[n_faces][0] += 1 250 | 251 | if COMPONENT_ANALYSIS: 252 | model = dataset.load_model_from_npz(model_fn) 253 | comp_summary = dataset_prepare.component_analysis(model['faces'], model['vertices']) 254 | comp_area = [a['area'] for a in comp_summary] 255 | n_components = len(comp_summary) 256 | biggest_comp_area_ratio = np.sort(comp_area)[-1] / np.sum(comp_area) 257 | 258 | if max_hit != gt: 259 | false_str = ' , predicted: ' + labels[int(max_hit)] + ' ; ' + model_fn 260 | if COMPONENT_ANALYSIS: 261 | bad_pred.n_comp.append(n_components) 262 | bad_pred.biggest_comp_area_ratio.append(biggest_comp_area_ratio) 263 | else: 264 | res_per_n_faces[n_faces][1] += 1 265 | false_str = '' 266 | if 
COMPONENT_ANALYSIS: 267 | good_pred.n_comp.append(n_components) 268 | good_pred.biggest_comp_area_ratio.append(biggest_comp_area_ratio) 269 | if print_details: 270 | print(' ', max_hit == gt, labels[int(gt)], false_str, 'n_vertices: ')#, model['vertices'].shape[0]) 271 | if 0:#max_hit != gt: 272 | model = dataset.load_model_from_npz(model_fn) 273 | # Amir - only for model net 274 | #utils.visualize_model(model['vertices'], model['faces'], line_width=1, opacity=1) 275 | 276 | all_confusion[int(gt), max_hit] += 1 277 | n_pos_all += (max_hit == gt) 278 | ii += 1 279 | if print_details: 280 | print(i, '/', n_models_to_test, ') Total accuracy: ', round(n_pos_all / ii * 100, 1), 'n_pos_all:', n_pos_all, str2add) 281 | 282 | if print_details: 283 | print(utils.color.BLUE + 'Total time, all:', time.time() - tb_all, utils.color.END) 284 | 285 | n_models = 0 286 | n_sucesses = 0 287 | all_confusion_all_faces = np.zeros((n_classes, n_classes), dtype=np.int) 288 | for k, v in pred_per_model_name.items(): 289 | gt = v[0] 290 | pred = v[1] 291 | max_hit = np.argmax(pred) 292 | all_confusion_all_faces[gt, max_hit] += 1 293 | n_models += 1 294 | n_sucesses += max_hit == gt 295 | mean_accuracy_all_faces = n_sucesses / n_models 296 | if print_details: 297 | print('\n\n ---------------\nOn avarage, for all faces:') 298 | print(' Accuracy: ', np.round(mean_accuracy_all_faces * 100, 2), '% ; n models checkd: ', n_models) 299 | print('Results per number of faces:') 300 | print(' ', res_per_n_faces, '\n\n--------------\n\n') 301 | 302 | if 0: 303 | bins = [0, 700, 1500, 3000, 5000] 304 | accuracy_per_n_faces = [] 305 | for b in range(len(bins) - 1): 306 | ks = [k for k in res_per_n_faces.keys() if k >= bins[b] and k < bins[b + 1]] 307 | attempts = 0 308 | successes = 0 309 | for k in ks: 310 | attempts_, successes_ = res_per_n_faces[k] 311 | attempts += attempts_ 312 | successes += successes_ 313 | if attempts: 314 | accuracy_per_n_faces.append(successes / attempts) 315 | else: 316 | 
accuracy_per_n_faces.append(np.nan) 317 | x = (np.array(bins[1:]) + np.array(bins[:-1])) / 2 318 | plt.figure() 319 | plt.plot(x, accuracy_per_n_faces) 320 | plt.xlabel('Number of faces') 321 | plt.ylabel('Accuracy') 322 | plt.show() 323 | 324 | if PRINT_CONFUSION_MATRIX: 325 | b = 0; 326 | e = 40 327 | utils.plot_confusion_matrix(all_confusion[b:e, b:e], labels[b:e], normalize=1, show_txt=0) 328 | 329 | # Print list of accuracy per model 330 | for confusion in [all_confusion, all_confusion_all_faces]: 331 | if print_details: 332 | print('------') 333 | acc_per_class = [] 334 | for i, name in enumerate(labels): 335 | this_type = confusion[i] 336 | n_this_type = this_type.sum() 337 | accuracy_this_type = this_type[i] / n_this_type 338 | if n_this_type: 339 | acc_per_class.append(accuracy_this_type) 340 | this_type_ = this_type.copy() 341 | this_type_[i] = -1 342 | scnd_best = np.argmax(this_type_) 343 | scnd_best_name = labels[scnd_best] 344 | accuracy_2nd_best = this_type[scnd_best] / n_this_type 345 | if print_details: 346 | print(str(i).ljust(3), name.ljust(12), n_this_type, ',', str(round(accuracy_this_type * 100, 1)).ljust(5), ' ; 2nd best:', scnd_best_name.ljust(12), round(accuracy_2nd_best * 100, 1)) 347 | mean_acc_per_class = np.mean(acc_per_class) 348 | 349 | if 0: 350 | print('Time Log:') 351 | for k, v in timelog.items(): 352 | print(' ' , k, ':', np.mean(v)) 353 | 354 | return [mean_accuracy_all_faces, mean_acc_per_class], dnn_model 355 | 356 | def show_features_tsne(dataset_files_path, logdir, dnn_model=None, cls2show=None, n_iters=None, dataset_labels=None, 357 | model_fn='', max_size_per_class=5): 358 | with open(logdir + '/params.txt') as fp: 359 | params = EasyDict(json.load(fp)) 360 | params.network_task = 'classification' 361 | params.batch_size = 1 362 | params.one_label_per_model = True 363 | params.n_walks_per_model = 8 364 | params.logdir = logdir 365 | params.seq_len = 200 366 | params.new_run = 0 367 | if n_iters is None: 368 | n_iters = 1 
369 | 370 | if "modelnet" in params.logdir: 371 | model_name = "modelnet40" 372 | elif "shrec" in params.logdir: 373 | model_name = "shrec11" 374 | else: 375 | print("Unknown model ! exiting...") 376 | exit(0) 377 | 378 | # choose between "last" and "features" 379 | layer_to_show = "last" 380 | layer_to_show = "features" 381 | 382 | utils.print_labels_names_and_indices(model_name) 383 | 384 | params.classes_indices_to_use = cls2show 385 | 386 | pathname_expansion = dataset_files_path 387 | 388 | test_dataset, n_items = dataset.tf_mesh_dataset(params, pathname_expansion, mode=params.network_task, 389 | max_size_per_class=max_size_per_class) 390 | 391 | if params.net == 'RnnWalkNet': 392 | print('RnnWalkNet') 393 | # Gals changes 394 | dnn_model = rnn_model.RnnWalkNet(params, params.n_classes, params.net_input_dim, model_fn, 395 | model_must_be_load=True, dump_model_visualization=False) 396 | # BEFORE 397 | ''' 398 | dnn_model = dnn_cad_seq.RnnWalkNet(params, params.n_classes, params.net_input_dim, model_fn, 399 | model_must_be_load=True, dump_model_visualization=False) 400 | ''' 401 | # END 402 | elif params.net == 'Unsupervised_RnnWalkNet': 403 | dnn_model = rnn_model.Unsupervised_RnnWalkNet(params, params.n_classes, params.net_input_dim, model_fn, 404 | model_must_be_load=True, dump_model_visualization=False) 405 | 406 | test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') 407 | n_walks = 0 408 | model_fns = pred_all = lbl_all = None 409 | print('Calculating embeddings.') 410 | tb = time.time() 411 | for iter in range(n_iters): 412 | for name_, model_ftrs, labels in test_dataset: 413 | labels = labels.numpy() 414 | name = name_.numpy()[0].decode() 415 | print(' - Got data', name, labels) 416 | sp = model_ftrs.shape 417 | ftrs = tf.reshape(model_ftrs, (-1, sp[-2], sp[-1])) 418 | #print(' - Start Run Pred') 419 | if layer_to_show == "features": 420 | predictions_features = dnn_model(ftrs, training=False, classify=False).numpy() 421 | if 
0: 422 | predictions_features = np.mean(predictions_features, axis=0)[None, :] 423 | name = [name] 424 | else: 425 | labels = np.repeat(labels, predictions_features.shape[0]) 426 | name = [name] * predictions_features.shape[0] 427 | predictions_labels = dnn_model(ftrs, training=False, classify=True).numpy() 428 | elif layer_to_show == "last": 429 | predictions_labels = dnn_model(ftrs, training=False, classify=True).numpy() 430 | predictions_features = predictions_labels 431 | if 0: 432 | predictions_features = np.mean(predictions_features, axis=0)[None, :] 433 | name = [name] 434 | else: 435 | labels = np.repeat(labels, predictions_features.shape[0]) 436 | name = [name] * predictions_features.shape[0] 437 | else: 438 | print("layer_to_show is unvalid ! exiting..") 439 | exit(-1) 440 | #print(' - End Run Pred') 441 | pred_best = predictions_labels.argmax(axis=1) 442 | acc = np.mean(labels == pred_best) 443 | print('This batch accuracy:', round(100 * acc, 2)) 444 | if pred_all is None: 445 | pred_all = predictions_features 446 | lbl_all = labels 447 | model_fns = name 448 | else: 449 | pred_all = np.vstack((pred_all, predictions_features)) 450 | lbl_all = np.concatenate((lbl_all, labels)) 451 | model_fns += name 452 | #if pred_all.shape[0] > 1200: 453 | # break 454 | #break 455 | print('Feature calc time: ', round(time.time() - tb, 2)) 456 | shape_fn_2_id = {} 457 | shape_fns = np.array(model_fns) 458 | for cls in np.unique(lbl_all): 459 | this_cls_idxs = np.where(lbl_all == cls)[0] 460 | shape_fn_this_cls = shape_fns[this_cls_idxs] 461 | shape_fn_2_id[cls] = {n: i for i, n in enumerate(list(set(shape_fn_this_cls)))} 462 | if 0: 463 | pred_all = pred_all[:1200, :20] 464 | lbl_all = lbl_all[:1200] 465 | print('Embedding shape:', pred_all.shape) 466 | print('t-SNE calculation') 467 | transformer = TSNE(n_components=2) 468 | 469 | ftrs_tsne = transformer.fit_transform(pred_all) 470 | print(' t-SNE calc finished') 471 | shps = '.1234+X|_' 472 | shps = '.<*>^vspPDd' 473 | 
def check_rotation_weak_points():
    """Plot test accuracy against the test-time rotation angle, per axis.

    For each rotation axis (0, 1, 2) the test set is evaluated at every
    10 degrees in [0, 360). Each angle is evaluated 5 times; the mean and
    standard deviation of the first accuracy value returned by
    ``calc_accuracy_test`` are drawn as one error-bar curve per axis.

    The run directory and the dataset location are hard-coded below.

    Raises:
        FileNotFoundError: if no ``learned_model2keep__*.keras`` checkpoint
            exists in the run directory.
    """
    # Other runs that were used with this experiment:
    #   .../runs_aug_45/0002-06.08.2020..21.48__shrec11_10-10_B/
    #   .../runs_aug_360/0001-06.08.2020..17.40__shrec11_10-10_A/
    logdir = '/home/alonla/mesh_walker/runs_aug_360_must/0001-23.08.2020..18.03__shrec11_10-10_A/'
    print(logdir)
    dataset_folder = '/home/alonla/mesh_walker/datasets_processed/shrec11/10-10_A/test/*.*'

    # glob returns files in arbitrary order, so sort before taking the last
    # one; otherwise "[-1]" does not reliably pick the same checkpoint.
    checkpoints = sorted(glob.glob(logdir + "learned_model2keep__*.keras"))
    if not checkpoints:
        raise FileNotFoundError('No learned_model2keep__*.keras file found in ' + logdir)
    model_fn = checkpoints[-1]

    # Starts as None; the model returned by each evaluation is passed back
    # in on the next call.
    dnn_model = None
    rot_angles = range(0, 360, 10)
    for axis in [0, 1, 2]:
        accs = []
        stds = []
        dataset.data_augmentation_rotation.test_rotation_axis = axis
        for rot in rot_angles:
            accs_this_rot = []
            for _ in range(5):
                acc, dnn_model = calc_accuracy_test(logdir=logdir, dataset_folder=dataset_folder, n_walks_per_model=16,
                                                    dnn_model=dnn_model, labels=dataset_prepare.shrec11_labels,
                                                    model_fn=model_fn, verbose_level=0,
                                                    data_augmentation={'rotation': rot})
                accs_this_rot.append(acc[0])
            accs.append(np.mean(accs_this_rot))
            stds.append(np.std(accs_this_rot))
            print(rot, accs, stds)
        plt.errorbar(rot_angles, accs, yerr=stds)
        plt.xlabel('Rotation [degrees]')
        plt.ylabel('Accuracy')
        plt.title('Accuracy VS rotation angles, axis = ' + str(axis))
    plt.legend(['axis=0', 'axis=1', 'axis=2'])
    plt.suptitle('/'.join(logdir.split('/')[-3:-1]))
    plt.show()
| print(glob.glob(logdir + "learned_model2keep__*.keras")) 578 | model_fn = glob.glob(logdir + "learned_model2keep__*.keras")[-1] 579 | 580 | if 0: 581 | check_rotation_weak_points() 582 | exit(0) 583 | 584 | #test_dataset() 585 | iter2use = 'last' 586 | classes_indices_to_use = None 587 | model_fn = None 588 | 589 | 590 | if t_SNE: # t-SNE 591 | if shrec11: # Use shrec model 592 | dataset_labels = dataset_prepare.shrec11_labels 593 | 594 | cls2show = range(30)[0:6] 595 | #cls2show = range(30)[-5:-1] 596 | elif modelnet40: # Use ModelNet 597 | dataset_labels = dataset_prepare.model_net_labels 598 | 599 | cls2show = range(30)[-9:-1] 600 | else: 601 | print("No module specified !! exiting..") 602 | exit(0) 603 | 604 | show_features_tsne(dataset_files_path=dataset_files_path, logdir=logdir, n_iters=1, cls2show=cls2show, dataset_labels=dataset_labels, 605 | model_fn=model_fn, max_size_per_class=3) 606 | elif modelnet40 : # ModelNet 607 | dataset_folder = dataset_files_path 608 | 609 | min_max_faces2use = [000, 4000] 610 | if 1: 611 | print(utils.color.BOLD + utils.color.BLUE + '6 classes used for fast run', utils.color.END) 612 | classes_indices_to_use = [39, 38, 37, 31, 28, 19] 613 | #classes_indices_to_use = range(30) 614 | accs, _ = calc_accuracy_test(logdir=logdir, dataset_folder=dataset_folder, 615 | labels=dataset_prepare.model_net_labels, iter2use=iter2use, 616 | classes_indices_to_use=classes_indices_to_use, 617 | min_max_faces2use=min_max_faces2use, model_fn=model_fn, n_walks_per_model=16 * 4) 618 | print('Overall Accuracy / Mean Accuracy:', np.round(np.array(accs) * 100, 2)) 619 | elif shrec11: # SHREC11 620 | # Gals changes 621 | dataset_path_a = '16-04_a/test/*.npz' 622 | dataset_path_b = '16-4_B/test/*.npz' 623 | dataset_path_c = '16-4_C/test/*.npz' 624 | # END 625 | 626 | acc_all = [] 627 | for curr_dataset_path in [dataset_path_a, dataset_path_b, dataset_path_c]: 628 | print("\n\nNew Iteration !") 629 | dataset_path = data_path + curr_dataset_path 630 | 
print("dataset_path=", dataset_path, "\n\n") 631 | 632 | if config is not None: 633 | if config['trained_only_2_classes'] == True: 634 | # params.classes_indices_to_use = (params.classes_indices_to_use)[0:2] 635 | first_label = min(config['source_label'], config['target_label']) 636 | sec_label = max(config['source_label'], config['target_label']) 637 | cls2show = [first_label, sec_label] 638 | else: 639 | cls2show = None 640 | cls2show = None #[15, 25] 641 | acc, _ = calc_accuracy_test(logdir=logdir, 642 | dataset_folder=dataset_path, classes_indices_to_use=cls2show, labels=dataset_prepare.shrec11_labels, iter2use=iter2use, 643 | model_fn=model_fn, n_walks_per_model=8) 644 | acc_all.append(acc) 645 | continue 646 | print(acc_all) 647 | print(np.mean(acc_all)) 648 | elif 1: # Look for Rotation weekpoints 649 | if 0: 650 | logdir = '/home/alonla/mesh_walker/runs_aug_360/0004-07.08.2020..06.05__shrec11_16-04_A' 651 | dataset_folder = '/home/alonla/mesh_walker/datasets_processed/shrec11/16-04_a/test/*.*' 652 | else: 653 | if 0: 654 | logdir = '/home/alonla/mesh_walker/runs_aug_45/0001-06.08.2020..17.40__shrec11_10-10_A/' 655 | else: 656 | logdir = '/home/alonla/mesh_walker/runs_aug_360/0001-06.08.2020..17.40__shrec11_10-10_A/' 657 | dataset_folder = '/home/alonla/mesh_walker/datasets_processed/shrec11/10-10_A/test/*.*' 658 | model_fn = None # logdir + 'learned_model2keep__00200010.keras' 659 | tb = time.time() 660 | dnn_model = None 661 | accs = [] 662 | rot_angles = range(0, 180, 10) 663 | for rot in rot_angles: 664 | acc, dnn_model = calc_accuracy_test(logdir=logdir, dataset_folder=dataset_folder, n_walks_per_model=8, 665 | dnn_model=dnn_model, labels=dataset_prepare.shrec11_labels, iter2use=str(iter2use), 666 | model_fn=model_fn, verbose_level=0, data_augmentation={'rotation': rot}) 667 | accs.append(acc[0]) 668 | print(rot, accs) 669 | plt.plot(rot_angles, accs) 670 | plt.xlabel('Rotation [degrees]') 671 | plt.ylabel('Accuracy') 672 | plt.title('Accuracy VS 
rotation angles') 673 | plt.show() 674 | elif 1: # Check STD vs Number of walks 675 | if 1: 676 | logdir = '/home/alonla/mesh_walker/runs_aug_360/0004-07.08.2020..06.05__shrec11_16-04_A' 677 | dataset_folder = '/home/alonla/mesh_walker/datasets_processed/shrec11/16-04_a/test/*.*' 678 | else: 679 | logdir = '/home/alonla/mesh_walker/runs_aug_360/0001-06.08.2020..17.40__shrec11_10-10_A' 680 | dataset_folder = '/home/alonla/mesh_walker/datasets_processed/shrec11/10-10_A/test/*.*' 681 | model_fn = None # logdir + 'learned_model2keep__00200010.keras' 682 | tb = time.time() 683 | dnn_model = None 684 | for n_walks in [1, 2, 4, 8, 16, 32]: 685 | accs = [] 686 | for _ in range(6): 687 | acc, dnn_model = calc_accuracy_test(logdir=logdir, dataset_folder=dataset_folder, n_walks_per_model=n_walks, 688 | dnn_model=dnn_model, labels=dataset_prepare.shrec11_labels, iter2use=str(iter2use), 689 | model_fn=model_fn, verbose_level=0) 690 | accs.append(acc[0]) 691 | #print('Run Time: ', time.time() - tb, ' ; Accuracy:', acc) 692 | print(n_walks, accs, 'STD:', np.std(accs)) 693 | elif 1: 694 | r = 'cubes2keep/0016-03.04.2020..08.59__Cubes_NewPrms' 695 | logdir = os.path.expanduser('~') + '/mesh_walker/mesh_learning/' + r + '/' 696 | model_fn = logdir + 'learned_model2keep__00160080.keras' 697 | # Gals changes 698 | #model_fn = logdir + 'learned_model2keep__00160080.keras' 699 | model_fn = glob.glob(logdir + "learned_model2keep__*.keras")[-1] 700 | # End 701 | n_walks_to_check = [1, 2, 4, 8, 16, 32] 702 | acc_all = [] 703 | for n_walks_per_model in n_walks_to_check: 704 | acc = calc_accuracy_test(logdir=logdir, model_fn=model_fn, target_n_faces=[1000], 705 | from_cach_dataset='cubes/test*.npz', labels=dataset_prepare.cubes_labels, n_walks_per_model=n_walks_per_model) 706 | acc_all.append(acc[0][0]) 707 | print('--------------------------------') 708 | print(acc_all) 709 | #[0.7708649468892261, 0.8482549317147192, 0.921092564491654, 0.952959028831563, 0.9742033383915023, 
0.9787556904400607] 710 | #calc_accuracy_per_seq_len() 711 | #calc_accuracy_per_n_faces() 712 | #features_analysis() 713 | -------------------------------------------------------------------------------- /dataset_prepare.py: -------------------------------------------------------------------------------- 1 | import glob, os, shutil, sys, json 2 | from pathlib import Path 3 | 4 | import pylab as plt 5 | import trimesh 6 | import open3d 7 | from easydict import EasyDict 8 | import numpy as np 9 | from tqdm import tqdm 10 | import re 11 | 12 | import utils 13 | 14 | 15 | FIX_BAD_ANNOTATION_HUMAN_15 = 0 16 | 17 | # Labels for all datasets_processed 18 | # ----------------------- 19 | sigg17_part_labels = ['---', 'head', 'hand', 'lower-arm', 'upper-arm', 'body', 'upper-lag', 'lower-leg', 'foot'] 20 | sigg17_shape2label = {v: k for k, v in enumerate(sigg17_part_labels)} 21 | 22 | model_net_labels = [ 23 | 'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 'night_stand', 'sofa', 'table', 'toilet', 24 | 'wardrobe', 'bookshelf', 'laptop', 'door', 'lamp', 'person', 'curtain', 'piano', 'airplane', 'cup', 25 | 'cone', 'tent', 'radio', 'stool', 'range_hood', 'car', 'sink', 'guitar', 'tv_stand', 'stairs', 26 | 'mantel', 'bench', 'plant', 'bottle', 'bowl', 'flower_pot', 'keyboard', 'vase', 'xbox', 'glass_box' 27 | ] 28 | model_net_shape2label = {v: k for k, v in enumerate(model_net_labels)} 29 | 30 | cubes_labels = [ 31 | 'apple', 'bat', 'bell', 'brick', 'camel', 32 | 'car', 'carriage', 'chopper', 'elephant', 'fork', 33 | 'guitar', 'hammer', 'heart', 'horseshoe', 'key', 34 | 'lmfish', 'octopus', 'shoe', 'spoon', 'tree', 35 | 'turtle', 'watch' 36 | ] 37 | cubes_shape2label = {v: k for k, v in enumerate(cubes_labels)} 38 | 39 | shrec11_labels = [ 40 | 'armadillo', 'man', 'centaur', 'dinosaur', 'dog2', 41 | 'ants', 'rabbit', 'dog1', 'snake', 'bird2', 42 | 'shark', 'dino_ske', 'laptop', 'santa', 'flamingo', 43 | 'horse', 'hand', 'lamp', 'two_balls', 'gorilla', 44 | 'alien', 
def calc_vertex_labels_from_face_labels(mesh, face_labels):
    """Derive per-vertex labels from per-face labels.

    Each vertex collects the labels of all faces it belongs to.

    Args:
        mesh: mapping with ``'vertices'`` (array whose first dimension is the
            number of vertices) and ``'faces'`` (array of vertex indices, one
            row per face).
        face_labels: one label per face. Labels are 1-based (the minimum must
            be 1); integer-valued floats are accepted.

    Returns:
        ``(vertex_labels, v_labels_fuzzy)``:
        * ``vertex_labels`` - int array, the most frequent label among the
          faces touching each vertex (ties resolve to the smallest label).
        * ``v_labels_fuzzy`` - float array ``(n_vertices, n_classes)``; column
          ``k`` holds the fraction of touching faces labelled ``k + 1``.
        A vertex that is not used by any face keeps the sentinel ``-1`` in
        both outputs.
    """
    vertices = mesh['vertices']
    faces = mesh['faces']
    n_vertices = vertices.shape[0]
    n_classes = int(np.max(face_labels))
    assert np.min(face_labels) == 1  # min label is 1, for compatibility to human_seg labels representation

    # Gather, per vertex, the labels of every face that references it.
    labels_per_vertex = [[] for _ in range(n_vertices)]
    for i in range(faces.shape[0]):
        # Cast to int: np.bincount rejects float input.
        label = int(face_labels[i])
        for v in faces[i]:
            labels_per_vertex[v].append(label)

    # Builtin int instead of np.int, which was removed from NumPy.
    vertex_labels = -np.ones((n_vertices,), dtype=int)
    v_labels_fuzzy = -np.ones((n_vertices, n_classes))
    for i, labels_here in enumerate(labels_per_vertex):
        if not labels_here:
            # Unreferenced vertex: leave the -1 sentinel instead of crashing
            # on argmax of an empty histogram.
            continue
        vertex_labels[i] = np.argmax(np.bincount(labels_here))
        v_labels_fuzzy[i] = 0
        weight = 1 / len(labels_here)
        for label in labels_here:
            v_labels_fuzzy[i, label - 1] += weight
    return vertex_labels, v_labels_fuzzy
def add_fields_and_dump_model(mesh_data, fileds_needed, out_fn, dataset_name, dump_model=True):
    """Build the dictionary stored for one model and optionally save it.

    Only the entries of ``mesh_data`` whose key appears in ``fileds_needed``
    are kept. Requested fields that are still missing are then filled in:
    ``'labels'`` and ``'walk_cache'`` become empty arrays, ``'dataset_name'``
    becomes ``dataset_name``, and ``'edges'`` / ``'kdtree_query'`` are
    produced by ``prepare_edges_and_kdtree``. Any other missing field is left
    out.

    Args:
        mesh_data: mapping with the available fields of the model.
        fileds_needed: names of the fields the output should contain.
        out_fn: file name handed to ``np.savez``.
        dataset_name: value used for a missing ``'dataset_name'`` field.
        dump_model: when true, write the dictionary with ``np.savez``.

    Returns:
        The assembled dictionary.
    """
    model = {key: value for key, value in mesh_data.items() if key in fileds_needed}

    for name in fileds_needed:
        if name in model:
            continue
        if name in ('labels', 'walk_cache'):
            model[name] = np.zeros((0,))
        elif name == 'dataset_name':
            model[name] = dataset_name
        elif name in ('kdtree_query', 'edges'):
            # One call adds both fields, so the other one is skipped above
            # when its turn comes.
            prepare_edges_and_kdtree(model)

    if dump_model:
        np.savez(out_fn, **model)

    return model
def get_labels(dataset_name, mesh, file, fn2labels_map=None):
    """Return ``(model_label, vertex_labels, v_labels_fuzzy)`` for one mesh file.

    Which labels are filled depends on ``dataset_name``:

    * ``'faust'``: face labels are read from
      ``faust_labels/faust_part_segmentation.npy`` and converted to vertex
      labels; ``model_label`` is an empty array.
    * names starting with ``'coseg'`` or ``'human_seg_from_meshcnn'``:
      per-edge labels are read from ``<root>/seg/<name>.eseg`` and per-edge
      soft labels from ``<root>/sseg/<name>.seseg`` (``<root>`` is two levels
      above ``file``); both are accumulated onto the edge end points.
      ``model_label`` is an empty array.
    * names starting with ``modelnet`` / ``cubes`` / ``shrec11`` (case
      insensitive): only ``model_label`` is set; the two other outputs are
      empty arrays. For shrec11, ``fn2labels_map`` (if given) is indexed by
      the number following the last ``'T'`` in the file name.

    Args:
        dataset_name: name of the dataset the file belongs to.
        mesh: mapping with ``'vertices'`` and ``'faces'`` arrays (not used by
            the classification datasets).
        file: '/'-separated path of the mesh file.
        fn2labels_map: optional sequence mapping a shrec11 file index to its
            label.

    Raises:
        Exception: if ``dataset_name`` matches none of the known datasets.
    """
    v_labels_fuzzy = np.zeros((0,))

    if dataset_name == 'faust':
        # Builtin int instead of np.int, which was removed from NumPy.
        face_labels = np.load('faust_labels/faust_part_segmentation.npy').astype(int)
        vertex_labels, v_labels_fuzzy = calc_vertex_labels_from_face_labels(mesh, face_labels)
        model_label = np.zeros((0,))
        return model_label, vertex_labels, v_labels_fuzzy

    if dataset_name.startswith('coseg') or dataset_name == 'human_seg_from_meshcnn':
        root = '/'.join(file.split('/')[:-2])
        stem = file.split('/')[-1].split('.')[-2]
        e_labels = np.loadtxt(root + '/seg/' + stem + '.eseg')
        seseg_labels = np.loadtxt(root + '/sseg/' + stem + '.seseg')

        n_vertices = mesh['vertices'].shape[0]
        v_labels = [[] for _ in range(n_vertices)]
        v_labels_fuzzy = np.zeros((n_vertices, seseg_labels.shape[1]))

        # Edges are numbered in order of first appearance while walking the
        # faces; that running number is the row index into both label files.
        seen_edges = set()
        edges_count = 0
        for face in mesh['faces']:
            for i in range(3):
                edge = tuple(sorted((face[i], face[(i + 1) % 3])))
                if edge in seen_edges:
                    continue
                seen_edges.add(edge)
                for v in edge:
                    v_labels_fuzzy[v] += seseg_labels[edges_count]
                    v_labels[v].append(e_labels[edges_count])
                edges_count += 1

        assert np.max(np.sum(v_labels_fuzzy != 0, axis=1)) <= 3, 'Number of non-zero labels must not acceeds 3!'

        # np.loadtxt returns floats by default while np.bincount only takes
        # integers, so cast before voting for the most frequent label.
        vertex_labels = np.array([np.argmax(np.bincount(np.asarray(l, dtype=int))) for l in v_labels])
        model_label = np.zeros((0,))
        return model_label, vertex_labels, v_labels_fuzzy

    # Classification datasets: one label per model.
    tmp = file.split('/')[-1]
    model_name = '_'.join(tmp.split('_')[:-1])
    lowered = dataset_name.lower()
    if lowered.startswith('modelnet'):
        model_label = model_net_shape2label[model_name]
    elif lowered.startswith('cubes'):
        model_label = cubes_shape2label[model_name]
    elif lowered.startswith('shrec11'):
        model_name = file.split('/')[-3]
        if fn2labels_map is None:
            model_label = shrec11_shape2label[model_name]
        else:
            file_index = int(file.split('.')[-2].split('T')[-1])
            model_label = fn2labels_map[file_index]
    else:
        raise Exception('Cannot find labels for the dataset')
    vertex_labels = np.zeros((0,))
    return model_label, vertex_labels, v_labels_fuzzy
def load_mesh(model_fn, classification=True):
    """Load a mesh file with trimesh and return it as an open3d TriangleMesh.

    Args:
        model_fn: file name passed to ``trimesh.load_mesh``.
        classification: when true the file is loaded with ``process=True``
            and duplicate faces are removed; otherwise it is loaded with
            ``process=False`` and left as is.

    Returns:
        An ``open3d.geometry.TriangleMesh`` holding the loaded vertices and
        faces.
    """
    if not classification:
        loaded = trimesh.load_mesh(model_fn, process=False)
    else:
        # Load and clean up the mesh; per the original note this is meant to
        # "remove vertices that share position".
        loaded = trimesh.load_mesh(model_fn, process=True)
        loaded.remove_duplicate_faces()

    result = open3d.geometry.TriangleMesh()
    result.vertices = open3d.utility.Vector3dVector(loaded.vertices)
    result.triangles = open3d.utility.Vector3iVector(loaded.faces)
    return result
def prepare_directory_from_scratch(dataset_name, pathname_expansion=None, p_out=None, n_target_faces=None, add_labels=True,
                                   size_limit=np.inf, fn_prefix='', verbose=True, classification=True, adversrial_data = None):
    """Convert every mesh file matching a glob pattern into processed model files.

    Each file is loaded, optionally labelled, remeshed once per entry of
    ``n_target_faces`` and written through ``add_fields_and_dump_model``.

    Args:
        dataset_name: dataset name, forwarded to ``get_labels`` and stored
            with each model.
        pathname_expansion: glob pattern of the input mesh files.
        p_out: output directory; created if it does not exist.
        n_target_faces: iterable of target face counts; one output file is
            written per entry.
        add_labels: if falsy no labels are read and empty label arrays are
            stored. If it is a list it is passed to ``get_labels`` as
            ``fn2labels_map``; any other truthy value just enables labels.
        size_limit: maximal number of input files to process (after sorting).
        fn_prefix: prefix prepended to every output file name.
        verbose: show a tqdm progress bar.
        classification: forwarded to ``load_mesh``.
        adversrial_data: ``'mesh_CNN'`` / ``'PD_MeshNet'`` run
            ``change_faces_and_vertices`` on the data, ``'walker_copycat'``
            runs ``change_to_copycat_walker``; other values do nothing extra.
    """
    fileds_needed = ['vertices', 'faces', 'edges', 'kdtree_query',
                     'label', 'labels', 'dataset_name', 'labels_fuzzy']

    os.makedirs(p_out, exist_ok=True)

    filenames = sorted(glob.glob(pathname_expansion))
    if len(filenames) > size_limit:
        # size_limit may be a float (its default is np.inf); slices need an int.
        filenames = filenames[:int(size_limit)]

    for file in tqdm(filenames, disable=not verbose):
        out_fn = p_out + '/' + fn_prefix + os.path.split(file)[1].split('.')[0]
        mesh_orig = load_mesh(file, classification=classification)
        mesh_data = EasyDict({'vertices': np.asarray(mesh_orig.vertices), 'faces': np.asarray(mesh_orig.triangles)})
        if add_labels:
            fn2labels_map = add_labels if isinstance(add_labels, list) else None
            label, labels_orig, v_labels_fuzzy = get_labels(dataset_name, mesh_data, file, fn2labels_map=fn2labels_map)
        else:
            # Without labels, still define the names used below; previously
            # this path failed with an unbound-variable error.
            label = np.zeros((0, ))
            labels_orig = np.zeros((0, ))
            v_labels_fuzzy = np.zeros((0, ))

        for this_target_n_faces in n_target_faces:
            mesh, labels, str_to_add = remesh(mesh_orig, this_target_n_faces, add_labels=add_labels, labels_orig=labels_orig)
            mesh_data = EasyDict({'vertices': np.asarray(mesh.vertices), 'faces': np.asarray(mesh.triangles),
                                  'label': label, 'labels': labels})
            mesh_data['labels_fuzzy'] = v_labels_fuzzy
            out_fc_full = out_fn + str_to_add
            if adversrial_data == 'mesh_CNN' or adversrial_data == 'PD_MeshNet':
                change_faces_and_vertices(mesh_data, str(file))
            if adversrial_data == 'walker_copycat':
                change_to_copycat_walker(mesh_data, str(file))
            add_fields_and_dump_model(mesh_data, fileds_needed, out_fc_full, dataset_name)
399 | pin = os.path.expanduser('~') + '/mesh_walker/datasets_raw/ModelNet40/' + name + '/' + part + '/' 400 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/modelnet40/' 401 | prepare_directory_from_scratch('modelnet40', pathname_expansion=pin + '*.off', 402 | p_out=p_out, add_labels='modelnet', n_target_faces=n_target_faces, 403 | fn_prefix=part + '_', verbose=False) 404 | 405 | 406 | def prepare_cubes(labels2use=cubes_labels, 407 | path_in=os.path.expanduser('~') + '/datasets_processed/cubes/', 408 | p_out=os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/cubes_tmp'): 409 | dataset_name = 'cubes' 410 | if not os.path.isdir(p_out): 411 | os.makedirs(p_out) 412 | 413 | for i, name in enumerate(labels2use): 414 | print('-->>>', name) 415 | for part in ['test', 'train']: 416 | pin = path_in + name + '/' + part + '/' 417 | prepare_directory_from_scratch(dataset_name, pathname_expansion=pin + '*.obj', 418 | p_out=p_out, add_labels=dataset_name, fn_prefix=part + '_', n_target_faces=[np.inf], 419 | classification=False) 420 | 421 | def prepare_shrec11_from_raw(): 422 | # Prepare labels per model name 423 | current_label = None 424 | model_number2label = [-1 for _ in range(600)] 425 | for line in open(os.path.expanduser('~') + '/Desktop/Shrec11/test.cla'): 426 | sp_line = line.split(' ') 427 | if len(sp_line) == 3: 428 | name = sp_line[0].replace('_test', '') 429 | if name in shrec11_labels: 430 | current_label = name 431 | else: 432 | raise Exception('?') 433 | if len(sp_line) == 1 and sp_line[0] != '\n': 434 | model_number2label[int(sp_line[0])] = shrec11_shape2label[current_label] 435 | 436 | 437 | # Prepare npz files 438 | p_in = os.path.expanduser('~') + '/Desktop/Shrec11/raw/' 439 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/shrec11_raw_500_meshCNN/' 440 | prepare_directory_from_scratch('shrec11', pathname_expansion=p_in + '*.off', 441 | p_out=p_out, add_labels=model_number2label, n_target_faces=[500]) 
442 | 443 | def prepare_shrec11_meshCNN_from_raw(): 444 | # Prepare labels per model name 445 | current_label = None 446 | model_number2label = [-1 for _ in range(600)] 447 | for line in open(os.path.expanduser('~') + '/Desktop/Shrec11/test.cla'): 448 | sp_line = line.split(' ') 449 | if len(sp_line) == 3: 450 | name = sp_line[0].replace('_test', '') 451 | if name in shrec11_labels: 452 | current_label = name 453 | else: 454 | raise Exception('?') 455 | if len(sp_line) == 1 and sp_line[0] != '\n': 456 | model_number2label[int(sp_line[0])] = shrec11_shape2label[current_label] 457 | 458 | 459 | # Prepare npz files 460 | p_in = os.path.expanduser('~') + '/Desktop/Shrec11/raw/' 461 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/shrec11_raw_500_meshCNN/' 462 | prepare_directory_from_scratch('shrec11', pathname_expansion=p_in + '*.off', 463 | p_out=p_out, add_labels=model_number2label, n_target_faces=[500], adversrial_data='mesh_CNN') 464 | 465 | def prepare_shrec11_PD_MeshNet_from_raw(): 466 | # Prepare labels per model name 467 | current_label = None 468 | model_number2label = [-1 for _ in range(600)] 469 | for line in open(os.path.expanduser('~') + '/Desktop/Shrec11/test.cla'): 470 | sp_line = line.split(' ') 471 | if len(sp_line) == 3: 472 | name = sp_line[0].replace('_test', '') 473 | if name in shrec11_labels: 474 | current_label = name 475 | else: 476 | raise Exception('?') 477 | if len(sp_line) == 1 and sp_line[0] != '\n': 478 | model_number2label[int(sp_line[0])] = shrec11_shape2label[current_label] 479 | 480 | 481 | # Prepare npz files 482 | p_in = os.path.expanduser('~') + '/Desktop/Shrec11/raw/' 483 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/shrec11_raw_500_meshCNN/' 484 | prepare_directory_from_scratch('shrec11', pathname_expansion=p_in + '*.off', 485 | p_out=p_out, add_labels=model_number2label, n_target_faces=[500], adversrial_data='PD_MeshNet') 486 | 487 | 488 | 489 | def 
prepare_walker_shrec11_copycat_from_raw(): 490 | # Prepare labels per model name 491 | current_label = None 492 | model_number2label = [-1 for _ in range(600)] 493 | for line in open(os.path.expanduser('~') + '/Desktop/Shrec11/test.cla'): 494 | sp_line = line.split(' ') 495 | if len(sp_line) == 3: 496 | name = sp_line[0].replace('_test', '') 497 | if name in shrec11_labels: 498 | current_label = name 499 | else: 500 | raise Exception('?') 501 | if len(sp_line) == 1 and sp_line[0] != '\n': 502 | model_number2label[int(sp_line[0])] = shrec11_shape2label[current_label] 503 | 504 | 505 | # Prepare npz files 506 | p_in = os.path.expanduser('~') + '/Desktop/Shrec11/raw/' 507 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/shrec11_raw_500_copycat/' 508 | prepare_directory_from_scratch('shrec11', pathname_expansion=p_in + '*.off', 509 | p_out=p_out, add_labels=model_number2label, n_target_faces=[500], meshCNN_data='walker_copycat') 510 | 511 | # Prepare split train / test 512 | change_train_test_split(p_out, 16, 4, '16-04_a') 513 | 514 | def prepare_modelnet40_mesh_net(): 515 | n_target_faces = [1024] 516 | labels2use = mesh_net_labels 517 | for i, name in tqdm(enumerate(labels2use)): 518 | for part in ['test', 'train']: 519 | p_in = os.path.expanduser('~') + '/meshNet_adverserial/ModelNet40/' + name + '/' + part + '/' 520 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/modelnet40_mesh_net/' #+ name + '/' + part + '/' 521 | mesh_net_vertices_and_faces_path = os.path.expanduser('~') + '/meshNet_adverserial/ModelNet40_MeshNet_raw/' + name + '/' + part + '/' 522 | mesh_net_labels_path = os.path.expanduser('~') + '/meshNet_adverserial/mesh_net_predicted_labels/' + name + '/' + part + '/' 523 | 524 | prepare_mesh_net_directory_from_scratch('mesh_net_modelnet40', pathname_expansion=p_in + '*.off', 525 | p_out=p_out, add_labels='modelnet', n_target_faces=n_target_faces, 526 | fn_prefix=part + '_', verbose=False, 
mesh_net_vertices_and_faces_path= mesh_net_vertices_and_faces_path, 527 | mesh_net_labels_path= mesh_net_labels_path) 528 | 529 | 530 | 531 | def calc_face_labels_after_remesh(mesh_orig, mesh, face_labels): 532 | t_mesh = trimesh.Trimesh(vertices=np.array(mesh_orig.vertices), faces=np.array(mesh_orig.triangles), process=False) 533 | 534 | remeshed_face_labels = [] 535 | for face in mesh.triangles: 536 | vertices = np.array(mesh.vertices)[face] 537 | center = np.mean(vertices, axis=0) 538 | p, d, closest_face = trimesh.proximity.closest_point(t_mesh, [center]) 539 | remeshed_face_labels.append(face_labels[closest_face[0]]) 540 | return remeshed_face_labels 541 | 542 | 543 | def prepare_human_body_segmentation(): 544 | dataset_name = 'sig17_seg_benchmark' 545 | labels_fuzzy = True 546 | human_seg_path = os.path.expanduser('~') + '/mesh_walker/datasets_raw/sig17_seg_benchmark/' 547 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/sig17_seg_benchmark-no_simplification/' 548 | 549 | fileds_needed = ['vertices', 'faces', 'edge_features', 'edges_map', 'edges', 'kdtree_query', 550 | 'label', 'labels', 'dataset_name', 'face_labels'] 551 | if labels_fuzzy: 552 | fileds_needed += ['labels_fuzzy'] 553 | 554 | n_target_faces = [np.inf] 555 | if not os.path.isdir(p_out): 556 | os.makedirs(p_out) 557 | for part in ['test', 'train']: 558 | print('part: ', part) 559 | path_meshes = human_seg_path + '/meshes/' + part 560 | seg_path = human_seg_path + '/segs/' + part 561 | all_fns = [] 562 | for fn in Path(path_meshes).rglob('*.*'): 563 | all_fns.append(fn) 564 | for fn in tqdm(all_fns): 565 | model_name = str(fn) 566 | if model_name.endswith('.obj') or model_name.endswith('.off') or model_name.endswith('.ply'): 567 | new_fn = model_name[model_name.find(part) + len(part) + 1:] 568 | new_fn = new_fn.replace('/', '_') 569 | new_fn = new_fn.split('.')[-2] 570 | out_fn = p_out + '/' + part + '__' + new_fn 571 | mesh = mesh_orig = load_mesh(model_name, 
classification=False) 572 | mesh_data = EasyDict({'vertices': np.asarray(mesh.vertices), 'faces': np.asarray(mesh.triangles)}) 573 | face_labels = get_sig17_seg_bm_labels(mesh_data, model_name, seg_path) 574 | labels_orig, v_labels_fuzzy = calc_vertex_labels_from_face_labels(mesh_data, face_labels) 575 | if 0: # Show segment borders 576 | b_vertices = np.where(np.sum(v_labels_fuzzy != 0, axis=1) > 1)[0] 577 | vertex_colors = np.zeros((mesh_data['vertices'].shape[0],), dtype=np.int) 578 | vertex_colors[b_vertices] = 1 579 | utils.visualize_model(mesh_data['vertices'], mesh_data['faces'], vertex_colors_idx=vertex_colors, point_size=2) 580 | if 0: # Show face labels 581 | utils.visualize_model(mesh_data['vertices'], mesh_data['faces'], face_colors=face_labels, show_vertices=False, show_edges=False) 582 | if 0: 583 | print(model_name) 584 | print('min: ', np.min(mesh_data['vertices'], axis=0)) 585 | print('max: ', np.max(mesh_data['vertices'], axis=0)) 586 | cpos = [(-3.5, -0.12, 6.0), (0., 0., 0.1), (0., 1., 0.)] 587 | utils.visualize_model(mesh_data['vertices'], mesh_data['faces'], vertex_colors_idx=labels_orig, cpos=cpos) 588 | add_labels = 1 589 | label = -1 590 | for this_target_n_faces in n_target_faces: 591 | mesh, labels, str_to_add = remesh(mesh_orig, this_target_n_faces, add_labels=add_labels, labels_orig=labels_orig) 592 | if mesh == mesh_orig: 593 | remeshed_face_labels = face_labels 594 | else: 595 | remeshed_face_labels = calc_face_labels_after_remesh(mesh_orig, mesh, face_labels) 596 | mesh_data = EasyDict({'vertices': np.asarray(mesh.vertices), 597 | 'faces': np.asarray(mesh.triangles), 598 | 'label': label, 'labels': labels, 599 | 'face_labels': remeshed_face_labels}) 600 | if 1: 601 | v_labels, v_labels_fuzzy = calc_vertex_labels_from_face_labels(mesh_data, remeshed_face_labels) 602 | mesh_data['labels'] = v_labels 603 | mesh_data['labels_fuzzy'] = v_labels_fuzzy 604 | if 0: # Show segment borders 605 | b_vertices = np.where(np.sum(v_labels_fuzzy != 
0, axis=1) > 1)[0] 606 | vertex_colors = np.zeros((mesh_data['vertices'].shape[0],), dtype=np.int) 607 | vertex_colors[b_vertices] = 1 608 | utils.visualize_model(mesh_data['vertices'], mesh_data['faces'], vertex_colors_idx=vertex_colors, point_size=10) 609 | if 0: # Show face labels 610 | utils.visualize_model(np.array(mesh.vertices), np.array(mesh.triangles), face_colors=remeshed_face_labels, show_vertices=False, show_edges=False) 611 | out_fc_full = out_fn + str_to_add 612 | if os.path.isfile(out_fc_full + '.npz'): 613 | continue 614 | add_fields_and_dump_model(mesh_data, fileds_needed, out_fc_full, dataset_name) 615 | if 0: 616 | utils.visualize_model(mesh_data['vertices'], mesh_data['faces'], vertex_colors_idx=mesh_data['labels'].astype(np.int), 617 | cpos=[(-2., -0.2, 3.3), (0., -0.3, 0.1), (0., 1., 0.)]) 618 | 619 | 620 | def prepare_seg_from_meshcnn(dataset, subfolder=None): 621 | if dataset == 'human_body': 622 | dataset_name = 'human_seg_from_meshcnn' 623 | p_in2add = 'human_seg' 624 | p_out_sub = p_in2add 625 | p_ext = '' 626 | elif dataset == 'coseg': 627 | p_out_sub = dataset_name = 'coseg' 628 | p_in2add = dataset_name + '/' + subfolder 629 | p_ext = subfolder 630 | 631 | path_in = os.path.expanduser('~') + '/mesh_walker/datasets_raw/from_meshcnn/' + p_in2add + '/' 632 | p_out = os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/' + p_out_sub + '_from_meshcnn/' + p_ext 633 | 634 | for part in ['test', 'train']: 635 | pin = path_in + '/' + part + '/' 636 | prepare_directory_from_scratch(dataset_name, pathname_expansion=pin + '*.obj', 637 | p_out=p_out, add_labels=dataset_name, fn_prefix=part + '_', n_target_faces=[np.inf], 638 | classification=False) 639 | 640 | 641 | def prepare_coseg(dataset_name='coseg', 642 | path_in=os.path.expanduser('~') + '/datasets_processed/coseg/', 643 | p_out_root=os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/coseg_tmp2'): 644 | for sub_folder in os.listdir(path_in): 645 | p_out = p_out_root + 
'/' + sub_folder 646 | if not os.path.isdir(p_out): 647 | os.makedirs(p_out + '/' + sub_folder) 648 | 649 | for part in ['test', 'train']: 650 | pin = path_in + '/' + sub_folder + '/' + part + '/' 651 | prepare_directory_from_scratch(sub_folder, pathname_expansion=pin + '*.obj', 652 | p_out=p_out, add_labels=dataset_name, fn_prefix=part + '_', n_target_faces=[np.inf]) 653 | 654 | # ------------------------------------------------------- # 655 | 656 | def map_fns_to_label(path=None, filenames=None): 657 | lmap = {} 658 | if path is not None: 659 | iterate = glob.glob(path + '/*.npz') 660 | elif filenames is not None: 661 | iterate = filenames 662 | 663 | for fn in iterate: 664 | mesh_data = np.load(fn, encoding='latin1', allow_pickle=True) 665 | label = int(mesh_data['label']) 666 | if label not in lmap.keys(): 667 | lmap[label] = [] 668 | if path is None: 669 | lmap[label].append(fn) 670 | else: 671 | lmap[label].append(fn.split('/')[-1]) 672 | return lmap 673 | 674 | 675 | def change_train_test_split(path, n_train_examples, n_test_examples, split_name): 676 | np.random.seed() 677 | fns_lbls_map = map_fns_to_label(path) 678 | for label, fns_ in fns_lbls_map.items(): 679 | fns = np.random.permutation(fns_) 680 | assert len(fns) == n_train_examples + n_test_examples 681 | train_path = path + '/' + split_name + '/train' 682 | if not os.path.isdir(train_path): 683 | os.makedirs(train_path) 684 | test_path = path + '/' + split_name + '/test' 685 | if not os.path.isdir(test_path): 686 | os.makedirs(test_path) 687 | for i, fn in enumerate(fns): 688 | out_fn = fn.replace('train_', '').replace('test_', '') 689 | if i < n_train_examples: 690 | shutil.copy(path + '/' + fn, train_path + '/' + out_fn) 691 | else: 692 | shutil.copy(path + '/' + fn, test_path + '/' + out_fn) 693 | 694 | 695 | # ------------------------------------------------------- # 696 | 697 | 698 | def prepare_one_dataset(dataset_name, mode): 699 | dataset_name = dataset_name.lower() 700 | if dataset_name == 
'modelnet40' or dataset_name == 'modelnet': 701 | prepare_modelnet40() 702 | 703 | if dataset_name == 'shrec11': 704 | pass 705 | 706 | if dataset_name == 'cubes': 707 | pass 708 | 709 | # Semantic Segmentations 710 | if dataset_name == 'human_seg': 711 | if mode == 'from_meshcnn': 712 | prepare_seg_from_meshcnn('human_body') 713 | else: 714 | prepare_human_body_segmentation() 715 | 716 | if dataset_name == 'coseg': 717 | prepare_seg_from_meshcnn('coseg', 'coseg_aliens') 718 | prepare_seg_from_meshcnn('coseg', 'coseg_chairs') 719 | prepare_seg_from_meshcnn('coseg', 'coseg_vases') 720 | 721 | 722 | def vertex_pertubation(faces, vertices): 723 | n_vertices2change = int(vertices.shape[0] * 0.3) 724 | for _ in range(n_vertices2change): 725 | face = faces[np.random.randint(faces.shape[0])] 726 | vertices_mean = np.mean(vertices[face, :], axis=0) 727 | v = np.random.choice(face) 728 | vertices[v] = vertices_mean 729 | return vertices 730 | 731 | 732 | def visualize_dataset(pathname_expansion): 733 | cpos = None 734 | filenames = glob.glob(pathname_expansion) 735 | while 1: 736 | fn = np.random.choice(filenames) 737 | mesh_data = np.load(fn, encoding='latin1', allow_pickle=True) 738 | vertex_colors_idx = mesh_data['labels'].astype(np.int) if mesh_data['labels'].size else None 739 | vertices = mesh_data['vertices'] 740 | #vertices = vertex_pertubation(mesh_data['faces'], vertices) 741 | utils.visualize_model(vertices, mesh_data['faces'], vertex_colors_idx=vertex_colors_idx, cpos=cpos, point_size=5) 742 | 743 | 744 | if __name__ == '__main__': 745 | TEST_FAST = 0 746 | utils.config_gpu(False) 747 | np.random.seed(1) 748 | #prepare_shrec11_from_raw() 749 | #prepare_copycat_shrec11_from_raw() 750 | add_scale_to_dataset() 751 | 752 | #visualize_dataset('/home/alonlahav/mesh_walker/datasets_processed-tmp/sig17_seg_benchmark-no_simplification/*.npz') 753 | #visualize_dataset('/home/galye/mesh_walker/datasets_processed/shrec16/*.npz') 754 | ''' 755 | dataset_name = 'human_seg' 
756 | mode = 'from_raw' # from_meshcnn / from_raw 757 | if len(sys.argv) > 1: 758 | dataset_name = sys.argv[1] 759 | if len(sys.argv) > 2: 760 | mode = sys.argv[2] 761 | 762 | if dataset_name == 'all': 763 | for dataset_name_ in ['modelnet40', 'shrec11', 'cubes', 'human_seg', 'coseg']: 764 | prepare_one_dataset(dataset_name_) 765 | else: 766 | prepare_one_dataset(dataset_name, mode) 767 | 768 | if 0: 769 | prepare_shrec11_from_raw() 770 | elif 0: 771 | prepare_cubes() 772 | elif 0: 773 | prepare_cubes(dataset_name='shrec11', path_in=os.path.expanduser('~') + '/datasets_processed/shrec_16/', 774 | p_out=os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/shrec11_tmp', 775 | labels2use=shrec11_labels) 776 | elif 0: 777 | prepare_coseg() 778 | elif 0: 779 | change_train_test_split(path=os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/shrec11/', 780 | n_train_examples=16, n_test_examples=4, split_name='16-04_C') 781 | elif 0: 782 | collect_n_models_per_class(in_path=os.path.expanduser('~') + '/mesh_walker/datasets_processed-tmp/coseg/coseg_vases/', 783 | n_models4train=[1, 2, 4, 8, 16, 32]) 784 | ''' 785 | -------------------------------------------------------------------------------- /rnn_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from easydict import EasyDict 6 | import copy 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow_addons as tfa 11 | 12 | import utils 13 | 14 | from tensorflow import keras 15 | layers = tf.keras.layers 16 | 17 | 18 | class RnnWalkBase(tf.keras.Model): 19 | def __init__(self, 20 | params, 21 | classes, 22 | net_input_dim, 23 | model_fn=None, 24 | model_must_be_load=False, 25 | dump_model_visualization=True, 26 | optimizer=None): 27 | super(RnnWalkBase, self).__init__(name='') 28 | 29 | self._classes = classes 30 | 
self._params = params 31 | self._model_must_be_load = model_must_be_load 32 | 33 | self._pooling_betwin_grus = 'pooling' in self._params.aditional_network_params 34 | self._bidirectional_rnn = 'bidirectional_rnn' in self._params.aditional_network_params 35 | 36 | self._init_layers() 37 | inputs = tf.keras.layers.Input(shape=(100, net_input_dim)) 38 | self.build(input_shape=(1, 100, net_input_dim)) 39 | outputs = self.call(inputs) 40 | if dump_model_visualization: 41 | tmp_model = keras.Model(inputs=inputs, outputs=outputs, name='WalkModel') 42 | tmp_model.summary(print_fn=self._print_fn) 43 | tf.keras.utils.plot_model(tmp_model, params.logdir + '/RnnWalkModel.png', show_shapes=True) 44 | 45 | self.manager = None 46 | if optimizer: 47 | if model_fn: 48 | #self.checkpoint = tf.train.Checkpoint(optimizer=copy.deepcopy(optimizer), model=self) 49 | self.checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=self) 50 | else: 51 | self.checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=self) 52 | self.manager = tf.train.CheckpointManager(self.checkpoint, directory=self._params.logdir, max_to_keep=5) 53 | if model_fn: # Transfer learning 54 | self.load_weights(model_fn) 55 | self.checkpoint.optimizer = optimizer 56 | else: 57 | self.load_weights() 58 | else: 59 | self.checkpoint = tf.train.Checkpoint(model=self) 60 | if model_fn: 61 | self.load_weights(model_fn) 62 | else: 63 | self.load_weights(tf.train.latest_checkpoint(self._params.logdir)) 64 | 65 | def _print_fn(self, st): 66 | with open(self._params.logdir + '/log.txt', 'at') as f: 67 | f.write(st + '\n') 68 | 69 | def load_weights(self, filepath=None): 70 | if filepath is not None and filepath.endswith('.keras'): 71 | super(RnnWalkBase, self).load_weights(filepath) 72 | elif filepath is None: 73 | status = self.checkpoint.restore(self.manager.latest_checkpoint) 74 | print(utils.color.BLUE, 'Starting from iteration: ', self.checkpoint.optimizer.iterations.numpy(), utils.color.END) 75 | else: 76 | 
filepath = filepath.replace('//', '/') 77 | status = self.checkpoint.restore(filepath) 78 | 79 | def save_weights(self, folder, step=None, keep=False): 80 | if self.manager is not None: 81 | self.manager.save() 82 | if keep: 83 | super(RnnWalkBase, self).save_weights(folder + '/learned_model2keep__' + str(step).zfill(8) + '.keras') 84 | #self.checkpoint.write(folder + '/learned_model2keep--' + str(step)) 85 | 86 | 87 | 88 | class RnnWalkNet(RnnWalkBase): 89 | def __init__(self, 90 | params, 91 | classes, 92 | net_input_dim, 93 | model_fn, 94 | model_must_be_load=False, 95 | dump_model_visualization=True, 96 | optimizer=None): 97 | if params.layer_sizes is None: 98 | self._layer_sizes = {'fc1': 128, 'fc2': 256, 'gru1': 1024, 'gru2': 1024, 'gru3': 512} 99 | else: 100 | self._layer_sizes = params.layer_sizes 101 | super(RnnWalkNet, self).__init__(params, classes, net_input_dim, model_fn, model_must_be_load=model_must_be_load, 102 | dump_model_visualization=dump_model_visualization, optimizer=optimizer) 103 | 104 | def _init_layers(self): 105 | kernel_regularizer = tf.keras.regularizers.l2(0.0001) 106 | initializer = tf.initializers.Orthogonal(3) 107 | self._use_norm_layer = self._params.use_norm_layer is not None 108 | if self._params.use_norm_layer == 'InstanceNorm': 109 | self._norm1 = tfa.layers.InstanceNormalization(axis=2) 110 | self._norm2 = tfa.layers.InstanceNormalization(axis=2) 111 | elif self._params.use_norm_layer == 'BatchNorm': 112 | self._norm1 = layers.BatchNormalization(axis=2) 113 | self._norm2 = layers.BatchNormalization(axis=2) 114 | self._fc1 = layers.Dense(self._layer_sizes['fc1'], kernel_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer, 115 | kernel_initializer=initializer) 116 | self._fc2 = layers.Dense(self._layer_sizes['fc2'], kernel_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer, 117 | kernel_initializer=initializer) 118 | #rnn_layer = layers.LSTM 119 | rnn_layer = layers.GRU 120 | self._gru1 = 
rnn_layer(self._layer_sizes['gru1'], time_major=False, return_sequences=True, return_state=False, 121 | #trainable=False, 122 | #activation='sigmoid', 123 | dropout=self._params.net_gru_dropout, 124 | #recurrent_dropout=self._params.net_gru_dropout, --->> very slow!! (tf2.1) 125 | recurrent_initializer=initializer, kernel_initializer=initializer, 126 | kernel_regularizer=kernel_regularizer, recurrent_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer) 127 | if self._bidirectional_rnn: 128 | self._gru1 = layers.Bidirectional(self._gru1) 129 | self._gru2 = rnn_layer(self._layer_sizes['gru2'], time_major=False, return_sequences=True, return_state=False, 130 | #trainable=False, 131 | #activation='sigmoid', 132 | dropout=self._params.net_gru_dropout, 133 | #recurrent_dropout=self._params.net_gru_dropout, 134 | recurrent_initializer=initializer, kernel_initializer=initializer, 135 | kernel_regularizer=kernel_regularizer, recurrent_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer) 136 | if self._bidirectional_rnn: 137 | self._gru2 = layers.Bidirectional(self._gru2) 138 | self._gru3 = rnn_layer(self._layer_sizes['gru3'], time_major=False, 139 | return_sequences=not self._params.one_label_per_model, 140 | return_state=False, 141 | #trainable=False, 142 | #activation='sigmoid', 143 | dropout=self._params.net_gru_dropout, 144 | #recurrent_dropout=self._params.net_gru_dropout, 145 | recurrent_initializer=initializer, kernel_initializer=initializer, 146 | kernel_regularizer=kernel_regularizer, recurrent_regularizer=kernel_regularizer, 147 | bias_regularizer=kernel_regularizer) 148 | if self._bidirectional_rnn: 149 | self._gru3 = layers.Bidirectional(self._gru3) 150 | print('Using Bidirectional GRUs.') 151 | self._fc_last = layers.Dense(self._classes, activation=self._params.last_layer_activation, kernel_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer, 152 | kernel_initializer=initializer) 153 | self._pooling = 
layers.MaxPooling1D(pool_size=3, strides=2, padding='same') 154 | 155 | self._norm_input = False 156 | if self._norm_input: 157 | self._norm_features = layers.LayerNormalization(axis=-1, trainable=False) 158 | 159 | # @tf.function 160 | def call(self, model_ftrs, classify=True, skip_1st=True, training=True, mask=None): 161 | if self._norm_input: 162 | model_ftrs = self._norm_features(model_ftrs) 163 | if skip_1st: 164 | x = model_ftrs[:, 1:] 165 | else: 166 | x = model_ftrs 167 | x = self._fc1(x) 168 | if self._use_norm_layer: 169 | x = self._norm1(x, training=training) 170 | x = tf.nn.relu(x) 171 | x = self._fc2(x) 172 | if self._use_norm_layer: 173 | x = self._norm2(x, training=training) 174 | x = tf.nn.relu(x) 175 | x1 = self._gru1(x, training=training) 176 | if self._pooling_betwin_grus: 177 | x1 = self._pooling(x1) 178 | if mask is not None: 179 | mask = mask[:, ::2] 180 | x2 = self._gru2(x1, training=training) 181 | if self._pooling_betwin_grus: 182 | x2 = self._pooling(x2) 183 | if mask is not None: 184 | mask = mask[:, ::2] 185 | x3 = self._gru3(x2, training=training, mask=mask) 186 | x = x3 187 | 188 | #if self._params.one_label_per_model: 189 | # x = x[:, -1, :] 190 | 191 | if classify: 192 | x = self._fc_last(x) 193 | return x 194 | 195 | def call_dbg(self, model_ftrs, classify=True, skip_1st=True, training=True, get_layer=None): 196 | if skip_1st: 197 | x = model_ftrs[:, 1:] 198 | else: 199 | x = model_ftrs 200 | if get_layer == 'input': 201 | return x 202 | x = self._fc1(x) 203 | if self._use_norm_layer: 204 | x = self._norm1(x, training=training) 205 | x = tf.nn.relu(x) 206 | if get_layer == 'fc1': 207 | return x 208 | x = self._fc2(x) 209 | if self._use_norm_layer: 210 | x = self._norm2(x, training=training) 211 | x = tf.nn.relu(x) 212 | if get_layer == 'fc2': 213 | return x 214 | x = self._gru1(x, training=training) 215 | if get_layer == 'gru1': 216 | return x 217 | x = self._gru2(x, training=training) 218 | if get_layer == 'gru2': 219 | return x 
220 | x = self._gru3(x, training=training) 221 | if get_layer == 'gru3': 222 | return x 223 | 224 | if self._params.one_label_per_model: 225 | x = x[:, -1, :] 226 | 227 | if classify: 228 | x = self._fc_last(x) 229 | return x 230 | 231 | 232 | class RnnManifoldWalkNet(RnnWalkBase): 233 | def __init__(self, 234 | params, 235 | classes, 236 | net_input_dim, 237 | model_fn, 238 | model_must_be_load=False, 239 | dump_model_visualization=True, 240 | optimizer=None): 241 | if params.layer_sizes is None: 242 | self._layer_sizes = {'fc1': 128, 'fc2': 256, 'gru1': 1024, 'gru2': 1024, 'gru3': 512} 243 | else: 244 | self._layer_sizes = params.layer_sizes 245 | super(RnnManifoldWalkNet, self).__init__(params, classes, net_input_dim, model_fn, model_must_be_load=model_must_be_load, 246 | dump_model_visualization=dump_model_visualization, optimizer=optimizer) 247 | 248 | def _init_layers(self): 249 | kernel_regularizer = tf.keras.regularizers.l2(0.0001) 250 | initializer = tf.initializers.Orthogonal(3) 251 | self._use_norm_layer = self._params.use_norm_layer is not None 252 | if self._params.use_norm_layer == 'InstanceNorm': 253 | self._norm1 = tfa.layers.InstanceNormalization(axis=2) 254 | self._norm2 = tfa.layers.InstanceNormalization(axis=2) 255 | elif self._params.use_norm_layer == 'BatchNorm': 256 | self._norm1 = layers.BatchNormalization(axis=2) 257 | self._norm2 = layers.BatchNormalization(axis=2) 258 | self._fc1 = layers.Dense(self._layer_sizes['fc1'], kernel_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer, 259 | kernel_initializer=initializer) 260 | self._fc2 = layers.Dense(self._layer_sizes['fc2'], kernel_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer, 261 | kernel_initializer=initializer) 262 | #rnn_layer = layers.LSTM 263 | rnn_layer = layers.GRU 264 | self._gru1 = rnn_layer(self._layer_sizes['gru1'], time_major=False, return_sequences=True, return_state=False, 265 | #trainable=False, 266 | #activation='sigmoid', 267 | 
dropout=self._params.net_gru_dropout, 268 | #recurrent_dropout=self._params.net_gru_dropout, --->> very slow!! (tf2.1) 269 | recurrent_initializer=initializer, kernel_initializer=initializer, 270 | kernel_regularizer=kernel_regularizer, recurrent_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer) 271 | if self._bidirectional_rnn: 272 | self._gru1 = layers.Bidirectional(self._gru1) 273 | self._gru2 = rnn_layer(self._layer_sizes['gru2'], time_major=False, return_sequences=True, return_state=False, 274 | #trainable=False, 275 | #activation='sigmoid', 276 | dropout=self._params.net_gru_dropout, 277 | #recurrent_dropout=self._params.net_gru_dropout, 278 | recurrent_initializer=initializer, kernel_initializer=initializer, 279 | kernel_regularizer=kernel_regularizer, recurrent_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer) 280 | if self._bidirectional_rnn: 281 | self._gru2 = layers.Bidirectional(self._gru2) 282 | self._gru3 = rnn_layer(self._layer_sizes['gru3'], time_major=False, 283 | return_sequences=not self._params.one_label_per_model, 284 | return_state=False, 285 | #trainable=False, 286 | #activation='sigmoid', 287 | dropout=self._params.net_gru_dropout, 288 | #recurrent_dropout=self._params.net_gru_dropout, 289 | recurrent_initializer=initializer, kernel_initializer=initializer, 290 | kernel_regularizer=kernel_regularizer, recurrent_regularizer=kernel_regularizer, 291 | bias_regularizer=kernel_regularizer) 292 | if self._bidirectional_rnn: 293 | self._gru3 = layers.Bidirectional(self._gru3) 294 | print('Using Bidirectional GRUs.') 295 | self._fc_last = layers.Dense(self._classes, activation=self._params.last_layer_activation, kernel_regularizer=kernel_regularizer, bias_regularizer=kernel_regularizer, 296 | kernel_initializer=initializer) 297 | self._pooling = layers.MaxPooling1D(pool_size=3, strides=2, padding='same') 298 | 299 | self._norm_input = False 300 | if self._norm_input: 301 | self._norm_features = 
layers.LayerNormalization(axis=-1, trainable=False) 302 | 303 | # @tf.function 304 | def call(self, model_ftrs, classify=True, skip_1st=True, training=True, mask=None): 305 | if self._norm_input: 306 | model_ftrs = self._norm_features(model_ftrs) 307 | if skip_1st: 308 | x = model_ftrs[:, 1:] 309 | else: 310 | x = model_ftrs 311 | x = self._fc1(x) 312 | if self._use_norm_layer: 313 | x = self._norm1(x, training=training) 314 | x = tf.nn.relu(x) 315 | x = self._fc2(x) 316 | if self._use_norm_layer: 317 | x = self._norm2(x, training=training) 318 | x = tf.nn.relu(x) 319 | x1 = self._gru1(x, training=training) 320 | if self._pooling_betwin_grus: 321 | x1 = self._pooling(x1) 322 | if mask is not None: 323 | mask = mask[:, ::2] 324 | x2 = self._gru2(x1, training=training) 325 | if self._pooling_betwin_grus: 326 | x2 = self._pooling(x2) 327 | if mask is not None: 328 | mask = mask[:, ::2] 329 | x3 = self._gru3(x2, training=training, mask=mask) 330 | x = x3 331 | 332 | if classify: 333 | x = self._fc_last(x) 334 | return x 335 | 336 | def call_dbg(self, model_ftrs, classify=True, skip_1st=True, training=True, get_layer=None): 337 | if skip_1st: 338 | x = model_ftrs[:, 1:] 339 | else: 340 | x = model_ftrs 341 | if get_layer == 'input': 342 | return x 343 | x = self._fc1(x) 344 | if self._use_norm_layer: 345 | x = self._norm1(x, training=training) 346 | x = tf.nn.relu(x) 347 | if get_layer == 'fc1': 348 | return x 349 | x = self._fc2(x) 350 | if self._use_norm_layer: 351 | x = self._norm2(x, training=training) 352 | x = tf.nn.relu(x) 353 | if get_layer == 'fc2': 354 | return x 355 | x = self._gru1(x, training=training) 356 | if get_layer == 'gru1': 357 | return x 358 | x = self._gru2(x, training=training) 359 | if get_layer == 'gru2': 360 | return x 361 | x = self._gru3(x, training=training) 362 | if get_layer == 'gru3': 363 | return x 364 | 365 | if self._params.one_label_per_model: 366 | x = x[:, -1, :] 367 | 368 | if classify: 369 | x = self._fc_last(x) 370 | return x 
class Unsupervised_RnnWalkNet(RnnWalkBase):
    """Walk network (2 dense layers + 3 GRUs) whose outputs are L2-normalized.

    When ``params.network_task == 'features_extraction'`` no classification
    head is created and ``call`` returns the normalized output of the last
    GRU. Otherwise a dense head is applied before normalization.

    NOTE(review): ``_init_layers`` and ``call`` read ``self._params``,
    ``self._classes``, ``self._bidirectional_rnn`` and
    ``self._pooling_betwin_grus``; none of them is assigned in this class, so
    they are presumably set by ``RnnWalkBase`` -- confirm against the base.
    """

    def __init__(self,
                 params,
                 classes,
                 net_input_dim,
                 model_fn,
                 model_must_be_load=False,
                 dump_model_visualization=True,
                 optimizer=None):
        """Choose layer sizes and the task mode, then defer to the base class.

        Args:
            params: configuration object; ``layer_sizes`` and ``network_task``
                are read here, further fields are read in ``_init_layers``.
            classes: forwarded unchanged to ``RnnWalkBase``.
            net_input_dim: forwarded unchanged to ``RnnWalkBase``.
            model_fn: forwarded unchanged to ``RnnWalkBase``.
            model_must_be_load: forwarded unchanged to ``RnnWalkBase``.
            dump_model_visualization: forwarded unchanged to ``RnnWalkBase``.
            optimizer: forwarded unchanged to ``RnnWalkBase``.
        """
        # These attributes are assigned before the base constructor runs;
        # presumably the base constructor triggers _init_layers, which needs
        # them -- TODO confirm.
        if params.layer_sizes is None:
            self._layer_sizes = {'fc1': 128, 'fc2': 256, 'gru1': 1024, 'gru2': 1024, 'gru3': 512}
        else:
            self._layer_sizes = params.layer_sizes
        self.features_extraction = params.network_task == 'features_extraction'

        super(Unsupervised_RnnWalkNet, self).__init__(
            params, classes, net_input_dim, model_fn,
            model_must_be_load=model_must_be_load,
            dump_model_visualization=dump_model_visualization,
            optimizer=optimizer)

    def _init_layers(self):
        """Create every layer used by ``call`` / ``call_dbg``.

        Raises:
            ValueError: if ``params.use_norm_layer`` is neither ``None``,
                ``'InstanceNorm'`` nor ``'BatchNorm'``. Previously such a
                value enabled normalization without creating the layers, and
                the failure surfaced later as an ``AttributeError``.
        """
        kernel_regularizer = tf.keras.regularizers.l2(0.0001)
        initializer = tf.initializers.Orthogonal(3)

        norm_type = self._params.use_norm_layer
        self._use_norm_layer = norm_type is not None
        if norm_type == 'InstanceNorm':
            self._norm1 = tfa.layers.InstanceNormalization(axis=2)
            self._norm2 = tfa.layers.InstanceNormalization(axis=2)
        elif norm_type == 'BatchNorm':
            self._norm1 = layers.BatchNormalization(axis=2)
            self._norm2 = layers.BatchNormalization(axis=2)
        elif norm_type is not None:
            raise ValueError(
                "Unsupported use_norm_layer: {!r} (expected None, 'InstanceNorm' "
                "or 'BatchNorm')".format(norm_type))

        dense_kwargs = dict(kernel_regularizer=kernel_regularizer,
                            bias_regularizer=kernel_regularizer,
                            kernel_initializer=initializer)
        self._fc1 = layers.Dense(self._layer_sizes['fc1'], **dense_kwargs)
        self._fc2 = layers.Dense(self._layer_sizes['fc2'], **dense_kwargs)

        rnn_layer = layers.GRU

        def make_rnn(units, return_sequences):
            # recurrent_dropout is deliberately left unset: the original
            # author noted it was very slow (tf2.1).
            return rnn_layer(units, time_major=False,
                             return_sequences=return_sequences, return_state=False,
                             dropout=self._params.net_gru_dropout,
                             recurrent_initializer=initializer,
                             kernel_initializer=initializer,
                             kernel_regularizer=kernel_regularizer,
                             recurrent_regularizer=kernel_regularizer,
                             bias_regularizer=kernel_regularizer)

        self._gru1 = make_rnn(self._layer_sizes['gru1'], True)
        if self._bidirectional_rnn:
            self._gru1 = layers.Bidirectional(self._gru1)
        self._gru2 = make_rnn(self._layer_sizes['gru2'], True)
        if self._bidirectional_rnn:
            self._gru2 = layers.Bidirectional(self._gru2)
        # With one label per model only the final time step is emitted, so
        # the output of the last GRU has no time axis in that mode.
        self._gru3 = make_rnn(self._layer_sizes['gru3'],
                              not self._params.one_label_per_model)
        if self._bidirectional_rnn:
            self._gru3 = layers.Bidirectional(self._gru3)
            print('Using Bidirectional GRUs.')

        if not self.features_extraction:
            self._fc_last = layers.Dense(self._classes,
                                         activation=self._params.last_layer_activation,
                                         **dense_kwargs)
        self._pooling = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')

        self._l2_normalization = layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))

        # Input normalization is hard-wired off; the layer is only created
        # when the flag is flipped here.
        self._norm_input = False
        if self._norm_input:
            self._norm_features = layers.LayerNormalization(axis=-1, trainable=False)

    # @tf.function
    def call(self, model_ftrs, classify=True, skip_1st=True, training=True, mask=None):
        """Forward pass.

        Args:
            model_ftrs: input features; presumably (batch, steps, features),
                since it is sliced on axis 1 and fed to GRUs -- TODO confirm.
            classify: apply the dense head (ignored in features-extraction
                mode, where no head exists).
            skip_1st: drop index 0 along axis 1 before the first layer.
            training: forwarded to the normalization and GRU layers.
            mask: optional mask passed to the last GRU only; it is subsampled
                by 2 for each pooling stage that is applied.

        Returns:
            The last GRU output, L2-normalized along axis 1 in
            features-extraction mode; the L2-normalized head output when
            ``classify`` is true; otherwise the raw last GRU output.
        """
        if self._norm_input:
            model_ftrs = self._norm_features(model_ftrs)
        x = model_ftrs[:, 1:] if skip_1st else model_ftrs

        x = self._fc1(x)
        if self._use_norm_layer:
            x = self._norm1(x, training=training)
        x = tf.nn.relu(x)
        x = self._fc2(x)
        if self._use_norm_layer:
            x = self._norm2(x, training=training)
        x = tf.nn.relu(x)

        x1 = self._gru1(x, training=training)
        if self._pooling_betwin_grus:
            x1 = self._pooling(x1)
            if mask is not None:
                # Stride-2 'same' pooling keeps ceil(n / 2) steps, exactly
                # what [::2] keeps, so the mask stays aligned with the data.
                mask = mask[:, ::2]
        x2 = self._gru2(x1, training=training)
        if self._pooling_betwin_grus:
            x2 = self._pooling(x2)
            if mask is not None:
                mask = mask[:, ::2]
        x = self._gru3(x2, training=training, mask=mask)

        if self.features_extraction:
            x = self._l2_normalization(x)
        elif classify:
            x = self._fc_last(x)
            # L2 normalization is needed when working with triplet loss
            x = self._l2_normalization(x)

        return x

    def call_dbg(self, model_ftrs, classify=True, skip_1st=True, training=True, get_layer=None):
        """Forward pass that can stop early and return an intermediate tensor.

        ``get_layer`` may be 'input', 'fc1', 'fc2', 'gru1', 'gru2' or 'gru3';
        any other value runs through to the end.

        NOTE(review): unlike ``call``, this path applies neither the pooling
        between GRUs, nor a mask, nor the final L2 normalization.
        """
        x = model_ftrs[:, 1:] if skip_1st else model_ftrs
        if get_layer == 'input':
            return x
        x = self._fc1(x)
        if self._use_norm_layer:
            x = self._norm1(x, training=training)
        x = tf.nn.relu(x)
        if get_layer == 'fc1':
            return x
        x = self._fc2(x)
        if self._use_norm_layer:
            x = self._norm2(x, training=training)
        x = tf.nn.relu(x)
        if get_layer == 'fc2':
            return x
        x = self._gru1(x, training=training)
        if get_layer == 'gru1':
            return x
        x = self._gru2(x, training=training)
        if get_layer == 'gru2':
            return x
        x = self._gru3(x, training=training)
        if get_layer == 'gru3':
            return x

        # _gru3 already drops the time axis when one_label_per_model is set
        # (return_sequences=False), so the old unconditional x[:, -1, :]
        # indexed a rank-2 tensor with three indices and raised. Only slice
        # when a time axis is really present.
        if self._params.one_label_per_model and x.shape.rank == 3:
            x = x[:, -1, :]

        if classify and not self.features_extraction:
            x = self._fc_last(x)
        return x


class AttentionWalkNet(RnnWalkBase):
    """GRU encoder, Bahdanau attention and GRU decoder over the walk."""

    def __init__(self,
                 params,
                 classes,
                 net_input_dim,
                 model_fn,
                 layer_sizes={'fc1': 128, 'fc2': 256, 'gru1': 1024, 'gru2': 1024, 'gru3': 512, 'gru_dec1': 1024, 'gru_dec2': 512},
                 model_must_be_load=False,
                 optimizer=None):
        """Store layer sizes and defer to the base class.

        ``layer_sizes`` maps layer names to unit counts. It is copied so the
        shared default dict (and a caller's dict) is never aliased by the
        instance.
        """
        self._layer_sizes = dict(layer_sizes)
        super(AttentionWalkNet, self).__init__(params, classes, net_input_dim, model_fn,
                                               model_must_be_load=model_must_be_load,
                                               optimizer=optimizer)

    def _init_layers(self):
        """Create encoder, attention, decoder and output layers."""
        kernel_regularizer = tf.keras.regularizers.l2(0.0001)
        initializer = tf.initializers.Orthogonal(3)
        dense_kwargs = dict(kernel_regularizer=kernel_regularizer,
                            bias_regularizer=kernel_regularizer,
                            kernel_initializer=initializer)
        gru_kwargs = dict(time_major=False, return_sequences=True,
                          recurrent_initializer=initializer, kernel_initializer=initializer,
                          kernel_regularizer=kernel_regularizer,
                          recurrent_regularizer=kernel_regularizer,
                          bias_regularizer=kernel_regularizer)

        # Instance normalization is always on for this architecture.
        self._use_norm_layer = 1
        if self._use_norm_layer:
            self._norm1 = tfa.layers.InstanceNormalization(axis=2)
            self._norm2 = tfa.layers.InstanceNormalization(axis=2)
        self._fc1 = layers.Dense(self._layer_sizes['fc1'], **dense_kwargs)
        self._fc2 = layers.Dense(self._layer_sizes['fc2'], **dense_kwargs)
        self._gru1 = layers.GRU(self._layer_sizes['gru1'], return_state=False, **gru_kwargs)
        self._gru2 = layers.GRU(self._layer_sizes['gru2'], return_state=False, **gru_kwargs)

        # The last encoder GRU also returns its final state; that state is
        # the attention query in call().
        self._gru3 = layers.GRU(self._layer_sizes['gru3'], return_state=True, **gru_kwargs)
        self._attention_layer = BahdanauAttention(10)

        self._gru_decode_1 = layers.GRU(self._layer_sizes['gru_dec1'], return_state=False,
                                        **gru_kwargs)
        self._gru_decode_2 = layers.GRU(self._layer_sizes['gru_dec2'], return_state=False,
                                        **gru_kwargs)

        self._fc_last = layers.Dense(self._classes, activation='sigmoid', **dense_kwargs)

    # @tf.function
    def call(self, model_ftrs, classify=True, skip_1st=True, training=True):
        """Encode the walk, attend over the encoding and decode per step.

        Args:
            model_ftrs: input features; presumably (batch, steps, features)
                -- TODO confirm.
            classify: accepted for interface parity; the output layer is
                always applied here.
            skip_1st: drop index 0 along axis 1 before the first layer.
            training: forwarded to the normalization layers only.

        Returns:
            Output of the final dense layer, one vector per time step.
        """
        x = model_ftrs[:, 1:] if skip_1st else model_ftrs
        model_ftrs_ = x

        # Encoder
        # -------
        x = self._fc1(x)
        if self._use_norm_layer:
            x = self._norm1(x, training=training)
        x = tf.nn.relu(x)
        x = self._fc2(x)
        if self._use_norm_layer:
            x = self._norm2(x, training=training)
        x = tf.nn.relu(x)
        x = self._gru1(x)
        x = self._gru2(x)
        output, hidden = self._gru3(x)

        # Attention
        # ---------
        context_vector, attention_weights = self._attention_layer(hidden, output)

        # Decoder
        # -------
        # The context is a single vector per sample. Concatenating it on the
        # feature axis needs matching time axes, so repeat it for every step;
        # the un-tiled (batch, 1, hidden) tensor only worked for 1-step input.
        context = tf.expand_dims(context_vector, 1)
        context = tf.tile(context, [1, tf.shape(model_ftrs_)[1], 1])
        x = tf.concat([context, model_ftrs_], axis=-1)
        x = self._gru_decode_1(x)
        x = self._gru_decode_2(x)
        x = self._fc_last(x)

        return x


class RnnStrideWalkNet(RnnWalkBase):
    """Walk network that pools after the first GRU and adds a skip connection."""

    def __init__(self,
                 params,
                 classes,
                 net_input_dim,
                 model_fn,
                 layer_sizes={'fc1': 128, 'fc2': 256, 'gru1': 1024, 'gru2': 1024, 'gru3': 1024},
                 model_must_be_load=False):
        """Store layer sizes and defer to the base class.

        ``layer_sizes`` is copied so the shared default dict is never aliased.
        ``gru1`` and ``gru3`` must have the same width because their outputs
        are summed in ``call``.
        """
        self._layer_sizes = dict(layer_sizes)
        super(RnnStrideWalkNet, self).__init__(params, classes, net_input_dim, model_fn,
                                               model_must_be_load=model_must_be_load)

    def _init_layers(self):
        """Create the dense, GRU, pooling, up-sampling and output layers."""
        kernel_regularizer = tf.keras.regularizers.l2(0.0001)
        initializer = tf.initializers.Orthogonal(3)
        dense_kwargs = dict(kernel_regularizer=kernel_regularizer,
                            bias_regularizer=kernel_regularizer,
                            kernel_initializer=initializer)
        gru_kwargs = dict(time_major=False, return_sequences=True, return_state=False,
                          recurrent_initializer=initializer, kernel_initializer=initializer,
                          kernel_regularizer=kernel_regularizer,
                          recurrent_regularizer=kernel_regularizer,
                          bias_regularizer=kernel_regularizer)

        # Instance normalization is always on for this architecture.
        self._use_norm_layer = 1
        if self._use_norm_layer:
            self._norm1 = tfa.layers.InstanceNormalization(axis=2)
            self._norm2 = tfa.layers.InstanceNormalization(axis=2)
        self._fc1 = layers.Dense(self._layer_sizes['fc1'], **dense_kwargs)
        self._fc2 = layers.Dense(self._layer_sizes['fc2'], **dense_kwargs)
        self._gru1 = layers.GRU(self._layer_sizes['gru1'], **gru_kwargs)
        self._gru2 = layers.GRU(self._layer_sizes['gru2'], **gru_kwargs)
        self._gru3 = layers.GRU(self._layer_sizes['gru3'], **gru_kwargs)
        self._fc_last = layers.Dense(self._classes, activation='sigmoid', **dense_kwargs)
        self._pooling = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')
        self._up_sampling = layers.UpSampling1D(size=2)

    # @tf.function
    def call(self, model_ftrs, classify=True, skip_1st=True, training=True):
        """Forward pass with a pooled middle section and a residual add.

        Args:
            model_ftrs: input features; presumably (batch, steps, features)
                -- TODO confirm.
            classify: apply the final dense layer.
            skip_1st: drop index 0 along axis 1 before the first layer.
            training: forwarded to the normalization layers only.

        Returns:
            The residual sum of the first GRU output and the up-sampled last
            GRU output, reduced to its last time step when
            ``params.one_label_per_model`` is set and passed through the
            dense head when ``classify`` is true.
        """
        x = model_ftrs[:, 1:] if skip_1st else model_ftrs
        x = self._fc1(x)
        if self._use_norm_layer:
            x = self._norm1(x, training=training)
        x = tf.nn.relu(x)
        x = self._fc2(x)
        if self._use_norm_layer:
            x = self._norm2(x, training=training)
        x = tf.nn.relu(x)
        x = self._gru1(x)
        before_pooling = x
        x = self._pooling(x)
        x = self._gru2(x)
        x = self._gru3(x)
        x = self._up_sampling(x)
        # Pool-then-upsample yields one extra step for odd lengths, so trim
        # back to the pre-pooling length. The dynamic shape is used because
        # the static one is None for variable-length inputs, which made the
        # old slice a no-op and the addition fail.
        x = x[:, :tf.shape(before_pooling)[1], :] + before_pooling

        if self._params.one_label_per_model:
            x = x[:, -1, :]

        if classify:
            x = self._fc_last(x)
        return x


def set_up_rnn_walk_model():
    """Build a functional-API Keras model with fixed hyper-parameters.

    The model takes inputs of shape (100, 4) per sample, applies two
    dense + instance-norm + ReLU stages and three GRUs, keeps the last time
    step and maps it to 40 classes with a softmax dense layer.

    Returns:
        A ``keras.Model`` named 'mnist_model'.
    """
    layer_sizes = {'fc1': 128, 'fc2': 256, 'gru1': 1024, 'gru2': 1024, 'gru3': 512}
    last_layer_activation = 'softmax'
    n_classes = 40
    training = True
    one_label_per_model = True
    classify = True

    kernel_regularizer = tf.keras.regularizers.l2(0.0001)
    initializer = tf.initializers.Orthogonal(3)
    dense_kwargs = dict(kernel_regularizer=kernel_regularizer,
                        bias_regularizer=kernel_regularizer,
                        kernel_initializer=initializer)
    gru_kwargs = dict(time_major=False, return_sequences=True, return_state=False,
                      recurrent_initializer=initializer, kernel_initializer=initializer,
                      kernel_regularizer=kernel_regularizer,
                      recurrent_regularizer=kernel_regularizer,
                      bias_regularizer=kernel_regularizer)

    norm1 = tfa.layers.InstanceNormalization(axis=2)
    norm2 = tfa.layers.InstanceNormalization(axis=2)
    fc1 = layers.Dense(layer_sizes['fc1'], **dense_kwargs)
    fc2 = layers.Dense(layer_sizes['fc2'], **dense_kwargs)
    gru1 = layers.GRU(layer_sizes['gru1'], **gru_kwargs)
    gru2 = layers.GRU(layer_sizes['gru2'], **gru_kwargs)
    gru3 = layers.GRU(layer_sizes['gru3'], **gru_kwargs)
    fc_last = layers.Dense(n_classes, activation=last_layer_activation, **dense_kwargs)

    # The unused (28, 28, 1) 'original_img' input that used to be created
    # here was dropped: it was never connected to the model and its variable
    # shadowed the builtin `input`.
    inputs = keras.Input(shape=(100, 4,))
    x = inputs
    x = fc1(x)
    x = norm1(x, training=training)
    x = tf.nn.relu(x)
    x = fc2(x)
    x = norm2(x, training=training)
    x = tf.nn.relu(x)
    x = gru1(x)
    x = gru2(x)
    x = gru3(x)

    if one_label_per_model:
        x = x[:, -1, :]

    if classify:
        x = fc_last(x)

    outputs = x

    model = keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

    return model


class BahdanauAttention(tf.keras.layers.Layer):
    """Additive attention: score = V(tanh(W1(values) + W2(query)))."""

    def __init__(self, units):
        """Create the three dense projections; ``units`` is the score width."""
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Return ``(context_vector, attention_weights)``.

        Args:
            query: (batch_size, hidden size) tensor.
            values: (batch_size, max_length, hidden size) tensor.
        """
        # hidden shape == (batch_size, hidden size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden size)
        # we are doing this to perform addition to calculate the score
        hidden_with_time_axis = tf.expand_dims(query, 1)

        # score shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because we are applying score to self.V
        # the shape of the tensor before applying self.V is
        # (batch_size, max_length, units)
        score = self.V(tf.nn.tanh(
            self.W1(values) + self.W2(hidden_with_time_axis)))

        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights


class GruAndVanillaLayer(tf.keras.Model):  # TODO: inherit from Layer, not Model
    """Run a GRU and a linear SimpleRNN side by side and concatenate them.

    Each branch gets half of ``units_``, so the concatenated output has
    ``2 * int(units_ / 2)`` features.
    """

    def __init__(self, units_, initializer, regularizer):
        """Create both recurrent branches.

        Args:
            units_: total requested width; each branch uses ``int(units_ / 2)``.
            initializer: kernel and recurrent initializer for both branches.
            regularizer: regularizer for the GRU branch. The SimpleRNN branch
                always uses its own ``l2(100)`` regularizer.
        """
        super(GruAndVanillaLayer, self).__init__(name='')
        units = int(units_ / 2)
        simple_rnn_regularizer = tf.keras.regularizers.l2(100)
        self._gru = layers.GRU(units, time_major=False, return_sequences=True,
                               return_state=False,
                               kernel_regularizer=regularizer,
                               recurrent_regularizer=regularizer,
                               bias_regularizer=regularizer,
                               kernel_initializer=initializer,
                               recurrent_initializer=initializer)

        self._simple_rnn = layers.SimpleRNN(units, time_major=False, return_sequences=True,
                                            return_state=False, activation=None,
                                            kernel_regularizer=simple_rnn_regularizer,
                                            recurrent_regularizer=simple_rnn_regularizer,
                                            bias_regularizer=simple_rnn_regularizer,
                                            kernel_initializer=initializer,
                                            recurrent_initializer=initializer)

    def call(self, x):
        """Return both branch outputs concatenated on axis 2."""
        x1 = self._gru(x)
        x2 = self._simple_rnn(x)
        r = tf.concat((x1, x2), axis=2)
        return r


def show_model():
    """Build a model, write its diagram to 'RnnWalkModel.png' and print a summary."""
    def fn(to_print):
        print(to_print)

    # Manual switch: True builds the class-based RnnWalkNet, False the
    # functional model from set_up_rnn_walk_model().
    use_class_model = True
    if use_class_model:
        params = EasyDict({'n_classes': 3, 'net_input_dim': 3, 'batch_size': 32,
                           'last_layer_activation': 'softmax',
                           'one_label_per_model': True, 'logdir': '.'})
        params.net_input_dim = 3 + 5
        model = RnnWalkNet(params, classes=3, net_input_dim=3, model_fn=None)
    else:
        model = set_up_rnn_walk_model()
        tf.keras.utils.plot_model(model, "RnnWalkModel.png", show_shapes=True)
    model.summary(print_fn=fn)


if __name__ == '__main__':
    np.random.seed(0)
    utils.config_gpu(0)
    show_model()