├── README.md
├── __init__.py
├── image
│   ├── CEM_result.png
│   └── multi_step_pred_result.png
├── model
│   ├── README.md
│   ├── __init__.py
│   ├── cem.py
│   ├── cem_e2c.py
│   ├── cem_plot_action.py
│   ├── cem_plot_curve.py
│   ├── create_dataset.py
│   ├── e2c.py
│   ├── hidden_dynamics.py
│   ├── image
│   │   ├── action_model.pdf
│   │   ├── action_model.png
│   │   ├── dynamics_model.pdf
│   │   ├── dynamics_model.png
│   │   ├── state_model.pdf
│   │   └── state_model.png
│   ├── nn_large_linear_seprt_no_loop_Kp_Lpa.py
│   ├── predict_e2c.py
│   ├── predict_e2c_our_method_compare_gpu.py
│   └── predict_large_linear_seprt_no_loop_Kp_Lpa.py
└── utils
    ├── __init__.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Deformable Linear Object Prediction Using Locally Linear Latent Dynamics
2 | 
3 | ## Introduction
4 | This repository contains the code for the paper "Deformable Linear Object Prediction Using Locally Linear Latent Dynamics".
5 | 
6 | Access the [Paper](https://arxiv.org/pdf/2103.14184.pdf) on arXiv
7 | Authors: [Wenbo Zhang](https://www.linkedin.com/in/wenbo-zhang6/), [Karl Schmeckpeper](https://sites.google.com/view/karlschmeckpeper), [Pratik Chaudhari](https://pratikac.github.io/), [Kostas Daniilidis](https://www.cis.upenn.edu/~kostas/)
8 | GRASP Laboratory, University of Pennsylvania
9 | The 2021 International Conference on Robotics and Automation (ICRA 2021), Xi'an, China
10 | 
11 | ## Citation
12 | If you use this code for your research, please cite our paper:
13 | ```
14 | @article{zhang2021deformable,
15 |   title={Deformable Linear Object Prediction Using Locally Linear Latent Dynamics},
16 |   author={Zhang, Wenbo and Schmeckpeper, Karl and Chaudhari, Pratik and Daniilidis, Kostas},
17 |   journal={IEEE International Conference on Robotics and Automation (ICRA)},
18 |   year={2021}
19 | }
20 | ```
21 | 
22 | ## Running the code
23 | ### Preparation
24 | Create a folder for the repo
25 | ```
26 | mkdir deform
27 | cd deform
28 | ```
29 | Create the virtual environment
30 | ```
31 | python3 -m pip install --user virtualenv
32 | python3 -m venv deform_env
33 | ```
34 | 
35 | Activate the virtual environment
36 | ```
37 | source deform_env/bin/activate
38 | ```
39 | 
40 | Install libraries and dependencies
41 | ```
42 | pip install torch torchvision
43 | pip install matplotlib
44 | ```
45 | 
46 | Clone the repository
47 | ```
48 | git clone https://github.com/zwbgood6/deform.git
49 | ```
50 | 
51 | ### Dataset
52 | ```
53 | mkdir rope_dataset
54 | ```
55 | Download the dataset from [Google Drive](https://drive.google.com/file/d/1jy1EUDSeH3d3cZUSK1xChOBvn-qqx9WA/view?usp=sharing) and place it in the folder `rope_dataset`.
56 | 
57 | ```
58 | cd rope_dataset
59 | unzip paper_dataset.zip
60 | ```
61 | 
62 | ### Training
63 | From the directory that contains the cloned `deform` repository, run
64 | ```
65 | python -m deform.model.nn_large_linear_seprt_no_loop_Kp_Lpa
66 | ```
67 | 
68 | ### Prediction
69 | After training, run
70 | ```
71 | python -m deform.model.predict_large_linear_seprt_no_loop_Kp_Lpa
72 | ```
73 | 
74 | ### Issues
75 | Please post a GitHub issue for any code-related questions.
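For custom evaluation or planning scripts, the trained weights can be restored the same way the planning code in `model/cem.py` does. A minimal sketch, assuming `CAE` and `SysDynamics` are importable from `deform.model.dl_model` (their use in the planning scripts suggests they are) and that `<folder_name>` is the result folder written during training:

```python
import torch
from deform.model.dl_model import CAE, SysDynamics

device = torch.device("cpu")
recon_model = CAE().to(device)          # state/action encoder-decoder
dyn_model = SysDynamics().to(device)    # latent dynamics (K, L) predictor

# The training script saves one state dict per sub-model in a single checkpoint.
checkpoint = torch.load('./result/<folder_name>/checkpoint', map_location=device)
recon_model.load_state_dict(checkpoint['recon_model_state_dict'])
dyn_model.load_state_dict(checkpoint['dyn_model_state_dict'])
```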
76 | 
77 | ## Model Architecture and Hyperparameters
78 | 
79 | [Click here](./model/README.md)
80 | 
81 | ## Experimental Results
82 | 
83 | Ten-step Prediction
84 | 
85 | 
86 | 
87 | Sampling-based Model Predictive Control
88 | 
89 | 
90 | 
91 | ## License
92 | MIT
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/__init__.py
--------------------------------------------------------------------------------
/image/CEM_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/image/CEM_result.png
--------------------------------------------------------------------------------
/image/multi_step_pred_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/image/multi_step_pred_result.png
--------------------------------------------------------------------------------
/model/README.md:
--------------------------------------------------------------------------------
1 | ## Model Architecture
2 | 
3 | State Model
4 | 
5 | 
6 | 
7 | Action Model
8 | 
9 | 
10 | 
11 | Dynamics Model
12 | 
13 | 
14 | 
15 | ## Hyperparameters
16 | 
17 | | Hyperparameter | Value |
18 | | :------------- | :----------: |
19 | | epochs (overall) | 1000 |
20 | | epochs (state and action encoder-decoder) | 500 |
21 | | epochs (dynamics model) | 500 |
22 | | learning rate | 1e-3 |
23 | | batch size | 32 |
24 | | latent state size | 80 |
25 | | latent action size | 80 |
26 | | λ1 (action coefficient in the loss function) | 450 |
27 | | λ2 (dynamics coefficient in the loss function) | 900 |
28 | | λ3 (prediction coefficient in the loss function) | 10 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__init__.py
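The λ1, λ2, and λ3 rows above weight the action, dynamics, and prediction terms of the training objective. A minimal sketch of how such a weighted sum is assembled; the component-loss names here are illustrative assumptions (the actual definitions live in `model/nn_large_linear_seprt_no_loop_Kp_Lpa.py`):

```python
# Illustrative sketch only: the four component losses are assumed names, and the
# state term is taken to carry an implicit coefficient of 1 (only the three
# lambdas are listed in the table above).
LAMBDA1, LAMBDA2, LAMBDA3 = 450.0, 900.0, 10.0

def total_loss(loss_state, loss_action, loss_dynamics, loss_prediction):
    return (loss_state
            + LAMBDA1 * loss_action
            + LAMBDA2 * loss_dynamics
            + LAMBDA3 * loss_prediction)
```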
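`model/cem.py` below plans actions with the cross-entropy method (CEM): sample N candidate actions, score each by the loss between the predicted next image and the goal image, refit a Gaussian to the K lowest-loss actions, and repeat until the action distribution converges. A distilled, generic sketch of that loop (the two callables stand in for the model-dependent pieces and are hypothetical):

```python
import numpy as np

def cem_plan(sample_fn, score_fn, n_samples=100, n_elite=50, n_iters=10):
    """Generic CEM loop. sample_fn(mean, cov, n) draws an (n, 4) array of actions
    (uniformly when mean/cov are None); score_fn(actions) returns per-action losses."""
    mean, cov = None, None
    for _ in range(n_iters):
        actions = sample_fn(mean, cov, n_samples)
        elite = actions[np.argsort(score_fn(actions))[:n_elite]]  # keep the K best
        mean = elite.mean(axis=0)
        cov = np.cov(elite, rowvar=False)
    return mean  # the planned action: mean of the final fitted Gaussian
```

Note that cem.py itself iterates until the KL divergence between successive fitted distributions drops below a threshold, rather than for a fixed number of iterations.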
--------------------------------------------------------------------------------
/model/cem.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import det
3 | from numpy import sqrt
4 | from deform.model.dl_model import *
5 | from deform.model.create_dataset import *
6 | from deform.model.hidden_dynamics import get_next_state_linear
7 | import math
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision.utils import save_image
11 | from torch.distributions import Uniform
12 | from torch.distributions.multivariate_normal import MultivariateNormal
13 | from torch.distributions.kl import kl_divergence
14 | import torchvision.transforms.functional as TF
15 | from PIL import Image
16 | from deform.utils.utils import plot_cem_sample
17 | import os
18 | 
19 | def sample_action(I, mean=None, cov=None):
20 |     '''TODO: unit test
21 |     sample a single action (x, y, angle, length);
22 |     re-sample until the pick point (x, y) lies on the rope in image I
23 |     '''
24 |     action = torch.tensor([0]*4, dtype=torch.float)
25 |     multiplier = torch.tensor([50, 50, 2*math.pi, 0.14])   # scales for (x, y, angle, length)
26 |     addition = torch.tensor([0, 0, 0, 0.01])               # offsets; keeps the length strictly positive
27 |     thres = 0.9
28 |     if I[0][0][0][0] == 1.:
29 |         if ((mean is None) and (cov is None)):
30 |             action_base = Uniform(low=0.0, high=1.0).sample((4,))
31 |             action = torch.mul(action_base, multiplier) + addition
32 |         else:
33 |             cov = add_eye(cov)
34 |             action = MultivariateNormal(mean, cov).sample()
35 |         action[0], action[1] = 0, 0
36 |         return action
37 | 
38 |     while I[0][0][torch.floor(action[0]).type(torch.LongTensor)][torch.floor(action[1]).type(torch.LongTensor)] != 1.:
39 |         if ((mean is None) and (cov is None)):
40 |             action_base = Uniform(low=0.0, high=1.0).sample((4,))
41 |             action = torch.mul(action_base, multiplier) + addition
42 |         else:
43 |             cov = add_eye(cov); action = MultivariateNormal(mean, cov).sample()  # a Gaussian sample needs no multiplier/offset
44 |             while torch.floor(action[0]).type(torch.LongTensor) >= 50 or torch.floor(action[1]).type(torch.LongTensor) >= 50:
45 |                 cov = add_eye(cov)
46 |                 action = MultivariateNormal(mean, cov).sample()
47 |     return action
48 | 
49 | def generate_next_pred_state(recon_model, dyn_model, img_pre, act_pre):
50 |     '''generate the next predicted state
51 |     reconstruction model: recon_model
52 |     dynamics model: dyn_model
53 |     previous image: img_pre
54 |     previous action: act_pre
55 |     (returns the decoded image at the next time step)
56 |     '''
57 |     latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float), None)
58 |     K_T_pre, L_T_pre = dyn_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float))
59 |     recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre)
60 |     return recon_model.decoder(recon_latent_img_cur)
61 | 
62 | def generate_next_pred_state_in_n_step(recon_model, dyn_model, img_initial, N, H, mean=None, cov=None):
63 |     imgs = [None]*N
64 |     actions = torch.tensor([[0.]*4]*N)
65 |     for n in range(N):
66 |         img_tmp = img_initial
67 |         for h in range(H):
68 |             action = sample_action(img_tmp, mean, cov)
69 |             if h==0:
70 |                 actions[n] = action
71 |             img_tmp = generate_next_pred_state(recon_model, dyn_model, img_tmp, action)
72 |         imgs[n] = img_tmp
73 |     return imgs, actions
74 | 
75 | def loss_function_img(img_recon, img_goal, N):
76 |     loss = torch.as_tensor([0.]*N)
77 |     for n in range(N):
78 |         loss[n] = F.binary_cross_entropy(img_recon[n].view(-1, 2500), img_goal.view(-1, 2500), reduction='sum')
79 |     return loss
80 | 
81 | def add_eye(cov):
82 |     if det(cov)==0:
83 |         return cov + torch.eye(4) * 0.000001   # regularize a singular covariance
84 |     else:
85 |         return cov
86 | 
87 | def mahalanobis(dist, cov):
88 |     '''dist = mu1 - mu2, where mu1 & mu2 are the means of two multivariate Gaussian distributions
89 |     matrix multiplication: dist^T * cov^(-1) * dist
90 |     '''
91 |     return (dist.transpose(0,1).mm(cov.inverse())).mm(dist)
92 | 
93 | def bhattacharyya(dist, cov1, cov2):
94 |     '''source: https://en.wikipedia.org/wiki/Bhattacharyya_distance
95 |     '''
96 |     cov = (cov1 + cov2) / 2
97 |     if det(cov)==0 or det(cov1)==0 or det(cov2)==0:
98 |         return math.inf
99 |     d1 = mahalanobis(dist.reshape((4,-1)), cov) / 8
100 |     d2 = np.log(det(cov) / sqrt(det(cov1) * det(cov2))) / 2
101 |     return d1 + d2
102 | 
103 | def main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act, step_i):
104 |     for t in range(T):
105 |         print("***** Start Step {}".format(t))
106 |         if t==0:
107 |             img_cur = img_initial
108 |         # Initialize Q with a uniform distribution
109 |         mean = None
110 |         cov = None
111 |         mean_tmp = None
112 | 
cov_tmp = None 113 | converge = False 114 | iter_count = 0 115 | while not converge: 116 | #Use model M to predict the next state using M action sequences 117 | imgs_recon, sample_actions = generate_next_pred_state_in_n_step(recon_model, dyn_model, img_cur, N, H, mean, cov) 118 | #Calculate binary cross entropy loss for predicted image and goal image 119 | loss = loss_function_img(imgs_recon, img_goal, N) 120 | #Select K action sequences with lowest loss 121 | loss_index = torch.argsort(loss) 122 | sorted_sample_actions = sample_actions[loss_index] 123 | #Fit multivariate gaussian distribution to K samples 124 | #(see how to fit algorithm: 125 | #https://stackoverflow.com/questions/27230824/fit-multivariate-gaussian-distribution-to-a-given-dataset) 126 | mean = torch.mean(sorted_sample_actions[:K], dim=0).type(torch.DoubleTensor) 127 | cov = torch.from_numpy(np.cov(sorted_sample_actions[:K], rowvar=0)).type(torch.DoubleTensor) 128 | # iteration is based on convergence of Q 129 | if det(cov) == 0 or cov_tmp == None: 130 | mean_tmp = mean 131 | cov_tmp = cov 132 | continue 133 | else: 134 | if det(cov_tmp)==0: 135 | mean_tmp = mean 136 | cov_tmp = cov 137 | continue 138 | else: 139 | p = MultivariateNormal(mean, cov) 140 | q = MultivariateNormal(mean_tmp, cov_tmp) 141 | if kl_divergence(p, q) < 1000: # 0.3 is okay 142 | converge = True 143 | 144 | mean_tmp = mean 145 | cov_tmp = cov 146 | 147 | print("***** At action time step {}, iteration {} *****".format(t, iter_count)) 148 | iter_count += 1 149 | 150 | 151 | #Execute action a{t}* with lowest loss 152 | action_best = sorted_sample_actions[0] 153 | torch.save(action_best, "./plan_result/{}/action_best_step{}_N{}_K{}.pt".format(plan_folder_name, step_i, N, K)) 154 | #Observe new image I{t+1} 155 | img_cur = generate_next_pred_state(recon_model, dyn_model, img_cur, action_best) 156 | comparison = torch.cat([img_initial, img_goal, img_cur]) 157 | save_image(comparison, "./plan_result/{}/image_best_step{}_N{}_K{}.png".format(plan_folder_name, step_i, N, K)) 158 | plot_cem_sample(img_initial.detach().reshape((50,50)).cpu().numpy(), 159 | img_goal.detach().reshape((50,50)).cpu().numpy(), 160 | img_cur.detach().reshape((50,50)).cpu().numpy(), 161 | resz_act[:4], 162 | action_best.detach().cpu().numpy(), 163 | './plan_result/{}/compare_align_{}.png'.format(plan_folder_name, i)) 164 | print("***** Generate Next Predicted Image {}*****".format(t+1)) 165 | 166 | print("***** End Planning *****") 167 | 168 | # plan result folder name 169 | plan_folder_name = 'test_KL1000' 170 | if not os.path.exists('./plan_result/{}'.format(plan_folder_name)): 171 | os.makedirs('./plan_result/{}'.format(plan_folder_name)) 172 | # time step to execute the action 173 | T = 1 174 | # total number of samples for action sequences 175 | N = 100 176 | # K samples to fit multivariate gaussian distribution (N>K, K>1) 177 | K = 50 178 | # length of action sequence 179 | H = 1 180 | # model 181 | torch.manual_seed(1) 182 | device = torch.device("cpu") 183 | print("Device is:", device) 184 | recon_model = CAE().to(device) 185 | dyn_model = SysDynamics().to(device) 186 | 187 | # action 188 | # load GT action 189 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy' 190 | resz_act = np.load(resz_act_path) 191 | 192 | # checkpoint 193 | print('***** Load Checkpoint *****') 194 | folder_name = "test_act80_pred30" 195 | PATH = './result/{}/checkpoint'.format(folder_name) 196 | checkpoint = torch.load(PATH, map_location=device) 
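# map_location=device lets a checkpoint trained on a GPU be restored on the CPU
# device chosen above; the checkpoint file is expected to hold one state dict
# per sub-model, stored under the keys used on the next two lines.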
197 | recon_model.load_state_dict(checkpoint['recon_model_state_dict']) 198 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict']) 199 | 200 | total_img_num = 22515 201 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num) 202 | 203 | 204 | def get_image(i): 205 | img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3 206 | return img.reshape((-1, 1, 50, 50)).type(torch.float) 207 | 208 | for i in range(0,1): 209 | img_initial = get_image(i) 210 | img_goal = get_image(i+1) 211 | main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act[i], i) -------------------------------------------------------------------------------- /model/cem_e2c.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import det 3 | from numpy import sqrt 4 | from deform.model.dl_model import * 5 | from deform.model.create_dataset import * 6 | from deform.model.hidden_dynamics import get_next_state_linear 7 | import math 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision.utils import save_image 11 | from torch.distributions import Uniform 12 | from torch.distributions.multivariate_normal import MultivariateNormal 13 | from torch.distributions.kl import kl_divergence 14 | import torchvision.transforms.functional as TF 15 | from PIL import Image 16 | from deform.utils.utils import plot_cem_sample 17 | import os 18 | 19 | def sample_action(I, mean=None, cov=None): 20 | '''TODO: unit test 21 | each action sequence length: H 22 | number of action sequences: N 23 | ''' 24 | action = torch.tensor([0]*4, dtype=torch.float) 25 | multiplier = torch.tensor([50, 50, 2*math.pi, 0.14]) 26 | addition = torch.tensor([0, 0, 0, 0.01]) 27 | thres = 0.9 28 | 29 | if I[0][0][0][0] == 1.: 30 | if ((mean is None) and (cov is None)): 31 | action_base = Uniform(low=0.0, high=1.0).sample((4,)) 32 | action = torch.mul(action_base, multiplier) + addition 33 | else: 34 | cov = add_eye(cov) 35 | action = MultivariateNormal(mean, cov).sample() 36 | action[0], action[1] = 0, 0 37 | return action 38 | 39 | while I[0][0][torch.floor(action[0]).type(torch.LongTensor)][torch.floor(action[1]).type(torch.LongTensor)] != 1.: 40 | if ((mean is None) and (cov is None)): 41 | action_base = Uniform(low=0.0, high=1.0).sample((4,)) 42 | action = torch.mul(action_base, multiplier) + addition 43 | else: 44 | cov = add_eye(cov) 45 | action = MultivariateNormal(mean, cov).sample() 46 | while torch.floor(action[0]).type(torch.LongTensor) >= 50 or torch.floor(action[1]).type(torch.LongTensor) >= 50: 47 | cov = add_eye(cov) 48 | action = MultivariateNormal(mean, cov).sample() 49 | 50 | return action 51 | 52 | def generate_next_pred_state(e2c_model, img_pre, act_pre): 53 | '''generate next predicted state 54 | reconstruction model: recon_model 55 | dynamics model: dyn_model 56 | initial image: img_pre 57 | each action sequence length: H 58 | number of action sequences: N 59 | ''' 60 | recon_img_cur = e2c_model.predict(img_pre, act_pre) 61 | return recon_img_cur 62 | 63 | 64 | def generate_next_pred_state_in_n_step(e2c_model, img_initial, N, H, mean=None, cov=None): 65 | imgs = [None]*N 66 | actions = torch.tensor([[0.]*4]*N) 67 | for n in range(N): 68 | img_tmp = img_initial 69 | for h in range(H): 70 | action = sample_action(img_tmp, mean, cov) 71 | if h==0: 72 | actions[n] = action 73 | img_tmp = generate_next_pred_state(e2c_model, img_tmp, action) 74 | imgs[n] = img_tmp 75 | return imgs, 
actions 76 | 77 | def loss_function_img(img_recon, img_goal, N): 78 | loss = torch.as_tensor([0.]*N) 79 | for n in range(N): 80 | loss[n] = F.binary_cross_entropy(img_recon[n].view(-1, 2500), img_goal.view(-1, 2500), reduction='sum') 81 | return loss 82 | 83 | def add_eye(cov): 84 | if det(cov)==0: 85 | return cov + torch.eye(4) * 0.000001 86 | else: 87 | return cov 88 | 89 | def mahalanobis(dist, cov): 90 | '''dist = mu1 - mu2, mu1 & mu2 are means of two multivariate gaussian distribution 91 | matrix multiplication: dist^T * cov^(-1) * dist 92 | ''' 93 | return (dist.transpose(0,1).mm(cov.inverse())).mm(dist) 94 | 95 | def bhattacharyya(dist, cov1, cov2): 96 | '''source: https://en.wikipedia.org/wiki/Bhattacharyya_distance 97 | ''' 98 | cov = (cov1 + cov2) / 2 99 | d1 = mahalanobis(dist.reshape((4,-1)), cov) / 8 100 | if det(cov)==0 or det(cov1)==0 or det(cov2)==0: 101 | return inf 102 | d2 = np.log(det(cov) / sqrt(det(cov1) * det(cov2))) / 2 103 | return d1 + d2 104 | 105 | def main(e2c_model, T, K, N, H, img_initial, img_goal, resz_act, step_i, KL): 106 | for t in range(T): 107 | print("***** Start Step {}".format(t)) 108 | if t==0: 109 | img_cur = img_initial 110 | #Initialize Q with uniform distribution 111 | mean = None 112 | cov = None 113 | mean_tmp = None 114 | cov_tmp = None 115 | converge = False 116 | iter_count = 0 117 | while not converge: 118 | imgs_recon, sample_actions = generate_next_pred_state_in_n_step(e2c_model, img_cur, N, H, mean, cov) 119 | #Calculate binary cross entropy loss for predicted image and goal image 120 | loss = loss_function_img(imgs_recon, img_goal, N) 121 | #Select K action sequences with lowest loss 122 | loss_index = torch.argsort(loss) 123 | sorted_sample_actions = sample_actions[loss_index] 124 | #Fit multivariate gaussian distribution to K samples 125 | #(see how to fit algorithm: 126 | #https://stackoverflow.com/questions/27230824/fit-multivariate-gaussian-distribution-to-a-given-dataset) 127 | mean = torch.mean(sorted_sample_actions[:K], dim=0).type(torch.DoubleTensor) 128 | cov = torch.from_numpy(np.cov(sorted_sample_actions[:K], rowvar=0)).type(torch.DoubleTensor) 129 | # iteration is based on convergence of Q 130 | if det(cov) == 0 or cov_tmp == None: 131 | mean_tmp = mean 132 | cov_tmp = cov 133 | continue 134 | else: 135 | if det(cov_tmp)==0: 136 | mean_tmp = mean 137 | cov_tmp = cov 138 | continue 139 | else: 140 | p = MultivariateNormal(mean, cov) 141 | q = MultivariateNormal(mean_tmp, cov_tmp) 142 | if kl_divergence(p, q) < KL: # 0.3 is okay 143 | converge = True 144 | mean_tmp = mean 145 | cov_tmp = cov 146 | 147 | print("***** At action time step {}, iteration {} *****".format(t, iter_count)) 148 | iter_count += 1 149 | 150 | #Execute action a{t}* with lowest loss 151 | action_best = sorted_sample_actions[0] 152 | action_loss = ((action_best.detach().cpu().numpy()-resz_act[:4])**2).mean(axis=None) 153 | #Observe new image I{t+1} 154 | img_cur = generate_next_pred_state(e2c_model, img_cur, action_best) 155 | img_loss = F.binary_cross_entropy(img_cur.view(-1, 2500), img_goal.view(-1, 2500), reduction='mean') 156 | 157 | print("***** Generate Next Predicted Image {}*****".format(t+1)) 158 | print("***** End Planning *****") 159 | return action_loss, img_loss.detach().cpu().numpy() 160 | 161 | # plan result folder name 162 | plan_folder_name = 'test_e2c' 163 | if not os.path.exists('./plan_result/{}'.format(plan_folder_name)): 164 | os.makedirs('./plan_result/{}'.format(plan_folder_name)) 165 | # time step to execute the action 166 | T = 
1
167 | # total number of samples for action sequences
168 | N = 100
169 | # K samples to fit the multivariate Gaussian distribution (N>K, K>1)
170 | K = 50
171 | # length of action sequence
172 | H = 1
173 | # model
174 | torch.manual_seed(1)
175 | device = torch.device("cpu")
176 | print("Device is:", device)
177 | e2c_model = E2C().to(device)
178 | 
179 | # action
180 | # load GT action
181 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
182 | resz_act = np.load(resz_act_path)
183 | 
184 | # checkpoint
185 | print('***** Load Checkpoint *****')
186 | folder_name = "test_E2C_gpu_update_loss"
187 | PATH = './result/{}/checkpoint'.format(folder_name)
188 | checkpoint_e2c = torch.load(PATH, map_location=device)
189 | e2c_model.load_state_dict(checkpoint_e2c['e2c_model_state_dict'])
190 | 
191 | total_img_num = 22515
192 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
193 | 
194 | 
195 | def get_image(i):
196 |     img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3   # binarize the grayscale image
197 |     return img.reshape((-1, 1, 50, 50)).type(torch.float)
198 | 
199 | for KL in [1000]:
200 |     action_loss_all = []
201 |     img_loss_all = []
202 |     for i in range(20000, 20010):
203 |         img_initial = get_image(i)
204 |         img_goal = get_image(i+1)
205 |         action_loss, img_loss = main(e2c_model, T, K, N, H, img_initial, img_goal, resz_act[i], i, KL)
206 |         action_loss_all.append(action_loss)
207 |         img_loss_all.append(img_loss)
208 |     np.save('./plan_result/{}/KL_action_{}.npy'.format(plan_folder_name, KL), action_loss_all)
209 |     np.save('./plan_result/{}/KL_image_{}.npy'.format(plan_folder_name, KL), img_loss_all)
--------------------------------------------------------------------------------
/model/cem_plot_action.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import math
5 | import torchvision.transforms.functional as TF
6 | from PIL import Image
7 | from deform.model.create_dataset import create_image_path
8 | 
9 | def rect(poke, c):
10 |     x, y, t, l = poke
11 |     dx = -2000 * l * math.cos(t)
12 |     dy = -2000 * l * math.sin(t)
13 |     plt.arrow(x, y, dx, dy, width=0.01, head_width=1, head_length=3, color=c)
14 | 
15 | def plot_sample(img_before, img_after, img_after_pred, resz_action, recon_action, directory):
16 |     plt.figure()
17 |     # top left: image before the action, with the ground-truth action overlaid
18 |     plt.subplot(2, 2, 1)
19 |     rect(resz_action, "blue")
20 |     plt.imshow(img_before.reshape((50,50)))
21 |     plt.axis('off')
22 |     # top right: image after the action
23 |     plt.subplot(2, 2, 2)
24 |     plt.imshow(img_after.reshape((50,50)))
25 |     plt.axis('off')
26 |     # bottom left: image before the action, with the planned action overlaid
27 |     plt.subplot(2, 2, 3)
28 |     rect(recon_action, "blue")
29 |     plt.imshow(img_before.reshape((50,50)))
30 |     plt.axis('off')
31 |     # bottom right: predicted image after the action
32 |     plt.subplot(2, 2, 4)
33 |     plt.imshow(img_after_pred.reshape((50,50)))
34 |     plt.axis('off')
35 |     plt.savefig(directory)
36 |     plt.close()
37 | 
38 | def plot_action(resz_action, recon_action, directory):
39 |     plt.figure()
40 |     # left: ground-truth action
41 |     plt.subplot(1, 2, 1)
42 |     rect(resz_action, "blue")
43 |     plt.axis('off')
44 |     # right: planned action
45 |     plt.subplot(1, 2, 2)
46 |     rect(recon_action, "red")
47 |     plt.axis('off')
48 |     plt.savefig(directory)
49 |     plt.close()
50 | 
51 | 
52 | def get_image(i):
53 |     img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3
54 |     return img.reshape((-1, 1, 50, 50)).type(torch.float)
55 | 
56 | # total number of samples for action sequences
57 | N = 500
58 | # K samples used to fit the multivariate Gaussian distribution (N>K, K>1)
59 | K = 50
60 | 
61 | # load GT action
62 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
63 | resz_act = np.load(resz_act_path)
64 | 
65 | # load image
66 | total_img_num = 22515
67 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
68 | 
69 | for i in range(45):
70 |     img_before = get_image(i)
71 |     img_after = get_image(i+1)
72 |     action_best = torch.load("./plan_result/03/action_best_step{}_N{}_K{}.pt".format(i, N, K))
73 |     # NOTE: the predicted next image is not regenerated in this script; producing
74 |     # img_after_pred requires a model rollout (see generate_next_pred_state in
75 |     # cem.py). The ground-truth next image is used here as a stand-in.
76 |     img_after_pred = img_after
77 |     plot_sample(img_before,
78 |                 img_after,
79 |                 img_after_pred,
80 |                 resz_act[i][:4],
81 |                 action_best,
82 |                 './plan_result/03/compare_align_{}.png'.format(i))
--------------------------------------------------------------------------------
/model/cem_plot_curve.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import det
3 | from numpy import sqrt
4 | from deform.model.dl_model import *
5 | from deform.model.create_dataset import *
6 | from deform.model.hidden_dynamics import get_next_state_linear
7 | import math
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision.utils import save_image
11 | from torch.distributions import Uniform
12 | from torch.distributions.multivariate_normal import MultivariateNormal
13 | from torch.distributions.kl import kl_divergence
14 | import torchvision.transforms.functional as TF
15 | from PIL import Image
16 | from deform.utils.utils import plot_cem_sample
17 | import os
18 | 
19 | def sample_action(I, mean=None, cov=None):
20 |     '''TODO: unit test
21 |     sample a single action (x, y, angle, length);
22 |     re-sample until the pick point (x, y) lies on the rope in image I
23 |     '''
24 |     action = torch.tensor([0]*4, dtype=torch.float)
25 |     multiplier = torch.tensor([50, 50, 2*math.pi, 0.14])
26 |     addition = torch.tensor([0, 0, 0, 0.01])
27 |     thres = 0.9
28 |     if I[0][0][0][0] == 1.:
29 |         if ((mean is None) and (cov is None)):
30 |             action_base = Uniform(low=0.0, high=1.0).sample((4,))
31 |             action = torch.mul(action_base, multiplier) + addition
32 |         else:
33 |             cov = add_eye(cov)
34 |             action = MultivariateNormal(mean, cov).sample()
35 |         action[0], action[1] = 0, 0
36 |         return action
37 | 
38 |     while I[0][0][torch.floor(action[0]).type(torch.LongTensor)][torch.floor(action[1]).type(torch.LongTensor)] != 1.:
39 |         if ((mean is None) and (cov is None)):
40 |             action_base = Uniform(low=0.0, high=1.0).sample((4,))
41 |             action = torch.mul(action_base, multiplier) + addition
42 |         else:
43 |             cov = add_eye(cov)
44 |             action = MultivariateNormal(mean, cov).sample()
45 |             while torch.floor(action[0]).type(torch.LongTensor) >= 50 or torch.floor(action[1]).type(torch.LongTensor) >= 50:
46 |                 cov = add_eye(cov)
47 |                 action = MultivariateNormal(mean, cov).sample()
48 | 
49 |     return action
50 | 
51 | def generate_next_pred_state(recon_model, dyn_model, img_pre, act_pre):
52 |     '''generate the next predicted state
53 |     reconstruction model: recon_model
54 |     dynamics model: dyn_model
55 |     previous image: img_pre
56 |     previous action: act_pre
57 |     (returns the decoded image at the next time step)
58 |     '''
59 |     latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float), None)
60 |     K_T_pre, L_T_pre = dyn_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float))
61 |     recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre)
62 |     return recon_model.decoder(recon_latent_img_cur)
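# (get_next_state_linear, imported from deform.model.hidden_dynamics, advances the
# latent state with the locally linear latent dynamics described in the paper,
# using the per-sample matrices K_T and L_T predicted by dyn_model.)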
63 | 64 | def generate_next_pred_state_in_n_step(recon_model, dyn_model, img_initial, N, H, mean=None, cov=None): 65 | imgs = [None]*N 66 | actions = torch.tensor([[0.]*4]*N) 67 | for n in range(N): 68 | img_tmp = img_initial 69 | for h in range(H): 70 | action = sample_action(img_tmp, mean, cov) 71 | if h==0: 72 | actions[n] = action 73 | img_tmp = generate_next_pred_state(recon_model, dyn_model, img_tmp, action) 74 | imgs[n] = img_tmp 75 | return imgs, actions 76 | 77 | def loss_function_img(img_recon, img_goal, N): 78 | loss = torch.as_tensor([0.]*N) 79 | for n in range(N): 80 | loss[n] = F.binary_cross_entropy(img_recon[n].view(-1, 2500), img_goal.view(-1, 2500), reduction='sum') 81 | return loss 82 | 83 | def add_eye(cov): 84 | if det(cov)==0: 85 | return cov + torch.eye(4) * 0.000001 86 | else: 87 | return cov 88 | 89 | def mahalanobis(dist, cov): 90 | '''dist = mu1 - mu2, mu1 & mu2 are means of two multivariate gaussian distribution 91 | matrix multiplication: dist^T * cov^(-1) * dist 92 | ''' 93 | return (dist.transpose(0,1).mm(cov.inverse())).mm(dist) 94 | 95 | def bhattacharyya(dist, cov1, cov2): 96 | '''source: https://en.wikipedia.org/wiki/Bhattacharyya_distance 97 | ''' 98 | cov = (cov1 + cov2) / 2 99 | d1 = mahalanobis(dist.reshape((4,-1)), cov) / 8 100 | if det(cov)==0 or det(cov1)==0 or det(cov2)==0: 101 | return inf 102 | d2 = np.log(det(cov) / sqrt(det(cov1) * det(cov2))) / 2 103 | return d1 + d2 104 | 105 | def main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act, step_i, KL): 106 | for t in range(T): 107 | print("***** Start Step {}".format(t)) 108 | if t==0: 109 | img_cur = img_initial 110 | #Initialize Q with uniform distribution 111 | mean = None 112 | cov = None 113 | mean_tmp = None 114 | cov_tmp = None 115 | converge = False 116 | iter_count = 0 117 | while not converge: 118 | imgs_recon, sample_actions = generate_next_pred_state_in_n_step(recon_model, dyn_model, img_cur, N, H, mean, cov) 119 | #Calculate binary cross entropy loss for predicted image and goal image 120 | loss = loss_function_img(imgs_recon, img_goal, N) 121 | #Select K action sequences with lowest loss 122 | loss_index = torch.argsort(loss) 123 | sorted_sample_actions = sample_actions[loss_index] 124 | #Fit multivariate gaussian distribution to K samples 125 | #(see how to fit algorithm: 126 | #https://stackoverflow.com/questions/27230824/fit-multivariate-gaussian-distribution-to-a-given-dataset) 127 | mean = torch.mean(sorted_sample_actions[:K], dim=0).type(torch.DoubleTensor) 128 | cov = torch.from_numpy(np.cov(sorted_sample_actions[:K], rowvar=0)).type(torch.DoubleTensor) 129 | # iteration is based on convergence of Q 130 | if det(cov) == 0 or cov_tmp == None: 131 | mean_tmp = mean 132 | cov_tmp = cov 133 | continue 134 | else: 135 | if det(cov_tmp)==0: 136 | mean_tmp = mean 137 | cov_tmp = cov 138 | continue 139 | else: 140 | p = MultivariateNormal(mean, cov) 141 | q = MultivariateNormal(mean_tmp, cov_tmp) 142 | if kl_divergence(p, q) < KL: 143 | converge = True 144 | mean_tmp = mean 145 | cov_tmp = cov 146 | 147 | print("***** At action time step {}, iteration {} *****".format(t, iter_count)) 148 | iter_count += 1 149 | 150 | #Execute action a{t}* with lowest loss 151 | action_best = sorted_sample_actions[0] 152 | action_loss = ((action_best.detach().cpu().numpy()-resz_act[:4])**2).mean(axis=None) 153 | #Observe new image I{t+1} 154 | img_cur = generate_next_pred_state(recon_model, dyn_model, img_cur, action_best) 155 | img_loss = F.binary_cross_entropy(img_cur.view(-1, 
2500), img_goal.view(-1, 2500), reduction='mean') 156 | print("***** Generate Next Predicted Image {}*****".format(t+1)) 157 | 158 | print("***** End Planning *****") 159 | return action_loss, img_loss.detach().cpu().numpy() 160 | 161 | # plan result folder name 162 | plan_folder_name = 'curve_KL' 163 | if not os.path.exists('./plan_result/{}'.format(plan_folder_name)): 164 | os.makedirs('./plan_result/{}'.format(plan_folder_name)) 165 | # time step to execute the action 166 | T = 1 167 | # total number of samples for action sequences 168 | N = 100 169 | # K samples to fit multivariate gaussian distribution (N>K, K>1) 170 | K = 50 171 | # length of action sequence 172 | H = 1 173 | # model 174 | torch.manual_seed(1) 175 | device = torch.device("cpu") 176 | print("Device is:", device) 177 | recon_model = CAE().to(device) 178 | dyn_model = SysDynamics().to(device) 179 | 180 | # action 181 | # load GT action 182 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy' 183 | resz_act = np.load(resz_act_path) 184 | 185 | # checkpoint 186 | print('***** Load Checkpoint *****') 187 | folder_name = "test_act80_pred30" 188 | PATH = './result/{}/checkpoint'.format(folder_name) 189 | checkpoint = torch.load(PATH, map_location=device) 190 | recon_model.load_state_dict(checkpoint['recon_model_state_dict']) 191 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict']) 192 | 193 | total_img_num = 22515 194 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num) 195 | 196 | 197 | def get_image(i): 198 | img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3 199 | return img.reshape((-1, 1, 50, 50)).type(torch.float) 200 | 201 | 202 | for KL in [1000]: 203 | action_loss_all = [] 204 | img_loss_all = [] 205 | for i in range(20000, 20010): 206 | img_initial = get_image(i) 207 | img_goal = get_image(i+1) 208 | action_loss, img_loss = main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act[i], i, KL) 209 | action_loss_all.append(action_loss) 210 | img_loss_all.append(img_loss) 211 | np.save('./plan_result/{}/KL_action_{}.npy'.format(plan_folder_name, KL), action_loss_all) 212 | np.save('./plan_result/{}/KL_image_{}.npy'.format(plan_folder_name, KL), img_loss_all) -------------------------------------------------------------------------------- /model/create_dataset.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms.functional as TF 2 | from torchvision import transforms 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import random 6 | import math 7 | from PIL import Image 8 | import matplotlib.pyplot as plt 9 | import torch 10 | from torch.utils.data.dataloader import default_collate 11 | 12 | class MyDataset(Dataset): 13 | def __init__(self, image_paths_bi, resize_actions, transform=None): 14 | self.image_paths_bi = image_paths_bi 15 | self.resz_actions = resize_actions 16 | self.transform = transform 17 | 18 | def __getitem__(self, index): 19 | n = self.__len__() 20 | none_sample = {'image_bi_pre': None, 'image_bi_cur': None, 'image_bi_post': None, 'resz_action_pre': None, 'resz_action_cur': None} 21 | if index == 0: 22 | if n == 2: 23 | index = 1 24 | else: 25 | index = np.random.randint(1, n-1) 26 | if index == n-1: 27 | if n == 2: 28 | index = 1 29 | else: 30 | index = np.random.randint(1, n-1) 31 | 32 | # load action 33 | resz_action_pre = self.resz_actions[index-1] 34 | resz_action_cur = 
self.resz_actions[index] 35 | # decide if action is valid 36 | if int(resz_action_pre[4]) == 0 or int(resz_action_cur[4]) == 0: 37 | return none_sample 38 | 39 | # load images (pre-transform images) 40 | image_bi_pre = Image.open(self.image_paths_bi[index-1]) 41 | image_bi_cur = Image.open(self.image_paths_bi[index]) 42 | image_bi_post = Image.open(self.image_paths_bi[index+1]) 43 | 44 | sample = {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': resz_action_pre[:4], 'resz_action_cur': resz_action_cur[:4]} 45 | 46 | if self.transform: 47 | sample = self.transform(sample) 48 | 49 | return sample 50 | 51 | def __len__(self): 52 | return len(self.image_paths_bi) 53 | 54 | class MyDatasetMultiPred4(Dataset): 55 | def __init__(self, image_paths_bi, resize_actions, transform=None): 56 | self.image_paths_bi = image_paths_bi 57 | self.resz_actions = resize_actions 58 | self.transform = transform 59 | 60 | def __getitem__(self, index): 61 | n = self.__len__() 62 | none_sample = {'image_bi_pre': None, 'image_bi_cur': None, 'image_bi_post': None, 'image_bi_post2': None, 'image_bi_post3': None,\ 63 | 'resz_action_pre': None, 'resz_action_cur': None, 'resz_action_post': None, 'resz_action_post2': None} 64 | # edge cases, use index in [1,n-1) 65 | if index == 0: 66 | index = np.random.randint(1, n-3) 67 | if index == n-1: 68 | index = np.random.randint(1, n-3) 69 | if index == n-2: 70 | index = np.random.randint(1, n-3) 71 | if index == n-3: 72 | index = np.random.randint(1, n-3) 73 | 74 | # load action 75 | resz_action_pre = self.resz_actions[index-1] 76 | resz_action_cur = self.resz_actions[index] 77 | resz_action_post = self.resz_actions[index+1] 78 | resz_action_post2 = self.resz_actions[index+2] 79 | # decide if action is valid 80 | if int(resz_action_pre[4]) == 0 or int(resz_action_cur[4]) == 0 or int(resz_action_post[4]) == 0 or int(resz_action_post2[4]) == 0: 81 | return none_sample 82 | 83 | # load images (pre-transform images) 84 | image_bi_pre = Image.open(self.image_paths_bi[index-1]) 85 | image_bi_cur = Image.open(self.image_paths_bi[index]) 86 | image_bi_post = Image.open(self.image_paths_bi[index+1]) 87 | image_bi_post2 = Image.open(self.image_paths_bi[index+2]) 88 | image_bi_post3 = Image.open(self.image_paths_bi[index+3]) 89 | 90 | sample = {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'image_bi_post2': image_bi_post2, 'image_bi_post3': image_bi_post3,\ 91 | 'resz_action_pre': resz_action_pre[:4], 'resz_action_cur': resz_action_cur[:4], 'resz_action_post': resz_action_post[:4], 'resz_action_post2': resz_action_post2[:4]} 92 | 93 | if self.transform: 94 | sample = self.transform(sample) 95 | 96 | return sample 97 | 98 | def __len__(self): 99 | return len(self.image_paths_bi) 100 | 101 | class MyDatasetMultiPred10(Dataset): 102 | def __init__(self, image_paths_bi, resize_actions, transform=None): 103 | self.image_paths_bi = image_paths_bi 104 | self.resz_actions = resize_actions 105 | self.transform = transform 106 | 107 | def __getitem__(self, index): 108 | n = self.__len__() 109 | none_sample = {'image_bi_pre': None, 'image_bi_cur': None, 'image_bi_post': None, 'image_bi_post2': None, 'image_bi_post3': None,\ 110 | 'image_bi_post4': None, 'image_bi_post5': None, 'image_bi_post6': None, 'image_bi_post7': None, 'image_bi_post8': None, 'image_bi_post9': None,\ 111 | 'resz_action_pre': None, 'resz_action_cur': None, 'resz_action_post': None, 'resz_action_post2': None, 
'resz_action_post3': None, 'resz_action_post4': None,\ 112 | 'resz_action_post5': None, 'resz_action_post6': None, 'resz_action_post7': None, 'resz_action_post8': None} 113 | # edge cases, use index in [1,n-1) 114 | if index == 0: 115 | index = np.random.randint(1, n-9) 116 | if index == n-1: 117 | index = np.random.randint(1, n-9) 118 | if index == n-2: 119 | index = np.random.randint(1, n-9) 120 | if index == n-3: 121 | index = np.random.randint(1, n-9) 122 | if index == n-4: 123 | index = np.random.randint(1, n-9) 124 | if index == n-5: 125 | index = np.random.randint(1, n-9) 126 | if index == n-6: 127 | index = np.random.randint(1, n-9) 128 | if index == n-7: 129 | index = np.random.randint(1, n-9) 130 | if index == n-8: 131 | index = np.random.randint(1, n-9) 132 | if index == n-9: 133 | index = np.random.randint(1, n-9) 134 | # load action 135 | resz_action_pre = self.resz_actions[index-1] 136 | resz_action_cur = self.resz_actions[index] 137 | resz_action_post = self.resz_actions[index+1] 138 | resz_action_post2 = self.resz_actions[index+2] 139 | resz_action_post3 = self.resz_actions[index+3] 140 | resz_action_post4 = self.resz_actions[index+4] 141 | resz_action_post5 = self.resz_actions[index+5] 142 | resz_action_post6 = self.resz_actions[index+6] 143 | resz_action_post7 = self.resz_actions[index+7] 144 | resz_action_post8 = self.resz_actions[index+8] 145 | resz_action_post9 = self.resz_actions[index+9] 146 | # decide if action is valid 147 | if int(resz_action_pre[4]) == 0 or int(resz_action_cur[4]) == 0 \ 148 | or int(resz_action_post[4]) == 0 or int(resz_action_post2[4]) == 0 \ 149 | or int(resz_action_post3[4]) == 0 or int(resz_action_post4[4]) == 0 \ 150 | or int(resz_action_post5[4]) == 0 or int(resz_action_post6[4]) == 0 \ 151 | or int(resz_action_post7[4]) == 0 or int(resz_action_post8[4]) == 0 \ 152 | or int(resz_action_post9[4]) == 0: 153 | return none_sample 154 | 155 | # load images (pre-transform images) 156 | image_bi_pre = Image.open(self.image_paths_bi[index-1]) 157 | image_bi_cur = Image.open(self.image_paths_bi[index]) 158 | image_bi_post = Image.open(self.image_paths_bi[index+1]) 159 | image_bi_post2 = Image.open(self.image_paths_bi[index+2]) 160 | image_bi_post3 = Image.open(self.image_paths_bi[index+3]) 161 | image_bi_post4 = Image.open(self.image_paths_bi[index+4]) 162 | image_bi_post5 = Image.open(self.image_paths_bi[index+5]) 163 | image_bi_post6 = Image.open(self.image_paths_bi[index+6]) 164 | image_bi_post7 = Image.open(self.image_paths_bi[index+7]) 165 | image_bi_post8 = Image.open(self.image_paths_bi[index+8]) 166 | image_bi_post9 = Image.open(self.image_paths_bi[index+9]) 167 | 168 | 169 | sample = {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'image_bi_post2': image_bi_post2, 'image_bi_post3': image_bi_post3,\ 170 | 'image_bi_post4': image_bi_post4, 'image_bi_post5': image_bi_post5, 'image_bi_post6': image_bi_post6, 'image_bi_post7': image_bi_post7, 'image_bi_post8': image_bi_post8, \ 171 | 'image_bi_post9': image_bi_post9, \ 172 | 'resz_action_pre': resz_action_pre[:4], 'resz_action_cur': resz_action_cur[:4], 'resz_action_post': resz_action_post[:4], 'resz_action_post2': resz_action_post2[:4], \ 173 | 'resz_action_post3': resz_action_post3[:4], 'resz_action_post4': resz_action_post4[:4], 'resz_action_post5': resz_action_post5[:4], 'resz_action_post6': resz_action_post6[:4], \ 174 | 'resz_action_post7': resz_action_post7[:4], 'resz_action_post8': resz_action_post8[:4]} 175 | 176 | if self.transform: 177 | 
sample = self.transform(sample)
178 | 
179 |         return sample
180 | 
181 |     def __len__(self):
182 |         return len(self.image_paths_bi)
183 | 
184 | class Translation(object):
185 |     '''Translate the image and action by a random offset in [-max_translation, max_translation], e.g. [-10, 10]
186 |     '''
187 |     def __init__(self, max_translation=10):
188 |         self.m_trans = max_translation
189 | 
190 | 
191 |     def __call__(self, sample):
192 |         image_bi_pre, image_bi_cur, image_bi_post, action_pre, action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
193 |         trans = list(2 * self.m_trans * np.random.random_sample((2,)) - self.m_trans)   # uniform in [-m_trans, m_trans]
194 |         trans_action_pre, trans_action_cur = action_pre.copy(), action_cur.copy()
195 |         trans_action_pre[:2], trans_action_cur[:2] = trans_action_pre[:2] + np.array(trans), trans_action_cur[:2] + np.array(trans)
196 |         image_bi_pre = TF.affine(image_bi_pre, angle=0, translate=trans, scale=1.0, shear=0.0)
197 |         image_bi_cur = TF.affine(image_bi_cur, angle=0, translate=trans, scale=1.0, shear=0.0)
198 |         image_bi_post = TF.affine(image_bi_post, angle=0, translate=trans, scale=1.0, shear=0.0)
199 | 
200 |         return {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': trans_action_pre, 'resz_action_cur': trans_action_cur}
201 | 
202 | 
203 | class HFlip(object):
204 |     '''Randomly flip the image and action horizontally with probability p
205 |     '''
206 |     def __init__(self, probability=0.5):
207 |         self.p = probability
208 | 
209 |     def flip_angle(self, action):
210 |         if action[2] > math.pi:
211 |             action[2] = 3 * math.pi - action[2]
212 |         else:
213 |             action[2] = math.pi - action[2]
214 |         return action
215 | 
216 |     def __call__(self, sample):
217 |         if random.random() >= self.p:
218 |             return sample
219 |         else:
220 |             image_bi_pre, image_bi_cur, image_bi_post, action_pre, action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
221 |             image_bi_pre, image_bi_cur, image_bi_post = TF.hflip(image_bi_pre), TF.hflip(image_bi_cur), TF.hflip(image_bi_post)
222 | 
223 |             tsfrm_action_pre, tsfrm_action_cur = action_pre.copy(), action_cur.copy()
224 |             tsfrm_action_pre[0], tsfrm_action_cur[0] = image_bi_pre.size[0] - tsfrm_action_pre[0], image_bi_cur.size[0] - tsfrm_action_cur[0]
225 | 
226 |             tsfrm_action_pre = self.flip_angle(tsfrm_action_pre)
227 |             tsfrm_action_cur = self.flip_angle(tsfrm_action_cur)
228 | 
229 |             return {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': tsfrm_action_pre, 'resz_action_cur': tsfrm_action_cur}
230 | 
231 | 
232 | class VFlip(object):
233 |     '''Randomly flip the image and action vertically with probability p
234 |     '''
235 |     def __init__(self, probability=0.5):
236 |         self.p = probability
237 | 
238 |     def __call__(self, sample):
239 |         if random.random() >= self.p:
240 |             return sample
241 |         else:
242 |             image_bi_pre, image_bi_cur, image_bi_post, action_pre, action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
243 |             image_bi_pre, image_bi_cur, image_bi_post = TF.vflip(image_bi_pre), TF.vflip(image_bi_cur), TF.vflip(image_bi_post)
244 | 
245 |             tsfrm_action_pre, tsfrm_action_cur = action_pre.copy(), action_cur.copy()
246 |             tsfrm_action_pre[1], tsfrm_action_cur[1] = image_bi_pre.size[1] - tsfrm_action_pre[1], image_bi_cur.size[1] - tsfrm_action_cur[1]
247 | 
248 | 
tsfrm_action_pre[2] = 2 * math.pi - tsfrm_action_pre[2]
249 |             tsfrm_action_cur[2] = 2 * math.pi - tsfrm_action_cur[2]
250 |             return {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': tsfrm_action_pre, 'resz_action_cur': tsfrm_action_cur}
251 | 
252 | class ToTensor(object):
253 |     '''convert the images in sample to binarized float tensors (threshold 0.3)
254 |     '''
255 |     def __call__(self, sample):
256 |         image_bi_pre, image_bi_cur, image_bi_post, resz_action_pre, resz_action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
257 |         image_bi_pre = TF.to_tensor(image_bi_pre) > 0.3
258 |         image_bi_cur = TF.to_tensor(image_bi_cur) > 0.3
259 |         image_bi_post = TF.to_tensor(image_bi_post) > 0.3
260 |         return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur)}
261 | 
262 | class ToTensorRGB(object):
263 |     '''convert the images in sample to float tensors without binarization
264 |     '''
265 |     def __call__(self, sample):
266 |         image_bi_pre, image_bi_cur, image_bi_post, resz_action_pre, resz_action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
267 |         image_bi_pre = TF.to_tensor(image_bi_pre)
268 |         image_bi_cur = TF.to_tensor(image_bi_cur)
269 |         image_bi_post = TF.to_tensor(image_bi_post)
270 |         return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur)}
271 | 
272 | 
273 | class ToTensorMultiPred4(object):
274 |     '''convert the 5-image / 4-action multi-step sample to binarized float tensors (threshold 0.3)
275 |     '''
276 |     def __call__(self, sample):
277 |         image_bi_pre, image_bi_cur, image_bi_post, image_bi_post2, image_bi_post3, resz_action_pre, resz_action_cur, resz_action_post, resz_action_post2 = \
278 |             sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['image_bi_post2'], sample['image_bi_post3'], sample['resz_action_pre'], sample['resz_action_cur'], sample['resz_action_post'], sample['resz_action_post2']
279 |         image_bi_pre = TF.to_tensor(image_bi_pre) > 0.3
280 |         image_bi_cur = TF.to_tensor(image_bi_cur) > 0.3
281 |         image_bi_post = TF.to_tensor(image_bi_post) > 0.3
282 |         image_bi_post2 = TF.to_tensor(image_bi_post2) > 0.3
283 |         image_bi_post3 = TF.to_tensor(image_bi_post3) > 0.3
284 |         return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'image_bi_post2': image_bi_post2.float(), 'image_bi_post3': image_bi_post3.float(), \
285 |                 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur), 'resz_action_post': torch.tensor(resz_action_post), 'resz_action_post2': torch.tensor(resz_action_post2)}
286 | 
287 | class ToTensorMultiPred10(object):
288 |     '''convert the 11-image / 10-action multi-step sample to binarized float tensors (threshold 0.1)
289 |     '''
290 |     def __call__(self, sample):
291 |         image_bi_pre, image_bi_cur, image_bi_post, image_bi_post2, image_bi_post3, image_bi_post4, image_bi_post5, image_bi_post6, image_bi_post7, image_bi_post8, image_bi_post9, \
292 |         resz_action_pre, resz_action_cur, resz_action_post, resz_action_post2, resz_action_post3, resz_action_post4, resz_action_post5, resz_action_post6, resz_action_post7, resz_action_post8 = \
293 |         sample['image_bi_pre'], 
sample['image_bi_cur'], sample['image_bi_post'], sample['image_bi_post2'], sample['image_bi_post3'], sample['image_bi_post4'], sample['image_bi_post5'],\
294 |         sample['image_bi_post6'], sample['image_bi_post7'], sample['image_bi_post8'], sample['image_bi_post9'], \
295 |         sample['resz_action_pre'], sample['resz_action_cur'], sample['resz_action_post'], sample['resz_action_post2'], sample['resz_action_post3'], sample['resz_action_post4'], sample['resz_action_post5'],\
296 |         sample['resz_action_post6'], sample['resz_action_post7'], sample['resz_action_post8']
297 |         value = 0.1
298 |         image_bi_pre = TF.to_tensor(image_bi_pre) > value
299 |         image_bi_cur = TF.to_tensor(image_bi_cur) > value
300 |         image_bi_post = TF.to_tensor(image_bi_post) > value
301 |         image_bi_post2 = TF.to_tensor(image_bi_post2) > value
302 |         image_bi_post3 = TF.to_tensor(image_bi_post3) > value
303 |         image_bi_post4 = TF.to_tensor(image_bi_post4) > value
304 |         image_bi_post5 = TF.to_tensor(image_bi_post5) > value
305 |         image_bi_post6 = TF.to_tensor(image_bi_post6) > value
306 |         image_bi_post7 = TF.to_tensor(image_bi_post7) > value
307 |         image_bi_post8 = TF.to_tensor(image_bi_post8) > value
308 |         image_bi_post9 = TF.to_tensor(image_bi_post9) > value
309 |         return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'image_bi_post2': image_bi_post2.float(), 'image_bi_post3': image_bi_post3.float(), \
310 |                 'image_bi_post4': image_bi_post4.float(), 'image_bi_post5': image_bi_post5.float(), 'image_bi_post6': image_bi_post6.float(), 'image_bi_post7': image_bi_post7.float(), 'image_bi_post8': image_bi_post8.float(), \
311 |                 'image_bi_post9': image_bi_post9.float(), 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur), 'resz_action_post': torch.tensor(resz_action_post), \
312 |                 'resz_action_post2': torch.tensor(resz_action_post2), 'resz_action_post3': torch.tensor(resz_action_post3), 'resz_action_post4': torch.tensor(resz_action_post4), 'resz_action_post5': torch.tensor(resz_action_post5), \
313 |                 'resz_action_post6': torch.tensor(resz_action_post6), 'resz_action_post7': torch.tensor(resz_action_post7), 'resz_action_post8': torch.tensor(resz_action_post8)}
314 | 
315 | def my_collate(batch):
316 |     batch = list(filter(lambda x: x['image_bi_post'] is not None, batch))  # drop invalid (None) samples before collating
317 |     return default_collate(batch)
318 | 
319 | def create_image_path(folder, total_img_num):
320 |     '''create image_path list as input of MyDataset
321 |     total_img_num: number of images
322 |     '''
323 |     add1 = './rope_dataset/{}'.format(folder)
324 |     image_paths = []
325 |     for i in range(total_img_num):
326 |         add2 = '/img_{:05d}.jpg'.format(i)  # zero-pad the index to five digits
336 |         image_paths.append(add1+add2)
337 |     return image_paths
338 | 
339 | 
--------------------------------------------------------------------------------
/model/e2c.py:
--------------------------------------------------------------------------------
1 | # separate two models, train g^t first, then train K and L
2 | from __future__ import print_function
3 | import argparse
4 | 
5 | import torch
6 | from torch import nn, optim, sigmoid, tanh, relu
7 | from torch.autograd import Variable
8 | from torch.nn import functional as F
9 | from 
torchvision.utils import save_image
10 | from torch.utils.data import Dataset, DataLoader
11 | from deform.model.create_dataset import *
12 | from deform.model.hidden_dynamics import *
13 | import matplotlib.pyplot as plt
14 | from deform.utils.utils import plot_train_loss, plot_train_bound_loss, plot_train_kl_loss, plot_train_pred_loss, \
15 |                                plot_test_loss, plot_test_bound_loss, plot_test_kl_loss, plot_test_pred_loss, \
16 |                                plot_sample, rect, save_data, save_e2c_data, create_loss_list, create_folder, plot_grad_flow
17 | from deform.model.configs import load_config
18 | import os
19 | import math
20 | 
21 | class NormalDistribution(object):
22 |     """
23 |     Wrapper class representing a multivariate normal distribution parameterized by
24 |     N(mu,Cov). If cov. matrix is diagonal, Cov=(sigma).^2. Otherwise,
25 |     Cov=A*(sigma).^2*A', where A = (I+v*r^T).
26 |     """
27 | 
28 |     def __init__(self, mu, sigma, logsigma, *, v=None, r=None):
29 |         self.mu = mu
30 |         self.sigma = sigma
31 |         self.logsigma = logsigma
32 |         self.v = v
33 |         self.r = r
34 | 
35 |     @property
36 |     def cov(self):
37 |         """This should only be called when NormalDistribution represents one sample"""
38 |         if self.v is not None and self.r is not None:
39 |             assert self.v.dim() == 1
40 |             dim = self.v.size(0)  # D, the length of v (not v.dim(), which is always 1 here)
41 |             v = self.v.unsqueeze(1)  # D * 1 vector
42 |             rt = self.r.unsqueeze(0)  # 1 * D vector
43 |             A = torch.eye(dim) + v.mm(rt)
44 |             return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t()))
45 |         else:
46 |             return torch.diag(self.sigma.pow(2))
47 | 
48 | 
49 | def KLDGaussian(Q, N, eps=1e-8):
50 |     """KL Divergence between two Gaussians
51 |         Assuming Q ~ N(mu0, A\sigma_0A') where A = I + vr^{T}
52 |         and N ~ N(mu1, \sigma_1)
53 |     """
54 |     sum = lambda x: torch.sum(x, dim=1)
55 |     k = float(Q.mu.size()[1])  # dimension of distribution
56 |     mu0, v, r, mu1 = Q.mu, Q.v, Q.r, N.mu
57 |     s02, s12 = (Q.sigma).pow(2) + eps, (N.sigma).pow(2) + eps
58 |     a = sum(s02 * (1. + 2. * v * r) / s12) + sum(v.pow(2) / s12) * sum(r.pow(2) * s02)  # trace term
59 |     b = sum((mu1 - mu0).pow(2) / s12)  # difference-of-means term
60 |     c = 2. * (sum(N.logsigma - Q.logsigma) - torch.log(1. + sum(v * r) + eps))  # ratio-of-determinants term.
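    # For reference, 0.5 * (a + b - k + c) below assembles the standard closed form
    #   KL(Q||N) = 0.5 * [ tr(S1^{-1} S0) + (mu1-mu0)^T S1^{-1} (mu1-mu0) - k + ln(det S1 / det S0) ]
    # with S0 = A diag(sigma0^2) A^T for A = I + v r^T, and diagonal S1 = diag(sigma1^2).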
61 | 
62 |     return 0.5 * (a + b - k + c)
63 | 
64 | class E2C(nn.Module):
65 |     def __init__(self, dim_z=80, dim_u=4):
66 |         super(E2C, self).__init__()
67 |         self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
68 |                                          nn.ReLU(),
69 |                                          nn.MaxPool2d(3, stride=2),
70 |                                          nn.Conv2d(32, 64, 3, padding=1),
71 |                                          nn.ReLU(),
72 |                                          nn.MaxPool2d(3, stride=2),
73 |                                          nn.Conv2d(64, 128, 3, padding=1),
74 |                                          nn.ReLU(),
75 |                                          nn.MaxPool2d(3, stride=2),
76 |                                          nn.Conv2d(128, 128, 3, padding=1),
77 |                                          nn.ReLU(),
78 |                                          nn.Conv2d(128, 128, 3, padding=1),
79 |                                          nn.ReLU(),
80 |                                          nn.MaxPool2d(3, stride=2, padding=1))
81 |         self.fc1 = nn.Linear(128*3*3, dim_z*2)
82 |         self.fc2 = nn.Linear(dim_z, 128*3*3)
83 |         self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
84 |                                           nn.ReLU(),
85 |                                           nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
86 |                                           nn.ReLU(),
87 |                                           nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
88 |                                           nn.ReLU(),
89 |                                           nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
90 |                                           nn.ReLU(),
91 |                                           nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
92 |                                           nn.Sigmoid())
93 |         self.trans = nn.Sequential(
94 |             nn.Linear(dim_z, 100),
95 |             nn.BatchNorm1d(100),
96 |             nn.ReLU(),
97 |             nn.Linear(100, 100),
98 |             nn.BatchNorm1d(100),
99 |             nn.ReLU(),
100 |             nn.Linear(100, dim_z*2)
101 |         )
102 |         self.fc_B = nn.Linear(dim_z, dim_z * dim_u)
103 |         self.fc_o = nn.Linear(dim_z, dim_z)
104 |         self.dim_z = dim_z
105 |         self.dim_u = dim_u
106 | 
107 | 
108 |     def encode(self, x):
109 |         x = self.conv_layers(x)
110 |         x = x.view(x.shape[0], -1)
111 |         return relu(self.fc1(x)).chunk(2, dim=1)
112 | 
113 |     def decode(self, x):
114 |         x = relu(self.fc2(x))
115 |         x = x.view(-1, 128, 3, 3)
116 |         return self.dconv_layers(x)
117 | 
118 |     def transition(self, h, Q, u):
119 |         batch_size = h.size()[0]
120 |         v, r = self.trans(h).chunk(2, dim=1)
121 |         v1 = v.unsqueeze(2)
122 |         rT = r.unsqueeze(1)
123 |         I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
124 |         if rT.data.is_cuda:
125 |             I = I.cuda()  # move the identity onto the GPU alongside the other tensors
126 |         A = I.add(v1.bmm(rT))
127 | 
128 |         B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
129 |         o = self.fc_o(h).reshape((-1, self.dim_z, 1))
130 | 
131 |         u = u.unsqueeze(2)
132 | 
133 |         d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
134 |         sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
135 |         return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
136 | 
137 |     def reparam(self, mean, logvar):
138 |         std = logvar.mul(0.5).exp_()
139 |         self.z_mean = mean
140 |         self.z_sigma = std
141 |         eps = torch.FloatTensor(std.size()).normal_()
142 |         if std.data.is_cuda:
143 |             eps = eps.cuda()  # keep the noise on the same device as std
144 |         eps = Variable(eps)
145 |         return eps.mul(std).add_(mean), NormalDistribution(mean, std, torch.log(std))
146 | 
147 |     def forward(self, x, action, x_next):
148 |         mean, logvar = self.encode(x)
149 |         mean_next, logvar_next = self.encode(x_next)
150 | 
151 |         z, self.Qz = self.reparam(mean, logvar)
152 |         z_next, self.Qz_next = self.reparam(mean_next, logvar_next)
153 | 
154 |         self.x_dec = self.decode(z)
155 |         self.x_next_dec = self.decode(z_next)
156 | 
157 |         self.z_next_pred, self.Qz_next_pred = self.transition(z, self.Qz, action)
158 |         self.x_next_pred_dec = self.decode(self.z_next_pred)
159 | 
160 |         return self.x_dec, self.x_next_pred_dec, self.Qz, self.Qz_next, self.Qz_next_pred
161 | 
162 |     def latent_embeddings(self, x):
163 |         return self.encode(x)[0]
164 | 
165 |     def predict(self, X, U):
166 |         mean, logvar = self.encode(X)
167 |         z, Qz = self.reparam(mean, logvar)
168 |         z_next_pred, Qz_next_pred = self.transition(z, Qz, U)
169 | 
return self.decode(z_next_pred) 170 | 171 | def binary_crossentropy(t, o, eps=1e-8): 172 | return t * torch.log(o + eps) + (1.0 - t) * torch.log(1.0 - o + eps) 173 | 174 | def compute_loss(x_dec, x_next_pred_dec, x, x_next, 175 | Qz, Qz_next_pred, 176 | Qz_next): 177 | # Reconstruction losses 178 | if False: 179 | x_reconst_loss = (x_dec - x_next).pow(2).sum(dim=1) 180 | x_next_reconst_loss = (x_next_pred_dec - x_next).pow(2).sum(dim=1) 181 | else: 182 | x_reconst_loss = -binary_crossentropy(x, x_dec).sum(dim=1) 183 | x_next_reconst_loss = -binary_crossentropy(x_next, x_next_pred_dec).sum(dim=1) 184 | 185 | logvar = Qz.logsigma.mul(2) 186 | KLD_element = Qz.mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 187 | KLD = torch.sum(KLD_element, dim=1).mul(-0.5) 188 | 189 | # ELBO 190 | bound_loss = x_reconst_loss.add(x_next_reconst_loss).add(KLD.reshape(-1,1,1)) 191 | kl = KLDGaussian(Qz_next_pred, Qz_next) 192 | kl = kl[~torch.isnan(kl)] 193 | return bound_loss.mean(), kl.mean() 194 | 195 | def train(e2c_model): 196 | e2c_model.train() 197 | bound_loss = 0 198 | kl_loss = 0 199 | train_loss = 0 200 | pred_loss = 0 201 | for batch_idx, batch_data in enumerate(trainloader): 202 | # current image before action 203 | x = batch_data['image_bi_cur'] 204 | x = x.float().to(device).view(-1, 1, 50, 50) 205 | # action 206 | action = batch_data['resz_action_cur'] 207 | action = action.float().to(device).view(-1, 4) 208 | # image after action 209 | x_next = batch_data['image_bi_post'] 210 | x_next = x_next.float().to(device).view(-1, 1, 50, 50) 211 | # optimization 212 | e2c_optimizer.zero_grad() 213 | # model 214 | x_dec, x_next_pred_dec, Qz, Qz_next, Qz_next_pred = e2c_model(x, action, x_next) 215 | # prediction 216 | x_next_pred = e2c_model.predict(x, action) 217 | # loss 218 | loss_pred = F.binary_cross_entropy(x_next_pred.view(-1, 2500), x_next.view(-1, 2500), reduction='sum') 219 | loss_bound, loss_kl = compute_loss(x_dec, x_next_pred_dec, x, x_next, Qz, Qz_next_pred, Qz_next) 220 | loss = GAMMA_bound * loss_bound + GAMMA_kl * loss_kl + GAMMA_pred * loss_pred 221 | loss.backward() 222 | train_loss += loss.item() 223 | bound_loss += GAMMA_bound * loss_bound.item() 224 | kl_loss += GAMMA_kl * loss_kl.item() 225 | pred_loss += GAMMA_pred * loss_pred.item() 226 | e2c_optimizer.step() 227 | if batch_idx % 5 == 0: 228 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 229 | epoch, batch_idx * len(batch_data['image_bi_cur']), len(trainloader.dataset), 230 | 100. 
* batch_idx / len(trainloader),
231 |                 loss.item() / len(batch_data['image_bi_cur'])))
232 |         # reconstruction
233 |         if batch_idx == 0:
234 |             n = min(batch_data['image_bi_cur'].size(0), 8)
235 |             comparison = torch.cat([batch_data['image_bi_cur'][:n],        # current image
236 |                                     x_dec.view(-1, 1, 50, 50).cpu()[:n]])  # reconstruction of current image
237 |             save_image(comparison.cpu(),
238 |                        './result/{}/reconstruction_train/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n)
239 |     print('====> Epoch: {} Average loss: {:.4f}'.format(
240 |         epoch, train_loss / len(trainloader.dataset)))
241 |     n = len(trainloader.dataset)
242 |     return train_loss/n, bound_loss/n, kl_loss/n, pred_loss/n
243 | 
244 | def test(e2c_model):
245 |     e2c_model.eval()
246 |     test_loss = 0
247 |     bound_loss = 0
248 |     kl_loss = 0
249 |     pred_loss = 0
250 |     with torch.no_grad():
251 |         for batch_idx, batch_data in enumerate(testloader):
252 |             # current image before current action
253 |             x = batch_data['image_bi_cur']
254 |             x = x.float().to(device).view(-1, 1, 50, 50)
255 |             # current action
256 |             action = batch_data['resz_action_cur']
257 |             action = action.float().to(device).view(-1, 4)
258 |             # post image after current action
259 |             x_next = batch_data['image_bi_post']
260 |             x_next = x_next.float().to(device).view(-1, 1, 50, 50)
261 |             # model
262 |             x_dec, x_next_pred_dec, Qz, Qz_next, Qz_next_pred = e2c_model(x, action, x_next)
263 |             # prediction
264 |             x_next_pred = e2c_model.predict(x, action)
265 |             # loss
266 |             loss_bound, loss_kl = compute_loss(x_dec, x_next_pred_dec, x, x_next, Qz, Qz_next_pred, Qz_next)
267 |             loss_pred = F.binary_cross_entropy(x_next_pred.view(-1, 2500), x_next.view(-1, 2500), reduction='sum')
268 |             loss = GAMMA_bound * loss_bound + GAMMA_kl * loss_kl + GAMMA_pred * loss_pred
269 |             test_loss += loss.item()
270 |             bound_loss += GAMMA_bound * loss_bound.item()
271 |             kl_loss += GAMMA_kl * loss_kl.item()
272 |             pred_loss += GAMMA_pred * loss_pred.item()
273 |             if batch_idx == 0:
274 |                 n = min(batch_data['image_bi_cur'].size(0), 8)
275 |                 comparison = torch.cat([batch_data['image_bi_cur'][:n],             # current image
276 |                                         x_dec.view(-1, 1, 50, 50).cpu()[:n],        # reconstruction of current image
277 |                                         batch_data['image_bi_post'][:n],            # post image
278 |                                         x_next_pred.view(-1, 1, 50, 50).cpu()[:n]]) # prediction of post image
279 |                 save_image(comparison.cpu(),
280 |                            './result/{}/reconstruction_test/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n)
281 |     n = len(testloader.dataset)
282 |     return test_loss/n, bound_loss/n, kl_loss/n, pred_loss/n
283 | 
284 | # args
285 | parser = argparse.ArgumentParser(description='E2C Rope Deform Example')
286 | parser.add_argument('--folder-name', default='test_E2C',
287 |                     help='set folder name to save image files')
288 | parser.add_argument('--batch-size', type=int, default=32, metavar='N',
289 |                     help='input batch size for training (default: 32)')
290 | parser.add_argument('--epochs', type=int, default=2, metavar='N',
291 |                     help='number of epochs to train (default: 2)')
292 | parser.add_argument('--gamma-bound', type=int, default=10000, metavar='N',
293 |                     help='scale coefficient for the ELBO bound loss (default: 10000)')
294 | parser.add_argument('--gamma-kl', type=int, default=1, metavar='N',
295 |                     help='scale coefficient for the KL consistency loss (default: 1)')
296 | parser.add_argument('--gamma-pred', type=int, default=1, metavar='N',
297 |                     help='scale coefficient for the prediction loss (default: 1)')
298 | parser.add_argument('--no-cuda', action='store_true', 
default=False,
299 |                     help='disables CUDA training')
300 | parser.add_argument('--seed', type=int, default=1, metavar='S',
301 |                     help='random seed (default: 1)')
302 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
303 |                     help='how many batches to wait before logging training status')
304 | parser.add_argument('--restore', action='store_true', default=False)
305 | args = parser.parse_args()
306 | args.cuda = not args.no_cuda and torch.cuda.is_available()
307 | torch.manual_seed(args.seed)
308 | 
309 | 
310 | # dataset
311 | print('***** Preparing Data *****')
312 | total_img_num = 22515
313 | train_num = int(total_img_num * 0.8)
314 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
315 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
316 | resz_act = np.load(resz_act_path)
317 | # transform = transforms.Compose([Translation(10),
318 | #                                 HFlip(0.5),
319 | #                                 VFlip(0.5),
320 | #                                 ToTensor()])
321 | transform = transforms.Compose([Translation(10),
322 |                                 ToTensor()])
323 | trainset = MyDataset(image_paths_bi[0:train_num], resz_act[0:train_num], transform=transform)
324 | testset = MyDataset(image_paths_bi[train_num:], resz_act[train_num:], transform=ToTensor())
325 | trainloader = DataLoader(trainset, batch_size=args.batch_size,
326 |                          shuffle=True, num_workers=4, collate_fn=my_collate)
327 | testloader = DataLoader(testset, batch_size=args.batch_size,
328 |                         shuffle=True, num_workers=4, collate_fn=my_collate)
329 | print('***** Finish Preparing Data *****')
330 | 
331 | 
332 | # train var
333 | GAMMA_bound = args.gamma_bound
334 | GAMMA_kl = args.gamma_kl
335 | GAMMA_pred = args.gamma_pred
336 | 
337 | # create folders
338 | folder_name = args.folder_name
339 | create_folder(folder_name)
340 | 
341 | print('***** Start Training & Testing *****')
342 | device = torch.device("cuda" if args.cuda else "cpu")
343 | epochs = args.epochs
344 | e2c_model = E2C().to(device)
345 | e2c_optimizer = optim.Adam(e2c_model.parameters(), lr=1e-3)
346 | 
347 | # initial train
348 | if not args.restore:
349 |     init_epoch = 1
350 |     loss_logger = None
351 | # restore previous train
352 | else:
353 |     print('***** Load Checkpoint *****')
354 |     PATH = './result/{}/checkpoint'.format(folder_name)
355 |     checkpoint = torch.load(PATH, map_location=device)
356 |     e2c_model.load_state_dict(checkpoint['e2c_model_state_dict'])
357 |     e2c_optimizer.load_state_dict(checkpoint['e2c_optimizer_state_dict'])
358 |     init_epoch = checkpoint['epoch'] + 1
359 |     loss_logger = checkpoint['loss_logger']
360 | 
361 | train_loss_all = []
362 | test_loss_all = []
363 | train_bound_loss_all = []
364 | test_bound_loss_all = []
365 | train_kl_loss_all = []
366 | test_kl_loss_all = []
367 | train_pred_loss_all = []
368 | test_pred_loss_all = []
369 | for epoch in range(init_epoch, epochs+1):
370 |     train_loss, train_bound_loss, train_kl_loss, train_pred_loss = train(e2c_model)
371 |     test_loss, test_bound_loss, test_kl_loss, test_pred_loss = test(e2c_model)
372 |     train_loss_all.append(train_loss)
373 |     test_loss_all.append(test_loss)
374 |     train_bound_loss_all.append(train_bound_loss)
375 |     test_bound_loss_all.append(test_bound_loss)
376 |     train_kl_loss_all.append(train_kl_loss)
377 |     test_kl_loss_all.append(test_kl_loss)
378 |     train_pred_loss_all.append(train_pred_loss)
379 |     test_pred_loss_all.append(test_pred_loss)
380 |     if epoch % args.log_interval == 0:
381 |         save_e2c_data(folder_name, epochs, train_loss_all, train_bound_loss_all, 
train_kl_loss_all, train_pred_loss_all, \
382 |                       test_loss_all, test_bound_loss_all, test_kl_loss_all, test_pred_loss_all)
383 |         # save checkpoint
384 |         PATH = './result/{}/checkpoint'.format(folder_name)
385 |         loss_logger = {'train_loss_all': train_loss_all, 'train_bound_loss_all': train_bound_loss_all,
386 |                        'train_kl_loss_all': train_kl_loss_all, 'train_pred_loss_all': train_pred_loss_all,
387 |                        'test_loss_all': test_loss_all, 'test_bound_loss_all': test_bound_loss_all,
388 |                        'test_kl_loss_all': test_kl_loss_all, 'test_pred_loss_all': test_pred_loss_all}
389 |         torch.save({
390 |                     'epoch': epoch,
391 |                     'e2c_model_state_dict': e2c_model.state_dict(),
392 |                     'e2c_optimizer_state_dict': e2c_optimizer.state_dict(),
393 |                     'loss_logger': loss_logger
394 |                     },
395 |                     PATH)
396 | 
397 | 
398 | save_e2c_data(folder_name, epochs, train_loss_all, train_bound_loss_all, train_kl_loss_all, train_pred_loss_all, \
399 |               test_loss_all, test_bound_loss_all, test_kl_loss_all, test_pred_loss_all)
400 | 
401 | # plot
402 | plot_train_loss('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
403 | plot_train_bound_loss('./result/{}/train_bound_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
404 | plot_train_kl_loss('./result/{}/train_kl_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
405 | plot_train_pred_loss('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
406 | plot_test_loss('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
407 | plot_test_bound_loss('./result/{}/test_bound_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
408 | plot_test_kl_loss('./result/{}/test_kl_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
409 | plot_test_pred_loss('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
410 | 
411 | 
412 | # save checkpoint
413 | PATH = './result/{}/checkpoint'.format(folder_name)
414 | loss_logger = {'train_loss_all': train_loss_all, 'train_bound_loss_all': train_bound_loss_all,
415 |                'train_kl_loss_all': train_kl_loss_all, 'train_pred_loss_all': train_pred_loss_all,
416 |                'test_loss_all': test_loss_all, 'test_bound_loss_all': test_bound_loss_all,
417 |                'test_kl_loss_all': test_kl_loss_all, 'test_pred_loss_all': test_pred_loss_all}
418 | torch.save({
419 |             'epoch': epoch,
420 |             'e2c_model_state_dict': e2c_model.state_dict(),
421 |             'e2c_optimizer_state_dict': e2c_optimizer.state_dict(),
422 |             'loss_logger': loss_logger
423 |             },
424 |             PATH)
425 | print('***** End Program *****')
--------------------------------------------------------------------------------
/model/hidden_dynamics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy
3 | from numpy.linalg import inv, pinv, norm, det
4 | import torch
5 | 
6 | 
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 | 
9 | def get_control_matrix(G, U):
10 |     '''Calculate control matrix
11 |        G: n * len(x_i)
12 |        U: n * len(u_i)
13 |        return control matrix L: len(x_i) * len(u_i)
14 |     '''
15 |     n, d = U.size()
16 |     # U is a thin matrix
17 |     if n > d:
18 |         eps = 1e-5
19 |         return (torch.pinverse(U.t().mm(U) + eps*torch.eye(d).to(device)).mm(U.t()).mm(G)).t()  # L^T = (U^T U + eps*I)^{-1} U^T G
20 |     # U is a fat matrix
21 |     elif n < d:
22 |         eps = 1e-5
23 |         return (U.t().mm(torch.pinverse(U.mm(U.t()) + eps*torch.eye(n).to(device))).to(device).mm(G)).t()  # L^T = U^T (U U^T + eps*I)^{-1} G
24 |     # U is a square matrix
25 |     else:
26 |         return (torch.inverse(U).to(device).mm(G)).t()
27 | 
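# A minimal usage sketch of the least-squares fit above (illustrative shapes only,
# not part of the training pipeline): recover L from G ~ U L^T and check the residual.
def _demo_control_matrix():
    G = torch.randn(100, 80).to(device)  # n = 100 rows of latent-state targets, len(x_i) = 80
    U = torch.randn(100, 4).to(device)   # the matching actions, len(u_i) = 4
    L = get_control_matrix(G, U)         # regularized least-squares fit, shape (80, 4)
    return torch.norm(G - U.mm(L.t()))   # residual of the fit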
28 | 
29 | def get_error(G, U, L):
30 |     '''||G-UL^T||^2
31 |     '''
32 |     return torch.norm(G-U.mm(L.t()))
33 | 
34 | def get_error_linear(G, U, L_T):
35 |     '''||G-UL_T||^2
36 |     '''
37 |     err = G - torch.matmul(U.view(U.shape[0], 1, -1), L_T)
38 |     return torch.norm(err.view(err.shape[0], -1))
39 | 
40 | def get_next_state(latent_image_pre, latent_action, K, L):
41 |     '''get the next embedded state: g_{t+1} = K * g_{t} + L * u_{t}
42 | 
43 |     latent_image_pre: 1 * len(x_i), the embedded state g_t
44 |     latent_action: m * len(u_i), m is the number of predicted steps
45 |     K: len(x_i) * len(x_i), dynamics matrix
46 |     L: len(x_i) * len(u_i), control matrix
47 |     '''
48 | 
49 |     return latent_image_pre.mm(K.t().to(device)) + latent_action.mm(L.t().to(device))
50 | 
51 | def get_next_state_linear(latent_image_pre, latent_action, K_T, L_T, z=None):
52 |     '''
53 |     latent_image_pre: (batch_size, latent_state_dim)
54 |     latent_action: (batch_size, latent_act_dim)
55 |     K_T: (batch_size, latent_state_dim, latent_state_dim)
56 |     L_T: (batch_size, latent_act_dim, latent_state_dim)
57 |     '''
58 |     if z is not None:
59 |         return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1) + \
60 |                (torch.matmul(latent_action.view(latent_action.shape[0], 1, -1), L_T)).view(latent_action.shape[0], -1) + z.to(device)
61 |     else:
62 |         return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1) + \
63 |                (torch.matmul(latent_action.view(latent_action.shape[0], 1, -1), L_T)).view(latent_action.shape[0], -1)
64 | 
65 | def get_next_state_linear_without_L(latent_image_pre, K_T, z=None):
66 |     '''
67 |     latent_image_pre: (batch_size, latent_state_dim)
68 | 
69 |     K_T: (batch_size, latent_state_dim, latent_state_dim)
70 |     '''
71 |     if z is not None:
72 |         return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1) + z.to(device)
73 |     else:
74 |         return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1)
75 | 
76 | 
--------------------------------------------------------------------------------
/model/image/action_model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/action_model.pdf
--------------------------------------------------------------------------------
/model/image/action_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/action_model.png
--------------------------------------------------------------------------------
/model/image/dynamics_model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/dynamics_model.pdf
--------------------------------------------------------------------------------
/model/image/dynamics_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/dynamics_model.png
--------------------------------------------------------------------------------
/model/image/state_model.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/state_model.pdf -------------------------------------------------------------------------------- /model/image/state_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/state_model.png -------------------------------------------------------------------------------- /model/nn_large_linear_seprt_no_loop_Kp_Lpa.py: -------------------------------------------------------------------------------- 1 | # separate two models, train g^t first, then train K and L 2 | from __future__ import print_function 3 | import argparse 4 | 5 | import torch 6 | from torch import nn, optim, sigmoid, tanh, relu 7 | from torch.nn import functional as F 8 | from torchvision.utils import save_image 9 | 10 | from torch.utils.data import Dataset, DataLoader 11 | from deform.model.create_dataset import * 12 | from deform.model.hidden_dynamics import * 13 | import matplotlib.pyplot as plt 14 | from deform.utils.utils import plot_train_loss, plot_train_latent_loss, plot_train_img_loss, plot_train_act_loss, plot_train_pred_loss, \ 15 | plot_test_loss, plot_test_latent_loss, plot_test_img_loss, plot_test_act_loss, plot_test_pred_loss, \ 16 | plot_sample, rect, save_data, create_loss_list, create_folder, plot_grad_flow 17 | import os 18 | import math 19 | 20 | class CAE(nn.Module): 21 | def __init__(self, latent_state_dim=80, latent_act_dim=80): 22 | super(CAE, self).__init__() 23 | # state 24 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), 25 | nn.ReLU(), 26 | nn.MaxPool2d(3, stride=2), 27 | nn.Conv2d(32, 64, 3, padding=1), 28 | nn.ReLU(), 29 | nn.MaxPool2d(3, stride=2), 30 | nn.Conv2d(64, 128, 3, padding=1), 31 | nn.ReLU(), 32 | nn.MaxPool2d(3, stride=2), 33 | nn.Conv2d(128, 128, 3, padding=1), 34 | nn.ReLU(), 35 | nn.Conv2d(128, 128, 3, padding=1), 36 | nn.ReLU(), 37 | nn.MaxPool2d(3, stride=2, padding=1)) 38 | self.fc1 = nn.Linear(128*3*3, latent_state_dim) 39 | self.fc2 = nn.Linear(latent_state_dim, 128*3*3) 40 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 41 | nn.ReLU(), 42 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 43 | nn.ReLU(), 44 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2), 45 | nn.ReLU(), 46 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2), 47 | nn.ReLU(), 48 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2), 49 | nn.Sigmoid()) 50 | # action 51 | self.fc5 = nn.Linear(4, latent_act_dim) 52 | self.fc6 = nn.Linear(latent_act_dim, latent_act_dim) 53 | self.fc7 = nn.Linear(latent_act_dim, latent_act_dim) 54 | self.fc8 = nn.Linear(latent_act_dim, 4) 55 | # add these in order to use GPU for parameters 56 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14]) 57 | self.add_tensor = torch.tensor([0, 0, 0, 0.01]) 58 | 59 | 60 | def encoder(self, x): 61 | x = self.conv_layers(x) 62 | x = x.view(x.shape[0], -1) 63 | return relu(self.fc1(x)) 64 | 65 | def decoder(self, x): 66 | x = relu(self.fc2(x)) 67 | x = x.view(-1, 128, 3, 3) 68 | return self.dconv_layers(x) 69 | 70 | def encoder_act(self, u): 71 | h1 = relu(self.fc5(u)) 72 | return relu(self.fc6(h1)) 73 | 74 | def decoder_act(self, u): 75 | h2 = relu(self.fc7(u)) 76 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.cuda()) + self.add_tensor.cuda() 
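    # Note (assumption about the action convention): decoder_act rescales each sigmoid
    # output into its physical range, roughly x, y in [0, 50] pixels, angle in [0, 2*pi],
    # and move length in [0.01, 0.15], via the element-wise map mul_tensor * sigmoid(.) + add_tensor.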
77 | 
78 |     def forward(self, x_cur, u, x_post):
79 |         g_cur = self.encoder(x_cur)
80 |         a = self.encoder_act(u)
81 |         g_post = self.encoder(x_post)
82 | 
83 |         return g_cur, a, g_post, self.decoder(g_cur), self.decoder_act(a)
84 | 
85 | class SysDynamics(nn.Module):
86 |     def __init__(self, latent_state_dim=80, latent_act_dim=80):
87 |         super(SysDynamics, self).__init__()
88 |         self.conv_layers_matrix = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
89 |                                                 nn.ReLU(),
90 |                                                 nn.MaxPool2d(3, stride=1),
91 |                                                 nn.Conv2d(32, 64, 3, padding=1),
92 |                                                 nn.ReLU(),
93 |                                                 nn.MaxPool2d(3, stride=1),
94 |                                                 nn.Conv2d(64, 128, 3, padding=1),
95 |                                                 nn.ReLU(),
96 |                                                 nn.MaxPool2d(3, stride=2),
97 |                                                 nn.Conv2d(128, 256, 3, padding=1),
98 |                                                 nn.ReLU(),
99 |                                                 nn.MaxPool2d(3, stride=2),
100 |                                                 nn.Conv2d(256, 512, 3, padding=1),
101 |                                                 nn.ReLU(),
102 |                                                 nn.MaxPool2d(3, stride=2),
103 |                                                 nn.Conv2d(512, 512, 3, padding=1),
104 |                                                 nn.ReLU(),
105 |                                                 nn.Conv2d(512, 512, 3, padding=1),
106 |                                                 nn.ReLU(),
107 |                                                 nn.MaxPool2d(3, stride=2, padding=1))
108 |         self.fc31 = nn.Linear(512*2*2, latent_state_dim*latent_state_dim)
109 |         self.fc32 = nn.Linear(latent_state_dim*latent_state_dim, latent_state_dim*latent_state_dim)
110 |         self.fc41 = nn.Linear(512*2*2 + latent_act_dim, latent_state_dim*latent_act_dim)
111 |         self.fc42 = nn.Linear(latent_state_dim*latent_act_dim, latent_state_dim*latent_act_dim)
112 |         self.fc9 = nn.Linear(4, latent_act_dim)
113 |         self.fc10 = nn.Linear(latent_act_dim, latent_act_dim)
114 |         # latent dim
115 |         self.latent_act_dim = latent_act_dim
116 |         self.latent_state_dim = latent_state_dim
117 | 
118 |     def encoder_matrix(self, x, a):
119 |         x = self.conv_layers_matrix(x)
120 |         x = x.view(x.shape[0], -1)
121 |         xa = torch.cat((x,a), 1)
122 | 
123 |         return relu(self.fc32(relu(self.fc31(x)))).view(-1, self.latent_state_dim, self.latent_state_dim), \
124 |                relu(self.fc42(relu(self.fc41(xa)))).view(-1, self.latent_act_dim, self.latent_state_dim)
125 | 
126 |     def forward(self, x_cur, u):
127 |         a = relu(self.fc10(relu(self.fc9(u))))
128 |         K_T, L_T = self.encoder_matrix(x_cur, a)
129 | 
130 |         return K_T, L_T
131 | 
132 | def loss_function(recon_x, x):
133 |     '''
134 |     recon_x: tensor
135 |     x: tensor
136 |     '''
137 |     return F.binary_cross_entropy(recon_x.view(-1, 2500), x.view(-1, 2500), reduction='sum')
138 | 
139 | 
140 | def mse(recon_x, x):
141 |     '''mean square error
142 |     recon_x: tensor
143 |     x: tensor
144 |     '''
145 |     return F.mse_loss(recon_x, x)
146 | 
147 | def loss_function_img(recon_img, img):
148 |     return F.binary_cross_entropy(recon_img.view(-1, 2500), img.view(-1, 2500), reduction='sum')
149 | 
150 | def loss_function_act(recon_act, act):
151 |     return F.mse_loss(recon_act.view(-1, 4), act.view(-1, 4), reduction='sum')
152 | 
153 | def loss_function_latent_linear(latent_img_pre, latent_img_post, latent_action, K_T, L_T):
154 |     G = latent_img_post.view(latent_img_post.shape[0], 1, -1) - torch.matmul(latent_img_pre.view(latent_img_pre.shape[0], 1, -1), K_T)
155 |     return get_error_linear(G, latent_action, L_T)
156 | 
157 | def loss_function_pred_linear(img_post, latent_img_pre, latent_act, K_T, L_T):
158 |     recon_latent_img_post = get_next_state_linear(latent_img_pre, latent_act, K_T, L_T)
159 |     recon_img_post = recon_model.decoder(recon_latent_img_post)
160 |     return F.binary_cross_entropy(recon_img_post.view(-1, 2500), img_post.view(-1, 2500), reduction='sum')
161 | 
162 | def constraint_loss(steps, idx, trainset, U_latent, L):  # note: expects a global `model`, which this script does not define; not called below
163 |     loss = 0
164 |     data = trainset.__getitem__(idx).float().to(device).view(-1, 1, 50, 50)
165 |     embed_state = 
model.encoder(data).detach().cpu().numpy() 166 | for i in range(steps): 167 | step = i + 1 168 | data_next = trainset.__getitem__(idx+step).float().to(device).view(-1, 1, 50, 50) 169 | action = U_latent[idx:idx+step][:] 170 | embed_state_next = torch.from_numpy(get_next_state(embed_state, action, L)).to(device).float() 171 | recon_state_next = model.decoder(embed_state_next)#.detach().cpu()#.numpy() 172 | loss += mse(recon_state_next, data_next) 173 | return loss 174 | 175 | def train_new(epoch, recon_model, dyn_model, epoch_thres=500): 176 | if epoch < epoch_thres: 177 | recon_model.train() 178 | dyn_model.eval() 179 | train_loss = 0 180 | img_loss = 0 181 | act_loss = 0 182 | latent_loss = 0 183 | pred_loss = 0 184 | for batch_idx, batch_data in enumerate(trainloader): 185 | # current image before action 186 | img_cur = batch_data['image_bi_cur'] 187 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50) 188 | # action 189 | act = batch_data['resz_action_cur'] 190 | act = act.float().to(device).view(-1, 4) 191 | # image after action 192 | img_post = batch_data['image_bi_post'] 193 | img_post = img_post.float().to(device).view(-1, 1, 50, 50) 194 | # optimization 195 | recon_optimizer.zero_grad() 196 | # model 197 | latent_img_cur, latent_act, latent_img_post, recon_img_cur, recon_act = recon_model(img_cur, act, img_post) 198 | # loss 199 | loss_img = loss_function_img(recon_img_cur, img_cur) 200 | loss_act = loss_function_act(recon_act, act) 201 | loss = loss_img + GAMMA_act * loss_act 202 | loss.backward() 203 | train_loss += loss.item() 204 | img_loss += loss_img.item() 205 | act_loss += GAMMA_act * loss_act.item() 206 | 207 | recon_optimizer.step() 208 | if batch_idx % 5 == 0: 209 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 210 | epoch, batch_idx * len(batch_data['image_bi_cur']), len(trainloader.dataset), 211 | 100. 
* batch_idx / len(trainloader), 212 | loss.item() / len(batch_data['image_bi_cur']))) 213 | # reconstruction 214 | if batch_idx == 0: 215 | n = min(batch_data['image_bi_cur'].size(0), 8) 216 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image 217 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n]]) # reconstruction of current image 218 | save_image(comparison.cpu(), 219 | './result/{}/reconstruction_train/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n) 220 | plot_sample(batch_data['image_bi_cur'][:n].detach().cpu().numpy(), 221 | batch_data['image_bi_post'][:n].detach().cpu().numpy(), 222 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(), 223 | recon_act.view(-1, 4)[:n].detach().cpu().numpy(), 224 | './result/{}/reconstruction_act_train/recon_epoch_{}.png'.format(folder_name, epoch)) 225 | print('====> Epoch: {} Average loss: {:.4f}'.format( 226 | epoch, train_loss / len(trainloader.dataset))) 227 | n = len(trainloader.dataset) 228 | return train_loss/n, img_loss/n, act_loss/n, latent_loss/n, pred_loss/n 229 | else: 230 | recon_model.eval() 231 | dyn_model.train() 232 | train_loss = 0 233 | img_loss = 0 234 | act_loss = 0 235 | latent_loss = 0 236 | pred_loss = 0 237 | for batch_idx, batch_data in enumerate(trainloader): 238 | # current image before action 239 | img_cur = batch_data['image_bi_cur'] 240 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50) 241 | # action 242 | act = batch_data['resz_action_cur'] 243 | act = act.float().to(device).view(-1, 4) 244 | # image after action 245 | img_post = batch_data['image_bi_post'] 246 | img_post = img_post.float().to(device).view(-1, 1, 50, 50) 247 | # optimization 248 | dyn_optimizer.zero_grad() 249 | # model 250 | latent_img_cur, latent_act, latent_img_post, recon_img_cur, recon_act = recon_model(img_cur, act, img_post) 251 | K_T, L_T = dyn_model(img_cur, act) 252 | # prediction 253 | pred_latent_img_post = get_next_state_linear(latent_img_cur, latent_act, K_T, L_T) 254 | pred_img_post = recon_model.decoder(pred_latent_img_post) 255 | # loss 256 | loss_latent = loss_function_latent_linear(latent_img_cur, latent_img_post, latent_act, K_T, L_T) 257 | loss_predict = loss_function_pred_linear(img_post, latent_img_cur, latent_act, K_T, L_T) 258 | loss = GAMMA_latent * loss_latent + GAMMA_pred * loss_predict 259 | loss.backward() 260 | train_loss += loss.item() 261 | latent_loss += GAMMA_latent * loss_latent.item() 262 | pred_loss += GAMMA_pred * loss_predict.item() 263 | 264 | dyn_optimizer.step() 265 | if batch_idx % 5 == 0: 266 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 267 | epoch, batch_idx * len(batch_data['image_bi_cur']), len(trainloader.dataset), 268 | 100. 
* batch_idx / len(trainloader), 269 | loss.item() / len(batch_data['image_bi_cur']))) 270 | # reconstruction 271 | if batch_idx == 0: 272 | n = min(batch_data['image_bi_cur'].size(0), 8) 273 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image 274 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], # reconstruction of current image 275 | batch_data['image_bi_post'][:n], # post image 276 | pred_img_post.view(-1, 1, 50, 50).cpu()[:n]]) # prediction of post image 277 | save_image(comparison.cpu(), 278 | './result/{}/reconstruction_train/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n) 279 | plot_sample(batch_data['image_bi_cur'][:n].detach().cpu().numpy(), 280 | batch_data['image_bi_post'][:n].detach().cpu().numpy(), 281 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(), 282 | recon_act.view(-1, 4)[:n].detach().cpu().numpy(), 283 | './result/{}/reconstruction_act_train/recon_epoch_{}.png'.format(folder_name, epoch)) 284 | print('====> Epoch: {} Average loss: {:.4f}'.format( 285 | epoch, train_loss / len(trainloader.dataset))) 286 | n = len(trainloader.dataset) 287 | return train_loss/n, img_loss/n, act_loss/n, latent_loss/n, pred_loss/n 288 | 289 | def test_new(epoch, recon_model, dyn_model): 290 | recon_model.eval() 291 | dyn_model.eval() 292 | test_loss = 0 293 | img_loss = 0 294 | act_loss = 0 295 | latent_loss = 0 296 | pred_loss = 0 297 | with torch.no_grad(): 298 | for batch_idx, batch_data in enumerate(testloader): 299 | # current image before current action 300 | img_cur = batch_data['image_bi_cur'] 301 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50) 302 | # current action 303 | act = batch_data['resz_action_cur'] 304 | act = act.float().to(device).view(-1, 4) 305 | # post image after current action 306 | img_post = batch_data['image_bi_post'] 307 | img_post = img_post.float().to(device).view(-1, 1, 50, 50) 308 | # model 309 | latent_img_cur, latent_act, latent_img_post, recon_img_cur, recon_act = recon_model(img_cur, act, img_post) 310 | K_T, L_T = dyn_model(img_cur, act) 311 | # prediction 312 | pred_latent_img_post = get_next_state_linear(latent_img_cur, latent_act, K_T, L_T) 313 | pred_img_post = recon_model.decoder(pred_latent_img_post) 314 | # loss 315 | loss_img = loss_function_img(recon_img_cur, img_cur) 316 | loss_act = loss_function_act(recon_act, act) 317 | loss_latent = loss_function_latent_linear(latent_img_cur, latent_img_post, latent_act, K_T, L_T) 318 | loss_predict = loss_function_pred_linear(img_post, latent_img_cur, latent_act, K_T, L_T) 319 | loss = loss_img + GAMMA_act * loss_act + GAMMA_latent * loss_latent + GAMMA_pred * loss_predict 320 | test_loss += loss.item() 321 | img_loss += loss_img.item() 322 | act_loss += GAMMA_act * loss_act.item() 323 | latent_loss += GAMMA_latent * loss_latent.item() 324 | pred_loss += GAMMA_pred * loss_predict.item() 325 | if batch_idx == 0: 326 | n = min(batch_data['image_bi_cur'].size(0), 8) 327 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image 328 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], # reconstruction of current image 329 | batch_data['image_bi_post'][:n], # post image 330 | pred_img_post.view(-1, 1, 50, 50).cpu()[:n]]) # prediction of post image 331 | save_image(comparison.cpu(), 332 | './result/{}/reconstruction_test/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n) 333 | plot_sample(batch_data['image_bi_cur'][:n].detach().cpu().numpy(), 334 | batch_data['image_bi_post'][:n].detach().cpu().numpy(), 335 | 
batch_data['resz_action_cur'][:n].detach().cpu().numpy(),
336 |                             recon_act.view(-1, 4)[:n].detach().cpu().numpy(),
337 |                             './result/{}/reconstruction_act_test/recon_epoch_{}.png'.format(folder_name, epoch))
338 |     n = len(testloader.dataset)
339 |     return test_loss/n, img_loss/n, act_loss/n, latent_loss/n, pred_loss/n
340 | 
341 | # args
342 | parser = argparse.ArgumentParser(description='CAE Rope Deform Example')
343 | parser.add_argument('--folder-name', default='test',
344 |                     help='set folder name to save image files')
345 | parser.add_argument('--batch-size', type=int, default=32, metavar='N',
346 |                     help='input batch size for training (default: 32)')
347 | parser.add_argument('--epochs', type=int, default=1000, metavar='N',
348 |                     help='number of epochs to train (default: 1000)')
349 | parser.add_argument('--gamma-act', type=int, default=450, metavar='N',
350 |                     help='scale coefficient for loss of action (default: 450, i.e. 150*3)')
351 | parser.add_argument('--gamma-lat', type=int, default=900, metavar='N',
352 |                     help='scale coefficient for loss of latent dynamics (default: 900, i.e. 150*6)')
353 | parser.add_argument('--gamma-pred', type=int, default=10, metavar='N',
354 |                     help='scale coefficient for loss of prediction (default: 10)')
355 | parser.add_argument('--no-cuda', action='store_true', default=False,
356 |                     help='disables CUDA training')
357 | parser.add_argument('--math', default=False,
358 |                     help='get control matrix L: True - use regression, False - use backpropagation')
359 | parser.add_argument('--seed', type=int, default=1, metavar='S',
360 |                     help='random seed (default: 1)')
361 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
362 |                     help='how many batches to wait before logging training status')
363 | parser.add_argument('--restore', action='store_true', default=False)
364 | args = parser.parse_args()
365 | args.cuda = not args.no_cuda and torch.cuda.is_available()
366 | torch.manual_seed(args.seed)
367 | 
368 | 
369 | # dataset
370 | print('***** Preparing Data *****')
371 | total_img_num = 22515
372 | train_num = int(total_img_num * 0.8)
373 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
374 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
375 | resz_act = np.load(resz_act_path)
376 | # transform = transforms.Compose([Translation(10),
377 | #                                 HFlip(0.5),
378 | #                                 VFlip(0.5),
379 | #                                 ToTensor()])
380 | transform = transforms.Compose([Translation(10),
381 |                                 ToTensor()])
382 | trainset = MyDataset(image_paths_bi[0:train_num], resz_act[0:train_num], transform=transform)
383 | testset = MyDataset(image_paths_bi[train_num:], resz_act[train_num:], transform=ToTensor())
384 | trainloader = DataLoader(trainset, batch_size=args.batch_size,
385 |                          shuffle=True, num_workers=4, collate_fn=my_collate)
386 | testloader = DataLoader(testset, batch_size=args.batch_size,
387 |                         shuffle=True, num_workers=4, collate_fn=my_collate)
388 | print('***** Finish Preparing Data *****')
389 | 
390 | # train var
391 | MATH = args.math
392 | GAMMA_act = args.gamma_act
393 | GAMMA_latent = args.gamma_lat
394 | GAMMA_pred = args.gamma_pred
395 | 
396 | # create folders
397 | folder_name = args.folder_name
398 | create_folder(folder_name)
399 | 
400 | print('***** Start Training & Testing *****')
401 | device = torch.device("cuda" if args.cuda else "cpu")
402 | epochs = args.epochs
403 | recon_model = CAE().to(device)
404 | dyn_model = SysDynamics().to(device)
405 | recon_optimizer = 
optim.Adam(recon_model.parameters(), lr=1e-3) 406 | dyn_optimizer = optim.Adam(dyn_model.parameters(), lr=1e-3) 407 | 408 | # initial train 409 | if not args.restore: 410 | init_epoch = 1 411 | loss_logger = None 412 | # restore previous train 413 | else: 414 | print('***** Load Checkpoint *****') 415 | PATH = './result/{}/checkpoint'.format(folder_name) 416 | checkpoint = torch.load(PATH, map_location=device) 417 | recon_model.load_state_dict(checkpoint['recon_model_state_dict']) 418 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict']) 419 | recon_optimizer.load_state_dict(checkpoint['recon_optimizer_state_dict']) 420 | dyn_optimizer.load_state_dict(checkpoint['dyn_optimizer_state_dict']) 421 | init_epoch = checkpoint['epoch'] + 1 422 | loss_logger = checkpoint['loss_logger'] 423 | 424 | train_loss_all, train_img_loss_all, train_act_loss_all, train_latent_loss_all, train_pred_loss_all, _, \ 425 | test_loss_all, test_img_loss_all, test_act_loss_all, test_latent_loss_all, test_pred_loss_all, _ = create_loss_list(loss_logger, kld=False) 426 | 427 | 428 | for epoch in range(init_epoch, epochs+1): 429 | train_loss, train_img_loss, train_act_loss, train_latent_loss, train_pred_loss = train_new(epoch, recon_model, dyn_model, epoch_thres=int(epochs/2)) 430 | test_loss, test_img_loss, test_act_loss, test_latent_loss, test_pred_loss = test_new(epoch, recon_model, dyn_model) 431 | train_loss_all.append(train_loss) 432 | train_img_loss_all.append(train_img_loss) 433 | train_act_loss_all.append(train_act_loss) 434 | train_latent_loss_all.append(train_latent_loss) 435 | train_pred_loss_all.append(train_pred_loss) 436 | test_loss_all.append(test_loss) 437 | test_img_loss_all.append(test_img_loss) 438 | test_act_loss_all.append(test_act_loss) 439 | test_latent_loss_all.append(test_latent_loss) 440 | test_pred_loss_all.append(test_pred_loss) 441 | if epoch % args.log_interval == 0: 442 | save_data(folder_name, epochs, train_loss_all, train_img_loss_all, train_act_loss_all, 443 | train_latent_loss_all, train_pred_loss_all, test_loss_all, test_img_loss_all, 444 | test_act_loss_all, test_latent_loss_all, test_pred_loss_all, None, None, None, None) 445 | # save checkpoint 446 | PATH = './result/{}/checkpoint'.format(folder_name) 447 | loss_logger = {'train_loss_all': train_loss_all, 'train_img_loss_all': train_img_loss_all, 448 | 'train_act_loss_all': train_act_loss_all, 'train_latent_loss_all': train_latent_loss_all, 449 | 'train_pred_loss_all': train_pred_loss_all, 'test_loss_all': test_loss_all, 450 | 'test_img_loss_all': test_img_loss_all, 'test_act_loss_all': test_act_loss_all, 451 | 'test_latent_loss_all': test_latent_loss_all, 'test_pred_loss_all': test_pred_loss_all} 452 | torch.save({ 453 | 'epoch': epoch, 454 | 'recon_model_state_dict': recon_model.state_dict(), 455 | 'dyn_model_state_dict': dyn_model.state_dict(), 456 | 'recon_optimizer_state_dict': recon_optimizer.state_dict(), 457 | 'dyn_optimizer_state_dict': dyn_optimizer.state_dict(), 458 | 'loss_logger': loss_logger 459 | }, 460 | PATH) 461 | 462 | 463 | save_data(folder_name, epochs, train_loss_all, train_img_loss_all, train_act_loss_all, 464 | train_latent_loss_all, train_pred_loss_all, test_loss_all, test_img_loss_all, 465 | test_act_loss_all, test_latent_loss_all, test_pred_loss_all, None, None, None, None) 466 | 467 | # plot 468 | plot_train_loss('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 469 | plot_train_img_loss('./result/{}/train_img_loss_epoch{}.npy'.format(folder_name, epochs), 
folder_name) 470 | plot_train_act_loss('./result/{}/train_act_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 471 | plot_train_latent_loss('./result/{}/train_latent_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 472 | plot_train_pred_loss('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 473 | plot_test_loss('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 474 | plot_test_img_loss('./result/{}/test_img_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 475 | plot_test_act_loss('./result/{}/test_act_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 476 | plot_test_latent_loss('./result/{}/test_latent_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 477 | plot_test_pred_loss('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name) 478 | 479 | # save checkpoint 480 | PATH = './result/{}/checkpoint'.format(folder_name) 481 | loss_logger = {'train_loss_all': train_loss_all, 'train_img_loss_all': train_img_loss_all, 482 | 'train_act_loss_all': train_act_loss_all, 'train_latent_loss_all': train_latent_loss_all, 483 | 'train_pred_loss_all': train_pred_loss_all, 'test_loss_all': test_loss_all, 484 | 'test_img_loss_all': test_img_loss_all, 'test_act_loss_all': test_act_loss_all, 485 | 'test_latent_loss_all': test_latent_loss_all, 'test_pred_loss_all': test_pred_loss_all} 486 | torch.save({ 487 | 'epoch': epoch, 488 | 'recon_model_state_dict': recon_model.state_dict(), 489 | 'dyn_model_state_dict': dyn_model.state_dict(), 490 | 'recon_optimizer_state_dict': recon_optimizer.state_dict(), 491 | 'dyn_optimizer_state_dict': dyn_optimizer.state_dict(), 492 | 'loss_logger': loss_logger 493 | }, 494 | PATH) 495 | print('***** End Program *****') 496 | 497 | -------------------------------------------------------------------------------- /model/predict_e2c.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | from torch.autograd import Variable 5 | from torch import nn, optim, sigmoid, tanh, relu 6 | from torch.utils.data import Dataset, DataLoader 7 | from torch.nn import functional as F 8 | from deform.model.create_dataset import * 9 | from deform.model.hidden_dynamics import * 10 | from deform.utils.utils import plot_sample_multi_step 11 | from torchvision.utils import save_image 12 | import os 13 | import math 14 | 15 | 16 | class NormalDistribution(object): 17 | """ 18 | Wrapper class representing a multivariate normal distribution parameterized by 19 | N(mu,Cov). If cov. matrix is diagonal, Cov=(sigma).^2. Otherwise, 20 | Cov=A*(sigma).^2*A', where A = (I+v*r^T). 
21 |     """
22 | 
23 |     def __init__(self, mu, sigma, logsigma, *, v=None, r=None):
24 |         self.mu = mu
25 |         self.sigma = sigma
26 |         self.logsigma = logsigma
27 |         self.v = v
28 |         self.r = r
29 | 
30 |     @property
31 |     def cov(self):
32 |         """This should only be called when NormalDistribution represents one sample"""
33 |         if self.v is not None and self.r is not None:
34 |             assert self.v.dim() == 1
35 |             dim = self.v.size(0)  # D, the length of v (not v.dim(), which is always 1 here)
36 |             v = self.v.unsqueeze(1)  # D * 1 vector
37 |             rt = self.r.unsqueeze(0)  # 1 * D vector
38 |             A = torch.eye(dim) + v.mm(rt)
39 |             return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t()))
40 |         else:
41 |             return torch.diag(self.sigma.pow(2))
42 | 
43 | 
44 | def KLDGaussian(Q, N, eps=1e-8):
45 |     """KL Divergence between two Gaussians
46 |         Assuming Q ~ N(mu0, A\sigma_0A') where A = I + vr^{T}
47 |         and N ~ N(mu1, \sigma_1)
48 |     """
49 |     sum = lambda x: torch.sum(x, dim=1)
50 |     k = float(Q.mu.size()[1])  # dimension of distribution
51 |     mu0, v, r, mu1 = Q.mu, Q.v, Q.r, N.mu
52 |     s02, s12 = (Q.sigma).pow(2) + eps, (N.sigma).pow(2) + eps
53 |     a = sum(s02 * (1. + 2. * v * r) / s12) + sum(v.pow(2) / s12) * sum(r.pow(2) * s02)  # trace term
54 |     b = sum((mu1 - mu0).pow(2) / s12)  # difference-of-means term
55 |     c = 2. * (sum(N.logsigma - Q.logsigma) - torch.log(1. + sum(v * r) + eps))  # ratio-of-determinants term.
56 | 
57 |     return 0.5 * (a + b - k + c)
58 | 
59 | class E2C(nn.Module):
60 |     def __init__(self, dim_z=80, dim_u=4):
61 |         super(E2C, self).__init__()
62 |         self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
63 |                                          nn.ReLU(),
64 |                                          nn.MaxPool2d(3, stride=2),
65 |                                          nn.Conv2d(32, 64, 3, padding=1),
66 |                                          nn.ReLU(),
67 |                                          nn.MaxPool2d(3, stride=2),
68 |                                          nn.Conv2d(64, 128, 3, padding=1),
69 |                                          nn.ReLU(),
70 |                                          nn.MaxPool2d(3, stride=2),
71 |                                          nn.Conv2d(128, 128, 3, padding=1),
72 |                                          nn.ReLU(),
73 |                                          nn.Conv2d(128, 128, 3, padding=1),
74 |                                          nn.ReLU(),
75 |                                          nn.MaxPool2d(3, stride=2, padding=1))
76 |         self.fc1 = nn.Linear(128*3*3, dim_z*2)
77 |         self.fc2 = nn.Linear(dim_z, 128*3*3)
78 |         self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
79 |                                           nn.ReLU(),
80 |                                           nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
81 |                                           nn.ReLU(),
82 |                                           nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
83 |                                           nn.ReLU(),
84 |                                           nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
85 |                                           nn.ReLU(),
86 |                                           nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
87 |                                           nn.Sigmoid())
88 |         self.trans = nn.Sequential(
89 |             nn.Linear(dim_z, 100),
90 |             nn.BatchNorm1d(100),
91 |             nn.ReLU(),
92 |             nn.Linear(100, 100),
93 |             nn.BatchNorm1d(100),
94 |             nn.ReLU(),
95 |             nn.Linear(100, dim_z*2)
96 |         )
97 |         self.fc_B = nn.Linear(dim_z, dim_z * dim_u)
98 |         self.fc_o = nn.Linear(dim_z, dim_z)
99 |         self.dim_z = dim_z
100 |         self.dim_u = dim_u
101 | 
102 | 
103 |     def encode(self, x):
104 |         x = self.conv_layers(x)
105 |         x = x.view(x.shape[0], -1)
106 |         return relu(self.fc1(x)).chunk(2, dim=1)
107 | 
108 |     def decode(self, x):
109 |         x = relu(self.fc2(x))
110 |         x = x.view(-1, 128, 3, 3)
111 |         return self.dconv_layers(x)
112 | 
113 |     def transition(self, h, Q, u):
114 |         batch_size = h.size()[0]
115 |         v, r = self.trans(h).chunk(2, dim=1)
116 |         v1 = v.unsqueeze(2)
117 |         rT = r.unsqueeze(1)
118 |         I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
119 |         if rT.data.is_cuda:
120 |             I = I.cuda()  # move the identity onto the GPU alongside the other tensors
121 |         A = I.add(v1.bmm(rT))
122 | 
123 |         B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
124 |         o = self.fc_o(h).reshape((-1, self.dim_z, 1))
125 | 
126 |         u = u.unsqueeze(2)
127 | 
128 |         d = 
A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2) 129 | sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2) 130 | return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r) 131 | 132 | def reparam(self, mean, logvar): 133 | std = logvar.mul(0.5).exp_() 134 | self.z_mean = mean 135 | self.z_sigma = std 136 | eps = torch.FloatTensor(std.size()).normal_() 137 | if std.data.is_cuda: 138 | eps = eps.cuda() 139 | eps = Variable(eps) 140 | return eps.mul(std).add_(mean), NormalDistribution(mean, std, torch.log(std)) 141 | 142 | def forward(self, x, action, x_next): 143 | mean, logvar = self.encode(x) 144 | mean_next, logvar_next = self.encode(x_next) 145 | 146 | z, self.Qz = self.reparam(mean, logvar) 147 | z_next, self.Qz_next = self.reparam(mean_next, logvar_next) 148 | 149 | self.x_dec = self.decode(z) 150 | self.x_next_dec = self.decode(z_next) 151 | 152 | self.z_next_pred, self.Qz_next_pred = self.transition(z, self.Qz, action) 153 | self.x_next_pred_dec = self.decode(self.z_next_pred) 154 | 155 | return self.x_dec, self.x_next_pred_dec, self.Qz, self.Qz_next, self.Qz_next_pred 156 | 157 | def latent_embeddings(self, x): 158 | return self.encode(x)[0] 159 | 160 | def predict(self, X, U): 161 | mean, logvar = self.encode(X) 162 | z, Qz = self.reparam(mean, logvar) 163 | z_next_pred, Qz_next_pred = self.transition(z, Qz, U) 164 | return self.decode(z_next_pred) 165 | 166 | def binary_crossentropy(t, o, eps=1e-8): 167 | return t * torch.log(o + eps) + (1.0 - t) * torch.log(1.0 - o + eps) 168 | 169 | def compute_loss(x_dec, x_next_pred_dec, x, x_next, 170 | Qz, Qz_next_pred, 171 | Qz_next): 172 | # Reconstruction losses 173 | if False: 174 | x_reconst_loss = (x_dec - x_next).pow(2).sum(dim=1) 175 | x_next_reconst_loss = (x_next_pred_dec - x_next).pow(2).sum(dim=1) 176 | else: 177 | x_reconst_loss = -binary_crossentropy(x, x_dec).sum(dim=1) 178 | x_next_reconst_loss = -binary_crossentropy(x_next, x_next_pred_dec).sum(dim=1) 179 | 180 | logvar = Qz.logsigma.mul(2) 181 | KLD_element = Qz.mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 182 | KLD = torch.sum(KLD_element, dim=1).mul(-0.5) 183 | 184 | # ELBO 185 | bound_loss = x_reconst_loss.add(x_next_reconst_loss).add(KLD.reshape(-1,1,1)) 186 | kl = KLDGaussian(Qz_next_pred, Qz_next) 187 | kl = kl[~torch.isnan(kl)] 188 | return bound_loss.mean(), kl.mean() 189 | 190 | def predict(): 191 | e2c_model.eval() 192 | with torch.no_grad(): 193 | for batch_idx, batch_data in enumerate(dataloader): 194 | # order: img_pre -> act_pre -> img_cur -> act_cur -> img_post 195 | # previous image 196 | img_pre = batch_data['image_bi_pre'] 197 | img_pre = img_pre.float().to(device).view(-1, 1, 50, 50) 198 | # previous action 199 | act_pre = batch_data['resz_action_pre'] 200 | act_pre = act_pre.float().to(device).view(-1, 4) 201 | # current image 202 | img_cur = batch_data['image_bi_cur'] 203 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50) 204 | # current action 205 | act_cur = batch_data['resz_action_cur'] 206 | act_cur = act_cur.float().to(device).view(-1, 4) 207 | # post image 208 | img_post = batch_data['image_bi_post'] 209 | img_post = img_post.float().to(device).view(-1, 1, 50, 50) 210 | # post action 211 | act_post = batch_data['resz_action_post'] 212 | act_post = act_post.float().to(device).view(-1, 4) 213 | # post2 image 214 | img_post2 = batch_data['image_bi_post2'] 215 | img_post2 = img_post2.float().to(device).view(-1, 1, 50, 50) 216 | # post2 action 217 | act_post2 = batch_data['resz_action_post2'] 218 | act_post2 
= act_post2.float().to(device).view(-1, 4) 219 | # post3 image 220 | img_post3 = batch_data['image_bi_post3'] 221 | img_post3 = img_post3.float().to(device).view(-1, 1, 50, 50) 222 | # post3 action 223 | act_post3 = batch_data['resz_action_post3'] 224 | act_post3 = act_post3.float().to(device).view(-1, 4) 225 | # post4 image 226 | img_post4 = batch_data['image_bi_post4'] 227 | img_post4 = img_post4.float().to(device).view(-1, 1, 50, 50) 228 | # post4 action 229 | act_post4 = batch_data['resz_action_post4'] 230 | act_post4 = act_post4.float().to(device).view(-1, 4) 231 | # post5 image 232 | img_post5 = batch_data['image_bi_post5'] 233 | img_post5 = img_post5.float().to(device).view(-1, 1, 50, 50) 234 | # post5 action 235 | act_post5 = batch_data['resz_action_post5'] 236 | act_post5 = act_post5.float().to(device).view(-1, 4) 237 | # post6 image 238 | img_post6 = batch_data['image_bi_post6'] 239 | img_post6 = img_post6.float().to(device).view(-1, 1, 50, 50) 240 | # post6 action 241 | act_post6 = batch_data['resz_action_post6'] 242 | act_post6 = act_post6.float().to(device).view(-1, 4) 243 | # post7 image 244 | img_post7 = batch_data['image_bi_post7'] 245 | img_post7 = img_post7.float().to(device).view(-1, 1, 50, 50) 246 | # post7 action 247 | act_post7 = batch_data['resz_action_post7'] 248 | act_post7 = act_post7.float().to(device).view(-1, 4) 249 | # post8 image 250 | img_post8 = batch_data['image_bi_post8'] 251 | img_post8 = img_post8.float().to(device).view(-1, 1, 50, 50) 252 | # post8 action 253 | act_post8 = batch_data['resz_action_post8'] 254 | act_post8 = act_post8.float().to(device).view(-1, 4) 255 | # post9 image 256 | img_post9 = batch_data['image_bi_post9'] 257 | img_post9 = img_post9.float().to(device).view(-1, 1, 50, 50) 258 | # ten step prediction 259 | # prediction for current image from pre image 260 | recon_img_cur = e2c_model.predict(img_pre, act_pre) 261 | # prediction for post image from pre image 262 | recon_img_post = e2c_model.predict(recon_img_cur, act_cur) 263 | # prediction for post2 image from pre image 264 | recon_img_post2 = e2c_model.predict(recon_img_post, act_post) 265 | # prediction for post3 image from pre image 266 | recon_img_post3 = e2c_model.predict(recon_img_post2, act_post2) 267 | # prediction for post4 image from pre image 268 | recon_img_post4 = e2c_model.predict(recon_img_post3, act_post3) 269 | # prediction for post5 image from pre image 270 | recon_img_post5 = e2c_model.predict(recon_img_post4, act_post4) 271 | # prediction for post6 image from pre image 272 | recon_img_post6 = e2c_model.predict(recon_img_post5, act_post5) 273 | # prediction for post7 image from pre image 274 | recon_img_post7 = e2c_model.predict(recon_img_post6, act_post6) 275 | # prediction for post8 image from pre image 276 | recon_img_post8 = e2c_model.predict(recon_img_post7, act_post7) 277 | # prediction for post9 image from pre image 278 | recon_img_post9 = e2c_model.predict(recon_img_post8, act_post8) 279 | if batch_idx % 10 == 0: 280 | n = min(batch_data['image_bi_pre'].size(0), 1) 281 | comparison_GT = torch.cat([batch_data['image_bi_pre'][:n], 282 | batch_data['image_bi_cur'][:n], 283 | batch_data['image_bi_post'][:n], 284 | batch_data['image_bi_post2'][:n], 285 | batch_data['image_bi_post3'][:n], 286 | batch_data['image_bi_post4'][:n], 287 | batch_data['image_bi_post5'][:n], 288 | batch_data['image_bi_post6'][:n], 289 | batch_data['image_bi_post7'][:n], 290 | batch_data['image_bi_post8'][:n], 291 | batch_data['image_bi_post9'][:n]]) 292 | save_image(comparison_GT.cpu(), 
293 | './result/{}/prediction_full_step{}/prediction_GT_batch{}.png'.format(folder_name, step, batch_idx), nrow=n) 294 | comparison_Pred = torch.cat([batch_data['image_bi_pre'][:n], 295 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], 296 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n], 297 | recon_img_post2.view(-1, 1, 50, 50).cpu()[:n], 298 | recon_img_post3.view(-1, 1, 50, 50).cpu()[:n], 299 | recon_img_post4.view(-1, 1, 50, 50).cpu()[:n], 300 | recon_img_post5.view(-1, 1, 50, 50).cpu()[:n], 301 | recon_img_post6.view(-1, 1, 50, 50).cpu()[:n], 302 | recon_img_post7.view(-1, 1, 50, 50).cpu()[:n], 303 | recon_img_post8.view(-1, 1, 50, 50).cpu()[:n], 304 | recon_img_post9.view(-1, 1, 50, 50).cpu()[:n]]) 305 | save_image(comparison_Pred.cpu(), 306 | './result/{}/prediction_full_step{}/prediction_Pred_batch{}.png'.format(folder_name, step, batch_idx), nrow=n) 307 | #GT 308 | plot_sample_multi_step(batch_data['image_bi_pre'][:n].detach().cpu().numpy(), 309 | batch_data['image_bi_cur'][:n].detach().cpu().numpy(), 310 | batch_data['image_bi_post'][:n].detach().cpu().numpy(), 311 | batch_data['image_bi_post2'][:n].detach().cpu().numpy(), 312 | batch_data['image_bi_post3'][:n].detach().cpu().numpy(), 313 | batch_data['image_bi_post4'][:n].detach().cpu().numpy(), 314 | batch_data['image_bi_post5'][:n].detach().cpu().numpy(), 315 | batch_data['image_bi_post6'][:n].detach().cpu().numpy(), 316 | batch_data['image_bi_post7'][:n].detach().cpu().numpy(), 317 | batch_data['image_bi_post8'][:n].detach().cpu().numpy(), 318 | batch_data['image_bi_post9'][:n].detach().cpu().numpy(), 319 | batch_data['resz_action_pre'][:n].detach().cpu().numpy(), 320 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(), 321 | batch_data['resz_action_post'][:n].detach().cpu().numpy(), 322 | batch_data['resz_action_post2'][:n].detach().cpu().numpy(), 323 | batch_data['resz_action_post3'][:n].detach().cpu().numpy(), 324 | batch_data['resz_action_post4'][:n].detach().cpu().numpy(), 325 | batch_data['resz_action_post5'][:n].detach().cpu().numpy(), 326 | batch_data['resz_action_post6'][:n].detach().cpu().numpy(), 327 | batch_data['resz_action_post7'][:n].detach().cpu().numpy(), 328 | batch_data['resz_action_post8'][:n].detach().cpu().numpy(), 329 | './result/{}/prediction_with_action_step{}/recon_GT_epoch_{}.png'.format(folder_name, step, batch_idx)) 330 | # Predicted 331 | plot_sample_multi_step(batch_data['image_bi_pre'][:n].detach().cpu().numpy(), 332 | recon_img_cur.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 333 | recon_img_post.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 334 | recon_img_post2.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 335 | recon_img_post3.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 336 | recon_img_post4.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 337 | recon_img_post5.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 338 | recon_img_post6.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 339 | recon_img_post7.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 340 | recon_img_post8.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 341 | recon_img_post9.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(), 342 | batch_data['resz_action_pre'][:n].detach().cpu().numpy(), 343 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(), 344 | batch_data['resz_action_post'][:n].detach().cpu().numpy(), 345 | batch_data['resz_action_post2'][:n].detach().cpu().numpy(), 346 | batch_data['resz_action_post3'][:n].detach().cpu().numpy(), 347 | batch_data['resz_action_post4'][:n].detach().cpu().numpy(), 348 | 
batch_data['resz_action_post5'][:n].detach().cpu().numpy(), 349 | batch_data['resz_action_post6'][:n].detach().cpu().numpy(), 350 | batch_data['resz_action_post7'][:n].detach().cpu().numpy(), 351 | batch_data['resz_action_post8'][:n].detach().cpu().numpy(), 352 | './result/{}/prediction_with_action_step{}/recon_Pred_epoch_{}.png'.format(folder_name, step, batch_idx)) 353 | 354 | print('***** Preparing Data *****') 355 | total_img_num = 22515 356 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num) 357 | action_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy' 358 | actions = np.load(action_path) 359 | dataset = MyDatasetMultiPred10(image_paths_bi, actions, transform=ToTensorMultiPred10()) 360 | dataloader = DataLoader(dataset, batch_size=32, 361 | shuffle=True, num_workers=4, collate_fn=my_collate) 362 | print('***** Finish Preparing Data *****') 363 | 364 | folder_name = 'test_E2C' 365 | PATH = './result/{}/checkpoint'.format(folder_name) 366 | 367 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 368 | e2c_model = E2C().to(device) 369 | 370 | 371 | # load check point 372 | print('***** Load Checkpoint *****') 373 | checkpoint = torch.load(PATH, map_location=torch.device('cpu')) 374 | e2c_model.load_state_dict(checkpoint['e2c_model_state_dict']) 375 | 376 | # prediction 377 | print('***** Start Prediction *****') 378 | step=10 # Change this based on different prediction steps 379 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name, step)): 380 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name, step)) 381 | if not os.path.exists('./result/{}/prediction_with_action_step{}'.format(folder_name, step)): 382 | os.makedirs('./result/{}/prediction_with_action_step{}'.format(folder_name, step)) 383 | predict() 384 | print('***** Finish Prediction *****') -------------------------------------------------------------------------------- /model/predict_e2c_our_method_compare_gpu.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | from torch.autograd import Variable 5 | from torch import nn, optim, sigmoid, tanh, relu 6 | from torch.utils.data import Dataset, DataLoader 7 | from torch.nn import functional as F 8 | from deform.model.create_dataset import * 9 | from deform.model.hidden_dynamics import * 10 | from deform.utils.utils import plot_sample_multi_step 11 | from torchvision.utils import save_image 12 | import os 13 | import math 14 | 15 | class CAE(nn.Module): 16 | def __init__(self, latent_state_dim=80, latent_act_dim=80): 17 | super(CAE, self).__init__() 18 | # state 19 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), 20 | nn.ReLU(), 21 | nn.MaxPool2d(3, stride=2), 22 | nn.Conv2d(32, 64, 3, padding=1), 23 | nn.ReLU(), 24 | nn.MaxPool2d(3, stride=2), 25 | nn.Conv2d(64, 128, 3, padding=1), 26 | nn.ReLU(), 27 | nn.MaxPool2d(3, stride=2), 28 | nn.Conv2d(128, 128, 3, padding=1), 29 | nn.ReLU(), 30 | nn.Conv2d(128, 128, 3, padding=1), 31 | nn.ReLU(), 32 | nn.MaxPool2d(3, stride=2, padding=1)) 33 | self.fc1 = nn.Linear(128*3*3, latent_state_dim) 34 | self.fc2 = nn.Linear(latent_state_dim, 128*3*3) 35 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 36 | nn.ReLU(), 37 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 38 | nn.ReLU(), 39 | nn.ConvTranspose2d(128, 64, 3, 
stride=2, padding=2), 40 | nn.ReLU(), 41 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2), 42 | nn.ReLU(), 43 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2), 44 | nn.Sigmoid()) 45 | # action 46 | self.fc5 = nn.Linear(4, latent_act_dim) 47 | self.fc6 = nn.Linear(latent_act_dim, latent_act_dim) 48 | self.fc7 = nn.Linear(latent_act_dim, latent_act_dim) 49 | self.fc8 = nn.Linear(latent_act_dim, 4) 50 | # add these in order to use GPU for parameters 51 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14]) 52 | self.add_tensor = torch.tensor([0, 0, 0, 0.01]) 53 | 54 | 55 | def encoder(self, x): 56 | x = self.conv_layers(x) 57 | x = x.view(x.shape[0], -1) 58 | return relu(self.fc1(x)) 59 | 60 | def decoder(self, x): 61 | x = relu(self.fc2(x)) 62 | x = x.view(-1, 128, 3, 3) 63 | return self.dconv_layers(x) 64 | 65 | def encoder_act(self, u): 66 | h1 = relu(self.fc5(u)) 67 | return relu(self.fc6(h1)) 68 | 69 | def decoder_act(self, u): 70 | h2 = relu(self.fc7(u)) 71 | if torch.cuda.is_available(): 72 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.cuda()) + self.add_tensor.cuda() 73 | else: 74 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor) + self.add_tensor 75 | 76 | 77 | def forward(self, x_cur, u, x_post): 78 | g_cur = self.encoder(x_cur) 79 | a = self.encoder_act(u) 80 | g_post = self.encoder(x_post) 81 | 82 | return g_cur, a, g_post, self.decoder(g_cur), self.decoder_act(a) 83 | 84 | class SysDynamics(nn.Module): 85 | def __init__(self, latent_state_dim=80, latent_act_dim=80): 86 | super(SysDynamics, self).__init__() 87 | self.conv_layers_matrix = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), 88 | nn.ReLU(), 89 | nn.MaxPool2d(3, stride=1), 90 | nn.Conv2d(32, 64, 3, padding=1), 91 | nn.ReLU(), 92 | nn.MaxPool2d(3, stride=1), 93 | nn.Conv2d(64, 128, 3, padding=1), 94 | nn.ReLU(), 95 | nn.MaxPool2d(3, stride=2), 96 | nn.Conv2d(128, 256, 3, padding=1), 97 | nn.ReLU(), 98 | nn.MaxPool2d(3, stride=2), 99 | nn.Conv2d(256, 512, 3, padding=1), 100 | nn.ReLU(), 101 | nn.MaxPool2d(3, stride=2), 102 | nn.Conv2d(512, 512, 3, padding=1), 103 | nn.ReLU(), 104 | nn.Conv2d(512, 512, 3, padding=1), 105 | nn.ReLU(), 106 | nn.MaxPool2d(3, stride=2, padding=1)) 107 | self.fc31 = nn.Linear(512*2*2, latent_state_dim*latent_state_dim) 108 | self.fc32 = nn.Linear(latent_state_dim*latent_state_dim, latent_state_dim*latent_state_dim) 109 | self.fc41 = nn.Linear(512*2*2 + latent_act_dim, latent_state_dim*latent_act_dim) 110 | self.fc42 = nn.Linear(latent_state_dim*latent_act_dim, latent_state_dim*latent_act_dim) 111 | self.fc9 = nn.Linear(4, latent_act_dim) 112 | self.fc10 = nn.Linear(latent_act_dim, latent_act_dim) 113 | 114 | self.latent_act_dim = latent_act_dim 115 | self.latent_state_dim = latent_state_dim 116 | 117 | def encoder_matrix(self, x, a): 118 | x = self.conv_layers_matrix(x) 119 | x = x.view(x.shape[0], -1) 120 | xa = torch.cat((x,a), 1) 121 | 122 | return relu(self.fc32(relu(self.fc31(x)))).view(-1, self.latent_state_dim, self.latent_state_dim), \ 123 | relu(self.fc42(relu(self.fc41(xa)))).view(-1, self.latent_act_dim, self.latent_state_dim) 124 | 125 | def forward(self, x_cur, u): 126 | a = relu(self.fc10(relu(self.fc9(u)))) 127 | K_T, L_T = self.encoder_matrix(x_cur, a) 128 | 129 | return K_T, L_T 130 | 131 | class NormalDistribution(object): 132 | """ 133 | Wrapper class representing a multivariate normal distribution parameterized by 134 | N(mu,Cov). If cov. matrix is diagonal, Cov=(sigma).^2. Otherwise, 135 | Cov=A*(sigma).^2*A', where A = (I+v*r^T). 
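(This class is duplicated from model/predict_e2c.py: v and r parameterize the rank-one covariance correction A = I + v*r^T that cov below applies on top of the diagonal term.)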
136 | """ 137 | 138 | def __init__(self, mu, sigma, logsigma, *, v=None, r=None): 139 | self.mu = mu 140 | self.sigma = sigma 141 | self.logsigma = logsigma 142 | self.v = v 143 | self.r = r 144 | 145 | @property 146 | def cov(self): 147 | """This should only be called when NormalDistribution represents one sample""" 148 | if self.v is not None and self.r is not None: 149 | assert self.v.dim() == 1 150 | dim = self.v.dim() 151 | v = self.v.unsqueeze(1) # D * 1 vector 152 | rt = self.r.unsqueeze(0) # 1 * D vector 153 | A = torch.eye(dim) + v.mm(rt) 154 | return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t())) 155 | else: 156 | return torch.diag(self.sigma.pow(2)) 157 | 158 | 159 | def KLDGaussian(Q, N, eps=1e-8): 160 | """KL Divergence between two Gaussians 161 | Assuming Q ~ N(mu0, A\sigma_0A') where A = I + vr^{T} 162 | and N ~ N(mu1, \sigma_1) 163 | """ 164 | sum = lambda x: torch.sum(x, dim=1) 165 | k = float(Q.mu.size()[1]) # dimension of distribution 166 | mu0, v, r, mu1 = Q.mu, Q.v, Q.r, N.mu 167 | s02, s12 = (Q.sigma).pow(2) + eps, (N.sigma).pow(2) + eps 168 | a = sum(s02 * (1. + 2. * v * r) / s12) + sum(v.pow(2) / s12) * sum(r.pow(2) * s02) # trace term 169 | b = sum((mu1 - mu0).pow(2) / s12) # difference-of-means term 170 | c = 2. * (sum(N.logsigma - Q.logsigma) - torch.log(1. + sum(v * r) + eps)) # ratio-of-determinants term. 171 | 172 | return 0.5 * (a + b - k + c) 173 | 174 | class E2C(nn.Module): 175 | def __init__(self, dim_z=80, dim_u=4): 176 | super(E2C, self).__init__() 177 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), 178 | nn.ReLU(), 179 | nn.MaxPool2d(3, stride=2), 180 | nn.Conv2d(32, 64, 3, padding=1), 181 | nn.ReLU(), 182 | nn.MaxPool2d(3, stride=2), 183 | nn.Conv2d(64, 128, 3, padding=1), 184 | nn.ReLU(), 185 | nn.MaxPool2d(3, stride=2), 186 | nn.Conv2d(128, 128, 3, padding=1), 187 | nn.ReLU(), 188 | nn.Conv2d(128, 128, 3, padding=1), 189 | nn.ReLU(), 190 | nn.MaxPool2d(3, stride=2, padding=1)) 191 | self.fc1 = nn.Linear(128*3*3, dim_z*2) 192 | self.fc2 = nn.Linear(dim_z, 128*3*3) 193 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 194 | nn.ReLU(), 195 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 196 | nn.ReLU(), 197 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2), 198 | nn.ReLU(), 199 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2), 200 | nn.ReLU(), 201 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2), 202 | nn.Sigmoid()) 203 | self.trans = nn.Sequential( 204 | nn.Linear(dim_z, 100), 205 | nn.BatchNorm1d(100), 206 | nn.ReLU(), 207 | nn.Linear(100, 100), 208 | nn.BatchNorm1d(100), 209 | nn.ReLU(), 210 | nn.Linear(100, dim_z*2) 211 | ) 212 | self.fc_B = nn.Linear(dim_z, dim_z * dim_u) 213 | self.fc_o = nn.Linear(dim_z, dim_z) 214 | self.dim_z = dim_z 215 | self.dim_u = dim_u 216 | # action 217 | self.fc5 = nn.Linear(4, dim_u*20) 218 | self.fc6 = nn.Linear(dim_u*20, dim_u) 219 | self.fc7 = nn.Linear(dim_u, dim_u*20) 220 | self.fc8 = nn.Linear(dim_u*20, 4) 221 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14]) 222 | self.add_tensor = torch.tensor([0, 0, 0, 0.01]) 223 | 224 | def encode(self, x): 225 | x = self.conv_layers(x) 226 | x = x.view(x.shape[0], -1) 227 | return relu(self.fc1(x)).chunk(2, dim=1) 228 | 229 | def decode(self, x): 230 | x = relu(self.fc2(x)) 231 | x = x.view(-1, 128, 3, 3) 232 | return self.dconv_layers(x) 233 | 234 | def encode_act(self, u): 235 | h1 = relu(self.fc5(u)) 236 | return relu(self.fc6(h1)) 237 | 238 | def decode_act(self, u): 239 | h2 = 
relu(self.fc7(u)) 240 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.cuda()) + self.add_tensor.cuda() 241 | 242 | def transition(self, h, Q, u): 243 | batch_size = h.size()[0] 244 | v, r = self.trans(h).chunk(2, dim=1) 245 | v1 = v.unsqueeze(2).cpu() 246 | rT = r.unsqueeze(1).cpu() 247 | I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1)) 248 | if rT.data.is_cuda: 249 | I.data.cuda() 250 | A = I.add(v1.bmm(rT)).cuda() 251 | 252 | B = self.fc_B(h).view(-1, self.dim_z, self.dim_u) 253 | o = self.fc_o(h).reshape((-1, self.dim_z, 1)) 254 | 255 | u = u.unsqueeze(2) 256 | 257 | d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2) 258 | sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2) 259 | return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r) 260 | 261 | def reparam(self, mean, logvar): 262 | std = logvar.mul(0.5).exp_() 263 | self.z_mean = mean 264 | self.z_sigma = std 265 | eps = torch.FloatTensor(std.size()).normal_() 266 | eps = Variable(eps) 267 | return eps.mul(std.cpu()).add_(mean.cpu()).cuda(), NormalDistribution(mean, std, torch.log(std)) 268 | 269 | def forward(self, x, action, x_next): 270 | mean, logvar = self.encode(x) 271 | mean_next, logvar_next = self.encode(x_next) 272 | 273 | z, self.Qz = self.reparam(mean, logvar) 274 | z_next, self.Qz_next = self.reparam(mean_next, logvar_next) 275 | 276 | self.x_dec = self.decode(z) 277 | self.x_next_dec = self.decode(z_next) 278 | 279 | latent_a = self.encode_act(action) 280 | action_dec = self.decode_act(latent_a) 281 | 282 | self.z_next_pred, self.Qz_next_pred = self.transition(z, self.Qz, latent_a) 283 | self.x_next_pred_dec = self.decode(self.z_next_pred) 284 | 285 | 286 | return self.x_dec, self.x_next_dec, self.x_next_pred_dec, self.Qz, self.Qz_next, self.Qz_next_pred, action_dec 287 | 288 | def latent_embeddings(self, x): 289 | return self.encode(x)[0] 290 | 291 | def predict(self, X, U): 292 | mean, logvar = self.encode(X) 293 | z, Qz = self.reparam(mean, logvar) 294 | z_next_pred, Qz_next_pred = self.transition(z, Qz, U) 295 | return self.decode(z_next_pred) 296 | 297 | def binary_crossentropy(t, o, eps=1e-8): 298 | return t * torch.log(o + eps) + (1.0 - t) * torch.log(1.0 - o + eps) 299 | 300 | def compute_loss(x_dec, x_next_pred_dec, x, x_next, 301 | Qz, Qz_next_pred, 302 | Qz_next): 303 | # Reconstruction losses 304 | if False: 305 | x_reconst_loss = (x_dec - x_next).pow(2).sum(dim=1) 306 | x_next_reconst_loss = (x_next_pred_dec - x_next).pow(2).sum(dim=1) 307 | else: 308 | x_reconst_loss = -binary_crossentropy(x, x_dec).sum(dim=1) 309 | x_next_reconst_loss = -binary_crossentropy(x_next, x_next_pred_dec).sum(dim=1) 310 | 311 | logvar = Qz.logsigma.mul(2) 312 | KLD_element = Qz.mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 313 | KLD = torch.sum(KLD_element, dim=1).mul(-0.5) 314 | 315 | # ELBO 316 | bound_loss = x_reconst_loss.add(x_next_reconst_loss).add(KLD.reshape(-1,1,1)) 317 | kl = KLDGaussian(Qz_next_pred, Qz_next) 318 | kl = kl[~torch.isnan(kl)] 319 | return bound_loss.mean(), kl.mean() 320 | 321 | def predict(): 322 | e2c_model.eval() 323 | recon_model.eval() 324 | dyn_model.eval() 325 | with torch.no_grad(): 326 | for batch_idx, batch_data in enumerate(dataloader): 327 | # order: img_pre -> act_pre -> img_cur -> act_cur -> img_post 328 | # previous image 329 | img_pre = batch_data['image_bi_pre'] 330 | img_pre = img_pre.float().to(device).view(-1, 1, 50, 50) 331 | # previous action 332 | act_pre = batch_data['resz_action_pre'] 333 | act_pre = 
act_pre.float().to(device).view(-1, 4) 334 | # current image 335 | img_cur = batch_data['image_bi_cur'] 336 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50) 337 | # current action 338 | act_cur = batch_data['resz_action_cur'] 339 | act_cur = act_cur.float().to(device).view(-1, 4) 340 | # post image 341 | img_post = batch_data['image_bi_post'] 342 | img_post = img_post.float().to(device).view(-1, 1, 50, 50) 343 | # post action 344 | act_post = batch_data['resz_action_post'] 345 | act_post = act_post.float().to(device).view(-1, 4) 346 | # post2 image 347 | img_post2 = batch_data['image_bi_post2'] 348 | img_post2 = img_post2.float().to(device).view(-1, 1, 50, 50) 349 | # post2 action 350 | act_post2 = batch_data['resz_action_post2'] 351 | act_post2 = act_post2.float().to(device).view(-1, 4) 352 | # post3 image 353 | img_post3 = batch_data['image_bi_post3'] 354 | img_post3 = img_post3.float().to(device).view(-1, 1, 50, 50) 355 | # post3 action 356 | act_post3 = batch_data['resz_action_post3'] 357 | act_post3 = act_post3.float().to(device).view(-1, 4) 358 | # post4 image 359 | img_post4 = batch_data['image_bi_post4'] 360 | img_post4 = img_post4.float().to(device).view(-1, 1, 50, 50) 361 | # post4 action 362 | act_post4 = batch_data['resz_action_post4'] 363 | act_post4 = act_post4.float().to(device).view(-1, 4) 364 | # post5 image 365 | img_post5 = batch_data['image_bi_post5'] 366 | img_post5 = img_post5.float().to(device).view(-1, 1, 50, 50) 367 | # post5 action 368 | act_post5 = batch_data['resz_action_post5'] 369 | act_post5 = act_post5.float().to(device).view(-1, 4) 370 | # post6 image 371 | img_post6 = batch_data['image_bi_post6'] 372 | img_post6 = img_post6.float().to(device).view(-1, 1, 50, 50) 373 | # post6 action 374 | act_post6 = batch_data['resz_action_post6'] 375 | act_post6 = act_post6.float().to(device).view(-1, 4) 376 | # post7 image 377 | img_post7 = batch_data['image_bi_post7'] 378 | img_post7 = img_post7.float().to(device).view(-1, 1, 50, 50) 379 | # post7 action 380 | act_post7 = batch_data['resz_action_post7'] 381 | act_post7 = act_post7.float().to(device).view(-1, 4) 382 | # post8 image 383 | img_post8 = batch_data['image_bi_post8'] 384 | img_post8 = img_post8.float().to(device).view(-1, 1, 50, 50) 385 | # post8 action 386 | act_post8 = batch_data['resz_action_post8'] 387 | act_post8 = act_post8.float().to(device).view(-1, 4) 388 | # post9 image 389 | img_post9 = batch_data['image_bi_post9'] 390 | img_post9 = img_post9.float().to(device).view(-1, 1, 50, 50) 391 | # ten step prediction 392 | # prediction for current image from pre image 393 | recon_img_cur = e2c_model.predict(img_pre, act_pre) 394 | # prediction for post image from pre image 395 | recon_img_post = e2c_model.predict(recon_img_cur, act_cur) 396 | # prediction for post2 image from pre image 397 | recon_img_post2 = e2c_model.predict(recon_img_post, act_post) 398 | # prediction for post3 image from pre image 399 | recon_img_post3 = e2c_model.predict(recon_img_post2, act_post2) 400 | # prediction for post4 image from pre image 401 | recon_img_post4 = e2c_model.predict(recon_img_post3, act_post3) 402 | # prediction for post5 image from pre image 403 | recon_img_post5 = e2c_model.predict(recon_img_post4, act_post4) 404 | # prediction for post6 image from pre image 405 | recon_img_post6 = e2c_model.predict(recon_img_post5, act_post5) 406 | # prediction for post7 image from pre image 407 | recon_img_post7 = e2c_model.predict(recon_img_post6, act_post6) 408 | # prediction for post8 image from pre image 409 | 
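# note: every step of this chain feeds the previous *predicted* image back into the model (closed-loop rollout), so reconstruction errors can accumulate over the ten steps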
recon_img_post8 = e2c_model.predict(recon_img_post7, act_post7) 410 | # prediction for post9 image from pre image 411 | recon_img_post9 = e2c_model.predict(recon_img_post8, act_post8) 412 | if batch_idx % 5 == 0: 413 | n = min(batch_data['image_bi_pre'].size(0), 1) 414 | comparison_GT = torch.cat([batch_data['image_bi_pre'][:n], 415 | batch_data['image_bi_cur'][:n], 416 | batch_data['image_bi_post'][:n], 417 | batch_data['image_bi_post2'][:n], 418 | batch_data['image_bi_post3'][:n], 419 | batch_data['image_bi_post4'][:n], 420 | batch_data['image_bi_post5'][:n], 421 | batch_data['image_bi_post6'][:n], 422 | batch_data['image_bi_post7'][:n], 423 | batch_data['image_bi_post8'][:n], 424 | batch_data['image_bi_post9'][:n]]) 425 | save_image(comparison_GT.cpu(), 426 | './result/{}/prediction_full_step{}/prediction_GT_batch{}.png'.format(folder_name_e2c, step, batch_idx), nrow=n) 427 | comparison_Pred = torch.cat([batch_data['image_bi_pre'][:n], 428 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], 429 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n], 430 | recon_img_post2.view(-1, 1, 50, 50).cpu()[:n], 431 | recon_img_post3.view(-1, 1, 50, 50).cpu()[:n], 432 | recon_img_post4.view(-1, 1, 50, 50).cpu()[:n], 433 | recon_img_post5.view(-1, 1, 50, 50).cpu()[:n], 434 | recon_img_post6.view(-1, 1, 50, 50).cpu()[:n], 435 | recon_img_post7.view(-1, 1, 50, 50).cpu()[:n], 436 | recon_img_post8.view(-1, 1, 50, 50).cpu()[:n], 437 | recon_img_post9.view(-1, 1, 50, 50).cpu()[:n]]) 438 | save_image(comparison_Pred.cpu(), 439 | './result/{}/prediction_full_step{}/prediction_Pred_batch{}.png'.format(folder_name_e2c, step, batch_idx), nrow=n) 440 | # ten step prediction 441 | # prediction for current image from pre image 442 | latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre, act_pre, img_cur) 443 | K_T_pre, L_T_pre = dyn_model(img_pre, act_pre) 444 | recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre) 445 | recon_img_cur = recon_model.decoder(recon_latent_img_cur) 446 | # prediction for post image from pre image 447 | _, latent_act_cur, _, _, _ = recon_model(img_cur, act_cur, img_post) 448 | K_T_cur, L_T_cur = dyn_model(recon_img_cur, act_cur) 449 | recon_latent_img_post = get_next_state_linear(recon_latent_img_cur, latent_act_cur, K_T_cur, L_T_cur) 450 | recon_img_post = recon_model.decoder(recon_latent_img_post) 451 | # prediction for post2 image from pre image 452 | _, latent_act_post, _, _, _ = recon_model(img_post, act_post, img_post2) 453 | K_T_post, L_T_post = dyn_model(recon_img_post, act_post) 454 | recon_latent_img_post2 = get_next_state_linear(recon_latent_img_post, latent_act_post, K_T_post, L_T_post) 455 | recon_img_post2 = recon_model.decoder(recon_latent_img_post2) 456 | # prediction for post3 image from pre image 457 | _, latent_act_post2, _, _, _ = recon_model(img_post2, act_post2, img_post3) 458 | K_T_post2, L_T_post2 = dyn_model(recon_img_post2, act_post2) 459 | recon_latent_img_post3 = get_next_state_linear(recon_latent_img_post2, latent_act_post2, K_T_post2, L_T_post2) 460 | recon_img_post3 = recon_model.decoder(recon_latent_img_post3) 461 | # prediction for post4 image from pre image 462 | _, latent_act_post3, _, _, _ = recon_model(img_post3, act_post3, img_post4) 463 | K_T_post3, L_T_post3 = dyn_model(recon_img_post3, act_post3) 464 | recon_latent_img_post4 = get_next_state_linear(recon_latent_img_post3, latent_act_post3, K_T_post3, L_T_post3) 465 | recon_img_post4 = recon_model.decoder(recon_latent_img_post4) 466 | # prediction for post5 
image from pre image 467 | _, latent_act_post4, _, _, _ = recon_model(img_post4, act_post4, img_post5) 468 | K_T_post4, L_T_post4 = dyn_model(recon_img_post4, act_post4) 469 | recon_latent_img_post5 = get_next_state_linear(recon_latent_img_post4, latent_act_post4, K_T_post4, L_T_post4) 470 | recon_img_post5 = recon_model.decoder(recon_latent_img_post5) 471 | # prediction for post6 image from pre image 472 | _, latent_act_post5, _, _, _ = recon_model(img_post5, act_post5, img_post6) 473 | K_T_post5, L_T_post5 = dyn_model(recon_img_post5, act_post5) 474 | recon_latent_img_post6 = get_next_state_linear(recon_latent_img_post5, latent_act_post5, K_T_post5, L_T_post5) 475 | recon_img_post6 = recon_model.decoder(recon_latent_img_post6) 476 | # prediction for post7 image from pre image 477 | _, latent_act_post6, _, _, _ = recon_model(img_post6, act_post6, img_post7) 478 | K_T_post6, L_T_post6 = dyn_model(recon_img_post6, act_post6) 479 | recon_latent_img_post7 = get_next_state_linear(recon_latent_img_post6, latent_act_post6, K_T_post6, L_T_post6) 480 | recon_img_post7 = recon_model.decoder(recon_latent_img_post7) 481 | # prediction for post8 image from pre image 482 | _, latent_act_post7, _, _, _ = recon_model(img_post7, act_post7, img_post8) 483 | K_T_post7, L_T_post7 = dyn_model(recon_img_post7, act_post7) 484 | recon_latent_img_post8 = get_next_state_linear(recon_latent_img_post7, latent_act_post7, K_T_post7, L_T_post7) 485 | recon_img_post8 = recon_model.decoder(recon_latent_img_post8) 486 | # prediction for post9 image from pre image 487 | _, latent_act_post8, _, _, _ = recon_model(img_post8, act_post8, img_post9) 488 | K_T_post8, L_T_post8 = dyn_model(recon_img_post8, act_post8) 489 | recon_latent_img_post9 = get_next_state_linear(recon_latent_img_post8, latent_act_post8, K_T_post8, L_T_post8) 490 | recon_img_post9 = recon_model.decoder(recon_latent_img_post9) 491 | if batch_idx % 5 == 0: 492 | n = min(batch_data['image_bi_pre'].size(0), 1) 493 | comparison_GT = torch.cat([batch_data['image_bi_pre'][:n], 494 | batch_data['image_bi_cur'][:n], 495 | batch_data['image_bi_post'][:n], 496 | batch_data['image_bi_post2'][:n], 497 | batch_data['image_bi_post3'][:n], 498 | batch_data['image_bi_post4'][:n], 499 | batch_data['image_bi_post5'][:n], 500 | batch_data['image_bi_post6'][:n], 501 | batch_data['image_bi_post7'][:n], 502 | batch_data['image_bi_post8'][:n], 503 | batch_data['image_bi_post9'][:n]]) 504 | save_image(comparison_GT.cpu(), 505 | './result/{}/prediction_full_step{}/prediction_GT_batch{}.png'.format(folder_name_our, step, batch_idx), nrow=n) 506 | comparison_Pred = torch.cat([batch_data['image_bi_pre'][:n], 507 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], 508 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n], 509 | recon_img_post2.view(-1, 1, 50, 50).cpu()[:n], 510 | recon_img_post3.view(-1, 1, 50, 50).cpu()[:n], 511 | recon_img_post4.view(-1, 1, 50, 50).cpu()[:n], 512 | recon_img_post5.view(-1, 1, 50, 50).cpu()[:n], 513 | recon_img_post6.view(-1, 1, 50, 50).cpu()[:n], 514 | recon_img_post7.view(-1, 1, 50, 50).cpu()[:n], 515 | recon_img_post8.view(-1, 1, 50, 50).cpu()[:n], 516 | recon_img_post9.view(-1, 1, 50, 50).cpu()[:n]]) 517 | save_image(comparison_Pred.cpu(), 518 | './result/{}/prediction_full_step{}/prediction_Pred_batch{}.png'.format(folder_name_our, step, batch_idx), nrow=n) 519 | 520 | print('***** Preparing Data *****') 521 | total_img_num = 22515 522 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num) 523 | action_path = 
'./rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy' 524 | actions = np.load(action_path) 525 | dataset = MyDatasetMultiPred10(image_paths_bi, actions, transform=ToTensorMultiPred10()) 526 | dataloader = DataLoader(dataset, batch_size=32, 527 | shuffle=True, num_workers=4, collate_fn=my_collate) 528 | print('***** Finish Preparing Data *****') 529 | 530 | folder_name_e2c = 'test_E2C_gpu_update_loss' 531 | PATH_e2c = './result/{}/checkpoint'.format(folder_name_e2c) 532 | folder_name_our = 'test_act80_pred30' 533 | PATH_our = './result/{}/checkpoint'.format(folder_name_our) 534 | 535 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 536 | e2c_model = E2C().to(device) 537 | recon_model = CAE().to(device) 538 | dyn_model = SysDynamics().to(device) 539 | 540 | # load check point 541 | print('***** Load Checkpoint *****') 542 | checkpoint_e2c = torch.load(PATH_e2c, map_location=torch.device('cpu')) 543 | e2c_model.load_state_dict(checkpoint_e2c['e2c_model_state_dict']) 544 | checkpoint_our = torch.load(PATH_our, map_location=torch.device('cpu')) 545 | recon_model.load_state_dict(checkpoint_our['recon_model_state_dict']) 546 | dyn_model.load_state_dict(checkpoint_our['dyn_model_state_dict']) 547 | 548 | # prediction 549 | print('***** Start Prediction *****') 550 | step=10 # Change this based on different prediction steps 551 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name_e2c, step)): 552 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name_e2c, step)) 553 | if not os.path.exists('./result/{}/prediction_with_action_step{}'.format(folder_name_e2c, step)): 554 | os.makedirs('./result/{}/prediction_with_action_step{}'.format(folder_name_e2c, step)) 555 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name_our, step)): 556 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name_our, step)) 557 | if not os.path.exists('./result/{}/prediction_with_action_step{}'.format(folder_name_our, step)): 558 | os.makedirs('./result/{}/prediction_with_action_step{}'.format(folder_name_our, step)) 559 | predict() 560 | print('***** Finish Prediction *****') 561 | -------------------------------------------------------------------------------- /model/predict_large_linear_seprt_no_loop_Kp_Lpa.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | from torch import nn, optim, sigmoid, tanh, relu 5 | from torch.utils.data import Dataset, DataLoader 6 | from torch.nn import functional as F 7 | from deform.model.create_dataset import * 8 | from deform.model.hidden_dynamics import * 9 | from torchvision.utils import save_image 10 | import os 11 | import math 12 | 13 | class CAE(nn.Module): 14 | def __init__(self, latent_state_dim=80, latent_act_dim=80): 15 | super(CAE, self).__init__() 16 | # state 17 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), 18 | nn.ReLU(), 19 | nn.MaxPool2d(3, stride=2), 20 | nn.Conv2d(32, 64, 3, padding=1), 21 | nn.ReLU(), 22 | nn.MaxPool2d(3, stride=2), 23 | nn.Conv2d(64, 128, 3, padding=1), 24 | nn.ReLU(), 25 | nn.MaxPool2d(3, stride=2), 26 | nn.Conv2d(128, 128, 3, padding=1), 27 | nn.ReLU(), 28 | nn.Conv2d(128, 128, 3, padding=1), 29 | nn.ReLU(), 30 | nn.MaxPool2d(3, stride=2, padding=1)) 31 | self.fc1 = nn.Linear(128*3*3, latent_state_dim) 32 | self.fc2 = nn.Linear(latent_state_dim, 128*3*3) 33 | self.dconv_layers 
= nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 34 | nn.ReLU(), 35 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1), 36 | nn.ReLU(), 37 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2), 38 | nn.ReLU(), 39 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2), 40 | nn.ReLU(), 41 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2), 42 | nn.Sigmoid()) 43 | # action 44 | self.fc5 = nn.Linear(4, latent_act_dim) 45 | self.fc6 = nn.Linear(latent_act_dim, latent_act_dim) 46 | self.fc7 = nn.Linear(latent_act_dim, latent_act_dim) 47 | self.fc8 = nn.Linear(latent_act_dim, 4) 48 | 49 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14]) 50 | self.add_tensor = torch.tensor([0, 0, 0, 0.01]) 51 | 52 | 53 | def encoder(self, x): 54 | x = self.conv_layers(x) 55 | x = x.view(x.shape[0], -1) 56 | return relu(self.fc1(x)) 57 | 58 | def decoder(self, x): 59 | x = relu(self.fc2(x)) 60 | x = x.view(-1, 128, 3, 3) 61 | return self.dconv_layers(x) 62 | 63 | def encoder_act(self, u): 64 | h1 = relu(self.fc5(u)) 65 | return relu(self.fc6(h1)) 66 | 67 | def decoder_act(self, u): 68 | h2 = relu(self.fc7(u)) 69 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.to(u.device)) + self.add_tensor.to(u.device) 70 | 71 | def forward(self, x_cur, u, x_post): 72 | g_cur = self.encoder(x_cur) 73 | a = self.encoder_act(u) 74 | g_post = self.encoder(x_post) 75 | 76 | return g_cur, a, g_post, self.decoder(g_cur), self.decoder_act(a) 77 | 78 | class SysDynamics(nn.Module): 79 | def __init__(self, latent_state_dim=80, latent_act_dim=80): 80 | super(SysDynamics, self).__init__() 81 | self.conv_layers_matrix = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), 82 | nn.ReLU(), 83 | nn.MaxPool2d(3, stride=1), 84 | nn.Conv2d(32, 64, 3, padding=1), 85 | nn.ReLU(), 86 | nn.MaxPool2d(3, stride=1), 87 | nn.Conv2d(64, 128, 3, padding=1), 88 | nn.ReLU(), 89 | nn.MaxPool2d(3, stride=2), 90 | nn.Conv2d(128, 256, 3, padding=1), 91 | nn.ReLU(), 92 | nn.MaxPool2d(3, stride=2), 93 | nn.Conv2d(256, 512, 3, padding=1), 94 | nn.ReLU(), 95 | nn.MaxPool2d(3, stride=2), 96 | nn.Conv2d(512, 512, 3, padding=1), 97 | nn.ReLU(), 98 | nn.Conv2d(512, 512, 3, padding=1), 99 | nn.ReLU(), 100 | nn.MaxPool2d(3, stride=2, padding=1)) 101 | self.fc31 = nn.Linear(512*2*2, latent_state_dim*latent_state_dim) 102 | self.fc32 = nn.Linear(latent_state_dim*latent_state_dim, latent_state_dim*latent_state_dim) 103 | self.fc41 = nn.Linear(512*2*2 + latent_act_dim, latent_state_dim*latent_act_dim) 104 | self.fc42 = nn.Linear(latent_state_dim*latent_act_dim, latent_state_dim*latent_act_dim) 105 | self.fc9 = nn.Linear(4, latent_act_dim) 106 | self.fc10 = nn.Linear(latent_act_dim, latent_act_dim) 107 | # latent dim 108 | self.latent_act_dim = latent_act_dim 109 | self.latent_state_dim = latent_state_dim 110 | 111 | def encoder_matrix(self, x, a): 112 | x = self.conv_layers_matrix(x) 113 | x = x.view(x.shape[0], -1) 114 | xa = torch.cat((x,a), 1) 115 | 116 | return relu(self.fc32(relu(self.fc31(x)))).view(-1, self.latent_state_dim, self.latent_state_dim), \ 117 | relu(self.fc42(relu(self.fc41(xa)))).view(-1, self.latent_act_dim, self.latent_state_dim) 118 | 119 | def forward(self, x_cur, u): 120 | a = relu(self.fc10(relu(self.fc9(u)))) 121 | K_T, L_T = self.encoder_matrix(x_cur, a) 122 | 123 | return K_T, L_T 124 | 125 | def predict(): 126 | recon_model.eval() 127 | dyn_model.eval() 128 | with torch.no_grad(): 129 | for batch_idx, batch_data in enumerate(dataloader): 130 | # order: img_pre -> act_pre -> img_cur -> act_cur -> img_post 131 | # previous 
image 132 | img_pre = batch_data['image_bi_pre'] 133 | img_pre = img_pre.float().to(device).view(-1, 1, 50, 50) 134 | # previous action 135 | act_pre = batch_data['resz_action_pre'] 136 | act_pre = act_pre.float().to(device).view(-1, 4) 137 | # current image 138 | img_cur = batch_data['image_bi_cur'] 139 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50) 140 | # current action 141 | act_cur = batch_data['resz_action_cur'] 142 | act_cur = act_cur.float().to(device).view(-1, 4) 143 | # post image 144 | img_post = batch_data['image_bi_post'] 145 | img_post = img_post.float().to(device).view(-1, 1, 50, 50) 146 | # prediction for current image 147 | latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre, act_pre, img_cur) 148 | K_T_pre, L_T_pre = dyn_model(img_pre, act_pre) 149 | recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre) 150 | recon_img_cur = recon_model.decoder(recon_latent_img_cur) 151 | # prediction for post image 152 | latent_img_cur, latent_act_cur, _, _, _ = recon_model(img_cur, act_cur, img_post) 153 | K_T_cur, L_T_cur = dyn_model(img_cur, act_cur) 154 | recon_latent_img_post = get_next_state_linear(latent_img_cur, latent_act_cur, K_T_cur, L_T_cur) 155 | recon_img_post = recon_model.decoder(recon_latent_img_post) 156 | if batch_idx % 10 == 0: 157 | n = min(batch_data['image_bi_pre'].size(0), 8) 158 | comparison = torch.cat([batch_data['image_bi_pre'][:n], 159 | batch_data['image_bi_cur'][:n], 160 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], 161 | batch_data['image_bi_post'][:n], 162 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n]]) 163 | save_image(comparison.cpu(), 164 | './result/{}/prediction_full_step{}/prediction_batch{}.png'.format(folder_name, step, batch_idx), nrow=n) 165 | 166 | 167 | print('***** Preparing Data *****') 168 | total_img_num = 22515 169 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num) 170 | action_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy' 171 | actions = np.load(action_path) 172 | dataset = MyDataset(image_paths_bi, actions, transform=ToTensor()) 173 | dataloader = DataLoader(dataset, batch_size=64, 174 | shuffle=True, num_workers=4, collate_fn=my_collate) 175 | print('***** Finish Preparing Data *****') 176 | 177 | folder_name = 'test_act80_pred50' 178 | PATH = './result/{}/checkpoint'.format(folder_name) 179 | 180 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 181 | recon_model = CAE().to(device) 182 | dyn_model = SysDynamics().to(device) 183 | 184 | # load check point 185 | print('***** Load Checkpoint *****') 186 | checkpoint = torch.load(PATH, map_location=torch.device('cpu')) 187 | recon_model.load_state_dict(checkpoint['recon_model_state_dict']) 188 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict']) 189 | 190 | # prediction 191 | print('***** Start Prediction *****') 192 | step=1 193 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name, step)): 194 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name, step)) 195 | predict() 196 | print('***** Finish Prediction *****') -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__init__.py 
-------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/params.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__pycache__/params.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import matplotlib.pyplot as plt 4 | import os 5 | from matplotlib.lines import Line2D 6 | 7 | def plot_grad_flow(named_parameters, folder_name): 8 | '''Plots the gradients flowing through different layers in the net during training. 9 | Can be used for checking for possible gradient vanishing / exploding problems. 
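The plot shows the mean absolute gradient per non-bias parameter; call it after loss.backward() while gradients are still populated (move them to the CPU first when training on a GPU, since matplotlib cannot plot CUDA tensors).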
10 | 11 | Usage: Plug this function in Trainer class after loss.backwards() as 12 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow 13 | 14 | Source: https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10 15 | ''' 16 | ave_grads = [] 17 | layers = [] 18 | for n, p in named_parameters: 19 | if(p.requires_grad) and ("bias" not in n): 20 | layers.append(n) 21 | ave_grads.append(p.grad.abs().mean()) 22 | plt.plot(ave_grads, alpha=0.3, color="b") 23 | plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k" ) 24 | plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical") 25 | plt.xlim(xmin=0, xmax=len(ave_grads)) 26 | plt.xlabel("Layers") 27 | plt.ylabel("average gradient") 28 | plt.title("Gradient flow") 29 | plt.grid(True) 30 | plt.savefig('./result/{}/plot/gradient_flow.png'.format(folder_name)) 31 | plt.close() 32 | 33 | def create_loss_list(loss_logger, kld=True): 34 | if loss_logger is None: 35 | train_loss_all = [] 36 | train_img_loss_all = [] 37 | train_act_loss_all = [] 38 | train_latent_loss_all = [] 39 | train_pred_loss_all = [] 40 | train_kld_loss_all = [] 41 | test_loss_all = [] 42 | test_img_loss_all = [] 43 | test_act_loss_all = [] 44 | test_latent_loss_all = [] 45 | test_pred_loss_all = [] 46 | test_kld_loss_all = [] 47 | else: 48 | train_loss_all = loss_logger['train_loss_all'] 49 | train_img_loss_all = loss_logger['train_img_loss_all'] 50 | train_act_loss_all = loss_logger['train_act_loss_all'] 51 | train_latent_loss_all = loss_logger['train_latent_loss_all'] 52 | train_pred_loss_all = loss_logger['train_pred_loss_all'] 53 | test_loss_all = loss_logger['test_loss_all'] 54 | test_img_loss_all = loss_logger['test_img_loss_all'] 55 | test_act_loss_all = loss_logger['test_act_loss_all'] 56 | test_latent_loss_all = loss_logger['test_latent_loss_all'] 57 | test_pred_loss_all = loss_logger['test_pred_loss_all'] 58 | if kld is True: 59 | train_kld_loss_all = loss_logger['train_kld_loss_all'] 60 | test_kld_loss_all = loss_logger['test_kld_loss_all'] 61 | else: 62 | train_kld_loss_all = [] 63 | test_kld_loss_all = [] 64 | return train_loss_all, train_img_loss_all, train_act_loss_all, train_latent_loss_all, train_pred_loss_all, train_kld_loss_all, \ 65 | test_loss_all, test_img_loss_all, test_act_loss_all, test_latent_loss_all, test_pred_loss_all, test_kld_loss_all 66 | 67 | def create_folder(folder_name): 68 | if not os.path.exists('./result/' + folder_name): 69 | os.makedirs('./result/' + folder_name) 70 | if not os.path.exists('./result/' + folder_name + '/plot'): 71 | os.makedirs('./result/' + folder_name + '/plot') 72 | if not os.path.exists('./result/' + folder_name + '/reconstruction_test'): 73 | os.makedirs('./result/' + folder_name + '/reconstruction_test') 74 | if not os.path.exists('./result/' + folder_name + '/reconstruction_train'): 75 | os.makedirs('./result/' + folder_name + '/reconstruction_train') 76 | if not os.path.exists('./result/' + folder_name + '/reconstruction_act_train'): 77 | os.makedirs('./result/' + folder_name + '/reconstruction_act_train') 78 | if not os.path.exists('./result/' + folder_name + '/reconstruction_act_test'): 79 | os.makedirs('./result/' + folder_name + '/reconstruction_act_test') 80 | 81 | def rect(poke, c, label=None): 82 | x, y, t, l = poke 83 | dx = -100 * l * math.cos(t) 84 | dy = -100 * l * math.sin(t) 85 | arrow = plt.arrow(x, y, dx, dy, width=0.001, head_width=6, head_length=6, color=c, label=label) 86 | 87 | 88 | def plot_action(resz_action, recon_action, directory): 89 | 
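"""Plot a single ground-truth action (blue arrow, left) beside its reconstructed action (red arrow, right) and save the figure to directory."""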
plt.figure() 90 | # left: original action 91 | plt.subplot(1, 2, 1) 92 | rect(resz_action, "blue") 93 | plt.axis('off') 94 | # right: reconstructed action 95 | plt.subplot(1, 2, 2) 96 | rect(recon_action, "red") 97 | plt.axis('off') 98 | plt.savefig(directory) 99 | plt.close() 100 | 101 | def change_img_dim(img): 102 | img_tmp = np.ones((img.shape[1], img.shape[2], img.shape[0])) 103 | img_tmp[:,:,0] = img[0] 104 | img_tmp[:,:,1] = img[1] 105 | img_tmp[:,:,2] = img[2] 106 | return img_tmp 107 | 108 | def plot_sample(img_before, img_after, resz_action, recon_action, directory): 109 | plt.figure() 110 | N = int(img_before.shape[0]) 111 | for i in range(N): 112 | # upper row original 113 | plt.subplot(3, N, i+1) 114 | rect(resz_action[i], "blue") 115 | plt.imshow(change_img_dim(img_before[i])) 116 | plt.axis('off') 117 | # middle row reconstruction 118 | plt.subplot(3, N, i+1+N) 119 | rect(recon_action[i], "red") 120 | plt.imshow(change_img_dim(img_before[i])) 121 | plt.axis('off') 122 | # lower row: next image after action 123 | plt.subplot(3, N, i+1+2*N) 124 | plt.imshow(change_img_dim(img_after[i])) 125 | plt.axis('off') 126 | plt.savefig(directory) 127 | plt.close() 128 | 129 | def plot_cem_sample(img_before, img_after, img_after_pred, resz_action, recon_action, directory): 130 | # source from rope.ipynb in Berkeley's rope dataset file 131 | plt.figure() 132 | # top left: ground truth action on the current image 133 | plt.subplot(2, 2, 1) 134 | rect(resz_action, "blue", "Ground Truth Action") 135 | plt.imshow(img_before.reshape((50,50,-1)), cmap='gray') 136 | plt.axis('off') 137 | # top right: ground truth next image 138 | plt.subplot(2, 2, 2) 139 | plt.imshow(img_after.reshape((50,50,-1)), cmap='gray') 140 | plt.axis('off') 141 | # bottom left: sampled action on the current image 142 | plt.subplot(2, 2, 3) 143 | rect(recon_action, "red", "Sampled Action") 144 | plt.imshow(img_before.reshape((50,50,-1)), cmap='gray') 145 | plt.axis('off') 146 | # bottom right: predicted next image 147 | plt.subplot(2, 2, 4) 148 | plt.imshow(img_after_pred.reshape((50,50,-1)), cmap='gray') 149 | plt.axis('off') 150 | plt.savefig(directory) 151 | plt.close() 152 | 153 | def plot_sample_multi_step(img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, act1, act2,\ 154 | act3, act4, act5, act6, act7, act8, act9, act10, directory): 155 | # multi-step prediction with action on the image 156 | plt.figure() 157 | N = int(img1.shape[0]) 158 | plt.subplots_adjust(wspace=0, hspace=0) 159 | for i in range(N): 160 | # img1, act1 161 | plt.subplot(11, N, i+1) 162 | rect(act1[i], "red") 163 | plt.imshow(img1[i].reshape((50,50,-1)), cmap='binary') 164 | plt.axis('off') 165 | # img2, act2 166 | plt.subplot(11, N, i+1+N) 167 | rect(act2[i], "red") 168 | plt.imshow(img2[i].reshape((50,50,-1)), cmap='binary') 169 | plt.axis('off') 170 | # img3, act3 171 | plt.subplot(11, N, i+1+2*N) 172 | rect(act3[i], "red") 173 | plt.imshow(img3[i].reshape((50,50,-1)), cmap='binary') 174 | plt.axis('off') 175 | # img4, act4 176 | plt.subplot(11, N, i+1+3*N) 177 | rect(act4[i], "red") 178 | plt.imshow(img4[i].reshape((50,50,-1)), cmap='binary') 179 | plt.axis('off') 180 | # img5, act5 181 | plt.subplot(11, N, i+1+4*N) 182 | rect(act5[i], "red") 183 | plt.imshow(img5[i].reshape((50,50,-1)), cmap='binary') 184 | plt.axis('off') 185 | # img6, act6 186 | plt.subplot(11, N, i+1+5*N) 187 | rect(act6[i], "red") 188 | plt.imshow(img6[i].reshape((50,50,-1)), cmap='binary') 189 | plt.axis('off') 190 | # img7, act7 191 | plt.subplot(11, N, i+1+6*N) 192 | rect(act7[i], "red") 193 | 
plt.imshow(img7[i].reshape((50,50,-1)), cmap='binary') 194 | plt.axis('off') 195 | # img8, act8 196 | plt.subplot(11, N, i+1+7*N) 197 | rect(act8[i], "red") 198 | plt.imshow(img8[i].reshape((50,50,-1)), cmap='binary') 199 | plt.axis('off') 200 | # img9, act9 201 | plt.subplot(11, N, i+1+8*N) 202 | rect(act9[i], "red") 203 | plt.imshow(img9[i].reshape((50,50,-1)), cmap='binary') 204 | plt.axis('off') 205 | # img10, act10 206 | plt.subplot(11, N, i+1+9*N) 207 | rect(act10[i], "red") 208 | plt.imshow(img10[i].reshape((50,50,-1)), cmap='binary') 209 | plt.axis('off') 210 | # img11 (final image, no action drawn) 211 | plt.subplot(11, N, i+1+10*N) 212 | plt.imshow(img11[i].reshape((50,50,-1)), cmap='binary') 213 | plt.axis('off') 214 | plt.savefig(directory) 215 | plt.close() 216 | 217 | def generate_initial_points(x, y, num_points, link_length): 218 | """generate initial points for a line 219 | x: the first (start from left) point's x position 220 | y: the first (start from left) point's y position 221 | num_points: number of points on a line 222 | link_length: each segment's length 223 | """ 224 | x_all = [x] 225 | y_all = [y] 226 | for _ in range(num_points-1): 227 | phi = np.random.uniform(-np.pi/10, np.pi/10) 228 | x1, y1 = x + link_length * np.cos(phi), y + link_length * np.sin(phi) 229 | x_all.append(x1) 230 | y_all.append(y1) 231 | x, y = x1, y1 232 | 233 | return x_all, y_all 234 | 235 | def generate_new_line(line_x_all, line_y_all, index, move_angle, move_length, link_length): 236 | ### 237 | # line_x_all: x position for all points on a line 238 | # i.e. np.array([x1, x2, x3,..., xn]) 239 | # line_y_all: y position for all points on a line 240 | # i.e. np.array([y1, y2, y3,..., yn]) 241 | # index: touching point's index in the line. Scalar. 242 | # grip_pos_before: gripper's position before moving in 2D 243 | # assume it is the same as the touching position on the line 244 | # i.e. np.array([x, y]) 245 | # grip_pos_after: gripper's position after moving in 2D 246 | # move_angle: gripper's moving angle. [0, 2*pi] 247 | # move_length: gripper's moving length. Scalar. 248 | # link_length: constant distance between two nearby points. Scalar. 249 | # num_points: number of points on the line. Scalar. 250 | # action: relative x and y position change. 251 | # i.e. 
def generate_new_line(line_x_all, line_y_all, index, move_angle, move_length, link_length):
    """Move one point of the line with the gripper and drag the rest along.

    line_x_all: x positions of all points on the line, i.e. np.array([x1, x2, ..., xn])
    line_y_all: y positions of all points on the line, i.e. np.array([y1, y2, ..., yn])
    index: the touched (grasped) point's index on the line; scalar
    move_angle: gripper's moving angle in [0, 2*pi]
    move_length: gripper's moving length; scalar
    link_length: constant distance between two neighboring points; scalar

    The gripper position before moving is assumed to coincide with the touched
    point's position; the action is the relative position change
    np.array([delta_x, delta_y]).
    """
    num_points = np.size(line_x_all)
    # initialize new_line_x_all and new_line_y_all
    new_line_x_all = [0] * num_points
    new_line_y_all = [0] * num_points

    # action
    action = get_action(move_angle, move_length)

    # gripper position (touched position on the line) before moving
    grip_pos_before = np.array([line_x_all[index], line_y_all[index]])

    # gripper position after moving
    grip_pos_after = get_pos_after(grip_pos_before, action)
    new_line_x_all[index] = grip_pos_after[0]
    new_line_y_all[index] = grip_pos_after[1]

    # move the points on the left side in order, from the touched point outward
    if index != 0:
        grip_pos_after_temp = grip_pos_after
        for i in range(index):
            new_index_left = index - (i + 1)
            moved_pos_before = np.array([line_x_all[new_index_left], line_y_all[new_index_left]])
            moved_pos_after = generate_new_point_pos_on_the_line(grip_pos_after_temp, moved_pos_before, link_length)
            grip_pos_after_temp = moved_pos_after
            new_line_x_all[new_index_left] = grip_pos_after_temp[0]
            new_line_y_all[new_index_left] = grip_pos_after_temp[1]

    # move the points on the right side in order, from the touched point outward
    if index != (num_points - 1):
        grip_pos_after_temp = grip_pos_after
        for j in range(num_points - index - 1):
            new_index_right = index + (j + 1)
            moved_pos_before = np.array([line_x_all[new_index_right], line_y_all[new_index_right]])
            moved_pos_after = generate_new_point_pos_on_the_line(grip_pos_after_temp, moved_pos_before, link_length)
            grip_pos_after_temp = moved_pos_after
            new_line_x_all[new_index_right] = grip_pos_after_temp[0]
            new_line_y_all[new_index_right] = grip_pos_after_temp[1]

    return new_line_x_all, new_line_y_all

def generate_new_point_pos_on_the_line(grip_pos_after, moved_pos_before, link_length):
    # place the neighboring point at link_length from the already-moved point,
    # along the direction from the moved point toward the neighbor's old position;
    # the argument order of these two points cannot be swapped
    delta_x = moved_pos_before[0] - grip_pos_after[0]
    delta_y = moved_pos_before[1] - grip_pos_after[1]
    angle = math.atan2(delta_y, delta_x)

    # get the neighboring point's position after moving
    x_after = grip_pos_after[0] + link_length * np.cos(angle)
    y_after = grip_pos_after[1] + link_length * np.sin(angle)
    moved_pos_after = np.array([x_after, y_after])

    return moved_pos_after

def get_action(angle, length):
    # convert a polar move (angle, length) into a Cartesian displacement
    action = np.array([length * np.cos(angle), length * np.sin(angle)])

    return action

def get_pos_after(grip_pos_before, action):
    # apply a relative displacement to a 2D position
    x, y = grip_pos_before[0], grip_pos_before[1]
    pos_after = np.array([x + action[0], y + action[1]])

    return pos_after

def get_line_index(gripper_x_pos, x_all):
    # index of the segment whose left endpoint is the rightmost point
    # with x position <= gripper_x_pos
    line_index = sum(gripper_x_pos >= np.array(x_all)) - 1

    return line_index
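# Illustrative sketch with made-up values (not from the original code):
# simulate a single gripper push on a generated line and check that the
# chain update above keeps every segment at exactly link_length.
def _demo_generate_new_line():
    link_length = 0.05
    x_all, y_all = generate_initial_points(0.0, 0.0, num_points=10, link_length=link_length)
    new_x, new_y = generate_new_line(np.array(x_all), np.array(y_all),
                                     index=4, move_angle=np.pi/2,
                                     move_length=0.1, link_length=link_length)
    for k in range(9):
        seg = np.hypot(new_x[k+1] - new_x[k], new_y[k+1] - new_y[k])
        assert abs(seg - link_length) < 1e-9
    return new_x, new_y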
def _plot_loss_curve(file_name, title, save_path):
    """Shared helper: load a saved loss array and plot it over epochs."""
    loss = np.load(file_name)
    plt.figure()
    plt.plot(loss)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.savefig(save_path)
    plt.close()

def plot_train_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train Loss',
                     './result/{}/plot/train_loss.png'.format(folder_name))

def plot_test_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test Loss',
                     './result/{}/plot/test_loss.png'.format(folder_name))

def plot_train_img_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train Image Loss',
                     './result/{}/plot/train_image_loss.png'.format(folder_name))

def plot_train_act_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train Action Loss',
                     './result/{}/plot/train_action_loss.png'.format(folder_name))

def plot_train_latent_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train Latent Loss',
                     './result/{}/plot/train_latent_loss.png'.format(folder_name))

def plot_train_pred_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train Prediction Loss',
                     './result/{}/plot/train_prediction_loss.png'.format(folder_name))

def plot_train_kld_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train KL Divergence Loss',
                     './result/{}/plot/train_kld_loss.png'.format(folder_name))

def plot_test_img_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test Image Loss',
                     './result/{}/plot/test_image_loss.png'.format(folder_name))

def plot_test_act_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test Action Loss',
                     './result/{}/plot/test_action_loss.png'.format(folder_name))

def plot_test_latent_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test Latent Loss',
                     './result/{}/plot/test_latent_loss.png'.format(folder_name))

def plot_test_pred_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test Prediction Loss',
                     './result/{}/plot/test_prediction_loss.png'.format(folder_name))

def plot_test_kld_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test KL Divergence Loss',
                     './result/{}/plot/test_kld_loss.png'.format(folder_name))
def plot_train_bound_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train Bound Loss',
                     './result/{}/plot/train_bound_loss.png'.format(folder_name))

def plot_test_bound_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test Bound Loss',
                     './result/{}/plot/test_bound_loss.png'.format(folder_name))

def plot_train_kl_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Train KL Loss',
                     './result/{}/plot/train_kl_loss.png'.format(folder_name))

def plot_test_kl_loss(file_name, folder_name):
    _plot_loss_curve(file_name, 'Test KL Loss',
                     './result/{}/plot/test_kl_loss.png'.format(folder_name))

def plot_all_train_loss_with_noise(train, test, img, act, latent, pred, kld, folder_name):
    curves = [(train, 'Train'), (test, 'Test'), (img, 'Image'), (act, 'Action'),
              (latent, 'Latent'), (pred, 'Prediction'), (kld, 'KL Divergence')]
    plt.figure()
    for file_name, label in curves:
        plt.plot(np.load(file_name), label=label)
    plt.title('Train loss and its subcomponents')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('./result/{}/plot/all_train_loss.png'.format(folder_name))
    plt.close()

def plot_all_test_loss_with_noise(test, img, act, latent, pred, kld, folder_name):
    curves = [(test, 'Test'), (img, 'Image'), (act, 'Action'),
              (latent, 'Latent'), (pred, 'Prediction'), (kld, 'KL Divergence')]
    plt.figure()
    for file_name, label in curves:
        plt.plot(np.load(file_name), label=label)
    plt.title('Test loss and its subcomponents')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('./result/{}/plot/all_test_loss.png'.format(folder_name))
    plt.close()
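# Illustrative sketch (not part of the training pipeline): save a dummy loss
# array and render it with the helpers above; assumes the ./result/<folder>/plot
# directory convention used throughout this module.
def _demo_plot_train_loss(folder_name='demo'):
    import os
    os.makedirs('./result/{}/plot'.format(folder_name), exist_ok=True)
    np.save('./result/{}/train_loss_epoch1.npy'.format(folder_name), np.random.rand(10))
    plot_train_loss('./result/{}/train_loss_epoch1.npy'.format(folder_name), folder_name)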
def plot_all_train_loss_without_noise(train, test, img, act, latent, pred, folder_name):
    curves = [(train, 'Train'), (test, 'Test'), (img, 'Image'), (act, 'Action'),
              (latent, 'Latent'), (pred, 'Prediction')]
    plt.figure()
    for file_name, label in curves:
        plt.plot(np.load(file_name), label=label)
    plt.title('Train loss and its subcomponents')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('./result/{}/plot/all_train_loss.png'.format(folder_name))
    plt.close()

def plot_all_test_loss_without_noise(test, img, act, latent, pred, folder_name):
    curves = [(test, 'Test'), (img, 'Image'), (act, 'Action'),
              (latent, 'Latent'), (pred, 'Prediction')]
    plt.figure()
    for file_name, label in curves:
        plt.plot(np.load(file_name), label=label)
    plt.title('Test loss and its subcomponents')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('./result/{}/plot/all_test_loss.png'.format(folder_name))
    plt.close()

def save_data(folder_name, epochs, train_loss_all, train_img_loss_all, train_act_loss_all,
              train_latent_loss_all, train_pred_loss_all,
              test_loss_all, test_img_loss_all, test_act_loss_all, test_latent_loss_all,
              test_pred_loss_all, train_kld_loss_all=None, test_kld_loss_all=None, K=None, L=None):
    """Save all train/test loss curves (and optionally the KLD losses and the
    Koopman/control matrices) under ./result/<folder_name>/."""
    np.save('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), train_loss_all)
    np.save('./result/{}/train_img_loss_epoch{}.npy'.format(folder_name, epochs), train_img_loss_all)
    np.save('./result/{}/train_act_loss_epoch{}.npy'.format(folder_name, epochs), train_act_loss_all)
    np.save('./result/{}/train_latent_loss_epoch{}.npy'.format(folder_name, epochs), train_latent_loss_all)
    np.save('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), train_pred_loss_all)
    np.save('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), test_loss_all)
    np.save('./result/{}/test_img_loss_epoch{}.npy'.format(folder_name, epochs), test_img_loss_all)
    np.save('./result/{}/test_act_loss_epoch{}.npy'.format(folder_name, epochs), test_act_loss_all)
    np.save('./result/{}/test_latent_loss_epoch{}.npy'.format(folder_name, epochs), test_latent_loss_all)
    np.save('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), test_pred_loss_all)
    if train_kld_loss_all is not None:
        np.save('./result/{}/train_kld_loss_epoch{}.npy'.format(folder_name, epochs), train_kld_loss_all)
    if test_kld_loss_all is not None:
        np.save('./result/{}/test_kld_loss_epoch{}.npy'.format(folder_name, epochs), test_kld_loss_all)
    if K is not None:
        np.save('./result/{}/koopman_matrix.npy'.format(folder_name), K)
    if L is not None:
        np.save('./result/{}/control_matrix.npy'.format(folder_name), L)
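# Illustrative sketch with dummy arrays (not from the original training
# script): persist every loss curve at once, assuming the ./result/<folder>
# directory convention above; in real use K and L would be the learned
# Koopman and control matrices.
def _demo_save_data(folder_name='demo', epochs=1):
    import os
    os.makedirs('./result/{}'.format(folder_name), exist_ok=True)
    dummy = np.zeros(epochs)
    save_data(folder_name, epochs, dummy, dummy, dummy, dummy, dummy,
              dummy, dummy, dummy, dummy, dummy)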
def save_e2c_data(folder_name, epochs, train_loss_all, train_bound_loss_all, train_kl_loss_all, train_pred_loss_all,
                  test_loss_all, test_bound_loss_all, test_kl_loss_all, test_pred_loss_all):
    np.save('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), train_loss_all)
    np.save('./result/{}/train_bound_loss_epoch{}.npy'.format(folder_name, epochs), train_bound_loss_all)
    np.save('./result/{}/train_kl_loss_epoch{}.npy'.format(folder_name, epochs), train_kl_loss_all)
    np.save('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), train_pred_loss_all)
    np.save('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), test_loss_all)
    np.save('./result/{}/test_bound_loss_epoch{}.npy'.format(folder_name, epochs), test_bound_loss_all)
    np.save('./result/{}/test_kl_loss_epoch{}.npy'.format(folder_name, epochs), test_kl_loss_all)
    np.save('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), test_pred_loss_all)
--------------------------------------------------------------------------------