├── Model.py ├── README.md ├── SphereDataset.py ├── Test.py ├── Train.py ├── Utils.py ├── data_visualization.py ├── state_stat.npz └── state_stat_sample.npz /Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Encoder(nn.Module): 5 | def __init__(self, in_features, out_features, hidden_features, norm_type): 6 | super(Encoder, self).__init__() 7 | self.in_features = in_features 8 | self.out_features = out_features 9 | self.hidden_features = hidden_features 10 | self.fc_in = nn.Linear(self.in_features, self.hidden_features, bias = True) 11 | # self.ln_in = nn.LayerNorm(self.hidden_features) 12 | self.ac_in = nn.ReLU() 13 | self.fc_h0 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 14 | # self.ln_h0 = nn.LayerNorm(self.hidden_features) 15 | self.ac_h0 = nn.ReLU() 16 | self.fc_h1 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 17 | # self.ln_h1 = nn.LayerNorm(self.hidden_features) 18 | self.ac_h1 = nn.ReLU() 19 | self.fc_out = nn.Linear(self.hidden_features, self.out_features, bias = True) 20 | 21 | if norm_type == 'bn': 22 | self.norm_out = nn.BatchNorm1d(self.out_features) 23 | else: 24 | self.norm_out = nn.LayerNorm(self.out_features) 25 | 26 | def forward(self, x_in): 27 | x = self.fc_in(x_in) 28 | # x = self.ln_in(x) 29 | x = self.ac_in(x) 30 | 31 | x = self.fc_h0(x) 32 | # x = self.ln_h0(x) 33 | x = self.ac_h0(x) 34 | 35 | x = self.fc_h1(x) 36 | # x = self.ln_h1(x) 37 | x = self.ac_h1(x) 38 | 39 | x = self.fc_out(x) 40 | 41 | x = x.view(-1, x.size(-1)) 42 | x = self.norm_out(x) 43 | x = x.view(x_in.size(0), x_in.size(1), x.size(-1)) 44 | return x 45 | 46 | class Decoder(nn.Module): 47 | def __init__(self, in_features, out_features, hidden_features): 48 | super(Decoder, self).__init__() 49 | self.in_features = in_features 50 | self.out_features = out_features 51 | self.hidden_features = hidden_features 52 | self.fc_in = nn.Linear(self.in_features, self.hidden_features, bias = True) 53 | self.ac_in = nn.ReLU() 54 | self.fc_h0 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 55 | self.ac_h0 = nn.ReLU() 56 | self.fc_h1 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 57 | self.ac_h1 = nn.ReLU() 58 | self.fc_out = nn.Linear(self.hidden_features, self.out_features, bias = True) 59 | 60 | def forward(self, x): 61 | x = self.fc_in(x) 62 | x = self.ac_in(x) 63 | x = self.fc_h0(x) 64 | x = self.ac_h0(x) 65 | x = self.fc_h1(x) 66 | x = self.ac_h1(x) 67 | x = self.fc_out(x) 68 | return x 69 | 70 | 71 | class Processor(nn.Module): 72 | def __init__(self, in_features, out_features, hidden_features, norm_type): 73 | super(Processor, self).__init__() 74 | self.in_features = in_features 75 | self.out_features = out_features 76 | self.hidden_features = hidden_features 77 | self.fc_in = nn.Linear(self.in_features, self.hidden_features, bias = True) 78 | # self.ln_in = nn.LayerNorm(self.hidden_features) 79 | self.ac_in = nn.ReLU() 80 | self.fc_h0 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 81 | # self.ln_h0 = nn.LayerNorm(self.hidden_features) 82 | self.ac_h0 = nn.ReLU() 83 | self.fc_h1 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 84 | # self.ln_h1 = nn.LayerNorm(self.hidden_features) 85 | self.ac_h1 = nn.ReLU() 86 | self.fc_out = nn.Linear(self.hidden_features, self.out_features, bias = True) 87 | if norm_type == 'bn': 88 | self.norm_out = nn.BatchNorm1d(self.out_features) 89 | else: 
90 | self.norm_out = nn.LayerNorm(self.out_features) 91 | 92 | def forward(self, x_in): 93 | x = self.fc_in(x_in) 94 | # x = self.ln_in(x) 95 | x = self.ac_in(x) 96 | 97 | x = self.fc_h0(x) 98 | # x = self.ln_h0(x) 99 | x = self.ac_h0(x) 100 | 101 | x = self.fc_h1(x) 102 | # x = self.ln_h1(x) 103 | x = self.ac_h1(x) 104 | 105 | x = self.fc_out(x) 106 | x = x.view(-1, x.size(-1)) 107 | x = self.norm_out(x) 108 | x = x.view(x_in.size(0), x_in.size(1), x.size(-1)) 109 | return x 110 | 111 | class Processor_Res(nn.Module): 112 | def __init__(self, in_features, out_features, hidden_features, norm_type): 113 | super(Processor_Res, self).__init__() 114 | self.in_features = in_features 115 | self.out_features = out_features 116 | self.hidden_features = hidden_features 117 | self.fc_in = nn.Linear(self.in_features, self.hidden_features, bias = True) 118 | # self.ln_in = nn.LayerNorm(self.hidden_features) 119 | self.ac_in = nn.ReLU() 120 | self.fc_h0 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 121 | # self.ln_h0 = nn.LayerNorm(self.hidden_features) 122 | self.ac_h0 = nn.ReLU() 123 | self.fc_h1 = nn.Linear(self.hidden_features, self.hidden_features, bias = True) 124 | # self.ln_h1 = nn.LayerNorm(self.hidden_features) 125 | self.ac_h1 = nn.ReLU() 126 | self.fc_out = nn.Linear(self.hidden_features, self.out_features, bias = True) 127 | if norm_type == 'bn': 128 | self.norm_out = nn.BatchNorm1d(self.out_features) 129 | else: 130 | self.norm_out = nn.LayerNorm(self.out_features) 131 | 132 | def forward(self, x_in): 133 | x = self.fc_in(x_in) 134 | # x = self.ln_in(x) 135 | x = self.ac_in(x) 136 | 137 | x = self.fc_h0(x) 138 | # x = self.ln_h0(x) 139 | x = self.ac_h0(x) 140 | 141 | x = self.fc_h1(x) 142 | # x = self.ln_h1(x) 143 | x = self.ac_h1(x) 144 | 145 | x = self.fc_out(x) 146 | 147 | x = x + x_in[:, :, :x.size(-1)] 148 | 149 | x = x.view(-1, x.size(-1)) 150 | x = self.norm_out(x) 151 | x = x.view(x_in.size(0), x_in.size(1), x.size(-1)) 152 | 153 | return x 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Learning Mesh-based Simulation With Graph Networks 2 | 3 | This is an unofficial implementation of the approach described in the paper: 4 | > Tobias Pfaff, Meire Fortunato, Alvaro Sanchez-Gonzalez, and Peter W. Battaglia. [Learning Mesh-based Simulation With Graph Networks](https://openreview.net/pdf?id=roNqYL0_XP). In *ICLR*, 2021. 5 | 6 | Work in progress. 7 | 8 | [Current progress](https://drive.google.com/file/d/1znvXDCT-_EBQOUpeePhMU3_zGYeLg3Xo/view?usp=sharing): capable of long-term rollout that takes (self-)collision into account.
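For orientation, the sketch below shows how the `Encoder` / `Processor` / `Decoder` modules from `Model.py` are chained in `Train.py`: encode node and uv-edge features, run several rounds of message passing in which edge latents are aggregated onto nodes through an incidence matrix (as `adj_map` is used in the code), apply residual updates, and decode a per-node acceleration. The tensor sizes and random inputs are toy values, a single `Processor` is reused across rounds (the code builds one per step), and the world-edge (collision) slot is simply zeroed here as `Train.py` does when no world edges exist; treat it as an illustrative sketch, not the exact training setup.

```python
import torch
from Model import Encoder, Processor, Decoder

hidden = 128
node_encoder   = Encoder(6, hidden, hidden, 'ln')    # one-hot node type (3) + velocity (3)
uvedge_encoder = Encoder(7, hidden, hidden, 'ln')    # uv offset (2) + norm, world offset (3) + norm
edge_proc = Processor(hidden * 3, hidden, hidden * 3, 'ln')
node_proc = Processor(hidden * 3, hidden, hidden * 3, 'ln')
decoder   = Decoder(hidden, 3, hidden)               # per-node 3D acceleration

B, N, E = 1, 10, 20                                  # toy batch / node / uv-edge counts
node_state = torch.randn(B, N, 6)
edge_state = torch.randn(B, E, 7)
sender   = torch.randint(0, N, (E,))                 # plays the role of uvedge_node_i
receiver = torch.randint(0, N, (E,))                 # plays the role of uvedge_node_j
adj = torch.zeros(B, N, E)                           # node-to-edge incidence, like adj_map
adj[0, receiver, torch.arange(E)] = 1.0

h_v = node_encoder(node_state)
h_e = uvedge_encoder(edge_state)
for _ in range(15):                                  # process_steps rounds of message passing
    e_in  = torch.cat([h_e, h_v[:, sender], h_v[:, receiver]], -1)
    e_new = edge_proc(e_in)                          # update uv-edge latents
    agr   = torch.matmul(adj, e_new)                 # sum incident edge latents per node
    v_in  = torch.cat([h_v, agr, torch.zeros_like(h_v)], -1)  # world-edge slot zeroed in this sketch
    v_new = node_proc(v_in)
    h_e, h_v = h_e + e_new, h_v + v_new              # residual connections
acc = decoder(h_v)                                   # normalized acceleration, shape (B, N, 3)
```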
9 | 10 | 11 | -------------------------------------------------------------------------------- /SphereDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import scipy.io as io 5 | from torch.utils.data import Dataset, DataLoader 6 | import matplotlib.pyplot as plt 7 | 8 | def collate_fn(batch): 9 | data0 = [item[0] for item in batch] 10 | data1 = [item[1] for item in batch] 11 | data2 = [item[2] for item in batch] 12 | data3 = [item[3] for item in batch] 13 | data4 = [item[4] for item in batch] 14 | data5 = [item[5] for item in batch] 15 | data6 = [item[6] for item in batch] 16 | data7 = [item[7] for item in batch] 17 | return [data0, data1, data2, data3, data4, data5, data6, data7] 18 | 19 | class SphereDataset(Dataset): 20 | def __init__(self, data_dir, length, train, noise): 21 | self.data = [] 22 | self.train = train 23 | self.length = length 24 | self.noise = noise 25 | self.vel_scale = np.array([0.005, 0.005, 0.005]) / 3.0 26 | self.kinematic_node = [1, 645] 27 | for data_id in range(100): 28 | state_dir = os.path.join(data_dir, ('data/%04d' % data_id)) 29 | for file_id in range(1, 498): 30 | pre_file = os.path.join(state_dir, '%03d_cloth.txt' % (file_id - 1)) 31 | cur_file = os.path.join(state_dir, '%03d_cloth.txt' % file_id) 32 | nxt_file = os.path.join(state_dir, '%03d_cloth.txt' % (file_id + 1)) 33 | self.data.append([pre_file, cur_file, nxt_file]) 34 | self.cloth_topo = np.load(os.path.join(data_dir, 'cloth_connection.npy'), allow_pickle = True).item() 35 | self.sphere_topo = np.load(os.path.join(data_dir, 'sphere_connection.npy'), allow_pickle = True) 36 | self.adj_map = np.load(os.path.join(data_dir, 'adj_map.npy'), allow_pickle = True) 37 | self.uvedge_node_i = np.load(os.path.join(data_dir, 'uvedge_node_i.npy'), allow_pickle = True) 38 | self.uvedge_node_j = np.load(os.path.join(data_dir, 'uvedge_node_j.npy'), allow_pickle = True) 39 | 40 | state_stat = np.load('state_stat_sample_ball.npz') 41 | self.cloth_mean = state_stat['arr_0'].item()['cloth_mean'] 42 | self.cloth_std = state_stat['arr_0'].item()['cloth_std'] 43 | self.ball_mean = state_stat['arr_0'].item()['ball_mean'] 44 | self.ball_std = state_stat['arr_0'].item()['ball_std'] 45 | self.uv_mean = state_stat['arr_0'].item()['uv_mean'] 46 | self.uv_std = state_stat['arr_0'].item()['uv_std'] 47 | self.worldcloth_mean = state_stat['arr_0'].item()['worldcloth_mean'] 48 | self.worldcloth_std = state_stat['arr_0'].item()['worldcloth_std'] 49 | self.worldball_mean = state_stat['arr_0'].item()['worldball_mean'] 50 | self.worldball_std = state_stat['arr_0'].item()['worldball_std'] 51 | self.cloth_nxt_mean = state_stat['arr_0'].item()['cloth_nxt_mean'] 52 | self.cloth_nxt_std = state_stat['arr_0'].item()['cloth_nxt_std'] 53 | self.set_seed = False 54 | self.cloth_file_name = None 55 | 56 | def __len__(self): 57 | return len(self.data) 58 | 59 | def GetState(self, index): 60 | #### get previous state to calculate the velocity #### 61 | # print(self.data[index][0]) 62 | p = 1.0 63 | cloth_pre_file = self.data[index][0] 64 | cloth_pre_data = [] 65 | with open(cloth_pre_file, 'r') as f: 66 | for line in f: 67 | line = line.split('\n')[0] 68 | cloth_pre_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 69 | cloth_pre_data = np.array(cloth_pre_data) 70 | f.close() 71 | 72 | ball_pre_file = self.data[index][0].replace('cloth', 'ball') 73 | ball_pre_data = [] 74 | with open(ball_pre_file, 'r'): 75 | for line in 
open(ball_pre_file, 'r'): 76 | line = line.split('\n')[0] 77 | ball_pre_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 78 | ball_pre_data = np.array(ball_pre_data) 79 | f.close() 80 | 81 | #### get current state, including vertex and edge information, as input #### 82 | cloth_file = self.data[index][1] 83 | ball_file = self.data[index][1].replace('cloth', 'ball') 84 | uv_file = self.data[index][1].replace('cloth', 'uv') 85 | world_file = self.data[index][1].replace('cloth', 'world') 86 | 87 | cloth_data = [] 88 | ball_data = [] 89 | uv_data = [] 90 | worldcloth_data = [] 91 | worldball_data = [] 92 | with open(cloth_file, 'r') as f: 93 | for line in f: 94 | line = line.split('\n')[0] 95 | cloth_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 96 | f.close() 97 | 98 | with open(ball_file, 'r') as f: 99 | for line in f: 100 | line = line.split('\n')[0] 101 | ball_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 102 | f.close() 103 | 104 | with open(uv_file, 'r') as f: 105 | for line in f: 106 | line = line.split('\n')[0] 107 | uv_data.append(np.array([float(data) for data in line.split(' ')])) 108 | f.close() 109 | 110 | if not self.noise: 111 | with open(world_file, 'r') as f: 112 | for line in f: 113 | line = line.split('\n')[0] 114 | if line.split(' ')[0] == '1': 115 | feat = np.array([float(data) for data in line.split(' ')[1:]]) 116 | worldball_data.append(feat) 117 | if line.split(' ')[0] == '0': 118 | feat = np.array([float(data) for data in line.split(' ')[1:]]) 119 | worldcloth_data.append(feat) 120 | f.close() 121 | 122 | cloth_data = np.array(cloth_data) 123 | ball_data = np.array(ball_data) 124 | uv_data = np.array(uv_data) 125 | 126 | cloth_data_noise = None 127 | if self.noise and (p > 0.5): 128 | delta_x = np.random.normal(0, self.vel_scale[0], cloth_data[:, 0:1].shape) 129 | delta_y = np.random.normal(0, self.vel_scale[1], cloth_data[:, 1:2].shape) 130 | delta_z = np.random.normal(0, self.vel_scale[2], cloth_data[:, 2:3].shape) 131 | delta = np.concatenate([delta_x, delta_y, delta_z], -1) 132 | #### zero-out kinematic node #### 133 | delta[self.kinematic_node] = 0.0 134 | cloth_data_noise = cloth_data[:, :3] + delta 135 | 136 | #### get next state, mainly the position, as output #### 137 | cloth_nxt_file = self.data[index][2] 138 | cloth_nxt_data = [] 139 | with open(cloth_nxt_file, 'r') as f: 140 | for line in f: 141 | line = line.split('\n')[0] 142 | cloth_nxt_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 143 | f.close() 144 | cloth_nxt_data = np.array(cloth_nxt_data) 145 | 146 | #### get next kinematics node information, mainly the position, as actuator #### 147 | ball_nxt_file = self.data[index][2].replace('cloth', 'ball') 148 | ball_nxt_data = [] 149 | with open(ball_nxt_file, 'r') as f: 150 | for line in f: 151 | line = line.split('\n')[0] 152 | ball_nxt_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 153 | f.close() 154 | ball_nxt_data = np.array(ball_nxt_data) 155 | 156 | if not (self.noise and p > 0.5): 157 | #### get the velocity information #### 158 | cloth_vel = cloth_data[:, :3] - cloth_pre_data[:, :3] 159 | cloth_label = np.zeros((cloth_vel.shape[0], 3)) 160 | cloth_label[:, 0] = 1.0 161 | #### kinematic node #### 162 | for node_idx in self.kinematic_node: 163 | cloth_label[node_idx] = np.array([0.0, 1.0, 0.0]) 164 | cloth_state = np.concatenate([cloth_label, cloth_vel], -1) 165 | #### get the final state information #### 166 | cloth_acc = 
cloth_nxt_data[:, :3] + cloth_pre_data[:, :3] - 2 * cloth_data[:, :3] 167 | else: 168 | cloth_vel_noise = cloth_data_noise[:, :3] - cloth_pre_data[:, :3] 169 | cloth_nxt_vel = cloth_nxt_data[:, :3] - cloth_data_noise[:, :3] 170 | cloth_acc_p = cloth_nxt_vel - cloth_vel_noise 171 | 172 | cloth_nxt_vel = cloth_nxt_data[:, :3] - cloth_data[:, :3] 173 | cloth_acc_v = cloth_nxt_vel - cloth_vel_noise 174 | cloth_acc = 0.1 * cloth_acc_p + 0.9 * cloth_acc_v 175 | 176 | #### get the velocity information #### 177 | cloth_label = np.zeros((cloth_vel_noise.shape[0], 3)) 178 | cloth_label[:, 0] = 1.0 179 | #### kinematic node #### 180 | for node_idx in self.kinematic_node: 181 | cloth_label[node_idx] = np.array([0.0, 1.0, 0.0]) 182 | cloth_state = np.concatenate([cloth_label, cloth_vel_noise], -1) 183 | 184 | ### recompute the uvedge feature #### 185 | uvedge_ij = cloth_data_noise[self.uvedge_node_i, :3] - cloth_data_noise[self.uvedge_node_j, :3] 186 | uvedge_ij_norm = np.linalg.norm(uvedge_ij, ord = 2, axis = -1, keepdims = True) 187 | uv_data = np.concatenate([uv_data[:, :3], uvedge_ij, uvedge_ij_norm], -1) 188 | 189 | ### recompute the world-edge frature #### 190 | cloth_world_dis = np.sum((cloth_data_noise[None, :, :3] - cloth_data_noise[:, None, :3])**2, -1)**0.5 191 | ball_world_dis = np.sum((cloth_data_noise[None, :, :3] - ball_nxt_data[:, None, :3])**2, -1)**0.5 192 | idxs_cloth = np.argwhere(cloth_world_dis < 0.02) 193 | idxs_ball = np.argwhere(ball_world_dis < 0.04) 194 | 195 | for idx in idxs_cloth: 196 | i_vertx = cloth_data[idx[0]] 197 | j_vertx = cloth_data[idx[1]] 198 | xij = i_vertx[:3] - j_vertx[:3] 199 | xij_norm = np.linalg.norm(xij, ord = 2) 200 | if (idx[0] != idx[1]) and (idx[0] not in self.cloth_topo[idx[1]]): 201 | worldcloth_data.append(np.array([idx[0], idx[1], xij[0], xij[1], xij[2], xij_norm])) 202 | 203 | for idx in idxs_ball: 204 | i_vertx = cloth_data[idx[1]] 205 | j_vertx = ball_data[idx[0]] 206 | xij = i_vertx[:3] - j_vertx[:3] 207 | xij_norm = np.linalg.norm(xij, ord = 2) 208 | worldball_data.append(np.array([idx[1], idx[0], xij[0], xij[1], xij[2], xij_norm])) 209 | 210 | worldcloth_data = np.array(worldcloth_data) 211 | worldball_data = np.array(worldball_data) 212 | 213 | worldcloth_adjmap = np.zeros((cloth_data.shape[0], worldcloth_data.shape[0])) 214 | for i in range(worldcloth_adjmap.shape[1]): 215 | worldcloth_adjmap[int(worldcloth_data[i, 0]), i] = 1.0 216 | if worldcloth_adjmap.shape[1] == 0: 217 | worldcloth_adjmap = [] 218 | worldcloth_adjmap = np.array(worldcloth_adjmap) 219 | 220 | worldball_adjmap = np.zeros((cloth_data.shape[0], worldball_data.shape[0])) 221 | for i in range(worldball_adjmap.shape[1]): 222 | worldball_adjmap[int(worldball_data[i, 0]), i] = 1.0 223 | if worldball_adjmap.shape[1] == 0: 224 | worldball_adjmap = [] 225 | worldball_adjmap = np.array(worldball_adjmap) 226 | 227 | ball_vel = ball_nxt_data[:, :3] - ball_data[:, :3] 228 | ball_label = np.zeros((ball_vel.shape[0], 3)) 229 | ball_label[:, 2] = 1.0 230 | ball_state = np.concatenate([ball_label, ball_vel], -1) 231 | 232 | # cloth_state = (cloth_state - self.cloth_mean) / self.cloth_std 233 | # ball_state = (ball_state - self.ball_mean) / self.ball_std 234 | # uv_data = (uv_data - self.uv_mean) / self.uv_std 235 | # if len(worldcloth_data) > 0: 236 | # worldcloth_data = (worldcloth_data - self.worldcloth_mean) / self.worldcloth_std 237 | # if len(worldball_data) > 0: 238 | # worldball_data = (worldball_data - self.worldball_mean) / self.worldball_std 239 | # cloth_acc = (cloth_acc - 
self.cloth_nxt_mean) / self.cloth_nxt_std 240 | 241 | if self.train: 242 | return cloth_state, ball_state, uv_data, worldcloth_data, worldball_data, cloth_acc, worldcloth_adjmap, worldball_adjmap 243 | else: 244 | return cloth_state, np.concatenate([ball_state, ball_nxt_data[:, :3]], -1), uv_data, worldcloth_data, worldball_data, np.concatenate([cloth_pre_data[:,:3], cloth_data[:, :3], cloth_nxt_data[:, :3]], -1), worldcloth_adjmap, worldball_adjmap 245 | 246 | def __getitem__(self, index): 247 | cloth_state, ball_state, uv_state, worldcloth_state, worldball_state, cloth_nxt_state, worldcloth_adjmap, worldball_adjmap = self.GetState(index) 248 | return cloth_state.astype(np.float32),\ 249 | ball_state.astype(np.float32), \ 250 | uv_state.astype(np.float32),\ 251 | worldcloth_state.astype(np.float32),\ 252 | worldball_state.astype(np.float32),\ 253 | cloth_nxt_state.astype(np.float32),\ 254 | worldcloth_adjmap.astype(np.float32),\ 255 | worldball_adjmap.astype(np.float32) 256 | # return torch.from_numpy(cloth_state.astype(np.float32)),\ 257 | # torch.from_numpy(ball_state.astype(np.float32)), \ 258 | # torch.from_numpy(uv_state.astype(np.float32)),\ 259 | # torch.from_numpy(worldcloth_state.astype(np.float32)),\ 260 | # torch.from_numpy(worldball_state.astype(np.float32)),\ 261 | # torch.from_numpy(cloth_nxt_state.astype(np.float32)),\ 262 | # torch.from_numpy(worldcloth_adjmap.astype(np.float32)),\ 263 | # torch.from_numpy(worldball_adjmap.astype(np.float32)) 264 | 265 | def GenDataStatics(): 266 | spdataset = SphereDataset('../Data', 500, True, False) 267 | sploader = DataLoader(spdataset, batch_size = 32, shuffle = False, num_workers = 48, collate_fn = collate_fn) 268 | cloth_state_list = [] 269 | ball_state_list = [] 270 | uv_state_list = [] 271 | worldcloth_state_list = [] 272 | worldball_state_list = [] 273 | cloth_nxt_state_list = [] 274 | torch.multiprocessing.set_sharing_strategy('file_system') 275 | for step, (cloth_state, ball_state, uv_state, worldcloth_state, worldball_state, cloth_nxt_state, _, _) in enumerate(sploader): 276 | for bs in range(len(cloth_state)): 277 | cloth_state_list.append(cloth_state[bs]) 278 | ball_state_list.append(ball_state[bs]) 279 | uv_state_list.append(uv_state[bs]) 280 | if worldcloth_state[bs].shape[0] > 0: 281 | worldcloth_state_list.append(worldcloth_state[bs]) 282 | if worldball_state[bs].shape[0] > 0: 283 | worldball_state_list.append(worldball_state[bs]) 284 | cloth_nxt_state_list.append(cloth_nxt_state[bs]) 285 | 286 | #### the conversion to np.float64 is very important to prevent bound error #### 287 | cloth_state = np.concatenate(cloth_state_list, 0) 288 | cloth_state = cloth_state.astype(np.float64) 289 | cloth_state_mean = np.mean(cloth_state, 0) 290 | cloth_state_std = np.std(cloth_state, 0) 291 | 292 | ball_state = np.concatenate(ball_state_list, 0) 293 | ball_state = ball_state.astype(np.float64) 294 | ball_state_mean = np.mean(ball_state, 0) 295 | ball_state_std = np.std(ball_state, 0) 296 | 297 | uv_state = np.concatenate(uv_state_list, 0) 298 | uv_state = uv_state.astype(np.float64) 299 | uv_state_mean = np.mean(uv_state, 0) 300 | uv_state_std = np.std(uv_state, 0) 301 | 302 | worldcloth_state = np.concatenate(worldcloth_state_list, 0) 303 | worldcloth_state = worldcloth_state.astype(np.float64) 304 | worldcloth_state_mean = np.mean(worldcloth_state, 0) 305 | worldcloth_state_std = np.std(worldcloth_state, 0) 306 | 307 | worldball_state = np.concatenate(worldball_state_list, 0) 308 | worldball_state = 
worldball_state.astype(np.float64) 309 | worldball_state_mean = np.mean(worldball_state, 0) 310 | worldball_state_std = np.std(worldball_state, 0) 311 | 312 | cloth_nxt_state = np.concatenate(cloth_nxt_state_list, 0) 313 | cloth_nxt_state = cloth_nxt_state.astype(np.float64) 314 | cloth_nxt_state_mean = np.mean(cloth_nxt_state, 0) 315 | cloth_nxt_state_std = np.std(cloth_nxt_state, 0) 316 | 317 | cloth_state_mean[:3] = 0.0 318 | cloth_state_std[:3] = 1.0 319 | ball_state_mean[:3] = 0.0 320 | ball_state_std[:3] = 1.0 321 | ball_state_std[4:] = 1.0 322 | worldcloth_state_mean[:2] = 0.0 323 | worldcloth_state_std[:2] = 1.0 324 | worldball_state_mean[:2] = 0.0 325 | worldball_state_std[:2] = 1.0 326 | 327 | print('cloth_state:', cloth_state_mean, cloth_state_std) 328 | print('ball_state:', ball_state_mean, ball_state_std) 329 | print('uv_state:', uv_state_mean, uv_state_std) 330 | print('worldcloth_state:', worldcloth_state_mean, worldcloth_state_std) 331 | print('worldball_state:', worldball_state_mean, worldball_state_std) 332 | print('cloth_nxt_state:', cloth_nxt_state_mean, cloth_nxt_state_std) 333 | np.savez('state_stat_sample.npz', {'cloth_mean':cloth_state_mean.astype(np.float32), 'cloth_std':cloth_state_std.astype(np.float32),\ 334 | 'ball_mean':ball_state_mean.astype(np.float32), 'ball_std':ball_state_std.astype(np.float32),\ 335 | 'uv_mean':uv_state_mean.astype(np.float32), 'uv_std':uv_state_std.astype(np.float32),\ 336 | 'worldcloth_mean':worldcloth_state_mean.astype(np.float32), 'worldcloth_std':worldcloth_state_std.astype(np.float32),\ 337 | 'worldball_mean':worldball_state_mean.astype(np.float32), 'worldball_std':worldball_state_std.astype(np.float32),\ 338 | 'cloth_nxt_mean':cloth_nxt_state_mean.astype(np.float32), 'cloth_nxt_std':cloth_nxt_state_std.astype(np.float32)}) 339 | 340 | if __name__ == '__main__': 341 | GenDataStatics() 342 | 343 | 344 | -------------------------------------------------------------------------------- /Test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from Model import Encoder, Decoder, Processor 8 | from torch.utils.data import DataLoader 9 | from SphereDataset import SphereDataset, collate_fn 10 | from mpl_toolkits.mplot3d import Axes3D 11 | import matplotlib.pyplot as plt 12 | 13 | def main(): 14 | log_dir = '/home/xjwxjw/Documents/ClothSim/Models/2021-05-10-09:45:07' 15 | process_steps = 15 16 | train = False 17 | noise = False 18 | draw_ball = True 19 | 20 | input_cloth_feature = 6 21 | input_uvedge_feature = 7 22 | input_worldedge_feature = 6 23 | hidden_feature = 128 24 | output_feature = 3 25 | 26 | cloth_topo = np.load(os.path.join('../Data', 'cloth_connection.npy'), allow_pickle = True).item() 27 | 28 | cloth_mesh_tri_array = [] 29 | for line in open('../Data/cloth_connection.txt', 'r'): 30 | cloth_mesh_tri_array.append(int(line.split('\n')[0])) 31 | cloth_mesh_tri_array = np.array(cloth_mesh_tri_array).reshape(-1, 3) 32 | 33 | ball_mesh_tri_array = [] 34 | for line in open('../Data/sphere_connection.txt', 'r'): 35 | ball_mesh_tri_array.append(int(line.split('\n')[0])) 36 | ball_mesh_tri_array = np.array(ball_mesh_tri_array).reshape(-1, 3) 37 | 38 | spdataset = SphereDataset('../Data', 500, train, noise) 39 | sploader = DataLoader(spdataset, batch_size = 1, shuffle = False, num_workers = 1, collate_fn = collate_fn) 40 | 41 | adj_map = 
torch.from_numpy(spdataset.adj_map.astype(np.float32)).cuda().unsqueeze(0) 42 | uvedge_node_i = spdataset.uvedge_node_i.astype(np.float32) 43 | uvedge_node_j = spdataset.uvedge_node_j.astype(np.float32) 44 | 45 | node_encoder = Encoder(input_cloth_feature, hidden_feature, hidden_feature, 'ln').cuda() 46 | node_encoder.load_state_dict(torch.load(os.path.join(log_dir, 'node_encoder.pkl'))) 47 | node_encoder.eval() 48 | 49 | uvedge_encoder = Encoder(input_uvedge_feature, hidden_feature, hidden_feature, 'ln').cuda() 50 | uvedge_encoder.load_state_dict(torch.load(os.path.join(log_dir, 'uvedge_encoder.pkl'))) 51 | uvedge_encoder.eval() 52 | 53 | worldedge_encoder = Encoder(input_worldedge_feature - 2, hidden_feature, hidden_feature, 'ln').cuda() 54 | worldedge_encoder.load_state_dict(torch.load(os.path.join(log_dir, 'worldedge_encoder.pkl'))) 55 | worldedge_encoder.eval() 56 | 57 | decoder = Decoder(hidden_feature, output_feature, hidden_feature).cuda() 58 | decoder.load_state_dict(torch.load(os.path.join(log_dir, 'decoder.pkl'))) 59 | decoder.eval() 60 | 61 | node_processor_list = [] 62 | uvedge_processor_list = [] 63 | worldedge_processor_list = [] 64 | for l in range(process_steps): 65 | node_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 66 | node_processor_list[-1].load_state_dict(torch.load(os.path.join(log_dir, 'node_processor_%02d.pkl' % l))) 67 | node_processor_list[-1].eval() 68 | 69 | uvedge_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 70 | uvedge_processor_list[-1].load_state_dict(torch.load(os.path.join(log_dir, 'uvedge_processor_%02d.pkl' % l))) 71 | uvedge_processor_list[-1].eval() 72 | 73 | worldedge_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 74 | worldedge_processor_list[-1].load_state_dict(torch.load(os.path.join(log_dir, 'worldedge_processor_%02d.pkl' % l))) 75 | worldedge_processor_list[-1].eval() 76 | print("All pretrained models successfully loaded") 77 | 78 | cloth_pre_pos = None 79 | cloth_cur_pos = None 80 | cloth_nxt_pos = None 81 | with torch.no_grad(): 82 | for step, (cloth_state, ball_state, uv_state, worldcloth_state, worldball_state, cloth_pos, worldcloth_adjmap, worldball_adjmap) in enumerate(sploader): 83 | ball_state = torch.stack([item for item in ball_state], 0).cuda() 84 | ball_nxt_pos = ball_state[:,:,6:] 85 | ball_state = ball_state[:,:,:6] 86 | if step == 0: 87 | cloth_pos = torch.stack([item for item in cloth_pos], 0).cuda() 88 | cloth_pre_pos = cloth_pos[:, :, 0:3] 89 | cloth_cur_pos = cloth_pos[:, :, 3:6] 90 | cloth_nxt_pos = cloth_pos[:, :, 6:9] 91 | tmp_std = torch.from_numpy(spdataset.cloth_nxt_std).unsqueeze(0).unsqueeze(0).cuda() 92 | tmp_mean = torch.from_numpy(spdataset.cloth_nxt_mean).unsqueeze(0).unsqueeze(0).cuda() 93 | 94 | cloth_state = torch.stack([item for item in cloth_state], 0).cuda() 95 | uv_state = torch.stack([item for item in uv_state], 0).cuda() 96 | worldedgecloth_state_list = [] 97 | for bs in range(len(worldcloth_state)): 98 | if worldcloth_state[bs].size(0) > 0: 99 | worldedgecloth_state_list.append(worldcloth_state[bs]) 100 | worldedgeball_state_list = [] 101 | for bs in range(len(worldball_state)): 102 | if worldball_state[bs].size(0) > 0: 103 | worldedgeball_state_list.append(worldball_state[bs]) 104 | 105 | fig = plt.figure() 106 | ax = fig.gca(projection='3d') 107 | 108 | cloth_cur_pos_np = cloth_cur_pos.detach().cpu().numpy() 109 | x = 
cloth_cur_pos_np[0,:,0] 110 | y = cloth_cur_pos_np[0,:,2] 111 | z = cloth_cur_pos_np[0,:,1] 112 | ax.plot_trisurf(x, y, z, triangles = cloth_mesh_tri_array, linewidth = 0.2, antialiased = True, color = 'b') 113 | if draw_ball: 114 | ball_nxt_pos_np = ball_nxt_pos.detach().cpu().numpy() 115 | x = ball_nxt_pos_np[0,:,0] 116 | y = ball_nxt_pos_np[0,:,2] 117 | z = ball_nxt_pos_np[0,:,1] 118 | ax.plot_trisurf(x, y, z, triangles = ball_mesh_tri_array, linewidth = 0.2, antialiased = True, color = 'r') 119 | 120 | ax.set_xlim([-1.0, 1.0]) 121 | ax.set_ylim([-1.0, 1.0]) 122 | ax.set_zlim([-1.0, 1.0]) 123 | plt.savefig('../Results/%03d.png' % (step)) 124 | plt.close('all') 125 | else: 126 | #### compute the cloth node feature #### 127 | cloth_pre_pos = cloth_cur_pos.clone() 128 | cloth_cur_pos = cloth_nxt_pos.clone() 129 | cloth_cur_vel = cloth_cur_pos - cloth_pre_pos 130 | tmp_mean = torch.from_numpy(spdataset.cloth_mean[3:]).unsqueeze(0).unsqueeze(0).cuda() 131 | tmp_std = torch.from_numpy(spdataset.cloth_std[3:]).unsqueeze(0).unsqueeze(0).cuda() 132 | cloth_cur_vel = (cloth_cur_vel - tmp_mean) / tmp_std 133 | cloth_state = torch.cat([cloth_state[0][:, :3].unsqueeze(0).cuda(), cloth_cur_vel], -1) 134 | #### compute the uv edge feature #### 135 | cloth_uvworld_ = cloth_cur_pos[0, uvedge_node_i] - cloth_cur_pos[0, uvedge_node_j] 136 | cloth_uvworld_norm = torch.norm(cloth_uvworld_, p = 2, dim = -1, keepdim = True) 137 | cloth_uvworld_feature = torch.cat([cloth_uvworld_, cloth_uvworld_norm], -1) 138 | tmp_mean = torch.from_numpy(spdataset.uv_mean[3:]).unsqueeze(0).unsqueeze(0).cuda() 139 | tmp_std = torch.from_numpy(spdataset.uv_std[3:]).unsqueeze(0).unsqueeze(0).cuda() 140 | cloth_uvworld_feature = (cloth_uvworld_feature - tmp_mean) / tmp_std 141 | uv_state = torch.stack([item for item in uv_state], 0).cuda() 142 | uv_state[0, :, 3:] = cloth_uvworld_feature 143 | 144 | #### compute the world edge feature #### 145 | cloth_data = cloth_cur_pos[0].detach().cpu().numpy() 146 | ball_data = ball_nxt_pos[0].detach().cpu().numpy() 147 | 148 | #### collision between cloth and ball #### 149 | ball_world_dis = np.sum((cloth_data[None, :, :3] - ball_data[:, None, :3]) ** 2, -1) ** 0.5 150 | idxs_ball = np.argwhere(ball_world_dis < 0.04) 151 | worldball_state = [] 152 | tmp_state_list = [] 153 | tmp_adj_idx = [] 154 | for idx in idxs_ball: 155 | i_vertx = cloth_data[idx[1]] 156 | j_vertx = ball_data[idx[0]] 157 | xij = i_vertx[:3] - j_vertx[:3] 158 | xij_norm = np.linalg.norm(xij, ord = 2) 159 | worldball_feat = np.array([idx[1], idx[0], xij[0], xij[1], xij[2], xij_norm]) 160 | worldball_feat = (worldball_feat - spdataset.worldball_mean) / spdataset.worldball_std 161 | tmp_state_list.append(torch.from_numpy(worldball_feat.astype(np.float32)).cuda()) 162 | tmp_adj_idx.append(idx[1]) 163 | if len(tmp_state_list) == 0: 164 | worldball_state.append(torch.from_numpy(np.array(tmp_state_list))) 165 | else: 166 | worldball_state.append(torch.stack(tmp_state_list)) 167 | 168 | worldball_adjmap = np.zeros((cloth_data.shape[0], len(tmp_state_list))) 169 | for i in range(worldball_adjmap.shape[1]): 170 | worldball_adjmap[tmp_adj_idx[i], i] = 1.0 171 | if len(tmp_state_list) == 0: 172 | worldball_adjmap = [] 173 | else: 174 | worldball_adjmap = [torch.from_numpy(worldball_adjmap.astype(np.float32))] 175 | 176 | worldedgeball_state_list = [] 177 | for bs in range(len(worldball_state)): 178 | if worldball_state[bs].size(0) > 0: 179 | worldedgeball_state_list.append(worldball_state[bs]) 180 | 181 | #### collision between cloth 
and cloth #### 182 | cloth_world_dis = np.sum((cloth_data[None, :, :3] - cloth_data[:, None, :3]) ** 2, -1) ** 0.5 183 | idxs_cloth = np.argwhere(cloth_world_dis < 0.02) 184 | worldcloth_state = [] 185 | tmp_state_list = [] 186 | tmp_adj_idx = [] 187 | for idx in idxs_cloth: 188 | i_vertx = cloth_data[idx[0]] 189 | j_vertx = cloth_data[idx[1]] 190 | if (idx[0] != idx[1]) and (idx[0] not in cloth_topo[idx[1]]): 191 | xij = i_vertx[:3] - j_vertx[:3] 192 | xij_norm = np.linalg.norm(xij, ord = 2) 193 | worldcloth_feat = np.array([idx[0], idx[1], xij[0], xij[1], xij[2], xij_norm]) 194 | worldcloth_feat = (worldcloth_feat - spdataset.worldcloth_mean) / spdataset.worldcloth_std 195 | tmp_state_list.append(torch.from_numpy(worldcloth_feat.astype(np.float32)).cuda()) 196 | tmp_adj_idx.append(idx[0]) 197 | if len(tmp_state_list) == 0: 198 | worldcloth_state.append(torch.from_numpy(np.array(tmp_state_list))) 199 | else: 200 | worldcloth_state.append(torch.stack(tmp_state_list)) 201 | 202 | worldcloth_adjmap = np.zeros((cloth_data.shape[0], len(tmp_state_list))) 203 | for i in range(worldcloth_adjmap.shape[1]): 204 | worldcloth_adjmap[tmp_adj_idx[i], i] = 1.0 205 | if len(tmp_state_list) == 0: 206 | worldcloth_adjmap = [] 207 | else: 208 | worldcloth_adjmap = [torch.from_numpy(worldcloth_adjmap.astype(np.float32))] 209 | 210 | worldedgecloth_state_list = [] 211 | for bs in range(len(worldcloth_state)): 212 | if worldcloth_state[bs].size(0) > 0: 213 | worldedgecloth_state_list.append(worldcloth_state[bs]) 214 | 215 | #### encoder part #### 216 | # print(step, 'input', cloth_state[0, 100, 3:].detach().cpu().numpy()) 217 | cloth_feature = node_encoder(cloth_state) 218 | ball_feature = node_encoder(ball_state) 219 | uvedge_feature = uvedge_encoder(uv_state) 220 | 221 | worldedgecloth_feature = None 222 | if len(worldedgecloth_state_list) > 0: 223 | worldedgecloth_state = torch.cat(worldedgecloth_state_list).unsqueeze(0) 224 | worldedgecloth_feature = worldedge_encoder(worldedgecloth_state[:, :, 2:].cuda()) 225 | 226 | worldedgeball_feature = None 227 | if len(worldedgeball_state_list) > 0: 228 | worldedgeball_state = torch.cat(worldedgeball_state_list).unsqueeze(0) 229 | worldedgeball_feature = worldedge_encoder(worldedgeball_state[:, :, 2:].cuda()) 230 | 231 | for l in range(process_steps): 232 | #### uv edge feature update #### 233 | uvedge_feature_cat = torch.cat([uvedge_feature, cloth_feature[:, uvedge_node_i], cloth_feature[:, uvedge_node_j]], -1) 234 | uvedge_nxt_feature = uvedge_processor_list[l](uvedge_feature_cat) 235 | 236 | ### cloth-ball world edge feature update #### 237 | if worldedgeball_feature is not None: 238 | worldedge_feature_node_i_list = [] 239 | worldedge_feature_node_j_list = [] 240 | for bs in range(len(worldball_state)): 241 | if worldball_state[bs].size(0) > 0: 242 | node_i_index = worldball_state[bs][:, 0].detach().cpu().numpy() 243 | node_j_index = worldball_state[bs][:, 1].detach().cpu().numpy() 244 | worldedge_feature_node_i_list.append(cloth_feature[bs, node_i_index]) 245 | worldedge_feature_node_j_list.append(ball_feature[bs, node_j_index]) 246 | worldedge_feature_node_i = torch.cat(worldedge_feature_node_i_list, 0).unsqueeze(0) 247 | worldedge_feature_node_j = torch.cat(worldedge_feature_node_j_list, 0).unsqueeze(0) 248 | worldedge_feature_cat = torch.cat([worldedgeball_feature, worldedge_feature_node_i, worldedge_feature_node_j], -1) 249 | worldedgeball_nxt_feature = worldedge_processor_list[l](worldedge_feature_cat) 250 | #### NOTE: here we assume batch size is 1 #### 
251 | agr_worldball_feature = torch.matmul(worldball_adjmap[0].unsqueeze(0).cuda(), worldedgeball_nxt_feature) 252 | else: 253 | agr_worldball_feature = torch.zeros((len(cloth_state), cloth_state[0].size(0), hidden_feature)).cuda() 254 | 255 | ### cloth-cloth world edge feature update #### 256 | if worldedgecloth_feature is not None: 257 | worldedge_feature_node_i_list = [] 258 | worldedge_feature_node_j_list = [] 259 | for bs in range(len(worldcloth_state)): 260 | if worldcloth_state[bs].size(0) > 0: 261 | node_i_index = worldcloth_state[bs][:, 0].detach().cpu().numpy() 262 | node_j_index = worldcloth_state[bs][:, 1].detach().cpu().numpy() 263 | worldedge_feature_node_i_list.append(cloth_feature[bs, node_i_index]) 264 | worldedge_feature_node_j_list.append(cloth_feature[bs, node_j_index]) 265 | worldedge_feature_node_i = torch.cat(worldedge_feature_node_i_list, 0).unsqueeze(0) 266 | worldedge_feature_node_j = torch.cat(worldedge_feature_node_j_list, 0).unsqueeze(0) 267 | worldedge_feature_cat = torch.cat([worldedgecloth_feature, worldedge_feature_node_i, worldedge_feature_node_j], -1) 268 | worldedgecloth_nxt_feature = worldedge_processor_list[l](worldedge_feature_cat) 269 | #### NOTE: here we assume batch size is 1 #### 270 | agr_worldcloth_feature = torch.matmul(worldcloth_adjmap[0].unsqueeze(0).cuda(), worldedgecloth_nxt_feature) 271 | else: 272 | agr_worldcloth_feature = torch.zeros((len(cloth_state), cloth_state[0].size(0), hidden_feature)).cuda() 273 | 274 | ### node feature update #### 275 | agr_uv_feature = torch.matmul(adj_map[:cloth_feature.size(0)], uvedge_nxt_feature) 276 | cloth_feature_cat = torch.cat([cloth_feature, agr_uv_feature, agr_worldball_feature + agr_worldcloth_feature], -1) 277 | cloth_nxt_feature = node_processor_list[l](cloth_feature_cat) 278 | 279 | #### residual connection #### 280 | uvedge_feature = uvedge_feature + uvedge_nxt_feature 281 | if worldedgeball_feature is not None: 282 | worldedgeball_feature = worldedgeball_feature + worldedgeball_nxt_feature 283 | if worldedgecloth_feature is not None: 284 | worldedgecloth_feature = worldedgecloth_feature + worldedgecloth_nxt_feature 285 | cloth_feature = cloth_feature + cloth_nxt_feature 286 | 287 | output = decoder(cloth_feature) 288 | # print(step, output[0,100]) 289 | #### use predicted acc to calculate the position #### 290 | tmp_std = torch.from_numpy(spdataset.cloth_nxt_std).unsqueeze(0).unsqueeze(0).cuda() 291 | tmp_mean = torch.from_numpy(spdataset.cloth_nxt_mean).unsqueeze(0).unsqueeze(0).cuda() 292 | output = (output * tmp_std) + tmp_mean 293 | output[0, 1] = 0.0 294 | output[0, 645] = 0.0 295 | cloth_nxt_pos = 2 * cloth_cur_pos + output - cloth_pre_pos 296 | # print(step, 'after', cloth_pre_pos[0,100], cloth_cur_pos[0,100], cloth_nxt_pos[0,100]) 297 | 298 | fig = plt.figure() 299 | ax = fig.gca(projection='3d') 300 | 301 | cloth_nxt_pos_np = cloth_nxt_pos.detach().cpu().numpy() 302 | x = cloth_nxt_pos_np[0,:,0] 303 | y = cloth_nxt_pos_np[0,:,2] 304 | z = cloth_nxt_pos_np[0,:,1] 305 | ax.plot_trisurf(x, y, z, triangles = cloth_mesh_tri_array, linewidth = 0.2, antialiased = True, color = 'b') 306 | if draw_ball: 307 | ball_nxt_pos_np = ball_nxt_pos.detach().cpu().numpy() 308 | x = ball_nxt_pos_np[0,:,0] 309 | y = ball_nxt_pos_np[0,:,2] 310 | z = ball_nxt_pos_np[0,:,1] 311 | ax.plot_trisurf(x, y, z, triangles = ball_mesh_tri_array, linewidth = 0.2, antialiased = True, color = 'r') 312 | 313 | ax.set_xlim([-1.0, 1.0]) 314 | ax.set_ylim([-1.0, 1.0]) 315 | ax.set_zlim([-1.0, 1.0]) 316 | 
plt.savefig('../Results/%03d.png' % (step + 1)) 317 | plt.close('all') 318 | 319 | if __name__ == "__main__": 320 | main() 321 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import shutil 5 | from tensorboardX import SummaryWriter 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | from SphereDataset import SphereDataset, collate_fn 12 | from Model import Encoder, Decoder, Processor, Processor_Res 13 | 14 | def main(): 15 | pretrained_model = None#'/home/xjwxjw/Documents/ClothSim/Models/2021-05-05-15:15:14' 16 | learning_rate = 1e-4 17 | batch_size = 1 18 | num_workers = 8 19 | shuffle = True 20 | train = True 21 | noise = True 22 | num_epochs = 5001 23 | beta0 = 0.9 24 | beta1 = 0.999 25 | use_scheduler = True 26 | 27 | process_steps = 15 28 | 29 | input_cloth_feature = 6 30 | input_uvedge_feature = 7 31 | input_worldedge_feature = 6 32 | hidden_feature = 128 33 | output_feature = 3 34 | 35 | now = int(time.time()) 36 | timeArray = time.localtime(now) 37 | otherStyleTime = time.strftime("%Y-%m-%d-%H:%M:%S", timeArray) 38 | log_dir = '../Logs/%s' % otherStyleTime 39 | model_dir = '../Models/%s' % otherStyleTime 40 | 41 | if not os.path.exists(log_dir): 42 | os.makedirs(log_dir) 43 | if not os.path.exists(model_dir): 44 | os.makedirs(model_dir) 45 | 46 | def copydirs(from_file, to_file): 47 | if not os.path.exists(to_file): 48 | os.makedirs(to_file) 49 | files = os.listdir(from_file) 50 | for f in files: 51 | if os.path.isdir(from_file + '/' + f): 52 | copydirs(from_file + '/' + f, to_file + '/' + f) 53 | else: 54 | if '.git' not in from_file: 55 | shutil.copy(from_file + '/' + f, to_file + '/' + f) 56 | copydirs('./', log_dir + '/Src') 57 | 58 | writer = SummaryWriter(log_dir) 59 | 60 | spdataset = SphereDataset('../Data', 500, train, noise) 61 | adj_map = torch.from_numpy(spdataset.adj_map.astype(np.float32)).cuda().unsqueeze(0) 62 | adj_map = torch.cat([adj_map for i in range(batch_size)], 0) 63 | uvedge_node_i = spdataset.uvedge_node_i.astype(np.float32) 64 | uvedge_node_j = spdataset.uvedge_node_j.astype(np.float32) 65 | 66 | def truncated_normal_(tensor, mean = 0, std = 0.2): 67 | with torch.no_grad(): 68 | size = tensor.shape 69 | tmp = tensor.new_empty(size+(4,)).normal_() 70 | valid = (tmp < 2) & (tmp > -2) 71 | ind = valid.max(-1, keepdim=True)[1] 72 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 73 | tensor.data.mul_(std).add_(mean) 74 | return tensor 75 | 76 | def init_weights(m): 77 | if type(m) == nn.Linear: 78 | truncated_normal_(m.weight) 79 | m.bias.data.fill_(0.0) 80 | 81 | if pretrained_model is not None: 82 | node_encoder = Encoder(input_cloth_feature, hidden_feature, hidden_feature, 'ln').cuda() 83 | node_encoder.load_state_dict(torch.load(os.path.join(pretrained_model, 'node_encoder.pkl'))) 84 | 85 | uvedge_encoder = Encoder(input_uvedge_feature, hidden_feature, hidden_feature, 'ln').cuda() 86 | uvedge_encoder.load_state_dict(torch.load(os.path.join(pretrained_model, 'uvedge_encoder.pkl'))) 87 | 88 | worldedge_encoder = Encoder(input_worldedge_feature - 2, hidden_feature, hidden_feature, 'ln').cuda() 89 | worldedge_encoder.load_state_dict(torch.load(os.path.join(pretrained_model, 'worldedge_encoder.pkl'))) 90 | 91 | decoder = Decoder(hidden_feature, output_feature, hidden_feature).cuda() 92 | 
decoder.load_state_dict(torch.load(os.path.join(pretrained_model, 'decoder.pkl'))) 93 | 94 | node_processor_list = [] 95 | uvedge_processor_list = [] 96 | worldedge_processor_list = [] 97 | for l in range(process_steps): 98 | node_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 99 | node_processor_list[-1].load_state_dict(torch.load(os.path.join(pretrained_model, 'node_processor_%02d.pkl' % l))) 100 | 101 | uvedge_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 102 | uvedge_processor_list[-1].load_state_dict(torch.load(os.path.join(pretrained_model, 'uvedge_processor_%02d.pkl' % l))) 103 | 104 | worldedge_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 105 | worldedge_processor_list[-1].load_state_dict(torch.load(os.path.join(pretrained_model, 'worldedge_processor_%02d.pkl' % l))) 106 | print("All pretrained models successfully loaded") 107 | else: 108 | node_encoder = Encoder(input_cloth_feature, hidden_feature, hidden_feature, 'ln').cuda() 109 | node_encoder.apply(init_weights) 110 | 111 | uvedge_encoder = Encoder(input_uvedge_feature, hidden_feature, hidden_feature, 'ln').cuda() 112 | uvedge_encoder.apply(init_weights) 113 | 114 | worldedge_encoder = Encoder(input_worldedge_feature - 2, hidden_feature, hidden_feature, 'ln').cuda() 115 | worldedge_encoder.apply(init_weights) 116 | 117 | decoder = Decoder(hidden_feature, output_feature, hidden_feature).cuda() 118 | decoder.apply(init_weights) 119 | 120 | node_processor_list = [] 121 | uvedge_processor_list = [] 122 | worldedge_processor_list = [] 123 | for l in range(process_steps): 124 | node_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 125 | node_processor_list[-1].apply(init_weights) 126 | 127 | uvedge_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 128 | uvedge_processor_list[-1].apply(init_weights) 129 | 130 | worldedge_processor_list.append(Processor(hidden_feature * 3, hidden_feature, hidden_feature * 3, 'ln').cuda()) 131 | worldedge_processor_list[-1].apply(init_weights) 132 | 133 | def worker_init_fn(worker_id): 134 | np.random.seed(np.random.get_state()[1][0] + worker_id) 135 | 136 | sploader = DataLoader(spdataset, batch_size = batch_size, shuffle = shuffle, num_workers = num_workers, collate_fn = collate_fn, worker_init_fn=worker_init_fn) 137 | 138 | node_encoder.train() 139 | uvedge_encoder.train() 140 | worldedge_encoder.train() 141 | for l in range(process_steps): 142 | node_processor_list[l].train() 143 | worldedge_processor_list[l].train() 144 | uvedge_processor_list[l].train() 145 | decoder.train() 146 | 147 | parm_list = [] 148 | parm_list += node_encoder.parameters() 149 | parm_list += uvedge_encoder.parameters() 150 | parm_list += worldedge_encoder.parameters() 151 | for l in range(process_steps): 152 | parm_list += node_processor_list[l].parameters() 153 | parm_list += worldedge_processor_list[l].parameters() 154 | parm_list += uvedge_processor_list[l].parameters() 155 | parm_list += decoder.parameters() 156 | 157 | optimizer = optim.Adam(parm_list, lr=learning_rate, betas=(beta0, beta1)) 158 | total_step = len(sploader) 159 | scheduler = None 160 | if use_scheduler: 161 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 100, gamma=0.1) 162 | 163 | # world_feature = None 164 | for num_epoch in range(num_epochs): 165 | np.random.seed() 166 | 
for step, (cloth_state, ball_state, uv_state, worldcloth_state, worldball_state, cloth_nxt_state, worldcloth_adjmap, worldball_adjmap) in enumerate(sploader): 167 | 168 | cloth_state = torch.stack([item for item in cloth_state], 0).cuda() 169 | ball_state = torch.stack([item for item in ball_state], 0).cuda() 170 | uv_state = torch.stack([item for item in uv_state], 0).cuda() 171 | cloth_nxt_state = torch.stack([item for item in cloth_nxt_state], 0).cuda() 172 | 173 | #### encoder part #### 174 | cloth_feature = node_encoder(cloth_state) 175 | ball_feature = node_encoder(ball_state) 176 | 177 | uvedge_feature = uvedge_encoder(uv_state) 178 | 179 | worldedge_state_list = [] 180 | worldedge_node_i_index_list = [] 181 | worldedge_node_j_index_list = [] 182 | for bs in range(len(worldcloth_state)): 183 | if worldcloth_state[bs].size(0) > 0: 184 | worldedge_state_list.append(worldcloth_state[bs]) 185 | node_i_index = worldcloth_state[bs][:, 0].detach().cpu().numpy() 186 | node_j_index = worldcloth_state[bs][:, 1].detach().cpu().numpy() 187 | worldedge_node_i_index_list.append(node_i_index) 188 | worldedge_node_j_index_list.append(node_j_index) 189 | else: 190 | worldedge_node_i_index_list.append([]) 191 | worldedge_node_j_index_list.append([]) 192 | worldedgecloth_feature = None 193 | if len(worldedge_state_list) > 0: 194 | worldedge_state = torch.cat(worldedge_state_list).unsqueeze(0) 195 | worldedgecloth_feature = worldedge_encoder(worldedge_state[:, :, 2:].cuda()) 196 | 197 | worldedge_state_list = [] 198 | worldedge_node_i_index_list = [] 199 | worldedge_node_j_index_list = [] 200 | for bs in range(len(worldball_state)): 201 | if worldball_state[bs].size(0) > 0: 202 | worldedge_state_list.append(worldball_state[bs]) 203 | node_i_index = worldball_state[bs][:, 0].detach().cpu().numpy() 204 | node_j_index = worldball_state[bs][:, 1].detach().cpu().numpy() 205 | worldedge_node_i_index_list.append(node_i_index) 206 | worldedge_node_j_index_list.append(node_j_index) 207 | else: 208 | worldedge_node_i_index_list.append([]) 209 | worldedge_node_j_index_list.append([]) 210 | worldedgeball_feature = None 211 | if len(worldedge_state_list) > 0: 212 | worldedge_state = torch.cat(worldedge_state_list).unsqueeze(0) 213 | worldedgeball_feature = worldedge_encoder(worldedge_state[:, :, 2:].cuda()) 214 | 215 | for l in range(process_steps): 216 | ### uv edge feature update #### 217 | uvedge_feature_cat = torch.cat([uvedge_feature, cloth_feature[:, uvedge_node_i], cloth_feature[:, uvedge_node_j]], -1) 218 | uvedge_nxt_feature = uvedge_processor_list[l](uvedge_feature_cat) 219 | 220 | ### cloth-ball world edge feature update #### 221 | if worldedgeball_feature is not None: 222 | worldedge_feature_node_i_list = [] 223 | worldedge_feature_node_j_list = [] 224 | for bs in range(len(worldball_state)): 225 | if worldball_state[bs].size(0) > 0: 226 | node_i_index = worldball_state[bs][:, 0].detach().cpu().numpy() 227 | node_j_index = worldball_state[bs][:, 1].detach().cpu().numpy() 228 | worldedge_feature_node_i_list.append(cloth_feature[bs, node_i_index]) 229 | worldedge_feature_node_j_list.append(ball_feature[bs, node_j_index]) 230 | worldedge_feature_node_i = torch.cat(worldedge_feature_node_i_list, 0).unsqueeze(0) 231 | worldedge_feature_node_j = torch.cat(worldedge_feature_node_j_list, 0).unsqueeze(0) 232 | worldedge_feature_cat = torch.cat([worldedgeball_feature, worldedge_feature_node_i, worldedge_feature_node_j], -1) 233 | worldedgeball_nxt_feature = worldedge_processor_list[l](worldedge_feature_cat) 234 | 
#### NOTE: here we assume batch size is 1 #### 235 | agr_worldball_feature = torch.matmul(worldball_adjmap[0].unsqueeze(0).cuda(), worldedgeball_nxt_feature) 236 | else: 237 | agr_worldball_feature = torch.zeros((len(cloth_state), cloth_state[0].size(0), hidden_feature)).cuda() 238 | 239 | ### cloth-cloth world edge feature update #### 240 | if worldedgecloth_feature is not None: 241 | worldedge_feature_node_i_list = [] 242 | worldedge_feature_node_j_list = [] 243 | for bs in range(len(worldcloth_state)): 244 | if worldcloth_state[bs].size(0) > 0: 245 | node_i_index = worldcloth_state[bs][:, 0].detach().cpu().numpy() 246 | node_j_index = worldcloth_state[bs][:, 1].detach().cpu().numpy() 247 | worldedge_feature_node_i_list.append(cloth_feature[bs, node_i_index]) 248 | worldedge_feature_node_j_list.append(cloth_feature[bs, node_j_index]) 249 | worldedge_feature_node_i = torch.cat(worldedge_feature_node_i_list, 0).unsqueeze(0) 250 | worldedge_feature_node_j = torch.cat(worldedge_feature_node_j_list, 0).unsqueeze(0) 251 | worldedge_feature_cat = torch.cat([worldedgecloth_feature, worldedge_feature_node_i, worldedge_feature_node_j], -1) 252 | worldedgecloth_nxt_feature = worldedge_processor_list[l](worldedge_feature_cat) 253 | #### NOTE: here we assume batch size is 1 #### 254 | agr_worldcloth_feature = torch.matmul(worldcloth_adjmap[0].unsqueeze(0).cuda(), worldedgecloth_nxt_feature) 255 | else: 256 | agr_worldcloth_feature = torch.zeros((len(cloth_state), cloth_state[0].size(0), hidden_feature)).cuda() 257 | 258 | ### node feature update #### 259 | agr_uv_feature = torch.matmul(adj_map[:cloth_feature.size(0)], uvedge_nxt_feature) 260 | cloth_feature_cat = torch.cat([cloth_feature, agr_uv_feature, agr_worldball_feature + agr_worldcloth_feature], -1) 261 | cloth_nxt_feature = node_processor_list[l](cloth_feature_cat) 262 | 263 | #### residual connection #### 264 | uvedge_feature = uvedge_feature + uvedge_nxt_feature 265 | if worldedgeball_feature is not None: 266 | worldedgeball_feature = worldedgeball_feature + worldedgeball_nxt_feature 267 | if worldedgecloth_feature is not None: 268 | worldedgecloth_feature = worldedgecloth_feature + worldedgecloth_nxt_feature 269 | cloth_feature = cloth_feature + cloth_nxt_feature 270 | 271 | output = decoder(cloth_feature) 272 | 273 | #### zero-out kinematic node #### 274 | # kinematic_node = [1, 645] 275 | # output[:, kinematic_node, :] = 0.0 276 | # cloth_nxt_state[:, kinematic_node, :] = 0.0 277 | 278 | loss = torch.sum((output - cloth_nxt_state) ** 2) / (output.size(0) * output.size(1)) 279 | print(num_epoch, step, output[0, 100, :].detach().cpu().numpy(), cloth_nxt_state[0, 100, :].detach().cpu().numpy(), loss.detach().cpu().numpy()) 280 | 281 | optimizer.zero_grad() 282 | loss.backward() 283 | optimizer.step() 284 | writer.add_scalar('train_loss', loss.detach().cpu().numpy(), global_step = num_epoch * total_step + step) 285 | 286 | torch.save(node_encoder.state_dict(), model_dir + '/node_encoder.pkl') 287 | torch.save(uvedge_encoder.state_dict(), model_dir + '/uvedge_encoder.pkl') 288 | torch.save(worldedge_encoder.state_dict(), model_dir + '/worldedge_encoder.pkl') 289 | for l in range(process_steps): 290 | torch.save(node_processor_list[l].state_dict(), model_dir + '/node_processor_%02d.pkl' % l) 291 | torch.save(uvedge_processor_list[l].state_dict(), model_dir + '/uvedge_processor_%02d.pkl' % l) 292 | torch.save(worldedge_processor_list[l].state_dict(), model_dir + '/worldedge_processor_%02d.pkl' % l) 293 | torch.save(decoder.state_dict(), 
model_dir + '/decoder.pkl') 294 | if use_scheduler: 295 | scheduler.step() 296 | 297 | if __name__ == '__main__': 298 | main() 299 | -------------------------------------------------------------------------------- /Utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | def GenDataStat(): 6 | target_list = ['cloth', 'ball'] 7 | for target in target_list: 8 | target_data = [] 9 | for i in range(500): 10 | target_path = os.path.join('../Data/0002/%03d_%s.txt' % (i, target)) 11 | for line in open(target_path, 'r'): 12 | line = line.split('\n')[0] 13 | target_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 14 | target_data = np.array(target_data) 15 | target_mean = np.mean(target_data, 0) 16 | target_std = np.std(target_data, 0) 17 | np.save('../Data/%s_mean.npy' % target, target_mean) 18 | np.save('../Data/%s_std.npy' % target, target_std + 1e-10) 19 | print('%s_mean:' % target, target_mean) 20 | print('%s_std:' % target, target_std + 1e-10) 21 | 22 | def draw_gt_data(): 23 | for i in range(100): 24 | for j in range(500): 25 | for line in open('../Data/data/%04d/%03d_world.txt' % (i, j), 'r'): 26 | if (line.split('\n')[0].split(' ')[0] == '0'): 27 | print(line.split('\n')[0].split(' ')) 28 | cloth_connection = np.array([int(line.split('\n')[0]) for line in open('../Data/cloth_connection.txt', 'r')]) 29 | cloth_connection = np.reshape(cloth_connection, (-1, 3)) 30 | ball_connection = np.array([int(line.split('\n')[0]) for line in open('../Data/sphere_connection.txt', 'r')]) 31 | ball_connection = np.reshape(ball_connection, (-1, 3)) 32 | 33 | from mpl_toolkits.mplot3d import Axes3D as axes3d 34 | import matplotlib.pyplot as plt 35 | 36 | for i in range(500): 37 | cloth_data = [] 38 | for line in open('../Data/data/0002/%03d_cloth.txt' % i, 'r'): 39 | line = line.split('\n')[0] 40 | cloth_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 41 | cloth_data = np.array(cloth_data) 42 | cloth_x, cloth_y, cloth_z = cloth_data[:,:3].T 43 | 44 | ball_data = [] 45 | for line in open('../Data/data/0002/%03d_ball.txt' % i, 'r'): 46 | line = line.split('\n')[0] 47 | ball_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 48 | ball_data = np.array(ball_data) 49 | ball_x, ball_y, ball_z = ball_data[:,:3].T 50 | 51 | fig = plt.figure() 52 | ax = fig.gca(projection = '3d') 53 | ax.plot_trisurf(cloth_x, cloth_z, cloth_y, triangles = cloth_connection) 54 | ax.plot_trisurf(ball_x, ball_z, ball_y, triangles = ball_connection) 55 | ax.set(xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1)) 56 | plt.savefig('../Results/%03d.png' % i) 57 | # plt.show() 58 | plt.close('all') 59 | 60 | def GenEdgeFeature(seq_id): 61 | cloth_connection = np.load('../Data/cloth_connection.npy', allow_pickle=True).item() 62 | # ball_connection = np.load('../Data/sphere_connection.npy', allow_pickle=True).item() 63 | for i in range(500): 64 | print(i) 65 | cloth_data = [] 66 | cloth_path = ('../Data/data/%04d/%03d_cloth.txt' % (seq_id, i)) 67 | for line in open(cloth_path, 'r'): 68 | line = line.split('\n')[0] 69 | cloth_data.append(np.array([float(data) for data in line.split(' ')[:-1]])) 70 | cloth_data = np.array(cloth_data) 71 | 72 | cloth_idx = i 73 | if i < 499: 74 | cloth_idx = i+1 75 | ball_data = [] 76 | ball_path = ('../Data/data/%04d/%03d_ball.txt' % (seq_id, cloth_idx)) 77 | for line in open(ball_path, 'r'): 78 | line = line.split('\n')[0] 79 | ball_data.append(np.array([float(data) 
for data in line.split(' ')[:-1]])) 80 | ball_data = np.array(ball_data) 81 | 82 | foutuv = open('../Data/data/%04d/%03d_uv.txt' % (seq_id, i), 'w') 83 | for key in cloth_connection.keys(): 84 | for val in cloth_connection[key]: 85 | i_vertx = cloth_data[key] 86 | j_vertx = cloth_data[val] 87 | uij = i_vertx[15:18] - j_vertx[15:18] 88 | uij_norm = np.linalg.norm(uij, ord = 2) 89 | xij = i_vertx[:3] - j_vertx[:3] 90 | xij_norm = np.linalg.norm(xij, ord = 2) 91 | foutuv.write("%.6f %.6f %.6f %.6f %.6f %.6f %.6f\n" % \ 92 | (uij[0], uij[1], uij_norm, xij[0], xij[1], xij[2], xij_norm)) 93 | foutuv.close() 94 | 95 | foutworld = open('../Data/data/%04d/%03d_world.txt' % (seq_id, i), 'w') 96 | cloth_world_dis = np.sum((cloth_data[None, :, :3] - cloth_data[:, None, :3])**2, -1)**0.5 97 | ball_world_dis = np.sum((cloth_data[None, :, :3] - ball_data[:, None, :3])**2, -1)**0.5 98 | idxs_cloth = np.argwhere(cloth_world_dis < 0.02) 99 | idxs_ball = np.argwhere(ball_world_dis < 0.04) 100 | 101 | for idx in idxs_cloth: 102 | i_vertx = cloth_data[idx[0]] 103 | j_vertx = cloth_data[idx[1]] 104 | xij = i_vertx[:3] - j_vertx[:3] 105 | xij_norm = np.linalg.norm(xij, ord = 2) 106 | if (idx[0] != idx[1]) and (idx[0] not in cloth_connection[idx[1]]): 107 | foutworld.write("0 %d %d %.6f %.6f %.6f %.6f\n" % (idx[0], idx[1], xij[0], xij[1], xij[2], xij_norm)) 108 | 109 | for idx in idxs_ball: 110 | i_vertx = cloth_data[idx[1]] 111 | j_vertx = ball_data[idx[0]] 112 | xij = i_vertx[:3] - j_vertx[:3] 113 | xij_norm = np.linalg.norm(xij, ord = 2) 114 | foutworld.write("1 %d %d %.6f %.6f %.6f %.6f\n" % (idx[1], idx[0], xij[0], xij[1], xij[2], xij_norm)) 115 | foutworld.close() 116 | 117 | import threading 118 | import multiprocessing 119 | 120 | if __name__ == "__main__": 121 | for i in range(100): 122 | p1 = multiprocessing.Process(target = GenEdgeFeature, args = (i,)) 123 | p1.start() 124 | 125 | -------------------------------------------------------------------------------- /data_visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | 5 | def hist_vis(): 6 | has_ball = True 7 | cloth_mesh_tri_array = [] 8 | ball_mesh_tri_array = [] 9 | for line in open('../Data/cloth_connection.txt', 'r'): 10 | cloth_mesh_tri_array.append(int(line.split('\n')[0])) 11 | cloth_mesh_tri_array = np.array(cloth_mesh_tri_array).reshape(-1, 3) 12 | if has_ball: 13 | for line in open('../Data/sphere_connection.txt', 'r'): 14 | ball_mesh_tri_array.append(int(line.split('\n')[0])) 15 | ball_mesh_tri_array = np.array(ball_mesh_tri_array).reshape(-1, 3) 16 | 17 | pos = [] 18 | for line in open('/media/xjwxjw/626898FF2DF873F3/Documents/ClothSimData/data_0424/0043/%03d_ball.txt' % 3, 'r'): 19 | line = line.split('\n')[0].split(' ')[:3] 20 | pos.append(np.array([float(x) for x in line])) 21 | pos = np.stack(pos, 0) 22 | plt.hist(np.linalg.norm(pos[ball_mesh_tri_array[:,0]] - pos[ball_mesh_tri_array[:,1]], ord = 2, axis = -1), 200) 23 | plt.show() 24 | 25 | pos = [] 26 | for line in open('/media/xjwxjw/626898FF2DF873F3/Documents/ClothSimData/data_0424/0043/%03d_cloth.txt' % 3, 'r'): 27 | line = line.split('\n')[0].split(' ')[:3] 28 | pos.append(np.array([float(x) for x in line])) 29 | pos = np.stack(pos, 0) 30 | plt.hist(np.linalg.norm(pos[cloth_mesh_tri_array[:,0]] - pos[cloth_mesh_tri_array[:,1]], ord = 2, axis = -1), 200) 31 | plt.show() 32 | 33 | def acc_vis(): 34 | acc_array = [] 35 | for t in 
36 |         for line in open('../Results_Acc/%04d.txt' % t, 'r'):
37 |             acc = line.split('\n')[0].split(' ')
38 |             acc = np.array([ float(a) for a in acc])
39 |             acc_array.append(acc)
40 |     acc_array = np.array(acc_array)
41 |     plt.hist(acc_array[:,1], 200)
42 |     plt.show()
43 | 
44 | def vel_vis():
45 |     vel_list = []
46 |     for t in range(3, 496):
47 |         vel_array = []
48 |         for line in open('../Data/0050/%03d_cloth.txt' % t, 'r'):
49 |             vel = line.split('\n')[0].split(' ')[:3]
50 |             vel = np.array([float(v) for v in vel])
51 |             vel_array.append(vel)
52 |         vel_array = np.array(vel_array)
53 | 
54 |         vel_nxt_array = []
55 |         for line in open('../Data/0050/%03d_cloth.txt' % (t+1), 'r'):
56 |             vel = line.split('\n')[0].split(' ')[:3]
57 |             vel = np.array([float(v) for v in vel])
58 |             vel_nxt_array.append(vel)
59 |         vel_nxt_array = np.array(vel_nxt_array)
60 |         vel_list.append(vel_nxt_array - vel_array)
61 |     vel_list = np.concatenate(vel_list, 0)
62 |     # plt.hist(vel_list[:,2], 200)
63 |     # plt.show()
64 |     std = np.std(vel_list, 0)
65 |     print(std)
66 | 
67 | def cloth_vis():
68 |     has_ball = True
69 |     cloth_mesh_tri_array = []
70 |     ball_mesh_tri_array = []
71 |     for line in open('../Data/cloth_connection.txt', 'r'):
72 |         cloth_mesh_tri_array.append(int(line.split('\n')[0]))
73 |     cloth_mesh_tri_array = np.array(cloth_mesh_tri_array).reshape(-1, 3)
74 | 
75 |     if has_ball:
76 |         for line in open('../Data/sphere_connection.txt', 'r'):
77 |             ball_mesh_tri_array.append(int(line.split('\n')[0]))
78 |         ball_mesh_tri_array = np.array(ball_mesh_tri_array).reshape(-1, 3)
79 | 
80 |     for file_idx in range(500):
81 |         fig = plt.figure()
82 |         ax = fig.add_subplot(projection='3d')
83 | 
84 |         pos = []
85 |         for line in open('/home/xjwxjw/Documents/ClothSim/Data/data/0050/%03d_cloth.txt' % file_idx, 'r'):
86 |             line = line.split('\n')[0].split(' ')[:3]
87 |             pos.append(np.array([float(x) for x in line]))
88 |         pos = np.stack(pos, 0)
89 |         x = pos[:,0]
90 |         y = pos[:,2]
91 |         z = pos[:,1]
92 |         ax.plot_trisurf(x, y, z, triangles = cloth_mesh_tri_array, linewidth = 0.2, antialiased = True, color = 'b')
93 | 
94 |         pos = []
95 |         for line in open('/home/xjwxjw/Documents/ClothSim/Data/data/0050/%03d_ball.txt' % file_idx, 'r'):
96 |             line = line.split('\n')[0].split(' ')[:3]
97 |             pos.append(np.array([float(x) for x in line]))
98 |         pos = np.stack(pos, 0)
99 |         x = pos[:,0]
100 |         y = pos[:,2]
101 |         z = pos[:,1]
102 |         ax.plot_trisurf(x, y, z, triangles = ball_mesh_tri_array, linewidth = 0.2, antialiased = True, color = 'r')
103 | 
104 |         ax.set_xlim([-1.0, 1.0])
105 |         ax.set_ylim([-1.0, 1.0])
106 |         ax.set_zlim([-1.0, 1.0])
107 |         plt.savefig('../Results/%03d.png' % (file_idx - 0))
108 |         plt.close('all')
109 | 
110 | def gen_video():
111 |     import cv2
112 |     import os
113 | 
114 |     def get_file_names(search_path):
115 |         for (dirpath, _, filenames) in os.walk(search_path):
116 |             for filename in filenames:
117 |                 yield filename  # os.path.join(dirpath, filename)
118 | 
119 |     def save_to_video(output_path, output_video_file, frame_rate):
120 |         list_files = sorted([int(i.split('_')[-1].split('.')[0]) for i in get_file_names(output_path)])
121 |         # read one frame to determine the output width and height
122 |         img0 = cv2.imread(os.path.join(output_path, '%03d.png' % list_files[0]))
123 |         img1 = cv2.imread(os.path.join(output_path, '%03d.png' % list_files[0]).replace('Results', 'Pred'))
124 |         img = np.concatenate([img0, img1], 1)
125 | 
126 |         # print(img0)
127 |         height, width, layers = img.shape
128 |         # initialize the VideoWriter used to save the video
129 |         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
130 |         videowriter = cv2.VideoWriter(output_video_file, fourcc, frame_rate, (width, height))
131 |         # core loop: write each concatenated frame into the video
132 |         font = cv2.FONT_HERSHEY_SIMPLEX
133 |         for f in list_files:
134 |             try:
135 |                 f = '%03d.png' % f
136 |                 # print("saving..." + f)
137 | 
138 |                 img0 = cv2.imread(os.path.join(output_path, f))
139 |                 img1 = cv2.imread(os.path.join(output_path, f).replace('Results', 'Pred'))
140 | 
141 |                 img = np.concatenate([img0, img1], 1)
142 |                 img = cv2.putText(img, f, (0, 100), font, 1.2, (255, 0, 0), 2)
143 |                 img = cv2.putText(img, 'houdini', (250, 100), font, 1.2, (0, 0, 255), 2)
144 |                 img = cv2.putText(img, 'model pred', (850, 100), font, 1.2, (0, 0, 255), 2)
145 |                 videowriter.write(img)
146 |             except Exception:
147 |                 print(os.path.join(output_path, f).replace('Results', 'Pred'))
148 |         videowriter.release()
149 |         cv2.destroyAllWindows()
150 |         print('Successfully saved %s!' % output_video_file)
151 |         pass
152 | 
153 |     # turn the image sequence into a video
154 |     output_dir = '../Results'
155 |     output_path = os.path.join(output_dir, '')  # directory containing the input images
156 |     output_video_file = './gt.mp4'  # output path and file name of the video
157 |     save_to_video(output_path, output_video_file, 20)
158 | 
159 | if __name__ == "__main__":
160 |     cloth_vis()
161 | 
--------------------------------------------------------------------------------
/state_stat.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjwxjw/Pytorch-Learned-Cloth-Simulation/111c02217caf09ff365a0a507a50c755faea7083/state_stat.npz
--------------------------------------------------------------------------------
/state_stat_sample.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjwxjw/Pytorch-Learned-Cloth-Simulation/111c02217caf09ff365a0a507a50c755faea7083/state_stat_sample.npz
--------------------------------------------------------------------------------
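
The per-feature mean and standard deviation that GenDataStat() in Utils.py writes to ../Data/cloth_mean.npy and ../Data/cloth_std.npy are meant for normalizing the raw vertex states before they are fed to the network (the saved std is already offset by 1e-10, so the division is safe). A minimal sketch of that normalization, assuming those two files have been generated; the helper names here are illustrative only and not part of the repo:

import numpy as np

# Assumes GenDataStat() in Utils.py has already written these files.
cloth_mean = np.load('../Data/cloth_mean.npy')
cloth_std = np.load('../Data/cloth_std.npy')    # saved as std + 1e-10, so dividing is safe

def normalize(state):
    # map raw per-vertex features to roughly zero mean / unit variance
    return (state - cloth_mean) / cloth_std

def denormalize(state):
    # invert the transform, e.g. to turn network outputs back into world units
    return state * cloth_std + cloth_mean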
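GenEdgeFeature() in Utils.py builds the proximity ("world") edges from dense pairwise distance matrices, which costs O(N^2) memory in the number of vertices. Below is a sketch of an equivalent radius query using scipy's cKDTree; this is an alternative that the repo does not use, cloth_pos and ball_pos are assumed to be (N, 3) position arrays, and mesh-connected cloth pairs would still have to be filtered out exactly as GenEdgeFeature() does before writing them:

import numpy as np
from scipy.spatial import cKDTree

def world_edge_candidates(cloth_pos, ball_pos, cloth_radius=0.02, ball_radius=0.04):
    # cloth-cloth vertex pairs closer than cloth_radius (same 0.02 threshold as GenEdgeFeature)
    cloth_tree = cKDTree(cloth_pos)
    cloth_pairs = np.array(sorted(cloth_tree.query_pairs(r=cloth_radius)))  # (E, 2), each pair once with i < j

    # cloth-ball vertex pairs closer than ball_radius (same 0.04 threshold)
    ball_tree = cKDTree(ball_pos)
    ball_pairs = [(ci, bi)
                  for ci, neighbours in enumerate(ball_tree.query_ball_point(cloth_pos, r=ball_radius))
                  for bi in neighbours]
    return cloth_pairs, np.array(ball_pairs)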