├── README.md
├── __init__.py
├── image
│   ├── CEM_result.png
│   └── multi_step_pred_result.png
├── model
│   ├── README.md
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── cem.cpython-36.pyc
│   │   ├── cem.cpython-37.pyc
│   │   ├── cem.cpython-38.pyc
│   │   ├── cem_e2c.cpython-37.pyc
│   │   ├── cem_gpu.cpython-38.pyc
│   │   ├── cem_plot_curve.cpython-37.pyc
│   │   ├── configs.cpython-36.pyc
│   │   ├── configs.cpython-37.pyc
│   │   ├── create_dataset.cpython-36.pyc
│   │   ├── create_dataset.cpython-37.pyc
│   │   ├── create_dataset.cpython-38.pyc
│   │   ├── dl_model.cpython-36.pyc
│   │   ├── dl_model.cpython-37.pyc
│   │   ├── dl_model.cpython-38.pyc
│   │   ├── dl_model_gpu.cpython-38.pyc
│   │   ├── e2c.cpython-37.pyc
│   │   ├── hidden_dynamics.cpython-36.pyc
│   │   ├── hidden_dynamics.cpython-37.pyc
│   │   ├── hidden_dynamics.cpython-38.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── nn.cpython-36.pyc
│   │   ├── nn.cpython-37.pyc
│   │   ├── nn_linear.cpython-36.pyc
│   │   ├── nn_linear.cpython-37.pyc
│   │   ├── predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred10.cpython-36.pyc
│   │   ├── predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred10.cpython-37.pyc
│   │   └── predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred4.cpython-36.pyc
│   ├── cem.py
│   ├── cem_e2c.py
│   ├── cem_plot_action.py
│   ├── cem_plot_curve.py
│   ├── create_dataset.py
│   ├── e2c.py
│   ├── hidden_dynamics.py
│   ├── image
│   │   ├── action_model.pdf
│   │   ├── action_model.png
│   │   ├── dynamics_model.pdf
│   │   ├── dynamics_model.png
│   │   ├── state_model.pdf
│   │   └── state_model.png
│   ├── nn_large_linear_seprt_no_loop_Kp_Lpa.py
│   ├── predict_e2c.py
│   ├── predict_e2c_our_method_compare_gpu.py
│   └── predict_large_linear_seprt_no_loop_Kp_Lpa.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── params.cpython-36.pyc
    │   ├── utils.cpython-36.pyc
    │   ├── utils.cpython-37.pyc
    │   └── utils.cpython-38.pyc
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Deformable Linear Object Prediction Using Locally Linear Latent Dynamics
2 |
3 | ## Introduction
4 | This repository contains the code for the paper - Deformable Linear Object Prediction Using Locally Linear Latent Dynamics.
5 |
6 | Access the [Paper](https://arxiv.org/pdf/2103.14184.pdf) on arXiv.
7 | Authors: [Wenbo Zhang](https://www.linkedin.com/in/wenbo-zhang6/), [Karl Schmeckpeper](https://sites.google.com/view/karlschmeckpeper), [Pratik Chaudhari](https://pratikac.github.io/), [Kostas Daniilidis](https://www.cis.upenn.edu/~kostas/)
8 | GRASP Laboratory, University of Pennsylvania
9 | The 2021 IEEE International Conference on Robotics and Automation (ICRA 2021), Xi'an, China
10 |
11 | ## Citation
12 | If you use this code for your research, please cite our paper:
13 | ```
14 | @article{zhang2021deformable,
15 | title={Deformable Linear Object Prediction Using Locally Linear Latent Dynamics},
16 | author={Zhang, Wenbo and Schmeckpeper, Karl and Chaudhari, Pratik and Daniilidis, Kostas},
17 | journal={IEEE International Conference on Robotics and Automation (ICRA)},
18 | year={2021}
19 | }
20 | ```
21 |
22 | ## Running the code
23 | ### Preparation
24 | Create a folder for the repo
25 | ```
26 | mkdir deform
27 | cd deform
28 | ```
29 | Create a virtual environment
30 | ```
31 | python3 -m pip install --user virtualenv
32 | python3 -m venv deform_env
33 | ```
34 |
35 | Activate virtual environment
36 | ```
37 | source deform_env/bin/activate
38 | ```
39 |
40 | Install libraries and dependencies
41 | ```
42 | pip install torch torchvision
43 | pip install matplotlib
44 | ```
45 |
46 | Clone the repository
47 | ```
48 | git clone https://github.com/zwbgood6/deform.git
49 | ```
50 |
51 | ### Dataset
52 | ```
53 | mkdir rope_dataset
54 | ```
55 | Download the dataset from the [Google Drive](https://drive.google.com/file/d/1jy1EUDSeH3d3cZUSK1xChOBvn-qqx9WA/view?usp=sharing) and place it in the folder `rope_dataset`.
56 |
57 | ```
58 | cd rope_dataset
59 | unzip paper_dataset.zip
60 | ```
61 |
62 | ### Training
63 | Go to the top-level `deform` folder created in Preparation (the parent of the cloned repository) and run
64 | ```
65 | python -m deform.model.nn_large_linear_seprt_no_loop_Kp_Lpa
66 | ```
67 |
68 | ### Prediction
69 | After training, run
70 | ```
71 | python -m deform.model.predict_large_linear_seprt_no_loop_Kp_Lpa
72 | ```
73 |
74 | ### Issues
75 | Please open a GitHub issue for any code-related questions.
76 |
77 | ## Model Architecture and Hyperparameters
78 |
79 | See [model/README.md](model/README.md) for the model architecture diagrams and the full hyperparameter table.
80 |
81 | ## Experimental Results
82 |
83 | Ten-step Prediction
84 |
85 | <img src="image/multi_step_pred_result.png" alt="Ten-step prediction results">
86 |
87 | Sampling-based Model Predictive Control
88 |
89 | <img src="image/CEM_result.png" alt="Sampling-based model predictive control results">
90 |
91 | ## License
92 | MIT
93 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/__init__.py
--------------------------------------------------------------------------------
/image/CEM_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/image/CEM_result.png
--------------------------------------------------------------------------------
/image/multi_step_pred_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/image/multi_step_pred_result.png
--------------------------------------------------------------------------------
/model/README.md:
--------------------------------------------------------------------------------
1 | ## Model Architecture
2 |
3 | State Model
4 |
5 | <img src="image/state_model.png" alt="State model">
6 |
7 | Action Model
8 |
9 | <img src="image/action_model.png" alt="Action model">
10 |
11 | Dynamics Model
12 |
13 | <img src="image/dynamics_model.png" alt="Dynamics model">
14 |
15 | ## Hyperparameters
16 |
17 | | Hyperparameters | Values |
18 | | :------------- | :----------: |
19 | | epochs (overall) | 1000 |
20 | | epochs (state and action encoder-decoder) | 500 |
21 | | epochs (dynamics model) | 500 |
22 | | learning rate | 1e-3 |
23 | | batch size | 32 |
24 | | latent state size | 80 |
25 | | latent action size | 80 |
26 | | λ1 (action coefficient in the loss function) | 450 |
27 | | λ2 (dynamics coefficient in the loss function) | 900 |
28 | | λ3 (prediction coefficient in the loss function) | 10 |
29 |
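30 | Below is a minimal sketch (an assumption for illustration, not code from this
31 | repository) of how these coefficients might combine into the overall training
32 | objective, assuming a unit-weight state reconstruction term:
33 |
34 | ```python
35 | LAMBDA1, LAMBDA2, LAMBDA3 = 450.0, 900.0, 10.0  # values from the table above
36 |
37 | def total_loss(loss_state, loss_action, loss_dynamics, loss_prediction):
38 |     """Weighted sum of the four training losses; term names are assumptions."""
39 |     return (loss_state + LAMBDA1 * loss_action
40 |             + LAMBDA2 * loss_dynamics + LAMBDA3 * loss_prediction)
41 | ```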
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__init__.py
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cem.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/cem.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cem.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/cem.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cem.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/cem.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cem_e2c.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/cem_e2c.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cem_gpu.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/cem_gpu.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cem_plot_curve.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/cem_plot_curve.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/configs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/configs.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/configs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/configs.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/create_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/create_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/create_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/create_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/create_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/create_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/dl_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/dl_model.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/dl_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/dl_model.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/dl_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/dl_model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/dl_model_gpu.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/dl_model_gpu.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/e2c.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/e2c.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/hidden_dynamics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/hidden_dynamics.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/hidden_dynamics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/hidden_dynamics.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/hidden_dynamics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/hidden_dynamics.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/nn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/nn.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/nn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/nn.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/nn_linear.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/nn_linear.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/nn_linear.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/nn_linear.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred10.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred10.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred10.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred10.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred4.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/__pycache__/predict_large_linear_seprt_no_loop_Kp_Lpa_multiStepPred4.cpython-36.pyc
--------------------------------------------------------------------------------
/model/cem.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import det
3 | from numpy import sqrt
4 | from deform.model.dl_model import *
5 | from deform.model.create_dataset import *
6 | from deform.model.hidden_dynamics import get_next_state_linear
7 | import math
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision.utils import save_image
11 | from torch.distributions import Uniform
12 | from torch.distributions.multivariate_normal import MultivariateNormal
13 | from torch.distributions.kl import kl_divergence
14 | import torchvision.transforms.functional as TF
15 | from PIL import Image
16 | from deform.utils.utils import plot_cem_sample
17 | import os
18 |
19 | def sample_action(I, mean=None, cov=None):
20 | '''TODO: unit test
21 | each action sequence length: H
22 | number of action sequences: N
23 | '''
24 | action = torch.tensor([0]*4, dtype=torch.float)
25 | multiplier = torch.tensor([50, 50, 2*math.pi, 0.14])
26 | addition = torch.tensor([0, 0, 0, 0.01])
27 | thres = 0.9
28 | if I[0][0][0][0] == 1.:
29 | if ((mean is None) and (cov is None)):
30 | action_base = Uniform(low=0.0, high=1.0).sample((4,))
31 | action = torch.mul(action_base, multiplier) + addition
32 | else:
33 | cov = add_eye(cov)
34 | action = MultivariateNormal(mean, cov).sample()
35 | action[0], action[1] = 0, 0
36 | return action
37 |
38 | while I[0][0][torch.floor(action[0]).type(torch.LongTensor)][torch.floor(action[1]).type(torch.LongTensor)] != 1.:
39 | if ((mean is None) and (cov is None)):
40 | action_base = Uniform(low=0.0, high=1.0).sample((4,))
41 | action = torch.mul(action_base, multiplier) + addition
42 |         else:
43 |             cov = add_eye(cov)  # Gaussian samples need no multiplier/addition rescaling
44 |             action = MultivariateNormal(mean, cov).sample()
45 |             while torch.floor(action[0]).type(torch.LongTensor) >= 50 or torch.floor(action[1]).type(torch.LongTensor) >= 50:
46 |                 action = MultivariateNormal(mean, add_eye(cov)).sample()
47 | return action
48 |
49 | def generate_next_pred_state(recon_model, dyn_model, img_pre, act_pre):
50 |     '''generate the next predicted state
51 |     recon_model: reconstruction (encoder-decoder) model
52 |     dyn_model: latent dynamics model
53 |     img_pre: image at the current step
54 |     act_pre: action applied at the current step
55 |     returns the decoded prediction of the next image
56 |     '''
57 | latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float), None)
58 | K_T_pre, L_T_pre = dyn_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float))
59 | recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre)
60 | return recon_model.decoder(recon_latent_img_cur)
61 |
62 | def generate_next_pred_state_in_n_step(recon_model, dyn_model, img_initial, N, H, mean=None, cov=None):
63 | imgs = [None]*N
64 | actions = torch.tensor([[0.]*4]*N)
65 | for n in range(N):
66 | img_tmp = img_initial
67 | for h in range(H):
68 | action = sample_action(img_tmp, mean, cov)
69 | if h==0:
70 | actions[n] = action
71 | img_tmp = generate_next_pred_state(recon_model, dyn_model, img_tmp, action)
72 | imgs[n] = img_tmp
73 | return imgs, actions
74 |
75 | def loss_function_img(img_recon, img_goal, N):
76 | loss = torch.as_tensor([0.]*N)
77 | for n in range(N):
78 | loss[n] = F.binary_cross_entropy(img_recon[n].view(-1, 2500), img_goal.view(-1, 2500), reduction='sum')
79 | return loss
80 |
81 | def add_eye(cov):
82 | if det(cov)==0:
83 | return cov + torch.eye(4) * 0.000001
84 | else:
85 | return cov
86 |
87 | def mahalanobis(dist, cov):
88 | '''dist = mu1 - mu2, mu1 & mu2 are means of two multivariate gaussian distribution
89 | matrix multiplication: dist^T * cov^(-1) * dist
90 | '''
91 | return (dist.transpose(0,1).mm(cov.inverse())).mm(dist)
92 |
93 | def bhattacharyya(dist, cov1, cov2):
94 | '''source: https://en.wikipedia.org/wiki/Bhattacharyya_distance
95 | '''
96 | cov = (cov1 + cov2) / 2
97 | d1 = mahalanobis(dist.reshape((4,-1)), cov) / 8
98 | if det(cov)==0 or det(cov1)==0 or det(cov2)==0:
99 |         return math.inf
100 | d2 = np.log(det(cov) / sqrt(det(cov1) * det(cov2))) / 2
101 | return d1 + d2
102 |
103 | def main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act, step_i):
104 | for t in range(T):
105 | print("***** Start Step {}".format(t))
106 | if t==0:
107 | img_cur = img_initial
108 | #Initialize Q with uniform distribution
109 | mean = None
110 | cov = None
111 | mean_tmp = None
112 | cov_tmp = None
113 | converge = False
114 | iter_count = 0
115 | while not converge:
116 | #Use model M to predict the next state using M action sequences
117 | imgs_recon, sample_actions = generate_next_pred_state_in_n_step(recon_model, dyn_model, img_cur, N, H, mean, cov)
118 | #Calculate binary cross entropy loss for predicted image and goal image
119 | loss = loss_function_img(imgs_recon, img_goal, N)
120 | #Select K action sequences with lowest loss
121 | loss_index = torch.argsort(loss)
122 | sorted_sample_actions = sample_actions[loss_index]
123 | #Fit multivariate gaussian distribution to K samples
124 | #(see how to fit algorithm:
125 | #https://stackoverflow.com/questions/27230824/fit-multivariate-gaussian-distribution-to-a-given-dataset)
126 | mean = torch.mean(sorted_sample_actions[:K], dim=0).type(torch.DoubleTensor)
127 | cov = torch.from_numpy(np.cov(sorted_sample_actions[:K], rowvar=0)).type(torch.DoubleTensor)
128 | # iteration is based on convergence of Q
129 |             if det(cov) == 0 or cov_tmp is None:
130 | mean_tmp = mean
131 | cov_tmp = cov
132 | continue
133 | else:
134 | if det(cov_tmp)==0:
135 | mean_tmp = mean
136 | cov_tmp = cov
137 | continue
138 | else:
139 | p = MultivariateNormal(mean, cov)
140 | q = MultivariateNormal(mean_tmp, cov_tmp)
141 |                     if kl_divergence(p, q) < 1000:  # KL threshold below which Q is considered converged
142 | converge = True
143 |
144 | mean_tmp = mean
145 | cov_tmp = cov
146 |
147 | print("***** At action time step {}, iteration {} *****".format(t, iter_count))
148 | iter_count += 1
149 |
150 |
151 | #Execute action a{t}* with lowest loss
152 | action_best = sorted_sample_actions[0]
153 | torch.save(action_best, "./plan_result/{}/action_best_step{}_N{}_K{}.pt".format(plan_folder_name, step_i, N, K))
154 | #Observe new image I{t+1}
155 | img_cur = generate_next_pred_state(recon_model, dyn_model, img_cur, action_best)
156 | comparison = torch.cat([img_initial, img_goal, img_cur])
157 | save_image(comparison, "./plan_result/{}/image_best_step{}_N{}_K{}.png".format(plan_folder_name, step_i, N, K))
158 | plot_cem_sample(img_initial.detach().reshape((50,50)).cpu().numpy(),
159 | img_goal.detach().reshape((50,50)).cpu().numpy(),
160 | img_cur.detach().reshape((50,50)).cpu().numpy(),
161 | resz_act[:4],
162 | action_best.detach().cpu().numpy(),
163 |                         './plan_result/{}/compare_align_{}.png'.format(plan_folder_name, step_i))
164 | print("***** Generate Next Predicted Image {}*****".format(t+1))
165 |
166 | print("***** End Planning *****")
167 |
168 | # plan result folder name
169 | plan_folder_name = 'test_KL1000'
170 | if not os.path.exists('./plan_result/{}'.format(plan_folder_name)):
171 | os.makedirs('./plan_result/{}'.format(plan_folder_name))
172 | # time step to execute the action
173 | T = 1
174 | # total number of samples for action sequences
175 | N = 100
176 | # K samples to fit multivariate gaussian distribution (N>K, K>1)
177 | K = 50
178 | # length of action sequence
179 | H = 1
180 | # model
181 | torch.manual_seed(1)
182 | device = torch.device("cpu")
183 | print("Device is:", device)
184 | recon_model = CAE().to(device)
185 | dyn_model = SysDynamics().to(device)
186 |
187 | # action
188 | # load GT action
189 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
190 | resz_act = np.load(resz_act_path)
191 |
192 | # checkpoint
193 | print('***** Load Checkpoint *****')
194 | folder_name = "test_act80_pred30"
195 | PATH = './result/{}/checkpoint'.format(folder_name)
196 | checkpoint = torch.load(PATH, map_location=device)
197 | recon_model.load_state_dict(checkpoint['recon_model_state_dict'])
198 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict'])
199 |
200 | total_img_num = 22515
201 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
202 |
203 |
204 | def get_image(i):
205 | img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3
206 | return img.reshape((-1, 1, 50, 50)).type(torch.float)
207 |
208 | for i in range(0,1):
209 | img_initial = get_image(i)
210 | img_goal = get_image(i+1)
211 | main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act[i], i)
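212 |
213 | # --- Hedged addition (not in the original script) ---------------------------
214 | # sample_action is marked "TODO: unit test" above; this minimal smoke test
215 | # checks that uniformly sampled actions stay inside the 50x50 image and land
216 | # on the rope mask (or hit the corner-pixel special case handled above).
217 | def _sample_action_smoke_test(num_trials=10):
218 |     img = get_image(0)
219 |     for _ in range(num_trials):
220 |         a = sample_action(img)
221 |         r, c = int(torch.floor(a[0])), int(torch.floor(a[1]))
222 |         assert 0 <= r < 50 and 0 <= c < 50
223 |         assert img[0][0][0][0] == 1. or img[0][0][r][c] == 1.
224 |
225 | # _sample_action_smoke_test()  # uncomment to run the check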
--------------------------------------------------------------------------------
/model/cem_e2c.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import det
3 | from numpy import sqrt
4 | from deform.model.dl_model import *
5 | from deform.model.create_dataset import *
6 | from deform.model.hidden_dynamics import get_next_state_linear
7 | import math
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision.utils import save_image
11 | from torch.distributions import Uniform
12 | from torch.distributions.multivariate_normal import MultivariateNormal
13 | from torch.distributions.kl import kl_divergence
14 | import torchvision.transforms.functional as TF
15 | from PIL import Image
16 | from deform.utils.utils import plot_cem_sample
17 | import os
18 | from deform.model.e2c import E2C  # assumption: the E2C class is defined in model/e2c.py
19 | def sample_action(I, mean=None, cov=None):
20 | '''TODO: unit test
21 | each action sequence length: H
22 | number of action sequences: N
23 | '''
24 | action = torch.tensor([0]*4, dtype=torch.float)
25 | multiplier = torch.tensor([50, 50, 2*math.pi, 0.14])
26 | addition = torch.tensor([0, 0, 0, 0.01])
27 | thres = 0.9
28 |
29 | if I[0][0][0][0] == 1.:
30 | if ((mean is None) and (cov is None)):
31 | action_base = Uniform(low=0.0, high=1.0).sample((4,))
32 | action = torch.mul(action_base, multiplier) + addition
33 | else:
34 | cov = add_eye(cov)
35 | action = MultivariateNormal(mean, cov).sample()
36 | action[0], action[1] = 0, 0
37 | return action
38 |
39 | while I[0][0][torch.floor(action[0]).type(torch.LongTensor)][torch.floor(action[1]).type(torch.LongTensor)] != 1.:
40 | if ((mean is None) and (cov is None)):
41 | action_base = Uniform(low=0.0, high=1.0).sample((4,))
42 | action = torch.mul(action_base, multiplier) + addition
43 | else:
44 | cov = add_eye(cov)
45 | action = MultivariateNormal(mean, cov).sample()
46 | while torch.floor(action[0]).type(torch.LongTensor) >= 50 or torch.floor(action[1]).type(torch.LongTensor) >= 50:
47 | cov = add_eye(cov)
48 | action = MultivariateNormal(mean, cov).sample()
49 |
50 | return action
51 |
52 | def generate_next_pred_state(e2c_model, img_pre, act_pre):
53 |     '''generate the next predicted state with the E2C model
54 |     e2c_model: trained E2C model, replacing the separate
55 |     reconstruction and dynamics models used in cem.py
56 |     img_pre: image at the current step
57 |     act_pre: action applied at the current step
58 |     returns the predicted next image
59 |     '''
60 | recon_img_cur = e2c_model.predict(img_pre, act_pre)
61 | return recon_img_cur
62 |
63 |
64 | def generate_next_pred_state_in_n_step(e2c_model, img_initial, N, H, mean=None, cov=None):
65 | imgs = [None]*N
66 | actions = torch.tensor([[0.]*4]*N)
67 | for n in range(N):
68 | img_tmp = img_initial
69 | for h in range(H):
70 | action = sample_action(img_tmp, mean, cov)
71 | if h==0:
72 | actions[n] = action
73 | img_tmp = generate_next_pred_state(e2c_model, img_tmp, action)
74 | imgs[n] = img_tmp
75 | return imgs, actions
76 |
77 | def loss_function_img(img_recon, img_goal, N):
78 | loss = torch.as_tensor([0.]*N)
79 | for n in range(N):
80 | loss[n] = F.binary_cross_entropy(img_recon[n].view(-1, 2500), img_goal.view(-1, 2500), reduction='sum')
81 | return loss
82 |
83 | def add_eye(cov):
84 | if det(cov)==0:
85 | return cov + torch.eye(4) * 0.000001
86 | else:
87 | return cov
88 |
89 | def mahalanobis(dist, cov):
90 | '''dist = mu1 - mu2, mu1 & mu2 are means of two multivariate gaussian distribution
91 | matrix multiplication: dist^T * cov^(-1) * dist
92 | '''
93 | return (dist.transpose(0,1).mm(cov.inverse())).mm(dist)
94 |
95 | def bhattacharyya(dist, cov1, cov2):
96 | '''source: https://en.wikipedia.org/wiki/Bhattacharyya_distance
97 | '''
98 | cov = (cov1 + cov2) / 2
99 | d1 = mahalanobis(dist.reshape((4,-1)), cov) / 8
100 | if det(cov)==0 or det(cov1)==0 or det(cov2)==0:
101 |         return math.inf
102 | d2 = np.log(det(cov) / sqrt(det(cov1) * det(cov2))) / 2
103 | return d1 + d2
104 |
105 | def main(e2c_model, T, K, N, H, img_initial, img_goal, resz_act, step_i, KL):
106 | for t in range(T):
107 | print("***** Start Step {}".format(t))
108 | if t==0:
109 | img_cur = img_initial
110 | #Initialize Q with uniform distribution
111 | mean = None
112 | cov = None
113 | mean_tmp = None
114 | cov_tmp = None
115 | converge = False
116 | iter_count = 0
117 | while not converge:
118 | imgs_recon, sample_actions = generate_next_pred_state_in_n_step(e2c_model, img_cur, N, H, mean, cov)
119 | #Calculate binary cross entropy loss for predicted image and goal image
120 | loss = loss_function_img(imgs_recon, img_goal, N)
121 | #Select K action sequences with lowest loss
122 | loss_index = torch.argsort(loss)
123 | sorted_sample_actions = sample_actions[loss_index]
124 | #Fit multivariate gaussian distribution to K samples
125 | #(see how to fit algorithm:
126 | #https://stackoverflow.com/questions/27230824/fit-multivariate-gaussian-distribution-to-a-given-dataset)
127 | mean = torch.mean(sorted_sample_actions[:K], dim=0).type(torch.DoubleTensor)
128 | cov = torch.from_numpy(np.cov(sorted_sample_actions[:K], rowvar=0)).type(torch.DoubleTensor)
129 | # iteration is based on convergence of Q
130 |             if det(cov) == 0 or cov_tmp is None:
131 | mean_tmp = mean
132 | cov_tmp = cov
133 | continue
134 | else:
135 | if det(cov_tmp)==0:
136 | mean_tmp = mean
137 | cov_tmp = cov
138 | continue
139 | else:
140 | p = MultivariateNormal(mean, cov)
141 | q = MultivariateNormal(mean_tmp, cov_tmp)
142 |                     if kl_divergence(p, q) < KL:  # threshold below which Q is considered converged
143 | converge = True
144 | mean_tmp = mean
145 | cov_tmp = cov
146 |
147 | print("***** At action time step {}, iteration {} *****".format(t, iter_count))
148 | iter_count += 1
149 |
150 | #Execute action a{t}* with lowest loss
151 | action_best = sorted_sample_actions[0]
152 | action_loss = ((action_best.detach().cpu().numpy()-resz_act[:4])**2).mean(axis=None)
153 | #Observe new image I{t+1}
154 | img_cur = generate_next_pred_state(e2c_model, img_cur, action_best)
155 | img_loss = F.binary_cross_entropy(img_cur.view(-1, 2500), img_goal.view(-1, 2500), reduction='mean')
156 |
157 | print("***** Generate Next Predicted Image {}*****".format(t+1))
158 | print("***** End Planning *****")
159 | return action_loss, img_loss.detach().cpu().numpy()
160 |
161 | # plan result folder name
162 | plan_folder_name = 'test_e2c'
163 | if not os.path.exists('./plan_result/{}'.format(plan_folder_name)):
164 | os.makedirs('./plan_result/{}'.format(plan_folder_name))
165 | # time step to execute the action
166 | T = 1
167 | # total number of samples for action sequences
168 | N = 100
169 | # K samples to fit multivariate gaussian distribution (N>K, K>1)
170 | K = 50
171 | # length of action sequence
172 | H = 1
173 | # model
174 | torch.manual_seed(1)
175 | device = torch.device("cpu")
176 | print("Device is:", device)
177 | e2c_model = E2C().to(device)
178 |
179 | # action
180 | # load GT action
181 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
182 | resz_act = np.load(resz_act_path)
183 |
184 | # checkpoint
185 | print('***** Load Checkpoint *****')
186 | folder_name = "test_E2C_gpu_update_loss"
187 | PATH = './result/{}/checkpoint'.format(folder_name)
188 | checkpoint_e2c = torch.load(PATH, map_location=device)
189 | e2c_model.load_state_dict(checkpoint_e2c['e2c_model_state_dict'])
190 |
191 | total_img_num = 22515
192 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
193 |
194 |
195 | def get_image(i):
196 | img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3
197 | return img.reshape((-1, 1, 50, 50)).type(torch.float)
198 |
199 | for KL in [1000]:
200 | action_loss_all = []
201 | img_loss_all = []
202 | for i in range(20000, 20010):
203 | img_initial = get_image(i)
204 | img_goal = get_image(i+1)
205 | action_loss, img_loss = main(e2c_model, T, K, N, H, img_initial, img_goal, resz_act[i], i, KL)
206 | action_loss_all.append(action_loss)
207 | img_loss_all.append(img_loss)
208 | np.save('./plan_result/{}/KL_action_{}.npy'.format(plan_folder_name, KL), action_loss_all)
209 | np.save('./plan_result/{}/KL_image_{}.npy'.format(plan_folder_name, KL), img_loss_all)
--------------------------------------------------------------------------------
/model/cem_plot_action.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import math
5 | import torchvision.transforms.functional as TF
6 | def rect(poke, c):
7 | x, y, t, l = poke
8 | dx = -2000 * l * math.cos(t)
9 | dy = -2000 * l * math.sin(t)
10 | plt.arrow(x, y, dx, dy, width=0.01, head_width=1, head_length=3, color=c)
11 | from PIL import Image
12 | def plot_sample(img_before, img_after, img_after_pred, resz_action, recon_action, directory):
13 |     plt.figure()
14 |     # 2x2 grid: ground-truth transition (top) vs. planned/predicted one (bottom)
15 |     # top left: image before the action, with the ground-truth action
16 |     plt.subplot(2, 2, 1)
17 |     rect(resz_action, "blue")
18 |     plt.imshow(img_before.reshape((50,50)))
19 |     plt.axis('off')
20 |     # top right: ground-truth image after the action
21 |     plt.subplot(2, 2, 2)
22 |     plt.imshow(img_after.reshape((50,50)))
23 |     plt.axis('off')
24 |     # bottom left: image before the action, with the best action found by CEM
25 |     plt.subplot(2, 2, 3)
26 |     rect(recon_action, "red")
27 |     plt.imshow(img_before.reshape((50,50)))
28 |     plt.axis('off')
29 |     # bottom right: predicted image after the best action
30 |     plt.subplot(2, 2, 4)
31 |     plt.imshow(img_after_pred.reshape((50,50)))
32 |     plt.axis('off')
33 |     plt.savefig(directory)
34 |     plt.close()
35 |
36 | def plot_action(resz_action, recon_action, directory):
37 | plt.figure()
38 | # upper row original
39 | plt.subplot(1, 2, 1)
40 | rect(resz_action, "blue")
41 | plt.axis('off')
42 | # middle row reconstruction
43 | plt.subplot(1, 2, 2)
44 | rect(recon_action, "red")
45 | plt.axis('off')
46 | plt.savefig(directory)
47 | plt.close()
48 |
49 |
50 | def get_image(i):
51 | img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3
52 | return img.reshape((-1, 1, 50, 50)).type(torch.float)
53 | from deform.model.create_dataset import *  # provides create_image_path
54 | # total number of samples for action sequences
55 | N = 500
56 | # K samples to fit multivariate gaussian distribution (N>K, K>1)
57 | K = 50
58 |
59 | # load GT action
60 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
61 | resz_act = np.load(resz_act_path)
62 |
63 | # load image
64 | total_img_num = 22515
65 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
66 |
67 | for i in range(45):
68 |     img_before = get_image(i)
69 |     img_after = get_image(i+1)
70 |     action_best = torch.load("./plan_result/03/action_best_step{}_N{}_K{}.pt".format(i, N, K))
71 |     # NOTE: the original script referenced undefined img_after_pred/resz_action/recon_action;
72 |     # as a labeled placeholder, the ground-truth next image stands in for the missing prediction.
73 |     img_after_pred = img_after
74 |     plot_sample(img_before, img_after, img_after_pred,
75 |                 resz_act[i][:4], action_best.detach().cpu().numpy(),
76 |                 './plan_result/03/compare_align_{}.png'.format(i))
--------------------------------------------------------------------------------
/model/cem_plot_curve.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import det
3 | from numpy import sqrt
4 | from deform.model.dl_model import *
5 | from deform.model.create_dataset import *
6 | from deform.model.hidden_dynamics import get_next_state_linear
7 | import math
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision.utils import save_image
11 | from torch.distributions import Uniform
12 | from torch.distributions.multivariate_normal import MultivariateNormal
13 | from torch.distributions.kl import kl_divergence
14 | import torchvision.transforms.functional as TF
15 | from PIL import Image
16 | from deform.utils.utils import plot_cem_sample
17 | import os
18 |
19 | def sample_action(I, mean=None, cov=None):
20 | '''TODO: unit test
21 | each action sequence length: H
22 | number of action sequences: N
23 | '''
24 | action = torch.tensor([0]*4, dtype=torch.float)
25 | multiplier = torch.tensor([50, 50, 2*math.pi, 0.14])
26 | addition = torch.tensor([0, 0, 0, 0.01])
27 | thres = 0.9
28 | if I[0][0][0][0] == 1.:
29 | if ((mean is None) and (cov is None)):
30 | action_base = Uniform(low=0.0, high=1.0).sample((4,))
31 | action = torch.mul(action_base, multiplier) + addition
32 | else:
33 | cov = add_eye(cov)
34 | action = MultivariateNormal(mean, cov).sample()
35 | action[0], action[1] = 0, 0
36 | return action
37 |
38 | while I[0][0][torch.floor(action[0]).type(torch.LongTensor)][torch.floor(action[1]).type(torch.LongTensor)] != 1.:
39 | if ((mean is None) and (cov is None)):
40 | action_base = Uniform(low=0.0, high=1.0).sample((4,))
41 | action = torch.mul(action_base, multiplier) + addition
42 | else:
43 | cov = add_eye(cov)
44 | action = MultivariateNormal(mean, cov).sample()
45 | while torch.floor(action[0]).type(torch.LongTensor) >= 50 or torch.floor(action[1]).type(torch.LongTensor) >= 50:
46 | cov = add_eye(cov)
47 | action = MultivariateNormal(mean, cov).sample()
48 |
49 | return action
50 |
51 | def generate_next_pred_state(recon_model, dyn_model, img_pre, act_pre):
52 |     '''generate the next predicted state
53 |     recon_model: reconstruction (encoder-decoder) model
54 |     dyn_model: latent dynamics model
55 |     img_pre: image at the current step
56 |     act_pre: action applied at the current step
57 |     returns the decoded prediction of the next image
58 |     '''
59 | latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float), None)
60 | K_T_pre, L_T_pre = dyn_model(img_pre.reshape((-1, 1, 50, 50)), act_pre.reshape((-1, 4)).type(torch.float))
61 | recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre)
62 | return recon_model.decoder(recon_latent_img_cur)
63 |
64 | def generate_next_pred_state_in_n_step(recon_model, dyn_model, img_initial, N, H, mean=None, cov=None):
65 | imgs = [None]*N
66 | actions = torch.tensor([[0.]*4]*N)
67 | for n in range(N):
68 | img_tmp = img_initial
69 | for h in range(H):
70 | action = sample_action(img_tmp, mean, cov)
71 | if h==0:
72 | actions[n] = action
73 | img_tmp = generate_next_pred_state(recon_model, dyn_model, img_tmp, action)
74 | imgs[n] = img_tmp
75 | return imgs, actions
76 |
77 | def loss_function_img(img_recon, img_goal, N):
78 | loss = torch.as_tensor([0.]*N)
79 | for n in range(N):
80 | loss[n] = F.binary_cross_entropy(img_recon[n].view(-1, 2500), img_goal.view(-1, 2500), reduction='sum')
81 | return loss
82 |
83 | def add_eye(cov):
84 | if det(cov)==0:
85 | return cov + torch.eye(4) * 0.000001
86 | else:
87 | return cov
88 |
89 | def mahalanobis(dist, cov):
90 | '''dist = mu1 - mu2, mu1 & mu2 are means of two multivariate gaussian distribution
91 | matrix multiplication: dist^T * cov^(-1) * dist
92 | '''
93 | return (dist.transpose(0,1).mm(cov.inverse())).mm(dist)
94 |
95 | def bhattacharyya(dist, cov1, cov2):
96 | '''source: https://en.wikipedia.org/wiki/Bhattacharyya_distance
97 | '''
98 | cov = (cov1 + cov2) / 2
99 | d1 = mahalanobis(dist.reshape((4,-1)), cov) / 8
100 | if det(cov)==0 or det(cov1)==0 or det(cov2)==0:
101 |         return math.inf
102 | d2 = np.log(det(cov) / sqrt(det(cov1) * det(cov2))) / 2
103 | return d1 + d2
104 |
105 | def main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act, step_i, KL):
106 | for t in range(T):
107 | print("***** Start Step {}".format(t))
108 | if t==0:
109 | img_cur = img_initial
110 | #Initialize Q with uniform distribution
111 | mean = None
112 | cov = None
113 | mean_tmp = None
114 | cov_tmp = None
115 | converge = False
116 | iter_count = 0
117 | while not converge:
118 | imgs_recon, sample_actions = generate_next_pred_state_in_n_step(recon_model, dyn_model, img_cur, N, H, mean, cov)
119 | #Calculate binary cross entropy loss for predicted image and goal image
120 | loss = loss_function_img(imgs_recon, img_goal, N)
121 | #Select K action sequences with lowest loss
122 | loss_index = torch.argsort(loss)
123 | sorted_sample_actions = sample_actions[loss_index]
124 | #Fit multivariate gaussian distribution to K samples
125 | #(see how to fit algorithm:
126 | #https://stackoverflow.com/questions/27230824/fit-multivariate-gaussian-distribution-to-a-given-dataset)
127 | mean = torch.mean(sorted_sample_actions[:K], dim=0).type(torch.DoubleTensor)
128 | cov = torch.from_numpy(np.cov(sorted_sample_actions[:K], rowvar=0)).type(torch.DoubleTensor)
129 | # iteration is based on convergence of Q
130 |             if det(cov) == 0 or cov_tmp is None:
131 | mean_tmp = mean
132 | cov_tmp = cov
133 | continue
134 | else:
135 | if det(cov_tmp)==0:
136 | mean_tmp = mean
137 | cov_tmp = cov
138 | continue
139 | else:
140 | p = MultivariateNormal(mean, cov)
141 | q = MultivariateNormal(mean_tmp, cov_tmp)
142 | if kl_divergence(p, q) < KL:
143 | converge = True
144 | mean_tmp = mean
145 | cov_tmp = cov
146 |
147 | print("***** At action time step {}, iteration {} *****".format(t, iter_count))
148 | iter_count += 1
149 |
150 | #Execute action a{t}* with lowest loss
151 | action_best = sorted_sample_actions[0]
152 | action_loss = ((action_best.detach().cpu().numpy()-resz_act[:4])**2).mean(axis=None)
153 | #Observe new image I{t+1}
154 | img_cur = generate_next_pred_state(recon_model, dyn_model, img_cur, action_best)
155 | img_loss = F.binary_cross_entropy(img_cur.view(-1, 2500), img_goal.view(-1, 2500), reduction='mean')
156 | print("***** Generate Next Predicted Image {}*****".format(t+1))
157 |
158 | print("***** End Planning *****")
159 | return action_loss, img_loss.detach().cpu().numpy()
160 |
161 | # plan result folder name
162 | plan_folder_name = 'curve_KL'
163 | if not os.path.exists('./plan_result/{}'.format(plan_folder_name)):
164 | os.makedirs('./plan_result/{}'.format(plan_folder_name))
165 | # time step to execute the action
166 | T = 1
167 | # total number of samples for action sequences
168 | N = 100
169 | # K samples to fit multivariate gaussian distribution (N>K, K>1)
170 | K = 50
171 | # length of action sequence
172 | H = 1
173 | # model
174 | torch.manual_seed(1)
175 | device = torch.device("cpu")
176 | print("Device is:", device)
177 | recon_model = CAE().to(device)
178 | dyn_model = SysDynamics().to(device)
179 |
180 | # action
181 | # load GT action
182 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
183 | resz_act = np.load(resz_act_path)
184 |
185 | # checkpoint
186 | print('***** Load Checkpoint *****')
187 | folder_name = "test_act80_pred30"
188 | PATH = './result/{}/checkpoint'.format(folder_name)
189 | checkpoint = torch.load(PATH, map_location=device)
190 | recon_model.load_state_dict(checkpoint['recon_model_state_dict'])
191 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict'])
192 |
193 | total_img_num = 22515
194 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
195 |
196 |
197 | def get_image(i):
198 | img = TF.to_tensor(Image.open(image_paths_bi[i])) > 0.3
199 | return img.reshape((-1, 1, 50, 50)).type(torch.float)
200 |
201 |
202 | for KL in [1000]:
203 | action_loss_all = []
204 | img_loss_all = []
205 | for i in range(20000, 20010):
206 | img_initial = get_image(i)
207 | img_goal = get_image(i+1)
208 | action_loss, img_loss = main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act[i], i, KL)
209 | action_loss_all.append(action_loss)
210 | img_loss_all.append(img_loss)
211 | np.save('./plan_result/{}/KL_action_{}.npy'.format(plan_folder_name, KL), action_loss_all)
212 | np.save('./plan_result/{}/KL_image_{}.npy'.format(plan_folder_name, KL), img_loss_all)
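213 |
214 | # --- Hedged addition (not in the original script) ---------------------------
215 | # Load the losses saved above and draw simple curves; the matplotlib import
216 | # and the output filename are assumptions.
217 | import matplotlib.pyplot as plt
218 | for KL in [1000]:
219 |     action_curve = np.load('./plan_result/{}/KL_action_{}.npy'.format(plan_folder_name, KL))
220 |     image_curve = np.load('./plan_result/{}/KL_image_{}.npy'.format(plan_folder_name, KL))
221 |     plt.plot(action_curve, label='action MSE (KL={})'.format(KL))
222 |     plt.plot(image_curve, label='image BCE (KL={})'.format(KL))
223 | plt.xlabel('test pair index')
224 | plt.ylabel('loss')
225 | plt.legend()
226 | plt.savefig('./plan_result/{}/loss_curves.png'.format(plan_folder_name))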
--------------------------------------------------------------------------------
/model/create_dataset.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as TF
2 | from torchvision import transforms
3 | from torch.utils.data import Dataset
4 | import numpy as np
5 | import random
6 | import math
7 | from PIL import Image
8 | import matplotlib.pyplot as plt
9 | import torch
10 | from torch.utils.data.dataloader import default_collate
11 |
12 | class MyDataset(Dataset):
13 | def __init__(self, image_paths_bi, resize_actions, transform=None):
14 | self.image_paths_bi = image_paths_bi
15 | self.resz_actions = resize_actions
16 | self.transform = transform
17 |
18 | def __getitem__(self, index):
19 | n = self.__len__()
20 | none_sample = {'image_bi_pre': None, 'image_bi_cur': None, 'image_bi_post': None, 'resz_action_pre': None, 'resz_action_cur': None}
21 | if index == 0:
22 | if n == 2:
23 | index = 1
24 | else:
25 | index = np.random.randint(1, n-1)
26 | if index == n-1:
27 | if n == 2:
28 | index = 1
29 | else:
30 | index = np.random.randint(1, n-1)
31 |
32 | # load action
33 | resz_action_pre = self.resz_actions[index-1]
34 | resz_action_cur = self.resz_actions[index]
35 | # decide if action is valid
36 | if int(resz_action_pre[4]) == 0 or int(resz_action_cur[4]) == 0:
37 | return none_sample
38 |
39 | # load images (pre-transform images)
40 | image_bi_pre = Image.open(self.image_paths_bi[index-1])
41 | image_bi_cur = Image.open(self.image_paths_bi[index])
42 | image_bi_post = Image.open(self.image_paths_bi[index+1])
43 |
44 | sample = {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': resz_action_pre[:4], 'resz_action_cur': resz_action_cur[:4]}
45 |
46 | if self.transform:
47 | sample = self.transform(sample)
48 |
49 | return sample
50 |
51 | def __len__(self):
52 | return len(self.image_paths_bi)
53 |
54 | class MyDatasetMultiPred4(Dataset):
55 | def __init__(self, image_paths_bi, resize_actions, transform=None):
56 | self.image_paths_bi = image_paths_bi
57 | self.resz_actions = resize_actions
58 | self.transform = transform
59 |
60 | def __getitem__(self, index):
61 | n = self.__len__()
62 | none_sample = {'image_bi_pre': None, 'image_bi_cur': None, 'image_bi_post': None, 'image_bi_post2': None, 'image_bi_post3': None,\
63 | 'resz_action_pre': None, 'resz_action_cur': None, 'resz_action_post': None, 'resz_action_post2': None}
64 |         # edge cases: resample an index from [1, n-3) so index+3 stays in range
65 | if index == 0:
66 | index = np.random.randint(1, n-3)
67 | if index == n-1:
68 | index = np.random.randint(1, n-3)
69 | if index == n-2:
70 | index = np.random.randint(1, n-3)
71 | if index == n-3:
72 | index = np.random.randint(1, n-3)
73 |
74 | # load action
75 | resz_action_pre = self.resz_actions[index-1]
76 | resz_action_cur = self.resz_actions[index]
77 | resz_action_post = self.resz_actions[index+1]
78 | resz_action_post2 = self.resz_actions[index+2]
79 | # decide if action is valid
80 | if int(resz_action_pre[4]) == 0 or int(resz_action_cur[4]) == 0 or int(resz_action_post[4]) == 0 or int(resz_action_post2[4]) == 0:
81 | return none_sample
82 |
83 | # load images (pre-transform images)
84 | image_bi_pre = Image.open(self.image_paths_bi[index-1])
85 | image_bi_cur = Image.open(self.image_paths_bi[index])
86 | image_bi_post = Image.open(self.image_paths_bi[index+1])
87 | image_bi_post2 = Image.open(self.image_paths_bi[index+2])
88 | image_bi_post3 = Image.open(self.image_paths_bi[index+3])
89 |
90 | sample = {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'image_bi_post2': image_bi_post2, 'image_bi_post3': image_bi_post3,\
91 | 'resz_action_pre': resz_action_pre[:4], 'resz_action_cur': resz_action_cur[:4], 'resz_action_post': resz_action_post[:4], 'resz_action_post2': resz_action_post2[:4]}
92 |
93 | if self.transform:
94 | sample = self.transform(sample)
95 |
96 | return sample
97 |
98 | def __len__(self):
99 | return len(self.image_paths_bi)
100 |
101 | class MyDatasetMultiPred10(Dataset):
102 | def __init__(self, image_paths_bi, resize_actions, transform=None):
103 | self.image_paths_bi = image_paths_bi
104 | self.resz_actions = resize_actions
105 | self.transform = transform
106 |
107 | def __getitem__(self, index):
108 | n = self.__len__()
109 | none_sample = {'image_bi_pre': None, 'image_bi_cur': None, 'image_bi_post': None, 'image_bi_post2': None, 'image_bi_post3': None,\
110 | 'image_bi_post4': None, 'image_bi_post5': None, 'image_bi_post6': None, 'image_bi_post7': None, 'image_bi_post8': None, 'image_bi_post9': None,\
111 | 'resz_action_pre': None, 'resz_action_cur': None, 'resz_action_post': None, 'resz_action_post2': None, 'resz_action_post3': None, 'resz_action_post4': None,\
112 | 'resz_action_post5': None, 'resz_action_post6': None, 'resz_action_post7': None, 'resz_action_post8': None}
113 |         # edge cases: resample an index from [1, n-9) so index+9 stays in range
114 | if index == 0:
115 | index = np.random.randint(1, n-9)
116 | if index == n-1:
117 | index = np.random.randint(1, n-9)
118 | if index == n-2:
119 | index = np.random.randint(1, n-9)
120 | if index == n-3:
121 | index = np.random.randint(1, n-9)
122 | if index == n-4:
123 | index = np.random.randint(1, n-9)
124 | if index == n-5:
125 | index = np.random.randint(1, n-9)
126 | if index == n-6:
127 | index = np.random.randint(1, n-9)
128 | if index == n-7:
129 | index = np.random.randint(1, n-9)
130 | if index == n-8:
131 | index = np.random.randint(1, n-9)
132 | if index == n-9:
133 | index = np.random.randint(1, n-9)
134 | # load action
135 | resz_action_pre = self.resz_actions[index-1]
136 | resz_action_cur = self.resz_actions[index]
137 | resz_action_post = self.resz_actions[index+1]
138 | resz_action_post2 = self.resz_actions[index+2]
139 | resz_action_post3 = self.resz_actions[index+3]
140 | resz_action_post4 = self.resz_actions[index+4]
141 | resz_action_post5 = self.resz_actions[index+5]
142 | resz_action_post6 = self.resz_actions[index+6]
143 | resz_action_post7 = self.resz_actions[index+7]
144 | resz_action_post8 = self.resz_actions[index+8]
145 | resz_action_post9 = self.resz_actions[index+9]
146 | # decide if action is valid
147 | if int(resz_action_pre[4]) == 0 or int(resz_action_cur[4]) == 0 \
148 | or int(resz_action_post[4]) == 0 or int(resz_action_post2[4]) == 0 \
149 | or int(resz_action_post3[4]) == 0 or int(resz_action_post4[4]) == 0 \
150 | or int(resz_action_post5[4]) == 0 or int(resz_action_post6[4]) == 0 \
151 | or int(resz_action_post7[4]) == 0 or int(resz_action_post8[4]) == 0 \
152 | or int(resz_action_post9[4]) == 0:
153 | return none_sample
154 |
155 | # load images (pre-transform images)
156 | image_bi_pre = Image.open(self.image_paths_bi[index-1])
157 | image_bi_cur = Image.open(self.image_paths_bi[index])
158 | image_bi_post = Image.open(self.image_paths_bi[index+1])
159 | image_bi_post2 = Image.open(self.image_paths_bi[index+2])
160 | image_bi_post3 = Image.open(self.image_paths_bi[index+3])
161 | image_bi_post4 = Image.open(self.image_paths_bi[index+4])
162 | image_bi_post5 = Image.open(self.image_paths_bi[index+5])
163 | image_bi_post6 = Image.open(self.image_paths_bi[index+6])
164 | image_bi_post7 = Image.open(self.image_paths_bi[index+7])
165 | image_bi_post8 = Image.open(self.image_paths_bi[index+8])
166 | image_bi_post9 = Image.open(self.image_paths_bi[index+9])
167 |
168 |
169 | sample = {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'image_bi_post2': image_bi_post2, 'image_bi_post3': image_bi_post3,\
170 | 'image_bi_post4': image_bi_post4, 'image_bi_post5': image_bi_post5, 'image_bi_post6': image_bi_post6, 'image_bi_post7': image_bi_post7, 'image_bi_post8': image_bi_post8, \
171 | 'image_bi_post9': image_bi_post9, \
172 | 'resz_action_pre': resz_action_pre[:4], 'resz_action_cur': resz_action_cur[:4], 'resz_action_post': resz_action_post[:4], 'resz_action_post2': resz_action_post2[:4], \
173 | 'resz_action_post3': resz_action_post3[:4], 'resz_action_post4': resz_action_post4[:4], 'resz_action_post5': resz_action_post5[:4], 'resz_action_post6': resz_action_post6[:4], \
174 | 'resz_action_post7': resz_action_post7[:4], 'resz_action_post8': resz_action_post8[:4]}
175 |
176 | if self.transform:
177 | sample = self.transform(sample)
178 |
179 | return sample
180 |
181 | def __len__(self):
182 | return len(self.image_paths_bi)
183 |
184 | class Translation(object):
185 |     '''Translate the image and action by a random offset drawn from [-max_translation, max_translation], e.g. [-10, 10].
186 | '''
187 | def __init__(self, max_translation=10):
188 | self.m_trans = max_translation
189 |
190 |
191 | def __call__(self, sample):
192 | image_bi_pre, image_bi_cur, image_bi_post, action_pre, action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
193 | trans = list(2 * self.m_trans * np.random.random_sample((2,)) - self.m_trans)
194 | trans_action_pre, trans_action_cur = action_pre.copy(), action_cur.copy()
195 | trans_action_pre[:2], trans_action_cur[:2] = trans_action_pre[:2] + np.array(trans), trans_action_cur[:2] + np.array(trans)
196 | image_bi_pre = TF.affine(image_bi_pre, angle=0, translate=trans, scale=1.0, shear=0.0)
197 | image_bi_cur = TF.affine(image_bi_cur, angle=0, translate=trans, scale=1.0, shear=0.0)
198 | image_bi_post = TF.affine(image_bi_post, angle=0, translate=trans, scale=1.0, shear=0.0)
199 |
200 | return {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': trans_action_pre, 'resz_action_cur': trans_action_cur}
201 |
202 |
203 | class HFlip(object):
204 |     '''Randomly flip the image and action horizontally with probability p
205 | '''
206 | def __init__(self, probability=0.5):
207 | self.p = probability
208 |
209 | def flip_angle(self, action):
210 | if action[2] > math.pi:
211 | action[2] = 3 * math.pi - action[2]
212 | else:
213 | action[2] = math.pi - action[2]
214 | return action
215 |
216 | def __call__(self, sample):
217 | if random.random() >= self.p:
218 | return sample
219 | else:
220 | image_bi_pre, image_bi_cur, image_bi_post, action_pre, action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
221 | image_bi_pre, image_bi_cur, image_bi_post = TF.hflip(image_bi_pre), TF.hflip(image_bi_cur), TF.hflip(image_bi_post)
222 |
223 | tsfrm_action_pre, tsfrm_action_cur = action_pre.copy(), action_cur.copy()
224 | tsfrm_action_pre[0], tsfrm_action_cur[0] = image_bi_pre.size[0] - tsfrm_action_pre[0], image_bi_cur.size[0] - tsfrm_action_cur[0]
225 |
226 | tsfrm_action_pre = self.flip_angle(tsfrm_action_pre)
227 | tsfrm_action_cur = self.flip_angle(tsfrm_action_cur)
228 |
229 | return {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': tsfrm_action_pre, 'resz_action_cur': tsfrm_action_cur}
230 |
231 |
232 | class VFlip(object):
233 |     '''Randomly flip the image and action vertically with probability p
234 | '''
235 | def __init__(self, probability=0.5):
236 | self.p = probability
237 |
238 | def __call__(self, sample):
239 | if random.random() >= self.p:
240 | return sample
241 | else:
242 | image_bi_pre, image_bi_cur, image_bi_post, action_pre, action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
243 | image_bi_pre, image_bi_cur, image_bi_post = TF.vflip(image_bi_pre), TF.vflip(image_bi_cur), TF.vflip(image_bi_post)
244 |
245 | tsfrm_action_pre, tsfrm_action_cur = action_pre.copy(), action_cur.copy()
246 | tsfrm_action_pre[1], tsfrm_action_cur[1] = image_bi_pre.size[1] - tsfrm_action_pre[1], image_bi_cur.size[1] - tsfrm_action_cur[1]
247 |
248 | tsfrm_action_pre[2] = 2 * math.pi - tsfrm_action_pre[2]
249 | tsfrm_action_cur[2] = 2 * math.pi - tsfrm_action_cur[2]
250 | return {'image_bi_pre': image_bi_pre, 'image_bi_cur': image_bi_cur, 'image_bi_post': image_bi_post, 'resz_action_pre': tsfrm_action_pre, 'resz_action_cur': tsfrm_action_cur}
251 |
252 | class ToTensor(object):
253 |     '''Convert sample images and actions to tensors (images binarized at threshold 0.3)
254 |     '''
255 | def __call__(self, sample):
256 | image_bi_pre, image_bi_cur, image_bi_post, resz_action_pre, resz_action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
257 | image_bi_pre = TF.to_tensor(image_bi_pre) > 0.3
258 | image_bi_cur = TF.to_tensor(image_bi_cur) > 0.3
259 | image_bi_post = TF.to_tensor(image_bi_post) > 0.3
260 | return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur)}
261 |
262 | class ToTensorRGB(object):
263 |     '''Convert sample images and actions to tensors (no binarization)
264 |     '''
265 | def __call__(self, sample):
266 | image_bi_pre, image_bi_cur, image_bi_post, resz_action_pre, resz_action_cur = sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['resz_action_pre'], sample['resz_action_cur']
267 | image_bi_pre = TF.to_tensor(image_bi_pre)
268 | image_bi_cur = TF.to_tensor(image_bi_cur)
269 | image_bi_post = TF.to_tensor(image_bi_post)
270 | return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur)}
271 |
272 |
273 | class ToTensorMultiPred4(object):
274 |     '''Convert sample images and actions to tensors (images binarized at threshold 0.3)
275 |     '''
276 | def __call__(self, sample):
277 | image_bi_pre, image_bi_cur, image_bi_post, image_bi_post2, image_bi_post3, resz_action_pre, resz_action_cur, resz_action_post, resz_action_post2 = \
278 | sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['image_bi_post2'], sample['image_bi_post3'], sample['resz_action_pre'], sample['resz_action_cur'], sample['resz_action_post'], sample['resz_action_post2']
279 | image_bi_pre = TF.to_tensor(image_bi_pre) > 0.3
280 | image_bi_cur = TF.to_tensor(image_bi_cur) > 0.3
281 | image_bi_post = TF.to_tensor(image_bi_post) > 0.3
282 | image_bi_post2 = TF.to_tensor(image_bi_post2) > 0.3
283 | image_bi_post3 = TF.to_tensor(image_bi_post3) > 0.3
284 | return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'image_bi_post2': image_bi_post2.float(), 'image_bi_post3': image_bi_post3.float(), \
285 | 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur), 'resz_action_post': torch.tensor(resz_action_post), 'resz_action_post2': torch.tensor(resz_action_post2)}
286 |
287 | class ToTensorMultiPred10(object):
288 |     '''Convert sample images and actions to tensors (images binarized at threshold 0.1)
289 |     '''
290 | def __call__(self, sample):
291 | image_bi_pre, image_bi_cur, image_bi_post, image_bi_post2, image_bi_post3, image_bi_post4, image_bi_post5, image_bi_post6, image_bi_post7, image_bi_post8, image_bi_post9, \
292 | resz_action_pre, resz_action_cur, resz_action_post, resz_action_post2, resz_action_post3, resz_action_post4, resz_action_post5, resz_action_post6, resz_action_post7, resz_action_post8 = \
293 | sample['image_bi_pre'], sample['image_bi_cur'], sample['image_bi_post'], sample['image_bi_post2'], sample['image_bi_post3'], sample['image_bi_post4'], sample['image_bi_post5'],\
294 | sample['image_bi_post6'], sample['image_bi_post7'], sample['image_bi_post8'], sample['image_bi_post9'], \
295 | sample['resz_action_pre'], sample['resz_action_cur'], sample['resz_action_post'], sample['resz_action_post2'], sample['resz_action_post3'], sample['resz_action_post4'], sample['resz_action_post5'],\
296 | sample['resz_action_post6'], sample['resz_action_post7'], sample['resz_action_post8']
297 | value = 0.1
298 | image_bi_pre = TF.to_tensor(image_bi_pre) > value
299 | image_bi_cur = TF.to_tensor(image_bi_cur) > value
300 | image_bi_post = TF.to_tensor(image_bi_post) > value
301 | image_bi_post2 = TF.to_tensor(image_bi_post2) > value
302 | image_bi_post3 = TF.to_tensor(image_bi_post3) > value
303 | image_bi_post4 = TF.to_tensor(image_bi_post4) > value
304 | image_bi_post5 = TF.to_tensor(image_bi_post5) > value
305 | image_bi_post6 = TF.to_tensor(image_bi_post6) > value
306 | image_bi_post7 = TF.to_tensor(image_bi_post7) > value
307 | image_bi_post8 = TF.to_tensor(image_bi_post8) > value
308 | image_bi_post9 = TF.to_tensor(image_bi_post9) > value
309 | return {'image_bi_pre': image_bi_pre.float(), 'image_bi_cur': image_bi_cur.float(), 'image_bi_post': image_bi_post.float(), 'image_bi_post2': image_bi_post2.float(), 'image_bi_post3': image_bi_post3.float(), \
310 | 'image_bi_post4': image_bi_post4.float(), 'image_bi_post5': image_bi_post5.float(), 'image_bi_post6': image_bi_post6.float(), 'image_bi_post7': image_bi_post7.float(), 'image_bi_post8': image_bi_post8.float(), \
311 | 'image_bi_post9': image_bi_post9.float(), 'resz_action_pre': torch.tensor(resz_action_pre), 'resz_action_cur': torch.tensor(resz_action_cur), 'resz_action_post': torch.tensor(resz_action_post), \
312 | 'resz_action_post2': torch.tensor(resz_action_post2), 'resz_action_post3': torch.tensor(resz_action_post3), 'resz_action_post4': torch.tensor(resz_action_post4), 'resz_action_post5': torch.tensor(resz_action_post5), \
313 | 'resz_action_post6': torch.tensor(resz_action_post6), 'resz_action_post7': torch.tensor(resz_action_post7), 'resz_action_post8': torch.tensor(resz_action_post8)}
314 |
315 | def my_collate(batch):
316 | batch = list(filter(lambda x: x['image_bi_post'] is not None, batch))
317 | return default_collate(batch)
318 |
319 | def create_image_path(folder, total_img_num):
320 | '''create image_path list as input of MyDataset
321 | total_img_num: number of images
322 | '''
323 | add1 = './rope_dataset/{}'.format(folder)
324 | image_paths = []
325 |     for i in range(total_img_num):
326 |         # zero-pad the index to five digits, e.g. img_00042.jpg
327 |         add2 = '/img_{:05d}.jpg'.format(i)
336 | image_paths.append(add1+add2)
337 | return image_paths
338 |
339 |
--------------------------------------------------------------------------------
/model/e2c.py:
--------------------------------------------------------------------------------
1 | # E2C baseline: jointly train the encoder, decoder, and latent transition
2 | from __future__ import print_function
3 | import argparse
4 |
5 | import torch
6 | from torch import nn, optim, sigmoid, tanh, relu
7 | from torch.autograd import Variable
8 | from torch.nn import functional as F
9 | from torchvision.utils import save_image
10 | from torch.utils.data import Dataset, DataLoader
11 | from deform.model.create_dataset import *
12 | from deform.model.hidden_dynamics import *
13 | import matplotlib.pyplot as plt
14 | from deform.utils.utils import plot_train_loss, plot_train_bound_loss, plot_train_kl_loss, plot_train_pred_loss, \
15 | plot_test_loss, plot_test_bound_loss, plot_test_kl_loss, plot_test_pred_loss, \
16 | plot_sample, rect, save_data, save_e2c_data, create_loss_list, create_folder, plot_grad_flow
17 | from deform.model.configs import load_config
18 | import os
19 | import math
20 |
21 | class NormalDistribution(object):
22 | """
23 |     Wrapper class representing a multivariate normal distribution N(mu, Cov).
24 |     If the covariance matrix is diagonal, Cov = diag(sigma^2). Otherwise,
25 |     Cov = A diag(sigma^2) A^T, where A = I + v r^T.
26 | """
27 |
28 | def __init__(self, mu, sigma, logsigma, *, v=None, r=None):
29 | self.mu = mu
30 | self.sigma = sigma
31 | self.logsigma = logsigma
32 | self.v = v
33 | self.r = r
34 |
35 | @property
36 | def cov(self):
37 | """This should only be called when NormalDistribution represents one sample"""
38 | if self.v is not None and self.r is not None:
39 | assert self.v.dim() == 1
40 |             dim = self.v.size(0)  # length of v, not its number of axes
41 | v = self.v.unsqueeze(1) # D * 1 vector
42 | rt = self.r.unsqueeze(0) # 1 * D vector
43 | A = torch.eye(dim) + v.mm(rt)
44 | return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t()))
45 | else:
46 | return torch.diag(self.sigma.pow(2))
47 |
48 |
49 | def KLDGaussian(Q, N, eps=1e-8):
50 | """KL Divergence between two Gaussians
51 | Assuming Q ~ N(mu0, A\sigma_0A') where A = I + vr^{T}
52 | and N ~ N(mu1, \sigma_1)
53 | """
54 | sum = lambda x: torch.sum(x, dim=1)
55 | k = float(Q.mu.size()[1]) # dimension of distribution
56 | mu0, v, r, mu1 = Q.mu, Q.v, Q.r, N.mu
57 | s02, s12 = (Q.sigma).pow(2) + eps, (N.sigma).pow(2) + eps
58 | a = sum(s02 * (1. + 2. * v * r) / s12) + sum(v.pow(2) / s12) * sum(r.pow(2) * s02) # trace term
59 | b = sum((mu1 - mu0).pow(2) / s12) # difference-of-means term
60 | c = 2. * (sum(N.logsigma - Q.logsigma) - torch.log(1. + sum(v * r) + eps)) # ratio-of-determinants term.
61 |
62 | return 0.5 * (a + b - k + c)
63 |
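As a quick sanity check (not part of the original code): when v = r = 0 the rank-one term vanishes (A = I), and `KLDGaussian` should agree with the standard diagonal-Gaussian KL from `torch.distributions`:

```
from torch.distributions import Normal, kl_divergence

mu0, mu1 = torch.randn(1, 5), torch.randn(1, 5)
s0, s1 = torch.rand(1, 5) + 0.5, torch.rand(1, 5) + 0.5
Q = NormalDistribution(mu0, s0, s0.log(), v=torch.zeros(1, 5), r=torch.zeros(1, 5))
P = NormalDistribution(mu1, s1, s1.log())
ref = kl_divergence(Normal(mu0, s0), Normal(mu1, s1)).sum(dim=1)
print(torch.allclose(KLDGaussian(Q, P), ref, atol=1e-4))  # expected: True
```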
64 | class E2C(nn.Module):
65 | def __init__(self, dim_z=80, dim_u=4):
66 | super(E2C, self).__init__()
67 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
68 | nn.ReLU(),
69 | nn.MaxPool2d(3, stride=2),
70 | nn.Conv2d(32, 64, 3, padding=1),
71 | nn.ReLU(),
72 | nn.MaxPool2d(3, stride=2),
73 | nn.Conv2d(64, 128, 3, padding=1),
74 | nn.ReLU(),
75 | nn.MaxPool2d(3, stride=2),
76 | nn.Conv2d(128, 128, 3, padding=1),
77 | nn.ReLU(),
78 | nn.Conv2d(128, 128, 3, padding=1),
79 | nn.ReLU(),
80 | nn.MaxPool2d(3, stride=2, padding=1))
81 | self.fc1 = nn.Linear(128*3*3, dim_z*2)
82 | self.fc2 = nn.Linear(dim_z, 128*3*3)
83 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
84 | nn.ReLU(),
85 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
86 | nn.ReLU(),
87 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
88 | nn.ReLU(),
89 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
90 | nn.ReLU(),
91 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
92 | nn.Sigmoid())
93 | self.trans = nn.Sequential(
94 | nn.Linear(dim_z, 100),
95 | nn.BatchNorm1d(100),
96 | nn.ReLU(),
97 | nn.Linear(100, 100),
98 | nn.BatchNorm1d(100),
99 | nn.ReLU(),
100 | nn.Linear(100, dim_z*2)
101 | )
102 | self.fc_B = nn.Linear(dim_z, dim_z * dim_u)
103 | self.fc_o = nn.Linear(dim_z, dim_z)
104 | self.dim_z = dim_z
105 | self.dim_u = dim_u
106 |
107 |
108 | def encode(self, x):
109 | x = self.conv_layers(x)
110 | x = x.view(x.shape[0], -1)
111 | return relu(self.fc1(x)).chunk(2, dim=1)
112 |
113 | def decode(self, x):
114 | x = relu(self.fc2(x))
115 | x = x.view(-1, 128, 3, 3)
116 | return self.dconv_layers(x)
117 |
118 | def transition(self, h, Q, u):
119 | batch_size = h.size()[0]
120 | v, r = self.trans(h).chunk(2, dim=1)
121 | v1 = v.unsqueeze(2)
122 | rT = r.unsqueeze(1)
123 | I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
124 | if rT.data.is_cuda:
125 |             I = I.cuda()  # fix: assign the result (.cuda() is not in-place, and .dada was a typo)
126 | A = I.add(v1.bmm(rT))
127 |
128 | B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
129 | o = self.fc_o(h).reshape((-1, self.dim_z, 1))
130 |
131 | u = u.unsqueeze(2)
132 |
133 | d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
134 | sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
135 | return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
136 |
137 | def reparam(self, mean, logvar):
138 | std = logvar.mul(0.5).exp_()
139 | self.z_mean = mean
140 | self.z_sigma = std
141 | eps = torch.FloatTensor(std.size()).normal_()
142 | if std.data.is_cuda:
143 |             eps = eps.cuda()  # fix: .cuda() is not in-place
144 | eps = Variable(eps)
145 | return eps.mul(std).add_(mean), NormalDistribution(mean, std, torch.log(std))
146 |
147 | def forward(self, x, action, x_next):
148 | mean, logvar = self.encode(x)
149 | mean_next, logvar_next = self.encode(x_next)
150 |
151 | z, self.Qz = self.reparam(mean, logvar)
152 | z_next, self.Qz_next = self.reparam(mean_next, logvar_next)
153 |
154 | self.x_dec = self.decode(z)
155 | self.x_next_dec = self.decode(z_next)
156 |
157 | self.z_next_pred, self.Qz_next_pred = self.transition(z, self.Qz, action)
158 | self.x_next_pred_dec = self.decode(self.z_next_pred)
159 |
160 | return self.x_dec, self.x_next_pred_dec, self.Qz, self.Qz_next, self.Qz_next_pred
161 |
162 | def latent_embeddings(self, x):
163 | return self.encode(x)[0]
164 |
165 | def predict(self, X, U):
166 | mean, logvar = self.encode(X)
167 | z, Qz = self.reparam(mean, logvar)
168 | z_next_pred, Qz_next_pred = self.transition(z, Qz, U)
169 | return self.decode(z_next_pred)
170 |
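A CPU shape check for the model above (a sketch with dummy data; batch size 2 because the `BatchNorm1d` layers in `self.trans` need more than one sample in training mode). The transition implements the locally linear step z' = A z + B u + o with A = I + v r^T:

```
model = E2C(dim_z=80, dim_u=4)
x, x_next = torch.rand(2, 1, 50, 50), torch.rand(2, 1, 50, 50)
u = torch.rand(2, 4)
x_dec, x_next_pred_dec, Qz, Qz_next, Qz_next_pred = model(x, u, x_next)
print(x_dec.shape, x_next_pred_dec.shape)  # torch.Size([2, 1, 50, 50]) twice
```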
171 | def binary_crossentropy(t, o, eps=1e-8):
172 | return t * torch.log(o + eps) + (1.0 - t) * torch.log(1.0 - o + eps)
173 |
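Illustrative check: negating and summing `binary_crossentropy` matches PyTorch's built-in summed BCE up to the `eps` term (probabilities kept away from 0 and 1 so the comparison is stable):

```
t = torch.randint(0, 2, (4, 10)).float()
o = torch.rand(4, 10) * 0.98 + 0.01  # keep probabilities in [0.01, 0.99]
ours = -binary_crossentropy(t, o).sum()
ref = F.binary_cross_entropy(o, t, reduction='sum')
print(torch.allclose(ours, ref, atol=1e-4))  # expected: True
```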
174 | def compute_loss(x_dec, x_next_pred_dec, x, x_next,
175 | Qz, Qz_next_pred,
176 | Qz_next):
177 | # Reconstruction losses
178 | if False:
179 | x_reconst_loss = (x_dec - x_next).pow(2).sum(dim=1)
180 | x_next_reconst_loss = (x_next_pred_dec - x_next).pow(2).sum(dim=1)
181 | else:
182 | x_reconst_loss = -binary_crossentropy(x, x_dec).sum(dim=1)
183 | x_next_reconst_loss = -binary_crossentropy(x_next, x_next_pred_dec).sum(dim=1)
184 |
185 | logvar = Qz.logsigma.mul(2)
186 | KLD_element = Qz.mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
187 | KLD = torch.sum(KLD_element, dim=1).mul(-0.5)
188 |
189 | # ELBO
190 | bound_loss = x_reconst_loss.add(x_next_reconst_loss).add(KLD.reshape(-1,1,1))
191 | kl = KLDGaussian(Qz_next_pred, Qz_next)
192 | kl = kl[~torch.isnan(kl)]
193 | return bound_loss.mean(), kl.mean()
194 |
195 | def train(e2c_model):
196 | e2c_model.train()
197 | bound_loss = 0
198 | kl_loss = 0
199 | train_loss = 0
200 | pred_loss = 0
201 | for batch_idx, batch_data in enumerate(trainloader):
202 | # current image before action
203 | x = batch_data['image_bi_cur']
204 | x = x.float().to(device).view(-1, 1, 50, 50)
205 | # action
206 | action = batch_data['resz_action_cur']
207 | action = action.float().to(device).view(-1, 4)
208 | # image after action
209 | x_next = batch_data['image_bi_post']
210 | x_next = x_next.float().to(device).view(-1, 1, 50, 50)
211 | # optimization
212 | e2c_optimizer.zero_grad()
213 | # model
214 | x_dec, x_next_pred_dec, Qz, Qz_next, Qz_next_pred = e2c_model(x, action, x_next)
215 | # prediction
216 | x_next_pred = e2c_model.predict(x, action)
217 | # loss
218 | loss_pred = F.binary_cross_entropy(x_next_pred.view(-1, 2500), x_next.view(-1, 2500), reduction='sum')
219 | loss_bound, loss_kl = compute_loss(x_dec, x_next_pred_dec, x, x_next, Qz, Qz_next_pred, Qz_next)
220 | loss = GAMMA_bound * loss_bound + GAMMA_kl * loss_kl + GAMMA_pred * loss_pred
221 | loss.backward()
222 | train_loss += loss.item()
223 | bound_loss += GAMMA_bound * loss_bound.item()
224 | kl_loss += GAMMA_kl * loss_kl.item()
225 | pred_loss += GAMMA_pred * loss_pred.item()
226 | e2c_optimizer.step()
227 | if batch_idx % 5 == 0:
228 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
229 | epoch, batch_idx * len(batch_data['image_bi_cur']), len(trainloader.dataset),
230 | 100. * batch_idx / len(trainloader),
231 | loss.item() / len(batch_data['image_bi_cur'])))
232 | # reconstruction
233 | if batch_idx == 0:
234 | n = min(batch_data['image_bi_cur'].size(0), 8)
235 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image
236 | x_dec.view(-1, 1, 50, 50).cpu()[:n]]) # reconstruction of current image
237 | save_image(comparison.cpu(),
238 | './result/{}/reconstruction_train/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n)
239 | print('====> Epoch: {} Average loss: {:.4f}'.format(
240 | epoch, train_loss / len(trainloader.dataset)))
241 | n = len(trainloader.dataset)
242 | return train_loss/n, bound_loss/n, kl_loss/n, pred_loss/n
243 |
244 | def test(e2c_model):
245 | e2c_model.eval()
246 | test_loss = 0
247 | bound_loss = 0
248 | kl_loss = 0
249 | pred_loss = 0
250 | with torch.no_grad():
251 | for batch_idx, batch_data in enumerate(testloader):
252 | # current image before current action
253 | x = batch_data['image_bi_cur']
254 | x = x.float().to(device).view(-1, 1, 50, 50)
255 | # current action
256 | action = batch_data['resz_action_cur']
257 | action = action.float().to(device).view(-1, 4)
258 | # post image after current action
259 | x_next = batch_data['image_bi_post']
260 | x_next = x_next.float().to(device).view(-1, 1, 50, 50)
261 | # model
262 | x_dec, x_next_pred_dec, Qz, Qz_next, Qz_next_pred = e2c_model(x, action, x_next)
263 | # prediction
264 | x_next_pred = e2c_model.predict(x, action)
265 | # loss
266 | loss_bound, loss_kl = compute_loss(x_dec, x_next_pred_dec, x, x_next, Qz, Qz_next_pred, Qz_next)
267 | loss_pred = F.binary_cross_entropy(x_next_pred.view(-1, 2500), x_next.view(-1, 2500), reduction='sum')
268 | loss = GAMMA_bound * loss_bound + GAMMA_kl * loss_kl + GAMMA_pred * loss_pred
269 | test_loss += loss.item()
270 | bound_loss += GAMMA_bound * loss_bound.item()
271 | kl_loss += GAMMA_kl * loss_kl.item()
272 | pred_loss += GAMMA_pred * loss_pred.item()
273 | if batch_idx == 0:
274 | n = min(batch_data['image_bi_cur'].size(0), 8)
275 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image
276 | x_dec.view(-1, 1, 50, 50).cpu()[:n], # reconstruction of current image
277 | batch_data['image_bi_post'][:n], # post image
278 | x_next_pred.view(-1, 1, 50, 50).cpu()[:n]]) # prediction of post image
279 | save_image(comparison.cpu(),
280 | './result/{}/reconstruction_test/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n)
281 | n = len(testloader.dataset)
282 | return test_loss/n, bound_loss/n, kl_loss/n, pred_loss/n
283 |
284 | # args
285 | parser = argparse.ArgumentParser(description='E2C Rope Deform Example')
286 | parser.add_argument('--folder-name', default='test_E2C',
287 | help='set folder name to save image files')
288 | parser.add_argument('--batch-size', type=int, default=32, metavar='N',
289 |                     help='input batch size for training (default: 32)')
290 | parser.add_argument('--epochs', type=int, default=2, metavar='N',
291 |                     help='number of epochs to train (default: 2)')
292 | parser.add_argument('--gamma-bound', type=int, default=10000, metavar='N',
293 |                     help='scale coefficient for the ELBO bound loss (default: 10000)')
294 | parser.add_argument('--gamma-kl', type=int, default=1, metavar='N',
295 |                     help='scale coefficient for the KL divergence between predicted and encoded next-state distributions (default: 1)')
296 | parser.add_argument('--gamma-pred', type=int, default=1, metavar='N',
297 |                     help='scale coefficient for the prediction loss (default: 1)')
298 | parser.add_argument('--no-cuda', action='store_true', default=False,
299 |                     help='disables CUDA training')
300 | parser.add_argument('--seed', type=int, default=1, metavar='S',
301 | help='random seed (default: 1)')
302 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
303 | help='how many batches to wait before logging training status')
304 | parser.add_argument('--restore', action='store_true', default=False)
305 | args = parser.parse_args()
306 | args.cuda = not args.no_cuda and torch.cuda.is_available()
307 | torch.manual_seed(args.seed)
308 |
309 |
310 | # dataset
311 | print('***** Preparing Data *****')
312 | total_img_num = 22515
313 | train_num = int(total_img_num * 0.8)
314 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
315 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
316 | resz_act = np.load(resz_act_path)
317 | # transform = transforms.Compose([Translation(10),
318 | # HFlip(0.5),
319 | # VFlip(0.5),
320 | # ToTensor()])
321 | transform = transforms.Compose([Translation(10),
322 | ToTensor()])
323 | trainset = MyDataset(image_paths_bi[0:train_num], resz_act[0:train_num], transform=transform)
324 | testset = MyDataset(image_paths_bi[train_num:], resz_act[train_num:], transform=ToTensor())
325 | trainloader = DataLoader(trainset, batch_size=args.batch_size,
326 | shuffle=True, num_workers=4, collate_fn=my_collate)
327 | testloader = DataLoader(testset, batch_size=args.batch_size,
328 | shuffle=True, num_workers=4, collate_fn=my_collate)
329 | print('***** Finish Preparing Data *****')
330 |
331 |
332 | # train var
333 | GAMMA_bound = args.gamma_bound
334 | GAMMA_kl = args.gamma_kl
335 | GAMMA_pred = args.gamma_pred
336 |
337 | # create folders
338 | folder_name = args.folder_name
339 | create_folder(folder_name)
340 |
341 | print('***** Start Training & Testing *****')
342 | device = torch.device("cuda" if args.cuda else "cpu")
343 | epochs = args.epochs
344 | e2c_model = E2C().to(device)
345 | e2c_optimizer = optim.Adam(e2c_model.parameters(), lr=1e-3)
346 |
347 | # initial train
348 | if not args.restore:
349 | init_epoch = 1
350 | loss_logger = None
351 | # restore previous train
352 | else:
353 | print('***** Load Checkpoint *****')
354 | PATH = './result/{}/checkpoint'.format(folder_name)
355 | checkpoint = torch.load(PATH, map_location=device)
356 | e2c_model.load_state_dict(checkpoint['e2c_model_state_dict'])
357 | e2c_optimizer.load_state_dict(checkpoint['e2c_optimizer_state_dict'])
358 | init_epoch = checkpoint['epoch'] + 1
359 | loss_logger = checkpoint['loss_logger']
360 |
361 | train_loss_all = []
362 | test_loss_all = []
363 | train_bound_loss_all = []
364 | test_bound_loss_all = []
365 | train_kl_loss_all = []
366 | test_kl_loss_all = []
367 | train_pred_loss_all = []
368 | test_pred_loss_all = []
369 | for epoch in range(init_epoch, epochs+1):
370 | train_loss, train_bound_loss, train_kl_loss, train_pred_loss = train(e2c_model)
371 | test_loss, test_bound_loss, test_kl_loss, test_pred_loss = test(e2c_model)
372 | train_loss_all.append(train_loss)
373 | test_loss_all.append(test_loss)
374 | train_bound_loss_all.append(train_bound_loss)
375 | test_bound_loss_all.append(test_bound_loss)
376 | train_kl_loss_all.append(train_kl_loss)
377 | test_kl_loss_all.append(test_kl_loss)
378 | train_pred_loss_all.append(train_pred_loss)
379 | test_pred_loss_all.append(test_pred_loss)
380 | if epoch % args.log_interval == 0:
381 | save_e2c_data(folder_name, epochs, train_loss_all, train_bound_loss_all, train_kl_loss_all, train_pred_loss_all, \
382 | test_loss_all, test_bound_loss_all, test_kl_loss_all, test_pred_loss_all)
383 | # save checkpoint
384 | PATH = './result/{}/checkpoint'.format(folder_name)
385 | loss_logger = {'train_loss_all': train_loss_all, 'train_bound_loss_all': train_bound_loss_all,
386 | 'train_kl_loss_all': train_kl_loss_all, 'train_pred_loss_all': train_pred_loss_all,
387 | 'test_loss_all': test_loss_all, 'test_bound_loss_all': test_bound_loss_all,
388 | 'test_kl_loss_all': test_kl_loss_all, 'test_pred_loss_all': test_pred_loss_all}
389 | torch.save({
390 | 'epoch': epoch,
391 | 'e2c_model_state_dict': e2c_model.state_dict(),
392 | 'e2c_optimizer_state_dict': e2c_optimizer.state_dict(),
393 | 'loss_logger': loss_logger
394 | },
395 | PATH)
396 |
397 |
398 | save_e2c_data(folder_name, epochs, train_loss_all, train_bound_loss_all, train_kl_loss_all, train_pred_loss_all, \
399 | test_loss_all, test_bound_loss_all, test_kl_loss_all, test_pred_loss_all)
400 |
401 | # plot
402 | plot_train_loss('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
403 | plot_train_bound_loss('./result/{}/train_bound_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
404 | plot_train_kl_loss('./result/{}/train_kl_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
405 | plot_train_pred_loss('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
406 | plot_test_loss('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
407 | plot_test_bound_loss('./result/{}/test_bound_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
408 | plot_test_kl_loss('./result/{}/test_kl_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
409 | plot_test_pred_loss('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
410 |
411 |
412 | # save checkpoint
413 | PATH = './result/{}/checkpoint'.format(folder_name)
414 | loss_logger = {'train_loss_all': train_loss_all, 'train_bound_loss_all': train_bound_loss_all,
415 | 'train_kl_loss_all': train_kl_loss_all, 'train_pred_loss_all': train_pred_loss_all,
416 | 'test_loss_all': test_loss_all, 'test_bound_loss_all': test_bound_loss_all,
417 | 'test_kl_loss_all': test_kl_loss_all, 'test_pred_loss_all': test_pred_loss_all}
418 | torch.save({
419 | 'epoch': epoch,
420 | 'e2c_model_state_dict': e2c_model.state_dict(),
421 | 'e2c_optimizer_state_dict': e2c_optimizer.state_dict(),
422 | 'loss_logger': loss_logger
423 | },
424 | PATH)
425 | print('***** End Program *****')
--------------------------------------------------------------------------------
/model/hidden_dynamics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy
3 | from numpy.linalg import inv, pinv, norm, det
4 | import torch
5 |
6 |
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 |
9 | def get_control_matrix(G, U):
10 | '''Calculate control matrix
11 | G: n * len(x_i)
12 | U: n * len(u_i)
13 | return control matrix L: len(x_i) * len(u_i)
14 | '''
15 | n, d = U.size()
16 | # U is a thin matrix
17 | if n > d:
18 | eps = 1e-5
19 |         return (torch.pinverse(U.t().mm(U) + eps*torch.eye(d).to(device)).mm(U.t()).mm(G)).t()  # (U^T U + eps*I)^{-1} U^T G, transposed
20 | # U is a fat matrix
21 | elif n < d:
22 | eps = 1e-5
23 | return (U.t().mm(torch.pinverse(U.mm(U.t()) + eps*torch.eye(n).to(device))).to(device).mm(G)).t()
24 |     # U is a square matrix
25 | else:
26 | return (torch.inverse(U).to(device).mm(G)).t()
27 |
28 |
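A small recovery test for the thin-matrix branch (hypothetical sizes; relies on the parenthesization fixed above): generate G = U L_true^T and check that the ridge-regularized pseudo-inverse recovers L_true:

```
torch.manual_seed(0)
n, dx, du = 100, 8, 4                 # samples, state dim, action dim (n > du)
L_true = torch.randn(dx, du)
U = torch.randn(n, du)
G = U.mm(L_true.t())
L_hat = get_control_matrix(G.to(device), U.to(device))
print(torch.allclose(L_hat.cpu(), L_true, atol=1e-3))  # expected: True
```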
29 | def get_error(G, U, L):
30 |     '''Frobenius norm ||G - U L^T||
31 | '''
32 | return torch.norm(G-U.mm(L.t()))
33 |
34 | def get_error_linear(G, U, L_T):
35 |     '''Frobenius norm ||G - U L_T||
36 | '''
37 | err = G - torch.matmul(U.view(U.shape[0], 1, -1), L_T)
38 | return torch.norm(err.view(err.shape[0], -1))
39 |
40 | def get_next_state(latent_image_pre, latent_action, K, L):
41 |     '''Get the next embedded state, g_{t+1} = K * g_t + L * u_t
42 | 
43 |     latent_image_pre: 1 * len(x_i), current embedded state
44 |     latent_action: m * len(u_i), m is the number of predicted steps
45 |     K: len(x_i) * len(x_i), state transition matrix
46 |     L: len(x_i) * len(u_i), control matrix
47 |     '''
48 |
49 | return latent_image_pre.mm(K.t().to(device)) + latent_action.mm(L.t().to(device))
50 |
51 | def get_next_state_linear(latent_image_pre, latent_action, K_T, L_T, z=None):
52 | '''
53 | latent_image_pre: (batch_size, latent_state_dim)
54 | latent_action: (batch_size, latent_act_dim)
55 | K_T: (batch_size, latent_state_dim, latent_state_dim)
56 | L_T: (batch_size, latent_act_dim, latent_state_dim)
57 | '''
58 | if z is not None:
59 | return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1) + \
60 | (torch.matmul(latent_action.view(latent_action.shape[0], 1, -1), L_T)).view(latent_action.shape[0], -1) + z.to(device)
61 | else:
62 | return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1) + \
63 | (torch.matmul(latent_action.view(latent_action.shape[0], 1, -1), L_T)).view(latent_action.shape[0], -1)
64 |
65 | def get_next_state_linear_without_L(latent_image_pre, K_T, z=None):
66 | '''
67 | latent_image_pre: (batch_size, latent_state_dim)
68 | latent_action: (batch_size, latent_act_dim)
69 | K_T: (batch_size, latent_state_dim, latent_state_dim)
70 | '''
71 | if z is not None:
72 | return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1) + z.to(device)
73 | else:
74 | return (torch.matmul(latent_image_pre.view(latent_image_pre.shape[0], 1, -1), K_T)).view(latent_image_pre.shape[0], -1)
75 |
76 |
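For shape orientation (dimensions are illustrative), the batched locally linear step g_{t+1} = g_t K_T + u_t L_T:

```
B, dim_g, dim_a = 4, 80, 80
g_cur = torch.randn(B, dim_g).to(device)
a_cur = torch.randn(B, dim_a).to(device)
K_T = torch.randn(B, dim_g, dim_g).to(device)
L_T = torch.randn(B, dim_a, dim_g).to(device)
print(get_next_state_linear(g_cur, a_cur, K_T, L_T).shape)  # torch.Size([4, 80])
```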
--------------------------------------------------------------------------------
/model/image/action_model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/action_model.pdf
--------------------------------------------------------------------------------
/model/image/action_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/action_model.png
--------------------------------------------------------------------------------
/model/image/dynamics_model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/dynamics_model.pdf
--------------------------------------------------------------------------------
/model/image/dynamics_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/dynamics_model.png
--------------------------------------------------------------------------------
/model/image/state_model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/state_model.pdf
--------------------------------------------------------------------------------
/model/image/state_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/model/image/state_model.png
--------------------------------------------------------------------------------
/model/nn_large_linear_seprt_no_loop_Kp_Lpa.py:
--------------------------------------------------------------------------------
1 | # separate two models, train g^t first, then train K and L
2 | from __future__ import print_function
3 | import argparse
4 |
5 | import torch
6 | from torch import nn, optim, sigmoid, tanh, relu
7 | from torch.nn import functional as F
8 | from torchvision.utils import save_image
9 |
10 | from torch.utils.data import Dataset, DataLoader
11 | from deform.model.create_dataset import *
12 | from deform.model.hidden_dynamics import *
13 | import matplotlib.pyplot as plt
14 | from deform.utils.utils import plot_train_loss, plot_train_latent_loss, plot_train_img_loss, plot_train_act_loss, plot_train_pred_loss, \
15 | plot_test_loss, plot_test_latent_loss, plot_test_img_loss, plot_test_act_loss, plot_test_pred_loss, \
16 | plot_sample, rect, save_data, create_loss_list, create_folder, plot_grad_flow
17 | import os
18 | import math
19 |
20 | class CAE(nn.Module):
21 | def __init__(self, latent_state_dim=80, latent_act_dim=80):
22 | super(CAE, self).__init__()
23 | # state
24 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
25 | nn.ReLU(),
26 | nn.MaxPool2d(3, stride=2),
27 | nn.Conv2d(32, 64, 3, padding=1),
28 | nn.ReLU(),
29 | nn.MaxPool2d(3, stride=2),
30 | nn.Conv2d(64, 128, 3, padding=1),
31 | nn.ReLU(),
32 | nn.MaxPool2d(3, stride=2),
33 | nn.Conv2d(128, 128, 3, padding=1),
34 | nn.ReLU(),
35 | nn.Conv2d(128, 128, 3, padding=1),
36 | nn.ReLU(),
37 | nn.MaxPool2d(3, stride=2, padding=1))
38 | self.fc1 = nn.Linear(128*3*3, latent_state_dim)
39 | self.fc2 = nn.Linear(latent_state_dim, 128*3*3)
40 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
41 | nn.ReLU(),
42 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
43 | nn.ReLU(),
44 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
45 | nn.ReLU(),
46 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
47 | nn.ReLU(),
48 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
49 | nn.Sigmoid())
50 | # action
51 | self.fc5 = nn.Linear(4, latent_act_dim)
52 | self.fc6 = nn.Linear(latent_act_dim, latent_act_dim)
53 | self.fc7 = nn.Linear(latent_act_dim, latent_act_dim)
54 | self.fc8 = nn.Linear(latent_act_dim, 4)
55 | # add these in order to use GPU for parameters
56 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14])
57 | self.add_tensor = torch.tensor([0, 0, 0, 0.01])
58 |
59 |
60 | def encoder(self, x):
61 | x = self.conv_layers(x)
62 | x = x.view(x.shape[0], -1)
63 | return relu(self.fc1(x))
64 |
65 | def decoder(self, x):
66 | x = relu(self.fc2(x))
67 | x = x.view(-1, 128, 3, 3)
68 | return self.dconv_layers(x)
69 |
70 | def encoder_act(self, u):
71 | h1 = relu(self.fc5(u))
72 | return relu(self.fc6(h1))
73 |
74 | def decoder_act(self, u):
75 | h2 = relu(self.fc7(u))
76 |         return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.to(u.device)) + self.add_tensor.to(u.device)  # device-agnostic (original hard-coded .cuda())
77 |
78 | def forward(self, x_cur, u, x_post):
79 | g_cur = self.encoder(x_cur)
80 | a = self.encoder_act(u)
81 | g_post = self.encoder(x_post)
82 |
83 | return g_cur, a, g_post, self.decoder(g_cur), self.decoder_act(a)
84 |
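A round-trip sketch for the action branch (CPU; assumes the device-agnostic `decoder_act` above, or a CUDA device with the original code). The sigmoid output is rescaled so x, y land in (0, 50), the angle in (0, 2*pi), and the pull length in (0.01, 0.15):

```
cae = CAE()
act = torch.tensor([[25.0, 25.0, math.pi, 0.05]])
recon_act = cae.decoder_act(cae.encoder_act(act))
print(recon_act.shape)  # torch.Size([1, 4])
```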
85 | class SysDynamics(nn.Module):
86 | def __init__(self, latent_state_dim=80, latent_act_dim=80):
87 | super(SysDynamics, self).__init__()
88 | self.conv_layers_matrix = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
89 | nn.ReLU(),
90 | nn.MaxPool2d(3, stride=1),
91 | nn.Conv2d(32, 64, 3, padding=1),
92 | nn.ReLU(),
93 | nn.MaxPool2d(3, stride=1),
94 | nn.Conv2d(64, 128, 3, padding=1),
95 | nn.ReLU(),
96 | nn.MaxPool2d(3, stride=2),
97 | nn.Conv2d(128, 256, 3, padding=1),
98 | nn.ReLU(),
99 | nn.MaxPool2d(3, stride=2),
100 | nn.Conv2d(256, 512, 3, padding=1),
101 | nn.ReLU(),
102 | nn.MaxPool2d(3, stride=2),
103 | nn.Conv2d(512, 512, 3, padding=1),
104 | nn.ReLU(),
105 | nn.Conv2d(512, 512, 3, padding=1),
106 | nn.ReLU(),
107 | nn.MaxPool2d(3, stride=2, padding=1))
108 | self.fc31 = nn.Linear(512*2*2, latent_state_dim*latent_state_dim)
109 | self.fc32 = nn.Linear(latent_state_dim*latent_state_dim, latent_state_dim*latent_state_dim)
110 | self.fc41 = nn.Linear(512*2*2 + latent_act_dim, latent_state_dim*latent_act_dim)
111 | self.fc42 = nn.Linear(latent_state_dim*latent_act_dim, latent_state_dim*latent_act_dim)
112 | self.fc9 = nn.Linear(4, latent_act_dim)
113 | self.fc10 = nn.Linear(latent_act_dim, latent_act_dim)
114 | # latent dim
115 | self.latent_act_dim = latent_act_dim
116 | self.latent_state_dim = latent_state_dim
117 |
118 | def encoder_matrix(self, x, a):
119 | x = self.conv_layers_matrix(x)
120 | x = x.view(x.shape[0], -1)
121 | xa = torch.cat((x,a), 1)
122 |
123 | return relu(self.fc32(relu(self.fc31(x)))).view(-1, self.latent_state_dim, self.latent_state_dim), \
124 | relu(self.fc42(relu(self.fc41(xa)))).view(-1, self.latent_act_dim, self.latent_state_dim)
125 |
126 | def forward(self, x_cur, u):
127 | a = relu(self.fc10(relu(self.fc9(u))))
128 | K_T, L_T = self.encoder_matrix(x_cur, a)
129 |
130 | return K_T, L_T
131 |
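End to end, one latent prediction step on the CPU (a sketch with random inputs; real training feeds 50x50 binarized rope images):

```
recon = CAE()
dyn = SysDynamics()
img_cur, act = torch.rand(2, 1, 50, 50), torch.rand(2, 4)
g_cur = recon.encoder(img_cur)
a_lat = recon.encoder_act(act)
K_T, L_T = dyn(img_cur, act)
g_next = get_next_state_linear(g_cur, a_lat, K_T, L_T)
print(recon.decoder(g_next).shape)  # torch.Size([2, 1, 50, 50])
```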
132 | def loss_function(recon_x, x):
133 | '''
134 | recon_x: tensor
135 | x: tensor
136 | '''
137 | return F.binary_cross_entropy(recon_x.view(-1, 2500), x.view(-1, 2500), reduction='sum')
138 |
139 |
140 | def mse(recon_x, x):
141 | '''mean square error
142 | recon_x: numpy array
143 | x: numpy array
144 | '''
145 | return F.mse_loss(recon_x, x)
146 |
147 | def loss_function_img(recon_img, img):
148 | return F.binary_cross_entropy(recon_img.view(-1, 2500), img.view(-1, 2500), reduction='sum')
149 |
150 | def loss_function_act(recon_act, act):
151 | return F.mse_loss(recon_act.view(-1, 4), act.view(-1, 4), reduction='sum')
152 |
153 | def loss_function_latent_linear(latent_img_pre, latent_img_post, latent_action, K_T, L_T):
154 | G = latent_img_post.view(latent_img_post.shape[0], 1, -1) - torch.matmul(latent_img_pre.view(latent_img_pre.shape[0], 1, -1), K_T)
155 | return get_error_linear(G, latent_action, L_T)
156 |
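Equivalently (a quick check with hypothetical shapes), this latent loss is the Frobenius norm of the one-step residual g_post - (g_cur K_T + a L_T):

```
B, dim_g, dim_a = 3, 80, 80
g_cur, g_post = torch.randn(B, dim_g), torch.randn(B, dim_g)
a_lat = torch.randn(B, dim_a)
K_T, L_T = torch.randn(B, dim_g, dim_g), torch.randn(B, dim_a, dim_g)
lhs = loss_function_latent_linear(g_cur, g_post, a_lat, K_T, L_T)
rhs = torch.norm(g_post - get_next_state_linear(g_cur, a_lat, K_T, L_T))
print(torch.allclose(lhs, rhs, atol=1e-3))  # expected: True
```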
157 | def loss_function_pred_linear(img_post, latent_img_pre, latent_act, K_T, L_T):
158 | recon_latent_img_post = get_next_state_linear(latent_img_pre, latent_act, K_T, L_T)
159 | recon_img_post = recon_model.decoder(recon_latent_img_post)
160 | return F.binary_cross_entropy(recon_img_post.view(-1, 2500), img_post.view(-1, 2500), reduction='sum')
161 |
162 | def constraint_loss(steps, idx, trainset, U_latent, L):
163 | loss = 0
164 | data = trainset.__getitem__(idx).float().to(device).view(-1, 1, 50, 50)
165 |     embed_state = recon_model.encoder(data).detach().cpu().numpy()
166 | for i in range(steps):
167 | step = i + 1
168 | data_next = trainset.__getitem__(idx+step).float().to(device).view(-1, 1, 50, 50)
169 | action = U_latent[idx:idx+step][:]
170 | embed_state_next = torch.from_numpy(get_next_state(embed_state, action, L)).to(device).float()
171 |         recon_state_next = recon_model.decoder(embed_state_next)
172 | loss += mse(recon_state_next, data_next)
173 | return loss
174 |
175 | def train_new(epoch, recon_model, dyn_model, epoch_thres=500):
176 | if epoch < epoch_thres:
177 | recon_model.train()
178 | dyn_model.eval()
179 | train_loss = 0
180 | img_loss = 0
181 | act_loss = 0
182 | latent_loss = 0
183 | pred_loss = 0
184 | for batch_idx, batch_data in enumerate(trainloader):
185 | # current image before action
186 | img_cur = batch_data['image_bi_cur']
187 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50)
188 | # action
189 | act = batch_data['resz_action_cur']
190 | act = act.float().to(device).view(-1, 4)
191 | # image after action
192 | img_post = batch_data['image_bi_post']
193 | img_post = img_post.float().to(device).view(-1, 1, 50, 50)
194 | # optimization
195 | recon_optimizer.zero_grad()
196 | # model
197 | latent_img_cur, latent_act, latent_img_post, recon_img_cur, recon_act = recon_model(img_cur, act, img_post)
198 | # loss
199 | loss_img = loss_function_img(recon_img_cur, img_cur)
200 | loss_act = loss_function_act(recon_act, act)
201 | loss = loss_img + GAMMA_act * loss_act
202 | loss.backward()
203 | train_loss += loss.item()
204 | img_loss += loss_img.item()
205 | act_loss += GAMMA_act * loss_act.item()
206 |
207 | recon_optimizer.step()
208 | if batch_idx % 5 == 0:
209 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
210 | epoch, batch_idx * len(batch_data['image_bi_cur']), len(trainloader.dataset),
211 | 100. * batch_idx / len(trainloader),
212 | loss.item() / len(batch_data['image_bi_cur'])))
213 | # reconstruction
214 | if batch_idx == 0:
215 | n = min(batch_data['image_bi_cur'].size(0), 8)
216 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image
217 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n]]) # reconstruction of current image
218 | save_image(comparison.cpu(),
219 | './result/{}/reconstruction_train/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n)
220 | plot_sample(batch_data['image_bi_cur'][:n].detach().cpu().numpy(),
221 | batch_data['image_bi_post'][:n].detach().cpu().numpy(),
222 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(),
223 | recon_act.view(-1, 4)[:n].detach().cpu().numpy(),
224 | './result/{}/reconstruction_act_train/recon_epoch_{}.png'.format(folder_name, epoch))
225 | print('====> Epoch: {} Average loss: {:.4f}'.format(
226 | epoch, train_loss / len(trainloader.dataset)))
227 | n = len(trainloader.dataset)
228 | return train_loss/n, img_loss/n, act_loss/n, latent_loss/n, pred_loss/n
229 | else:
230 | recon_model.eval()
231 | dyn_model.train()
232 | train_loss = 0
233 | img_loss = 0
234 | act_loss = 0
235 | latent_loss = 0
236 | pred_loss = 0
237 | for batch_idx, batch_data in enumerate(trainloader):
238 | # current image before action
239 | img_cur = batch_data['image_bi_cur']
240 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50)
241 | # action
242 | act = batch_data['resz_action_cur']
243 | act = act.float().to(device).view(-1, 4)
244 | # image after action
245 | img_post = batch_data['image_bi_post']
246 | img_post = img_post.float().to(device).view(-1, 1, 50, 50)
247 | # optimization
248 | dyn_optimizer.zero_grad()
249 | # model
250 | latent_img_cur, latent_act, latent_img_post, recon_img_cur, recon_act = recon_model(img_cur, act, img_post)
251 | K_T, L_T = dyn_model(img_cur, act)
252 | # prediction
253 | pred_latent_img_post = get_next_state_linear(latent_img_cur, latent_act, K_T, L_T)
254 | pred_img_post = recon_model.decoder(pred_latent_img_post)
255 | # loss
256 | loss_latent = loss_function_latent_linear(latent_img_cur, latent_img_post, latent_act, K_T, L_T)
257 | loss_predict = loss_function_pred_linear(img_post, latent_img_cur, latent_act, K_T, L_T)
258 | loss = GAMMA_latent * loss_latent + GAMMA_pred * loss_predict
259 | loss.backward()
260 | train_loss += loss.item()
261 | latent_loss += GAMMA_latent * loss_latent.item()
262 | pred_loss += GAMMA_pred * loss_predict.item()
263 |
264 | dyn_optimizer.step()
265 | if batch_idx % 5 == 0:
266 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
267 | epoch, batch_idx * len(batch_data['image_bi_cur']), len(trainloader.dataset),
268 | 100. * batch_idx / len(trainloader),
269 | loss.item() / len(batch_data['image_bi_cur'])))
270 | # reconstruction
271 | if batch_idx == 0:
272 | n = min(batch_data['image_bi_cur'].size(0), 8)
273 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image
274 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], # reconstruction of current image
275 | batch_data['image_bi_post'][:n], # post image
276 | pred_img_post.view(-1, 1, 50, 50).cpu()[:n]]) # prediction of post image
277 | save_image(comparison.cpu(),
278 | './result/{}/reconstruction_train/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n)
279 | plot_sample(batch_data['image_bi_cur'][:n].detach().cpu().numpy(),
280 | batch_data['image_bi_post'][:n].detach().cpu().numpy(),
281 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(),
282 | recon_act.view(-1, 4)[:n].detach().cpu().numpy(),
283 | './result/{}/reconstruction_act_train/recon_epoch_{}.png'.format(folder_name, epoch))
284 | print('====> Epoch: {} Average loss: {:.4f}'.format(
285 | epoch, train_loss / len(trainloader.dataset)))
286 | n = len(trainloader.dataset)
287 | return train_loss/n, img_loss/n, act_loss/n, latent_loss/n, pred_loss/n
288 |
289 | def test_new(epoch, recon_model, dyn_model):
290 | recon_model.eval()
291 | dyn_model.eval()
292 | test_loss = 0
293 | img_loss = 0
294 | act_loss = 0
295 | latent_loss = 0
296 | pred_loss = 0
297 | with torch.no_grad():
298 | for batch_idx, batch_data in enumerate(testloader):
299 | # current image before current action
300 | img_cur = batch_data['image_bi_cur']
301 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50)
302 | # current action
303 | act = batch_data['resz_action_cur']
304 | act = act.float().to(device).view(-1, 4)
305 | # post image after current action
306 | img_post = batch_data['image_bi_post']
307 | img_post = img_post.float().to(device).view(-1, 1, 50, 50)
308 | # model
309 | latent_img_cur, latent_act, latent_img_post, recon_img_cur, recon_act = recon_model(img_cur, act, img_post)
310 | K_T, L_T = dyn_model(img_cur, act)
311 | # prediction
312 | pred_latent_img_post = get_next_state_linear(latent_img_cur, latent_act, K_T, L_T)
313 | pred_img_post = recon_model.decoder(pred_latent_img_post)
314 | # loss
315 | loss_img = loss_function_img(recon_img_cur, img_cur)
316 | loss_act = loss_function_act(recon_act, act)
317 | loss_latent = loss_function_latent_linear(latent_img_cur, latent_img_post, latent_act, K_T, L_T)
318 | loss_predict = loss_function_pred_linear(img_post, latent_img_cur, latent_act, K_T, L_T)
319 | loss = loss_img + GAMMA_act * loss_act + GAMMA_latent * loss_latent + GAMMA_pred * loss_predict
320 | test_loss += loss.item()
321 | img_loss += loss_img.item()
322 | act_loss += GAMMA_act * loss_act.item()
323 | latent_loss += GAMMA_latent * loss_latent.item()
324 | pred_loss += GAMMA_pred * loss_predict.item()
325 | if batch_idx == 0:
326 | n = min(batch_data['image_bi_cur'].size(0), 8)
327 | comparison = torch.cat([batch_data['image_bi_cur'][:n], # current image
328 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n], # reconstruction of current image
329 | batch_data['image_bi_post'][:n], # post image
330 | pred_img_post.view(-1, 1, 50, 50).cpu()[:n]]) # prediction of post image
331 | save_image(comparison.cpu(),
332 | './result/{}/reconstruction_test/reconstruct_epoch_{}.png'.format(folder_name, epoch), nrow=n)
333 | plot_sample(batch_data['image_bi_cur'][:n].detach().cpu().numpy(),
334 | batch_data['image_bi_post'][:n].detach().cpu().numpy(),
335 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(),
336 | recon_act.view(-1, 4)[:n].detach().cpu().numpy(),
337 | './result/{}/reconstruction_act_test/recon_epoch_{}.png'.format(folder_name, epoch))
338 | n = len(testloader.dataset)
339 | return test_loss/n, img_loss/n, act_loss/n, latent_loss/n, pred_loss/n
340 |
341 | # args
342 | parser = argparse.ArgumentParser(description='CAE Rope Deform Example')
343 | parser.add_argument('--folder-name', default='test',
344 | help='set folder name to save image files')
345 | parser.add_argument('--batch-size', type=int, default=32, metavar='N',
346 |                     help='input batch size for training (default: 32)')
347 | parser.add_argument('--epochs', type=int, default=1000, metavar='N',
348 |                     help='number of epochs to train (default: 1000)')
349 | parser.add_argument('--gamma-act', type=int, default=450, metavar='N',
350 |                     help='scale coefficient for loss of action (default: 450)')
351 | parser.add_argument('--gamma-lat', type=int, default=900, metavar='N',
352 |                     help='scale coefficient for loss of latent dynamics (default: 900)')
353 | parser.add_argument('--gamma-pred', type=int, default=10, metavar='N',
354 |                     help='scale coefficient for loss of prediction (default: 10)')
355 | parser.add_argument('--no-cuda', action='store_true', default=False,
356 |                     help='disables CUDA training')
357 | parser.add_argument('--math', action='store_true', default=False,
358 | help='get control matrix L: True - use regression, False - use backpropagation')
359 | parser.add_argument('--seed', type=int, default=1, metavar='S',
360 | help='random seed (default: 1)')
361 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
362 | help='how many batches to wait before logging training status')
363 | parser.add_argument('--restore', action='store_true', default=False)
364 | args = parser.parse_args()
365 | args.cuda = not args.no_cuda and torch.cuda.is_available()
366 | torch.manual_seed(args.seed)
367 |
368 |
369 | # dataset
370 | print('***** Preparing Data *****')
371 | total_img_num = 22515
372 | train_num = int(total_img_num * 0.8)
373 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
374 | resz_act_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
375 | resz_act = np.load(resz_act_path)
376 | # transform = transforms.Compose([Translation(10),
377 | # HFlip(0.5),
378 | # VFlip(0.5),
379 | # ToTensor()])
380 | transform = transforms.Compose([Translation(10),
381 | ToTensor()])
382 | trainset = MyDataset(image_paths_bi[0:train_num], resz_act[0:train_num], transform=transform)
383 | testset = MyDataset(image_paths_bi[train_num:], resz_act[train_num:], transform=ToTensor())
384 | trainloader = DataLoader(trainset, batch_size=args.batch_size,
385 | shuffle=True, num_workers=4, collate_fn=my_collate)
386 | testloader = DataLoader(testset, batch_size=args.batch_size,
387 | shuffle=True, num_workers=4, collate_fn=my_collate)
388 | print('***** Finish Preparing Data *****')
389 |
390 | # train var
391 | MATH = args.math
392 | GAMMA_act = args.gamma_act
393 | GAMMA_latent = args.gamma_lat
394 | GAMMA_pred = args.gamma_pred
395 |
396 | # create folders
397 | folder_name = args.folder_name
398 | create_folder(folder_name)
399 |
400 | print('***** Start Training & Testing *****')
401 | device = torch.device("cuda" if args.cuda else "cpu")
402 | epochs = args.epochs
403 | recon_model = CAE().to(device)
404 | dyn_model = SysDynamics().to(device)
405 | recon_optimizer = optim.Adam(recon_model.parameters(), lr=1e-3)
406 | dyn_optimizer = optim.Adam(dyn_model.parameters(), lr=1e-3)
407 |
408 | # initial train
409 | if not args.restore:
410 | init_epoch = 1
411 | loss_logger = None
412 | # restore previous train
413 | else:
414 | print('***** Load Checkpoint *****')
415 | PATH = './result/{}/checkpoint'.format(folder_name)
416 | checkpoint = torch.load(PATH, map_location=device)
417 | recon_model.load_state_dict(checkpoint['recon_model_state_dict'])
418 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict'])
419 | recon_optimizer.load_state_dict(checkpoint['recon_optimizer_state_dict'])
420 | dyn_optimizer.load_state_dict(checkpoint['dyn_optimizer_state_dict'])
421 | init_epoch = checkpoint['epoch'] + 1
422 | loss_logger = checkpoint['loss_logger']
423 |
424 | train_loss_all, train_img_loss_all, train_act_loss_all, train_latent_loss_all, train_pred_loss_all, _, \
425 | test_loss_all, test_img_loss_all, test_act_loss_all, test_latent_loss_all, test_pred_loss_all, _ = create_loss_list(loss_logger, kld=False)
426 |
427 |
428 | for epoch in range(init_epoch, epochs+1):
429 | train_loss, train_img_loss, train_act_loss, train_latent_loss, train_pred_loss = train_new(epoch, recon_model, dyn_model, epoch_thres=int(epochs/2))
430 | test_loss, test_img_loss, test_act_loss, test_latent_loss, test_pred_loss = test_new(epoch, recon_model, dyn_model)
431 | train_loss_all.append(train_loss)
432 | train_img_loss_all.append(train_img_loss)
433 | train_act_loss_all.append(train_act_loss)
434 | train_latent_loss_all.append(train_latent_loss)
435 | train_pred_loss_all.append(train_pred_loss)
436 | test_loss_all.append(test_loss)
437 | test_img_loss_all.append(test_img_loss)
438 | test_act_loss_all.append(test_act_loss)
439 | test_latent_loss_all.append(test_latent_loss)
440 | test_pred_loss_all.append(test_pred_loss)
441 | if epoch % args.log_interval == 0:
442 | save_data(folder_name, epochs, train_loss_all, train_img_loss_all, train_act_loss_all,
443 | train_latent_loss_all, train_pred_loss_all, test_loss_all, test_img_loss_all,
444 | test_act_loss_all, test_latent_loss_all, test_pred_loss_all, None, None, None, None)
445 | # save checkpoint
446 | PATH = './result/{}/checkpoint'.format(folder_name)
447 | loss_logger = {'train_loss_all': train_loss_all, 'train_img_loss_all': train_img_loss_all,
448 | 'train_act_loss_all': train_act_loss_all, 'train_latent_loss_all': train_latent_loss_all,
449 | 'train_pred_loss_all': train_pred_loss_all, 'test_loss_all': test_loss_all,
450 | 'test_img_loss_all': test_img_loss_all, 'test_act_loss_all': test_act_loss_all,
451 | 'test_latent_loss_all': test_latent_loss_all, 'test_pred_loss_all': test_pred_loss_all}
452 | torch.save({
453 | 'epoch': epoch,
454 | 'recon_model_state_dict': recon_model.state_dict(),
455 | 'dyn_model_state_dict': dyn_model.state_dict(),
456 | 'recon_optimizer_state_dict': recon_optimizer.state_dict(),
457 | 'dyn_optimizer_state_dict': dyn_optimizer.state_dict(),
458 | 'loss_logger': loss_logger
459 | },
460 | PATH)
461 |
462 |
463 | save_data(folder_name, epochs, train_loss_all, train_img_loss_all, train_act_loss_all,
464 | train_latent_loss_all, train_pred_loss_all, test_loss_all, test_img_loss_all,
465 | test_act_loss_all, test_latent_loss_all, test_pred_loss_all, None, None, None, None)
466 |
467 | # plot
468 | plot_train_loss('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
469 | plot_train_img_loss('./result/{}/train_img_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
470 | plot_train_act_loss('./result/{}/train_act_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
471 | plot_train_latent_loss('./result/{}/train_latent_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
472 | plot_train_pred_loss('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
473 | plot_test_loss('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
474 | plot_test_img_loss('./result/{}/test_img_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
475 | plot_test_act_loss('./result/{}/test_act_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
476 | plot_test_latent_loss('./result/{}/test_latent_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
477 | plot_test_pred_loss('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), folder_name)
478 |
479 | # save checkpoint
480 | PATH = './result/{}/checkpoint'.format(folder_name)
481 | loss_logger = {'train_loss_all': train_loss_all, 'train_img_loss_all': train_img_loss_all,
482 | 'train_act_loss_all': train_act_loss_all, 'train_latent_loss_all': train_latent_loss_all,
483 | 'train_pred_loss_all': train_pred_loss_all, 'test_loss_all': test_loss_all,
484 | 'test_img_loss_all': test_img_loss_all, 'test_act_loss_all': test_act_loss_all,
485 | 'test_latent_loss_all': test_latent_loss_all, 'test_pred_loss_all': test_pred_loss_all}
486 | torch.save({
487 | 'epoch': epoch,
488 | 'recon_model_state_dict': recon_model.state_dict(),
489 | 'dyn_model_state_dict': dyn_model.state_dict(),
490 | 'recon_optimizer_state_dict': recon_optimizer.state_dict(),
491 | 'dyn_optimizer_state_dict': dyn_optimizer.state_dict(),
492 | 'loss_logger': loss_logger
493 | },
494 | PATH)
495 | print('***** End Program *****')
496 |
497 |
--------------------------------------------------------------------------------
/model/predict_e2c.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | from torch.autograd import Variable
5 | from torch import nn, optim, sigmoid, tanh, relu
6 | from torch.utils.data import Dataset, DataLoader
7 | from torch.nn import functional as F
8 | from deform.model.create_dataset import *
9 | from deform.model.hidden_dynamics import *
10 | from deform.utils.utils import plot_sample_multi_step
11 | from torchvision.utils import save_image
12 | import os
13 | import math
14 |
15 |
16 | class NormalDistribution(object):
17 | """
18 |     Wrapper class representing a multivariate normal distribution N(mu, Cov).
19 |     If the covariance matrix is diagonal, Cov = diag(sigma^2). Otherwise,
20 |     Cov = A diag(sigma^2) A^T, where A = I + v r^T.
21 | """
22 |
23 | def __init__(self, mu, sigma, logsigma, *, v=None, r=None):
24 | self.mu = mu
25 | self.sigma = sigma
26 | self.logsigma = logsigma
27 | self.v = v
28 | self.r = r
29 |
30 | @property
31 | def cov(self):
32 | """This should only be called when NormalDistribution represents one sample"""
33 | if self.v is not None and self.r is not None:
34 | assert self.v.dim() == 1
35 |             dim = self.v.size(0)  # length of v, not its number of axes
36 | v = self.v.unsqueeze(1) # D * 1 vector
37 | rt = self.r.unsqueeze(0) # 1 * D vector
38 | A = torch.eye(dim) + v.mm(rt)
39 | return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t()))
40 | else:
41 | return torch.diag(self.sigma.pow(2))
42 |
43 |
44 | def KLDGaussian(Q, N, eps=1e-8):
45 | """KL Divergence between two Gaussians
46 | Assuming Q ~ N(mu0, A\sigma_0A') where A = I + vr^{T}
47 | and N ~ N(mu1, \sigma_1)
48 | """
49 | sum = lambda x: torch.sum(x, dim=1)
50 | k = float(Q.mu.size()[1]) # dimension of distribution
51 | mu0, v, r, mu1 = Q.mu, Q.v, Q.r, N.mu
52 | s02, s12 = (Q.sigma).pow(2) + eps, (N.sigma).pow(2) + eps
53 | a = sum(s02 * (1. + 2. * v * r) / s12) + sum(v.pow(2) / s12) * sum(r.pow(2) * s02) # trace term
54 | b = sum((mu1 - mu0).pow(2) / s12) # difference-of-means term
55 | c = 2. * (sum(N.logsigma - Q.logsigma) - torch.log(1. + sum(v * r) + eps)) # ratio-of-determinants term.
56 |
57 | return 0.5 * (a + b - k + c)
58 |
59 | class E2C(nn.Module):
60 | def __init__(self, dim_z=80, dim_u=4):
61 | super(E2C, self).__init__()
62 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
63 | nn.ReLU(),
64 | nn.MaxPool2d(3, stride=2),
65 | nn.Conv2d(32, 64, 3, padding=1),
66 | nn.ReLU(),
67 | nn.MaxPool2d(3, stride=2),
68 | nn.Conv2d(64, 128, 3, padding=1),
69 | nn.ReLU(),
70 | nn.MaxPool2d(3, stride=2),
71 | nn.Conv2d(128, 128, 3, padding=1),
72 | nn.ReLU(),
73 | nn.Conv2d(128, 128, 3, padding=1),
74 | nn.ReLU(),
75 | nn.MaxPool2d(3, stride=2, padding=1))
76 | self.fc1 = nn.Linear(128*3*3, dim_z*2)
77 | self.fc2 = nn.Linear(dim_z, 128*3*3)
78 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
79 | nn.ReLU(),
80 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
81 | nn.ReLU(),
82 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
83 | nn.ReLU(),
84 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
85 | nn.ReLU(),
86 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
87 | nn.Sigmoid())
88 | self.trans = nn.Sequential(
89 | nn.Linear(dim_z, 100),
90 | nn.BatchNorm1d(100),
91 | nn.ReLU(),
92 | nn.Linear(100, 100),
93 | nn.BatchNorm1d(100),
94 | nn.ReLU(),
95 | nn.Linear(100, dim_z*2)
96 | )
97 | self.fc_B = nn.Linear(dim_z, dim_z * dim_u)
98 | self.fc_o = nn.Linear(dim_z, dim_z)
99 | self.dim_z = dim_z
100 | self.dim_u = dim_u
101 |
102 |
103 | def encode(self, x):
104 | x = self.conv_layers(x)
105 | x = x.view(x.shape[0], -1)
106 | return relu(self.fc1(x)).chunk(2, dim=1)
107 |
108 | def decode(self, x):
109 | x = relu(self.fc2(x))
110 | x = x.view(-1, 128, 3, 3)
111 | return self.dconv_layers(x)
112 |
113 | def transition(self, h, Q, u):
114 | batch_size = h.size()[0]
115 | v, r = self.trans(h).chunk(2, dim=1)
116 | v1 = v.unsqueeze(2)
117 | rT = r.unsqueeze(1)
118 | I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
119 | if rT.data.is_cuda:
120 |             I = I.cuda()
121 | A = I.add(v1.bmm(rT))
122 |
123 | B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
124 | o = self.fc_o(h).reshape((-1, self.dim_z, 1))
125 |
126 | u = u.unsqueeze(2)
127 |
128 | d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
129 | sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
130 | return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
131 |
132 | def reparam(self, mean, logvar):
133 | std = logvar.mul(0.5).exp_()
134 | self.z_mean = mean
135 | self.z_sigma = std
136 | eps = torch.FloatTensor(std.size()).normal_()
137 | if std.data.is_cuda:
138 |             eps = eps.cuda()
139 | eps = Variable(eps)
140 | return eps.mul(std).add_(mean), NormalDistribution(mean, std, torch.log(std))
141 |
142 | def forward(self, x, action, x_next):
143 | mean, logvar = self.encode(x)
144 | mean_next, logvar_next = self.encode(x_next)
145 |
146 | z, self.Qz = self.reparam(mean, logvar)
147 | z_next, self.Qz_next = self.reparam(mean_next, logvar_next)
148 |
149 | self.x_dec = self.decode(z)
150 | self.x_next_dec = self.decode(z_next)
151 |
152 | self.z_next_pred, self.Qz_next_pred = self.transition(z, self.Qz, action)
153 | self.x_next_pred_dec = self.decode(self.z_next_pred)
154 |
155 | return self.x_dec, self.x_next_pred_dec, self.Qz, self.Qz_next, self.Qz_next_pred
156 |
157 | def latent_embeddings(self, x):
158 | return self.encode(x)[0]
159 |
160 | def predict(self, X, U):
161 | mean, logvar = self.encode(X)
162 | z, Qz = self.reparam(mean, logvar)
163 | z_next_pred, Qz_next_pred = self.transition(z, Qz, U)
164 | return self.decode(z_next_pred)
165 |
166 | def binary_crossentropy(t, o, eps=1e-8):
167 | return t * torch.log(o + eps) + (1.0 - t) * torch.log(1.0 - o + eps)
168 |
169 | def compute_loss(x_dec, x_next_pred_dec, x, x_next,
170 | Qz, Qz_next_pred,
171 | Qz_next):
172 |     # Reconstruction losses: binary cross-entropy by default; set the flag below to True for squared error
173 |     if False:
174 | x_reconst_loss = (x_dec - x_next).pow(2).sum(dim=1)
175 | x_next_reconst_loss = (x_next_pred_dec - x_next).pow(2).sum(dim=1)
176 | else:
177 | x_reconst_loss = -binary_crossentropy(x, x_dec).sum(dim=1)
178 | x_next_reconst_loss = -binary_crossentropy(x_next, x_next_pred_dec).sum(dim=1)
179 |
180 | logvar = Qz.logsigma.mul(2)
181 | KLD_element = Qz.mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
182 | KLD = torch.sum(KLD_element, dim=1).mul(-0.5)
183 |
184 | # ELBO
185 | bound_loss = x_reconst_loss.add(x_next_reconst_loss).add(KLD.reshape(-1,1,1))
186 | kl = KLDGaussian(Qz_next_pred, Qz_next)
187 | kl = kl[~torch.isnan(kl)]
188 | return bound_loss.mean(), kl.mean()
189 |
190 | def predict():
191 | e2c_model.eval()
192 | with torch.no_grad():
193 | for batch_idx, batch_data in enumerate(dataloader):
194 | # order: img_pre -> act_pre -> img_cur -> act_cur -> img_post
195 | # previous image
196 | img_pre = batch_data['image_bi_pre']
197 | img_pre = img_pre.float().to(device).view(-1, 1, 50, 50)
198 | # previous action
199 | act_pre = batch_data['resz_action_pre']
200 | act_pre = act_pre.float().to(device).view(-1, 4)
201 | # current image
202 | img_cur = batch_data['image_bi_cur']
203 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50)
204 | # current action
205 | act_cur = batch_data['resz_action_cur']
206 | act_cur = act_cur.float().to(device).view(-1, 4)
207 | # post image
208 | img_post = batch_data['image_bi_post']
209 | img_post = img_post.float().to(device).view(-1, 1, 50, 50)
210 | # post action
211 | act_post = batch_data['resz_action_post']
212 | act_post = act_post.float().to(device).view(-1, 4)
213 | # post2 image
214 | img_post2 = batch_data['image_bi_post2']
215 | img_post2 = img_post2.float().to(device).view(-1, 1, 50, 50)
216 | # post2 action
217 | act_post2 = batch_data['resz_action_post2']
218 | act_post2 = act_post2.float().to(device).view(-1, 4)
219 | # post3 image
220 | img_post3 = batch_data['image_bi_post3']
221 | img_post3 = img_post3.float().to(device).view(-1, 1, 50, 50)
222 | # post3 action
223 | act_post3 = batch_data['resz_action_post3']
224 | act_post3 = act_post3.float().to(device).view(-1, 4)
225 | # post4 image
226 | img_post4 = batch_data['image_bi_post4']
227 | img_post4 = img_post4.float().to(device).view(-1, 1, 50, 50)
228 | # post4 action
229 | act_post4 = batch_data['resz_action_post4']
230 | act_post4 = act_post4.float().to(device).view(-1, 4)
231 | # post5 image
232 | img_post5 = batch_data['image_bi_post5']
233 | img_post5 = img_post5.float().to(device).view(-1, 1, 50, 50)
234 | # post5 action
235 | act_post5 = batch_data['resz_action_post5']
236 | act_post5 = act_post5.float().to(device).view(-1, 4)
237 | # post6 image
238 | img_post6 = batch_data['image_bi_post6']
239 | img_post6 = img_post6.float().to(device).view(-1, 1, 50, 50)
240 | # post6 action
241 | act_post6 = batch_data['resz_action_post6']
242 | act_post6 = act_post6.float().to(device).view(-1, 4)
243 | # post7 image
244 | img_post7 = batch_data['image_bi_post7']
245 | img_post7 = img_post7.float().to(device).view(-1, 1, 50, 50)
246 | # post7 action
247 | act_post7 = batch_data['resz_action_post7']
248 | act_post7 = act_post7.float().to(device).view(-1, 4)
249 | # post8 image
250 | img_post8 = batch_data['image_bi_post8']
251 | img_post8 = img_post8.float().to(device).view(-1, 1, 50, 50)
252 | # post8 action
253 | act_post8 = batch_data['resz_action_post8']
254 | act_post8 = act_post8.float().to(device).view(-1, 4)
255 | # post9 image
256 | img_post9 = batch_data['image_bi_post9']
257 | img_post9 = img_post9.float().to(device).view(-1, 1, 50, 50)
258 | # ten step prediction
259 | # prediction for current image from pre image
260 | recon_img_cur = e2c_model.predict(img_pre, act_pre)
261 | # prediction for post image from pre image
262 | recon_img_post = e2c_model.predict(recon_img_cur, act_cur)
263 | # prediction for post2 image from pre image
264 | recon_img_post2 = e2c_model.predict(recon_img_post, act_post)
265 | # prediction for post3 image from pre image
266 | recon_img_post3 = e2c_model.predict(recon_img_post2, act_post2)
267 | # prediction for post4 image from pre image
268 | recon_img_post4 = e2c_model.predict(recon_img_post3, act_post3)
269 | # prediction for post5 image from pre image
270 | recon_img_post5 = e2c_model.predict(recon_img_post4, act_post4)
271 | # prediction for post6 image from pre image
272 | recon_img_post6 = e2c_model.predict(recon_img_post5, act_post5)
273 | # prediction for post7 image from pre image
274 | recon_img_post7 = e2c_model.predict(recon_img_post6, act_post6)
275 | # prediction for post8 image from pre image
276 | recon_img_post8 = e2c_model.predict(recon_img_post7, act_post7)
277 | # prediction for post9 image from pre image
278 | recon_img_post9 = e2c_model.predict(recon_img_post8, act_post8)
279 | if batch_idx % 10 == 0:
280 | n = min(batch_data['image_bi_pre'].size(0), 1)
281 | comparison_GT = torch.cat([batch_data['image_bi_pre'][:n],
282 | batch_data['image_bi_cur'][:n],
283 | batch_data['image_bi_post'][:n],
284 | batch_data['image_bi_post2'][:n],
285 | batch_data['image_bi_post3'][:n],
286 | batch_data['image_bi_post4'][:n],
287 | batch_data['image_bi_post5'][:n],
288 | batch_data['image_bi_post6'][:n],
289 | batch_data['image_bi_post7'][:n],
290 | batch_data['image_bi_post8'][:n],
291 | batch_data['image_bi_post9'][:n]])
292 | save_image(comparison_GT.cpu(),
293 | './result/{}/prediction_full_step{}/prediction_GT_batch{}.png'.format(folder_name, step, batch_idx), nrow=n)
294 | comparison_Pred = torch.cat([batch_data['image_bi_pre'][:n],
295 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n],
296 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n],
297 | recon_img_post2.view(-1, 1, 50, 50).cpu()[:n],
298 | recon_img_post3.view(-1, 1, 50, 50).cpu()[:n],
299 | recon_img_post4.view(-1, 1, 50, 50).cpu()[:n],
300 | recon_img_post5.view(-1, 1, 50, 50).cpu()[:n],
301 | recon_img_post6.view(-1, 1, 50, 50).cpu()[:n],
302 | recon_img_post7.view(-1, 1, 50, 50).cpu()[:n],
303 | recon_img_post8.view(-1, 1, 50, 50).cpu()[:n],
304 | recon_img_post9.view(-1, 1, 50, 50).cpu()[:n]])
305 | save_image(comparison_Pred.cpu(),
306 | './result/{}/prediction_full_step{}/prediction_Pred_batch{}.png'.format(folder_name, step, batch_idx), nrow=n)
307 | #GT
308 | plot_sample_multi_step(batch_data['image_bi_pre'][:n].detach().cpu().numpy(),
309 | batch_data['image_bi_cur'][:n].detach().cpu().numpy(),
310 | batch_data['image_bi_post'][:n].detach().cpu().numpy(),
311 | batch_data['image_bi_post2'][:n].detach().cpu().numpy(),
312 | batch_data['image_bi_post3'][:n].detach().cpu().numpy(),
313 | batch_data['image_bi_post4'][:n].detach().cpu().numpy(),
314 | batch_data['image_bi_post5'][:n].detach().cpu().numpy(),
315 | batch_data['image_bi_post6'][:n].detach().cpu().numpy(),
316 | batch_data['image_bi_post7'][:n].detach().cpu().numpy(),
317 | batch_data['image_bi_post8'][:n].detach().cpu().numpy(),
318 | batch_data['image_bi_post9'][:n].detach().cpu().numpy(),
319 | batch_data['resz_action_pre'][:n].detach().cpu().numpy(),
320 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(),
321 | batch_data['resz_action_post'][:n].detach().cpu().numpy(),
322 | batch_data['resz_action_post2'][:n].detach().cpu().numpy(),
323 | batch_data['resz_action_post3'][:n].detach().cpu().numpy(),
324 | batch_data['resz_action_post4'][:n].detach().cpu().numpy(),
325 | batch_data['resz_action_post5'][:n].detach().cpu().numpy(),
326 | batch_data['resz_action_post6'][:n].detach().cpu().numpy(),
327 | batch_data['resz_action_post7'][:n].detach().cpu().numpy(),
328 | batch_data['resz_action_post8'][:n].detach().cpu().numpy(),
329 |                                     './result/{}/prediction_with_action_step{}/recon_GT_batch_{}.png'.format(folder_name, step, batch_idx))
330 | # Predicted
331 | plot_sample_multi_step(batch_data['image_bi_pre'][:n].detach().cpu().numpy(),
332 | recon_img_cur.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
333 | recon_img_post.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
334 | recon_img_post2.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
335 | recon_img_post3.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
336 | recon_img_post4.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
337 | recon_img_post5.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
338 | recon_img_post6.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
339 | recon_img_post7.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
340 | recon_img_post8.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
341 | recon_img_post9.view(-1, 1, 50, 50)[:n].detach().cpu().numpy(),
342 | batch_data['resz_action_pre'][:n].detach().cpu().numpy(),
343 | batch_data['resz_action_cur'][:n].detach().cpu().numpy(),
344 | batch_data['resz_action_post'][:n].detach().cpu().numpy(),
345 | batch_data['resz_action_post2'][:n].detach().cpu().numpy(),
346 | batch_data['resz_action_post3'][:n].detach().cpu().numpy(),
347 | batch_data['resz_action_post4'][:n].detach().cpu().numpy(),
348 | batch_data['resz_action_post5'][:n].detach().cpu().numpy(),
349 | batch_data['resz_action_post6'][:n].detach().cpu().numpy(),
350 | batch_data['resz_action_post7'][:n].detach().cpu().numpy(),
351 | batch_data['resz_action_post8'][:n].detach().cpu().numpy(),
352 |                                     './result/{}/prediction_with_action_step{}/recon_Pred_batch_{}.png'.format(folder_name, step, batch_idx))
353 |
354 | print('***** Preparing Data *****')
355 | total_img_num = 22515
356 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
357 | action_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
358 | actions = np.load(action_path)
359 | dataset = MyDatasetMultiPred10(image_paths_bi, actions, transform=ToTensorMultiPred10())
360 | dataloader = DataLoader(dataset, batch_size=32,
361 | shuffle=True, num_workers=4, collate_fn=my_collate)
362 | print('***** Finish Preparing Data *****')
363 |
364 | folder_name = 'test_E2C'
365 | PATH = './result/{}/checkpoint'.format(folder_name)
366 |
367 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
368 | e2c_model = E2C().to(device)
369 |
370 |
371 | # load checkpoint
372 | print('***** Load Checkpoint *****')
373 | checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
374 | e2c_model.load_state_dict(checkpoint['e2c_model_state_dict'])
375 |
376 | # prediction
377 | print('***** Start Prediction *****')
378 | step=10 # Change this based on different prediction steps
379 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name, step)):
380 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name, step))
381 | if not os.path.exists('./result/{}/prediction_with_action_step{}'.format(folder_name, step)):
382 | os.makedirs('./result/{}/prediction_with_action_step{}'.format(folder_name, step))
383 | predict()
384 | print('***** Finish Prediction *****')
--------------------------------------------------------------------------------
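For reference, the `transition` method in `predict_e2c.py` above implements the locally linear latent step of E2C: the `trans` head predicts vectors $v$ and $r$ that form a rank-one perturbation of the identity, while `fc_B` and `fc_o` produce a control matrix and an offset, so that

$$ z_{t+1} = A\,z_t + B\,u_t + o, \qquad A = I + v\,r^{\top}. $$

The same $A$, $B$, and $o$ propagate the mean of the approximate posterior to form `Qz_next_pred`, whose divergence from the encoded `Qz_next` is penalized by `KLDGaussian` in `compute_loss`.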
/model/predict_e2c_our_method_compare_gpu.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | from torch.autograd import Variable
5 | from torch import nn, optim, sigmoid, tanh, relu
6 | from torch.utils.data import Dataset, DataLoader
7 | from torch.nn import functional as F
8 | from deform.model.create_dataset import *
9 | from deform.model.hidden_dynamics import *
10 | from deform.utils.utils import plot_sample_multi_step
11 | from torchvision.utils import save_image
12 | import os
13 | import math
14 |
15 | class CAE(nn.Module):
16 | def __init__(self, latent_state_dim=80, latent_act_dim=80):
17 | super(CAE, self).__init__()
18 | # state
19 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
20 | nn.ReLU(),
21 | nn.MaxPool2d(3, stride=2),
22 | nn.Conv2d(32, 64, 3, padding=1),
23 | nn.ReLU(),
24 | nn.MaxPool2d(3, stride=2),
25 | nn.Conv2d(64, 128, 3, padding=1),
26 | nn.ReLU(),
27 | nn.MaxPool2d(3, stride=2),
28 | nn.Conv2d(128, 128, 3, padding=1),
29 | nn.ReLU(),
30 | nn.Conv2d(128, 128, 3, padding=1),
31 | nn.ReLU(),
32 | nn.MaxPool2d(3, stride=2, padding=1))
33 | self.fc1 = nn.Linear(128*3*3, latent_state_dim)
34 | self.fc2 = nn.Linear(latent_state_dim, 128*3*3)
35 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
36 | nn.ReLU(),
37 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
38 | nn.ReLU(),
39 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
40 | nn.ReLU(),
41 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
42 | nn.ReLU(),
43 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
44 | nn.Sigmoid())
45 | # action
46 | self.fc5 = nn.Linear(4, latent_act_dim)
47 | self.fc6 = nn.Linear(latent_act_dim, latent_act_dim)
48 | self.fc7 = nn.Linear(latent_act_dim, latent_act_dim)
49 | self.fc8 = nn.Linear(latent_act_dim, 4)
50 |         # constant tensors for rescaling decoded actions; moved to the GPU in decoder_act when available
51 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14])
52 | self.add_tensor = torch.tensor([0, 0, 0, 0.01])
53 |
54 |
55 | def encoder(self, x):
56 | x = self.conv_layers(x)
57 | x = x.view(x.shape[0], -1)
58 | return relu(self.fc1(x))
59 |
60 | def decoder(self, x):
61 | x = relu(self.fc2(x))
62 | x = x.view(-1, 128, 3, 3)
63 | return self.dconv_layers(x)
64 |
65 | def encoder_act(self, u):
66 | h1 = relu(self.fc5(u))
67 | return relu(self.fc6(h1))
68 |
69 | def decoder_act(self, u):
70 | h2 = relu(self.fc7(u))
71 | if torch.cuda.is_available():
72 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.cuda()) + self.add_tensor.cuda()
73 | else:
74 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor) + self.add_tensor
75 |
76 |
77 | def forward(self, x_cur, u, x_post):
78 | g_cur = self.encoder(x_cur)
79 | a = self.encoder_act(u)
80 | g_post = self.encoder(x_post)
81 |
82 | return g_cur, a, g_post, self.decoder(g_cur), self.decoder_act(a)
83 |
84 | class SysDynamics(nn.Module):
85 | def __init__(self, latent_state_dim=80, latent_act_dim=80):
86 | super(SysDynamics, self).__init__()
87 | self.conv_layers_matrix = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
88 | nn.ReLU(),
89 | nn.MaxPool2d(3, stride=1),
90 | nn.Conv2d(32, 64, 3, padding=1),
91 | nn.ReLU(),
92 | nn.MaxPool2d(3, stride=1),
93 | nn.Conv2d(64, 128, 3, padding=1),
94 | nn.ReLU(),
95 | nn.MaxPool2d(3, stride=2),
96 | nn.Conv2d(128, 256, 3, padding=1),
97 | nn.ReLU(),
98 | nn.MaxPool2d(3, stride=2),
99 | nn.Conv2d(256, 512, 3, padding=1),
100 | nn.ReLU(),
101 | nn.MaxPool2d(3, stride=2),
102 | nn.Conv2d(512, 512, 3, padding=1),
103 | nn.ReLU(),
104 | nn.Conv2d(512, 512, 3, padding=1),
105 | nn.ReLU(),
106 | nn.MaxPool2d(3, stride=2, padding=1))
107 | self.fc31 = nn.Linear(512*2*2, latent_state_dim*latent_state_dim)
108 | self.fc32 = nn.Linear(latent_state_dim*latent_state_dim, latent_state_dim*latent_state_dim)
109 | self.fc41 = nn.Linear(512*2*2 + latent_act_dim, latent_state_dim*latent_act_dim)
110 | self.fc42 = nn.Linear(latent_state_dim*latent_act_dim, latent_state_dim*latent_act_dim)
111 | self.fc9 = nn.Linear(4, latent_act_dim)
112 | self.fc10 = nn.Linear(latent_act_dim, latent_act_dim)
113 |
114 | self.latent_act_dim = latent_act_dim
115 | self.latent_state_dim = latent_state_dim
116 |
117 | def encoder_matrix(self, x, a):
118 | x = self.conv_layers_matrix(x)
119 | x = x.view(x.shape[0], -1)
120 | xa = torch.cat((x,a), 1)
121 |
122 | return relu(self.fc32(relu(self.fc31(x)))).view(-1, self.latent_state_dim, self.latent_state_dim), \
123 | relu(self.fc42(relu(self.fc41(xa)))).view(-1, self.latent_act_dim, self.latent_state_dim)
124 |
125 | def forward(self, x_cur, u):
126 | a = relu(self.fc10(relu(self.fc9(u))))
127 | K_T, L_T = self.encoder_matrix(x_cur, a)
128 |
129 | return K_T, L_T
130 |
131 | class NormalDistribution(object):
132 | """
133 | Wrapper class representing a multivariate normal distribution parameterized by
134 | N(mu,Cov). If cov. matrix is diagonal, Cov=(sigma).^2. Otherwise,
135 | Cov=A*(sigma).^2*A', where A = (I+v*r^T).
136 | """
137 |
138 | def __init__(self, mu, sigma, logsigma, *, v=None, r=None):
139 | self.mu = mu
140 | self.sigma = sigma
141 | self.logsigma = logsigma
142 | self.v = v
143 | self.r = r
144 |
145 | @property
146 | def cov(self):
147 | """This should only be called when NormalDistribution represents one sample"""
148 | if self.v is not None and self.r is not None:
149 | assert self.v.dim() == 1
150 |             dim = self.v.size(0)  # length of v (D), not its number of axes
151 | v = self.v.unsqueeze(1) # D * 1 vector
152 | rt = self.r.unsqueeze(0) # 1 * D vector
153 | A = torch.eye(dim) + v.mm(rt)
154 | return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t()))
155 | else:
156 | return torch.diag(self.sigma.pow(2))
157 |
158 |
159 | def KLDGaussian(Q, N, eps=1e-8):
160 | """KL Divergence between two Gaussians
161 | Assuming Q ~ N(mu0, A\sigma_0A') where A = I + vr^{T}
162 | and N ~ N(mu1, \sigma_1)
163 | """
164 | sum = lambda x: torch.sum(x, dim=1)
165 | k = float(Q.mu.size()[1]) # dimension of distribution
166 | mu0, v, r, mu1 = Q.mu, Q.v, Q.r, N.mu
167 | s02, s12 = (Q.sigma).pow(2) + eps, (N.sigma).pow(2) + eps
168 | a = sum(s02 * (1. + 2. * v * r) / s12) + sum(v.pow(2) / s12) * sum(r.pow(2) * s02) # trace term
169 | b = sum((mu1 - mu0).pow(2) / s12) # difference-of-means term
170 | c = 2. * (sum(N.logsigma - Q.logsigma) - torch.log(1. + sum(v * r) + eps)) # ratio-of-determinants term.
171 |
172 | return 0.5 * (a + b - k + c)
173 |
174 | class E2C(nn.Module):
175 | def __init__(self, dim_z=80, dim_u=4):
176 | super(E2C, self).__init__()
177 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
178 | nn.ReLU(),
179 | nn.MaxPool2d(3, stride=2),
180 | nn.Conv2d(32, 64, 3, padding=1),
181 | nn.ReLU(),
182 | nn.MaxPool2d(3, stride=2),
183 | nn.Conv2d(64, 128, 3, padding=1),
184 | nn.ReLU(),
185 | nn.MaxPool2d(3, stride=2),
186 | nn.Conv2d(128, 128, 3, padding=1),
187 | nn.ReLU(),
188 | nn.Conv2d(128, 128, 3, padding=1),
189 | nn.ReLU(),
190 | nn.MaxPool2d(3, stride=2, padding=1))
191 | self.fc1 = nn.Linear(128*3*3, dim_z*2)
192 | self.fc2 = nn.Linear(dim_z, 128*3*3)
193 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
194 | nn.ReLU(),
195 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
196 | nn.ReLU(),
197 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
198 | nn.ReLU(),
199 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
200 | nn.ReLU(),
201 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
202 | nn.Sigmoid())
203 | self.trans = nn.Sequential(
204 | nn.Linear(dim_z, 100),
205 | nn.BatchNorm1d(100),
206 | nn.ReLU(),
207 | nn.Linear(100, 100),
208 | nn.BatchNorm1d(100),
209 | nn.ReLU(),
210 | nn.Linear(100, dim_z*2)
211 | )
212 | self.fc_B = nn.Linear(dim_z, dim_z * dim_u)
213 | self.fc_o = nn.Linear(dim_z, dim_z)
214 | self.dim_z = dim_z
215 | self.dim_u = dim_u
216 | # action
217 | self.fc5 = nn.Linear(4, dim_u*20)
218 | self.fc6 = nn.Linear(dim_u*20, dim_u)
219 | self.fc7 = nn.Linear(dim_u, dim_u*20)
220 | self.fc8 = nn.Linear(dim_u*20, 4)
221 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14])
222 | self.add_tensor = torch.tensor([0, 0, 0, 0.01])
223 |
224 | def encode(self, x):
225 | x = self.conv_layers(x)
226 | x = x.view(x.shape[0], -1)
227 | return relu(self.fc1(x)).chunk(2, dim=1)
228 |
229 | def decode(self, x):
230 | x = relu(self.fc2(x))
231 | x = x.view(-1, 128, 3, 3)
232 | return self.dconv_layers(x)
233 |
234 | def encode_act(self, u):
235 | h1 = relu(self.fc5(u))
236 | return relu(self.fc6(h1))
237 |
238 | def decode_act(self, u):
239 | h2 = relu(self.fc7(u))
240 | return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.cuda()) + self.add_tensor.cuda()
241 |
242 | def transition(self, h, Q, u):
243 | batch_size = h.size()[0]
244 | v, r = self.trans(h).chunk(2, dim=1)
245 | v1 = v.unsqueeze(2).cpu()
246 | rT = r.unsqueeze(1).cpu()
247 | I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
248 | if rT.data.is_cuda:
249 |             I = I.cuda()
250 | A = I.add(v1.bmm(rT)).cuda()
251 |
252 | B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
253 | o = self.fc_o(h).reshape((-1, self.dim_z, 1))
254 |
255 | u = u.unsqueeze(2)
256 |
257 | d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
258 | sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
259 | return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
260 |
261 | def reparam(self, mean, logvar):
262 | std = logvar.mul(0.5).exp_()
263 | self.z_mean = mean
264 | self.z_sigma = std
265 | eps = torch.FloatTensor(std.size()).normal_()
266 | eps = Variable(eps)
267 | return eps.mul(std.cpu()).add_(mean.cpu()).cuda(), NormalDistribution(mean, std, torch.log(std))
268 |
269 | def forward(self, x, action, x_next):
270 | mean, logvar = self.encode(x)
271 | mean_next, logvar_next = self.encode(x_next)
272 |
273 | z, self.Qz = self.reparam(mean, logvar)
274 | z_next, self.Qz_next = self.reparam(mean_next, logvar_next)
275 |
276 | self.x_dec = self.decode(z)
277 | self.x_next_dec = self.decode(z_next)
278 |
279 | latent_a = self.encode_act(action)
280 | action_dec = self.decode_act(latent_a)
281 |
282 | self.z_next_pred, self.Qz_next_pred = self.transition(z, self.Qz, latent_a)
283 | self.x_next_pred_dec = self.decode(self.z_next_pred)
284 |
285 |
286 | return self.x_dec, self.x_next_dec, self.x_next_pred_dec, self.Qz, self.Qz_next, self.Qz_next_pred, action_dec
287 |
288 | def latent_embeddings(self, x):
289 | return self.encode(x)[0]
290 |
291 | def predict(self, X, U):
292 | mean, logvar = self.encode(X)
293 | z, Qz = self.reparam(mean, logvar)
294 | z_next_pred, Qz_next_pred = self.transition(z, Qz, U)
295 | return self.decode(z_next_pred)
296 |
297 | def binary_crossentropy(t, o, eps=1e-8):
298 | return t * torch.log(o + eps) + (1.0 - t) * torch.log(1.0 - o + eps)
299 |
300 | def compute_loss(x_dec, x_next_pred_dec, x, x_next,
301 | Qz, Qz_next_pred,
302 | Qz_next):
303 |     # Reconstruction losses: binary cross-entropy by default; set the flag below to True for squared error
304 |     if False:
305 | x_reconst_loss = (x_dec - x_next).pow(2).sum(dim=1)
306 | x_next_reconst_loss = (x_next_pred_dec - x_next).pow(2).sum(dim=1)
307 | else:
308 | x_reconst_loss = -binary_crossentropy(x, x_dec).sum(dim=1)
309 | x_next_reconst_loss = -binary_crossentropy(x_next, x_next_pred_dec).sum(dim=1)
310 |
311 | logvar = Qz.logsigma.mul(2)
312 | KLD_element = Qz.mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
313 | KLD = torch.sum(KLD_element, dim=1).mul(-0.5)
314 |
315 | # ELBO
316 | bound_loss = x_reconst_loss.add(x_next_reconst_loss).add(KLD.reshape(-1,1,1))
317 | kl = KLDGaussian(Qz_next_pred, Qz_next)
318 | kl = kl[~torch.isnan(kl)]
319 | return bound_loss.mean(), kl.mean()
320 |
321 | def predict():
322 | e2c_model.eval()
323 | recon_model.eval()
324 | dyn_model.eval()
325 | with torch.no_grad():
326 | for batch_idx, batch_data in enumerate(dataloader):
327 | # order: img_pre -> act_pre -> img_cur -> act_cur -> img_post
328 | # previous image
329 | img_pre = batch_data['image_bi_pre']
330 | img_pre = img_pre.float().to(device).view(-1, 1, 50, 50)
331 | # previous action
332 | act_pre = batch_data['resz_action_pre']
333 | act_pre = act_pre.float().to(device).view(-1, 4)
334 | # current image
335 | img_cur = batch_data['image_bi_cur']
336 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50)
337 | # current action
338 | act_cur = batch_data['resz_action_cur']
339 | act_cur = act_cur.float().to(device).view(-1, 4)
340 | # post image
341 | img_post = batch_data['image_bi_post']
342 | img_post = img_post.float().to(device).view(-1, 1, 50, 50)
343 | # post action
344 | act_post = batch_data['resz_action_post']
345 | act_post = act_post.float().to(device).view(-1, 4)
346 | # post2 image
347 | img_post2 = batch_data['image_bi_post2']
348 | img_post2 = img_post2.float().to(device).view(-1, 1, 50, 50)
349 | # post2 action
350 | act_post2 = batch_data['resz_action_post2']
351 | act_post2 = act_post2.float().to(device).view(-1, 4)
352 | # post3 image
353 | img_post3 = batch_data['image_bi_post3']
354 | img_post3 = img_post3.float().to(device).view(-1, 1, 50, 50)
355 | # post3 action
356 | act_post3 = batch_data['resz_action_post3']
357 | act_post3 = act_post3.float().to(device).view(-1, 4)
358 | # post4 image
359 | img_post4 = batch_data['image_bi_post4']
360 | img_post4 = img_post4.float().to(device).view(-1, 1, 50, 50)
361 | # post4 action
362 | act_post4 = batch_data['resz_action_post4']
363 | act_post4 = act_post4.float().to(device).view(-1, 4)
364 | # post5 image
365 | img_post5 = batch_data['image_bi_post5']
366 | img_post5 = img_post5.float().to(device).view(-1, 1, 50, 50)
367 | # post5 action
368 | act_post5 = batch_data['resz_action_post5']
369 | act_post5 = act_post5.float().to(device).view(-1, 4)
370 | # post6 image
371 | img_post6 = batch_data['image_bi_post6']
372 | img_post6 = img_post6.float().to(device).view(-1, 1, 50, 50)
373 | # post6 action
374 | act_post6 = batch_data['resz_action_post6']
375 | act_post6 = act_post6.float().to(device).view(-1, 4)
376 | # post7 image
377 | img_post7 = batch_data['image_bi_post7']
378 | img_post7 = img_post7.float().to(device).view(-1, 1, 50, 50)
379 | # post7 action
380 | act_post7 = batch_data['resz_action_post7']
381 | act_post7 = act_post7.float().to(device).view(-1, 4)
382 | # post8 image
383 | img_post8 = batch_data['image_bi_post8']
384 | img_post8 = img_post8.float().to(device).view(-1, 1, 50, 50)
385 | # post8 action
386 | act_post8 = batch_data['resz_action_post8']
387 | act_post8 = act_post8.float().to(device).view(-1, 4)
388 | # post9 image
389 | img_post9 = batch_data['image_bi_post9']
390 | img_post9 = img_post9.float().to(device).view(-1, 1, 50, 50)
391 | # ten step prediction
392 | # prediction for current image from pre image
393 | recon_img_cur = e2c_model.predict(img_pre, act_pre)
394 | # prediction for post image from pre image
395 | recon_img_post = e2c_model.predict(recon_img_cur, act_cur)
396 | # prediction for post2 image from pre image
397 | recon_img_post2 = e2c_model.predict(recon_img_post, act_post)
398 | # prediction for post3 image from pre image
399 | recon_img_post3 = e2c_model.predict(recon_img_post2, act_post2)
400 | # prediction for post4 image from pre image
401 | recon_img_post4 = e2c_model.predict(recon_img_post3, act_post3)
402 | # prediction for post5 image from pre image
403 | recon_img_post5 = e2c_model.predict(recon_img_post4, act_post4)
404 | # prediction for post6 image from pre image
405 | recon_img_post6 = e2c_model.predict(recon_img_post5, act_post5)
406 | # prediction for post7 image from pre image
407 | recon_img_post7 = e2c_model.predict(recon_img_post6, act_post6)
408 | # prediction for post8 image from pre image
409 | recon_img_post8 = e2c_model.predict(recon_img_post7, act_post7)
410 | # prediction for post9 image from pre image
411 | recon_img_post9 = e2c_model.predict(recon_img_post8, act_post8)
412 | if batch_idx % 5 == 0:
413 | n = min(batch_data['image_bi_pre'].size(0), 1)
414 | comparison_GT = torch.cat([batch_data['image_bi_pre'][:n],
415 | batch_data['image_bi_cur'][:n],
416 | batch_data['image_bi_post'][:n],
417 | batch_data['image_bi_post2'][:n],
418 | batch_data['image_bi_post3'][:n],
419 | batch_data['image_bi_post4'][:n],
420 | batch_data['image_bi_post5'][:n],
421 | batch_data['image_bi_post6'][:n],
422 | batch_data['image_bi_post7'][:n],
423 | batch_data['image_bi_post8'][:n],
424 | batch_data['image_bi_post9'][:n]])
425 | save_image(comparison_GT.cpu(),
426 | './result/{}/prediction_full_step{}/prediction_GT_batch{}.png'.format(folder_name_e2c, step, batch_idx), nrow=n)
427 | comparison_Pred = torch.cat([batch_data['image_bi_pre'][:n],
428 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n],
429 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n],
430 | recon_img_post2.view(-1, 1, 50, 50).cpu()[:n],
431 | recon_img_post3.view(-1, 1, 50, 50).cpu()[:n],
432 | recon_img_post4.view(-1, 1, 50, 50).cpu()[:n],
433 | recon_img_post5.view(-1, 1, 50, 50).cpu()[:n],
434 | recon_img_post6.view(-1, 1, 50, 50).cpu()[:n],
435 | recon_img_post7.view(-1, 1, 50, 50).cpu()[:n],
436 | recon_img_post8.view(-1, 1, 50, 50).cpu()[:n],
437 | recon_img_post9.view(-1, 1, 50, 50).cpu()[:n]])
438 | save_image(comparison_Pred.cpu(),
439 | './result/{}/prediction_full_step{}/prediction_Pred_batch{}.png'.format(folder_name_e2c, step, batch_idx), nrow=n)
440 | # ten step prediction
441 | # prediction for current image from pre image
442 | latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre, act_pre, img_cur)
443 | K_T_pre, L_T_pre = dyn_model(img_pre, act_pre)
444 | recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre)
445 | recon_img_cur = recon_model.decoder(recon_latent_img_cur)
446 | # prediction for post image from pre image
447 | _, latent_act_cur, _, _, _ = recon_model(img_cur, act_cur, img_post)
448 | K_T_cur, L_T_cur = dyn_model(recon_img_cur, act_cur)
449 | recon_latent_img_post = get_next_state_linear(recon_latent_img_cur, latent_act_cur, K_T_cur, L_T_cur)
450 | recon_img_post = recon_model.decoder(recon_latent_img_post)
451 | # prediction for post2 image from pre image
452 | _, latent_act_post, _, _, _ = recon_model(img_post, act_post, img_post2)
453 | K_T_post, L_T_post = dyn_model(recon_img_post, act_post)
454 | recon_latent_img_post2 = get_next_state_linear(recon_latent_img_post, latent_act_post, K_T_post, L_T_post)
455 | recon_img_post2 = recon_model.decoder(recon_latent_img_post2)
456 | # prediction for post3 image from pre image
457 | _, latent_act_post2, _, _, _ = recon_model(img_post2, act_post2, img_post3)
458 | K_T_post2, L_T_post2 = dyn_model(recon_img_post2, act_post2)
459 | recon_latent_img_post3 = get_next_state_linear(recon_latent_img_post2, latent_act_post2, K_T_post2, L_T_post2)
460 | recon_img_post3 = recon_model.decoder(recon_latent_img_post3)
461 | # prediction for post4 image from pre image
462 | _, latent_act_post3, _, _, _ = recon_model(img_post3, act_post3, img_post4)
463 | K_T_post3, L_T_post3 = dyn_model(recon_img_post3, act_post3)
464 | recon_latent_img_post4 = get_next_state_linear(recon_latent_img_post3, latent_act_post3, K_T_post3, L_T_post3)
465 | recon_img_post4 = recon_model.decoder(recon_latent_img_post4)
466 | # prediction for post5 image from pre image
467 | _, latent_act_post4, _, _, _ = recon_model(img_post4, act_post4, img_post5)
468 | K_T_post4, L_T_post4 = dyn_model(recon_img_post4, act_post4)
469 | recon_latent_img_post5 = get_next_state_linear(recon_latent_img_post4, latent_act_post4, K_T_post4, L_T_post4)
470 | recon_img_post5 = recon_model.decoder(recon_latent_img_post5)
471 | # prediction for post6 image from pre image
472 | _, latent_act_post5, _, _, _ = recon_model(img_post5, act_post5, img_post6)
473 | K_T_post5, L_T_post5 = dyn_model(recon_img_post5, act_post5)
474 | recon_latent_img_post6 = get_next_state_linear(recon_latent_img_post5, latent_act_post5, K_T_post5, L_T_post5)
475 | recon_img_post6 = recon_model.decoder(recon_latent_img_post6)
476 | # prediction for post7 image from pre image
477 | _, latent_act_post6, _, _, _ = recon_model(img_post6, act_post6, img_post7)
478 | K_T_post6, L_T_post6 = dyn_model(recon_img_post6, act_post6)
479 | recon_latent_img_post7 = get_next_state_linear(recon_latent_img_post6, latent_act_post6, K_T_post6, L_T_post6)
480 | recon_img_post7 = recon_model.decoder(recon_latent_img_post7)
481 | # prediction for post8 image from pre image
482 | _, latent_act_post7, _, _, _ = recon_model(img_post7, act_post7, img_post8)
483 | K_T_post7, L_T_post7 = dyn_model(recon_img_post7, act_post7)
484 | recon_latent_img_post8 = get_next_state_linear(recon_latent_img_post7, latent_act_post7, K_T_post7, L_T_post7)
485 | recon_img_post8 = recon_model.decoder(recon_latent_img_post8)
486 | # prediction for post9 image from pre image
487 | _, latent_act_post8, _, _, _ = recon_model(img_post8, act_post8, img_post9)
488 | K_T_post8, L_T_post8 = dyn_model(recon_img_post8, act_post8)
489 | recon_latent_img_post9 = get_next_state_linear(recon_latent_img_post8, latent_act_post8, K_T_post8, L_T_post8)
490 | recon_img_post9 = recon_model.decoder(recon_latent_img_post9)
491 | if batch_idx % 5 == 0:
492 | n = min(batch_data['image_bi_pre'].size(0), 1)
493 | comparison_GT = torch.cat([batch_data['image_bi_pre'][:n],
494 | batch_data['image_bi_cur'][:n],
495 | batch_data['image_bi_post'][:n],
496 | batch_data['image_bi_post2'][:n],
497 | batch_data['image_bi_post3'][:n],
498 | batch_data['image_bi_post4'][:n],
499 | batch_data['image_bi_post5'][:n],
500 | batch_data['image_bi_post6'][:n],
501 | batch_data['image_bi_post7'][:n],
502 | batch_data['image_bi_post8'][:n],
503 | batch_data['image_bi_post9'][:n]])
504 | save_image(comparison_GT.cpu(),
505 | './result/{}/prediction_full_step{}/prediction_GT_batch{}.png'.format(folder_name_our, step, batch_idx), nrow=n)
506 | comparison_Pred = torch.cat([batch_data['image_bi_pre'][:n],
507 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n],
508 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n],
509 | recon_img_post2.view(-1, 1, 50, 50).cpu()[:n],
510 | recon_img_post3.view(-1, 1, 50, 50).cpu()[:n],
511 | recon_img_post4.view(-1, 1, 50, 50).cpu()[:n],
512 | recon_img_post5.view(-1, 1, 50, 50).cpu()[:n],
513 | recon_img_post6.view(-1, 1, 50, 50).cpu()[:n],
514 | recon_img_post7.view(-1, 1, 50, 50).cpu()[:n],
515 | recon_img_post8.view(-1, 1, 50, 50).cpu()[:n],
516 | recon_img_post9.view(-1, 1, 50, 50).cpu()[:n]])
517 | save_image(comparison_Pred.cpu(),
518 | './result/{}/prediction_full_step{}/prediction_Pred_batch{}.png'.format(folder_name_our, step, batch_idx), nrow=n)
519 |
520 | print('***** Preparing Data *****')
521 | total_img_num = 22515
522 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
523 | action_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
524 | actions = np.load(action_path)
525 | dataset = MyDatasetMultiPred10(image_paths_bi, actions, transform=ToTensorMultiPred10())
526 | dataloader = DataLoader(dataset, batch_size=32,
527 | shuffle=True, num_workers=4, collate_fn=my_collate)
528 | print('***** Finish Preparing Data *****')
529 |
530 | folder_name_e2c = 'test_E2C_gpu_update_loss'
531 | PATH_e2c = './result/{}/checkpoint'.format(folder_name_e2c)
532 | folder_name_our = 'test_act80_pred30'
533 | PATH_our = './result/{}/checkpoint'.format(folder_name_our)
534 |
535 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
536 | e2c_model = E2C().to(device)
537 | recon_model = CAE().to(device)
538 | dyn_model = SysDynamics().to(device)
539 |
540 | # load checkpoint
541 | print('***** Load Checkpoint *****')
542 | checkpoint_e2c = torch.load(PATH_e2c, map_location=torch.device('cpu'))
543 | e2c_model.load_state_dict(checkpoint_e2c['e2c_model_state_dict'])
544 | checkpoint_our = torch.load(PATH_our, map_location=torch.device('cpu'))
545 | recon_model.load_state_dict(checkpoint_our['recon_model_state_dict'])
546 | dyn_model.load_state_dict(checkpoint_our['dyn_model_state_dict'])
547 |
548 | # prediction
549 | print('***** Start Prediction *****')
550 | step=10 # Change this based on different prediction steps
551 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name_e2c, step)):
552 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name_e2c, step))
553 | if not os.path.exists('./result/{}/prediction_with_action_step{}'.format(folder_name_e2c, step)):
554 | os.makedirs('./result/{}/prediction_with_action_step{}'.format(folder_name_e2c, step))
555 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name_our, step)):
556 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name_our, step))
557 | if not os.path.exists('./result/{}/prediction_with_action_step{}'.format(folder_name_our, step)):
558 | os.makedirs('./result/{}/prediction_with_action_step{}'.format(folder_name_our, step))
559 | predict()
560 | print('***** Finish Prediction *****')
561 |
--------------------------------------------------------------------------------
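Every step of the multi-step rollout for our method in the comparison script above follows the same pattern: embed the action, query `SysDynamics` for the local matrices, advance the latent state with `get_next_state_linear` (imported from `deform.model.hidden_dynamics`), and decode. A sketch of one such step factored into a helper; the name `rollout_step` is ours, not part of the repository:

```python
def rollout_step(recon_model, dyn_model, img, act, latent_img):
    # latent embedding of the commanded action
    latent_act = recon_model.encoder_act(act)
    # locally linear matrices K^T and L^T conditioned on the current image and action
    K_T, L_T = dyn_model(img, act)
    # advance the latent state and decode it back to a 50x50 image
    latent_img_next = get_next_state_linear(latent_img, latent_act, K_T, L_T)
    return latent_img_next, recon_model.decoder(latent_img_next)
```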
/model/predict_large_linear_seprt_no_loop_Kp_Lpa.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | from torch import nn, optim, sigmoid, tanh, relu
5 | from torch.utils.data import Dataset, DataLoader
6 | from torch.nn import functional as F
7 | from deform.model.create_dataset import *
8 | from deform.model.hidden_dynamics import *
9 | from torchvision.utils import save_image
10 | import os
11 | import math
12 |
13 | class CAE(nn.Module):
14 | def __init__(self, latent_state_dim=80, latent_act_dim=80):
15 | super(CAE, self).__init__()
16 | # state
17 | self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
18 | nn.ReLU(),
19 | nn.MaxPool2d(3, stride=2),
20 | nn.Conv2d(32, 64, 3, padding=1),
21 | nn.ReLU(),
22 | nn.MaxPool2d(3, stride=2),
23 | nn.Conv2d(64, 128, 3, padding=1),
24 | nn.ReLU(),
25 | nn.MaxPool2d(3, stride=2),
26 | nn.Conv2d(128, 128, 3, padding=1),
27 | nn.ReLU(),
28 | nn.Conv2d(128, 128, 3, padding=1),
29 | nn.ReLU(),
30 | nn.MaxPool2d(3, stride=2, padding=1))
31 | self.fc1 = nn.Linear(128*3*3, latent_state_dim)
32 | self.fc2 = nn.Linear(latent_state_dim, 128*3*3)
33 | self.dconv_layers = nn.Sequential(nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
34 | nn.ReLU(),
35 | nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1),
36 | nn.ReLU(),
37 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=2),
38 | nn.ReLU(),
39 | nn.ConvTranspose2d(64, 32, 3, stride=2, padding=2),
40 | nn.ReLU(),
41 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=2),
42 | nn.Sigmoid())
43 | # action
44 | self.fc5 = nn.Linear(4, latent_act_dim)
45 | self.fc6 = nn.Linear(latent_act_dim, latent_act_dim)
46 | self.fc7 = nn.Linear(latent_act_dim, latent_act_dim)
47 | self.fc8 = nn.Linear(latent_act_dim, 4)
48 |
49 | self.mul_tensor = torch.tensor([50, 50, 2*math.pi, 0.14])
50 | self.add_tensor = torch.tensor([0, 0, 0, 0.01])
51 |
52 |
53 | def encoder(self, x):
54 | x = self.conv_layers(x)
55 | x = x.view(x.shape[0], -1)
56 | return relu(self.fc1(x))
57 |
58 | def decoder(self, x):
59 | x = relu(self.fc2(x))
60 | x = x.view(-1, 128, 3, 3)
61 | return self.dconv_layers(x)
62 |
63 | def encoder_act(self, u):
64 | h1 = relu(self.fc5(u))
65 | return relu(self.fc6(h1))
66 |
67 | def decoder_act(self, u):
68 | h2 = relu(self.fc7(u))
69 |         return torch.mul(sigmoid(self.fc8(h2)), self.mul_tensor.to(h2.device)) + self.add_tensor.to(h2.device)
70 |
71 | def forward(self, x_cur, u, x_post):
72 | g_cur = self.encoder(x_cur)
73 | a = self.encoder_act(u)
74 | g_post = self.encoder(x_post)
75 |
76 | return g_cur, a, g_post, self.decoder(g_cur), self.decoder_act(a)
77 |
78 | class SysDynamics(nn.Module):
79 | def __init__(self, latent_state_dim=80, latent_act_dim=80):
80 | super(SysDynamics, self).__init__()
81 | self.conv_layers_matrix = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1),
82 | nn.ReLU(),
83 | nn.MaxPool2d(3, stride=1),
84 | nn.Conv2d(32, 64, 3, padding=1),
85 | nn.ReLU(),
86 | nn.MaxPool2d(3, stride=1),
87 | nn.Conv2d(64, 128, 3, padding=1),
88 | nn.ReLU(),
89 | nn.MaxPool2d(3, stride=2),
90 | nn.Conv2d(128, 256, 3, padding=1),
91 | nn.ReLU(),
92 | nn.MaxPool2d(3, stride=2),
93 | nn.Conv2d(256, 512, 3, padding=1),
94 | nn.ReLU(),
95 | nn.MaxPool2d(3, stride=2),
96 | nn.Conv2d(512, 512, 3, padding=1),
97 | nn.ReLU(),
98 | nn.Conv2d(512, 512, 3, padding=1),
99 | nn.ReLU(),
100 | nn.MaxPool2d(3, stride=2, padding=1))
101 | self.fc31 = nn.Linear(512*2*2, latent_state_dim*latent_state_dim)
102 | self.fc32 = nn.Linear(latent_state_dim*latent_state_dim, latent_state_dim*latent_state_dim)
103 | self.fc41 = nn.Linear(512*2*2 + latent_act_dim, latent_state_dim*latent_act_dim)
104 | self.fc42 = nn.Linear(latent_state_dim*latent_act_dim, latent_state_dim*latent_act_dim)
105 | self.fc9 = nn.Linear(4, latent_act_dim)
106 | self.fc10 = nn.Linear(latent_act_dim, latent_act_dim)
107 | # latent dim
108 | self.latent_act_dim = latent_act_dim
109 | self.latent_state_dim = latent_state_dim
110 |
111 | def encoder_matrix(self, x, a):
112 | x = self.conv_layers_matrix(x)
113 | x = x.view(x.shape[0], -1)
114 | xa = torch.cat((x,a), 1)
115 |
116 | return relu(self.fc32(relu(self.fc31(x)))).view(-1, self.latent_state_dim, self.latent_state_dim), \
117 | relu(self.fc42(relu(self.fc41(xa)))).view(-1, self.latent_act_dim, self.latent_state_dim)
118 |
119 | def forward(self, x_cur, u):
120 | a = relu(self.fc10(relu(self.fc9(u))))
121 | K_T, L_T = self.encoder_matrix(x_cur, a)
122 |
123 | return K_T, L_T
124 |
125 | def predict():
126 | recon_model.eval()
127 | dyn_model.eval()
128 | with torch.no_grad():
129 | for batch_idx, batch_data in enumerate(dataloader):
130 | # order: img_pre -> act_pre -> img_cur -> act_cur -> img_post
131 | # previous image
132 | img_pre = batch_data['image_bi_pre']
133 | img_pre = img_pre.float().to(device).view(-1, 1, 50, 50)
134 | # previous action
135 | act_pre = batch_data['resz_action_pre']
136 | act_pre = act_pre.float().to(device).view(-1, 4)
137 | # current image
138 | img_cur = batch_data['image_bi_cur']
139 | img_cur = img_cur.float().to(device).view(-1, 1, 50, 50)
140 | # current action
141 | act_cur = batch_data['resz_action_cur']
142 | act_cur = act_cur.float().to(device).view(-1, 4)
143 | # post image
144 | img_post = batch_data['image_bi_post']
145 | img_post = img_post.float().to(device).view(-1, 1, 50, 50)
146 | # prediction for current image
147 | latent_img_pre, latent_act_pre, _, _, _ = recon_model(img_pre, act_pre, img_cur)
148 | K_T_pre, L_T_pre = dyn_model(img_pre, act_pre)
149 | recon_latent_img_cur = get_next_state_linear(latent_img_pre, latent_act_pre, K_T_pre, L_T_pre)
150 | recon_img_cur = recon_model.decoder(recon_latent_img_cur)
151 | # prediction for post image
152 | latent_img_cur, latent_act_cur, _, _, _ = recon_model(img_cur, act_cur, img_post)
153 | K_T_cur, L_T_cur = dyn_model(img_cur, act_cur)
154 | recon_latent_img_post = get_next_state_linear(latent_img_cur, latent_act_cur, K_T_cur, L_T_cur)
155 | recon_img_post = recon_model.decoder(recon_latent_img_post)
156 | if batch_idx % 10 == 0:
157 | n = min(batch_data['image_bi_pre'].size(0), 8)
158 | comparison = torch.cat([batch_data['image_bi_pre'][:n],
159 | batch_data['image_bi_cur'][:n],
160 | recon_img_cur.view(-1, 1, 50, 50).cpu()[:n],
161 | batch_data['image_bi_post'][:n],
162 | recon_img_post.view(-1, 1, 50, 50).cpu()[:n]])
163 | save_image(comparison.cpu(),
164 | './result/{}/prediction_full_step{}/prediction_batch{}.png'.format(folder_name, step, batch_idx), nrow=n)
165 |
166 |
167 | print('***** Preparing Data *****')
168 | total_img_num = 22515
169 | image_paths_bi = create_image_path('rope_no_loop_all_resize_gray_clean', total_img_num)
170 | action_path = './rope_dataset/rope_no_loop_all_resize_gray_clean/simplified_clean_actions_all_size50.npy'
171 | actions = np.load(action_path)
172 | dataset = MyDataset(image_paths_bi, actions, transform=ToTensor())
173 | dataloader = DataLoader(dataset, batch_size=64,
174 | shuffle=True, num_workers=4, collate_fn=my_collate)
175 | print('***** Finish Preparing Data *****')
176 |
177 | folder_name = 'test_act80_pred50'
178 | PATH = './result/{}/checkpoint'.format(folder_name)
179 |
180 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
181 | recon_model = CAE().to(device)
182 | dyn_model = SysDynamics().to(device)
183 |
184 | # load checkpoint
185 | print('***** Load Checkpoint *****')
186 | checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
187 | recon_model.load_state_dict(checkpoint['recon_model_state_dict'])
188 | dyn_model.load_state_dict(checkpoint['dyn_model_state_dict'])
189 |
190 | # prediction
191 | print('***** Start Prediction *****')
192 | step=1
193 | if not os.path.exists('./result/{}/prediction_full_step{}'.format(folder_name, step)):
194 | os.makedirs('./result/{}/prediction_full_step{}'.format(folder_name, step))
195 | predict()
196 | print('***** Finish Prediction *****')
--------------------------------------------------------------------------------
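A note on the four-dimensional actions used throughout these scripts: a poke is (x, y, angle, length), which is also how the `rect` helper in `utils/utils.py` draws action arrows. `decoder_act` squashes the network output with a sigmoid and rescales it by `mul_tensor` and `add_tensor`, so decoded actions always land in the valid poke ranges. A standalone sketch of that mapping, with a random tensor standing in for the network output:

```python
import math
import torch

mul = torch.tensor([50, 50, 2 * math.pi, 0.14])  # scales for x, y, angle, length
add = torch.tensor([0, 0, 0, 0.01])              # offset keeps the poke length at least 0.01
s = torch.rand(4)                                 # stand-in for sigmoid(fc8(h2)), in (0, 1)
action = s * mul + add                            # x, y in (0, 50); angle in (0, 2*pi); length in (0.01, 0.15)
```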
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbgood6/deform/eb7fc4f55d20812efb5575890bec652c4a420f68/utils/__init__.py
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import matplotlib.pyplot as plt
4 | import os
5 | from matplotlib.lines import Line2D
6 |
7 | def plot_grad_flow(named_parameters, folder_name):
8 | '''Plots the gradients flowing through different layers in the net during training.
9 | Can be used for checking for possible gradient vanishing / exploding problems.
10 |
11 |     Usage: Plug this function in Trainer class after loss.backward() as
12 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow
13 |
14 | Source: https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
15 | '''
16 | ave_grads = []
17 | layers = []
18 | for n, p in named_parameters:
19 | if(p.requires_grad) and ("bias" not in n):
20 | layers.append(n)
21 | ave_grads.append(p.grad.abs().mean())
22 | plt.plot(ave_grads, alpha=0.3, color="b")
23 | plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k" )
24 | plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
25 | plt.xlim(xmin=0, xmax=len(ave_grads))
26 | plt.xlabel("Layers")
27 | plt.ylabel("average gradient")
28 | plt.title("Gradient flow")
29 | plt.grid(True)
30 | plt.savefig('./result/{}/plot/gradient_flow.png'.format(folder_name))
31 | plt.close()
32 |
33 | def create_loss_list(loss_logger, kld=True):
34 | if loss_logger is None:
35 | train_loss_all = []
36 | train_img_loss_all = []
37 | train_act_loss_all = []
38 | train_latent_loss_all = []
39 | train_pred_loss_all = []
40 | train_kld_loss_all = []
41 | test_loss_all = []
42 | test_img_loss_all = []
43 | test_act_loss_all = []
44 | test_latent_loss_all = []
45 | test_pred_loss_all = []
46 | test_kld_loss_all = []
47 | else:
48 | train_loss_all = loss_logger['train_loss_all']
49 | train_img_loss_all = loss_logger['train_img_loss_all']
50 | train_act_loss_all = loss_logger['train_act_loss_all']
51 | train_latent_loss_all = loss_logger['train_latent_loss_all']
52 | train_pred_loss_all = loss_logger['train_pred_loss_all']
53 | test_loss_all = loss_logger['test_loss_all']
54 | test_img_loss_all = loss_logger['test_img_loss_all']
55 | test_act_loss_all = loss_logger['test_act_loss_all']
56 | test_latent_loss_all = loss_logger['test_latent_loss_all']
57 | test_pred_loss_all = loss_logger['test_pred_loss_all']
58 | if kld is True:
59 | train_kld_loss_all = loss_logger['train_kld_loss_all']
60 | test_kld_loss_all = loss_logger['test_kld_loss_all']
61 | else:
62 | train_kld_loss_all = []
63 | test_kld_loss_all = []
64 | return train_loss_all, train_img_loss_all, train_act_loss_all, train_latent_loss_all, train_pred_loss_all, train_kld_loss_all, \
65 | test_loss_all, test_img_loss_all, test_act_loss_all, test_latent_loss_all, test_pred_loss_all, test_kld_loss_all
66 |
67 | def create_folder(folder_name):
68 | if not os.path.exists('./result/' + folder_name):
69 | os.makedirs('./result/' + folder_name)
70 | if not os.path.exists('./result/' + folder_name + '/plot'):
71 | os.makedirs('./result/' + folder_name + '/plot')
72 | if not os.path.exists('./result/' + folder_name + '/reconstruction_test'):
73 | os.makedirs('./result/' + folder_name + '/reconstruction_test')
74 | if not os.path.exists('./result/' + folder_name + '/reconstruction_train'):
75 | os.makedirs('./result/' + folder_name + '/reconstruction_train')
76 | if not os.path.exists('./result/' + folder_name + '/reconstruction_act_train'):
77 | os.makedirs('./result/' + folder_name + '/reconstruction_act_train')
78 | if not os.path.exists('./result/' + folder_name + '/reconstruction_act_test'):
79 | os.makedirs('./result/' + folder_name + '/reconstruction_act_test')
80 |
81 | def rect(poke, c, label=None):
82 | x, y, t, l = poke
83 | dx = -100 * l * math.cos(t)
84 | dy = -100 * l * math.sin(t)
85 | arrow = plt.arrow(x, y, dx, dy, width=0.001, head_width=6, head_length=6, color=c, label=label)
86 |
87 |
88 | def plot_action(resz_action, recon_action, directory):
89 |     plt.figure()
90 |     # left: original action
91 |     plt.subplot(1, 2, 1)
92 |     rect(resz_action, "blue")
93 |     plt.axis('off')
94 |     # right: reconstructed action
95 |     plt.subplot(1, 2, 2)
96 |     rect(recon_action, "red")
97 |     plt.axis('off')
98 |     plt.savefig(directory)
99 |     plt.close()
100 |
101 | def change_img_dim(img):
102 | img_tmp = np.ones((img.shape[1], img.shape[2], img.shape[0]))
103 | img_tmp[:,:,0] = img[0]
104 | img_tmp[:,:,1] = img[1]
105 | img_tmp[:,:,2] = img[2]
106 | return img_tmp
107 |
108 | def plot_sample(img_before, img_after, resz_action, recon_action, directory):
109 | plt.figure()
110 | N = int(img_before.shape[0])
111 | for i in range(N):
112 | # upper row original
113 | plt.subplot(3, N, i+1)
114 | rect(resz_action[i], "blue")
115 | plt.imshow(change_img_dim(img_before[i]))
116 | plt.axis('off')
117 | # middle row reconstruction
118 | plt.subplot(3, N, i+1+N)
119 | rect(recon_action[i], "red")
120 | plt.imshow(change_img_dim(img_before[i]))
121 | plt.axis('off')
122 | # lower row: next image after action
123 | plt.subplot(3, N, i+1+2*N)
124 | plt.imshow(change_img_dim(img_after[i]))
125 | plt.axis('off')
126 | plt.savefig(directory)
127 | plt.close()
128 |
129 | def plot_cem_sample(img_before, img_after, img_after_pred, resz_action, recon_action, directory):
130 | # source from rope.ipynb in Berkeley's rope dataset file
131 | plt.figure()
132 |     # top left: ground truth action on the image before the action
133 |     plt.subplot(2, 2, 1)
134 |     rect(resz_action, "blue", "Ground Truth Action")
135 |     plt.imshow(img_before.reshape((50,50,-1)), cmap='gray')
136 |     plt.axis('off')
137 |     # top right: ground truth image after the action
138 |     plt.subplot(2, 2, 2)
139 |     plt.imshow(img_after.reshape((50,50,-1)), cmap='gray')
140 |     plt.axis('off')
141 |     # bottom left: sampled action on the image before the action
142 |     plt.subplot(2, 2, 3)
143 |     rect(recon_action, "red", "Sampled Action")
144 |     plt.imshow(img_before.reshape((50,50,-1)), cmap='gray')
145 |     plt.axis('off')
146 |     # bottom right: predicted image after the sampled action
147 |     plt.subplot(2, 2, 4)
148 |     plt.imshow(img_after_pred.reshape((50,50,-1)), cmap='gray')
149 |     plt.axis('off')
150 | plt.savefig(directory)
151 | plt.close()
152 |
153 | def plot_sample_multi_step(img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, act1, act2,\
154 | act3, act4, act5, act6, act7, act8, act9, act10, directory):
155 | # multi-step prediction with action on the image
156 | plt.figure()
157 | N = int(img1.shape[0])
158 | plt.subplots_adjust(wspace=0, hspace=0)
159 | for i in range(N):
160 | # img1, act1
161 | plt.subplot(11, N, i+1)
162 | rect(act1[i], "red")
163 | plt.imshow(img1[i].reshape((50,50,-1)), cmap='binary')
164 | plt.axis('off')
165 | # img2, act2
166 | plt.subplot(11, N, i+1+N)
167 | rect(act2[i], "red")
168 | plt.imshow(img2[i].reshape((50,50,-1)), cmap='binary')
169 | plt.axis('off')
170 | # img3, act3
171 | plt.subplot(11, N, i+1+2*N)
172 | rect(act3[i], "red")
173 | plt.imshow(img3[i].reshape((50,50,-1)), cmap='binary')
174 | plt.axis('off')
175 | # img4, act4
176 | plt.subplot(11, N, i+1+3*N)
177 | rect(act4[i], "red")
178 | plt.imshow(img4[i].reshape((50,50,-1)), cmap='binary')
179 | plt.axis('off')
180 | # img5, act5
181 | plt.subplot(11, N, i+1+4*N)
182 | rect(act5[i], "red")
183 | plt.imshow(img5[i].reshape((50,50,-1)), cmap='binary')
184 | plt.axis('off')
185 | # img6, act6
186 | plt.subplot(11, N, i+1+5*N)
187 | rect(act6[i], "red")
188 | plt.imshow(img6[i].reshape((50,50,-1)), cmap='binary')
189 | plt.axis('off')
190 | # img7, act7
191 | plt.subplot(11, N, i+1+6*N)
192 | rect(act7[i], "red")
193 | plt.imshow(img7[i].reshape((50,50,-1)), cmap='binary')
194 | plt.axis('off')
195 | # img8, act8
196 | plt.subplot(11, N, i+1+7*N)
197 | rect(act8[i], "red")
198 | plt.imshow(img8[i].reshape((50,50,-1)), cmap='binary')
199 | plt.axis('off')
200 | # img9, act9
201 | plt.subplot(11, N, i+1+8*N)
202 | rect(act9[i], "red")
203 | plt.imshow(img9[i].reshape((50,50,-1)), cmap='binary')
204 | plt.axis('off')
205 | # img10, act10
206 | plt.subplot(11, N, i+1+9*N)
207 | rect(act10[i], "red")
208 | plt.imshow(img10[i].reshape((50,50,-1)), cmap='binary')
209 | plt.axis('off')
210 |         # img11: final image (no further action)
211 | plt.subplot(11, N, i+1+10*N)
212 | plt.imshow(img11[i].reshape((50,50,-1)), cmap='binary')
213 | plt.axis('off')
214 | plt.savefig(directory)
215 | plt.close()
216 |
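# The ten near-identical blocks above follow a single pattern. A hedged,
# behavior-equivalent sketch (hypothetical name plot_sample_multi_step_v2) that
# loops over the images instead; it assumes the same 50x50 images and the rect()
# helper above:
def plot_sample_multi_step_v2(imgs, acts, directory):
    # imgs: list of 11 image batches; acts: list of 10 action batches (one per step)
    plt.figure()
    N = int(imgs[0].shape[0])
    plt.subplots_adjust(wspace=0, hspace=0)
    for i in range(N):
        for row, img in enumerate(imgs):
            plt.subplot(len(imgs), N, i + 1 + row * N)
            if row < len(acts):  # no action is drawn on the final image
                rect(acts[row][i], "red")
            plt.imshow(img[i].reshape((50, 50, -1)), cmap='binary')
            plt.axis('off')
    plt.savefig(directory)
    plt.close()
# e.g. plot_sample_multi_step_v2([img1, ..., img11], [act1, ..., act10], directory)
# would match plot_sample_multi_step(img1, ..., img11, act1, ..., act10, directory)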
217 | def generate_initial_points(x, y, num_points, link_length):
218 | """generate initial points for a line
219 | x: the first (start from left) point's x position
220 | y: the first (start from left) point's y position
221 | num_points: number of points on a line
222 | link_length: each segment's length
223 | """
224 | x_all = [x]
225 | y_all = [y]
226 | for _ in range(num_points-1):
227 | phi = np.random.uniform(-np.pi/10, np.pi/10)
228 | x1, y1 = x + link_length * np.cos(phi), y + link_length * np.sin(phi)
229 | x_all.append(x1)
230 | y_all.append(y1)
231 | x, y = x1, y1
232 |
233 | return x_all, y_all
234 |
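# Hypothetical usage (values assumed): a 40-point line starting at (5, 25), with
# unit-length segments whose angles are sampled uniformly from [-pi/10, pi/10]:
#   x_all, y_all = generate_initial_points(5.0, 25.0, 40, 1.0)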
235 | def generate_new_line(line_x_all, line_y_all, index, move_angle, move_length, link_length):
236 |     ###
237 |     # Parameters
238 |     # line_x_all: x positions of all points on the line,
239 |     #             i.e. np.array([x1, x2, x3, ..., xn])
240 |     # line_y_all: y positions of all points on the line,
241 |     #             i.e. np.array([y1, y2, y3, ..., yn])
242 |     # index: index of the touched point on the line. Scalar.
243 |     # move_angle: gripper's moving angle, in [0, 2*pi].
244 |     # move_length: gripper's moving length. Scalar.
245 |     # link_length: constant distance between two neighboring points. Scalar.
246 |     # Derived locally (not parameters)
247 |     # grip_pos_before: gripper's 2D position before moving, assumed equal to
248 |     #                  the touched position on the line, i.e. np.array([x, y])
249 |     # grip_pos_after: gripper's 2D position after moving
250 |     # action: relative change in x and y, i.e. np.array([delta_x, delta_y])
251 |     # num_points: number of points on the line. Scalar.
252 |     ###
253 |
254 | num_points = np.size(line_x_all)
255 | # initialize new_line_x_all and new_line_y_all
256 | new_line_x_all = [0] * num_points
257 | new_line_y_all = [0] * num_points
258 |
259 | # action
260 | action = get_action(move_angle, move_length)
261 |
262 | # gripper position (touching position on the line) before moving
263 | grip_pos_before = np.array([line_x_all[index], line_y_all[index]])
264 |
265 | # gripper position after moving
266 | grip_pos_after = get_pos_after(grip_pos_before, action)
267 | new_line_x_all[index] = grip_pos_after[0]
268 | new_line_y_all[index] = grip_pos_after[1]
269 |
270 |     # move points on the left side in order
271 | if index != 0:
272 | grip_pos_after_temp = grip_pos_after
273 | for i in range(index):
274 | new_index_left = index - (i+1)
275 | moved_pos_before = np.array([line_x_all[new_index_left], line_y_all[new_index_left]])
276 | moved_pos_after = generate_new_point_pos_on_the_line(grip_pos_after_temp, moved_pos_before, link_length)
277 | grip_pos_after_temp = moved_pos_after
278 | new_line_x_all[new_index_left] = grip_pos_after_temp[0]
279 | new_line_y_all[new_index_left] = grip_pos_after_temp[1]
280 |
281 |     # move points on the right side in order
282 | if index != (num_points-1):
283 | grip_pos_after_temp = grip_pos_after
284 | for j in range(num_points-index-1):
285 | new_index_right = index + (j+1)
286 | moved_pos_before = np.array([line_x_all[new_index_right], line_y_all[new_index_right]])
287 | moved_pos_after = generate_new_point_pos_on_the_line(grip_pos_after_temp, moved_pos_before, link_length)
288 | grip_pos_after_temp = moved_pos_after
289 | new_line_x_all[new_index_right] = grip_pos_after_temp[0]
290 | new_line_y_all[new_index_right] = grip_pos_after_temp[1]
291 |
292 | # # move points in the right side in order
293 | # # touch the line
294 | # line_index = get_line_index(gripper_x_pos, x_all)
295 | # touch_line_pos = [x_all[line_index], y_all[line_index], x_all[line_index+1], y_all[line_index+1]] # [x1, y1, x2, y2]
296 |
297 | # # move the touching line
298 | # touch_line_pos += [action[0], action[1], action[0], action[1]]
299 |
300 |
301 | return new_line_x_all, new_line_y_all
302 |
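# Minimal usage sketch (all values assumed). Note that get_action and get_pos_after
# are defined further down, so this only runs once the whole module is loaded:
#   xs, ys = generate_initial_points(5.0, 25.0, 40, 1.0)
#   new_xs, new_ys = generate_new_line(xs, ys, index=20, move_angle=np.pi / 2,
#                                      move_length=2.0, link_length=1.0)
#   # every link keeps its length after the poke:
#   assert np.allclose(np.hypot(np.diff(new_xs), np.diff(new_ys)), 1.0)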
303 | def generate_new_point_pos_on_the_line(grip_pos_after, moved_pos_before, link_length):
304 |     # relative position of moved_pos_before with respect to grip_pos_after;
305 |     # the order matters: the angle points from grip_pos_after toward moved_pos_before
306 | delta_x = moved_pos_before[0] - grip_pos_after[0]
307 | delta_y = moved_pos_before[1] - grip_pos_after[1]
308 | angle = math.atan2(delta_y, delta_x)
309 |
310 | # get nearby point position after moving
311 | x_after = grip_pos_after[0] + link_length * np.cos(angle)
312 | y_after = grip_pos_after[1] + link_length * np.sin(angle)
313 | moved_pos_after = np.array([x_after, y_after])
314 |
315 | return moved_pos_after
316 |
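# Worked example (assumed values): with grip_pos_after = (0, 0), a neighbor
# previously at (3, 4), and link_length = 1, the neighbor is pulled onto the unit
# circle along the same direction: angle = atan2(4, 3), so the new position is
# (cos(angle), sin(angle)) = (0.6, 0.8).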
317 | def get_action(angle, length):
318 | action = np.array([length*np.cos(angle), length*np.sin(angle)])
319 |
320 | return action
321 |
322 | def get_pos_after(grip_pos_before, action):
323 | x, y = grip_pos_before[0], grip_pos_before[1]
324 | pos_after = np.array([x+action[0], y+action[1]])
325 |
326 | return pos_after
327 |
328 | def get_line_index(gripper_x_pos, x_all):
329 |     # segment index: count the points at or to the left of the gripper, minus one;
330 |     # e.g. x_all = [0, 1, 2, 3] and gripper_x_pos = 2.5 -> mask [T, T, T, F] -> index 2
331 |     return int(np.sum(gripper_x_pos >= np.array(x_all))) - 1
332 |
333 |
334 | def plot_train_loss(file_name, folder_name):
335 | train_loss = np.load(file_name)
336 | plt.figure()
337 | plt.plot(train_loss)
338 | plt.title('Train Loss')
339 | plt.xlabel('Epochs')
340 | plt.ylabel('Loss')
341 | plt.savefig('./result/{}/plot/train_loss.png'.format(folder_name))
342 | plt.close()
343 |
344 | def plot_test_loss(file_name, folder_name):
345 | test_loss = np.load(file_name)
346 | plt.figure()
347 | plt.plot(test_loss)
348 | plt.title('Test Loss')
349 | plt.xlabel('Epochs')
350 | plt.ylabel('Loss')
351 | plt.savefig('./result/{}/plot/test_loss.png'.format(folder_name))
352 | plt.close()
353 |
354 | def plot_train_img_loss(file_name, folder_name):
355 | img_loss = np.load(file_name)
356 | plt.figure()
357 | plt.plot(img_loss)
358 | plt.title('Train Image Loss')
359 | plt.xlabel('Epochs')
360 | plt.ylabel('Loss')
361 | plt.savefig('./result/{}/plot/train_image_loss.png'.format(folder_name))
362 | plt.close()
363 |
364 | def plot_train_act_loss(file_name, folder_name):
365 | act_loss = np.load(file_name)
366 | plt.figure()
367 | plt.plot(act_loss)
368 | plt.title('Train Action Loss')
369 | plt.xlabel('Epochs')
370 | plt.ylabel('Loss')
371 | plt.savefig('./result/{}/plot/train_action_loss.png'.format(folder_name))
372 | plt.close()
373 |
374 | def plot_train_latent_loss(file_name, folder_name):
375 | latent_loss = np.load(file_name)
376 | plt.figure()
377 | plt.plot(latent_loss)
378 | plt.title('Train Latent Loss')
379 | plt.xlabel('Epochs')
380 | plt.ylabel('Loss')
381 | plt.savefig('./result/{}/plot/train_latent_loss.png'.format(folder_name))
382 | plt.close()
383 |
384 | def plot_train_pred_loss(file_name, folder_name):
385 | pred_loss = np.load(file_name)
386 | plt.figure()
387 | plt.plot(pred_loss)
388 | plt.title('Train Prediction Loss')
389 | plt.xlabel('Epochs')
390 | plt.ylabel('Loss')
391 | plt.savefig('./result/{}/plot/train_prediction_loss.png'.format(folder_name))
392 | plt.close()
393 |
394 | def plot_train_kld_loss(file_name, folder_name):
395 | kld_loss = np.load(file_name)
396 | plt.figure()
397 | plt.plot(kld_loss)
398 | plt.title('Train KL Divergence Loss')
399 | plt.xlabel('Epochs')
400 | plt.ylabel('Loss')
401 | plt.savefig('./result/{}/plot/train_kld_loss.png'.format(folder_name))
402 | plt.close()
403 |
404 | def plot_test_img_loss(file_name, folder_name):
405 | img_loss = np.load(file_name)
406 | plt.figure()
407 | plt.plot(img_loss)
408 | plt.title('Test Image Loss')
409 | plt.xlabel('Epochs')
410 | plt.ylabel('Loss')
411 | plt.savefig('./result/{}/plot/test_image_loss.png'.format(folder_name))
412 | plt.close()
413 |
414 | def plot_test_act_loss(file_name, folder_name):
415 | act_loss = np.load(file_name)
416 | plt.figure()
417 | plt.plot(act_loss)
418 | plt.title('Test Action Loss')
419 | plt.xlabel('Epochs')
420 | plt.ylabel('Loss')
421 | plt.savefig('./result/{}/plot/test_action_loss.png'.format(folder_name))
422 | plt.close()
423 |
424 | def plot_test_latent_loss(file_name, folder_name):
425 | latent_loss = np.load(file_name)
426 | plt.figure()
427 | plt.plot(latent_loss)
428 | plt.title('Test Latent Loss')
429 | plt.xlabel('Epochs')
430 | plt.ylabel('Loss')
431 | plt.savefig('./result/{}/plot/test_latent_loss.png'.format(folder_name))
432 | plt.close()
433 |
434 | def plot_test_pred_loss(file_name, folder_name):
435 | pred_loss = np.load(file_name)
436 | plt.figure()
437 | plt.plot(pred_loss)
438 | plt.title('Test Prediction Loss')
439 | plt.xlabel('Epochs')
440 | plt.ylabel('Loss')
441 | plt.savefig('./result/{}/plot/test_prediction_loss.png'.format(folder_name))
442 | plt.close()
443 |
444 | def plot_test_kld_loss(file_name, folder_name):
445 | kld_loss = np.load(file_name)
446 | plt.figure()
447 | plt.plot(kld_loss)
448 | plt.title('Test KL Divergence Loss')
449 | plt.xlabel('Epochs')
450 | plt.ylabel('Loss')
451 | plt.savefig('./result/{}/plot/test_kld_loss.png'.format(folder_name))
452 | plt.close()
453 |
454 | def plot_train_bound_loss(file_name, folder_name):
455 | bound_loss = np.load(file_name)
456 | plt.figure()
457 | plt.plot(bound_loss)
458 | plt.title('Train Bound Loss')
459 | plt.xlabel('Epochs')
460 | plt.ylabel('Loss')
461 | plt.savefig('./result/{}/plot/train_bound_loss.png'.format(folder_name))
462 | plt.close()
463 |
464 | def plot_test_bound_loss(file_name, folder_name):
465 | bound_loss = np.load(file_name)
466 | plt.figure()
467 | plt.plot(bound_loss)
468 | plt.title('Test Bound Loss')
469 | plt.xlabel('Epochs')
470 | plt.ylabel('Loss')
471 | plt.savefig('./result/{}/plot/test_bound_loss.png'.format(folder_name))
472 | plt.close()
473 |
474 | def plot_train_kl_loss(file_name, folder_name):
475 | kl_loss = np.load(file_name)
476 | plt.figure()
477 | plt.plot(kl_loss)
478 | plt.title('Train KL Loss')
479 | plt.xlabel('Epochs')
480 | plt.ylabel('Loss')
481 | plt.savefig('./result/{}/plot/train_kl_loss.png'.format(folder_name))
482 | plt.close()
483 |
484 | def plot_test_kl_loss(file_name, folder_name):
485 | kl_loss = np.load(file_name)
486 | plt.figure()
487 | plt.plot(kl_loss)
488 | plt.title('Test KL Loss')
489 | plt.xlabel('Epochs')
490 | plt.ylabel('Loss')
491 | plt.savefig('./result/{}/plot/test_kl_loss.png'.format(folder_name))
492 | plt.close()
493 |
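# The sixteen single-curve plotters above differ only in plot title and output file
# name. A hedged consolidation sketch (hypothetical name plot_loss_curve), under the
# assumption that every caller follows the same './result/<folder>/plot/' layout:
def plot_loss_curve(file_name, folder_name, title, out_name):
    # load a saved loss history (one value per epoch) and plot it
    loss = np.load(file_name)
    plt.figure()
    plt.plot(loss)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.savefig('./result/{}/plot/{}.png'.format(folder_name, out_name))
    plt.close()
# e.g. plot_loss_curve(f, folder, 'Test KL Loss', 'test_kl_loss')
# reproduces plot_test_kl_loss(f, folder)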
494 | def plot_all_train_loss_with_noise(train, test, img, act, latent, pred, kld, folder_name):
495 | train_loss = np.load(train)
496 | test_loss = np.load(test)
497 | img_loss = np.load(img)
498 | act_loss = np.load(act)
499 | latent_loss = np.load(latent)
500 | pred_loss = np.load(pred)
501 | kld_loss = np.load(kld)
502 | plt.figure()
503 | train_curve, = plt.plot(train_loss, label='Train')
504 | test_curve, = plt.plot(test_loss, label='Test')
505 | img_curve, = plt.plot(img_loss, label='Image')
506 | act_curve, = plt.plot(act_loss, label='Action')
507 | latent_curve, = plt.plot(latent_loss, label='Latent')
508 | pred_curve, = plt.plot(pred_loss, label='Prediction')
509 | kld_curve, = plt.plot(kld_loss, label='KL Divergence')
510 | plt.title('Train loss and its subcomponents')
511 | plt.xlabel('Epochs')
512 | plt.ylabel('Loss')
513 | plt.legend([train_curve, test_curve, img_curve, act_curve, latent_curve, pred_curve, kld_curve], ['Train', 'Test', 'Image', 'Action', 'Latent', 'Prediction', 'KL Divergence'])
514 | plt.savefig('./result/{}/plot/all_train_loss.png'.format(folder_name))
515 | plt.close()
516 |
517 | def plot_all_test_loss_with_noise(test, img, act, latent, pred, kld, folder_name):
518 | test_loss = np.load(test)
519 | img_loss = np.load(img)
520 | act_loss = np.load(act)
521 | latent_loss = np.load(latent)
522 | pred_loss = np.load(pred)
523 | kld_loss = np.load(kld)
524 | plt.figure()
525 | test_curve, = plt.plot(test_loss, label='Test')
526 | img_curve, = plt.plot(img_loss, label='Image')
527 | act_curve, = plt.plot(act_loss, label='Action')
528 | latent_curve, = plt.plot(latent_loss, label='Latent')
529 | pred_curve, = plt.plot(pred_loss, label='Prediction')
530 | kld_curve, = plt.plot(kld_loss, label='KL Divergence')
531 | plt.title('Test loss and its subcomponents')
532 | plt.xlabel('Epochs')
533 | plt.ylabel('Loss')
534 | plt.legend([test_curve, img_curve, act_curve, latent_curve, pred_curve, kld_curve], ['Test', 'Image', 'Action', 'Latent', 'Prediction', 'KL Divergence'])
535 | plt.savefig('./result/{}/plot/all_test_loss.png'.format(folder_name))
536 | plt.close()
537 |
538 | def plot_all_train_loss_without_noise(train, test, img, act, latent, pred, folder_name):
539 | train_loss = np.load(train)
540 | test_loss = np.load(test)
541 | img_loss = np.load(img)
542 | act_loss = np.load(act)
543 | latent_loss = np.load(latent)
544 | pred_loss = np.load(pred)
545 | plt.figure()
546 | train_curve, = plt.plot(train_loss, label='Train')
547 | test_curve, = plt.plot(test_loss, label='Test')
548 | img_curve, = plt.plot(img_loss, label='Image')
549 | act_curve, = plt.plot(act_loss, label='Action')
550 | latent_curve, = plt.plot(latent_loss, label='Latent')
551 | pred_curve, = plt.plot(pred_loss, label='Prediction')
552 | plt.title('Train loss and its subcomponents')
553 | plt.xlabel('Epochs')
554 | plt.ylabel('Loss')
555 | plt.legend([train_curve, test_curve, img_curve, act_curve, latent_curve, pred_curve], ['Train', 'Test', 'Image', 'Action', 'Latent', 'Prediction'])
556 | plt.savefig('./result/{}/plot/all_train_loss.png'.format(folder_name))
557 | plt.close()
558 |
559 | def plot_all_test_loss_without_noise(test, img, act, latent, pred, folder_name):
560 | test_loss = np.load(test)
561 | img_loss = np.load(img)
562 | act_loss = np.load(act)
563 | latent_loss = np.load(latent)
564 | pred_loss = np.load(pred)
565 | plt.figure()
566 | test_curve, = plt.plot(test_loss, label='Test')
567 | img_curve, = plt.plot(img_loss, label='Image')
568 | act_curve, = plt.plot(act_loss, label='Action')
569 | latent_curve, = plt.plot(latent_loss, label='Latent')
570 | pred_curve, = plt.plot(pred_loss, label='Prediction')
571 | plt.title('Test loss and its subcomponents')
572 | plt.xlabel('Epochs')
573 | plt.ylabel('Loss')
574 | plt.legend([test_curve, img_curve, act_curve, latent_curve, pred_curve], ['Test', 'Image', 'Action', 'Latent', 'Prediction'])
575 | plt.savefig('./result/{}/plot/all_test_loss.png'.format(folder_name))
576 | plt.close()
577 |
578 | def save_data(folder_name, epochs, train_loss_all, train_img_loss_all, train_act_loss_all,
579 | train_latent_loss_all, train_pred_loss_all,
580 | test_loss_all, test_img_loss_all, test_act_loss_all, test_latent_loss_all,
581 | test_pred_loss_all, train_kld_loss_all=None, test_kld_loss_all=None, K=None, L=None):
582 | np.save('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), train_loss_all)
583 | np.save('./result/{}/train_img_loss_epoch{}.npy'.format(folder_name, epochs), train_img_loss_all)
584 | np.save('./result/{}/train_act_loss_epoch{}.npy'.format(folder_name, epochs), train_act_loss_all)
585 | np.save('./result/{}/train_latent_loss_epoch{}.npy'.format(folder_name, epochs), train_latent_loss_all)
586 | np.save('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), train_pred_loss_all)
587 | np.save('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), test_loss_all)
588 | np.save('./result/{}/test_img_loss_epoch{}.npy'.format(folder_name, epochs), test_img_loss_all)
589 | np.save('./result/{}/test_act_loss_epoch{}.npy'.format(folder_name, epochs), test_act_loss_all)
590 | np.save('./result/{}/test_latent_loss_epoch{}.npy'.format(folder_name, epochs), test_latent_loss_all)
591 | np.save('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), test_pred_loss_all)
592 | if train_kld_loss_all is not None:
593 | np.save('./result/{}/train_kld_loss_epoch{}.npy'.format(folder_name, epochs), train_kld_loss_all)
594 | if test_kld_loss_all is not None:
595 | np.save('./result/{}/test_kld_loss_epoch{}.npy'.format(folder_name, epochs), test_kld_loss_all)
596 | if K is not None:
597 | np.save('./result/{}/koopman_matrix.npy'.format(folder_name), K)
598 | if L is not None:
599 | np.save('./result/{}/control_matrix.npy'.format(folder_name), L)
600 |
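# Hypothetical call (names and values assumed): after training, persist all loss
# histories plus the learned Koopman matrix K and control matrix L for this run:
#   save_data('rope_run', 500, train_loss, train_img, train_act, train_latent,
#             train_pred, test_loss, test_img, test_act, test_latent, test_pred,
#             K=K_matrix, L=L_matrix)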
601 | def save_e2c_data(folder_name, epochs, train_loss_all, train_bound_loss_all, train_kl_loss_all, train_pred_loss_all, \
602 | test_loss_all, test_bound_loss_all, test_kl_loss_all, test_pred_loss_all):
603 | np.save('./result/{}/train_loss_epoch{}.npy'.format(folder_name, epochs), train_loss_all)
604 | np.save('./result/{}/train_bound_loss_epoch{}.npy'.format(folder_name, epochs), train_bound_loss_all)
605 | np.save('./result/{}/train_kl_loss_epoch{}.npy'.format(folder_name, epochs), train_kl_loss_all)
606 | np.save('./result/{}/train_pred_loss_epoch{}.npy'.format(folder_name, epochs), train_pred_loss_all)
607 | np.save('./result/{}/test_loss_epoch{}.npy'.format(folder_name, epochs), test_loss_all)
608 | np.save('./result/{}/test_bound_loss_epoch{}.npy'.format(folder_name, epochs), test_bound_loss_all)
609 | np.save('./result/{}/test_kl_loss_epoch{}.npy'.format(folder_name, epochs), test_kl_loss_all)
610 | np.save('./result/{}/test_pred_loss_epoch{}.npy'.format(folder_name, epochs), test_pred_loss_all)
611 |
--------------------------------------------------------------------------------