├── .gitignore
├── LICENSE
├── README.md
├── eipl
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_dict.py
│   │   ├── dataset.py
│   │   └── downloader.py
│   ├── layer
│   │   ├── CoordConv2d.py
│   │   ├── GridMask.py
│   │   ├── MultipleTimescaleRNN.py
│   │   ├── SpatialSoftmax.py
│   │   └── __init__.py
│   ├── model
│   │   ├── BasicRNN.py
│   │   ├── CAE.py
│   │   ├── CAEBN.py
│   │   ├── CNNRNN.py
│   │   ├── CNNRNNLN.py
│   │   ├── SARNN.py
│   │   └── __init__.py
│   ├── test
│   │   ├── benchmark_dataloader.py
│   │   ├── test_CoordConv2d.py
│   │   ├── test_GridMask.py
│   │   ├── test_LossScheduler.py
│   │   ├── test_SampleDownloader.py
│   │   ├── test_SpatialSoftmax.py
│   │   ├── test_WeightDownloader.py
│   │   ├── test_bounds.py
│   │   ├── test_cos_interpolation.py
│   │   ├── test_dataloader.py
│   │   └── test_models.py
│   ├── tutorials
│   │   ├── airec
│   │   │   ├── ros
│   │   │   │   ├── 1_rosbag2npz.py
│   │   │   │   ├── 2_make_dataset.py
│   │   │   │   └── 3_check_data.py
│   │   │   └── sarnn
│   │   │       ├── bin
│   │   │       │   ├── test.py
│   │   │       │   ├── test_pca_sarnn.py
│   │   │       │   └── train.py
│   │   │       ├── libs
│   │   │       │   └── fullBPTT.py
│   │   │       ├── log
│   │   │       │   └── .gitignore
│   │   │       └── output
│   │   │           └── .gitignore
│   │   ├── open_manipulator
│   │   │   ├── ros
│   │   │   │   ├── bag2npz
│   │   │   │   │   ├── 1_rosbag2npz.py
│   │   │   │   │   ├── 2_make_dataset.py
│   │   │   │   │   ├── 3_check_data.py
│   │   │   │   │   ├── bag
│   │   │   │   │   │   └── .gitignore
│   │   │   │   │   ├── data
│   │   │   │   │   │   └── .gitignore
│   │   │   │   │   └── utils.py
│   │   │   │   └── om_teleop
│   │   │   │       ├── CMakeLists.txt
│   │   │   │       ├── config
│   │   │   │       │   └── d435_param.yaml
│   │   │   │       ├── launch
│   │   │   │       │   ├── playback_bringup.launch
│   │   │   │       │   ├── rt_bringup.launch
│   │   │   │       │   └── teleop_bringup.launch
│   │   │   │       ├── package.xml
│   │   │   │       └── src
│   │   │   │           ├── dynamixel_driver.py
│   │   │   │           ├── dynamixel_utils.py
│   │   │   │           ├── follower_bringup.py
│   │   │   │           ├── interplation_node.py
│   │   │   │           ├── leader_bringup.py
│   │   │   │           └── playback.py
│   │   │   └── sarnn
│   │   │       ├── bin
│   │   │       │   ├── export_onnx.py
│   │   │       │   ├── test.py
│   │   │       │   ├── test_pca_sarnn.py
│   │   │       │   └── train.py
│   │   │       ├── libs
│   │   │       │   └── fullBPTT.py
│   │   │       ├── log
│   │   │       │   └── .gitignore
│   │   │       └── output
│   │   │           └── .gitignore
│   │   └── robosuite
│   │       ├── README.md
│   │       ├── sarnn
│   │       │   ├── bin
│   │       │   │   ├── test.py
│   │       │   │   ├── test_pca_sarnn.py
│   │       │   │   └── train.py
│   │       │   ├── libs
│   │       │   │   └── fullBPTT.py
│   │       │   ├── log
│   │       │   │   └── .gitignore
│   │       │   └── output
│   │       │       └── .gitignore
│   │       └── simulator
│   │           ├── 2_resave.sh
│   │           ├── 3_check_data.sh
│   │           ├── bin
│   │           │   ├── 1_teaching.py
│   │           │   ├── 2_resave.py
│   │           │   ├── 3_check_playback_data.py
│   │           │   ├── 4_generate_dataset.py
│   │           │   ├── 5_check_dataset.py
│   │           │   └── 6_rt_control.py
│   │           ├── data
│   │           │   └── .gitignore
│   │           ├── libs
│   │           │   ├── devices.py
│   │           │   ├── environment.py
│   │           │   ├── rt_control_wrapper.py
│   │           │   ├── samplers.py
│   │           │   └── utils.py
│   │           └── output
│   │               └── .gitignore
│   ├── utils
│   │   ├── __init__.py
│   │   ├── arg_utils.py
│   │   ├── callback.py
│   │   ├── check_gpu.py
│   │   ├── convert_compiled_pth.py
│   │   ├── data.py
│   │   ├── nn_func.py
│   │   ├── path_utils.py
│   │   ├── print_func.py
│   │   ├── resave_pth.py
│   │   └── utils.py
│   └── zoo
│       ├── cae
│       │   ├── bin
│       │   │   ├── extract.py
│       │   │   ├── test.py
│       │   │   ├── test_pca_cae.py
│       │   │   └── train.py
│       │   ├── data
│       │   │   └── .gitignore
│       │   ├── libs
│       │   │   └── trainer.py
│       │   ├── log
│       │   │   └── .gitignore
│       │   └── output
│       │       └── .gitignore
│       ├── cnnrnn
│       │   ├── bin
│       │   │   ├── test.py
│       │   │   ├── test_pca_cnnrnn.py
│       │   │   └── train.py
│       │   ├── libs
│       │   │   └── fullBPTT.py
│       │   ├── log
│       │   │   └── .gitignore
│       │   └── output
│       │       └── .gitignore
│       └── rnn
│           ├── bin
│           │   ├── test.py
│           │   ├── test_pca_rnn.py
│           │   └── train.py
│           ├── libs
│           │   ├── dataloader.py
│           │   └── fullBPTT.py
│           ├── log
│           │   └── .gitignore
│           └── output
│               └── .gitignore
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # weight
2 | *.pth
3 | *.onnx
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
114 | .pdm.toml
115 | .pdm-python
116 | .pdm-build/
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | #.idea/
167 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Ogata Laboratory (Waseda University)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
 3 |
 4 |
5 |
6 |
7 |
8 | ## What's EIPL?
9 |
10 | EIPL (Embodied Intelligence with Deep Predictive Learning) is a library for robot motion generation using deep predictive learning developed at [Ogata Laboratory](https://ogata-lab.jp/), [Waseda University](https://www.waseda.jp/top/en).
11 | Highlighted features include:
12 |
13 | - [**Full documentation**](https://ogata-lab.github.io/eipl-docs) for the systematic understanding of deep predictive learning
14 | - **Easy model training:** Includes sample datasets, source code, and pre-trained weights
15 | - **Applicable to real robots:** Generalized motion can be acquired even from small datasets
16 |
17 | ## Install
18 |
19 | ```sh
20 | pip install -r requirements.txt
21 | pip install -e .
22 | ```
23 |
24 | ## Citation
25 |
26 | ```bibtex
27 | @article{suzuki2023deep,
28 |     author  = {Kanata Suzuki and Hiroshi Ito and Tatsuro Yamada and Kei Kase and Tetsuya Ogata},
29 |     title   = {Deep Predictive Learning: Motion Learning Concept inspired by Cognitive Robotics},
30 |     journal = {arXiv preprint arXiv:2306.14714},
31 |     year    = {2023},
32 | }
33 | ```
34 |
35 | ## LICENSE
36 |
37 | MIT License
--------------------------------------------------------------------------------
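For orientation before the individual files: a minimal end-to-end sketch of the data and model API, assembled from the usage in `eipl/test/test_dataloader.py` and the `CNNRNN` definition in `eipl/model/CNNRNN.py` (the AIREC sample data provides 128x128 images and 8 joint values; everything else below follows those test scripts and is not part of the repository itself):

```python
import torch
from eipl.data import SampleDownloader, MultimodalDataset
from eipl.model import CNNRNN

# Download and normalize the AIREC "grasp_bottle" sample dataset.
grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW")
images, joints = grasp_data.load_norm_data("train", vmin=0.1, vmax=0.9)

# The dataset yields (noised input, ground truth) sequence pairs.
dataset = MultimodalDataset(images, joints)
(x_img, x_joint), (y_img, y_joint) = dataset[0]

# Run a single time step of CNNRNN; the LSTM state is threaded explicitly.
model = CNNRNN(rec_dim=50, joint_dim=8, feat_dim=10)
y_image, y_joint_hat, state = model(x_img[:1], x_joint[:1], state=None)
print(y_image.shape, y_joint_hat.shape)
```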
/eipl/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import *
2 | from .layer import *
3 | from .utils import *
4 |
--------------------------------------------------------------------------------
/eipl/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .downloader import SampleDownloader, WeightDownloader
2 | from .dataset import *
3 |
--------------------------------------------------------------------------------
/eipl/data/data_dict.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | data_dict = {
7 | "airec": {
8 | "grasp_bottle": {
9 | "https://drive.google.com/uc?id=11NhuH35DR2Pg7gWqOIhw5yJ_Nmeto-wO"
10 | }
11 | },
12 | "om": {
13 | "grasp_cube": {
14 | "https://drive.google.com/uc?id=16M0wEPleiRAucZeGRcUel5yO8ou7R-SA"
15 | }
16 | },
17 | }
18 |
--------------------------------------------------------------------------------
/eipl/layer/CoordConv2d.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.modules.conv as conv
9 |
10 |
11 | class AddCoords(nn.Module):
12 | """AddCoords
13 |
14 | Arguments:
15 | min_range (float, optional): Minimum value for xx_channel and yy_channel. The default is 0.0, while the original paper uses -1.
16 | with_r (Boolean, optional): Whether to add the radial channel. The default is False.
17 | """
18 |
19 | def __init__(self, min_range=0.0, with_r=False):
20 | super(AddCoords, self).__init__()
21 | self.min_range = min_range
22 | self.with_r = with_r
23 |
24 | def forward(self, x):
25 | batch_size, channels, width, height = x.shape
26 | device = x.device
27 |
28 | xx_channel, yy_channel = torch.meshgrid(
29 | torch.linspace(self.min_range, 1.0, height, dtype=torch.float32),
30 | torch.linspace(self.min_range, 1.0, width, dtype=torch.float32),
31 | indexing="ij",
32 | )
33 | xx_channel = xx_channel.expand(batch_size, 1, width, height).to(device)
34 | yy_channel = yy_channel.expand(batch_size, 1, width, height).to(device)
35 |
36 | y = torch.cat([x, xx_channel, yy_channel], dim=1)
37 |
38 | if self.with_r:
39 | rr = torch.sqrt(
40 | torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)
41 | )
42 | y = torch.cat([y, rr], dim=1)
43 |
44 | return y
45 |
46 |
47 | class CoordConv2d(conv.Conv2d):
48 | """CoordConv2d
49 | Rosanne Liu, Joel Lehman, Piero Molino, Felipe Petroski Such, Eric Frank, Alex Sergeev, Jason Yosinski,
50 | ``An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution.``,
51 | NeurIPS 2018.
52 | https://arxiv.org/abs/1807.03247v2
53 | """
54 |
55 | def __init__(
56 | self,
57 | input_size,
58 | output_size,
59 | kernel_size,
60 | stride=1,
61 | padding=0,
62 | dilation=1,
63 | groups=1,
64 | bias=True,
65 | min_range=0.0,
66 | with_r=False,
67 | ):
68 | super(CoordConv2d, self).__init__(
69 | input_size,
70 | output_size,
71 | kernel_size,
72 | stride,
73 | padding,
74 | dilation,
75 | groups,
76 | bias,
77 | )
78 | rank = 2
79 | self.addcoords = AddCoords(min_range=min_range, with_r=with_r)
80 | self.conv = nn.Conv2d(
81 | input_size + rank + int(with_r),
82 | output_size,
83 | kernel_size,
84 | stride,
85 | padding,
86 | dilation,
87 | groups,
88 | bias,
89 | )
90 |
91 | def forward(self, x):
92 | hid = self.addcoords(x)
93 | y = self.conv(hid)
94 |
95 | return y
96 |
--------------------------------------------------------------------------------
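A minimal usage sketch for the layer above (sizes are arbitrary; the coordinate grid construction in `AddCoords` assumes square feature maps as written, so a square input is used here):

```python
import torch
from eipl.layer import CoordConv2d

# CoordConv2d appends normalized xx/yy (and optionally radial) coordinate
# channels to the input before convolving, so it is used like nn.Conv2d.
layer = CoordConv2d(3, 16, kernel_size=3, stride=1, padding=1, with_r=True)
y = layer(torch.zeros(8, 3, 64, 64))
print(y.shape)  # torch.Size([8, 16, 64, 64])
```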
/eipl/layer/GridMask.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import numpy as np
8 |
9 |
10 | class GridMask:
11 | """GridMask
12 |
13 | Arguments:
14 | p (float, optional): Probability of applying the mask. The default is 0.6.
15 | d_range (tuple of int, optional): Range from which the grid spacing d is sampled. The default is (10, 30).
16 | r (float, optional): This parameter determines how much of the original image is retained.
17 | If r=0, the entire image is masked; if r=1, the image is not masked at all.
18 |
19 | Chen, Pengguang, et al. "Gridmask data augmentation."
20 | https://arxiv.org/abs/2001.04086
21 | """
22 |
23 | def __init__(self, p=0.6, d_range=(10, 30), r=0.6, channel_first=True):
24 | self.p = p
25 | self.d_range = d_range
26 | self.r = r
27 | self.channel_first = channel_first
28 |
29 | def __call__(self, img, debug=False):
30 | if not debug and np.random.uniform() > self.p:
31 | return img
32 |
33 | side = img.shape[-2]
34 | d = np.random.randint(*self.d_range, dtype=np.uint8)
35 | r = int(self.r * d)
36 |
37 | mask = np.ones((side + d, side + d), dtype=np.uint8)
38 | for i in range(0, side + d, d):
39 | for j in range(0, side + d, d):
40 | mask[i : i + (d - r), j : j + (d - r)] = 0
41 |
42 | delta_x, delta_y = np.random.randint(0, d, size=2)
43 | mask = mask[delta_x : delta_x + side, delta_y : delta_y + side]
44 |
45 | if self.channel_first:
46 | img *= np.expand_dims(mask, 0)
47 | else:
48 | img *= np.expand_dims(mask, -1)
49 |
50 | return img
51 |
--------------------------------------------------------------------------------
/eipl/layer/MultipleTimescaleRNN.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from eipl.utils import get_activation_fn
9 |
10 |
11 | class MTRNNCell(nn.Module):
12 | #:: MTRNNCell
13 | """Multiple Timescale RNN.
14 |
15 | Implements a form of Recurrent Neural Network (RNN) that operates with multiple timescales.
16 | This is based on the idea of hierarchical organization in human cognitive functions.
17 |
18 | Arguments:
19 | input_dim (int): Number of input features.
20 | fast_dim (int): Number of fast context neurons.
21 | slow_dim (int): Number of slow context neurons.
22 | fast_tau (float): Time constant value of fast context.
23 | slow_tau (float): Time constant value of slow context.
24 | activation (string, optional): If you set `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
25 | use_bias (Boolean, optional): whether the layer uses a bias vector. The default is False.
27 | use_pb (Boolean, optional): whether the recurrent layer uses a parametric bias (pb) vector. The default is False.
27 |
28 | Yuichi Yamashita, Jun Tani,
29 | "Emergence of functional hierarchy in a multiple timescale neural network model: a humanoid robot experiment." PLoS computational biology, 2008.
30 | https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000220
31 | """
32 |
33 | def __init__(
34 | self,
35 | input_dim,
36 | fast_dim,
37 | slow_dim,
38 | fast_tau,
39 | slow_tau,
40 | activation="tanh",
41 | use_bias=False,
42 | use_pb=False,
43 | ):
44 | super(MTRNNCell, self).__init__()
45 |
46 | self.input_dim = input_dim
47 | self.fast_dim = fast_dim
48 | self.slow_dim = slow_dim
49 | self.fast_tau = fast_tau
50 | self.slow_tau = slow_tau
51 | self.use_bias = use_bias
52 | self.use_pb = use_pb
53 |
54 | # Legacy string support for activation function.
55 | if isinstance(activation, str):
56 | self.activation = get_activation_fn(activation)
57 | else:
58 | self.activation = activation
59 |
60 | # Input Layers
61 | self.i2f = nn.Linear(input_dim, fast_dim, bias=use_bias)
62 |
63 | # Fast context layer
64 | self.f2f = nn.Linear(fast_dim, fast_dim, bias=False)
65 | self.f2s = nn.Linear(fast_dim, slow_dim, bias=use_bias)
66 |
67 | # Slow context layer
68 | self.s2s = nn.Linear(slow_dim, slow_dim, bias=False)
69 | self.s2f = nn.Linear(slow_dim, fast_dim, bias=use_bias)
70 |
71 | def forward(self, x, state=None, pb=None):
72 | """Forward propagation of the MTRNN.
73 |
74 | Arguments:
75 | x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
76 | state (list): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
77 | If None, initialize states to zeros.
78 | pb (torch.Tensor, optional): Parametric bias vector, added to the slow-context input when provided (see use_pb).
79 |
80 | Returns:
81 | new_h_fast (torch.Tensor): Updated fast context state.
82 | new_h_slow (torch.Tensor): Updated slow context state.
83 | new_u_fast (torch.Tensor): Updated fast internal state.
84 | new_u_slow (torch.Tensor): Updated slow internal state.
85 | """
86 | batch_size = x.shape[0]
87 | if state is not None:
88 | prev_h_fast, prev_h_slow, prev_u_fast, prev_u_slow = state
89 | else:
90 | device = x.device
91 | prev_h_fast = torch.zeros(batch_size, self.fast_dim).to(device)
92 | prev_h_slow = torch.zeros(batch_size, self.slow_dim).to(device)
93 | prev_u_fast = torch.zeros(batch_size, self.fast_dim).to(device)
94 | prev_u_slow = torch.zeros(batch_size, self.slow_dim).to(device)
95 |
96 | new_u_fast = (1.0 - 1.0 / self.fast_tau) * prev_u_fast + 1.0 / self.fast_tau * (
97 | self.i2f(x) + self.f2f(prev_h_fast) + self.s2f(prev_h_slow)
98 | )
99 |
100 | _input_slow = self.f2s(prev_h_fast) + self.s2s(prev_h_slow)
101 | if pb is not None:
102 | _input_slow += pb
103 |
104 | new_u_slow = (
105 | 1.0 - 1.0 / self.slow_tau
106 | ) * prev_u_slow + 1.0 / self.slow_tau * _input_slow
107 |
108 | new_h_fast = self.activation(new_u_fast)
109 | new_h_slow = self.activation(new_u_slow)
110 |
111 | return new_h_fast, new_h_slow, new_u_fast, new_u_slow
112 |
--------------------------------------------------------------------------------
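Since `MTRNNCell` processes one time step per call, a rollout threads the four-tuple state through time explicitly. A minimal sketch (all dimensions and time constants below are arbitrary):

```python
import torch
from eipl.layer import MTRNNCell

cell = MTRNNCell(input_dim=12, fast_dim=50, slow_dim=5, fast_tau=2.0, slow_tau=12.0)

x_seq = torch.randn(100, 8, 12)  # (seq_len, batch_size, input_dim)
state = None                     # zero-initialized inside the cell on the first step
for x_t in x_seq:
    h_fast, h_slow, u_fast, u_slow = cell(x_t, state)
    state = (h_fast, h_slow, u_fast, u_slow)

# h_fast is what downstream layers typically read, cf. eipl.model.BasicMTRNN.
print(h_fast.shape, h_slow.shape)  # torch.Size([8, 50]) torch.Size([8, 5])
```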
/eipl/layer/SpatialSoftmax.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | from eipl.utils import tensor2numpy, plt_img, get_feature_map
11 |
12 |
13 | def create_position_encoding(width: int, height: int, normalized=True, data_format="channels_first"):
14 | if normalized:
15 | pos_x, pos_y = np.meshgrid(np.linspace(0.0, 1.0, height), np.linspace(0.0, 1.0, width), indexing="xy")
16 | else:
17 | pos_x, pos_y = np.meshgrid(
18 | np.linspace(0, height - 1, height),
19 | np.linspace(0, width - 1, width),
20 | indexing="xy",
21 | )
22 |
23 | if data_format == "channels_first":
24 | pos_xy = torch.from_numpy(np.stack([pos_x, pos_y], axis=0)).float() # (2,W,H)
25 | else:
26 | pos_xy = torch.from_numpy(np.stack([pos_x, pos_y], axis=2)).float() # (W,H,2)
27 |
28 | pos_x = torch.from_numpy(pos_x.reshape(height * width)).float()
29 | pos_y = torch.from_numpy(pos_y.reshape(height * width)).float()
30 |
31 | return pos_xy, pos_x, pos_y
32 |
33 |
34 | class SpatialSoftmax(nn.Module):
35 | """Spatial Softmax
36 | Extract XY position from feature map of CNN
37 |
38 | Chelsea Finn, Xin Yu Tan, Yan Duan, Trevor Darrell, Sergey Levine, Pieter Abbeel
39 | ``Deep spatial autoencoders for visuomotor learning.``
40 | 2016 IEEE International Conference on Robotics and Automation (ICRA). IEEE, 2016.
41 | https://ieeexplore.ieee.org/abstract/document/7487173
42 | """
43 |
44 | def __init__(self, width: int, height: int, temperature=1e-4, normalized=True):
45 | super(SpatialSoftmax, self).__init__()
46 | self.width = width
47 | self.height = height
48 | self.temperature = temperature
49 |
50 | _, pos_x, pos_y = create_position_encoding(width, height, normalized=normalized)
51 | self.register_buffer("pos_x", pos_x)
52 | self.register_buffer("pos_y", pos_y)
53 |
54 | def forward(self, x):
55 | batch_size, channels, width, height = x.shape
56 | assert height == self.height
57 | assert width == self.width
58 |
59 | # flatten, apply softmax
60 | logit = x.reshape(batch_size, channels, -1)
61 | att_map = torch.softmax(logit / self.temperature, dim=-1)
62 |
63 | # compute expectation
64 | expected_x = torch.sum(self.pos_x * att_map, dim=-1, keepdim=True)
65 | expected_y = torch.sum(self.pos_y * att_map, dim=-1, keepdim=True)
66 | keys = torch.cat([expected_x, expected_y], -1)
67 |
68 | # keys [[x,y], [x,y], [x,y],...]
69 | keys = keys.reshape(batch_size, channels, 2)
70 | att_map = att_map.reshape(-1, channels, width, height)
71 | return keys, att_map
72 |
73 |
74 | class InverseSpatialSoftmax(nn.Module):
75 | """InverseSpatialSoftmax
76 | Generate heatmap from XY position
77 |
78 | Hideyuki Ichiwara, Hiroshi Ito, Kenjiro Yamamoto, Hiroki Mori, Tetsuya Ogata
79 | ``Spatial Attention Point Network for Deep-learning-based Robust Autonomous Robot Motion Generation.``
80 | https://arxiv.org/abs/2103.01598
81 | """
82 |
83 | def __init__(self, width: int, height: int, heatmap_size=0.1, normalized=True, convex=True):
84 | super(InverseSpatialSoftmax, self).__init__()
85 |
86 | self.width = width
87 | self.height = height
88 | self.normalized = normalized
89 | self.heatmap_size = heatmap_size
90 | self.convex = convex
91 |
92 | pos_xy, _, _ = create_position_encoding(width, height, normalized=normalized)
93 | self.register_buffer("pos_xy", pos_xy)
94 |
95 | def forward(self, keys):
96 | squared_distances = torch.sum(torch.pow(self.pos_xy[None, None] - keys[:, :, :, None, None], 2.0), axis=2)
97 | heatmap = torch.exp(-squared_distances / self.heatmap_size)
98 |
99 | if not self.convex:
100 | heatmap = torch.abs(1.0 - heatmap)
101 |
102 | return heatmap
103 |
--------------------------------------------------------------------------------
/eipl/layer/__init__.py:
--------------------------------------------------------------------------------
1 | from .CoordConv2d import *
2 | from .GridMask import *
3 | from .MultipleTimescaleRNN import *
4 | from .SpatialSoftmax import *
5 |
--------------------------------------------------------------------------------
/eipl/model/BasicRNN.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from eipl.utils import get_activation_fn
9 | from eipl.layer import MTRNNCell
10 |
11 |
12 | class BasicLSTM(nn.Module):
13 | #:: BasicLSTM
14 | """BasicLSTM
15 |
16 | Arguments:
17 | in_dim (int): Number of input features
18 | rec_dim (int): Number of recurrent context neurons
19 | out_dim (int): Number of output features
20 | activation (string, optional): If you set `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
21 | The default is hyperbolic tangent (`tanh`).
22 | """
23 |
24 | def __init__(self, in_dim, rec_dim, out_dim, activation="tanh"):
25 | super(BasicLSTM, self).__init__()
26 |
27 | # Legacy string support for activation function.
28 | if isinstance(activation, str):
29 | activation = get_activation_fn(activation)
30 |
31 | self.rnn = nn.LSTMCell(in_dim, rec_dim)
32 | self.rnn_out = nn.Sequential(nn.Linear(rec_dim, out_dim), activation)
33 |
34 | def forward(self, x, state=None):
35 | rnn_hid = self.rnn(x, state)
36 | y_hat = self.rnn_out(rnn_hid[0])
37 |
38 | return y_hat, rnn_hid
39 |
40 |
41 | class BasicMTRNN(nn.Module):
42 | #:: BasicMTRNN
43 | """BasicMTRNN
44 |
45 | Arguments:
46 | in_dim (int): Number of input features
47 | fast_dim (int): Number of fast context neurons; slow_dim (int): Number of slow context neurons
48 | fast_tau (float): Time constant of the fast context; slow_tau (float): Time constant of the slow context
49 | out_dim (int, optional): Number of output features. If None, it defaults to in_dim.
50 | activation (string, optional): If you set `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). The default is hyperbolic tangent (`tanh`).
51 | """
52 |
53 | def __init__(
54 | self,
55 | in_dim,
56 | fast_dim,
57 | slow_dim,
58 | fast_tau,
59 | slow_tau,
60 | out_dim=None,
61 | activation="tanh",
62 | ):
63 | super(BasicMTRNN, self).__init__()
64 |
65 | if out_dim is None:
66 | out_dim = in_dim
67 |
68 | # Legacy string support for activation function.
69 | if isinstance(activation, str):
70 | activation = get_activation_fn(activation)
71 |
72 | self.mtrnn = MTRNNCell(
73 | in_dim, fast_dim, slow_dim, fast_tau, slow_tau, activation=activation
74 | )
75 | # Output of RNN
76 | self.rnn_out = nn.Sequential(nn.Linear(fast_dim, out_dim), activation)
77 |
78 | def forward(self, x, state=None):
79 | rnn_hid = self.mtrnn(x, state)
80 | y_hat = self.rnn_out(rnn_hid[0])
81 |
82 | return y_hat, rnn_hid
83 |
--------------------------------------------------------------------------------
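Both wrappers share the same step-wise convention: one time step in, one prediction out, with the recurrent state owned by the caller. A minimal sketch for `BasicLSTM` (dimensions arbitrary):

```python
import torch
from eipl.model import BasicLSTM

model = BasicLSTM(in_dim=12, rec_dim=50, out_dim=10, activation="tanh")

state = None
for x_t in torch.randn(30, 8, 12):    # (seq_len, batch_size, in_dim)
    y_hat, state = model(x_t, state)  # state is the LSTMCell (h, c) tuple
print(y_hat.shape)  # torch.Size([8, 10])
```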
/eipl/model/CAE.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from eipl.layer import CoordConv2d, AddCoords
9 |
10 |
11 | class BasicCAE(nn.Module):
12 | #:: BasicCAE
13 | """BasicCAE"""
14 |
15 | def __init__(self, feat_dim=10):
16 | super(BasicCAE, self).__init__()
17 |
18 | # encoder
19 | self.encoder = nn.Sequential(
20 | nn.Conv2d(3, 64, 3, 2, 1),
21 | nn.Tanh(),
22 | nn.Conv2d(64, 32, 3, 2, 1),
23 | nn.Tanh(),
24 | nn.Conv2d(32, 16, 3, 2, 1),
25 | nn.Tanh(),
26 | nn.Conv2d(16, 12, 3, 2, 1),
27 | nn.Tanh(),
28 | nn.Conv2d(12, 8, 3, 2, 1),
29 | nn.Tanh(),
30 | nn.Flatten(),
31 | nn.Linear(8 * 4 * 4, 50),
32 | nn.Tanh(),
33 | nn.Linear(50, feat_dim),
34 | nn.Tanh(),
35 | )
36 |
37 | # decoder
38 | self.decoder = nn.Sequential(
39 | nn.Linear(feat_dim, 50),
40 | nn.Tanh(),
41 | nn.Linear(50, 8 * 4 * 4),
42 | nn.Tanh(),
43 | nn.Unflatten(1, (8, 4, 4)),
44 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1),
45 | nn.Tanh(),
46 | nn.ConvTranspose2d(12, 16, 3, 2, padding=1, output_padding=1),
47 | nn.Tanh(),
48 | nn.ConvTranspose2d(16, 32, 3, 2, padding=1, output_padding=1),
49 | nn.Tanh(),
50 | nn.ConvTranspose2d(32, 64, 3, 2, padding=1, output_padding=1),
51 | nn.Tanh(),
52 | nn.ConvTranspose2d(64, 3, 3, 2, padding=1, output_padding=1),
53 | nn.Tanh(),
54 | )
55 |
56 | def forward(self, x):
57 | return self.decoder(self.encoder(x))
58 |
59 |
60 | class CAE(nn.Module):
61 | #:: CAE
62 | """CAE"""
63 |
64 | def __init__(self, feat_dim=10):
65 | super(CAE, self).__init__()
66 |
67 | # encoder
68 | self.encoder = nn.Sequential(
69 | nn.Conv2d(3, 32, 6, 2, 1),
70 | nn.ReLU(True),
71 | nn.Conv2d(32, 64, 6, 2, 1),
72 | nn.ReLU(True),
73 | nn.Conv2d(64, 128, 6, 2, 1),
74 | nn.ReLU(True),
75 | nn.Flatten(),
76 | nn.Linear(128 * 14 * 14, 1000),
77 | nn.ReLU(True),
78 | nn.Linear(1000, feat_dim),
79 | nn.ReLU(True),
80 | )
81 |
82 | # decoder
83 | self.decoder = nn.Sequential(
84 | nn.Linear(feat_dim, 1000),
85 | nn.ReLU(True),
86 | nn.Linear(1000, 128 * 14 * 14),
87 | nn.ReLU(True),
88 | nn.Unflatten(1, (128, 14, 14)),
89 | nn.ConvTranspose2d(128, 64, 6, 2, padding=1),
90 | nn.ReLU(True),
91 | nn.ConvTranspose2d(64, 32, 6, 2, padding=1),
92 | nn.ReLU(True),
93 | nn.ConvTranspose2d(32, 3, 6, 2, padding=0),
94 | nn.ReLU(True),
95 | )
96 |
97 | def forward(self, x):
98 | return self.decoder(self.encoder(x))
99 |
--------------------------------------------------------------------------------
/eipl/model/CAEBN.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from eipl.layer import CoordConv2d, AddCoords
9 |
10 |
11 | class BasicCAEBN(nn.Module):
12 | #:: BasicCAEBN
13 | """BasicCAEBN"""
14 |
15 | def __init__(self, feat_dim=10):
16 | super(BasicCAEBN, self).__init__()
17 |
18 | # Encoder
19 | self.encoder = nn.Sequential(
20 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
21 | nn.BatchNorm2d(64),
22 | nn.ReLU(True),
23 | nn.Conv2d(64, 32, 3, 2, 1),
24 | nn.BatchNorm2d(32),
25 | nn.ReLU(True),
26 | nn.Conv2d(32, 16, 3, 2, 1),
27 | nn.BatchNorm2d(16),
28 | nn.ReLU(True),
29 | nn.Conv2d(16, 12, 3, 2, 1),
30 | nn.BatchNorm2d(12),
31 | nn.ReLU(True),
32 | nn.Conv2d(12, 8, 3, 2, 1),
33 | nn.BatchNorm2d(8),
34 | nn.ReLU(True),
35 | nn.Flatten(),
36 | nn.Linear(8 * 4 * 4, 50),
37 | nn.BatchNorm1d(50),
38 | nn.ReLU(True),
39 | nn.Linear(50, feat_dim),
40 | nn.BatchNorm1d(feat_dim),
41 | nn.ReLU(True),
42 | )
43 |
44 | # Decoder
45 | self.decoder = nn.Sequential(
46 | nn.Linear(feat_dim, 50),
47 | nn.BatchNorm1d(50),
48 | nn.ReLU(True),
49 | nn.Linear(50, 8 * 4 * 4),
50 | nn.BatchNorm1d(8 * 4 * 4),
51 | nn.ReLU(True),
52 | nn.Unflatten(1, (8, 4, 4)),
53 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1),
54 | nn.BatchNorm2d(12),
55 | nn.ReLU(True),
56 | nn.ConvTranspose2d(12, 16, 3, 2, 1, 1),
57 | nn.BatchNorm2d(16),
58 | nn.ReLU(True),
59 | nn.ConvTranspose2d(16, 32, 3, 2, 1, 1),
60 | nn.BatchNorm2d(32),
61 | nn.ReLU(True),
62 | nn.ConvTranspose2d(32, 64, 3, 2, 1, 1),
63 | nn.BatchNorm2d(64),
64 | nn.ReLU(True),
65 | nn.ConvTranspose2d(64, 3, 3, 2, 1, 1),
66 | nn.ReLU(True),
67 | )
68 |
69 | def forward(self, x):
70 | return self.decoder(self.encoder(x))
71 |
72 |
73 | class CAEBN(nn.Module):
74 | #:: CAEBN
75 | """CAEBN"""
76 |
77 | def __init__(self, feat_dim=10):
78 | super(CAEBN, self).__init__()
79 |
80 | # Encoder
81 | self.encoder = nn.Sequential(
82 | nn.Conv2d(3, 32, 6, 2, 1),
83 | nn.BatchNorm2d(32),
84 | nn.ReLU(True),
85 | nn.Conv2d(32, 64, 6, 2, 1),
86 | nn.BatchNorm2d(64),
87 | nn.ReLU(True),
88 | nn.Conv2d(64, 128, 6, 2, 1),
89 | nn.BatchNorm2d(128),
90 | nn.ReLU(True),
91 | nn.Flatten(),
92 | nn.Linear(128 * 14 * 14, 1000),
93 | nn.BatchNorm1d(1000),
94 | nn.ReLU(True),
95 | nn.Linear(1000, feat_dim),
96 | nn.BatchNorm1d(feat_dim),
97 | nn.ReLU(True),
98 | )
99 |
100 | # Decoder
101 | self.decoder = nn.Sequential(
102 | nn.Linear(feat_dim, 1000),
103 | nn.BatchNorm1d(1000),
104 | nn.ReLU(True),
105 | nn.Linear(1000, 128 * 14 * 14),
106 | nn.BatchNorm1d(128 * 14 * 14),
107 | nn.ReLU(True),
108 | nn.Unflatten(1, (128, 14, 14)),
109 | nn.ConvTranspose2d(128, 64, 6, 2, padding=1),
110 | nn.BatchNorm2d(64),
111 | nn.ReLU(True),
112 | nn.ConvTranspose2d(64, 32, 6, 2, padding=1),
113 | nn.BatchNorm2d(32),
114 | nn.ReLU(True),
115 | nn.ConvTranspose2d(32, 3, 6, 2, padding=0),
116 | nn.ReLU(True),
117 | )
118 |
119 | def forward(self, x):
120 | return self.decoder(self.encoder(x))
121 |
--------------------------------------------------------------------------------
/eipl/model/CNNRNN.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class CNNRNN(nn.Module):
11 | #:: CNNRNN
12 | """CNNRNN"""
13 |
14 | def __init__(self, rec_dim=50, joint_dim=8, feat_dim=10):
15 | super(CNNRNN, self).__init__()
16 |
17 | # Encoder
18 | self.encoder_image = nn.Sequential(
19 | nn.Conv2d(3, 64, 3, 2, 1),
20 | nn.Tanh(),
21 | nn.Conv2d(64, 32, 3, 2, 1),
22 | nn.Tanh(),
23 | nn.Conv2d(32, 16, 3, 2, 1),
24 | nn.Tanh(),
25 | nn.Conv2d(16, 12, 3, 2, 1),
26 | nn.Tanh(),
27 | nn.Conv2d(12, 8, 3, 2, 1),
28 | nn.Tanh(),
29 | nn.Flatten(),
30 | nn.Linear(8 * 4 * 4, 50),
31 | nn.Tanh(),
32 | nn.Linear(50, feat_dim),
33 | nn.Tanh(),
34 | )
35 |
36 | # Recurrent
37 | rec_in = feat_dim + joint_dim
38 | self.rec = nn.LSTMCell(rec_in, rec_dim)
39 |
40 | # Decoder for joint angle
41 | self.decoder_joint = nn.Sequential(nn.Linear(rec_dim, joint_dim), nn.Tanh())
42 |
43 | # Decoder for image
44 | self.decoder_image = nn.Sequential(
45 | nn.Linear(rec_dim, 8 * 4 * 4),
46 | nn.Tanh(),
47 | nn.Unflatten(1, (8, 4, 4)),
48 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1),
49 | nn.Tanh(),
50 | nn.ConvTranspose2d(12, 16, 3, 2, padding=1, output_padding=1),
51 | nn.Tanh(),
52 | nn.ConvTranspose2d(16, 32, 3, 2, padding=1, output_padding=1),
53 | nn.Tanh(),
54 | nn.ConvTranspose2d(32, 64, 3, 2, padding=1, output_padding=1),
55 | nn.Tanh(),
56 | nn.ConvTranspose2d(64, 3, 3, 2, padding=1, output_padding=1),
57 | nn.Tanh(),
58 | )
59 |
60 | # image, joint
61 | def forward(self, xi, xv, state=None):
62 | # Encoder
63 | im_feat = self.encoder_image(xi)
64 | hid = torch.concat([im_feat, xv], -1)
65 |
66 | # Recurrent
67 | rnn_hid = self.rec(hid, state)
68 |
69 | # Decoder
70 | y_joint = self.decoder_joint(rnn_hid[0])
71 | y_image = self.decoder_image(rnn_hid[0])
72 |
73 | return y_image, y_joint, rnn_hid
74 |
--------------------------------------------------------------------------------
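At inference time, models like this are typically rolled out closed-loop, feeding the predictions back in as the next input. A minimal sketch (batch size and horizon arbitrary; the 128x128 resolution is what the five stride-2 convolutions assume):

```python
import torch
from eipl.model import CNNRNN

model = CNNRNN(rec_dim=50, joint_dim=8, feat_dim=10).eval()

xi = torch.randn(1, 3, 128, 128)  # current camera image
xv = torch.randn(1, 8)            # current joint angles
state = None

with torch.no_grad():
    for _ in range(10):
        y_image, y_joint, state = model(xi, xv, state)
        # Closed-loop generation: reuse the model's own predictions
        # instead of new observations.
        xi, xv = y_image, y_joint
```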
/eipl/model/CNNRNNLN.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class CNNRNNLN(nn.Module):
11 | #:: CNNRNNLN
12 | """CNNRNNLN"""
13 |
14 | def __init__(self, rec_dim=50, joint_dim=8, feat_dim=10):
15 | super(CNNRNNLN, self).__init__()
16 |
17 | # Encoder
18 | self.encoder_image = nn.Sequential(
19 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
20 | nn.LayerNorm([64, 64, 64]),
21 | nn.ReLU(True),
22 | nn.Conv2d(64, 32, 3, 2, 1),
23 | nn.LayerNorm([32, 32, 32]),
24 | nn.ReLU(True),
25 | nn.Conv2d(32, 16, 3, 2, 1),
26 | nn.LayerNorm([16, 16, 16]),
27 | nn.ReLU(True),
28 | nn.Conv2d(16, 12, 3, 2, 1),
29 | nn.LayerNorm([12, 8, 8]),
30 | nn.ReLU(True),
31 | nn.Conv2d(12, 8, 3, 2, 1),
32 | nn.LayerNorm([8, 4, 4]),
33 | nn.ReLU(True),
34 | nn.Flatten(),
35 | nn.Linear(8 * 4 * 4, 50),
36 | nn.LayerNorm([50]),
37 | nn.ReLU(True),
38 | nn.Linear(50, feat_dim),
39 | nn.LayerNorm([feat_dim]),
40 | nn.ReLU(True),
41 | )
42 |
43 | # Recurrent
44 | rec_in = feat_dim + joint_dim
45 | self.rec = nn.LSTMCell(rec_in, rec_dim)
46 |
47 | # Decoder for joint angle
48 | self.decoder_joint = nn.Sequential(nn.Linear(rec_dim, joint_dim), nn.ReLU(True))
49 |
50 | # Decoder for image
51 | self.decoder_image = nn.Sequential(
52 | nn.Linear(rec_dim, 8 * 4 * 4),
53 | nn.LayerNorm([8 * 4 * 4]),
54 | nn.ReLU(True),
55 | nn.Unflatten(1, (8, 4, 4)),
56 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1),
57 | nn.LayerNorm([12, 8, 8]),
58 | nn.ReLU(True),
59 | nn.ConvTranspose2d(12, 16, 3, 2, 1, 1),
60 | nn.LayerNorm([16, 16, 16]),
61 | nn.ReLU(True),
62 | nn.ConvTranspose2d(16, 32, 3, 2, 1, 1),
63 | nn.LayerNorm([32, 32, 32]),
64 | nn.ReLU(True),
65 | nn.ConvTranspose2d(32, 64, 3, 2, 1, 1),
66 | nn.LayerNorm([64, 64, 64]),
67 | nn.ReLU(True),
68 | nn.ConvTranspose2d(64, 3, 3, 2, 1, 1),
69 | nn.ReLU(True),
70 | )
71 |
72 | # image, joint
73 | def forward(self, xi, xv, state=None):
74 | # Encoder
75 | im_feat = self.encoder_image(xi)
76 | hid = torch.concat([im_feat, xv], -1)
77 |
78 | # Recurrent
79 | rnn_hid = self.rec(hid, state)
80 |
81 | # Decoder
82 | y_joint = self.decoder_joint(rnn_hid[0])
83 | y_image = self.decoder_image(rnn_hid[0])
84 |
85 | return y_image, y_joint, rnn_hid
86 |
--------------------------------------------------------------------------------
/eipl/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .BasicRNN import *
2 | from .CAE import *
3 | from .CAEBN import *
4 | from .CNNRNN import *
5 | from .CNNRNNLN import *
6 | from .SARNN import *
--------------------------------------------------------------------------------
/eipl/test/benchmark_dataloader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import time
7 | import torch
8 | import numpy as np
9 | from eipl.data import MultimodalDataset, MultiEpochsDataLoader
10 |
11 |
12 | def test_dataloader(dataloader):
13 | torch.cuda.synchronize()
14 | start = time.time()
15 | for _ in range(10):
16 | for _ in dataloader:
17 | pass
18 | torch.cuda.synchronize()
19 | elapsed_time = time.time() - start
20 | print(elapsed_time, "sec.")
21 |
22 |
23 | # parameter
24 | stdev = 0.1
25 | device = "cuda:0"
26 | batch_size = 8
27 | images = np.random.random((30, 200, 3, 64, 64))
28 | joints = np.random.random((30, 200, 7))
29 |
30 | # Original dataloader with CUDA transformation
31 | # Note that CUDA transformation does not support pin_memory or num_workers.
32 | dataset = MultimodalDataset(images, joints, device="cuda:0", stdev=stdev)
33 | cuda_loader = torch.utils.data.DataLoader(
34 | dataset, batch_size=batch_size, shuffle=True, drop_last=False
35 | )
36 | test_dataloader(cuda_loader)
37 | del cuda_loader
38 |
39 | # Original dataloader with CPU loading (pin_memory and multiple workers)
40 | dataset = MultimodalDataset(images, joints, device="cpu", stdev=stdev)
41 | cpu_loader = torch.utils.data.DataLoader(
42 | dataset,
43 | batch_size=batch_size,
44 | shuffle=True,
45 | drop_last=False,
46 | pin_memory=True,
47 | prefetch_factor=4,
48 | num_workers=8,
49 | )
50 | test_dataloader(cpu_loader)
51 | del cpu_loader
52 |
53 | # Multiprocess dataloader
54 | dataset = MultimodalDataset(images, joints, device="cpu", stdev=stdev)
55 | multiepoch_loader = MultiEpochsDataLoader(
56 | dataset,
57 | batch_size=batch_size,
58 | shuffle=True,
59 | drop_last=False,
60 | pin_memory=True,
61 | prefetch_factor=8,
62 | num_workers=20,
63 | )
64 | test_dataloader(multiepoch_loader)
65 | del multiepoch_loader
66 |
--------------------------------------------------------------------------------
/eipl/test/test_CoordConv2d.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import matplotlib.pylab as plt
8 | from eipl.layer import CoordConv2d, AddCoords
9 |
10 |
11 | add_coord = AddCoords(with_r=True)
12 | x_im = torch.zeros(7, 12, 64, 64)
13 | y_im = add_coord(x_im)
14 | print("x_shape:", x_im.shape)
15 | print("y_shape:", y_im.shape)
16 |
17 | plt.subplot(1, 3, 1)
18 | plt.imshow(y_im[0, -3])
19 | plt.subplot(1, 3, 2)
20 | plt.imshow(y_im[0, -2])
21 | plt.subplot(1, 3, 3)
22 | plt.imshow(y_im[0, -1])
23 | plt.show()
24 |
--------------------------------------------------------------------------------
/eipl/test/test_GridMask.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import numpy as np
7 | import matplotlib.pylab as plt
8 | from eipl.layer import GridMask
9 |
10 |
11 | # single image with channel last
12 | img = np.ones((64, 64, 3))
13 | mask = GridMask(channel_first=False)
14 | out = mask(img, debug=True)
15 | plt.figure(dpi=60)
16 | plt.imshow(out)
17 |
18 | # single image with channel first
19 | img = np.ones((3, 64, 64))
20 | mask = GridMask(channel_first=True)
21 | out = mask(img, debug=True)
22 | out = out.transpose(1, 2, 0)
23 | plt.figure(dpi=60)
24 | plt.imshow(out)
25 |
26 | # multi images with channel last
27 | img = np.ones((4, 64, 64, 3))
28 | mask = GridMask(channel_first=False)
29 | out = mask(img, debug=True)
30 |
31 | plt.figure(dpi=60)
32 | for i in range(4):
33 | plt.subplot(2, 2, i + 1)
34 | plt.imshow(out[i])
35 |
36 | # multi images with channel first
37 | img = np.ones((4, 3, 64, 64))
38 | mask = GridMask(channel_first=True)
39 | out = mask(img, debug=True)
40 | out = out.transpose(0, 2, 3, 1)
41 |
42 | plt.figure(dpi=60)
43 | for i in range(4):
44 | plt.subplot(2, 2, i + 1)
45 | plt.imshow(out[i])
46 | plt.show()
47 |
--------------------------------------------------------------------------------
/eipl/test/test_LossScheduler.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import numpy as np
7 | import matplotlib.pylab as plt
8 | from eipl.utils import LossScheduler
9 |
10 |
11 | ax = []
12 | ax.append(plt.subplot2grid(shape=(2, 6), loc=(0, 0), colspan=2))
13 | ax.append(plt.subplot2grid((2, 6), (0, 2), colspan=2))
14 | ax.append(plt.subplot2grid((2, 6), (0, 4), colspan=2))
15 | ax.append(plt.subplot2grid((2, 6), (1, 1), colspan=2))
16 | ax.append(plt.subplot2grid((2, 6), (1, 3), colspan=2))
17 |
18 | for i, curve_name in enumerate(
19 | ["linear", "s", "inverse_s", "deceleration", "acceleration"]
20 | ):
21 | scheduler = LossScheduler(decay_end=100, curve_name=curve_name)
22 | loss_weight_list = []
23 | for _ in range(150):
24 | loss_weight_list.append(scheduler(loss_weight=0.1))
25 |
26 | ax[i].plot(loss_weight_list)
27 | ax[i].set_title(curve_name)
28 | ax[i].grid()
29 |
30 | plt.tight_layout()
31 | # plt.savefig("./output/loss_scheduler.png", dpi=60)
32 | plt.show()
33 |
--------------------------------------------------------------------------------
/eipl/test/test_SampleDownloader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import numpy as np
7 | import matplotlib.pylab as plt
8 | import matplotlib.animation as anim
9 | from eipl.utils import print_info
10 | from eipl.data import SampleDownloader
11 |
12 | # get airec dataset
13 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC")
14 | images, joints = grasp_data.load_raw_data("train")
15 | print_info("Load raw data")
16 | print(
17 | "images: shape={}, min={:.2f}, max={:.2f}".format(
18 | images.shape, images.min(), images.max()
19 | )
20 | )
21 | print(
22 | "joints: shape={}, min={:.2f}, max={:.2f}".format(
23 | joints.shape, joints.min(), joints.max()
24 | )
25 | )
26 |
27 | # plot animation
28 | idx = 0
29 | T = images.shape[1]
30 | fig, ax = plt.subplots(1, 2, figsize=(9, 5), dpi=60)
31 |
32 |
33 | def anim_update(i):
34 | for j in range(2):
35 | ax[j].cla()
36 |
37 | # plot camera image
38 | ax[0].imshow(images[idx, i, :, :, ::-1])
39 | ax[0].axis("off")
40 | ax[0].set_title("Input image")
41 |
42 | # plot joint angle
43 | ax[1].set_ylim(-1.0, 2.0)
44 | ax[1].set_xlim(0, T)
45 | ax[1].plot(joints[idx, 1:], linestyle="dashed", c="k")
46 | for joint_idx in range(8):
47 | ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx])
48 | ax[1].set_xlabel("Step")
49 | ax[1].set_title("Joint angles")
50 |
51 |
52 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
53 | ani.save("./output/viz_downloader.gif")
54 |
--------------------------------------------------------------------------------
/eipl/test/test_SpatialSoftmax.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import matplotlib.pylab as plt
8 | from eipl.layer import SpatialSoftmax, InverseSpatialSoftmax
9 | from eipl.utils import tensor2numpy, plt_img, get_feature_map
10 |
11 |
12 | channels = 4
13 | im_size = 64
14 | temperature = 1e-4
15 | heatmap_size = 0.05
16 |
17 | # Generate feature_map
18 | in_features = get_feature_map(im_size=im_size, channels=channels)
19 | print("in_features shape: ", in_features.shape)
20 |
21 | # Apply spatial softmax to get keypoints and attention map
22 | ssm = SpatialSoftmax(
23 | width=im_size, height=im_size, temperature=temperature, normalized=True
24 | )
25 | keypoints, att_map = ssm(torch.tensor(in_features))
26 | print("keypoints shape: ", keypoints.shape)
27 |
28 | # Generate heatmap from keypoints
29 | issm = InverseSpatialSoftmax(
30 | width=im_size, height=im_size, heatmap_size=heatmap_size, normalized=True
31 | )
32 | out_features = issm(keypoints)
33 | out_features = tensor2numpy(out_features)
34 | print("out_features shape: ", out_features.shape)
35 |
36 | plt.figure(dpi=60)
37 | # feature map
38 | for i in range(1, channels + 1):
39 | plt.subplot(2, channels, i)
40 | plt_img(
41 | in_features[0, i - 1], key=keypoints[0, i - 1], title="feature map {}".format(i)
42 | )
43 |
44 | # plot heatmap
45 | for i in range(1, channels + 1):
46 | plt.subplot(2, channels, channels + i)
47 | plt_img(out_features[0, i - 1], title="heatmap map {}".format(i))
48 |
49 | plt.tight_layout()
50 | # plt.savefig("spatial_softmax.png")
51 | plt.show()
52 |
--------------------------------------------------------------------------------
/eipl/test/test_WeightDownloader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import shutil
8 | from eipl.data import WeightDownloader
9 |
10 | # remove old data
11 | root_dir = os.path.join(os.path.expanduser("~"), ".eipl/")
12 | shutil.rmtree(root_dir, ignore_errors=True)
13 |
14 | WeightDownloader("airec", "grasp_bottle")
15 | WeightDownloader("om", "grasp_cube")
16 |
--------------------------------------------------------------------------------
/eipl/test/test_bounds.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import numpy as np
7 | import matplotlib
8 | import matplotlib.pylab as plt
9 | from eipl.utils import get_mean_minmax, get_bounds, normalization
10 |
11 | matplotlib.use("tkagg")
12 |
13 | vmin = 0.1
14 | vmax = 0.9
15 |
16 | x = np.arange(0, 10, 0.1)
17 | y1 = np.sin(x)
18 | y2 = np.sin(x + 1.0)
19 | y3 = np.sin(x + 2.0)
20 |
21 | data = np.array([y1 * 0.1, y2 * 0.5, y3 * 0.8]).T
22 | # Change data shape from [seq_len, dim] to [N, seq_len, dim]
23 | data = np.expand_dims(data, axis=0)
24 | data_mean, data_min, data_max = get_mean_minmax(data)
25 | bounds = get_bounds(data_mean, data_min, data_max, clip=0.2, vmin=vmin, vmax=vmax)
26 | norm_data = normalization(data - bounds[0], bounds[1:], (vmin, vmax))
27 | denorm_data = normalization(norm_data, (vmin, vmax), bounds[1:]) + bounds[0]
28 |
29 | plt.figure()
30 | plt.plot(data[0])
31 | plt.title("original data")
32 |
33 | plt.figure()
34 | plt.plot(norm_data[0])
35 | plt.title("normalized data")
36 |
37 | plt.figure()
38 | plt.plot(denorm_data[0])
39 | plt.title("denormalized data")
40 |
41 | plt.show()
42 |
--------------------------------------------------------------------------------
/eipl/test/test_cos_interpolation.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import numpy as np
7 | import matplotlib
8 | import matplotlib.pylab as plt
9 | from eipl.utils import cos_interpolation
10 |
11 | matplotlib.use("tkagg")
12 |
13 | data = np.zeros(120)
14 | data[30:70] = 1
15 | smoothed_data = cos_interpolation(data, step=15)
16 |
17 | plt.plot(data)
18 | plt.plot(smoothed_data)
19 | plt.show()
20 |
--------------------------------------------------------------------------------
/eipl/test/test_dataloader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import cv2
7 | import time
8 | import numpy as np
9 | import matplotlib.pylab as plt
10 | import matplotlib.animation as anim
11 | from eipl.utils import deprocess_img
12 | from eipl.data import SampleDownloader, ImageDataset, MultimodalDataset
13 |
14 |
15 | # Download dataset
16 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW")
17 | images, joints = grasp_data.load_norm_data("train", vmin=0.1, vmax=0.9)
18 |
19 | # test for ImageDataset
20 | dataset = ImageDataset(images, stdev=0.02)
21 | x_img, y_img = dataset[0]
22 | x_im = x_img.numpy()[::-1]
23 | y_im = y_img.numpy()[::-1]
24 |
25 | plt.figure(dpi=60)
26 | plt.subplot(1, 2, 1)
27 | plt.imshow(x_im.transpose(1, 2, 0))
28 | plt.axis("off")
29 | plt.title("Input image")
30 | plt.subplot(1, 2, 2)
31 | plt.imshow(y_im.transpose(1, 2, 0))
32 | plt.title("True image")
33 | plt.axis("off")
34 | plt.savefig("./output/viz_image_dataset.png")
35 | plt.close()
36 |
37 | # test for MultimodalDataset
38 | multi_dataset = MultimodalDataset(images, joints)
39 | x_data, y_data = multi_dataset[1]
40 | x_img = x_data[0]
41 | y_img = y_data[0]
42 |
43 | # tensor to numpy
44 | x_img = deprocess_img(x_img.numpy().transpose(0, 2, 3, 1), 0.1, 0.9)
45 | y_img = deprocess_img(y_img.numpy().transpose(0, 2, 3, 1), 0.1, 0.9)
46 |
47 |
48 | # plot images
49 | T = len(x_img)
50 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60)
51 |
52 |
53 | def anim_update(i):
54 | for j in range(3):
55 | ax[j].cla()
56 |
57 | # plot predicted image
58 | ax[0].imshow(y_img[i, :, :, ::-1])
59 | ax[0].axis("off")
60 | ax[0].set_title("Original image", fontsize=20)
61 |
62 | # plot camera image
63 | ax[1].imshow(x_img[i, :, :, ::-1])
64 | ax[1].axis("off")
65 | ax[1].set_title("Noised image", fontsize=20)
66 |
67 | # plot joint angle
68 | ax[2].set_ylim(0.0, 1.0)
69 | ax[2].set_xlim(0, T)
70 | ax[2].plot(y_data[1], linestyle="dashed", c="k")
71 | for joint_idx in range(8):
72 | ax[2].plot(np.arange(i + 1), x_data[1][: i + 1, joint_idx])
73 | ax[2].set_xlabel("Step", fontsize=20)
74 | ax[2].set_title("Scaled joint angles", fontsize=20)
75 | ax[2].tick_params(axis="x", labelsize=16)
76 | ax[2].tick_params(axis="y", labelsize=16)
77 | plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9)
78 |
79 |
80 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
81 | ani.save("./output/viz_multimodal_dataset.gif")
82 |
--------------------------------------------------------------------------------
/eipl/test/test_models.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | from eipl.layer import MTRNNCell
7 | from eipl.model import BasicLSTM, BasicMTRNN
8 | from eipl.model import BasicCAE, CAE
9 | from eipl.model import BasicCAEBN, CAEBN
10 | from eipl.model import CNNRNN, CNNRNNLN, SARNN
11 | from torchinfo import summary
12 |
13 |
14 | batch_size = 50
15 | input_dim = 12
16 |
17 | ### MTRNNCell
18 | print("# MTRNNCell")
19 | model = MTRNNCell(input_dim=12, fast_dim=50, slow_dim=5, fast_tau=2, slow_tau=12)
20 | summary(model, input_size=(batch_size, input_dim))
21 |
22 | print("# BasicMTRNN")
23 | model = BasicMTRNN(
24 | in_dim=12, fast_dim=30, slow_dim=5, fast_tau=2, slow_tau=12, activation="tanh"
25 | )
26 | summary(model, input_size=(batch_size, input_dim))
27 |
28 | print("# BasicLSTM")
29 | model = BasicLSTM(in_dim=12, rec_dim=50, out_dim=10, activation="tanh")
30 | summary(model, input_size=(batch_size, input_dim))
31 |
32 | print("BasicCAE")
33 | model = BasicCAE()
34 | summary(model, input_size=(batch_size, 3, 128, 128))
35 |
36 | print("CAE")
37 | model = CAE()
38 | summary(model, input_size=(batch_size, 3, 128, 128))
39 |
40 | print("BasicCAEBN")
41 | model = BasicCAEBN(feat_dim=10)
42 | summary(model, input_size=(batch_size, 3, 128, 128))
43 |
44 | print("CAEBN")
45 | model = CAEBN(feat_dim=30)
46 | summary(model, input_size=(batch_size, 3, 128, 128))
47 |
48 |
49 | print("CNNRNN")
50 | model = CNNRNN(rec_dim=50, joint_dim=8, feat_dim=10)
51 | summary(model, input_size=[(batch_size, 3, 128, 128), (batch_size, 8)])
52 |
53 | print("CNNRNNLN")
54 | model = CNNRNNLN(rec_dim=50, joint_dim=8, feat_dim=10)
55 | summary(model, input_size=[(batch_size, 3, 128, 128), (batch_size, 8)])
56 |
57 | print("SARNN")
58 | model = SARNN(rec_dim=50, k_dim=5, joint_dim=8)
59 | summary(model, input_size=[(batch_size, 3, 128, 128), (batch_size, 8)])
60 |
--------------------------------------------------------------------------------
/eipl/tutorials/airec/ros/1_rosbag2npz.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import cv2
8 | import glob
9 | import rospy
10 | import rosbag
11 | import argparse
12 | import numpy as np
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("bag_dir", type=str)
17 | parser.add_argument("--freq", type=float, default=10)
18 | args = parser.parse_args()
19 |
20 |
21 | files = glob.glob(os.path.join(args.bag_dir, "*.bag"))
22 | files.sort()
23 | for file in files:
24 | print(file)
25 | savename = file.split(".bag")[0] + ".npz"
26 |
27 | # Open the rosbag file
28 | bag = rosbag.Bag(file)
29 |
30 | # Get the start and end times of the rosbag file
31 | start_time = bag.get_start_time()
32 | end_time = bag.get_end_time()
33 |
34 | # Get the topics in the rosbag file
35 | # topics = bag.get_type_and_topic_info()[1].keys()
36 | topics = [
37 | "/torobo/joint_states",
38 | "/torobo/head/see3cam_left/camera/color/image_repub/compressed",
39 | "/torobo/left_hand_controller/state",
40 | ]
41 |
42 | # Create a rospy.Time object to represent the current time
43 | current_time = rospy.Time.from_sec(start_time)
44 |
45 | joint_list = []
46 | finger_list = []
47 | image_list = []
48 | finger_state_list = []
49 |
50 | prev_finger = None
51 | finger_state = 0
52 |
53 | # Loop through the rosbag file at regular intervals (args.freq)
54 | freq = 1.0 / float(args.freq)
55 | while current_time.to_sec() < end_time:
56 | print(current_time.to_sec())
57 |
58 | # Get the messages for each topic at the current time
59 | for topic in topics:
60 | for topic_msg, msg, time in bag.read_messages(
61 | topic, start_time=current_time
62 | ):
63 | if time >= current_time:
64 | if topic == "/torobo/joint_states":
65 | joint_list.append(msg.position[7:14])
66 |
67 | if (
68 | topic
69 | == "/torobo/head/see3cam_left/camera/color/image_repub/compressed"
70 | ):
71 | np_arr = np.frombuffer(msg.data, np.uint8)
72 | np_img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
73 | np_img = np_img[::2, ::2]
74 | image_list.append(np_img[150:470, 110:430].astype(np.uint8))
75 |
76 | if topic == "/torobo/left_hand_controller/state":
77 | finger = np.array(msg.desired.positions[3])
78 | if prev_finger is None:
79 | prev_finger = finger
80 |
81 | if finger - prev_finger > 0.005 and finger_state == 0:
82 | finger_state = 1
83 | elif prev_finger - finger > 0.005 and finger_state == 1:
84 | finger_state = 0
85 | prev_finger = finger
86 |
87 | finger_list.append(finger)
88 | finger_state_list.append(finger_state)
89 |
90 | break
91 |
92 | # Wait for the next interval
93 | current_time += rospy.Duration.from_sec(freq)
94 |
95 | # Close the rosbag file
96 | bag.close()
97 |
98 | # Convert list to array
99 | joints = np.array(joint_list, dtype=np.float32)
100 | finger = np.array(finger_list, dtype=np.float32)
101 | finger_state = np.array(finger_state_list, dtype=np.float32)
102 | images = np.array(image_list, dtype=np.uint8)
103 |
104 | # Get shorter length
105 | shorter_length = min(len(joints), len(images), len(finger), len(finger_state))
106 |
107 | # Trim
108 | joints = joints[:shorter_length]
109 | finger = finger[:shorter_length]
110 | images = images[:shorter_length]
111 | finger_state = finger_state[:shorter_length]
112 |
113 | # Save
114 | np.savez(
115 | savename, joints=joints, finger=finger, finger_state=finger_state, images=images
116 | )
117 |
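A quick way to sanity-check a converted file is to reload it and inspect the saved arrays; a minimal sketch, assuming a hypothetical output file ./bag/sample.npz produced by this script:

import numpy as np

# report shape and dtype of each array written by np.savez above
data = np.load("./bag/sample.npz")  # hypothetical filename
for key in ("joints", "finger", "finger_state", "images"):
    print(key, data[key].shape, data[key].dtype)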
--------------------------------------------------------------------------------
/eipl/tutorials/airec/ros/2_make_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import cv2
8 | import glob
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pylab as plt
12 | from eipl.utils import resize_img, calc_minmax, list_to_numpy, cos_interpolation
13 |
14 |
15 | def load_data(dir):
16 | joints = []
17 | images = []
18 | seq_length = []
19 |
20 | files = glob.glob(os.path.join(dir, "*.npz"))
21 | files.sort()
22 | for filename in files:
23 | print(filename)
24 | npz_data = np.load(filename)
25 |
26 | images.append(resize_img(npz_data["images"], (128, 128)))
27 | finger_state = cos_interpolation(npz_data["finger_state"], expand_dims=True)
28 | _joints = np.concatenate((npz_data["joints"], finger_state), axis=-1)
29 | joints.append(_joints)
30 | seq_length.append(len(_joints))
31 |
32 | max_seq = max(seq_length)
33 | images = list_to_numpy(images, max_seq)
34 | joints = list_to_numpy(joints, max_seq)
35 |
36 | return images, joints
37 |
38 |
39 | if __name__ == "__main__":
40 | # dataset index
41 | train_list = [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13]
42 | test_list = [4, 9, 14, 15, 16]
43 |
44 | # load data
45 | images, joints = load_data("./bag/")
46 |
47 | # save images and joints
48 | np.save("./data/train/images.npy", images[train_list].astype(np.uint8))
49 | np.save("./data/train/joints.npy", joints[train_list].astype(np.float32))
50 | np.save("./data/test/images.npy", images[test_list].astype(np.uint8))
51 | np.save("./data/test/joints.npy", joints[test_list].astype(np.float32))
52 |
53 | # save joint bounds
54 | joint_bounds = calc_minmax(joints)
55 | np.save("./data/joint_bounds.npy", joint_bounds)
56 |
--------------------------------------------------------------------------------
/eipl/tutorials/airec/ros/3_check_data.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import argparse
7 | import numpy as np
8 | import matplotlib.pylab as plt
9 | import matplotlib.animation as anim
10 | from eipl.utils import normalization
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--idx", type=int, default=0)
15 | args = parser.parse_args()
16 |
17 | idx = int(args.idx)
18 | joints = np.load("./data/test/joints.npy")
19 | joint_bounds = np.load("./data/joint_bounds.npy")
20 | images = np.load("./data/test/images.npy")
21 | N = images.shape[1]
22 |
23 |
24 | # normalized joints
25 | minmax = [0.1, 0.9]
26 | norm_joints = normalization(joints, joint_bounds, minmax)
27 |
28 | # print data information
29 | print("load test data, index number is {}".format(idx))
30 | print(
31 | "Joint: shape={}, min={:.3g}, max={:.3g}".format(
32 | joints.shape, joints.min(), joints.max()
33 | )
34 | )
35 | print(
36 | "Norm joint: shape={}, min={:.3g}, max={:.3g}".format(
37 | norm_joints.shape, norm_joints.min(), norm_joints.max()
38 | )
39 | )
40 |
41 | # plot images and normalized joints
42 | fig, ax = plt.subplots(1, 3, figsize=(14, 5), dpi=60)
43 |
44 |
45 | def anim_update(i):
46 | for j in range(3):
47 | ax[j].cla()
48 |
49 | # plot image
50 | ax[0].imshow(images[idx, i, :, :, ::-1])
51 | ax[0].axis("off")
52 | ax[0].set_title("Image")
53 |
54 | # plot joint angle
55 | ax[1].set_ylim(-1.0, 2.0)
56 | ax[1].set_xlim(0, N)
57 | ax[1].plot(joints[idx], linestyle="dashed", c="k")
58 |
59 | for joint_idx in range(8):
60 | ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx])
61 | ax[1].set_xlabel("Step")
62 | ax[1].set_title("Joint angles")
63 |
64 | # plot normalized joint angle
65 | ax[2].set_ylim(0.0, 1.0)
66 | ax[2].set_xlim(0, N)
67 | ax[2].plot(norm_joints[idx], linestyle="dashed", c="k")
68 |
69 | for joint_idx in range(8):
70 | ax[2].plot(np.arange(i + 1), norm_joints[idx, : i + 1, joint_idx])
71 | ax[2].set_xlabel("Step")
72 | ax[2].set_title("Normalized joint angles")
73 |
74 |
75 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N)
76 | ani.save("./output/check_data_{}.gif".format(idx))
77 |
--------------------------------------------------------------------------------
/eipl/tutorials/airec/sarnn/bin/test_pca_sarnn.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import torch
8 | import argparse
9 | import numpy as np
10 | import matplotlib.pylab as plt
11 | import matplotlib.animation as anim
12 | from sklearn.decomposition import PCA
13 | from eipl.data import SampleDownloader
14 | from eipl.model import SARNN
15 | from eipl.utils import restore_args, tensor2numpy, normalization
16 |
17 |
18 | # argument parser
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--filename", type=str, default=None)
21 | args = parser.parse_args()
22 |
23 | # check args
24 | assert args.filename, "Please set filename"
25 |
26 | # restore parameters
27 | dir_name = os.path.split(args.filename)[0]
28 | params = restore_args(os.path.join(dir_name, "args.json"))
29 |
30 | # load dataset
31 | minmax = [params["vmin"], params["vmax"]]
32 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC")
33 | images, joints = grasp_data.load_raw_data("test")
34 | joint_bounds = grasp_data.joint_bounds
35 | print(
36 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max())
37 | )
38 | print(
39 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max())
40 | )
41 |
42 | # define model
43 | model = SARNN(
44 | rec_dim=params["rec_dim"],
45 | joint_dim=8,
46 | k_dim=params["k_dim"],
47 | heatmap_size=params["heatmap_size"],
48 | temperature=params["temperature"],
49 | )
50 |
51 | # If the model was trained with torch.compile, uncomment the following line.
52 | # model = torch.compile(model)
53 |
54 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
55 | model.load_state_dict(ckpt["model_state_dict"])
56 | model.eval()
57 |
58 | # Inference
59 | states = []
60 | state = None
61 | nloop = images.shape[1]
62 | for loop_ct in range(nloop):
63 | # load data and normalization
64 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2)
65 | img_t = torch.Tensor(img_t)
66 | img_t = normalization(img_t, (0, 255), minmax)
67 | joint_t = torch.Tensor(joints[:, loop_ct])
68 | joint_t = normalization(joint_t, joint_bounds, minmax)
69 |
70 | # predict rnn
71 | _, _, _, _, state = model(img_t, joint_t, state)
72 | states.append(state[0])
73 |
74 | states = torch.permute(torch.stack(states), (1, 0, 2))
75 | states = tensor2numpy(states)
76 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN.
77 | # N is the number of datasets
78 | # T is the sequence length
79 | # D is the dimension of the hidden state
80 | N, T, D = states.shape
81 | states = states.reshape(-1, D)
82 |
83 | # plot pca
84 | loop_ct = float(360) / T
85 | pca_dim = 3
86 | pca = PCA(n_components=pca_dim).fit(states)
87 | pca_val = pca.transform(states)
88 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to
89 | # visualize each state as a 3D scatter.
90 | pca_val = pca_val.reshape(N, T, pca_dim)
91 |
92 | fig = plt.figure(dpi=60)
93 | ax = fig.add_subplot(projection="3d")
94 |
95 |
96 | def anim_update(i):
97 | ax.cla()
98 | angle = int(loop_ct * i)
99 | ax.view_init(30, angle)
100 |
101 | c_list = ["C0", "C1", "C2", "C3", "C4"]
102 | for n, color in enumerate(c_list):
103 | ax.scatter(
104 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
105 | )
106 |
107 | ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)
108 | pca_ratio = pca.explained_variance_ratio_ * 100
109 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
110 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
111 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
112 |
113 |
114 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
115 | ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]))
116 |
117 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
118 | # ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]), writer="imagemagick")
119 | # ani.save("./output/PCA_SARNN_{}.mp4".format(params["tag"]), writer="ffmpeg")
120 |
--------------------------------------------------------------------------------
/eipl/tutorials/airec/sarnn/bin/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import torch
9 | import numpy as np
10 | import argparse
11 | from tqdm import tqdm
12 | import torch.optim as optim
13 | from collections import OrderedDict
14 | from torch.utils.tensorboard import SummaryWriter
15 | from eipl.model import SARNN
16 | from eipl.data import MultimodalDataset, SampleDownloader
17 | from eipl.utils import EarlyStopping, check_args, set_logdir, resize_img
18 |
19 | # load own library
20 | sys.path.append("./libs/")
21 | from fullBPTT import fullBPTTtrainer
22 |
23 | # argument parser
24 | parser = argparse.ArgumentParser(
25 | description="Learning spatial autoencoder with recurrent neural network"
26 | )
27 | parser.add_argument("--model", type=str, default="sarnn")
28 | parser.add_argument("--epoch", type=int, default=10000)
29 | parser.add_argument("--batch_size", type=int, default=5)
30 | parser.add_argument("--rec_dim", type=int, default=50)
31 | parser.add_argument("--k_dim", type=int, default=5)
32 | parser.add_argument("--img_loss", type=float, default=0.1)
33 | parser.add_argument("--joint_loss", type=float, default=1.0)
34 | parser.add_argument("--pt_loss", type=float, default=0.1)
35 | parser.add_argument("--heatmap_size", type=float, default=0.1)
36 | parser.add_argument("--temperature", type=float, default=1e-4)
37 | parser.add_argument("--stdev", type=float, default=0.1)
38 | parser.add_argument("--log_dir", default="log/")
39 | parser.add_argument("--vmin", type=float, default=0.0)
40 | parser.add_argument("--vmax", type=float, default=1.0)
41 | parser.add_argument("--device", type=int, default=0)
42 | parser.add_argument("--compile", action="store_true")
43 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
44 | args = parser.parse_args()
45 |
46 | # check args
47 | args = check_args(args)
48 |
49 | # calculate the noise level (standard deviation) from the normalized range
50 | stdev = args.stdev * (args.vmax - args.vmin)
51 |
52 | # set device id
53 | if args.device >= 0:
54 | device = "cuda:{}".format(args.device)
55 | else:
56 | device = "cpu"
57 |
58 | # load dataset
59 | minmax = [args.vmin, args.vmax]
60 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC")
61 | images, joints = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax)
62 | images = resize_img(images, (64, 64))
63 | images = images.transpose(0, 1, 4, 2, 3)
64 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev)
65 | train_loader = torch.utils.data.DataLoader(
66 | train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False
67 | )
68 |
69 | images, joints = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax)
70 | images = resize_img(images, (64, 64))
71 | images = images.transpose(0, 1, 4, 2, 3)
72 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None)
73 | test_loader = torch.utils.data.DataLoader(
74 | test_dataset,
75 | batch_size=args.batch_size,
76 | shuffle=True,
77 | drop_last=False,
78 | )
79 |
80 | # define model
81 | model = SARNN(
82 | rec_dim=args.rec_dim,
83 | joint_dim=8,
84 | k_dim=args.k_dim,
85 | heatmap_size=args.heatmap_size,
86 | temperature=args.temperature,
87 | )
88 |
89 | # torch.compile makes PyTorch code run faster
90 | if args.compile:
91 | torch.set_float32_matmul_precision("high")
92 | model = torch.compile(model)
93 |
94 | # set optimizer
95 | optimizer = optim.Adam(model.parameters(), eps=1e-07)
96 |
97 | # load trainer/tester class
98 | loss_weights = [args.img_loss, args.joint_loss, args.pt_loss]
99 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device)
100 |
101 | ### training main
102 | log_dir_path = set_logdir("./" + args.log_dir, args.tag)
103 | save_name = os.path.join(log_dir_path, "SARNN.pth")
104 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30)
105 | early_stop = EarlyStopping(patience=1000)
106 |
107 | with tqdm(range(args.epoch)) as pbar_epoch:
108 | for epoch in pbar_epoch:
109 | # train and test
110 | train_loss = trainer.process_epoch(train_loader)
111 | with torch.no_grad():
112 | test_loss = trainer.process_epoch(test_loader, training=False)
113 | writer.add_scalar("Loss/train_loss", train_loss, epoch)
114 | writer.add_scalar("Loss/test_loss", test_loss, epoch)
115 |
116 | # early stop
117 | save_ckpt, _ = early_stop(test_loss)
118 |
119 | if save_ckpt:
120 | trainer.save(epoch, [train_loss, test_loss], save_name)
121 |
122 | # print process bar
123 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss))
124 |
--------------------------------------------------------------------------------
/eipl/tutorials/airec/sarnn/libs/fullBPTT.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from eipl.utils import LossScheduler, tensor2numpy
9 |
10 |
11 | class fullBPTTtrainer:
12 | """
13 | Helper class to train recurrent neural networks with numpy sequences
14 |
15 | Args:
16 | traindata (np.array): list of np.array. First diemension should be time steps
17 | model (torch.nn.Module): rnn model
18 | optimizer (torch.optim): optimizer
19 | input_param (float): input parameter of sequential generation. 1.0 means open mode.
20 | """
21 |
22 | def __init__(self, model, optimizer, loss_weights=[1.0, 1.0, 1.0], device="cpu"):
23 | self.device = device
24 | self.optimizer = optimizer
25 | self.loss_weights = loss_weights
26 | self.scheduler = LossScheduler(decay_end=1000, curve_name="s")
27 | self.model = model.to(self.device)
28 |
29 | def save(self, epoch, loss, savename):
30 | torch.save(
31 | {
32 | "epoch": epoch,
33 | "model_state_dict": self.model.state_dict(),
34 | #'optimizer_state_dict': self.optimizer.state_dict(),
35 | "train_loss": loss[0],
36 | "test_loss": loss[1],
37 | },
38 | savename,
39 | )
40 |
41 | def process_epoch(self, data, training=True):
42 | if not training:
43 | self.model.eval()
44 | else:
45 | self.model.train()
46 |
47 | total_loss = 0.0
48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data):
49 | if "cpu" in self.device:
50 | x_img = x_img.to(self.device)
51 | y_img = y_img.to(self.device)
52 | x_joint = x_joint.to(self.device)
53 | y_joint = y_joint.to(self.device)
54 |
55 | state = None
56 | yi_list, yv_list = [], []
57 | dec_pts_list, enc_pts_list = [], []
58 | self.optimizer.zero_grad(set_to_none=True)
59 | for t in range(x_img.shape[1] - 1):
60 | _yi_hat, _yv_hat, enc_ij, dec_ij, state = self.model(
61 | x_img[:, t], x_joint[:, t], state
62 | )
63 | yi_list.append(_yi_hat)
64 | yv_list.append(_yv_hat)
65 | enc_pts_list.append(enc_ij)
66 | dec_pts_list.append(dec_ij)
67 |
68 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4))
69 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2))
70 |
71 | img_loss = nn.MSELoss()(yi_hat, y_img[:, 1:]) * self.loss_weights[0]
72 | joint_loss = nn.MSELoss()(yv_hat, y_joint[:, 1:]) * self.loss_weights[1]
73 | # Gradually change the loss value using the LossScheduler class.
74 | pt_loss = nn.MSELoss()(
75 | torch.stack(dec_pts_list[:-1]), torch.stack(enc_pts_list[1:])
76 | ) * self.scheduler(self.loss_weights[2])
77 | loss = img_loss + joint_loss + pt_loss
78 | total_loss += tensor2numpy(loss)
79 |
80 | if training:
81 | loss.backward()
82 | self.optimizer.step()
83 |
84 | return total_loss / (n_batch + 1)
85 |
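For reference, a minimal sketch of how this trainer is wired up, mirroring bin/train.py (the loss weights follow that script's defaults, and train_loader/test_loader are assumed to be the MultimodalDataset loaders built there):

import torch.optim as optim
from eipl.model import SARNN

model = SARNN(rec_dim=50, joint_dim=8, k_dim=5)
optimizer = optim.Adam(model.parameters(), eps=1e-07)
trainer = fullBPTTtrainer(model, optimizer, loss_weights=[0.1, 1.0, 0.1], device="cpu")
# one epoch of training, then evaluation without gradient updates:
# train_loss = trainer.process_epoch(train_loader)
# test_loss = trainer.process_epoch(test_loader, training=False)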
--------------------------------------------------------------------------------
/eipl/tutorials/airec/sarnn/log/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/airec/sarnn/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/bag2npz/1_rosbag2npz.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import cv2
8 | import glob
9 | import rospy
10 | import rosbag
11 | import argparse
12 | import numpy as np
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--bag_dir", type=str, default="./bag/")
17 | parser.add_argument("--freq", type=float, default=10)
18 | args = parser.parse_args()
19 |
20 |
21 | files = glob.glob(os.path.join(args.bag_dir, "*.bag"))
22 | files.sort()
23 | for file in files:
24 | print(file)
25 | savename = file.split(".bag")[0] + ".npz"
26 |
27 | # Open the rosbag file
28 | bag = rosbag.Bag(file)
29 |
30 | # Get the start and end times of the rosbag file
31 | start_time = bag.get_start_time()
32 | end_time = bag.get_end_time()
33 |
34 | # Get the topics in the rosbag file
35 | topics = ["/follower/joint_states", "/camera/color/image_raw/compressed"]
36 |
37 | # Create a rospy.Time object to represent the current time
38 | current_time = rospy.Time.from_sec(start_time)
39 |
40 | joint_list = []
41 | image_list = []
42 |
43 | prev_gripper = None
44 | gripper_state = 0
45 |
46 | # Loop through the rosbag file at regular intervals (args.freq)
47 | freq = 1.0 / float(args.freq)
48 | while current_time.to_sec() < end_time:
49 | # Get the messages for each topic at the current time
50 | for topic in topics:
51 | for topic_msg, msg, time in bag.read_messages(
52 | topic, start_time=current_time
53 | ):
54 | if time >= current_time:
55 | if topic == "/camera/color/image_raw/compressed":
56 | np_arr = np.frombuffer(msg.data, np.uint8)
57 | np_img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
58 | np_img = np_img[::2, ::2]
59 | image_list.append(np_img.astype(np.uint8))
60 |
61 | if topic == "/follower/joint_states":
62 | joint_list.append(msg.position)
63 | break
64 |
65 | # Wait for the next interval
66 | current_time += rospy.Duration.from_sec(freq)
67 |
68 | # Close the rosbag file
69 | bag.close()
70 |
71 | # Convert list to array
72 | joints = np.array(joint_list, dtype=np.float32)
73 | images = np.array(image_list, dtype=np.uint8)
74 |
75 | # Get shorter length
76 | shorter_length = min(len(joints), len(images))
77 |
78 | # Trim
79 | joints = joints[:shorter_length]
80 | images = images[:shorter_length]
81 |
82 | # Save
83 | np.savez(savename, joints=joints, images=images)
84 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/bag2npz/2_make_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import cv2
8 | import glob
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pylab as plt
12 | from utils import resize_img, calc_minmax, list_to_numpy
13 |
14 |
15 | def load_data(dir):
16 | joints = []
17 | images = []
18 | seq_length = []
19 |
20 | files = glob.glob(os.path.join(dir, "*.npz"))
21 | files.sort()
22 |
23 | for filename in files:
24 | print(filename)
25 | npz_data = np.load(filename)
26 | images.append(resize_img(npz_data["images"], (64, 64)))
27 | _joints = npz_data["joints"]
28 | joints.append(_joints)
29 | seq_length.append(len(_joints))
30 |
31 | max_seq = max(seq_length)
32 | images = list_to_numpy(images, max_seq)
33 | joints = list_to_numpy(joints, max_seq)
34 |
35 | return images, joints
36 |
37 |
38 | if __name__ == "__main__":
39 | train_list = [0, 1, 2, 3, 4, 5, 6, 7, 8]
40 | test_list = [9, 10, 11, 12, 13]
41 |
42 | # load data
43 | images, joints = load_data("./bag/")
44 |
45 | # save images and joints
46 | np.save("./data/train/images.npy", images[train_list].astype(np.uint8))
47 | np.save("./data/train/joints.npy", joints[train_list].astype(np.float32))
48 | np.save("./data/test/images.npy", images[test_list].astype(np.uint8))
49 | np.save("./data/test/joints.npy", joints[test_list].astype(np.float32))
50 |
51 | # save joint bounds
52 | joint_bounds = calc_minmax(joints)
53 | np.save("./data/joint_bounds.npy", joint_bounds)
54 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/bag2npz/3_check_data.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import argparse
7 | import numpy as np
8 | import matplotlib.pylab as plt
9 | import matplotlib.animation as anim
10 | from utils import normalization
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--idx", type=int, default=0)
14 | args = parser.parse_args()
15 |
16 | idx = int(args.idx)
17 | joints = np.load("./data/test/joints.npy")
18 | joint_bounds = np.load("./data/joint_bounds.npy")
19 | images = np.load("./data/test/images.npy")
20 | N = images.shape[1]
21 |
22 | # normalized joints
23 | minmax = [0.1, 0.9]
24 | norm_joints = normalization(joints, joint_bounds, minmax)
25 |
26 | # print data information
27 | print("load test data, index number is {}".format(idx))
28 | print(
29 | "Joint: shape={}, min={:.3g}, max={:.3g}".format(
30 | joints.shape, joints.min(), joints.max()
31 | )
32 | )
33 | print(
34 | "Norm joint: shape={}, min={:.3g}, max={:.3g}".format(
35 | norm_joints.shape, norm_joints.min(), norm_joints.max()
36 | )
37 | )
38 |
39 | # plot images and normalized joints
40 | fig, ax = plt.subplots(1, 3, figsize=(14, 5), dpi=60)
41 |
42 |
43 | def anim_update(i):
44 | for j in range(3):
45 | ax[j].cla()
46 |
47 | # plot image
48 | ax[0].imshow(images[idx, i, :, :, ::-1])
49 | ax[0].axis("off")
50 | ax[0].set_title("Image")
51 |
52 | # plot joint angle
53 | ax[1].set_ylim(-1.0, 2.0)
54 | ax[1].set_xlim(0, N)
55 | ax[1].plot(joints[idx], linestyle="dashed", c="k")
56 |
57 | for joint_idx in range(5):
58 | ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx])
59 | ax[1].set_xlabel("Step")
60 | ax[1].set_title("Joint angles")
61 |
62 | # plot normalized joint angle
63 | ax[2].set_ylim(0.0, 1.0)
64 | ax[2].set_xlim(0, N)
65 | ax[2].plot(norm_joints[idx], linestyle="dashed", c="k")
66 |
67 | for joint_idx in range(5):
68 | ax[2].plot(np.arange(i + 1), norm_joints[idx, : i + 1, joint_idx])
69 | ax[2].set_xlabel("Step")
70 | ax[2].set_title("Normalized joint angles")
71 |
72 |
73 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N)
74 | ani.save("./output/check_data_{}.gif".format(idx))
75 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/bag2npz/bag/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/bag2npz/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/bag2npz/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import cv2
8 | import glob
9 | import datetime
10 | import numpy as np
11 | import matplotlib.pylab as plt
12 |
13 |
14 | def normalization(data, indataRange, outdataRange):
15 | """
16 | Function to normalize a numpy array within a specified range
17 | Args:
18 | data (np.array): Data array
19 | indataRange (float list): List of maximum and minimum values of original data, e.g. indataRange=[0.0, 255.0].
20 | outdataRange (float list): List of maximum and minimum values of output data, e.g. indataRange=[0.0, 1.0].
21 | Return:
22 | data (np.array): Normalized data array
23 | """
24 | data = (data - indataRange[0]) / (indataRange[1] - indataRange[0])
25 | data = data * (outdataRange[1] - outdataRange[0]) + outdataRange[0]
26 | return data
27 |
28 |
29 | def resize_img(img, size=(64, 64), reshape_flag=True):
30 | """
31 | Resize a batch of images with cv2; accepts (N, W, H, C) or (N, T, W, H, C) arrays.
32 | """
33 | if len(img.shape) == 5:
34 | N, T, W, H, C = img.shape
35 | img = img.reshape((-1,) + img.shape[2:])
36 | else:
37 | reshape_flag = False
38 |
39 | imgs = []
40 | for i in range(len(img)):
41 | imgs.append(cv2.resize(img[i], size))
42 |
43 | imgs = np.array(imgs)
44 | if reshape_flag:
45 | imgs = imgs.reshape(N, T, size[1], size[0], 3)
46 | return imgs
47 |
48 |
49 | def calc_minmax(_data):
50 | data = _data.reshape(-1, _data.shape[-1])
51 | data_minmax = np.array([np.min(data, 0), np.max(data, 0)])
52 | return data_minmax
53 |
54 |
55 | def list_to_numpy(data_list, max_N):
56 | dtype = data_list[0].dtype
57 | array = np.ones(
58 | (
59 | len(data_list),
60 | max_N,
61 | )
62 | + data_list[0].shape[1:],
63 | dtype,
64 | )
65 |
66 | for i, data in enumerate(data_list):
67 | N = len(data)
68 | array[i, :N] = data[:N].astype(dtype)
69 | array[i, N:] = array[i, N:] * data[-1].astype(dtype)
70 |
71 | return array
72 |
73 |
74 | def cos_interpolation(data, step=20, expand_dims=False):
75 | """
76 | Args:
77 | data (seq_length): time-series sensor data
78 | """
79 |
80 | data = data.copy()
81 | points = np.diff(data)
82 |
83 | for i, p in enumerate(points):
84 | if p == 1:
85 | t = np.linspace(0.0, 1.0, step * 2)
86 | elif p == -1:
87 | t = np.linspace(1.0, 0.0, step * 2)
88 | else:
89 | continue
90 |
91 | x_latent = (1 - np.cos(t * np.pi)) / 2
92 | data[i - step + 1 : i + step + 1] = x_latent
93 |
94 | if expand_dims:
95 | data = np.expand_dims(data, axis=-1)
96 |
97 | return data
98 |
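A small worked example of cos_interpolation: each hard 0/1 edge is replaced by a half-cosine ramp, so the network is given a smooth gripper target instead of a step. A sketch with synthetic data:

import numpy as np

signal = np.concatenate([np.zeros(30), np.ones(30)])  # hard 0 -> 1 jump at step 30
smooth = cos_interpolation(signal, step=10)
print(np.round(smooth[20:40], 2))  # rises smoothly from 0.0 to 1.0 over 20 steps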
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/config/d435_param.yaml:
--------------------------------------------------------------------------------
1 | !!python/object/new:dynamic_reconfigure.encoding.Config
2 | dictitems:
3 | auto_exposure_priority: false
4 | backlight_compensation: false
5 | brightness: -20
6 | contrast: 60
7 | enable_auto_exposure: false
8 | enable_auto_white_balance: false
9 | exposure: 166
10 | frames_queue_size: 16
11 | gain: 100
12 | gamma: 300
13 | global_time_enabled: true
14 | groups: !!python/object/new:dynamic_reconfigure.encoding.Config
15 | dictitems:
16 | auto_exposure_priority: false
17 | backlight_compensation: false
18 | brightness: -20
19 | contrast: 60
20 | enable_auto_exposure: false
21 | enable_auto_white_balance: false
22 | exposure: 166
23 | frames_queue_size: 16
24 | gain: 100
25 | gamma: 300
26 | global_time_enabled: true
27 | groups: !!python/object/new:dynamic_reconfigure.encoding.Config
28 | state: []
29 | hue: 0
30 | id: 0
31 | name: Default
32 | parameters: !!python/object/new:dynamic_reconfigure.encoding.Config
33 | state: []
34 | parent: 0
35 | power_line_frequency: 3
36 | saturation: 64
37 | sharpness: 50
38 | state: true
39 | type: ''
40 | white_balance: 4600.0
41 | state: []
42 | hue: 0
43 | power_line_frequency: 3
44 | saturation: 64
45 | sharpness: 50
46 | white_balance: 4600.0
47 | state: []
48 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/launch/playback_bringup.launch:
--------------------------------------------------------------------------------
1 | <!-- launch XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/launch/rt_bringup.launch:
--------------------------------------------------------------------------------
1 | <!-- launch XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/launch/teleop_bringup.launch:
--------------------------------------------------------------------------------
1 | <!-- launch XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="2">
3 | <name>om_teleop</name>
4 | <version>0.0.0</version>
5 | <description>The om_teleop package</description>
6 | <maintainer>ito</maintainer>
7 | <license>TODO</license>
8 | <buildtool_depend>catkin</buildtool_depend>
9 | <build_depend>rospy</build_depend>
10 | <build_depend>std_msgs</build_depend>
11 | <build_export_depend>rospy</build_export_depend>
12 | <build_export_depend>std_msgs</build_export_depend>
13 | <exec_depend>rospy</exec_depend>
14 | <exec_depend>std_msgs</exec_depend>
15 | <export>
16 | </export>
17 | </package>
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/src/dynamixel_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys, tty, termios
8 | import numpy as np
9 |
10 | # Control table address (need to change)
11 | ADDR_OPERATING_MODE = 1  # Operating mode register address
12 | CURRENT_POSITION_CONTROL_MODE = 5
13 | ADDR_GOAL_CURRENT = 102
14 | ADDR_TORQUE_ENABLE = 64
15 | ADDR_LED_RED = 65
16 | LEN_LED_RED = 1 # Data Byte Length
17 | ADDR_GOAL_POSITION = 116
18 | LEN_GOAL_POSITION = 4 # Data Byte Length
19 | ADDR_PRESENT_POSITION = 132
20 | LEN_PRESENT_POSITION = 4 # Data Byte Length
21 | DXL_MINIMUM_POSITION_VALUE = 0 # Refer to the Minimum Position Limit of product eManual
22 | DXL_MAXIMUM_POSITION_VALUE = (
23 | 4095 # Refer to the Maximum Position Limit of product eManual
24 | )
25 | TORQUE_ENABLE = 1 # Value for enabling the torque
26 | TORQUE_DISABLE = 0 # Value for disabling the torque
27 | DXL_MOVING_STATUS_THRESHOLD = 20 # Dynamixel moving status threshold
28 | PROTOCOL_VERSION = 2.0
29 |
30 |
31 | ## getch
32 | fd = sys.stdin.fileno()
33 | old_settings = termios.tcgetattr(fd)
34 |
35 |
36 | def getch():
37 | try:
38 | tty.setraw(sys.stdin.fileno())
39 | ch = sys.stdin.read(1)
40 | finally:
41 | termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
42 | return ch
43 |
44 |
45 | ## value/degree/radian converters
46 | def value2degree(data):
47 | """convert dynamixel value to radian"""
48 | return np.array(data, dtype=np.float32) * 360.0 / 4095.0
49 |
50 |
51 | def value2radian(data):
52 | """convert dynamixel value to degree"""
53 | return (np.array(data, dtype=np.float32) - 2047.5) / 2047.5 * np.pi
54 |
55 |
56 | def degree2value(data):
57 | """convert dynamixel value to radian"""
58 | return np.array(data, dtype=np.float32) * 4095.0 / 360.0
59 |
60 |
61 | def radian2value(data):
62 | """convert dynamixel value to degree"""
63 | val = np.array(data, dtype=np.float32) / np.pi * 2047.5 + 2047.5
64 | return val.astype(np.int32)
65 |
66 |
67 | def normalization(data, indataRange, outdataRange):
68 | data = (data - indataRange[0]) / (indataRange[1] - indataRange[0])
69 | data = data * (outdataRange[1] - outdataRange[0]) + outdataRange[0]
70 | return data
71 |
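A quick round-trip check of the converters above (Dynamixel positions span 0-4095 over one revolution, centered at 2047.5):

import numpy as np

print(value2radian([0, 2047.5, 4095]))     # approx. [-3.14, 0.0, 3.14]
print(radian2value([-np.pi, 0.0, np.pi]))  # approx. [0, 2047, 4095]
print(value2degree(2047.5))                # approx. 180.0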
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/src/follower_bringup.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import time
9 | import numpy as np
10 | from dynamixel_sdk import *
11 | from dynamixel_driver import *
12 |
13 | # ROS
14 | import rospy
15 | from sensor_msgs.msg import JointState
16 |
17 |
18 | class FollowerARM(DynamixelDriver):
19 | def __init__(self, freq, motor_list, device="/dev/ttyUSB0", baudrate=1000000):
20 | self.freq = freq
21 | self.motor_list = motor_list
22 | self.motor_names = ["motor_{}".format(m) for m in motor_list]
23 | self.baudrate = baudrate
24 | self.cmd_sub = rospy.Subscriber(
25 | "/leader/interpolated_command", JointState, self.cmdCallback
26 | )
27 | self.joint_pub = rospy.Publisher(
28 | "/follower/joint_states", JointState, queue_size=1
29 | )
30 |
31 | self.cmd = None
32 |
33 | # setup the message
34 | self.joint_msg = JointState()
35 | self.joint_msg.name = self.motor_names
36 |
37 | # Initialization
38 | self.portHandler = PortHandler(device)
39 | self.packetHandler = PacketHandler(PROTOCOL_VERSION)
40 | self.groupBulkWrite = GroupBulkWrite(self.portHandler, self.packetHandler)
41 | self.groupBulkRead = GroupBulkRead(self.portHandler, self.packetHandler)
42 | self.serial_open()
43 | self.set_gripper_mode()
44 | self.torque_on()
45 |
46 | def cmdCallback(self, msg):
47 | self.cmd = np.array(msg.position)
48 |
49 | def run(self):
50 | rate = rospy.Rate(self.freq)
51 |
52 | rospy.logwarn("FollowerARM.run(): Starting execution")
53 | while not rospy.is_shutdown():
54 | # start_time = time.time()
55 | # set joint
56 | if self.cmd is not None:
57 | self.set_joint_radian(self.cmd)
58 |
59 | # set message
60 | self.joint_msg.header.stamp = rospy.Time.now()
61 | self.joint_msg.position = self.get_joint_radian()
62 |
63 | # publish
64 | self.joint_pub.publish(self.joint_msg)
65 | rate.sleep()
66 | # print(time.time()-start_time)
67 |
68 | rospy.logwarn("FollowerARM.run(): Finished execution")
69 | self.torque_off()
70 |
71 |
72 | def main(freq, motor_list, device, baudrate):
73 | follower_arm = FollowerARM(freq, motor_list, device, baudrate)
74 |
75 | follower_arm.run()
76 |
77 |
78 | if __name__ == "__main__":
79 | try:
80 | rospy.init_node("dxl_follower_node", anonymous=True)
81 | freq = rospy.get_param("dxl_follower_node/freq")
82 | device = rospy.get_param("dxl_follower_node/device")
83 | baudrate = rospy.get_param("dxl_follower_node/baudrate")
84 | motor_list = rospy.get_param("dxl_follower_node/motor_list")
85 | motor_list = [int(m) for m in motor_list.split(",")]
86 | main(freq, motor_list, device, baudrate)
87 | except rospy.ROSInterruptException:
88 | pass
89 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/src/interplation_node.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import time
7 | import rospy
8 | import numpy as np
9 | from enum import Enum
10 | from sensor_msgs.msg import JointState
11 |
12 |
13 | def linear_interpolation(q0, q1, T, t):
14 | """Linear interpolation between q0 and q1"""
15 | s = min(1.0, max(0.0, t / T))
16 | return (q1 - q0) * s + q0
17 |
18 |
19 | class Interpolator:
20 | def __init__(
21 | self, control_freq=50, target_freq=10, max_target_jump=20.0 / 180.0 * np.pi
22 | ):
23 | self.control_period = 1.0 / control_freq
24 | self.target_period = 1.0 / target_freq
25 | self.max_target_jump = max_target_jump
26 | self.max_target_delay = 2 * self.target_period
27 |
28 | # setup the message
29 | motor_list = [11, 12, 13, 14, 15]
30 | self.joint_msg = JointState()
31 | self.motor_names = ["motor_{}".format(m) for m in motor_list]
32 | self.joint_msg.name = self.motor_names
33 |
34 | # setup states
35 | self.has_valid_target = False
36 | self.target_update_time = None
37 | self.cmd = None
38 | self.prev_target = None
39 |
40 | # low freq target input from network
41 | self.target_sub = rospy.Subscriber(
42 | "/leader/joint_states", JointState, self.targetCallback
43 | )
44 |
45 | # setup the online effort controller client
46 | self.cmd_pub = rospy.Publisher(
47 | "/leader/interpolated_command", JointState, queue_size=1
48 | )
49 |
50 | def run(self):
51 | rate = rospy.Rate(1.0 / self.control_period)
52 |
53 | self.target_update_time = None
54 | self.target = None
55 | self.has_valid_target = False
56 |
57 | rospy.logwarn("OnlineExecutor.run(): Starting execution")
58 | while not rospy.is_shutdown():
59 | if not self.has_valid_target:
60 | # just keep the current state
61 | rospy.logwarn_throttle(
62 | 1.0, "OnlineExecutor.run(): Wait for first target"
63 | )
64 |
65 | if self.has_valid_target:
66 | # interpolation between prev_target and next target
67 | t = (rospy.Time.now() - self.target_update_time).to_sec()
68 |
69 | self.cmd = linear_interpolation(
70 | self.prev_target, self.target, self.target_period, t
71 | )
72 |
73 | # check if communicaton is broken
74 | if t > self.max_target_delay:
75 | self.has_valid_target = False
76 | rospy.logwarn(
77 | "OnlineExecutor.run(): Interpolation stopped, wait for valid command"
78 | )
79 |
80 | # send the target to robot
81 | self.command(self.cmd)
82 |
83 | rate.sleep()
84 |
85 | rospy.logwarn("OnlineExecutor.run(): Finished execution")
86 |
87 | def command(self, cmd):
88 | if not np.isnan(cmd).any():
89 | self.joint_msg.header.stamp = rospy.Time.now()
90 | self.joint_msg.position = cmd
91 | self.cmd_pub.publish(self.joint_msg)
92 |
93 | def targetCallback(self, msg):
94 | """
95 | next joint target callback from DL model
96 | stores next target, current time, previous target
97 | """
98 |
99 | # extract the state from the message
100 | target = np.array(msg.position)
101 |
102 | if self.cmd is not None:
103 | # safety first, check the last command against the new target pose
104 | if np.max(np.abs(self.cmd - target)) > self.max_target_jump:
105 | self.has_valid_target = False
106 | idx = np.argmax(np.abs(self.cmd - target))
107 | rospy.logerr_throttle(
108 | 1.0,
109 | "OnlineExecutor.targetCallback(): Jump in cmd[{:d}]={:f} to target[{:d}]={:f} > {:f}".format(
110 | idx, self.cmd[idx], idx, target[idx], self.max_target_jump
111 | ),
112 | )
113 | return
114 | else:
115 | # initialization
116 | rospy.logwarn("OnlineExecutor.run(): Recieved first data")
117 | self.cmd = target
118 |
119 | # store target and last target
120 | self.target = target
121 | self.prev_target = np.copy(self.cmd)
122 | self.target_update_time = rospy.Time.now()
123 |
124 | # target was good
125 | self.has_valid_target = True
126 |
127 |
128 | def main(control_freq, target_freq):
129 | executor = Interpolator(
130 | control_freq=control_freq,
131 | target_freq=target_freq,
132 | max_target_jump=50.0 / 180.0 * np.pi,
133 | )
134 |
135 | executor.run()
136 |
137 |
138 | if __name__ == "__main__":
139 | try:
140 | rospy.init_node("interpolator_node", anonymous=True)
141 | control_freq = rospy.get_param("interpolator_node/control_freq")
142 | target_freq = rospy.get_param("interpolator_node/target_freq")
143 | main(control_freq, target_freq)
144 | except rospy.ROSInterruptException:
145 | pass
146 |
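The interpolator's behavior follows directly from linear_interpolation: within one target period the command moves proportionally toward the target, and the clamp on s holds the final value once t exceeds T. A small numeric sketch:

import numpy as np

q0, q1 = np.zeros(5), np.ones(5)
print(linear_interpolation(q0, q1, T=0.1, t=0.05))  # halfway: all elements 0.5
print(linear_interpolation(q0, q1, T=0.1, t=0.25))  # past T, clamped: all elements 1.0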
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/src/leader_bringup.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import time
9 | import numpy as np
10 | from dynamixel_sdk import *
11 | from dynamixel_driver import *
12 |
13 | # ROS
14 | import rospy
15 | from sensor_msgs.msg import JointState
16 |
17 |
18 | class LeaderARM(DynamixelDriver):
19 | def __init__(self, freq, motor_list, device="/dev/ttyUSB1", baudrate=1000000):
20 | self.freq = freq
21 | self.motor_list = motor_list
22 | self.motor_names = ["motor_{}".format(m) for m in motor_list]
23 | self.baudrate = baudrate
24 | self.joint_pub = rospy.Publisher(
25 | "/leader/joint_states", JointState, queue_size=1
26 | )
27 |
28 | # setup the message
29 | self.joint_msg = JointState()
30 | self.joint_msg.name = self.motor_names
31 |
32 | # Initialization
33 | self.portHandler = PortHandler(device)
34 | self.packetHandler = PacketHandler(PROTOCOL_VERSION)
35 | self.groupBulkWrite = GroupBulkWrite(self.portHandler, self.packetHandler)
36 | self.groupBulkRead = GroupBulkRead(self.portHandler, self.packetHandler)
37 | self.serial_open()
38 |
39 | def run(self):
40 | rate = rospy.Rate(self.freq)
41 |
42 | rospy.logwarn("LeaderARM.run(): Starting execution")
43 | while not rospy.is_shutdown():
44 | # set message
45 | self.joint_msg.header.stamp = rospy.Time.now()
46 | self.joint_msg.position = self.get_joint_radian()
47 |
48 | # publish
49 | self.joint_pub.publish(self.joint_msg)
50 | rate.sleep()
51 |
52 | rospy.logwarn("LeaderARM.run(): Finished execution")
53 | self.torque_off()
54 |
55 |
56 | def main(freq, motor_list, device, baudrate):
57 | leader_arm = LeaderARM(freq, motor_list, device, baudrate)
58 |
59 | leader_arm.run()
60 |
61 |
62 | if __name__ == "__main__":
63 | try:
64 | rospy.init_node("dxl_leader_node", anonymous=True)
65 | freq = rospy.get_param("dxl_leader_node/freq")
66 | device = rospy.get_param("dxl_leader_node/device")
67 | baudrate = rospy.get_param("dxl_leader_node/baudrate")
68 | motor_list = rospy.get_param("dxl_leader_node/motor_list")
69 | motor_list = [int(m) for m in motor_list.split(",")]
70 | main(freq, motor_list, device, baudrate)
71 | except rospy.ROSInterruptException:
72 | pass
73 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/ros/om_teleop/src/playback.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import time
9 | import numpy as np
10 | from dynamixel_sdk import *
11 | from dynamixel_driver import *
12 |
13 | # ROS
14 | import rospy
15 | from sensor_msgs.msg import JointState
16 |
17 |
18 | def main(freq, motor_list):
19 | # initialization
20 | motor_list = motor_list
21 | motor_names = ["motor_{}".format(m) for m in motor_list]
22 |
23 | # publisher
24 | joint_pub = rospy.Publisher("/leader/joint_states", JointState, queue_size=1)
25 |
26 | # setup the message
27 | joint_msg = JointState()
28 | joint_msg.name = motor_names
29 |
30 | # load_data
31 | joint_angles = np.load("../bag2npy/data/train/joints.npy")[0]
32 | nloop = len(joint_angles)
33 | rate = rospy.Rate(freq)
34 |
35 | rospy.logwarn("Playback: Starting execution")
36 | for loop_ct in range(nloop):
37 | # set message
38 | joint_msg.header.stamp = rospy.Time.now()
39 | joint_msg.position = joint_angles[loop_ct]
40 |
41 | # publish
42 | joint_pub.publish(joint_msg)
43 | rate.sleep()
44 |
45 | rospy.logwarn("Playback: Finished execution")
46 |
47 |
48 | if __name__ == "__main__":
49 | try:
50 | rospy.init_node("dxl_playback_node", anonymous=True)
51 | freq = rospy.get_param("dxl_playback_node/freq")
52 | motor_list = rospy.get_param("dxl_playback_node/motor_list")
53 | motor_list = [int(m) for m in motor_list.split(",")]
54 | main(freq, motor_list)
55 | except rospy.ROSInterruptException:
56 | pass
57 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/sarnn/bin/export_onnx.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from os import path
4 |
5 | import torch
6 |
7 | from eipl.model import SARNN
8 |
9 | def load_config(file):
10 | with open(file, "r") as f:
11 | config = json.load(f)
12 |
13 | #! temp fix
14 | # joint_dim is 5 for open_manipulator
15 | # im_size is 64x64 for object grasp task
16 | # hard coded for now
17 | config["joint_dim"] = 5
18 | config["im_size"] = [64, 64]
19 |
20 | return config
21 |
22 |
23 | def export_onnx(weight, output, config, verbose=False):
24 | model = SARNN(
25 | rec_dim=config["rec_dim"],
26 | joint_dim=config["joint_dim"],
27 | k_dim=config["k_dim"],
28 | heatmap_size=config["heatmap_size"],
29 | temperature=config["temperature"],
30 | im_size=config["im_size"],
31 | )
32 |
33 | ckpt = torch.load(weight, map_location=torch.device("cpu"))
34 | model.load_state_dict(ckpt["model_state_dict"])
35 | model.eval()
36 |
37 | input_names = ["i.image", "i.joint", "i.state_h", "i.state_c"]
38 | output_names = ["o.image", "o.joint", "o.enc_pts", "o.dec_pts", "o.state_h", "o.state_c"]
39 | dummy_input = (
40 | torch.randn(1, 3, config["im_size"][0], config["im_size"][1]),
41 | torch.randn(1, config["joint_dim"]),
42 | tuple(torch.randn(1, config["rec_dim"]) for _ in range(2)),
43 | )
44 | torch.onnx.export(
45 | model,
46 | dummy_input,
47 | output,
48 | input_names=input_names,
49 | output_names=output_names,
50 | verbose=verbose,
51 | )
52 |
53 | if __name__ == "__main__":
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument("weight", help="PyTorch model weight file")
56 | parser.add_argument("--config", help="Configuration file", default=None)
57 | parser.add_argument("--output", help="Output ONNX file", default="model.onnx")
58 | parser.add_argument("--verbose", "-v", action="store_true")
59 | args = parser.parse_args()
60 |
61 | if args.config is not None:
62 | config = load_config(args.config)
63 | else:
64 | config_path = path.join(path.dirname(args.weight), "args.json")
65 | config = load_config(config_path)
66 |
67 | export_onnx(args.weight, args.output, config, args.verbose)
68 |
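Once exported, the graph can be exercised with onnxruntime; a minimal sketch, assuming the default model.onnx output and the shapes hard-coded above (64x64 image, 5 joints, with rec_dim=50 taken from the training default in args.json):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
outputs = sess.run(
    None,  # return all outputs: o.image, o.joint, o.enc_pts, o.dec_pts, o.state_h, o.state_c
    {
        "i.image": np.zeros((1, 3, 64, 64), dtype=np.float32),
        "i.joint": np.zeros((1, 5), dtype=np.float32),
        "i.state_h": np.zeros((1, 50), dtype=np.float32),
        "i.state_c": np.zeros((1, 50), dtype=np.float32),
    },
)
print([o.shape for o in outputs])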
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/sarnn/bin/test.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import torch
8 | import argparse
9 | import numpy as np
10 | import matplotlib.pylab as plt
11 | import matplotlib.animation as anim
12 | from eipl.data import SampleDownloader, WeightDownloader
13 | from eipl.utils import restore_args, tensor2numpy, deprocess_img, normalization
14 | from eipl.model import SARNN
15 |
16 | # argument parser
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--filename", type=str, default=None)
19 | parser.add_argument("--idx", type=int, default=0)
20 | parser.add_argument("--pretrained", action="store_true")
21 | args = parser.parse_args()
22 |
23 | # check args
24 | assert args.filename or args.pretrained, "Please set filename or pretrained"
25 |
26 | # load pretrained weight
27 | if args.pretrained:
28 | WeightDownloader("om", "grasp_cube")
29 | args.filename = os.path.join(
30 | os.path.expanduser("~"), ".eipl/om/grasp_cube/weights/SARNN/model.pth"
31 | )
32 |
33 | # restore parameters
34 | dir_name = os.path.split(args.filename)[0]
35 | params = restore_args(os.path.join(dir_name, "args.json"))
36 | idx = args.idx
37 |
38 | # load dataset
39 | minmax = [params["vmin"], params["vmax"]]
40 | grasp_data = SampleDownloader("om", "grasp_cube", img_format="HWC")
41 | _images, _joints = grasp_data.load_raw_data("test")
42 | images = _images[idx]
43 | joints = _joints[idx]
44 | joint_bounds = np.load(
45 | os.path.join(os.path.expanduser("~"), ".eipl/om/grasp_cube/joint_bounds.npy")
46 | )
47 | print(
48 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max())
49 | )
50 | print(
51 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max())
52 | )
53 |
54 | # define model
55 | model = SARNN(
56 | rec_dim=params["rec_dim"],
57 | joint_dim=5,
58 | k_dim=params["k_dim"],
59 | heatmap_size=params["heatmap_size"],
60 | temperature=params["temperature"],
61 | im_size=[64, 64],
62 | )
63 |
64 | if params["compile"]:
65 | model = torch.compile(model)
66 |
67 | # load weight
68 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
69 | model.load_state_dict(ckpt["model_state_dict"])
70 | model.eval()
71 |
72 | # Inference
73 | img_size = 64
74 | image_list, joint_list = [], []
75 | enc_pts_list, dec_pts_list = [], []
76 | state = None
77 | nloop = len(images)
78 | for loop_ct in range(nloop):
79 | # load data and normalization
80 | img_t = images[loop_ct].transpose(2, 0, 1)
81 | img_t = torch.Tensor(np.expand_dims(img_t, 0))
82 | img_t = normalization(img_t, (0, 255), minmax)
83 | joint_t = torch.Tensor(np.expand_dims(joints[loop_ct], 0))
84 | joint_t = normalization(joint_t, joint_bounds, minmax)
85 |
86 | # predict rnn
87 | y_image, y_joint, enc_pts, dec_pts, state = model(img_t, joint_t, state)
88 |
89 | # denormalization
90 | pred_image = tensor2numpy(y_image[0])
91 | pred_image = deprocess_img(pred_image, params["vmin"], params["vmax"])
92 | pred_image = pred_image.transpose(1, 2, 0)
93 | pred_joint = tensor2numpy(y_joint[0])
94 | pred_joint = normalization(pred_joint, minmax, joint_bounds)
95 |
96 | # append data
97 | image_list.append(pred_image)
98 | joint_list.append(pred_joint)
99 | enc_pts_list.append(tensor2numpy(enc_pts[0]))
100 | dec_pts_list.append(tensor2numpy(dec_pts[0]))
101 |
102 | print("loop_ct:{}, joint:{}".format(loop_ct, pred_joint))
103 |
104 | pred_image = np.array(image_list)
105 | pred_joint = np.array(joint_list)
106 |
107 | # split key points
108 | enc_pts = np.array(enc_pts_list)
109 | dec_pts = np.array(dec_pts_list)
110 | enc_pts = enc_pts.reshape(-1, params["k_dim"], 2) * img_size
111 | dec_pts = dec_pts.reshape(-1, params["k_dim"], 2) * img_size
112 | enc_pts = np.clip(enc_pts, 0, img_size)
113 | dec_pts = np.clip(dec_pts, 0, img_size)
114 |
115 |
116 | # plot images
117 | T = len(images)
118 | fig, ax = plt.subplots(1, 3, figsize=(12, 5), dpi=60)
119 |
120 |
121 | def anim_update(i):
122 | for j in range(3):
123 | ax[j].cla()
124 |
125 | # plot camera image
126 | ax[0].imshow(images[i, :, :, ::-1])
127 | for j in range(params["k_dim"]):
128 | ax[0].plot(enc_pts[i, j, 0], enc_pts[i, j, 1], "bo", markersize=6) # encoder
129 | ax[0].plot(
130 | dec_pts[i, j, 0], dec_pts[i, j, 1], "rx", markersize=6, markeredgewidth=2
131 | ) # decoder
132 | ax[0].axis("off")
133 | ax[0].set_title("Input image")
134 |
135 | # plot predicted image
136 | ax[1].imshow(pred_image[i, :, :, ::-1])
137 | ax[1].axis("off")
138 | ax[1].set_title("Predicted image")
139 |
140 | # plot joint angle
141 | ax[2].set_ylim(-1.0, 2.0)
142 | ax[2].set_xlim(0, T)
143 | ax[2].plot(joints[1:], linestyle="dashed", c="k")
144 | # om has 5 joints, not 8
145 | for joint_idx in range(5):
146 | ax[2].plot(np.arange(i + 1), pred_joint[: i + 1, joint_idx])
147 | ax[2].set_xlabel("Step")
148 | ax[2].set_title("Joint angles")
149 |
150 |
151 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
152 | ani.save("./output/SARNN_{}_{}.gif".format(params["tag"], idx))
153 |
154 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg).
155 | # ani.save("./output/SARNN_{}_{}_{}.gif".format(params["tag"], idx, args.input_param), writer="ffmpeg")
156 |
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/sarnn/bin/test_pca_sarnn.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import torch
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pylab as plt
12 | import matplotlib.animation as anim
13 | from sklearn.decomposition import PCA
14 | from eipl.data import SampleDownloader
15 | from eipl.model import SARNN
16 | from eipl.utils import restore_args, tensor2numpy, normalization
17 |
18 |
19 | # argument parser
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("filename", type=str, default=None)
22 | args = parser.parse_args()
23 |
24 | # restore parameters
25 | dir_name = os.path.split(args.filename)[0]
26 | params = restore_args(os.path.join(dir_name, "args.json"))
27 | # idx = args.idx
28 |
29 | # load dataset
30 | minmax = [params["vmin"], params["vmax"]]
31 | grasp_data = SampleDownloader("om", "grasp_cube", img_format="HWC")
32 | images, joints = grasp_data.load_raw_data("test")
33 | joint_bounds = grasp_data.joint_bounds
34 | print(
35 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max())
36 | )
37 | print(
38 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max())
39 | )
40 |
41 | # define model
42 | model = SARNN(
43 | rec_dim=params["rec_dim"],
44 | joint_dim=5,
45 | k_dim=params["k_dim"],
46 | heatmap_size=params["heatmap_size"],
47 | temperature=params["temperature"],
48 | im_size=[64, 64],
49 | )
50 |
51 | if params["compile"]:
52 | model = torch.compile(model)
53 |
54 | # load weight
55 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
56 | model.load_state_dict(ckpt["model_state_dict"])
57 | model.eval()
58 |
59 | # Inference
60 | states = []
61 | state = None
62 | nloop = images.shape[1]
63 | for loop_ct in range(nloop):
64 | # load data and normalization
65 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2)
66 | img_t = torch.Tensor(img_t)
67 | img_t = normalization(img_t, (0, 255), minmax)
68 | joint_t = torch.Tensor(joints[:, loop_ct])
69 | joint_t = normalization(joint_t, joint_bounds, minmax)
70 |
71 | # predict rnn
72 | _, _, _, _, state = model(img_t, joint_t, state)
73 | states.append(state[0])
74 |
75 | states = torch.permute(torch.stack(states), (1, 0, 2))
76 | states = tensor2numpy(states)
77 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN.
78 | # N is the number of datasets
79 | # T is the sequence length
80 | # D is the dimension of the hidden state
81 | N, T, D = states.shape
82 | states = states.reshape(-1, D)
83 |
84 | # PCA
85 | loop_ct = float(360) / T
86 | pca_dim = 3
87 | pca = PCA(n_components=pca_dim).fit(states)
88 | pca_val = pca.transform(states)
89 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to
90 | # visualize each state as a 3D scatter.
91 | pca_val = pca_val.reshape(N, T, pca_dim)
92 |
93 | # plot images
94 | fig = plt.figure(dpi=60)
95 | ax = fig.add_subplot(projection="3d")
96 |
97 |
98 | def anim_update(i):
99 | ax.cla()
100 | angle = int(loop_ct * i)
101 | ax.view_init(30, angle)
102 |
103 | c_list = ["C0", "C1", "C2", "C3", "C4"]
104 | for n, color in enumerate(c_list):
105 | ax.scatter(
106 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
107 | )
108 |
109 | ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)
110 | pca_ratio = pca.explained_variance_ratio_ * 100
111 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
112 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
113 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
114 |
115 |
116 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
117 | ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]))
118 |
119 | # If an error occurs in generating the gif or mp4 animation, change the writer (imagemagick/ffmpeg).
120 | # ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]), writer="imagemagick")
121 | # ani.save("./output/PCA_SARNN_{}.mp4".format(params["tag"]), writer="ffmpeg")
122 |
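
A note on the PCA block above: hidden states are flattened across sequences and time before fitting, then reshaped back so each sequence becomes one trajectory in PC space. A minimal, self-contained sketch of that round trip on synthetic data (no checkpoint or robot data needed):

    import numpy as np
    from sklearn.decomposition import PCA

    N, T, D = 5, 100, 50                  # sequences, time steps, hidden dimension
    states = np.random.randn(N, T, D)     # stand-in for the collected RNN hidden states

    flat = states.reshape(-1, D)          # [N*T, D]: every time step is one PCA sample
    pca = PCA(n_components=3).fit(flat)
    pca_val = pca.transform(flat).reshape(N, T, 3)  # one 3D trajectory per sequence

    print(pca_val.shape)                  # (5, 100, 3)
    print(pca.explained_variance_ratio_)  # fraction of variance along PC1..PC3

Fitting on the flattened array is what makes the trajectories comparable: all sequences share one set of principal axes.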
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/sarnn/bin/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import numpy as np
9 | import torch
10 | import argparse
11 | from tqdm import tqdm
12 | import torch.optim as optim
13 | from collections import OrderedDict
14 | from torch.utils.tensorboard import SummaryWriter
15 | from eipl.model import SARNN
16 | from eipl.data import MultimodalDataset, SampleDownloader
17 | from eipl.utils import EarlyStopping, check_args, set_logdir
18 |
19 | # load own library
20 | sys.path.append("./libs/")
21 | from fullBPTT import fullBPTTtrainer
22 |
23 | # argument parser
24 | parser = argparse.ArgumentParser(
25 | description="Learning spatial autoencoder with recurrent neural network"
26 | )
27 | parser.add_argument("--model", type=str, default="sarnn")
28 | parser.add_argument("--epoch", type=int, default=10000)
29 | parser.add_argument("--batch_size", type=int, default=5)
30 | parser.add_argument("--rec_dim", type=int, default=50)
31 | parser.add_argument("--k_dim", type=int, default=5)
32 | parser.add_argument("--img_loss", type=float, default=0.1)
33 | parser.add_argument("--joint_loss", type=float, default=1.0)
34 | parser.add_argument("--pt_loss", type=float, default=0.1)
35 | parser.add_argument("--heatmap_size", type=float, default=0.1)
36 | parser.add_argument("--temperature", type=float, default=1e-4)
37 | parser.add_argument("--stdev", type=float, default=0.1)
38 | parser.add_argument("--log_dir", default="log/")
39 | parser.add_argument("--vmin", type=float, default=0.0)
40 | parser.add_argument("--vmax", type=float, default=1.0)
41 | parser.add_argument("--device", type=int, default=0)
42 | parser.add_argument("--compile", action="store_true")
43 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
44 | args = parser.parse_args()
45 |
46 | # check args
47 | args = check_args(args)
48 |
49 | # calculate the noise level (standard deviation) from the normalized range
50 | stdev = args.stdev * (args.vmax - args.vmin)
51 |
52 | # set device id
53 | if args.device >= 0:
54 | device = "cuda:{}".format(args.device)
55 | else:
56 | device = "cpu"
57 |
58 | # load dataset
59 | minmax = [args.vmin, args.vmax]
60 | grasp_data = SampleDownloader("om", "grasp_cube", img_format="CHW")
61 | images, joints = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax)
62 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev)
63 | train_loader = torch.utils.data.DataLoader(
64 | train_dataset,
65 | batch_size=args.batch_size,
66 | shuffle=True,
67 | drop_last=False,
68 | )
69 |
70 | images, joints = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax)
71 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None)
72 | test_loader = torch.utils.data.DataLoader(
73 | test_dataset,
74 | batch_size=args.batch_size,
75 | shuffle=True,
76 | drop_last=False,
77 | )
78 |
79 | # define model
80 | model = SARNN(
81 | rec_dim=args.rec_dim,
82 | joint_dim=5,
83 | k_dim=args.k_dim,
84 | heatmap_size=args.heatmap_size,
85 | temperature=args.temperature,
86 | im_size=[64, 64],
87 | )
88 |
89 | # torch.compile makes PyTorch code run faster
90 | if args.compile:
91 | torch.set_float32_matmul_precision("high")
92 | model = torch.compile(model)
93 |
94 | # set optimizer
95 | optimizer = optim.Adam(model.parameters(), eps=1e-07)
96 |
97 | # load trainer/tester class
98 | loss_weights = [args.img_loss, args.joint_loss, args.pt_loss]
99 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device)
100 |
101 | # training main
102 | log_dir_path = set_logdir("./" + args.log_dir, args.tag)
103 | save_name = os.path.join(log_dir_path, "SARNN.pth")
104 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30)
105 | early_stop = EarlyStopping(patience=1000)
106 |
107 | with tqdm(range(args.epoch)) as pbar_epoch:
108 | for epoch in pbar_epoch:
109 | # train and test
110 | train_loss = trainer.process_epoch(train_loader)
111 | with torch.no_grad():
112 | test_loss = trainer.process_epoch(test_loader, training=False)
113 | writer.add_scalar("Loss/train_loss", train_loss, epoch)
114 | writer.add_scalar("Loss/test_loss", test_loss, epoch)
115 |
116 | # early stop
117 | save_ckpt, _ = early_stop(test_loss)
118 |
119 | if save_ckpt:
120 | trainer.save(epoch, [train_loss, test_loss], save_name)
121 |
122 | # print process bar
123 |         pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss))
124 |         # iterating over tqdm already advances the bar, so no manual update() is needed
125 |
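
The loop above saves a checkpoint only when EarlyStopping reports an improvement; note that the stop flag is discarded (save_ckpt, _ = ...), so training always runs the full epoch count. A small sketch of the EarlyStopping contract on synthetic losses:

    from eipl.utils import EarlyStopping

    early_stop = EarlyStopping(patience=3)
    for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.83, 0.82]):
        save_ckpt, stop_flag = early_stop(loss)
        # save_ckpt is True whenever `loss` improves on the best value so far;
        # stop_flag turns True after `patience` consecutive epochs without improvement.
        print(epoch, save_ckpt, stop_flag)
        if stop_flag:
            break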
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/sarnn/libs/fullBPTT.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from eipl.utils import LossScheduler
9 |
10 |
11 | class fullBPTTtrainer:
12 | """
13 | Helper class to train recurrent neural networks with numpy sequences
14 |
15 | Args:
16 | traindata (np.array): list of np.array. First diemension should be time steps
17 | model (torch.nn.Module): rnn model
18 | optimizer (torch.optim): optimizer
19 | input_param (float): input parameter of sequential generation. 1.0 means open mode.
20 | """
21 |
22 |     def __init__(self, model, optimizer, loss_weights=[1.0, 1.0, 1.0], device="cpu"):
23 | self.device = device
24 | self.optimizer = optimizer
25 | self.loss_weights = loss_weights
26 | self.scheduler = LossScheduler(decay_end=1000, curve_name="s")
27 | self.model = model.to(self.device)
28 |
29 | def save(self, epoch, loss, savename):
30 | torch.save(
31 | {
32 | "epoch": epoch,
33 | "model_state_dict": self.model.state_dict(),
34 | # 'optimizer_state_dict': self.optimizer.state_dict(),
35 | "train_loss": loss[0],
36 | "test_loss": loss[1],
37 | },
38 | savename,
39 | )
40 |
41 | def process_epoch(self, data, training=True):
42 | if not training:
43 | self.model.eval()
44 | else:
45 | self.model.train()
46 |
47 | total_loss = 0.0
48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data):
49 | if "cpu" in self.device:
50 | x_img = x_img.to(self.device)
51 | y_img = y_img.to(self.device)
52 | x_joint = x_joint.to(self.device)
53 | y_joint = y_joint.to(self.device)
54 |
55 | state = None
56 | yi_list, yv_list = [], []
57 | dec_pts_list, enc_pts_list = [], []
58 | self.optimizer.zero_grad(set_to_none=True)
59 | for t in range(x_img.shape[1] - 1):
60 | _yi_hat, _yv_hat, enc_ij, dec_ij, state = self.model(
61 | x_img[:, t], x_joint[:, t], state
62 | )
63 | yi_list.append(_yi_hat)
64 | yv_list.append(_yv_hat)
65 | enc_pts_list.append(enc_ij)
66 | dec_pts_list.append(dec_ij)
67 |
68 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4))
69 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2))
70 |
71 | img_loss = nn.MSELoss()(yi_hat, y_img[:, 1:]) * self.loss_weights[0]
72 | joint_loss = nn.MSELoss()(yv_hat, y_joint[:, 1:]) * self.loss_weights[1]
73 |             # Gradually ramp up the attention-point loss weight using the LossScheduler class.
74 | pt_loss = nn.MSELoss()(
75 | torch.stack(dec_pts_list[:-1]), torch.stack(enc_pts_list[1:])
76 | ) * self.scheduler(self.loss_weights[2])
77 | loss = img_loss + joint_loss + pt_loss
78 | total_loss += loss.item()
79 |
80 | if training:
81 | loss.backward()
82 | self.optimizer.step()
83 |
84 | return total_loss / (n_batch + 1)
85 |
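
The point loss above pairs the attention points decoded at step t with those encoded from the frame at t+1 (dec_pts_list[:-1] against enc_pts_list[1:]), so the decoder is pushed to anticipate where the keypoints move next. A toy illustration of that one-step alignment:

    import torch

    T, B, K2 = 4, 2, 10  # time steps, batch size, flattened (x, y) keypoints
    enc = torch.arange(T, dtype=torch.float32).view(T, 1, 1).expand(T, B, K2)
    dec = enc.clone()    # a decoder that merely copies the current points

    # dec at step t is matched with enc at step t+1: (dec_0, enc_1), (dec_1, enc_2), ...
    pt_loss = torch.nn.MSELoss()(dec[:-1], enc[1:])
    print(pt_loss)       # tensor(1.): copying the present costs one full step of motion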
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/sarnn/log/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/open_manipulator/sarnn/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/sarnn/bin/test.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import argparse
7 | import os
8 |
9 | import matplotlib.animation as anim
10 | import matplotlib.pylab as plt
11 | import numpy as np
12 | import torch
13 |
14 | from eipl.model import SARNN
15 | from eipl.utils import deprocess_img, normalization, restore_args, tensor2numpy
16 |
17 | # argument parser
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--filename", type=str, default=None)
20 | parser.add_argument("--idx", type=int, default=0)
21 | args = parser.parse_args()
22 |
23 | # restore parameters
24 | dir_name = os.path.split(args.filename)[0]
25 | params = restore_args(os.path.join(dir_name, "args.json"))
26 | idx = args.idx
27 |
28 | # load dataset
29 | minmax = [params["vmin"], params["vmax"]]
30 | images_raw = np.load("../simulator/data/test/images.npy")
31 | joints_raw = np.load("../simulator/data/test/joints.npy")
32 | joint_bounds = np.load("../simulator/data/joint_bounds.npy")
33 | images = images_raw[idx]
34 | joints = joints_raw[idx]
35 |
36 | # define model
37 | model = SARNN(
38 | rec_dim=params["rec_dim"],
39 | joint_dim=8,
40 | k_dim=params["k_dim"],
41 | heatmap_size=params["heatmap_size"],
42 | temperature=params["temperature"],
43 | im_size=[64, 64],
44 | )
45 |
46 | if params["compile"]:
47 | model = torch.compile(model)
48 |
49 | # load weight
50 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
51 | model.load_state_dict(ckpt["model_state_dict"])
52 | model.eval()
53 |
54 | # Inference
55 | im_size = 64
56 | image_list, joint_list = [], []
57 | enc_pts_list, dec_pts_list = [], []
58 | state = None
59 | nloop = len(images)
60 | for loop_ct in range(nloop):
61 | # load data and normalization
62 | img_t = images[loop_ct].transpose(2, 0, 1)
63 | img_t = normalization(img_t, (0, 255), minmax)
64 | img_t = torch.Tensor(np.expand_dims(img_t, 0))
65 | joint_t = normalization(joints[loop_ct], joint_bounds, minmax)
66 | joint_t = torch.Tensor(np.expand_dims(joint_t, 0))
67 |
68 | # predict rnn
69 | y_image, y_joint, enc_pts, dec_pts, state = model(img_t, joint_t, state)
70 |
71 | # denormalization
72 | pred_image = tensor2numpy(y_image[0])
73 | pred_image = deprocess_img(pred_image, params["vmin"], params["vmax"])
74 | pred_image = pred_image.transpose(1, 2, 0)
75 | pred_joint = tensor2numpy(y_joint[0])
76 | pred_joint = normalization(pred_joint, minmax, joint_bounds)
77 |
78 | # append data
79 | image_list.append(pred_image)
80 | joint_list.append(pred_joint)
81 | enc_pts_list.append(tensor2numpy(enc_pts[0]))
82 | dec_pts_list.append(tensor2numpy(dec_pts[0]))
83 |
84 | print("loop_ct:{}, joint:{}".format(loop_ct, pred_joint))
85 |
86 | pred_image = np.array(image_list)
87 | pred_joint = np.array(joint_list)
88 |
89 | # split key points
90 | enc_pts = np.array(enc_pts_list)
91 | dec_pts = np.array(dec_pts_list)
92 | enc_pts = enc_pts.reshape(-1, params["k_dim"], 2) * im_size
93 | dec_pts = dec_pts.reshape(-1, params["k_dim"], 2) * im_size
94 | enc_pts = np.clip(enc_pts, 0, im_size)
95 | dec_pts = np.clip(dec_pts, 0, im_size)
96 |
97 |
98 | # plot images
99 | T = len(images)
100 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60)
101 |
102 |
103 | def anim_update(i):
104 | for j in range(3):
105 | ax[j].cla()
106 |
107 | # plot camera image
108 | ax[0].imshow(images[i])
109 | for j in range(params["k_dim"]):
110 | ax[0].plot(enc_pts[i, j, 0], enc_pts[i, j, 1], "co", markersize=12) # encoder
111 | ax[0].plot(
112 | dec_pts[i, j, 0], dec_pts[i, j, 1], "rx", markersize=12, markeredgewidth=2
113 | ) # decoder
114 | ax[0].axis("off")
115 | ax[0].set_title("Input image", fontsize=20)
116 |
117 | # plot predicted image
118 | ax[1].imshow(pred_image[i])
119 | ax[1].axis("off")
120 | ax[1].set_title("Predicted image", fontsize=20)
121 |
122 | # plot joint angle
123 | ax[2].set_ylim(-np.pi, 3.4)
124 | ax[2].set_xlim(0, T)
125 | ax[2].plot(joints[1:], linestyle="dashed", c="k")
126 |     # 8 joint dimensions: 7 Panda arm joints + 1 gripper
127 | for joint_idx in range(8):
128 | ax[2].plot(np.arange(i + 1), pred_joint[: i + 1, joint_idx])
129 | ax[2].set_xlabel("Step", fontsize=20)
130 | ax[2].set_title("Joint angles", fontsize=20)
131 | ax[2].tick_params(axis="x", labelsize=16)
132 | ax[2].tick_params(axis="y", labelsize=16)
133 | plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9)
134 |
135 |
136 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
137 | ani.save("./output/SARNN_{}_{}.gif".format(params["tag"], idx))
138 |
139 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg).
140 | # ani.save("./output/SARNN_{}_{}_{}.gif".format(params["tag"], idx, args.input_param), writer="ffmpeg")
141 |
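
The denormalization above reuses normalization with its ranges swapped: joints were mapped from joint_bounds into minmax for the network, so predictions are mapped back from minmax into joint_bounds. A sketch of the round trip, assuming (as its usage here implies) that eipl.utils.normalization is a linear map from the input range onto the output range:

    import numpy as np
    from eipl.utils import normalization

    joint_bounds = np.array([0.0, 3.14])  # toy bounds; the script loads these from .npy
    minmax = (0.1, 0.9)

    joint = np.array([0.0, 1.57, 3.14])
    norm = normalization(joint, joint_bounds, minmax)  # bounds -> [0.1, 0.9]
    back = normalization(norm, minmax, joint_bounds)   # [0.1, 0.9] -> bounds
    print(norm)                      # approximately [0.1, 0.5, 0.9]
    print(np.allclose(back, joint))  # True for a linear range map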
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/sarnn/bin/test_pca_sarnn.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import torch
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pylab as plt
12 | import matplotlib.animation as anim
13 | from sklearn.decomposition import PCA
14 | from eipl.model import SARNN
15 | from eipl.utils import restore_args, tensor2numpy, normalization
16 |
17 |
18 | # argument parser
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("filename", type=str, default=None)
21 | args = parser.parse_args()
22 |
23 | # restore parameters
24 | dir_name = os.path.split(args.filename)[0]
25 | params = restore_args(os.path.join(dir_name, "args.json"))
26 | # idx = args.idx
27 |
28 | # load dataset
29 | minmax = [params["vmin"], params["vmax"]]
30 | images = np.load("../simulator/data/test/images.npy")
31 | joints = np.load("../simulator/data/test/joints.npy")
32 | joint_bounds = np.load("../simulator/data/joint_bounds.npy")
33 |
34 | # define model
35 | model = SARNN(
36 | rec_dim=params["rec_dim"],
37 | joint_dim=8,
38 | k_dim=params["k_dim"],
39 | heatmap_size=params["heatmap_size"],
40 | temperature=params["temperature"],
41 | im_size=[64, 64],
42 | )
43 |
44 | if params["compile"]:
45 | model = torch.compile(model)
46 |
47 | # load weight
48 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
49 | model.load_state_dict(ckpt["model_state_dict"])
50 | model.eval()
51 |
52 | # Inference
53 | states = []
54 | state = None
55 | nloop = images.shape[1]
56 | for loop_ct in range(nloop):
57 | # load data and normalization
58 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2)
59 | img_t = normalization(img_t, (0, 255), minmax)
60 | img_t = torch.Tensor(img_t)
61 | joint_t = normalization(joints[:, loop_ct], joint_bounds, minmax)
62 | joint_t = torch.Tensor(joint_t)
63 |
64 | # predict rnn
65 | _, _, _, _, state = model(img_t, joint_t, state)
66 |     states.append(state[0])  # keep the hidden state (first element of the RNN state)
67 |
68 | states = torch.permute(torch.stack(states), (1, 0, 2))
69 | states = tensor2numpy(states)
70 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN.
71 | # N is the number of datasets
72 | # T is the sequence length
73 | # D is the dimension of the hidden state
74 | N, T, D = states.shape
75 | states = states.reshape(-1, D)
76 |
77 | # PCA
78 | angle_step = 360.0 / T  # degrees of camera rotation per animation frame
79 | pca_dim = 3
80 | pca = PCA(n_components=pca_dim).fit(states)
81 | pca_val = pca.transform(states)
82 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to
83 | # visualize each state as a 3D scatter.
84 | pca_val = pca_val.reshape(N, T, pca_dim)
85 |
86 | # plot images
87 | fig = plt.figure(dpi=120)
88 | ax = fig.add_subplot(projection="3d")
89 |
90 |
91 | def anim_update(i):
92 | ax.cla()
93 |     angle = int(angle_step * i)
94 | ax.view_init(30, angle)
95 |
96 | c_list = ["C0", "C1", "C2", "C3", "C4", "C5", "C6", "C7", "C8"]
97 | for n, color in enumerate(c_list):
98 | ax.scatter(
99 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
100 | )
101 |
102 |         ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)  # mark each start point
103 | pca_ratio = pca.explained_variance_ratio_ * 100
104 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
105 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
106 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
107 | ax.tick_params(axis="x", labelsize=8)
108 | ax.tick_params(axis="y", labelsize=8)
109 | ax.tick_params(axis="z", labelsize=8)
110 |
111 |
112 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
113 | ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]))
114 |
115 | # If an error occurs in generating the gif or mp4 animation, change the writer (imagemagick/ffmpeg).
116 | # ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]), writer="imagemagick")
117 | # ani.save("./output/PCA_SARNN_{}.mp4".format(params["tag"]), writer="ffmpeg")
118 |
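
The commented-out writer hints above come up often enough to deserve a concrete pattern. A standalone sketch that saves a trivial animation and falls back from ffmpeg (an external binary) to pillow (bundled with matplotlib) when the first writer is unavailable:

    import matplotlib

    matplotlib.use("Agg")  # headless backend, suitable for batch scripts
    import matplotlib.pylab as plt
    import matplotlib.animation as anim

    fig, ax = plt.subplots()
    ani = anim.FuncAnimation(fig, lambda i: ax.plot([0, i], [0, 1]), frames=5)
    try:
        ani.save("demo.gif", writer="ffmpeg")  # needs the ffmpeg executable on PATH
    except Exception:
        ani.save("demo.gif", writer="pillow")  # pure-Python fallback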
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/sarnn/bin/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import numpy as np
9 | import torch
10 | import argparse
11 | from tqdm import tqdm
12 | import torch.optim as optim
13 | from collections import OrderedDict
14 | from torch.utils.tensorboard import SummaryWriter
15 | from eipl.model import SARNN
16 | from eipl.data import MultimodalDataset
17 | from eipl.utils import EarlyStopping, check_args, set_logdir, normalization
18 |
19 | # load own library
20 | sys.path.append("./libs/")
21 | from fullBPTT import fullBPTTtrainer
22 |
23 | # argument parser
24 | parser = argparse.ArgumentParser(
25 | description="Learning spatial autoencoder with recurrent neural network"
26 | )
27 | parser.add_argument("--model", type=str, default="sarnn")
28 | parser.add_argument("--epoch", type=int, default=10000)
29 | parser.add_argument("--batch_size", type=int, default=5)
30 | parser.add_argument("--rec_dim", type=int, default=50)
31 | parser.add_argument("--k_dim", type=int, default=5)
32 | parser.add_argument("--img_loss", type=float, default=0.1)
33 | parser.add_argument("--joint_loss", type=float, default=1.0)
34 | parser.add_argument("--pt_loss", type=float, default=0.1)
35 | parser.add_argument("--heatmap_size", type=float, default=0.1)
36 | parser.add_argument("--temperature", type=float, default=1e-4)
37 | parser.add_argument("--stdev", type=float, default=0.1)
38 | parser.add_argument("--lr", type=float, default=1e-4)
39 | parser.add_argument("--optimizer", type=str, default="adam")
40 | parser.add_argument("--log_dir", default="log/")
41 | parser.add_argument("--vmin", type=float, default=0.0)
42 | parser.add_argument("--vmax", type=float, default=1.0)
43 | parser.add_argument("--device", type=int, default=0)
44 | parser.add_argument("--compile", action="store_true")
45 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
46 | args = parser.parse_args()
47 |
48 | # check args
49 | args = check_args(args)
50 |
51 | # calculate the noise level (standard deviation) from the normalized range
52 | stdev = args.stdev * (args.vmax - args.vmin)
53 |
54 | # set device id
55 | if args.device >= 0:
56 | device = "cuda:{}".format(args.device)
57 | else:
58 | device = "cpu"
59 |
60 | # load dataset
61 | minmax = [args.vmin, args.vmax]
62 | images_raw = np.load("../simulator/data/train/images.npy")
63 | joints_raw = np.load("../simulator/data/train/joints.npy")
64 | joint_bounds = np.load("../simulator/data/joint_bounds.npy")
65 | images = normalization(images_raw.transpose(0, 1, 4, 2, 3), (0, 255), minmax)
66 | joints = normalization(joints_raw, joint_bounds, minmax)
67 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev)
68 | train_loader = torch.utils.data.DataLoader(
69 | train_dataset,
70 | batch_size=args.batch_size,
71 | shuffle=True,
72 | drop_last=False,
73 | )
74 |
75 | images_raw = np.load("../simulator/data/test/images.npy")
76 | joints_raw = np.load("../simulator/data/test/joints.npy")
77 | images = normalization(images_raw.transpose(0, 1, 4, 2, 3), (0, 255), minmax)
78 | joints = normalization(joints_raw, joint_bounds, minmax)
79 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None)
80 | test_loader = torch.utils.data.DataLoader(
81 | test_dataset,
82 | batch_size=args.batch_size,
83 | shuffle=True,
84 | drop_last=False,
85 | )
86 |
87 | # define model
88 | model = SARNN(
89 | rec_dim=args.rec_dim,
90 | joint_dim=8,
91 | k_dim=args.k_dim,
92 | heatmap_size=args.heatmap_size,
93 | temperature=args.temperature,
94 | im_size=[64, 64],
95 | )
96 |
97 | # torch.compile makes PyTorch code run faster
98 | if args.compile:
99 | torch.set_float32_matmul_precision("high")
100 | model = torch.compile(model)
101 |
102 | # set optimizer
103 | optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-07)
104 |
105 | # load trainer/tester class
106 | loss_weights = [args.img_loss, args.joint_loss, args.pt_loss]
107 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device)
108 |
109 | # training main
110 | log_dir_path = set_logdir("./" + args.log_dir, args.tag)
111 | save_name = os.path.join(log_dir_path, "SARNN.pth")
112 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30)
113 | early_stop = EarlyStopping(patience=1000)
114 |
115 | with tqdm(range(args.epoch)) as pbar_epoch:
116 | for epoch in pbar_epoch:
117 | # train and test
118 | train_loss = trainer.process_epoch(train_loader)
119 | with torch.no_grad():
120 | test_loss = trainer.process_epoch(test_loader, training=False)
121 | writer.add_scalar("Loss/train_loss", train_loss, epoch)
122 | writer.add_scalar("Loss/test_loss", test_loss, epoch)
123 |
124 | # early stop
125 | save_ckpt, _ = early_stop(test_loss)
126 |
127 | if save_ckpt:
128 | trainer.save(epoch, [train_loss, test_loss], save_name)
129 |
130 | # print process bar
131 |         pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss))
132 |         # iterating over tqdm already advances the bar, so no manual update() is needed
133 |
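
The pre-processing above turns raw channel-last uint8 frames into channel-first floats in [vmin, vmax]. A shape-level sketch on synthetic data, with the linear range map written out inline (eipl.utils.normalization is assumed to do the equivalent):

    import numpy as np

    N, T, H, W, C = 2, 3, 64, 64, 3
    images_raw = np.random.randint(0, 256, (N, T, H, W, C), dtype=np.uint8)

    vmin, vmax = 0.0, 1.0
    images = images_raw.transpose(0, 1, 4, 2, 3).astype(np.float32)  # -> [N, T, C, H, W]
    images = images / 255.0 * (vmax - vmin) + vmin  # [0, 255] -> [vmin, vmax]
    print(images.shape, float(images.min()), float(images.max()))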
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/sarnn/libs/fullBPTT.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from eipl.utils import LossScheduler
9 |
10 |
11 | class fullBPTTtrainer:
12 | """
13 | Helper class to train recurrent neural networks with numpy sequences
14 |
15 | Args:
16 | traindata (np.array): list of np.array. First diemension should be time steps
17 | model (torch.nn.Module): rnn model
18 | optimizer (torch.optim): optimizer
19 | input_param (float): input parameter of sequential generation. 1.0 means open mode.
20 | """
21 |
22 |     def __init__(self, model, optimizer, loss_weights=[1.0, 1.0, 1.0], device="cpu"):
23 | self.device = device
24 | self.optimizer = optimizer
25 | self.loss_weights = loss_weights
26 | self.scheduler = LossScheduler(decay_end=1000, curve_name="s")
27 | self.model = model.to(self.device)
28 |
29 | def save(self, epoch, loss, savename):
30 | torch.save(
31 | {
32 | "epoch": epoch,
33 | "model_state_dict": self.model.state_dict(),
34 | #'optimizer_state_dict': self.optimizer.state_dict(),
35 | "train_loss": loss[0],
36 | "test_loss": loss[1],
37 | },
38 | savename,
39 | )
40 |
41 | def process_epoch(self, data, training=True):
42 | if not training:
43 | self.model.eval()
44 | else:
45 | self.model.train()
46 |
47 | total_loss = 0.0
48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data):
49 | if "cpu" in self.device:
50 | x_img = x_img.to(self.device)
51 | y_img = y_img.to(self.device)
52 | x_joint = x_joint.to(self.device)
53 | y_joint = y_joint.to(self.device)
54 |
55 | state = None
56 | yi_list, yv_list = [], []
57 | dec_pts_list, enc_pts_list = [], []
58 | self.optimizer.zero_grad(set_to_none=True)
59 | for t in range(x_img.shape[1] - 1):
60 | _yi_hat, _yv_hat, enc_ij, dec_ij, state = self.model(
61 | x_img[:, t], x_joint[:, t], state
62 | )
63 | yi_list.append(_yi_hat)
64 | yv_list.append(_yv_hat)
65 | enc_pts_list.append(enc_ij)
66 | dec_pts_list.append(dec_ij)
67 |
68 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4))
69 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2))
70 |
71 | img_loss = nn.MSELoss()(yi_hat, y_img[:, 1:]) * self.loss_weights[0]
72 | joint_loss = nn.MSELoss()(yv_hat, y_joint[:, 1:]) * self.loss_weights[1]
73 |             # Gradually ramp up the attention-point loss weight using the LossScheduler class.
74 | pt_loss = nn.MSELoss()(
75 | torch.stack(dec_pts_list[:-1]), torch.stack(enc_pts_list[1:])
76 | ) * self.scheduler(self.loss_weights[2])
77 | loss = img_loss + joint_loss + pt_loss
78 | total_loss += loss.item()
79 |
80 | if training:
81 | loss.backward()
82 | self.optimizer.step()
83 |
84 | return total_loss / (n_batch + 1)
85 |
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/sarnn/log/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/sarnn/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/2_resave.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dir in `ls ./data/raw_data/`; do
4 | echo "python3 ./bin/2_resave.py --ep_dir ./data/raw_data/${dir}"
5 | python3 ./bin/2_resave.py --ep_dir ./data/raw_data/${dir}
6 | done
7 |
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/3_check_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dir in `ls ./data/raw_data/`; do
4 | echo "python3 ./bin/3_check_playback_data.py ./data/raw_data/${dir}/state_resave.npz"
5 | python3 ./bin/3_check_playback_data.py ./data/raw_data/${dir}/state_resave.npz
6 | done
7 |
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/bin/2_resave.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import sys
8 | import argparse
9 | import numpy as np
10 | import robosuite as suite
11 | from robosuite import load_controller_config
12 |
13 | sys.path.append("./libs")
14 | from rt_control_wrapper import RTControlWrapper
15 |
16 |
17 | def playback(env, ep_dir, rate):
18 | data = np.load(os.path.join(ep_dir, "state.npz"), allow_pickle=True)
19 | nloop = data["nloop"]
20 | states = data["states"]
21 | joints = data["joint_angles"]
22 | grippers = np.array([el["actions"][-1] for el in data["action_infos"]])
23 |
24 | # Set user parameter
25 | loop_ct = 0
26 |
27 | # Reset the environment and Set environment
28 | env.set_environment(ep_dir, states[0])
29 |
30 | # save list
31 | pose_list = []
32 | state_list = []
33 | joint_list = []
34 | image_list = []
35 | gripper_list = []
36 | success_list = []
37 |
38 | for loop_ct in range(nloop):
39 |         # refresh the action every rate-th step; it is held constant in between
40 | if loop_ct % rate == 0:
41 | action = env.get_joint_action(joints[loop_ct], kp=rate)
42 | action[-1] = grippers[loop_ct]
43 |
44 | obs, reward, done, info = env.step(action)
45 | # env.render()
46 |
47 | # save data
48 | if loop_ct % rate == 0:
49 | state, success = env.get_state()
50 | image = env.get_image()
51 | joint = env.get_joints()
52 | pose = env.get_pose()
53 |
54 | # save robot sensor data
55 | pose_list.append(pose)
56 | state_list.append(state)
57 | image_list.append(image)
58 | joint_list.append(joint)
59 | success_list.append(success)
60 | gripper_list.append(action[-1])
61 |
62 |     # save the downsampled episode
63 |     save_name = os.path.join(ep_dir, "state_resave.npz")
64 |     print("save file: ", save_name)
65 | np.savez(
66 | save_name,
67 | poses=np.array(pose_list),
68 | states=np.array(state_list),
69 | joints=np.array(joint_list),
70 | images=np.array(image_list),
71 | success=np.array(success_list),
72 | gripper=np.array(gripper_list).reshape(-1, 1),
73 | )
74 |
75 | env.close()
76 |
77 |
78 | if __name__ == "__main__":
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument("--ep_dir", type=str, default="./data/raw_data/test/")
81 | parser.add_argument("--rate", type=int, default=5) # Down sampling rate
82 | args = parser.parse_args()
83 |
84 | # Get controller config
85 | # Import controller config IK_POSE/OSC_POSE/OSC_POSITION/JOINT_POSITION
86 | controller_config = load_controller_config(default_controller="JOINT_POSITION")
87 |
88 | # Create argument configuration
89 | config = {
90 | "env_name": "Lift",
91 | "robots": "Panda",
92 | "controller_configs": controller_config,
93 | }
94 |
95 | # create environment
96 | env = suite.make(
97 | **config,
98 | has_renderer=True,
99 | has_offscreen_renderer=True,
100 | render_camera="agentview",
101 | ignore_done=True,
102 | use_camera_obs=True,
103 | reward_shaping=True,
104 | control_freq=20,
105 | hard_reset=False,
106 | )
107 |
108 | # wrap the environment with data collection wrapper
109 | env = RTControlWrapper(env)
110 |
111 | # collect some data
112 | playback(env, args.ep_dir, args.rate)
113 |
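
The playback loop above refreshes the action and logs data only every rate-th simulator step, holding the last action in between; that is how the 20 Hz simulation is downsampled for learning. The control flow in isolation:

    rate = 5
    action, log = None, []
    for loop_ct in range(20):
        if loop_ct % rate == 0:
            action = "action@{}".format(loop_ct)  # stand-in for get_joint_action(...)
        # env.step(action) would run here on every iteration, reusing the held action
        if loop_ct % rate == 0:
            log.append((loop_ct, action))
    print(log)  # [(0, 'action@0'), (5, 'action@5'), (10, 'action@10'), (15, 'action@15')]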
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/bin/3_check_playback_data.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import argparse
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import matplotlib.animation as anim
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("filename", type=str, default="./data/test/state.npz")
14 | args = parser.parse_args()
15 |
16 | # load sub directory name from filename
17 | savename = "./output/{}_".format(args.filename.split("/")[-2])
18 |
19 | # action_infos: action
20 | # obs_infos: image, joint_cos, joint_sin, gripper
21 | data = np.load(args.filename, allow_pickle=True)
22 | if sum(data["success"]) == 0:
23 | print("[ERROR] This data set has failed task during playback.")
24 | exit()
25 |
26 | # plot joint and images
27 | images = data["images"]
28 | poses = data["poses"]
29 | joints = np.concatenate((data["joints"], data["gripper"]), axis=-1)
30 |
31 | N = len(joints)
32 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60)
33 |
34 |
35 | def anim_update(i):
36 | for j in range(3):
37 | ax[j].axis("off")
38 | ax[j].cla()
39 |
40 | ax[0].imshow(np.flipud(images[i]))
41 | ax[0].axis("off")
42 | ax[0].set_title("Image", fontsize=20)
43 |
44 | ax[1].set_ylim(-3.5, 3.5)
45 | ax[1].set_xlim(0, N)
46 | for idx in range(8):
47 | ax[1].plot(np.arange(i + 1), joints[: i + 1, idx])
48 | ax[1].set_xlabel("Step", fontsize=20)
49 | ax[1].set_title("Joint angles", fontsize=20)
50 | ax[1].tick_params(axis="x", labelsize=16)
51 | ax[1].tick_params(axis="y", labelsize=16)
52 |
53 | ax[2].set_ylim(-3.5, 3.5)
54 | ax[2].set_xlim(0, N)
55 | for idx in range(6):
56 | ax[2].plot(np.arange(i + 1), poses[: i + 1, idx])
57 | ax[2].set_xlabel("Step", fontsize=20)
58 | ax[2].set_title("End effector poses", fontsize=20)
59 | ax[2].tick_params(axis="x", labelsize=16)
60 | ax[2].tick_params(axis="y", labelsize=16)
61 | plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9)
62 |
63 |
64 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N)
65 | ani.save(savename + "image_joint_ani.gif")
66 |
67 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg).
68 | # ani.save(savename + "image_joint_ani.gif", writer="imagemagick")
69 | # ani.save(savename + "image_joint_ani.mp4", writer="ffmpeg")
70 |
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/bin/4_generate_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import numpy as np
7 | from eipl.utils import resize_img, normalization, cos_interpolation, calc_minmax
8 |
9 |
10 | def load_data(index, data_dir, train):
11 | image_list = []
12 | joint_list = []
13 | pose_list = []
14 |
15 | num_list = [1, 2, 3, 4, 5] if train else [6, 7]
16 | for h in index:
17 | for i in num_list:
18 | filename = data_dir.format(h, i)
19 | print(filename)
20 | data = np.load(filename, allow_pickle=True)
21 |
22 | images = data["images"]
23 | images = np.array(images)[:, ::-1]
24 | images = resize_img(images, (64, 64))
25 |
26 | _poses = data["poses"]
27 | _joints = data["joints"]
28 | _gripper = data["gripper"][:, 0]
29 | _gripper = normalization(_gripper, (-1, 1), (0, 1))
30 | _gripper = cos_interpolation(_gripper, 10, expand_dims=True)
31 | poses = np.concatenate((_poses, _gripper), axis=-1)
32 | joints = np.concatenate((_joints, _gripper), axis=-1)
33 |
34 | joint_list.append(joints)
35 | image_list.append(images)
36 | pose_list.append(poses)
37 |
38 | joints = np.array(joint_list)
39 | images = np.array(image_list)
40 | poses = np.array(pose_list)
41 |
42 | return images, joints, poses
43 |
44 |
45 | if __name__ == "__main__":
46 |     # train positions: 7 episodes collected per position (5 for train, 2 for test)
47 |     # test positions: 2 episodes collected per position (all for test)
48 | # Pos1 Pos2 Pos3 Pos4 Pos5 Pos6 Pos7 Pos8 Pos9
49 | # pos_train: -0.2 -0.1 0.0 0.1 0.2
50 | # pos_test: -0.2 -0.15 -0.1 -0.05 0.0 0.05 0.1 0.15 0.2
51 |
52 | data_dir = "./data/raw_data/Pos{}_{}/state_resave.npz"
53 |
54 | # load train data
55 | train_index = [1, 3, 5, 7, 9]
56 | train_images, train_joints, train_poses = load_data(
57 | train_index, data_dir, train=True
58 | )
59 | np.save("./data/train/images.npy", train_images.astype(np.uint8))
60 | np.save("./data/train/joints.npy", train_joints.astype(np.float32))
61 | np.save("./data/train/poses.npy", train_poses.astype(np.float32))
62 |
63 | test_index = [1, 2, 3, 4, 5, 6, 7, 8, 9]
64 | test_images, test_joints, test_poses = load_data(test_index, data_dir, train=False)
65 | np.save("./data/test/images.npy", test_images.astype(np.uint8))
66 | np.save("./data/test/joints.npy", test_joints.astype(np.float32))
67 | np.save("./data/test/poses.npy", test_poses.astype(np.float32))
68 |
69 | # save bounds
70 | poses = np.concatenate((train_poses, test_poses), axis=0)
71 | joints = np.concatenate((train_joints, test_joints), axis=0)
72 | pose_bounds = calc_minmax(poses)
73 | joint_bounds = calc_minmax(joints)
74 | np.save("./data/pose_bounds.npy", pose_bounds)
75 | np.save("./data/joint_bounds.npy", joint_bounds)
76 |
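
The gripper command is a 0/1 step signal, which is hard for an RNN to regress; cos_interpolation smooths it before training. The helper below is a hypothetical re-implementation of that idea (a half-cosine ramp spread over width frames around each step), shown only to illustrate the effect; the real function lives in eipl.utils:

    import numpy as np

    def half_cosine_step(signal, width):
        # hypothetical stand-in for eipl.utils.cos_interpolation
        out = signal.astype(np.float32)
        for i in np.where(np.diff(signal) != 0)[0]:  # step locations
            lo, hi = max(i - width // 2, 0), min(i + width // 2, len(signal) - 1)
            ramp = (1 - np.cos(np.linspace(0, np.pi, hi - lo + 1))) / 2  # 0 -> 1
            out[lo : hi + 1] = signal[lo] + (signal[hi] - signal[lo]) * ramp
        return out

    gripper = np.array([0] * 10 + [1] * 10)
    print(half_cosine_step(gripper, width=10).round(2))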
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/bin/5_check_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import argparse
7 | import numpy as np
8 | import matplotlib.pylab as plt
9 | import matplotlib.animation as anim
10 | from eipl.utils import normalization
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--idx", type=int, default=0)
15 | args = parser.parse_args()
16 |
17 | idx = int(args.idx)
18 | joints = np.load("./data/test/joints.npy")
19 | joint_bounds = np.load("./data/joint_bounds.npy")
20 | images = np.load("./data/test/images.npy")
21 | N = images.shape[1]
22 |
23 |
24 | # normalize joints
25 | minmax = [0.1, 0.9]
26 | norm_joints = normalization(joints, joint_bounds, minmax)
27 |
28 | # print data information
29 | print("load test data, index number is {}".format(idx))
30 | print(
31 | "Joint: shape={}, min={:.3g}, max={:.3g}".format(
32 | joints.shape, joints.min(), joints.max()
33 | )
34 | )
35 | print(
36 | "Norm joint: shape={}, min={:.3g}, max={:.3g}".format(
37 | norm_joints.shape, norm_joints.min(), norm_joints.max()
38 | )
39 | )
40 |
41 | # plot images and normalized joints
42 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60)
43 |
44 |
45 | def anim_update(i):
46 | for j in range(3):
47 | ax[j].cla()
48 |
49 | # plot image
50 | ax[0].imshow(images[idx, i])
51 | ax[0].axis("off")
52 | ax[0].set_title("Image", fontsize=20)
53 |
54 | # plot joint angle
55 | ax[1].set_ylim(-1.0, 2.0)
56 | ax[1].set_xlim(0, N)
57 | ax[1].plot(joints[idx], linestyle="dashed", c="k")
58 |
59 | for joint_idx in range(8):
60 | ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx])
61 | ax[1].set_xlabel("Step", fontsize=20)
62 | ax[1].set_title("Joint angles", fontsize=20)
63 | ax[1].tick_params(axis="x", labelsize=16)
64 | ax[1].tick_params(axis="y", labelsize=16)
65 |
66 | # plot normalized joint angle
67 | ax[2].set_ylim(0.0, 1.0)
68 | ax[2].set_xlim(0, N)
69 | ax[2].plot(norm_joints[idx], linestyle="dashed", c="k")
70 |
71 | for joint_idx in range(8):
72 | ax[2].plot(np.arange(i + 1), norm_joints[idx, : i + 1, joint_idx])
73 | ax[2].set_xlabel("Step", fontsize=20)
74 | ax[2].set_title("Normalized joint angles", fontsize=20)
75 | ax[2].tick_params(axis="x", labelsize=16)
76 | ax[2].tick_params(axis="y", labelsize=16)
77 | plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9)
78 |
79 |
80 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N)
81 | ani.save("./output/check_dataset_{}.gif".format(idx))
82 |
83 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg).
84 | # ani.save("./output/check_dataset_{}.gif".format(idx), writer="imagemagick")
85 | # ani.save("./output/check_dataset_{}.gif".format(idx), writer="ffmpeg")
86 |
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/libs/environment.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | from robosuite.models.objects import BoxObject
7 | from robosuite.utils.mjcf_utils import CustomMaterial
8 |
9 |
10 | tex_attrib = {
11 | "type": "cube",
12 | }
13 |
14 | mat_attrib = {
15 | "texrepeat": "1 1",
16 | "specular": "0.4",
17 | "shininess": "0.1",
18 | }
19 |
20 | redwood = CustomMaterial(
21 | texture="WoodRed",
22 | tex_name="redwood",
23 | mat_name="redwood_mat",
24 | tex_attrib=tex_attrib,
25 | mat_attrib=mat_attrib,
26 | )
27 |
28 | cube = BoxObject(
29 | name="cube",
30 | size_min=[0.020, 0.020, 0.020],
31 | size_max=[0.022, 0.022, 0.022],
32 | rgba=[1, 0, 0, 1],
33 | material=redwood,
34 | )
35 |
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/libs/rt_control_wrapper.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 |
6 | import os
7 | import numpy as np
8 | import transforms3d as T
9 | from robosuite.wrappers import Wrapper
10 |
11 |
12 | class RTControlWrapper(Wrapper):
13 | def __init__(self, env, save_dir=None, vis_settings=False):
14 |         """
15 |         RTControlWrapper specialized for imitation learning
16 |         """
17 |         super().__init__(env)
18 |
19 | # initialization
20 | self.robot = self.env.robots[0]
21 | self.robot.controller_config["kp"] = 150
22 |
23 | # save the task instance (will be saved on the first env interaction)
24 | self._current_task_instance_xml = self.env.sim.model.get_xml()
25 | self._current_task_instance_state = np.array(self.env.sim.get_state().flatten())
26 |
27 | # make save directory and save xml file
28 | if save_dir is not None:
29 | if not os.path.exists(save_dir):
30 | print(
31 | "DataCollectionWrapper: making new directory at {}".format(save_dir)
32 | )
33 | os.makedirs(save_dir)
34 | self.save_xml(save_dir)
35 |
36 | self._vis_settings = None
37 | if vis_settings:
38 | # Create internal dict to store visualization settings (set to True by default)
39 | self._vis_settings = {vis: True for vis in self.env._visualizations}
40 |
41 | def reset(self):
42 | """
43 | Extends vanilla reset() function call to accommodate data collection
44 | Returns:
45 | OrderedDict: Environment observation space after reset occurs
46 | """
47 | ret = super().reset()
48 | return ret
49 |
50 | def step(self, action):
51 | """
52 | Extends vanilla step() function call to accommodate data collection
53 | Args:
54 | action (np.array): Action to take in environment
55 | Returns:
56 | 4-tuple:
57 | - (OrderedDict) observations from the environment
58 | - (float) reward from the environment
59 | - (bool) whether the current episode is completed or not
60 | - (dict) misc information
61 | """
62 | ret = super().step(action)
63 | if self._vis_settings is not None:
64 | self.env.visualize(vis_settings=self._vis_settings)
65 |
66 | return ret
67 |
68 | def save_xml(self, save_dir="./data"):
69 | # save the model xml
70 | xml_path = os.path.join(save_dir, "model.xml")
71 | with open(xml_path, "w") as f:
72 | f.write(self._current_task_instance_xml)
73 |
74 | def set_environment(self, save_dir, state):
75 | xml_path = os.path.join(save_dir, "model.xml")
76 | with open(xml_path, "r") as f:
77 | self.env.reset_from_xml_string(f.read())
78 |
79 | self.env.sim.set_state_from_flattened(state)
80 | if self._vis_settings is not None:
81 | self.env.visualize(vis_settings=self._vis_settings)
82 |
83 | def set_gripper_qpos(self, qpos):
84 | for i, q in enumerate(qpos):
85 | self.env.sim.data.set_joint_qpos(
86 | name="gripper0_finger_joint{}".format(i + 1), value=q
87 | )
88 |
89 | def set_joint_qpos(self, qpos):
90 | for i, q in enumerate(qpos):
91 | self.env.sim.data.set_joint_qpos(
92 | name="robot0_joint{}".format(i + 1), value=q
93 | )
94 |
95 | def get_image(self, name="agentview_image"):
96 | return self.env.observation_spec()[name]
97 |
98 | def get_joint_action(self, goal_joint_pos, kp, kd=0.0):
99 | """relative2absolute_joint_pos_commands"""
100 | action = [0 for _ in range(self.robot.dof)]
101 | curr_joint_pos = self.robot._joint_positions
102 | curr_joint_vel = self.robot._joint_velocities
103 |
104 | for i in range(len(goal_joint_pos)):
105 | action[i] = (goal_joint_pos[i] - curr_joint_pos[i]) * kp - curr_joint_vel[
106 | i
107 | ] * kd
108 |
109 | return action
110 |
111 | def get_pose(self):
112 | position = self.env.robots[0]._hand_pos
113 | orientation_matrix = self.env.robots[0]._hand_orn
114 | orientation_euler = T.euler.mat2euler(orientation_matrix)
115 |
116 | pose = list(position) + list(orientation_euler)
117 | pose = np.array(pose)
118 | pose[pose < -np.pi / 2] += np.pi * 2
119 |
120 | return pose
121 |
122 | def get_gripper(self):
123 | return self.env.observation_spec()["robot0_gripper_qpos"]
124 |
125 | def get_joints(self):
126 | return self.robot._joint_positions
127 |
128 | def check_success(self):
129 | return self.env._check_success()
130 |
131 | def get_state(self):
132 | state = self.env.sim.get_state().flatten()
133 |         # return the flattened simulator state and current task success
134 |         return state, self.env._check_success()
138 |
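
get_joint_action above turns an absolute joint target into the relative command expected by the JOINT_POSITION controller; numerically it is a plain PD law, action_i = kp * (goal_i - q_i) - kd * dq_i. A vectorized toy check:

    import numpy as np

    goal = np.array([0.5, -0.2, 0.0])  # target joint positions
    q = np.array([0.4, 0.0, 0.0])      # current joint positions
    dq = np.array([0.1, 0.0, 0.0])     # current joint velocities

    kp, kd = 5.0, 0.5
    action = kp * (goal - q) - kd * dq
    print(action)  # approximately [0.45, -1.0, 0.0]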
--------------------------------------------------------------------------------
/eipl/tutorials/robosuite/simulator/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .arg_utils import *
2 | from .callback import *
3 | from .data import *
4 | from .nn_func import *
5 | from .path_utils import *
6 | from .print_func import *
7 | from .utils import *
8 |
--------------------------------------------------------------------------------
/eipl/utils/arg_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os, json  # os is used directly by check_args below
9 | import datetime
10 | from .print_func import *
11 | from .path_utils import *
12 |
13 |
14 | def print_args(args):
15 | """Print arguments"""
16 | if not isinstance(args, dict):
17 | args = vars(args)
18 |
19 | keys = args.keys()
20 | keys = sorted(keys)
21 |
22 | print("================================")
23 | for key in keys:
24 | print("{} : {}".format(key, args[key]))
25 | print("================================")
26 |
27 |
28 | def save_args(args, filename):
29 | """Dump arguments as json file"""
30 | with open(filename, "w") as f:
31 | json.dump(vars(args), f, indent=4, sort_keys=True)
32 |
33 |
34 | def restore_args(filename):
35 | """Load argument file from file"""
36 | with open(filename, "r") as f:
37 | args = json.load(f)
38 | return args
39 |
40 |
41 | def get_config(args, tag, default=None):
42 | """Get value from argument"""
43 | if not isinstance(args, dict):
44 | raise ValueError("args should be dict.")
45 |
46 | if tag in args:
47 | if args[tag] is None:
48 | print_info("set {} <-- {} (default)".format(tag, default))
49 | return default
50 | else:
51 | print_info("set {} <--- {}".format(tag, args[tag]))
52 | return args[tag]
53 | else:
54 | if default is None:
55 | raise ValueError("you need to specify config {}".format(tag))
56 |
57 | print_info("set {} <-- {} (default)".format(tag, default))
58 | return default
59 |
60 |
61 | def check_args(args):
62 | """Check arguments"""
63 |
64 | if args.tag is None:
65 | tag = datetime.datetime.today().strftime("%Y%m%d_%H%M_%S")
66 | args.tag = tag
67 | print_info("Set tag = %s" % tag)
68 |
69 | # make log directory
70 | check_path(os.path.join(args.log_dir, args.tag), mkdir=True)
71 |
72 | # saves arguments into json file
73 | save_args(args, os.path.join(args.log_dir, args.tag, "args.json"))
74 |
75 | print_args(args)
76 | return args
77 |
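
save_args and restore_args are the glue between training and evaluation in this repository: train.py dumps its parsed arguments to args.json via check_args, and the test scripts restore them from the directory next to the checkpoint. A minimal round trip:

    import argparse
    import os
    import tempfile

    from eipl.utils import restore_args, save_args

    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=10)
    args = parser.parse_args([])

    path = os.path.join(tempfile.mkdtemp(), "args.json")
    save_args(args, path)        # json.dump(vars(args), ...)
    params = restore_args(path)  # plain dict, as used by the test scripts
    print(params["epoch"])       # 10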
--------------------------------------------------------------------------------
/eipl/utils/callback.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 |
12 | class EarlyStopping:
13 | def __init__(self, patience=5):
14 | """
15 | Early stops the training if validation loss doesn't improve after a given patience.
16 | Args:
17 | patience (int): Number of epochs with no improvement after which training will be stopped.
18 | """
19 | self.patience = patience
20 | self.counter = 0
21 | self.best_score = None
22 | self.save_ckpt = False
23 | self.stop_flag = False
24 | self.val_loss_min = np.inf
25 |
26 | def __call__(self, val_loss):
27 | if np.isnan(val_loss) or np.isinf(val_loss):
28 | raise RuntimeError("Invalid loss, terminating training")
29 |
30 | score = -val_loss
31 |
32 | if self.best_score is None:
33 | self.save_ckpt = True
34 | self.best_score = score
35 | elif score < self.best_score:
36 | self.save_ckpt = False
37 | self.counter += 1
38 | if self.counter >= self.patience:
39 | self.stop_flag = True
40 | else:
41 | self.save_ckpt = True
42 | self.best_score = score
43 | self.counter = 0
44 |
45 | return self.save_ckpt, self.stop_flag
46 |
47 |
48 | class LossScheduler:
49 | def __init__(self, decay_end=1000, curve_name="s"):
50 | decay_start = 0
51 | self.counter = -1
52 | self.decay_end = decay_end
53 | self.interpolated_values = self.curve_interpolation(
54 | decay_start, decay_end, decay_end, curve_name
55 | )
56 |
57 | def linear_interpolation(self, start, end, num_points):
58 | x = np.linspace(start, end, num_points)
59 | return x
60 |
61 | def s_curve_interpolation(self, start, end, num_points):
62 | t = np.linspace(0, 1, num_points)
63 | x = start + (end - start) * (t - np.sin(2 * np.pi * t) / (2 * np.pi))
64 | return x
65 |
66 | def inverse_s_curve_interpolation(self, start, end, num_points):
67 | t = np.linspace(0, 1, num_points)
68 | x = start + (end - start) * (t + np.sin(2 * np.pi * t) / (2 * np.pi))
69 | return x
70 |
71 | def deceleration_curve_interpolation(self, start, end, num_points):
72 | t = np.linspace(0, 1, num_points)
73 | x = start + (end - start) * (1 - np.cos(np.pi * t / 2))
74 | return x
75 |
76 | def acceleration_curve_interpolation(self, start, end, num_points):
77 | t = np.linspace(0, 1, num_points)
78 | x = start + (end - start) * (np.sin(np.pi * t / 2))
79 | return x
80 |
81 | def curve_interpolation(self, start, end, num_points, curve_name):
82 | if curve_name == "linear":
83 | interpolated_values = self.linear_interpolation(start, end, num_points)
84 | elif curve_name == "s":
85 | interpolated_values = self.s_curve_interpolation(start, end, num_points)
86 | elif curve_name == "inverse_s":
87 | interpolated_values = self.inverse_s_curve_interpolation(
88 | start, end, num_points
89 | )
90 | elif curve_name == "deceleration":
91 | interpolated_values = self.deceleration_curve_interpolation(
92 | start, end, num_points
93 | )
94 | elif curve_name == "acceleration":
95 | interpolated_values = self.acceleration_curve_interpolation(
96 | start, end, num_points
97 | )
98 | else:
99 | assert False, "Invalid curve name. {}".format(curve_name)
100 |
101 | return interpolated_values / num_points
102 |
103 | def __call__(self, loss_weight):
104 | self.counter += 1
105 | if self.counter >= self.decay_end:
106 | return loss_weight
107 | else:
108 | return self.interpolated_values[self.counter] * loss_weight
109 |
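
LossScheduler, used by the fullBPTT trainers to phase in the attention-point loss, ramps a weight from roughly zero to its full value over decay_end calls along the chosen curve, then returns it unchanged. Usage sketch:

    from eipl.utils import LossScheduler

    scheduler = LossScheduler(decay_end=1000, curve_name="s")
    weights = [scheduler(0.1) for _ in range(1200)]
    # ~0.0 at the first call, ~0.05 halfway, and the full 0.1 from call 1000 onward
    print(weights[0], weights[500], weights[999], weights[1100])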
--------------------------------------------------------------------------------
/eipl/utils/check_gpu.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import torch
9 |
10 | print("torch version:", torch.__version__)
11 |
12 | if torch.cuda.is_available():
13 | print("cuda is available")
14 | print("device_count: ", torch.cuda.device_count())
15 | print("device name: ", torch.cuda.get_device_name())
16 | else:
17 | print("cuda is not avaiable")
18 |
--------------------------------------------------------------------------------
/eipl/utils/convert_compiled_pth.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("filename", type=str, default=None)
14 | args = parser.parse_args()
15 |
16 | # resave original file
17 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
18 |
19 | restored_ckpt = {}
20 | for k, v in ckpt["model_state_dict"].items():
21 | restored_ckpt[k.replace("_orig_mod.", "")] = v
22 |
23 | ckpt["model_state_dict"] = restored_ckpt
24 |
25 | savename = "{}_v1.pth".format(os.path.splitext(args.filename)[0])
26 | torch.save(ckpt, savename)
27 |
--------------------------------------------------------------------------------
/eipl/utils/nn_func.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import torch.nn as nn
9 |
10 |
11 | def get_activation_fn(name, inplace=True):
12 | if name.casefold() == "relu":
13 | return nn.ReLU(inplace=inplace)
14 | elif name.casefold() == "lrelu":
15 | return nn.LeakyReLU(inplace=inplace)
16 | elif name.casefold() == "softmax":
17 |         return nn.Softmax(dim=-1)  # explicit dim; the implicit choice is deprecated
18 | elif name.casefold() == "tanh":
19 | return nn.Tanh()
20 | else:
21 | assert False, "Unknown activation function {}".format(name)
22 |
--------------------------------------------------------------------------------
/eipl/utils/path_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import numpy as np
10 |
11 |
12 | def check_filename(filename):
13 | if os.path.exists(filename):
14 | raise ValueError("{} exists.".format(filename))
15 | return filename
16 |
17 |
18 | def check_path(path, mkdir=False):
19 | """
20 |     Checks that the path exists, optionally creating it
21 | """
22 | if path[-1] == "/":
23 | path = path[:-1]
24 |
25 | if not os.path.exists(path):
26 | if mkdir:
27 | os.makedirs(path, exist_ok=True)
28 | else:
29 | raise ValueError("%s does not exist" % path)
30 |
31 | return path
32 |
--------------------------------------------------------------------------------
/eipl/utils/print_func.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | OK = "\033[92m"
9 | WARN = "\033[93m"
10 | NG = "\033[91m"
11 | END_CODE = "\033[0m"
12 |
13 |
14 | def print_info(msg):
15 | print(OK + "[INFO] " + END_CODE + msg)
16 |
17 |
18 | def print_warn(msg):
19 | print(WARN + "[WARNING] " + END_CODE + msg)
20 |
21 |
22 | def print_error(msg):
23 | print(NG + "[ERROR] " + END_CODE + msg)
24 |
--------------------------------------------------------------------------------
/eipl/utils/resave_pth.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("filename", type=str, default=None)
14 | args = parser.parse_args()
15 |
16 | # resave original file
17 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
18 | savename = "{}_org.pth".format(os.path.splitext(args.filename)[0])
19 | torch.save(ckpt, savename)
20 |
21 | # save pth file without optimizer state
22 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
23 | ckpt.pop("optimizer_state_dict")
24 | torch.save(ckpt, args.filename)
25 |
--------------------------------------------------------------------------------
/eipl/utils/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import cv2
10 | import glob
11 | import datetime
12 | import numpy as np
13 | import matplotlib.pylab as plt
14 |
15 |
16 | def check_path(path, mkdir=False):
17 | """
18 |     Checks whether the given path exists; creates it when mkdir=True.
19 | """
20 | if path[-1] == "/":
21 | path = path[:-1]
22 |
23 | if not os.path.exists(path):
24 | if mkdir:
25 | os.mkdir(path)
26 | else:
27 | raise ValueError("%s does not exist" % path)
28 | return path
29 |
30 |
31 | def set_logdir(log_dir, tag):
32 | return check_path(os.path.join(log_dir, tag), mkdir=True)
33 |
--------------------------------------------------------------------------------
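A sketch of set_logdir, which joins the log directory and a tag and creates the subdirectory when missing (note that this variant of check_path uses os.mkdir, so the parent directory must already exist):

    from eipl.utils import set_logdir

    log_dir_path = set_logdir("./log", "20230101_0000")  # hypothetical tag
    print(log_dir_path)                                  # ./log/20230101_0000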
/eipl/zoo/cae/bin/extract.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 | import numpy as np
12 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN
13 | from eipl.data import SampleDownloader
14 | from eipl.utils import print_info, restore_args, tensor2numpy
15 |
16 |
17 | # argument parser
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--filename", type=str, default=None)
20 | args = parser.parse_args()
21 |
22 | # restore parameters
23 | dir_name = os.path.split(args.filename)[0]
24 | params = restore_args(os.path.join(dir_name, "args.json"))
25 |
26 | # define model
27 | if params["model"] == "BasicCAE":
28 | model = BasicCAE(feat_dim=params["feat_dim"])
29 | elif params["model"] == "CAE":
30 | model = CAE(feat_dim=params["feat_dim"])
31 | elif params["model"] == "BasicCAEBN":
32 | model = BasicCAEBN(feat_dim=params["feat_dim"])
33 | elif params["model"] == "CAEBN":
34 | model = CAEBN(feat_dim=params["feat_dim"])
35 | else:
36 | assert False, "Unknown model name {}".format(params["model"])
37 |
38 | # load weight
39 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
40 | model.load_state_dict(ckpt["model_state_dict"])
41 | model.eval()
42 |
43 | for data_type in ["train", "test"]:
44 | # generate rnn dataset
45 | os.makedirs("./data/{}/".format(data_type), exist_ok=True)
46 |
47 | # load dataset
48 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW")
49 | images, joints = grasp_data.load_norm_data(
50 | data_type, params["vmin"], params["vmax"]
51 | )
52 | images = torch.tensor(images)
53 | joint_bounds = np.load(
54 | os.path.join(
55 | os.path.expanduser("~"), ".eipl/airec/grasp_bottle/joint_bounds.npy"
56 | )
57 | )
58 |
59 | # extract image feature
60 | N = images.shape[0]
61 | feature_list = []
62 | for i in range(N):
63 | _features = model.encoder(images[i])
64 | feature_list.append(tensor2numpy(_features))
65 |
66 | features = np.array(feature_list)
67 | np.save("./data/joint_bounds.npy", joint_bounds)
68 | np.save("./data/{}/features.npy".format(data_type), features)
69 | np.save("./data/{}/joints.npy".format(data_type), joints)
70 |
71 | print_info("{} data".format(data_type))
72 | print("==================================================")
73 | print("Shape of joints angle:", joints.shape)
74 | print("Shape of image feature:", features.shape)
75 | print("==================================================")
76 | print()
77 |
78 | # save features minmax bounds
79 | feat_list = []
80 | for data_type in ["train", "test"]:
81 | feat_list.append(np.load("./data/{}/features.npy".format(data_type)))
82 |
83 | feat = np.vstack(feat_list)
84 | feat_minmax = np.array([feat.min(), feat.max()])
85 | np.save("./data/feat_bounds.npy", feat_minmax)
86 |
--------------------------------------------------------------------------------
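The extracted features are consumed by the RNN scripts in zoo/rnn, which rescale them into the training range with the stored bounds. A minimal sketch of that step, assuming the same file layout as above:

    import numpy as np
    from eipl.utils import normalization

    feat_bounds = np.load("./data/feat_bounds.npy")             # [global min, global max] saved above
    feats = np.load("./data/train/features.npy")
    norm_feats = normalization(feats, feat_bounds, [0.1, 0.9])  # map bounds -> [vmin, vmax]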
/eipl/zoo/cae/bin/test.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pylab as plt
13 | import matplotlib.animation as anim
14 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN
15 | from eipl.data import SampleDownloader, WeightDownloader
16 | from eipl.utils import normalization, deprocess_img, restore_args
17 | from eipl.utils import tensor2numpy
18 |
19 |
20 | # argument parser
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--filename", type=str, default=None)
23 | parser.add_argument("--idx", type=str, default="0")
24 | parser.add_argument("--pretrained", action="store_true")
25 | args = parser.parse_args()
26 |
27 | # check args
28 | assert args.filename or args.pretrained, "Please set filename or pretrained"
29 |
30 | # load pretrained weight
31 | if args.pretrained:
32 | WeightDownloader("airec", "grasp_bottle")
33 | args.filename = os.path.join(
34 | os.path.expanduser("~"), ".eipl/airec/grasp_bottle/weights/CAEBN/model.pth"
35 | )
36 |
37 | # restore parameters
38 | dir_name = os.path.split(args.filename)[0]
39 | params = restore_args(os.path.join(dir_name, "args.json"))
40 | idx = int(args.idx)
41 |
42 | # load dataset
43 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC")
44 | images_raw, _ = grasp_data.load_raw_data("test")
45 | images = normalization(
46 | images_raw.astype(np.float32), (0.0, 255.0), (params["vmin"], params["vmax"])
47 | )
48 | images = images.transpose(0, 1, 4, 2, 3)
49 | images = torch.tensor(images)
50 | T = images.shape[1]
51 |
52 | # define model
53 | if params["model"] == "BasicCAE":
54 | model = BasicCAE(feat_dim=params["feat_dim"])
55 | elif params["model"] == "CAE":
56 | model = CAE(feat_dim=params["feat_dim"])
57 | elif params["model"] == "BasicCAEBN":
58 | model = BasicCAEBN(feat_dim=params["feat_dim"])
59 | elif params["model"] == "CAEBN":
60 | model = CAEBN(feat_dim=params["feat_dim"])
61 | else:
62 | assert False, "Unknown model name {}".format(params["model"])
63 |
64 | if params["compile"]:
65 | model = torch.compile(model)
66 |
67 | # load weight
68 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
69 | model.load_state_dict(ckpt["model_state_dict"])
70 | model.eval()
71 |
72 | # prediction
73 | _yi = model(images[idx])
74 | yi = deprocess_img(tensor2numpy(_yi), params["vmin"], params["vmax"])
75 | yi = yi.transpose(0, 2, 3, 1)
76 |
77 | # plot images
78 | fig, ax = plt.subplots(1, 2, figsize=(8, 3), dpi=60)
79 |
80 |
81 | def anim_update(i):
82 | for j in range(2):
83 | ax[j].cla()
84 |
85 | ax[0].imshow(images_raw[idx, i, :, :, ::-1])
86 | ax[0].axis("off")
87 | ax[0].set_title("Input image")
88 |
89 | ax[1].imshow(yi[i, :, :, ::-1])
90 | ax[1].axis("off")
91 | ax[1].set_title("Reconstructed image")
92 |
93 |
94 | # save the animation with the default writer
95 | ani = anim.FuncAnimation(fig, anim_update, interval=int(T / 10), frames=T)
96 | ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx))
97 |
98 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
99 | # ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx), writer="imagemagick")
100 | # ani.save("./output/{}_{}_{}.mp4".format(params["model"], params["tag"], idx), writer="ffmpeg")
101 |
--------------------------------------------------------------------------------
/eipl/zoo/cae/bin/test_pca_cae.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 | import matplotlib.pylab as plt
12 | import matplotlib.animation as anim
13 | from sklearn.decomposition import PCA
14 | from eipl.data import SampleDownloader
15 | from eipl.utils import tensor2numpy, restore_args
16 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN
17 |
18 |
19 | # argument parser
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--filename", type=str, default=None)
22 | args = parser.parse_args()
23 |
24 | # restore parameters
25 | dir_name = os.path.split(args.filename)[0]
26 | params = restore_args(os.path.join(dir_name, "args.json"))
27 |
28 | # load dataset
29 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW")
30 | images, _ = grasp_data.load_norm_data("test", params["vmin"], params["vmax"])
31 | N, T, C, H, W = images.shape
32 | images = images.reshape(N * T, C, H, W)
33 | images = torch.tensor(images)
34 |
35 | # define model
36 | if params["model"] == "BasicCAE":
37 | model = BasicCAE(feat_dim=params["feat_dim"])
38 | elif params["model"] == "CAE":
39 | model = CAE(feat_dim=params["feat_dim"])
40 | elif params["model"] == "BasicCAEBN":
41 | model = BasicCAEBN(feat_dim=params["feat_dim"])
42 | elif params["model"] == "CAEBN":
43 | model = CAEBN(feat_dim=params["feat_dim"])
44 | else:
45 | assert False, "Unknown model name {}".format(params["model"])
46 |
47 | if params["compile"]:
48 | model = torch.compile(model)
49 |
50 | # load weight
51 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
52 | model.load_state_dict(ckpt["model_state_dict"])
53 | model.eval()
54 |
55 | # prediction
56 | feat = model.encoder(images)
57 | feat = tensor2numpy(feat)
58 | loop_ct = float(360) / T
59 |
60 | pca_dim = 3
61 | pca = PCA(n_components=pca_dim).fit(feat)
62 | pca_val = pca.transform(feat)
63 | pca_val = pca_val.reshape(N, T, pca_dim)
64 |
65 | fig = plt.figure()
66 | ax = fig.add_subplot(projection="3d")
67 |
68 |
69 | def anim_update(i):
70 | ax.cla()
71 | angle = int(loop_ct * i)
72 | ax.view_init(30, angle)
73 |
74 | c_list = ["C0", "C1", "C2", "C3", "C4"]
75 | for n, color in enumerate(c_list):
76 | ax.scatter(
77 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
78 | )
79 |         # mark the initial state of each sequence
80 |         ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)
81 | pca_ratio = pca.explained_variance_ratio_ * 100
82 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
83 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
84 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
85 |
86 |
87 | ani = anim.FuncAnimation(fig, anim_update, interval=T, frames=T)
88 | ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]))
89 |
90 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
91 | # ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]), writer="imagemagick")
92 | # ani.save("./output/PCA_{}_{}.mp4".format(params["model"], params["tag"]), writer="ffmpeg")
93 |
--------------------------------------------------------------------------------
/eipl/zoo/cae/bin/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import sys
10 | import torch
11 | import argparse
12 | from tqdm import tqdm
13 | from collections import OrderedDict
14 | import torch.optim as optim
15 | from torch.utils.tensorboard import SummaryWriter
16 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN
17 | from eipl.data import ImageDataset, SampleDownloader
18 | from eipl.utils import EarlyStopping, check_args, set_logdir
19 |
20 | # load own library
21 | sys.path.append("./libs/")
22 | from trainer import Trainer
23 |
24 |
25 | # argument parser
26 | parser = argparse.ArgumentParser(description="Learning convolutional autoencoder")
27 | parser.add_argument("--model", type=str, default="CAEBN")
28 | parser.add_argument("--epoch", type=int, default=10000)
29 | parser.add_argument("--batch_size", type=int, default=32)
30 | parser.add_argument("--feat_dim", type=int, default=10)
31 | parser.add_argument("--stdev", type=float, default=0.1)
32 | parser.add_argument("--lr", type=float, default=1e-3)
33 | parser.add_argument("--optimizer", type=str, default="adam")
34 | parser.add_argument("--log_dir", default="log/")
35 | parser.add_argument("--vmin", type=float, default=0.1)
36 | parser.add_argument("--vmax", type=float, default=0.9)
37 | parser.add_argument("--device", type=int, default=0)
38 | parser.add_argument("--compile", action="store_true")
39 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
40 | args = parser.parse_args()
41 |
42 | # check args
43 | args = check_args(args)
44 |
45 | # calculate the noise level (variance) from the normalized range
46 | stdev = args.stdev * (args.vmax - args.vmin)
47 |
48 | # set device id
49 | if args.device >= 0:
50 | device = "cuda:{}".format(args.device)
51 | else:
52 | device = "cpu"
53 |
54 | # load dataset
55 | minmax = [args.vmin, args.vmax]
56 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW")
57 | images, _ = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax)
58 | train_dataset = ImageDataset(images, device=device, stdev=stdev)
59 | train_loader = torch.utils.data.DataLoader(
60 | train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False
61 | )
62 |
63 | images, _ = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax)
64 | test_dataset = ImageDataset(images, device=device, stdev=None)
65 | test_loader = torch.utils.data.DataLoader(
66 | test_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False
67 | )
68 |
69 |
70 | # define model
71 | if args.model == "BasicCAE":
72 | model = BasicCAE(feat_dim=args.feat_dim)
73 | elif args.model == "CAE":
74 | model = CAE(feat_dim=args.feat_dim)
75 | elif args.model == "BasicCAEBN":
76 | model = BasicCAEBN(feat_dim=args.feat_dim)
77 | elif args.model == "CAEBN":
78 | model = CAEBN(feat_dim=args.feat_dim)
79 | else:
80 | assert False, "Unknown model name {}".format(args.model)
81 |
82 | # torch.compile makes PyTorch code run faster
83 | if args.compile:
84 | model = torch.compile(model)
85 |
86 | # set optimizer
87 | optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-07)
88 |
89 | # load trainer/tester class
90 | trainer = Trainer(model, optimizer, device=device)
91 |
92 | ### training main
93 | log_dir_path = set_logdir("./" + args.log_dir, args.tag)
94 | save_name = os.path.join(log_dir_path, "{}.pth".format(args.model))
95 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30)
96 | early_stop = EarlyStopping(patience=1000)
97 |
98 | with tqdm(range(args.epoch)) as pbar_epoch:
99 | for epoch in pbar_epoch:
100 | # train and test
101 | train_loss = trainer.process_epoch(train_loader)
102 | with torch.no_grad():
103 | test_loss = trainer.process_epoch(test_loader, training=False)
104 | writer.add_scalar("Loss/train_loss", train_loss, epoch)
105 | writer.add_scalar("Loss/test_loss", test_loss, epoch)
106 |
107 | # early stop
108 | save_ckpt, _ = early_stop(test_loss)
109 |
110 | if save_ckpt:
111 | trainer.save(epoch, [train_loss, test_loss], save_name)
112 |
113 |         # update progress bar
114 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss))
115 |
--------------------------------------------------------------------------------
/eipl/zoo/cae/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/zoo/cae/libs/trainer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class Trainer:
13 | """
14 |     Helper class to train a convolutional autoencoder with a dataloader
15 |
16 |     Args:
17 |         model (torch.nn.Module): convolutional autoencoder model
18 |         optimizer (torch.optim.Optimizer): optimizer
19 |         device (str): device to train on, e.g. "cpu" or "cuda:0"
20 |
21 |     The dataloader is expected to yield (input, target) image pairs;
22 |     inputs may be noise-augmented while targets stay clean.
23 | """
24 |
25 | def __init__(self, model, optimizer, device="cpu"):
26 | self.device = device
27 | self.optimizer = optimizer
28 | self.model = model.to(self.device)
29 |
30 | def save(self, epoch, loss, savename):
31 | torch.save(
32 | {
33 | "epoch": epoch,
34 | "model_state_dict": self.model.state_dict(),
35 | #'optimizer_state_dict': self.optimizer.state_dict(),
36 | "train_loss": loss[0],
37 | "test_loss": loss[1],
38 | },
39 | savename,
40 | )
41 |
42 | def process_epoch(self, data, training=True):
43 | if not training:
44 | self.model.eval()
45 | else:
46 | self.model.train()
47 |
48 | total_loss = 0.0
49 | for n_batch, (xi, yi) in enumerate(data):
50 | xi = xi.to(self.device)
51 | yi = yi.to(self.device)
52 |
53 | yi_hat = self.model(xi)
54 | loss = nn.MSELoss()(yi_hat, yi)
55 | total_loss += loss.item()
56 |
57 | if training:
58 | self.optimizer.zero_grad(set_to_none=True)
59 | loss.backward()
60 | self.optimizer.step()
61 |
62 |         return total_loss / (n_batch + 1)
63 |
--------------------------------------------------------------------------------
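A minimal sketch of driving Trainer for one epoch, with random tensors standing in for the ImageDataset pairs that bin/train.py builds (the 128x128 RGB resolution is an assumption):

    import sys
    import torch
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    from eipl.model import CAEBN

    sys.path.append("./libs/")  # same convention as bin/train.py
    from trainer import Trainer

    x = torch.randn(8, 3, 128, 128)  # dummy image batch (assumed resolution)
    loader = DataLoader(TensorDataset(x, x.clone()), batch_size=4, shuffle=True)

    model = CAEBN(feat_dim=10)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, eps=1e-07)
    trainer = Trainer(model, optimizer, device="cpu")
    train_loss = trainer.process_epoch(loader)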
/eipl/zoo/cae/log/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/zoo/cae/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/zoo/cnnrnn/bin/test_pca_cnnrnn.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pylab as plt
13 | import matplotlib.animation as anim
14 | from sklearn.decomposition import PCA
15 | from eipl.model import CNNRNN, CNNRNNLN
16 | from eipl.data import SampleDownloader
17 | from eipl.utils import normalization
18 | from eipl.utils import restore_args, tensor2numpy
19 |
20 |
21 | # argument parser
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--filename", type=str, default=None)
24 | args = parser.parse_args()
25 |
26 | # restore parameters
27 | dir_name = os.path.split(args.filename)[0]
28 | params = restore_args(os.path.join(dir_name, "args.json"))
29 |
30 | # load dataset
31 | minmax = [params["vmin"], params["vmax"]]
32 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC")
33 | images, joints = grasp_data.load_raw_data("test")
34 | joint_bounds = grasp_data.joint_bounds
35 | print(
36 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max())
37 | )
38 | print(
39 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max())
40 | )
41 |
42 | # define model
43 | if params["model"] == "CNNRNN":
44 | model = CNNRNN(rec_dim=params["rec_dim"], joint_dim=8, feat_dim=params["feat_dim"])
45 | elif params["model"] == "CNNRNNLN":
46 | model = CNNRNNLN(
47 | rec_dim=params["rec_dim"], joint_dim=8, feat_dim=params["feat_dim"]
48 | )
49 | else:
50 | assert False, "Unknown model name {}".format(params["model"])
51 |
52 | if params["compile"]:
53 | model = torch.compile(model)
54 |
55 | # load weight
56 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
57 | model.load_state_dict(ckpt["model_state_dict"])
58 | model.eval()
59 |
60 | # Inference
61 | states = []
62 | state = None
63 | nloop = images.shape[1]
64 | for loop_ct in range(nloop):
65 | # load data and normalization
66 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2)
67 | img_t = torch.Tensor(img_t)
68 | img_t = normalization(img_t, (0, 255), minmax)
69 | joint_t = torch.Tensor(joints[:, loop_ct])
70 | joint_t = normalization(joint_t, joint_bounds, minmax)
71 |
72 | # predict rnn
73 | _, _, state = model(img_t, joint_t, state)
74 | states.append(state[0])
75 |
76 | states = torch.permute(torch.stack(states), (1, 0, 2))
77 | states = tensor2numpy(states)
78 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN.
79 | # N is the number of datasets
80 | # T is the sequence length
81 | # D is the dimension of the hidden state
82 | N, T, D = states.shape
83 | states = states.reshape(-1, D)
84 |
85 | # plot pca
86 | loop_ct = float(360) / T
87 | pca_dim = 3
88 | pca = PCA(n_components=pca_dim).fit(states)
89 | pca_val = pca.transform(states)
90 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to
91 | # visualize each state as a 3D scatter.
92 | pca_val = pca_val.reshape(N, T, pca_dim)
93 |
94 | fig = plt.figure(dpi=60)
95 | ax = fig.add_subplot(projection="3d")
96 |
97 |
98 | def anim_update(i):
99 | ax.cla()
100 | angle = int(loop_ct * i)
101 | ax.view_init(30, angle)
102 |
103 | c_list = ["C0", "C1", "C2", "C3", "C4"]
104 | for n, color in enumerate(c_list):
105 | ax.scatter(
106 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
107 | )
108 |         # mark the initial state of each sequence
109 |         ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)
110 | pca_ratio = pca.explained_variance_ratio_ * 100
111 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
112 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
113 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
114 |
115 |
116 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
117 | ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]))
118 |
119 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
120 | # ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]), writer="imagemagick")
121 | # ani.save("./output/PCA_{}_{}.mp4".format(params["model"], params["tag"]), writer="ffmpeg")
122 |
--------------------------------------------------------------------------------
/eipl/zoo/cnnrnn/bin/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import sys
10 | import torch
11 | import argparse
12 | from tqdm import tqdm
13 | import torch.optim as optim
14 | from collections import OrderedDict
15 | from torch.utils.tensorboard import SummaryWriter
16 | from eipl.model import CNNRNN, CNNRNNLN
17 | from eipl.data import MultimodalDataset, SampleDownloader
18 | from eipl.utils import EarlyStopping, check_args, set_logdir
19 |
20 | # load own library
21 | sys.path.append("./libs/")
22 | from fullBPTT import fullBPTTtrainer
23 |
24 |
25 | # argument parser
26 | parser = argparse.ArgumentParser(
27 | description="Learning convolutional and recurrent neural network"
28 | )
29 | parser.add_argument("--model", type=str, default="CNNRNN")
30 | parser.add_argument("--epoch", type=int, default=10000)
31 | parser.add_argument("--batch_size", type=int, default=5)
32 | parser.add_argument("--rec_dim", type=int, default=50)
33 | parser.add_argument("--feat_dim", type=int, default=10)
34 | parser.add_argument("--img_loss", type=float, default=1.0)
35 | parser.add_argument("--joint_loss", type=float, default=1.0)
36 | parser.add_argument("--stdev", type=float, default=0.1)
37 | parser.add_argument("--lr", type=float, default=1e-3)
38 | parser.add_argument("--optimizer", type=str, default="adam")
39 | parser.add_argument("--log_dir", default="log/")
40 | parser.add_argument("--vmin", type=float, default=0.0)
41 | parser.add_argument("--vmax", type=float, default=1.0)
42 | parser.add_argument("--device", type=int, default=0)
43 | parser.add_argument("--compile", action="store_true")
44 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
45 | args = parser.parse_args()
46 |
47 | # check args
48 | args = check_args(args)
49 |
50 | # calculate the noise level (variance) from the normalized range
51 | stdev = args.stdev * (args.vmax - args.vmin)
52 |
53 | # set device id
54 | if args.device >= 0:
55 | device = "cuda:{}".format(args.device)
56 | else:
57 | device = "cpu"
58 |
59 | # load dataset
60 | minmax = [args.vmin, args.vmax]
61 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW")
62 | images, joints = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax)
63 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev)
64 | train_loader = torch.utils.data.DataLoader(
65 | train_dataset,
66 | batch_size=args.batch_size,
67 | shuffle=True,
68 | drop_last=False,
69 | )
70 |
71 | images, joints = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax)
72 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None)
73 | test_loader = torch.utils.data.DataLoader(
74 | test_dataset,
75 | batch_size=args.batch_size,
76 | shuffle=True,
77 | drop_last=False,
78 | )
79 |
80 | # define model
81 | if args.model == "CNNRNN":
82 | model = CNNRNN(rec_dim=args.rec_dim, joint_dim=8, feat_dim=args.feat_dim)
83 | elif args.model == "CNNRNNLN":
84 | model = CNNRNNLN(rec_dim=args.rec_dim, joint_dim=8, feat_dim=args.feat_dim)
85 | else:
86 | assert False, "Unknown model name {}".format(args.model)
87 |
88 | # torch.compile makes PyTorch code run faster
89 | if args.compile:
90 | model = torch.compile(model)
91 |
92 | # set optimizer
93 | optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-07)
94 |
95 | # load trainer/tester class
96 | loss_weights = [args.img_loss, args.joint_loss]
97 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device)
98 |
99 | ### training main
100 | log_dir_path = set_logdir("./" + args.log_dir, args.tag)
101 | save_name = os.path.join(log_dir_path, "{}.pth".format(args.model))
102 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30)
103 | early_stop = EarlyStopping(patience=1000)
104 |
105 | with tqdm(range(args.epoch)) as pbar_epoch:
106 | for epoch in pbar_epoch:
107 | # train and test
108 | train_loss = trainer.process_epoch(train_loader)
109 | with torch.no_grad():
110 | test_loss = trainer.process_epoch(test_loader, training=False)
111 | writer.add_scalar("Loss/train_loss", train_loss, epoch)
112 | writer.add_scalar("Loss/test_loss", test_loss, epoch)
113 |
114 | # early stop
115 | save_ckpt, _ = early_stop(test_loss)
116 |
117 | if save_ckpt:
118 | trainer.save(epoch, [train_loss, test_loss], save_name)
119 |
120 |         # update progress bar
121 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss))
122 |
--------------------------------------------------------------------------------
/eipl/zoo/cnnrnn/libs/fullBPTT.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class fullBPTTtrainer:
13 | """
14 |     Helper class to train recurrent neural networks by full backpropagation through time (BPTT)
15 |
16 |     Args:
17 |         model (torch.nn.Module): rnn model
18 |         optimizer (torch.optim.Optimizer): optimizer
19 |         loss_weights (list): weights for the image and joint loss terms
20 |         device (str): device to train on, e.g. "cpu" or "cuda:0"
21 | """
22 |
23 | def __init__(self, model, optimizer, loss_weights=[1.0, 1.0], device="cpu"):
24 | self.device = device
25 | self.optimizer = optimizer
26 | self.loss_weights = loss_weights
27 | self.model = model.to(self.device)
28 |
29 | def save(self, epoch, loss, savename):
30 | torch.save(
31 | {
32 | "epoch": epoch,
33 | "model_state_dict": self.model.state_dict(),
34 | #'optimizer_state_dict': self.optimizer.state_dict(),
35 | "train_loss": loss[0],
36 | "test_loss": loss[1],
37 | },
38 | savename,
39 | )
40 |
41 | def process_epoch(self, data, training=True):
42 | if not training:
43 | self.model.eval()
44 | else:
45 | self.model.train()
46 |
47 | total_loss = 0.0
48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data):
49 | x_img = x_img.to(self.device)
50 | y_img = y_img.to(self.device)
51 | x_joint = x_joint.to(self.device)
52 | y_joint = y_joint.to(self.device)
53 |
54 | state = None
55 | yi_list, yv_list = [], []
56 | T = x_img.shape[1]
57 | for t in range(T - 1):
58 | _yi_hat, _yv_hat, state = self.model(x_img[:, t], x_joint[:, t], state)
59 | yi_list.append(_yi_hat)
60 | yv_list.append(_yv_hat)
61 |
62 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4))
63 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2))
64 | loss = self.loss_weights[0] * nn.MSELoss()(
65 | yi_hat, y_img[:, 1:]
66 | ) + self.loss_weights[1] * nn.MSELoss()(yv_hat, y_joint[:, 1:])
67 | total_loss += loss.item()
68 |
69 | if training:
70 | self.optimizer.zero_grad(set_to_none=True)
71 | loss.backward()
72 | self.optimizer.step()
73 |
74 | return total_loss / (n_batch + 1)
75 |
--------------------------------------------------------------------------------
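The trainer unrolls each sequence in full, predicts the next image and joint vector at every step, and backpropagates one weighted image-plus-joint MSE loss per sequence. A sketch of driving it with random multimodal data in place of the grasp_bottle set (sequence length and image resolution are assumptions):

    import sys
    import numpy as np
    import torch
    import torch.optim as optim
    from eipl.model import CNNRNN
    from eipl.data import MultimodalDataset

    sys.path.append("./libs/")  # same convention as bin/train.py
    from fullBPTT import fullBPTTtrainer

    # 4 random sequences of 30 steps: 128x128 RGB images and 8 joints (assumed shapes)
    images = np.random.rand(4, 30, 3, 128, 128).astype(np.float32)
    joints = np.random.rand(4, 30, 8).astype(np.float32)
    dataset = MultimodalDataset(images, joints, device="cpu", stdev=0.02)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

    model = CNNRNN(rec_dim=50, joint_dim=8, feat_dim=10)
    trainer = fullBPTTtrainer(model, optim.Adam(model.parameters()), loss_weights=[1.0, 1.0], device="cpu")
    train_loss = trainer.process_epoch(loader)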
/eipl/zoo/cnnrnn/log/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/zoo/cnnrnn/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/zoo/rnn/bin/test.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pylab as plt
13 | import matplotlib.animation as anim
14 |
15 | # own libraries
16 | from eipl.model import BasicLSTM, BasicMTRNN
17 | from eipl.utils import normalization
18 | from eipl.utils import restore_args, tensor2numpy
19 | from eipl.data import WeightDownloader
20 |
21 |
22 | # argument parser
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--filename", type=str, default=None)
25 | parser.add_argument("--idx", type=str, default="0")
26 | parser.add_argument("--pretrained", action="store_true")
27 | args = parser.parse_args()
28 |
29 | # check args
30 | assert args.filename or args.pretrained, "Please set filename or pretrained"
31 |
32 | # load pretrained weight
33 | if args.pretrained:
34 | WeightDownloader("airec", "grasp_bottle")
35 | args.filename = os.path.join(
36 | os.path.expanduser("~"), ".eipl/airec/grasp_bottle/weights/RNN/model.pth"
37 | )
38 |
39 | # restore parameters
40 | dir_name = os.path.split(args.filename)[0]
41 | params = restore_args(os.path.join(dir_name, "args.json"))
42 | idx = int(args.idx)
43 |
44 | # load dataset
45 | minmax = [params["vmin"], params["vmax"]]
46 | feat_bounds = np.load("../cae/data/feat_bounds.npy")
47 | _feats = np.load("../cae/data/test/features.npy")
48 | test_feats = normalization(_feats, feat_bounds, minmax)
49 | test_joints = np.load("../cae/data/test/joints.npy")
50 | x_data = np.concatenate((test_feats, test_joints), axis=-1)
51 | x_data = torch.Tensor(x_data)
52 | in_dim = x_data.shape[-1]
53 |
54 | # define model
55 | if params["model"] == "LSTM":
56 | model = BasicLSTM(in_dim=in_dim, rec_dim=params["rec_dim"], out_dim=in_dim)
57 | elif params["model"] == "MTRNN":
58 | model = BasicMTRNN(in_dim, fast_dim=60, slow_dim=5, fast_tau=2, slow_tau=12)
59 | else:
60 | assert False, "Unknown model name {}".format(params["model"])
61 |
62 | if params["compile"]:
63 | model = torch.compile(model)
64 |
65 | # load weight
66 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
67 | model.load_state_dict(ckpt["model_state_dict"])
68 | model.eval()
69 |
70 | # Inference
71 | y_hat = []
72 | state = None
73 | T = x_data.shape[1]
74 | for i in range(T):
75 | _y, state = model(x_data[:, i], state)
76 | y_hat.append(_y)
77 |
78 | y_hat = torch.permute(torch.stack(y_hat), (1, 0, 2))
79 | y_hat = tensor2numpy(y_hat)
80 | y_joints = y_hat[:, :, 10:]  # dims after the 10-dim image feature (CAE feat_dim) are joints
81 | y_feats = y_hat[:, :, :10]  # first 10 dims are the CAE image features
82 |
83 | # plot animation
84 | fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=60)
85 |
86 |
87 | def anim_update(i):
88 | for j in range(2):
89 | ax[j].cla()
90 |
91 | ax[0].set_ylim(-0.1, 1.1)
92 | ax[0].set_xlim(0, T)
93 | ax[0].plot(test_joints[idx, 1:], linestyle="dashed", c="k")
94 | for joint_idx in range(8):
95 | ax[0].plot(np.arange(i + 1), y_joints[idx, : i + 1, joint_idx])
96 | ax[0].set_xlabel("Step")
97 | ax[0].set_title("Joint angles")
98 |
99 | ax[1].set_ylim(-0.1, 1.1)
100 | ax[1].set_xlim(0, T)
101 | ax[1].plot(test_feats[idx, 1:], linestyle="dashed", c="k")
102 | for joint_idx in range(10):
103 | ax[1].plot(np.arange(i + 1), y_feats[idx, : i + 1, joint_idx])
104 | ax[1].set_xlabel("Step")
105 | ax[1].set_title("Image features")
106 |
107 |
108 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
109 | ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx))
110 |
111 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
112 | # ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx), writer="imagemagick")
113 | # ani.save("./output/{}_{}_{}.mp4".format(params["model"], params["tag"], idx), writer="ffmpeg")
114 |
--------------------------------------------------------------------------------
/eipl/zoo/rnn/bin/test_pca_rnn.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import torch
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pylab as plt
13 | import matplotlib.animation as anim
14 | from sklearn.decomposition import PCA
15 |
16 | # own libraries
17 | from eipl.model import BasicLSTM, BasicMTRNN
18 | from eipl.utils import normalization
19 | from eipl.utils import restore_args, tensor2numpy
20 |
21 |
22 | # argument parser
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--filename", type=str, default=None)
25 | args = parser.parse_args()
26 |
27 | # restore parameters
28 | dir_name = os.path.split(args.filename)[0]
29 | params = restore_args(os.path.join(dir_name, "args.json"))
30 |
31 | # load dataset
32 | minmax = [params["vmin"], params["vmax"]]
33 | feat_bounds = np.load("../cae/data/feat_bounds.npy")
34 | _feats = np.load("../cae/data/test/features.npy")
35 | test_feats = normalization(_feats, feat_bounds, minmax)
36 | test_joints = np.load("../cae/data/test/joints.npy")
37 | x_data = np.concatenate((test_feats, test_joints), axis=-1)
38 | x_data = torch.Tensor(x_data)
39 | in_dim = x_data.shape[-1]
40 |
41 | # define model
42 | if params["model"] == "LSTM":
43 | model = BasicLSTM(in_dim=in_dim, rec_dim=params["rec_dim"], out_dim=in_dim)
44 | elif params["model"] == "MTRNN":
45 | model = BasicMTRNN(in_dim, fast_dim=60, slow_dim=5, fast_tau=2, slow_tau=12)
46 | else:
47 | assert False, "Unknown model name {}".format(params["model"])
48 |
49 | if params["compile"]:
50 | model = torch.compile(model)
51 |
52 | # load weight
53 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
54 | model.load_state_dict(ckpt["model_state_dict"])
55 | model.eval()
56 |
57 | # Inference
58 | states = []
59 | state = None
60 | T = x_data.shape[1]
61 | for i in range(T):
62 | _, state = model(x_data[:, i], state)
63 | # lstm returns hidden state and cell state.
64 | # Here we store the hidden state to analyze the internal representation of the RNN.
65 | states.append(state[0])
66 |
67 | states = torch.permute(torch.stack(states), (1, 0, 2))
68 | states = tensor2numpy(states)
69 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN.
70 | # N is the number of datasets
71 | # T is the sequence length
72 | # D is the dimension of the hidden state
73 | N, T, D = states.shape
74 | states = states.reshape(-1, D)
75 |
76 | # plot pca
77 | loop_ct = float(360) / T
78 | pca_dim = 3
79 | pca = PCA(n_components=pca_dim).fit(states)
80 | pca_val = pca.transform(states)
81 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to
82 | # visualize each state as a 3D scatter.
83 | pca_val = pca_val.reshape(N, T, pca_dim)
84 |
85 | fig = plt.figure(dpi=60)
86 | ax = fig.add_subplot(projection="3d")
87 |
88 |
89 | def anim_update(i):
90 | ax.cla()
91 | angle = int(loop_ct * i)
92 | ax.view_init(30, angle)
93 |
94 | c_list = ["C0", "C1", "C2", "C3", "C4"]
95 | for n, color in enumerate(c_list):
96 | ax.scatter(
97 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
98 | )
99 |         # mark the initial state of each sequence
100 |         ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)
101 | pca_ratio = pca.explained_variance_ratio_ * 100
102 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
103 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
104 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
105 |
106 |
107 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
108 | ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]))
109 |
110 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
111 | # ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]), writer="imagemagick")
112 | # ani.save("./output/PCA_{}_{}.mp4".format(params["model"], params["tag"]), writer="ffmpeg")
113 |
--------------------------------------------------------------------------------
/eipl/zoo/rnn/bin/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import os
9 | import sys
10 | import torch
11 | import argparse
12 | import numpy as np
13 | from tqdm import tqdm
14 | from collections import OrderedDict
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from eipl.model import BasicLSTM, BasicMTRNN
18 | from eipl.utils import normalization
19 | from eipl.utils import EarlyStopping, check_args, set_logdir
20 |
21 | # load own library
22 | sys.path.append("./libs/")
23 | from fullBPTT import fullBPTTtrainer
24 | from dataloader import TimeSeriesDataSet
25 |
26 |
27 | # argument parser
28 | parser = argparse.ArgumentParser(description="Learning recurrent neural network")
29 | parser.add_argument("--model", type=str, default="LSTM")
30 | parser.add_argument("--epoch", type=int, default=20000)
31 | parser.add_argument("--batch_size", type=int, default=5)
32 | parser.add_argument("--rec_dim", type=int, default=50)
33 | parser.add_argument("--stdev", type=float, default=0.1)
34 | parser.add_argument("--lr", type=float, default=1e-3)
35 | parser.add_argument("--optimizer", type=str, default="adam")
36 | parser.add_argument("--log_dir", default="log/")
37 | parser.add_argument("--vmin", type=float, default=0.0)
38 | parser.add_argument("--vmax", type=float, default=1.0)
39 | parser.add_argument("--device", type=int, default=-1)
40 | parser.add_argument("--compile", action="store_true")
41 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
42 | args = parser.parse_args()
43 |
44 | # check args
45 | args = check_args(args)
46 |
47 | # calculate the noise level (variance) from the normalized range
48 | stdev = args.stdev * (args.vmax - args.vmin)
49 |
50 | # set device id
51 | if args.device >= 0:
52 | device = "cuda:{}".format(args.device)
53 | else:
54 | device = "cpu"
55 |
56 | # load dataset
57 | minmax = [args.vmin, args.vmax]
58 | feat_bounds = np.load("../cae/data/feat_bounds.npy")
59 | _feats = np.load("../cae/data/train/features.npy")
60 | train_feats = normalization(_feats, feat_bounds, minmax)
61 | train_joints = np.load("../cae/data/train/joints.npy")
62 | in_dim = train_feats.shape[-1] + train_joints.shape[-1]
63 | train_dataset = TimeSeriesDataSet(
64 | train_feats, train_joints, minmax=[args.vmin, args.vmax], stdev=stdev
65 | )
66 | train_loader = torch.utils.data.DataLoader(
67 | train_dataset,
68 | batch_size=args.batch_size,
69 | shuffle=True,
70 | drop_last=False,
71 | pin_memory=True,
72 | )
73 |
74 | _feats = np.load("../cae/data/test/features.npy")
75 | test_feats = normalization(_feats, feat_bounds, minmax)
76 | test_joints = np.load("../cae/data/test/joints.npy")
77 | test_dataset = TimeSeriesDataSet(
78 | test_feats, test_joints, minmax=[args.vmin, args.vmax], stdev=None
79 | )
80 | test_loader = torch.utils.data.DataLoader(
81 | test_dataset,
82 | batch_size=args.batch_size,
83 | shuffle=True,
84 | drop_last=False,
85 | pin_memory=True,
86 | )
87 |
88 | # define model
89 | if args.model == "LSTM":
90 | model = BasicLSTM(in_dim=in_dim, rec_dim=args.rec_dim, out_dim=in_dim)
91 | elif args.model == "MTRNN":
92 | model = BasicMTRNN(in_dim, fast_dim=60, slow_dim=5, fast_tau=2, slow_tau=12)
93 | else:
94 | assert False, "Unknown model name {}".format(args.model)
95 |
96 | # torch.compile makes PyTorch code run faster
97 | if args.compile:
98 | torch.set_float32_matmul_precision("high")
99 | model = torch.compile(model)
100 |
101 | # set optimizer
102 | if args.optimizer.casefold() == "adam":
103 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
104 | elif args.optimizer.casefold() == "radam":
105 | optimizer = optim.RAdam(model.parameters(), lr=args.lr)
106 | else:
107 |     assert False, "Unknown optimizer name {}. Please set adam or radam.".format(
108 | args.optimizer
109 | )
110 |
111 | # load trainer/tester class
112 | trainer = fullBPTTtrainer(model, optimizer, device=device)
113 |
114 | ### training main
115 | log_dir_path = set_logdir("./" + args.log_dir, args.tag)
116 | save_name = os.path.join(log_dir_path, "{}.pth".format(args.model))
117 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30)
118 | early_stop = EarlyStopping(patience=1000)
119 |
120 | with tqdm(range(args.epoch)) as pbar_epoch:
121 | for epoch in pbar_epoch:
122 | # train and test
123 | train_loss = trainer.process_epoch(train_loader)
124 | with torch.no_grad():
125 | test_loss = trainer.process_epoch(test_loader, training=False)
126 | writer.add_scalar("Loss/train_loss", train_loss, epoch)
127 | writer.add_scalar("Loss/test_loss", test_loss, epoch)
128 |
129 | # early stop
130 | save_ckpt, _ = early_stop(test_loss)
131 |
132 | if save_ckpt:
133 | trainer.save(epoch, [train_loss, test_loss], save_name)
134 |
135 |         # update progress bar
136 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss))
137 |
--------------------------------------------------------------------------------
/eipl/zoo/rnn/libs/dataloader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import torch
9 | import numpy as np
10 | from torch.utils.data import Dataset
11 |
12 |
13 | class TimeSeriesDataSet(Dataset):
14 |     """Time-series dataset of CAE features and joints from the AIREC sample data.
15 |
16 |     Args:
17 |         feats (np.array): image features extracted by the CAE.
18 |         joints (np.array): joint angles.
19 |         minmax (list, optional): normalization range, default is [0.1, 0.9].
20 |         stdev (float, optional): std of Gaussian noise added to the inputs; None disables it.
21 |     """
22 | def __init__(self, feats, joints, minmax=[0.1, 0.9], stdev=0.0):
23 | self.stdev = stdev
24 | self.feats = torch.from_numpy(feats).float()
25 | self.joints = torch.from_numpy(joints).float()
26 |
27 | def __len__(self):
28 | return len(self.feats)
29 |
30 | def __getitem__(self, idx):
31 | # normalization and convert numpy array to torch tensor
32 | y_feat = self.feats[idx]
33 | y_joint = self.joints[idx]
34 |         y_data = torch.cat((y_feat, y_joint), dim=-1)
35 |
36 | # apply gaussian noise to joint angles and image features
37 | if self.stdev is not None:
38 | x_feat = self.feats[idx] + torch.normal(
39 | mean=0, std=self.stdev, size=y_feat.shape
40 | )
41 | x_joint = self.joints[idx] + torch.normal(
42 | mean=0, std=self.stdev, size=y_joint.shape
43 | )
44 | else:
45 | x_feat = self.feats[idx]
46 | x_joint = self.joints[idx]
47 |
48 |         x_data = torch.cat((x_feat, x_joint), dim=-1)
49 |
50 | return [x_data, y_data]
51 |
52 |
53 | if __name__ == "__main__":
54 | import time
55 |
56 | # random dataset
57 | feats = np.random.randn(10, 120, 10)
58 | joints = np.random.randn(10, 120, 8)
59 |
60 | # load data
61 | data_loader = TimeSeriesDataSet(feats, joints, minmax=[0.1, 0.9])
62 | x_data, y_data = data_loader[1]
63 | print(x_data.shape, y_data.shape)
64 |
65 | train_loader = torch.utils.data.DataLoader(
66 | data_loader,
67 | batch_size=3,
68 | shuffle=True,
69 | drop_last=False,
70 | pin_memory=True,
71 | )
72 |
73 | print("[Start] load data using torch data loader")
74 | start_time = time.time()
75 | for batch in train_loader:
76 | print(batch[0].shape, batch[1].shape)
77 |
78 | print("[Finish] time: ", time.time() - start_time)
79 |
--------------------------------------------------------------------------------
/eipl/zoo/rnn/libs/fullBPTT.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class fullBPTTtrainer:
13 | """
14 |     Helper class to train recurrent neural networks by full backpropagation through time (BPTT)
15 |
16 |     Args:
17 |         model (torch.nn.Module): rnn model
18 |         optimizer (torch.optim.Optimizer): optimizer
19 |         device (str): device to train on, e.g. "cpu" or "cuda:0"
20 |
21 | """
22 |
23 | def __init__(self, model, optimizer, device="cpu"):
24 | self.device = device
25 | self.optimizer = optimizer
26 | self.model = model.to(self.device)
27 |
28 | def save(self, epoch, loss, savename):
29 | torch.save(
30 | {
31 | "epoch": epoch,
32 | "model_state_dict": self.model.state_dict(),
33 | #'optimizer_state_dict': self.optimizer.state_dict(),
34 | "train_loss": loss[0],
35 | "test_loss": loss[1],
36 | },
37 | savename,
38 | )
39 |
40 | def process_epoch(self, data, training=True):
41 | if not training:
42 | self.model.eval()
43 | else:
44 | self.model.train()
45 |
46 | total_loss = 0.0
47 | for n_batch, (x, y) in enumerate(data):
48 | x = x.to(self.device)
49 | y = y.to(self.device)
50 |
51 | state = None
52 | y_list = []
53 | self.optimizer.zero_grad(set_to_none=True)
54 | for t in range(x.shape[1] - 1):
55 | y_hat, state = self.model(x[:, t], state)
56 | y_list.append(y_hat)
57 |
58 | y_hat = torch.permute(torch.stack(y_list), (1, 0, 2))
59 | loss = nn.MSELoss()(y_hat, y[:, 1:])
60 | total_loss += loss.item()
61 |
62 | if training:
63 | loss.backward()
64 | self.optimizer.step()
65 |
66 | return total_loss / (n_batch + 1)
67 |
--------------------------------------------------------------------------------
/eipl/zoo/rnn/log/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/eipl/zoo/rnn/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ###### Requirements without Version Specifiers ######
2 | scipy
3 | gdown
4 | matplotlib
5 | ipdb
6 | opencv-python
7 | tensorboard
8 | torchinfo
9 | tqdm
10 | scikit-learn
11 | torch
12 | torchvision
13 | onnx
14 |
--------------------------------------------------------------------------------
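The dependency list is intentionally unpinned. A typical setup is pip install -r requirements.txt in a fresh virtual environment, followed by pip install -e . from the repository root so that the eipl package itself becomes importable.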
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | DESCRIPTION = "EIPL: Embodied Intelligence with Deep Predictive Learning"
4 | NAME = "eipl"
5 | AUTHOR = "Hiroshi Ito"
6 | AUTHOR_EMAIL = "it.hiroshi.o@gmail.com"
7 | URL = "https://github.com/ogata-lab/eipl"
8 | LICENSE = "AGPL-3.0 License"
9 | DOWNLOAD_URL = "https://github.com/ogata-lab/eipl"
10 | VERSION = "1.1.1"
11 | PYTHON_REQUIRES = ">=3.8"
12 |
13 | """
14 | INSTALL_REQUIRES = [
15 |     'torch>=1.9.0',
16 |     'matplotlib>=3.3.4',
17 |     'numpy>=1.20.3',
18 |     'scipy>=1.6.3',
19 |     'scikit-learn>=0.24.2',
20 | ]
21 |
22 | """
23 |
24 | """
25 | with open('README.rst', 'r') as fp:
26 | readme = fp.read()
27 | with open('CONTACT.txt', 'r') as fp:
28 | contacts = fp.read()
29 | long_description = readme + '\n\n' + contacts
30 | """
31 |
32 |
33 | setup(
34 | name=NAME,
35 | author=AUTHOR,
36 | author_email=AUTHOR_EMAIL,
37 | maintainer=AUTHOR,
38 | maintainer_email=AUTHOR_EMAIL,
39 | description=DESCRIPTION,
40 | # long_description=long_description,
41 | license=LICENSE,
42 | url=URL,
43 | version=VERSION,
44 | download_url=DOWNLOAD_URL,
45 | # python_requires=PYTHON_REQUIRES,
46 | # install_requires=INSTALL_REQUIRES,
47 | packages=find_packages(),
48 | )
49 |
--------------------------------------------------------------------------------