├── .gitignore ├── COPYING ├── README.md ├── fcn_approx_utils.py ├── function_optimization ├── Function_2D.ipynb ├── Gmm_nD.ipynb ├── Himmelblaue_4D.ipynb ├── Rosenbrock_4D.ipynb ├── Rosenbrock_nD.ipynb ├── fcn_plotting_utils.py ├── sine_nD.ipynb └── test_fcns.py ├── manipulator ├── data │ ├── sdf.npy │ ├── sdf_close_ground.npy │ ├── sphere_setting.npy │ └── urdf │ │ ├── env.urdf │ │ ├── frankaemika_new │ │ ├── meshes │ │ │ ├── collision │ │ │ │ ├── finger.stl │ │ │ │ ├── hand.stl │ │ │ │ ├── link0.stl │ │ │ │ ├── link1.stl │ │ │ │ ├── link2.stl │ │ │ │ ├── link3.stl │ │ │ │ ├── link4.stl │ │ │ │ ├── link5.stl │ │ │ │ ├── link6.stl │ │ │ │ └── link7.stl │ │ │ └── visual │ │ │ │ ├── finger.dae │ │ │ │ ├── hand.dae │ │ │ │ ├── link0.dae │ │ │ │ ├── link1.dae │ │ │ │ ├── link2.dae │ │ │ │ ├── link3.dae │ │ │ │ ├── link4.dae │ │ │ │ ├── link5.dae │ │ │ │ ├── link6.dae │ │ │ │ └── link7.dae │ │ ├── panda_arm.urdf │ │ └── panda_arm_franka.urdf │ │ ├── mesh.stl │ │ ├── meshes │ │ ├── collision │ │ │ ├── finger.stl │ │ │ ├── hand.stl │ │ │ ├── link0.stl │ │ │ ├── link1.stl │ │ │ ├── link2.stl │ │ │ ├── link3.stl │ │ │ ├── link4.stl │ │ │ ├── link5.stl │ │ │ ├── link6.stl │ │ │ └── link7.stl │ │ └── visual │ │ │ ├── finger.dae │ │ │ ├── hand.dae │ │ │ ├── link0.dae │ │ │ ├── link1.dae │ │ │ ├── link2.dae │ │ │ ├── link3.dae │ │ │ ├── link4.dae │ │ │ ├── link5.dae │ │ │ ├── link6.dae │ │ │ └── link7.dae │ │ ├── panda_arm.urdf │ │ ├── quadrotor.urdf │ │ ├── quadrotor_base.obj │ │ └── ur10 │ │ ├── meshes │ │ └── ur10 │ │ │ ├── collision │ │ │ ├── Base.stl │ │ │ ├── Forearm.stl │ │ │ ├── Shoulder.stl │ │ │ ├── UpperArm.stl │ │ │ ├── Wrist1.stl │ │ │ ├── Wrist2.stl │ │ │ └── Wrist3.stl │ │ │ └── visual │ │ │ ├── Base.dae │ │ │ ├── Forearm.dae │ │ │ ├── Shoulder.dae │ │ │ ├── UpperArm.dae │ │ │ ├── Wrist1.dae │ │ │ ├── Wrist2.dae │ │ │ └── Wrist3.dae │ │ └── ur10.urdf ├── manipulator_utils.py ├── panda_cost_utils.py ├── panda_ik.ipynb ├── panda_ik.py ├── panda_ik_visualize.ipynb 
├── panda_kinematics.py ├── panda_reaching.py ├── panda_reaching_noTask.ipynb ├── panda_reaching_visualize.ipynb ├── panda_visualization_utils.py ├── ur10_ik.ipynb ├── ur10_ik.py ├── ur10_ik_visualize.ipynb └── ur10_kinematics.py ├── toy_robots ├── cost_utils.py ├── planar_manipulator.py ├── planar_manipulator_ik.ipynb ├── planar_manipulator_ik.py ├── planar_manipulator_reaching-Old.ipynb ├── planar_manipulator_reaching.py ├── planar_manipulator_reaching_draft.ipynb ├── planar_manipulator_reaching_noTask.ipynb └── plot_utils.py ├── tt_utils.py ├── tt_utils_ol.py ├── tt_vs_gmm_batch.ipynb ├── tt_vs_nn.ipynb ├── tt_vs_nn_batch.ipynb ├── tt_vs_nn_vs_bgmm.ipynb ├── ttgo.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *checkpoint.ipynb 2 | *.ipynb_checkpoints 3 | *.pyc 4 | *.sql3 5 | **/logs/ 6 | 7 | 8 | /manipulator/logs/** 9 | *.pyc 10 | *.pickle 11 | *.png 12 | *.jpeg 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TTGO: Tensor Train for Global Optimization Problems in Robotics 2 | 3 | A PyTorch implementation of the TTGO algorithm and the applications presented in the paper "Tensor Train for Global Optimization Problems in Robotics" 4 | 5 | Website: https://sites.google.com/view/ttgo/home 6 | 7 | Paper: https://arxiv.org/pdf/2206.05077.pdf 8 | 9 | ### Prerequisites 10 | - Install the tntorch library from: https://github.com/rballester/tntorch (pip install tntorch) 11 | - Pybullet (only required for visualization of robotics applications): https://pypi.org/project/pybullet/ 12 | - RoMa (only required for robotics applications; used for quaternion calculations): https://naver.github.io/roma/ 13 | 14 | ### Overview 15 | - *./ttgo.py*: the TTGO algorithm is implemented in this file (class `TTGO`) 16 | - *./function_optimization/*: includes the application of ttgo for optimization of several benchmark
functions 17 | - Recommendation: try these notebooks first to understand the approach 18 | - *./toy_robots/*: application of ttgo for simple toy models of robotics problems (planar manipulator IK and reaching tasks) 19 | - *./manipulator/*: application of ttgo for IK and reaching tasks with some standard manipulators 20 | 21 | Note: All the implementations are fully compatible for use with GPU. For faster computation, it is highly recommended to use GPU 22 | 23 | For any questions, contact the author Suhan Shetty 24 | -------------------------------------------------------------------------------- /fcn_approx_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import time 6 | from sklearn.mixture import BayesianGaussianMixture 7 | import numpy as np 8 | 9 | class GMM: 10 | """ 11 | Mixture of spherical Gaussians (un-normalized) 12 | nmix: number of mixture coefficients 13 | n: dimension of the domain 14 | s: variance 15 | mu: the centers assumed to be in : [-L,L]^n 16 | """ 17 | def __init__(self, n=2, nmix=3, L=1, mx_coef=None, mu=None, s=0.2, device='cpu'): 18 | self.device = device 19 | self.n = n # dim 20 | self.nmix = nmix # number of components 21 | self.L = L # boundary 22 | 23 | self.s = s # assuming spherical Gaussians 24 | self.std = (s*torch.ones(self.n)).view(1,self.n).expand(self.nmix,-1).to(device) 25 | 26 | if mu is None: 27 | self.generate_gmm_params() 28 | else: 29 | self.mx_coef = mx_coef.to(self.device) 30 | self.mu = mu.to(self.device) 31 | self.mv_normals = [torch.distributions.MultivariateNormal(self.mu[k], (self.s**2)*torch.eye(self.n).to(device)) for k in range(self.nmix)] 32 | 33 | def generate_gmm_params(self): 34 | self.mx_coef = torch.rand(self.nmix).to(self.device) 35 | self.mx_coef = self.mx_coef/torch.sum(self.mx_coef) 36 | self.mu = (torch.rand(self.nmix,self.n).to(self.device)-0.5)*2*self.L 37 | 38 | def 
pdf(self, x): 39 | prob = torch.tensor([0]).to(self.device) 40 | for k in range(self.nmix): 41 | pdf_k = torch.exp(self.mv_normals[k].log_prob(x)) 42 | prob = prob + self.mx_coef[k]*pdf_k #*normalize_k*torch.exp(-0.5*l.view(-1)/self.s) 43 | return prob 44 | 45 | def log_pdf(self,x): 46 | return torch.log(1e-6+self.pdf(x)) 47 | 48 | def generate_sample(self, n_samples): 49 | X_noise = torch.randn(n_samples, self.n).to(self.device) 50 | idx_comp = torch.multinomial(self.mx_coef.view(-1), n_samples, replacement=True).view(-1) 51 | X = self.mu[idx_comp] + self.s*X_noise 52 | return X 53 | 54 | 55 | class BGMM: 56 | def __init__(self, nmix): 57 | self.nmix = nmix 58 | self.model = BayesianGaussianMixture(n_components=nmix, covariance_type='full') 59 | 60 | def load_data(self, data): 61 | # data should be a numpy array: n_samples x dim 62 | self.data = data 63 | 64 | def fit(self): 65 | self.model.fit(self.data) 66 | 67 | def pdf(self, x): 68 | x = torch.from_numpy(x) 69 | prob = torch.tensor([0]) 70 | for k in range(self.nmix): 71 | mu = torch.from_numpy(self.model.means_[k]).view(-1) 72 | cov = torch.from_numpy(self.model.covariances_[k]) 73 | mv_normal = torch.distributions.MultivariateNormal(mu, cov) 74 | pdf_k = torch.exp(mv_normal.log_prob(x)) 75 | prob = prob + self.model.weights_[k]*pdf_k 76 | return prob 77 | 78 | # def pdf(self, X): 79 | # return np.exp(self.model.score_samples(X)) 80 | 81 | def log_pdf(self, x): 82 | return np.log(1e-6+self.pdf(x)) 83 | 84 | class NNModel(nn.Module): 85 | def __init__(self, dim, width=64): 86 | super(NNModel, self).__init__() 87 | self.flatten = nn.Flatten() 88 | self.linear_relu_stack= nn.Sequential( 89 | nn.Linear(dim, width), 90 | nn.ReLU(), 91 | nn.Linear(width, width), 92 | nn.ReLU(), 93 | nn.Linear(width, 1), 94 | ) 95 | 96 | def forward(self, x): 97 | x = self.flatten(x) 98 | y = self.linear_relu_stack(x) 99 | return y 100 | 101 | 102 | class NeuralNetwork(nn.Module): 103 | def __init__(self, dim, width=64, lr=1e-3, 
device='cpu'): 104 | super(NeuralNetwork, self).__init__() 105 | self.device = device 106 | self.flatten = nn.Flatten() 107 | self.model = NNModel(dim, width).to(device) 108 | self.loss_fcn = nn.MSELoss() 109 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 110 | 111 | def load_data(self, train_data, test_data): 112 | self.train_data = train_data.to(self.device) 113 | self.test_data = test_data.to(self.device) 114 | 115 | def train(self, num_epochs=10, batch_size=128, verbose=False): 116 | size = self.train_data.shape[0] 117 | for k in range(num_epochs): 118 | counter = 0 119 | loss_train = 0. 120 | counter_batch = 0 121 | for i in range(int(size/batch_size)-1): 122 | # Compute prediction and loss 123 | next_counter = (counter+batch_size) 124 | x_data = self.train_data[counter:next_counter,:-1] 125 | y_data = self.train_data[counter:next_counter,-1].view(-1,1) 126 | y_pred = self.model(x_data) 127 | loss = self.loss_fcn(y_pred, y_data) 128 | # Backpropagation 129 | self.optimizer.zero_grad() 130 | loss.backward() 131 | self.optimizer.step() 132 | counter = 1*next_counter 133 | loss_train += loss.item() 134 | counter_batch += 1 135 | loss_train = loss_train/counter_batch 136 | loss_test = self.test() 137 | if verbose: 138 | print(f"epoch:{k}, loss-train:{loss_train}, loss-test:{loss_test}") 139 | 140 | def test(self): 141 | self.model.eval() 142 | x_data = self.test_data[:,:-1] 143 | y_data = self.test_data[:,-1].view(-1,1) 144 | with torch.no_grad(): 145 | pred = self.model(x_data) 146 | test_loss = self.loss_fcn(pred, y_data).item() 147 | return test_loss 148 | 149 | 150 | 151 | # def FitNN(x_train, y_train, x_test, y_test, 152 | # learning_rate = 1e-3,batch_size = 128, epochs=10, device="cpu"): 153 | # dim = x_train.shape[-1] 154 | # data_train = torch.cat((x_train.view(-1,dim),y_train.view(-1,1)),dim=-1) 155 | # data_test = torch.cat((x_test.view(-1,dim),y_test.view(-1,1)),dim=-1) 156 | # model = 
NeuralNetwork(dim=x_train.shape[-1],width=128).to(device)#width=dim*nmix*10 157 | # loss_fn = nn.MSELoss() 158 | # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 159 | # t1 = time.time() 160 | # for t in range(epochs): 161 | # # print(f"Epoch {t+1}\n-------------------------------") 162 | # train_loop(data_train, model, loss_fn, optimizer, batch_size) 163 | # model.eval() 164 | # test_loop(data_test, model, loss_fn) 165 | # t2 = time.time() 166 | # y_nn_0 = model(x_test) 167 | # mse_nn_0 = (((y_nn_0.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean() 168 | # return (mse_nn_0, (t2-t1)) -------------------------------------------------------------------------------- /function_optimization/Himmelblaue_4D.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Optimization of 2-D Himmelblaue function for varied coefficients\n", 8 | "##### Reference: https://en.wikipedia.org/wiki/Himmelblau%27s_function\n", 9 | "$$ cost(a,b,x,y) = (x^2+y-a)^2 + (x+y^2-b)^2 ,$$\n", 10 | "$$pdf(a,b,x,y) = e^{-cost(a,b,x,y)}$$ \n", 11 | "\n", 12 | "Here, $\\mathbf{x}_{task}=(a,b)$ and $\\mathbf{x}_{decision} = (x,y)$\n", 13 | "\n", 14 | "Depending on the choice of task-parameters $(a,b)$ there could be several global optima.\n", 15 | "\n", 16 | "We show that TTGO is able to find the multiple global optima consistently with a hand few of samples from the constructed tt-model of the above pdf (constructed offline) for various selection of $\\mathbf{x}_{task}=(a,b)$ in the online phase. We use scipy's SLSQP to fine tune the initialization. \n", 17 | "\n", 18 | "Condition on different values of $\\mathbf{x}_{task}=(a,b)$ to test the model. 
Watch out for the multimodality in the solutions of TTGO!\n", 19 | "\n", 20 | "Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/\n", 21 | " Written by Suhan Shetty ,\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "import numpy as np\n", 32 | "import sys\n", 33 | "sys.path.append('./fcn_opt')\n", 34 | "sys.path.append('../')\n", 35 | "\n", 36 | "from ttgo import TTGO\n", 37 | "import tt_utils\n", 38 | "from test_fcns import Himmelblaue_4D \n", 39 | "from fcn_plotting_utils import plot_surf, plot_contour\n", 40 | "\n", 41 | "%load_ext autoreload\n", 42 | "np.set_printoptions(precision=3)\n", 43 | "%autoreload 2" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 53 | "device" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "### Define the function" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "pdf, cost = Himmelblaue_4D(alpha=0.25)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "### Define the domain and the discretization" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# Define the domain of the function\n", 86 | "L = 5 # [-L,L]^2 is the domain of the function\n", 87 | "# domain of task params: domain of coefficients a and b in Himmelblaue \n", 88 | "domain_task = [torch.linspace(1,15,100).to(device)]+[torch.linspace(1,15,500).to(device)] \n", 89 | "# domain of decision variables\n", 90 | "domain_decision = [torch.linspace(-L,L,100).to(device)]*2 # domain of x-y coordinates \n", 91 | 
"domain = domain_task+domain_decision\n" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# Find the tt-model corresponding to the pdf\n", 101 | "tt_model = tt_utils.cross_approximate(fcn=pdf, domain=domain, \n", 102 | " rmax=200, nswp=20, eps=1e-3, verbose=True, \n", 103 | " kickrank=5, device=device)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "### Fit the TT-Model" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# Refine the discretization and interpolate the model\n", 120 | "scale_factor = 20\n", 121 | "site_list = torch.arange(len(domain))#len(domain_task)+torch.arange(len(domain_decision))\n", 122 | "domain_new = tt_utils.refine_domain(domain=domain, \n", 123 | " site_list=site_list,\n", 124 | " scale_factor=scale_factor, device=device)\n", 125 | "tt_model_new = tt_utils.refine_model(tt_model=tt_model, \n", 126 | " site_list=site_list,\n", 127 | " scale_factor=scale_factor, device=device)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "ttgo = TTGO(tt_model=tt_model_new,domain=domain_new,cost=cost, device=device)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "# torch.save([ttgo.tt_model,domain],'himmel4D.pickle')" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "### Sample from TT-Model" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "a=14; b=2.\n", 162 | "x_task = torch.tensor([a,b]).view(1,-1).to(device) #given task-parameters\n", 163 | 
"n_samples_tt = 100\n", 164 | "samples = ttgo.sample_tt(n_samples=n_samples_tt, x_task=x_task.view(1,-1), alpha=0.9) " 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "### Choose the best sample as an estimate for optima" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "best_estimate = ttgo.choose_best_sample(samples)[0]\n", 181 | "top_k_estimate = ttgo.choose_top_k_sample(samples,k=50)[0] # for multiple solutions" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "##### Fine-tune the estimate using gradient-based optimization" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "ttgo_optimized, _ = ttgo.optimize(best_estimate)\n", 198 | "\n", 199 | "ttgo_optimized_k = 1*top_k_estimate\n", 200 | "for i, x in enumerate(ttgo_optimized_k):\n", 201 | " ttgo_optimized_k[i], _ = ttgo.optimize(x)\n" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "print(\"PDF at the estimated point: \", pdf(best_estimate))\n", 211 | "print(\"PDF at the optima: \", pdf(ttgo_optimized))" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "print(\"Estimated Optima: \", best_estimate)\n", 221 | "print(\"Optima: \", ttgo_optimized)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "### Visualization" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "# Redefinig the function given the coefficients\n", 238 | "def cost_fcn(X):\n", 239 | " X = 
torch.from_numpy(X)\n", 240 | " X_ext = torch.empty(X.shape[0],4)\n", 241 | " X_ext[:,:2] = x_task\n", 242 | " X_ext[:,2:] = X\n", 243 | " return cost(X_ext.to(device)).cpu().numpy()" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "x = np.linspace(-L,L,200)\n", 253 | "y = np.linspace(-L,L,200)\n", 254 | "data = samples[0,:,2:].cpu()\n", 255 | "\n", 256 | "plt=plot_contour(x,y,cost_fcn,data=data, contour_scale=1000, figsize=10, markersize=1)\n", 257 | "# plt.plot(ttgo_optimized[:,2],ttgo_optimized[:,3],'*r',markersize=10)\n", 258 | "plt.plot(ttgo_optimized_k[:,2].cpu(),ttgo_optimized_k[:,3].cpu(),'.r',markersize=10)\n", 259 | "# plt.legend([\"samples\",\"optima\"])\n", 260 | "# plt.title(r\"Himmelblau: $cost=(x^2+y-{})^2+(x+y^2-{})^2$\".format(a,b))\n", 261 | "# plt.savefig('Himmelblau4D_a13_b5_alpha0_ns1000_k10.png',pad_inches=0.01, dpi=300)\n", 262 | "# plt.plot(gott_top_k_estimate[:,2],gott_top_k_estimate[:,3],'*r',markersize=8)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [] 271 | } 272 | ], 273 | "metadata": { 274 | "kernelspec": { 275 | "display_name": "Python 3 (ipykernel)", 276 | "language": "python", 277 | "name": "python3" 278 | }, 279 | "language_info": { 280 | "codemirror_mode": { 281 | "name": "ipython", 282 | "version": 3 283 | }, 284 | "file_extension": ".py", 285 | "mimetype": "text/x-python", 286 | "name": "python", 287 | "nbconvert_exporter": "python", 288 | "pygments_lexer": "ipython3", 289 | "version": "3.9.7" 290 | }, 291 | "vscode": { 292 | "interpreter": { 293 | "hash": "cf96f6c213ba3f9333b362e3bb271376c1f8feeec3b85b92580d68346ee16de3" 294 | } 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 4 299 | } 300 | -------------------------------------------------------------------------------- /function_optimization/Rosenbrock_nD.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Optimization of n-D Rosenbrock function\n", 8 | "\n", 9 | "Reference: https://en.wikipedia.org/wiki/Rosenbrock_function\n", 10 | "\n", 11 | "$$ cost(a,b,\\mathbf{x}) = \\sum_{i=0}^{n/2-1} \\left[ b (x_{2i+1}-x_{2i}^2)^2 + (x_{2i}-a)^2 \\right] ,$$\n", 12 | "$$pdf(a,b,\\mathbf{x}) = e^{-cost(a,b,\\mathbf{x})}$$ \n", 13 | "\n", 14 | "Here, $\\mathbf{x}_{task}=(a,b)$ and $\\mathbf{x}_{decision} = \\mathbf{x} = (x_0,\\ldots,x_{n-1})$\n", 15 | "\n", 16 | "The global optimum is uniquely given by $(a,a^2,a,a^2,\\ldots, a,a^2)$\n", 17 | "\n", 18 | "We show that TTGO is able to find the global optimum consistently with a handful of samples from the constructed tt-model of the above pdf (constructed offline) for various selections of $\\mathbf{x}_{task}$ in the online phase. However, a naive approach of sampling from a uniform distribution to initialize a Newton-type optimizer does not work for larger $n$ (try $n=10$ in this notebook).
We use scipy's SLSQP to fine tune the initialization provided by TTGO and uniform distribution.\n", 19 | "\n", 20 | "Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/\n", 21 | " Written by Suhan Shetty ,\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "torch.set_default_dtype(torch.float64)\n", 32 | "\n", 33 | "import numpy as np\n", 34 | "import sys\n", 35 | "sys.path.append('./fcn_opt')\n", 36 | "sys.path.append('../')\n", 37 | "\n", 38 | "from ttgo import TTGO\n", 39 | "from test_fcns import Rosenbrock_nD \n", 40 | "from fcn_plotting_utils import plot_surf, plot_contour\n", 41 | "\n", 42 | "%load_ext autoreload\n", 43 | "np.set_printoptions(precision=2)\n", 44 | "torch.set_printoptions(precision=2)\n", 45 | "\n", 46 | "%autoreload 2" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "#### Define the cost and the correpsonding pdf function" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "n=20\n", 63 | "pdf, cost = Rosenbrock_nD(n=n,alpha=0.01) # n>=4" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "#### Define the domain and the discretization" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Define the domain of the function\n", 80 | "\n", 81 | "L = 2 # [-L,L]^n is the domain of the function\n", 82 | "\n", 83 | "# domain of task params: domain of coefficients a and b in Rosenbrock_4D\n", 84 | "# Note: a should in (-sqrt(L), sqrt(L)) \n", 85 | "domain_task = [torch.linspace(-np.sqrt(L),np.sqrt(L),500)]+[torch.linspace(50,150,500)] \n", 86 | "# domain of decison varibales\n", 87 | "domain_decision = [torch.linspace(-L,L,500)]*(n)\n", 88 | "domain = 
domain_task+domain_decision" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "### Fit the TT-Model" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# Find the tt-model corresponding to the pdf\n", 105 | "tt_model = tt_utils.cross_approximate(fcn=pdf, domain=domain, \n", 106 | " rmax=200, nswp=20, eps=1e-3, verbose=True, \n", 107 | " kickrank=5, device=device)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "ttgo = TTGO(domain=domain,tt_model=tt_model, cost=cost)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "#### Specify task parameters" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 7, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "a = 0; b = 100;\n", 133 | "x_task = torch.tensor([a,b]).to(device).view(1,-1) #given task-parameters" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "### Sample from TT-Model" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 10, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "n_samples_tt = 5\n", 150 | "\n", 151 | "samples = ttgo.sample_tt(n_samples=n_samples_tt, x_task=x_task, alpha=0.5) \n" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "### Choose the best sample as an estimate for optima (initial guess)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 11, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "best_estimate = ttgo.choose_best_sample(samples)[0]\n", 168 | "top_k_estimate = ttgo.choose_top_k_sample(samples,k=1)[0] # for multiple solutions" 169 | ] 170 | }, 171 | { 172 | 
"cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "##### Fine-tune the estimate using gradient-based optimization" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 12, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "ttgo_optimized,_ = ttgo.optimize(best_estimate)\n", 185 | "\n", 186 | "ttgo_optimized_k = 1*top_k_estimate\n", 187 | "for i, x in enumerate(ttgo_optimized_k):\n", 188 | " ttgo_optimized_k[i],_ = ttgo.optimize(x)\n" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 13, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | "PDF at the estimated point(initial guess): tensor([0.01])\n", 201 | "PDF at the TTGO Optima: tensor([1.00])\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "print(\"PDF at the estimated point(initial guess): \", pdf(best_estimate))\n", 207 | "print(\"PDF at the TTGO Optima: \", pdf(ttgo_optimized))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 14, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "Estimated Optima: tensor([[-2.83e-03, 9.99e+01, -7.66e-01, 9.10e-01, 1.11e+00, 9.74e-01,\n", 220 | " 6.37e-01, 2.20e-01, 2.53e-01, 2.53e-01, 4.01e-03, -2.36e-01,\n", 221 | " 5.21e-02, -5.17e-01, 3.61e-02, 7.90e-01, -6.61e-01, 8.14e-01,\n", 222 | " -2.69e-01, -2.00e-02, 8.42e-02, 3.09e-01]])\n", 223 | "Optima from TTGO: tensor([[-2.83e-03, 9.99e+01, -1.95e-02, 8.09e-04, 4.06e-04, 1.53e-04,\n", 224 | " 1.97e-04, -3.36e-04, -7.23e-05, -3.59e-05, -5.69e-07, -1.52e-05,\n", 225 | " -9.01e-08, -1.04e-04, -2.93e-05, -5.60e-05, -4.27e-05, 7.92e-06,\n", 226 | " 3.55e-05, -2.10e-05, 2.33e-05, -7.57e-05]])\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "print(\"Estimated Optima: \", best_estimate)\n", 232 | "print(\"Optima from TTGO: \", ttgo_optimized)" 233 | ] 234 | }, 235 
| { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "#### Global Optima from Analytical Evaluation" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 15, 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "Global Optima (analytical): tensor([[ 0, 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 252 | " 0, 0, 0, 0, 0, 0, 0, 0]])\n", 253 | "PDF at the gloabl optima: tensor([1.])\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "global_optima = torch.tensor([ [a,b] +[a**(i%2+1) for i in range(n)]]).view(1,-1)\n", 259 | "print(\"Global Optima (analytical): \", global_optima )\n", 260 | "print(\"PDF at the gloabl optima: \", pdf(global_optima))" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "Python 3 (ipykernel)", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.9.7" 288 | } 289 | }, 290 | "nbformat": 4, 291 | "nbformat_minor": 4 292 | } 293 | -------------------------------------------------------------------------------- /function_optimization/fcn_plotting_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty , 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 
10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 18 | ''' 19 | 20 | 21 | import seaborn as sns 22 | import matplotlib 23 | import matplotlib.pyplot as plt 24 | import numpy as np 25 | from matplotlib.ticker import LinearLocator 26 | import warnings 27 | from matplotlib.colors import LogNorm 28 | from matplotlib import ticker, cm 29 | warnings.filterwarnings("ignore") 30 | 31 | def plot_surf(x,y,cost,data=None,zlim=(0,1000),figsize=10, view_angle=(45,45),markersize=3): 32 | 33 | fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) 34 | fig.set_size_inches(figsize,figsize) 35 | 36 | Z = np.empty((len(x),len(y))) 37 | 38 | X, Y = np.meshgrid(x, y) 39 | XY = np.array([X.reshape(-1,),Y.reshape(-1,)]).T 40 | Z = cost(XY).reshape(X.shape[0],X.shape[1]) 41 | 42 | cmap = sns.cm.rocket_r 43 | surf = ax.plot_surface(X, Y, Z,cmap=cmap, 44 | linewidth=0, antialiased=False, zorder=0, alpha=1) 45 | # Customize the z axis. 46 | ax.set_zlim(zlim[0], zlim[1]) 47 | ax.zaxis.set_major_locator(LinearLocator(10)) 48 | ax.zaxis.set_major_formatter('{x:.1f}') 49 | 50 | if not (data is None): 51 | data_z = 0 52 | if len(data.shape)==3: 53 | data_z = data[:,2] 54 | ax.plot(data[:,0],data[:,1],data_z,'ob', markersize=markersize, zorder=10) 55 | 56 | ax.view_init(view_angle[0], view_angle[1]) 57 | # Add a color bar which maps values to colors. 
58 | fig.colorbar(surf, shrink=0.4, aspect=5) 59 | 60 | return plt 61 | 62 | def plot_contour(x,y,cost, data=None, contour_scale=100, figsize=10, markersize=3,log_norm=True): 63 | plt.style.use('seaborn-white') 64 | Z = np.empty((len(x),len(y))) 65 | X, Y = np.meshgrid(x, y) 66 | XY = np.array([X.reshape(-1,),Y.reshape(-1,)]).T 67 | Z = cost(XY).reshape(X.shape[0],X.shape[1]) 68 | sns.set_style("white") 69 | cmap = 'binary_r' 70 | if log_norm == True: 71 | levels = 10**(0.25*np.arange(-6,14)) 72 | cs = plt.contour(X, Y, Z, contour_scale, cmap=cmap, shade=True,locator=ticker.LogLocator(), 73 | levels=levels, norm=LogNorm(), alpha=1); 74 | else: 75 | cs = plt.contour(X, Y, Z, contour_scale, cmap=cmap, shade=True,alpha=1); 76 | 77 | plt.colorbar(cs); 78 | if not (data is None): 79 | plt.plot(data[:,0],data[:,1],'ob', markersize=markersize) 80 | plt.rcParams["figure.figsize"] = (figsize, figsize) 81 | 82 | return plt 83 | -------------------------------------------------------------------------------- /function_optimization/sine_nD.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/idiap/temp/sshetty/miniconda/envs/pyml/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n", 14 | "Matplotlib created a temporary config/cache directory at /tmp/matplotlib-vn_v6jr1 because the default path (/idiap/home/sshetty/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import torch\n", 20 | "import numpy as np\n", 21 | "import sys\n", 22 | "sys.path.append('./fcn_opt')\n", 23 | "sys.path.append('../')\n", 24 | "\n", 25 | "from ttgo import TTGO\n", 26 | "import tt_utils\n", 27 | "from test_fcns import sine_nD \n", 28 | "from fcn_plotting_utils import plot_surf, plot_contour\n", 29 | "\n", 30 | "%load_ext autoreload\n", 31 | "np.set_printoptions(precision=2)\n", 32 | "torch.set_printoptions(precision=2)\n", 33 | "\n", 34 | "%autoreload 2" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "device(type='cuda')" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 55 | "device" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "#### Define the cost and the correpsonding pdf function" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "n=3\n", 72 | "pdf, cost = sine_nD(n=n,alpha=1,device=device) # n>=4" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "#### Define the domain and the discretization" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | 
"execution_count": 4, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# Define the domain of the function\n", 89 | "\n", 90 | "L = 2 # [-L,L]^n is the domain of the function\n", 91 | "\n", 92 | "# domain of task params: domain of coefficients a and b in Rosenbrock_4D\n", 93 | "# Note: a should in (-sqrt(L), sqrt(L)) \n", 94 | "domain_task = [torch.linspace(-np.sqrt(L),np.sqrt(L),1000).to(device)]+[torch.linspace(50,150,1000).to(device)] \n", 95 | "# domain of decison varibales\n", 96 | "domain_decision = [torch.linspace(-L,L,1000).to(device)]*(n)\n", 97 | "domain = domain_task+domain_decision" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 5, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "cross device is cuda\n", 110 | "Cross-approximation over a 5D domain containing 1e+15 grid points:\n", 111 | "iter: 0 | tt-error: 1.159e+00, test-error:1.195e-01 | time: 1.2515 | largest rank: 1\n", 112 | "iter: 1 | tt-error: 1.199e-01, test-error:6.690e-09 | time: 1.2934 | largest rank: 6\n", 113 | "iter: 2 | tt-error: 0.000e+00, test-error:9.511e-14 | time: 1.3603 | largest rank: 11 <- converged: eps < 0.001\n", 114 | "Did 1002000 function evaluations, which took 0.01277s (7.846e+07 evals/s)\n", 115 | "\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "# Find the tt-model corresponding to the pdf\n", 121 | "tt_model = tt_utils.cross_approximate(fcn=pdf, domain=domain, \n", 122 | " rmax=200, nswp=20, eps=1e-3, verbose=True, \n", 123 | " kickrank=5, device=device)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### Fit the TT-Model" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "ttgo = TTGO(tt_model=tt_model,domain=domain,cost=cost,device=device)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 
144 | "metadata": {}, 145 | "source": [ 146 | "#### Specify task parameters" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 7, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "x_task = torch.tensor([1]*3).view(1,-1).to(device) #given task-parameters" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "### Sample from TT-Model" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "n_samples_tt = 50\n", 172 | "\n", 173 | "samples = ttgo.sample_tt(n_samples=n_samples_tt, x_task=x_task, alpha=0.9) \n" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### Choose the best sample as an estimate for optima (initial guess)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 9, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "best_estimate = ttgo.choose_best_sample(samples)[0]\n", 190 | "top_k_estimate = ttgo.choose_top_k_sample(samples,k=10)[0] # for multiple solutions" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "##### Fine-tune the estimate using gradient-based optimization" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 10, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "ttgo_optimized,_ = ttgo.optimize(best_estimate)\n", 207 | "\n", 208 | "ttgo_optimized_k = 1*top_k_estimate\n", 209 | "for i, x in enumerate(ttgo_optimized_k):\n", 210 | " ttgo_optimized_k[i],_ = ttgo.optimize(x)\n" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 11, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "PDF at the estimated point(initial guess): tensor([0.74], device='cuda:0')\n", 223 | "PDF at 
the TTGO Optima: tensor([0.98], device='cuda:0')\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "print(\"PDF at the estimated point(initial guess): \", pdf(best_estimate))\n", 229 | "print(\"PDF at the TTGO Optima: \", pdf(ttgo_optimized))" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 12, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "Estimated Optima: tensor([[1.00, 1.00, 1.00, 2.00, 1.33]], device='cuda:0')\n", 242 | "Optima from TTGO: tensor([[ 1.01, 50.37, 1.01, 2.00, 1.34]], device='cuda:0')\n" 243 | ] 244 | } 245 | ], 246 | "source": [ 247 | "print(\"Estimated Optima: \", best_estimate)\n", 248 | "print(\"Optima from TTGO: \", ttgo_optimized)" 249 | ] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "pyml", 255 | "language": "python", 256 | "name": "python3" 257 | }, 258 | "language_info": { 259 | "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.9.7" 269 | }, 270 | "vscode": { 271 | "interpreter": { 272 | "hash": "cf96f6c213ba3f9333b362e3bb271376c1f8feeec3b85b92580d68346ee16de3" 273 | } 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 4 278 | } 279 | -------------------------------------------------------------------------------- /function_optimization/test_fcns.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty , 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 
10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 18 | ''' 19 | 20 | 21 | import torch 22 | import numpy as np 23 | torch.set_default_dtype(torch.float64) 24 | 25 | def Rosenbrock_2D(a=1, b=100, alpha=1): 26 | ''' 27 | a 2D version of the Rosenbrock function with fixed coefficients (a,b) 28 | https://en.wikipedia.org/wiki/Rosenbrock_function 29 | ''' 30 | def cost(x): 31 | result = b*(x[:,1]-x[:,0]**2)**2 + (x[:,0]-a)**2 32 | return result 33 | 34 | def pdf(x): 35 | return torch.exp(-alpha*cost(x)) 36 | 37 | return pdf, cost 38 | 39 | 40 | def Rosenbrock_4D(alpha=1): 41 | ''' 42 | a 4D version of the Rosenbrock function with coefficients considered as variables of the function 43 | a=1, b=100 represents the standard 2D Rosenbrock function 44 | https://en.wikipedia.org/wiki/Rosenbrock_function 45 | ''' 46 | def cost(x): 47 | result = x[:,1]*(x[:,3]-x[:,2]**2)**2 + (x[:,2]-x[:,0])**2 48 | return result 49 | 50 | def pdf(x): 51 | return torch.exp(-alpha*cost(x)) 52 | 53 | return pdf, cost 54 | 55 | 56 | def Rosenbrock_nD(n=2,alpha=1): 57 | ''' 58 | nD version of Rosenbrock function: https://en.wikipedia.org/wiki/Rosenbrock_function 59 | actual minima is at (a, a^2, a, a^2,...,a, a^2). 60 | Domain: [-2,2] 61 | ''' 62 | 63 | def cost(x): 64 | a = x[:,0] 65 | b = x[:,1] 66 | y = x[:,2:] 67 | result = 0. 
68 | for i in range(y.shape[1]-1): 69 | result = result+b*(y[:,i+1]-y[:,i]**2)**2 + (y[:,i]-a)**2 70 | return result 71 | 72 | def pdf(x): 73 | return torch.exp(-alpha*cost(x)) 74 | 75 | return pdf, cost 76 | 77 | def Rosenbrock_nD_2(n=2,alpha=1): 78 | ''' 79 | nD version of Rosenbrock function: https://en.wikipedia.org/wiki/Rosenbrock_function 80 | actual minima is at (a, a^2, a, a^2,...,a, a^2). 81 | Domain: [-2,2] 82 | ''' 83 | assert ((n%2)==0 and n>2), 'n has to be even number greater than 2' 84 | 85 | def cost(x): 86 | a = x[:,0] 87 | b = x[:,1] 88 | y = x[:,2:] 89 | result = 0. 90 | for i in range(int(y.shape[1]/2)): 91 | result = result+b*(y[:,2*i+1]-y[:,2*i]**2)**2 + (y[:,2*i]-a)**2 92 | return result 93 | 94 | def pdf(x): 95 | return torch.exp(-alpha*cost(x))+1e-9 96 | 97 | return pdf, cost 98 | 99 | def Himmelblaue_2D(alpha=1,a=11,b=7): 100 | ''' 101 | a 2D function: https://en.wikipedia.org/wiki/Himmelblau%27s_function 102 | cost(x,y)=(x^2+y-11)^2+(x+y^2-7)^2 103 | Domain: [-5,5] 104 | ''' 105 | def cost(x): 106 | result = (x[:,0]**2+x[:,1]-a)**2 + (x[:,0]+x[:,1]**2-b)**2 107 | return result 108 | 109 | def pdf(x): # Cost-to-PDF transformation 110 | return torch.exp(-alpha*cost(x)) # or use: 1/(eps+cost(x)) 111 | 112 | return pdf, cost 113 | 114 | 115 | 116 | def Himmelblaue_4D(alpha=1): 117 | ''' 118 | a 4D version of the Himmelblaue2D function with coefficients considered as variables of the function 119 | cost(a,b,x,y)=(x^2+y-a)^2+(x+y^2-b)^2 120 | a=11, b=7 represents the standard 2D Himmelblaue function 121 | ''' 122 | def cost(x): 123 | result = (x[:,2]**2+x[:,3]-x[:,0])**2 + (x[:,2]+x[:,3]**2-x[:,1])**2 #11, 7 124 | return result 125 | 126 | def pdf(x): 127 | return torch.exp(-alpha*cost(x)) + 1e-9 #or use: 1/(eps+cost(x))# 128 | 129 | return pdf, cost 130 | 131 | 132 | def gmm(n=2,nmix=3,L=1,mx_coef=None,mu=None,s=0.1,device='cpu'): 133 | """ 134 | Mixture of spherical Gaussians (un-normalized) 135 | nmix: number of mixture coefficients 136 | n: 
dimension of the domain 137 | s: variance 138 | mu: the centers assumed to be in : [-L,L]^n 139 | """ 140 | n_sqrt = torch.sqrt(torch.tensor([n])).to(device) 141 | if mx_coef is None: # if centers and mixture coef are not given, generate them randomly 142 | mx_coef = torch.rand(nmix) 143 | mx_coef = mx_coef/torch.sum(mx_coef) 144 | mu = ((torch.rand(nmix,n)-0.5)*2*L).to(device) 145 | 146 | def pdf(x): 147 | result = torch.tensor([0]).to(device) 148 | for k in range(nmix): 149 | l = torch.linalg.norm(mu[k]-x, dim=1)/n_sqrt 150 | result = result + mx_coef[k]*torch.exp(-(l/s)**2) 151 | return result 152 | 153 | def cost(x): 154 | return 1.-pdf(x) 155 | 156 | return pdf, cost 157 | 158 | 159 | def sine_nD(n=2, alpha=1, device='cpu'): 160 | "an nD sinusoidal surface" 161 | n_sqrt = torch.sqrt(torch.tensor([n])).to(device) 162 | def pdf(x): 163 | return 0.49*(1.001+torch.sin(4*torch.pi*torch.linalg.norm(x,dim=1)/n_sqrt)) 164 | 165 | def cost(x): 166 | return 1.-pdf(x) 167 | 168 | return pdf, cost 169 | 170 | 171 | -------------------------------------------------------------------------------- /manipulator/data/sdf.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/sdf.npy -------------------------------------------------------------------------------- /manipulator/data/sdf_close_ground.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/sdf_close_ground.npy -------------------------------------------------------------------------------- /manipulator/data/sphere_setting.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/sphere_setting.npy 
-------------------------------------------------------------------------------- /manipulator/data/urdf/env.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 10 | 17 | 18 | 19 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/finger.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/hand.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link0.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link1.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link2.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link2.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link3.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link4.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link5.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link6.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/meshes/collision/link7.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/frankaemika_new/meshes/collision/link7.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/panda_arm.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 
291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | -------------------------------------------------------------------------------- /manipulator/data/urdf/frankaemika_new/panda_arm_franka.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 
277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | -------------------------------------------------------------------------------- /manipulator/data/urdf/mesh.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/mesh.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/finger.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/hand.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/link0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link0.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/link1.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link1.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/link2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link2.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link3.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link4.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/link5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link5.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/meshes/collision/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link6.stl -------------------------------------------------------------------------------- 
/manipulator/data/urdf/meshes/collision/link7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/meshes/collision/link7.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/panda_arm.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 
289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | -------------------------------------------------------------------------------- /manipulator/data/urdf/quadrotor.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/meshes/ur10/collision/Base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/ur10/meshes/ur10/collision/Base.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/meshes/ur10/collision/Forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/ur10/meshes/ur10/collision/Forearm.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/meshes/ur10/collision/Shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/ur10/meshes/ur10/collision/Shoulder.stl 
-------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/meshes/ur10/collision/UpperArm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/ur10/meshes/ur10/collision/UpperArm.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/meshes/ur10/collision/Wrist1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/ur10/meshes/ur10/collision/Wrist1.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/meshes/ur10/collision/Wrist2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/ur10/meshes/ur10/collision/Wrist2.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/meshes/ur10/collision/Wrist3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/ttgo/cadfac764fd00459f070b4ed3d71f7087efe22d4/manipulator/data/urdf/ur10/meshes/ur10/collision/Wrist3.stl -------------------------------------------------------------------------------- /manipulator/data/urdf/ur10/ur10.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 
| 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | transmission_interface/SimpleTransmission 182 | 183 | 184 | EffortJointInterface 185 | 1 186 | 187 | 188 | 189 | transmission_interface/SimpleTransmission 190 | 191 | 192 | EffortJointInterface 193 | 1 194 | 195 | 196 | 197 | transmission_interface/SimpleTransmission 198 | 199 | 200 | EffortJointInterface 201 | 1 202 | 203 | 204 | 205 | transmission_interface/SimpleTransmission 206 | 207 | 208 | EffortJointInterface 209 | 1 210 | 211 | 212 | 213 | transmission_interface/SimpleTransmission 214 | 215 | 216 | EffortJointInterface 217 | 1 218 | 219 | 220 | 221 | transmission_interface/SimpleTransmission 222 | 223 | 224 | EffortJointInterface 225 | 1 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /manipulator/manipulator_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty , 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 
10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 18 | ''' 19 | 20 | 21 | import torch 22 | import numpy as np 23 | from roma import rotmat_to_unitquat as tfm 24 | from utils import test_ttgo 25 | 26 | def dist_orientation(Rd_0,v_d,Ra_0): 27 | ''' 28 | Cost on orientation (flexible orientation) 29 | Rd_0: a 3x3 rotation matrix corresponding to the desired orientation (w.r.t world frame) 30 | v_d: 1x3 vector w.r.t. Rd frame w.r.t. which rotation is allowed 31 | Ra_0: ..x3x3 batch of rotation matrices w.r.t world frame 32 | returns distance in range (0,1) 33 | ''' 34 | v_d = (v_d/torch.linalg.norm(v_d)).view(-1) # normalize the axis vector 35 | Rd_0 = Rd_0.view(3,3) 36 | Ra_d = torch.matmul(Ra_0,Rd_0.T) # Ra w.r.t. 
Rd frame 37 | 38 | qa_d = tfm(Ra_d) # corresponding quarternion (imaginary_vector,real) 39 | va_d = qa_d[:,:-1] 40 | va_d = va_d/(torch.linalg.vector_norm(va_d,dim=1).view(-1,1)+1e-9) # axis vector w.r.t Rd frame to get Ra_d 41 | 42 | d_orient = 1-torch.einsum('ij,j->i',va_d,v_d)**2 43 | return d_orient 44 | 45 | def dist_orientation_fixed(Rd_0, Ra_0,device='cpu'): 46 | ''' 47 | distance between two orientations: Rd_0 (fixed desired orientation), Ra_0 (actual orientation) 48 | ''' 49 | Rd_0 = Rd_0.view(3,3) 50 | qa_d = tfm(torch.matmul(Ra_0,Rd_0.T)) 51 | q0 = torch.tensor([0.,0.,0.,1.]).to(device) 52 | dist_orient = 1-torch.einsum('ij,j->i',qa_d,q0)**2 53 | 54 | return dist_orient 55 | 56 | 57 | 58 | def exp_space(xmin=-1,xmax=1.,d=100): 59 | ''' 60 | discretization of an interval with exponential sepration from the center 61 | ''' 62 | d1 = int(d/2) 63 | d2 = d - d1 64 | xmid = 0.5*(xmin+xmax) 65 | t1 = np.logspace(0., 1, d1); 66 | t2 = np.logspace(0., 1, d2); 67 | t1 = t1 - t1[0]; t1 = t1/t1[-1] 68 | t2 = t2 - t2[0] +t1[1];t2 = t2/t2[-1]; 69 | t1 = xmid + (xmax-xmid)*t1 70 | t2 = xmid + (xmin-xmid)*np.flip(t2) 71 | t = np.concatenate((t2,t1)) 72 | return torch.from_numpy(t) 73 | 74 | def get_latex_str(results, alphas): 75 | ''' 76 | Used for generating tables for latex 77 | ''' 78 | latex_str_tt = [] 79 | for i in range(results.shape[1]-1): # over aphas 80 | latex_str_tt.append("& "+ "$" +str(round(alphas[i],2)) + "$") 81 | for j in range(results.shape[0]): # over sample_set 82 | for item_ in results[j,i,:]: 83 | latex_str_tt.append( " & " + "$" + str(round(item_.item(),2)) + "$") 84 | latex_str_tt.append(" \\\ ") 85 | latex_str_rand = [] 86 | latex_str_rand.append("&-") 87 | for i in range(results.shape[0]): 88 | for item_ in results[i,-1,:]: 89 | latex_str_rand.append( " & " + "$" +str(round(item_.item(),2))+ "$") 90 | latex_str_rand.append(" \\\ ") 91 | 92 | latex_tt = "" 93 | latex_rand = "" 94 | for str_ in latex_str_tt: 95 | latex_tt+=str_ 96 | for str_ in 
latex_str_rand: 97 | latex_rand+=str_ 98 | return latex_tt, latex_rand 99 | 100 | 101 | def test_robotics_task(ttgo, cost_all, test_task, alphas, sample_set, cut_total=0.25,device='cpu'): 102 | # for latex 103 | norm = 1 104 | results_union = torch.empty((len(sample_set),len(alphas)+1,3)).to(device) #n_samples x (alpha,rand) x (raw,opt,sucess) 105 | results_intersection = torch.empty((len(sample_set),len(alphas)+1,3)).to(device) #n_samples x (alpha,rand) x (raw,opt,sucess) 106 | for i, n_samples in enumerate(sample_set): 107 | results_rand_union = torch.empty(len(alphas),3).to(device) 108 | results_rand_intersection = torch.empty(len(alphas),3).to(device) 109 | for j, alpha in enumerate(alphas): 110 | costs_tt,costs_tt_opt,costs_rand,costs_rand_opt,tt_nit,rand_nit = test_ttgo(ttgo=ttgo.clone(), cost=cost_all, 111 | test_task=test_task, n_samples_tt=n_samples, 112 | alpha=alpha, norm=norm, device=device, test_rand=True, cut_total=cut_total) 113 | n_test = costs_tt.shape[0] 114 | idx_tt = costs_tt_opt[:,0], 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 
18 | ''' 19 | 20 | # Note: Prefer using jupyter-notebook version as it is more up to date 21 | 22 | import torch 23 | import numpy as np 24 | np.set_printoptions(2, suppress=True) 25 | torch.set_printoptions(2, sci_mode=False) 26 | 27 | import sys 28 | sys.path.append('../') 29 | from ttgo import TTGO 30 | import tt_utils 31 | from utils import test_ttgo 32 | from manipulator_utils import test_robotics_task 33 | from panda_cost_utils import SDF_Cost, PandaCost 34 | from panda_kinematics import PandaKinematics 35 | import argparse 36 | 37 | 38 | import warnings 39 | warnings.filterwarnings("ignore") 40 | 41 | 42 | if __name__ == '__main__': 43 | 44 | ############################################################ 45 | 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--dh_x', type=float, default=0.01) 48 | parser.add_argument('--d0_theta',type=int, default=50) 49 | parser.add_argument('--rmax', type=int, default=500) # max tt-rank for tt-cross 50 | parser.add_argument('--nswp', type=int, default=50) # number of sweeps in tt-cross 51 | parser.add_argument('--kr', type=float, default=5) # kickrank param for tt-cross 52 | parser.add_argument('--b_goal', type=float, default=0.05) #nominal distance of goal 53 | parser.add_argument('--b_obst', type=float, default=0.01) # nominal collision 54 | parser.add_argument('--b_orient', type=float, default=0.2) #nominal error in orientation 55 | parser.add_argument('--margin', type=float, default=0.0) #safety margin of collision for end-effector 56 | parser.add_argument('--d_type', type=str, default='uniform') # or {'log', 'uniform'} disctretization of joint angles 57 | parser.add_argument('--name', type=str) # file nazme for saving 58 | args = parser.parse_args() 59 | file_name = args.name if args.name else "panda-ik-margin-{}-dh-{}-d0_theta-{}-nswp-{}-rmax-{}-kr-{}-b-{}-{}-{}.pickle".format(args.margin, args.dh_x, args.d0_theta, args.nswp, args.rmax, args.kr, args.b_goal, args.b_obst, args.b_orient) 60 | 
print(file_name) 61 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 62 | ############################################################ 63 | 64 | 65 | with torch.no_grad(): 66 | # Setup the robot and the environment 67 | 68 | data_sdf = np.load('./data/sdf.npy', allow_pickle=True)[()] 69 | sdf_matr = data_sdf['sdf_matr'] 70 | bounds = torch.tensor(data_sdf['bounds']).float().to(device) # Bound of the environment 71 | sdf_tensor = torch.from_numpy(sdf_matr).float().to(device) 72 | sdf_cost = SDF_Cost(sdf_tensor=sdf_tensor, domain=bounds, device=device) 73 | env_bound = data_sdf['env_bound'] 74 | shelf_bound = data_sdf['shelf_bound'] 75 | box_bound = data_sdf['box_bound'] 76 | 77 | # key-points on the body of the robot for collision check 78 | data_keys = np.load('./data/sphere_setting.npy', allow_pickle=True)[()]# key_points 79 | status_array = data_keys['status_array'] 80 | body_radius = data_keys['body_radius'] 81 | relative_pos = data_keys['relative_pos'] 82 | key_points_weight = torch.from_numpy(status_array).float().to(device) >0 # 8xMx1 83 | key_points_weight[-1] = 1*args.margin 84 | key_points_margin = torch.from_numpy(body_radius).float().to(device)# 85 | key_points_pos = torch.from_numpy(relative_pos).float().to(device) 86 | key_points = [key_points_pos, key_points_weight, key_points_margin] 87 | 88 | # define the robot 89 | panda = PandaKinematics(device=device, key_points_data=key_points) 90 | 91 | ############################################################ 92 | 93 | # Define the cost function 94 | 95 | # Specify the doesired orientation 96 | Rd_0 = torch.tensor([[ 0.7071,0.7071,0.], [0.,0.,1],[0.7071, -0.7071, 0.]]).to(device) # desired orientation 97 | v_d = torch.tensor([0.,0.,1.]).to(device) 98 | # Rd = torch.tensor([[ 0,0.,0.], [0.,0.,1],[0., 0., 0.]]) 99 | 100 | pandaCost = PandaCost(robot=panda, sdf_cost=sdf_cost, 101 | Rd_0=Rd_0, v_d=v_d,b_obst=args.b_obst, 102 | b_goal=args.b_goal,b_orient=args.b_orient,device=device) 103 | 
104 | 105 | def cost(x): # For inverse kinematics 106 | return pandaCost.cost_ik(x)[:,0] 107 | 108 | def cost_all(x): # For inverse kinematics 109 | return pandaCost.cost_ik(x) 110 | 111 | def pdf(x): 112 | x = x.to(device) 113 | pdf_ = torch.exp(-cost(x)**2) 114 | return pdf_ 115 | 116 | ############################################################################ 117 | 118 | # Define the domain for discretization 119 | 120 | n_joints=7 121 | d_theta_all = [args.d0_theta]*n_joints 122 | d_theta = [int(d_theta_all[joint]) for joint in range(n_joints)] 123 | 124 | # type of discretization of intervals of decision variables 125 | if args.d_type == 'uniform': 126 | domain_decision = [torch.linspace(panda.theta_min[i],panda.theta_max[i],d_theta[i]).to(device) for i in range(n_joints)] 127 | else: # logarithmic scaling 128 | domain_decision = [exp_space(panda.theta_min[i],panda.theta_max[i],d_theta[i]).to(device) for i in range(n_joints)] 129 | 130 | # task space of the manipulator (the shelf) 131 | env_bounds = torch.from_numpy(shelf_bound) 132 | x_min = env_bounds[:,0] 133 | x_max = env_bounds[:,1] 134 | x_max[0] = 0.75; x_min[0]=0.45 135 | x_max[1] = x_max[1]-0.1 136 | x_min[1] = x_min[1]+0.1 137 | x_max[-1] = 0.75; x_min[-1] = 0. 
138 | 139 | 140 | domain_task = [torch.linspace(x_min[i], x_max[i], int((x_max[i]-x_min[i])/args.dh_x)) for i in range(3)] 141 | domain = domain_task + domain_decision 142 | print("Discretization: ",[len(x) for x in domain]) 143 | 144 | ####################################################################################### 145 | # Fit the TT-Model 146 | tt_model = tt_utils.cross_approximate(fcn=pdf, domain=[x.to(device) for x in domain], 147 | rmax=200, nswp=20, eps=1e-3, verbose=True, 148 | # Refine the discretization and interpolate the model 149 | scale_factor = 10 150 | site_list = torch.arange(len(domain))#len(domain_task)+torch.arange(len(domain_decision)) 151 | domain_new = tt_utils.refine_domain(domain=domain, 152 | site_list=site_list, 153 | scale_factor=scale_factor, device=device) 154 | tt_model_new = tt_utils.refine_model(tt_model=tt_model.to(device), 155 | site_list=site_list, 156 | scale_factor=scale_factor, device=device) kickrank=5, device=device) 157 | 158 | 159 | ttgo = TTGO(tt_model=tt_model_new, domain=domain_new, cost=cost,device=device) 160 | 161 | 162 | 163 | ############################################################ 164 | print("############################") 165 | print("Test the model") 166 | print("############################") 167 | 168 | # generate test set 169 | ns = 100 170 | test_task = torch.zeros(ns,len(domain_task)).to(device) 171 | for i in range(len(domain_task)): 172 | unif = torch.distributions.uniform.Uniform(low=domain_task[i][0],high=domain_task[i][-1]) 173 | test_task[:,i]= torch.tensor([unif.sample() for i in range(ns)]).to(device) 174 | 175 | 176 | torch.save({ 177 | 'tt_model':ttgo.tt_model, 178 | 'panda': panda, 179 | 'pandaCost':pandaCost, 180 | 'sdf_cost':sdf_cost, 181 | 'w': (pandaCost.w_goal,pandaCost.w_obst,pandaCost.w_orient), 182 | 'b': (args.b_goal,args.b_obst,args.b_orient), 183 | 'margin': args.margin, 184 | 'key_points_weight':key_points_weight, 185 | 'key_points_margin':key_points_margin, 186 | 
'domains': domain, 187 | 'Rd_0': Rd_0, 188 | 'v_d':v_d, 189 | 'test_task': test_task, 190 | }, file_name) 191 | 192 | 193 | # Test the model 194 | sample_set = [1,10,100,1000] 195 | alphas = [0.9,0.75,0.5,0] 196 | cut_total=0.33 197 | test_robotics_task(ttgo.clone(), cost_all, test_task, alphas, sample_set, cut_total,device) -------------------------------------------------------------------------------- /manipulator/panda_kinematics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty , 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 
18 | ''' 19 | 20 | import torch 21 | torch.set_default_dtype(torch.float64) 22 | 23 | class PandaKinematics: 24 | ''' 25 | Manipulator: Franka-Emika Panda 26 | Get the position of the key-points on the body of the manipulator and 27 | the pose (position and orientation) of the 28 | end-effector given a batch of joint angles 29 | ''' 30 | def __init__(self, device="cpu", key_points_data=None): 31 | self.device = device 32 | self.n_joints = 7 33 | self.theta_max_robot = torch.tensor([2.8973, 1.7628, 2.5, 34 | -0.0698, 2.8973, 3.7525, 35 | 2.8973]).to(device) 36 | self.theta_min_robot = torch.tensor([-2.8973, -1.7628, -2.5, 37 | -3.0718, -2.8973, -0.0175, 38 | -2.8973]).to(device) 39 | 40 | # a factor of safety to strictly avoid joint limits 41 | self.theta_max = self.theta_max_robot - 0.2 42 | self.theta_min = self.theta_min_robot + 0.2 43 | 44 | self.theta_min[4] = self.theta_min[4] + 0.2 45 | self.theta_max[4] = self.theta_max[4] - 0.2 46 | self.theta_min[5] = self.theta_min[5] + 0.2 47 | self.theta_max[5] = self.theta_max[5] - 0.2 48 | 49 | self.max_config = self.theta_max.reshape(1,-1).to(device) 50 | self.min_config = self.theta_min.reshape(1,-1).to(device) 51 | 52 | # DH Params 53 | self.dh_a = torch.tensor([0, 0, 0, 0.0825, 54 | -0.0825, 0, 0.088, 55 | 0]).to(device) 56 | 57 | self.dh_d = torch.tensor([0.333, 0, 0.316, 58 | 0, 0.384, 0, 59 | 0, 0.107+0.103]).to(self.device) # 0.103 is added for the tip of the flange/ee 60 | self.dh_alpha = (torch.pi/2)*torch.tensor([0, -1, 1, 1, -1, 1, 1, 0]).to(device) 61 | 62 | # Define key-points on the surface of the robot for collision detection 63 | if key_points_data is None: # choose the joint locations as the key-points 64 | self.key_points = torch.empty(8,1,3).fill_(0.).to(device).double() 65 | self.key_points_weight = torch.empty(8,1,1).fill_(1./8).to(device).double() 66 | self.key_points_margin = torch.empty(8,1,1).fill_(0.1).to(device).double() 67 | else: 68 | self.key_points = 
key_points_data[0].to(self.device).double() # 8xMx3 tensor 69 | self.key_points_weight = key_points_data[1].to(device).double() # 8xMx3 tensor 70 | self.key_points_margin = key_points_data[2].to(device).double() # 8xMX3 tensor 71 | 72 | self.n_kp = self.key_points.shape[1] # number of key points per joint 73 | 74 | 75 | # Initialize tranformation matrices 76 | ca = torch.cos(self.dh_alpha) 77 | sa = torch.sin(self.dh_alpha) 78 | Talpha = torch.eye(4,4).reshape(1,4,4).to(device) 79 | Talpha = Talpha.repeat(len(self.dh_alpha),1,1) 80 | Talpha[:,1,1] = ca 81 | Talpha[:,1,2] = -sa 82 | Talpha[:,2,1] = sa 83 | Talpha[:,2,2] = ca 84 | 85 | Ta = torch.eye(4,4).reshape(1,4,4).to(device) 86 | Ta = Ta.repeat(len(self.dh_a),1,1) 87 | Td = Ta.clone() 88 | Ta[:,0,-1] = self.dh_a 89 | Td[:,2,-1] = self.dh_d 90 | self.T_prod = torch.einsum('ijk,ikl,ilm->ijm',Talpha, Ta, Td).to(device) 91 | 92 | # Base frame (adapt it to the actual setup in the lab) 93 | T0 = torch.eye(4,4).reshape(1,1,4,4).to(device) 94 | T0[:,:,0,0] = 0.6157; T0[:,:,1,1]=0.6157; T0[:,:,0,1]=-0.788;T0[:,:,1,0]= 0.788 # offset orientation w.r.t the table 95 | self.T0 = T0.clone() 96 | 97 | def set_device(self,device): 98 | self.device=device 99 | 100 | 101 | 102 | 103 | ################################################################################################################################## 104 | def forward_kin(self,q): 105 | ''' 106 | given the joint angles get the position of key-points, position and orientation 107 | of the end-effector 108 | ''' 109 | self.T = self.computeTransformation(q) # 4D-tensor: batch x joint x (2D-Transformation-matrix) 110 | self.key_positions = self.getKeyPosition() # position of key-points 111 | self.ee_position = self.key_positions[:,-1,-1,:] # end-effector position 112 | self.ee_orientation = self.getEndPose() 113 | return self.key_positions, self.ee_position, self.ee_orientation 114 | 115 | 116 | def getKeyPosition(self): 117 | ''' returns position of all the key points 
given a batch of joint angles q''' 118 | key_position, _ = self.TransformationToKeyPosition(self.T) # 3D-tensor: batch x keys x position 119 | # Note: key_position[:,-1,:] gives end-effector position 120 | return key_position # 3D array: batch x joint x keys x position 121 | 122 | def getEndPoseEuler(self): 123 | ''' returns pose of end-effetor given a batch of joint angles q''' 124 | end_pose, end_R = self.TransformationToEndPoseEuler(self.T) # 3D-tensor: batch x pose 125 | return end_pose, end_R # 2D array: batch x pose and 3D array of rotation matrices 126 | 127 | def getEndPose(self): 128 | ''' returns pose of end-effetor given a batch of joint angles q''' 129 | _, end_R = self.TransformationToEndPose(self.T) # 3D-tensor: batch x pose 130 | return end_R # 2D array: batch x pose and 3D array of rotation matrices 131 | 132 | 133 | def TransformationToKeyPosition(self, T): 134 | ''' 135 | Given a batch of the Transformation matrices (4D array: batch x joint x Tranform-matrix ) get the 3D positions of the key-points 136 | output pose (3D array): batch x keys x position 137 | ''' 138 | x_joint = T[:, :, :3, -1].to(self.device) # 3D array of position of joints: batch x joint x position 139 | R_joint = T[:,:,:3,:3] # batch x joint x rot_matrix 140 | x_key = x_joint.view(x_joint.shape[0],x_joint.shape[1],1,x_joint.shape[2]) + torch.einsum('ijpr,jkr->ijkp',R_joint,self.key_points) # batch x joint x key x position 141 | 142 | return x_key, x_joint # batch x joint x key x x position 143 | 144 | 145 | 146 | def TransformationToEndPoseEuler(self,T): 147 | ''' 148 | Given a batch of the Transformation matrices (4D array: batch x joint x Tranform-matrix ) get the 6D pose (position and euler angles) 149 | of the end-effector 150 | output pose (2D array): batch x pose 151 | ''' 152 | x = T[:, -1, :3, -1].to(self.device) # 2D array of position of end-effector: batch x pose 153 | R = T[:, -1 , :3, :3].to(self.device) # 3D array containing rotation matrices: batch x 
rotation_matrix 154 | 155 | sy = torch.sqrt(R[:,0, 0] * R[:,0, 0] + R[:,1, 0] * R[:,1, 0]) 156 | 157 | t1 = torch.stack((torch.atan2(R[:,2, 1], R[:,2, 2]), torch.atan2(-R[:,2, 0], sy), 158 | torch.atan2(R[:,1, 0], R[:,0, 0])), dim=1) 159 | t2 = torch.stack((torch.atan2(-R[:,1, 2], R[:,1, 1]),torch.atan2(-R[:,2, 0], sy), 160 | torch.zeros(T.shape[0]).to(self.device)), dim=1) 161 | 162 | sy = (sy.reshape(sy.shape[0],1)).repeat(1,3) 163 | singular = sy < 1e-6 164 | ts = torch.where(singular, t2, t1) # 2D array of orientation: batch x orientation 165 | return torch.cat((x, ts), dim=1), R 166 | 167 | def TransformationToEndPose(self,T): 168 | ''' 169 | Given a batch of the Transformation matrices (4D array: batch x joint x Tranform-matrix ) get the 3D position 170 | of the end-effector and the orientation matrix 171 | ''' 172 | x = T[:, -1, :3, -1].to(self.device) # 2D array of position of end-effector: batch x pose 173 | R = T[:, -1 , :3, :3].to(self.device) # 3D array containing rotation matrices: batch x rotation_matrix 174 | return x,R 175 | 176 | 177 | def computeTransformation(self,q): 178 | ''' Returns transformation matrices: batch x joint x tranformation-matrix ''' 179 | q = q.to(self.device) 180 | q = torch.cat((q, torch.tensor([0]*q.shape[0]).to(self.device).view(-1,1)),dim=1) 181 | cq = torch.cos(q) 182 | sq = torch.sin(q) 183 | Tz = self.T0.repeat(q.shape[0],q.shape[1],1,1) 184 | T = self.T0.repeat(q.shape[0],q.shape[1]+1,1,1) 185 | Tz[:,:,0,0] = cq 186 | Tz[:,:,1,1] = cq 187 | Tz[:,:,0,1] = -1*sq 188 | Tz[:,:,1,0] = sq 189 | T_joints = torch.einsum('jkl,ijlm->ijkm',self.T_prod, Tz) 190 | for i in range(1,q.shape[1]+1): 191 | T[:,i,:,:] = torch.einsum('ijk,ikl->ijl', 1*T[:,i-1,:,:], T_joints[:,i-1,:,:]) 192 | return T[:,1:,:,:] 193 | 194 | 195 | -------------------------------------------------------------------------------- /manipulator/ur10_ik.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 
2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty , 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 18 | ''' 19 | 20 | 21 | import torch 22 | import numpy as np 23 | np.set_printoptions(4, suppress=True) 24 | torch.set_printoptions(4, sci_mode=False) 25 | import sys 26 | sys.path.append('../') 27 | from ttgo import TTGO 28 | from utils import test_ttgo 29 | from ur10_kinematics import Ur10Kinematics 30 | from manipulator_utils import dist_orientation_fixed 31 | import argparse 32 | import warnings 33 | warnings.filterwarnings("ignore") 34 | 35 | 36 | if __name__ == '__main__': 37 | 38 | ############################################################ 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--dh_x', type=float, default=0.01) 42 | parser.add_argument('--d0_theta',type=int, default=60) 43 | parser.add_argument('--rmax', type=int, default=500) # max tt-rank for tt-cross 44 | parser.add_argument('--nswp', type=int, default=30) # number of sweeps in tt-cross 45 | parser.add_argument('--b_goal', type=float, default=0.05) 46 | parser.add_argument('--b_orient', type=float, default=0.2) 47 | parser.add_argument('--kr', type=float, default=3) 48 | parser.add_argument('--d_type', type=str, default='uniform') # or {'log', 'uniform'} disctretization of joint angles 49 | parser.add_argument('--name', type=str) # file nazme for saving 50 | args = parser.parse_args() 51 | file_name = args.name if 
args.name else "ur10-ik-d0_theta-{}-kr-{}-nswp-{}-rmax-{}-dh-{}-b-{}-{}.pickle".format(args.d0_theta,args.kr, args.nswp, args.rmax, args.dh_x, args.b_orient, args.b_goal) 52 | print(file_name) 53 | 54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 55 | 56 | ############################################################ 57 | 58 | # Setup the robot and the environment 59 | 60 | ur10 = Ur10Kinematics(device=device) 61 | n_joints= ur10.n_joints 62 | 63 | ############################################################ 64 | 65 | # Desired orientation (fixed orientation) 66 | Rd_0 = torch.eye(3).to(device) 67 | 68 | def cost_all(x): # For inverse kinematics 69 | x = x.to(device) 70 | batch_size = x.shape[0] 71 | goal_loc = x[:,:3] 72 | q = x[:,3:] # batch x joint angles 73 | _, end_loc, end_R = ur10.forward_kin(q) # batch x joint x keys x positions 74 | 75 | # cost on error in end-effector position 76 | d_goal = torch.linalg.norm(end_loc-goal_loc, dim=1) 77 | 78 | # cost on error in end-effector orientation 79 | d_orient = dist_orientation_fixed(Rd_0,end_R,device=device) 80 | 81 | c_total = 0.5*(d_goal/args.b_goal + d_orient/args.b_orient) 82 | 83 | c_return = torch.cat((c_total.view(-1,1), d_goal.view(-1,1), d_orient.view(-1,1)),dim=-1) 84 | return c_return 85 | 86 | def cost(x): 87 | return cost_all(x)[:,0] 88 | 89 | 90 | def pdf(x): 91 | x = x.to(device) 92 | pdf_ = torch.exp(-cost(x)**2) 93 | return pdf_ 94 | 95 | ##################################################################### 96 | 97 | # Define the domain 98 | d_theta_all = [args.d0_theta]*n_joints 99 | d_theta = [int(d_theta_all[joint]) for joint in range(n_joints)] 100 | if args.d_type == 'uniform': 101 | domain_decision = [0.5*torch.linspace(ur10.theta_min[i],0.5*ur10.theta_max[i],d_theta[i]).to(device) for i in range(n_joints)] 102 | else: # logarithmic scaling 103 | domain_decision = [exp_space(0.5*ur10.theta_min[i].to('cpu'),0.5*ur10.theta_max[i].to('cpu'),d_theta[i]).to(device) for i 
in range(n_joints)] 104 | 105 | # Find the work-space 106 | n_test = 1000 107 | test_theta = torch.zeros(n_test,n_joints).to(device) 108 | for i in range(n_joints): 109 | unif = torch.distributions.uniform.Uniform(low = domain_decision[i][0],high=domain_decision[i][-1]) 110 | test_theta[:,i]= torch.tensor([unif.sample() for i in range(n_test)]).to(device) 111 | _, test_xpos, _ = ur10.forward_kin(test_theta) 112 | x_min,_ = torch.min(test_xpos, dim=0) 113 | x_max,_ = torch.max(test_xpos,dim=0) 114 | x_min[-1] = 0.1 115 | idx_select = test_xpos[:,-1]>x_min[-1] 116 | test_task = test_xpos[idx_select,:] 117 | 118 | # discretize the domain 119 | domain_task = [torch.linspace(x_min[i], x_max[i], int((x_max[i]-x_min[i])/args.dh_x)).to(device) for i in range(3)] 120 | domain = domain_task + domain_decision 121 | print("Discretization: ",[len(x) for x in domain]) 122 | 123 | ############################################################################### 124 | # Fit TT-model 125 | tt_model = tt_utils.cross_approximate(fcn=pdf, domain=[x.to(device) for x in domain], 126 | rmax=200, nswp=20, eps=1e-3, verbose=True, 127 | # Refine the discretization and interpolate the model 128 | scale_factor = 10 129 | site_list = torch.arange(len(domain))#len(domain_task)+torch.arange(len(domain_decision)) 130 | domain_new = tt_utils.refine_domain(domain=domain, 131 | site_list=site_list, 132 | scale_factor=scale_factor, device=device) 133 | tt_model_new = tt_utils.refine_model(tt_model=tt_model.to(device), 134 | site_list=site_list, 135 | scale_factor=scale_factor, device=device) kickrank=5, device=device) 136 | 137 | 138 | ttgo = TTGO(tt_model=tt_model_new, domain=domain_new, cost=cost,device=device) 139 | 140 | 141 | # Save 142 | torch.save({ 143 | 'tt_model':ttgo.tt_model, 144 | 'b': (args.b_goal,args.b_orient), 145 | 'd0':args.d0_theta, 146 | 'dh_x': args.dh_x, 147 | 'domain': domain, 148 | 'Rd_0':Rd_0, 149 | 'test_task':test_task 150 | }, file_name) 151 | 152 | 
############################################################## 153 | # Prepare for test 154 | 155 | sites_task = list(range(len(domain_task))) 156 | ttgo.set_sites(sites_task) 157 | 158 | # Test the model 159 | sample_set = [1,10,100,1000] 160 | alphas = [0.9,0.75] 161 | cut_total=0.33 162 | test_robotics_task(ttgo.clone(), cost_all, test_task, alphas, sample_set, cut_total,device) -------------------------------------------------------------------------------- /manipulator/ur10_ik_visualize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "WARNING: torch_batch_svd (https://github.com/KinglittleQ/torch-batch-svd) is not installed and is required for maximum efficiency of special_procrustes. Using torch.svd as a fallback.\n", 13 | "device: cpu\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "import torch\n", 21 | "import numpy as np\n", 22 | "np.set_printoptions(4, suppress=True)\n", 23 | "torch.set_printoptions(4, sci_mode=False)\n", 24 | "import sys\n", 25 | "sys.path.append('../')\n", 26 | "from ttgo import TTGO\n", 27 | "from utils import test_ttgo\n", 28 | "from ur10_kinematics import Ur10Kinematics\n", 29 | "from manipulator_utils import dist_orientation_fixed\n", 30 | "device = 'cpu'#torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 31 | "print(\"device: \", device)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "trained_model = torch.load('ur10_ik.pickle', map_location=torch.device('cpu'))\n", 48 | "b_goal,b_orient = trained_model['b']\n", 
49 | "d0_theta = trained_model['d0_theta']\n", 50 | "dh_x = trained_model['dh_x']\n", 51 | "domain = trained_model['domain']\n", 52 | "Rd_0 = trained_model['Rd_0']\n", 53 | "test_task = trained_model['test_task']" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# Setup the robot and the environment\n", 63 | "\n", 64 | "ur10 = Ur10Kinematics(device=device)\n", 65 | "n_joints= ur10.n_joints\n", 66 | "# Desired orientation (fixed orientation)\n", 67 | "\n", 68 | "def cost_all(x): # For inverse kinematics\n", 69 | " x = x.to(device)\n", 70 | " batch_size = x.shape[0]\n", 71 | " goal_loc = x[:,:3]\n", 72 | " q = x[:,3:] # batch x joint angles\n", 73 | " _, end_loc, end_R = ur10.forward_kin(q) # batch x joint x keys x positions\n", 74 | " # cost on error in end-effector position\n", 75 | " d_goal = torch.linalg.norm(end_loc-goal_loc, dim=1)\n", 76 | " # cost on error in end-effector orientation\n", 77 | " d_orient = dist_orientation_fixed(Rd_0,end_R,device=device)\n", 78 | " c_total = 0.5*(d_goal/b_goal + d_orient/b_orient)\n", 79 | " c_return = torch.cat((c_total.view(-1,1), d_goal.view(-1,1), d_orient.view(-1,1)),dim=-1)\n", 80 | " return c_return\n", 81 | "\n", 82 | "def cost(x):\n", 83 | " return cost_all(x)[:,0]\n", 84 | "\n", 85 | "\n", 86 | "def pdf(x):\n", 87 | " x = x.to(device)\n", 88 | " pdf_ = torch.exp(-cost(x)**2) \n", 89 | " return pdf_\n", 90 | "\n", 91 | "#####################################################################\n" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "\n", 101 | "ttgo = TTGO(domain=domain,pdf=pdf,cost=cost)\n", 102 | "ttgo.tt_model = trained_model['tt_model']\n", 103 | "ttgo.to('cpu')\n", 104 | "\n", 105 | "# Prepare for the task\n", 106 | "sites_task = list(range(3))\n", 107 | "ttgo.set_sites(sites_task)" 108 | ] 109 | }, 110 | { 111 | "cell_type": 
"code", 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "tensor([[1., 0., 0.],\n", 119 | " [0., 1., 0.],\n", 120 | " [0., 0., 1.]])" 121 | ] 122 | }, 123 | "execution_count": 5, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "Rd_0" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## Visualization" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 6, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "import pybullet_data\n", 146 | "from panda_visualization_utils import *\n", 147 | "import pybullet as p\n", 148 | "from functools import partial\n", 149 | "# import the environment (SDF and for graphics visualization in pybullet)\n", 150 | "import sys\n", 151 | "sys.path.append('../../lib')\n", 152 | "\n", 153 | "import sys\n", 154 | "DATA_PATH = './data'\n", 155 | "robot_urdf = DATA_PATH + '/urdf/ur10/ur10.urdf'\n", 156 | "\n", 157 | "\n", 158 | "physics_client_id = p.connect(p.GUI)\n", 159 | "p.setPhysicsEngineParameter(enableFileCaching=0)\n", 160 | "p.setAdditionalSearchPath(pybullet_data.getDataPath())\n", 161 | "p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)\n", 162 | "\n", 163 | "p.resetSimulation()\n", 164 | "\n", 165 | "# for i in range(8):\n", 166 | "# print(i, p.getJointInfo(robot_id, i)[1])\n", 167 | " \n", 168 | "robot_id = p.loadURDF(robot_urdf)\n", 169 | "\n", 170 | "dof = p.getNumJoints(robot_id)\n", 171 | "pb_joint_indices = np.arange(1,7)\n", 172 | "joint_limits = get_joint_limits(robot_id,pb_joint_indices)\n", 173 | "set_q_std = partial(set_q,robot_id, pb_joint_indices)\n", 174 | "\n", 175 | "plane_id = p.loadURDF('plane.urdf')\n", 176 | "p.resetBasePositionAndOrientation(plane_id, (0,0,0.), (0,0,0,1))\n", 177 | "\n", 178 | "#for visualizing the desired target\n", 179 | "_,_,ball_id = create_primitives(radius=0.05)\n", 180 | "\n", 
181 | "ee_pb_id = 7\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 7, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# # Prepare for the task\n", 191 | "# sites_task = list(range(3))\n", 192 | "# ttgo.set_sites(sites_task)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 8, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "tensor([-0.00, 0.30, 0.70])\n", 205 | "Time taken: 7.0895466804504395 0.00010991096496582031\n", 206 | "Cost-mean-tt: tensor([ 0.00, 0.00, 0.00])\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "s = np.random.randint(0,test_task.shape[0]-1)\n", 212 | "sample_xe = torch.tensor([-0.0, 0.3, 0.7]) #\n", 213 | "print(sample_xe)\n", 214 | "\n", 215 | "\n", 216 | "n_solutions=50\n", 217 | "n_samples_tt = 200 #50*n_solutions\n", 218 | "n_samples_rand= 1*n_samples_tt\n", 219 | "\n", 220 | "alpha=0.75; norm=1 ; sample_replace = True; \n", 221 | "\n", 222 | "t1 = time.time()\n", 223 | "samples_tt, samples_idx = ttgo.sample(n_samples=n_samples_tt, x_task=sample_xe.reshape(1,-1),alpha=alpha, norm=norm, sample_replace=sample_replace)\n", 224 | "state_k_tt = ttgo.choose_top_k_sample(samples_tt,n_solutions)\n", 225 | "\n", 226 | "#Optimize\n", 227 | "state_k_tt_opt = 1*state_k_tt\n", 228 | "for i, state in enumerate(state_k_tt):\n", 229 | " state_k_tt_opt[i,:],results= ttgo.optimize(state,bound=True)\n", 230 | "t2 = time.time()\n", 231 | "\n", 232 | "# samples_rand, _ = ttgo.sample_random(n_samples=n_samples_rand, x_task=sample_xe.reshape(1,-1))\n", 233 | "# state_k_rand = ttgo.choose_top_k_sample(samples_rand,n_solutions)\n", 234 | "\n", 235 | "# #Optimize\n", 236 | "# state_k_rand_opt = state_k_rand*1\n", 237 | "# for i, state in enumerate(state_k_rand):\n", 238 | "# state_k_rand_opt[i,:],results= ttgo.optimize(state,bound=True)\n", 239 | "t3=time.time()\n", 240 | "\n", 241 | "print(\"Time taken:\", 
(t2-t1), t3-t2)\n", 242 | " \n", 243 | "c_tt = cost_all(state_k_tt_opt)\n", 244 | "# c_rand = cost_all(state_k_rand_opt)\n", 245 | "\n", 246 | "print(\"Cost-mean-tt:\",torch.mean(c_tt,dim=0))\n", 247 | "# print(\"Cost-mean-rand:\",torch.mean(c_rand,dim=0))\n", 248 | "\n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 9, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "x_target = sample_xe[:3].numpy()\n", 258 | "joint_angles_k = state_k_tt[:,3:].numpy() \n", 259 | "joint_angles_k_opt = state_k_tt_opt[:,3:].numpy() " 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# _,_,test_sphere = create_primitives(p.GEOM_SPHERE, radius = 0.02)\n", 269 | "# p.resetBasePositionAndOrientation(test_sphere, (0,0,1.), (0,0,0,1))\n", 270 | "_ , _,sphere_id = create_primitives(p.GEOM_SPHERE, radius = 0.02)\n", 271 | "pos = x_target[:]\n", 272 | "\n", 273 | "p.resetBasePositionAndOrientation(sphere_id, pos, (0,0,0,1))\n", 274 | "\n", 275 | "\n", 276 | "k = joint_angles_k.shape[0]\n", 277 | "dt = 0.5\n", 278 | "dT = 2\n", 279 | "for i in range(2*k):\n", 280 | " set_q_std(joint_angles_k[i%k])\n", 281 | " time.sleep(dt)\n", 282 | " set_q_std(joint_angles_k_opt[i%k])\n", 283 | " time.sleep(2*dt)\n", 284 | " " 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [] 300 | } 301 | ], 302 | "metadata": { 303 | "kernelspec": { 304 | "display_name": "Python 3", 305 | "language": "python", 306 | "name": "python3" 307 | }, 308 | "language_info": { 309 | "codemirror_mode": { 310 | "name": "ipython", 311 | "version": 3 312 | }, 313 | "file_extension": ".py", 314 | "mimetype": "text/x-python", 315 | "name": "python", 316 | 
"nbconvert_exporter": "python", 317 | "pygments_lexer": "ipython3", 318 | "version": "3.8.5" 319 | } 320 | }, 321 | "nbformat": 4, 322 | "nbformat_minor": 4 323 | } 324 | -------------------------------------------------------------------------------- /manipulator/ur10_kinematics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty , 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 
# 18 | '''


import torch


class Ur10Kinematics:
    '''
    Batched forward kinematics of the UR10 arm built from its DH parameters.

    Every frame is chained as T_i = T_{i-1} @ Td_i @ Rz(q_i) @ Rx(alpha_i) @ Tx(a_i).
    The 6 actuated joints are followed by one fixed frame (angle always 0, a = d = alpha = 0),
    so all per-frame tensors below have 7 entries.
    '''

    def __init__(self, device="cpu", key_points_data=None):
        '''
        device: torch device that holds every internal tensor
        key_points_data: optional tuple (key_points, key_points_weight, key_points_margin), each
            indexed by frame first (7 frames). NOTE(review): trailing shapes are not checked here;
            presumably 7xMx3, 7xMx1 and 7xMx1 -- confirm against the callers.
            When None, a single key-point at each frame origin is used, with weight 1/7 and margin 0.1.
        '''
        self.device = device
        self.n_joints = 6

        # joint limits: [-2*pi, 2*pi] for every joint
        self.theta_max = torch.tensor([2.*torch.pi]*self.n_joints).to(self.device)
        self.theta_min = torch.tensor([-2*torch.pi]*self.n_joints).to(self.device)
        self.max_config = self.theta_max.reshape(1, -1).to(device)
        self.min_config = self.theta_min.reshape(1, -1).to(device)

        # DH parameters: 6 joints + 1 fixed frame. The last row is all zeros, so no
        # flange/tool offset is included in the end-effector frame.
        self.dh_a = torch.tensor([0, 0.612, 0.5723, 0., -0., 0, 0.]).to(self.device)
        self.dh_d = torch.tensor([0.1273, 0, 0., 0.163941, 0.1157, 0.0922, 0.]).to(self.device)
        self.dh_alpha = (torch.pi/2)*torch.tensor([-1, 0, 0, -1, 1, 0, 0]).to(self.device)
        n_frames = len(self.dh_alpha)

        # key-points on the body of the robot for collision checks
        if key_points_data is None:  # one key-point at each frame origin
            self.key_points = torch.zeros(n_frames, 1, 3).to(self.device)
            self.key_points_weight = torch.full((n_frames, 1, 1), 1./n_frames).to(self.device)
            self.key_points_margin = torch.full((n_frames, 1, 1), 0.1).to(self.device)
        else:
            self.key_points = key_points_data[0].to(self.device)
            self.key_points_weight = key_points_data[1].to(self.device)
            self.key_points_margin = key_points_data[2].to(self.device)

        # Constant (joint-angle independent) parts of every frame transform:
        # Td = translation d along z, T_alpha_a = Rx(alpha) @ Tx(a).
        ca = torch.cos(self.dh_alpha)
        sa = torch.sin(self.dh_alpha)
        Talpha = torch.eye(4, 4).reshape(1, 4, 4).to(self.device).repeat(n_frames, 1, 1)
        Talpha[:, 1, 1] = ca
        Talpha[:, 1, 2] = -sa
        Talpha[:, 2, 1] = sa
        Talpha[:, 2, 2] = ca

        Ta = torch.eye(4, 4).reshape(1, 4, 4).to(self.device).repeat(n_frames, 1, 1)
        Td = Ta.clone()
        Ta[:, 0, -1] = self.dh_a
        Td[:, 2, -1] = self.dh_d
        self.Td = Td.to(self.device)
        self.T_alpha_a = torch.einsum('ijk,ikl->ijl', Talpha, Ta).to(self.device)

    def forward_kin(self, q):
        '''
        Forward kinematics for a batch of joint angles q (batch x 6).

        Returns (key_positions, ee_position, ee_orientation):
            key_positions: batch x 7 x M x 3 positions of the M key-points of each frame
            ee_position: batch x 3, the last key-point of the last frame
            ee_orientation: batch x 3 x 3 rotation matrix of the last frame
        The results (and the frame transforms self.T) are also stored on the instance.
        '''
        self.T = self.computeTransformation(q)  # batch x frame x 4 x 4
        self.key_positions = self.getKeyPosition()
        self.ee_position = self.key_positions[:, -1, -1, :]
        self.ee_orientation = self.getEndPose()
        return self.key_positions, self.ee_position, self.ee_orientation

    def getKeyPosition(self):
        ''' Key-point positions (batch x frame x key x 3) for the transforms stored in self.T '''
        return self.TransformationToKeyPosition(self.T)

    def getEndPoseEuler(self):
        '''
        End-effector pose for the transforms stored in self.T.
        Returns (batch x 6 tensor [position, Euler angles], batch x 3 x 3 rotation matrices).
        '''
        end_pose, end_R = self.TransformationToEndPoseEuler(self.T)
        return end_pose, end_R

    def getEndPose(self):
        ''' End-effector rotation matrices (batch x 3 x 3) for the transforms stored in self.T '''
        _, end_R = self.TransformationToEndPose(self.T)
        return end_R

    def TransformationToKeyPosition(self, T):
        '''
        Given a batch of transformation matrices (batch x frame x 4 x 4), return the positions of
        the key-points attached to every frame: batch x frame x key x 3.
        '''
        x_joint = T[:, :, :3, -1].to(self.device)  # frame origins: batch x frame x 3
        R_joint = T[:, :, :3, :3]  # frame rotations: batch x frame x 3 x 3
        # rotate each frame's local key-points into the base frame and offset by the frame origin
        x_key = x_joint.view(x_joint.shape[0], x_joint.shape[1], 1, x_joint.shape[2]) \
            + torch.einsum('ijpr,jkr->ijkp', R_joint, self.key_points)
        return x_key

    def TransformationToEndPoseEuler(self, T):
        '''
        Given a batch of transformation matrices (batch x frame x 4 x 4), return the end-effector
        pose as a batch x 6 tensor (position followed by three atan2-based Euler angles, with a
        fallback when sy < 1e-6) together with the batch x 3 x 3 rotation matrices.
        '''
        x = T[:, -1, :3, -1].to(self.device)  # end-effector position: batch x 3
        R = T[:, -1, :3, :3].to(self.device)  # end-effector rotation: batch x 3 x 3

        sy = torch.sqrt(R[:, 0, 0] * R[:, 0, 0] + R[:, 1, 0] * R[:, 1, 0])

        t1 = torch.stack((torch.atan2(R[:, 2, 1], R[:, 2, 2]), torch.atan2(-R[:, 2, 0], sy),
                          torch.atan2(R[:, 1, 0], R[:, 0, 0])), dim=1)
        t2 = torch.stack((torch.atan2(-R[:, 1, 2], R[:, 1, 1]), torch.atan2(-R[:, 2, 0], sy),
                          torch.zeros(T.shape[0]).to(self.device)), dim=1)

        sy = (sy.reshape(sy.shape[0], 1)).repeat(1, 3)
        singular = sy < 1e-6
        ts = torch.where(singular, t2, t1)  # orientation: batch x 3
        return torch.cat((x, ts), dim=1), R

    def TransformationToEndPose(self, T):
        '''
        Given a batch of transformation matrices (batch x frame x 4 x 4), return the end-effector
        position (batch x 3) and rotation matrix (batch x 3 x 3).
        '''
        x = T[:, -1, :3, -1].to(self.device)
        R = T[:, -1, :3, :3].to(self.device)
        return x, R

    def computeTransformation(self, q):
        '''
        Returns the base-to-frame transformation matrices for a batch of joint angles:
        q (batch x 6) -> batch x 7 x 4 x 4 (the 7th frame is the fixed trailing frame).
        '''
        q = q.to(self.device)
        # append the angle (always 0) of the fixed trailing frame, matching q's dtype/device
        q = torch.cat((q, q.new_zeros(q.shape[0], 1)), dim=1)
        cq = torch.cos(q)
        sq = torch.sin(q)

        # rotation about z by each joint angle: batch x frame x 4 x 4
        Tz = torch.eye(4, 4).reshape(1, 1, 4, 4).to(self.device).repeat(q.shape[0], q.shape[1], 1, 1)
        Tz[:, :, 0, 0] = cq
        Tz[:, :, 1, 1] = cq
        Tz[:, :, 0, 1] = -1*sq
        Tz[:, :, 1, 0] = sq

        # per-frame transform relative to the previous frame: Td @ Rz(q) @ Rx(alpha) @ Tx(a)
        T_joints = torch.einsum('ijk,nikl,ilm->nijm', self.Td, Tz, self.T_alpha_a)

        # chain the relative transforms; building a list (no in-place writes) keeps autograd happy
        frames = [T_joints[:, 0]]
        for i in range(1, q.shape[1]):
            frames.append(frames[-1] @ T_joints[:, i])
        return torch.stack(frames, dim=1)


# --------------------------------------------------------------------------------
# /toy_robots/planar_manipulator.py:
# --------------------------------------------------------------------------------
# 1 | '''
# 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
# 3 | Written by Suhan Shetty ,
# 4 |
# 5 | This file is part of TTGO.
# 6 |
# 7 | TTGO is free software: you can redistribute it and/or modify
# 8 | it under the terms of the GNU General Public License version 3 as
# 9 | published by the Free Software Foundation.
# 10 |
# 11 | TTGO is distributed in the hope that it will be useful,
# 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
# 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 14 | GNU General Public License for more details.
# 15 |
# 16 | You should have received a copy of the GNU General Public License
# 17 | along with TTGO. If not, see .
# 18 | '''


import torch
import numpy as np


class PlanarManipulator:
    def __init__(self, n_joints=2, link_lengths=None, max_theta=torch.pi/1.1, n_kp=3, device="cpu"):
        '''
        n_joints: number of joints in the planar manipulator
        link_lengths: length of each link, as a list or 1-D tensor with n_joints entries.
            Defaults to n_joints equal links of length 1/n_joints (total reach of 1).
        max_theta: max joint angle (same for all joints); joints are limited to [-max_theta, max_theta]
        n_kp: number of key-points on each link (for collision check), at least 2
        device: torch device for all internal tensors
        '''
        self.device = device
        self.n_joints = n_joints
        if link_lengths is None:
            self.link_lengths = torch.tensor([1./n_joints]*n_joints).to(self.device)
        else:
            # the docstring promises list support, so convert non-tensors first
            if not torch.is_tensor(link_lengths):
                link_lengths = torch.tensor(link_lengths, dtype=torch.get_default_dtype())
            self.link_lengths = link_lengths.to(device)
            assert n_joints == link_lengths.shape[0], 'The length of the list containing link_lengths should match n_joints'

        self.max_config = torch.tensor([max_theta]*n_joints).to(self.device)
        self.min_config = -1*self.max_config
        self.theta_max = self.max_config
        self.theta_min = self.min_config

        self.n_kp = n_kp
        assert self.n_kp >= 2, 'number of key points should be at least two'
        # key-points as fractions along each link, evenly spaced from 0 (proximal) to 1 (distal)
        fractions = torch.arange(0, self.n_kp)/(self.n_kp-1)
        self.key_points = fractions.repeat(self.n_joints, 1).to(device)  # n_joints x n_kp

    # forward kinematics
    def forward_kin(self, q):
        '''
        Given a batch of joint angles, find the position of all the key-points and the end-effector.

        q: batch_size x n_joints relative joint angles (clipped to [min_config, max_config])
        returns:
            key_loc: batch_size x n_joints x n_kp x 2, key-points along each link
            joint_loc: batch_size x (n_joints+1) x 2, base (origin), joints and end-effector
            end_loc: batch_size x 2, end-effector position
            theta_orient: batch_size, end-effector orientation wrapped into [0, 2*pi]
        '''
        batch_size = q.shape[0]
        q = torch.clip(q, self.min_config, self.max_config)
        # absolute angle of each link = running sum of the relative joint angles
        q_cumsum = torch.cumsum(q, dim=1)
        # unit direction of each link: batch_size x n_joints x 2
        cq_sq = torch.stack((torch.cos(q_cumsum), torch.sin(q_cumsum)), dim=2)
        link_vec = self.link_lengths.reshape(1, -1, 1)*cq_sq

        # joint positions: the base at the origin followed by the running sum of link vectors
        base = torch.zeros((batch_size, 1, 2), dtype=link_vec.dtype, device=self.device)
        joint_loc = torch.cat((base, torch.cumsum(link_vec, dim=1)), dim=1)

        # key-point k on link i: joint_i + fraction_k * link_vec_i
        key_loc = joint_loc[:, :-1, None, :] + link_vec[:, :, None, :]*self.key_points[None, :, :, None]

        end_loc = joint_loc[:, -1, :]
        # orientation of the end-effector: fmod lands in (-2*pi, 2*pi); shift negatives up by 2*pi
        theta_orient = torch.fmod(q_cumsum[:, -1], 2*torch.pi)
        theta_orient = torch.where(theta_orient < 0, theta_orient + 2*torch.pi, theta_orient)

        return key_loc, joint_loc, end_loc, theta_orient


# --------------------------------------------------------------------------------
# /toy_robots/planar_manipulator_ik.py:
# --------------------------------------------------------------------------------
# 1 | '''
# 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
# 3 | Written by Suhan Shetty ,
# 4 |
# 5 | This file is part of TTGO.
# 6 |
# 7 | TTGO is free software: you can redistribute it and/or modify
# 8 | it under the terms of the GNU General Public License version 3 as
# 9 | published by the Free Software Foundation.
# 10 |
# 11 | TTGO is distributed in the hope that it will be useful,
# 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
# 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 14 | GNU General Public License for more details.
# 15 |
# 16 | You should have received a copy of the GNU General Public License
# 17 | along with TTGO. If not, see .
# 18 | '''

import torch
import numpy as np
import sys
sys.path.append('../')

from planar_manipulator import PlanarManipulator
from cost_utils import PlanarManipulatorCost
from utils import test_ttgo
from ttgo import TTGO
import tt_utils
import time

np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(precision=4)

import argparse

if __name__ == '__main__':

    ############################################################

    parser = argparse.ArgumentParser()
    # d0_x is a grid size passed as `steps` to torch.linspace, which requires an int
    parser.add_argument('--d0_x', type=int, default=100)
    parser.add_argument('--d0_theta', type=int, default=50)
    parser.add_argument('--rmax', type=int, default=500)  # max tt-rank for tt-cross
    parser.add_argument('--nswp', type=int, default=30)  # number of sweeps in tt-cross
    parser.add_argument('--margin', type=float, default=0.025)
    parser.add_argument('--b_goal', type=float, default=0.1)
    parser.add_argument('--b_obst', type=float, default=0.1)
    parser.add_argument('--b_orient', type=float, default=1.)

    parser.add_argument('--kr', type=int, default=5)
    parser.add_argument('--w_goal', type=float, default=0.4)
    parser.add_argument('--w_obst', type=float, default=0.4)
    parser.add_argument('--w_orient', type=float, default=0.2)

    parser.add_argument('--n_joints', type=int, default=4)

    args = parser.parse_args()

    file_name = "planar-ik-n_joints-{}-margin-{}-d0_x-{}-d0_theta-{}-nswp-{}-rmax-{}-kr-{}-b-{}-{}-{}.pickle".format(args.n_joints,
        args.margin, args.d0_x, args.d0_theta, args.nswp, args.rmax, args.kr, args.b_goal, args.b_obst, args.b_orient)
    print(file_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device is ", device)

    ############################################################

    # Setup the robot and the environment
    n_joints = args.n_joints
    link_lengths = torch.tensor([1./n_joints]*n_joints).to(device)
    max_theta = np.pi/1.1
    min_theta = -1*max_theta
    robot = PlanarManipulator(n_joints=n_joints, link_lengths=link_lengths,
                              max_theta=max_theta, n_kp=10, device=device)

    # Define the cost and pdf
    x_obst = [torch.tensor([0.5, 0.5]), torch.tensor([-0.35, 0.]),
              torch.tensor([-0.25, 0.75]), torch.tensor([0, -0.75])]
    x_obst = [x_.to(device) for x_ in x_obst]
    r_obst = [0.25, 0.15, 0.25, 0.3]

    costPlanarManipulator = PlanarManipulatorCost(robot, x_obst=x_obst, r_obst=r_obst,
                                                  margin=args.margin, w_goal=args.w_goal, w_obst=args.w_obst, w_orient=args.w_orient,
                                                  b_goal=args.b_goal, b_obst=args.b_obst, b_orient=args.b_orient, device=device)

    def cost(x):
        return costPlanarManipulator.cost_ik(x)[:, 0]

    def cost_to_print(x):  # for printing purposes
        return costPlanarManipulator.cost_ik(x)

    def pdf(x):
        return torch.exp(-cost(x)**2)

    #########################################################
    # Discretize the domain

    # Range of target positions of the end-effector: +/- the total reach of the arm.
    # .item() gives plain Python floats, which torch.linspace accepts as bounds.
    pose_max = torch.sum(link_lengths).item()
    pose_min = -1*pose_max

    domain_task = [torch.linspace(pose_min, pose_max, args.d0_x).to(device)]*2
    domain_decision = [torch.linspace(min_theta, max_theta, args.d0_theta).to(device)]*args.n_joints
    domain = domain_task + domain_decision

    print("Discretization: ", [len(x) for x in domain])

    #########################################################
    # Fit TT-Model
    def pdf_goal(x):
        d_goal = costPlanarManipulator.cost_ik(x)[:, 1]
        return torch.exp(-(d_goal/1)**2)

    def cost_obst(q):
        return costPlanarManipulator.cost_ik(q)[:, 2]

    def pdf_obst_q(q):
        kp_loc = robot.forward_kin(q)[0]  # key-point positions
        d_obst = costPlanarManipulator.dist_obst(kp_loc)
        return torch.exp(-(d_obst/1)**2)

    # NOTE(review): args.rmax / args.nswp / args.kr only appear in file_name; the
    # cross-approximations below use fixed settings -- confirm which is intended.
    print("Find tt_model of pdf_goal:")
    tt_goal = tt_utils.cross_approximate(fcn=pdf_goal, domain=domain,
                                         rmax=200, nswp=10, eps=1e-3, verbose=True,
                                         kickrank=10, device=device)
    print("Find tt_model of pdf_obst:")
    tt_obst_q = tt_utils.cross_approximate(fcn=pdf_obst_q, domain=domain_decision,
                                           rmax=200, nswp=10, eps=1e-3, verbose=True,
                                           kickrank=10, device=device)
    # make sure the dimensions of tt_obst match those of the desired tt_model,
    # i.e. pdf_obst(x_task, q) = pdf_obst_q(q)
    tt_obst = tt_utils.extend_model(tt_model=tt_obst_q, site=0, n_cores=2, d=[args.d0_x]*2).to(device)

    print("Take product: pdf(x_task,x_decision) = pdf_goal(x_task,x_decision)*pdf_obst(x_decision)")
    tt_model = tt_goal.to('cpu')*tt_obst.to('cpu')

    tt_model.round_tt(1e-3)

    ttgo = TTGO(tt_model=tt_model.to(device), domain=domain, cost=cost, device=device)

    #########################################################
    # Generate test set (feasible target points)
    ns = 1000
    # joint angles drawn uniformly inside the joint limits, all at once
    unif = torch.distributions.uniform.Uniform(low=min_theta, high=max_theta)
    test_theta = unif.sample((ns, n_joints)).to(device)
    _, _, test_x, _ = robot.forward_kin(test_theta)
    test_set = torch.cat((test_x, test_theta), dim=-1)
    cost_values = costPlanarManipulator.cost_ik(test_set)
    test_set = test_set[cost_values[:, 0] < 0.1]  # keep only low-cost (feasible) samples
    ns = min(test_set.shape[0], 50)
    test_task = test_set[:ns, :2]

    ##########################################################
    # Save the model.
    torch.save({
        'ttgo': ttgo,
        'w': (args.w_goal, args.w_obst),
        'b': (args.b_goal, args.b_obst),
        'margin': args.margin,
        'domain': domain,
        'test_task': test_task,
        'n_joints': n_joints
    }, file_name)

    ############################################################
    print("############################")
    print("Test the model")
    print("############################")

    for alpha in [0.99, 0.9, 0.8, 0.5]:
        for n_samples_tt in [10, 50, 100, 1000]:
            test_ttgo(ttgo=ttgo.clone(), cost=cost_to_print,
                      test_task=test_task, n_samples_tt=n_samples_tt,
                      alpha=alpha, device=device, test_rand=True, cut_total=0.2)
    ############################################################


# --------------------------------------------------------------------------------
# /toy_robots/planar_manipulator_reaching.py:
# --------------------------------------------------------------------------------
# 1 | '''
# 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
# 3 | Written by Suhan Shetty ,
# 4 |
# 5 | This file is part of TTGO.
# 6 |
# 7 | TTGO is free software: you can redistribute it and/or modify
# 8 | it under the terms of the GNU General Public License version 3 as
# 9 | published by the Free Software Foundation.
# 10 |
# 11 | TTGO is distributed in the hope that it will be useful,
# 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
# 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 14 | GNU General Public License for more details.
# 15 |
# 16 | You should have received a copy of the GNU General Public License
# 17 | along with TTGO. If not, see .
# 18 | '''


import torch
import numpy as np
from planar_manipulator import PlanarManipulator
from plot_utils import plot_chain
np.set_printoptions(3, suppress=True)
torch.set_printoptions(3, sci_mode=False)
import sys
sys.path.append('../')
from ttgo import TTGO
from cost_utils import PlanarManipulatorCost
from utils import Point2PointMotion
from utils import test_ttgo
import tt_utils

import warnings

warnings.filterwarnings('ignore')
#####################################################################
import argparse

if __name__ == '__main__':

    ############################################################

    parser = argparse.ArgumentParser()
    # d0_x is a grid size passed as `steps` to torch.linspace, which requires an int
    parser.add_argument('--d0_x', type=int, default=50)
    parser.add_argument('--d0_theta', type=int, default=50)
    parser.add_argument('--d0_w', type=int, default=50)
    parser.add_argument('--rmax', type=int, default=500)  # max tt-rank for tt-cross
    parser.add_argument('--nswp', type=int, default=30)  # number of sweeps in tt-cross
    parser.add_argument('--margin', type=float, default=0.02)
    parser.add_argument('--b_goal', type=float, default=0.1)
    parser.add_argument('--b_obst', type=float, default=1.)
    parser.add_argument('--b_ee', type=float, default=1.)
    parser.add_argument('--b_control', type=float, default=1.)
    parser.add_argument('--w_goal', type=float, default=1.)
    parser.add_argument('--w_obst', type=float, default=1.)
    parser.add_argument('--w_ee', type=float, default=1.)
    parser.add_argument('--w_control', type=float, default=0.)
    parser.add_argument('--K', type=int, default=2)
    parser.add_argument('--dt', type=float, default=0.02)
    parser.add_argument('--kr', type=int, default=5)
    parser.add_argument('--n_joints', type=int, default=2)
    parser.add_argument('--n_kp', type=int, default=5)
    # 0: reach-target, 1: one via-point, 2: two via points, 3: two via and return.
    # `choices` rejects other values up front instead of failing later with an undefined `cost`.
    parser.add_argument('--mp', type=int, default=3, choices=[0, 1, 2, 3])
    args = parser.parse_args()

    file_name = "planar-ik-mp-{}-n_joints-{}-margin-{}-d0_x-{}-d0_theta-{}-d0_w-{}-nswp-{}-rmax-{}-kr-{}-b-{}-{}-{}-{}.pickle".format(args.mp,
        args.n_joints, args.margin, args.d0_x, args.d0_theta, args.d0_w, args.nswp,
        args.rmax, args.kr, args.b_goal, args.b_obst, args.b_ee, args.b_control)
    print(file_name)

    ##################################################################
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device is ", device)

    ########################################################
    # Define the robot
    n_joints = args.n_joints
    link_lengths = torch.tensor([1./n_joints]*n_joints).to(device)
    max_theta = np.pi/1.1
    min_theta = -1*max_theta
    robot = PlanarManipulator(n_joints=n_joints, link_lengths=link_lengths,
                              max_theta=max_theta, n_kp=args.n_kp, device=device)

    ########################################################
    # Define the environment and the task (Cost function)
    x_obst = [torch.tensor([0., -0.75]).to(device)]
    r_obst = [0.25]

    bounds = [robot.min_config, robot.max_config]
    p2p_motion = Point2PointMotion(n=n_joints, dt=args.dt, K=args.K, basis='rbf',
                                   bounds=bounds, device=device)

    costPlanarManipulator = PlanarManipulatorCost(robot, p2p_motion=p2p_motion, x_obst=x_obst,
                                                  r_obst=r_obst, margin=args.margin,
                                                  w_goal=args.w_goal, w_obst=args.w_obst,
                                                  w_ee=args.w_ee, w_control=args.w_control,
                                                  b_goal=args.b_goal, b_obst=args.b_obst,
                                                  b_ee=args.b_ee, b_control=args.b_control,
                                                  device=device)

    # Initial and final configuration.
    # NOTE(review): theta_0 has exactly 2 entries, so this presumably requires --n_joints 2.
    theta_0 = torch.tensor([2.1*torch.pi/4, -1.5*torch.pi/4]).view(1, -1).to(device)
    theta_3 = 1*theta_0

    # Pick and place location (via-points: x_1 and x_2)
    x_min_place = -0.75; x_max_place = -0.5
    y_min_place = -0.5; y_max_place = 1.

    x_min_pick = 0.5; x_max_pick = 0.75
    y_min_pick = -0.5; y_max_pick = 1.

    d0_y = args.d0_x // 5  # coarser grid along y
    domain_x1 = [torch.linspace(x_min_pick, x_max_pick, args.d0_x),
                 torch.linspace(y_min_pick, y_max_pick, d0_y)]
    domain_x2 = [torch.linspace(x_min_place, x_max_place, args.d0_x),
                 torch.linspace(y_min_place, y_max_place, d0_y)]

    domain_theta = [torch.linspace(min_theta, max_theta, args.d0_theta)]*n_joints
    domain_w = [torch.linspace(min_theta, max_theta, args.d0_w)]*(args.K*n_joints)

    if args.mp == 3:  # 2-via points and initial and final config given
        def cost(x):
            return costPlanarManipulator.cost_j2p2p2j(x, theta_0, theta_3)[:, 0]

        def cost_to_print(x):  # for printing results
            return costPlanarManipulator.cost_j2p2p2j(x, theta_0, theta_3)

        domain_task = domain_x1 + domain_x2
        domain_decision = domain_theta*2 + domain_w*3

    elif args.mp == 2:  # 2-via-points only initial configuration is given
        def cost(x):
            return costPlanarManipulator.cost_j2p2p(x, theta_0)[:, 0]

        def cost_to_print(x):  # for printing results
            return costPlanarManipulator.cost_j2p2p(x, theta_0)

        domain_task = domain_x1 + domain_x2
        domain_decision = domain_theta*2 + domain_w*2

    elif args.mp == 1:  # one via point with initial and final config given
        def cost(x):
            return costPlanarManipulator.cost_j2p2j(x, theta_0, theta_3)[:, 0]

        def cost_to_print(x):  # for printing results
            return costPlanarManipulator.cost_j2p2j(x, theta_0, theta_3)

        domain_task = domain_x1
        domain_decision = domain_theta + domain_w*2

    else:  # args.mp == 0: only target point is given
        def cost(x):
            x = x.to(device)
            return costPlanarManipulator.cost_j2p(x, theta_0)[:, 0]

        def cost_to_print(x):  # for printing results
            return costPlanarManipulator.cost_j2p(x, theta_0)

        domain_task = domain_x1
        domain_decision = domain_theta + domain_w

    def pdf(x):
        return torch.exp(-cost(x)**2)

    domain = domain_task + domain_decision
    #########################################################
    # Fit TT-Model
    # NOTE(review): args.rmax / args.nswp / args.kr only appear in file_name; the
    # cross-approximation below uses fixed settings -- confirm which is intended.
    tt_model = tt_utils.cross_approximate(fcn=pdf, domain=domain,
                                          rmax=100, nswp=20, eps=1e-3, verbose=True,
                                          kickrank=5, device=device)
    # keyword `tt_model` matches the ik script and the `ttgo.tt_model` attribute saved below
    ttgo = TTGO(domain=domain, tt_model=tt_model.to(device), cost=cost, device=device)
    ########################################################
    # generate test set: task parameters drawn uniformly inside each task-domain range
    ns = 50
    test_task = torch.zeros(ns, len(domain_task)).to(device)
    for i, dom in enumerate(domain_task):
        unif = torch.distributions.uniform.Uniform(low=dom[0], high=dom[-1])
        test_task[:, i] = unif.sample((ns,)).to(device)

    ########################################################

    # Save the model
    torch.save({
        'mp': args.mp,
        'tt_model': ttgo.tt_model,
        'w': (args.w_goal, args.w_obst, args.w_ee, args.w_control),
        'b': (args.b_goal, args.b_obst, args.b_ee, args.b_control),
        'margin': args.margin,
        'domain': domain,
        'test_task': test_task,
        'x_obst': x_obst,
        'r_obst': r_obst,
        'n_joints': args.n_joints,
        'n_kp': args.n_kp,
        'dt': args.dt,
        'theta_0': theta_0,
        'theta_3': theta_3
    }, file_name)

    ########################################################
    # Test the model

    print("total-cost | goal | collidion | end-effector | control ")
    for alpha in [0.99, 0.9, 0.8, 0.5]:
        for n_samples_tt in [10, 50, 100, 1000]:
            _ = test_ttgo(ttgo=ttgo.clone(), cost=cost_to_print,
                          test_task=test_task, n_samples_tt=n_samples_tt,
                          alpha=alpha, device=device, test_rand=True)


# --------------------------------------------------------------------------------
# /toy_robots/plot_utils.py:
# --------------------------------------------------------------------------------
# 1 |
# 2 | '''
# 3 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
# 4 | Written by Suhan Shetty ,
# 5 |
# 6 | This file is part of TTGO.
# 7 |
# 8 | TTGO is free software: you can redistribute it and/or modify
# 9 | it under the terms of the GNU General Public License version 3 as
# 10 | published by the Free Software Foundation.
# 11 |
# 12 | TTGO is distributed in the hope that it will be useful,
# 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
# 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 15 | GNU General Public License for more details.
# 16 |
# 17 | You should have received a copy of the GNU General Public License
# 18 | along with TTGO. If not, see .
19 | ''' 20 | 21 | 22 | import numpy as np 23 | from matplotlib import pyplot as plt 24 | import seaborn as sns 25 | 26 | 27 | def plot_chain(joint_loc, link_lengths, x_obst=[], r_obst=[], x_target=[], rect_patch=[], 28 | batch=False, skip_frame=10, title=None, save_as=None,figsize=3, 29 | color_intensity=0.9, motion=False, alpha=0.5, contrast=0.4, idx_highlight=[], lw=7, task='ik'): 30 | 31 | fig = plt.figure(edgecolor=[0.1,0.1,0.1]) 32 | 33 | fig.set_size_inches(figsize, figsize) 34 | sns.set_theme() 35 | 36 | sns.set_context("paper") 37 | 38 | 39 | # fig.patch.set_facecolor('white') 40 | # fig.patch.set_alpha(0.9) 41 | xmax_ = 1.1*np.sum(link_lengths) 42 | ax = fig.add_subplot(111, aspect='equal', autoscale_on=False, 43 | xlim=(-xmax_, xmax_), ylim=(-xmax_, xmax_)) 44 | 45 | for i,x in enumerate(x_obst): 46 | circ = plt.Circle(x,r_obst[i],color='grey',alpha=0.5) 47 | ax.add_patch(circ) 48 | for i, x_ in enumerate(rect_patch): 49 | rect = plt.Rectangle(rect_patch[i][0:2],rect_patch[i][2],rect_patch[i][3], color='c',alpha=0.5) 50 | ax.add_patch(rect) 51 | color_ = ['g','r'] 52 | for i, x_ in enumerate(x_target): 53 | ax.plot(x_[0],x_[1],'or', markersize=10) 54 | 55 | 56 | 57 | if batch is False: 58 | x = joint_loc[:,0] 59 | y = joint_loc[:,1] 60 | k_ = 1*color_intensity 61 | color_ = [k_,k_,k_] 62 | plt.plot(x, y, 'o-',zorder=0, marker='o',color=color_,lw=lw,mfc='w', 63 | solid_capstyle='round') 64 | else: 65 | 66 | T = joint_loc.shape[0] 67 | 68 | if task=='via': 69 | ax.legend(["target","obstacle"]) 70 | idx = np.arange(0,int(T/2), skip_frame) 71 | k_ = np.linspace(0.3,0.7,len(idx))[::-1] 72 | k_[0]=1 73 | for count,i in enumerate(idx): 74 | # color_ = np.where(motion, 1-k_[count], contrast) 75 | x = joint_loc[i,:,0] 76 | y = joint_loc[i,:,1] 77 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='g',lw=lw,mfc='w', 78 | solid_capstyle='round', alpha= k_[count]) 79 | plt.plot(joint_loc[i,-1,0],joint_loc[i,-1,1],'oy', markersize=3) 80 | 81 | idx = idx = 
np.arange(int(T/2),T, skip_frame) 82 | k_ = np.linspace(0.2,0.5,len(idx)) 83 | k_[-1] = 1 84 | for count,i in enumerate(idx): 85 | # color_ = np.where(motion, 1-k_[count], contrast) 86 | x = joint_loc[i,:,0] 87 | y = joint_loc[i,:,1] 88 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='b',lw=lw,mfc='w', 89 | solid_capstyle='round', alpha= k_[count]) 90 | plt.plot(joint_loc[i,-1,0],joint_loc[i,-1,1],'oy',markersize=3) 91 | 92 | 93 | 94 | elif task=='via2': 95 | ax.legend(["target-1","target-2","obstacle"]) 96 | idx = np.arange(0,int(T/3), skip_frame) 97 | k_ = np.linspace(0.3,0.7,len(idx))[::-1] 98 | k_[-1]=1 99 | for count,i in enumerate(idx): 100 | # color_ = np.where(motion, 1-k_[count], contrast) 101 | x = joint_loc[i,:,0] 102 | y = joint_loc[i,:,1] 103 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='g',lw=lw,mfc='w', 104 | solid_capstyle='round', alpha= k_[count]) 105 | plt.plot(joint_loc[i,-1,0],joint_loc[i,-1,1],'og', markersize=3) 106 | 107 | idx = idx = np.arange(int(T/3),2*int(T/3), skip_frame) 108 | k_ = np.linspace(0.1,0.2,len(idx)) 109 | k_[-1] = 1 110 | for count,i in enumerate(idx): 111 | # color_ = np.where(motion, 1-k_[count], contrast) 112 | x = joint_loc[i,:,0] 113 | y = joint_loc[i,:,1] 114 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='r',lw=lw,mfc='w', 115 | solid_capstyle='round', alpha= k_[count]) 116 | plt.plot(joint_loc[i,-1,0],joint_loc[i,-1,1],'or',markersize=3) 117 | 118 | idx = idx = np.arange(2*int(T/3),T, skip_frame) 119 | k_ = np.linspace(0.1,0.2,len(idx)) 120 | k_[-1] = 1 121 | for count,i in enumerate(idx): 122 | # color_ = np.where(motion, 1-k_[count], contrast) 123 | x = joint_loc[i,:,0] 124 | y = joint_loc[i,:,1] 125 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='k',lw=lw,mfc='w', 126 | solid_capstyle='round', alpha= k_[count]) 127 | plt.plot(joint_loc[i,-1,0],joint_loc[i,-1,1],'ok',markersize=3) 128 | 129 | elif task=='reaching': 130 | ax.legend(["target","obstacle"]) 131 | idx = np.arange(0,int(T), skip_frame) 132 
| k_ = np.linspace(0.3,0.7,len(idx))[::-1] 133 | k_[0]=1 134 | for count,i in enumerate(idx): 135 | # color_ = np.where(motion, 1-k_[count], contrast) 136 | x = joint_loc[i,:,0] 137 | y = joint_loc[i,:,1] 138 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='g',lw=lw,mfc='w', 139 | solid_capstyle='round', alpha= k_[count]) 140 | plt.plot(joint_loc[i,-1,0],joint_loc[i,-1,1],'oy', markersize=3) 141 | elif task=='ik': 142 | ax.legend(["target","obstacle"]) 143 | idx = np.arange(0,int(T), skip_frame) 144 | for count,i in enumerate(idx): 145 | # color_ = np.where(motion, 1-k_[count], contrast) 146 | x = joint_loc[i,:,0] 147 | y = joint_loc[i,:,1] 148 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='g',lw=lw,mfc='w', 149 | solid_capstyle='round', alpha=alpha ) 150 | plt.plot(joint_loc[i,-1,0],joint_loc[i,-1,1],'oy', markersize=3,alpha=alpha) 151 | 152 | 153 | for count,i in enumerate(idx_highlight): 154 | color_ = [0.1]*3 155 | x = joint_loc[i,:,0] 156 | y = joint_loc[i,:,1] 157 | plt.plot(x, y, 'o-',zorder=0.9,marker='o',color='k',lw=lw,mfc='w', 158 | solid_capstyle='round', alpha=0.5) 159 | 160 | plt.plot(0,0,color='y',marker='o', markersize=15) 161 | plt.grid(True) 162 | 163 | if not title is None: 164 | plt.title(title) 165 | if not save_as is None: 166 | fig.savefig('./images/'+save_as+".jpeg",bbox_inches='tight', pad_inches=0.01, dpi=300) 167 | 168 | return plt 169 | 170 | 171 | 172 | ########################################################################### 173 | ########################################################################### 174 | 175 | def plot_point_mass(x_t, xmax=1, x_obst=[],r_obst=[],batch=False,title=None, save_as=None,figsize=3): 176 | 177 | fig = plt.figure(edgecolor=[0.1,0.1,0.1]) 178 | fig.set_size_inches(figsize, figsize) 179 | 180 | # fig.patch.set_facecolor('white') 181 | # fig.patch.set_alpha(0.9) 182 | ax = fig.add_subplot(111, aspect='equal', autoscale_on=False, 183 | xlim=(-xmax, xmax), ylim=(-xmax, xmax)) 184 | 185 | if not 
x_obst is None: 186 | for i,x in enumerate(x_obst): 187 | circ = plt.Circle(x,r_obst[i],color='grey',alpha=0.5) 188 | ax.add_patch(circ) 189 | # if not rect_patch is None: 190 | # rect = plt.Rectangle(rect_patch[0:2],rect_patch[2],rect_patch[3], color='c',alpha=0.5) 191 | # ax.add_patch(rect) 192 | 193 | 194 | ax.plot(x_t[:,0,0],x_t[:,0,1],'og', markersize=10) 195 | ax.plot(x_t[:,-1,0],x_t[:,-1,1],'or', markersize=10) 196 | 197 | for i in range(x_t.shape[0]): 198 | plt.plot(x_t[i,:,0],x_t[i,:,1],'-b') 199 | 200 | # ax.legend(["target","init","obstacle"]) 201 | plt.grid("True") 202 | 203 | if not title is None: 204 | plt.title(title) 205 | if not save_as is None: 206 | fig.savefig(save_as+".jpeg",bbox_inches='tight', pad_inches=0.01, dpi=300) 207 | 208 | return plt -------------------------------------------------------------------------------- /tt_vs_nn_batch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1ddacdf5-a59b-48b1-8431-57c9032fb439", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import time\n", 12 | "from tt_utils import *\n", 13 | "from fcn_approx_utils import GMM, NeuralNetwork" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "ff5e3456-9ee6-4705-a749-de3abd085edc", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "device = \"cpu\"#torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 24 | ] 25 | }, 26 | { 27 | "attachments": {}, 28 | "cell_type": "markdown", 29 | "id": "219ee14e-e65d-42a7-9339-05d7c48cef67", 30 | "metadata": {}, 31 | "source": [ 32 | "### Fit NN-model" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "69ccc462-4ef5-4321-9824-105fbc363d45", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "dim_list = [5,10,20]\n", 43 | "mix_list = [5,10,20]\n", 44 | 
"s_list =[0.15,0.3,0.45]\n", 45 | "\n", 46 | "Ndim = len(dim_list)\n", 47 | "Nmix = len(mix_list)\n", 48 | "Ns = len(s_list)\n", 49 | "Nt = 5\n", 50 | "nn_gmm_data_err = torch.empty((Ndim, Nmix, Ns,Nt))\n", 51 | "nn_gmm_data_time = torch.empty((Ndim, Nmix, Ns,Nt))\n", 52 | "\n", 53 | "\n", 54 | "dim_list = [5,10,20]\n", 55 | "mix_list = [10,20,40]\n", 56 | "s_list = [0.15,0.3,0.45]\n", 57 | "\n", 58 | "tt_gmm_data_err = torch.empty((Ndim, Nmix, Ns,Nt))\n", 59 | "tt_gmm_data_rank = torch.empty((Ndim, Nmix, Ns,Nt))\n", 60 | "tt_gmm_data_time = torch.empty((Ndim, Nmix, Ns,Nt))\n", 61 | "L=1.0\n", 62 | "for i, dim_ in enumerate(dim_list):\n", 63 | " for j, nmix_ in enumerate(mix_list):\n", 64 | " for k, s_ in enumerate(s_list):\n", 65 | " for p in range(Nt):\n", 66 | " print(\"###########\")\n", 67 | " print(i,j,k,p)\n", 68 | " print(dim_,nmix_,s_,p)\n", 69 | " print(\"###########\")\n", 70 | " gmm = GMM(n=dim_,nmix=nmix_,L=L,mx_coef=None,mu=None,s=s_, device=device)\n", 71 | " pdf = gmm.pdf\n", 72 | " ndata_train = 100000*dim_\n", 73 | " x_train = 2*L*(-0.5 + torch.rand((ndata_train,dim_)).to(device))\n", 74 | " y_train = pdf(x_train).view(-1,1)\n", 75 | " ndata_test = 10000\n", 76 | " x_test = 2*L*(-0.5 + torch.rand((ndata_test,dim_)).to(device))\n", 77 | " y_test = pdf(x_test)\n", 78 | " nn = NeuralNetwork(dim=dim_, width=64, lr=1e-3)\n", 79 | " data_train = torch.cat((x_train.view(-1,dim_),y_train.view(-1,1)),dim=-1)\n", 80 | " data_test = torch.cat((x_test.view(-1,dim_),y_test.view(-1,1)),dim=-1)\n", 81 | " nn.load_data(data_train, data_test)\n", 82 | " t1_nn = time.time()\n", 83 | " nn.train(num_epochs=10, batch_size=128, verbose=True)\n", 84 | " t2_nn = time.time()\n", 85 | " dt = t2_nn - t1_nn\n", 86 | " # Test the accuracy of NN over the test set\n", 87 | " y_nn = nn.model(x_test)\n", 88 | " mse_nn = ((y_nn.view(-1)-y_test.view(-1))**2).mean().detach().cpu() #(((y_nn.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean().detach()\n", 89 | " 
print(\"mse_nn: \", mse_nn)\n", 90 | " nn_gmm_data_err[i,j,k,p] = 1*mse_nn\n", 91 | " nn_gmm_data_time[i,j,k,p] = dt\n", 92 | " \n", 93 | " n_discretization = torch.tensor([200]*dim_).to(device)\n", 94 | " domain = [torch.linspace(-L,L,n_discretization[i_]).to(device) for i_ in range(dim_)] \n", 95 | " t1 = time.time()\n", 96 | " tt_gmm = cross_approximate(fcn=pdf, max_batch=10**6, domain=domain, \n", 97 | " rmax=200, nswp=20, eps=1e-3, verbose=False, \n", 98 | " kickrank=10, device=device)\n", 99 | " t2 = time.time()\n", 100 | "\n", 101 | " y_tt = get_value(tt_model=tt_gmm, x=x_test.to(device), domain=domain, \n", 102 | " n_discretization=n_discretization , max_batch=10**5, device=device)\n", 103 | "\n", 104 | " mse_tt = ((y_tt.view(-1)-y_test.view(-1))**2).mean()\n", 105 | " print(\"mse_tt: \", mse_tt)\n", 106 | "\n", 107 | " tt_gmm_data_err[i,j,k,p] = 1*mse_tt\n", 108 | " tt_gmm_data_rank[i,j,k,p] = max(tt_gmm.ranks_tt)\n", 109 | " tt_gmm_data_time[i,j,k,p] = (t2-t1)\n", 110 | "\n", 111 | "\n", 112 | " torch.save({'dim_list':dim_list,'s_list':s_list,'mix_list':mix_list,\n", 113 | " 'tt_gmm_data_err':tt_gmm_data_err,\n", 114 | " 'tt_gmm_data_time':tt_gmm_data_time,\n", 115 | " 'tt_gmm_data_rank':tt_gmm_data_rank},'tt_gmm_data_2.pt')\n", 116 | " torch.save({'dim_list':dim_list,'s_list':s_list,'mix_list':mix_list,\n", 117 | " 'nn_gmm_data_err':nn_gmm_data_err,\n", 118 | " 'nn_gmm_data_time':nn_gmm_data_time},'nn_gmm_data_2.pt')" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "f999a701", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "tt_data = torch.load('tt_gmm_data_1.pt')\n", 129 | "nn_data = torch.load('nn_gmm_data_1.pt')" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "3e4962b9", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "tt_mean = tt_data['tt_gmm_data_err'].mean(dim=(-1,-2))\n", 140 | "tt_std = 
tt_data['tt_gmm_data_err'].std(dim=(-1,-2))\n", 141 | "\n", 142 | "nn_mean = nn_data['nn_gmm_data_err'].mean(dim=(-1,-2))\n", 143 | "nn_std = nn_data['nn_gmm_data_err'].std(dim=(-1,-2))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "8d6d9993-5b5e-4ffa-8f53-8974110f3323", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "tt_nn = (tt_data['tt_gmm_data_err']/nn_data['nn_gmm_data_err']).detach().cpu()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "70dc9e1d-b3f4-4726-b4f4-86043d4017ae", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "y_mean = tt_nn.mean(dim=(-1,-2))\n", 164 | "y_std = tt_nn.std(dim=(-1,-2))" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "82bf7fe3", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "dim_list = tt_data['dim_list']\n", 175 | "dim_list" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "id": "eb0fee0a", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "mix_list = tt_data['mix_list']\n", 186 | "mix_list" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "70826373-40cf-4673-8ea3-09c2f23c4a7e", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "s_list = tt_data['s_list']\n", 197 | "s_list" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "f59cffee-56e8-4c4d-a356-83a640e5a68c", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "y_mean.shape" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "22bf0f8a-b582-4087-8db7-8aec1348622b", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "import numpy as np\n", 218 | "import matplotlib\n", 219 | "import matplotlib.pyplot as plt\n", 220 | 
"matplotlib.rcParams.update({'font.size': 10})\n", 221 | "\n", 222 | "# Define your data\n", 223 | "conditions = ['k=5','k=10', 'k=20']\n", 224 | "m1_means = y_mean[0] # Mean values for Method 1\n", 225 | "m1_stdevs = y_std[0] # Standard deviations for Method 1\n", 226 | "m2_means = y_mean[1] # Mean values for Method 2\n", 227 | "m2_stdevs = y_std[1] # Standard deviations for Method 2\n", 228 | "m3_means = y_mean[2] # Mean values for Method 2\n", 229 | "m3_stdevs = y_std[2] # Standard deviations for Method 2\n", 230 | "m4_means = y_mean[3] # Mean values for Method 2\n", 231 | "m4_stdevs = y_std[3] # Standard deviations for Method 2\n", 232 | "\n", 233 | "\n", 234 | "\n", 235 | "# Set the width of the bars\n", 236 | "bar_width = 0.1\n", 237 | "\n", 238 | "# Set the positions of the bars on the x-axis\n", 239 | "r1 = np.arange(len(conditions))-0.1\n", 240 | "r2 = [x + bar_width for x in r1]\n", 241 | "r3 = [x + bar_width for x in r2]\n", 242 | "r4 = [x + bar_width for x in r3]\n", 243 | "\n", 244 | "# Plot the bars\n", 245 | "plt.bar(r1, m1_means, width=bar_width, label='d=10', capsize=5)\n", 246 | "plt.bar(r2, m2_means, width=bar_width, label='d=20', capsize=5)\n", 247 | "plt.bar(r3, m3_means, width=bar_width, label='d=30', capsize=5)\n", 248 | "plt.bar(r4, m4_means, width=bar_width, label='d=40', capsize=5)\n", 249 | "\n", 250 | "# Add labels, title, and legend\n", 251 | "# plt.xlabel('number of mixture components',fontsize='12')\n", 252 | "plt.ylabel('Error ratio TT/NN',fontsize='13')\n", 253 | "# plt.title('TT vs NN')\n", 254 | "plt.xticks([r + bar_width/2 for r in range(len(conditions))], conditions, fontsize=12)\n", 255 | "plt.legend(ncol=4)\n", 256 | "plt.yticks(fontsize=12)\n", 257 | "plt.yscale('log')\n", 258 | "plt.savefig('tt_vs_nn.jpeg', bbox_inches='tight',pad_inches=0.01, dpi=1000)\n", 259 | "# Show the plot\n", 260 | "plt.show()\n" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "id": "3da23969", 267 | 
"metadata": {}, 268 | "outputs": [], 269 | "source": [] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "id": "a7edc12a-b8bb-461e-ba8d-0eb9d612061b", 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "id": "82295f86-373d-4382-baae-b70535824379", 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "Python 3 (ipykernel)", 291 | "language": "python", 292 | "name": "python3" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.9.7" 305 | } 306 | }, 307 | "nbformat": 4, 308 | "nbformat_minor": 5 309 | } 310 | -------------------------------------------------------------------------------- /tt_vs_nn_vs_bgmm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "53ff8cc6", 7 | "metadata": {}, 8 | "source": [ 9 | "'''\n", 10 | " \n", 11 | " Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/\n", 12 | " Written by Suhan Shetty ,\n", 13 | " \n", 14 | " This file is part of TTGO.\n", 15 | "\n", 16 | " TTGO is free software: you can redistribute it and/or modify\n", 17 | " it under the terms of the GNU General Public License version 3 as\n", 18 | " published by the Free Software Foundation.\n", 19 | "\n", 20 | " TTGO is distributed in the hope that it will be useful,\n", 21 | " but WITHOUT ANY WARRANTY; without even the implied warranty of\n", 22 | " MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the\n", 23 | " GNU General Public License for more details.\n", 24 | "\n", 25 | " You should have received a copy of the GNU General Public License\n", 26 | " along with TTGO. If not, see .\n", 27 | "'''\n" 28 | ] 29 | }, 30 | { 31 | "attachments": {}, 32 | "cell_type": "markdown", 33 | "id": "4552dce4", 34 | "metadata": {}, 35 | "source": [ 36 | "### Comparison of performance of TT vs NN vs BGMM\n", 37 | "In this notebook, we compare the approximation accuracy and training speed of TT and NN. NN is a great tool for data-driven function approximation. However, it is less effective when the function to be approximated is given explicitly. On the other hand, TT is equipped with a powerful technique called TT-Cross that can approximate a given function in TT format more efficiently. It directly takes the function to be approximated as input and outputs the function in TT format. Moreover, TT representation, unlike NN, offers other benefits such as fast sampling, optimization and algebraic operations.\n", 38 | "\n", 39 | "We also compare it against Bayesian GMM (BGMM). Note that, unlike NN and TT, BGMM must be fit to exact samples from the reference pdf. 
" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "id": "1ddacdf5-a59b-48b1-8431-57c9032fb439", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import torch\n", 50 | "from tt_utils import *\n", 51 | "from fcn_approx_utils import GMM, NeuralNetwork, BGMM\n", 52 | "import time \n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "id": "ff5e3456-9ee6-4705-a749-de3abd085edc", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "device = \"cpu\"#torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "id": "29d66285-42f9-4bb4-bc0f-57d82e20bd5b", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "dim = 5\n", 73 | "L = 1\n", 74 | "nmix = 20\n", 75 | "s = 0.2\n", 76 | "\n", 77 | "# generate an arbitrary function (gmm with centers and covariances chosen randomly)\n", 78 | "gmm = GMM(n=dim,nmix=nmix,L=L,mx_coef=None,mu=None,s=s, device=device) \n", 79 | "pdf = gmm.pdf" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "id": "7239100c-5790-43f7-bbc3-bf6a3aeedb35", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# For testing and training NN\n", 90 | "ndata_train = int(1e5)\n", 91 | "ndata_test = 10000\n", 92 | "\n", 93 | "x_train = 2*L*(-0.5 + torch.rand((ndata_train,dim)).to(device))\n", 94 | "y_train = pdf(x_train)\n", 95 | "\n", 96 | "x_test = 2*L*(-0.5 + torch.rand((ndata_test,dim)).to(device))\n", 97 | "y_test = pdf(x_test)\n", 98 | "\n", 99 | "data_train = torch.cat((x_train.view(-1,dim),y_train.view(-1,1)),dim=-1)\n", 100 | "data_test = torch.cat((x_test.view(-1,dim),y_test.view(-1,1)),dim=-1)" 101 | ] 102 | }, 103 | { 104 | "attachments": {}, 105 | "cell_type": "markdown", 106 | "id": "e2f2b2db", 107 | "metadata": {}, 108 | "source": [ 109 | "### Fit TT Model" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | 
"execution_count": 5, 115 | "id": "4b86ba96-9280-4706-9f55-820ea55ce398", 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "cross device is cpu\n", 123 | "Cross-approximation over a 5D domain containing 3.2e+11 grid points:\n", 124 | "iter: 0 | tt-error: 1.000e+00, test-error:9.807e-01 | time: 0.0699 | largest rank: 1\n", 125 | "iter: 1 | tt-error: 2.078e+00, test-error:8.744e-01 | time: 0.2167 | largest rank: 4\n", 126 | "iter: 2 | tt-error: 1.188e+00, test-error:7.371e-01 | time: 0.3070 | largest rank: 7\n", 127 | "iter: 3 | tt-error: 7.334e-01, test-error:5.693e-01 | time: 0.4282 | largest rank: 10\n", 128 | "iter: 4 | tt-error: 5.292e-01, test-error:3.752e-01 | time: 0.5595 | largest rank: 13\n", 129 | "iter: 5 | tt-error: 1.260e-01, test-error:3.545e-01 | time: 0.7153 | largest rank: 16\n", 130 | "iter: 6 | tt-error: 3.123e-01, test-error:7.426e-03 | time: 0.9527 | largest rank: 19\n", 131 | "iter: 7 | tt-error: 5.737e-03, test-error:1.889e-15 | time: 1.2702 | largest rank: 22\n", 132 | "iter: 8 | tt-error: 2.039e-08, test-error:1.663e-15 | time: 1.6888 | largest rank: 25 <- converged: eps < 0.001\n", 133 | "Did 2543400 function evaluations, which took 1.37s (1.856e+06 evals/s)\n", 134 | "\n", 135 | "time taken: 1.7976365089416504\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "# Represent the function in TT format (unsupervised learning and kind of non-parametric)\n", 141 | "n_discretization = torch.tensor([200]*dim).to(device)\n", 142 | "domain = [torch.linspace(-L,L,n_discretization[i]).to(device) for i in range(dim)] \n", 143 | "\n", 144 | "t1 = time.time()\n", 145 | "tt_gmm = cross_approximate(fcn=pdf, max_batch=10**6, domain=domain, \n", 146 | " rmax=200, nswp=20, eps=1e-3, verbose=True, \n", 147 | " kickrank=3, device=device)\n", 148 | "t2 = time.time()\n", 149 | "print(\"time taken: \", t2-t1)\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 
6, 155 | "id": "bd26218e-27c2-4c03-9e52-25f2c0640edd", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "mse_tt: tensor(5.2491e-09)\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "# Test the accuracy of TT over the test set \n", 168 | "y_tt = get_value(tt_model=tt_gmm, x=x_test.to(device), domain=domain, \n", 169 | " n_discretization=n_discretization , max_batch=10**5, device=device)\n", 170 | "\n", 171 | "mse_tt = ((y_tt.view(-1)-y_test.view(-1))**2).mean()\n", 172 | "print(\"mse_tt: \", mse_tt)" 173 | ] 174 | }, 175 | { 176 | "attachments": {}, 177 | "cell_type": "markdown", 178 | "id": "79525b5f", 179 | "metadata": {}, 180 | "source": [ 181 | "### Fit NN Model" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 7, 187 | "id": "66181567-4394-4706-9c8f-9fb526bd7aab", 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "time taken: 3.4809112548828125e-05\n", 195 | "Done!\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "# Fit NN\n", 201 | "lr= 1e-3\n", 202 | "batch_size = 128\n", 203 | "epochs = 1\n", 204 | "nn = NeuralNetwork(dim, width=64, lr=1e-3, device=device)\n", 205 | "nn.load_data(data_train, data_test)\n", 206 | "t1 = time.time()\n", 207 | "# nn.train(num_epochs=epochs, batch_size=batch_size, verbose=True)\n", 208 | "t2 = time.time()\n", 209 | "print(\"time taken: \", t2-t1)\n", 210 | "print(\"Done!\")" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 8, 216 | "id": "7539baea-ecda-4204-ad3b-58b246262033", 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "mse_nn: tensor(0.0167)\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "# Test the accuracy of NN over the test set\n", 229 | "y_nn = nn.model(x_test)\n", 230 | "mse_nn = 
((y_nn.view(-1)-y_test.view(-1))**2).mean().detach()\n", 231 | "print(\"mse_nn: \", mse_nn)" 232 | ] 233 | }, 234 | { 235 | "attachments": {}, 236 | "cell_type": "markdown", 237 | "id": "43eff033", 238 | "metadata": {}, 239 | "source": [ 240 | "### Fit BGMM Model" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 9, 246 | "id": "05c0a3bd", 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "name": "stdout", 251 | "output_type": "stream", 252 | "text": [ 253 | "mse_bgmm: tensor(0.0002)\n" 254 | ] 255 | }, 256 | { 257 | "name": "stderr", 258 | "output_type": "stream", 259 | "text": [ 260 | "/idiap/temp/sshetty/miniconda/envs/pyml/lib/python3.9/site-packages/sklearn/mixture/_base.py:268: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", 261 | " warnings.warn(\n" 262 | ] 263 | } 264 | ], 265 | "source": [ 266 | "# Sample data and Train BGMM\n", 267 | "X_sample = gmm.generate_sample(x_train.shape[0]) # sample from reference distribution\n", 268 | "bgmm = BGMM(nmix=nmix)\n", 269 | "X_numpy = X_sample.detach().cpu().numpy()\n", 270 | "bgmm.load_data(X_numpy)\n", 271 | "bgmm.fit()\n", 272 | "\n", 273 | "# Test bgmm\n", 274 | "y_bgmm = bgmm.pdf(x_test.detach().cpu().numpy())\n", 275 | "y_test_numpy = y_test.detach().cpu().numpy()\n", 276 | "mse_bgmm = ((y_bgmm.reshape(-1)-y_test_numpy.reshape(-1))**2).mean()\n", 277 | "print(\"mse_bgmm: \", mse_bgmm)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 10, 283 | "id": "65914903", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | " mse_tt:5.249069685752172e-09,\n", 291 | " mse_bgmm:0.00016431352560326084,\n", 292 | " mse_nn:0.016652875400446403\n" 293 | ] 294 | } 295 | ], 296 | "source": [ 297 | "print(f\" mse_tt:{mse_tt},\\n mse_bgmm:{mse_bgmm},\\n mse_nn:{mse_nn}\")" 298 | ] 299 | }, 300 | { 301 | 
"cell_type": "code", 302 | "execution_count": null, 303 | "id": "8ec3fb22", 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [] 307 | } 308 | ], 309 | "metadata": { 310 | "kernelspec": { 311 | "display_name": "Python 3 (ipykernel)", 312 | "language": "python", 313 | "name": "python3" 314 | }, 315 | "language_info": { 316 | "codemirror_mode": { 317 | "name": "ipython", 318 | "version": 3 319 | }, 320 | "file_extension": ".py", 321 | "mimetype": "text/x-python", 322 | "name": "python", 323 | "nbconvert_exporter": "python", 324 | "pygments_lexer": "ipython3", 325 | "version": "3.9.7" 326 | } 327 | }, 328 | "nbformat": 4, 329 | "nbformat_minor": 5 330 | } 331 | -------------------------------------------------------------------------------- /ttgo.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see . 
18 | ''' 19 | 20 | 21 | """ 22 | This class contains the pytorch implementation of the whole pipeline of TTGO: 23 | - Input: 24 | - cost: the cost function, 25 | - tt_model: corresponding to the pdf (e.g.: tt model of exp(-cost(x))) 26 | - domain: the discretization of the domain of the pdf, 27 | - max_batch: specifies the maximum batch size (decrease it if you encounter memory issues) 28 | - sites_task: a list containing the modes corresponding to the task parameters (optional). You can instead 29 | use set_sites() at test time 30 | 31 | - cross_approximate: Fit the TT-model to the given PDF using TT-Cross (Uses tntorch library) 32 | - Sample from the TT-Model 33 | - set the sites/modes for task parameters using set_sites() before calling sample (or use set_task at initialization) 34 | - two different samplers are provided: based on 1-norm or 2-norm as the density function 35 | - prioritized sampling can be done by setting alpha parameter in sampling() 36 | - Choose the best sample(s) 37 | - Fine-tune the best sample(s) using gradient-based optimization 38 | 39 | """ 40 | 41 | import numpy as np 42 | import torch 43 | import tntorch as tnt 44 | from scipy.optimize import minimize 45 | from scipy.optimize import Bounds 46 | import copy 47 | import warnings 48 | import tt_utils 49 | # torch.set_default_dtype(torch.float64) 50 | warnings.filterwarnings("ignore") 51 | 52 | 53 | class TTGO: 54 | def __init__(self, domain, cost, tt_model, sites_task=[],max_batch=10**5, device="cpu"): 55 | self.device = device 56 | self.domain = [x.to(self.device) for x in domain] # a list of 1-D torch-tensors containing the discretization points along each axis/mode 57 | self.min_domain = torch.tensor([x[0] for x in domain]).to(device) 58 | self.max_domain = torch.tensor([x[-1] for x in domain]).to(device) 59 | self.n = torch.tensor([len(x) for x in domain]).to(device) # number of discretization points along each axis/mode 60 | self.dim = len(domain) # dimension of the tensor 61 | 
self.tt_model = tt_model.to(device) 62 | self.canonicalize() 63 | self.cost = cost # the total cost function 64 | self.sites_task = sites_task 65 | 66 | # Box bounds (first and last point of each domain discretization) for scipy-based optimization/fine-tuning 67 | lb = []; ub = [] 68 | for domain_i in self.domain: 69 | lb.append(domain_i[0].item()) 70 | ub.append(domain_i[-1].item()) 71 | self.scipy_bounds = Bounds(np.array(lb),np.array(ub)) 72 | 73 | 74 | def to(self,device='cpu'): 75 | self.device = device 76 | self.domain = [x.to(device) for x in self.domain] 77 | if self.tt_model: 78 | self.tt_model.to(device) 79 | 80 | def clone(self): 81 | return copy.deepcopy(self) 82 | 83 | def pdf(self,x): 84 | return -self.cost(x) # negated cost (larger is better) 85 | 86 | 87 | def idx2domain(self,I): 88 | ''' Map the index of the tensor/discretization to the domain''' 89 | return tt_utils.idx2domain(I=I, domain=self.domain, device=self.device) 90 | 91 | 92 | def domain2idx(self, x_task): 93 | ''' Map points x_task from the domain to indices of the discretization (only the first x_task.shape[-1] domains are used) ''' 94 | return tt_utils.domain2idx(x=x_task, domain=self.domain[:x_task.shape[-1]], device=self.device, uniform=False) 95 | 96 | 97 | def __getitem__(self,idxs): 98 | return self.tt_model[idxs].torch() 99 | 100 | def choose_best_sample(self,samples): 101 | ''' 102 | Given the samples (candidates for optima), return the lowest-cost sample for each batch entry 103 | samples: batch_size x n_samples x dim (batch_size corresponds to the number of task-parameter instances); returns batch_size x 1 x dim 104 | ''' 105 | cost_values = self.cost(samples.view(-1,samples.shape[-1])).view(samples.shape[0],samples.shape[1]) 106 | idx = torch.argmax(-cost_values, dim=-1) 107 | best_sample = samples[torch.arange(samples.shape[0]).unsqueeze(1),idx.view(-1,1),:] 108 | return best_sample.view(-1, 1, samples.shape[-1]) # batch_size x 1 x dim 109 | 110 | 111 | def choose_top_k_sample(self,samples,k=1): 112 | '''Given the samples (batch_size x n_samples x dim), return the k lowest-cost samples per batch entry (batch_size x k x dim) ''' 113 | cost_values = self.cost(samples.view(-1,samples.shape[-1])).view(samples.shape[0],samples.shape[1]) 114 | values, idx = 
torch.topk(-cost_values, k, dim=-1) 115 | return samples[torch.arange(samples.shape[0]).unsqueeze(1),idx,:] 116 | 117 | 118 | def optimize(self, x, bound=True, method='SLSQP', tol=1e-3): 119 | ''' 120 | Optimize (minimize self.cost) from an initial guess x using scipy's minimize; returns (x_opt as a 1 x dim tensor on self.device, the scipy result object). 121 | To Do: Move it to pytorch based optimization instead of depending on scipy (slow) 122 | method: 'L-BFGS-B' or 'SLSQP' 123 | bound: if True the optimization (decision) variables will be constrained to the domain bounds (self.scipy_bounds). The gradient entries at self.sites_task (task variables) are zeroed. 124 | ''' 125 | # pytorch-to-numpy interface: scipy works on flat numpy arrays; the cost is evaluated on a batch of one point 126 | @torch.enable_grad() 127 | def cost_fcn(x): 128 | return self.cost(torch.from_numpy(x).reshape(1,-1).to(self.device)).to("cpu").numpy() 129 | @torch.enable_grad() 130 | def jacobian_cost(x): 131 | jac= torch.autograd.functional.jacobian(self.cost,torch.from_numpy(x).reshape(1,-1).to(self.device)).reshape(-1) 132 | jac[self.sites_task] = 0 # NOTE(review): if self.sites_task is None, jac[None] = 0 zeroes the whole gradient - confirm the constructor default 133 | return jac.cpu().numpy().reshape(-1) 134 | 135 | if bound ==True: # constrained optimization 136 | results = minimize(cost_fcn, x.cpu().numpy().reshape(-1), method=method,jac=jacobian_cost, tol=tol, bounds=self.scipy_bounds) 137 | else: # unconstrained optimization 138 | results = minimize(cost_fcn, x.cpu().numpy().reshape(-1), method=method,jac=jacobian_cost, tol=tol) 139 | return torch.from_numpy(results.x).view(1,-1).to(self.device), results 140 | 141 | 142 | def sample_tt(self, x_task=None, n_samples=500, deterministic=False, alpha=0.75): 143 | ''' Draw n_samples candidate solutions from the tt-model (tt_utils.stochastic_top_k with alpha if not deterministic, else tt_utils.deterministic_top_k). If x_task is given, it is passed to the sampler as the first x_task.shape[-1] (task) variables and those sites are recorded in self.sites_task ''' 144 | if x_task is None: 145 | n_discretization_task = None 146 | else: 147 | self.sites_task=np.arange(x_task.shape[-1]) 148 | n_discretization_task = self.n[:x_task.shape[-1]] 149 | if not deterministic: 150 | samples = tt_utils.stochastic_top_k(tt_cores=self.tt_model.tt().cores[:], domain=self.domain, 151 | n_discretization_x=n_discretization_task , x=x_task, n_samples=n_samples, 152 | alpha=alpha, device=self.device) 153 | else: 154 | samples = tt_utils.deterministic_top_k(tt_cores=self.tt_model.tt().cores[:], domain=self.domain, 155 | 
n_discretization_x=n_discretization_task, x=x_task, n_samples=n_samples, 156 | device=self.device) 157 | return samples 158 | 159 | def sample_random(self, n_samples, x_task=None): 160 | ''' Sample n_samples points uniformly from the domain (batch_size=1); if x_task is given, the first x_task.shape[-1] variables of every sample are set to x_task and recorded in self.sites_task ''' 161 | samples = tt_utils.sample_random(batch_size=1, n_samples=n_samples, domain=self.domain, device=self.device) 162 | if x_task is not None: 163 | self.sites_task=np.arange(x_task.shape[-1]) 164 | samples[0,:,:x_task.shape[-1]] = x_task 165 | return samples 166 | 167 | 168 | 169 | def canonicalize(self): 170 | ''' Canonicalize the tt-cores ''' 171 | self.tt_model = tt_utils.tt_canonicalize(self.tt_model,site=0).to(self.device) 172 | 173 | 174 | def gradient_optimization(self,x, is_site_fixed, GN=True, lr=1e-2, n_step=10): 175 | ''' 176 | Given a batch of initializations x, fine-tune the solutions (the objective passed to tt_utils.gradient_optimization is self.pdf, i.e. the negated cost) 177 | is_site_fixed: a list or tensor. is_site_fixed[i]=1 if x[:,i] is fixed/constant (e.g. task variables and discrete variables) 178 | GN=True => Gauss Newton else gradient-descent/ascent with learning rate lr 179 | n_step: number of steps of gd or GN 180 | ''' 181 | # NOTE(review): confirm self.min_domain / self.max_domain are defined somewhere (__init__ only sets self.scipy_bounds next to the domain bounds) 182 | x_opt = tt_utils.gradient_optimization(x, fcn=self.pdf, is_site_fixed=is_site_fixed, 183 | x_min=self.min_domain, x_max=self.max_domain, 184 | lr=lr, n_step=n_step, GN=GN, max_batch=10**4, device=self.device) 185 | return x_opt 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/ 3 | Written by Suhan Shetty , 4 | 5 | This file is part of TTGO. 6 | 7 | TTGO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License version 3 as 9 | published by the Free Software Foundation. 
10 | 11 | TTGO is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with TTGO. If not, see https://www.gnu.org/licenses/. 18 | ''' 19 | 20 | 21 | import torch 22 | import numpy as np 23 | np.set_printoptions(2, suppress=True) 24 | torch.set_printoptions(2, sci_mode=False) 25 | 26 | def test_ttgo(ttgo, cost, test_task, n_samples_tt, 27 | deterministic=True, alpha=0, device='cpu', 28 | test_rand=False, robotics=True, cut_total=0.33): 29 | ''' 30 | Test TTGO for a given application 31 | test_task: a batch of task parameters to test (one row per test case) 32 | n_samples_tt: number of samples drawn from the tt-model for each test case 33 | n_samples_rand: (not an argument; set equal to n_samples_tt) number of uniform samples for the random-initialization baseline 34 | alpha: choose a value between (0,1) for prioritized sampling (passed to ttgo.sample_tt) 35 | deterministic: passed to ttgo.sample_tt to choose deterministic or stochastic sampling (check the paper) 36 | cost: the cost function, evaluated on the batches of solutions; test_rand: if True also print the random-baseline costs (the baseline is computed either way) 37 | ''' 38 | import time 39 | test_task = test_task.to(device) 40 | n=ttgo.dim 41 | n_samples_rand = 1*n_samples_tt 42 | n_test = test_task.shape[0] 43 | 44 | state_tt = torch.zeros(n_test,n).to(device); state_tt_opt = state_tt.clone() 45 | 46 | state_rand = state_tt.clone(); state_rand_opt = state_tt.clone() 47 | 48 | tt_t = torch.zeros(n_test).to(device); rand_t = tt_t.clone() 49 | tt_nit = tt_t.clone(); rand_nit = tt_t.clone() 50 | 51 | for i,sample_task in enumerate(test_task): 52 | t1 = time.time() 53 | # sample from tt 54 | samples = ttgo.sample_tt(n_samples=n_samples_tt, 55 | x_task=sample_task.reshape(1,-1),alpha=alpha,deterministic=deterministic) 56 | # choose the best solution 57 | state = ttgo.choose_best_sample(samples) 58 | t2= time.time() 59 | # optimize 60 | state_opt, results = ttgo.optimize(state) 61 | t3 = time.time() 62 | tt_nit_i = results.nit 
63 | state_tt[i,:]= 1*state 64 | state_tt_opt[i,:]= 1*state_opt 65 | 66 | t4 = time.time() 67 | # sample from uniform distribution 68 | samples_rand = ttgo.sample_random(n_samples=n_samples_rand, 69 | x_task=sample_task.reshape(1,-1)) 70 | # choose the best sample 71 | state = ttgo.choose_best_sample(samples_rand) 72 | t5=time.time() 73 | # optimize 74 | state_opt, results = ttgo.optimize(state) 75 | t6=time.time() 76 | rand_nit_i = results.nit 77 | 78 | state_rand[i,:]= 1*state 79 | state_rand_opt[i,:]= 1*state_opt 80 | 81 | tt_t[i]=(t2-t1);rand_t[i]=(t5-t4); # timings cover sampling + best-sample selection only (optimize() runs after t2 / t5) 82 | tt_nit[i] = tt_nit_i; rand_nit[i] = rand_nit_i 83 | 84 | costs_tt = cost(state_tt);costs_tt_opt = cost(state_tt_opt) 85 | costs_rand = cost(state_rand);costs_rand_opt = cost(state_rand_opt) 86 | 87 | print("################################################################") 88 | print("################################################################") 89 | print("deterministic:{} | alpha:{} | n_samples_tt:{} | n_samples_rand:{} | ".format(deterministic, 90 | alpha,n_samples_tt,n_samples_rand)) 91 | print('################################################################') 92 | print("################################################################") 93 | 94 | print("Cost TT (raw) : ", torch.mean(costs_tt,dim=0)) 95 | print("Cost TT (optimized) : ", torch.mean(costs_tt_opt,dim=0)) 96 | 97 | if test_rand==True: 98 | print("Cost rand (raw) : ", torch.mean(costs_rand,dim=0)) 99 | print("Cost rand (optimized) : ", torch.mean(costs_rand_opt,dim=0)) 100 | 101 | if robotics==True: 102 | 103 | n_test = costs_tt.shape[0] 104 | idx_tt = costs_tt_opt[:,0]ijl',self.Phi,w) # batch x time x n 207 | z_t = z_t - z_t[:,0,:][:,None,:] # so that z(0) = 0 208 | x_t = x_0 + z_t 209 | x_t_bounded = self.bound_traj(x_t) # clip the trajectory to maintain the upper and lower limits 210 | return x_t_bounded #.reshape(batch_size,self.T,self.n) 211 | 212 | def gen_traj_p2p(self,x_0, x_f, w): 213 | ''' 214 | generate trajectory with 
boundary conditions satisfied (starts at x_0 and ends at x_f, assuming self.t runs from 0 to 1) 215 | x_0: batch x n, initial state 216 | x_f: batch x n, final state 217 | w: batch x (K*n), weights of the basis functions (reshaped to batch x K x n) 218 | ''' 219 | batch_size = w.shape[0] 220 | x_0 = x_0.reshape(batch_size,1,self.n).repeat(1,self.T,1) #batch x time x n 221 | x_f = x_f.reshape(batch_size,1,self.n).repeat(1,self.T,1) #batch x time x n 222 | w = w.reshape(batch_size,self.K,self.n) 223 | z_t = torch.einsum('jk,ikl->ijl',self.Phi,w) # batch x time x n 224 | z_0 = z_t[:,0,:][:,None,:] 225 | z_f = z_t[:,-1,:][:,None,:] 226 | x_t = x_0 + z_t - z_0 + torch.einsum('j,ijk->ijk',self.t,x_f-x_0+z_0-z_f) # x(t) = x(0)+ z(t)-z(0)+t*(x(1)-x(0)+z(0)-z(1)) 227 | x_t_bounded = self.bound_traj(x_t) # clip the trajectory to maintain the upper and lower limits 228 | return x_t_bounded # NOTE(review): bound_traj returns T+2k time steps (k=4), i.e. (batch_size,self.T+8,self.n) rather than (batch_size,self.T,self.n) - confirm intended 229 | 230 | 231 | def bound_traj(self,x): 232 | ''' 233 | clip the given trajectories (batch x T x n) to the limits shrunk inward by 1% of the range, 234 | then smooth them with a 7-tap weighted moving average over edge-padded copies (keeping the clipped boundary values); returns batch x (T+2k) x n 235 | ''' 236 | delta = self.upper_bound-self.lower_bound 237 | lower_x = self.lower_bound + delta*0.01 238 | upper_x = self.upper_bound - delta*0.01 239 | x = torch.clip(x, lower_x, upper_x ) # clip it 240 | 241 | # weighted moving average (weights 1,2,3,8,3,2,1; sum 20) over x padded with 2k copies of each end point (also ensures zero velocity at the boundaries) 242 | k = 4 # set (k>=3): the filter reads offsets -3..+3 around index k 243 | x = torch.cat((x[:,0,:][:,None,:].repeat(1,2*k,1), x, 244 | x[:,-1,:][:,None,:].repeat(1,2*k,1)),dim=1) 245 | 246 | cum_l = x[:,k:-k,:].shape[1] 247 | cum_x = 8*x[:,k:k+cum_l,:]+3*(x[:,(k-1):(k-1+cum_l),:]+ 248 | x[:,(k+1):(k+1+cum_l),:])+2*(x[:,(k-2):(k-2+cum_l),:]+ 249 | x[:,(k+2):(k+2+cum_l),:])+1*(x[:,(k-3):(k-3+cum_l),:]+ 250 | x[:,(k+3):(k+3+cum_l),:]) 251 | cum_w = 2*(4+3+2+1) 252 | 253 | x_transformed = cum_x/cum_w 254 | 255 | return x_transformed 256 | --------------------------------------------------------------------------------