├── data
│   ├── __init__.py
│   ├── narrator_dataset.txt
│   ├── scene_registration.pkl
│   ├── body_model.py
│   ├── shape_distribution.py
│   ├── data_utils.py
│   ├── scene_registration.py
│   └── scene_graph.py
├── models
│   ├── __init__.py
│   ├── graph_layers.py
│   ├── mesh.py
│   ├── body_encoder.py
│   ├── loss.py
│   └── transform_trainer.py
├── evaluation
│   ├── __init__.py
│   ├── render_results.py
│   ├── load_results.py
│   └── eval_results.py
├── utils
│   ├── narrator_utils.py
│   ├── mesh_utils.py
│   ├── eulerangles.py
│   ├── chamfer_distance.py
│   ├── pointnet2.py
│   └── viz_util.py
├── configuration
│   ├── __init__.py
│   ├── recordings_temporal.txt
│   ├── mpcat40.tsv
│   ├── joints.py
│   └── config.py
├── SceneGraphNet
│   ├── data
│   │   ├── __init__.py
│   │   ├── preprocess
│   │   │   ├── __init__.py
│   │   │   ├── TRAIN_id2cat_bathroom.json
│   │   │   ├── TRAIN_id2cat_office.json
│   │   │   ├── TRAIN_id2cat_bedroom.json
│   │   │   └── TRAIN_id2cat_living.json
│   │   └── data_structure.md
│   ├── sgmodel
│   │   ├── __init__.py
│   │   └── model.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── default_settings.py
│   │   └── utl.py
│   └── main.py
├── imgs_Teaser.png
├── imgs_Pipeline.png
├── environment.yml
└── README.md
/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/narrator_utils.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configuration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SceneGraphNet/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SceneGraphNet/sgmodel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SceneGraphNet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/narrator_dataset.txt: -------------------------------------------------------------------------------- 1 | waiting 2 | -------------------------------------------------------------------------------- /SceneGraphNet/data/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs_Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaibiaoXuan/Narrator/HEAD/imgs_Teaser.png -------------------------------------------------------------------------------- /imgs_Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaibiaoXuan/Narrator/HEAD/imgs_Pipeline.png --------------------------------------------------------------------------------
/data/scene_registration.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaibiaoXuan/Narrator/HEAD/data/scene_registration.pkl -------------------------------------------------------------------------------- /SceneGraphNet/data/preprocess/TRAIN_id2cat_bathroom.json: -------------------------------------------------------------------------------- 1 | {"0": "kitchenware", "1": "trash_can", "2": "ottoman", "3": "vase", "4": "plant", "5": "air_conditioner", "6": "heater", "7": "person", "8": "rug", "9": "door", "10": "household_appliance", "11": "hanger", "12": "stand", "13": "wardrobe_cabinet", "14": "curtain", "15": "toilet", "16": "wall", "17": "shower", "18": "shelving", "19": "picture_frame", "20": "chair", "21": "switch", "22": "window", "23": "column", "24": "bathroom_stuff", "25": "sink", "26": "partition", "27": "toy", "28": "mirror", "29": "bathtub", "30": "indoor_lamp"} -------------------------------------------------------------------------------- /SceneGraphNet/data/preprocess/TRAIN_id2cat_office.json: -------------------------------------------------------------------------------- 1 | {"0": "household_appliance", "1": "kitchen_appliance", "2": "toy", "3": "mirror", "4": "table", "5": "column", "6": "air_conditioner", "7": "wall", "8": "hanger", "9": "dresser", "10": "desk", "11": "table_and_chair", "12": "shelving", "13": "television", "14": "fan", "15": "window", "16": "clock", "17": "heater", "18": "books", "19": "rug", "20": "whiteboard", "21": "plant", "22": "indoor_lamp", "23": "sofa", "24": "gym_equipment", "25": "switch", "26": "computer", "27": "kitchenware", "28": "workplace", "29": "ottoman", "30": "arch", "31": "chair", "32": "door", "33": "stand", "34": "wardrobe_cabinet", "35": "person", "36": "picture_frame", "37": "vase", "38": "tv_stand", "39": "curtain", "40": "music", "41": "partition"} -------------------------------------------------------------------------------- /SceneGraphNet/utils/default_settings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, json 3 | from utils.utl import try_mkdir 4 | 5 | ''' CHANGE TO YOUR OWN DATASET DIRECTORY ''' 6 | root_dir = r'/YOUR/DATASET/DIRECTORY' 7 | 8 | pkl_dir = os.path.join(root_dir, 'data') 9 | log_dir = os.path.join(root_dir, 'nn', 'logs') 10 | ckpt_dir = os.path.join(root_dir, 'nn', 'ckpts') 11 | 12 | try_mkdir(os.path.join(root_dir, 'nn')) 13 | try_mkdir(log_dir) 14 | try_mkdir(ckpt_dir) 15 | 16 | 17 | id2type = np.loadtxt('data/preprocess/SUNCG_id2type.csv', delimiter=',', dtype=str) 18 | dic_id2type = {} 19 | dic_detail2causal = {} 20 | for line in id2type: 21 | dic_id2type[line[1]] = (line[2], line[3]) 22 | dic_detail2causal[line[2]] = line[3] 23 | 24 | k_size_dic = {'bedroom':58, 'living':58, 'bathroom':38, 'office':49} 25 | -------------------------------------------------------------------------------- /data/body_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | from configuration.config import * 4 | 5 | import smplx 6 | import torch 7 | import numpy as np 8 | 9 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | body_model_dict = { 11 | 'male': smplx.create(smplx_model_folder, model_type='smplx', 12 | gender='male', ext='npz', 13 | num_pca_comps=num_pca_comps).to(device), 14 | 'female': smplx.create(smplx_model_folder, model_type='smplx', 15 | 
gender='female', ext='npz', 16 | num_pca_comps=num_pca_comps).to(device), 17 | 'neutral': smplx.create(smplx_model_folder, model_type='smplx', 18 | gender='neutral', ext='npz', 19 | num_pca_comps=num_pca_comps).to(device) 20 | } -------------------------------------------------------------------------------- /SceneGraphNet/data/preprocess/TRAIN_id2cat_bedroom.json: -------------------------------------------------------------------------------- 1 | {"0": "pillow", "1": "arch", "2": "kitchen_cabinet", "3": "shoes_cabinet", "4": "column", "5": "partition", "6": "clock", "7": "kitchenware", "8": "household_appliance", "9": "fan", "10": "heater", "11": "hanger", "12": "vase", "13": "music", "14": "air_conditioner", "15": "sofa", "16": "mirror", "17": "dressing_table", "18": "tv_stand", "19": "person", "20": "ottoman", "21": "table", "22": "toy", "23": "switch", "24": "plant", "25": "dresser", "26": "desk", "27": "books", "28": "computer", "29": "shelving", "30": "rug", "31": "television", "32": "picture_frame", "33": "curtain", "34": "chair", "35": "stand", "36": "wardrobe_cabinet", "37": "bed", "38": "window", "39": "door", "40": "indoor_lamp", "41": "wall", "42": "magazines", "43": "table_and_chair", "44": "shoes", "45": "recreation", "46": "whiteboard", "47": "trinket", "48": "candle", "49": "gym_equipment", "50": "pet"} -------------------------------------------------------------------------------- /SceneGraphNet/data/preprocess/TRAIN_id2cat_living.json: -------------------------------------------------------------------------------- 1 | {"0": "clock", "1": "tv_stand", "2": "switch", "3": "desk", "4": "table", "5": "household_appliance", "6": "kitchenware", "7": "gym_equipment", "8": "fan", "9": "hanging_kitchen_cabinet", "10": "curtain", "11": "recreation", "12": "arch", "13": "television", "14": "pet", "15": "dresser", "16": "stand", "17": "table_and_chair", "18": "candle", "19": "kitchen_cabinet", "20": "wardrobe_cabinet", "21": "books", "22": "rug", "23": "sofa", "24": "wall", "25": "magazines", "26": "music", "27": "shoes_cabinet", "28": "fireplace", "29": "hanger", "30": "computer", "31": "vase", "32": "shelving", "33": "stairs", "34": "partition", "35": "heater", "36": "ottoman", "37": "window", "38": "picture_frame", "39": "chair", "40": "pillow", "41": "kitchen_appliance", "42": "plant", "43": "door", "44": "column", "45": "mirror", "46": "toy", "47": "air_conditioner", "48": "indoor_lamp", "49": "trinket", "50": "person"} -------------------------------------------------------------------------------- /data/shape_distribution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Naive shape parameter distribution. Used by PiGraph. 
3 | """ 4 | 5 | import sys 6 | sys.path.append('..') 7 | from configuration.config import * 8 | 9 | import numpy as np 10 | import pickle 11 | 12 | with open(Path.joinpath(project_folder, "data", 'train.pkl'), 'rb') as data_file: 13 | train_data = pickle.load(data_file) 14 | shape_params = np.asarray([record['smplx_param']['betas'] for record in train_data]).reshape((-1, 10)) 15 | # print(shape_params.shape) 16 | shape_params = np.unique(shape_params, axis=0) 17 | # print(shape_params.shape) 18 | shape_mean = np.mean(shape_params, axis=0) 19 | shape_cov = np.cov(shape_params, rowvar=0) 20 | shape_distribution = {'mean': shape_mean, 'cov': shape_cov} 21 | # print(np.random.multivariate_normal(**shape_distribution, size=4)) 22 | 23 | def sample_betas(size=1): 24 | return np.random.multivariate_normal(**shape_distribution, size=size).astype(np.float32) 25 | 26 | -------------------------------------------------------------------------------- /configuration/recordings_temporal.txt: -------------------------------------------------------------------------------- 1 | BasementSittingBooth_00142_01 2 | MPH11_00034_01 3 | MPH11_00150_01 4 | MPH11_00151_01 5 | MPH11_03515_01 6 | MPH112_00034_01 7 | MPH112_00150_01 8 | MPH112_00151_01 9 | MPH112_00157_01 10 | MPH112_00169_01 11 | MPH16_00157_01 12 | MPH1Library_00034_01 13 | MPH8_00168_01 14 | N0SittingBooth_00169_01 15 | N0SittingBooth_03301_01 16 | N0SittingBooth_03403_01 17 | N0Sofa_00034_01 18 | N0Sofa_00034_02 19 | N0Sofa_00141_01 20 | N0Sofa_00145_01 21 | N3Library_00157_01 22 | N3Library_00157_02 23 | N3Library_03301_01 24 | N3Library_03301_02 25 | N3Library_03375_01 26 | N3Library_03375_02 27 | N3Library_03403_01 28 | N3Library_03403_02 29 | N3Office_00034_01 30 | N3Office_00139_01 31 | N3Office_00150_01 32 | N3Office_00153_01 33 | N3Office_00159_01 34 | N3Office_03301_01 35 | N3OpenArea_00157_01 36 | N3OpenArea_00157_02 37 | N3OpenArea_00158_01 38 | N3OpenArea_03301_01 39 | N3OpenArea_03403_01 40 | Werkraum_03301_01 41 | Werkraum_03403_01 42 | Werkraum_03516_01 43 | Werkraum_03516_02 44 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - pytorch3d 4 | - conda-forge 5 | - defaults 6 | - nvidia 7 | dependencies: 8 | - python=3.7 9 | - cudatoolkit=11.3 10 | - pytorch::pytorch=1.11.0 11 | - pytorch3d::pytorch3d 12 | - pytorch-lightning 13 | - torchvision 14 | - trimesh 15 | - tqdm 16 | - opencv 17 | - scikit-learn 18 | - matplotlib 19 | - Pillow 20 | - PyYAML 21 | - numpy 22 | - scipy 23 | - pandas 24 | - tensorboardX 25 | - pip 26 | - pip: 27 | - "smplx" 28 | - "plyfile" 29 | - "configer" 30 | - "pyrender" 31 | - "openmesh" 32 | - "boto3" 33 | - "torchgeometry" 34 | - "tensorboard" 35 | - "rtree" 36 | - "open3d" 37 | - "setuptools==59.5.0" 38 | - "git+https://github.com/MPI-IS/mesh.git" 39 | - "git+https://github.com/nghorbani/human_body_prior.git" 40 | - "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib" 41 | - "pyopengl==3.1.5" 42 | - "scikit-image" 43 | - "dotmap" 44 | 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Narrator: Towards Natural Control of Human-Scene Interaction Generation via Relationship Reasoning 2 | 3 | 4 | 5 | 6 | ### [Project 
Page](https://haibiaoxuan.github.io/Narrator/) | [Paper](https://arxiv.org/pdf/2303.09410.pdf) 7 | 8 | 9 | > [Narrator: Towards Natural Control of Human-Scene Interaction Generation via Relationship Reasoning]() 10 | > Haibiao Xuan, Xiongzheng Li, Jinsong Zhang, Hongwen Zhang, Yebin Liu and Kun Li 11 | 12 | Any discussions or questions would be welcome! 13 | 14 | ## News 15 | 16 | ## Install 17 | 18 | ``` 19 | ``` 20 | 21 | ## Dependencies 22 | 23 | ## Pretrained model 24 | 25 | 26 | ## TODO 27 | 28 | - Model checkpoints 29 | - Narrator dataset 30 | - Remaining code 31 | 32 | 33 | ## Citation 34 | 35 | If you find this code useful for your research, please use the following BibTeX entry. 36 | 37 | ``` 38 | @article{Narrator, 39 | title={Narrator: Towards Natural Control of Human-Scene Interaction Generation via Relationship Reasoning}, 40 | author={Haibiao Xuan and Xiongzheng Li and Jinsong Zhang and Hongwen Zhang and Yebin Liu and Kun Li}, 41 | journal={arXiv preprint arXiv:2303.09410}, 42 | year={2023} 43 | } 44 | ``` 45 | 46 | ## Acknowledgement 47 | 48 | ## License 49 | -------------------------------------------------------------------------------- /utils/mesh_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import torch 5 | import numpy as np 6 | import trimesh 7 | 8 | from configuration.config import * 9 | from configuration.joints import * 10 | 11 | def padded_idx_to_code(verb_ids): 12 | codebook = torch.cat([torch.eye(4, dtype=torch.float32, device=verb_ids.device), 13 | torch.zeros((1, 4), dtype=torch.float32, device=verb_ids.device)], dim=0) 14 | return codebook[verb_ids.long()] 15 | 16 | def transform_back(vertices, centroid, rotation): 17 | B, N, C = vertices.shape 18 | vertices = vertices.matmul(torch.inverse(rotation).transpose(1, 2)) 19 | vertices = vertices + centroid.unsqueeze(1) 20 | return vertices 21 | 22 | def skeleton_to_mesh(skeleton, color): 23 | joint_num = skeleton.shape[0] 24 | body = trimesh.primitives.Sphere(radius=0.05, center=skeleton[0]) 25 | body.visual.vertex_colors = np.array(color[0] * 255, dtype=np.uint8) 26 | for idx in range(1, joint_num): 27 | joint = skeleton[idx] 28 | joint_mesh = trimesh.primitives.Sphere(radius=0.05, center=joint) 29 | joint_mesh.visual.vertex_colors = np.array(color[idx] * 255, dtype=np.uint8) 30 | body = body + joint_mesh 31 | parent_joint = skeleton[parent_joint_idx[idx]] 32 | bone = np.array([joint, parent_joint]) 33 | body = body + trimesh.creation.cylinder(0.02, segment=bone, vertex_colors=color[0]) 34 | return body 35 | 36 | def narrator(query_text): 37 | return -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from configuration.config import * 6 | from human_body_prior.tools import tgm_conversion as tgm 7 | 8 | def matrot2aa(pose_matrot): 9 | ''' 10 | :param pose_matrot: Nx3x3 11 | :return: Nx3 12 | ''' 13 | bs = pose_matrot.size(0) 14 | homogen_matrot = F.pad(pose_matrot, [0,1]) 15 | pose = tgm.rotation_matrix_to_angle_axis(homogen_matrot) 16 | return pose 17 | 18 | def aa2matrot(pose): 19 | ''' 20 | :param pose: Nx3 21 | :return: pose_matrot: Nx3x3 22 | ''' 23 | bs = pose.size(0) 24 | num_joints = pose.size(1)//3 25 | pose_body_matrot = tgm.angle_axis_to_rotation_matrix(pose)[:, :3, :3].contiguous()#.view(bs,
num_joints*9) 26 | return pose_body_matrot 27 | 28 | def smplx_dict_to_tensor(smplx_dict): 29 | # ['transl', 'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'betas'] 30 | smplx_params = [smplx_dict[param] for param in used_smplx_param_names] 31 | smplx_tensor = torch.cat(smplx_params, dim=1) 32 | return smplx_tensor 33 | 34 | def smplx_dict_to_rotmat(smplx_dict): 35 | rotvec = torch.cat([smplx_dict['global_orient'], smplx_dict['body_pose']], dim=1) 36 | rotmat = aa2matrot(rotvec.view(-1, 3)).view(-1, 3, 3) 37 | return rotmat # batch*22 x 3 x 3 38 | 39 | def smplx_dict_to_nonrot(smplx_dict, include_transl=True): 40 | nonrot = [smplx_dict['transl']] if include_transl else [] 41 | nonrot += [ 42 | smplx_dict['left_hand_pose'], 43 | smplx_dict['right_hand_pose'], 44 | smplx_dict['betas'] 45 | ] 46 | return torch.cat(nonrot, dim=1) # batch x (3 + num_pca *2 + 10) 47 | 48 | -------------------------------------------------------------------------------- /data/scene_registration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformation between PROX and POSA scene assets. 3 | """ 4 | import genericpath 5 | import sys 6 | sys.path.append('..') 7 | from configuration.config import * 8 | 9 | import pickle 10 | import numpy as np 11 | import trimesh 12 | 13 | def scene_registration(scene_name): 14 | PROX_scene = trimesh.load_mesh(Path.joinpath(scene_folder, scene_name + '.ply')) 15 | POSA_scene = trimesh.load_mesh(Path.joinpath(proxe_base_folder, 'POSA_dir/scenes', scene_name + '.ply')) 16 | transform, cost = trimesh.registration.mesh_other(POSA_scene, PROX_scene) 17 | # (PROX_scene + POSA_scene.apply_transform(transform)).show() 18 | return transform 19 | 20 | def prox_to_posa(scene_name, points): 21 | transform = np.linalg.inv(POSA_to_PROX_transform[scene_name]) 22 | return np.dot(transform[:3, :3], points.T).T + transform[:3, 3].reshape((1, 3)) 23 | 24 | if not Path.exists(scene_registration_file): 25 | POSA_to_PROX_transform = {} 26 | for scene_name in scene_names: 27 | POSA_to_PROX_transform[scene_name] = scene_registration(scene_name) 28 | with open(scene_registration_file, 'wb') as file: 29 | pickle.dump(POSA_to_PROX_transform, file) 30 | 31 | with open(scene_registration_file, 'rb') as file: 32 | POSA_to_PROX_transform = pickle.load(file) 33 | 34 | if __name__ == '__main__': 35 | for scene_name in scene_names: 36 | PROX_scene = trimesh.load_mesh(Path.joinpath(scene_folder, scene_name + '.ply')) 37 | num_vertex = len(PROX_scene.vertices) 38 | PROX_scene.visual.vertex_colors = np.array([[255, 0, 0, 255]]*num_vertex, dtype=np.uint8) 39 | POSA_scene = trimesh.load_mesh(Path.joinpath(proxe_base_folder, 'POSA_dir/scenes', scene_name + '.ply')) 40 | POSA_scene.visual.vertex_colors = np.array([[0, 255, 0, 255]]*num_vertex, dtype=np.uint8) 41 | (PROX_scene + POSA_scene.apply_transform(POSA_to_PROX_transform[scene_name])).show() 42 | -------------------------------------------------------------------------------- /SceneGraphNet/data/data_structure.md: -------------------------------------------------------------------------------- 1 | # Dataset Structure 2 | 3 | ## How to read 4 | 5 | Use JSON package to load file 6 | 7 | ``` 8 | import json 9 | f = open('[YOUR_DIR]/data/[ROOM_TYPE]_data.json') 10 | valid_rooms = json.load(f) 11 | ``` 12 | 13 | ## Data Structure: 14 | 15 | - `valid_rooms` is a list = `[room_1, room_2, .., room_i, ...]` 16 | 17 | - For each `room_i`, it's a dictionary organized as 18 | 19 | - `idx` _int_ : index of 
the room 20 | - `room_scene_id` _string_ : scene ID which this room belongs to in the SUNCG dataset 21 | - `room_model_id` _string_ : room floor ID of the scene in the SUNCG dataset 22 | - `room_type` _string_ : room type 23 | - `node_list` _dict_ : a dictionary consisting of all nodes (objects) in this room. 24 | Each key represents a node - _key_ is the node's name and _val_ is its features. 25 | 26 | - `KEYS`: the node's ID, such as `18_shelving` or `wall_0`, organized as _index_category_. 27 | - `VALUES`: a dictionary of this node's features 28 | 29 | - `co-occurrence` _list_ : a list of IDs of nodes that neighbor this node, used especially for `root` and `wall` nodes 30 | _(note that this is not the `next-to` relation mentioned in the paper)_ 31 | - `support` _list_ : a list of IDs of nodes that are supported by this node 32 | - `surround` _list_ : a list of dictionaries whose key and value are the IDs of the two nodes surrounding this node. 33 | For example: 34 | ```json 35 | { 36 | "15_bed":{ 37 | "surround":[ 38 | {"6_nightstand":"7_nightstand"}, 39 | {"9_lamp":"10_lamp"} 40 | ] 41 | } 42 | } 43 | ``` 44 | This means a pair of nightstands and a pair of lamps surround the bed. 45 | - `self_info` : a dictionary containing the node's own transform and ID information. 46 | 47 | - `type` : type of the node, one of [`'root'`, `'wall'`, `'node'`] 48 | - `node_model_id` : model ID of this node (you can look it up in the 2nd column of this [file](preprocess/SUNCG_id2type.csv)) 49 | - `dim` _vector3_ : dimensions of the node 50 | - `translation` _vector3_ : absolute position of this node 51 | - `rotation` _vector9_ : flattened 3x3 rotation matrix of this node 52 | 53 | - An example visualization of a bedroom data structure: `co-occurrence` in dense dashed lines, `support` in solid lines, `surround` in dashed arrow lines. 54 | 55 | ![img](../docs/data_structure_example.png) 56 | 57 | - Note that this is the original data structure we load for some preprocessing steps. 58 | The `Supporting`, `Supported-by`, `Surrounding`, `Surrounded-by`, `Next-to`, and `Co-occurring` relationships are defined later, in the training process. A minimal loading sketch follows below.
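- A minimal loading sketch (illustrative only): it assumes the per-room JSON layout described above; the directory and room type are placeholders, and `.get()` guards relation keys that may be absent for some nodes.

  ```python
  import json

  # Placeholders: point this at your extracted [ROOM_TYPE]_data.json
  with open('[YOUR_DIR]/data/bedroom_data.json') as f:
      valid_rooms = json.load(f)

  room = valid_rooms[0]
  print(room['room_type'], room['room_scene_id'], room['room_model_id'])
  for node_id, node in room['node_list'].items():
      info = node['self_info']
      print(node_id, info['type'], 'dim:', info['dim'], 'translation:', info['translation'])
      # Relation lists may be missing or empty for some nodes, hence .get()
      for supported_id in node.get('support', []):
          print('  supports ->', supported_id)
      for pair in node.get('surround', []):
          for a, b in pair.items():
              print('  surrounded by the pair ->', a, 'and', b)
  ```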
-------------------------------------------------------------------------------- /configuration/mpcat40.tsv: -------------------------------------------------------------------------------- 1 | mpcat40index mpcat40 hex wnsynsetkey nyu40 skip labels affordance 2 | 0 void #ffffff void "remove,void" 3 | 1 wall #aec7e8 "wall.n.01,baseboard.n.01,paneling.n.01" wall wallpaper "touch,lean against,put sth on" 4 | 2 floor #708090 "floor.n.01,rug.n.01,mat.n.01,bath_mat.n.01,landing.n.01" "floor,floor mat" "touch,stand on,walk on,put sth on" 5 | 3 chair #98df8a "chair.n.01,beanbag.n.01" chair "touch,sit on,put sth on" 6 | 4 door #c5b0d5 "door.n.01,doorframe.n.01,doorway.n.01,doorknob.n.01,archway.n.01" door garage door "touch,walk thorugh" 7 | 5 table #ff7f0e "table.n.02,dressing.n.04" "table,desk" counter.n.01 "touch,rest hand on,put sth on" 8 | 6 picture #d62728 "picture.n.01,photograph.n.01,picture_frame.n.01" picture "touch,watch" 9 | 7 cabinet #1f77b4 "cabinet.n.01,cupboard.n.01" cabinet "touch,put sth on" 10 | 8 cushion #bcbd22 cushion.n.03 pillow couch cushion "touch,lay on" 11 | 9 window #ff9896 "window.n.01,windowsill.n.01,window_frame.n.01,windowpane.n.02,window_screen.n.01" window "touch,put sth on" 12 | 10 sofa #2ca02c sofa.n.01 sofa "touch,sit on,lay on, stand on,lean against,put sth on" 13 | 11 bed #e377c2 "bed.n.01,bedpost.n.01,bedstead.n.01,headboard.n.01,footboard.n.01,bedspread.n.01,mattress.n.01,sheet.n.03" bed "touch,sit on,lay on, stand on,lean against,put sth on" 14 | 12 curtain #de9ed6 curtain.n.01 "curtain,shower curtain" "curtain rod,shower curtain rod,shower rod" touch 15 | 13 chest_of_drawers #9467bd "chest_of_drawers.n.01,drawer.n.01" "dresser,night stand" "touch,put sth on" 16 | 14 plant #8ca252 plant.n.02 "touch,water" 17 | 15 sink #843c39 sink.n.01 sink "touch,clean" 18 | 16 stairs #9edae5 "step.n.04,stairway.n.01,stairwell.n.01" "touch,walk on,stand on" 19 | 17 ceiling #9c9ede "ceiling.n.01,roof.n.01" ceiling 20 | 18 toilet #e7969c "toilet.n.01,bidet.n.01" toilet "touch,sit on" 21 | 19 stool #637939 stool.n.01 "touch,sit on,rest foot on" 22 | 20 towel #8c564b towel.n.01 towel "touch,move" 23 | 21 mirror #dbdb8d mirror.n.01 mirror "touch,watch" 24 | 22 tv_monitor #d6616b display.n.06 television "touch,watch" 25 | 23 shower #cedb9c "shower.n.01,showerhead.n.01" "touch,use" 26 | 24 column #e7ba52 "column.n.07,post.n.04" "touch,put sth on" 27 | 25 bathtub #393b79 bathtub.n.01 bathtub "touch,sit on,lay on, stand on,lean against,put sth on" 28 | 26 counter #a55194 "countertop.n.01,counter.n.01,kitchen_island.n.01" counter touch 29 | 27 fireplace #ad494a "fireplace.n.01,mantel.n.01" touch 30 | 28 lighting #b5cf6b "lamp.n.02,lampshade.n.01,light.n.02,chandelier.n.01,spotlight.n.02" lamp touch 31 | 29 beam #5254a3 beam.n.02 touch 32 | 30 railing #bd9e39 "railing.n.01,bannister.n.02" "touch,lean against" 33 | 31 shelving #c49c94 "bookshelf.n.01,shelf.n.01,rack.n.05" shelves "touch,put sth on" 34 | 32 blinds #f7b6d2 window_blind.n.01 blinds touch 35 | 33 gym_equipment #6b6ecf "sports_equipment.n.01,treadmill.n.01,exercise_bike.n.01" touch 36 | 34 seating #ffbb78 "bench.n.01,seat.n.03" "touch,sit on" 37 | 35 board_panel #c7c7c7 panel.n.01 whiteboard board touch 38 | 36 furniture #8c6d31 furniture.n.01 otherfurniture touch 39 | 37 appliances #e7cb94 "home_appliance.n.01,stove.n.02,dryer.n.01" refridgerator washing machine and dryer "touch,use" 40 | 38 clothes #ce6dbd clothing.n.01 clothes "touch,wear on" 41 | 39 objects #17becf "physical_object.n.01,material.n.01" 
"books,paper,box,bag,otherprop" "structure.n.01,way.n.06,vent.n.01,unknown.n.01,pool.n.01" "touch,move" 42 | 40 misc #7f7f7f "person,otherstructure" unknown.n.01 43 | 41 unlabeled #000000 unknown.n.01 unknown 44 | -------------------------------------------------------------------------------- /SceneGraphNet/main.py: -------------------------------------------------------------------------------- 1 | from utils.default_settings import * 2 | from utils.utl import try_mkdir 3 | import argparse 4 | from sgmodel.train import train_model 5 | 6 | ''' parser input ''' 7 | parser = argparse.ArgumentParser() 8 | 9 | # train process settings 10 | parser.add_argument('--nepoch', type=int, default=100, help='number of epochs to train for') 11 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 12 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 13 | parser.add_argument('--reg_lr', type=float, default=1e-5, help='weight decay') 14 | parser.add_argument('--d_vec_dim', type=int, default=100, help='feature dimension for encoded vector') 15 | parser.add_argument('--h_vec_dim', type=int, default=300, help='feature dimension for hidden layers') 16 | parser.add_argument('--train_cat', default=False, action='store_true', help='train for object categories') 17 | parser.add_argument('--train_dim', default=False, action='store_true', help='train for object dimensions') 18 | 19 | # model variants 20 | parser.add_argument('--K', type=int, default=3, help='times of iteration') 21 | parser.add_argument('--aggregate_in_order', default=True, action='store_false', help='if aggregating object features in distance order') 22 | parser.add_argument('--aggregation_func', default='GRU', help='aggregation function, choice=[GRU, CatRNN, MaxPool, Sum]') 23 | parser.add_argument('--decode_cat_d_vec', default=True, action='store_false', help='if decode concatenated object feature') 24 | parser.add_argument('--cat_msg', default=False, action='store_true', help='if true, use MLP to predict message passing, else, directly use node representation as message') 25 | parser.add_argument('--adapt_training_on_large_graph', default=True, action='store_false', help='if adapt acceleration on on large graphs') 26 | parser.add_argument('--max_scene_nodes', type=int, default=60, help='(if adpat accelecration) max number of nodes under the root node. 
if exceed, split into subgraphs') 27 | 28 | # room type settings 29 | parser.add_argument('--room_type', type=str, default='bedroom', help='room type, choice=[bedroom, living, bathroom, office]') 30 | parser.add_argument('--num_train_rooms', default=5000, type=int, help='number of rooms for training') 31 | parser.add_argument('--num_test_rooms', default=500, type=int, help='number of rooms for testing') 32 | 33 | # for load and test on pretrained model 34 | parser.add_argument('--test', default=False, action='store_true') 35 | parser.add_argument('--load_model_name', type=str, default='', help='dir of pretrained model') 36 | parser.add_argument('--load_model_along_with_optimizer', default=False, action='store_true', help='if load pretrained model along with optimizer') 37 | 38 | # others 39 | parser.add_argument('--verbose', default=0, type=int, help='') 40 | parser.add_argument('--name', default='my-train-model') 41 | 42 | opt_parser = parser.parse_args() 43 | opt_parser.write = not opt_parser.test 44 | 45 | id2cat_file = open('data/preprocess/TRAIN_id2cat_{}.json'.format(opt_parser.room_type)) 46 | opt_parser.id2cat = json.load(id2cat_file) 47 | opt_parser.cat2id = {opt_parser.id2cat[id]: id for id in opt_parser.id2cat.keys()} 48 | 49 | if(opt_parser.load_model_name != ''): 50 | opt_parser.ckpt = os.path.join(ckpt_dir, opt_parser.load_model_name, 'Entire_model_max_acc.pth') 51 | else: 52 | opt_parser.ckpt = '' 53 | 54 | opt_parser.outf = os.path.join(ckpt_dir, opt_parser.name) 55 | try_mkdir(opt_parser.outf) 56 | 57 | M = train_model(opt_parser=opt_parser) 58 | 59 | if(not opt_parser.test): 60 | for epoch in range(opt_parser.nepoch): 61 | M.train(epoch) 62 | M.test(epoch) 63 | else: 64 | M.test(0) 65 | -------------------------------------------------------------------------------- /data/scene_graph.py: -------------------------------------------------------------------------------- 1 | from utils.default_settings import * 2 | from utils.utl import try_mkdir 3 | import argparse 4 | from sgmodel.train import train_model 5 | 6 | ''' parser input ''' 7 | parser = argparse.ArgumentParser() 8 | 9 | # train process settings 10 | parser.add_argument('--nepoch', type=int, default=100, help='number of epochs to train for') 11 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 12 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 13 | parser.add_argument('--reg_lr', type=float, default=1e-5, help='weight decay') 14 | parser.add_argument('--d_vec_dim', type=int, default=100, help='feature dimension for encoded vector') 15 | parser.add_argument('--h_vec_dim', type=int, default=300, help='feature dimension for hidden layers') 16 | parser.add_argument('--train_cat', default=False, action='store_true', help='train for object categories') 17 | parser.add_argument('--train_dim', default=False, action='store_true', help='train for object dimensions') 18 | 19 | # model variants 20 | parser.add_argument('--K', type=int, default=3, help='times of iteration') 21 | parser.add_argument('--aggregate_in_order', default=True, action='store_false', help='if aggregating object features in distance order') 22 | parser.add_argument('--aggregation_func', default='GRU', help='aggregation function, choice=[GRU, CatRNN, MaxPool, Sum]') 23 | parser.add_argument('--decode_cat_d_vec', default=True, action='store_false', help='if decode concatenated object feature') 24 | parser.add_argument('--cat_msg', default=False, action='store_true', help='if true, use MLP to predict 
message passing, else, directly use node representation as message') 25 | parser.add_argument('--adapt_training_on_large_graph', default=True, action='store_false', help='if adapt acceleration on on large graphs') 26 | parser.add_argument('--max_scene_nodes', type=int, default=60, help='(if adpat accelecration) max number of nodes under the root node. if exceed, split into subgraphs') 27 | 28 | # room type settings 29 | parser.add_argument('--room_type', type=str, default='bedroom', help='room type, choice=[bedroom, living, bathroom, office]') 30 | parser.add_argument('--num_train_rooms', default=5000, type=int, help='number of rooms for training') 31 | parser.add_argument('--num_test_rooms', default=500, type=int, help='number of rooms for testing') 32 | 33 | # for load and test on pretrained model 34 | parser.add_argument('--test', default=False, action='store_true') 35 | parser.add_argument('--load_model_name', type=str, default='', help='dir of pretrained model') 36 | parser.add_argument('--load_model_along_with_optimizer', default=False, action='store_true', help='if load pretrained model along with optimizer') 37 | 38 | # others 39 | parser.add_argument('--verbose', default=0, type=int, help='') 40 | parser.add_argument('--name', default='my-train-model') 41 | 42 | opt_parser = parser.parse_args() 43 | opt_parser.write = not opt_parser.test 44 | 45 | id2cat_file = open('data/preprocess/TRAIN_id2cat_{}.json'.format(opt_parser.room_type)) 46 | opt_parser.id2cat = json.load(id2cat_file) 47 | opt_parser.cat2id = {opt_parser.id2cat[id]: id for id in opt_parser.id2cat.keys()} 48 | 49 | if(opt_parser.load_model_name != ''): 50 | opt_parser.ckpt = os.path.join(ckpt_dir, opt_parser.load_model_name, 'Entire_model_max_acc.pth') 51 | else: 52 | opt_parser.ckpt = '' 53 | 54 | opt_parser.outf = os.path.join(ckpt_dir, opt_parser.name) 55 | try_mkdir(opt_parser.outf) 56 | 57 | M = train_model(opt_parser=opt_parser) 58 | 59 | if(not opt_parser.test): 60 | for epoch in range(opt_parser.nepoch): 61 | M.train(epoch) 62 | M.test(epoch) 63 | else: 64 | M.test(0) 65 | -------------------------------------------------------------------------------- /evaluation/render_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('..') 4 | from configuration.config import * 5 | os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 6 | 7 | import smplx 8 | import trimesh 9 | import pickle 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | 14 | from utils.viz_util import render_generation_multview 15 | from evaluation.load_results import synthesis_results_dict 16 | 17 | scene_meshes = {} 18 | for scene_name in scene_names: 19 | mesh_path = Path.joinpath(scene_folder, scene_name + '.ply').__str__() 20 | scene_meshes[scene_name] = trimesh.load_mesh(mesh_path) 21 | 22 | body_model_dict = { 23 | 'male': smplx.create(smplx_model_folder, model_type='smplx', 24 | gender='male', ext='npz', 25 | num_pca_comps=num_pca_comps), 26 | 'female': smplx.create(smplx_model_folder, model_type='smplx', 27 | gender='female', ext='npz', 28 | num_pca_comps=num_pca_comps), 29 | 'neutral': smplx.create(smplx_model_folder, model_type='smplx', 30 | gender='neutral', ext='npz', 31 | num_pca_comps=num_pca_comps) 32 | } 33 | 34 | default_colors = np.ones((10475, 3)) * np.array([0.80, 0.80, 0.80]) 35 | 36 | """ 37 | Input: 38 | results: frames 39 | render_dir: directory to save rendering 40 | max_render_num: maximum number of samples per semantic to 
be rendered 41 | num_view: number of rendering views for each sample 42 | """ 43 | def render_results(results, render_dir, max_render_num=10, num_view=2): 44 | for inter_idx, generation in enumerate(tqdm(results)): 45 | generation_dir = Path.joinpath(render_dir, generation) 46 | generation_dir.mkdir(parents=True, exist_ok=True) 47 | step_size = max(1, len(results[generation]) // max_render_num) 48 | for idx in range(0, len(results[generation]), step_size): 49 | generation_params = results[generation][idx] 50 | scene_mesh = scene_meshes[generation_params['scene']] 51 | if 'gender' in generation_params: 52 | body_model = body_model_dict[generation_params['gender']] 53 | else: 54 | body_model = body_model_dict['neutral'] 55 | for key in smplx_param_names: 56 | if key in generation_params: 57 | generation_params[key] = torch.tensor(generation_params[key], dtype=torch.float32).cpu() 58 | generation_params['left_hand_pose'] = generation_params['left_hand_pose'][:, :num_pca_comps] 59 | generation_params['right_hand_pose'] = generation_params['right_hand_pose'][:, :num_pca_comps] 60 | # print(generation_params) 61 | vertices = body_model(**generation_params).vertices.detach().cpu().numpy() 62 | body = trimesh.Trimesh(vertices[0], body_model.faces, vertex_colors=default_colors, 63 | process=False) 64 | 65 | img_collage = render_generation_multview(body, scene_mesh, clothed_body=None, 66 | body_center=True, 67 | num_view=num_view, 68 | collage_mode='grid' if num_view == 4 else 'vertical') 69 | export_path = Path.joinpath(generation_dir, generation + '_' +generation_params['scene'] + '_' + str(idx // step_size) + '.png') 70 | print(export_path) 71 | img_collage.save(export_path) 72 | 73 | if __name__ == '__main__': 74 | """render using multiple views or generation results from different sources""" 75 | for method in synthesis_results_dict: 76 | print('render for ', method) 77 | render_results(synthesis_results_dict[method], Path.joinpath(render_folder, method + '_2view'), 78 | max_render_num=16, num_view=2) 79 | -------------------------------------------------------------------------------- /configuration/joints.py: -------------------------------------------------------------------------------- 1 | # SMPL-X joints names: https://github.com/vchoutas/smplx/blob/f4206853a4746139f61bdcf58571f2cea0cbebad/smplx/joint_names.py#L17 2 | import numpy as np 3 | JOINT_NAMES = [ 4 | 'pelvis', 5 | 'left_hip', 6 | 'right_hip', 7 | 'spine1', 8 | 'left_knee', 9 | 'right_knee', 10 | 'spine2', 11 | 'left_ankle', 12 | 'right_ankle', 13 | 'spine3', 14 | 'left_foot', 15 | 'right_foot', 16 | 'neck', 17 | 'left_collar', 18 | 'right_collar', 19 | 'head', 20 | 'left_shoulder', 21 | 'right_shoulder', 22 | 'left_elbow', 23 | 'right_elbow', 24 | 'left_wrist', 25 | 'right_wrist', 26 | 'jaw', 27 | 'left_eye_smplhf', 28 | 'right_eye_smplhf', 29 | 'left_index1', 30 | 'left_index2', 31 | 'left_index3', 32 | 'left_middle1', 33 | 'left_middle2', 34 | 'left_middle3', 35 | 'left_pinky1', 36 | 'left_pinky2', 37 | 'left_pinky3', 38 | 'left_ring1', 39 | 'left_ring2', 40 | 'left_ring3', 41 | 'left_thumb1', 42 | 'left_thumb2', 43 | 'left_thumb3', 44 | 'right_index1', 45 | 'right_index2', 46 | 'right_index3', 47 | 'right_middle1', 48 | 'right_middle2', 49 | 'right_middle3', 50 | 'right_pinky1', 51 | 'right_pinky2', 52 | 'right_pinky3', 53 | 'right_ring1', 54 | 'right_ring2', 55 | 'right_ring3', 56 | 'right_thumb1', 57 | 'right_thumb2', 58 | 'right_thumb3', 59 | 'nose', 60 | 'right_eye', 61 | 'left_eye', 62 | 'right_ear', 63 | 'left_ear', 64 
| 'left_big_toe', 65 | 'left_small_toe', 66 | 'left_heel', 67 | 'right_big_toe', 68 | 'right_small_toe', 69 | 'right_heel', 70 | 'left_thumb', 71 | 'left_index', 72 | 'left_middle', 73 | 'left_ring', 74 | 'left_pinky', 75 | 'right_thumb', 76 | 'right_index', 77 | 'right_middle', 78 | 'right_ring', 79 | 'right_pinky', 80 | 'right_eye_brow1', 81 | 'right_eye_brow2', 82 | 'right_eye_brow3', 83 | 'right_eye_brow4', 84 | 'right_eye_brow5', 85 | 'left_eye_brow5', 86 | 'left_eye_brow4', 87 | 'left_eye_brow3', 88 | 'left_eye_brow2', 89 | 'left_eye_brow1', 90 | 'nose1', 91 | 'nose2', 92 | 'nose3', 93 | 'nose4', 94 | 'right_nose_2', 95 | 'right_nose_1', 96 | 'nose_middle', 97 | 'left_nose_1', 98 | 'left_nose_2', 99 | 'right_eye1', 100 | 'right_eye2', 101 | 'right_eye3', 102 | 'right_eye4', 103 | 'right_eye5', 104 | 'right_eye6', 105 | 'left_eye4', 106 | 'left_eye3', 107 | 'left_eye2', 108 | 'left_eye1', 109 | 'left_eye6', 110 | 'left_eye5', 111 | 'right_mouth_1', 112 | 'right_mouth_2', 113 | 'right_mouth_3', 114 | 'mouth_top', 115 | 'left_mouth_3', 116 | 'left_mouth_2', 117 | 'left_mouth_1', 118 | 'left_mouth_5', # 59 in OpenPose output 119 | 'left_mouth_4', # 58 in OpenPose output 120 | 'mouth_bottom', 121 | 'right_mouth_4', 122 | 'right_mouth_5', 123 | 'right_lip_1', 124 | 'right_lip_2', 125 | 'lip_top', 126 | 'left_lip_2', 127 | 'left_lip_1', 128 | 'left_lip_3', 129 | 'lip_bottom', 130 | 'right_lip_3', 131 | # Face contour 132 | 'right_contour_1', 133 | 'right_contour_2', 134 | 'right_contour_3', 135 | 'right_contour_4', 136 | 'right_contour_5', 137 | 'right_contour_6', 138 | 'right_contour_7', 139 | 'right_contour_8', 140 | 'contour_middle', 141 | 'left_contour_8', 142 | 'left_contour_7', 143 | 'left_contour_6', 144 | 'left_contour_5', 145 | 'left_contour_4', 146 | 'left_contour_3', 147 | 'left_contour_2', 148 | 'left_contour_1', 149 | ] 150 | # print(len(JOINT_NAMES)) 151 | joint_name_to_idx = {} 152 | for idx, joint_name in enumerate(JOINT_NAMES): 153 | joint_name_to_idx[joint_name] = idx 154 | 155 | bones = [ 156 | ['pelvis', 'left_hip'], 157 | ['pelvis', 'right_hip'], 158 | ['pelvis', 'neck'], 159 | ['left_knee', 'left_hip'], 160 | ['right_knee', 'right_hip'], 161 | ['left_knee', 'left_ankle'], 162 | ['right_knee', 'right_ankle'], 163 | ['left_foot', 'left_ankle'], 164 | ['right_foot', 'right_ankle'], 165 | ['neck', 'head'], 166 | ['neck', 'left_shoulder'], 167 | ['neck', 'right_shoulder'], 168 | ['left_elbow', 'left_shoulder'], 169 | ['right_elbow', 'right_shoulder'], 170 | ['left_elbow', 'left_wrist'], 171 | ['right_elbow', 'right_wrist'] 172 | ] 173 | 174 | parent_joint_idx = np.array([-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 175 | 16, 17, 18, 19, 15, 15, 15, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 176 | 35, 20, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50, 21, 52, 177 | 53]) # smplx kinematic chain as index of parent joint -------------------------------------------------------------------------------- /models/graph_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains definitions of layers used to build the GraphCNN 3 | """ 4 | from __future__ import division 5 | 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | def batch_sparse_dense_matmul(sparse, dense): 14 | B, N, C = dense.shape 15 | dense = dense.permute(1, 0, 2).reshape(N, B*C) 16 | out = torch.matmul(sparse, dense).reshape(-1, B, C).permute(1, 0, 2) 17 | return out 
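# Usage sketch (added for illustration, not part of the original file): batch_sparse_dense_matmul
# applies one sparse (N x N) adjacency matrix to every sample of a dense (B x N x C) batch by
# folding the batch into the column dimension, because the sparse @ dense product used here is 2-D.
# Assumed toy shapes:
#   adj = torch.eye(6).to_sparse()                # N x N sparse adjacency
#   feats = torch.randn(2, 6, 16)                 # B x N x C node features
#   out = batch_sparse_dense_matmul(adj, feats)   # 2 x 6 x 16, with out[b] == adj @ feats[b]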
18 | 19 | class GraphConvolution(nn.Module): 20 | """Simple interaction layer, similar to https://arxiv.org/abs/1609.02907.""" 21 | def __init__(self, in_features, out_features, adjmat, bias=True): 22 | super(GraphConvolution, self).__init__() 23 | self.in_features = in_features 24 | self.out_features = out_features 25 | self.adjmat = adjmat 26 | self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) 27 | if bias: 28 | self.bias = nn.Parameter(torch.FloatTensor(out_features)) 29 | else: 30 | self.register_parameter('bias', None) 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | # stdv = 1. / math.sqrt(self.weight.size(1)) 35 | stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1)) 36 | self.weight.data.uniform_(-stdv, stdv) 37 | if self.bias is not None: 38 | self.bias.data.uniform_(-stdv, stdv) 39 | 40 | def forward(self, x): 41 | if x.ndimension() == 2: 42 | support = torch.matmul(x, self.weight) 43 | output = torch.matmul(self.adjmat, support) 44 | if self.bias is not None: 45 | output = output + self.bias 46 | return output 47 | else: 48 | output = batch_sparse_dense_matmul(self.adjmat, torch.matmul(x, self.weight)) 49 | if self.bias is not None: 50 | output = output + self.bias 51 | return output 52 | 53 | def __repr__(self): 54 | return self.__class__.__name__ + ' (' \ 55 | + str(self.in_features) + ' -> ' \ 56 | + str(self.out_features) + ')' 57 | 58 | class GraphLinear(nn.Module): 59 | """ 60 | Generalization of 1x1 convolutions on Graphs 61 | """ 62 | def __init__(self, in_channels, out_channels): 63 | super(GraphLinear, self).__init__() 64 | self.in_channels = in_channels 65 | self.out_channels = out_channels 66 | self.W = nn.Parameter(torch.FloatTensor(out_channels, in_channels)) 67 | self.b = nn.Parameter(torch.FloatTensor(out_channels)) 68 | self.reset_parameters() 69 | 70 | def reset_parameters(self): 71 | w_stdv = 1 / (self.in_channels * self.out_channels) 72 | self.W.data.uniform_(-w_stdv, w_stdv) 73 | self.b.data.uniform_(-w_stdv, w_stdv) 74 | 75 | def forward(self, x): 76 | return torch.matmul(self.W[None, :], x) + self.b[None, :, None] 77 | 78 | class GraphResBlock(nn.Module): 79 | """ 80 | Graph Residual Block similar to the Bottleneck Residual Block in ResNet 81 | """ 82 | 83 | def __init__(self, in_channels, out_channels, A): 84 | super(GraphResBlock, self).__init__() 85 | self.in_channels = in_channels 86 | self.out_channels = out_channels 87 | self.lin1 = GraphLinear(in_channels, out_channels // 2) 88 | self.conv = GraphConvolution(out_channels // 2, out_channels // 2, A) 89 | self.lin2 = GraphLinear(out_channels // 2, out_channels) 90 | self.skip_conv = GraphLinear(in_channels, out_channels) 91 | self.pre_norm = nn.GroupNorm(in_channels // 8, in_channels) 92 | self.norm1 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2)) 93 | self.norm2 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2)) 94 | 95 | def forward(self, x): 96 | y = F.relu(self.pre_norm(x)) 97 | y = self.lin1(y) 98 | 99 | y = F.relu(self.norm1(y)) 100 | y = self.conv(y.transpose(1,2)).transpose(1,2) 101 | 102 | y = F.relu(self.norm2(y)) 103 | y = self.lin2(y) 104 | if self.in_channels != self.out_channels: 105 | x = self.skip_conv(x) 106 | return x+y 107 | 108 | class SparseMM(torch.autograd.Function): 109 | """Redefine sparse @ dense matrix multiplication to enable backpropagation. 110 | The builtin matrix multiplication operation does not support backpropagation in some cases. 
111 | """ 112 | @staticmethod 113 | def forward(ctx, sparse, dense): 114 | ctx.req_grad = dense.requires_grad 115 | ctx.save_for_backward(sparse) 116 | return torch.matmul(sparse, dense) 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | grad_input = None 121 | sparse, = ctx.saved_tensors 122 | if ctx.req_grad: 123 | grad_input = torch.matmul(sparse.t(), grad_output) 124 | return None, grad_input 125 | 126 | def spmm(sparse, dense): 127 | return SparseMM.apply(sparse, dense) 128 | -------------------------------------------------------------------------------- /SceneGraphNet/utils/utl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from torch._six import string_classes, int_classes 5 | import re, os 6 | import collections 7 | import numpy as np 8 | 9 | 10 | def weight_init(m): 11 | ''' 12 | Usage: 13 | model = Model() 14 | model.apply(weight_init) 15 | ''' 16 | if isinstance(m, nn.Conv1d): 17 | init.normal_(m.weight.data) 18 | if m.bias is not None: 19 | init.normal_(m.bias.data) 20 | elif isinstance(m, nn.Conv2d): 21 | init.xavier_normal_(m.weight.data) 22 | if m.bias is not None: 23 | init.normal_(m.bias.data) 24 | elif isinstance(m, nn.Conv3d): 25 | init.xavier_normal_(m.weight.data) 26 | if m.bias is not None: 27 | init.normal_(m.bias.data) 28 | elif isinstance(m, nn.ConvTranspose1d): 29 | init.normal_(m.weight.data) 30 | if m.bias is not None: 31 | init.normal_(m.bias.data) 32 | elif isinstance(m, nn.ConvTranspose2d): 33 | init.xavier_normal_(m.weight.data) 34 | if m.bias is not None: 35 | init.normal_(m.bias.data) 36 | elif isinstance(m, nn.ConvTranspose3d): 37 | init.xavier_normal_(m.weight.data) 38 | if m.bias is not None: 39 | init.normal_(m.bias.data) 40 | elif isinstance(m, nn.BatchNorm1d): 41 | init.normal_(m.weight.data, mean=1, std=0.02) 42 | init.constant_(m.bias.data, 0) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | init.normal_(m.weight.data, mean=1, std=0.02) 45 | init.constant_(m.bias.data, 0) 46 | elif isinstance(m, nn.BatchNorm3d): 47 | init.normal_(m.weight.data, mean=1, std=0.02) 48 | init.constant_(m.bias.data, 0) 49 | elif isinstance(m, nn.Linear): 50 | init.xavier_normal_(m.weight.data) 51 | init.normal_(m.bias.data) 52 | elif isinstance(m, nn.LSTM): 53 | for param in m.parameters(): 54 | if len(param.shape) >= 2: 55 | init.orthogonal_(param.data) 56 | else: 57 | init.normal_(param.data) 58 | elif isinstance(m, nn.LSTMCell): 59 | for param in m.parameters(): 60 | if len(param.shape) >= 2: 61 | init.orthogonal_(param.data) 62 | else: 63 | init.normal_(param.data) 64 | elif isinstance(m, nn.GRU): 65 | for param in m.parameters(): 66 | if len(param.shape) >= 2: 67 | init.orthogonal_(param.data) 68 | else: 69 | init.normal_(param.data) 70 | elif isinstance(m, nn.GRUCell): 71 | for param in m.parameters(): 72 | if len(param.shape) >= 2: 73 | init.orthogonal_(param.data) 74 | else: 75 | init.normal_(param.data) 76 | 77 | def get_n_params(model): 78 | pp=0 79 | for p in list(model.parameters()): 80 | nn=1 81 | for s in list(p.size()): 82 | nn = nn*s 83 | pp += nn 84 | return pp 85 | 86 | _use_shared_memory = False 87 | numpy_type_map = { 88 | 'float64': torch.DoubleTensor, 89 | 'float32': torch.FloatTensor, 90 | 'float16': torch.HalfTensor, 91 | 'int64': torch.LongTensor, 92 | 'int32': torch.IntTensor, 93 | 'int16': torch.ShortTensor, 94 | 'int8': torch.CharTensor, 95 | 'uint8': torch.ByteTensor, 96 | } 97 | def default_collate(batch): 
98 | r"""Puts each data field into a tensor with outer dimension batch size""" 99 | 100 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 101 | elem_type = type(batch[0]) 102 | if isinstance(batch[0], list): 103 | return batch 104 | elif isinstance(batch[0], torch.Tensor): 105 | out = None 106 | if _use_shared_memory: 107 | # If we're in a background process, concatenate directly into a 108 | # shared memory tensor to avoid an extra copy 109 | numel = sum([x.numel() for x in batch]) 110 | storage = batch[0].storage()._new_shared(numel) 111 | out = batch[0].new(storage) 112 | return torch.stack(batch, 0, out=out) 113 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 114 | and elem_type.__name__ != 'string_': 115 | elem = batch[0] 116 | if elem_type.__name__ == 'ndarray': 117 | # array of string classes and object 118 | if re.search('[SaUO]', elem.dtype.str) is not None: 119 | raise TypeError(error_msg.format(elem.dtype)) 120 | 121 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 122 | if elem.shape == (): # scalars 123 | py_type = float if elem.dtype.name.startswith('float') else int 124 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 125 | elif isinstance(batch[0], int_classes): 126 | return torch.LongTensor(batch) 127 | elif isinstance(batch[0], float): 128 | return torch.DoubleTensor(batch) 129 | elif isinstance(batch[0], string_classes): 130 | return batch 131 | elif isinstance(batch[0], collections.Mapping): 132 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 133 | elif isinstance(batch[0], collections.Sequence): 134 | transposed = zip(*batch) 135 | return [default_collate(samples) for samples in transposed] 136 | 137 | raise TypeError((error_msg.format(type(batch[0])))) 138 | 139 | def try_mkdir(dir): 140 | try: 141 | os.makedirs(dir) 142 | except OSError: 143 | pass 144 | 145 | def get_offset_vec(c1, c2): 146 | c1_trans_x = (c1['self_info']['translation'][0] - c1['self_info']['dim'][0] * 0.5, 147 | c1['self_info']['translation'][0] + c1['self_info']['dim'][0] * 0.5) 148 | c2_trans_x = (c2['self_info']['translation'][0] - c2['self_info']['dim'][0] * 0.5, 149 | c2['self_info']['translation'][0] + c2['self_info']['dim'][0] * 0.5) 150 | offset_x = np.array([c2_trans_x[0] - c1_trans_x[0], c2_trans_x[0] - c1_trans_x[1], 151 | c2_trans_x[1] - c1_trans_x[0], c2_trans_x[1] - c1_trans_x[1]]) 152 | offset_x_idx = int(np.argmin(np.abs(offset_x))) 153 | offset_x_val = offset_x[offset_x_idx] 154 | 155 | c1_trans_y = (c1['self_info']['translation'][2] - c1['self_info']['dim'][2] * 0.5, 156 | c1['self_info']['translation'][2] + c1['self_info']['dim'][2] * 0.5) 157 | c2_trans_y = (c2['self_info']['translation'][2] - c2['self_info']['dim'][2] * 0.5, 158 | c2['self_info']['translation'][2] + c2['self_info']['dim'][2] * 0.5) 159 | offset_y = np.array([c2_trans_y[0] - c1_trans_y[0], c2_trans_y[0] - c1_trans_y[1], 160 | c2_trans_y[1] - c1_trans_y[0], c2_trans_y[1] - c1_trans_y[1]]) 161 | offset_y_idx = int(np.argmin(np.abs(offset_y))) 162 | offset_y_val = offset_y[offset_y_idx] 163 | offset = [offset_x_val, offset_y_val] 164 | 165 | # dis = np.sqrt((offset[0]**2 + offset[1]**2)) 166 | delta_x = c1['self_info']['translation'][0] - c2['self_info']['translation'][0] 167 | delta_y = c1['self_info']['translation'][2] - c2['self_info']['translation'][2] 168 | dis = np.sqrt(delta_x ** 2 + delta_y ** 2) 169 | 170 | return offset_x_val, offset_y_val, dis 171 | 172 | 
-------------------------------------------------------------------------------- /models/mesh.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import sys 4 | sys.path.append('..') 5 | 6 | import numpy as np 7 | import scipy.sparse 8 | import torch 9 | import smplx 10 | import trimesh 11 | import os.path as osp 12 | 13 | from configuration.config import * 14 | from models.posa_utils import get_graph_params 15 | 16 | 17 | class SparseMM(torch.autograd.Function): 18 | """Redefine sparse @ dense matrix multiplication to enable backpropagation. 19 | The builtin matrix multiplication operation does not support backpropagation in some cases. 20 | """ 21 | @staticmethod 22 | def forward(ctx, sparse, dense): 23 | ctx.req_grad = dense.requires_grad 24 | ctx.save_for_backward(sparse) 25 | return torch.matmul(sparse, dense) 26 | 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | grad_input = None 31 | sparse, = ctx.saved_tensors 32 | if ctx.req_grad: 33 | grad_input = torch.matmul(sparse.t(), grad_output) 34 | return None, grad_input 35 | 36 | 37 | def spmm(sparse, dense): 38 | return SparseMM.apply(sparse, dense) 39 | 40 | 41 | def downsample_vertices(D, last_valid_vertices): 42 | new_dim, last_dim = D.shape 43 | last_valid_vertices = set(last_valid_vertices) 44 | # print(last_valid_vertices) 45 | valid_vertices = set() 46 | for vertex in range(new_dim): 47 | last_vertices = torch.nonzero(D[vertex] > 0, as_tuple=False).flatten() 48 | # print(vertex, last_vertices) 49 | if set(last_vertices.tolist()).issubset(last_valid_vertices): 50 | valid_vertices.add(vertex) 51 | # print(valid_vertices) 52 | return list(valid_vertices) 53 | 54 | 55 | # level of smplx meshes 56 | class Mesh(object): 57 | """Mesh object that is used for handling certain graph operations.""" 58 | def __init__(self, filename=mesh_operation_file, 59 | num_downsampling=1, nsize=1, device=torch.device('cuda')): 60 | self.num_downsampling = num_downsampling 61 | self._A, self._U, self._D, self.meshes = [], [], [], [] 62 | self._A.append(get_graph_params(mesh_ds_folder, 0, use_cuda=True)) 63 | for level in range(5): 64 | A, U, D = get_graph_params(mesh_ds_folder, level + 1, use_cuda=True) 65 | self._A.append(A) 66 | self._U.append(U) 67 | self._D.append(D) 68 | self.num_vertices = [] 69 | for level in range(6): 70 | m = trimesh.load(osp.join(str(mesh_ds_folder), 'mesh_{}.obj'.format(level)), process=False) 71 | self.meshes.append(m) 72 | self.num_vertices.append(m.vertices.shape[0]) 73 | 74 | ref_vertices = torch.tensor(self.meshes[0].vertices, dtype=torch.float32) 75 | center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] 76 | ref_vertices -= center 77 | ref_vertices /= ref_vertices.abs().max().item() 78 | 79 | self._ref_vertices = ref_vertices.to(device) 80 | self.faces = self.meshes[self.num_downsampling].faces 81 | self.device = device 82 | self.body_part_vertices = self.downsample_body_part_vertices(body_part_vertices) 83 | self.body_part_vertices_full = self.downsample_body_part_vertices(body_part_vertices_full) 84 | 85 | 86 | def downsample_body_part_vertices(self, body_part_vertices): 87 | """ 88 | Get vertices list of each body part on downsampled meshes 89 | """ 90 | downsample_level = self.num_downsampling 91 | last_body_part_vertices = body_part_vertices 92 | for level in range(downsample_level): 93 | D = self._D[level].cpu().to_dense() 94 | new_body_part_vertices = {} 95 | for part in last_body_part_vertices: 
96 | new_body_part_vertices[part] = downsample_vertices(D, last_body_part_vertices[part]) 97 | last_body_part_vertices = new_body_part_vertices 98 | 99 | return last_body_part_vertices 100 | 101 | 102 | @property 103 | def adjmat(self): 104 | """Return the graph adjacency matrix at the specified subsampling level.""" 105 | return self._A[self.num_downsampling].float() 106 | 107 | 108 | @property 109 | def ref_vertices(self): 110 | """Return the template vertices at the specified subsampling level.""" 111 | ref_vertices = torch.tensor(self.meshes[self.num_downsampling].vertices, dtype=torch.float32, device=self.device) 112 | center = 0.5 * (ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] 113 | ref_vertices -= center 114 | ref_vertices /= ref_vertices.abs().max().item() 115 | return ref_vertices 116 | 117 | 118 | def ref_vertices_by_level(self, num_downsampling): 119 | """Return the template vertices at the specified subsampling level.""" 120 | ref_vertices = torch.tensor(self.meshes[num_downsampling].vertices, dtype=torch.float32, device=self.device) 121 | center = 0.5 * (ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] 122 | ref_vertices -= center 123 | ref_vertices /= ref_vertices.abs().max().item() 124 | return ref_vertices 125 | 126 | 127 | def downsample(self, x, n1=0, n2=None): 128 | """Downsample mesh.""" 129 | if n2 is None: 130 | n2 = self.num_downsampling 131 | if x.ndimension() < 3: 132 | for i in range(n1, n2): 133 | # print(self._D[i].shape, x.shape) 134 | x = spmm(self._D[i], x) 135 | elif x.ndimension() == 3: 136 | out = [] 137 | for i in range(x.shape[0]): 138 | y = x[i] 139 | for j in range(n1, n2): 140 | y = spmm(self._D[j], y) 141 | out.append(y) 142 | x = torch.stack(out, dim=0) 143 | return x 144 | 145 | 146 | def upsample(self, x, n1=None, n2=0): 147 | """Upsample mesh.""" 148 | if n1 is None: 149 | n1 = self.num_downsampling 150 | if x.ndimension() < 3: 151 | for i in reversed(range(n2, n1)): 152 | x = spmm(self._U[i], x) 153 | elif x.ndimension() == 3: 154 | out = [] 155 | for i in range(x.shape[0]): 156 | y = x[i] 157 | for j in reversed(range(n2, n1)): 158 | y = spmm(self._U[j], y) 159 | out.append(y) 160 | x = torch.stack(out, dim=0) 161 | return x 162 | 163 | 164 | if __name__ == '__main__': 165 | import pylab 166 | color_map = pylab.get_cmap('hsv') 167 | for level in range(1, 6): 168 | mesh = Mesh(num_downsampling=level) 169 | print('faces:', mesh.faces.shape) 170 | for part in mesh.body_part_vertices: 171 | print(part, len(mesh.body_part_vertices[part])) 172 | print(mesh._A[-1]) 173 | vertices = mesh._ref_vertices 174 | down_sampled = mesh.downsample(vertices) 175 | up_sampled = mesh.upsample(down_sampled) 176 | print(down_sampled.shape, up_sampled.shape) 177 | import trimesh 178 | colors = np.ones((mesh.num_vertices[mesh.num_downsampling], 4), dtype=np.float32) * 0.8 179 | # for idx, part in enumerate(mesh.body_part_vertices_full): 180 | # colors[mesh.body_part_vertices_full[part], :] = color_map(idx / len(mesh.body_part_vertices_full)) 181 | # for idx, part in enumerate(mesh.body_part_vertices): 182 | # colors[mesh.body_part_vertices[part], :] = color_map(idx / len(mesh.body_part_vertices)) 183 | downsampled = trimesh.Trimesh( 184 | vertices=down_sampled.cpu().numpy(), 185 | faces=mesh.faces, 186 | vertex_colors=colors 187 | ) 188 | downsampled.show() 189 | colors = np.ones((mesh.num_vertices[0], 4), dtype=np.float32) * 0.8 190 | # for idx, part in enumerate(body_part_vertices_full): 191 | # colors[body_part_vertices_full[part], 
:] = color_map(idx / len(body_part_vertices_full)) 192 | # for idx, part in enumerate(body_part_vertices): 193 | # colors[body_part_vertices[part], :] = color_map(idx / len(body_part_vertices)) 194 | reconstructed = trimesh.Trimesh( 195 | vertices=up_sampled.cpu().numpy(), 196 | faces=mesh.meshes[0].faces, 197 | vertex_colors=colors 198 | ) 199 | reconstructed.show() 200 | 201 | -------------------------------------------------------------------------------- /evaluation/load_results.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import glob 5 | import pickle 6 | from collections import defaultdict 7 | 8 | from configuration.config import * 9 | from data.hand_pca_transform import pose_to_pca 10 | from data.load_generation import get_generation_segments 11 | 12 | """Load results generated by pigraph""" 13 | def load_pigraph(result_dir): 14 | results = defaultdict(list) 15 | 16 | for generation_dir in result_dir.iterdir(): 17 | if generation_dir.is_dir(): 18 | generation = generation_dir.name 19 | for scene_dir in generation_dir.iterdir(): 20 | scene_name = scene_dir.name 21 | if scene_dir.is_dir(): 22 | for result_file in scene_dir.iterdir(): 23 | if result_file.name[-4:] == '.pkl': 24 | combination_name = result_file.name[:-4] 25 | with result_file.open('rb') as pkl_file: 26 | smplx_params = pickle.load(pkl_file) 27 | T = len(smplx_params['transl']) 28 | for idx in range(T): 29 | generation_param = {'scene': scene_name, 'generation': generation, 30 | 'gender': 'neutral', 'object_combination': combination_name} 31 | for key in smplx_params: 32 | if key in smplx_param_names: 33 | generation_param[key] = smplx_params[key][[idx]].cpu() 34 | generation_param['left_hand_pose'], generation_param['right_hand_pose'] = \ 35 | pose_to_pca(generation_param['left_hand_pose'], generation_param['right_hand_pose'], gender=generation_param['gender']) 36 | 37 | results[scene_name + '_' + combination_name].append(generation_param) 38 | # print(results.keys()) 39 | return results 40 | 41 | """Load results generated using POSA""" 42 | def load_posa(result_dir): 43 | results = defaultdict(list) 44 | 45 | for scene_dir in result_dir.iterdir(): 46 | if scene_dir.is_dir(): 47 | scene_name = scene_dir.name 48 | for result_file in scene_dir.iterdir(): 49 | if result_file.name[-4:] == '.pkl': 50 | generation = result_file.name.split('.')[0] 51 | with result_file.open('rb') as pkl_file: 52 | smplx_params = pickle.load(pkl_file) 53 | T = len(smplx_params) 54 | for idx in range(T): 55 | generation_param = {'scene': scene_name} 56 | for key in smplx_params[idx]: 57 | if key in smplx_param_names: 58 | generation_param[key] = smplx_params[idx][key] 59 | results[generation].append(generation_param) 60 | print(results.keys()) 61 | return results 62 | 63 | """ Load pseudo ground truth PROX generation data from test set""" 64 | def load_prox(): 65 | with open(Path.joinpath(project_folder, "data", 'test.pkl'), 'rb') as data_file: 66 | test_data = pickle.load(data_file) 67 | results = defaultdict(list) 68 | for generation in generation_names: 69 | generation_data = get_generation_segments(generation.split('+'), test_data, mode='verb-noun') 70 | for record in generation_data: 71 | # scene_name, sequence, frame_idx, smplx_param, generation_labels, generation_obj_idx = record 72 | scene_name = record['scene_name'] 73 | atomics = generation.split('+') 74 | # verbs = [atomic.split('-')[0] for atomic in atomics] 75 | # nouns = [atomic.split('-')[1] 
for atomic in atomics] 76 | obj_ids = [record['generation_obj_idx'][record['generation_labels'].index(atomic)] for atomic in atomics] 77 | combination_name = '+'.join([atomics[atomic_idx] + '-' + str(obj_ids[atomic_idx]) for atomic_idx in range(len(atomics))]) 78 | wrong_combination = ['MPH1Library_sit down-chair-5', 'MPH1Library_step up-chair-6', 'MPH1Library_stand up-chair-6', 'MPH1Library_stand up-chair-5', 'MPH1Library_step down-chair-8'] 79 | if (scene_name + '_' + combination_name) in wrong_combination: # filter wrong records 80 | continue 81 | generation_param = {'scene': scene_name, 'generation': generation, 82 | 'object_combination': combination_name} 83 | generation_param.update(record['smplx_param']) 84 | if 'gender' not in generation_param: 85 | generation_param['gender'] = 'neutral' 86 | generation_param['left_hand_pose'] = generation_param['left_hand_pose'][:, :num_pca_comps] 87 | generation_param['right_hand_pose'] = generation_param['right_hand_pose'][:, :num_pca_comps] 88 | results[scene_name + '_' + combination_name].append(generation_param) 89 | 90 | print(results.keys()) 91 | return results 92 | 93 | 94 | def load_coins(result_dir=None): # result_dir is currently unused; accepted so the call in synthesis_results_dict, which passes a results path, matches the other loaders 95 | with open(Path.joinpath(project_folder, "data", 'test.pkl'), 'rb') as data_file: 96 | test_data = pickle.load(data_file) 97 | results = defaultdict(list) 98 | for generation in generation_names: 99 | generation_data = get_generation_segments(generation.split('+'), test_data, mode='verb-noun') 100 | for record in generation_data: 101 | scene_name = record['scene_name'] 102 | atomics = generation.split('+') 103 | obj_ids = [record['generation_obj_idx'][record['generation_labels'].index(atomic)] for atomic in atomics] 104 | combination_name = '+'.join([atomics[atomic_idx] + '-' + str(obj_ids[atomic_idx]) for atomic_idx in range(len(atomics))]) 105 | wrong_combination = ['MPH1Library_sit down-chair-5', 'MPH1Library_step up-chair-6', 'MPH1Library_stand up-chair-6', 'MPH1Library_stand up-chair-5', 'MPH1Library_step down-chair-8'] 106 | if (scene_name + '_' + combination_name) in wrong_combination: # filter wrong records 107 | continue 108 | generation_param = {'scene': scene_name, 'generation': generation, 109 | 'object_combination': combination_name} 110 | generation_param.update(record['smplx_param']) 111 | if 'gender' not in generation_param: 112 | generation_param['gender'] = 'neutral' 113 | generation_param['left_hand_pose'] = generation_param['left_hand_pose'][:, :num_pca_comps] 114 | generation_param['right_hand_pose'] = generation_param['right_hand_pose'][:, :num_pca_comps] 115 | results[scene_name + '_' + combination_name].append(generation_param) 116 | 117 | print(results.keys()) 118 | return results 119 | 120 | """Load results generated by our method.""" 121 | def load_results(result_dir): 122 | results = defaultdict(list) 123 | 124 | for generation_dir in result_dir.iterdir(): 125 | if generation_dir.is_dir(): 126 | generation = generation_dir.name 127 | for scene_dir in generation_dir.iterdir(): 128 | scene_name = scene_dir.name 129 | if scene_dir.is_dir(): 130 | for result_file in scene_dir.iterdir(): 131 | if result_file.name[-4:] == '.pkl': 132 | combination_name = result_file.name[:-4] 133 | with result_file.open('rb') as pkl_file: 134 | smplx_params = pickle.load(pkl_file) 135 | T = len(smplx_params) 136 | for idx in range(T): 137 | generation_param = {'scene': scene_name, 'generation': generation, 138 | 'gender': 'neutral', 'object_combination': combination_name} 139 | for key in smplx_params[idx]: 140 | if key in smplx_param_names:
141 | generation_param[key] = smplx_params[idx][key] 142 | results[scene_name + '_' + combination_name].append(generation_param) 143 | # print(results.keys()) 144 | return results 145 | 146 | # dict of generation results from different sources. Used in render_results.py and eval_results.py 147 | synthesis_results_dict = { 148 | 'prox': load_prox(), 149 | 'pigraph_no_penetration': load_pigraph(Path('./scene_graph/results') / 'pigraph_normal'), 150 | 'POSA_best1': load_posa(Path('./scene_graph/results') / 'POSA_IPoser_best1'), 151 | 'COINS': load_coins(Path('/home/kaizhao/projects/scene_graph/results') / 'two_stage' / 'floor_eval_try1_pene20_noseed_lr0.01' / 'optimization_after_get_body'), 152 | 'Narrator': load_results(Path('./scene_graph/results') / 'best_results') 153 | } -------------------------------------------------------------------------------- /configuration/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import pandas as pd 4 | import numpy as np 5 | from PIL import ImageColor 6 | import platform 7 | 8 | local_machine = (platform.node() == 'dalcowks') 9 | 10 | proxe_base_folder = Path("./proxe") 11 | scene_folder = Path.joinpath(proxe_base_folder, "scenes_semantics") 12 | sdf_folder = Path.joinpath(proxe_base_folder, "sdf") 13 | cam2world_folder = Path.joinpath(proxe_base_folder, "cam2world") 14 | # human_folder = Path.joinpath(proxe_base_folder, "PROX_temporal/PROXD_temp_v2") 15 | # graph_folder = Path.joinpath(proxe_base_folder, "scene_graph") 16 | scene_cache_folder = Path.joinpath(proxe_base_folder, 'scene_segmentation') 17 | # downloaded files from POSA: https://posa.is.tue.mpg.de/index.html 18 | posa_folder = Path.joinpath(proxe_base_folder, 'POSA_dir') 19 | mesh_ds_folder = Path.joinpath(proxe_base_folder, 'POSA_dir', 'mesh_ds') 20 | # smplx models 21 | smplx_model_folder = Path.joinpath(proxe_base_folder, "models_smplx_v1_1/models") 22 | # project directory 23 | project_folder = Path(__file__).resolve().parents[1] 24 | # mesh upsample and downsample weights 25 | mesh_operation_file = Path.joinpath(project_folder, "data", 'mesh_operation.npz') 26 | # transformation matrix between PROX and POSA scenes 27 | scene_registration_file = Path.joinpath(project_folder, "data", 'scene_registration.pkl') 28 | # checkpoints 29 | checkpoint_folder = Path.joinpath(project_folder, 'checkpoints') 30 | checkpoint_folder.mkdir(parents=True, exist_ok=True) 31 | # rendering and results 32 | results_folder = Path.joinpath(Path(__file__).resolve().parents[1], "results") 33 | results_folder.mkdir(parents=True, exist_ok=True) 34 | render_folder = Path.joinpath(Path(__file__).resolve().parents[1], "render") 35 | render_folder.mkdir(parents=True, exist_ok=True) 36 | 37 | # scene names 38 | scene_names = ["BasementSittingBooth", "MPH11", "MPH112", "MPH16", "MPH1Library", "MPH8", 39 | "N0SittingBooth", "N0Sofa", "N3Library", "N3Office", "N3OpenArea", "Werkraum"] 40 | test_scenes = ['MPH1Library', 'MPH16', 'N0SittingBooth', 'N3OpenArea'] 41 | train_scenes = sorted(list(set(scene_names) - set(test_scenes))) 42 | 43 | # manually selected object instances for interactions. For each scene and each object category combination, we list the selected instance combinations.
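# For example, with the values defined below, candidate_combination_dict['MPH16']['chair+table'] == [[3, 4]],
# i.e. a 'sit on-chair+touch-table' interaction in scene MPH16 is presumably only synthesized and evaluated
# with chair instance 3 together with table instance 4, while a single-object key such as
# candidate_combination_dict['MPH1Library']['chair'] lists several alternative chair instances.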
44 | candidate_combination_dict = { 45 | 'MPH1Library':{ 46 | 'wall': [[0]], 47 | 'floor': [[1]], 48 | 'chair': [[2], [3], [4], [7], [9]], 49 | 'table': [[11]], 50 | 'shelving': [[12]], 51 | 'floor+wall':[[1, 0]], 52 | 'floor+shelving':[[1, 12]], 53 | 'floor+table':[[1, 11]], 54 | 'chair+table':[[2, 11], [3, 11], [4, 11], [7, 11], [9, 11]] 55 | }, 56 | 'MPH16':{ 57 | 'wall': [[0]], 58 | 'floor':[[2]], 59 | 'chair':[[3]], 60 | 'cabinet':[[5], [6]], 61 | 'table':[[4]], 62 | 'bed':[[9]], 63 | 'tv_monitor':[[10]], 64 | 'shelving':[[11], [12]], 65 | 'floor+wall':[[2, 0]], 66 | 'floor+table':[[2, 4]], 67 | 'floor+tv_monitor':[[2, 10]], 68 | 'floor+shelving':[[2, 11], [2, 12]], 69 | 'chair+table':[[3, 4]] 70 | }, 71 | 'N0SittingBooth':{ 72 | 'wall':[], 73 | 'floor':[[1]], 74 | 'table':[[2], [3]], 75 | 'floor+table':[[1, 2], [1,3]], 76 | 'floor+wall':[], 77 | }, 78 | 'N3OpenArea':{ 79 | 'wall':[[0]], 80 | 'floor':[[1]], 81 | 'chair':[[2], [3], [4], [5]], 82 | 'table':[[6]], 83 | 'sofa':[[11]], 84 | 'floor+wall':[[1, 0]], 85 | 'floor+table':[[1, 6]], 86 | 'chair+table':[[3, 6], [4, 6], [5, 6], [2, 6]], 87 | # 'chair+table':[[2, 6]], 88 | 'sofa+table':[[11, 6]] 89 | } 90 | } 91 | 92 | # sequence names 93 | recordings_temporal = Path.joinpath(Path(__file__).resolve().parent, "recordings_temporal.txt") 94 | sequence_names = [sequence.split('\n')[0] for sequence in recordings_temporal.open().readlines()] 95 | 96 | # interaction names 97 | atomic_interaction_names = ['sit on-chair', 'sit on-sofa', 'sit on-bed', 'sit on-cabinet', 'sit on-table', 98 | # 'sit on-stool', 'stand on-furniture', 99 | 'stand on-floor', 'stand on-table', 'stand on-bed', 100 | 'stand on-chest_of_drawers', 101 | 'lie on-sofa', 'lie on-bed', 102 | 'touch-table', 'touch-board_panel', 'touch-tv_monitor', 'touch-shelving', 'touch-wall', 'touch-shelving' 103 | # 'touch-lighting', 'touch-objects', 104 | ] 105 | atomic_interaction_names_include_motion = ['jump on-sofa', 'step down-table', 'touch-shelving', 'sit down-sofa', 'step up-table', 'side walk-floor', 'turn-floor', 'sit down-chair', 'stand up-bed', 'step up-sofa', 'step down-sofa', 'step down-chair', 'touch-board_panel', 'sit on-seating', 'sit on-chair', 'walk on-floor', 'sit on-bed', 'stand on-table', 'stand up-sofa', 'turnover-floor', 'lie on-sofa', 'lie down-sofa', 'a pose-floor', 'touch-tv_monitor', 'stand up-chair', 'sit up-sofa', 'restfoot-chair', 'stand on-bed', 'step back-floor', 'touch-chair', 'step up-chair', 'move leg-sofa', 'move on-sofa', 'touch-chest_of_drawers', 'touch-sofa', 'stand up-cabinet', 'sit on-stool', 106 | 'lie on-bed', 'touch-table', 'lie on-seating', 'touch-wall', 'stand on-floor', 'sit on-sofa', 'move leg-bed', 'sit on-table', 'sit on-cabinet', 'restfoot-stool', 'sit down-cabinet', 'stand on-chest_of_drawers', 'sit down-bed'] 107 | atomic_interaction_names_include_motion_train = ['sit on-sofa', 'touch-shelving', 'touch-tv_monitor', 'sit down-sofa', 'jump on-sofa', 'touch-chair', 'step down-chair', 'walk on-floor', 'touch-chest_of_drawers', 'sit down-bed', 'sit on-table', 'move on-sofa', 'stand on-chest_of_drawers', 'turn-floor', 'lie on-sofa', 'stand up-bed', 'lie on-bed', 'step up-sofa', 'side walk-floor', 'sit down-cabinet', 'stand up-chair', 'stand up-cabinet', 'touch-sofa', 'sit on-cabinet', 'a pose-floor', 'move leg-sofa', 'sit on-bed', 'touch-wall', 'sit on-chair', 'step down-table', 'stand up-sofa', 'sit up-sofa', 'touch-table', 'step up-chair', 'stand on-table', 'step down-sofa', 'sit down-chair', 'stand on-floor', 'stand on-bed', 
'touch-board_panel', 'lie down-sofa', 'step up-table'] 108 | composed_interaction_names = ['sit on-chair+touch-table', 'sit on-sofa+touch-table', 109 | # 'stand on-floor+touch-lighting', 'stand on-floor+touch-objects', 110 | 'stand on-floor+touch-board_panel', 'stand on-floor+touch-table', 111 | 'stand on-floor+touch-tv_monitor', 'stand on-floor+touch-shelving', 'stand on-floor+touch-wall', 112 | ] 113 | test_composed_interaction_names = [ 114 | 'sit on-chair+touch-table', 115 | 'stand on-floor+touch-board_panel', 'stand on-floor+touch-table', 116 | ] 117 | interaction_names = atomic_interaction_names_include_motion_train + composed_interaction_names 118 | 119 | # load category name and visualization color 120 | #mpcat40index mpcat40 hex wnsynsetkey nyu40 skip labels 121 | mptsv_path = Path.joinpath(Path(__file__).resolve().parent, "mpcat40.tsv") 122 | category_dict = pd.read_csv(mptsv_path, sep='\t') 123 | category_dict['color'] = category_dict.apply(lambda row: np.array(ImageColor.getrgb(row['hex'])), axis=1) 124 | obj_category_num = 42 125 | 126 | # human body param 127 | num_pca_comps = 6 128 | smplx_param_names = ['betas', 'global_orient', 'transl', 'body_pose', 'left_hand_pose', 'right_hand_pose'] 129 | smplx_param_names += ['jaw_pose', 'leye_pose', 'reye_pose', 'expression'] 130 | used_smplx_param_names = ['transl', 'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'betas'] # these are used in diversity evaluation 131 | 132 | # body part segmentation 133 | body_parts = ['back', 'gluteus', 'L_Hand', 'R_Hand', 'L_Leg', 'R_Leg', 'thighs'] 134 | body_part_vertices = {} 135 | for body_part in body_parts: 136 | with open(Path.joinpath(proxe_base_folder, 'body_segments', body_part + '.json'), 'r') as file: 137 | body_part_vertices[body_part] = json.load(file)['verts_ind'] 138 | #https://github.com/Meshcapade/wiki/blob/main/assets/SMPL_body_segmentation/smplx/smplx_vert_segmentation.json 139 | with open((project_folder / 'configuration' / 'smplx_vert_segmentation.json'), 'r') as file: 140 | body_part_vertices_full = json.load(file) 141 | upper_body_parts = ["rightHand", "leftArm", 142 | "rightArm", "leftHandIndex1", "rightHandIndex1", "leftForeArm", 143 | "rightForeArm", "leftHand", 144 | ] 145 | lower_body_parts = ["rightUpLeg", "leftLeg", "leftToeBase", "leftFoot", "rightFoot", 146 | "rightLeg", "rightToeBase", "leftUpLeg", 147 | ] 148 | 149 | # map action to corresponding body parts 150 | action_names = ['sit on', 'lie on', 'stand on', 'touch', 'step back', 'restfoot', 'step down', 'turn', 'jump on', 'sit up', 'stand up', 'turnover', 'sit down', 'move on', 'lie down', 'move leg', 'walk on', 'a pose', 'step up', 'side walk'] 151 | action_names_train = ['sit on', 'lie on', 'stand on', 'touch', 'jump on', 'turn', 'move leg', 'stand up', 'sit down', 'sit up', 'side walk', 'step down', 'walk on', 'a pose', 'lie down', 'step up', 'move on'] 152 | num_verb = len(action_names) 153 | num_noun = 42 154 | maximum_atomics = 2 155 | action_body_part_mapping = { 156 | 'sit on': ['gluteus', 'thighs'], 157 | 'lie on': ['back', 'gluteus', 'thighs'], 158 | 'stand on': ['L_Leg', 'R_Leg'], 159 | 'touch': ['L_Hand', 'R_Hand'], 160 | } 161 | -------------------------------------------------------------------------------- /models/body_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from configuration.config import * 9 | 
from models.graph_layers import GraphResBlock, GraphLinear, spmm, batch_sparse_dense_matmul 10 | 11 | class NormalDistDecoder(nn.Module): 12 | def __init__(self, num_feat_in, latentD): 13 | super(NormalDistDecoder, self).__init__() 14 | 15 | self.mu = nn.Linear(num_feat_in, latentD) 16 | nn.init.kaiming_uniform_(self.mu.weight) 17 | nn.init.uniform_(self.mu.bias) 18 | self.logvar = nn.Linear(num_feat_in, latentD) 19 | nn.init.kaiming_uniform_(self.logvar.weight) 20 | nn.init.uniform_(self.logvar.bias) 21 | 22 | def forward(self, Xout): 23 | Xout = Xout.squeeze() 24 | return torch.distributions.normal.Normal(self.mu(Xout), F.softplus(self.logvar(Xout))) 25 | 26 | class ds_us_fn(nn.Module): 27 | def __init__(self, M): 28 | super(ds_us_fn, self).__init__() 29 | self.M = M 30 | 31 | def forward(self, x): 32 | if x.ndimension() < 3: 33 | x = x.transpose() 34 | x = spmm(self.M, x) 35 | elif x.ndimension() == 3: 36 | B, C, V = x.shape 37 | x = x.permute(2, 0, 1).reshape(V, B*C) 38 | x = torch.matmul(self.M, x).reshape(-1, B, C).permute(1, 2, 0) 39 | return x 40 | 41 | 42 | class FCBlock(nn.Module): 43 | """Wrapper around nn.Linear that includes batch normalization and activation functions.""" 44 | 45 | def __init__(self, in_size, out_size, batchnorm=True, activation=nn.ReLU(inplace=True), dropout=False): 46 | super(FCBlock, self).__init__() 47 | module_list = [nn.Conv1d(in_size, out_size, 1)] 48 | if batchnorm: 49 | module_list.append(nn.BatchNorm1d(out_size)) 50 | if activation is not None: 51 | module_list.append(activation) 52 | if dropout: 53 | module_list.append(dropout) 54 | self.fc_block = nn.Sequential(*module_list) 55 | 56 | def forward(self, x): 57 | return self.fc_block(x) 58 | 59 | 60 | class FCResBlock(nn.Module): 61 | """Residual block using fully-connected layers.""" 62 | 63 | def __init__(self, in_size, out_size, batchnorm=True, activation=nn.ReLU(inplace=True), dropout=False): 64 | super(FCResBlock, self).__init__() 65 | self.fc_block = nn.Sequential(nn.Conv1d(in_size, out_size, 1), 66 | nn.BatchNorm1d(out_size), 67 | nn.ReLU(inplace=True), 68 | nn.Conv1d(out_size, out_size, 1), 69 | nn.BatchNorm1d(out_size)) 70 | 71 | def forward(self, x): 72 | return F.relu(x + self.fc_block(x)) 73 | 74 | class BodyEncoder(nn.Module): 75 | def __init__(self, mesh, args): 76 | super(BodyEncoder, self).__init__() 77 | self.mesh = mesh 78 | self.ref_vertices = nn.Parameter(mesh.ref_vertices_by_level(args.final_downsample_level).permute(1, 0), 79 | requires_grad=False) 80 | self.ref_vertices_init = mesh.ref_vertices_by_level(args.init_downsample_level).permute(1, 0) 81 | self.num_vertices = self.ref_vertices.shape[-1] 82 | self.args = args 83 | 84 | # graph CVAE 85 | init_level = self.args.init_downsample_level 86 | final_level = self.args.final_downsample_level 87 | encoder_channels = self.args.encoder_channels 88 | encoder_layers = nn.ModuleList() 89 | encoder_layers.append(GraphLinear(3 + num_noun + 1 + num_verb * num_noun, encoder_channels)) 90 | for mesh_level in range(init_level, final_level): 91 | for _ in range(args.conv_per_level): 92 | encoder_layers.append( 93 | GraphResBlock(encoder_channels, encoder_channels, self.mesh._A[mesh_level]) 94 | ) 95 | encoder_layers.append( 96 | ds_us_fn(self.mesh._D[mesh_level]) 97 | ) 98 | for _ in range(args.conv_per_level): 99 | encoder_layers.append( 100 | GraphResBlock(encoder_channels, encoder_channels, self.mesh._A[final_level]) 101 | ) 102 | self.encoder = nn.Sequential(*encoder_layers) 103 | self.latent_linear = 
nn.Linear(self.mesh.num_vertices[final_level] * encoder_channels, args.latent_dimension) 104 | 105 | self.dist_decoder = NormalDistDecoder(self.mesh.num_vertices[final_level] * encoder_channels, args.latent_dimension) 106 | 107 | decoder_channels = self.args.decoder_channels 108 | decoder_layers = nn.ModuleList( 109 | [GraphLinear(3 + args.latent_dimension + num_verb * num_noun, decoder_channels)] # concatenate 110 | ) 111 | for mesh_level in reversed(range(init_level + 1, final_level + 1)): 112 | for _ in range(args.conv_per_level): 113 | decoder_layers.append( 114 | GraphResBlock(decoder_channels, decoder_channels, self.mesh._A[mesh_level]) 115 | ) 116 | decoder_layers.append( 117 | ds_us_fn(self.mesh._U[mesh_level - 1]) 118 | ) 119 | for _ in range(args.conv_per_level): 120 | decoder_layers.append( 121 | GraphResBlock(decoder_channels, decoder_channels, self.mesh._A[init_level]) 122 | ) 123 | decoder_layers += [ 124 | GraphResBlock(decoder_channels, 64, self.mesh._A[init_level]), 125 | GraphResBlock(64, 64, self.mesh._A[init_level]), 126 | nn.GroupNorm(64 // 8, 64), 127 | nn.ReLU(inplace=True), 128 | GraphLinear(64, 3 + num_noun + 1), 129 | ] 130 | self.decoder = nn.Sequential(*decoder_layers) 131 | if args.residual: 132 | self.residual_net = nn.Sequential(FCBlock(3 + num_verb * num_noun, 512), 133 | FCResBlock(512, 512), 134 | FCResBlock(512, 512), 135 | nn.Conv1d(512, num_noun + 1, 1)) 136 | 137 | def forward(self, body_vertices, contact_features, generation_code): 138 | """Forward pass 139 | """ 140 | batch_size, num_vertices = body_vertices.shape[:2] 141 | body_vertices = body_vertices.transpose(1, 2) 142 | feature = self.encoder( 143 | torch.cat([body_vertices, 144 | contact_features.transpose(1, 2), 145 | generation_code.unsqueeze(2).expand(-1, -1, num_vertices)], dim=1) 146 | ) 147 | if self.args.model == 'AE': 148 | z = self.latent_linear(feature.reshape(batch_size, -1)) 149 | z_dist = None 150 | else: 151 | z_dist = self.dist_decoder(feature.reshape(batch_size, -1)) 152 | z = z_dist.rsample() 153 | decoder_input = torch.cat((z.unsqueeze(2).expand(-1, -1, self.num_vertices), 154 | generation_code.unsqueeze(2).expand(-1, -1, self.num_vertices), 155 | self.ref_vertices.unsqueeze(0).expand(batch_size, -1, -1)), dim=1) 156 | 157 | pred = self.decoder(decoder_input).transpose(1, 2) 158 | x_rec = pred[:, :, :3] 159 | pred_f = pred[:, :, 3:] 160 | if self.args.residual: 161 | residual_f = self.residual_net( 162 | torch.cat([generation_code.unsqueeze(2).expand(-1, -1, self.mesh.num_vertices[self.args.init_downsample_level]), 163 | self.ref_vertices_init.unsqueeze(0).expand(batch_size, -1, -1) 164 | ], dim=1) 165 | ).transpose(1, 2) 166 | pred_f = pred_f + residual_f 167 | f = torch.cat((torch.sigmoid(pred_f[:, :, 0]).unsqueeze(-1), 168 | pred_f[:, :, 1:]), dim=-1) 169 | return x_rec, f, z_dist 170 | 171 | def sample(self, batch_size, generation_code): 172 | assert self.args.model == 'VAE' 173 | set_training = False 174 | if self.training: 175 | set_training = True 176 | self.eval() 177 | z = torch.distributions.normal.Normal( 178 | loc=torch.zeros((batch_size, self.args.latent_dimension), requires_grad=False, device=self.args.device), 179 | scale=torch.ones((batch_size, self.args.latent_dimension), requires_grad=False, 180 | device=self.args.device)).rsample() 181 | decoder_input = torch.cat((z.unsqueeze(2).expand(-1, -1, self.num_vertices), 182 | generation_code.unsqueeze(2).expand(-1, -1, self.num_vertices), 183 | self.ref_vertices.unsqueeze(0).expand(batch_size, -1, -1)), dim=1) 184 
| 185 | pred = self.decoder(decoder_input).transpose(1, 2) # (batch, channel, nodes) -> (batch, nodes, channel) 186 | x_rec = pred[:, :, :3] 187 | pred_f = pred[:, :, 3:] 188 | if self.args.residual: 189 | residual_f = self.residual_net( 190 | torch.cat([generation_code.unsqueeze(2).expand(-1, -1, self.mesh.num_vertices[self.args.init_downsample_level]), 191 | self.ref_vertices_init.unsqueeze(0).expand(batch_size, -1, -1) 192 | ], dim=1) # Bx(3+4*42)xV 193 | ).transpose(1, 2) 194 | pred_f = pred_f + residual_f 195 | f = torch.cat((torch.sigmoid(pred_f[:, :, 0]).unsqueeze(-1), 196 | pred_f[:, :, 1:]), dim=-1) 197 | if set_training: 198 | self.train() 199 | return x_rec, f 200 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import pytorch3d.loss 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy 6 | 7 | from pytorch3d.ops import cot_laplacian 8 | from pytorch3d.structures import Meshes 9 | 10 | def mesh_laplacian_smoothing(meshes, method: str = "uniform"): 11 | r""" 12 | Computes the laplacian smoothing objective for a batch of meshes. 13 | This function supports three variants of Laplacian smoothing, 14 | namely with uniform weights("uniform"), with cotangent weights ("cot"), 15 | and cotangent curvature ("cotcurv").For more details read [1, 2]. 16 | 17 | Args: 18 | meshes: Meshes object with a batch of meshes. 19 | method: str specifying the method for the laplacian. 20 | Returns: 21 | loss: Average laplacian smoothing loss across the batch. 22 | Returns 0 if meshes contains no meshes or all empty meshes. 23 | 24 | Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3. 25 | The Laplacian matrix L is a NxN tensor such that LV gives a tensor of vectors: 26 | for a uniform Laplacian, LuV[i] points to the centroid of its neighboring 27 | vertices, a cotangent Laplacian LcV[i] is known to be an approximation of 28 | the surface normal, while the curvature variant LckV[i] scales the normals 29 | by the discrete mean curvature. For vertex i, assume S[i] is the set of 30 | neighboring vertices to i, a_ij and b_ij are the "outside" angles in the 31 | two triangles connecting vertex v_i and its neighboring vertex v_j 32 | for j in S[i], as seen in the diagram below. 33 | 34 | .. code-block:: python 35 | 36 | a_ij 37 | /\ 38 | / \ 39 | / \ 40 | / \ 41 | v_i /________\ v_j 42 | \ / 43 | \ / 44 | \ / 45 | \ / 46 | \/ 47 | b_ij 48 | 49 | The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i) 50 | For the uniform variant, w_ij = 1 / |S[i]| 51 | For the cotangent variant, 52 | w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik) 53 | For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i]) 54 | where A[i] is the sum of the areas of all triangles containing vertex v_i. 55 | 56 | There is a nice trigonometry identity to compute cotangents. Consider a triangle 57 | with side lengths A, B, C and angles a, b, c. 58 | 59 | .. 
code-block:: python 60 | 61 | c 62 | /|\ 63 | / | \ 64 | / | \ 65 | B / H| \ A 66 | / | \ 67 | / | \ 68 | /a_____|_____b\ 69 | C 70 | 71 | Then cot a = (B^2 + C^2 - A^2) / 4 * area 72 | We know that area = CH/2, and by the law of cosines we have 73 | 74 | A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a 75 | 76 | Putting these together, we get: 77 | 78 | B^2 + C^2 - A^2 2BC cos a 79 | _______________ = _________ = (B/H) cos a = cos a / sin a = cot a 80 | 4 * area 2CH 81 | 82 | 83 | [1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion 84 | and curvature flow", SIGGRAPH 1999. 85 | 86 | [2] Nealan et al, "Laplacian Mesh Optimization", Graphite 2006. 87 | """ 88 | 89 | if meshes.isempty(): 90 | return torch.tensor( 91 | [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True 92 | ) 93 | 94 | N = len(meshes) 95 | verts_packed = meshes.verts_packed() # (sum(V_n), 3) 96 | faces_packed = meshes.faces_packed() # (sum(F_n), 3) 97 | num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,) 98 | verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),) 99 | weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),) 100 | weights = 1.0 / weights.float() 101 | 102 | with torch.no_grad(): 103 | if method == "uniform": 104 | L = meshes.laplacian_packed() 105 | elif method in ["cot", "cotcurv"]: 106 | L, inv_areas = cot_laplacian(verts_packed, faces_packed) 107 | if method == "cot": 108 | norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) 109 | idx = norm_w > 0 110 | norm_w[idx] = 1.0 / norm_w[idx] 111 | else: 112 | L_sum = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) 113 | norm_w = 0.25 * inv_areas 114 | else: 115 | raise ValueError("Method should be one of {uniform, cot, cotcurv}") 116 | 117 | if method == "uniform": 118 | laplacian = L.mm(verts_packed) 119 | elif method == "cot": 120 | laplacian = L.mm(verts_packed) * norm_w - verts_packed 121 | elif method == "cotcurv": 122 | laplacian = (L.mm(verts_packed) - L_sum * verts_packed) * norm_w 123 | curvature = laplacian.norm(p=2, dim=1) 124 | 125 | return curvature 126 | 127 | class LaplacianLoss(nn.Module): 128 | def __init__(self, faces): 129 | super(LaplacianLoss, self).__init__() 130 | self.faces = faces 131 | self.criterion = nn.L1Loss(reduction='mean') 132 | 133 | def forward(self, x, y): 134 | batch_size = x.shape[0] 135 | mesh_x = Meshes( 136 | verts=x, 137 | faces=self.faces.unsqueeze(0).expand(batch_size, -1, -1) 138 | ) 139 | 140 | mesh_y = Meshes( 141 | verts=y, 142 | faces=self.faces.unsqueeze(0).expand(batch_size, -1, -1) 143 | ) 144 | 145 | curvature_x = mesh_laplacian_smoothing(mesh_x, method='cotcurv') 146 | # curvature_y = mesh_laplacian_smoothing(mesh_y, method='cotcurv') 147 | # loss = self.criterion(curvature_x, curvature_y) 148 | loss = curvature_x.mean() 149 | return loss 150 | 151 | class NormalConsistencyLoss(nn.Module): 152 | # faces: BxFx3 153 | def __init__(self, faces): 154 | super(NormalConsistencyLoss, self).__init__() 155 | self.faces = faces 156 | 157 | # x,y: BxVx3 158 | def forward(self, x): 159 | batch_size = x.shape[0] 160 | mesh_x = Meshes( 161 | verts=x, 162 | faces=self.faces.unsqueeze(0).expand(batch_size, -1, -1) 163 | ) 164 | 165 | return pytorch3d.loss.mesh_normal_consistency(mesh_x) 166 | 167 | # https://github.com/hongsukchoi/Pose2Mesh_RELEASE/blob/master/lib/core/loss.py 168 | class NormalVectorLoss(nn.Module): 169 | # face: Fx3 170 | def __init__(self, face): 171 | super(NormalVectorLoss, self).__init__() 172 | self.face = face 173 | 
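    # The forward pass below forms the three normalized edge vectors of every predicted triangle,
    # computes the ground-truth face normal from the corresponding GT edges, and penalizes the
    # absolute cosine between each predicted edge and that GT normal, which pushes predicted faces
    # to stay parallel to the ground-truth surface.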
174 | def forward(self, coord_out, coord_gt): 175 | face = torch.LongTensor(self.face).cuda() 176 | 177 | v1_out = coord_out[:, face[:, 1], :] - coord_out[:, face[:, 0], :] 178 | v1_out = F.normalize(v1_out, p=2, dim=2) # L2 normalize to make unit vector 179 | v2_out = coord_out[:, face[:, 2], :] - coord_out[:, face[:, 0], :] 180 | v2_out = F.normalize(v2_out, p=2, dim=2) # L2 normalize to make unit vector 181 | v3_out = coord_out[:, face[:, 2], :] - coord_out[:, face[:, 1], :] 182 | v3_out = F.normalize(v3_out, p=2, dim=2) # L2 nroamlize to make unit vector 183 | 184 | v1_gt = coord_gt[:, face[:, 1], :] - coord_gt[:, face[:, 0], :] 185 | v1_gt = F.normalize(v1_gt, p=2, dim=2) # L2 normalize to make unit vector 186 | v2_gt = coord_gt[:, face[:, 2], :] - coord_gt[:, face[:, 0], :] 187 | v2_gt = F.normalize(v2_gt, p=2, dim=2) # L2 normalize to make unit vector 188 | normal_gt = torch.cross(v1_gt, v2_gt, dim=2) 189 | normal_gt = F.normalize(normal_gt, p=2, dim=2) # L2 normalize to make unit vector 190 | 191 | cos1 = torch.abs(torch.sum(v1_out * normal_gt, 2, keepdim=True)) 192 | cos2 = torch.abs(torch.sum(v2_out * normal_gt, 2, keepdim=True)) 193 | cos3 = torch.abs(torch.sum(v3_out * normal_gt, 2, keepdim=True)) 194 | loss = torch.cat((cos1, cos2, cos3), 1) 195 | return loss.mean() 196 | 197 | epsilon = 1e-16 198 | class EdgeLengthLoss(nn.Module): 199 | def __init__(self, face, relative_length=False): 200 | super(EdgeLengthLoss, self).__init__() 201 | self.face = face 202 | self.relative_length = relative_length 203 | 204 | def forward(self, coord_out, coord_gt): 205 | face = torch.LongTensor(self.face).cuda() 206 | 207 | d1_out = torch.sqrt(epsilon + 208 | torch.sum((coord_out[:, face[:, 0], :] - coord_out[:, face[:, 1], :]) ** 2, 2, keepdim=True)) 209 | d2_out = torch.sqrt(epsilon + 210 | torch.sum((coord_out[:, face[:, 0], :] - coord_out[:, face[:, 2], :]) ** 2, 2, keepdim=True)) 211 | d3_out = torch.sqrt(epsilon + 212 | torch.sum((coord_out[:, face[:, 1], :] - coord_out[:, face[:, 2], :]) ** 2, 2, keepdim=True)) 213 | 214 | d1_gt = torch.sqrt(epsilon + torch.sum((coord_gt[:, face[:, 0], :] - coord_gt[:, face[:, 1], :]) ** 2, 2, keepdim=True)) 215 | d2_gt = torch.sqrt(epsilon + torch.sum((coord_gt[:, face[:, 0], :] - coord_gt[:, face[:, 2], :]) ** 2, 2, keepdim=True)) 216 | d3_gt = torch.sqrt(epsilon + torch.sum((coord_gt[:, face[:, 1], :] - coord_gt[:, face[:, 2], :]) ** 2, 2, keepdim=True)) 217 | 218 | diff1 = torch.abs(d1_out - d1_gt) 219 | diff2 = torch.abs(d2_out - d2_gt) 220 | diff3 = torch.abs(d3_out - d3_gt) 221 | if self.relative_length: 222 | diff1 = diff1 / d1_gt 223 | diff2 = diff2 / d2_gt 224 | diff3 = diff3 / d3_gt 225 | loss = torch.cat((diff1, diff2, diff3), 1) 226 | return loss.mean() 227 | 228 | 229 | class IBSLoss(nn.Module): 230 | def __init__(self, faces): 231 | super(IBSLoss, self).__init__() 232 | self.faces = faces 233 | self.criterion = nn.L1Loss(reduction='mean') 234 | 235 | def forward(self, x, y): 236 | batch_size = x.shape[0] 237 | mesh_x = Meshes( 238 | verts=x, 239 | faces=self.faces.unsqueeze(0).expand(batch_size, -1, -1) 240 | ) 241 | 242 | mesh_y = Meshes( 243 | verts=y, 244 | faces=self.faces.unsqueeze(0).expand(batch_size, -1, -1) 245 | ) 246 | 247 | curvature_x = mesh_laplacian_smoothing(mesh_x, method='cotcurv') 248 | loss = curvature_x.mean() 249 | return loss -------------------------------------------------------------------------------- /utils/eulerangles.py: -------------------------------------------------------------------------------- 1 | # 
From https://github.com/matthew-brett/transforms3d/blob/master/transforms3d/euler.py 2 | """ Generic Euler rotations 3 | See: 4 | * http://en.wikipedia.org/wiki/Rotation_matrix 5 | * http://en.wikipedia.org/wiki/Euler_angles 6 | * http://mathworld.wolfram.com/EulerAngles.html 7 | See also: *Representing Attitude with Euler Angles and Quaternions: A 8 | Reference* (2006) by James Diebel. A cached PDF link last found here: 9 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134 10 | ****************** 11 | Defining rotations 12 | ****************** 13 | Euler's rotation theorem tells us that any rotation in 3D can be described by 3 14 | angles. Let's call the 3 angles the *Euler angle vector* and call the angles 15 | in the vector :math:`alpha`, :math:`beta` and :math:`gamma`. The vector is [ 16 | :math:`alpha`, :math:`beta`. :math:`gamma` ] and, in this description, the 17 | order of the parameters specifies the order in which the rotations occur (so 18 | the rotation corresponding to :math:`alpha` is applied first). 19 | In order to specify the meaning of an *Euler angle vector* we need to specify 20 | the axes around which each of the rotations corresponding to :math:`alpha`, 21 | :math:`beta` and :math:`gamma` will occur. 22 | There are therefore three axes for the rotations :math:`alpha`, :math:`beta` 23 | and :math:`gamma`; let's call them :math:`i` :math:`j`, :math:`k`. 24 | Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3 rotation 25 | matrix `A`. Similarly :math:`beta` around `j` becomes 3 x 3 matrix `B` and 26 | :math:`gamma` around `k` becomes matrix `G`. Then the whole rotation expressed 27 | by the Euler angle vector [ :math:`alpha`, :math:`beta`. :math:`gamma` ], `R` 28 | is given by:: 29 | R = np.dot(G, np.dot(B, A)) 30 | See http://mathworld.wolfram.com/EulerAngles.html 31 | The order :math:`G B A` expresses the fact that the rotations are 32 | performed in the order of the vector (:math:`alpha` around axis `i` = 33 | `A` first). 34 | To convert a given Euler angle vector to a meaningful rotation, and a 35 | rotation matrix, we need to define: 36 | * the axes `i`, `j`, `k`; 37 | * whether the rotations move the axes as they are applied (intrinsic 38 | rotations) - compared the situation where the axes stay fixed and the 39 | vectors move within the axis frame (extrinsic); 40 | * whether a rotation matrix should be applied on the left of a vector to 41 | be transformed (vectors are column vectors) or on the right (vectors 42 | are row vectors); 43 | * the handedness of the coordinate system. 44 | See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities 45 | This module implements intrinsic and extrinsic axes, with standard conventions 46 | for axes `i`, `j`, `k`. We assume that the matrix should be applied on the 47 | left of the vector, and right-handed coordinate systems. To get the matrix to 48 | apply on the right of the vector, you need the transpose of the matrix we 49 | supply here, by the matrix transpose rule: $(M . V)^T = V^T M^T$. 50 | ************* 51 | Rotation axes 52 | ************* 53 | Rotations given as a set of three angles can refer to any of 24 different ways 54 | of applying these rotations, or equivalently, 24 conventions for rotation 55 | angles. See http://en.wikipedia.org/wiki/Euler_angles. 56 | The different conventions break down into two groups of 12. In the first 57 | group, the rotation axes are fixed (also, global, static), and do not move with 58 | rotations. These are called *extrinsic* axes. 
The axes can also move with the 59 | rotations. These are called *intrinsic*, local or rotating axes. 60 | Each of the two groups (*intrinsic* and *extrinsic*) can further be divided 61 | into so-called Euler rotations (rotation about one axis, then a second and then 62 | the first again), and Tait-Bryan angles (rotations about all three axes). The 63 | two groups (Euler rotations and Tait-Bryan rotations) each have 6 possible 64 | choices. There are therefore 2 * 2 * 6 = 24 possible conventions that could 65 | apply to rotations about a sequence of three given angles. 66 | This module gives an implementation of conversion between angles and rotation 67 | matrices for which you can specify any of the 24 different conventions. 68 | **************************** 69 | Specifying angle conventions 70 | **************************** 71 | You specify conventions for interpreting the sequence of angles with a four 72 | character string. 73 | The first character is 'r' (rotating == intrinsic), or 's' (static == 74 | extrinsic). 75 | The next three characters give the axis ('x', 'y' or 'z') about which to 76 | perform the rotation, in the order in which the rotations will be performed. 77 | For example the string 'szyx' specifies that the angles should be interpreted 78 | relative to extrinsic (static) coordinate axes, and be performed in the order: 79 | rotation about z axis; rotation about y axis; rotation about x axis. This 80 | is a relatively common convention, with customized implementations in 81 | :mod:`taitbryan` in this package. 82 | The string 'rzxz' specifies that the angles should be interpreted 83 | relative to intrinsic (rotating) coordinate axes, and be performed in the 84 | order: rotation about z axis; rotation about the rotated x axis; rotation 85 | about the rotated z axis. Wolfram Mathworld claim this is the most common 86 | convention : http://mathworld.wolfram.com/EulerAngles.html. 87 | ********************* 88 | Direction of rotation 89 | ********************* 90 | The direction of rotation is given by the right-hand rule (orient the thumb of 91 | the right hand along the axis around which the rotation occurs, with the end of 92 | the thumb at the positive end of the axis; curl your fingers; the direction 93 | your fingers curl is the direction of rotation). Therefore, the rotations are 94 | counterclockwise if looking along the axis of rotation from positive to 95 | negative. 96 | **************************** 97 | Terms used in function names 98 | **************************** 99 | * *mat* : array shape (3, 3) (3D non-homogenous coordinates) 100 | * *euler* : (sequence of) rotation angles about the z, y, x axes (in that 101 | order) 102 | * *axangle* : rotations encoded by axis vector and angle scalar 103 | * *quat* : quaternion shape (4,) 104 | """ 105 | 106 | import math 107 | 108 | import numpy as np 109 | 110 | # from .quaternions import quat2mat, quat2axangle 111 | # from .axangles import axangle2mat 112 | # from . 
import taitbryan as tb 113 | 114 | # axis sequences for Euler angles 115 | _NEXT_AXIS = [1, 2, 0, 1] 116 | 117 | # map axes strings to/from tuples of inner axis, parity, repetition, frame 118 | _AXES2TUPLE = { 119 | 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0), 120 | 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0), 121 | 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0), 122 | 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0), 123 | 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1), 124 | 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1), 125 | 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1), 126 | 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)} 127 | 128 | _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) 129 | 130 | # For testing whether a number is close to zero 131 | _EPS4 = np.finfo(float).eps * 4.0 132 | 133 | 134 | def euler2mat(ai, aj, ak, axes='sxyz'): 135 | """Return rotation matrix from Euler angles and axis sequence. 136 | Parameters 137 | ---------- 138 | ai : float 139 | First rotation angle (according to `axes`). 140 | aj : float 141 | Second rotation angle (according to `axes`). 142 | ak : float 143 | Third rotation angle (according to `axes`). 144 | axes : str, optional 145 | Axis specification; one of 24 axis sequences as string or encoded 146 | tuple - e.g. ``sxyz`` (the default). 147 | Returns 148 | ------- 149 | mat : array-like shape (3, 3) or (4, 4) 150 | Rotation matrix or affine. 151 | Examples 152 | -------- 153 | >>> R = euler2mat(1, 2, 3, 'syxz') 154 | >>> np.allclose(np.sum(R[0]), -1.34786452) 155 | True 156 | >>> R = euler2mat(1, 2, 3, (0, 1, 0, 1)) 157 | >>> np.allclose(np.sum(R[0]), -0.383436184) 158 | True 159 | """ 160 | try: 161 | firstaxis, parity, repetition, frame = _AXES2TUPLE[axes] 162 | except (AttributeError, KeyError): 163 | _TUPLE2AXES[axes] # validation 164 | firstaxis, parity, repetition, frame = axes 165 | 166 | i = firstaxis 167 | j = _NEXT_AXIS[i+parity] 168 | k = _NEXT_AXIS[i-parity+1] 169 | 170 | if frame: 171 | ai, ak = ak, ai 172 | if parity: 173 | ai, aj, ak = -ai, -aj, -ak 174 | 175 | si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak) 176 | ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak) 177 | cc, cs = ci*ck, ci*sk 178 | sc, ss = si*ck, si*sk 179 | 180 | M = np.eye(3) 181 | if repetition: 182 | M[i, i] = cj 183 | M[i, j] = sj*si 184 | M[i, k] = sj*ci 185 | M[j, i] = sj*sk 186 | M[j, j] = -cj*ss+cc 187 | M[j, k] = -cj*cs-sc 188 | M[k, i] = -sj*ck 189 | M[k, j] = cj*sc+cs 190 | M[k, k] = cj*cc-ss 191 | else: 192 | M[i, i] = cj*ck 193 | M[i, j] = sj*sc-cs 194 | M[i, k] = sj*cc+ss 195 | M[j, i] = cj*sk 196 | M[j, j] = sj*ss+cc 197 | M[j, k] = sj*cs-sc 198 | M[k, i] = -sj 199 | M[k, j] = cj*si 200 | M[k, k] = cj*ci 201 | return M 202 | 203 | 204 | def mat2euler(mat, axes='sxyz'): 205 | """Return Euler angles from rotation matrix for specified axis sequence. 206 | Note that many Euler angle triplets can describe one matrix. 207 | Parameters 208 | ---------- 209 | mat : array-like shape (3, 3) or (4, 4) 210 | Rotation matrix or affine. 211 | axes : str, optional 212 | Axis specification; one of 24 axis sequences as string or encoded 213 | tuple - e.g. ``sxyz`` (the default). 214 | Returns 215 | ------- 216 | ai : float 217 | First rotation angle (according to `axes`). 218 | aj : float 219 | Second rotation angle (according to `axes`). 
220 | ak : float 221 | Third rotation angle (according to `axes`). 222 | Examples 223 | -------- 224 | >>> R0 = euler2mat(1, 2, 3, 'syxz') 225 | >>> al, be, ga = mat2euler(R0, 'syxz') 226 | >>> R1 = euler2mat(al, be, ga, 'syxz') 227 | >>> np.allclose(R0, R1) 228 | True 229 | """ 230 | try: 231 | firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] 232 | except (AttributeError, KeyError): 233 | _TUPLE2AXES[axes] # validation 234 | firstaxis, parity, repetition, frame = axes 235 | 236 | i = firstaxis 237 | j = _NEXT_AXIS[i+parity] 238 | k = _NEXT_AXIS[i-parity+1] 239 | 240 | M = np.array(mat, dtype=np.float64, copy=False)[:3, :3] 241 | if repetition: 242 | sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k]) 243 | if sy > _EPS4: 244 | ax = math.atan2( M[i, j], M[i, k]) 245 | ay = math.atan2( sy, M[i, i]) 246 | az = math.atan2( M[j, i], -M[k, i]) 247 | else: 248 | ax = math.atan2(-M[j, k], M[j, j]) 249 | ay = math.atan2( sy, M[i, i]) 250 | az = 0.0 251 | else: 252 | cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i]) 253 | if cy > _EPS4: 254 | ax = math.atan2( M[k, j], M[k, k]) 255 | ay = math.atan2(-M[k, i], cy) 256 | az = math.atan2( M[j, i], M[i, i]) 257 | else: 258 | ax = math.atan2(-M[j, k], M[j, j]) 259 | ay = math.atan2(-M[k, i], cy) 260 | az = 0.0 261 | 262 | if parity: 263 | ax, ay, az = -ax, -ay, -az 264 | if frame: 265 | ax, az = az, ax 266 | return ax, ay, az 267 | -------------------------------------------------------------------------------- /utils/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from pytorch3d.ops.knn import knn_gather, knn_points 12 | from pytorch3d.structures.pointclouds import Pointclouds 13 | 14 | 15 | def _validate_chamfer_reduction_inputs( 16 | batch_reduction: Union[str, None], point_reduction: str 17 | ): 18 | """Check the requested reductions are valid. 19 | 20 | Args: 21 | batch_reduction: Reduction operation to apply for the loss across the 22 | batch, can be one of ["mean", "sum"] or None. 23 | point_reduction: Reduction operation to apply for the loss across the 24 | points, can be one of ["mean", "sum"]. 25 | """ 26 | if batch_reduction is not None and batch_reduction not in ["mean", "sum"]: 27 | raise ValueError('batch_reduction must be one of ["mean", "sum"] or None') 28 | if point_reduction not in ["mean", "sum"]: 29 | raise ValueError('point_reduction must be one of ["mean", "sum"]') 30 | 31 | 32 | def _handle_pointcloud_input( 33 | points: Union[torch.Tensor, Pointclouds], 34 | lengths: Union[torch.Tensor, None], 35 | normals: Union[torch.Tensor, None], 36 | ): 37 | """ 38 | If points is an instance of Pointclouds, retrieve the padded points tensor 39 | along with the number of points per batch and the padded normals. 40 | Otherwise, return the input points (and normals) with the number of points per cloud 41 | set to the size of the second dimension of `points`. 
42 | """ 43 | if isinstance(points, Pointclouds): 44 | X = points.points_padded() 45 | lengths = points.num_points_per_cloud() 46 | normals = points.normals_padded() # either a tensor or None 47 | elif torch.is_tensor(points): 48 | if points.ndim != 3: 49 | raise ValueError("Expected points to be of shape (N, P, D)") 50 | X = points 51 | if lengths is not None and ( 52 | lengths.ndim != 1 or lengths.shape[0] != X.shape[0] 53 | ): 54 | raise ValueError("Expected lengths to be of shape (N,)") 55 | if lengths is None: 56 | lengths = torch.full( 57 | (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device 58 | ) 59 | if normals is not None and normals.ndim != 3: 60 | raise ValueError("Expected normals to be of shape (N, P, 3") 61 | else: 62 | raise ValueError( 63 | "The input pointclouds should be either " 64 | + "Pointclouds objects or torch.Tensor of shape " 65 | + "(minibatch, num_points, 3)." 66 | ) 67 | return X, lengths, normals 68 | 69 | 70 | def chamfer_contact_loss( 71 | x, 72 | y, 73 | x_lengths=None, 74 | y_lengths=None, 75 | x_normals=None, 76 | y_normals=None, 77 | weights=None, 78 | batch_reduction: Union[str, None] = "mean", 79 | point_reduction: str = "mean", 80 | ): 81 | """ 82 | Chamfer distance between two pointclouds x and y. 83 | 84 | Args: 85 | x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing 86 | a batch of point clouds with at most P1 points in each batch element, 87 | batch size N and feature dimension D. 88 | y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing 89 | a batch of point clouds with at most P2 points in each batch element, 90 | batch size N and feature dimension D. 91 | x_lengths: Optional LongTensor of shape (N,) giving the number of points in each 92 | cloud in x. 93 | y_lengths: Optional LongTensor of shape (N,) giving the number of points in each 94 | cloud in x. 95 | x_normals: Optional FloatTensor of shape (N, P1, D). 96 | y_normals: Optional FloatTensor of shape (N, P2, D). 97 | weights: Optional FloatTensor of shape (N,) giving weights for 98 | batch elements for reduction operation. 99 | batch_reduction: Reduction operation to apply for the loss across the 100 | batch, can be one of ["mean", "sum"] or None. 101 | point_reduction: Reduction operation to apply for the loss across the 102 | points, can be one of ["mean", "sum"]. 103 | 104 | Returns: 105 | 2-element tuple containing 106 | 107 | - **loss**: Tensor giving the reduced distance between the pointclouds 108 | in x and the pointclouds in y. 109 | - **loss_normals**: Tensor giving the reduced cosine distance of normals 110 | between pointclouds in x and pointclouds in y. Returns None if 111 | x_normals and y_normals are None. 112 | """ 113 | _validate_chamfer_reduction_inputs(batch_reduction, point_reduction) 114 | 115 | x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) 116 | y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) 117 | 118 | return_normals = x_normals is not None and y_normals is not None 119 | 120 | N, P1, D = x.shape 121 | P2 = y.shape[1] 122 | 123 | # Check if inputs are heterogeneous and create a lengths mask. 
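    # "Heterogeneous" here means the clouds in the batch have different numbers of valid points, so the
    # padded tensors contain dummy rows. x_mask / y_mask mark those padding positions (index >= per-cloud
    # length); their nearest-neighbour distances are zeroed out further below so that padding never
    # contributes to the reduced loss.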
124 | is_x_heterogeneous = (x_lengths != P1).any() 125 | is_y_heterogeneous = (y_lengths != P2).any() 126 | x_mask = ( 127 | torch.arange(P1, device=x.device)[None] >= x_lengths[:, None] 128 | ) # shape [N, P1] 129 | y_mask = ( 130 | torch.arange(P2, device=y.device)[None] >= y_lengths[:, None] 131 | ) # shape [N, P2] 132 | 133 | if y.shape[0] != N or y.shape[2] != D: 134 | raise ValueError("y does not have the correct shape.") 135 | if weights is not None: 136 | if weights.size(0) != N: 137 | raise ValueError("weights must be of shape (N,).") 138 | if not (weights >= 0).all(): 139 | raise ValueError("weights cannot be negative.") 140 | if weights.sum() == 0.0: 141 | weights = weights.view(N, 1) 142 | if batch_reduction in ["mean", "sum"]: 143 | return ( 144 | (x.sum((1, 2)) * weights).sum() * 0.0, 145 | (x.sum((1, 2)) * weights).sum() * 0.0, 146 | ) 147 | return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0) 148 | 149 | cham_norm_x = x.new_zeros(()) 150 | cham_norm_y = x.new_zeros(()) 151 | 152 | x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1) 153 | y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1) 154 | 155 | cham_x = x_nn.dists[..., 0] # (N, P1) 156 | cham_y = y_nn.dists[..., 0] # (N, P2) 157 | 158 | if is_x_heterogeneous: 159 | cham_x[x_mask] = 0.0 160 | if is_y_heterogeneous: 161 | cham_y[y_mask] = 0.0 162 | 163 | # from squared distance to L2 164 | # epsion = 1e-8 165 | # cham_x = cham_x.clamp(min=epsion).sqrt() # use sqrt lead to autograd error, why? 166 | # cham_y = cham_y.clamp(min=epsion).sqrt() 167 | # cham_x = cham_x ** 0.5 #use sqrt lead to autograd error, why? 168 | # cham_y = cham_y ** 0.5 169 | 170 | if weights is not None: 171 | cham_x *= weights.view(N, 1) 172 | cham_y *= weights.view(N, 1) 173 | 174 | if return_normals: 175 | # Gather the normals using the indices and keep only value for k=0 176 | x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :] 177 | y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :] 178 | 179 | cham_norm_x = 1 - torch.abs( 180 | F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6) 181 | ) 182 | cham_norm_y = 1 - torch.abs( 183 | F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6) 184 | ) 185 | 186 | if is_x_heterogeneous: 187 | cham_norm_x[x_mask] = 0.0 188 | if is_y_heterogeneous: 189 | cham_norm_y[y_mask] = 0.0 190 | 191 | if weights is not None: 192 | cham_norm_x *= weights.view(N, 1) 193 | cham_norm_y *= weights.view(N, 1) 194 | 195 | # Apply point reduction 196 | cham_x = cham_x.sum(1) # (N,) 197 | cham_y = cham_y.sum(1) # (N,) 198 | # cham_x, _ = torch.min(cham_x, 1) # (N,) 199 | # cham_y, _ = torch.min(cham_y, 1) # (N,) 200 | if return_normals: 201 | cham_norm_x = cham_norm_x.sum(1) # (N,) 202 | cham_norm_y = cham_norm_y.sum(1) # (N,) 203 | if point_reduction == "mean": 204 | cham_x /= x_lengths 205 | cham_y /= y_lengths 206 | if return_normals: 207 | cham_norm_x /= x_lengths 208 | cham_norm_y /= y_lengths 209 | 210 | if batch_reduction is not None: 211 | # batch_reduction == "sum" 212 | cham_x = cham_x.sum() 213 | cham_y = cham_y.sum() 214 | if return_normals: 215 | cham_norm_x = cham_norm_x.sum() 216 | cham_norm_y = cham_norm_y.sum() 217 | if batch_reduction == "mean": 218 | div = weights.sum() if weights is not None else N 219 | cham_x /= div 220 | cham_y /= div 221 | if return_normals: 222 | cham_norm_x /= div 223 | cham_norm_y /= div 224 | 225 | # cham_dist = cham_x + cham_y 226 | cham_dist = cham_x # only cares body 
to object distance 227 | cham_normals = cham_norm_x + cham_norm_y if return_normals else None 228 | 229 | return cham_dist, cham_normals 230 | 231 | 232 | def chamfer_dists( 233 | x, 234 | y, 235 | x_lengths=None, 236 | y_lengths=None, 237 | x_normals=None, 238 | y_normals=None, 239 | weights=None, 240 | batch_reduction: Union[str, None] = "mean", 241 | point_reduction: str = "mean", 242 | return_idx=False 243 | ): 244 | """ 245 | Chamfer distance between two pointclouds x and y. 246 | 247 | Args: 248 | x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing 249 | a batch of point clouds with at most P1 points in each batch element, 250 | batch size N and feature dimension D. 251 | y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing 252 | a batch of point clouds with at most P2 points in each batch element, 253 | batch size N and feature dimension D. 254 | x_lengths: Optional LongTensor of shape (N,) giving the number of points in each 255 | cloud in x. 256 | y_lengths: Optional LongTensor of shape (N,) giving the number of points in each 257 | cloud in x. 258 | x_normals: Optional FloatTensor of shape (N, P1, D). 259 | y_normals: Optional FloatTensor of shape (N, P2, D). 260 | weights: Optional FloatTensor of shape (N,) giving weights for 261 | batch elements for reduction operation. 262 | batch_reduction: Reduction operation to apply for the loss across the 263 | batch, can be one of ["mean", "sum"] or None. 264 | point_reduction: Reduction operation to apply for the loss across the 265 | points, can be one of ["mean", "sum"]. 266 | 267 | Returns: 268 | 2-element tuple containing 269 | 270 | - **loss**: Tensor giving the reduced distance between the pointclouds 271 | in x and the pointclouds in y. 272 | - **loss_normals**: Tensor giving the reduced cosine distance of normals 273 | between pointclouds in x and pointclouds in y. Returns None if 274 | x_normals and y_normals are None. 275 | """ 276 | _validate_chamfer_reduction_inputs(batch_reduction, point_reduction) 277 | 278 | x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) 279 | y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) 280 | 281 | return_normals = x_normals is not None and y_normals is not None 282 | 283 | N, P1, D = x.shape 284 | P2 = y.shape[1] 285 | 286 | # Check if inputs are heterogeneous and create a lengths mask. 
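    # Unlike chamfer_contact_loss above, this helper is one-directional: only the x -> y nearest-neighbour
    # (squared) distances are computed (the y -> x branch is commented out below), and they are returned
    # per point without the point/batch reductions, optionally together with the index of the nearest
    # y point for each x point when return_idx=True.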
287 | is_x_heterogeneous = (x_lengths != P1).any() 288 | is_y_heterogeneous = (y_lengths != P2).any() 289 | x_mask = ( 290 | torch.arange(P1, device=x.device)[None] >= x_lengths[:, None] 291 | ) # shape [N, P1] 292 | y_mask = ( 293 | torch.arange(P2, device=y.device)[None] >= y_lengths[:, None] 294 | ) # shape [N, P2] 295 | 296 | if y.shape[0] != N or y.shape[2] != D: 297 | raise ValueError("y does not have the correct shape.") 298 | if weights is not None: 299 | if weights.size(0) != N: 300 | raise ValueError("weights must be of shape (N,).") 301 | if not (weights >= 0).all(): 302 | raise ValueError("weights cannot be negative.") 303 | if weights.sum() == 0.0: 304 | weights = weights.view(N, 1) 305 | if batch_reduction in ["mean", "sum"]: 306 | return ( 307 | (x.sum((1, 2)) * weights).sum() * 0.0, 308 | (x.sum((1, 2)) * weights).sum() * 0.0, 309 | ) 310 | return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0) 311 | 312 | cham_norm_x = x.new_zeros(()) 313 | cham_norm_y = x.new_zeros(()) 314 | 315 | x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1) 316 | # y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1) 317 | 318 | cham_x = x_nn.dists[..., 0] # (N, P1) 319 | # cham_y = y_nn.dists[..., 0] # (N, P2) 320 | if return_idx: 321 | idx_x = x_nn.idx[..., 0] 322 | 323 | if is_x_heterogeneous: 324 | cham_x[x_mask] = 0.0 325 | # if is_y_heterogeneous: 326 | # cham_y[y_mask] = 0.0 327 | 328 | if return_idx: 329 | return cham_x, idx_x 330 | else: 331 | return cham_x -------------------------------------------------------------------------------- /evaluation/eval_results.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy 4 | 5 | sys.path.append('..') 6 | 7 | import torch 8 | import json 9 | import numpy as np 10 | import smplx 11 | import trimesh 12 | import scipy.cluster 13 | from scipy.stats import entropy 14 | from scipy.spatial import KDTree 15 | from sklearn.manifold import TSNE 16 | from sklearn.decomposition import PCA 17 | import matplotlib.pyplot as plt 18 | from collections import defaultdict 19 | from tqdm import tqdm 20 | from copy import copy 21 | 22 | from configuration.config import * 23 | from data.scene import scenes, to_trimesh 24 | from load_results import synthesis_results_dict 25 | from models.mesh import Mesh 26 | from utils.metric_utils import load_scene_data, read_sdf, eval_physical_metric 27 | from utils.chamfer_distance import chamfer_dists 28 | 29 | # DEBUG = True 30 | DEBUG = False 31 | 32 | """Get bodu vertices in scene coordinates. 
""" 33 | def get_vertices(generation_param, scene_coords='PROX'): 34 | if 'gender' in generation_param: 35 | body_model = body_model_dict[generation_param['gender']] 36 | else: 37 | body_model = body_model_dict['neutral'] 38 | scene_name = generation_param['scene'] 39 | for key in smplx_param_names: 40 | if key in generation_param: 41 | generation_param[key] = torch.tensor(generation_param[key], device=device) 42 | vertices = body_model(**generation_param).vertices.detach() 43 | 44 | return vertices 45 | 46 | def calc_diversity(body_params, cls_num=20): 47 | # print(body_params.shape) 48 | if body_params.shape[0] < cls_num: #deals with very few samples 49 | cls_num = max(1, body_params.shape[0] // 10) 50 | ## k-means 51 | codes, dist = scipy.cluster.vq.kmeans(body_params, cls_num) # codes: [20, 72], dist: scalar 52 | vecs, dist = scipy.cluster.vq.vq(body_params, codes) # assign codes, vecs/dist: [1200] 53 | counts, bins = np.histogram(vecs, np.arange(len(codes) + 1)) # count occurrences count: [20] 54 | ee = entropy(counts) 55 | return {'entropy': float(ee), 56 | 'mean_dist': float(np.mean(dist))} 57 | 58 | def evaluate_diversity(results, method='default'): 59 | diversity_metric = {'all': []} 60 | colors = [] 61 | # diversity_metric = {} 62 | for generation in results: 63 | smplx_params = [] 64 | combination_name = generation[generation.find('_') + 1:] 65 | atomics = combination_name.split('+') 66 | verbs = [atomic.split('-')[0] for atomic in atomics] 67 | verb_ids = [action_names_train.index(verb) for verb in verbs] 68 | color_id = np.prod(np.array(verb_ids)) % 233 69 | for generation_param in results[generation]: 70 | smplx_param = [] 71 | for param in used_smplx_param_names: 72 | if param in ['transl', 'global_orient', 'betas']: 73 | continue 74 | smplx_param.append(np.asarray(generation_param[param].detach().cpu()) if torch.is_tensor(generation_param[param]) else generation_param[param]) 75 | smplx_param = np.concatenate(smplx_param, axis=1) 76 | if np.isnan(smplx_param).any(): 77 | print(generation_param) 78 | continue 79 | smplx_param = numpy.zeros_like(smplx_param) 80 | smplx_params.append(smplx_param) 81 | 82 | diversity_metric[generation] = np.concatenate(smplx_params, axis=0) 83 | diversity_metric['all'].append(diversity_metric[generation]) 84 | colors += [color_id] * diversity_metric[generation].shape[0] 85 | diversity_metric['all'] = np.concatenate(diversity_metric['all'], axis=0) 86 | if len(diversity_metric['all']) == 0: 87 | return {} 88 | 89 | # """visualization""" 90 | # # Transform the data 91 | # data = diversity_metric['all'] 92 | # codes, _ = scipy.cluster.vq.kmeans(data, 50) # codes: [20, 72], dist: scalar 93 | # vecs, dist = scipy.cluster.vq.vq(data, codes) # assign codes, vecs/dist: [1200] 94 | # # t-sne 95 | # img_file = results_folder / 'tsne_new' /(method + '_.png') 96 | # img_file.parent.mkdir(exist_ok=True) 97 | # tsne = TSNE(n_components=2, verbose=0) 98 | # z = tsne.fit_transform(data) 99 | # plt.scatter(z[:, 0], z[:, 1], s=5, c=vecs, cmap='hsv') 100 | # plt.axis('off') 101 | # plt.savefig(str(img_file)) 102 | # plt.clf() 103 | # 104 | # # pca 105 | # img_file = results_folder / 'pca' / (method + '_.png') 106 | # img_file.parent.mkdir(exist_ok=True) 107 | # z = PCA(2).fit_transform(data) 108 | # plt.scatter(z[:, 0], z[:, 1], s=5, c=vecs, cmap='hsv') 109 | # plt.axis('off') 110 | # plt.savefig(str(img_file)) 111 | # plt.clf() 112 | 113 | # diversity using different number of clusters 114 | for key in diversity_metric: 115 | diversity_metric[key] = { 116 | 1: 
calc_diversity(diversity_metric[key], cls_num=1), 117 | 20: calc_diversity(diversity_metric[key], cls_num=20), 118 | 50: calc_diversity(diversity_metric[key], cls_num=50), 119 | 150: calc_diversity(diversity_metric[key], cls_num=150), 120 | } 121 | 122 | # for key in diversity_metric: 123 | # diversity_metric[key] = calc_diversity(diversity_metric[key], cls_num=3) 124 | # per_generation_metrics = [] 125 | # for key in diversity_metric: 126 | # per_generation_metrics.append(diversity_metric[key]) 127 | # diversity_metric['all'] = { 128 | # 'entropy': np.asarray([metric['entropy'] for metric in per_generation_metrics]).mean(), 129 | # 'mean_dist': np.asarray([metric['mean_dist'] for metric in per_generation_metrics]).mean(), 130 | # } 131 | 132 | return diversity_metric 133 | 134 | def calc_physical_metric(vertices, scene_data): 135 | nv = float(vertices.shape[1]) 136 | x = read_sdf(vertices, scene_data['sdf'], 137 | scene_data['grid_dim'], scene_data['grid_min'], scene_data['grid_max'], 138 | mode='bilinear').squeeze() 139 | 140 | contact_thresh = 0 141 | if x.le(contact_thresh).sum().item() < 1: # if the number of negative sdf entries is less than one 142 | contact_score = torch.tensor(0.0) 143 | else: 144 | contact_score = torch.tensor(1.0) 145 | non_collision_score = (x >= 0).sum().float() / nv 146 | 147 | return float(non_collision_score.detach().cpu().squeeze()), float(contact_score.detach().cpu().squeeze()) 148 | 149 | def evaluate_physical_plausibility(results): 150 | contact_metric = {'all': []} 151 | non_collision_metric = {'all': []} 152 | for generation in tqdm(results): 153 | non_collision_scores = [] 154 | contact_scores = [] 155 | for generation_param in results[generation]: 156 | vertices = get_vertices(generation_param, scene_coords='Narrator') 157 | non_collision_score, contact_score = calc_physical_metric(vertices, scene_data_dict[generation_param['scene']]) 158 | non_collision_scores.append(non_collision_score) 159 | contact_scores.append(contact_score) 160 | 161 | contact_metric[generation] = contact_scores 162 | contact_metric['all'] += contact_scores 163 | non_collision_metric[generation] = non_collision_scores 164 | non_collision_metric['all'] += non_collision_scores 165 | 166 | for key in contact_metric.keys(): 167 | contact_metric[key] = float(np.array(contact_metric[key]).mean()) 168 | non_collision_metric[key] = float(np.array(non_collision_metric[key]).mean()) 169 | 170 | return contact_metric, non_collision_metric 171 | 172 | # semantic sdf is very inaccurate 173 | def calc_semantic_accuracy(vertices, scene_data): 174 | x_semantics = read_sdf(vertices, scene_data['semantics'], 175 | scene_data['grid_dim'], scene_data['grid_min'], 176 | scene_data['grid_max'], mode="bilinear").squeeze() 177 | 178 | print(np.unique(scene_data['semantics'].cpu().numpy())) 179 | print(scene_data['scene_semantics'].shape) 180 | if DEBUG: 181 | x_semantics = x_semantics.type(torch.int).cpu().numpy() #(10475, 0) 182 | print(np.unique(x_semantics)) 183 | colors = category_dict['color'][x_semantics].to_numpy() 184 | print(colors) 185 | colors = np.asarray([np.asarray(color) for color in colors]) 186 | print(colors) 187 | body = trimesh.Trimesh(vertices=vertices[0].detach().cpu().numpy(), faces=body_model_dict['neutral'].faces, 188 | vertex_colors=colors) 189 | body.show() 190 | scene_name = generation_param['scene'] 191 | scene = trimesh.load_mesh(Path.joinpath(proxe_base_folder, 'Narrator_dir/scenes', scene_name + '.ply')) 192 | (body + scene).show() 193 | 194 | return 195 | 196 | def 
calc_semantic_accuracy(vertices, scene_name, generation): 197 | vertices = vertices.squeeze().detach().cpu().numpy() 198 | atomic_generations = generation.split('+') 199 | scores = [] 200 | for atomic in atomic_generations: 201 | verb, noun = atomic.split('-') 202 | if noun not in object_proximity_dict[scene_name]: 203 | objs = [obj for obj in scenes[scene_name].object_nodes if obj.category_name == noun] 204 | if len(objs) == 0: 205 | print(noun, scene_name) 206 | scores.append(0) 207 | continue 208 | points = np.concatenate([np.asarray(obj.mesh.vertices) for obj in objs], axis=0) 209 | # print(points.shape) 210 | object_proximity_dict[scene_name][noun] = KDTree(points) 211 | 212 | body_parts = action_body_part_mapping[verb] 213 | vertices_of_interest = [] 214 | for body_part in body_parts: 215 | vertices_of_interest += body_part_vertices[body_part] 216 | vertices_of_interest = np.asarray(vertices_of_interest) 217 | dists, idx = object_proximity_dict[scene_name][noun].query(vertices[vertices_of_interest]) 218 | contact_thresh_dict = { 219 | 'sit on': 0.1, 220 | 'stand on': 0.05, 221 | 'lie on': 0.1, 222 | 'touch': 0.1, 223 | } 224 | contact_thresh = contact_thresh_dict[verb] 225 | dist = dists.min() if verb == 'touch' else dists.mean() 226 | score = 1.0 if dist < contact_thresh else 0.0 227 | scores.append(score) 228 | 229 | if DEBUG: 230 | contact_vertices = vertices_of_interest[dists < contact_thresh] 231 | colors = np.ones((10475, 3)) * np.array([1.00, 0.75, 0.80]) 232 | colors[contact_vertices, :] = np.array([1.00, 0.0, 0.0]) 233 | body = trimesh.Trimesh(vertices=vertices, faces=body_model_dict['neutral'].faces, 234 | vertex_colors=colors) 235 | body.show() 236 | scene = trimesh.load_mesh(Path.joinpath(scene_folder, scene_name + '.ply')) 237 | (body + scene).show() 238 | 239 | 240 | return np.asarray(scores).mean() 241 | 242 | def calc_semantic_contact(vertices, scene_name, generation, object_combination): 243 | vertices = body_mesh.downsample(vertices) 244 | atomic_generations = object_combination.split('+') 245 | verbs = '+'.join([atomic.split('-')[0] for atomic in atomic_generations]) 246 | instance_idx = [int(atomic.split('-')[-1]) for atomic in atomic_generations] 247 | scene = scenes[scene_name] 248 | # print(instance_idx) 249 | object_points = [np.asarray(scene.get_mesh_with_accessory(node_idx).vertices) for node_idx in instance_idx] 250 | object_points = torch.from_numpy(np.concatenate(object_points)).unsqueeze(0).float().to(vertices.device) 251 | body_obj_dists = torch.sqrt(chamfer_dists(vertices, object_points)).squeeze().detach().cpu().numpy() 252 | 253 | contact_probability, contact_score_thresh = contact_statistics['probability'][verbs], contact_statistics['score'][verbs] 254 | contact_thresh = 0.05 255 | contact_score = np.sum((body_obj_dists < contact_thresh) * contact_probability) 256 | contact_score = (contact_score >= contact_score_thresh * 0.8) 257 | # contact_score = min(contact_score / contact_score_thresh, 1) if contact_score_thresh > 0 else 1 258 | return contact_score 259 | 260 | def evaluate_semantic(results): 261 | semantic_metric = {'all': []} 262 | 263 | for generation in tqdm(results): 264 | semantic_scores = [] 265 | for generation_param in results[generation]: 266 | vertices = get_vertices(generation_param, scene_coords='PROX') 267 | # semantic_score = calc_semantic_accuracy(vertices, generation_param['scene'], generation_param['generation']) 268 | semantic_score = calc_semantic_contact(vertices, generation_param['scene'], 269 | 
generation_param['generation'], generation_param['object_combination']) 270 | semantic_scores.append(semantic_score) 271 | 272 | semantic_metric[generation] = semantic_scores 273 | semantic_metric['all'] += semantic_scores 274 | 275 | for key in semantic_metric.keys(): 276 | semantic_metric[key] = float(np.asarray(semantic_metric[key]).mean()) 277 | 278 | return semantic_metric 279 | 280 | def evaluate_results(results, method='default'): 281 | metrics = {} 282 | metrics['diversity'] = evaluate_diversity(results, method=method) 283 | metrics['contact'], metrics['non_collision'] = evaluate_physical_plausibility(results) 284 | metrics['semantic'] = evaluate_semantic(results) 285 | return metrics 286 | 287 | if __name__ == '__main__': 288 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 289 | body_mesh = Mesh(num_downsampling=2) 290 | # load contact statistics 291 | with open(project_folder / 'data' / 'contact_statistics.json', 'r') as f: 292 | contact_statistics = json.load(f) 293 | body_model_dict = { 294 | 'male': smplx.create(smplx_model_folder, model_type='smplx', 295 | gender='male', ext='npz', 296 | num_pca_comps=num_pca_comps).to(device), 297 | 'female': smplx.create(smplx_model_folder, model_type='smplx', 298 | gender='female', ext='npz', 299 | num_pca_comps=num_pca_comps).to(device), 300 | 'neutral': smplx.create(smplx_model_folder, model_type='smplx', 301 | gender='neutral', ext='npz', 302 | num_pca_comps=num_pca_comps).to(device) 303 | } 304 | 305 | # load scenes using util function 306 | scene_data_dict = {} 307 | for scene_name in scene_names: 308 | scene_data_dict[scene_name] = load_scene_data(name=scene_name, 309 | sdf_dir=Path.joinpath(sdf_folder, 'sdf').__str__(), 310 | use_semantics=True, 311 | no_obj_classes=42, 312 | device=device 313 | ) 314 | # trimesh proximity of objects of specified category 315 | object_proximity_dict = {} 316 | for scene_name in test_scenes: 317 | object_proximity_dict[scene_name] = {} 318 | 319 | if DEBUG: 320 | generation_param = synthesis_results_dict['prox']['sit on-chair'][0] 321 | vertices = get_vertices(generation_param) 322 | calc_semantic_accuracy(vertices, generation_param['scene'], 'sit on-chair') 323 | 324 | # evaluate metrics for each interation semantics 325 | metrics = {} 326 | for method in tqdm(synthesis_results_dict): 327 | print('evaluate metrics for:', method) 328 | metrics[method] = evaluate_results(synthesis_results_dict[method], method) 329 | print(metrics) 330 | with open(Path.joinpath(results_folder, 'metrics.json'), 'w') as file: 331 | json.dump(metrics, file) 332 | 333 | # evaluate overall metrics for all interation frames 334 | metrics_overview = copy(metrics) 335 | for method in metrics_overview: 336 | for metric in metrics_overview[method]: 337 | metrics_overview[method][metric] = metrics_overview[method][metric]['all'] 338 | with open(Path.joinpath(results_folder, 'metrics_overview.json'), 'w') as file: 339 | json.dump(metrics_overview, file) 340 | -------------------------------------------------------------------------------- /utils/pointnet2.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet2_utils.py 2 | import sys 3 | sys.path.append('..') 4 | from configuration.config import checkpoint_folder 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from time import time 10 | import numpy as np 11 | 12 | def timeit(tag, t): 13 | 
print("{}: {}s".format(tag, time() - t)) 14 | return time() 15 | 16 | def pc_normalize(pc): 17 | l = pc.shape[0] 18 | centroid = np.mean(pc, axis=0) 19 | pc = pc - centroid 20 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 21 | pc = pc / m 22 | return pc 23 | 24 | def square_distance(src, dst): 25 | """ 26 | Calculate Euclid distance between each two points. 27 | 28 | src^T * dst = xn * xm + yn * ym + zn * zm; 29 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 30 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 31 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 32 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 33 | 34 | Input: 35 | src: source points, [B, N, C] 36 | dst: target points, [B, M, C] 37 | Output: 38 | dist: per-point square distance, [B, N, M] 39 | """ 40 | B, N, _ = src.shape 41 | _, M, _ = dst.shape 42 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 43 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 44 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 45 | return dist 46 | 47 | 48 | def index_points(points, idx): 49 | """ 50 | 51 | Input: 52 | points: input points data, [B, N, C] 53 | idx: sample index data, [B, S] 54 | Return: 55 | new_points:, indexed points data, [B, S, C] 56 | """ 57 | device = points.device 58 | B = points.shape[0] 59 | view_shape = list(idx.shape) 60 | view_shape[1:] = [1] * (len(view_shape) - 1) 61 | repeat_shape = list(idx.shape) 62 | repeat_shape[0] = 1 63 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 64 | new_points = points[batch_indices, idx, :] 65 | return new_points 66 | 67 | 68 | def farthest_point_sample(xyz, npoint): 69 | """ 70 | Input: 71 | xyz: pointcloud data, [B, N, 3] 72 | npoint: number of samples 73 | Return: 74 | centroids: sampled pointcloud index, [B, npoint] 75 | """ 76 | device = xyz.device 77 | B, N, C = xyz.shape 78 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 79 | distance = torch.ones(B, N).to(device) * 1e10 80 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 81 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 82 | for i in range(npoint): 83 | centroids[:, i] = farthest 84 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 85 | dist = torch.sum((xyz - centroid) ** 2, -1) 86 | mask = dist < distance 87 | distance[mask] = dist[mask] 88 | farthest = torch.max(distance, -1)[1] 89 | return centroids 90 | 91 | 92 | def query_ball_point(radius, nsample, xyz, new_xyz): 93 | """ 94 | Input: 95 | radius: local region radius 96 | nsample: max sample number in local region 97 | xyz: all points, [B, N, 3] 98 | new_xyz: query points, [B, S, 3] 99 | Return: 100 | group_idx: grouped points index, [B, S, nsample] 101 | """ 102 | device = xyz.device 103 | B, N, C = xyz.shape 104 | _, S, _ = new_xyz.shape 105 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 106 | sqrdists = square_distance(new_xyz, xyz) 107 | group_idx[sqrdists > radius ** 2] = N 108 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 109 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 110 | mask = group_idx == N 111 | group_idx[mask] = group_first[mask] 112 | return group_idx 113 | 114 | 115 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 116 | """ 117 | Input: 118 | npoint: 119 | radius: 120 | nsample: 121 | xyz: input points position data, [B, N, 3] 122 | points: input points data, [B, N, D] 123 | Return: 124 | new_xyz: sampled points position data, [B, 
npoint, nsample, 3] 125 | new_points: sampled points data, [B, npoint, nsample, 3+D] 126 | """ 127 | B, N, C = xyz.shape 128 | S = npoint 129 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 130 | new_xyz = index_points(xyz, fps_idx) 131 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 132 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 133 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 134 | 135 | if points is not None: 136 | grouped_points = index_points(points, idx) 137 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 138 | else: 139 | new_points = grouped_xyz_norm 140 | if returnfps: 141 | return new_xyz, new_points, grouped_xyz, fps_idx 142 | else: 143 | return new_xyz, new_points 144 | 145 | 146 | def sample_and_group_all(xyz, points): 147 | """ 148 | Input: 149 | xyz: input points position data, [B, N, 3] 150 | points: input points data, [B, N, D] 151 | Return: 152 | new_xyz: sampled points position data, [B, 1, 3] 153 | new_points: sampled points data, [B, 1, N, 3+D] 154 | """ 155 | device = xyz.device 156 | B, N, C = xyz.shape 157 | new_xyz = torch.zeros(B, 1, C).to(device) 158 | grouped_xyz = xyz.view(B, 1, N, C) 159 | if points is not None: 160 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 161 | else: 162 | new_points = grouped_xyz 163 | return new_xyz, new_points 164 | 165 | 166 | class PointNetSetAbstraction(nn.Module): 167 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 168 | super(PointNetSetAbstraction, self).__init__() 169 | self.npoint = npoint 170 | self.radius = radius 171 | self.nsample = nsample 172 | self.mlp_convs = nn.ModuleList() 173 | self.mlp_bns = nn.ModuleList() 174 | last_channel = in_channel 175 | for out_channel in mlp: 176 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 177 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 178 | last_channel = out_channel 179 | self.group_all = group_all 180 | 181 | def forward(self, xyz, points): 182 | """ 183 | Input: 184 | xyz: input points position data, [B, C, N] 185 | points: input points data, [B, D, N] 186 | Return: 187 | new_xyz: sampled points position data, [B, C, S] 188 | new_points_concat: sample points feature data, [B, D', S] 189 | """ 190 | xyz = xyz.permute(0, 2, 1) 191 | if points is not None: 192 | points = points.permute(0, 2, 1) 193 | 194 | if self.group_all: 195 | new_xyz, new_points = sample_and_group_all(xyz, points) 196 | else: 197 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 198 | # new_xyz: sampled points position data, [B, npoint, C] 199 | # new_points: sampled points data, [B, npoint, nsample, C+D] 200 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 201 | for i, conv in enumerate(self.mlp_convs): 202 | bn = self.mlp_bns[i] 203 | new_points = F.relu(bn(conv(new_points))) 204 | 205 | new_points = torch.max(new_points, 2)[0] 206 | new_xyz = new_xyz.permute(0, 2, 1) 207 | return new_xyz, new_points 208 | 209 | 210 | class PointNetSetAbstractionMsg(nn.Module): 211 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 212 | super(PointNetSetAbstractionMsg, self).__init__() 213 | self.npoint = npoint 214 | self.radius_list = radius_list 215 | self.nsample_list = nsample_list 216 | self.conv_blocks = nn.ModuleList() 217 | self.bn_blocks = nn.ModuleList() 218 | for i in range(len(mlp_list)): 219 | convs = nn.ModuleList() 220 | bns = 
nn.ModuleList() 221 | last_channel = in_channel + 3 222 | for out_channel in mlp_list[i]: 223 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 224 | bns.append(nn.BatchNorm2d(out_channel)) 225 | last_channel = out_channel 226 | self.conv_blocks.append(convs) 227 | self.bn_blocks.append(bns) 228 | 229 | def forward(self, xyz, points): 230 | """ 231 | Input: 232 | xyz: input points position data, [B, C, N] 233 | points: input points data, [B, D, N] 234 | Return: 235 | new_xyz: sampled points position data, [B, C, S] 236 | new_points_concat: sample points feature data, [B, D', S] 237 | """ 238 | xyz = xyz.permute(0, 2, 1) 239 | if points is not None: 240 | points = points.permute(0, 2, 1) 241 | 242 | B, N, C = xyz.shape 243 | S = self.npoint 244 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 245 | new_points_list = [] 246 | for i, radius in enumerate(self.radius_list): 247 | K = self.nsample_list[i] 248 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 249 | grouped_xyz = index_points(xyz, group_idx) 250 | grouped_xyz -= new_xyz.view(B, S, 1, C) 251 | if points is not None: 252 | grouped_points = index_points(points, group_idx) 253 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 254 | else: 255 | grouped_points = grouped_xyz 256 | 257 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 258 | for j in range(len(self.conv_blocks[i])): 259 | conv = self.conv_blocks[i][j] 260 | bn = self.bn_blocks[i][j] 261 | grouped_points = F.relu(bn(conv(grouped_points))) 262 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 263 | new_points_list.append(new_points) 264 | 265 | new_xyz = new_xyz.permute(0, 2, 1) 266 | new_points_concat = torch.cat(new_points_list, dim=1) 267 | return new_xyz, new_points_concat 268 | 269 | 270 | class PointNetFeaturePropagation(nn.Module): 271 | def __init__(self, in_channel, mlp): 272 | super(PointNetFeaturePropagation, self).__init__() 273 | self.mlp_convs = nn.ModuleList() 274 | self.mlp_bns = nn.ModuleList() 275 | last_channel = in_channel 276 | for out_channel in mlp: 277 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 278 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 279 | last_channel = out_channel 280 | 281 | def forward(self, xyz1, xyz2, points1, points2): 282 | """ 283 | Input: 284 | xyz1: input points position data, [B, C, N] 285 | xyz2: sampled input points position data, [B, C, S] 286 | points1: input points data, [B, D, N] 287 | points2: input points data, [B, D, S] 288 | Return: 289 | new_points: upsampled points data, [B, D', N] 290 | """ 291 | xyz1 = xyz1.permute(0, 2, 1) 292 | xyz2 = xyz2.permute(0, 2, 1) 293 | 294 | points2 = points2.permute(0, 2, 1) 295 | B, N, C = xyz1.shape 296 | _, S, _ = xyz2.shape 297 | 298 | if S == 1: 299 | interpolated_points = points2.repeat(1, N, 1) 300 | else: 301 | dists = square_distance(xyz1, xyz2) 302 | dists, idx = dists.sort(dim=-1) 303 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 304 | 305 | dist_recip = 1.0 / (dists + 1e-8) 306 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 307 | weight = dist_recip / norm 308 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 309 | 310 | if points1 is not None: 311 | points1 = points1.permute(0, 2, 1) 312 | new_points = torch.cat([points1, interpolated_points], dim=-1) 313 | else: 314 | new_points = interpolated_points 315 | 316 | new_points = new_points.permute(0, 2, 1) 317 | for i, conv in enumerate(self.mlp_convs): 318 | bn = 
self.mlp_bns[i] 319 | new_points = F.relu(bn(conv(new_points))) 320 | return new_points 321 | 322 | class Pointnet2_sem(nn.Module): 323 | def __init__(self, num_classes, return_level): 324 | super(Pointnet2_sem, self).__init__() 325 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False) 326 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False) 327 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False) 328 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False) 329 | self.fp4 = PointNetFeaturePropagation(768, [256, 256]) 330 | self.fp3 = PointNetFeaturePropagation(384, [256, 256]) 331 | self.fp2 = PointNetFeaturePropagation(320, [256, 128]) 332 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128]) 333 | self.conv1 = nn.Conv1d(128, 128, 1) 334 | self.bn1 = nn.BatchNorm1d(128) 335 | self.drop1 = nn.Dropout(0.5) 336 | self.conv2 = nn.Conv1d(128, num_classes, 1) 337 | self.return_level = return_level 338 | 339 | def forward(self, xyz): 340 | l0_points = xyz 341 | l0_xyz = xyz[:,:3,:] 342 | 343 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 344 | if self.return_level == 1: 345 | return l1_xyz, l1_points 346 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 347 | if self.return_level == 2: 348 | return l2_xyz, l2_points 349 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 350 | if self.return_level == 3: 351 | return l3_xyz, l3_points 352 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 353 | return l4_xyz, l4_points 354 | 355 | # l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 356 | # l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 357 | # l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 358 | # l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points) 359 | # 360 | # x = self.drop1(F.relu(self.bn1(self.conv1(l0_points)))) 361 | # x = self.conv2(x) 362 | # x = F.log_softmax(x, dim=1) 363 | # x = x.permute(0, 2, 1) 364 | # return x, l4_xyz, l4_points 365 | 366 | 367 | class Pointnet2_encoder(nn.Module): 368 | def __init__(self, num_classes=13, output_dim=128, pointnet2_checkpoint=checkpoint_folder.joinpath('pointnet2_seg_ssg.pth'), return_level=4): 369 | super(Pointnet2_encoder, self).__init__() 370 | self.pointenet2_sem = Pointnet2_sem(num_classes, return_level) 371 | self.pointenet2_sem.load_state_dict(torch.load(pointnet2_checkpoint)['model_state_dict']) 372 | for param in self.pointenet2_sem.parameters(): 373 | param.requires_grad = False 374 | self.sa = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 128, output_dim], group_all=True) 375 | 376 | 377 | def forward(self, xyz): 378 | l4_xyz, l4_points = self.pointenet2_sem(xyz) 379 | l5_xyz, l5_points = self.sa(l4_xyz, l4_points) 380 | # print(l5_points.squeeze().shape) 381 | return l5_points.squeeze() 382 | # return torch.zeros((xyz.shape[0], 128), device=xyz.device) 383 | 384 | class LocalPointEncoder(nn.Module): 385 | def __init__(self, num_classes=13, output_dim=128, 386 | pointnet2_checkpoint=checkpoint_folder.joinpath('pointnet2_seg_ssg.pth'), 387 | return_level=4, 388 | freeze=False 389 | ): 390 | super(LocalPointEncoder, self).__init__() 391 | self.pointenet2_sem = Pointnet2_sem(num_classes, return_level) 392 | self.pointenet2_sem.load_state_dict(torch.load(pointnet2_checkpoint)['model_state_dict']) 393 | if freeze: 394 | for param in self.pointenet2_sem.parameters(): 395 | param.requires_grad = False 396 | 397 | def forward(self, xyz): 398 | # return 
torch.zeros((xyz.shape[0], 3, 16), device=xyz.device, dtype=torch.float32), torch.zeros( 399 | # (xyz.shape[0], 512, 16), device=xyz.device, dtype=torch.float32) 400 | xyz, points = self.pointenet2_sem(xyz) 401 | return xyz, points 402 | 403 | 404 | class PointNet2_cls_ssg(nn.Module): 405 | def __init__(self, num_class, normal_channel=True): 406 | super(PointNet2_cls_ssg, self).__init__() 407 | in_channel = 6 if normal_channel else 3 408 | self.normal_channel = normal_channel 409 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False) 410 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) 411 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) 412 | self.fc1 = nn.Linear(1024, 512) 413 | self.bn1 = nn.BatchNorm1d(512) 414 | self.drop1 = nn.Dropout(0.4) 415 | self.fc2 = nn.Linear(512, 256) 416 | self.bn2 = nn.BatchNorm1d(256) 417 | self.drop2 = nn.Dropout(0.4) 418 | self.fc3 = nn.Linear(256, num_class) 419 | 420 | def forward(self, xyz): 421 | B, _, _ = xyz.shape 422 | if self.normal_channel: 423 | norm = xyz[:, 3:, :] 424 | xyz = xyz[:, :3, :] 425 | else: 426 | norm = None 427 | l1_xyz, l1_points = self.sa1(xyz, norm) 428 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 429 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 430 | x = l3_points.view(B, 1024) 431 | # x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 432 | # x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 433 | # x = self.fc3(x) 434 | # x = F.log_softmax(x, -1) 435 | 436 | 437 | return x, l3_points 438 | 439 | class Pointnet2_cls_encoder(nn.Module): 440 | def __init__(self, num_classes=40, output_dim=128, checkpoint=checkpoint_folder.joinpath('pointnet2_cls_ssg.pth'), freeze=True): 441 | super(Pointnet2_cls_encoder, self).__init__() 442 | self.pointenet2_cls = PointNet2_cls_ssg(num_classes, normal_channel=False) 443 | self.pointenet2_cls.load_state_dict(torch.load(checkpoint)['model_state_dict']) 444 | if freeze: 445 | for param in self.pointenet2_cls.parameters(): 446 | param.requires_grad = False 447 | 448 | def forward(self, xyz): 449 | xyz = xyz[:, :3, :] 450 | x, l3_points = self.pointenet2_cls(xyz) 451 | return l3_points.squeeze() 452 | # return torch.zeros((xyz.shape[0], 1024), device=xyz.device, dtype=torch.float32) -------------------------------------------------------------------------------- /utils/viz_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import os.path as osp 5 | import open3d as o3d 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | import trimesh 10 | import pyrender 11 | import PIL.Image as pil_img 12 | import eulerangles 13 | 14 | def create_renderer(H=1080, W=1920, intensity=50, fov=None, point_size=1.0): 15 | if fov is None: 16 | fov = np.pi / 3.0 17 | r = pyrender.OffscreenRenderer(viewport_width=W, 18 | viewport_height=H, 19 | point_size=point_size) 20 | camera = pyrender.PerspectiveCamera(yfov=fov, aspectRatio=1.333) 21 | light_directional = pyrender.DirectionalLight(color=np.ones(3), intensity=intensity) 22 | light_point = pyrender.PointLight(color=np.ones(3), intensity=intensity) 23 | material = pyrender.MetallicRoughnessMaterial( 24 | metallicFactor=0.0, 25 | alphaMode='OPAQUE', 26 | baseColorFactor=(1.0, 1.0, 0.9, 1.0)) 27 | return r, camera, 
light_directional, light_point, material 28 | 29 | 30 | def create_collage(images, mode='grid'): 31 | n = len(images) 32 | W, H = images[0].size 33 | if mode == 'grid': 34 | img_collage = pil_img.new('RGB', (2 * W, 2 * H)) 35 | for id, img in enumerate(images): 36 | img_collage.paste(img, (W * (id % 2), H * int(id / 2))) 37 | elif mode == 'vertical': 38 | img_collage = pil_img.new('RGB', (W, n * H)) 39 | for id, img in enumerate(images): 40 | img_collage.paste(img, (0, id * H)) 41 | elif mode == 'horizantal': 42 | img_collage = pil_img.new('RGB', (n * W, H)) 43 | for id, img in enumerate(images): 44 | img_collage.paste(img, (id * W, 0)) 45 | return img_collage 46 | 47 | 48 | def render_interaction_multview(body, static_scene, clothed_body=None, use_clothed_mesh=False, body_center=True, smooth_body=True, 49 | collage_mode='grid', body_contrast=None, obj_points_coord=None, num_view=4, **kwargs): 50 | H, W = int(480 * 1.5), int(640 * 1.5) 51 | renderer, camera, light_directional, light_point, material = create_renderer(H=H, W=W, intensity=2.0) 52 | light_point.intensity = 10.0 53 | 54 | # this will make the camera looks in the -x direction 55 | camera_pose = np.eye(4) 56 | camera_pose[0, 3] = 2 57 | camera_pose[2, 3] = 1 58 | camera_pose[:3, :3] = eulerangles.euler2mat(-np.pi / 6, np.pi / 2, np.pi / 2, 'sxzy') 59 | 60 | if body_center: 61 | center = (body.vertices.max(axis=0) + body.vertices.min(axis=0)) / 2.0 62 | # camera_pose[0, 3] = 0.5 63 | # camera_pose[2, 3] = 2 64 | else: 65 | center = (static_scene.vertices.max(axis=0) + static_scene.vertices.min(axis=0)) / 2.0 66 | camera_pose[0, 3] = 3 67 | 68 | static_scene.vertices -= center 69 | body.vertices -= center 70 | if use_clothed_mesh: 71 | clothed_body.vertices -= center 72 | if body_contrast is not None: 73 | body_contrast.vertices -= center 74 | 75 | 76 | images = [] 77 | views = list(range(0, 360, 90)) if num_view == 4 else [0, 90] 78 | # for ang_id, ang in enumerate(range(0, 360, 90)): 79 | for ang_id, ang in enumerate(views): 80 | ang = np.pi / 180 * ang 81 | rot_z = np.eye(4) 82 | rot_z[:3, :3] = eulerangles.euler2mat(ang, 0, 0, 'szxy') 83 | 84 | # print(1) 85 | static_scene_mesh = pyrender.Mesh.from_trimesh(static_scene) 86 | if use_clothed_mesh: 87 | body_mesh = pyrender.Mesh.from_trimesh(clothed_body, material=pyrender.MetallicRoughnessMaterial(alphaMode="BLEND", 88 | baseColorFactor=(1.0, 1.0, 1.0, 0.5), 89 | metallicFactor=0.0,)) 90 | else: 91 | body_mesh = pyrender.Mesh.from_trimesh(body, smooth=smooth_body) 92 | 93 | # print(2) 94 | scene = pyrender.Scene() 95 | scene.add(camera, pose=np.matmul(rot_z, camera_pose)) 96 | scene.add(light_point, pose=np.matmul(rot_z, camera_pose)) 97 | scene.add(light_directional, pose=np.eye(4)) 98 | scene.add(static_scene_mesh, 'mesh') 99 | # print(3) 100 | if obj_points_coord is not None: 101 | scene.add(pyrender.Mesh.from_points(points=obj_points_coord - center, colors=(1.0, 0.1, 0.1)), 'mesh') 102 | # print(4) 103 | scene.add(body_mesh, 'mesh') 104 | if body_contrast is not None: 105 | scene.add(pyrender.Mesh.from_trimesh(body_contrast, material=material), 'mesh') 106 | # print(5) 107 | color, _ = renderer.render(scene, pyrender.constants.RenderFlags.SHADOWS_DIRECTIONAL 108 | # | pyrender.constants.RenderFlags.SKIP_CULL_FACES 109 | # | pyrender.constants.RenderFlags.VERTEX_NORMALS 110 | ) 111 | # print(6) 112 | color = color.astype(np.float32) / 255.0 113 | img = pil_img.fromarray((color * 255).astype(np.uint8)) 114 | images.append(img) 115 | # print(7) 116 | # collage_mode = 'grid' if 
num_view == 4 else 'horizantal' 117 | images = create_collage(images, collage_mode) 118 | static_scene.vertices += center 119 | body.vertices += center 120 | return images 121 | 122 | 123 | def render_composite_interaction_multview(body, scene_meshes, body_center=True, use_material=True, smooth_body=True, 124 | collage_mode='grid', body_contrast=None, obj_points_coord=None, **kwargs): 125 | H, W = 720, 960 126 | renderer, camera, light_directional, light_point, material = create_renderer(H=H, W=W, intensity=2.0) 127 | light_point.intensity = 10.0 128 | 129 | # assert(len(scene_meshes) == obj_points_coord.shape[0]) 130 | 131 | # this will make the camera looks in the -x direction 132 | camera_pose = np.eye(4) 133 | camera_pose[0, 3] = 2 134 | camera_pose[2, 3] = 1 135 | camera_pose[:3, :3] = eulerangles.euler2mat(-np.pi / 6, np.pi / 2, np.pi / 2, 'sxzy') 136 | 137 | center = (body.vertices.max(axis=0) + body.vertices.min(axis=0)) / 2.0 138 | 139 | for scene_mesh in scene_meshes: 140 | scene_mesh.vertices -= center 141 | body.vertices -= center 142 | if body_contrast is not None: 143 | body_contrast.vertices -= center 144 | 145 | images = [] 146 | # for ang_id, ang in enumerate(range(0, 360, 90)): 147 | for ang_id, ang in enumerate(range(0, 360, 90)): 148 | ang = np.pi / 180 * ang 149 | rot_z = np.eye(4) 150 | rot_z[:3, :3] = eulerangles.euler2mat(ang, 0, 0, 'szxy') 151 | 152 | 153 | 154 | scene = pyrender.Scene() 155 | scene.add(camera, pose=np.matmul(rot_z, camera_pose)) 156 | scene.add(light_point, pose=np.matmul(rot_z, camera_pose)) 157 | scene.add(light_directional, pose=np.eye(4)) 158 | for scene_mesh in scene_meshes: 159 | static_scene_mesh = pyrender.Mesh.from_trimesh(scene_mesh) 160 | scene.add(static_scene_mesh, 'mesh') 161 | if obj_points_coord is not None: 162 | for obj_idx in range(obj_points_coord.shape[0]): 163 | scene.add(pyrender.Mesh.from_points(points=obj_points_coord[obj_idx, :, :] - center, colors=(1.0, 0.1, 0.1)), 'mesh') 164 | body_mesh = pyrender.Mesh.from_trimesh(body, material=material, smooth=smooth_body) if use_material else pyrender.Mesh.from_trimesh(body, smooth=smooth_body) 165 | scene.add(body_mesh, 'mesh') 166 | if body_contrast is not None: 167 | scene.add(pyrender.Mesh.from_trimesh(body_contrast, material=material, smooth=smooth_body) if use_material else pyrender.Mesh.from_trimesh(body_contrast, smooth=smooth_body), 'mesh') 168 | 169 | color, _ = renderer.render(scene, pyrender.constants.RenderFlags.SHADOWS_DIRECTIONAL) 170 | color = color.astype(np.float32) / 255.0 171 | img = pil_img.fromarray((color * 255).astype(np.uint8)) 172 | images.append(img) 173 | images = create_collage(images, collage_mode) 174 | for scene_mesh in scene_meshes: 175 | scene_mesh.vertices += center 176 | body.vertices += center 177 | if body_contrast is not None: 178 | body_contrast.vertices += center 179 | return images 180 | 181 | 182 | def render_body_multview(body, body_center=True, num_view=4, use_material=True, 183 | collage_mode='grid', body_contrast=None, **kwargs): 184 | H, W = 480 * 2, 640 * 2 185 | renderer, camera, light_directional, light_point, material = create_renderer(H=H, W=W, intensity=2.0) 186 | light_point.intensity = 10.0 187 | 188 | # this will make the camera looks in the -x direction 189 | camera_pose = np.eye(4) 190 | camera_pose[0, 3] = 2 191 | camera_pose[2, 3] = 1 192 | camera_pose[:3, :3] = eulerangles.euler2mat(-np.pi / 6, np.pi / 2, np.pi / 2, 'sxzy') 193 | 194 | center = (body.vertices.max(axis=0) + body.vertices.min(axis=0)) / 2.0 195 | 196 | 
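# Center the body at the origin before rendering; a contrast body, if given, is
# additionally offset by 0.5 along +x (see the next lines) so the two meshes stay
# visually separated in the rendered views.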
body.vertices -= center 197 | if body_contrast is not None: 198 | body_contrast.vertices -= center - np.array((0.5, 0.0, 0.0)) 199 | 200 | images = [] 201 | views = list(range(0, 360, 90)) if num_view == 4 else [0, 90] 202 | # for ang_id, ang in enumerate(range(0, 360, 90)): 203 | for ang_id, ang in enumerate(views): 204 | ang = np.pi / 180 * ang 205 | rot_z = np.eye(4) 206 | rot_z[:3, :3] = eulerangles.euler2mat(0, 0, ang, 'szxy') 207 | 208 | 209 | 210 | scene = pyrender.Scene() 211 | scene.add(camera, pose=np.matmul(rot_z, camera_pose)) 212 | scene.add(light_point, pose=np.matmul(rot_z, camera_pose)) 213 | scene.add(light_directional, pose=np.eye(4)) 214 | body_mesh = pyrender.Mesh.from_trimesh(body, material=material, smooth=True) if use_material else pyrender.Mesh.from_trimesh(body, smooth=True) 215 | scene.add(body_mesh, 'mesh') 216 | if body_contrast is not None: 217 | scene.add(pyrender.Mesh.from_trimesh(body_contrast, material=material, smooth=True) if use_material else pyrender.Mesh.from_trimesh(body_contrast, smooth=True), 'mesh') 218 | 219 | color, _ = renderer.render(scene, pyrender.constants.RenderFlags.SHADOWS_DIRECTIONAL 220 | # | pyrender.constants.RenderFlags.SKIP_CULL_FACES 221 | ) 222 | color = color.astype(np.float32) / 255.0 223 | img = pil_img.fromarray((color * 255).astype(np.uint8)) 224 | images.append(img) 225 | images = create_collage(images, collage_mode) 226 | body.vertices += center 227 | if body_contrast is not None: 228 | body_contrast.vertices += center - np.array((0.5, 0.0, 0.0)) 229 | return images 230 | 231 | 232 | def render_obj_multview(obj_pointcloud, frame, 233 | collage_mode='grid', frame_contrast=None, body=None, **kwargs): 234 | H, W = 480, 640 235 | renderer, camera, light_directional, light_point, material = create_renderer(H=H, W=W, intensity=2.0, point_size=5.0) 236 | light_point.intensity = 10.0 237 | 238 | # this will make the camera looks in the -x direction 239 | camera_pose = np.eye(4) 240 | camera_pose[0, 3] = 2 241 | camera_pose[2, 3] = 1 242 | camera_pose[:3, :3] = eulerangles.euler2mat(-np.pi / 6, np.pi / 2, np.pi / 2, 'sxzy') 243 | 244 | center = (obj_pointcloud.vertices.max(axis=0) + obj_pointcloud.vertices.min(axis=0)) / 2.0 245 | 246 | obj_pointcloud.vertices -= center 247 | frame.vertices -= center 248 | if frame_contrast is not None: 249 | frame_contrast.vertices -= center 250 | if body is not None: 251 | body.vertices -= center 252 | 253 | images = [] 254 | # for ang_id, ang in enumerate(range(0, 360, 90)): 255 | for ang_id, ang in enumerate(range(0, 360, 90)): 256 | ang = np.pi / 180 * ang 257 | rot_z = np.eye(4) 258 | rot_z[:3, :3] = eulerangles.euler2mat(ang, 0, 0, 'szxy') 259 | 260 | scene = pyrender.Scene() 261 | scene.add(camera, pose=np.matmul(rot_z, camera_pose)) 262 | scene.add(light_point, pose=np.matmul(rot_z, camera_pose)) 263 | scene.add(light_directional, pose=np.eye(4)) 264 | scene.add(pyrender.Mesh.from_points(points=obj_pointcloud.vertices, colors=obj_pointcloud.colors / 255.0), 'mesh') 265 | scene.add(pyrender.Mesh.from_trimesh(frame, smooth=False), 'mesh') 266 | if body is not None: 267 | body_mesh = pyrender.Mesh.from_trimesh(body, material=pyrender.MetallicRoughnessMaterial(alphaMode="BLEND", 268 | baseColorFactor=(1.0, 1.0, 1.0, 0.5), 269 | metallicFactor=0.0,), 270 | smooth=False) 271 | # print('transparent', body_mesh.is_transparent) 272 | scene.add(body_mesh, 'mesh') 273 | if frame_contrast is not None: 274 | scene.add(pyrender.Mesh.from_trimesh(frame_contrast, smooth=False), 'mesh') 275 | 276 | color, _ 
= renderer.render(scene, pyrender.constants.RenderFlags.SHADOWS_DIRECTIONAL 277 | | pyrender.constants.RenderFlags.RGBA 278 | # | pyrender.constants.RenderFlags.SKIP_CULL_FACES 279 | ) 280 | color = color.astype(np.float32) / 255.0 281 | img = pil_img.fromarray((color * 255).astype(np.uint8)) 282 | images.append(img) 283 | images = create_collage(images, collage_mode) 284 | obj_pointcloud.vertices += center 285 | frame.vertices += center 286 | if frame_contrast is not None: 287 | frame_contrast.vertices += center 288 | if body is not None: 289 | body.vertices += center 290 | return images 291 | 292 | 293 | def makeLookAt(position, target, up): 294 | 295 | forward = np.subtract(target, position) 296 | forward = np.divide(forward, np.linalg.norm(forward)) 297 | 298 | right = np.cross(forward, up) 299 | 300 | # if forward and up vectors are parallel, right vector is zero; 301 | # fix by perturbing up vector a bit 302 | if np.linalg.norm(right) < 0.001: 303 | epsilon = np.array([0.001, 0, 0]) 304 | right = np.cross(forward, up + epsilon) 305 | 306 | right = np.divide(right, np.linalg.norm(right)) 307 | 308 | up = np.cross(right, forward) 309 | up = np.divide(up, np.linalg.norm(up)) 310 | 311 | return np.array([[right[0], up[0], -forward[0], position[0]], 312 | [right[1], up[1], -forward[1], position[1]], 313 | [right[2], up[2], -forward[2], position[2]], 314 | [0, 0, 0, 1]]) 315 | 316 | 317 | def render_scene_three_view(scene_mesh, body_mesh, collage_mode='grid', center='human', render_points=None, **kwargs): 318 | H, W = 480, 640 319 | renderer, camera, light_directional, light_point, material = create_renderer(H=H, W=W, intensity=2.0, point_size=10.0) 320 | light_point.intensity = 10.0 321 | 322 | center = (body_mesh.vertices.max(axis=0) + body_mesh.vertices.min(axis=0)) / 2.0 if center == 'human' else (scene_mesh.vertices.max(axis=0) + scene_mesh.vertices.min(axis=0)) / 2.0 323 | scene_mesh.vertices -= center 324 | body_mesh.vertices -= center 325 | dist = max(np.absolute(body_mesh.vertices).max() + 1, 1.5) if center == 'human' else max(np.absolute(scene_mesh.vertices).max() + 1, 1.5) 326 | 327 | camera_poses = [ 328 | makeLookAt(np.array([dist, 0, 0]), np.array([0, 0, 0]), np.array([0, 0, 1])), 329 | makeLookAt(np.array([0, dist, 0]), np.array([0, 0, 0]), np.array([0, 0, 1])), 330 | makeLookAt(np.array([0, 0, dist]), np.array([0, 0, 0]), np.array([-1, 0, 0])), 331 | ] 332 | 333 | images = [] 334 | # for ang_id, ang in enumerate(range(0, 360, 90)): 335 | for camera_pose in camera_poses: 336 | 337 | scene = pyrender.Scene() 338 | scene.add(camera, pose=camera_pose) 339 | scene.add(light_point, pose=camera_pose) 340 | scene.add(light_directional, pose=np.eye(4)) 341 | for mesh in [scene_mesh, body_mesh]: 342 | scene.add(pyrender.Mesh.from_trimesh(mesh, material=pyrender.MetallicRoughnessMaterial(alphaMode="BLEND", 343 | baseColorFactor=(1.0, 1.0, 1.0, 1.0), 344 | metallicFactor=0.0,), smooth=False), 'mesh') 345 | if render_points is not None: 346 | points_mesh = pyrender.Mesh.from_points(render_points.vertices - center, render_points.colors / 255.0) 347 | # points_mesh.material = pyrender.MetallicRoughnessMaterial(alphaMode="BLEND", 348 | # baseColorFactor=(1.0, 1.0, 1.0, 1.0), 349 | # metallicFactor=0.0,) 350 | # print(points_mesh.primitives[0].color_0) 351 | # print(points_mesh.is_transparent) 352 | scene.add(points_mesh, 'mesh') 353 | 354 | color, _ = renderer.render(scene, pyrender.constants.RenderFlags.SHADOWS_DIRECTIONAL 355 | | pyrender.constants.RenderFlags.RGBA 356 | | 
pyrender.constants.RenderFlags.SKIP_CULL_FACES 357 | ) 358 | color = color.astype(np.float32) / 255.0 359 | img = pil_img.fromarray((color * 255).astype(np.uint8)) 360 | images.append(img) 361 | images = create_collage(images, collage_mode) 362 | scene_mesh.vertices += center 363 | body_mesh.vertices += center 364 | return images 365 | 366 | def render_alignment_three_view(scene_mesh, shapenet_mesh, collage_mode='grid', center='human', render_points=None, **kwargs): 367 | H, W = 480, 640 368 | renderer, camera, light_directional, light_point, material = create_renderer(H=H, W=W, intensity=2.0, point_size=10.0) 369 | light_point.intensity = 2.0 370 | 371 | center = (scene_mesh.vertices.max(axis=0) + scene_mesh.vertices.min(axis=0)) / 2.0 372 | scene_mesh.vertices -= center 373 | shapenet_mesh.vertices -= center 374 | dist = max(np.absolute(scene_mesh.vertices).max() + 1, 1.5) 375 | 376 | camera_poses = [ 377 | makeLookAt(np.array([dist, 0, 0]), np.array([0, 0, 0]), np.array([0, 0, 1])), 378 | makeLookAt(np.array([0, dist, 0]), np.array([0, 0, 0]), np.array([0, 0, 1])), 379 | makeLookAt(np.array([0, 0, dist]), np.array([0, 0, 0]), np.array([-1, 0, 0])), 380 | ] 381 | 382 | images = [] 383 | # for ang_id, ang in enumerate(range(0, 360, 90)): 384 | for camera_pose in camera_poses: 385 | 386 | scene = pyrender.Scene() 387 | scene.add(camera, pose=camera_pose) 388 | scene.add(light_point, pose=camera_pose) 389 | scene.add(light_directional, pose=np.eye(4)) 390 | for mesh in [scene_mesh, shapenet_mesh]: 391 | scene.add(pyrender.Mesh.from_trimesh(mesh, smooth=False), 'mesh') 392 | 393 | color, _ = renderer.render(scene, pyrender.constants.RenderFlags.SHADOWS_DIRECTIONAL 394 | | pyrender.constants.RenderFlags.RGBA 395 | | pyrender.constants.RenderFlags.SKIP_CULL_FACES 396 | ) 397 | color = color.astype(np.float32) / 255.0 398 | img = pil_img.fromarray((color * 255).astype(np.uint8)) 399 | images.append(img) 400 | images = create_collage(images, collage_mode) 401 | scene_mesh.vertices += center 402 | shapenet_mesh.vertices += center 403 | return images -------------------------------------------------------------------------------- /models/transform_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 3 | import sys 4 | sys.path.append('..') 5 | 6 | 7 | import smplx 8 | import trimesh 9 | from datetime import datetime 10 | import pickle 11 | import shutil 12 | import torch 13 | from torch.utils.data import DataLoader 14 | import torch.nn.functional as F 15 | from torch.optim.lr_scheduler import ReduceLROnPlateau 16 | import torchgeometry as tgm 17 | import pytorch3d 18 | from pytorch3d.structures import Pointclouds, Meshes 19 | import pytorch3d.loss 20 | import pytorch_lightning as pl 21 | from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler 22 | from pytorch_lightning import loggers as pl_loggers 23 | from pathlib import Path 24 | import open3d as o3d 25 | from datetime import datetime 26 | from copy import deepcopy 27 | from argparse import ArgumentParser, Namespace 28 | 29 | from configuration.joints import * 30 | from dataset import CompositeFrameDataset 31 | from data.scene import scenes, to_trimesh, to_open3d 32 | from models.smplx_regressor import SMPLX_Regressor 33 | from models.loss import * 34 | from models.module import InteractionVAE 35 | from utils.utils import * 36 | from utils.viz_util import render_obj_multview 37 | from utils.data_util import * 38 | 39 | 40 | def 
rot6d_to_mat(module_input): 41 | reshaped_input = module_input.view(-1, 3, 2) 42 | 43 | b1 = F.normalize(reshaped_input[:, :, 0], dim=1) 44 | 45 | dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True) 46 | b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1) 47 | b3 = torch.cross(b1, b2, dim=1) 48 | 49 | return torch.stack([b1, b2, b3], dim=-1) 50 | 51 | 52 | class geodesic_loss_R(nn.Module): 53 | def __init__(self, reduction='batchmean'): 54 | super(geodesic_loss_R, self).__init__() 55 | 56 | self.reduction = reduction 57 | self.eps = 1e-6 58 | 59 | 60 | # batch geodesic loss for rotation matrices 61 | def bgdR(self,m1,m2): 62 | batch = m1.shape[0] 63 | m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 64 | 65 | cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 66 | cos = torch.min(cos, m1.new(np.ones(batch))) 67 | cos = torch.max(cos, m1.new(np.ones(batch)) * -1) 68 | 69 | return torch.acos(cos) 70 | 71 | 72 | def forward(self, ypred, ytrue): 73 | theta = self.bgdR(ypred,ytrue) 74 | if self.reduction == 'mean': 75 | return torch.mean(theta) 76 | if self.reduction == 'batchmean': 77 | breakpoint() 78 | return torch.mean(torch.sum(theta, dim=theta.shape[1:])) 79 | 80 | else: 81 | return theta 82 | 83 | 84 | def create_frame(x, origin_color=(0.8, 0.8, 0.8)): 85 | pelvis = x[6:].detach().cpu().numpy() 86 | rotmat = rot6d_to_mat(x[:6]).detach().cpu().numpy().reshape((3, 3)) 87 | transform = np.eye(4, dtype=np.float32) 88 | transform[:3, :3] = rotmat 89 | transform[:3, 3] = pelvis 90 | return trimesh.creation.axis(transform=transform, origin_color=origin_color) 91 | 92 | 93 | class LitTransformNet(pl.LightningModule): 94 | def __init__(self, args): 95 | super().__init__() 96 | if isinstance(args, dict): 97 | args = Namespace(**args) 98 | self.args = args 99 | self.save_hyperparameters(args) 100 | self.start_time = datetime.now().strftime("%m:%d:%Y_%H:%M:%S") 101 | # self.save_hyperparameters('args') 102 | args.device = device 103 | # self.body_model = smplx.create(smplx_model_folder, model_type='smplx', 104 | # gender='neutral', ext='npz', 105 | # num_pca_comps=num_pca_comps, batch_size=1) 106 | 107 | if args.model == 'InteractionVAE': 108 | self.model = InteractionVAE(args) 109 | else: 110 | print('not implemented') 111 | return 112 | 113 | 114 | # x: 6d global orientation and 3d location of pelvis 115 | def forward(self, x, batch): 116 | return self.model(x, batch) 117 | 118 | 119 | def configure_optimizers(self): 120 | optimizer = torch.optim.Adam(params=self.model.parameters(), 121 | lr=self.args.lr, 122 | weight_decay=self.args.l2_norm) 123 | 124 | lr_scheduler = ReduceLROnPlateau(optimizer, patience=5, factor=0.9, verbose=True) 125 | return ({'optimizer': optimizer, 126 | }) 127 | 128 | 129 | def calc_loss(self, x, x_hat, q_z, batch=None): 130 | obj_pointclouds, verb_ids = batch['object_pointclouds'], batch['verb_ids'] 131 | batch_size = x.shape[0] 132 | rotmat = rot6d_to_mat(x[:, :6]) 133 | pelvis = x[:, 6:] 134 | rotmat_hat = rot6d_to_mat(x_hat[:, :6]) 135 | pelvis_hat = x_hat[:, 6:] 136 | loss_orient = geodesic_loss_R(reduction='mean')( 137 | rotmat_hat, 138 | rotmat 139 | ) 140 | loss_pelvis = F.l1_loss(pelvis_hat, pelvis) 141 | 142 | location = obj_pointclouds[:, :, :, :3] 143 | vectors = location - pelvis[:, None, None, :] 144 | min_dist, _ = torch.min(torch.sum(vectors ** 2, dim=-1), dim=-1) # BxI 145 | vectors_hat = location - pelvis_hat[:, None, None, :] 146 | min_dist_hat, _ = torch.min(torch.sum(vectors_hat ** 2, dim=-1), dim=-1) 147 | 
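# L1 loss between the ground-truth and reconstructed minimum squared pelvis-to-object
# distances (min over each object's points; one scalar per interaction object, shape BxI).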
loss_dist = F.l1_loss(min_dist_hat, min_dist) 148 | 149 | # loss of reconstruction of points coord in pelvis frame 150 | local_coords = torch.matmul(vectors, rotmat.transpose(1, 2)[:, None, :, :]) 151 | local_coords_hat = torch.matmul(vectors_hat, rotmat_hat.transpose(1, 2)[:, None, :, :]) 152 | loss_coord = F.l1_loss(local_coords, local_coords_hat) 153 | 154 | # pelvis-object penetration loss 155 | dist_hat = torch.sqrt(torch.sum(vectors_hat ** 2, dim=-1)) # BxIxP 156 | thresh = self.args.thresh_penetration 157 | # positive value means very close to pelvis, possible penetration 158 | penetration = thresh - dist_hat 159 | penetration_mask = (verb_ids == 3).unsqueeze(2) # whether atomic is touch, BxIx1 160 | penetration = penetration * penetration_mask.float() 161 | penetration = penetration[penetration > 0] 162 | loss_penetration = penetration.mean() if len(penetration) > 0 else torch.tensor(0.0, device=penetration.device) 163 | 164 | p_z = torch.distributions.normal.Normal( 165 | loc=torch.zeros((x.shape[0], self.args.latent_dim), requires_grad=False, device=device), 166 | scale=torch.ones((x.shape[0], self.args.latent_dim), requires_grad=False, device=device)) 167 | loss_kl = torch.mean(torch.mean(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1])) 168 | if self.args.robust_kl: 169 | loss_kl = torch.sqrt(loss_kl * loss_kl + 1) - 1.0 170 | 171 | loss_dict = dict(orient=loss_orient, 172 | pelvis=loss_pelvis, 173 | dist=loss_dist, 174 | kl=loss_kl, 175 | penetration=loss_penetration, 176 | coord=loss_coord, 177 | ) 178 | 179 | annealing_factor = min(1.0, max(float(self.current_epoch) / (self.args.second_stage), 0)) if self.args.use_annealing else 1 180 | weighted_loss_dict = { 181 | 'orient': loss_dict['orient'] * self.args.weight_orient, 182 | 'pelvis': loss_dict['pelvis'] * self.args.weight_pelvis, 183 | 'dist': loss_dict['dist'] * self.args.weight_dist, 184 | 'coord': loss_dict['coord'] * self.args.weight_coord, 185 | 'penetration': loss_dict['penetration'] * self.args.weight_penetration, 186 | 'kl': 187 | max(annealing_factor ** 2, 0) * 188 | self.args.weight_kl * loss_dict['kl'], 189 | } 190 | 191 | loss = torch.stack(list(weighted_loss_dict.values())).sum() 192 | 193 | return loss, loss_dict, weighted_loss_dict 194 | 195 | 196 | def _common_step(self, batch, batch_idx, mode): 197 | pelvis, rotmat = batch['pelvis'], batch['pelvis_orient'] 198 | rot6d = rotmat[:, :3, :2].reshape(-1, 6) 199 | x = torch.cat([rot6d, pelvis], dim=1) 200 | x_hat, q_z = self(x, batch) 201 | x_hat = x_hat.squeeze(1) 202 | loss, loss_dict, weighted_loss_dict = self.calc_loss(x, x_hat, q_z, batch=batch) 203 | 204 | # render reconstructed and sampled interactions 205 | render_interval = 64 if mode == 'valid' else 256 206 | if (batch_idx % render_interval == 0) and (self.current_epoch > self.args.render_epoch or self.args.debug): 207 | x_sample, _ = self.model.sample(batch) 208 | x_sample = x_sample.squeeze(1) 209 | obj_points = batch['object_pointclouds'] 210 | B, I, P, C = obj_points.shape 211 | obj_points = obj_points.reshape(B, I*P, C) 212 | batch_size = x.shape[0] 213 | render_num = 4 214 | for idx in range(min(batch_size, render_num)): 215 | base_name = mode + '_E{:03d}_It{:04d}_orient_{:.4f}_pelvis_{:.5f}_id{:d}_{}.png'.format( 216 | self.current_epoch, batch_idx, loss_dict['orient'].item(), 217 | loss_dict['pelvis'].item(), idx, batch['interaction'][idx]) 218 | 219 | colors = np.ones((obj_points.shape[1], 4), dtype=np.uint8) * 255 220 | colors[:, :3] = (obj_points[idx, :, 3:6].cpu().numpy() * 
255).astype(np.uint8) 221 | body=None 222 | obj_pointcloud = trimesh.PointCloud( 223 | vertices=obj_points[idx, :, :3].cpu().numpy(), 224 | colors=colors, 225 | ) 226 | frame_ori = create_frame(x[idx], origin_color=(1.0, 0.0, 0.0)) 227 | frame_rec = create_frame(x_hat[idx], origin_color=(0.0, 1.0, 0.0)) 228 | frame_sample = create_frame(x_sample[idx], origin_color=(0.0, 0.0, 1.0)) 229 | 230 | 231 | export_file = Path.joinpath(save_dir, 'render', 'contrast_' + base_name) 232 | export_file.parent.mkdir(exist_ok=True, parents=True) 233 | img_collage = render_obj_multview(obj_pointcloud, frame_rec, frame_contrast=frame_ori, body=body) 234 | img_collage.save(str(export_file)) 235 | 236 | export_file = Path.joinpath(save_dir, 'render', 'sample_' + base_name) 237 | img_collage = render_obj_multview(obj_pointcloud, frame_sample) 238 | img_collage.save(str(export_file)) 239 | 240 | 241 | return loss, loss_dict, weighted_loss_dict 242 | 243 | 244 | def training_step(self, batch, batch_idx): 245 | loss, loss_dict, weighted_loss_dict = self._common_step(batch, batch_idx, 'train') 246 | 247 | self.log('train_loss', loss, prog_bar=False) 248 | for key in loss_dict: 249 | self.log(key, loss_dict[key], prog_bar=True) 250 | 251 | return loss 252 | 253 | 254 | def validation_step(self, batch, batch_idx): 255 | loss, loss_dict, weighted_loss_dict = self._common_step(batch, batch_idx, 'valid') 256 | 257 | for key in loss_dict: 258 | self.log('val_' + key, loss_dict[key], prog_bar=False) 259 | self.log('val_loss', loss) 260 | 261 | 262 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 263 | if __name__ == '__main__': 264 | if torch.cuda.is_available(): 265 | print(torch.cuda.get_device_name(0)) 266 | # args 267 | parser = ArgumentParser() 268 | parser.add_argument("--model", type=str, default='InteractionVAE') 269 | parser.add_argument("--num_verb", type=int, default=4) 270 | parser.add_argument("--use_pointnet2", type=int, default=0) 271 | parser.add_argument("--num_obj_points", type=int, default=512) 272 | parser.add_argument("--num_obj_keypoints", type=int, default=512) 273 | parser.add_argument("--num_body_points", type=int, default=1) 274 | parser.add_argument("--dim_body_points", type=int, default=9) 275 | parser.add_argument("--point_level", type=int, default=3) 276 | parser.add_argument("--latent_dimension", type=int, default=128) 277 | parser.add_argument("--use_contact_feature", type=int, default=0) 278 | 279 | # transformer 280 | parser.add_argument("--latent_dim", type=int, default=6) 281 | parser.add_argument("--embedding_dim", type=int, default=64) 282 | parser.add_argument("--num_heads", type=int, default=4) 283 | parser.add_argument("--ff_size", type=int, default=512) 284 | parser.add_argument("--activation", type=str, default='gelu') 285 | parser.add_argument("--dropout", type=float, default=0) 286 | parser.add_argument("--num_layers", type=int, default=4) 287 | parser.add_argument("--interaction_bias", type=int, default=0) 288 | parser.add_argument("--latent_usage", type=str, default='memory') 289 | parser.add_argument("--template_type", type=str, default='zero') 290 | parser.add_argument("--return_attention", type=int, default=1) 291 | parser.add_argument("--mask_body", type=int, default=0) 292 | parser.add_argument("--mask_prob", type=float, default=0.05) 293 | 294 | parser.add_argument("--lr", type=float, default=3e-4) 295 | parser.add_argument("--l2_norm", type=float, default=0) 296 | parser.add_argument("--robust_kl", type=int, default=1) 297 | 
parser.add_argument("--weight_pelvis", type=float, default=1) 298 | parser.add_argument("--weight_orient", type=float, default=1) 299 | parser.add_argument("--weight_kl", type=float, default=1e-2) 300 | parser.add_argument("--weight_dist", type=float, default=1) 301 | parser.add_argument("--weight_coord", type=float, default=1) 302 | parser.add_argument("--weight_penetration", type=float, default=0) 303 | parser.add_argument("--thresh_penetration", type=float, default=0.1) 304 | parser.add_argument("--use_annealing", type=int, default=0) 305 | 306 | parser.add_argument("--use_regressor", type=int, default=0) 307 | parser.add_argument("--raw_points", type=int, default=0) 308 | parser.add_argument("--dummy_obj", type=int, default=0) 309 | parser.add_argument("--use_contact", type=int, default=0) 310 | 311 | parser.add_argument("--learned_prior", type=int, default=0) 312 | parser.add_argument("--use_kronecker", type=int, default=0) 313 | parser.add_argument("--freeze", type=int, default=0) 314 | 315 | parser.add_argument("--used_interaction", type=str, default='all') 316 | parser.add_argument("--skip_composite", type=str, default=None) 317 | parser.add_argument("--used_instance", type=str, default=None) 318 | parser.add_argument("--scale_obj", type=int, default=0) 319 | parser.add_argument("--center_type", type=str, default='object') 320 | parser.add_argument("--rotation", type=str, default='object') 321 | parser.add_argument("--point_sample", type=str, default='random') 322 | parser.add_argument("--use_augment", type=int, default=1) 323 | parser.add_argument("--data_overwrite", type=int, default=0) 324 | parser.add_argument("--use_prox_single", type=int, default=0) 325 | parser.add_argument("--use_annotate", type=int, default=1) 326 | parser.add_argument("--include_motion", type=int, default=0) 327 | parser.add_argument("--use_floor_height", type=int, default=1) 328 | 329 | parser.add_argument("--batch_size", type=int, default=16) 330 | parser.add_argument("--num_workers", type=int, default=8) 331 | parser.add_argument("--profiler", type=str, default='simple', help='simple or advanced') 332 | parser.add_argument("--gpus", type=int, default=1) 333 | parser.add_argument("--max_epochs", type=int, default=500) 334 | parser.add_argument("--second_stage", type=int, default=20, help="anneal some loss weights in epochs before this number") 335 | parser.add_argument("--expr_name", type=str, default=datetime.now().strftime("%H:%M:%S.%f")) 336 | parser.add_argument("--render_thresh", type=float, default=5e-2) 337 | parser.add_argument("--render_epoch", type=int, default=2333) 338 | parser.add_argument("--resume_checkpoint", type=str, default=None) 339 | parser.add_argument("--debug", type=int, default=0) 340 | args = parser.parse_args() 341 | 342 | # make deterministic 343 | pl.seed_everything(233, workers=True) 344 | # torch.autograd.set_detect_anomaly(True) 345 | # args.deterministic = True 346 | 347 | # data 348 | train_dataset = CompositeFrameDataset(split='train', augment=args.use_augment, 349 | data_overwrite=args.data_overwrite, 350 | use_prox_single=args.use_prox_single, 351 | skip_prox_composite=args.skip_composite, 352 | used_interaction=args.used_interaction, 353 | use_annotate=args.use_annotate, 354 | use_floor_height=args.use_floor_height, 355 | include_motion=args.include_motion, 356 | num_points=args.num_obj_points) 357 | test_dataset = CompositeFrameDataset(split='test', augment=False, 358 | data_overwrite=args.data_overwrite, 359 | used_interaction=args.used_interaction, 360 | 
include_motion=args.include_motion, 361 | use_floor_height=args.use_floor_height, 362 | use_annotate=args.use_annotate, 363 | use_prox_single=args.use_prox_single, 364 | skip_prox_composite=args.skip_composite, 365 | num_points=args.num_obj_points) 366 | 367 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, 368 | drop_last=True, pin_memory=False) #pin_memory cause warning in pytorch 1.9.0 369 | val_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, 370 | drop_last=True, pin_memory=False) 371 | print('dataset loaded') 372 | 373 | if args.resume_checkpoint is not None: 374 | print('resume training') 375 | model = LitTransformNet.load_from_checkpoint(args.resume_checkpoint, args=args) 376 | else: 377 | print('start training from scratch') 378 | model = LitTransformNet(args) 379 | 380 | # callback 381 | tb_logger = pl_loggers.TensorBoardLogger(str(results_folder / 'transform'), name=args.expr_name) 382 | save_dir = Path(tb_logger.log_dir) # for this version 383 | print(save_dir) 384 | checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=str(save_dir / 'checkpoints'), 385 | monitor="val_loss", 386 | save_weights_only=True, save_last=True) 387 | print(checkpoint_callback.dirpath) 388 | early_stop_callback = pl.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.00, patience=10000, verbose=False, 389 | mode="min") 390 | profiler = SimpleProfiler() if args.profiler == 'simple' else AdvancedProfiler(output_filename='profiling.txt') 391 | 392 | # trainer 393 | trainer = pl.Trainer.from_argparse_args(args, 394 | logger=tb_logger, 395 | profiler=profiler, 396 | progress_bar_refresh_rate=1, 397 | callbacks=[checkpoint_callback, early_stop_callback]) 398 | trainer.fit(model, train_loader, val_loader) 399 | 400 | 401 | 402 | -------------------------------------------------------------------------------- /SceneGraphNet/sgmodel/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | from utils.default_settings import dic_id2type 7 | import copy 8 | import numpy as np 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | ''' to_torch Variable ''' 13 | def to_torch(n, torch_type=torch.FloatTensor, requires_grad=False, dim_0=1): 14 | n = torch.tensor(n, requires_grad=requires_grad).type(torch_type).to(device) 15 | n = n.view(dim_0, -1) 16 | return n 17 | 18 | def get_gt_k_vec(node_list, cur_node, opt_parser): 19 | """ 20 | Get cur_node's k-vec = category + dimension + position 21 | :param node_list: 22 | :param cur_node: 23 | :param opt_parser: 24 | :return: 25 | """ 26 | 27 | if (node_list[cur_node]['type'] == 'root'): 28 | cat = 'wall' 29 | dim_vec = [0.0] * 3 30 | pos_vec = [0.0] * 3 31 | elif (node_list[cur_node]['type'] == 'wall'): 32 | cat = 'wall' 33 | dim_vec = node_list[cur_node]['self_info']['dim'] 34 | pos_vec = node_list[cur_node]['self_info']['translation'] 35 | else: 36 | if(len(node_list[cur_node]['self_info']['node_model_id']) > 2 and 37 | node_list[cur_node]['self_info']['node_model_id'][0:2] == 'EX'): 38 | cat = node_list[cur_node]['self_info']['node_model_id'][3:] 39 | else: 40 | cat = dic_id2type[node_list[cur_node]['self_info']['node_model_id']][1] 41 | dim_vec = node_list[cur_node]['self_info']['dim'] 42 | pos_vec = 
node_list[cur_node]['self_info']['translation'] 43 | 44 | cat_vec = [0.0] * (len(opt_parser.cat2id.keys()) + 1) 45 | cat_vec[int(opt_parser.cat2id[cat])] = 1.0 46 | 47 | return cat_vec + dim_vec + pos_vec 48 | 49 | 50 | class AggregateMaxPoolEnc(nn.Module): 51 | def __init__(self, k=54, d=100, h=300): 52 | super(AggregateMaxPoolEnc, self).__init__() 53 | 54 | self.enc = nn.Sequential( 55 | nn.Linear(d, h), 56 | nn.ReLU(), 57 | nn.Linear(h, d), 58 | # nn.Tanh() 59 | ) 60 | self.msg = nn.Sequential( 61 | nn.Linear(2 * d, h), 62 | nn.ReLU(), 63 | nn.Linear(h, d), 64 | ) 65 | 66 | def forward(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=[False]): 67 | if (cat_msg[0]): 68 | msg = self.msg(torch.cat((cur_d_vec, d_vec), dim=1)) 69 | else: 70 | msg = d_vec 71 | 72 | d_vec = self.enc(msg) * w 73 | compare = torch.stack((pre_vec, d_vec), dim=2) 74 | d_vec, _ = torch.max(compare, 2) 75 | d_vec.view(d_vec.shape) 76 | return d_vec 77 | 78 | class AggregateSumEnc(nn.Module): 79 | def __init__(self, k=54, d=100, h=300): 80 | super(AggregateSumEnc, self).__init__() 81 | 82 | self.enc = nn.Sequential( 83 | nn.Linear(d, h), 84 | nn.ReLU(), 85 | nn.Linear(h, d), 86 | # nn.Tanh() 87 | ) 88 | self.msg = nn.Sequential( 89 | nn.Linear(2 * d, h), 90 | nn.ReLU(), 91 | nn.Linear(h, d), 92 | ) 93 | 94 | def forward(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=[False]): 95 | if (cat_msg[0]): 96 | msg = self.msg(torch.cat((cur_d_vec, d_vec), dim=1)) 97 | else: 98 | msg = d_vec 99 | 100 | d_vec = pre_vec + self.enc(msg) * w 101 | return d_vec 102 | 103 | class AggregateCatEnc(nn.Module): 104 | def __init__(self, k=54, d=100, h=300): 105 | super(AggregateCatEnc, self).__init__() 106 | 107 | self.enc = nn.Sequential( 108 | nn.Linear(d * 2, h), 109 | nn.ReLU(), 110 | nn.Linear(h, d), 111 | # nn.Tanh() 112 | ) 113 | self.msg = nn.Sequential( 114 | nn.Linear(2 * d, h), 115 | nn.ReLU(), 116 | nn.Linear(h, d), 117 | ) 118 | 119 | def forward(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=[False]): 120 | if (cat_msg[0]): 121 | msg = self.msg(torch.cat((cur_d_vec, d_vec), dim=1)) 122 | else: 123 | msg = d_vec 124 | 125 | feat = torch.cat((pre_vec, msg), dim=1) 126 | d_vec = self.enc(feat) * w 127 | return d_vec 128 | 129 | class AggregateGRUEnc(nn.Module): 130 | def __init__(self, k=54, d=100, h=300): 131 | super(AggregateGRUEnc, self).__init__() 132 | 133 | self.w_x = nn.Sequential( 134 | nn.Linear(d, h), 135 | nn.ReLU(), 136 | nn.Linear(h, d), 137 | ) 138 | self.w_h = nn.Sequential( 139 | nn.Linear(d, h), 140 | nn.ReLU(), 141 | nn.Linear(h, d) 142 | ) 143 | self.msg = nn.Sequential( 144 | nn.Linear(2 * d, h), 145 | nn.ReLU(), 146 | nn.Linear(h, d), 147 | ) 148 | 149 | def forward(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=[False]): 150 | if(cat_msg[0]): 151 | msg = self.msg(torch.cat((cur_d_vec, d_vec), dim=1)) 152 | else: 153 | msg = d_vec 154 | 155 | ht = self.w_h(pre_vec) + self.w_x(msg) * w 156 | # ht = self.act(ht) 157 | return ht 158 | 159 | 160 | class UpdateEnc(nn.Module): 161 | def __init__(self, k=54, d=100, h=300): 162 | super(UpdateEnc, self).__init__() 163 | 164 | self.enc = nn.Sequential( 165 | nn.Linear(d * 7, h), 166 | nn.ReLU(), 167 | nn.Linear(h, d), 168 | # nn.Tanh() 169 | ) 170 | 171 | def forward(self, self_vec, p_sup_vec, c_sup_vec, p_sur_vec, c_sur_vec, n_vec, co_vec): 172 | feat = torch.cat((self_vec, p_sup_vec, c_sup_vec, p_sur_vec, c_sur_vec, n_vec, co_vec), dim=1) 173 | d_vec = self.enc(feat) 174 | return d_vec 175 | 176 | 177 | class BoxEnc(nn.Module): 178 | def __init__(self, k=54, 
d=100, h=300): 179 | super(BoxEnc, self).__init__() 180 | 181 | self.enc = nn.Sequential( 182 | nn.Linear(k, h), 183 | nn.ReLU(), 184 | nn.Linear(h, d), 185 | ) 186 | 187 | def forward(self, k_vec): 188 | d_vec = self.enc(k_vec) 189 | return d_vec 190 | 191 | 192 | class LearnedWeight(nn.Module): 193 | def __init__(self, k=54, h=300, dis_vec_dim=3): 194 | super(LearnedWeight, self).__init__() 195 | 196 | self.offset_enc = nn.Sequential( 197 | nn.Linear(dis_vec_dim, h), 198 | nn.ReLU(), 199 | nn.Linear(h, k) 200 | ) 201 | 202 | self.k_offset_enc = nn.Sequential( 203 | nn.Linear(k * 3, h), 204 | nn.ReLU(), 205 | nn.Linear(h, 1), 206 | nn.Sigmoid() 207 | ) 208 | 209 | def forward(self, k_vec1, k_vec2, offset_vec): 210 | offset_vec = self.offset_enc(offset_vec) 211 | k_offset_vec = torch.cat((k_vec1, k_vec2, offset_vec), dim=1) 212 | w = self.k_offset_enc(k_offset_vec) 213 | return w 214 | 215 | 216 | class FullEnc(nn.Module): 217 | def __init__(self, k=55, d=100, h=300, aggregate_func='GRU'): 218 | super(FullEnc, self).__init__() 219 | 220 | if(aggregate_func == 'Sum'): 221 | AggregateEnc = AggregateSumEnc 222 | elif(aggregate_func == 'GRU'): 223 | AggregateEnc = AggregateGRUEnc 224 | elif(aggregate_func == 'MaxPool'): 225 | AggregateEnc = AggregateMaxPoolEnc 226 | elif (aggregate_func == 'CatRNN'): 227 | print('CatAggregate') 228 | AggregateEnc = AggregateCatEnc 229 | else: 230 | AggregateEnc = None 231 | print('Aggregation function selection error') 232 | exit(-1) 233 | 234 | self.aggregate_neighbor_enc = AggregateEnc(k, d, h) 235 | self.aggregate_child_supp_enc = AggregateEnc(k, d, h) 236 | self.aggregate_child_surr_enc = AggregateEnc(k, d, h) 237 | self.aggregate_parent_supp_enc = AggregateEnc(k, d, h) 238 | self.aggregate_parent_surr_enc = AggregateEnc(k, d, h) 239 | self.aggregate_cooc_enc = AggregateEnc(k, d, h) 240 | 241 | dis_vec_dim = 3 242 | 243 | self.learned_weight = LearnedWeight(k, h, dis_vec_dim=dis_vec_dim) 244 | 245 | self.aggregate_self_enc = UpdateEnc(k, d, h) 246 | 247 | self.box_enc = BoxEnc(k, d, h) 248 | 249 | def aggregate_neighbor_func(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=False): 250 | return self.aggregate_neighbor_enc(d_vec, pre_vec, cur_d_vec, w, cat_msg) 251 | 252 | def aggregate_child_supp_func(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=False): 253 | return self.aggregate_child_supp_enc(d_vec, pre_vec, cur_d_vec, w, cat_msg) 254 | 255 | def aggregate_child_surr_func(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=False): 256 | return self.aggregate_child_surr_enc(d_vec, pre_vec, cur_d_vec, w, cat_msg) 257 | 258 | def aggregate_parent_supp_func(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=False): 259 | return self.aggregate_parent_supp_enc(d_vec, pre_vec, cur_d_vec, w, cat_msg) 260 | 261 | def aggregate_parent_surr_func(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=False): 262 | return self.aggregate_parent_surr_enc(d_vec, pre_vec, cur_d_vec, w, cat_msg) 263 | 264 | def aggregate_cooc_func(self, d_vec, pre_vec, cur_d_vec, w=1.0, cat_msg=False): 265 | return self.aggregate_cooc_enc(d_vec, pre_vec, cur_d_vec, w, cat_msg) 266 | 267 | def learned_weight_func(self, k_vec1, k_vec2, offset_vec): 268 | return self.learned_weight(k_vec1, k_vec2, offset_vec) 269 | 270 | 271 | def aggregate_self_func(self, self_vec, p_sup_vec, c_sup_vec, p_sur_vec, c_sur_vec, n_vec, co_vec): 272 | return self.aggregate_self_enc(self_vec, p_sup_vec, c_sup_vec, p_sur_vec, c_sur_vec, n_vec, co_vec) 273 | 274 | def cat_self_func(self, self_vec, p_sup_vec, c_sup_vec, 
p_sur_vec, c_sur_vec, n_vec, co_vec): 275 | return torch.cat((self_vec, p_sup_vec, c_sup_vec, p_sur_vec, c_sur_vec, n_vec, co_vec), dim=1) 276 | 277 | def box_enc_func(self, k_vec): 278 | return self.box_enc(k_vec) 279 | 280 | 281 | def encode_tree_fold(fold, raw_node_list, rand_path, opt_parser): 282 | node_list = copy.deepcopy(raw_node_list) 283 | d_vec_dim = opt_parser.d_vec_dim 284 | 285 | encode_fold_list = [] 286 | rand_path_node_name_order = [] 287 | tree_leaf_node = rand_path[-1] 288 | 289 | def encode_node(node_list, leaf_node=tree_leaf_node, step=0): 290 | """ 291 | Graph message passing by torchfold encoding 292 | :param node_list: 293 | :param leaf_node: 294 | :param step: 295 | :return: 296 | """ 297 | 298 | # init d-vec for all nodes 299 | if(step == 0): 300 | 301 | # loop to get each node's k-vec and d-vec 302 | for cur_node in node_list.keys(): 303 | 304 | # if leaf node, reset its k-vec to all-zeros to represent it is missing 305 | if(cur_node == leaf_node): 306 | missing_cat = [0.0] * len(opt_parser.cat2id.keys()) + [1.0] 307 | missing_dim_pos = [0.0] * 3 + node_list[cur_node]['self_info']['translation'] 308 | node_list[cur_node]['k-vec'] = missing_cat + missing_dim_pos 309 | else: 310 | node_list[cur_node]['k-vec'] = get_gt_k_vec(node_list, cur_node, opt_parser) 311 | 312 | node_list[cur_node]['k-vec'] = to_torch(node_list[cur_node]['k-vec']) 313 | node_list[cur_node]['d-vec'] = fold.add('box_enc_func', node_list[cur_node]['k-vec']) 314 | node_list[cur_node]['w'] = {} 315 | node_list[cur_node]['dis'] = {} 316 | 317 | # loop to get each pair of neighbor nodes' attention weight 318 | for cur_node in node_list.keys(): 319 | for neighbor_node in node_list.keys(): 320 | 321 | dis_feat_vec = [node_list[cur_node]['self_info']['translation'][0] - 322 | node_list[neighbor_node]['self_info']['translation'][0], 323 | node_list[cur_node]['self_info']['translation'][1] - 324 | node_list[neighbor_node]['self_info']['translation'][1], 325 | node_list[cur_node]['self_info']['translation'][2] - 326 | node_list[neighbor_node]['self_info']['translation'][2]] 327 | dis = np.sqrt(dis_feat_vec[0] ** 2 + dis_feat_vec[1] ** 2 + dis_feat_vec[2] ** 2) 328 | dis_feat = dis_feat_vec 329 | 330 | node_list[cur_node]['dis'][neighbor_node] = dis 331 | node_list[cur_node]['w'][neighbor_node] = \ 332 | fold.add('learned_weight_func', node_list[cur_node]['k-vec'], 333 | node_list[neighbor_node]['k-vec'], 334 | to_torch(dis_feat)) 335 | 336 | # graph message passing 337 | else: 338 | for cur_node in node_list.keys(): 339 | cur_node_d_vec = node_list[cur_node]['pre-d-vec'] 340 | 341 | # message from parents (supported-by, surrounded-by relation) 342 | aggregate_parent_d_vec ={'supp' : to_torch([0.0] * d_vec_dim), 'surr' : to_torch([0.0] * d_vec_dim)} 343 | for parent_node, parent_node_type in node_list[cur_node]['parents']: 344 | parent_d_vec = node_list[parent_node]['pre-d-vec'] 345 | aggregate_parent_d_vec[parent_node_type] = \ 346 | fold.add('aggregate_parent_{}_func'.format(parent_node_type), 347 | parent_d_vec, 348 | aggregate_parent_d_vec[parent_node_type], 349 | cur_node_d_vec, 350 | to_torch([1.]), 351 | to_torch(opt_parser.cat_msg, torch.bool)) 352 | 353 | # message from siblings (next-to relation) 354 | aggregate_neighbor_d_vec = to_torch([0.0] * d_vec_dim) 355 | for sibling_node_i, _ in node_list[cur_node]['siblings']: 356 | sibling_node_d_vec = node_list[sibling_node_i]['pre-d-vec'] 357 | aggregate_neighbor_d_vec = fold.add('aggregate_neighbor_func', 358 | sibling_node_d_vec, 359 | 
aggregate_neighbor_d_vec, 360 | cur_node_d_vec, 361 | to_torch([1.]), 362 | to_torch([opt_parser.cat_msg], torch.bool)) 363 | 364 | # message from childs (supporting, surrounding relation) 365 | aggregate_child_d_vec = {'supp' : to_torch([0.0] * d_vec_dim), 'surr' : to_torch([0.0] * d_vec_dim)} 366 | for child_node_i, child_node_type_i in node_list[cur_node]['childs']: 367 | child_node_d_vec = node_list[child_node_i]['pre-d-vec'] 368 | aggregate_child_d_vec[child_node_type_i] = \ 369 | fold.add('aggregate_child_{}_func'.format(child_node_type_i), 370 | child_node_d_vec, 371 | aggregate_child_d_vec[child_node_type_i], 372 | cur_node_d_vec, 373 | to_torch([1.]), 374 | to_torch([opt_parser.cat_msg], torch.bool)) 375 | 376 | # message from loose neighbors (co-occurring relation) 377 | aggregate_cooc_d_vec = to_torch([0.0] * d_vec_dim) 378 | if(opt_parser.aggregate_in_order): 379 | all_neighbor_nodes = list(node_list.keys()) 380 | all_neighbor_nodes.sort(key=lambda x: node_list[cur_node]['dis'][x], reverse=True) 381 | else: 382 | all_neighbor_nodes = node_list.keys() 383 | for neighbor_node in all_neighbor_nodes: 384 | if (neighbor_node != cur_node): 385 | w = node_list[cur_node]['w'][neighbor_node] 386 | neighbor_node_d_vec = node_list[neighbor_node]['pre-d-vec'] 387 | aggregate_cooc_d_vec = fold.add('aggregate_cooc_func', 388 | neighbor_node_d_vec, 389 | aggregate_cooc_d_vec, 390 | cur_node_d_vec, 391 | w, 392 | to_torch([opt_parser.cat_msg], torch.bool)) 393 | 394 | node_list[cur_node]['d-vec'] = fold.add('aggregate_self_func', 395 | node_list[cur_node]['pre-d-vec'], 396 | aggregate_parent_d_vec['supp'], 397 | aggregate_child_d_vec['supp'], 398 | aggregate_parent_d_vec['surr'], 399 | aggregate_child_d_vec['surr'], 400 | aggregate_neighbor_d_vec, 401 | aggregate_cooc_d_vec) 402 | node_list[cur_node]['cat-d-vec'] = fold.add('cat_self_func', 403 | node_list[cur_node]['pre-d-vec'], 404 | aggregate_parent_d_vec['supp'], 405 | aggregate_child_d_vec['supp'], 406 | aggregate_parent_d_vec['surr'], 407 | aggregate_child_d_vec['surr'], 408 | aggregate_neighbor_d_vec, 409 | aggregate_cooc_d_vec) 410 | 411 | # end of func 412 | 413 | 414 | for i in range(opt_parser.K): 415 | encode_node(node_list, leaf_node=tree_leaf_node, step=i) 416 | for cur_node in node_list.keys(): 417 | node_list[cur_node]['pre-d-vec'] = node_list[cur_node]['d-vec'] 418 | 419 | 420 | # encode all d-vec along rand path 421 | if(opt_parser.decode_cat_d_vec == True): 422 | # use cat version d-vec for decoding 423 | encode_fold_list.append(node_list[tree_leaf_node]['cat-d-vec']) 424 | else: 425 | # use normal version d-vec for decoding 426 | encode_fold_list.append(node_list[tree_leaf_node]['d-vec']) 427 | rand_path_node_name_order.append(tree_leaf_node) 428 | 429 | return encode_fold_list, rand_path_node_name_order 430 | 431 | 432 | 433 | class Root_to_leaf_Dec(nn.Module): 434 | def __init__(self, k=54, d=100, r=28, h=300): 435 | super(Root_to_leaf_Dec, self).__init__() 436 | 437 | self.dec = nn.Sequential( 438 | nn.Linear(d, h), 439 | nn.ReLU(), 440 | nn.Linear(h, d) 441 | ) 442 | 443 | def forward(self, parent_d_vec): 444 | dec_vec = self.dec(parent_d_vec) 445 | return dec_vec 446 | 447 | class Cat_Root_to_leaf_Dec(nn.Module): 448 | def __init__(self, k=54, d=100, r=28, h=300): 449 | super(Cat_Root_to_leaf_Dec, self).__init__() 450 | 451 | self.dec = nn.Sequential( 452 | nn.Linear(d * 7, h), 453 | nn.ReLU(), 454 | nn.Linear(h, d) 455 | ) 456 | 457 | def forward(self, parent_d_vec): 458 | dec_vec = self.dec(parent_d_vec) 459 | return 
dec_vec 460 | 461 | class BoxDec(nn.Module): 462 | def __init__(self, k=54, d=100, r=28, h=300): 463 | super(BoxDec, self).__init__() 464 | 465 | self.dec = nn.Sequential( 466 | nn.Linear(d, h), 467 | nn.ReLU(), 468 | nn.Linear(h, k) 469 | ) 470 | 471 | def forward(self, d_vec): 472 | k_vec = self.dec(d_vec) 473 | return k_vec 474 | 475 | class FullDec(nn.Module): 476 | def __init__(self, k=55, d=100, r=28, h=300, root_d=350, root_h=1050): 477 | super(FullDec, self).__init__() 478 | 479 | self.root_to_leaf_dec = Root_to_leaf_Dec(k, d, r, h) 480 | self.cat_root_to_leaf_dec = Cat_Root_to_leaf_Dec(k, d, r, h) 481 | self.box_dec = BoxDec(k, d, r, h) 482 | 483 | def root_dec_func(self, parent_d_vec): 484 | return self.root_to_leaf_dec(parent_d_vec) 485 | 486 | def cat_root_dec_func(self, parent_d_vec): 487 | return self.cat_root_to_leaf_dec(parent_d_vec) 488 | 489 | def box_dec_func(self, d_vec): 490 | return self.box_dec(d_vec) 491 | 492 | def forward(self, d_vec): 493 | box_d_vec = self.root_to_leaf_dec(d_vec) 494 | k_vec = self.box_dec(box_d_vec) 495 | 496 | return k_vec 497 | 498 | def decode_tree_fold(fold, d_vec, opt_parser=None): 499 | if(opt_parser != None and opt_parser.decode_cat_d_vec == True): 500 | leaf_node_d_vec = fold.add('cat_root_dec_func', d_vec) 501 | else: 502 | leaf_node_d_vec = fold.add('root_dec_func', d_vec) 503 | return fold.add('box_dec_func', leaf_node_d_vec) 504 | --------------------------------------------------------------------------------
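For reference, a minimal usage sketch of the encoder/decoder modules defined in SceneGraphNet/sgmodel/model.py above. This is an illustrative, eager example (no torchfold batching), and it assumes the snippet is run from inside the SceneGraphNet directory with utils/default_settings.py pointed at a valid dataset directory, since model.py loads dic_id2type at import time; the k and d sizes below are simply the constructor defaults of FullEnc/FullDec, not the per-room k_size_dic values.

import torch
from sgmodel.model import FullEnc, FullDec  # resolves when run from the SceneGraphNet directory

k, d = 55, 100                     # k-vec = one-hot category + 3-d dimensions + 3-d position; d = latent size
enc = FullEnc(k=k, d=d, aggregate_func='GRU')
dec = FullDec(k=k, d=d)

k_vec = torch.rand(1, k)           # a single node's k-vec (stand-in for a get_gt_k_vec output)
d_vec = enc.box_enc_func(k_vec)    # BoxEnc: embed the node into d-vec space, shape (1, d)
k_hat = dec(d_vec)                 # Root_to_leaf_Dec followed by BoxDec, shape (1, k)
print(d_vec.shape, k_hat.shape)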