├── .gitignore ├── LICENSE ├── README.md ├── eipl ├── __init__.py ├── data │ ├── __init__.py │ ├── data_dict.py │ ├── dataset.py │ └── downloader.py ├── layer │ ├── CoordConv2d.py │ ├── GridMask.py │ ├── MultipleTimescaleRNN.py │ ├── SpatialSoftmax.py │ └── __init__.py ├── model │ ├── BasicRNN.py │ ├── CAE.py │ ├── CAEBN.py │ ├── CNNRNN.py │ ├── CNNRNNLN.py │ ├── SARNN.py │ └── __init__.py ├── test │ ├── benchmark_dataloader.py │ ├── test_CoordConv2d.py │ ├── test_GridMask.py │ ├── test_LossScheduler.py │ ├── test_SampleDownloader.py │ ├── test_SpatialSoftmax.py │ ├── test_WeightDownloader.py │ ├── test_bounds.py │ ├── test_cos_interpolation.py │ ├── test_dataloader.py │ └── test_models.py ├── tutorials │ ├── airec │ │ ├── ros │ │ │ ├── 1_rosbag2npz.py │ │ │ ├── 2_make_dataset.py │ │ │ └── 3_check_data.py │ │ └── sarnn │ │ │ ├── bin │ │ │ ├── test.py │ │ │ ├── test_pca_sarnn.py │ │ │ └── train.py │ │ │ ├── libs │ │ │ └── fullBPTT.py │ │ │ ├── log │ │ │ └── .gitignore │ │ │ └── output │ │ │ └── .gitignore │ ├── open_manipulator │ │ ├── ros │ │ │ ├── bag2npz │ │ │ │ ├── 1_rosbag2npz.py │ │ │ │ ├── 2_make_dataset.py │ │ │ │ ├── 3_check_data.py │ │ │ │ ├── bag │ │ │ │ │ └── .gitignore │ │ │ │ ├── data │ │ │ │ │ └── .gitignore │ │ │ │ └── utils.py │ │ │ └── om_teleop │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── config │ │ │ │ └── d435_param.yaml │ │ │ │ ├── launch │ │ │ │ ├── playback_bringup.launch │ │ │ │ ├── rt_bringup.launch │ │ │ │ └── teleop_bringup.launch │ │ │ │ ├── package.xml │ │ │ │ └── src │ │ │ │ ├── dynamixel_driver.py │ │ │ │ ├── dynamixel_utils.py │ │ │ │ ├── follower_bringup.py │ │ │ │ ├── interplation_node.py │ │ │ │ ├── leader_bringup.py │ │ │ │ └── playback.py │ │ └── sarnn │ │ │ ├── bin │ │ │ ├── export_onnx.py │ │ │ ├── test.py │ │ │ ├── test_pca_sarnn.py │ │ │ └── train.py │ │ │ ├── libs │ │ │ └── fullBPTT.py │ │ │ ├── log │ │ │ └── .gitignore │ │ │ └── output │ │ │ └── .gitignore │ └── robosuite │ │ ├── README.md │ │ ├── sarnn │ │ ├── bin │ │ │ ├── test.py │ │ │ ├── test_pca_sarnn.py │ │ │ └── train.py │ │ ├── libs │ │ │ └── fullBPTT.py │ │ ├── log │ │ │ └── .gitignore │ │ └── output │ │ │ └── .gitignore │ │ └── simulator │ │ ├── 2_resave.sh │ │ ├── 3_check_data.sh │ │ ├── bin │ │ ├── 1_teaching.py │ │ ├── 2_resave.py │ │ ├── 3_check_playback_data.py │ │ ├── 4_generate_dataset.py │ │ ├── 5_check_dataset.py │ │ └── 6_rt_control.py │ │ ├── data │ │ └── .gitignore │ │ ├── libs │ │ ├── devices.py │ │ ├── environment.py │ │ ├── rt_control_wrapper.py │ │ ├── samplers.py │ │ └── utils.py │ │ └── output │ │ └── .gitignore ├── utils │ ├── __init__.py │ ├── arg_utils.py │ ├── callback.py │ ├── check_gpu.py │ ├── convert_compiled_pth.py │ ├── data.py │ ├── nn_func.py │ ├── path_utils.py │ ├── print_func.py │ ├── resave_pth.py │ └── utils.py └── zoo │ ├── cae │ ├── bin │ │ ├── extract.py │ │ ├── test.py │ │ ├── test_pca_cae.py │ │ └── train.py │ ├── data │ │ └── .gitignore │ ├── libs │ │ └── trainer.py │ ├── log │ │ └── .gitignore │ └── output │ │ └── .gitignore │ ├── cnnrnn │ ├── bin │ │ ├── test.py │ │ ├── test_pca_cnnrnn.py │ │ └── train.py │ ├── libs │ │ └── fullBPTT.py │ ├── log │ │ └── .gitignore │ └── output │ │ └── .gitignore │ └── rnn │ ├── bin │ ├── test.py │ ├── test_pca_rnn.py │ └── train.py │ ├── libs │ ├── dataloader.py │ └── fullBPTT.py │ ├── log │ └── .gitignore │ └── output │ └── .gitignore ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # weight 2 | *.pth 3 | *.onnx 4 | 5 | # 
Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 114 | .pdm.toml 115 | .pdm-python 116 | .pdm-build/ 117 | 118 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ogata Laboratory (Waseda University) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 |
5 | 
6 | 
7 | 
8 | ## What's EIPL? 
9 | 
10 | EIPL (Embodied Intelligence with Deep Predictive Learning) is a library for robot motion generation using deep predictive learning developed at [Ogata Laboratory](https://ogata-lab.jp/), [Waseda University](https://www.waseda.jp/top/en). 
11 | Highlighted features include: 
12 | 
13 | - [**Full documentation**](https://ogata-lab.github.io/eipl-docs) for the systematic understanding of deep predictive learning 
14 | - **Easy model training:** Includes sample datasets, source code, and pre-trained weights 
15 | - **Applicable to real robots:** Generalized motion can be acquired with small datasets 
16 | 
17 | ## Install 
18 | 
19 | ```sh 
20 | pip install -r requirements.txt 
21 | pip install -e . 
22 | ``` 
23 | 
24 | ## Citation 
25 | 
26 | ``` 
27 | @article{suzuki2023deep, 
28 |   author  = {Kanata Suzuki and Hiroshi Ito and Tatsuro Yamada and Kei Kase and Tetsuya Ogata}, 
29 |   title   = {Deep Predictive Learning: Motion Learning Concept inspired by Cognitive Robotics}, 
30 |   journal = {arXiv preprint arXiv:2306.14714}, 
31 |   year    = {2023}, 
32 | } 
33 | ``` 
34 | 
35 | ## LICENSE 
36 | 
37 | MIT License
--------------------------------------------------------------------------------
/eipl/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import *
2 | from .layer import *
3 | from .utils import *
--------------------------------------------------------------------------------
/eipl/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .downloader import SampleDownloader, WeightDownloader
2 | from .dataset import *
--------------------------------------------------------------------------------
/eipl/data/data_dict.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | data_dict = {
7 |     "airec": {
8 |         "grasp_bottle": {
9 |             "https://drive.google.com/uc?id=11NhuH35DR2Pg7gWqOIhw5yJ_Nmeto-wO"
10 |         }
11 |     },
12 |     "om": {
13 |         "grasp_cube": {
14 |             "https://drive.google.com/uc?id=16M0wEPleiRAucZeGRcUel5yO8ou7R-SA"
15 |         }
16 |     },
17 | }
--------------------------------------------------------------------------------
/eipl/layer/CoordConv2d.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.modules.conv as conv
9 | 
10 | 
11 | class AddCoords(nn.Module):
12 |     """AddCoords
13 | 
14 |     Arguments:
15 |         min_range (float, optional): Minimum value for xx_channel and yy_channel. The default is 0.0; the original paper uses -1.
16 |         with_r (Boolean, optional): Whether to add the radial channel. The default is False.
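 
    Example (illustrative only, not from the original docs; assumes a
    4-channel 64x64 input):
        >>> layer = AddCoords(with_r=True)
        >>> layer(torch.zeros(8, 4, 64, 64)).shape  # input + xx, yy, and radial channels
        torch.Size([8, 7, 64, 64])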
17 |     """
18 | 
19 |     def __init__(self, min_range=0.0, with_r=False):
20 |         super(AddCoords, self).__init__()
21 |         self.min_range = min_range
22 |         self.with_r = with_r
23 | 
24 |     def forward(self, x):
25 |         batch_size, channels, width, height = x.shape
26 |         device = x.device
27 | 
28 |         xx_channel, yy_channel = torch.meshgrid(
29 |             torch.linspace(self.min_range, 1.0, height, dtype=torch.float32),
30 |             torch.linspace(self.min_range, 1.0, width, dtype=torch.float32),
31 |             indexing="ij",
32 |         )
33 |         xx_channel = xx_channel.expand(batch_size, 1, width, height).to(device)
34 |         yy_channel = yy_channel.expand(batch_size, 1, width, height).to(device)
35 | 
36 |         y = torch.cat([x, xx_channel, yy_channel], dim=1)
37 | 
38 |         if self.with_r:
39 |             rr = torch.sqrt(
40 |                 torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)
41 |             )
42 |             y = torch.cat([y, rr], dim=1)
43 | 
44 |         return y
45 | 
46 | 
47 | class CoordConv2d(conv.Conv2d):
48 |     """CoordConv2d
49 |     Rosanne Liu, Joel Lehman, Piero Molino, Felipe Petroski Such, Eric Frank, Alex Sergeev, Jason Yosinski,
50 |     ``An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution.``,
51 |     NeurIPS 2018.
52 |     https://arxiv.org/abs/1807.03247v2
53 |     """
54 | 
55 |     def __init__(
56 |         self,
57 |         input_size,
58 |         output_size,
59 |         kernel_size,
60 |         stride=1,
61 |         padding=0,
62 |         dilation=1,
63 |         groups=1,
64 |         bias=True,
65 |         min_range=0.0,
66 |         with_r=False,
67 |     ):
68 |         super(CoordConv2d, self).__init__(
69 |             input_size,
70 |             output_size,
71 |             kernel_size,
72 |             stride,
73 |             padding,
74 |             dilation,
75 |             groups,
76 |             bias,
77 |         )
78 |         rank = 2
79 |         self.addcoords = AddCoords(min_range=min_range, with_r=with_r)
80 |         self.conv = nn.Conv2d(
81 |             input_size + rank + int(with_r),
82 |             output_size,
83 |             kernel_size,
84 |             stride,
85 |             padding,
86 |             dilation,
87 |             groups,
88 |             bias,
89 |         )
90 | 
91 |     def forward(self, x):
92 |         hid = self.addcoords(x)
93 |         y = self.conv(hid)
94 | 
95 |         return y
96 | 
--------------------------------------------------------------------------------
/eipl/layer/GridMask.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import torch
7 | import numpy as np
8 | 
9 | 
10 | class GridMask:
11 |     """GridMask
12 | 
13 |     Arguments:
14 |         p (float, optional): Probability of applying the mask. The default is 0.6.
15 |         d_range (tuple of int, optional): Range from which the grid spacing d is sampled. The default is (10, 30).
16 |         r (float, optional): This parameter determines how much of the original image is retained.
17 |             If r=0, the entire image is masked; if r=1, the image is not masked at all.
18 | 
19 |     Chen, Pengguang, et al. "Gridmask data augmentation."
20 | https://arxiv.org/abs/2001.04086 21 | """ 22 | 23 | def __init__(self, p=0.6, d_range=(10, 30), r=0.6, channel_first=True): 24 | self.p = p 25 | self.d_range = d_range 26 | self.r = r 27 | self.channel_first = channel_first 28 | 29 | def __call__(self, img, debug=False): 30 | if not debug and np.random.uniform() > self.p: 31 | return img 32 | 33 | side = img.shape[-2] 34 | d = np.random.randint(*self.d_range, dtype=np.uint8) 35 | r = int(self.r * d) 36 | 37 | mask = np.ones((side + d, side + d), dtype=np.uint8) 38 | for i in range(0, side + d, d): 39 | for j in range(0, side + d, d): 40 | mask[i : i + (d - r), j : j + (d - r)] = 0 41 | 42 | delta_x, delta_y = np.random.randint(0, d, size=2) 43 | mask = mask[delta_x : delta_x + side, delta_y : delta_y + side] 44 | 45 | if self.channel_first: 46 | img *= np.expand_dims(mask, 0) 47 | else: 48 | img *= np.expand_dims(mask, -1) 49 | 50 | return img 51 | -------------------------------------------------------------------------------- /eipl/layer/MultipleTimescaleRNN.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | from eipl.utils import get_activation_fn 9 | 10 | 11 | class MTRNNCell(nn.Module): 12 | #:: MTRNNCell 13 | """Multiple Timescale RNN. 14 | 15 | Implements a form of Recurrent Neural Network (RNN) that operates with multiple timescales. 16 | This is based on the idea of hierarchical organization in human cognitive functions. 17 | 18 | Arguments: 19 | input_dim (int): Number of input features. 20 | fast_dim (int): Number of fast context neurons. 21 | slow_dim (int): Number of slow context neurons. 22 | fast_tau (float): Time constant value of fast context. 23 | slow_tau (float): Time constant value of slow context. 24 | activation (string, optional): If you set `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). 25 | use_bias (Boolean, optional): whether the layer uses a bias vector. The default is False. 26 | use_pb (Boolean, optional): whether the recurrent uses a pb vector. The default is False. 27 | 28 | Yuichi Yamashita, Jun Tani, 29 | "Emergence of functional hierarchy in a multiple timescale neural network model: a humanoid robot experiment." PLoS computational biology, 2008. 30 | https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000220 31 | """ 32 | 33 | def __init__( 34 | self, 35 | input_dim, 36 | fast_dim, 37 | slow_dim, 38 | fast_tau, 39 | slow_tau, 40 | activation="tanh", 41 | use_bias=False, 42 | use_pb=False, 43 | ): 44 | super(MTRNNCell, self).__init__() 45 | 46 | self.input_dim = input_dim 47 | self.fast_dim = fast_dim 48 | self.slow_dim = slow_dim 49 | self.fast_tau = fast_tau 50 | self.slow_tau = slow_tau 51 | self.use_bias = use_bias 52 | self.use_pb = use_pb 53 | 54 | # Legacy string support for activation function. 
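        # (A string such as "tanh" is resolved to the corresponding torch
        # function via get_activation_fn; a callable is passed through as-is.)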
55 |         if isinstance(activation, str):
56 |             self.activation = get_activation_fn(activation)
57 |         else:
58 |             self.activation = activation
59 | 
60 |         # Input Layers
61 |         self.i2f = nn.Linear(input_dim, fast_dim, bias=use_bias)
62 | 
63 |         # Fast context layer
64 |         self.f2f = nn.Linear(fast_dim, fast_dim, bias=False)
65 |         self.f2s = nn.Linear(fast_dim, slow_dim, bias=use_bias)
66 | 
67 |         # Slow context layer
68 |         self.s2s = nn.Linear(slow_dim, slow_dim, bias=False)
69 |         self.s2f = nn.Linear(slow_dim, fast_dim, bias=use_bias)
70 | 
71 |     def forward(self, x, state=None, pb=None):
72 |         """Forward propagation of the MTRNN.
73 | 
74 |         Arguments:
75 |             x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
76 |             state (list): Previous states (h_fast, h_slow, u_fast, u_slow), each of shape (batch_size, context_dim).
77 |                 If None, initialize states to zeros.
78 |             pb (torch.Tensor, optional): Parametric bias vector added to the slow-context input. Used if self.use_pb is set to True.
79 | 
80 |         Returns:
81 |             new_h_fast (torch.Tensor): Updated fast context state.
82 |             new_h_slow (torch.Tensor): Updated slow context state.
83 |             new_u_fast (torch.Tensor): Updated fast internal state.
84 |             new_u_slow (torch.Tensor): Updated slow internal state.
85 |         """
86 |         batch_size = x.shape[0]
87 |         if state is not None:
88 |             prev_h_fast, prev_h_slow, prev_u_fast, prev_u_slow = state
89 |         else:
90 |             device = x.device
91 |             prev_h_fast = torch.zeros(batch_size, self.fast_dim).to(device)
92 |             prev_h_slow = torch.zeros(batch_size, self.slow_dim).to(device)
93 |             prev_u_fast = torch.zeros(batch_size, self.fast_dim).to(device)
94 |             prev_u_slow = torch.zeros(batch_size, self.slow_dim).to(device)
95 | 
96 |         new_u_fast = (1.0 - 1.0 / self.fast_tau) * prev_u_fast + 1.0 / self.fast_tau * (
97 |             self.i2f(x) + self.f2f(prev_h_fast) + self.s2f(prev_h_slow)
98 |         )
99 | 
100 |         _input_slow = self.f2s(prev_h_fast) + self.s2s(prev_h_slow)
101 |         if pb is not None:
102 |             _input_slow += pb
103 | 
104 |         new_u_slow = (
105 |             1.0 - 1.0 / self.slow_tau
106 |         ) * prev_u_slow + 1.0 / self.slow_tau * _input_slow
107 | 
108 |         new_h_fast = self.activation(new_u_fast)
109 |         new_h_slow = self.activation(new_u_slow)
110 | 
111 |         return new_h_fast, new_h_slow, new_u_fast, new_u_slow
112 | 
--------------------------------------------------------------------------------
/eipl/layer/SpatialSoftmax.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
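# This file implements a differentiable keypoint bottleneck: SpatialSoftmax
# turns CNN feature maps into per-channel (x, y) keypoints, and
# InverseSpatialSoftmax renders keypoints back into Gaussian-like heatmaps.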
4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from eipl.utils import tensor2numpy, plt_img, get_feature_map 11 | 12 | 13 | def create_position_encoding(width: int, height: int, normalized=True, data_format="channels_first"): 14 | if normalized: 15 | pos_x, pos_y = np.meshgrid(np.linspace(0.0, 1.0, height), np.linspace(0.0, 1.0, width), indexing="xy") 16 | else: 17 | pos_x, pos_y = np.meshgrid( 18 | np.linspace(0, height - 1, height), 19 | np.linspace(0, width - 1, width), 20 | indexing="xy", 21 | ) 22 | 23 | if data_format == "channels_first": 24 | pos_xy = torch.from_numpy(np.stack([pos_x, pos_y], axis=0)).float() # (2,W,H) 25 | else: 26 | pos_xy = torch.from_numpy(np.stack([pos_x, pos_y], axis=2)).float() # (W,H,2) 27 | 28 | pos_x = torch.from_numpy(pos_x.reshape(height * width)).float() 29 | pos_y = torch.from_numpy(pos_y.reshape(height * width)).float() 30 | 31 | return pos_xy, pos_x, pos_y 32 | 33 | 34 | class SpatialSoftmax(nn.Module): 35 | """Spatial Softmax 36 | Extract XY position from feature map of CNN 37 | 38 | Chelsea Finn, Xin Yu Tan, Yan Duan, Trevor Darrell, Sergey Levine, Pieter Abbeel 39 | ``Deep spatial autoencoders for visuomotor learning.`` 40 | 2016 IEEE International Conference on Robotics and Automation (ICRA). IEEE, 2016. 41 | https://ieeexplore.ieee.org/abstract/document/7487173 42 | """ 43 | 44 | def __init__(self, width: int, height: int, temperature=1e-4, normalized=True): 45 | super(SpatialSoftmax, self).__init__() 46 | self.width = width 47 | self.height = height 48 | self.temperature = temperature 49 | 50 | _, pos_x, pos_y = create_position_encoding(width, height, normalized=normalized) 51 | self.register_buffer("pos_x", pos_x) 52 | self.register_buffer("pos_y", pos_y) 53 | 54 | def forward(self, x): 55 | batch_size, channels, width, height = x.shape 56 | assert height == self.height 57 | assert width == self.width 58 | 59 | # flatten, apply softmax 60 | logit = x.reshape(batch_size, channels, -1) 61 | att_map = torch.softmax(logit / self.temperature, dim=-1) 62 | 63 | # compute expectation 64 | expected_x = torch.sum(self.pos_x * att_map, dim=-1, keepdim=True) 65 | expected_y = torch.sum(self.pos_y * att_map, dim=-1, keepdim=True) 66 | keys = torch.cat([expected_x, expected_y], -1) 67 | 68 | # keys [[x,y], [x,y], [x,y],...] 
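        # Each (x, y) pair is the attention-weighted expected coordinate of one
        # channel's activation, i.e. a differentiable soft-argmax over the map.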
69 |         keys = keys.reshape(batch_size, channels, 2)
70 |         att_map = att_map.reshape(-1, channels, width, height)
71 |         return keys, att_map
72 | 
73 | 
74 | class InverseSpatialSoftmax(nn.Module):
75 |     """InverseSpatialSoftmax
76 |     Generate heatmap from XY position
77 | 
78 |     Hideyuki Ichiwara, Hiroshi Ito, Kenjiro Yamamoto, Hiroki Mori, Tetsuya Ogata
79 |     ``Spatial Attention Point Network for Deep-learning-based Robust Autonomous Robot Motion Generation.``
80 |     https://arxiv.org/abs/2103.01598
81 |     """
82 | 
83 |     def __init__(self, width: int, height: int, heatmap_size=0.1, normalized=True, convex=True):
84 |         super(InverseSpatialSoftmax, self).__init__()
85 | 
86 |         self.width = width
87 |         self.height = height
88 |         self.normalized = normalized
89 |         self.heatmap_size = heatmap_size
90 |         self.convex = convex
91 | 
92 |         pos_xy, _, _ = create_position_encoding(width, height, normalized=normalized)
93 |         self.register_buffer("pos_xy", pos_xy)
94 | 
95 |     def forward(self, keys):
96 |         squared_distances = torch.sum(torch.pow(self.pos_xy[None, None] - keys[:, :, :, None, None], 2.0), axis=2)
97 |         heatmap = torch.exp(-squared_distances / self.heatmap_size)
98 | 
99 |         if not self.convex:
100 |             heatmap = torch.abs(1.0 - heatmap)
101 | 
102 |         return heatmap
--------------------------------------------------------------------------------
/eipl/layer/__init__.py:
--------------------------------------------------------------------------------
1 | from .CoordConv2d import *
2 | from .GridMask import *
3 | from .MultipleTimescaleRNN import *
4 | from .SpatialSoftmax import *
--------------------------------------------------------------------------------
/eipl/model/BasicRNN.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import torch
7 | import torch.nn as nn
8 | from eipl.utils import get_activation_fn
9 | from eipl.layer import MTRNNCell
10 | 
11 | 
12 | class BasicLSTM(nn.Module):
13 |     #:: BasicLSTM
14 |     """BasicLSTM
15 | 
16 |     Arguments:
17 |         in_dim (int): Number of input features
18 |         rec_dim (int): Number of recurrent hidden units
19 |         out_dim (int): Number of output features
20 |         activation (string, optional): If you set `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
21 |             The default is hyperbolic tangent (`tanh`).
22 |     """
23 | 
24 |     def __init__(self, in_dim, rec_dim, out_dim, activation="tanh"):
25 |         super(BasicLSTM, self).__init__()
26 | 
27 |         # Legacy string support for activation function.
28 |         if isinstance(activation, str):
29 |             activation = get_activation_fn(activation)
30 | 
31 |         self.rnn = nn.LSTMCell(in_dim, rec_dim)
32 |         self.rnn_out = nn.Sequential(nn.Linear(rec_dim, out_dim), activation)
33 | 
34 |     def forward(self, x, state=None):
35 |         rnn_hid = self.rnn(x, state)
36 |         y_hat = self.rnn_out(rnn_hid[0])
37 | 
38 |         return y_hat, rnn_hid
39 | 
40 | 
41 | class BasicMTRNN(nn.Module):
42 |     #:: BasicMTRNN
43 |     """BasicMTRNN
44 | 
45 |     Arguments:
46 |         in_dim (int): Number of input features
47 |         fast_dim / slow_dim (int): Number of fast / slow context neurons; fast_tau / slow_tau (float): their time constants
48 |         out_dim (int, optional): Number of output features. The default is in_dim.
49 |         activation (string, optional): If you set `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
50 |             The default is hyperbolic tangent (`tanh`).
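 
    Example (illustrative, single step on random input):
        >>> rnn = BasicMTRNN(in_dim=12, fast_dim=30, slow_dim=5, fast_tau=2, slow_tau=12)
        >>> y, state = rnn(torch.randn(8, 12))
        >>> y.shape  # out_dim defaults to in_dim
        torch.Size([8, 12])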
51 | """ 52 | 53 | def __init__( 54 | self, 55 | in_dim, 56 | fast_dim, 57 | slow_dim, 58 | fast_tau, 59 | slow_tau, 60 | out_dim=None, 61 | activation="tanh", 62 | ): 63 | super(BasicMTRNN, self).__init__() 64 | 65 | if out_dim is None: 66 | out_dim = in_dim 67 | 68 | # Legacy string support for activation function. 69 | if isinstance(activation, str): 70 | activation = get_activation_fn(activation) 71 | 72 | self.mtrnn = MTRNNCell( 73 | in_dim, fast_dim, slow_dim, fast_tau, slow_tau, activation=activation 74 | ) 75 | # Output of RNN 76 | self.rnn_out = nn.Sequential(nn.Linear(fast_dim, out_dim), activation) 77 | 78 | def forward(self, x, state=None): 79 | rnn_hid = self.mtrnn(x, state) 80 | y_hat = self.rnn_out(rnn_hid[0]) 81 | 82 | return y_hat, rnn_hid 83 | -------------------------------------------------------------------------------- /eipl/model/CAE.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | from eipl.layer import CoordConv2d, AddCoords 9 | 10 | 11 | class BasicCAE(nn.Module): 12 | #:: BasicCAE 13 | """BasicCAE""" 14 | 15 | def __init__(self, feat_dim=10): 16 | super(BasicCAE, self).__init__() 17 | 18 | # encoder 19 | self.encoder = nn.Sequential( 20 | nn.Conv2d(3, 64, 3, 2, 1), 21 | nn.Tanh(), 22 | nn.Conv2d(64, 32, 3, 2, 1), 23 | nn.Tanh(), 24 | nn.Conv2d(32, 16, 3, 2, 1), 25 | nn.Tanh(), 26 | nn.Conv2d(16, 12, 3, 2, 1), 27 | nn.Tanh(), 28 | nn.Conv2d(12, 8, 3, 2, 1), 29 | nn.Tanh(), 30 | nn.Flatten(), 31 | nn.Linear(8 * 4 * 4, 50), 32 | nn.Tanh(), 33 | nn.Linear(50, feat_dim), 34 | nn.Tanh(), 35 | ) 36 | 37 | # decoder 38 | self.decoder = nn.Sequential( 39 | nn.Linear(feat_dim, 50), 40 | nn.Tanh(), 41 | nn.Linear(50, 8 * 4 * 4), 42 | nn.Tanh(), 43 | nn.Unflatten(1, (8, 4, 4)), 44 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1), 45 | nn.Tanh(), 46 | nn.ConvTranspose2d(12, 16, 3, 2, padding=1, output_padding=1), 47 | nn.Tanh(), 48 | nn.ConvTranspose2d(16, 32, 3, 2, padding=1, output_padding=1), 49 | nn.Tanh(), 50 | nn.ConvTranspose2d(32, 64, 3, 2, padding=1, output_padding=1), 51 | nn.Tanh(), 52 | nn.ConvTranspose2d(64, 3, 3, 2, padding=1, output_padding=1), 53 | nn.Tanh(), 54 | ) 55 | 56 | def forward(self, x): 57 | return self.decoder(self.encoder(x)) 58 | 59 | 60 | class CAE(nn.Module): 61 | #:: CAE 62 | """CAE""" 63 | 64 | def __init__(self, feat_dim=10): 65 | super(CAE, self).__init__() 66 | 67 | # encoder 68 | self.encoder = nn.Sequential( 69 | nn.Conv2d(3, 32, 6, 2, 1), 70 | nn.ReLU(True), 71 | nn.Conv2d(32, 64, 6, 2, 1), 72 | nn.ReLU(True), 73 | nn.Conv2d(64, 128, 6, 2, 1), 74 | nn.ReLU(True), 75 | nn.Flatten(), 76 | nn.Linear(128 * 14 * 14, 1000), 77 | nn.ReLU(True), 78 | nn.Linear(1000, feat_dim), 79 | nn.ReLU(True), 80 | ) 81 | 82 | # decoder 83 | self.decoder = nn.Sequential( 84 | nn.Linear(feat_dim, 1000), 85 | nn.ReLU(True), 86 | nn.Linear(1000, 128 * 14 * 14), 87 | nn.ReLU(True), 88 | nn.Unflatten(1, (128, 14, 14)), 89 | nn.ConvTranspose2d(128, 64, 6, 2, padding=1), 90 | nn.ReLU(True), 91 | nn.ConvTranspose2d(64, 32, 6, 2, padding=1), 92 | nn.ReLU(True), 93 | nn.ConvTranspose2d(32, 3, 6, 2, padding=0), 94 | nn.ReLU(True), 95 | ) 96 | 97 | def forward(self, x): 98 | return self.decoder(self.encoder(x)) 99 | -------------------------------------------------------------------------------- /eipl/model/CAEBN.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | from eipl.layer import CoordConv2d, AddCoords 9 | 10 | 11 | class BasicCAEBN(nn.Module): 12 | #:: BasicCAEBN 13 | """BasicCAEBN""" 14 | 15 | def __init__(self, feat_dim=10): 16 | super(BasicCAEBN, self).__init__() 17 | 18 | # Encoder 19 | self.encoder = nn.Sequential( 20 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 21 | nn.BatchNorm2d(64), 22 | nn.ReLU(True), 23 | nn.Conv2d(64, 32, 3, 2, 1), 24 | nn.BatchNorm2d(32), 25 | nn.ReLU(True), 26 | nn.Conv2d(32, 16, 3, 2, 1), 27 | nn.BatchNorm2d(16), 28 | nn.ReLU(True), 29 | nn.Conv2d(16, 12, 3, 2, 1), 30 | nn.BatchNorm2d(12), 31 | nn.ReLU(True), 32 | nn.Conv2d(12, 8, 3, 2, 1), 33 | nn.BatchNorm2d(8), 34 | nn.ReLU(True), 35 | nn.Flatten(), 36 | nn.Linear(8 * 4 * 4, 50), 37 | nn.BatchNorm1d(50), 38 | nn.ReLU(True), 39 | nn.Linear(50, feat_dim), 40 | nn.BatchNorm1d(feat_dim), 41 | nn.ReLU(True), 42 | ) 43 | 44 | # Decoder 45 | self.decoder = nn.Sequential( 46 | nn.Linear(feat_dim, 50), 47 | nn.BatchNorm1d(50), 48 | nn.ReLU(True), 49 | nn.Linear(50, 8 * 4 * 4), 50 | nn.BatchNorm1d(8 * 4 * 4), 51 | nn.ReLU(True), 52 | nn.Unflatten(1, (8, 4, 4)), 53 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1), 54 | nn.BatchNorm2d(12), 55 | nn.ReLU(True), 56 | nn.ConvTranspose2d(12, 16, 3, 2, 1, 1), 57 | nn.BatchNorm2d(16), 58 | nn.ReLU(True), 59 | nn.ConvTranspose2d(16, 32, 3, 2, 1, 1), 60 | nn.BatchNorm2d(32), 61 | nn.ReLU(True), 62 | nn.ConvTranspose2d(32, 64, 3, 2, 1, 1), 63 | nn.BatchNorm2d(64), 64 | nn.ReLU(True), 65 | nn.ConvTranspose2d(64, 3, 3, 2, 1, 1), 66 | nn.ReLU(True), 67 | ) 68 | 69 | def forward(self, x): 70 | return self.decoder(self.encoder(x)) 71 | 72 | 73 | class CAEBN(nn.Module): 74 | #:: CAEBN 75 | """CAEBN""" 76 | 77 | def __init__(self, feat_dim=10): 78 | super(CAEBN, self).__init__() 79 | 80 | # Encoder 81 | self.encoder = nn.Sequential( 82 | nn.Conv2d(3, 32, 6, 2, 1), 83 | nn.BatchNorm2d(32), 84 | nn.ReLU(True), 85 | nn.Conv2d(32, 64, 6, 2, 1), 86 | nn.BatchNorm2d(64), 87 | nn.ReLU(True), 88 | nn.Conv2d(64, 128, 6, 2, 1), 89 | nn.BatchNorm2d(128), 90 | nn.ReLU(True), 91 | nn.Flatten(), 92 | nn.Linear(128 * 14 * 14, 1000), 93 | nn.BatchNorm1d(1000), 94 | nn.ReLU(True), 95 | nn.Linear(1000, feat_dim), 96 | nn.BatchNorm1d(feat_dim), 97 | nn.ReLU(True), 98 | ) 99 | 100 | # Decoder 101 | self.decoder = nn.Sequential( 102 | nn.Linear(feat_dim, 1000), 103 | nn.BatchNorm1d(1000), 104 | nn.ReLU(True), 105 | nn.Linear(1000, 128 * 14 * 14), 106 | nn.BatchNorm1d(128 * 14 * 14), 107 | nn.ReLU(True), 108 | nn.Unflatten(1, (128, 14, 14)), 109 | nn.ConvTranspose2d(128, 64, 6, 2, padding=1), 110 | nn.BatchNorm2d(64), 111 | nn.ReLU(True), 112 | nn.ConvTranspose2d(64, 32, 6, 2, padding=1), 113 | nn.BatchNorm2d(32), 114 | nn.ReLU(True), 115 | nn.ConvTranspose2d(32, 3, 6, 2, padding=0), 116 | nn.ReLU(True), 117 | ) 118 | 119 | def forward(self, x): 120 | return self.decoder(self.encoder(x)) 121 | -------------------------------------------------------------------------------- /eipl/model/CNNRNN.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
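# CNNRNN compresses the camera image to a low-dimensional feature vector,
# concatenates it with the joint angles, and decodes the predicted image and
# joint state from the hidden state of an LSTM cell.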
4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class CNNRNN(nn.Module): 11 | #:: CNNRNN 12 | """CNNRNN""" 13 | 14 | def __init__(self, rec_dim=50, joint_dim=8, feat_dim=10): 15 | super(CNNRNN, self).__init__() 16 | 17 | # Encoder 18 | self.encoder_image = nn.Sequential( 19 | nn.Conv2d(3, 64, 3, 2, 1), 20 | nn.Tanh(), 21 | nn.Conv2d(64, 32, 3, 2, 1), 22 | nn.Tanh(), 23 | nn.Conv2d(32, 16, 3, 2, 1), 24 | nn.Tanh(), 25 | nn.Conv2d(16, 12, 3, 2, 1), 26 | nn.Tanh(), 27 | nn.Conv2d(12, 8, 3, 2, 1), 28 | nn.Tanh(), 29 | nn.Flatten(), 30 | nn.Linear(8 * 4 * 4, 50), 31 | nn.Tanh(), 32 | nn.Linear(50, feat_dim), 33 | nn.Tanh(), 34 | ) 35 | 36 | # Recurrent 37 | rec_in = feat_dim + joint_dim 38 | self.rec = nn.LSTMCell(rec_in, rec_dim) 39 | 40 | # Decoder for joint angle 41 | self.decoder_joint = nn.Sequential(nn.Linear(rec_dim, joint_dim), nn.Tanh()) 42 | 43 | # Decoder for image 44 | self.decoder_image = nn.Sequential( 45 | nn.Linear(rec_dim, 8 * 4 * 4), 46 | nn.Tanh(), 47 | nn.Unflatten(1, (8, 4, 4)), 48 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1), 49 | nn.Tanh(), 50 | nn.ConvTranspose2d(12, 16, 3, 2, padding=1, output_padding=1), 51 | nn.Tanh(), 52 | nn.ConvTranspose2d(16, 32, 3, 2, padding=1, output_padding=1), 53 | nn.Tanh(), 54 | nn.ConvTranspose2d(32, 64, 3, 2, padding=1, output_padding=1), 55 | nn.Tanh(), 56 | nn.ConvTranspose2d(64, 3, 3, 2, padding=1, output_padding=1), 57 | nn.Tanh(), 58 | ) 59 | 60 | # image, joint 61 | def forward(self, xi, xv, state=None): 62 | # Encoder 63 | im_feat = self.encoder_image(xi) 64 | hid = torch.concat([im_feat, xv], -1) 65 | 66 | # Recurrent 67 | rnn_hid = self.rec(hid, state) 68 | 69 | # Decoder 70 | y_joint = self.decoder_joint(rnn_hid[0]) 71 | y_image = self.decoder_image(rnn_hid[0]) 72 | 73 | return y_image, y_joint, rnn_hid 74 | -------------------------------------------------------------------------------- /eipl/model/CNNRNNLN.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class CNNRNNLN(nn.Module): 11 | #:: CNNRNNLN 12 | """CNNRNNLN""" 13 | 14 | def __init__(self, rec_dim=50, joint_dim=8, feat_dim=10): 15 | super(CNNRNNLN, self).__init__() 16 | 17 | # Encoder 18 | self.encoder_image = nn.Sequential( 19 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 20 | nn.LayerNorm([64, 64, 64]), 21 | nn.ReLU(True), 22 | nn.Conv2d(64, 32, 3, 2, 1), 23 | nn.LayerNorm([32, 32, 32]), 24 | nn.ReLU(True), 25 | nn.Conv2d(32, 16, 3, 2, 1), 26 | nn.LayerNorm([16, 16, 16]), 27 | nn.ReLU(True), 28 | nn.Conv2d(16, 12, 3, 2, 1), 29 | nn.LayerNorm([12, 8, 8]), 30 | nn.ReLU(True), 31 | nn.Conv2d(12, 8, 3, 2, 1), 32 | nn.LayerNorm([8, 4, 4]), 33 | nn.ReLU(True), 34 | nn.Flatten(), 35 | nn.Linear(8 * 4 * 4, 50), 36 | nn.LayerNorm([50]), 37 | nn.ReLU(True), 38 | nn.Linear(50, feat_dim), 39 | nn.LayerNorm([feat_dim]), 40 | nn.ReLU(True), 41 | ) 42 | 43 | # Recurrent 44 | rec_in = feat_dim + joint_dim 45 | self.rec = nn.LSTMCell(rec_in, rec_dim) 46 | 47 | # Decoder for joint angle 48 | self.decoder_joint = nn.Sequential(nn.Linear(rec_dim, joint_dim), nn.ReLU(True)) 49 | 50 | # Decoder for image 51 | self.decoder_image = nn.Sequential( 52 | nn.Linear(rec_dim, 8 * 4 * 4), 53 | nn.LayerNorm([8 * 4 * 4]), 54 | nn.ReLU(True), 55 | nn.Unflatten(1, (8, 4, 4)), 56 | nn.ConvTranspose2d(8, 12, 3, 2, padding=1, output_padding=1), 57 | nn.LayerNorm([12, 8, 8]), 58 | nn.ReLU(True), 59 | nn.ConvTranspose2d(12, 16, 3, 2, 1, 1), 60 | nn.LayerNorm([16, 16, 16]), 61 | nn.ReLU(True), 62 | nn.ConvTranspose2d(16, 32, 3, 2, 1, 1), 63 | nn.LayerNorm([32, 32, 32]), 64 | nn.ReLU(True), 65 | nn.ConvTranspose2d(32, 64, 3, 2, 1, 1), 66 | nn.LayerNorm([64, 64, 64]), 67 | nn.ReLU(True), 68 | nn.ConvTranspose2d(64, 3, 3, 2, 1, 1), 69 | nn.ReLU(True), 70 | ) 71 | 72 | # image, joint 73 | def forward(self, xi, xv, state=None): 74 | # Encoder 75 | im_feat = self.encoder_image(xi) 76 | hid = torch.concat([im_feat, xv], -1) 77 | 78 | # Recurrent 79 | rnn_hid = self.rec(hid, state) 80 | 81 | # Decoder 82 | y_joint = self.decoder_joint(rnn_hid[0]) 83 | y_image = self.decoder_image(rnn_hid[0]) 84 | 85 | return y_image, y_joint, rnn_hid 86 | -------------------------------------------------------------------------------- /eipl/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .BasicRNN import * 2 | from .CAE import * 3 | from .CAEBN import * 4 | from .CNNRNN import * 5 | from .CNNRNNLN import * 6 | from .SARNN import * -------------------------------------------------------------------------------- /eipl/test/benchmark_dataloader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
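# Compares three dataloader configurations on dummy data: on-GPU
# augmentation, a CPU-side dataset with pinned memory and worker processes,
# and the multiprocess MultiEpochsDataLoader.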
4 | #
5 | 
6 | import time
7 | import torch
8 | import numpy as np
9 | from eipl.data import MultimodalDataset, MultiEpochsDataLoader
10 | 
11 | 
12 | def test_dataloader(dataloader):
13 |     torch.cuda.synchronize()
14 |     start = time.time()
15 |     for _ in range(10):
16 |         for _ in dataloader:
17 |             pass
18 |     torch.cuda.synchronize()
19 |     elapsed_time = time.time() - start
20 |     print(elapsed_time, "sec.")
21 | 
22 | 
23 | # parameter
24 | stdev = 0.1
25 | device = "cuda:0"
26 | batch_size = 8
27 | images = np.random.random((30, 200, 3, 64, 64))
28 | joints = np.random.random((30, 200, 7))
29 | 
30 | # Original dataloader with CUDA transformation
31 | # Note that CUDA transformation does not support pin_memory or num_workers.
32 | dataset = MultimodalDataset(images, joints, device="cuda:0", stdev=stdev)
33 | cuda_loader = torch.utils.data.DataLoader(
34 |     dataset, batch_size=batch_size, shuffle=True, drop_last=False
35 | )
36 | test_dataloader(cuda_loader)
37 | del cuda_loader
38 | 
39 | # Original dataloader with CPU transformation
40 | dataset = MultimodalDataset(images, joints, device="cpu", stdev=stdev)
41 | cpu_loader = torch.utils.data.DataLoader(
42 |     dataset,
43 |     batch_size=batch_size,
44 |     shuffle=True,
45 |     drop_last=False,
46 |     pin_memory=True,
47 |     prefetch_factor=4,
48 |     num_workers=8,
49 | )
50 | test_dataloader(cpu_loader)
51 | del cpu_loader
52 | 
53 | # Multiprocess dataloader
54 | dataset = MultimodalDataset(images, joints, device="cpu", stdev=stdev)
55 | multiepoch_loader = MultiEpochsDataLoader(
56 |     dataset,
57 |     batch_size=batch_size,
58 |     shuffle=True,
59 |     drop_last=False,
60 |     pin_memory=True,
61 |     prefetch_factor=8,
62 |     num_workers=20,
63 | )
64 | test_dataloader(multiepoch_loader)
65 | del multiepoch_loader
--------------------------------------------------------------------------------
/eipl/test/test_CoordConv2d.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import torch
7 | import matplotlib.pylab as plt
8 | from eipl.layer import CoordConv2d, AddCoords
9 | 
10 | 
11 | add_coord = AddCoords(with_r=True)
12 | x_im = torch.zeros(7, 12, 64, 64)
13 | y_im = add_coord(x_im)
14 | print("x_shape:", x_im.shape)
15 | print("y_shape:", y_im.shape)
16 | 
17 | plt.subplot(1, 3, 1)
18 | plt.imshow(y_im[0, -3])
19 | plt.subplot(1, 3, 2)
20 | plt.imshow(y_im[0, -2])
21 | plt.subplot(1, 3, 3)
22 | plt.imshow(y_im[0, -1])
23 | plt.show()
--------------------------------------------------------------------------------
/eipl/test/test_GridMask.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | # 5 | 6 | import numpy as np 7 | import matplotlib.pylab as plt 8 | from eipl.layer import GridMask 9 | 10 | 11 | # single image with channel last 12 | img = np.ones((64, 64, 3)) 13 | mask = GridMask(channel_first=False) 14 | out = mask(img, debug=True) 15 | plt.figure(dpi=60) 16 | plt.imshow(out) 17 | 18 | # single image with channel first 19 | img = np.ones((3, 64, 64)) 20 | mask = GridMask(channel_first=True) 21 | out = mask(img, debug=True) 22 | out = out.transpose(1, 2, 0) 23 | plt.figure(dpi=60) 24 | plt.imshow(out) 25 | 26 | # multi images with channel last 27 | img = np.ones((4, 64, 64, 3)) 28 | mask = GridMask(channel_first=False) 29 | out = mask(img, debug=True) 30 | 31 | plt.figure(dpi=60) 32 | for i in range(4): 33 | plt.subplot(2, 2, i + 1) 34 | plt.imshow(out[i]) 35 | 36 | # multi images with channel first 37 | img = np.ones((4, 3, 64, 64)) 38 | mask = GridMask(channel_first=True) 39 | out = mask(img, debug=True) 40 | out = out.transpose(0, 2, 3, 1) 41 | 42 | plt.figure(dpi=60) 43 | for i in range(4): 44 | plt.subplot(2, 2, i + 1) 45 | plt.imshow(out[i]) 46 | plt.show() 47 | -------------------------------------------------------------------------------- /eipl/test/test_LossScheduler.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import numpy as np 7 | import matplotlib.pylab as plt 8 | from eipl.utils import LossScheduler 9 | 10 | 11 | ax = [] 12 | ax.append(plt.subplot2grid(shape=(2, 6), loc=(0, 0), colspan=2)) 13 | ax.append(plt.subplot2grid((2, 6), (0, 2), colspan=2)) 14 | ax.append(plt.subplot2grid((2, 6), (0, 4), colspan=2)) 15 | ax.append(plt.subplot2grid((2, 6), (1, 1), colspan=2)) 16 | ax.append(plt.subplot2grid((2, 6), (1, 3), colspan=2)) 17 | 18 | for i, curve_name in enumerate( 19 | ["linear", "s", "inverse_s", "deceleration", "acceleration"] 20 | ): 21 | scheduler = LossScheduler(decay_end=100, curve_name=curve_name) 22 | loss_weight_list = [] 23 | for _ in range(150): 24 | loss_weight_list.append(scheduler(loss_weight=0.1)) 25 | 26 | ax[i].plot(loss_weight_list) 27 | ax[i].set_title(curve_name) 28 | ax[i].grid() 29 | 30 | plt.tight_layout() 31 | # plt.savefig("./output/loss_scheduler.png", dpi=60) 32 | plt.show() 33 | -------------------------------------------------------------------------------- /eipl/test/test_SampleDownloader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
4 | #
5 | 
6 | import numpy as np
7 | import matplotlib.pylab as plt
8 | import matplotlib.animation as anim
9 | from eipl.utils import print_info
10 | from eipl.data import SampleDownloader
11 | 
12 | # get airec dataset
13 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC")
14 | images, joints = grasp_data.load_raw_data("train")
15 | print_info("Load raw data")
16 | print(
17 |     "images: shape={}, min={:.2f}, max={:.2f}".format(
18 |         images.shape, images.min(), images.max()
19 |     )
20 | )
21 | print(
22 |     "joints: shape={}, min={:.2f}, max={:.2f}".format(
23 |         joints.shape, joints.min(), joints.max()
24 |     )
25 | )
26 | 
27 | # plot animation
28 | idx = 0
29 | T = images.shape[1]
30 | fig, ax = plt.subplots(1, 2, figsize=(9, 5), dpi=60)
31 | 
32 | 
33 | def anim_update(i):
34 |     for j in range(2):
35 |         ax[j].cla()
36 | 
37 |     # plot camera image
38 |     ax[0].imshow(images[idx, i, :, :, ::-1])
39 |     ax[0].axis("off")
40 |     ax[0].set_title("Input image")
41 | 
42 |     # plot joint angle
43 |     ax[1].set_ylim(-1.0, 2.0)
44 |     ax[1].set_xlim(0, T)
45 |     ax[1].plot(joints[idx, 1:], linestyle="dashed", c="k")
46 |     for joint_idx in range(8):
47 |         ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx])
48 |     ax[1].set_xlabel("Step")
49 |     ax[1].set_title("Joint angles")
50 | 
51 | 
52 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
53 | ani.save("./output/viz_downloader.gif")
--------------------------------------------------------------------------------
/eipl/test/test_SpatialSoftmax.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import torch
7 | import matplotlib.pylab as plt
8 | from eipl.layer import SpatialSoftmax, InverseSpatialSoftmax
9 | from eipl.utils import tensor2numpy, plt_img, get_feature_map
10 | 
11 | 
12 | channels = 4
13 | im_size = 64
14 | temperature = 1e-4
15 | heatmap_size = 0.05
16 | 
17 | # Generate feature_map
18 | in_features = get_feature_map(im_size=im_size, channels=channels)
19 | print("in_features shape: ", in_features.shape)
20 | 
21 | # Apply spatial softmax, get keypoints and attention map
22 | ssm = SpatialSoftmax(
23 |     width=im_size, height=im_size, temperature=temperature, normalized=True
24 | )
25 | keypoints, att_map = ssm(torch.tensor(in_features))
26 | print("keypoints shape: ", keypoints.shape)
27 | 
28 | # Generate heatmap from keypoints
29 | issm = InverseSpatialSoftmax(
30 |     width=im_size, height=im_size, heatmap_size=heatmap_size, normalized=True
31 | )
32 | out_features = issm(keypoints)
33 | out_features = tensor2numpy(out_features)
34 | print("out_features shape: ", out_features.shape)
35 | 
36 | plt.figure(dpi=60)
37 | # feature map
38 | for i in range(1, channels + 1):
39 |     plt.subplot(2, channels, i)
40 |     plt_img(
41 |         in_features[0, i - 1], key=keypoints[0, i - 1], title="feature map {}".format(i)
42 |     )
43 | 
44 | # plot heatmap
45 | for i in range(1, channels + 1):
46 |     plt.subplot(2, channels, channels + i)
47 |     plt_img(out_features[0, i - 1], title="heatmap {}".format(i))
48 | 
49 | plt.tight_layout()
50 | # plt.savefig("spatial_softmax.png")
51 | plt.show()
--------------------------------------------------------------------------------
/eipl/test/test_WeightDownloader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import os
7 | import shutil
8 | from eipl.data import WeightDownloader
9 | 
10 | # remove old data (ignore_errors avoids failing when no old data exists)
11 | root_dir = os.path.join(os.path.expanduser("~"), ".eipl/")
12 | shutil.rmtree(root_dir, ignore_errors=True)
13 | 
14 | WeightDownloader("airec", "grasp_bottle")
15 | WeightDownloader("om", "grasp_cube")
--------------------------------------------------------------------------------
/eipl/test/test_bounds.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import numpy as np
7 | import matplotlib
8 | import matplotlib.pylab as plt
9 | from eipl.utils import get_mean_minmax, get_bounds, normalization
10 | 
11 | matplotlib.use("tkagg")
12 | 
13 | vmin = 0.1
14 | vmax = 0.9
15 | 
16 | x = np.arange(0, 10, 0.1)
17 | y1 = np.sin(x)
18 | y2 = np.sin(x + 1.0)
19 | y3 = np.sin(x + 2.0)
20 | 
21 | data = np.array([y1 * 0.1, y2 * 0.5, y3 * 0.8]).T
22 | # Change data shape from [seq_len, dim] to [N, seq_len, dim]
23 | data = np.expand_dims(data, axis=0)
24 | data_mean, data_min, data_max = get_mean_minmax(data)
25 | bounds = get_bounds(data_mean, data_min, data_max, clip=0.2, vmin=vmin, vmax=vmax)
26 | norm_data = normalization(data - bounds[0], bounds[1:], (vmin, vmax))
27 | denorm_data = normalization(norm_data, (vmin, vmax), bounds[1:]) + bounds[0]
28 | 
29 | plt.figure()
30 | plt.plot(data[0])
31 | plt.title("original data")
32 | 
33 | plt.figure()
34 | plt.plot(norm_data[0])
35 | plt.title("normalized data")
36 | 
37 | plt.figure()
38 | plt.plot(denorm_data[0])
39 | plt.title("denormalized data")
40 | 
41 | plt.show()
--------------------------------------------------------------------------------
/eipl/test/test_cos_interpolation.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import numpy as np
7 | import matplotlib
8 | import matplotlib.pylab as plt
9 | from eipl.utils import cos_interpolation
10 | 
11 | matplotlib.use("tkagg")
12 | 
13 | data = np.zeros(120)
14 | data[30:70] = 1
15 | smoothed_data = cos_interpolation(data, step=15)
16 | 
17 | plt.plot(data)
18 | plt.plot(smoothed_data)
19 | plt.show()
--------------------------------------------------------------------------------
/eipl/test/test_dataloader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | import cv2
7 | import time
8 | import numpy as np
9 | import matplotlib.pylab as plt
10 | import matplotlib.animation as anim
11 | from eipl.utils import deprocess_img
12 | from eipl.data import SampleDownloader, ImageDataset, MultimodalDataset
13 | 
14 | 
15 | # Download dataset
16 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW")
17 | images, joints = grasp_data.load_norm_data("train", vmin=0.1, vmax=0.9)
18 | 
19 | # test for ImageDataset
20 | dataset = ImageDataset(images, stdev=0.02)
21 | x_img, y_img = dataset[0]
22 | x_im = x_img.numpy()[::-1]
23 | y_im = y_img.numpy()[::-1]
24 | 
25 | plt.figure(dpi=60)
26 | plt.subplot(1, 2, 1)
27 | plt.imshow(x_im.transpose(1, 2, 0))
28 | plt.axis("off")
29 | plt.title("Input image")
30 | plt.subplot(1, 2, 2)
31 | plt.imshow(y_im.transpose(1, 2, 0))
32 | plt.title("True image")
33 | plt.axis("off")
34 | plt.savefig("./output/viz_image_dataset.png")
35 | plt.close()
36 | 
37 | # test for MultimodalDataset
38 | multi_dataset = MultimodalDataset(images, joints)
39 | x_data, y_data = multi_dataset[1]
40 | x_img = x_data[0]
41 | y_img = y_data[0]
42 | 
43 | # tensor to numpy
44 | x_img = deprocess_img(x_img.numpy().transpose(0, 2, 3, 1), 0.1, 0.9)
45 | y_img = deprocess_img(y_img.numpy().transpose(0, 2, 3, 1), 0.1, 0.9)
46 | 
47 | 
48 | # plot images
49 | T = len(x_img)
50 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60)
51 | 
52 | 
53 | def anim_update(i):
54 |     for j in range(3):
55 |         ax[j].cla()
56 | 
57 |     # plot predicted image
58 |     ax[0].imshow(y_img[i, :, :, ::-1])
59 |     ax[0].axis("off")
60 |     ax[0].set_title("Original image", fontsize=20)
61 | 
62 |     # plot camera image
63 |     ax[1].imshow(x_img[i, :, :, ::-1])
64 |     ax[1].axis("off")
65 |     ax[1].set_title("Noised image", fontsize=20)
66 | 
67 |     # plot joint angle
68 |     ax[2].set_ylim(0.0, 1.0)
69 |     ax[2].set_xlim(0, T)
70 |     ax[2].plot(y_data[1], linestyle="dashed", c="k")
71 |     for joint_idx in range(8):
72 |         ax[2].plot(np.arange(i + 1), x_data[1][: i + 1, joint_idx])
73 |     ax[2].set_xlabel("Step", fontsize=20)
74 |     ax[2].set_title("Scaled joint angles", fontsize=20)
75 |     ax[2].tick_params(axis="x", labelsize=16)
76 |     ax[2].tick_params(axis="y", labelsize=16)
77 |     plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9)
78 | 
79 | 
80 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
81 | ani.save("./output/viz_multimodal_dataset.gif")
--------------------------------------------------------------------------------
/eipl/test/test_models.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | #
5 | 
6 | from eipl.layer import MTRNNCell
7 | from eipl.model import BasicLSTM, BasicMTRNN
8 | from eipl.model import BasicCAE, CAE
9 | from eipl.model import BasicCAEBN, CAEBN
10 | from eipl.model import CNNRNN, CNNRNNLN, SARNN
11 | from torchinfo import summary
12 | 
13 | 
14 | batch_size = 50
15 | input_dim = 12
16 | 
17 | ### MTRNNCell
18 | print("# MTRNNCell")
19 | model = MTRNNCell(input_dim=12, fast_dim=50, slow_dim=5, fast_tau=2, slow_tau=12)
20 | summary(model, input_size=(batch_size, input_dim))
21 | 
22 | print("# BasicMTRNN")
23 | model = BasicMTRNN(
24 |     in_dim=12, fast_dim=30, slow_dim=5, fast_tau=2, slow_tau=12, activation="tanh"
25 | )
26 | summary(model, input_size=(batch_size, input_dim))
27 | 
28 | print("# BasicLSTM")
29 | model = BasicLSTM(in_dim=12, rec_dim=50, out_dim=10, activation="tanh")
30 | summary(model, input_size=(batch_size, input_dim))
31 | 
32 | print("BasicCAE")
33 | model = BasicCAE()
34 | summary(model, input_size=(batch_size, 3, 128, 128))
35 | 
36 | print("CAE")
37 | model = CAE()
38 | summary(model, input_size=(batch_size, 3, 128, 128))
39 | 
40 | print("BasicCAEBN")
41 | model = BasicCAEBN(feat_dim=10)
42 | summary(model, input_size=(batch_size, 3, 128, 128))
43 | 
44 | print("CAEBN")
45 | model = CAEBN(feat_dim=30)
46 | summary(model, input_size=(batch_size, 3, 128, 128))
47 | 
48 | 
49 | print("CNNRNN")
50 | model = CNNRNN(rec_dim=50, joint_dim=8, feat_dim=10)
51 | summary(model, input_size=[(batch_size, 3, 128, 128), (batch_size, 8)])
52 | 
53 | print("CNNRNNLN")
54 | model = CNNRNNLN(rec_dim=50, joint_dim=8, feat_dim=10)
55 | summary(model, input_size=[(batch_size, 3, 128, 128), (batch_size, 8)])
56 | 
57 | print("SARNN")
58 | model = SARNN(rec_dim=50, k_dim=5, joint_dim=8)
59 | summary(model, input_size=[(batch_size, 3, 128, 128), (batch_size, 8)])
--------------------------------------------------------------------------------
/eipl/tutorials/airec/ros/1_rosbag2npz.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | # Released under the MIT License.
4 | # 5 | 6 | import os 7 | import cv2 8 | import glob 9 | import rospy 10 | import rosbag 11 | import argparse 12 | import numpy as np 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("bag_dir", type=str) 17 | parser.add_argument("--freq", type=float, default=10) 18 | args = parser.parse_args() 19 | 20 | 21 | files = glob.glob(os.path.join(args.bag_dir, "*.bag")) 22 | files.sort() 23 | for file in files: 24 | print(file) 25 | savename = file.split(".bag")[0] + ".npz" 26 | 27 | # Open the rosbag file 28 | bag = rosbag.Bag(file) 29 | 30 | # Get the start and end times of the rosbag file 31 | start_time = bag.get_start_time() 32 | end_time = bag.get_end_time() 33 | 34 | # Get the topics in the rosbag file 35 | # topics = bag.get_type_and_topic_info()[1].keys() 36 | topics = [ 37 | "/torobo/joint_states", 38 | "/torobo/head/see3cam_left/camera/color/image_repub/compressed", 39 | "/torobo/left_hand_controller/state", 40 | ] 41 | 42 | # Create a rospy.Time object to represent the current time 43 | current_time = rospy.Time.from_sec(start_time) 44 | 45 | joint_list = [] 46 | finger_list = [] 47 | image_list = [] 48 | finger_state_list = [] 49 | 50 | prev_finger = None 51 | finger_state = 0 52 | 53 | # Loop through the rosbag file at regular intervals (args.freq) 54 | freq = 1.0 / float(args.freq) 55 | while current_time.to_sec() < end_time: 56 | print(current_time.to_sec()) 57 | 58 | # Get the messages for each topic at the current time 59 | for topic in topics: 60 | for topic_msg, msg, time in bag.read_messages( 61 | topic, start_time=current_time 62 | ): 63 | if time >= current_time: 64 | if topic == "/torobo/joint_states": 65 | joint_list.append(msg.position[7:14]) 66 | 67 | if ( 68 | topic 69 | == "/torobo/head/see3cam_left/camera/color/image_repub/compressed" 70 | ): 71 | np_arr = np.frombuffer(msg.data, np.uint8) 72 | np_img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) 73 | np_img = np_img[::2, ::2] 74 | image_list.append(np_img[150:470, 110:430].astype(np.uint8)) 75 | 76 | if topic == "/torobo/left_hand_controller/state": 77 | finger = np.array(msg.desired.positions[3]) 78 | if prev_finger is None: 79 | prev_finger = finger 80 | 81 | if finger - prev_finger > 0.005 and finger_state == 0: 82 | finger_state = 1 83 | elif prev_finger - finger > 0.005 and finger_state == 1: 84 | finger_state = 0 85 | prev_finger = finger 86 | 87 | finger_list.append(finger) 88 | finger_state_list.append(finger_state) 89 | 90 | break 91 | 92 | # Wait for the next interval 93 | current_time += rospy.Duration.from_sec(freq) 94 | 95 | # Close the rosbag file 96 | bag.close() 97 | 98 | # Convert list to array 99 | joints = np.array(joint_list, dtype=np.float32) 100 | finger = np.array(finger_list, dtype=np.float32) 101 | finger_state = np.array(finger_state_list, dtype=np.float32) 102 | images = np.array(image_list, dtype=np.uint8) 103 | 104 | # Get shorter lenght 105 | shorter_length = min(len(joints), len(images), len(finger), len(finger_state)) 106 | 107 | # Trim 108 | joints = joints[:shorter_length] 109 | finger = finger[:shorter_length] 110 | images = images[:shorter_length] 111 | finger_state = finger_state[:shorter_length] 112 | 113 | # Save 114 | np.savez( 115 | savename, joints=joints, finger=finger, finger_state=finger_state, images=images 116 | ) 117 | -------------------------------------------------------------------------------- /eipl/tutorials/airec/ros/2_make_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 
Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import cv2 8 | import glob 9 | import argparse 10 | import numpy as np 11 | import matplotlib.pylab as plt 12 | from eipl.utils import resize_img, calc_minmax, list_to_numpy, cos_interpolation 13 | 14 | 15 | def load_data(dir): 16 | joints = [] 17 | images = [] 18 | seq_length = [] 19 | 20 | files = glob.glob(os.path.join(dir, "*.npz")) 21 | files.sort() 22 | for filename in files: 23 | print(filename) 24 | npz_data = np.load(filename) 25 | 26 | images.append(resize_img(npz_data["images"], (128, 128))) 27 | finger_state = cos_interpolation(npz_data["finger_state"], expand_dims=True) 28 | _joints = np.concatenate((npz_data["joints"], finger_state), axis=-1) 29 | joints.append(_joints) 30 | seq_length.append(len(_joints)) 31 | 32 | max_seq = max(seq_length) 33 | images = list_to_numpy(images, max_seq) 34 | joints = list_to_numpy(joints, max_seq) 35 | 36 | return images, joints 37 | 38 | 39 | if __name__ == "__main__": 40 | # dataset index 41 | train_list = [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13] 42 | test_list = [4, 9, 14, 15, 16] 43 | 44 | # load data 45 | images, joints = load_data("./bag/") 46 | 47 | # save images and joints 48 | np.save("./data/train/images.npy", images[train_list].astype(np.uint8)) 49 | np.save("./data/train/joints.npy", joints[train_list].astype(np.float32)) 50 | np.save("./data/test/images.npy", images[test_list].astype(np.uint8)) 51 | np.save("./data/test/joints.npy", joints[test_list].astype(np.float32)) 52 | 53 | # save joint bounds 54 | joint_bounds = calc_minmax(joints) 55 | np.save("./data/joint_bounds.npy", joint_bounds) 56 | -------------------------------------------------------------------------------- /eipl/tutorials/airec/ros/3_check_data.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License.
4 | # 5 | 6 | import argparse 7 | import numpy as np 8 | import matplotlib.pylab as plt 9 | import matplotlib.animation as anim 10 | from eipl.utils import normalization 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--idx", type=int, default=0) 15 | args = parser.parse_args() 16 | 17 | idx = int(args.idx) 18 | joints = np.load("./data/test/joints.npy") 19 | joint_bounds = np.load("./data/joint_bounds.npy") 20 | images = np.load("./data/test/images.npy") 21 | N = images.shape[1] 22 | 23 | 24 | # normalized joints 25 | minmax = [0.1, 0.9] 26 | norm_joints = normalization(joints, joint_bounds, minmax) 27 | 28 | # print data information 29 | print("load test data, index number is {}".format(idx)) 30 | print( 31 | "Joint: shape={}, min={:.3g}, max={:.3g}".format( 32 | joints.shape, joints.min(), joints.max() 33 | ) 34 | ) 35 | print( 36 | "Norm joint: shape={}, min={:.3g}, max={:.3g}".format( 37 | norm_joints.shape, norm_joints.min(), norm_joints.max() 38 | ) 39 | ) 40 | 41 | # plot images and normalized joints 42 | fig, ax = plt.subplots(1, 3, figsize=(14, 5), dpi=60) 43 | 44 | 45 | def anim_update(i): 46 | for j in range(3): 47 | ax[j].cla() 48 | 49 | # plot image 50 | ax[0].imshow(images[idx, i, :, :, ::-1]) 51 | ax[0].axis("off") 52 | ax[0].set_title("Image") 53 | 54 | # plot joint angle 55 | ax[1].set_ylim(-1.0, 2.0) 56 | ax[1].set_xlim(0, N) 57 | ax[1].plot(joints[idx], linestyle="dashed", c="k") 58 | 59 | for joint_idx in range(8): 60 | ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx]) 61 | ax[1].set_xlabel("Step") 62 | ax[1].set_title("Joint angles") 63 | 64 | # plot normalized joint angle 65 | ax[2].set_ylim(0.0, 1.0) 66 | ax[2].set_xlim(0, N) 67 | ax[2].plot(norm_joints[idx], linestyle="dashed", c="k") 68 | 69 | for joint_idx in range(8): 70 | ax[2].plot(np.arange(i + 1), norm_joints[idx, : i + 1, joint_idx]) 71 | ax[2].set_xlabel("Step") 72 | ax[2].set_title("Normalized joint angles") 73 | 74 | 75 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N) 76 | ani.save("./output/check_data_{}.gif".format(idx)) 77 | -------------------------------------------------------------------------------- /eipl/tutorials/airec/sarnn/bin/test_pca_sarnn.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
4 | # 5 | 6 | import os 7 | import torch 8 | import argparse 9 | import numpy as np 10 | import matplotlib.pylab as plt 11 | import matplotlib.animation as anim 12 | from sklearn.decomposition import PCA 13 | from eipl.data import SampleDownloader 14 | from eipl.model import SARNN 15 | from eipl.utils import restore_args, tensor2numpy, normalization 16 | 17 | 18 | # argument parser 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--filename", type=str, default=None) 21 | args = parser.parse_args() 22 | 23 | # check args 24 | assert args.filename, "Please set filename" 25 | 26 | # restore parameters 27 | dir_name = os.path.split(args.filename)[0] 28 | params = restore_args(os.path.join(dir_name, "args.json")) 29 | 30 | # load dataset 31 | minmax = [params["vmin"], params["vmax"]] 32 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC") 33 | images, joints = grasp_data.load_raw_data("test") 34 | joint_bounds = grasp_data.joint_bounds 35 | print( 36 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max()) 37 | ) 38 | print( 39 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max()) 40 | ) 41 | 42 | # define model 43 | model = SARNN( 44 | rec_dim=params["rec_dim"], 45 | joint_dim=8, 46 | k_dim=params["k_dim"], 47 | heatmap_size=params["heatmap_size"], 48 | temperature=params["temperature"], 49 | ) 50 | 51 | # If the model was trained with torch.compile, uncomment the following line. 52 | # model = torch.compile(model) 53 | 54 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 55 | model.load_state_dict(ckpt["model_state_dict"]) 56 | model.eval() 57 | 58 | # Inference 59 | states = [] 60 | state = None 61 | nloop = images.shape[1] 62 | for loop_ct in range(nloop): 63 | # load data and normalization 64 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2) 65 | img_t = torch.Tensor(img_t) 66 | img_t = normalization(img_t, (0, 255), minmax) 67 | joint_t = torch.Tensor(joints[:, loop_ct]) 68 | joint_t = normalization(joint_t, joint_bounds, minmax) 69 | 70 | # predict rnn 71 | _, _, _, _, state = model(img_t, joint_t, state) 72 | states.append(state[0]) 73 | 74 | states = torch.permute(torch.stack(states), (1, 0, 2)) 75 | states = tensor2numpy(states) 76 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN. 77 | # N is the number of datasets 78 | # T is the sequence length 79 | # D is the dimension of the hidden state 80 | N, T, D = states.shape 81 | states = states.reshape(-1, D) 82 | 83 | # plot pca 84 | loop_ct = float(360) / T 85 | pca_dim = 3 86 | pca = PCA(n_components=pca_dim).fit(states) 87 | pca_val = pca.transform(states) 88 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to 89 | # visualize each state as a 3D scatter.
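# A minimal sketch of the shape round trip used here (shapes taken from this script):
#   states: (N, T, D) --reshape--> (N*T, D) --PCA--> (N*T, 3) --reshape--> (N, T, 3)
# so pca_val[n, t] below is the 3D embedding of the hidden state at step t of
# test sequence n, and each sequence can be drawn as a single trajectory.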
90 | pca_val = pca_val.reshape(N, T, pca_dim) 91 | 92 | fig = plt.figure(dpi=60) 93 | ax = fig.add_subplot(projection="3d") 94 | 95 | 96 | def anim_update(i): 97 | ax.cla() 98 | angle = int(loop_ct * i) 99 | ax.view_init(30, angle) 100 | 101 | c_list = ["C0", "C1", "C2", "C3", "C4"] 102 | for n, color in enumerate(c_list): 103 | ax.scatter( 104 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0 105 | ) 106 | 107 | ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0) 108 | pca_ratio = pca.explained_variance_ratio_ * 100 109 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0])) 110 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1])) 111 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2])) 112 | 113 | 114 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T) 115 | ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"])) 116 | 117 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg). 118 | # ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]), writer="imagemagick") 119 | # ani.save("./output/PCA_SARNN_{}.mp4".format(params["tag"]), writer="ffmpeg") 120 | -------------------------------------------------------------------------------- /eipl/tutorials/airec/sarnn/bin/train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import sys 8 | import torch 9 | import numpy as np 10 | import argparse 11 | from tqdm import tqdm 12 | import torch.optim as optim 13 | from collections import OrderedDict 14 | from torch.utils.tensorboard import SummaryWriter 15 | from eipl.model import SARNN 16 | from eipl.data import MultimodalDataset, SampleDownloader 17 | from eipl.utils import EarlyStopping, check_args, set_logdir, resize_img 18 | 19 | # load own library 20 | sys.path.append("./libs/") 21 | from fullBPTT import fullBPTTtrainer 22 | 23 | # argument parser 24 | parser = argparse.ArgumentParser( 25 | description="Learning spatial autoencoder with recurrent neural network" 26 | ) 27 | parser.add_argument("--model", type=str, default="sarnn") 28 | parser.add_argument("--epoch", type=int, default=10000) 29 | parser.add_argument("--batch_size", type=int, default=5) 30 | parser.add_argument("--rec_dim", type=int, default=50) 31 | parser.add_argument("--k_dim", type=int, default=5) 32 | parser.add_argument("--img_loss", type=float, default=0.1) 33 | parser.add_argument("--joint_loss", type=float, default=1.0) 34 | parser.add_argument("--pt_loss", type=float, default=0.1) 35 | parser.add_argument("--heatmap_size", type=float, default=0.1) 36 | parser.add_argument("--temperature", type=float, default=1e-4) 37 | parser.add_argument("--stdev", type=float, default=0.1) 38 | parser.add_argument("--log_dir", default="log/") 39 | parser.add_argument("--vmin", type=float, default=0.0) 40 | parser.add_argument("--vmax", type=float, default=1.0) 41 | parser.add_argument("--device", type=int, default=0) 42 | parser.add_argument("--compile", action="store_true") 43 | parser.add_argument("--tag", help="Tag name for snap/log sub directory") 44 | args = parser.parse_args() 45 | 46 | # check args 47 | args = check_args(args) 48 | 49 | # calculate the noise level (standard deviation) from the normalized range 50 | stdev = args.stdev * (args.vmax - args.vmin) 51 | 52 | # set device id 53 | if args.device >= 0: 54 | device
= "cuda:{}".format(args.device) 55 | else: 56 | device = "cpu" 57 | 58 | # load dataset 59 | minmax = [args.vmin, args.vmax] 60 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC") 61 | images, joints = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax) 62 | images = resize_img(images, (64, 64)) 63 | images = images.transpose(0, 1, 4, 2, 3) 64 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev) 65 | train_loader = torch.utils.data.DataLoader( 66 | train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False 67 | ) 68 | 69 | images, joints = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax) 70 | images = resize_img(images, (64, 64)) 71 | images = images.transpose(0, 1, 4, 2, 3) 72 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None) 73 | test_loader = torch.utils.data.DataLoader( 74 | test_dataset, 75 | batch_size=args.batch_size, 76 | shuffle=True, 77 | drop_last=False, 78 | ) 79 | 80 | # define model 81 | model = SARNN( 82 | rec_dim=args.rec_dim, 83 | joint_dim=8, 84 | k_dim=args.k_dim, 85 | heatmap_size=args.heatmap_size, 86 | temperature=args.temperature, 87 | ) 88 | 89 | # torch.compile makes PyTorch code run faster 90 | if args.compile: 91 | torch.set_float32_matmul_precision("high") 92 | model = torch.compile(model) 93 | 94 | # set optimizer 95 | optimizer = optim.Adam(model.parameters(), eps=1e-07) 96 | 97 | # load trainer/tester class 98 | loss_weights = [args.img_loss, args.joint_loss, args.pt_loss] 99 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device) 100 | 101 | ### training main 102 | log_dir_path = set_logdir("./" + args.log_dir, args.tag) 103 | save_name = os.path.join(log_dir_path, "SARNN.pth") 104 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30) 105 | early_stop = EarlyStopping(patience=1000) 106 | 107 | with tqdm(range(args.epoch)) as pbar_epoch: 108 | for epoch in pbar_epoch: 109 | # train and test 110 | train_loss = trainer.process_epoch(train_loader) 111 | with torch.no_grad(): 112 | test_loss = trainer.process_epoch(test_loader, training=False) 113 | writer.add_scalar("Loss/train_loss", train_loss, epoch) 114 | writer.add_scalar("Loss/test_loss", test_loss, epoch) 115 | 116 | # early stop 117 | save_ckpt, _ = early_stop(test_loss) 118 | 119 | if save_ckpt: 120 | trainer.save(epoch, [train_loss, test_loss], save_name) 121 | 122 | # print process bar 123 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss)) 124 | -------------------------------------------------------------------------------- /eipl/tutorials/airec/sarnn/libs/fullBPTT.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | from eipl.utils import LossScheduler, tensor2numpy 9 | 10 | 11 | class fullBPTTtrainer: 12 | """ 13 | Helper class to train recurrent neural networks with numpy sequences 14 | 15 | Args: 16 | traindata (np.array): list of np.array. First diemension should be time steps 17 | model (torch.nn.Module): rnn model 18 | optimizer (torch.optim): optimizer 19 | input_param (float): input parameter of sequential generation. 1.0 means open mode. 
20 | """ 21 | 22 | def __init__(self, model, optimizer, loss_weights=[1.0, 1.0], device="cpu"): 23 | self.device = device 24 | self.optimizer = optimizer 25 | self.loss_weights = loss_weights 26 | self.scheduler = LossScheduler(decay_end=1000, curve_name="s") 27 | self.model = model.to(self.device) 28 | 29 | def save(self, epoch, loss, savename): 30 | torch.save( 31 | { 32 | "epoch": epoch, 33 | "model_state_dict": self.model.state_dict(), 34 | #'optimizer_state_dict': self.optimizer.state_dict(), 35 | "train_loss": loss[0], 36 | "test_loss": loss[1], 37 | }, 38 | savename, 39 | ) 40 | 41 | def process_epoch(self, data, training=True): 42 | if not training: 43 | self.model.eval() 44 | else: 45 | self.model.train() 46 | 47 | total_loss = 0.0 48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data): 49 | if "cpu" in self.device: 50 | x_img = x_img.to(self.device) 51 | y_img = y_img.to(self.device) 52 | x_joint = x_joint.to(self.device) 53 | y_joint = y_joint.to(self.device) 54 | 55 | state = None 56 | yi_list, yv_list = [], [] 57 | dec_pts_list, enc_pts_list = [], [] 58 | self.optimizer.zero_grad(set_to_none=True) 59 | for t in range(x_img.shape[1] - 1): 60 | _yi_hat, _yv_hat, enc_ij, dec_ij, state = self.model( 61 | x_img[:, t], x_joint[:, t], state 62 | ) 63 | yi_list.append(_yi_hat) 64 | yv_list.append(_yv_hat) 65 | enc_pts_list.append(enc_ij) 66 | dec_pts_list.append(dec_ij) 67 | 68 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4)) 69 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2)) 70 | 71 | img_loss = nn.MSELoss()(yi_hat, y_img[:, 1:]) * self.loss_weights[0] 72 | joint_loss = nn.MSELoss()(yv_hat, y_joint[:, 1:]) * self.loss_weights[1] 73 | # Gradually change the loss value using the LossScheluder class. 74 | pt_loss = nn.MSELoss()( 75 | torch.stack(dec_pts_list[:-1]), torch.stack(enc_pts_list[1:]) 76 | ) * self.scheduler(self.loss_weights[2]) 77 | loss = img_loss + joint_loss + pt_loss 78 | total_loss += tensor2numpy(loss) 79 | 80 | if training: 81 | loss.backward() 82 | self.optimizer.step() 83 | 84 | return total_loss / (n_batch + 1) 85 | -------------------------------------------------------------------------------- /eipl/tutorials/airec/sarnn/log/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/airec/sarnn/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/bag2npz/1_rosbag2npz.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
4 | # 5 | 6 | import os 7 | import cv2 8 | import glob 9 | import rospy 10 | import rosbag 11 | import argparse 12 | import numpy as np 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--bag_dir", type=str, default="./bag/") 17 | parser.add_argument("--freq", type=float, default=10) 18 | args = parser.parse_args() 19 | 20 | 21 | files = glob.glob(os.path.join(args.bag_dir, "*.bag")) 22 | files.sort() 23 | for file in files: 24 | print(file) 25 | savename = file.split(".bag")[0] + ".npz" 26 | 27 | # Open the rosbag file 28 | bag = rosbag.Bag(file) 29 | 30 | # Get the start and end times of the rosbag file 31 | start_time = bag.get_start_time() 32 | end_time = bag.get_end_time() 33 | 34 | # Get the topics in the rosbag file 35 | topics = ["/follower/joint_states", "/camera/color/image_raw/compressed"] 36 | 37 | # Create a rospy.Time object to represent the current time 38 | current_time = rospy.Time.from_sec(start_time) 39 | 40 | joint_list = [] 41 | image_list = [] 42 | 43 | prev_gripper = None 44 | gripper_state = 0 45 | 46 | # Loop through the rosbag file at regular intervals (args.freq) 47 | freq = 1.0 / float(args.freq) 48 | while current_time.to_sec() < end_time: 49 | # Get the messages for each topic at the current time 50 | for topic in topics: 51 | for topic_msg, msg, time in bag.read_messages( 52 | topic, start_time=current_time 53 | ): 54 | if time >= current_time: 55 | if topic == "/camera/color/image_raw/compressed": 56 | np_arr = np.frombuffer(msg.data, np.uint8) 57 | np_img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) 58 | np_img = np_img[::2, ::2] 59 | image_list.append(np_img.astype(np.uint8)) 60 | 61 | if topic == "/follower/joint_states": 62 | joint_list.append(msg.position) 63 | break 64 | 65 | # Wait for the next interval 66 | current_time += rospy.Duration.from_sec(freq) 67 | 68 | # Close the rosbag file 69 | bag.close() 70 | 71 | # Convert list to array 72 | joints = np.array(joint_list, dtype=np.float32) 73 | images = np.array(image_list, dtype=np.uint8) 74 | 75 | # Get the shortest length 76 | shorter_length = min(len(joints), len(images)) 77 | 78 | # Trim 79 | joints = joints[:shorter_length] 80 | images = images[:shorter_length] 81 | 82 | # Save 83 | np.savez(savename, joints=joints, images=images) 84 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/bag2npz/2_make_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License.
4 | # 5 | 6 | import os 7 | import cv2 8 | import glob 9 | import argparse 10 | import numpy as np 11 | import matplotlib.pylab as plt 12 | from utils import resize_img, calc_minmax, list_to_numpy 13 | 14 | 15 | def load_data(dir): 16 | joints = [] 17 | images = [] 18 | seq_length = [] 19 | 20 | files = glob.glob(os.path.join(dir, "*.npz")) 21 | files.sort() 22 | 23 | for filename in files: 24 | print(filename) 25 | npz_data = np.load(filename) 26 | images.append(resize_img(npz_data["images"], (64, 64))) 27 | _joints = npz_data["joints"] 28 | joints.append(_joints) 29 | seq_length.append(len(_joints)) 30 | 31 | max_seq = max(seq_length) 32 | images = list_to_numpy(images, max_seq) 33 | joints = list_to_numpy(joints, max_seq) 34 | 35 | return images, joints 36 | 37 | 38 | if __name__ == "__main__": 39 | train_list = [0, 1, 2, 3, 4, 5, 6, 7, 8] 40 | test_list = [9, 10, 11, 12, 13] 41 | 42 | # load data 43 | images, joints = load_data("./bag/") 44 | 45 | # save images and joints 46 | np.save("./data/train/images.npy", images[train_list].astype(np.uint8)) 47 | np.save("./data/train/joints.npy", joints[train_list].astype(np.float32)) 48 | np.save("./data/test/images.npy", images[test_list].astype(np.uint8)) 49 | np.save("./data/test/joints.npy", joints[test_list].astype(np.float32)) 50 | 51 | # save joint bounds 52 | joint_bounds = calc_minmax(joints) 53 | np.save("./data/joint_bounds.npy", joint_bounds) 54 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/bag2npz/3_check_data.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import argparse 7 | import numpy as np 8 | import matplotlib.pylab as plt 9 | import matplotlib.animation as anim 10 | from utils import normalization 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--idx", type=int, default=0) 14 | args = parser.parse_args() 15 | 16 | idx = int(args.idx) 17 | joints = np.load("./data/test/joints.npy") 18 | joint_bounds = np.load("./data/joint_bounds.npy") 19 | images = np.load("./data/test/images.npy") 20 | N = images.shape[1] 21 | 22 | # normalized joints 23 | minmax = [0.1, 0.9] 24 | norm_joints = normalization(joints, joint_bounds, minmax) 25 | 26 | # print data information 27 | print("load test data, index number is {}".format(idx)) 28 | print( 29 | "Joint: shape={}, min={:.3g}, max={:.3g}".format( 30 | joints.shape, joints.min(), joints.max() 31 | ) 32 | ) 33 | print( 34 | "Norm joint: shape={}, min={:.3g}, max={:.3g}".format( 35 | norm_joints.shape, norm_joints.min(), norm_joints.max() 36 | ) 37 | ) 38 | 39 | # plot images and normalized joints 40 | fig, ax = plt.subplots(1, 3, figsize=(14, 5), dpi=60) 41 | 42 | 43 | def anim_update(i): 44 | for j in range(3): 45 | ax[j].cla() 46 | 47 | # plot image 48 | ax[0].imshow(images[idx, i, :, :, ::-1]) 49 | ax[0].axis("off") 50 | ax[0].set_title("Image") 51 | 52 | # plot joint angle 53 | ax[1].set_ylim(-1.0, 2.0) 54 | ax[1].set_xlim(0, N) 55 | ax[1].plot(joints[idx], linestyle="dashed", c="k") 56 | 57 | for joint_idx in range(5): 58 | ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx]) 59 | ax[1].set_xlabel("Step") 60 | ax[1].set_title("Joint angles") 61 | 62 | # plot normalized joint angle 63 | ax[2].set_ylim(0.0, 1.0) 64 | ax[2].set_xlim(0, N) 65 | ax[2].plot(norm_joints[idx], linestyle="dashed", c="k") 66 | 67 | for joint_idx in 
range(5): 68 | ax[2].plot(np.arange(i + 1), norm_joints[idx, : i + 1, joint_idx]) 69 | ax[2].set_xlabel("Step") 70 | ax[2].set_title("Normalized joint angles") 71 | 72 | 73 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N) 74 | ani.save("./output/check_data_{}.gif".format(idx)) 75 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/bag2npz/bag/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/bag2npz/data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/bag2npz/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import cv2 8 | import glob 9 | import datetime 10 | import numpy as np 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def normalization(data, indataRange, outdataRange): 15 | """ 16 | Function to normalize a numpy array within a specified range 17 | Args: 18 | data (np.array): Data array 19 | indataRange (float list): List of minimum and maximum values of original data, e.g. indataRange=[0.0, 255.0]. 20 | outdataRange (float list): List of minimum and maximum values of output data, e.g. outdataRange=[0.0, 1.0]. 21 | Return: 22 | data (np.array): Normalized data array 23 | """ 24 | data = (data - indataRange[0]) / (indataRange[1] - indataRange[0]) 25 | data = data * (outdataRange[1] - outdataRange[0]) + outdataRange[0] 26 | return data 27 | 28 | 29 | def resize_img(img, size=(64, 64), reshape_flag=True): 30 | """ 31 | Resize a batch or sequence of images to the given size.
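    Args:
        img (np.array): image array, either (N, T, H, W, C) sequences or a single (T, H, W, C) sequence
        size (tuple): output image size passed to cv2.resize, given as (width, height)
        reshape_flag (bool): if True, restore the (N, T, ...) shape after resizing 5D input
    Return:
        imgs (np.array): resized image array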
32 | """ 33 | if len(img.shape) == 5: 34 | N, T, W, H, C = img.shape 35 | img = img.reshape((-1,) + img.shape[2:]) 36 | else: 37 | reshape_flag = False 38 | 39 | imgs = [] 40 | for i in range(len(img)): 41 | imgs.append(cv2.resize(img[i], size)) 42 | 43 | imgs = np.array(imgs) 44 | if reshape_flag: 45 | imgs = imgs.reshape(N, T, size[1], size[0], 3) 46 | return imgs 47 | 48 | 49 | def calc_minmax(_data): 50 | data = _data.reshape(-1, _data.shape[-1]) 51 | data_minmax = np.array([np.min(data, 0), np.max(data, 0)]) 52 | return data_minmax 53 | 54 | 55 | def list_to_numpy(data_list, max_N): 56 | dtype = data_list[0].dtype 57 | array = np.ones( 58 | ( 59 | len(data_list), 60 | max_N, 61 | ) 62 | + data_list[0].shape[1:], 63 | dtype, 64 | ) 65 | 66 | for i, data in enumerate(data_list): 67 | N = len(data) 68 | array[i, :N] = data[:N].astype(dtype) 69 | array[i, N:] = array[i, N:] * data[-1].astype(dtype) 70 | 71 | return array 72 | 73 | 74 | def cos_interpolation(data, step=20, expand_dims=False): 75 | """ 76 | Args: 77 | data (seq_length): time-series sensor data 78 | """ 79 | 80 | data = data.copy() 81 | points = np.diff(data) 82 | 83 | for i, p in enumerate(points): 84 | if p == 1: 85 | t = np.linspace(0.0, 1.0, step * 2) 86 | elif p == -1: 87 | t = np.linspace(1.0, 0.0, step * 2) 88 | else: 89 | continue 90 | 91 | x_latent = (1 - np.cos(t * np.pi)) / 2 92 | data[i - step + 1 : i + step + 1] = x_latent 93 | 94 | if expand_dims: 95 | data = np.expand_dims(data, axis=-1) 96 | 97 | return data 98 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/config/d435_param.yaml: -------------------------------------------------------------------------------- 1 | !!python/object/new:dynamic_reconfigure.encoding.Config 2 | dictitems: 3 | auto_exposure_priority: false 4 | backlight_compensation: false 5 | brightness: -20 6 | contrast: 60 7 | enable_auto_exposure: false 8 | enable_auto_white_balance: false 9 | exposure: 166 10 | frames_queue_size: 16 11 | gain: 100 12 | gamma: 300 13 | global_time_enabled: true 14 | groups: !!python/object/new:dynamic_reconfigure.encoding.Config 15 | dictitems: 16 | auto_exposure_priority: false 17 | backlight_compensation: false 18 | brightness: -20 19 | contrast: 60 20 | enable_auto_exposure: false 21 | enable_auto_white_balance: false 22 | exposure: 166 23 | frames_queue_size: 16 24 | gain: 100 25 | gamma: 300 26 | global_time_enabled: true 27 | groups: !!python/object/new:dynamic_reconfigure.encoding.Config 28 | state: [] 29 | hue: 0 30 | id: 0 31 | name: Default 32 | parameters: !!python/object/new:dynamic_reconfigure.encoding.Config 33 | state: [] 34 | parent: 0 35 | power_line_frequency: 3 36 | saturation: 64 37 | sharpness: 50 38 | state: true 39 | type: '' 40 | white_balance: 4600.0 41 | state: [] 42 | hue: 0 43 | power_line_frequency: 3 44 | saturation: 64 45 | sharpness: 50 46 | white_balance: 4600.0 47 | state: [] 48 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/launch/playback_bringup.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/launch/rt_bringup.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 
7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/launch/teleop_bringup.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | om_teleop 4 | 0.0.0 5 | The om_teleop package 6 | 7 | 8 | 9 | 10 | ito 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | catkin 52 | rospy 53 | std_msgs 54 | rospy 55 | std_msgs 56 | rospy 57 | std_msgs 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/src/dynamixel_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import sys, tty, termios 8 | import numpy as np 9 | 10 | # Control table address (need to change) 11 | ADDR_OPERATING_MODE = 1 # Data Byte Length 12 | CURRENT_POSITION_CONTROL_MODE = 5 13 | ADDR_GOAL_CURRENT = 102 14 | ADDR_TORQUE_ENABLE = 64 15 | ADDR_LED_RED = 65 16 | LEN_LED_RED = 1 # Data Byte Length 17 | ADDR_GOAL_POSITION = 116 18 | LEN_GOAL_POSITION = 4 # Data Byte Length 19 | ADDR_PRESENT_POSITION = 132 20 | LEN_PRESENT_POSITION = 4 # Data Byte Length 21 | DXL_MINIMUM_POSITION_VALUE = 0 # Refer to the Minimum Position Limit of product eManual 22 | DXL_MAXIMUM_POSITION_VALUE = ( 23 | 4095 # Refer to the Maximum Position Limit of product eManual 24 | ) 25 | TORQUE_ENABLE = 1 # Value for enabling the torque 26 | TORQUE_DISABLE = 0 # Value for disabling the torque 27 | DXL_MOVING_STATUS_THRESHOLD = 20 # Dynamixel moving status threshold 28 | PROTOCOL_VERSION = 2.0 29 | 30 | 31 | ## getch 32 | fd = sys.stdin.fileno() 33 | old_settings = termios.tcgetattr(fd) 34 | 35 | 36 | def getch(): 37 | try: 38 | tty.setraw(sys.stdin.fileno()) 39 | ch = sys.stdin.read(1) 40 | finally: 41 | termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) 42 | return ch 43 | 44 | 45 | ## value2radian 46 | def value2degree(data): 47 | """convert dynamixel value to degree""" 48 | return np.array(data, dtype=np.float32) * 360.0 / 4095.0 49 | 50 | 51 | def value2radian(data): 52 | """convert dynamixel value to radian""" 53 | return (np.array(data, dtype=np.float32) - 2047.5) / 2047.5 * np.pi 54 | 55 | 56 | def degree2value(data): 57 | """convert degree to dynamixel value""" 58 | return np.array(data, dtype=np.float32) * 4095.0 / 360.0 59 | 60 | 61 | def radian2value(data): 62 | """convert radian to dynamixel value""" 63 | val = np.array(data, dtype=np.float32) / np.pi * 2047.5 + 2047.5 64 | return val.astype(np.int32) 65 | 66 | 67 | def normalization(data, indataRange, outdataRange): 68 | data = (data - indataRange[0]) / (indataRange[1] - indataRange[0]) 69 | data = data * (outdataRange[1] - outdataRange[0]) + outdataRange[0] 70 |
return data 71 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/src/follower_bringup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import sys 8 | import time 9 | import numpy as np 10 | from dynamixel_sdk import * 11 | from dynamixel_driver import * 12 | 13 | # ROS 14 | import rospy 15 | from sensor_msgs.msg import JointState 16 | 17 | 18 | class FollowerARM(DynamixelDriver): 19 | def __init__(self, freq, motor_list, device="/dev/ttyUSB0", baudrate=1000000): 20 | self.freq = freq 21 | self.motor_list = motor_list 22 | self.motor_names = ["motor_{}".format(m) for m in motor_list] 23 | self.baudrate = baudrate 24 | self.cmd_sub = rospy.Subscriber( 25 | "/leader/interpolated_command", JointState, self.cmdCallback 26 | ) 27 | self.joint_pub = rospy.Publisher( 28 | "/follower/joint_states", JointState, queue_size=1 29 | ) 30 | 31 | self.cmd = None 32 | 33 | # setup the message 34 | self.joint_msg = JointState() 35 | self.joint_msg.name = self.motor_names 36 | 37 | # Initialization 38 | self.portHandler = PortHandler(device) 39 | self.packetHandler = PacketHandler(PROTOCOL_VERSION) 40 | self.groupBulkWrite = GroupBulkWrite(self.portHandler, self.packetHandler) 41 | self.groupBulkRead = GroupBulkRead(self.portHandler, self.packetHandler) 42 | self.serial_open() 43 | self.set_gripper_mode() 44 | self.torque_on() 45 | 46 | def cmdCallback(self, msg): 47 | self.cmd = np.array(msg.position) 48 | 49 | def run(self): 50 | rate = rospy.Rate(self.freq) 51 | 52 | rospy.logwarn("FollowerARM.run(): Starting execution") 53 | while not rospy.is_shutdown(): 54 | # start_time = time.time() 55 | # set joint 56 | if self.cmd is not None: 57 | self.set_joint_radian(self.cmd) 58 | 59 | # set message 60 | self.joint_msg.header.stamp = rospy.Time.now() 61 | self.joint_msg.position = self.get_joint_radian() 62 | 63 | # publish 64 | self.joint_pub.publish(self.joint_msg) 65 | rate.sleep() 66 | # print(time.time()-start_time) 67 | 68 | rospy.logwarn("FollowerARM.run(): Finished execution") 69 | self.torque_off() 70 | 71 | 72 | def main(freq, motor_list, device, baudrate): 73 | follower_arm = FollowerARM(freq, motor_list, device, baudrate) 74 | 75 | follower_arm.run() 76 | 77 | 78 | if __name__ == "__main__": 79 | try: 80 | rospy.init_node("dxl_follower_node", anonymous=True) 81 | freq = rospy.get_param("dxl_follower_node/freq") 82 | device = rospy.get_param("dxl_follower_node/device") 83 | baudrate = rospy.get_param("dxl_follower_node/baudrate") 84 | motor_list = rospy.get_param("dxl_follower_node/motor_list") 85 | motor_list = [eval(m) for m in motor_list.split(",")] 86 | main(freq, motor_list, device, baudrate) 87 | except rospy.ROSInterruptException: 88 | pass 89 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/src/interplation_node.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
4 | # 5 | 6 | import time 7 | import rospy 8 | import numpy as np 9 | from enum import Enum 10 | from sensor_msgs.msg import JointState 11 | 12 | 13 | def linear_interpolation(q0, q1, T, t): 14 | """Linear interpolation between q0 and q1""" 15 | s = min(1.0, max(0.0, t / T)) 16 | return (q1 - q0) * s + q0 17 | 18 | 19 | class Interpolator: 20 | def __init__( 21 | self, control_freq=50, target_freq=10, max_target_jump=20.0 / 180.0 * np.pi 22 | ): 23 | self.control_period = 1.0 / control_freq 24 | self.target_period = 1.0 / target_freq 25 | self.max_target_jump = max_target_jump 26 | self.max_target_delay = 2 * self.target_period 27 | 28 | # setup the message 29 | motor_list = [11, 12, 13, 14, 15] 30 | self.joint_msg = JointState() 31 | self.motor_names = ["motor_{}".format(m) for m in motor_list] 32 | self.joint_msg.name = self.motor_names 33 | 34 | # setup states 35 | self.has_valid_target = False 36 | self.target_update_time = None 37 | self.cmd = None 38 | self.prev_target = None 39 | 40 | # low freq target input from network 41 | self.target_sub = rospy.Subscriber( 42 | "/leader/joint_states", JointState, self.targetCallback 43 | ) 44 | 45 | # setup the online effort controller client 46 | self.cmd_pub = rospy.Publisher( 47 | "/leader/interpolated_command", JointState, queue_size=1 48 | ) 49 | 50 | def run(self): 51 | rate = rospy.Rate(1.0 / self.control_period) 52 | 53 | self.target_update_time = None 54 | self.target = None 55 | self.has_valid_target = False 56 | 57 | rospy.logwarn("OnlineExecutor.run(): Starting execution") 58 | while not rospy.is_shutdown(): 59 | if not self.has_valid_target: 60 | # just keep the current state 61 | rospy.logwarn_throttle( 62 | 1.0, "OnlineExecutor.run(): Wait for first target" 63 | ) 64 | 65 | if self.has_valid_target: 66 | # interpolation between prev_target and next target 67 | t = (rospy.Time.now() - self.target_update_time).to_sec() 68 | 69 | self.cmd = linear_interpolation( 70 | self.prev_target, self.target, self.target_period, t 71 | ) 72 | 73 | # check if communication is broken 74 | if t > self.max_target_delay: 75 | self.has_valid_target = False 76 | rospy.logwarn( 77 | "OnlineExecutor.run(): Interpolation stopped, wait for valid command" 78 | ) 79 | 80 | # send the target to robot 81 | self.command(self.cmd) 82 | 83 | rate.sleep() 84 | 85 | rospy.logwarn("OnlineExecutor.run(): Finished execution") 86 | 87 | def command(self, cmd): 88 | if not np.isnan(cmd).any(): 89 | self.joint_msg.header.stamp = rospy.Time.now() 90 | self.joint_msg.position = cmd 91 | self.cmd_pub.publish(self.joint_msg) 92 | 93 | def targetCallback(self, msg): 94 | """ 95 | next joint target callback from DL model 96 | stores the next target, current time, and previous target 97 | """ 98 | 99 | # extract the state from the message 100 | target = np.array(msg.position) 101 | 102 | if self.cmd is not None: 103 | # safety first, check the last command against the new target pose 104 | if np.max(np.abs(self.cmd - target)) > self.max_target_jump: 105 | self.has_valid_target = False 106 | idx = np.argmax(np.abs(self.cmd - target)) 107 | rospy.logerr_throttle( 108 | 1.0, 109 | "OnlineExecutor.targetCallback(): Jump in cmd[{:d}]={:f} to target[{:d}]={:f} > {:f}".format( 110 | idx, self.cmd[idx], idx, target[idx], self.max_target_jump 111 | ), 112 | ) 113 | return 114 | else: 115 | # initialization 116 | rospy.logwarn("OnlineExecutor.run(): Received first data") 117 | self.cmd = target 118 | 119 | # store target and last target 120 | self.target = target 121 | self.prev_target = np.copy(self.cmd)
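        # prev_target is the last command actually sent, so the next interpolation
        # segment starts from the robot's current commanded pose rather than from
        # the previous (possibly unreached) target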
122 | self.target_update_time = rospy.Time.now() 123 | 124 | # target was good 125 | self.has_valid_target = True 126 | 127 | 128 | def main(control_freq, target_freq): 129 | executor = Interpolator( 130 | control_freq=control_freq, 131 | target_freq=target_freq, 132 | max_target_jump=50.0 / 180.0 * np.pi, 133 | ) 134 | 135 | executor.run() 136 | 137 | 138 | if __name__ == "__main__": 139 | try: 140 | rospy.init_node("interpolator_node", anonymous=True) 141 | control_freq = rospy.get_param("interpolator_node/control_freq") 142 | target_freq = rospy.get_param("interpolator_node/target_freq") 143 | main(control_freq, target_freq) 144 | except rospy.ROSInterruptException: 145 | pass 146 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/src/leader_bringup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import sys 8 | import time 9 | import numpy as np 10 | from dynamixel_sdk import * 11 | from dynamixel_driver import * 12 | 13 | # ROS 14 | import rospy 15 | from sensor_msgs.msg import JointState 16 | 17 | 18 | class LeaderARM(DynamixelDriver): 19 | def __init__(self, freq, motor_list, device="/dev/ttyUSB1", baudrate=1000000): 20 | self.freq = freq 21 | self.motor_list = motor_list 22 | self.motor_names = ["motor_{}".format(m) for m in motor_list] 23 | self.baudrate = baudrate 24 | self.joint_pub = rospy.Publisher( 25 | "/leader/joint_states", JointState, queue_size=1 26 | ) 27 | 28 | # setup the message 29 | self.joint_msg = JointState() 30 | self.joint_msg.name = self.motor_names 31 | 32 | # Initialization 33 | self.portHandler = PortHandler(device) 34 | self.packetHandler = PacketHandler(PROTOCOL_VERSION) 35 | self.groupBulkWrite = GroupBulkWrite(self.portHandler, self.packetHandler) 36 | self.groupBulkRead = GroupBulkRead(self.portHandler, self.packetHandler) 37 | self.serial_open() 38 | 39 | def run(self): 40 | rate = rospy.Rate(self.freq) 41 | 42 | rospy.logwarn("LeaderARM.run(): Starting execution") 43 | while not rospy.is_shutdown(): 44 | # set message 45 | self.joint_msg.header.stamp = rospy.Time.now() 46 | self.joint_msg.position = self.get_joint_radian() 47 | 48 | # publish 49 | self.joint_pub.publish(self.joint_msg) 50 | rate.sleep() 51 | 52 | rospy.logwarn("LeaderARM.run(): Finished execution") 53 | self.torque_off() 54 | 55 | 56 | def main(freq, motor_list, device, baudrate): 57 | leader_arm = LeaderARM(freq, motor_list, device, baudrate) 58 | 59 | leader_arm.run() 60 | 61 | 62 | if __name__ == "__main__": 63 | try: 64 | rospy.init_node("dxl_leader_node", anonymous=True) 65 | freq = rospy.get_param("dxl_leader_node/freq") 66 | device = rospy.get_param("dxl_leader_node/device") 67 | baudrate = rospy.get_param("dxl_leader_node/baudrate") 68 | motor_list = rospy.get_param("dxl_leader_node/motor_list") 69 | motor_list = [eval(m) for m in motor_list.split(",")] 70 | main(freq, motor_list, device, baudrate) 71 | except rospy.ROSInterruptException: 72 | pass 73 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/ros/om_teleop/src/playback.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
4 | # 5 | 6 | import os 7 | import sys 8 | import time 9 | import numpy as np 10 | from dynamixel_sdk import * 11 | from dynamixel_driver import * 12 | 13 | # ROS 14 | import rospy 15 | from sensor_msgs.msg import JointState 16 | 17 | 18 | def main(freq, motor_list): 19 | # initialization 20 | motor_list = motor_list 21 | motor_names = ["motor_{}".format(m) for m in motor_list] 22 | 23 | # publisher 24 | joint_pub = rospy.Publisher("/leader/joint_states", JointState, queue_size=1) 25 | 26 | # setup the message 27 | joint_msg = JointState() 28 | joint_msg.name = motor_names 29 | 30 | # load_data 31 | joint_angles = np.load("../bag2npz/data/train/joints.npy")[0] 32 | nloop = len(joint_angles) 33 | rate = rospy.Rate(freq) 34 | 35 | rospy.logwarn("Playback: Starting execution") 36 | for loop_ct in range(nloop): 37 | # set message 38 | joint_msg.header.stamp = rospy.Time.now() 39 | joint_msg.position = joint_angles[loop_ct] 40 | 41 | # publish 42 | joint_pub.publish(joint_msg) 43 | rate.sleep() 44 | 45 | rospy.logwarn("Playback: Finished execution") 46 | 47 | 48 | if __name__ == "__main__": 49 | try: 50 | rospy.init_node("dxl_playback_node", anonymous=True) 51 | freq = rospy.get_param("dxl_playback_node/freq") 52 | motor_list = rospy.get_param("dxl_playback_node/motor_list") 53 | motor_list = [eval(m) for m in motor_list.split(",")] 54 | main(freq, motor_list) 55 | except rospy.ROSInterruptException: 56 | pass 57 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/sarnn/bin/export_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from os import path 4 | 5 | import torch 6 | 7 | from eipl.model import SARNN 8 | 9 | def load_config(file): 10 | with open(file, "r") as f: 11 | config = json.load(f) 12 | 13 | #!
temp fix 14 | # joint_dim is 5 for open_manipulator 15 | # im_size is 64x64 for object grasp task 16 | # hard coded for now 17 | config["joint_dim"] = 5 18 | config["im_size"] = [64, 64] 19 | 20 | return config 21 | 22 | 23 | def export_onnx(weight, output, config, verbose=False): 24 | model = SARNN( 25 | rec_dim=config["rec_dim"], 26 | joint_dim=config["joint_dim"], 27 | k_dim=config["k_dim"], 28 | heatmap_size=config["heatmap_size"], 29 | temperature=config["temperature"], 30 | im_size=config["im_size"], 31 | ) 32 | 33 | ckpt = torch.load(weight, map_location=torch.device("cpu")) 34 | model.load_state_dict(ckpt["model_state_dict"]) 35 | model.eval() 36 | 37 | input_names = ["i.image", "i.joint", "i.state_h", "i.state_c"] 38 | output_names = ["o.image", "o.joint", "o.enc_pts", "o.dec_pts", "o.state_h", "o.state_c"] 39 | dummy_input = ( 40 | torch.randn(1, 3, config["im_size"][0], config["im_size"][1]), 41 | torch.randn(1, config["joint_dim"]), 42 | tuple(torch.randn(1, config["rec_dim"]) for _ in range(2)), 43 | ) 44 | torch.onnx.export( 45 | model, 46 | dummy_input, 47 | output, 48 | input_names=input_names, 49 | output_names=output_names, 50 | verbose=verbose, 51 | ) 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("weight", help="PyTorch model weight file") 56 | parser.add_argument("--config", help="Configuration file", default=None) 57 | parser.add_argument("--output", help="Output ONNX file", default="model.onnx") 58 | parser.add_argument("--verbose", "-v", action="store_true") 59 | args = parser.parse_args() 60 | 61 | if args.config is not None: 62 | config = load_config(args.config) 63 | else: 64 | config_path = path.join(path.dirname(args.weight), "args.json") 65 | config = load_config(config_path) 66 | 67 | export_onnx(args.weight, args.output, config, args.verbose) 68 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/sarnn/bin/test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
4 | # 5 | 6 | import os 7 | import torch 8 | import argparse 9 | import numpy as np 10 | import matplotlib.pylab as plt 11 | import matplotlib.animation as anim 12 | from eipl.data import SampleDownloader, WeightDownloader 13 | from eipl.utils import restore_args, tensor2numpy, deprocess_img, normalization 14 | from eipl.model import SARNN 15 | 16 | # argument parser 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--filename", type=str, default=None) 19 | parser.add_argument("--idx", type=int, default=0) 20 | parser.add_argument("--pretrained", action="store_true") 21 | args = parser.parse_args() 22 | 23 | # check args 24 | assert args.filename or args.pretrained, "Please set filename or pretrained" 25 | 26 | # load pretrained weight 27 | if args.pretrained: 28 | WeightDownloader("om", "grasp_cube") 29 | args.filename = os.path.join( 30 | os.path.expanduser("~"), ".eipl/om/grasp_cube/weights/SARNN/model.pth" 31 | ) 32 | 33 | # restore parameters 34 | dir_name = os.path.split(args.filename)[0] 35 | params = restore_args(os.path.join(dir_name, "args.json")) 36 | idx = args.idx 37 | 38 | # load dataset 39 | minmax = [params["vmin"], params["vmax"]] 40 | grasp_data = SampleDownloader("om", "grasp_cube", img_format="HWC") 41 | _images, _joints = grasp_data.load_raw_data("test") 42 | images = _images[idx] 43 | joints = _joints[idx] 44 | joint_bounds = np.load( 45 | os.path.join(os.path.expanduser("~"), ".eipl/om/grasp_cube/joint_bounds.npy") 46 | ) 47 | print( 48 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max()) 49 | ) 50 | print( 51 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max()) 52 | ) 53 | 54 | # define model 55 | model = SARNN( 56 | rec_dim=params["rec_dim"], 57 | joint_dim=5, 58 | k_dim=params["k_dim"], 59 | heatmap_size=params["heatmap_size"], 60 | temperature=params["temperature"], 61 | im_size=[64, 64], 62 | ) 63 | 64 | if params["compile"]: 65 | model = torch.compile(model) 66 | 67 | # load weight 68 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 69 | model.load_state_dict(ckpt["model_state_dict"]) 70 | model.eval() 71 | 72 | # Inference 73 | img_size = 64 74 | image_list, joint_list = [], [] 75 | enc_pts_list, dec_pts_list = [], [] 76 | state = None 77 | nloop = len(images) 78 | for loop_ct in range(nloop): 79 | # load data and normalization 80 | img_t = images[loop_ct].transpose(2, 0, 1) 81 | img_t = torch.Tensor(np.expand_dims(img_t, 0)) 82 | img_t = normalization(img_t, (0, 255), minmax) 83 | joint_t = torch.Tensor(np.expand_dims(joints[loop_ct], 0)) 84 | joint_t = normalization(joint_t, joint_bounds, minmax) 85 | 86 | # predict rnn 87 | y_image, y_joint, enc_pts, dec_pts, state = model(img_t, joint_t, state) 88 | 89 | # denormalization 90 | pred_image = tensor2numpy(y_image[0]) 91 | pred_image = deprocess_img(pred_image, params["vmin"], params["vmax"]) 92 | pred_image = pred_image.transpose(1, 2, 0) 93 | pred_joint = tensor2numpy(y_joint[0]) 94 | pred_joint = normalization(pred_joint, minmax, joint_bounds) 95 | 96 | # append data 97 | image_list.append(pred_image) 98 | joint_list.append(pred_joint) 99 | enc_pts_list.append(tensor2numpy(enc_pts[0])) 100 | dec_pts_list.append(tensor2numpy(dec_pts[0])) 101 | 102 | print("loop_ct:{}, joint:{}".format(loop_ct, pred_joint)) 103 | 104 | pred_image = np.array(image_list) 105 | pred_joint = np.array(joint_list) 106 | 107 | # split key points 108 | enc_pts = np.array(enc_pts_list) 109 | dec_pts = np.array(dec_pts_list) 110 | 
enc_pts = enc_pts.reshape(-1, params["k_dim"], 2) * img_size 111 | dec_pts = dec_pts.reshape(-1, params["k_dim"], 2) * img_size 112 | enc_pts = np.clip(enc_pts, 0, img_size) 113 | dec_pts = np.clip(dec_pts, 0, img_size) 114 | 115 | 116 | # plot images 117 | T = len(images) 118 | fig, ax = plt.subplots(1, 3, figsize=(12, 5), dpi=60) 119 | 120 | 121 | def anim_update(i): 122 | for j in range(3): 123 | ax[j].cla() 124 | 125 | # plot camera image 126 | ax[0].imshow(images[i, :, :, ::-1]) 127 | for j in range(params["k_dim"]): 128 | ax[0].plot(enc_pts[i, j, 0], enc_pts[i, j, 1], "bo", markersize=6) # encoder 129 | ax[0].plot( 130 | dec_pts[i, j, 0], dec_pts[i, j, 1], "rx", markersize=6, markeredgewidth=2 131 | ) # decoder 132 | ax[0].axis("off") 133 | ax[0].set_title("Input image") 134 | 135 | # plot predicted image 136 | ax[1].imshow(pred_image[i, :, :, ::-1]) 137 | ax[1].axis("off") 138 | ax[1].set_title("Predicted image") 139 | 140 | # plot joint angle 141 | ax[2].set_ylim(-1.0, 2.0) 142 | ax[2].set_xlim(0, T) 143 | ax[2].plot(joints[1:], linestyle="dashed", c="k") 144 | # om has 5 joints, not 8 145 | for joint_idx in range(5): 146 | ax[2].plot(np.arange(i + 1), pred_joint[: i + 1, joint_idx]) 147 | ax[2].set_xlabel("Step") 148 | ax[2].set_title("Joint angles") 149 | 150 | 151 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T) 152 | ani.save("./output/SARNN_{}_{}.gif".format(params["tag"], idx)) 153 | 154 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg). 155 | # ani.save("./output/SARNN_{}_{}.gif".format(params["tag"], idx), writer="ffmpeg") 156 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/sarnn/bin/test_pca_sarnn.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License.
4 | # 5 | 6 | import os 7 | import sys 8 | import torch 9 | import argparse 10 | import numpy as np 11 | import matplotlib.pylab as plt 12 | import matplotlib.animation as anim 13 | from sklearn.decomposition import PCA 14 | from eipl.data import SampleDownloader 15 | from eipl.model import SARNN 16 | from eipl.utils import restore_args, tensor2numpy, normalization 17 | 18 | 19 | # argument parser 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("filename", type=str, default=None) 22 | args = parser.parse_args() 23 | 24 | # restore parameters 25 | dir_name = os.path.split(args.filename)[0] 26 | params = restore_args(os.path.join(dir_name, "args.json")) 27 | # idx = args.idx 28 | 29 | # load dataset 30 | minmax = [params["vmin"], params["vmax"]] 31 | grasp_data = SampleDownloader("om", "grasp_cube", img_format="HWC") 32 | images, joints = grasp_data.load_raw_data("test") 33 | joint_bounds = grasp_data.joint_bounds 34 | print( 35 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max()) 36 | ) 37 | print( 38 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max()) 39 | ) 40 | 41 | # define model 42 | model = SARNN( 43 | rec_dim=params["rec_dim"], 44 | joint_dim=5, 45 | k_dim=params["k_dim"], 46 | heatmap_size=params["heatmap_size"], 47 | temperature=params["temperature"], 48 | im_size=[64, 64], 49 | ) 50 | 51 | if params["compile"]: 52 | model = torch.compile(model) 53 | 54 | # load weight 55 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 56 | model.load_state_dict(ckpt["model_state_dict"]) 57 | model.eval() 58 | 59 | # Inference 60 | states = [] 61 | state = None 62 | nloop = images.shape[1] 63 | for loop_ct in range(nloop): 64 | # load data and normalization 65 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2) 66 | img_t = torch.Tensor(img_t) 67 | img_t = normalization(img_t, (0, 255), minmax) 68 | joint_t = torch.Tensor(joints[:, loop_ct]) 69 | joint_t = normalization(joint_t, joint_bounds, minmax) 70 | 71 | # predict rnn 72 | _, _, _, _, state = model(img_t, joint_t, state) 73 | states.append(state[0]) 74 | 75 | states = torch.permute(torch.stack(states), (1, 0, 2)) 76 | states = tensor2numpy(states) 77 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN. 78 | # N is the number of datasets 79 | # T is the sequence length 80 | # D is the dimension of the hidden state 81 | N, T, D = states.shape 82 | states = states.reshape(-1, D) 83 | 84 | # PCA 85 | loop_ct = float(360) / T 86 | pca_dim = 3 87 | pca = PCA(n_components=pca_dim).fit(states) 88 | pca_val = pca.transform(states) 89 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to 90 | # visualize each state as a 3D scatter. 
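# Note: PCA was fit on the pooled hidden states of all test sequences at once,
# so the principal axes are shared and the per-sequence trajectories recovered
# below are directly comparable in the same 3D space.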
91 | pca_val = pca_val.reshape(N, T, pca_dim) 92 | 93 | # plot images 94 | fig = plt.figure(dpi=60) 95 | ax = fig.add_subplot(projection="3d") 96 | 97 | 98 | def anim_update(i): 99 | ax.cla() 100 | angle = int(loop_ct * i) 101 | ax.view_init(30, angle) 102 | 103 | c_list = ["C0", "C1", "C2", "C3", "C4"] 104 | for n, color in enumerate(c_list): 105 | ax.scatter( 106 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0 107 | ) 108 | 109 | ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0) 110 | pca_ratio = pca.explained_variance_ratio_ * 100 111 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0])) 112 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1])) 113 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2])) 114 | 115 | 116 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T) 117 | ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"])) 118 | 119 | # If an error occurs in generating the gif or mp4 animation, change the writer (imagemagick/ffmpeg). 120 | # ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]), writer="imagemagick") 121 | # ani.save("./output/PCA_SARNN_{}.mp4".format(params["tag"]), writer="ffmpeg") 122 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/sarnn/bin/train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | import torch 10 | import argparse 11 | from tqdm import tqdm 12 | import torch.optim as optim 13 | from collections import OrderedDict 14 | from torch.utils.tensorboard import SummaryWriter 15 | from eipl.model import SARNN 16 | from eipl.data import MultimodalDataset, SampleDownloader 17 | from eipl.utils import EarlyStopping, check_args, set_logdir 18 | 19 | # load own library 20 | sys.path.append("./libs/") 21 | from fullBPTT import fullBPTTtrainer 22 | 23 | # argument parser 24 | parser = argparse.ArgumentParser( 25 | description="Learning spatial autoencoder with recurrent neural network" 26 | ) 27 | parser.add_argument("--model", type=str, default="sarnn") 28 | parser.add_argument("--epoch", type=int, default=10000) 29 | parser.add_argument("--batch_size", type=int, default=5) 30 | parser.add_argument("--rec_dim", type=int, default=50) 31 | parser.add_argument("--k_dim", type=int, default=5) 32 | parser.add_argument("--img_loss", type=float, default=0.1) 33 | parser.add_argument("--joint_loss", type=float, default=1.0) 34 | parser.add_argument("--pt_loss", type=float, default=0.1) 35 | parser.add_argument("--heatmap_size", type=float, default=0.1) 36 | parser.add_argument("--temperature", type=float, default=1e-4) 37 | parser.add_argument("--stdev", type=float, default=0.1) 38 | parser.add_argument("--log_dir", default="log/") 39 | parser.add_argument("--vmin", type=float, default=0.0) 40 | parser.add_argument("--vmax", type=float, default=1.0) 41 | parser.add_argument("--device", type=int, default=0) 42 | parser.add_argument("--compile", action="store_true") 43 | parser.add_argument("--tag", help="Tag name for snap/log sub directory") 44 | args = parser.parse_args() 45 | 46 | # check args 47 | args = check_args(args) 48 | 49 | # calculate the noise level (standard deviation) from the normalized range 50 | stdev = args.stdev * (args.vmax - args.vmin) 51 | 52 | # set device id 53 | if
args.device >= 0: 54 | device = "cuda:{}".format(args.device) 55 | else: 56 | device = "cpu" 57 | 58 | # load dataset 59 | minmax = [args.vmin, args.vmax] 60 | grasp_data = SampleDownloader("om", "grasp_cube", img_format="CHW") 61 | images, joints = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax) 62 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev) 63 | train_loader = torch.utils.data.DataLoader( 64 | train_dataset, 65 | batch_size=args.batch_size, 66 | shuffle=True, 67 | drop_last=False, 68 | ) 69 | 70 | images, joints = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax) 71 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None) 72 | test_loader = torch.utils.data.DataLoader( 73 | test_dataset, 74 | batch_size=args.batch_size, 75 | shuffle=True, 76 | drop_last=False, 77 | ) 78 | 79 | # define model 80 | model = SARNN( 81 | rec_dim=args.rec_dim, 82 | joint_dim=5, 83 | k_dim=args.k_dim, 84 | heatmap_size=args.heatmap_size, 85 | temperature=args.temperature, 86 | im_size=[64, 64], 87 | ) 88 | 89 | # torch.compile makes PyTorch code run faster 90 | if args.compile: 91 | torch.set_float32_matmul_precision("high") 92 | model = torch.compile(model) 93 | 94 | # set optimizer 95 | optimizer = optim.Adam(model.parameters(), eps=1e-07) 96 | 97 | # load trainer/tester class 98 | loss_weights = [args.img_loss, args.joint_loss, args.pt_loss] 99 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device) 100 | 101 | # training main 102 | log_dir_path = set_logdir("./" + args.log_dir, args.tag) 103 | save_name = os.path.join(log_dir_path, "SARNN.pth") 104 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30) 105 | early_stop = EarlyStopping(patience=1000) 106 | 107 | with tqdm(range(args.epoch)) as pbar_epoch: 108 | for epoch in pbar_epoch: 109 | # train and test 110 | train_loss = trainer.process_epoch(train_loader) 111 | with torch.no_grad(): 112 | test_loss = trainer.process_epoch(test_loader, training=False) 113 | writer.add_scalar("Loss/train_loss", train_loss, epoch) 114 | writer.add_scalar("Loss/test_loss", test_loss, epoch) 115 | 116 | # early stop 117 | save_ckpt, _ = early_stop(test_loss) 118 | 119 | if save_ckpt: 120 | trainer.save(epoch, [train_loss, test_loss], save_name) 121 | 122 | # update progress bar 123 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss)) 124 | pbar_epoch.update() 125 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/sarnn/libs/fullBPTT.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | from eipl.utils import LossScheduler 9 | 10 | 11 | class fullBPTTtrainer: 12 | """ 13 | Helper class to train recurrent neural networks with full backpropagation through time (BPTT) 14 | 15 | Args: 16 | model (torch.nn.Module): rnn model 17 | optimizer (torch.optim): optimizer 18 | loss_weights (list): weights of the image, joint, and attention point losses 19 | device (str): target device, e.g. "cpu" or "cuda:0" 
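Example (a minimal sketch; `model` and `train_loader` are the objects built in bin/train.py):
    >>> trainer = fullBPTTtrainer(model, torch.optim.Adam(model.parameters()),
    ...                           loss_weights=[0.1, 1.0, 0.1], device="cpu")
    >>> train_loss = trainer.process_epoch(train_loader)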
20 | """ 21 | 22 | def __init__(self, model, optimizer, loss_weights=[1.0, 1.0, 1.0], device="cpu"): 23 | self.device = device 24 | self.optimizer = optimizer 25 | self.loss_weights = loss_weights 26 | self.scheduler = LossScheduler(decay_end=1000, curve_name="s") 27 | self.model = model.to(self.device) 28 | 29 | def save(self, epoch, loss, savename): 30 | torch.save( 31 | { 32 | "epoch": epoch, 33 | "model_state_dict": self.model.state_dict(), 34 | # 'optimizer_state_dict': self.optimizer.state_dict(), 35 | "train_loss": loss[0], 36 | "test_loss": loss[1], 37 | }, 38 | savename, 39 | ) 40 | 41 | def process_epoch(self, data, training=True): 42 | if not training: 43 | self.model.eval() 44 | else: 45 | self.model.train() 46 | 47 | total_loss = 0.0 48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data): 49 | if "cpu" in self.device: 50 | x_img = x_img.to(self.device) 51 | y_img = y_img.to(self.device) 52 | x_joint = x_joint.to(self.device) 53 | y_joint = y_joint.to(self.device) 54 | 55 | state = None 56 | yi_list, yv_list = [], [] 57 | dec_pts_list, enc_pts_list = [], [] 58 | self.optimizer.zero_grad(set_to_none=True) 59 | for t in range(x_img.shape[1] - 1): 60 | _yi_hat, _yv_hat, enc_ij, dec_ij, state = self.model( 61 | x_img[:, t], x_joint[:, t], state 62 | ) 63 | yi_list.append(_yi_hat) 64 | yv_list.append(_yv_hat) 65 | enc_pts_list.append(enc_ij) 66 | dec_pts_list.append(dec_ij) 67 | 68 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4)) 69 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2)) 70 | 71 | img_loss = nn.MSELoss()(yi_hat, y_img[:, 1:]) * self.loss_weights[0] 72 | joint_loss = nn.MSELoss()(yv_hat, y_joint[:, 1:]) * self.loss_weights[1] 73 | # Gradually ramp up the attention point loss weight with the LossScheduler class. 74 | pt_loss = nn.MSELoss()( 75 | torch.stack(dec_pts_list[:-1]), torch.stack(enc_pts_list[1:]) 76 | ) * self.scheduler(self.loss_weights[2]) 77 | loss = img_loss + joint_loss + pt_loss 78 | total_loss += loss.item() 79 | 80 | if training: 81 | loss.backward() 82 | self.optimizer.step() 83 | 84 | return total_loss / (n_batch + 1) 85 | -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/sarnn/log/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/open_manipulator/sarnn/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/sarnn/bin/test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
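# Usage (a sketch; <tag> is the log sub directory created by bin/train.py):
#   python3 ./bin/test.py --filename ./log/<tag>/SARNN.pth --idx 0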
4 | # 5 | 6 | import argparse 7 | import os 8 | 9 | import matplotlib.animation as anim 10 | import matplotlib.pylab as plt 11 | import numpy as np 12 | import torch 13 | 14 | from eipl.model import SARNN 15 | from eipl.utils import deprocess_img, normalization, restore_args, tensor2numpy 16 | 17 | # argument parser 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--filename", type=str, default=None) 20 | parser.add_argument("--idx", type=int, default=0) 21 | args = parser.parse_args() 22 | 23 | # restore parameters 24 | dir_name = os.path.split(args.filename)[0] 25 | params = restore_args(os.path.join(dir_name, "args.json")) 26 | idx = args.idx 27 | 28 | # load dataset 29 | minmax = [params["vmin"], params["vmax"]] 30 | images_raw = np.load("../simulator/data/test/images.npy") 31 | joints_raw = np.load("../simulator/data/test/joints.npy") 32 | joint_bounds = np.load("../simulator/data/joint_bounds.npy") 33 | images = images_raw[idx] 34 | joints = joints_raw[idx] 35 | 36 | # define model 37 | model = SARNN( 38 | rec_dim=params["rec_dim"], 39 | joint_dim=8, 40 | k_dim=params["k_dim"], 41 | heatmap_size=params["heatmap_size"], 42 | temperature=params["temperature"], 43 | im_size=[64, 64], 44 | ) 45 | 46 | if params["compile"]: 47 | model = torch.compile(model) 48 | 49 | # load weight 50 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 51 | model.load_state_dict(ckpt["model_state_dict"]) 52 | model.eval() 53 | 54 | # Inference 55 | im_size = 64 56 | image_list, joint_list = [], [] 57 | enc_pts_list, dec_pts_list = [], [] 58 | state = None 59 | nloop = len(images) 60 | for loop_ct in range(nloop): 61 | # load data and normalization 62 | img_t = images[loop_ct].transpose(2, 0, 1) 63 | img_t = normalization(img_t, (0, 255), minmax) 64 | img_t = torch.Tensor(np.expand_dims(img_t, 0)) 65 | joint_t = normalization(joints[loop_ct], joint_bounds, minmax) 66 | joint_t = torch.Tensor(np.expand_dims(joint_t, 0)) 67 | 68 | # predict rnn 69 | y_image, y_joint, enc_pts, dec_pts, state = model(img_t, joint_t, state) 70 | 71 | # denormalization 72 | pred_image = tensor2numpy(y_image[0]) 73 | pred_image = deprocess_img(pred_image, params["vmin"], params["vmax"]) 74 | pred_image = pred_image.transpose(1, 2, 0) 75 | pred_joint = tensor2numpy(y_joint[0]) 76 | pred_joint = normalization(pred_joint, minmax, joint_bounds) 77 | 78 | # append data 79 | image_list.append(pred_image) 80 | joint_list.append(pred_joint) 81 | enc_pts_list.append(tensor2numpy(enc_pts[0])) 82 | dec_pts_list.append(tensor2numpy(dec_pts[0])) 83 | 84 | print("loop_ct:{}, joint:{}".format(loop_ct, pred_joint)) 85 | 86 | pred_image = np.array(image_list) 87 | pred_joint = np.array(joint_list) 88 | 89 | # split key points 90 | enc_pts = np.array(enc_pts_list) 91 | dec_pts = np.array(dec_pts_list) 92 | enc_pts = enc_pts.reshape(-1, params["k_dim"], 2) * im_size 93 | dec_pts = dec_pts.reshape(-1, params["k_dim"], 2) * im_size 94 | enc_pts = np.clip(enc_pts, 0, im_size) 95 | dec_pts = np.clip(dec_pts, 0, im_size) 96 | 97 | 98 | # plot images 99 | T = len(images) 100 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60) 101 | 102 | 103 | def anim_update(i): 104 | for j in range(3): 105 | ax[j].cla() 106 | 107 | # plot camera image 108 | ax[0].imshow(images[i]) 109 | for j in range(params["k_dim"]): 110 | ax[0].plot(enc_pts[i, j, 0], enc_pts[i, j, 1], "co", markersize=12) # encoder 111 | ax[0].plot( 112 | dec_pts[i, j, 0], dec_pts[i, j, 1], "rx", markersize=12, markeredgewidth=2 113 | ) # decoder 114 | 
ax[0].axis("off") 115 | ax[0].set_title("Input image", fontsize=20) 116 | 117 | # plot predicted image 118 | ax[1].imshow(pred_image[i]) 119 | ax[1].axis("off") 120 | ax[1].set_title("Predicted image", fontsize=20) 121 | 122 | # plot joint angle 123 | ax[2].set_ylim(-np.pi, 3.4) 124 | ax[2].set_xlim(0, T) 125 | ax[2].plot(joints[1:], linestyle="dashed", c="k") 126 | # the Panda arm has 7 joints plus 1 gripper, i.e. 8 DoF in total 127 | for joint_idx in range(8): 128 | ax[2].plot(np.arange(i + 1), pred_joint[: i + 1, joint_idx]) 129 | ax[2].set_xlabel("Step", fontsize=20) 130 | ax[2].set_title("Joint angles", fontsize=20) 131 | ax[2].tick_params(axis="x", labelsize=16) 132 | ax[2].tick_params(axis="y", labelsize=16) 133 | plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9) 134 | 135 | 136 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T) 137 | ani.save("./output/SARNN_{}_{}.gif".format(params["tag"], idx)) 138 | 139 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg). 140 | # ani.save("./output/SARNN_{}_{}.gif".format(params["tag"], idx), writer="ffmpeg") 141 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/sarnn/bin/test_pca_sarnn.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import sys 8 | import torch 9 | import argparse 10 | import numpy as np 11 | import matplotlib.pylab as plt 12 | import matplotlib.animation as anim 13 | from sklearn.decomposition import PCA 14 | from eipl.model import SARNN 15 | from eipl.utils import restore_args, tensor2numpy, normalization 16 | 17 | 18 | # argument parser 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("filename", type=str, default=None) 21 | args = parser.parse_args() 22 | 23 | # restore parameters 24 | dir_name = os.path.split(args.filename)[0] 25 | params = restore_args(os.path.join(dir_name, "args.json")) 26 | # idx = args.idx 27 | 28 | # load dataset 29 | minmax = [params["vmin"], params["vmax"]] 30 | images = np.load("../simulator/data/test/images.npy") 31 | joints = np.load("../simulator/data/test/joints.npy") 32 | joint_bounds = np.load("../simulator/data/joint_bounds.npy") 33 | 34 | # define model 35 | model = SARNN( 36 | rec_dim=params["rec_dim"], 37 | joint_dim=8, 38 | k_dim=params["k_dim"], 39 | heatmap_size=params["heatmap_size"], 40 | temperature=params["temperature"], 41 | im_size=[64, 64], 42 | ) 43 | 44 | if params["compile"]: 45 | model = torch.compile(model) 46 | 47 | # load weight 48 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 49 | model.load_state_dict(ckpt["model_state_dict"]) 50 | model.eval() 51 | 52 | # Inference 53 | states = [] 54 | state = None 55 | nloop = images.shape[1] 56 | for loop_ct in range(nloop): 57 | # load data and normalization 58 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2) 59 | img_t = normalization(img_t, (0, 255), minmax) 60 | img_t = torch.Tensor(img_t) 61 | joint_t = normalization(joints[:, loop_ct], joint_bounds, minmax) 62 | joint_t = torch.Tensor(joint_t) 63 | 64 | # predict rnn 65 | _, _, _, _, state = model(img_t, joint_t, state) 66 | states.append(state[0]) 67 | 68 | states = torch.permute(torch.stack(states), (1, 0, 2)) 69 | states = tensor2numpy(states) 70 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN. 
71 | # N is the number of datasets 72 | # T is the sequence length 73 | # D is the dimension of the hidden state 74 | N, T, D = states.shape 75 | states = states.reshape(-1, D) 76 | 77 | # PCA 78 | loop_ct = float(360) / T 79 | pca_dim = 3 80 | pca = PCA(n_components=pca_dim).fit(states) 81 | pca_val = pca.transform(states) 82 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to 83 | # visualize each state as a 3D scatter. 84 | pca_val = pca_val.reshape(N, T, pca_dim) 85 | 86 | # plot images 87 | fig = plt.figure(dpi=120) 88 | ax = fig.add_subplot(projection="3d") 89 | 90 | 91 | def anim_update(i): 92 | ax.cla() 93 | angle = int(loop_ct * i) 94 | ax.view_init(30, angle) 95 | 96 | c_list = ["C0", "C1", "C2", "C3", "C4", "C5", "C6", "C7", "C8"] 97 | for n, color in enumerate(c_list): 98 | ax.scatter( 99 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0 100 | ) 101 | 102 | ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0) 103 | pca_ratio = pca.explained_variance_ratio_ * 100 104 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0])) 105 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1])) 106 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2])) 107 | ax.tick_params(axis="x", labelsize=8) 108 | ax.tick_params(axis="y", labelsize=8) 109 | ax.tick_params(axis="z", labelsize=8) 110 | 111 | 112 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T) 113 | ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"])) 114 | 115 | # If an error occurs in generating the gif or mp4 animation, change the writer (imagemagick/ffmpeg). 116 | # ani.save("./output/PCA_SARNN_{}.gif".format(params["tag"]), writer="imagemagick") 117 | # ani.save("./output/PCA_SARNN_{}.mp4".format(params["tag"]), writer="ffmpeg") 118 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/sarnn/bin/train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
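# Usage (a sketch; expects the npy files produced by ../simulator/bin/4_generate_dataset.py):
#   python3 ./bin/train.py --device -1 --tag test_robosuite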
4 | # 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | import torch 10 | import argparse 11 | from tqdm import tqdm 12 | import torch.optim as optim 13 | from collections import OrderedDict 14 | from torch.utils.tensorboard import SummaryWriter 15 | from eipl.model import SARNN 16 | from eipl.data import MultimodalDataset 17 | from eipl.utils import EarlyStopping, check_args, set_logdir, normalization 18 | 19 | # load own library 20 | sys.path.append("./libs/") 21 | from fullBPTT import fullBPTTtrainer 22 | 23 | # argument parser 24 | parser = argparse.ArgumentParser( 25 | description="Learning spatial autoencoder with recurrent neural network" 26 | ) 27 | parser.add_argument("--model", type=str, default="sarnn") 28 | parser.add_argument("--epoch", type=int, default=10000) 29 | parser.add_argument("--batch_size", type=int, default=5) 30 | parser.add_argument("--rec_dim", type=int, default=50) 31 | parser.add_argument("--k_dim", type=int, default=5) 32 | parser.add_argument("--img_loss", type=float, default=0.1) 33 | parser.add_argument("--joint_loss", type=float, default=1.0) 34 | parser.add_argument("--pt_loss", type=float, default=0.1) 35 | parser.add_argument("--heatmap_size", type=float, default=0.1) 36 | parser.add_argument("--temperature", type=float, default=1e-4) 37 | parser.add_argument("--stdev", type=float, default=0.1) 38 | parser.add_argument("--lr", type=float, default=1e-4) 39 | parser.add_argument("--optimizer", type=str, default="adam") 40 | parser.add_argument("--log_dir", default="log/") 41 | parser.add_argument("--vmin", type=float, default=0.0) 42 | parser.add_argument("--vmax", type=float, default=1.0) 43 | parser.add_argument("--device", type=int, default=0) 44 | parser.add_argument("--compile", action="store_true") 45 | parser.add_argument("--tag", help="Tag name for snap/log sub directory") 46 | args = parser.parse_args() 47 | 48 | # check args 49 | args = check_args(args) 50 | 51 | # calculate the noise level (standard deviation) in the normalized range 52 | stdev = args.stdev * (args.vmax - args.vmin) 53 | 54 | # set device id 55 | if args.device >= 0: 56 | device = "cuda:{}".format(args.device) 57 | else: 58 | device = "cpu" 59 | 60 | # load dataset 61 | minmax = [args.vmin, args.vmax] 62 | images_raw = np.load("../simulator/data/train/images.npy") 63 | joints_raw = np.load("../simulator/data/train/joints.npy") 64 | joint_bounds = np.load("../simulator/data/joint_bounds.npy") 65 | images = normalization(images_raw.transpose(0, 1, 4, 2, 3), (0, 255), minmax) 66 | joints = normalization(joints_raw, joint_bounds, minmax) 67 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev) 68 | train_loader = torch.utils.data.DataLoader( 69 | train_dataset, 70 | batch_size=args.batch_size, 71 | shuffle=True, 72 | drop_last=False, 73 | ) 74 | 75 | images_raw = np.load("../simulator/data/test/images.npy") 76 | joints_raw = np.load("../simulator/data/test/joints.npy") 77 | images = normalization(images_raw.transpose(0, 1, 4, 2, 3), (0, 255), minmax) 78 | joints = normalization(joints_raw, joint_bounds, minmax) 79 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None) 80 | test_loader = torch.utils.data.DataLoader( 81 | test_dataset, 82 | batch_size=args.batch_size, 83 | shuffle=True, 84 | drop_last=False, 85 | ) 86 | 87 | # define model 88 | model = SARNN( 89 | rec_dim=args.rec_dim, 90 | joint_dim=8, 91 | k_dim=args.k_dim, 92 | heatmap_size=args.heatmap_size, 93 | temperature=args.temperature, 94 | im_size=[64, 64], 95 | ) 96 | 
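# A quick I/O sanity check (a sketch; dummy zero tensors, batch size 1 is an arbitrary choice):
#   >>> y_img, y_joint, enc_ij, dec_ij, state = model(
#   ...     torch.zeros(1, 3, 64, 64), torch.zeros(1, 8), None)
#   >>> y_img.shape, y_joint.shape
#   (torch.Size([1, 3, 64, 64]), torch.Size([1, 8]))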
97 | # torch.compile makes PyTorch code run faster 98 | if args.compile: 99 | torch.set_float32_matmul_precision("high") 100 | model = torch.compile(model) 101 | 102 | # set optimizer 103 | optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-07) 104 | 105 | # load trainer/tester class 106 | loss_weights = [args.img_loss, args.joint_loss, args.pt_loss] 107 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device) 108 | 109 | # training main 110 | log_dir_path = set_logdir("./" + args.log_dir, args.tag) 111 | save_name = os.path.join(log_dir_path, "SARNN.pth") 112 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30) 113 | early_stop = EarlyStopping(patience=1000) 114 | 115 | with tqdm(range(args.epoch)) as pbar_epoch: 116 | for epoch in pbar_epoch: 117 | # train and test 118 | train_loss = trainer.process_epoch(train_loader) 119 | with torch.no_grad(): 120 | test_loss = trainer.process_epoch(test_loader, training=False) 121 | writer.add_scalar("Loss/train_loss", train_loss, epoch) 122 | writer.add_scalar("Loss/test_loss", test_loss, epoch) 123 | 124 | # early stop 125 | save_ckpt, _ = early_stop(test_loss) 126 | 127 | if save_ckpt: 128 | trainer.save(epoch, [train_loss, test_loss], save_name) 129 | 130 | # update progress bar 131 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss)) 132 | pbar_epoch.update() 133 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/sarnn/libs/fullBPTT.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | from eipl.utils import LossScheduler 9 | 10 | 11 | class fullBPTTtrainer: 12 | """ 13 | Helper class to train recurrent neural networks with full backpropagation through time (BPTT) 14 | 15 | Args: 16 | model (torch.nn.Module): rnn model 17 | optimizer (torch.optim): optimizer 18 | loss_weights (list): weights of the image, joint, and attention point losses 19 | device (str): target device, e.g. "cpu" or "cuda:0" 
20 | """ 21 | 22 | def __init__(self, model, optimizer, loss_weights=[1.0, 1.0, 1.0], device="cpu"): 23 | self.device = device 24 | self.optimizer = optimizer 25 | self.loss_weights = loss_weights 26 | self.scheduler = LossScheduler(decay_end=1000, curve_name="s") 27 | self.model = model.to(self.device) 28 | 29 | def save(self, epoch, loss, savename): 30 | torch.save( 31 | { 32 | "epoch": epoch, 33 | "model_state_dict": self.model.state_dict(), 34 | #'optimizer_state_dict': self.optimizer.state_dict(), 35 | "train_loss": loss[0], 36 | "test_loss": loss[1], 37 | }, 38 | savename, 39 | ) 40 | 41 | def process_epoch(self, data, training=True): 42 | if not training: 43 | self.model.eval() 44 | else: 45 | self.model.train() 46 | 47 | total_loss = 0.0 48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data): 49 | if "cpu" in self.device: 50 | x_img = x_img.to(self.device) 51 | y_img = y_img.to(self.device) 52 | x_joint = x_joint.to(self.device) 53 | y_joint = y_joint.to(self.device) 54 | 55 | state = None 56 | yi_list, yv_list = [], [] 57 | dec_pts_list, enc_pts_list = [], [] 58 | self.optimizer.zero_grad(set_to_none=True) 59 | for t in range(x_img.shape[1] - 1): 60 | _yi_hat, _yv_hat, enc_ij, dec_ij, state = self.model( 61 | x_img[:, t], x_joint[:, t], state 62 | ) 63 | yi_list.append(_yi_hat) 64 | yv_list.append(_yv_hat) 65 | enc_pts_list.append(enc_ij) 66 | dec_pts_list.append(dec_ij) 67 | 68 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4)) 69 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2)) 70 | 71 | img_loss = nn.MSELoss()(yi_hat, y_img[:, 1:]) * self.loss_weights[0] 72 | joint_loss = nn.MSELoss()(yv_hat, y_joint[:, 1:]) * self.loss_weights[1] 73 | # Gradually ramp up the attention point loss weight with the LossScheduler class. 74 | pt_loss = nn.MSELoss()( 75 | torch.stack(dec_pts_list[:-1]), torch.stack(enc_pts_list[1:]) 76 | ) * self.scheduler(self.loss_weights[2]) 77 | loss = img_loss + joint_loss + pt_loss 78 | total_loss += loss.item() 79 | 80 | if training: 81 | loss.backward() 82 | self.optimizer.step() 83 | 84 | return total_loss / (n_batch + 1) 85 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/sarnn/log/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/sarnn/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/2_resave.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for dir in `ls ./data/raw_data/`; do 4 | echo "python3 ./bin/2_resave.py --ep_dir ./data/raw_data/${dir}" 5 | python3 ./bin/2_resave.py --ep_dir ./data/raw_data/${dir} 6 | done 7 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/3_check_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for dir in `ls ./data/raw_data/`; do 4 | echo "python3 ./bin/3_check_playback_data.py ./data/raw_data/${dir}/state_resave.npz" 5 | python3 ./bin/3_check_playback_data.py ./data/raw_data/${dir}/state_resave.npz 6 | done 7 | -------------------------------------------------------------------------------- 
/eipl/tutorials/robosuite/simulator/bin/2_resave.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import sys 8 | import argparse 9 | import numpy as np 10 | import robosuite as suite 11 | from robosuite import load_controller_config 12 | 13 | sys.path.append("./libs") 14 | from rt_control_wrapper import RTControlWrapper 15 | 16 | 17 | def playback(env, ep_dir, rate): 18 | data = np.load(os.path.join(ep_dir, "state.npz"), allow_pickle=True) 19 | nloop = data["nloop"] 20 | states = data["states"] 21 | joints = data["joint_angles"] 22 | grippers = np.array([el["actions"][-1] for el in data["action_infos"]]) 23 | 24 | # Set user parameter 25 | loop_ct = 0 26 | 27 | # Reset the environment and restore the recorded initial state 28 | env.set_environment(ep_dir, states[0]) 29 | 30 | # save list 31 | pose_list = [] 32 | state_list = [] 33 | joint_list = [] 34 | image_list = [] 35 | gripper_list = [] 36 | success_list = [] 37 | 38 | for loop_ct in range(nloop): 39 | # update the action at the down-sampled rate 40 | if loop_ct % rate == 0: 41 | action = env.get_joint_action(joints[loop_ct], kp=rate) 42 | action[-1] = grippers[loop_ct] 43 | 44 | obs, reward, done, info = env.step(action) 45 | # env.render() 46 | 47 | # save data 48 | if loop_ct % rate == 0: 49 | state, success = env.get_state() 50 | image = env.get_image() 51 | joint = env.get_joints() 52 | pose = env.get_pose() 53 | 54 | # save robot sensor data 55 | pose_list.append(pose) 56 | state_list.append(state) 57 | image_list.append(image) 58 | joint_list.append(joint) 59 | success_list.append(success) 60 | gripper_list.append(action[-1]) 61 | 62 | # saving 63 | save_name = os.path.join(ep_dir, "state_resave.npz") 64 | print("save file: ", save_name) 65 | np.savez( 66 | save_name, 67 | poses=np.array(pose_list), 68 | states=np.array(state_list), 69 | joints=np.array(joint_list), 70 | images=np.array(image_list), 71 | success=np.array(success_list), 72 | gripper=np.array(gripper_list).reshape(-1, 1), 73 | ) 74 | 75 | env.close() 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--ep_dir", type=str, default="./data/raw_data/test/") 81 | parser.add_argument("--rate", type=int, default=5) # down-sampling rate 82 | args = parser.parse_args() 83 | 84 | # Get controller config 85 | # Import controller config IK_POSE/OSC_POSE/OSC_POSITION/JOINT_POSITION 86 | controller_config = load_controller_config(default_controller="JOINT_POSITION") 87 | 88 | # Create argument configuration 89 | config = { 90 | "env_name": "Lift", 91 | "robots": "Panda", 92 | "controller_configs": controller_config, 93 | } 94 | 95 | # create environment 96 | env = suite.make( 97 | **config, 98 | has_renderer=True, 99 | has_offscreen_renderer=True, 100 | render_camera="agentview", 101 | ignore_done=True, 102 | use_camera_obs=True, 103 | reward_shaping=True, 104 | control_freq=20, 105 | hard_reset=False, 106 | ) 107 | 108 | # wrap the environment with data collection wrapper 109 | env = RTControlWrapper(env) 110 | 111 | # collect some data 112 | playback(env, args.ep_dir, args.rate) 113 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/bin/3_check_playback_data.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released 
under the MIT License. 4 | # 5 | 6 | import argparse 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import matplotlib.animation as anim 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("filename", type=str, default="./data/test/state.npz") 14 | args = parser.parse_args() 15 | 16 | # load sub directory name from filename 17 | savename = "./output/{}_".format(args.filename.split("/")[-2]) 18 | 19 | # action_infos: action 20 | # obs_infos: image, joint_cos, joint_sin, gripper 21 | data = np.load(args.filename, allow_pickle=True) 22 | if sum(data["success"]) == 0: 23 | print("[ERROR] This dataset failed the task during playback.") 24 | exit() 25 | 26 | # plot joint and images 27 | images = data["images"] 28 | poses = data["poses"] 29 | joints = np.concatenate((data["joints"], data["gripper"]), axis=-1) 30 | 31 | N = len(joints) 32 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60) 33 | 34 | 35 | def anim_update(i): 36 | for j in range(3): 37 | ax[j].axis("off") 38 | ax[j].cla() 39 | 40 | ax[0].imshow(np.flipud(images[i])) 41 | ax[0].axis("off") 42 | ax[0].set_title("Image", fontsize=20) 43 | 44 | ax[1].set_ylim(-3.5, 3.5) 45 | ax[1].set_xlim(0, N) 46 | for idx in range(8): 47 | ax[1].plot(np.arange(i + 1), joints[: i + 1, idx]) 48 | ax[1].set_xlabel("Step", fontsize=20) 49 | ax[1].set_title("Joint angles", fontsize=20) 50 | ax[1].tick_params(axis="x", labelsize=16) 51 | ax[1].tick_params(axis="y", labelsize=16) 52 | 53 | ax[2].set_ylim(-3.5, 3.5) 54 | ax[2].set_xlim(0, N) 55 | for idx in range(6): 56 | ax[2].plot(np.arange(i + 1), poses[: i + 1, idx]) 57 | ax[2].set_xlabel("Step", fontsize=20) 58 | ax[2].set_title("End effector poses", fontsize=20) 59 | ax[2].tick_params(axis="x", labelsize=16) 60 | ax[2].tick_params(axis="y", labelsize=16) 61 | plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9) 62 | 63 | 64 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N) 65 | ani.save(savename + "image_joint_ani.gif") 66 | 67 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg). 68 | # ani.save(savename + "image_joint_ani.gif", writer="imagemagick") 69 | # ani.save(savename + "image_joint_ani.mp4", writer="ffmpeg") 70 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/bin/4_generate_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
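# Usage (a sketch): run from the simulator directory after 2_resave.sh, so that
# every ./data/raw_data/Pos{}_{}/state_resave.npz referenced below exists:
#   python3 ./bin/4_generate_dataset.py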
4 | # 5 | 6 | import numpy as np 7 | from eipl.utils import resize_img, normalization, cos_interpolation, calc_minmax 8 | 9 | 10 | def load_data(index, data_dir, train): 11 | image_list = [] 12 | joint_list = [] 13 | pose_list = [] 14 | 15 | num_list = [1, 2, 3, 4, 5] if train else [6, 7] 16 | for h in index: 17 | for i in num_list: 18 | filename = data_dir.format(h, i) 19 | print(filename) 20 | data = np.load(filename, allow_pickle=True) 21 | 22 | images = data["images"] 23 | images = np.array(images)[:, ::-1] 24 | images = resize_img(images, (64, 64)) 25 | 26 | _poses = data["poses"] 27 | _joints = data["joints"] 28 | _gripper = data["gripper"][:, 0] 29 | _gripper = normalization(_gripper, (-1, 1), (0, 1)) 30 | _gripper = cos_interpolation(_gripper, 10, expand_dims=True) 31 | poses = np.concatenate((_poses, _gripper), axis=-1) 32 | joints = np.concatenate((_joints, _gripper), axis=-1) 33 | 34 | joint_list.append(joints) 35 | image_list.append(images) 36 | pose_list.append(poses) 37 | 38 | joints = np.array(joint_list) 39 | images = np.array(image_list) 40 | poses = np.array(pose_list) 41 | 42 | return images, joints, poses 43 | 44 | 45 | if __name__ == "__main__": 46 | # train position: collect 7 data/position (5 for train, 2 for test) 47 | # test position: collect 2 data/position (all for test) 48 | # Pos1 Pos2 Pos3 Pos4 Pos5 Pos6 Pos7 Pos8 Pos9 49 | # pos_train: -0.2 -0.1 0.0 0.1 0.2 50 | # pos_test: -0.2 -0.15 -0.1 -0.05 0.0 0.05 0.1 0.15 0.2 51 | 52 | data_dir = "./data/raw_data/Pos{}_{}/state_resave.npz" 53 | 54 | # load train data 55 | train_index = [1, 3, 5, 7, 9] 56 | train_images, train_joints, train_poses = load_data( 57 | train_index, data_dir, train=True 58 | ) 59 | np.save("./data/train/images.npy", train_images.astype(np.uint8)) 60 | np.save("./data/train/joints.npy", train_joints.astype(np.float32)) 61 | np.save("./data/train/poses.npy", train_poses.astype(np.float32)) 62 | 63 | test_index = [1, 2, 3, 4, 5, 6, 7, 8, 9] 64 | test_images, test_joints, test_poses = load_data(test_index, data_dir, train=False) 65 | np.save("./data/test/images.npy", test_images.astype(np.uint8)) 66 | np.save("./data/test/joints.npy", test_joints.astype(np.float32)) 67 | np.save("./data/test/poses.npy", test_poses.astype(np.float32)) 68 | 69 | # save bounds 70 | poses = np.concatenate((train_poses, test_poses), axis=0) 71 | joints = np.concatenate((train_joints, test_joints), axis=0) 72 | pose_bounds = calc_minmax(poses) 73 | joint_bounds = calc_minmax(joints) 74 | np.save("./data/pose_bounds.npy", pose_bounds) 75 | np.save("./data/joint_bounds.npy", joint_bounds) 76 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/bin/5_check_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
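# Usage (a sketch; --idx picks the test sequence to animate):
#   python3 ./bin/5_check_dataset.py --idx 0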
4 | # 5 | 6 | import argparse 7 | import numpy as np 8 | import matplotlib.pylab as plt 9 | import matplotlib.animation as anim 10 | from eipl.utils import normalization 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--idx", type=int, default=0) 15 | args = parser.parse_args() 16 | 17 | idx = int(args.idx) 18 | joints = np.load("./data/test/joints.npy") 19 | joint_bounds = np.load("./data/joint_bounds.npy") 20 | images = np.load("./data/test/images.npy") 21 | N = images.shape[1] 22 | 23 | 24 | # normalized joints 25 | minmax = [0.1, 0.9] 26 | norm_joints = normalization(joints, joint_bounds, minmax) 27 | 28 | # print data information 29 | print("load test data, index number is {}".format(idx)) 30 | print( 31 | "Joint: shape={}, min={:.3g}, max={:.3g}".format( 32 | joints.shape, joints.min(), joints.max() 33 | ) 34 | ) 35 | print( 36 | "Norm joint: shape={}, min={:.3g}, max={:.3g}".format( 37 | norm_joints.shape, norm_joints.min(), norm_joints.max() 38 | ) 39 | ) 40 | 41 | # plot images and normalized joints 42 | fig, ax = plt.subplots(1, 3, figsize=(14, 6), dpi=60) 43 | 44 | 45 | def anim_update(i): 46 | for j in range(3): 47 | ax[j].cla() 48 | 49 | # plot image 50 | ax[0].imshow(images[idx, i]) 51 | ax[0].axis("off") 52 | ax[0].set_title("Image", fontsize=20) 53 | 54 | # plot joint angle 55 | ax[1].set_ylim(-1.0, 2.0) 56 | ax[1].set_xlim(0, N) 57 | ax[1].plot(joints[idx], linestyle="dashed", c="k") 58 | 59 | for joint_idx in range(8): 60 | ax[1].plot(np.arange(i + 1), joints[idx, : i + 1, joint_idx]) 61 | ax[1].set_xlabel("Step", fontsize=20) 62 | ax[1].set_title("Joint angles", fontsize=20) 63 | ax[1].tick_params(axis="x", labelsize=16) 64 | ax[1].tick_params(axis="y", labelsize=16) 65 | 66 | # plot normalized joint angle 67 | ax[2].set_ylim(0.0, 1.0) 68 | ax[2].set_xlim(0, N) 69 | ax[2].plot(norm_joints[idx], linestyle="dashed", c="k") 70 | 71 | for joint_idx in range(8): 72 | ax[2].plot(np.arange(i + 1), norm_joints[idx, : i + 1, joint_idx]) 73 | ax[2].set_xlabel("Step", fontsize=20) 74 | ax[2].set_title("Normalized joint angles", fontsize=20) 75 | ax[2].tick_params(axis="x", labelsize=16) 76 | ax[2].tick_params(axis="y", labelsize=16) 77 | plt.subplots_adjust(left=0.01, right=0.98, bottom=0.12, top=0.9) 78 | 79 | 80 | ani = anim.FuncAnimation(fig, anim_update, interval=int(N / 10), frames=N) 81 | ani.save("./output/check_dataset_{}.gif".format(idx)) 82 | 83 | # If an error occurs in generating the gif animation, change the writer (imagemagick/ffmpeg). 84 | # ani.save("./output/check_dataset_{}.gif".format(idx), writer="imagemagick") 85 | # ani.save("./output/check_dataset_{}.gif".format(idx), writer="ffmpeg") 86 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/libs/environment.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 
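# Object definition used by the simulator: a red wooden cube whose size is
# sampled between size_min and size_max, serving as the target of the Lift
# task (see bin/2_resave.py).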
4 | # 5 | 6 | from robosuite.models.objects import BoxObject 7 | from robosuite.utils.mjcf_utils import CustomMaterial 8 | 9 | 10 | tex_attrib = { 11 | "type": "cube", 12 | } 13 | 14 | mat_attrib = { 15 | "texrepeat": "1 1", 16 | "specular": "0.4", 17 | "shininess": "0.1", 18 | } 19 | 20 | redwood = CustomMaterial( 21 | texture="WoodRed", 22 | tex_name="redwood", 23 | mat_name="redwood_mat", 24 | tex_attrib=tex_attrib, 25 | mat_attrib=mat_attrib, 26 | ) 27 | 28 | cube = BoxObject( 29 | name="cube", 30 | size_min=[0.020, 0.020, 0.020], 31 | size_max=[0.022, 0.022, 0.022], 32 | rgba=[1, 0, 0, 1], 33 | material=redwood, 34 | ) 35 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/libs/rt_control_wrapper.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # Released under the MIT License. 4 | # 5 | 6 | import os 7 | import numpy as np 8 | import transforms3d as T 9 | from robosuite.wrappers import Wrapper 10 | 11 | 12 | class RTControlWrapper(Wrapper): 13 | def __init__(self, env, save_dir=None, vis_settings=False): 14 | super().__init__(env) 15 | """ 16 | RTControlWrapper specialized for imitation learning 17 | """ 18 | 19 | # initialization 20 | self.robot = self.env.robots[0] 21 | self.robot.controller_config["kp"] = 150 22 | 23 | # save the task instance (will be saved on the first env interaction) 24 | self._current_task_instance_xml = self.env.sim.model.get_xml() 25 | self._current_task_instance_state = np.array(self.env.sim.get_state().flatten()) 26 | 27 | # make save directory and save xml file 28 | if save_dir is not None: 29 | if not os.path.exists(save_dir): 30 | print( 31 | "DataCollectionWrapper: making new directory at {}".format(save_dir) 32 | ) 33 | os.makedirs(save_dir) 34 | self.save_xml(save_dir) 35 | 36 | self._vis_settings = None 37 | if vis_settings: 38 | # Create internal dict to store visualization settings (set to True by default) 39 | self._vis_settings = {vis: True for vis in self.env._visualizations} 40 | 41 | def reset(self): 42 | """ 43 | Extends vanilla reset() function call to accommodate data collection 44 | Returns: 45 | OrderedDict: Environment observation space after reset occurs 46 | """ 47 | ret = super().reset() 48 | return ret 49 | 50 | def step(self, action): 51 | """ 52 | Extends vanilla step() function call to accommodate data collection 53 | Args: 54 | action (np.array): Action to take in environment 55 | Returns: 56 | 4-tuple: 57 | - (OrderedDict) observations from the environment 58 | - (float) reward from the environment 59 | - (bool) whether the current episode is completed or not 60 | - (dict) misc information 61 | """ 62 | ret = super().step(action) 63 | if self._vis_settings is not None: 64 | self.env.visualize(vis_settings=self._vis_settings) 65 | 66 | return ret 67 | 68 | def save_xml(self, save_dir="./data"): 69 | # save the model xml 70 | xml_path = os.path.join(save_dir, "model.xml") 71 | with open(xml_path, "w") as f: 72 | f.write(self._current_task_instance_xml) 73 | 74 | def set_environment(self, save_dir, state): 75 | xml_path = os.path.join(save_dir, "model.xml") 76 | with open(xml_path, "r") as f: 77 | self.env.reset_from_xml_string(f.read()) 78 | 79 | self.env.sim.set_state_from_flattened(state) 80 | if self._vis_settings is not None: 81 | self.env.visualize(vis_settings=self._vis_settings) 82 | 83 | def set_gripper_qpos(self, qpos): 84 | for i, q in 
enumerate(qpos): 85 | self.env.sim.data.set_joint_qpos( 86 | name="gripper0_finger_joint{}".format(i + 1), value=q 87 | ) 88 | 89 | def set_joint_qpos(self, qpos): 90 | for i, q in enumerate(qpos): 91 | self.env.sim.data.set_joint_qpos( 92 | name="robot0_joint{}".format(i + 1), value=q 93 | ) 94 | 95 | def get_image(self, name="agentview_image"): 96 | return self.env.observation_spec()[name] 97 | 98 | def get_joint_action(self, goal_joint_pos, kp, kd=0.0): 99 | """relative2absolute_joint_pos_commands""" 100 | action = [0 for _ in range(self.robot.dof)] 101 | curr_joint_pos = self.robot._joint_positions 102 | curr_joint_vel = self.robot._joint_velocities 103 | 104 | for i in range(len(goal_joint_pos)): 105 | action[i] = (goal_joint_pos[i] - curr_joint_pos[i]) * kp - curr_joint_vel[ 106 | i 107 | ] * kd 108 | 109 | return action 110 | 111 | def get_pose(self): 112 | position = self.env.robots[0]._hand_pos 113 | orientation_matrix = self.env.robots[0]._hand_orn 114 | orientation_euler = T.euler.mat2euler(orientation_matrix) 115 | 116 | pose = list(position) + list(orientation_euler) 117 | pose = np.array(pose) 118 | pose[pose < -np.pi / 2] += np.pi * 2 119 | 120 | return pose 121 | 122 | def get_gripper(self): 123 | return self.env.observation_spec()["robot0_gripper_qpos"] 124 | 125 | def get_joints(self): 126 | return self.robot._joint_positions 127 | 128 | def check_success(self): 129 | return self.env._check_success() 130 | 131 | def get_state(self): 132 | state = self.env.sim.get_state().flatten() 133 | # successful 134 | if self.env._check_success(): 135 | return state, True 136 | else: 137 | return state, False 138 | -------------------------------------------------------------------------------- /eipl/tutorials/robosuite/simulator/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .arg_utils import * 2 | from .callback import * 3 | from .data import * 4 | from .nn_func import * 5 | from .path_utils import * 6 | from .print_func import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /eipl/utils/arg_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import json 9 | import datetime 10 | from .print_func import * 11 | from .path_utils import * 12 | 13 | 14 | def print_args(args): 15 | """Print arguments""" 16 | if not isinstance(args, dict): 17 | args = vars(args) 18 | 19 | keys = args.keys() 20 | keys = sorted(keys) 21 | 22 | print("================================") 23 | for key in keys: 24 | print("{} : {}".format(key, args[key])) 25 | print("================================") 26 | 27 | 28 | def save_args(args, filename): 29 | """Dump arguments as json file""" 30 | with open(filename, "w") as f: 31 | json.dump(vars(args), f, indent=4, sort_keys=True) 32 | 33 | 34 | def restore_args(filename): 35 | """Load arguments from a json file""" 36 | with open(filename, "r") as f: 37 | args = json.load(f) 38 | return args 39 | 40 | 41 | def get_config(args, tag, default=None): 42 | """Get value from argument""" 43 | if not isinstance(args, dict): 44 | raise ValueError("args should be dict.") 45 | 46 | if tag in args: 47 | if args[tag] is None: 48 | print_info("set {} <-- {} (default)".format(tag, default)) 49 | return default 50 | else: 51 | print_info("set {} <-- {}".format(tag, args[tag])) 52 | return args[tag] 53 | else: 54 | if default is None: 55 | raise ValueError("you need to specify config {}".format(tag)) 56 | 57 | print_info("set {} <-- {} (default)".format(tag, default)) 58 | return default 59 | 60 | 61 | def check_args(args): 62 | """Check arguments""" 63 | 64 | if args.tag is None: 65 | tag = datetime.datetime.today().strftime("%Y%m%d_%H%M_%S") 66 | args.tag = tag 67 | print_info("Set tag = %s" % tag) 68 | 69 | # make log directory 70 | check_path(os.path.join(args.log_dir, args.tag), mkdir=True) 71 | 72 | # saves arguments into json file 73 | save_args(args, os.path.join(args.log_dir, args.tag, "args.json")) 74 | 75 | print_args(args) 76 | return args 77 | -------------------------------------------------------------------------------- /eipl/utils/callback.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class EarlyStopping: 13 | def __init__(self, patience=5): 14 | """ 15 | Stops the training early if the validation loss doesn't improve within a given patience. 16 | Args: 17 | patience (int): Number of epochs with no improvement after which training will be stopped. 
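Example (a sketch mirroring the training loops in this repository):
    >>> early_stop = EarlyStopping(patience=1000)
    >>> save_ckpt, stop_flag = early_stop(test_loss)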
18 | """ 19 | self.patience = patience 20 | self.counter = 0 21 | self.best_score = None 22 | self.save_ckpt = False 23 | self.stop_flag = False 24 | self.val_loss_min = np.inf 25 | 26 | def __call__(self, val_loss): 27 | if np.isnan(val_loss) or np.isinf(val_loss): 28 | raise RuntimeError("Invalid loss, terminating training") 29 | 30 | score = -val_loss 31 | 32 | if self.best_score is None: 33 | self.save_ckpt = True 34 | self.best_score = score 35 | elif score < self.best_score: 36 | self.save_ckpt = False 37 | self.counter += 1 38 | if self.counter >= self.patience: 39 | self.stop_flag = True 40 | else: 41 | self.save_ckpt = True 42 | self.best_score = score 43 | self.counter = 0 44 | 45 | return self.save_ckpt, self.stop_flag 46 | 47 | 48 | class LossScheduler: 49 | def __init__(self, decay_end=1000, curve_name="s"): 50 | decay_start = 0 51 | self.counter = -1 52 | self.decay_end = decay_end 53 | self.interpolated_values = self.curve_interpolation( 54 | decay_start, decay_end, decay_end, curve_name 55 | ) 56 | 57 | def linear_interpolation(self, start, end, num_points): 58 | x = np.linspace(start, end, num_points) 59 | return x 60 | 61 | def s_curve_interpolation(self, start, end, num_points): 62 | t = np.linspace(0, 1, num_points) 63 | x = start + (end - start) * (t - np.sin(2 * np.pi * t) / (2 * np.pi)) 64 | return x 65 | 66 | def inverse_s_curve_interpolation(self, start, end, num_points): 67 | t = np.linspace(0, 1, num_points) 68 | x = start + (end - start) * (t + np.sin(2 * np.pi * t) / (2 * np.pi)) 69 | return x 70 | 71 | def deceleration_curve_interpolation(self, start, end, num_points): 72 | t = np.linspace(0, 1, num_points) 73 | x = start + (end - start) * (1 - np.cos(np.pi * t / 2)) 74 | return x 75 | 76 | def acceleration_curve_interpolation(self, start, end, num_points): 77 | t = np.linspace(0, 1, num_points) 78 | x = start + (end - start) * (np.sin(np.pi * t / 2)) 79 | return x 80 | 81 | def curve_interpolation(self, start, end, num_points, curve_name): 82 | if curve_name == "linear": 83 | interpolated_values = self.linear_interpolation(start, end, num_points) 84 | elif curve_name == "s": 85 | interpolated_values = self.s_curve_interpolation(start, end, num_points) 86 | elif curve_name == "inverse_s": 87 | interpolated_values = self.inverse_s_curve_interpolation( 88 | start, end, num_points 89 | ) 90 | elif curve_name == "deceleration": 91 | interpolated_values = self.deceleration_curve_interpolation( 92 | start, end, num_points 93 | ) 94 | elif curve_name == "acceleration": 95 | interpolated_values = self.acceleration_curve_interpolation( 96 | start, end, num_points 97 | ) 98 | else: 99 | assert False, "Invalid curve name. {}".format(curve_name) 100 | 101 | return interpolated_values / num_points 102 | 103 | def __call__(self, loss_weight): 104 | self.counter += 1 105 | if self.counter >= self.decay_end: 106 | return loss_weight 107 | else: 108 | return self.interpolated_values[self.counter] * loss_weight 109 | -------------------------------------------------------------------------------- /eipl/utils/check_gpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import torch 9 | 10 | print("torch version:", torch.__version__) 11 | 12 | if torch.cuda.is_available(): 13 | print("cuda is available") 14 | print("device_count: ", torch.cuda.device_count()) 15 | print("device name: ", torch.cuda.get_device_name()) 16 | else: 17 | print("cuda is not available") 18 | -------------------------------------------------------------------------------- /eipl/utils/convert_compiled_pth.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import torch 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("filename", type=str, default=None) 14 | args = parser.parse_args() 15 | 16 | # load the compiled checkpoint and strip the torch.compile prefix from its keys 17 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 18 | 19 | restored_ckpt = {} 20 | for k, v in ckpt["model_state_dict"].items(): 21 | restored_ckpt[k.replace("_orig_mod.", "")] = v 22 | 23 | ckpt["model_state_dict"] = restored_ckpt 24 | 25 | savename = "{}_v1.pth".format(os.path.splitext(args.filename)[0]) 26 | torch.save(ckpt, savename) 27 | -------------------------------------------------------------------------------- /eipl/utils/nn_func.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import torch.nn as nn 9 | 10 | 11 | def get_activation_fn(name, inplace=True): 12 | if name.casefold() == "relu": 13 | return nn.ReLU(inplace=inplace) 14 | elif name.casefold() == "lrelu": 15 | return nn.LeakyReLU(inplace=inplace) 16 | elif name.casefold() == "softmax": 17 | return nn.Softmax() 18 | elif name.casefold() == "tanh": 19 | return nn.Tanh() 20 | else: 21 | assert False, "Unknown activation function {}".format(name) 22 | -------------------------------------------------------------------------------- /eipl/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import numpy as np 10 | 11 | 12 | def check_filename(filename): 13 | if os.path.exists(filename): 14 | raise ValueError("{} exists.".format(filename)) 15 | return filename 16 | 17 | 18 | def check_path(path, mkdir=False): 19 | """ 20 | Checks that the given path is valid 21 | """ 22 | if path[-1] == "/": 23 | path = path[:-1] 24 | 25 | if not os.path.exists(path): 26 | if mkdir: 27 | os.makedirs(path, exist_ok=True) 28 | else: 29 | raise ValueError("%s does not exist" % path) 30 | 31 | return path 32 | -------------------------------------------------------------------------------- /eipl/utils/print_func.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | OK = "\033[92m" 9 | WARN = "\033[93m" 10 | NG = "\033[91m" 11 | END_CODE = "\033[0m" 12 | 13 | 14 | def print_info(msg): 15 | print(OK + "[INFO] " + END_CODE + msg) 16 | 17 | 18 | def print_warn(msg): 19 | print(WARN + "[WARNING] " + END_CODE + msg) 20 | 21 | 22 | def print_error(msg): 23 | print(NG + "[ERROR] " + END_CODE + msg) 24 | -------------------------------------------------------------------------------- /eipl/utils/resave_pth.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import torch 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("filename", type=str, default=None) 14 | args = parser.parse_args() 15 | 16 | # resave original file 17 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 18 | savename = "{}_org.pth".format(os.path.splitext(args.filename)[0]) 19 | torch.save(ckpt, savename) 20 | 21 | # save pth file without optimizer state 22 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 23 | ckpt.pop("optimizer_state_dict") 24 | torch.save(ckpt, args.filename) 25 | -------------------------------------------------------------------------------- /eipl/utils/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import cv2 10 | import glob 11 | import datetime 12 | import numpy as np 13 | import matplotlib.pylab as plt 14 | 15 | 16 | def check_path(path, mkdir=False): 17 | """ 18 | checks given path is existing or not 19 | """ 20 | if path[-1] == "/": 21 | path = path[:-1] 22 | 23 | if not os.path.exists(path): 24 | if mkdir: 25 | os.mkdir(path) 26 | else: 27 | raise ValueError("%s does not exist" % path) 28 | return path 29 | 30 | 31 | def set_logdir(log_dir, tag): 32 | return check_path(os.path.join(log_dir, tag), mkdir=True) 33 | -------------------------------------------------------------------------------- /eipl/zoo/cae/bin/extract.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import torch 10 | import argparse 11 | import numpy as np 12 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN 13 | from eipl.data import SampleDownloader 14 | from eipl.utils import print_info, restore_args, tensor2numpy 15 | 16 | 17 | # argument parser 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--filename", type=str, default=None) 20 | args = parser.parse_args() 21 | 22 | # restore parameters 23 | dir_name = os.path.split(args.filename)[0] 24 | params = restore_args(os.path.join(dir_name, "args.json")) 25 | 26 | # define model 27 | if params["model"] == "BasicCAE": 28 | model = BasicCAE(feat_dim=params["feat_dim"]) 29 | elif params["model"] == "CAE": 30 | model = CAE(feat_dim=params["feat_dim"]) 31 | elif params["model"] == "BasicCAEBN": 32 | model = BasicCAEBN(feat_dim=params["feat_dim"]) 33 | elif params["model"] == "CAEBN": 34 | model = CAEBN(feat_dim=params["feat_dim"]) 35 | else: 36 | assert False, "Unknown model name {}".format(params["model"]) 37 | 38 | # load weight 39 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 40 | model.load_state_dict(ckpt["model_state_dict"]) 41 | model.eval() 42 | 43 | for data_type in ["train", "test"]: 44 | # generate rnn dataset 45 | os.makedirs("./data/{}/".format(data_type), exist_ok=True) 46 | 47 | # load dataset 48 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW") 49 | images, joints = grasp_data.load_norm_data( 50 | data_type, params["vmin"], params["vmax"] 51 | ) 52 | images = torch.tensor(images) 53 | joint_bounds = np.load( 54 | os.path.join( 55 | os.path.expanduser("~"), ".eipl/airec/grasp_bottle/joint_bounds.npy" 56 | ) 57 | ) 58 | 59 | # extract image feature 60 | N = images.shape[0] 61 | feature_list = [] 62 | for i in range(N): 63 | _features = model.encoder(images[i]) 64 | feature_list.append(tensor2numpy(_features)) 65 | 66 | features = np.array(feature_list) 67 | np.save("./data/joint_bounds.npy", joint_bounds) 68 | np.save("./data/{}/features.npy".format(data_type), features) 69 | np.save("./data/{}/joints.npy".format(data_type), joints) 70 | 71 | print_info("{} data".format(data_type)) 72 | print("==================================================") 73 | print("Shape of joints angle:", joints.shape) 74 | print("Shape of image feature:", features.shape) 75 | print("==================================================") 76 | print() 77 | 78 | # save features minmax bounds 79 | feat_list = [] 80 | for data_type in ["train", "test"]: 81 | feat_list.append(np.load("./data/{}/features.npy".format(data_type))) 82 | 83 | feat = np.vstack(feat_list) 84 | feat_minmax = np.array([feat.min(), feat.max()]) 85 | np.save("./data/feat_bounds.npy", feat_minmax) 86 | -------------------------------------------------------------------------------- /eipl/zoo/cae/bin/test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import torch 10 | import argparse 11 | import numpy as np 12 | import matplotlib.pylab as plt 13 | import matplotlib.animation as anim 14 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN 15 | from eipl.data import SampleDownloader, WeightDownloader 16 | from eipl.utils import normalization, deprocess_img, restore_args 17 | from eipl.utils import tensor2numpy 18 | 19 | 20 | # argument parser 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--filename", type=str, default=None) 23 | parser.add_argument("--idx", type=str, default="0") 24 | parser.add_argument("--pretrained", action="store_true") 25 | args = parser.parse_args() 26 | 27 | # check args 28 | assert args.filename or args.pretrained, "Please set filename or pretrained" 29 | 30 | # load pretrained weight 31 | if args.pretrained: 32 | WeightDownloader("airec", "grasp_bottle") 33 | args.filename = os.path.join( 34 | os.path.expanduser("~"), ".eipl/airec/grasp_bottle/weights/CAEBN/model.pth" 35 | ) 36 | 37 | # restore parameters 38 | dir_name = os.path.split(args.filename)[0] 39 | params = restore_args(os.path.join(dir_name, "args.json")) 40 | idx = int(args.idx) 41 | 42 | # load dataset 43 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC") 44 | images_raw, _ = grasp_data.load_raw_data("test") 45 | images = normalization( 46 | images_raw.astype(np.float32), (0.0, 255.0), (params["vmin"], params["vmax"]) 47 | ) 48 | images = images.transpose(0, 1, 4, 2, 3) 49 | images = torch.tensor(images) 50 | T = images.shape[1] 51 | 52 | # define model 53 | if params["model"] == "BasicCAE": 54 | model = BasicCAE(feat_dim=params["feat_dim"]) 55 | elif params["model"] == "CAE": 56 | model = CAE(feat_dim=params["feat_dim"]) 57 | elif params["model"] == "BasicCAEBN": 58 | model = BasicCAEBN(feat_dim=params["feat_dim"]) 59 | elif params["model"] == "CAEBN": 60 | model = CAEBN(feat_dim=params["feat_dim"]) 61 | else: 62 | assert False, "Unknown model name {}".format(params["model"]) 63 | 64 | if params["compile"]: 65 | model = torch.compile(model) 66 | 67 | # load weight 68 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 69 | model.load_state_dict(ckpt["model_state_dict"]) 70 | model.eval() 71 | 72 | # prediction 73 | _yi = model(images[idx]) 74 | yi = deprocess_img(tensor2numpy(_yi), params["vmin"], params["vmax"]) 75 | yi = yi.transpose(0, 2, 3, 1) 76 | 77 | # plot images 78 | fig, ax = plt.subplots(1, 2, figsize=(8, 3), dpi=60) 79 | 80 | 81 | def anim_update(i): 82 | for j in range(2): 83 | ax[j].cla() 84 | 85 | ax[0].imshow(images_raw[idx, i, :, :, ::-1]) 86 | ax[0].axis("off") 87 | ax[0].set_title("Input image") 88 | 89 | ax[1].imshow(yi[i, :, :, ::-1]) 90 | ax[1].axis("off") 91 | ax[1].set_title("Reconstructed image") 92 | 93 | 94 | # defaults 95 | ani = anim.FuncAnimation(fig, anim_update, interval=int(T / 10), frames=T) 96 | ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx)) 97 | 98 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg). 
99 | # ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx), writer="imagemagick") 100 | # ani.save("./output/{}_{}_{}.mp4".format(params["model"], params["tag"], idx), writer="ffmpeg") 101 | -------------------------------------------------------------------------------- /eipl/zoo/cae/bin/test_pca_cae.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import torch 10 | import argparse 11 | import matplotlib.pylab as plt 12 | import matplotlib.animation as anim 13 | from sklearn.decomposition import PCA 14 | from eipl.data import SampleDownloader 15 | from eipl.utils import tensor2numpy, restore_args 16 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN 17 | 18 | 19 | # argument parser 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--filename", type=str, default=None) 22 | args = parser.parse_args() 23 | 24 | # restore parameters 25 | dir_name = os.path.split(args.filename)[0] 26 | params = restore_args(os.path.join(dir_name, "args.json")) 27 | 28 | # load dataset 29 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW") 30 | images, _ = grasp_data.load_norm_data("test", params["vmin"], params["vmax"]) 31 | N, T, C, W, H = images.shape 32 | images = images.reshape(N * T, C, W, H) 33 | images = torch.tensor(images) 34 | 35 | # define model 36 | if params["model"] == "BasicCAE": 37 | model = BasicCAE(feat_dim=params["feat_dim"]) 38 | elif params["model"] == "CAE": 39 | model = CAE(feat_dim=params["feat_dim"]) 40 | elif params["model"] == "BasicCAEBN": 41 | model = BasicCAEBN(feat_dim=params["feat_dim"]) 42 | elif params["model"] == "CAEBN": 43 | model = CAEBN(feat_dim=params["feat_dim"]) 44 | else: 45 | assert False, "Unknown model name {}".format(params["model"]) 46 | 47 | if params["compile"]: 48 | model = torch.compile(model) 49 | 50 | # load weight 51 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 52 | model.load_state_dict(ckpt["model_state_dict"]) 53 | model.eval() 54 | 55 | # prediction 56 | feat = model.encoder(images) 57 | feat = tensor2numpy(feat) 58 | loop_ct = float(360) / T 59 | 60 | pca_dim = 3 61 | pca = PCA(n_components=pca_dim).fit(feat) 62 | pca_val = pca.transform(feat) 63 | pca_val = pca_val.reshape(N, T, pca_dim) 64 | 65 | fig = plt.figure() 66 | ax = fig.add_subplot(projection="3d") 67 | 68 | 69 | def anim_update(i): 70 | ax.cla() 71 | angle = int(loop_ct * i) 72 | ax.view_init(30, angle) 73 | 74 | c_list = ["C0", "C1", "C2", "C3", "C4"] 75 | for n, color in enumerate(c_list): 76 | ax.scatter( 77 | pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0 78 | ) 79 | 80 | ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0) 81 | pca_ratio = pca.explained_variance_ratio_ * 100 82 | ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0])) 83 | ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1])) 84 | ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2])) 85 | 86 | 87 | ani = anim.FuncAnimation(fig, anim_update, interval=T, frames=T) 88 | ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"])) 89 | 90 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg). 
91 | # ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]), writer="imagemagick") 92 | # ani.save("./output/PCA_{}_{}.mp4".format(params["model"], params["tag"]), writer="ffmpeg") 93 | -------------------------------------------------------------------------------- /eipl/zoo/cae/bin/train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import sys 10 | import torch 11 | import argparse 12 | from tqdm import tqdm 13 | from collections import OrderedDict 14 | import torch.optim as optim 15 | from torch.utils.tensorboard import SummaryWriter 16 | from eipl.model import BasicCAE, CAE, BasicCAEBN, CAEBN 17 | from eipl.data import ImageDataset, SampleDownloader 18 | from eipl.utils import EarlyStopping, check_args, set_logdir 19 | 20 | # load own library 21 | sys.path.append("./libs/") 22 | from trainer import Trainer 23 | 24 | 25 | # argument parser 26 | parser = argparse.ArgumentParser(description="Learning convolutional autoencoder") 27 | parser.add_argument("--model", type=str, default="CAEBN") 28 | parser.add_argument("--epoch", type=int, default=10000) 29 | parser.add_argument("--batch_size", type=int, default=32) 30 | parser.add_argument("--feat_dim", type=int, default=10) 31 | parser.add_argument("--stdev", type=float, default=0.1) 32 | parser.add_argument("--lr", type=float, default=1e-3) 33 | parser.add_argument("--optimizer", type=str, default="adam") 34 | parser.add_argument("--log_dir", default="log/") 35 | parser.add_argument("--vmin", type=float, default=0.1) 36 | parser.add_argument("--vmax", type=float, default=0.9) 37 | parser.add_argument("--device", type=int, default=0) 38 | parser.add_argument("--compile", action="store_true") 39 | parser.add_argument("--tag", help="Tag name for snap/log sub directory") 40 | args = parser.parse_args() 41 | 42 | # check args 43 | args = check_args(args) 44 | 45 | # calculate the noise level (variance) from the normalized range 46 | stdev = args.stdev * (args.vmax - args.vmin) 47 | 48 | # set device id 49 | if args.device >= 0: 50 | device = "cuda:{}".format(args.device) 51 | else: 52 | device = "cpu" 53 | 54 | # load dataset 55 | minmax = [args.vmin, args.vmax] 56 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW") 57 | images, _ = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax) 58 | train_dataset = ImageDataset(images, device=device, stdev=stdev) 59 | train_loader = torch.utils.data.DataLoader( 60 | train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False 61 | ) 62 | 63 | images, _ = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax) 64 | test_dataset = ImageDataset(images, device=device, stdev=None) 65 | test_loader = torch.utils.data.DataLoader( 66 | test_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False 67 | ) 68 | 69 | 70 | # define model 71 | if args.model == "BasicCAE": 72 | model = BasicCAE(feat_dim=args.feat_dim) 73 | elif args.model == "CAE": 74 | model = CAE(feat_dim=args.feat_dim) 75 | elif args.model == "BasicCAEBN": 76 | model = BasicCAEBN(feat_dim=args.feat_dim) 77 | elif args.model == "CAEBN": 78 | model = CAEBN(feat_dim=args.feat_dim) 79 | else: 80 | assert False, "Unknown model name {}".format(args.model) 81 | 82 | # torch.compile makes PyTorch code run faster 83 | if args.compile: 84 | 
84 | if args.compile:
85 |     model = torch.compile(model)
86 | 
87 | # set optimizer
88 | optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-07)
89 | 
90 | # load trainer/tester class
91 | trainer = Trainer(model, optimizer, device=device)
92 | 
93 | ### training main
94 | log_dir_path = set_logdir("./" + args.log_dir, args.tag)
95 | save_name = os.path.join(log_dir_path, "{}.pth".format(args.model))
96 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30)
97 | early_stop = EarlyStopping(patience=1000)
98 | 
99 | with tqdm(range(args.epoch)) as pbar_epoch:
100 |     for epoch in pbar_epoch:
101 |         # train and test
102 |         train_loss = trainer.process_epoch(train_loader)
103 |         with torch.no_grad():
104 |             test_loss = trainer.process_epoch(test_loader, training=False)
105 |         writer.add_scalar("Loss/train_loss", train_loss, epoch)
106 |         writer.add_scalar("Loss/test_loss", test_loss, epoch)
107 | 
108 |         # early stop
109 |         save_ckpt, _ = early_stop(test_loss)
110 | 
111 |         if save_ckpt:
112 |             trainer.save(epoch, [train_loss, test_loss], save_name)
113 | 
114 |         # print process bar
115 |         pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss))
116 | 
-------------------------------------------------------------------------------- /eipl/zoo/cae/data/.gitignore: --------------------------------------------------------------------------------
1 | *
2 | !.gitignore
-------------------------------------------------------------------------------- /eipl/zoo/cae/libs/trainer.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 | 
8 | import torch
9 | import torch.nn as nn
10 | 
11 | 
12 | class Trainer:
13 |     """
14 |     Helper class to train a convolutional autoencoder with a dataloader
15 | 
16 |     Args:
17 |         model (torch.nn.Module): autoencoder model
18 |         optimizer (torch.optim.Optimizer): optimizer
19 |         device (str): computation device, e.g. "cpu" or "cuda:0"
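20 | 
21 |     Example (illustrative usage, mirroring bin/train.py):
22 |         >>> trainer = Trainer(model, optimizer, device="cpu")
23 |         >>> train_loss = trainer.process_epoch(train_loader)
24 |         >>> test_loss = trainer.process_epoch(test_loader, training=False)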
25 |     """
26 | 
27 |     def __init__(self, model, optimizer, device="cpu"):
28 |         self.device = device
29 |         self.optimizer = optimizer
30 |         self.model = model.to(self.device)
31 | 
32 |     def save(self, epoch, loss, savename):
33 |         torch.save(
34 |             {
35 |                 "epoch": epoch,
36 |                 "model_state_dict": self.model.state_dict(),
37 |                 #'optimizer_state_dict': self.optimizer.state_dict(),
38 |                 "train_loss": loss[0],
39 |                 "test_loss": loss[1],
40 |             },
41 |             savename,
42 |         )
43 | 
44 |     def process_epoch(self, data, training=True):
45 |         if not training:
46 |             self.model.eval()
47 |         else:
48 |             self.model.train()
49 | 
50 |         total_loss = 0.0
51 |         for n_batch, (xi, yi) in enumerate(data):
52 |             xi = xi.to(self.device)
53 |             yi = yi.to(self.device)
54 | 
55 |             yi_hat = self.model(xi)
56 |             loss = nn.MSELoss()(yi_hat, yi)
57 |             total_loss += loss.item()
58 | 
59 |             if training:
60 |                 self.optimizer.zero_grad(set_to_none=True)
61 |                 loss.backward()
62 |                 self.optimizer.step()
63 | 
64 |         return total_loss / (n_batch + 1)
65 | 
-------------------------------------------------------------------------------- /eipl/zoo/cae/log/.gitignore: --------------------------------------------------------------------------------
1 | *
2 | !.gitignore
-------------------------------------------------------------------------------- /eipl/zoo/cae/output/.gitignore: --------------------------------------------------------------------------------
1 | *
2 | !.gitignore
-------------------------------------------------------------------------------- /eipl/zoo/cnnrnn/bin/test_pca_cnnrnn.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import os 9 | import torch 10 | import argparse 11 | import numpy as np 12 | import matplotlib.pylab as plt 13 | import matplotlib.animation as anim 14 | from sklearn.decomposition import PCA 15 | from eipl.model import CNNRNN, CNNRNNLN 16 | from eipl.data import SampleDownloader 17 | from eipl.utils import normalization 18 | from eipl.utils import restore_args, tensor2numpy 19 | 20 | 21 | # argument parser 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--filename", type=str, default=None) 24 | args = parser.parse_args() 25 | 26 | # restore parameters 27 | dir_name = os.path.split(args.filename)[0] 28 | params = restore_args(os.path.join(dir_name, "args.json")) 29 | 30 | # load dataset 31 | minmax = [params["vmin"], params["vmax"]] 32 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="HWC") 33 | images, joints = grasp_data.load_raw_data("test") 34 | joint_bounds = grasp_data.joint_bounds 35 | print( 36 | "images shape:{}, min={}, max={}".format(images.shape, images.min(), images.max()) 37 | ) 38 | print( 39 | "joints shape:{}, min={}, max={}".format(joints.shape, joints.min(), joints.max()) 40 | ) 41 | 42 | # define model 43 | if params["model"] == "CNNRNN": 44 | model = CNNRNN(rec_dim=params["rec_dim"], joint_dim=8, feat_dim=params["feat_dim"]) 45 | elif params["model"] == "CNNRNNLN": 46 | model = CNNRNNLN( 47 | rec_dim=params["rec_dim"], joint_dim=8, feat_dim=params["feat_dim"] 48 | ) 49 | else: 50 | assert False, "Unknown model name {}".format(params["model"]) 51 | 52 | if params["compile"]: 53 | model = torch.compile(model) 54 | 55 | # load weight 56 | ckpt = torch.load(args.filename, map_location=torch.device("cpu")) 57 | model.load_state_dict(ckpt["model_state_dict"]) 58 | model.eval() 59 | 60 | # Inference 61 | states = [] 62 | state = None 63 | nloop = images.shape[1] 64 | for loop_ct in range(nloop): 65 | # load data and normalization 66 | img_t = images[:, loop_ct].transpose(0, 3, 1, 2) 67 | img_t = torch.Tensor(img_t) 68 | img_t = normalization(img_t, (0, 255), minmax) 69 | joint_t = torch.Tensor(joints[:, loop_ct]) 70 | joint_t = normalization(joint_t, joint_bounds, minmax) 71 | 72 | # predict rnn 73 | _, _, state = model(img_t, joint_t, state) 74 | states.append(state[0]) 75 | 76 | states = torch.permute(torch.stack(states), (1, 0, 2)) 77 | states = tensor2numpy(states) 78 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN. 79 | # N is the number of datasets 80 | # T is the sequence length 81 | # D is the dimension of the hidden state 82 | N, T, D = states.shape 83 | states = states.reshape(-1, D) 84 | 85 | # plot pca 86 | loop_ct = float(360) / T 87 | pca_dim = 3 88 | pca = PCA(n_components=pca_dim).fit(states) 89 | pca_val = pca.transform(states) 90 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to 91 | # visualize each state as a 3D scatter. 
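92 | # Each test sequence is drawn in its own color (c_list below assumes five
93 | # sequences); the black dot marks the initial time step of a trajectory.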
94 | pca_val = pca_val.reshape(N, T, pca_dim)
95 | 
96 | fig = plt.figure(dpi=60)
97 | ax = fig.add_subplot(projection="3d")
98 | 
99 | 
100 | def anim_update(i):
101 |     ax.cla()
102 |     angle = int(loop_ct * i)
103 |     ax.view_init(30, angle)
104 | 
105 |     c_list = ["C0", "C1", "C2", "C3", "C4"]
106 |     for n, color in enumerate(c_list):
107 |         ax.scatter(
108 |             pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
109 |         )
110 | 
111 |     ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)
112 |     pca_ratio = pca.explained_variance_ratio_ * 100
113 |     ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
114 |     ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
115 |     ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
116 | 
117 | 
118 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
119 | ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]))
120 | 
121 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
122 | # ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]), writer="imagemagick")
123 | # ani.save("./output/PCA_{}_{}.mp4".format(params["model"], params["tag"]), writer="ffmpeg")
124 | 
-------------------------------------------------------------------------------- /eipl/zoo/cnnrnn/bin/train.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 | 
8 | import os
9 | import sys
10 | import torch
11 | import argparse
12 | from tqdm import tqdm
13 | import torch.optim as optim
14 | from collections import OrderedDict
15 | from torch.utils.tensorboard import SummaryWriter
16 | from eipl.model import CNNRNN, CNNRNNLN
17 | from eipl.data import MultimodalDataset, SampleDownloader
18 | from eipl.utils import EarlyStopping, check_args, set_logdir
19 | 
20 | # load own library
21 | sys.path.append("./libs/")
22 | from fullBPTT import fullBPTTtrainer
23 | 
24 | 
25 | # argument parser
26 | parser = argparse.ArgumentParser(
27 |     description="Learning convolutional and recurrent neural network"
28 | )
29 | parser.add_argument("--model", type=str, default="CNNRNN")
30 | parser.add_argument("--epoch", type=int, default=10000)
31 | parser.add_argument("--batch_size", type=int, default=5)
32 | parser.add_argument("--rec_dim", type=int, default=50)
33 | parser.add_argument("--feat_dim", type=int, default=10)
34 | parser.add_argument("--img_loss", type=float, default=1.0)
35 | parser.add_argument("--joint_loss", type=float, default=1.0)
36 | parser.add_argument("--stdev", type=float, default=0.1)
37 | parser.add_argument("--lr", type=float, default=1e-3)
38 | parser.add_argument("--optimizer", type=str, default="adam")
39 | parser.add_argument("--log_dir", default="log/")
40 | parser.add_argument("--vmin", type=float, default=0.0)
41 | parser.add_argument("--vmax", type=float, default=1.0)
42 | parser.add_argument("--device", type=int, default=0)
43 | parser.add_argument("--compile", action="store_true")
44 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
45 | args = parser.parse_args()
46 | 
47 | # check args
48 | args = check_args(args)
49 | 
50 | # calculate the noise level (standard deviation) in the normalized range
51 | stdev = args.stdev * (args.vmax - args.vmin)
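52 | # (e.g. with the defaults stdev=0.1, vmin=0.0, vmax=1.0, the injected Gaussian noise has std 0.1)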
"cuda:{}".format(args.device) 56 | else: 57 | device = "cpu" 58 | 59 | # load dataset 60 | minmax = [args.vmin, args.vmax] 61 | grasp_data = SampleDownloader("airec", "grasp_bottle", img_format="CHW") 62 | images, joints = grasp_data.load_norm_data("train", vmin=args.vmin, vmax=args.vmax) 63 | train_dataset = MultimodalDataset(images, joints, device=device, stdev=stdev) 64 | train_loader = torch.utils.data.DataLoader( 65 | train_dataset, 66 | batch_size=args.batch_size, 67 | shuffle=True, 68 | drop_last=False, 69 | ) 70 | 71 | images, joints = grasp_data.load_norm_data("test", vmin=args.vmin, vmax=args.vmax) 72 | test_dataset = MultimodalDataset(images, joints, device=device, stdev=None) 73 | test_loader = torch.utils.data.DataLoader( 74 | test_dataset, 75 | batch_size=args.batch_size, 76 | shuffle=True, 77 | drop_last=False, 78 | ) 79 | 80 | # define model 81 | if args.model == "CNNRNN": 82 | model = CNNRNN(rec_dim=args.rec_dim, joint_dim=8, feat_dim=args.feat_dim) 83 | elif args.model == "CNNRNNLN": 84 | model = CNNRNNLN(rec_dim=args.rec_dim, joint_dim=8, feat_dim=args.feat_dim) 85 | else: 86 | assert False, "Unknown model name {}".format(args.model) 87 | 88 | # torch.compile makes PyTorch code run faster 89 | if args.compile: 90 | model = torch.compile(model) 91 | 92 | # set optimizer 93 | optimizer = optim.Adam(model.parameters(), eps=1e-07) 94 | 95 | # load trainer/tester class 96 | loss_weights = [args.img_loss, args.joint_loss] 97 | trainer = fullBPTTtrainer(model, optimizer, loss_weights=loss_weights, device=device) 98 | 99 | ### training main 100 | log_dir_path = set_logdir("./" + args.log_dir, args.tag) 101 | save_name = os.path.join(log_dir_path, "{}.pth".format(args.model)) 102 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30) 103 | early_stop = EarlyStopping(patience=1000) 104 | 105 | with tqdm(range(args.epoch)) as pbar_epoch: 106 | for epoch in pbar_epoch: 107 | # train and test 108 | train_loss = trainer.process_epoch(train_loader) 109 | with torch.no_grad(): 110 | test_loss = trainer.process_epoch(test_loader, training=False) 111 | writer.add_scalar("Loss/train_loss", train_loss, epoch) 112 | writer.add_scalar("Loss/test_loss", test_loss, epoch) 113 | 114 | # early stop 115 | save_ckpt, _ = early_stop(test_loss) 116 | 117 | if save_ckpt: 118 | trainer.save(epoch, [train_loss, test_loss], save_name) 119 | 120 | # print process bar 121 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss)) 122 | -------------------------------------------------------------------------------- /eipl/zoo/cnnrnn/libs/fullBPTT.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class fullBPTTtrainer: 13 | """ 14 | Helper class to train recurrent neural networks with numpy sequences 15 | 16 | Args: 17 | traindata (np.array): list of np.array. First diemension should be time steps 18 | model (torch.nn.Module): rnn model 19 | optimizer (torch.optim): optimizer 20 | input_param (float): input parameter of sequential generation. 1.0 means open mode. 
21 | """ 22 | 23 | def __init__(self, model, optimizer, loss_weights=[1.0, 1.0], device="cpu"): 24 | self.device = device 25 | self.optimizer = optimizer 26 | self.loss_weights = loss_weights 27 | self.model = model.to(self.device) 28 | 29 | def save(self, epoch, loss, savename): 30 | torch.save( 31 | { 32 | "epoch": epoch, 33 | "model_state_dict": self.model.state_dict(), 34 | #'optimizer_state_dict': self.optimizer.state_dict(), 35 | "train_loss": loss[0], 36 | "test_loss": loss[1], 37 | }, 38 | savename, 39 | ) 40 | 41 | def process_epoch(self, data, training=True): 42 | if not training: 43 | self.model.eval() 44 | else: 45 | self.model.train() 46 | 47 | total_loss = 0.0 48 | for n_batch, ((x_img, x_joint), (y_img, y_joint)) in enumerate(data): 49 | x_img = x_img.to(self.device) 50 | y_img = y_img.to(self.device) 51 | x_joint = x_joint.to(self.device) 52 | y_joint = y_joint.to(self.device) 53 | 54 | state = None 55 | yi_list, yv_list = [], [] 56 | T = x_img.shape[1] 57 | for t in range(T - 1): 58 | _yi_hat, _yv_hat, state = self.model(x_img[:, t], x_joint[:, t], state) 59 | yi_list.append(_yi_hat) 60 | yv_list.append(_yv_hat) 61 | 62 | yi_hat = torch.permute(torch.stack(yi_list), (1, 0, 2, 3, 4)) 63 | yv_hat = torch.permute(torch.stack(yv_list), (1, 0, 2)) 64 | loss = self.loss_weights[0] * nn.MSELoss()( 65 | yi_hat, y_img[:, 1:] 66 | ) + self.loss_weights[1] * nn.MSELoss()(yv_hat, y_joint[:, 1:]) 67 | total_loss += loss.item() 68 | 69 | if training: 70 | self.optimizer.zero_grad(set_to_none=True) 71 | loss.backward() 72 | self.optimizer.step() 73 | 74 | return total_loss / (n_batch + 1) 75 | -------------------------------------------------------------------------------- /eipl/zoo/cnnrnn/log/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/zoo/cnnrnn/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/zoo/rnn/bin/test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 | 
8 | import os
9 | import torch
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pylab as plt
13 | import matplotlib.animation as anim
14 | 
15 | # own libraries
16 | from eipl.model import BasicLSTM, BasicMTRNN
17 | from eipl.utils import normalization
18 | from eipl.utils import restore_args, tensor2numpy
19 | from eipl.data import WeightDownloader
20 | 
21 | 
22 | # argument parser
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--filename", type=str, default=None)
25 | parser.add_argument("--idx", type=str, default="0")
26 | parser.add_argument("--pretrained", action="store_true")
27 | args = parser.parse_args()
28 | 
29 | # check args
30 | assert args.filename or args.pretrained, "Please set filename or pretrained"
31 | 
32 | # load pretrained weight
33 | if args.pretrained:
34 |     WeightDownloader("airec", "grasp_bottle")
35 |     args.filename = os.path.join(
36 |         os.path.expanduser("~"), ".eipl/airec/grasp_bottle/weights/RNN/model.pth"
37 |     )
38 | 
39 | # restore parameters
40 | dir_name = os.path.split(args.filename)[0]
41 | params = restore_args(os.path.join(dir_name, "args.json"))
42 | idx = int(args.idx)
43 | 
44 | # load dataset (features and feat_bounds are produced by eipl/zoo/cae/bin/extract.py)
45 | minmax = [params["vmin"], params["vmax"]]
46 | feat_bounds = np.load("../cae/data/feat_bounds.npy")
47 | _feats = np.load("../cae/data/test/features.npy")
48 | test_feats = normalization(_feats, feat_bounds, minmax)
49 | test_joints = np.load("../cae/data/test/joints.npy")
50 | x_data = np.concatenate((test_feats, test_joints), axis=-1)
51 | x_data = torch.Tensor(x_data)
52 | in_dim = x_data.shape[-1]
53 | 
54 | # define model
55 | if params["model"] == "LSTM":
56 |     model = BasicLSTM(in_dim=in_dim, rec_dim=params["rec_dim"], out_dim=in_dim)
57 | elif params["model"] == "MTRNN":
58 |     model = BasicMTRNN(in_dim, fast_dim=60, slow_dim=5, fast_tau=2, slow_tau=12)
59 | else:
60 |     assert False, "Unknown model name {}".format(params["model"])
61 | 
62 | if params["compile"]:
63 |     model = torch.compile(model)
64 | 
65 | # load weight
66 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
67 | model.load_state_dict(ckpt["model_state_dict"])
68 | model.eval()
69 | 
70 | # Inference
71 | y_hat = []
72 | state = None
73 | T = x_data.shape[1]
74 | for i in range(T):
75 |     _y, state = model(x_data[:, i], state)
76 |     y_hat.append(_y)
77 | 
78 | y_hat = torch.permute(torch.stack(y_hat), (1, 0, 2))
79 | y_hat = tensor2numpy(y_hat)
80 | y_joints = y_hat[:, :, 10:]  # joint-angle part (the dims after the 10-dim image feature)
81 | y_feats = y_hat[:, :, :10]  # image-feature part (the first feat_dim=10 dims)
82 | 
83 | # plot animation
84 | fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=60)
85 | 
86 | 
87 | def anim_update(i):
88 |     for j in range(2):
89 |         ax[j].cla()
90 | 
91 |     ax[0].set_ylim(-0.1, 1.1)
92 |     ax[0].set_xlim(0, T)
93 |     ax[0].plot(test_joints[idx, 1:], linestyle="dashed", c="k")
94 |     for joint_idx in range(8):
95 |         ax[0].plot(np.arange(i + 1), y_joints[idx, : i + 1, joint_idx])
96 |     ax[0].set_xlabel("Step")
97 |     ax[0].set_title("Joint angles")
98 | 
99 |     ax[1].set_ylim(-0.1, 1.1)
100 |     ax[1].set_xlim(0, T)
101 |     ax[1].plot(test_feats[idx, 1:], linestyle="dashed", c="k")
102 |     for feat_idx in range(10):
103 |         ax[1].plot(np.arange(i + 1), y_feats[idx, : i + 1, feat_idx])
104 |     ax[1].set_xlabel("Step")
105 |     ax[1].set_title("Image features")
106 | 
107 | 
108 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
109 | ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx))
110 | 
111 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
112 | # ani.save("./output/{}_{}_{}.gif".format(params["model"], params["tag"], idx), writer="imagemagick")
113 | # ani.save("./output/{}_{}_{}.mp4".format(params["model"], params["tag"], idx), writer="ffmpeg")
114 | 
-------------------------------------------------------------------------------- /eipl/zoo/rnn/bin/test_pca_rnn.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 | 
8 | import os
9 | import torch
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pylab as plt
13 | import matplotlib.animation as anim
14 | from sklearn.decomposition import PCA
15 | 
16 | # own libraries
17 | from eipl.model import BasicLSTM, BasicMTRNN
18 | from eipl.utils import normalization
19 | from eipl.utils import restore_args, tensor2numpy
20 | 
21 | 
22 | # argument parser
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--filename", type=str, default=None)
25 | args = parser.parse_args()
26 | 
27 | # restore parameters
28 | dir_name = os.path.split(args.filename)[0]
29 | params = restore_args(os.path.join(dir_name, "args.json"))
30 | 
31 | # load dataset
32 | minmax = [params["vmin"], params["vmax"]]
33 | feat_bounds = np.load("../cae/data/feat_bounds.npy")
34 | _feats = np.load("../cae/data/test/features.npy")
35 | test_feats = normalization(_feats, feat_bounds, minmax)
36 | test_joints = np.load("../cae/data/test/joints.npy")
37 | x_data = np.concatenate((test_feats, test_joints), axis=-1)
38 | x_data = torch.Tensor(x_data)
39 | in_dim = x_data.shape[-1]
40 | 
41 | # define model
42 | if params["model"] == "LSTM":
43 |     model = BasicLSTM(in_dim=in_dim, rec_dim=params["rec_dim"], out_dim=in_dim)
44 | elif params["model"] == "MTRNN":
45 |     model = BasicMTRNN(in_dim, fast_dim=60, slow_dim=5, fast_tau=2, slow_tau=12)
46 | else:
47 |     assert False, "Unknown model name {}".format(params["model"])
48 | 
49 | if params["compile"]:
50 |     model = torch.compile(model)
51 | 
52 | # load weight
53 | ckpt = torch.load(args.filename, map_location=torch.device("cpu"))
54 | model.load_state_dict(ckpt["model_state_dict"])
55 | model.eval()
56 | 
57 | # Inference
58 | states = []
59 | state = None
60 | nloop = x_data.shape[1]
61 | for i in range(nloop):
62 |     _, state = model(x_data[:, i], state)
63 |     # lstm returns hidden state and cell state.
64 |     # Here we store the hidden state to analyze the internal representation of the RNN.
65 |     states.append(state[0])
66 | 
67 | states = torch.permute(torch.stack(states), (1, 0, 2))
68 | states = tensor2numpy(states)
69 | # Reshape the state from [N,T,D] to [-1,D] for PCA of RNN.
70 | # N is the number of datasets
71 | # T is the sequence length
72 | # D is the dimension of the hidden state
73 | N, T, D = states.shape
74 | states = states.reshape(-1, D)
75 | 
76 | # plot pca
77 | loop_ct = float(360) / T
78 | pca_dim = 3
79 | pca = PCA(n_components=pca_dim).fit(states)
80 | pca_val = pca.transform(states)
81 | # Reshape the states from [-1, pca_dim] to [N,T,pca_dim] to
82 | # visualize each state as a 3D scatter.
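83 | # Each test sequence is drawn in its own color (c_list below assumes five
84 | # sequences); the black dot marks the initial time step of a trajectory.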
85 | pca_val = pca_val.reshape(N, T, pca_dim)
86 | 
87 | fig = plt.figure(dpi=60)
88 | ax = fig.add_subplot(projection="3d")
89 | 
90 | 
91 | def anim_update(i):
92 |     ax.cla()
93 |     angle = int(loop_ct * i)
94 |     ax.view_init(30, angle)
95 | 
96 |     c_list = ["C0", "C1", "C2", "C3", "C4"]
97 |     for n, color in enumerate(c_list):
98 |         ax.scatter(
99 |             pca_val[n, 1:, 0], pca_val[n, 1:, 1], pca_val[n, 1:, 2], color=color, s=3.0
100 |         )
101 | 
102 |     ax.scatter(pca_val[n, 0, 0], pca_val[n, 0, 1], pca_val[n, 0, 2], color="k", s=30.0)
103 |     pca_ratio = pca.explained_variance_ratio_ * 100
104 |     ax.set_xlabel("PC1 ({:.1f}%)".format(pca_ratio[0]))
105 |     ax.set_ylabel("PC2 ({:.1f}%)".format(pca_ratio[1]))
106 |     ax.set_zlabel("PC3 ({:.1f}%)".format(pca_ratio[2]))
107 | 
108 | 
109 | ani = anim.FuncAnimation(fig, anim_update, interval=int(np.ceil(T / 10)), frames=T)
110 | ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]))
111 | 
112 | # If an error occurs in generating the gif animation or mp4, change the writer (imagemagick/ffmpeg).
113 | # ani.save("./output/PCA_{}_{}.gif".format(params["model"], params["tag"]), writer="imagemagick")
114 | # ani.save("./output/PCA_{}_{}.mp4".format(params["model"], params["tag"]), writer="ffmpeg")
115 | 
-------------------------------------------------------------------------------- /eipl/zoo/rnn/bin/train.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 | 
8 | import os
9 | import sys
10 | import torch
11 | import argparse
12 | import numpy as np
13 | from tqdm import tqdm
14 | from collections import OrderedDict
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from eipl.model import BasicLSTM, BasicMTRNN
18 | from eipl.utils import normalization
19 | from eipl.utils import EarlyStopping, check_args, set_logdir
20 | 
21 | # load own library
22 | sys.path.append("./libs/")
23 | from fullBPTT import fullBPTTtrainer
24 | from dataloader import TimeSeriesDataSet
25 | 
26 | 
27 | # argument parser
28 | parser = argparse.ArgumentParser(description="Learning recurrent neural network")
29 | parser.add_argument("--model", type=str, default="LSTM")
30 | parser.add_argument("--epoch", type=int, default=20000)
31 | parser.add_argument("--batch_size", type=int, default=5)
32 | parser.add_argument("--rec_dim", type=int, default=50)
33 | parser.add_argument("--stdev", type=float, default=0.1)
34 | parser.add_argument("--lr", type=float, default=1e-3)
35 | parser.add_argument("--optimizer", type=str, default="adam")
36 | parser.add_argument("--log_dir", default="log/")
37 | parser.add_argument("--vmin", type=float, default=0.0)
38 | parser.add_argument("--vmax", type=float, default=1.0)
39 | parser.add_argument("--device", type=int, default=-1)
40 | parser.add_argument("--compile", action="store_true")
41 | parser.add_argument("--tag", help="Tag name for snap/log sub directory")
42 | args = parser.parse_args()
43 | 
44 | # check args
45 | args = check_args(args)
46 | 
47 | # calculate the noise level (standard deviation) in the normalized range
48 | stdev = args.stdev * (args.vmax - args.vmin)
49 | 
50 | # set device id
51 | if args.device >= 0:
52 |     device = "cuda:{}".format(args.device)
53 | else:
54 |     device = "cpu"
55 | 
56 | # load dataset (features and feat_bounds are produced by eipl/zoo/cae/bin/extract.py)
57 | minmax = [args.vmin, args.vmax]
58 | feat_bounds = 
np.load("../cae/data/feat_bounds.npy") 59 | _feats = np.load("../cae/data/train/features.npy") 60 | train_feats = normalization(_feats, feat_bounds, minmax) 61 | train_joints = np.load("../cae/data/train/joints.npy") 62 | in_dim = train_feats.shape[-1] + train_joints.shape[-1] 63 | train_dataset = TimeSeriesDataSet( 64 | train_feats, train_joints, minmax=[args.vmin, args.vmax], stdev=stdev 65 | ) 66 | train_loader = torch.utils.data.DataLoader( 67 | train_dataset, 68 | batch_size=args.batch_size, 69 | shuffle=True, 70 | drop_last=False, 71 | pin_memory=True, 72 | ) 73 | 74 | _feats = np.load("../cae/data/test/features.npy") 75 | test_feats = normalization(_feats, feat_bounds, minmax) 76 | test_joints = np.load("../cae/data/test/joints.npy") 77 | test_dataset = TimeSeriesDataSet( 78 | test_feats, test_joints, minmax=[args.vmin, args.vmax], stdev=None 79 | ) 80 | test_loader = torch.utils.data.DataLoader( 81 | test_dataset, 82 | batch_size=args.batch_size, 83 | shuffle=True, 84 | drop_last=False, 85 | pin_memory=True, 86 | ) 87 | 88 | # define model 89 | if args.model == "LSTM": 90 | model = BasicLSTM(in_dim=in_dim, rec_dim=args.rec_dim, out_dim=in_dim) 91 | elif args.model == "MTRNN": 92 | model = BasicMTRNN(in_dim, fast_dim=60, slow_dim=5, fast_tau=2, slow_tau=12) 93 | else: 94 | assert False, "Unknown model name {}".format(args.model) 95 | 96 | # torch.compile makes PyTorch code run faster 97 | if args.compile: 98 | torch.set_float32_matmul_precision("high") 99 | model = torch.compile(model) 100 | 101 | # set optimizer 102 | if args.optimizer.casefold() == "adam": 103 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 104 | elif args.optimizer.casefold() == "radam": 105 | optimizer = optim.RAdam(model.parameters(), lr=args.lr) 106 | else: 107 | assert False, "Unknown optimizer name {}. please set Adam or RAdam.".format( 108 | args.optimizer 109 | ) 110 | 111 | # load trainer/tester class 112 | trainer = fullBPTTtrainer(model, optimizer, device=device) 113 | 114 | ### training main 115 | log_dir_path = set_logdir("./" + args.log_dir, args.tag) 116 | save_name = os.path.join(log_dir_path, "{}.pth".format(args.model)) 117 | writer = SummaryWriter(log_dir=log_dir_path, flush_secs=30) 118 | early_stop = EarlyStopping(patience=1000) 119 | 120 | with tqdm(range(args.epoch)) as pbar_epoch: 121 | for epoch in pbar_epoch: 122 | # train and test 123 | train_loss = trainer.process_epoch(train_loader) 124 | with torch.no_grad(): 125 | test_loss = trainer.process_epoch(test_loader, training=False) 126 | writer.add_scalar("Loss/train_loss", train_loss, epoch) 127 | writer.add_scalar("Loss/test_loss", test_loss, epoch) 128 | 129 | # early stop 130 | save_ckpt, _ = early_stop(test_loss) 131 | 132 | if save_ckpt: 133 | trainer.save(epoch, [train_loss, test_loss], save_name) 134 | 135 | # print process bar 136 | pbar_epoch.set_postfix(OrderedDict(train_loss=train_loss, test_loss=test_loss)) 137 | -------------------------------------------------------------------------------- /eipl/zoo/rnn/libs/dataloader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University 3 | # 4 | # Released under the AGPL license. 5 | # see https://www.gnu.org/licenses/agpl-3.0.txt 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class TimeSeriesDataSet(Dataset): 14 | """AIREC_sample dataset. 15 | 16 | Args: 17 | feats (np.array): Set the image features. 
18 |         joints (np.array): Set the joint angles.
19 |         minmax (list, optional): Set the normalization range, default is [0.1, 0.9].
20 |         stdev (float, optional): Standard deviation of the Gaussian noise added to the inputs; None disables the noise.
21 |     """
22 | 
23 |     def __init__(self, feats, joints, minmax=[0.1, 0.9], stdev=0.0):
24 |         self.stdev = stdev
25 |         self.feats = torch.from_numpy(feats).float()
26 |         self.joints = torch.from_numpy(joints).float()
27 | 
28 |     def __len__(self):
29 |         return len(self.feats)
30 | 
31 |     def __getitem__(self, idx):
32 |         # target data (the features and joints are already normalized)
33 |         y_feat = self.feats[idx]
34 |         y_joint = self.joints[idx]
35 |         y_data = torch.concat((y_feat, y_joint), axis=-1)
36 | 
37 |         # apply gaussian noise to joint angles and image features
38 |         if self.stdev is not None:
39 |             x_feat = self.feats[idx] + torch.normal(
40 |                 mean=0, std=self.stdev, size=y_feat.shape
41 |             )
42 |             x_joint = self.joints[idx] + torch.normal(
43 |                 mean=0, std=self.stdev, size=y_joint.shape
44 |             )
45 |         else:
46 |             x_feat = self.feats[idx]
47 |             x_joint = self.joints[idx]
48 | 
49 |         x_data = torch.concat((x_feat, x_joint), axis=-1)
50 | 
51 |         return [x_data, y_data]
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     import time
56 | 
57 |     # random dataset
58 |     feats = np.random.randn(10, 120, 10)
59 |     joints = np.random.randn(10, 120, 8)
60 | 
61 |     # load data
62 |     data_loader = TimeSeriesDataSet(feats, joints, minmax=[0.1, 0.9])
63 |     x_data, y_data = data_loader[1]
64 |     print(x_data.shape, y_data.shape)
65 | 
66 |     train_loader = torch.utils.data.DataLoader(
67 |         data_loader,
68 |         batch_size=3,
69 |         shuffle=True,
70 |         drop_last=False,
71 |         pin_memory=True,
72 |     )
73 | 
74 |     print("[Start] load data using torch data loader")
75 |     start_time = time.time()
76 |     for batch in train_loader:
77 |         print(batch[0].shape, batch[1].shape)
78 | 
79 |     print("[Finish] time: ", time.time() - start_time)
80 | 
-------------------------------------------------------------------------------- /eipl/zoo/rnn/libs/fullBPTT.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Since 2023 Ogata Laboratory, Waseda University
3 | #
4 | # Released under the AGPL license.
5 | # see https://www.gnu.org/licenses/agpl-3.0.txt
6 | #
7 | 
8 | import torch
9 | import torch.nn as nn
10 | 
11 | 
12 | class fullBPTTtrainer:
13 |     """
14 |     Helper class to train recurrent neural networks with full backpropagation through time (BPTT)
15 | 
16 |     Args:
17 |         model (torch.nn.Module): rnn model
18 |         optimizer (torch.optim.Optimizer): optimizer
19 |         device (str): computation device, e.g. "cpu" or "cuda:0"
20 | 
21 | """ 22 | 23 | def __init__(self, model, optimizer, device="cpu"): 24 | self.device = device 25 | self.optimizer = optimizer 26 | self.model = model.to(self.device) 27 | 28 | def save(self, epoch, loss, savename): 29 | torch.save( 30 | { 31 | "epoch": epoch, 32 | "model_state_dict": self.model.state_dict(), 33 | #'optimizer_state_dict': self.optimizer.state_dict(), 34 | "train_loss": loss[0], 35 | "test_loss": loss[1], 36 | }, 37 | savename, 38 | ) 39 | 40 | def process_epoch(self, data, training=True): 41 | if not training: 42 | self.model.eval() 43 | else: 44 | self.model.train() 45 | 46 | total_loss = 0.0 47 | for n_batch, (x, y) in enumerate(data): 48 | x = x.to(self.device) 49 | y = y.to(self.device) 50 | 51 | state = None 52 | y_list = [] 53 | self.optimizer.zero_grad(set_to_none=True) 54 | for t in range(x.shape[1] - 1): 55 | y_hat, state = self.model(x[:, t], state) 56 | y_list.append(y_hat) 57 | 58 | y_hat = torch.permute(torch.stack(y_list), (1, 0, 2)) 59 | loss = nn.MSELoss()(y_hat, y[:, 1:]) 60 | total_loss += loss.item() 61 | 62 | if training: 63 | loss.backward() 64 | self.optimizer.step() 65 | 66 | return total_loss / (n_batch + 1) 67 | -------------------------------------------------------------------------------- /eipl/zoo/rnn/log/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /eipl/zoo/rnn/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ###### Requirements without Version Specifiers ###### 2 | scipy 3 | gdown 4 | matplotlib 5 | ipdb 6 | opencv-python 7 | tensorboard 8 | torchinfo 9 | tqdm 10 | scikit-learn 11 | torch 12 | torchvision 13 | onnx 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | DESCRIPTION = "EIPL: Embodied Intelligence with Deep Predictive Learning" 4 | NAME = "eipl" 5 | AUTHOR = "Hiroshi Ito" 6 | AUTHOR_EMAIL = "it.hiroshi.o@gmail.com" 7 | URL = "https://github.com/ogata-lab/eipl" 8 | LICENSE = "MIT License" 9 | DOWNLOAD_URL = "https://github.com/ogata-lab/eipl" 10 | VERSION = "1.1.1" 11 | PYTHON_REQUIRES = ">=3.8" 12 | 13 | """ 14 | INSTALL_REQUIRES = [ 15 | 'pytorch>=1.9.0', 16 | 'matplotlib>=3.3.4', 17 | 'numpy >=1.20.3', 18 | 'matplotlib>=3.3.4', 19 | 'scipy>=1.6.3', 20 | 'scikit-learn>=0.24.2', 21 | ] 22 | """ 23 | 24 | """ 25 | with open('README.rst', 'r') as fp: 26 | readme = fp.read() 27 | with open('CONTACT.txt', 'r') as fp: 28 | contacts = fp.read() 29 | long_description = readme + '\n\n' + contacts 30 | """ 31 | 32 | 33 | setup( 34 | name=NAME, 35 | author=AUTHOR, 36 | author_email=AUTHOR_EMAIL, 37 | maintainer=AUTHOR, 38 | maintainer_email=AUTHOR_EMAIL, 39 | description=DESCRIPTION, 40 | # long_description=long_description, 41 | license=LICENSE, 42 | url=URL, 43 | version=VERSION, 44 | download_url=DOWNLOAD_URL, 45 | # python_requires=PYTHON_REQUIRES, 46 | # install_requires=INSTALL_REQUIRES, 47 | packages=find_packages(), 48 | ) 49 | --------------------------------------------------------------------------------