├── README.md
├── __pycache__
│   ├── p3_helper.cpython-310.pyc
│   ├── p3_helper.cpython-38.pyc
│   ├── pose_cnn.cpython-310.pyc
│   └── pose_cnn.cpython-38.pyc
├── demo.py
├── image
│   ├── 6d1.png
│   ├── 6d2.png
│   ├── 6d3.png
│   ├── 6d4.png
│   ├── 6d5.png
│   └── 6d6.png
├── inference.py
├── p3_helper.py
├── pose_cnn.py
├── requirements.txt
├── rob599
│   ├── PROPSPoseDataset.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── PROPSPoseDataset.cpython-310.pyc
│   │   ├── PROPSPoseDataset.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── grad.cpython-310.pyc
│   │   ├── grad.cpython-38.pyc
│   │   ├── submit.cpython-38.pyc
│   │   ├── utils.cpython-310.pyc
│   │   └── utils.cpython-38.pyc
│   ├── grad.py
│   ├── submit.py
│   └── utils.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # PoseCNN_pytorch
2 | This is a PyTorch implementation of PoseCNN for 6D object pose estimation on the PROPS dataset.
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | # System Requirements
13 | We tested the code on:
14 | ```bash
15 | PyTorch version: 2.3.1
16 | CUDA version: 12.1
17 | Ubuntu 22.04
18 | GeForce RTX 4070 and 4090
19 |
20 | ```
21 |
22 | # Dependencies
23 |
24 | The project requires the following Python libraries and versions:
25 |
26 | | Package       | Version    | Description                                         |
27 | |---------------|------------|-----------------------------------------------------|
28 | | `matplotlib`  | `3.7.2`    | For plotting and visualization.                     |
29 | | `numpy`       | `1.24.3`   | Fundamental package for numerical computations.     |
30 | | `Pillow`      | `11.0.0`   | Imaging library for loading and processing images.  |
31 | | `pyrender`    | `0.1.45`   | Rendering 3D scenes for visualization.              |
32 | | `torch`       | `2.3.1`    | PyTorch library for deep learning.                  |
33 | | `torchvision` | `0.18.1`   | PyTorch's library for vision-related tasks.         |
34 | | `tqdm`        | `4.66.4`   | For creating progress bars in scripts.              |
35 | | `trimesh`     | `4.4.3`    | For loading and working with 3D triangular meshes.  |
36 |
37 | ### Installing Dependencies
38 |
39 | You can install the required dependencies using the `requirements.txt` file:
40 |
41 | ```bash
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ## Dataset Preparation
46 |
47 | To use this project, you need to download the required dataset and extract it to the root path of the project.
48 |
49 | ### Steps to Prepare the Dataset
50 |
51 | 1. **Download the Dataset:**
52 |    - Download the dataset from the following link:
53 |      [PROPS-Pose-Dataset](https://drive.google.com/file/d/15rhwXhzHGKtBcxJAYMWJG7gN7BLLhyAq/view)
54 |
55 | 2. **Place the Dataset:**
56 |    - Move the downloaded file `PROPS-Pose-Dataset.tar.gz` to the root directory of the project.
57 |
58 | 3. **Extract the Dataset:**
59 |    - Use the following command to extract the dataset:
60 |      ```bash
61 |      tar -xvzf PROPS-Pose-Dataset.tar.gz
62 |      ```
63 |    - This will create a folder named `PROPS-Pose-Dataset` in the root directory.
64 |
65 | 4. **Verify the Dataset Structure:**
66 |    - Ensure the folder structure matches the following:
67 |      ```
68 |      PROPS-Pose-Dataset/
69 |      ├── train/
70 |      │   ├── rgb/
71 |      │   ├── depth/
72 |      │   ├── mask_visib/
73 |      │   ├── train_gt.json
74 |      │   └── train_gt_info.json
75 |      ├── val/
76 |      │   ├── rgb/
77 |      │   ├── depth/
78 |      │   ├── mask_visib/
79 |      │   ├── val_gt.json
80 |      │   └── val_gt_info.json
81 |      └── model/
82 |          ├── 1_master_chef_can/
83 |          ├── ...
84 |      ```
85 |
86 | 5. **Set Dataset Path in Code:**
87 |    - The project will automatically look for the dataset in `PROPS-Pose-Dataset` under the root path during execution. Ensure this directory exists before running the code.
88 |
89 | ---
90 |
91 | ## Training
92 |
93 | To train the model, run the `train.py` script:
94 |
95 | ```bash
96 | python train.py
97 | ```
98 |
99 |
100 | ## Inference
101 |
102 | To visualize the results, follow these steps to set up and run the `inference.py` script:
103 |
104 | ### Steps for Inference
105 |
106 | 1. **Download Pretrained Weights:**
107 |    - Download the pretrained model weights from the following link:
108 |      [PoseCNN Pretrained Weights](https://drive.google.com/file/d/1-9iheQf-TL5MjHTYZITulqbdFn5UK1Sd/view?usp=sharing)
109 |
110 | 2. **Place the Weights:**
111 |    - Save the downloaded weights file (e.g., `posecnn_weights.pth`) to your desired directory.
112 |
113 | 3. **Set the Weights Path in Code:**
114 |    - Open the `inference.py` script and locate the following line:
115 |      ```python
116 |      posecnn_model.load_state_dict(torch.load(os.path.join("your weight here")))
117 |      ```
118 |    - Replace `"your weight here"` with the path to your weights file. For example:
119 |      ```python
120 |      posecnn_model.load_state_dict(torch.load(os.path.join("models/posecnn_weights.pth")))
121 |      ```
122 |
123 | 4. **Run the Inference Script:**
124 |    - Execute the script to visualize predictions:
125 |      ```bash
126 |      python inference.py
127 |      ```
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/__pycache__/p3_helper.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/p3_helper.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/p3_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/p3_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/pose_cnn.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/pose_cnn.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/pose_cnn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/pose_cnn.cpython-38.pyc
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | os.environ["TZ"] = "US/Eastern"
5 | time.tzset()
6 |
7 | import matplotlib.pyplot as plt
8 | import torch
9 | import torchvision.models as models
10 |
11 | from pose_cnn import hello_pose_cnn, FeatureExtraction, SegmentationBranch
12 | from p3_helper import hello_helper
13 | from rob599 import PROPSPoseDataset, reset_seed, visualize_dataset
14 | from rob599.grad import rel_error
15 |
16 |
17 | # Ensure helper functions run correctly
18 | hello_pose_cnn()
19 | hello_helper()
20 |
21 | # Check last modification time of pose_cnn.py
22 | pose_cnn_path = os.path.join("/home/yifeng/PycharmProjects/TestEnv/PoseCNN/pose_cnn.py")
23 | pose_cnn_edit_time = time.ctime(os.path.getmtime(pose_cnn_path))
24 | print("pose_cnn.py last edited on %s" % pose_cnn_edit_time)
25 |
26 | # Set up matplotlib plotting parameters
27 | plt.rcParams["figure.figsize"] = (10.0, 8.0)  # set default size of plots
28 | plt.rcParams["font.size"] = 16
29 | plt.rcParams["image.interpolation"] = "nearest"
30 | plt.rcParams["image.cmap"] = "gray"
31 |
32 | # Check for CUDA availability
33 | if torch.cuda.is_available():
34 |     print("Good to go!")
35 |     DEVICE = torch.device("cuda")
36 | else:
37 |     DEVICE = torch.device("cpu")
38 |
39 | # -------------------------- Dataset Preparation --------------------------------
40 | # NOTE: Set `download=True` for the first time to download the dataset.
41 | # After downloading, set `download=False` for faster execution.
42 |
43 | data_root = "/home/yifeng/PycharmProjects/TestEnv/PoseCNN"
44 | # Prepare train dataset
45 | train_dataset = PROPSPoseDataset(
46 |     root=data_root,
47 |     split="train",
48 |     download=False  # Change to True for the first-time download
49 | )
50 |
51 | # Prepare validation dataset
52 | val_dataset = PROPSPoseDataset(
53 |     root=data_root,
54 |     split="val",
55 |     download=False
56 | )
57 |
58 | # Print dataset sizes
59 | print(f"Dataset sizes: train ({len(train_dataset)}), val ({len(val_dataset)})")
60 |
61 |
62 | reset_seed(0)
63 |
64 | grid_vis = visualize_dataset(val_dataset, alpha=0.25)
65 | plt.axis('off')
66 | plt.imshow(grid_vis)
67 | plt.show()
68 |
69 | # Based on PoseCNN section III.B, the output features should
70 | # be 1/8 and 1/16 the input's spatial resolution with 512 channels
71 | print('feature1 expected shape: (N, {}, {}, {})'.format(512, 480//8, 640//8))
72 | print('feature2 expected shape: (N, {}, {}, {})'.format(512, 480//16, 640//16))
73 | print()
74 |
75 | vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
76 | feature_extractor = FeatureExtraction(pretrained_model=vgg16)
77 |
78 | dummy_input = {'rgb': torch.zeros((2,3,480,640))}
79 | feature1, feature2 = feature_extractor(dummy_input)
80 |
81 | print('feature1 shape:', feature1.shape)
82 | print('feature2 shape:', feature2.shape)
83 |
--------------------------------------------------------------------------------
/image/6d1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d1.png
--------------------------------------------------------------------------------
/image/6d2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d2.png
--------------------------------------------------------------------------------
/image/6d3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d3.png
--------------------------------------------------------------------------------
/image/6d4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d4.png
--------------------------------------------------------------------------------
/image/6d5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d5.png
--------------------------------------------------------------------------------
/image/6d6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d6.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 |
4 | import matplotlib.pyplot as plt
5 | import torch
6 | from torch.utils.data import DataLoader
7 | import torchvision.models as models
8 |
9 | import rob599
10 | from pose_cnn import PoseCNN, eval
11 | from rob599 import PROPSPoseDataset
12 |
13 | # Check for CUDA availability
14 | if torch.cuda.is_available():
15 |     print("Good to go!")
16 |     DEVICE = torch.device("cuda")
17 | else:
18 |     DEVICE = torch.device("cpu")
19 |
20 | rob599.reset_seed(0)
21 | NUM_CLASSES = 10
22 | BATCH_SIZE = 4
23 | NUM_WORKERS = multiprocessing.cpu_count()
24 |
25 | data_root = "/home/yifeng/PycharmProjects/TestEnv/PoseCNN"
26 | # Prepare train dataset
27 | train_dataset = PROPSPoseDataset(
28 |     root=data_root,
29 |     split="train",
30 |     download=False  # Change to True for the first-time download
31 | )
32 |
33 | # Prepare validation dataset
34 | val_dataset = PROPSPoseDataset(
35 |     root=data_root,
36 |     split="val",
37 |     download=False
38 | )
39 |
40 | dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE)
41 |
42 | vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
43 | posecnn_model = PoseCNN(pretrained_backbone=vgg16,
44 |                         models_pcd=torch.tensor(val_dataset.models_pcd).to(DEVICE, dtype=torch.float32),
45 |                         cam_intrinsic=val_dataset.cam_intrinsic).to(DEVICE)
46 | posecnn_model.load_state_dict(torch.load(os.path.join("/home/yifeng/PycharmProjects/TestEnv/posecnn_model.pth")))
47 |
48 | num_samples = 5
49 | for i in range(num_samples):
50 |     out = eval(posecnn_model, dataloader, DEVICE)
51 |     # show each sampled prediction (previously only the last one was plotted)
52 |     plt.axis('off')
53 |     plt.imshow(out)
54 |     plt.show()
--------------------------------------------------------------------------------
/p3_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | from torchvision.ops import box_iou
5 | import sys, os
6 | import json
7 | import random
8 | import cv2
9 | from PIL import Image
10 | import trimesh
11 | import pyrender
12 | import tqdm
13 |
14 | _HOUGHVOTING_NUM_INLIER = 100  # minimum vote count for an accepted center
15 | _HOUGHVOTING_DIRECTION_INLIER = 0.9  # cosine threshold for direction inliers
16 | _LABEL2MASK_THRESHOL = 100  # minimum pixel count for a class mask to count as an instance
17 |
18 | def hello_helper():
19 |     print("Hello from p3_helper.py!")
20 |
21 |
22 | def loss_cross_entropy(scores, labels):
23 |     """
24 |     scores: a tensor [batch_size, num_classes, height, width]
25 |     labels: a tensor [batch_size, num_classes, height, width]
26 |     """
27 |
28 |     cross_entropy = -torch.sum(labels * torch.log(scores + 1e-10), dim=1)
29 |     loss = torch.div(torch.sum(cross_entropy), torch.sum(labels) + 1e-10)
30 |
31 |     return loss
32 |
33 | def loss_Rotation(pred_R, gt_R, label, model):
34 |     """
35 |     pred_R: a tensor [N, 3, 3]
36 |     gt_R: a tensor [N, 3, 3]
37 |     label: a tensor [N, ]
38 |     model: a tensor [N_cls, 1024, 3]
39 |     """
40 |     device = pred_R.device
41 |     models_pcd = model[label - 1].to(device)
42 |     gt_points = models_pcd @ gt_R
43 |     pred_points = models_pcd @ pred_R
44 |     loss = ((pred_points - gt_points) ** 2).sum(dim=2).sqrt().mean()
45 |     return loss
46 |
47 |
48 | def IOUselection(pred_bbxes, gt_bbxes, threshold):
49 |     """
50 |     pred_bbxes is N_pred_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
51 |     gt_bbxes is N_gt_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
52 |     threshold : threshold of IOU for selection of predicted bbx
53 |     """
54 |     device = pred_bbxes.device
55 |     output_bbxes = torch.empty((0, 6)).to(device=device, dtype=torch.float)
56 |     for pred_bbx in pred_bbxes:
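        # a prediction survives only if some GT box from the same batch image shares its class and their IoU clears the threshold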
57 |         for gt_bbx in gt_bbxes:
58 |             if pred_bbx[0] == gt_bbx[0] and pred_bbx[5] == gt_bbx[5]:
59 |                 iou = box_iou(pred_bbx[1:5].unsqueeze(dim=0), gt_bbx[1:5].unsqueeze(dim=0)).item()
60 |                 if iou > threshold:
61 |                     output_bbxes = torch.cat((output_bbxes, pred_bbx.unsqueeze(dim=0)), dim=0)
62 |     return output_bbxes
63 |
64 |
65 | def HoughVoting(label, centermap, num_classes=10):
66 |     """
67 |     label [bs, H, W]
68 |     centermap [bs, 3*maxinstance, H, W]
69 |     """
70 |     batches, H, W = label.shape
71 |     x = np.linspace(0, W - 1, W)
72 |     y = np.linspace(0, H - 1, H)
73 |     xv, yv = np.meshgrid(x, y)
74 |     xy = torch.from_numpy(np.array((xv, yv))).to(device=label.device, dtype=torch.float32)
75 |     x_index = torch.from_numpy(x).to(device=label.device, dtype=torch.int32)
76 |     centers = torch.zeros(batches, num_classes, 2)
77 |     depths = torch.zeros(batches, num_classes)
78 |     for bs in range(batches):
79 |         for cls in range(1, num_classes + 1):
80 |             if (label[bs] == cls).sum() >= _LABEL2MASK_THRESHOL:
81 |                 pixel_location = xy[:2, label[bs] == cls]
82 |                 pixel_direction = centermap[bs, (cls-1)*3:cls*3][:2, label[bs] == cls]
83 |                 y_index = x_index.unsqueeze(dim=0) - pixel_location[0].unsqueeze(dim=1)  # x-offset of every column from each labeled pixel
84 |                 y_index = torch.round(pixel_location[1].unsqueeze(dim=1) + (pixel_direction[1]/pixel_direction[0]).unsqueeze(dim=1) * y_index).to(torch.int32)  # extrapolate each pixel's center direction to a row per column
85 |                 mask = (y_index >= 0) * (y_index < H)
86 |                 count = y_index * W + x_index.unsqueeze(dim=0)
87 |                 center, inlier_num = torch.bincount(count[mask]).argmax(), torch.bincount(count[mask]).max()
88 |                 center_x, center_y = center % W, torch.div(center, W, rounding_mode='trunc')
89 |                 if inlier_num > _HOUGHVOTING_NUM_INLIER:
90 |                     centers[bs, cls - 1, 0], centers[bs, cls - 1, 1] = center_x, center_y
91 |                     xyplane_dis = xy - torch.tensor([center_x, center_y])[:, None, None].to(device=label.device)
92 |                     xyplane_direction = xyplane_dis/(xyplane_dis**2).sum(dim=0).sqrt()[None, :, :]
93 |                     predict_direction = centermap[bs, (cls-1)*3:cls*3][:2]
94 |                     inlier_mask = ((xyplane_direction * predict_direction).sum(dim=0).abs() >= _HOUGHVOTING_DIRECTION_INLIER) * (label[bs] == cls)
95 |                     depths[bs, cls - 1] = centermap[bs, (cls-1)*3:cls*3][2, inlier_mask].mean()
96 |     return centers, depths
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/pose_cnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the PoseCNN network architecture in PyTorch.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.init import kaiming_normal_
7 | import torchvision.models as models
8 | from torchvision.ops import RoIPool
9 |
10 | import numpy as np
11 | import random
12 | import statistics
13 | import time
14 | from typing import Dict, List, Callable, Optional
15 |
16 | from rob599 import quaternion_to_matrix
17 | from p3_helper import HoughVoting, _LABEL2MASK_THRESHOL, loss_cross_entropy, loss_Rotation, IOUselection
18 |
19 |
20 | def hello_pose_cnn():
21 |     """
22 |     This is a sample function that we will try to import and run to ensure that
23 |     our environment is correctly set up on Google Colab.
24 |     """
25 |     print("Hello from pose_cnn.py!")
26 |
27 |
28 | class FeatureExtraction(nn.Module):
29 |     """
30 |     Feature Embedding Module for PoseCNN. Using a pretrained VGG16 network as backbone.
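    Usage sketch (mirrors demo.py; shapes assume a 480x640 RGB input):
        vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        extractor = FeatureExtraction(pretrained_model=vgg16)
        feature1, feature2 = extractor({'rgb': torch.zeros(2, 3, 480, 640)})
        # feature1: (2, 512, 60, 80), feature2: (2, 512, 30, 40)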
31 | """ 32 | def __init__(self, pretrained_model): 33 | super(FeatureExtraction, self).__init__() 34 | embedding_layers = list(pretrained_model.features)[:30] 35 | ## Embedding Module from begining till the first output feature map 36 | self.embedding1 = nn.Sequential(*embedding_layers[:23]) 37 | ## Embedding Module from the first output feature map till the second output feature map 38 | self.embedding2 = nn.Sequential(*embedding_layers[23:]) 39 | 40 | for i in [0, 2, 5, 7, 10, 12, 14]: 41 | self.embedding1[i].weight.requires_grad = False 42 | self.embedding1[i].bias.requires_grad = False 43 | 44 | def forward(self, datadict): 45 | """ 46 | feature1: [bs, 512, H/8, W/8] 47 | feature2: [bs, 512, H/16, W/16] 48 | """ 49 | feature1 = self.embedding1(datadict['rgb']) 50 | feature2 = self.embedding2(feature1) 51 | return feature1, feature2 52 | 53 | class SegmentationBranch(nn.Module): 54 | """ 55 | Instance Segmentation Module for PoseCNN. 56 | """ 57 | def __init__(self, num_classes = 10, hidden_layer_dim = 64): 58 | super(SegmentationBranch, self).__init__() 59 | 60 | self.num_classes = num_classes 61 | 62 | self.conv1_feat1 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1) 63 | self.relu1_feat1 = nn.ReLU() 64 | self.conv1_feat2 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1) 65 | self.relu1_feat2 = nn.ReLU() 66 | 67 | 68 | self.upsample_f1tof2 = nn.Upsample(scale_factor=2, mode='bilinear') 69 | 70 | # Using nearest neighbor interpolation to upsample from feature level 2 to full size 71 | # Note: For 'nearest', the align_corners option is not applicable and thus not used 72 | self.upsample_f2tofullsize = nn.Upsample(scale_factor=8, mode='bilinear') 73 | # print("*********", ) 74 | self.conv2 = nn.Conv2d(hidden_layer_dim, num_classes + 1, kernel_size=1) 75 | # self.relu2 = nn.ReLU() 76 | self.softmax_Seg = nn.Softmax(dim=1) 77 | 78 | # Initialize layers 79 | nn.init.kaiming_normal_(self.conv1_feat1.weight, nonlinearity='relu') 80 | nn.init.kaiming_normal_(self.conv1_feat2.weight, nonlinearity='relu') 81 | nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu') 82 | 83 | nn.init.zeros_(self.conv1_feat1.bias) 84 | nn.init.zeros_(self.conv1_feat2.bias) 85 | nn.init.zeros_(self.conv2.bias) 86 | 87 | 88 | 89 | def forward(self, feature1, feature2): 90 | """ 91 | Args: 92 | feature1: Features from feature extraction backbone (B, 512, h, w) 93 | feature2: Features from feature extraction backbone (B, 512, h//2, w//2) 94 | Returns: 95 | probability: Segmentation map of probability for each class at each pixel. 96 | probability size: (B,num_classes+1,H,W) 97 | segmentation: Segmentation map of class id's with highest prob at each pixel. 98 | segmentation size: (B,H,W) 99 | bbx: Bounding boxs detected from the segmentation. Can be extracted 100 | from the predicted segmentation map using self.label2bbx(segmentation). 
101 |                 bbx size: (N,6) with (batch_ids, x1, y1, x2, y2, cls)
102 |         """
103 |         probability = None
104 |         segmentation = None
105 |         bbx = None
106 |
107 |         # Reduce both feature maps to the hidden dimension, fuse them, and upsample to full resolution
108 |         feature1 = self.relu1_feat1(self.conv1_feat1(feature1))
109 |         feature2 = self.relu1_feat2(self.conv1_feat2(feature2))
110 |
111 |         feature2 = self.upsample_f1tof2(feature2)
112 |
113 |         feature = torch.add(feature1, feature2)
114 |         feature = self.upsample_f2tofullsize(feature)
115 |
116 |         feature = self.conv2(feature)
117 |
118 |         probability = self.softmax_Seg(feature)
119 |
120 |         segmentation = torch.argmax(probability, dim=1)
121 |         bbx = self.label2bbx(segmentation)
122 |
123 |         return probability, segmentation, bbx
124 |
125 |     def label2bbx(self, label):
126 |         bbx = []
127 |         bs, H, W = label.shape
128 |         device = label.device
129 |         label_repeat = label.view(bs, 1, H, W).repeat(1, self.num_classes, 1, 1).to(device)
130 |         label_target = torch.linspace(0, self.num_classes - 1, steps=self.num_classes).view(1, -1, 1, 1).repeat(bs, 1, H, W).to(device)
131 |         mask = (label_repeat == label_target)
132 |         for batch_id in range(mask.shape[0]):
133 |             for cls_id in range(mask.shape[1]):
134 |                 if cls_id != 0:
135 |                     # cls_id == 0 is the background
136 |                     y, x = torch.where(mask[batch_id, cls_id] != 0)
137 |                     if y.numel() >= _LABEL2MASK_THRESHOL:
138 |                         bbx.append([batch_id, torch.min(x).item(), torch.min(y).item(),
139 |                                     torch.max(x).item(), torch.max(y).item(), cls_id])
140 |         bbx = torch.tensor(bbx).to(device)
141 |         return bbx
142 |
143 |
144 | class TranslationBranch(nn.Module):
145 |     """
146 |     3D Translation Estimation Module for PoseCNN.
147 |     """
148 |     def __init__(self, num_classes=10, hidden_layer_dim=128):
149 |         super(TranslationBranch, self).__init__()
150 |
151 |
152 |         self.num_classes = num_classes
153 |
154 |         self.conv1_feat1 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1)
155 |         self.relu1_feat1 = nn.ReLU()
156 |         self.conv1_feat2 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1)
157 |         self.relu1_feat2 = nn.ReLU()
158 |
159 |         self.upsample_f1tof2 = nn.Upsample(scale_factor=2)  # nearest-neighbor by default
160 |         self.upsample_f2tofullsize = nn.Upsample(scale_factor=8)
161 |         self.conv2 = nn.Conv2d(hidden_layer_dim, num_classes*3, kernel_size=1)
162 |
163 |         # Initialize layers
164 |         nn.init.kaiming_normal_(self.conv1_feat1.weight, nonlinearity='relu')
165 |         nn.init.kaiming_normal_(self.conv1_feat2.weight, nonlinearity='relu')
166 |         nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
167 |
168 |         nn.init.zeros_(self.conv1_feat1.bias)
169 |         nn.init.zeros_(self.conv1_feat2.bias)
170 |         nn.init.zeros_(self.conv2.bias)
171 |
172 |
173 |     def forward(self, feature1, feature2):
174 |         """
175 |         Args:
176 |             feature1: Features from feature extraction backbone (B, 512, h, w)
177 |             feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
178 |         Returns:
179 |             translation: Map of object centroid predictions.
180 |                 translation size: (B, 3*num_classes, H, W)
181 |         """
182 |         translation = None
183 |
184 |         # Same fuse-and-upsample scheme as the segmentation branch
185 |         feature1 = self.relu1_feat1(self.conv1_feat1(feature1))
186 |         feature2 = self.relu1_feat2(self.conv1_feat2(feature2))
187 |         feature2 = self.upsample_f1tof2(feature2)
188 |         feature = torch.add(feature1, feature2)
189 |         feature = self.upsample_f2tofullsize(feature)
190 |
191 |         translation = self.conv2(feature)
192 |
193 |         return translation
194 |
195 | class RotationBranch(nn.Module):
196 |     """
197 |     3D Rotation Regression Module for PoseCNN.
198 |     """
199 |     def __init__(self, feature_dim=512, roi_shape=7, hidden_dim=4096, num_classes=10):
200 |         super(RotationBranch, self).__init__()
201 |
202 |
203 |         # One RoIPool per feature scale
204 |         self.roi_feat1 = RoIPool(output_size=roi_shape, spatial_scale=1/8)
205 |         self.roi_feat2 = RoIPool(output_size=roi_shape, spatial_scale=1/16)
206 |         self.fc1 = nn.Linear(feature_dim*roi_shape*roi_shape, hidden_dim)
207 |         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
208 |         self.fc3 = nn.Linear(hidden_dim, 4 * num_classes)
209 |
210 |
211 |
212 |     def forward(self, feature1, feature2, bbx):
213 |         """
214 |         Args:
215 |             feature1: Features from feature extraction backbone (B, 512, h, w)
216 |             feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
217 |             bbx: Bounding boxes of regions of interest (N, 5) with (batch_ids, x1, y1, x2, y2)
218 |         Returns:
219 |             quaternion: Regressed components of a quaternion for each class at each ROI.
220 |                 quaternion size: (N,4*num_classes)
221 |         """
222 |         quaternion = None
223 |
224 |         # Pool ROI features at each scale (fixed: feature2 previously went through roi_feat1, i.e. the wrong 1/8 spatial scale)
225 |         feature1 = self.roi_feat1(feature1, bbx)
226 |         feature2 = self.roi_feat2(feature2, bbx)
227 |
228 |         feature = torch.add(feature1, feature2)
229 |         feature = feature.view(feature.shape[0], -1)
230 |
231 |         feature = self.fc1(feature)
232 |         feature = self.fc2(feature)
233 |         quaternion = self.fc3(feature)
234 |
235 |
236 |         return quaternion
237 |
238 | class PoseCNN(nn.Module):
239 |     """
240 |     PoseCNN
241 |     """
242 |     def __init__(self, pretrained_backbone, models_pcd, cam_intrinsic):
243 |         super(PoseCNN, self).__init__()
244 |
245 |         self.iou_threshold = 0.7
246 |         self.models_pcd = models_pcd
247 |         self.cam_intrinsic = cam_intrinsic
248 |
249 |
250 |
251 |         self.feature_extraction = FeatureExtraction(pretrained_backbone)
252 |         self.segmentation_branch = SegmentationBranch()
253 |         self.translation_branch = TranslationBranch()
254 |         self.rotation_branch = RotationBranch()
255 |
256 |
257 |
258 |     def forward(self, input_dict):
259 |         """
260 |         input_dict = {
261 |             'rgb',
262 |             'depth',
263 |             'objs_id',
264 |             'label',
265 |             'bbx',
266 |             'RTs', 'centermaps'
267 |         }
268 |         """
269 |
270 |
271 |         if self.training:
272 |             loss_dict = {
273 |                 "loss_segmentation": 0,
274 |                 "loss_centermap": 0,
275 |                 "loss_R": 0
276 |             }
277 |
278 |             gt_bbx = self.getGTbbx(input_dict)
279 |
280 |             # Important: the rotation loss should be calculated only for regions
281 |             # of interest that match with a ground truth object instance.
282 |             # Note that the helper function, IOUselection, may be used for
283 |             # identifying the predicted regions of interest with acceptable IOU
284 |             # with the ground truth bounding boxes (same image, same class, IoU above self.iou_threshold).
285 |             # If no ROIs result from the selection, loss_R is left at its initial value of 0.
286 |
287 |
288 |             # Feature extraction
289 |             feature1, feature2 = self.feature_extraction(input_dict)
290 |
291 |             # Segmentation loss, stored in loss_dict["loss_segmentation"], uses loss_cross_entropy(.)
292 |             probability, segmentation, bbx = self.segmentation_branch(feature1, feature2)
293 |             loss_dict["loss_segmentation"] = loss_cross_entropy(probability, input_dict["label"])
294 |
295 |             # The training loss for translation is stored in loss_dict["loss_centermap"] using the L1 loss
296 |             translation = self.translation_branch(feature1, feature2)
297 |             loss_Translation = nn.L1Loss()
298 |
299 |             loss_dict["loss_centermap"] = loss_Translation(translation, input_dict["centermaps"])
300 |
301 |             # The training loss for rotation is stored in loss_dict["loss_R"] using the given loss_Rotation function
302 |             if bbx.numel() > 0:
303 |                 sel_bbx = IOUselection(bbx.to(feature1), gt_bbx, threshold=self.iou_threshold)
304 |
305 |                 if sel_bbx.numel() > 0:
306 |                     quaternion = self.rotation_branch(feature1, feature2, sel_bbx[:, :-1])
307 |                     pred_Rs, label = self.estimateRotation(quaternion, sel_bbx)
308 |                     gt_Rs = self.gtRotation(sel_bbx, input_dict)
309 |                     loss_dict["loss_R"] = loss_Rotation(pred_Rs, gt_Rs, label, self.models_pcd)
310 |
311 |
312 |
313 |             return loss_dict
314 |         else:
315 |             output_dict = None
316 |             segmentation = None
317 |
318 |             with torch.no_grad():
319 |
320 |
321 |                 feature1, feature2 = self.feature_extraction(input_dict)
322 |                 probability, segmentation, bbx = self.segmentation_branch(feature1, feature2)
323 |                 translation = self.translation_branch(feature1, feature2)
324 |                 pred_centers, depths = HoughVoting(segmentation, translation)
325 |
326 |                 quaternion = self.rotation_branch(feature1, feature2, bbx[:, :-1].to(dtype=feature1.dtype))
327 |                 pred_Rs, _ = self.estimateRotation(quaternion, bbx)
328 |
329 |
330 |                 output_dict = self.generate_pose(pred_Rs, pred_centers, depths, bbx)
331 |
332 |
333 |             return output_dict, segmentation
334 |
335 |     def estimateTrans(self, translation_map, filter_bbx, pred_label):
336 |         """
337 |         translation_map: a tensor [batch_size, num_classes * 3, height, width]
338 |         filter_bbx: N_filter_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
339 |         pred_label: a tensor [batch_size, height, width]
340 |         """
341 |         N_filter_bbx = filter_bbx.shape[0]
342 |         pred_Ts = torch.zeros(N_filter_bbx, 3)
343 |         for idx, bbx in enumerate(filter_bbx):
344 |             batch_id = int(bbx[0].item())
345 |             cls = int(bbx[5].item())
346 |             trans_map = translation_map[batch_id, (cls-1) * 3 : cls * 3, :]
347 |             label = (pred_label[batch_id] == cls).detach()
348 |             pred_T = trans_map[:, label].mean(dim=1)
349 |             pred_Ts[idx] = pred_T
350 |         return pred_Ts
351 |
352 |     def gtTrans(self, filter_bbx, input_dict):
353 |         N_filter_bbx = filter_bbx.shape[0]
354 |         gt_Ts = torch.zeros(N_filter_bbx, 3)
355 |         for idx, bbx in enumerate(filter_bbx):
356 |             batch_id = int(bbx[0].item())
357 |             cls = int(bbx[5].item())
358 |             gt_Ts[idx] = input_dict['RTs'][batch_id][cls - 1][:3, [3]].T
359 |         return gt_Ts
360 |
361 |     def getGTbbx(self, input_dict):
362 |         """
363 |         bbx is N*6 (batch_ids, x1, y1, x2, y2, cls)
364 |         """
365 |         gt_bbx = []
366 |         objs_id = input_dict['objs_id']
367 |         device = objs_id.device
368 |         ## [x_min, y_min, width, height]
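        ## converted below to corner format: (x1, y1, x2, y2) = (x_min, y_min, x_min + width, y_min + height)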
369 |         bbxes = input_dict['bbx']
370 |         for batch_id in range(bbxes.shape[0]):
371 |             for idx, obj_id in enumerate(objs_id[batch_id]):
372 |                 if obj_id.item() != 0:
373 |                     # the obj appears in this image
374 |                     bbx = bbxes[batch_id][idx]
375 |                     gt_bbx.append([batch_id, bbx[0].item(), bbx[1].item(),
376 |                                    bbx[0].item() + bbx[2].item(), bbx[1].item() + bbx[3].item(), obj_id.item()])
377 |         return torch.tensor(gt_bbx).to(device=device, dtype=torch.int16)
378 |
379 |     def estimateRotation(self, quaternion_map, filter_bbx):
380 |         """
381 |         quaternion_map: a tensor [N_filter_bbx, num_classes * 4]
382 |         filter_bbx: N_filter_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
383 |         """
384 |         N_filter_bbx = filter_bbx.shape[0]
385 |         pred_Rs = torch.zeros(N_filter_bbx, 3, 3)
386 |         label = []
387 |         for idx, bbx in enumerate(filter_bbx):
388 |             batch_id = int(bbx[0].item())
389 |             cls = int(bbx[5].item())
390 |             quaternion = quaternion_map[idx, (cls-1) * 4 : cls * 4]
391 |             quaternion = nn.functional.normalize(quaternion, dim=0)
392 |             pred_Rs[idx] = quaternion_to_matrix(quaternion)
393 |             label.append(cls)
394 |         label = torch.tensor(label)
395 |         return pred_Rs, label
396 |
397 |     def gtRotation(self, filter_bbx, input_dict):
398 |         N_filter_bbx = filter_bbx.shape[0]
399 |         gt_Rs = torch.zeros(N_filter_bbx, 3, 3)
400 |         for idx, bbx in enumerate(filter_bbx):
401 |             batch_id = int(bbx[0].item())
402 |             cls = int(bbx[5].item())
403 |             gt_Rs[idx] = input_dict['RTs'][batch_id][cls - 1][:3, :3]
404 |         return gt_Rs
405 |
406 |     def generate_pose(self, pred_Rs, pred_centers, pred_depths, bbxs):
407 |         """
408 |         pred_Rs: a tensor [pred_bbx_size, 3, 3]
409 |         pred_centers: [batch_size, num_classes, 2]
410 |         pred_depths: a tensor [batch_size, num_classes]
411 |         bbxs: a tensor [pred_bbx_size, 6]
412 |         """
413 |         output_dict = {}
414 |         for idx, bbx in enumerate(bbxs):
415 |             bs, _, _, _, _, obj_id = bbx
416 |             R = pred_Rs[idx].numpy()
417 |             center = pred_centers[bs, obj_id - 1].numpy()
418 |             depth = pred_depths[bs, obj_id - 1].numpy()
419 |             if (center**2).sum().item() != 0:
420 |                 T = np.linalg.inv(self.cam_intrinsic) @ np.array([center[0], center[1], 1]) * depth
421 |                 T = T[:, np.newaxis]
422 |                 if bs.item() not in output_dict:
423 |                     output_dict[bs.item()] = {}
424 |                 output_dict[bs.item()][obj_id.item()] = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]])))
425 |         return output_dict
426 |
427 |
428 | def eval(model, dataloader, device, alpha=0.35):
429 |     model.eval()
430 |
431 |     sample_idx = random.randint(0, len(dataloader.dataset) - 1)
432 |     ## visualize one random sample
433 |     rgb = torch.tensor(dataloader.dataset[sample_idx]['rgb'][None, :]).to(device)
434 |     inputdict = {'rgb': rgb}
435 |     pose_dict, _ = model(inputdict)
436 |     rgb = (rgb[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
437 |     return dataloader.dataset.visualizer.vis_oneview(
438 |         ipt_im=rgb,
439 |         obj_pose_dict=pose_dict[0],
440 |         alpha=alpha
441 |     )
442 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.7.2
2 | numpy==1.24.3
3 | Pillow==11.0.0
4 | pyrender==0.1.45
5 | torch==2.3.1
6 | torchvision==0.18.1
7 | tqdm==4.66.4
8 | trimesh==4.4.3
9 |
--------------------------------------------------------------------------------
/rob599/PROPSPoseDataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import Any, Callable, Optional, Tuple
4 | import random
5 |
6 | import cv2
7 | from PIL import Image
8 |
9 | import torch
10 | import numpy as np
11 | from torch.utils.data import Dataset
12 | from torchvision.datasets.utils import download_and_extract_archive
13 |
14 | from rob599 import Visualize, chromatic_transform, add_noise
15 |
16 |
17 |
18 | class PROPSPoseDataset(Dataset):
19 |
20 |     base_folder = "PROPS-Pose-Dataset"
21 |     url = "https://drive.google.com/file/d/15rhwXhzHGKtBcxJAYMWJG7gN7BLLhyAq/view?usp=share_link"
22 |     filename = "PROPS-Pose-Dataset.tar.gz"
23 |     tgz_md5 = "a0c39fe326377dacd1d652f9fe11a7f4"
24 |
25 |     def __init__(
26 |         self,
27 |         root: str,
28 |         split: str = 'train',
29 |         download: bool = False,
30 |     ) -> None:
31 |         assert split in ['train', 'val']
32 |
33 |         self.root = root
34 |         self.split = split
35 |         self.dataset_dir = os.path.join(self.root, self.base_folder)
36 |
37 |         if download:
38 |             self.download()
39 |
40 |
41 |
42 |         ## parameters
43 |         self.max_instance_num = 10
44 |         self.H = 480
45 |         self.W = 640
46 |         self.rgb_aug_prob = 0.4
47 |         self.cam_intrinsic = np.array([
48 |             [902.19, 0.0, 342.35],
49 |             [0.0, 902.39, 252.23],
50 |             [0.0, 0.0, 1.0]])
51 |         self.resolution = [640, 480]
52 |
53 |         self.all_lst = self.parse_dir()
54 |         self.shuffle()
55 |         self.models_pcd = self.parse_model()
56 |
57 |
58 |         self.obj_id_list = [
59 |             1,  # master chef
60 |             2,  # cracker box
61 |             3,  # sugar box
62 |             4,  # soup can
63 |             5,  # mustard bottle
64 |             6,  # tuna can
65 |             8,  # jello box
66 |             9,  # meat can
67 |             14, # mug
68 |             18  # marker
69 |         ]
70 |         self.id2label = {}
71 |         for idx, id in enumerate(self.obj_id_list):
72 |             self.id2label[id] = idx + 1
73 |
74 |     def parse_dir(self):
75 |         data_dir = os.path.join(self.dataset_dir, self.split)
76 |         rgb_path = os.path.join(data_dir, "rgb")
77 |         depth_path = os.path.join(data_dir, "depth")
78 |         mask_path = os.path.join(data_dir, "mask_visib")
79 |         scene_gt_json = os.path.join(data_dir, self.split+"_gt.json")
80 |         scene_gt_info_json = os.path.join(data_dir, self.split+"_gt_info.json")
81 |         rgb_list = os.listdir(rgb_path)
82 |         rgb_list.sort()
83 |         depth_list = os.listdir(depth_path)
84 |         depth_list.sort()
85 |         mask_list = os.listdir(mask_path)
86 |         mask_list.sort()
87 |         scene_gt = json.load(open(scene_gt_json))
88 |         scene_gt_info = json.load(open(scene_gt_info_json))
89 |         assert len(rgb_list) == len(depth_list) == len(scene_gt) == len(scene_gt_info), "mismatched numbers of data files"
90 |         all_lst = []
91 |         for rgb_file in rgb_list:
92 |             idx = int(rgb_file.split(".png")[0])
93 |             depth_file = f"{idx:06d}.png"
94 |             scene_objs_gt = scene_gt[str(idx)]
95 |             scene_objs_info_gt = scene_gt_info[str(idx)]
96 |             objs_dict = {}
97 |             for obj_idx in range(len(scene_objs_gt)):
98 |                 objs_dict[obj_idx] = {}
99 |                 objs_dict[obj_idx]['R'] = np.array(scene_objs_gt[obj_idx]['cam_R_m2c']).reshape(3, 3)
100 |                 objs_dict[obj_idx]['T'] = np.array(scene_objs_gt[obj_idx]['cam_t_m2c']).reshape(3, 1)
101 |                 objs_dict[obj_idx]['obj_id'] = scene_objs_gt[obj_idx]['obj_id']
102 |                 objs_dict[obj_idx]['bbox_visib'] = scene_objs_info_gt[obj_idx]['bbox_visib']
103 |                 assert f"{idx:06d}_{obj_idx:06d}.png" in mask_list
104 |                 objs_dict[obj_idx]['visible_mask_path'] = os.path.join(mask_path, f"{idx:06d}_{obj_idx:06d}.png")
105 |             """
106 |             obj_sample = (rgb_path, depth_path, objs_dict)
107 |             objs_dict = {
108 |                 0: {
109 |                     R:
110 |                     T:
111 |                     obj_id:
112 |                     bbox_visib:
113 |                     visible_mask_path:
114 |                 }
115 |                 ...
116 |             }
117 |             """
118 |             obj_sample = (
119 |                 os.path.join(rgb_path, rgb_file),
120 |                 os.path.join(depth_path, depth_file),
121 |                 objs_dict
122 |             )
123 |             all_lst.append(obj_sample)
124 |         return all_lst
125 |
126 |     def parse_model(self):
127 |         model_path = os.path.join(self.dataset_dir, "model")
128 |         objpathdict = {
129 |             1: ["master_chef_can", os.path.join(model_path, "1_master_chef_can", "textured_simple.obj")],
130 |             2: ["cracker_box", os.path.join(model_path, "2_cracker_box", "textured_simple.obj")],
131 |             3: ["sugar_box", os.path.join(model_path, "3_sugar_box", "textured_simple.obj")],
132 |             4: ["tomato_soup_can", os.path.join(model_path, "4_tomato_soup_can", "textured_simple.obj")],
133 |             5: ["mustard_bottle", os.path.join(model_path, "5_mustard_bottle", "textured_simple.obj")],
134 |             6: ["tuna_fish_can", os.path.join(model_path, "6_tuna_fish_can", "textured_simple.obj")],
135 |             7: ["gelatin_box", os.path.join(model_path, "8_gelatin_box", "textured_simple.obj")],
136 |             8: ["potted_meat_can", os.path.join(model_path, "9_potted_meat_can", "textured_simple.obj")],
137 |             9: ["mug", os.path.join(model_path, "14_mug", "textured_simple.obj")],
138 |             10: ["large_marker", os.path.join(model_path, "18_large_marker", "textured_simple.obj")],
139 |         }
140 |         self.visualizer = Visualize(objpathdict, self.cam_intrinsic, self.resolution)
141 |         models_pcd_dict = {index: np.array(self.visualizer.objnode[index]['mesh'].vertices) for index in self.visualizer.objnode}
142 |         models_pcd = np.zeros((len(models_pcd_dict), 1024, 3))
143 |         for m in models_pcd_dict:
144 |             model = models_pcd_dict[m]
145 |             models_pcd[m - 1] = model[np.random.randint(0, model.shape[0], 1024)]
146 |         return models_pcd
147 |
148 |     def __len__(self):
149 |         return len(self.all_lst)
150 |
151 |     def __getitem__(self, idx):
152 |         """
153 |         obj_sample = (rgb_path, depth_path, objs_dict)
154 |         objs_dict = {
155 |             0: {
156 |                 R:
157 |                 T:
158 |                 obj_id:
159 |                 bbox_visib:
160 |                 visible_mask_path:
161 |             }
162 |             ...
163 |         }
164 |
165 |         data_dict = {
166 |             'rgb',
167 |             'depth',
168 |             'objs_id',
169 |             'label',
170 |             'bbx',
171 |             'RTs',
172 |             'centermaps', 'centers'
173 |         }
174 |         """
175 |         rgb_path, depth_path, objs_dict = self.all_lst[idx]
176 |         data_dict = {}
177 |         with Image.open(rgb_path) as im:
178 |             rgb = np.array(im)
179 |
180 |         if self.split == 'train' and np.random.rand(1) > 1 - self.rgb_aug_prob:
181 |             rgb = chromatic_transform(rgb)
182 |             rgb = add_noise(rgb)
183 |         rgb = rgb.astype(np.float32)/255
184 |         data_dict['rgb'] = rgb.transpose((2,0,1))
185 |
186 |         with Image.open(depth_path) as im:
187 |             data_dict['depth'] = np.array(im)[np.newaxis, :]
188 |         ## TODO data-augmentation of depth
189 |         assert(len(objs_dict) <= self.max_instance_num)
190 |         objs_id = np.zeros(self.max_instance_num, dtype=np.uint8)
191 |         label = np.zeros((self.max_instance_num + 1, self.H, self.W), dtype=bool)
192 |         bbx = np.zeros((self.max_instance_num, 4))
193 |         RTs = np.zeros((self.max_instance_num, 3, 4))
194 |         centers = np.zeros((self.max_instance_num, 2))
195 |         centermaps = np.zeros((self.max_instance_num, 3, self.resolution[1], self.resolution[0]))
196 |         ## debug visualization: draws object centers on the image (not part of the returned sample)
197 |         img = cv2.imread(rgb_path)
198 |
199 |         for obj_idx in objs_dict.keys():
200 |             if len(objs_dict[obj_idx]['bbox_visib']) > 0:
201 |                 ## have visible mask
202 |                 objs_id[obj_idx] = self.id2label[objs_dict[obj_idx]['obj_id']]
203 |                 assert(objs_id[obj_idx] > 0)
204 |                 with Image.open(objs_dict[obj_idx]['visible_mask_path']) as im:
205 |                     label[objs_id[obj_idx]] = np.array(im, dtype=bool)
206 |                 ## [x_min, y_min, width, height]
207 |                 bbx[obj_idx] = objs_dict[obj_idx]['bbox_visib']
208 |                 RT = np.zeros((4, 4))
209 |                 RT[3, 3] = 1
210 |                 RT[:3, :3] = objs_dict[obj_idx]['R']
211 |                 RT[:3, [3]] = objs_dict[obj_idx]['T']
212 |                 RT = np.linalg.inv(RT)
213 |                 RTs[obj_idx] = RT[:3]
214 |                 center_homo = self.cam_intrinsic @ RT[:3, [3]]
215 |                 center = center_homo[:2]/center_homo[2]
216 |                 x = np.linspace(0, self.resolution[0] - 1, self.resolution[0])
217 |                 y = np.linspace(0, self.resolution[1] - 1, self.resolution[1])
218 |                 xv, yv = np.meshgrid(x, y)
219 |                 dx, dy = center[0] - xv, center[1] - yv
220 |                 distance = np.sqrt(dx ** 2 + dy ** 2)
221 |                 nx, ny = dx / distance, dy / distance
222 |                 Tz = np.ones((self.resolution[1], self.resolution[0])) * RT[2, 3]
223 |                 centermaps[obj_idx] = np.array([nx, ny, Tz])
224 |                 ## debug visualization
225 |                 img = cv2.circle(img, (int(center[0]), int(center[1])), radius=2, color=(0, 0, 255), thickness=-1)
226 |                 centers[obj_idx] = np.array([int(center[0]), int(center[1])])
227 |         label[0] = 1 - label[1:].sum(axis=0)
228 |         # Image.fromarray(label[0].astype(np.uint8) * 255).save("testlabel.png")
229 |         # Image.open(rgb_path).save("testrgb.png")
230 |         # cv2.imwrite("testcenter.png", img)
231 |         data_dict['objs_id'] = objs_id
232 |         data_dict['label'] = label
233 |         data_dict['bbx'] = bbx
234 |         data_dict['RTs'] = RTs
235 |         data_dict['centermaps'] = centermaps.reshape(-1, self.resolution[1], self.resolution[0])
236 |         data_dict['centers'] = centers
237 |         return data_dict
238 |
239 |     def shuffle(self):
240 |         random.shuffle(self.all_lst)
241 |
242 |     def download(self) -> None:
243 |         download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
244 |
--------------------------------------------------------------------------------
/rob599/__init__.py:
--------------------------------------------------------------------------------
1 | from . import grad, submit
2 | from .utils import reset_seed, tensor_to_image, visualize_dataset, chromatic_transform, add_noise, Visualize, quaternion_to_matrix, format_gt_RTs
3 | from .PROPSPoseDataset import PROPSPoseDataset
4 |
--------------------------------------------------------------------------------
/rob599/__pycache__/PROPSPoseDataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/PROPSPoseDataset.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/PROPSPoseDataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/PROPSPoseDataset.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/grad.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/grad.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/grad.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/grad.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/submit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/submit.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/grad.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 |
5 | import rob599
6 |
7 | """ Utilities for computing and checking gradients. """
8 |
9 |
10 | def grad_check_sparse(f, x, analytic_grad, num_checks=10, h=1e-7):
11 |     """
12 |     Utility function to perform numeric gradient checking. We use the centered
13 |     difference formula to compute a numeric derivative:
14 |
15 |     f'(x) =~ (f(x + h) - f(x - h)) / (2h)
16 |
17 |     Rather than computing a full numeric gradient, we sparsely sample a few
18 |     dimensions along which to compute numeric derivatives.
19 |
20 |     Inputs:
21 |     - f: A function that inputs a torch tensor and returns a torch scalar
22 |     - x: A torch tensor of the point at which to evaluate the numeric gradient
23 |     - analytic_grad: A torch tensor giving the analytic gradient of f at x
24 |     - num_checks: The number of dimensions along which to check
25 |     - h: Step size for computing numeric derivatives
26 |     """
27 |     # fix random seed to 0
28 |     rob599.reset_seed(0)
29 |     for i in range(num_checks):
30 |
31 |         ix = tuple([random.randrange(m) for m in x.shape])
32 |
33 |         oldval = x[ix].item()
34 |         x[ix] = oldval + h  # increment by h
35 |         fxph = f(x).item()  # evaluate f(x + h)
36 |         x[ix] = oldval - h  # decrement by h
37 |         fxmh = f(x).item()  # evaluate f(x - h)
38 |         x[ix] = oldval  # reset
39 |
40 |         grad_numerical = (fxph - fxmh) / (2 * h)
41 |         grad_analytic = analytic_grad[ix]
42 |         rel_error_top = abs(grad_numerical - grad_analytic)
43 |         rel_error_bot = abs(grad_numerical) + abs(grad_analytic) + 1e-12
44 |         rel_error = rel_error_top / rel_error_bot
45 |         msg = "numerical: %f analytic: %f, relative error: %e"
46 |         print(msg % (grad_numerical, grad_analytic, rel_error))
47 |
48 |
49 | def compute_numeric_gradient(f, x, dLdf=None, h=1e-7):
50 |     """
51 |     Compute the numeric gradient of f at x using a finite differences
52 |     approximation. We use the centered difference:
53 |
54 |     df    f(x + h) - f(x - h)
55 |     -- ~= -------------------
56 |     dx           2 * h
57 |
58 |     This can also be extended to intermediate layers using the
59 |     chain rule:
60 |
61 |     dL   df   dL
62 |     -- = -- * --
63 |     dx   dx   df
64 |
65 |     Inputs:
66 |     - f: A function that inputs a torch tensor and returns a torch scalar
67 |     - x: A torch tensor giving the point at which to compute the gradient
68 |     - dLdf: optional upstream gradient for intermediate layers
69 |     - h: epsilon used in the finite difference calculation
70 |     Returns:
71 |     - grad: A tensor of the same shape as x giving the gradient of f at x
72 |     """
73 |     flat_x = x.contiguous().flatten()
74 |     grad = torch.zeros_like(x)
75 |     flat_grad = grad.flatten()
76 |
77 |     # Initialize upstream gradient to be ones if not provided
78 |     if dLdf is None:
79 |         y = f(x)
80 |         dLdf = torch.ones_like(y)
81 |     dLdf = dLdf.flatten()
82 |
83 |     # iterate over all indexes in x
84 |     for i in range(flat_x.shape[0]):
85 |         oldval = flat_x[i].item()  # Store the original value
86 |         flat_x[i] = oldval + h  # Increment by h
87 |         fxph = f(x).flatten()  # Evaluate f(x + h)
88 |         flat_x[i] = oldval - h  # Decrement by h
89 |         fxmh = f(x).flatten()  # Evaluate f(x - h)
90 |         flat_x[i] = oldval  # Restore original value
91 |
92 |         # compute the partial derivative with centered formula
93 |         dfdxi = (fxph - fxmh) / (2 * h)
94 |
95 |         # use chain rule to compute dLdx
96 |         flat_grad[i] = dLdf.dot(dfdxi).item()
97 |
98 |     # Note that since flat_grad was only a reference to grad,
99 |     # we can just return the object in the shape of x by returning grad
100 |     return grad
101 |
102 |
103 | def rel_error(x, y, eps=1e-10):
104 |     """
105 |     Compute the relative error between a pair of tensors x and y,
106 |     which is defined as:
107 |
108 |                            max_i |x_i - y_i|
109 |     rel_error(x, y) = -------------------------------
110 |                       max_i |x_i| + max_i |y_i| + eps
111 |
112 |     Inputs:
113 |     - x, y: Tensors of the same shape
114 |     - eps: Small positive constant for numeric stability
115 |
116 |     Returns:
117 |     - rel_error: Scalar giving the relative error between x and y
118 |     """
119 |     top = (x - y).abs().max().item()
120 |     bot = (x.abs() + y.abs()).clamp(min=eps).max().item()
121 |     return top / bot
122 |
--------------------------------------------------------------------------------
/rob599/submit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 |
4 | _P3_FILES = [
5 |     "pose_cnn.py",
6 |     "pose_estimation.ipynb"
7 | ]
8 |
9 |
10 | def make_p3_submission(assignment_path, uniquename=None, umid=None):
11 |     _make_submission(assignment_path, _P3_FILES, "P3", uniquename, umid)
12 |
13 |
14 | def _make_submission(
15 |     assignment_path, file_list, assignment_no, uniquename=None, umid=None
16 | ):
17 |     if uniquename is None or umid is None:
18 |         uniquename, umid = _get_user_info()
19 |     zip_path = "{}_{}_{}.zip".format(uniquename, umid, assignment_no)
20 |     zip_path = os.path.join(assignment_path, zip_path)
21 |     print("Writing zip file to: ", zip_path)
22 |     with zipfile.ZipFile(zip_path, "w") as zf:
23 |         for filename in file_list:
24 |             if filename.startswith('rob599/'):
25 |                 filename_out = filename.split('/')[-1]
26 |             else:
27 |                 filename_out = filename
28 |             in_path = os.path.join(assignment_path, filename)
29 |             if not os.path.isfile(in_path):
30 |                 raise ValueError('Could not find file "%s"' % filename)
31 |             zf.write(in_path, filename_out)
32 |
33 |
34 | def _get_user_info():
35 |     uniquename = input("Enter your uniquename (e.g. topipari): ")
36 |     umid = input("Enter your umid (e.g. 12345678): ")
37 |     return uniquename, umid
38 |
--------------------------------------------------------------------------------
/rob599/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 | import matplotlib as mpl
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | from torchvision.utils import make_grid
9 |
10 | from torchvision.ops import box_iou
11 | import sys, os
12 | import trimesh
13 | import pyrender
14 | import tqdm
15 |
16 | """
17 | General utilities to help with implementation
18 | """
19 |
20 |
21 | def reset_seed(number):
22 |     """
23 |     Reset random seed to the specific number
24 |
25 |     Inputs:
26 |     - number: A seed number to use
27 |     """
28 |     random.seed(number)
29 |     np.random.seed(number)
30 |     torch.manual_seed(number)
31 |     return
32 |
33 |
34 | def tensor_to_image(tensor):
35 |     """
36 |     Convert a torch tensor into a numpy ndarray for visualization.
37 |
38 |     Inputs:
39 |     - tensor: A torch tensor of shape (3, H, W) with
40 |       elements in the range [0, 1]
41 |
42 |     Returns:
43 |     - ndarr: A uint8 numpy array of shape (H, W, 3)
44 |     """
45 |     tensor = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
46 |     ndarr = tensor.to("cpu", torch.uint8).numpy()
47 |     return ndarr
48 |
49 | def format_gt_RTs(RTs):
50 |     return {idx+1: np.concatenate((RTs[idx], [[0, 0, 0, 1]])) for idx in range(len(RTs))}
51 |
52 | def visualize_dataset(pose_dataset, num_samples=4, alpha=0.5):
53 |     """
54 |     Make a grid-shape image to plot
55 |
56 |     Inputs:
57 |     - pose_dataset: instance of PROPSPoseDataset
58 |
59 |     Outputs:
60 |     - A grid image that visualizes num_samples
61 |       image and pose label samples
62 |     """
63 |     plt.text(300, -40, 'RGB', ha="center")
64 |     plt.text(950, -40, 'Pose', ha="center")
65 |     plt.text(1600, -40, 'Depth', ha="center")
66 |     plt.text(2250, -40, 'Segmentation', ha="center")
67 |     plt.text(2900, -40, 'Centermaps[0]', ha="center")
68 |
69 |     samples = []
70 |     for sample_i in range(num_samples):
71 |         sample_idx = random.randint(0, len(pose_dataset) - 1)
72 |         sample = pose_dataset[sample_idx]
73 |         rgb = (sample['rgb'].transpose(1, 2, 0) * 255).astype(np.uint8)
74 |         depth = ((np.tile(sample['depth'], (3, 1, 1)) / sample['depth'].max()) * 255).astype(np.uint8)
75 |         segmentation = (sample['label']*np.arange(11).reshape((11,1,1))).sum(0, keepdims=True).astype(np.float64)
76 |         segmentation /= segmentation.max()
77 |         segmentation = (np.tile(segmentation, (3, 1, 1)) * 255).astype(np.uint8)
78 |         ctrs = sample['centermaps'].reshape(10, 3, 480, 640)[0]
79 |         ctrs -= ctrs.min()
80 |         ctrs /= ctrs.max()
81 |         ctrs = (ctrs * 255).astype(np.uint8)
82 |         pose_dict = format_gt_RTs(sample['RTs'])
83 |         render = pose_dataset.visualizer.vis_oneview(
84 |             ipt_im=rgb,
85 |             obj_pose_dict=pose_dict,
86 |             alpha=alpha
87 |         )
88 |         samples.append(torch.tensor(rgb.transpose(2, 0, 1)))
89 |         samples.append(torch.tensor(render.transpose(2, 0, 1)))
90 |         samples.append(torch.tensor(depth))
91 |         samples.append(torch.tensor(segmentation))
92 |         samples.append(torch.tensor(ctrs))
93 |     img = make_grid(samples, nrow=5).permute(1, 2, 0)
94 |     return img
95 |
96 |
97 |
98 | def chromatic_transform(image):
99 |     """
100 |     Randomly perturb the hue, luminosity and saturation of the image.
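    The shifts are sampled uniformly: hue up to ±9 on OpenCV's 0-180 hue scale,
    luminosity and saturation up to ±25.6 on the 0-255 scale.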
101 |
102 |     This is adapted from the implicit-depth repository, ref: https://github.com/NVlabs/implicit_depth/blob/main/src/utils/data_augmentation.py
103 |
104 |     Parameters
105 |     ----------
106 |
107 |     image: array, required, the given image.
108 |
109 |     Returns
110 |     -------
111 |
112 |     The new image after augmentation in HLS space.
113 |     """
114 |     # Sample random hue, luminosity and saturation shifts
115 |     d_h = (np.random.rand(1) - 0.5) * 0.1 * 180
116 |     d_l = (np.random.rand(1) - 0.5) * 0.2 * 256
117 |     d_s = (np.random.rand(1) - 0.5) * 0.2 * 256
118 |     # Convert the BGR to HLS
119 |     hls = cv2.cvtColor(image, cv2.COLOR_BGR2HLS)
120 |     h, l, s = cv2.split(hls)
121 |     # Add the shifts to the H, L, S channels
122 |     new_h = (h + d_h) % 180
123 |     new_l = np.clip(l + d_l, 0, 255)
124 |     new_s = np.clip(s + d_s, 0, 255)
125 |     # Convert the HLS back to BGR
126 |     new_hls = cv2.merge((new_h, new_l, new_s)).astype('uint8')
127 |     new_image = cv2.cvtColor(new_hls, cv2.COLOR_HLS2BGR)
128 |     return new_image
129 |
130 |
131 |
132 | def add_noise(image, level=0.1):
133 |     """
134 |     Add noise to the image.
135 |
136 |     This is adapted from the implicit-depth repository, ref: https://github.com/NVlabs/implicit_depth/blob/main/src/utils/data_augmentation.py
137 |
138 |     Parameters
139 |     ----------
140 |
141 |     image: array, required, the given image;
142 |
143 |     level: float, optional, default: 0.1, the maximum noise level.
144 |
145 |     Returns
146 |     -------
147 |
148 |     The new image after adding noise.
149 |     """
150 |     # random number
151 |     r = np.random.rand(1)
152 |
153 |     # gaussian noise
154 |     if r < 0.9:
155 |         row, col, ch = image.shape
156 |         mean = 0
157 |         noise_level = random.uniform(0, level)
158 |         sigma = np.random.rand(1) * noise_level * 256
159 |         gauss = sigma * np.random.randn(row, col) + mean
160 |         gauss = np.repeat(gauss[:, :, np.newaxis], ch, axis=2)
161 |         noisy = image + gauss
162 |         noisy = np.clip(noisy, 0, 255)
163 |     else:
164 |         # motion blur
165 |         sizes = [3, 5, 7, 9, 11, 15]
166 |         size = sizes[int(np.random.randint(len(sizes), size=1))]
167 |         kernel_motion_blur = np.zeros((size, size))
168 |         if np.random.rand(1) < 0.5:
169 |             kernel_motion_blur[int((size-1)/2), :] = np.ones(size)
170 |         else:
171 |             kernel_motion_blur[:, int((size-1)/2)] = np.ones(size)
172 |         kernel_motion_blur = kernel_motion_blur / size
173 |         noisy = cv2.filter2D(image, -1, kernel_motion_blur)
174 |
175 |     return noisy.astype('uint8')
176 |
177 |
178 | class Visualize:
179 |     def __init__(self, object_dict, cam_intrinsic, resolution):
180 |         '''
181 |         object_dict is a dict storing object labels, object names and object model paths,
182 |         example:
183 |             object_dict = {
184 |                 1: ["beaker_1", path]
185 |                 2: ["dropper_1", path]
186 |                 3: ["dropper_2", path]
187 |             }
188 |         '''
189 |         self.objnode = {}
190 |         self.render = pyrender.OffscreenRenderer(resolution[0], resolution[1])
191 |         self.scene = pyrender.Scene()
192 |         cam = pyrender.camera.IntrinsicsCamera(cam_intrinsic[0, 0],
193 |                                                cam_intrinsic[1, 1],
194 |                                                cam_intrinsic[0, 2],
195 |                                                cam_intrinsic[1, 2],
196 |                                                znear=0.05, zfar=100.0, name=None)
197 |         self.intrinsic = cam_intrinsic
198 |         Axis_align = np.array([[1, 0, 0, 0],
199 |                                [0, -1, 0, 0],
200 |                                [0, 0, -1, 0],
201 |                                [0, 0, 0, 1]])
202 |         self.nc = pyrender.Node(camera=cam, matrix=Axis_align)
203 |         self.scene.add_node(self.nc)
204 |
205 |         for obj_label in object_dict:
206 |             objname = object_dict[obj_label][0]
207 |             objpath = object_dict[obj_label][1]
208 |             tm = trimesh.load(objpath)
178 | class Visualize:
179 |     def __init__(self, object_dict, cam_intrinsic, resolution):
180 |         '''
181 |         object_dict is a dict that stores object labels, object names and object model paths,
182 |         example:
183 |         object_dict = {
184 |             1: ["beaker_1", path],
185 |             2: ["dropper_1", path],
186 |             3: ["dropper_2", path]
187 |         }
188 |         '''
189 |         self.objnode = {}
190 |         self.render = pyrender.OffscreenRenderer(resolution[0], resolution[1])
191 |         self.scene = pyrender.Scene()
192 |         cam = pyrender.camera.IntrinsicsCamera(cam_intrinsic[0, 0],
193 |                                                cam_intrinsic[1, 1],
194 |                                                cam_intrinsic[0, 2],
195 |                                                cam_intrinsic[1, 2],
196 |                                                znear=0.05, zfar=100.0, name=None)
197 |         self.intrinsic = cam_intrinsic
198 |         Axis_align = np.array([[1, 0, 0, 0],  # flip from the OpenCV to the OpenGL camera convention
199 |                                [0, -1, 0, 0],
200 |                                [0, 0, -1, 0],
201 |                                [0, 0, 0, 1]])
202 |         self.nc = pyrender.Node(camera=cam, matrix=Axis_align)
203 |         self.scene.add_node(self.nc)
204 | 
205 |         for obj_label in object_dict:
206 |             objname = object_dict[obj_label][0]
207 |             objpath = object_dict[obj_label][1]
208 |             tm = trimesh.load(objpath)
209 |             mesh = pyrender.Mesh.from_trimesh(tm, smooth=False)
210 |             node = pyrender.Node(mesh=mesh, matrix=np.eye(4))
211 |             node.mesh.is_visible = False
212 |             self.objnode[obj_label] = {"name": objname, "node": node, "mesh": tm}
213 |             self.scene.add_node(node)
214 |         self.cmp = self.color_map(N=len(object_dict))
215 |         self.object_dict = object_dict
216 | 
217 |     def vis_oneview(self, ipt_im, obj_pose_dict, alpha=0.5, axis_len=30):
218 |         '''
219 |         Input:
220 |             ipt_im: numpy [H, W, 3]
221 |                 input image
222 |             obj_pose_dict:
223 |                 a dict that stores the object poses within the input image,
224 |             example:
225 |             poselist = {
226 |                 15: numpy_pose 4X4,
227 |                 37: numpy_pose 4X4,
228 |                 39: numpy_pose 4X4,
229 |             }
230 |             alpha: float [0, 1]
231 |                 alpha for the labels' colormap overlaid on the image
232 |             axis_len: int
233 |                 pixel length of the drawn pose axes
234 |         '''
235 |         img = ipt_im.copy()
236 |         for obj_label in obj_pose_dict:  # first pass: pose every queried object and render the full depth map
237 |             if obj_label in self.object_dict:
238 |                 pose = obj_pose_dict[obj_label]
239 |                 node = self.objnode[obj_label]['node']
240 |                 node.mesh.is_visible = True
241 |                 self.scene.set_pose(node, pose=pose)
242 |         full_depth = self.render.render(self.scene, flags=pyrender.constants.RenderFlags.DEPTH_ONLY)
243 |         for obj_label in obj_pose_dict:
244 |             if obj_label in self.object_dict:
245 |                 node = self.objnode[obj_label]['node']
246 |                 node.mesh.is_visible = False
247 |         for obj_label in self.object_dict:
248 |             node = self.objnode[obj_label]['node']
249 |             node.mesh.is_visible = False
250 |         for obj_label in obj_pose_dict:  # second pass: render each object alone to find its visible pixels
251 |             if obj_label in self.object_dict:
252 |                 node = self.objnode[obj_label]['node']
253 |                 node.mesh.is_visible = True
254 |                 depth = self.render.render(self.scene, flags=pyrender.constants.RenderFlags.DEPTH_ONLY)
255 |                 node.mesh.is_visible = False
256 |                 mask = np.logical_and(
257 |                     (np.abs(depth - full_depth) < 1e-6), np.abs(full_depth) > 0.2
258 |                 )
259 |                 if np.sum(mask) > 0:
260 |                     color = self.cmp[obj_label - 1]
261 |                     img[mask, :] = alpha * img[mask, :] + (1 - alpha) * color[:]
262 |                 obj_pose = obj_pose_dict[obj_label]
263 |                 obj_center = self.project2d(self.intrinsic, obj_pose[:3, -1])
264 |                 rgb_colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]  # one color per drawn axis
265 |                 for j in range(3):
266 |                     obj_xyz_offset_2d = self.project2d(self.intrinsic, obj_pose[:3, -1] + obj_pose[:3, j] * 0.001)
267 |                     obj_axis_endpoint = obj_center + (obj_xyz_offset_2d - obj_center) / np.linalg.norm(obj_xyz_offset_2d - obj_center) * axis_len
268 |                     cv2.arrowedLine(img, (int(obj_center[0]), int(obj_center[1])), (int(obj_axis_endpoint[0]), int(obj_axis_endpoint[1])), rgb_colors[j], thickness=2, tipLength=0.15)
269 |         return img
270 | 
271 |     def color_map(self, N=256, normalized=False):
272 |         def bitget(byteval, idx):
273 |             return ((byteval & (1 << idx)) != 0)
274 |         dtype = 'float32' if normalized else 'uint8'
275 |         cmap = np.zeros((N, 3), dtype=dtype)
276 |         for i in range(N):
277 |             r = g = b = 0
278 |             c = i
279 |             for j in range(8):
280 |                 r = r | (bitget(c, 0) << 7 - j)
281 |                 g = g | (bitget(c, 1) << 7 - j)
282 |                 b = b | (bitget(c, 2) << 7 - j)
283 |                 c = c >> 3
284 |             cmap[i] = np.array([r, g, b])
285 |         cmap = cmap / 255 if normalized else cmap
286 |         return cmap
287 | 
288 |     def project2d(self, intrinsic, point3d):
289 |         return (intrinsic @ (point3d / point3d[2]))[:2]  # pinhole projection: normalize by depth, then apply K
290 | 
291 | def quaternion_to_matrix(quaternions):
292 |     """
293 |     Convert rotations given as quaternions to rotation matrices.
294 | 
295 |     Args:
296 |         quaternions: quaternions with real part first,
297 |             as tensor of shape (..., 4).
298 | 
299 |     Returns:
300 |         Rotation matrices as tensor of shape (..., 3, 3).
301 |     """
302 |     r, i, j, k = torch.unbind(quaternions, -1)
303 |     # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
304 |     two_s = 2.0 / (quaternions * quaternions).sum(-1)
305 | 
306 |     o = torch.stack(
307 |         (
308 |             1 - two_s * (j * j + k * k),
309 |             two_s * (i * j - k * r),
310 |             two_s * (i * k + j * r),
311 |             two_s * (i * j + k * r),
312 |             1 - two_s * (i * i + k * k),
313 |             two_s * (j * k - i * r),
314 |             two_s * (i * k - j * r),
315 |             two_s * (j * k + i * r),
316 |             1 - two_s * (i * i + j * j),
317 |         ),
318 |         -1,
319 |     )
320 |     return o.reshape(quaternions.shape[:-1] + (3, 3))
321 | 
322 | 
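# Sanity check for `quaternion_to_matrix` (illustrative sketch): for the
# identity quaternion (real part first) two_s equals 2, every cross term
# vanishes, and the result is the identity rotation:
#
#   q = torch.tensor([1.0, 0.0, 0.0, 0.0])   # (r, i, j, k)
#   quaternion_to_matrix(q)
#   # -> tensor([[1., 0., 0.],
#   #            [0., 1., 0.],
#   #            [0., 0., 1.]])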
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | 
4 | os.environ["TZ"] = "US/Eastern"
5 | time.tzset()
6 | 
7 | import matplotlib.pyplot as plt
8 | import multiprocessing
9 | import numpy as np  # needed for the mean-loss computation in the training loop
10 | import torch
11 | import torchvision.models as models
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 | 
15 | from p3_helper import hello_helper
16 | from pose_cnn import hello_pose_cnn, FeatureExtraction, SegmentationBranch, PoseCNN
17 | from rob599 import reset_seed, visualize_dataset, PROPSPoseDataset
18 | from rob599.grad import rel_error
19 | 
20 | # Set a few constants related to data loading.
21 | NUM_CLASSES = 10
22 | BATCH_SIZE = 4
23 | NUM_WORKERS = multiprocessing.cpu_count()
24 | 
25 | # Ensure the helper functions run correctly
26 | hello_pose_cnn()
27 | hello_helper()
28 | 
29 | # Check the last modification time of pose_cnn.py
30 | pose_cnn_path = os.path.join("/home/yifeng/PycharmProjects/TestEnv/PoseCNN/pose_cnn.py")
31 | pose_cnn_edit_time = time.ctime(os.path.getmtime(pose_cnn_path))
32 | print("pose_cnn.py last edited on %s" % pose_cnn_edit_time)
33 | 
34 | # Set up matplotlib plotting parameters
35 | plt.rcParams["figure.figsize"] = (10.0, 8.0)  # set default size of plots
36 | plt.rcParams["font.size"] = 16
37 | plt.rcParams["image.interpolation"] = "nearest"
38 | plt.rcParams["image.cmap"] = "gray"
39 | 
40 | # Check for CUDA availability
41 | if torch.cuda.is_available():
42 |     print("Good to go!")
43 |     DEVICE = torch.device("cuda")
44 | else:
45 |     DEVICE = torch.device("cpu")
46 | 
47 | # -------------------------- Dataset Preparation --------------------------------
48 | # NOTE: Set `download=True` the first time to download the dataset.
49 | # After downloading, set `download=False` for faster execution.
50 | 
51 | data_root = "/home/yifeng/PycharmProjects/TestEnv/PoseCNN"
52 | # Prepare the train dataset
53 | train_dataset = PROPSPoseDataset(
54 |     root=data_root,
55 |     split="train",
56 |     download=False  # Change to True for the first-time download
57 | )
58 | 
59 | # Prepare the validation dataset
60 | val_dataset = PROPSPoseDataset(
61 |     root=data_root,
62 |     split="val",
63 |     download=False
64 | )
65 | 
66 | # Print dataset sizes
67 | print(f"Dataset sizes: train ({len(train_dataset)}), val ({len(val_dataset)})")
68 | 
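# Each dataset item is a dict of numpy arrays keyed by 'rgb', 'depth',
# 'label', 'centermaps' and 'RTs' (the keys consumed by visualize_dataset in
# rob599/utils.py); a quick sketch for inspecting one sample before training:
#
#   sample = train_dataset[0]
#   for key, value in sample.items():
#       print(key, getattr(value, "shape", type(value)))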
69 | 
70 | reset_seed(0)
71 | 
72 | grid_vis = visualize_dataset(val_dataset, alpha=0.25)
73 | plt.axis('off')
74 | plt.imshow(grid_vis)
75 | plt.show()
76 | 
77 | reset_seed(0)
78 | 
79 | dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)  # NUM_WORKERS is defined above but unused; pass num_workers=NUM_WORKERS (and shuffle=True) if desired
80 | 
81 | vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
82 | posecnn_model = PoseCNN(pretrained_backbone=vgg16,
83 |                         models_pcd=torch.tensor(train_dataset.models_pcd).to(DEVICE, dtype=torch.float32),
84 |                         cam_intrinsic=train_dataset.cam_intrinsic).to(DEVICE)
85 | posecnn_model.train()
86 | 
87 | optimizer = torch.optim.Adam(posecnn_model.parameters(), lr=0.001,
88 |                              betas=(0.9, 0.999))
89 | 
90 | 
91 | loss_history = []
92 | log_period = 5
93 | _iter = 0
94 | 
95 | 
96 | st_time = time.time()
97 | for epoch in range(10):
98 |     train_loss = []
99 |     dataloader.dataset.dataset_type = 'train'
100 |     for batch in dataloader:
101 |         for item in batch:
102 |             batch[item] = batch[item].to(DEVICE)
103 |         loss_dict = posecnn_model(batch)
104 |         optimizer.zero_grad()
105 |         total_loss = 0
106 |         for loss in loss_dict:  # sum all loss terms returned by the model
107 |             total_loss += loss_dict[loss]
108 |         total_loss.backward()
109 |         optimizer.step()
110 |         train_loss.append(total_loss.item())
111 | 
112 |         if _iter % log_period == 0:
113 |             loss_str = f"[Iter {_iter}][loss: {total_loss:.3f}]"
114 |             for key, value in loss_dict.items():
115 |                 loss_str += f"[{key}: {value:.3f}]"
116 | 
117 |             print(loss_str)
118 |             loss_history.append(total_loss.item())
119 |         _iter += 1
120 | 
121 |     print('Time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))
122 |                             + ', ' + 'Epoch %02d' % epoch + f', Training finished, with mean training loss {np.array(train_loss).mean()}'))
123 | 
124 | torch.save(posecnn_model.state_dict(), os.path.join("your path here", "posecnn_model.pth"))  # replace "your path here" with your output directory
125 | 
126 | plt.title("Training loss history")
127 | plt.xlabel(f"Iteration (x {log_period})")
128 | plt.ylabel("Loss")
129 | plt.plot(loss_history)
130 | plt.show()
--------------------------------------------------------------------------------
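# To reuse the checkpoint saved by train.py for inference (see the README's
# Inference section), the state dict can be loaded back into a model built the
# same way as above; a minimal sketch, assuming the weights were saved as
# "posecnn_model.pth":
#
#   posecnn_model.load_state_dict(torch.load("posecnn_model.pth", map_location=DEVICE))
#   posecnn_model.eval()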