├── README.md
├── __pycache__
│   ├── p3_helper.cpython-310.pyc
│   ├── p3_helper.cpython-38.pyc
│   ├── pose_cnn.cpython-310.pyc
│   └── pose_cnn.cpython-38.pyc
├── demo.py
├── image
│   ├── 6d1.png
│   ├── 6d2.png
│   ├── 6d3.png
│   ├── 6d4.png
│   ├── 6d5.png
│   └── 6d6.png
├── inference.py
├── p3_helper.py
├── pose_cnn.py
├── requirements.txt
├── rob599
│   ├── PROPSPoseDataset.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── PROPSPoseDataset.cpython-310.pyc
│   │   ├── PROPSPoseDataset.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── grad.cpython-310.pyc
│   │   ├── grad.cpython-38.pyc
│   │   ├── submit.cpython-38.pyc
│   │   ├── utils.cpython-310.pyc
│   │   └── utils.cpython-38.pyc
│   ├── grad.py
│   ├── submit.py
│   └── utils.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # PoseCNN_pytorch
2 | A PyTorch implementation of PoseCNN for 6D object pose estimation on the PROPS dataset.
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | # System Requirements
13 | We tested the code on:
14 | ```bash
15 | PyTorch version: 2.3.1
16 | CUDA version: 12.1
17 | Ubuntu 22.04
18 | GeForce RTX 4070 and 4090
19 |
20 | ```
21 |
22 | # Dependencies
23 |
24 | The project requires the following Python libraries and versions:
25 |
26 | | Package | Version | Description |
27 | |---------------|------------|-----------------------------------------------------|
28 | | `matplotlib` | `3.7.2` | For plotting and visualization. |
29 | | `numpy` | `1.24.3` | Fundamental package for numerical computations. |
30 | | `Pillow` | `11.0.0` | Library for working with image processing tasks. |
31 | | `pyrender` | `0.1.45` | Rendering 3D scenes for visualization. |
32 | | `torch` | `2.3.1` | PyTorch library for deep learning. |
33 | | `torchvision` | `0.18.1` | PyTorch's library for vision-related tasks. |
34 | | `tqdm` | `4.66.4` | For creating progress bars in scripts. |
35 | | `trimesh` | `4.4.3` | For loading and working with 3D triangular meshes. |
36 |
37 | ### Installing Dependencies
38 |
39 | You can install the required dependencies using the `requirements.txt` file:
40 |
41 | ```bash
42 | pip install -r requirements.txt
43 |
44 | ```
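
As a quick sanity check after installation, you can confirm that PyTorch and CUDA are visible (this snippet is optional; the versions printed depend on your environment):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```
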
45 | ## Dataset Preparation
46 |
47 | To use this project, you need to download the required dataset and extract it to the root path of the project.
48 |
49 | ### Steps to Prepare the Dataset
50 |
51 | 1. **Download the Dataset:**
52 | - Download the dataset from the following link:
53 | [PROPS-Pose-Dataset](https://drive.google.com/file/d/15rhwXhzHGKtBcxJAYMWJG7gN7BLLhyAq/view)
54 |
55 | 2. **Place the Dataset:**
56 | - Move the downloaded file `PROPS-Pose-Dataset.tar.gz` to the root directory of the project.
57 |
58 | 3. **Extract the Dataset:**
59 | - Use the following command to extract the dataset:
60 | ```bash
61 | tar -xvzf PROPS-Pose-Dataset.tar.gz
62 | ```
63 | - This will create a folder named `PROPS-Pose-Dataset` in the root directory.
64 |
65 | 4. **Verify the Dataset Structure:**
66 | - Ensure the folder structure matches the following:
67 | ```
68 | PROPS-Pose-Dataset/
69 | ├── train/
70 | │   ├── rgb/
71 | │   ├── depth/
72 | │   ├── mask_visib/
73 | │   ├── train_gt.json
74 | │   └── train_gt_info.json
75 | ├── val/
76 | │   ├── rgb/
77 | │   ├── depth/
78 | │   ├── mask_visib/
79 | │   ├── val_gt.json
80 | │   └── val_gt_info.json
81 | └── model/
82 |     ├── 1_master_chef_can/
83 |     └── ...
84 | ```
85 |
86 | 5. **Set Dataset Path in Code:**
87 |    - The project automatically locates the dataset in `PROPS-Pose-Dataset` under the root path during execution. Ensure this directory exists before running the code (see the loading sketch below).
88 |
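
For reference, the snippet below is a minimal loading sketch adapted from `demo.py`. The `data_root` value is an assumption; point it at whatever directory contains the extracted `PROPS-Pose-Dataset` folder:

```python
from rob599 import PROPSPoseDataset

# Assumption: PROPS-Pose-Dataset was extracted into the project root.
data_root = "."

train_dataset = PROPSPoseDataset(root=data_root, split="train", download=False)
val_dataset = PROPSPoseDataset(root=data_root, split="val", download=False)
print(f"Dataset sizes: train ({len(train_dataset)}), val ({len(val_dataset)})")
```
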
89 | ---
90 |
91 | ## Training
92 |
93 | ### Run the Training Script
94 | To train the model, run the `train.py` script:
95 |
96 | ```bash
97 | python train.py
98 | ```
99 |
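
The sketch below is not a copy of `train.py`; it is a minimal illustration of how the `PoseCNN` module is constructed and how the three losses it returns in training mode (`loss_segmentation`, `loss_centermap`, `loss_R`) can be combined. The data root, batch size, optimizer, and learning rate are all assumptions:

```python
import multiprocessing

import torch
import torchvision.models as models
from torch.utils.data import DataLoader

from pose_cnn import PoseCNN
from rob599 import PROPSPoseDataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = PROPSPoseDataset(root=".", split="train", download=False)  # assumed data root
loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                    num_workers=multiprocessing.cpu_count())

vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
model = PoseCNN(pretrained_backbone=vgg16,
                models_pcd=torch.tensor(train_dataset.models_pcd).to(DEVICE, dtype=torch.float32),
                cam_intrinsic=train_dataset.cam_intrinsic).to(DEVICE)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer and learning rate
for batch in loader:
    # Move tensors to the device; cast double-precision arrays (e.g. centermaps) to float32.
    batch = {k: (v.float().to(DEVICE) if torch.is_floating_point(v) else v.to(DEVICE))
             for k, v in batch.items()}
    loss_dict = model(batch)  # {'loss_segmentation', 'loss_centermap', 'loss_R'}
    loss = sum(loss_dict.values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
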
100 | ## Inference
101 |
102 | To visualize the results, follow these steps to set up and run the `inference.py` script:
103 |
104 | ### Steps for Inference
105 |
106 | 1. **Download Pretrained Weights:**
107 | - Download the pretrained model weights from the following link:
108 | [PoseCNN Pretrained Weights](https://drive.google.com/file/d/1-9iheQf-TL5MjHTYZITulqbdFn5UK1Sd/view?usp=sharing)
109 |
110 | 2. **Place the Weights:**
111 | - Save the downloaded weights file (e.g., `posecnn_weights.pth`) to your desired directory.
112 |
113 | 3. **Set the Weights Path in Code:**
114 | - Open the `inference.py` script and locate the following line:
115 | ```python
116 | posecnn_model.load_state_dict(torch.load(os.path.join("your weight here")))
117 | ```
118 | - Replace `"your weight here"` with the path to your weights file. For example:
119 | ```python
120 | posecnn_model.load_state_dict(torch.load(os.path.join("models/posecnn_weights.pth")))
121 | ```
122 |
123 | 4. **Run the Inference Script:**
124 |    - Execute the script to visualize predictions (an optional headless-mode tweak is sketched after these steps):
125 | ```bash
126 | python inference.py
127 | ```
128 |
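
If you are running on a machine without a display, `plt.show()` will not open a window. One optional workaround (a sketch, not part of the repository) is to switch matplotlib to a non-interactive backend in `inference.py` and write the visualization to disk instead:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless machine with no display attached
import matplotlib.pyplot as plt

# `out` is the visualization image returned by eval(...) in inference.py
plt.imsave("prediction.png", out)
```
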
129 |
130 |
131 |
--------------------------------------------------------------------------------
/__pycache__/p3_helper.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/p3_helper.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/p3_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/p3_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/pose_cnn.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/pose_cnn.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/pose_cnn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/__pycache__/pose_cnn.cpython-38.pyc
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | os.environ["TZ"] = "US/Eastern"
5 | time.tzset()
6 |
7 | import matplotlib.pyplot as plt
8 | from pose_cnn import hello_pose_cnn
9 | from p3_helper import hello_helper
10 | from rob599 import reset_seed
11 | from rob599.grad import rel_error
12 | import torch
13 | from rob599 import PROPSPoseDataset
14 | from rob599 import reset_seed, visualize_dataset
15 | import torchvision.models as models
16 | from pose_cnn import FeatureExtraction
17 | from rob599 import reset_seed
18 | from pose_cnn import FeatureExtraction, SegmentationBranch
19 |
20 |
21 | # Ensure helper functions run correctly
22 | hello_pose_cnn()
23 | hello_helper()
24 |
25 | # Check last modification time of pose_cnn.py
26 | pose_cnn_path = os.path.join("/home/yifeng/PycharmProjects/TestEnv/PoseCNN/pose_cnn.py")
27 | pose_cnn_edit_time = time.ctime(os.path.getmtime(pose_cnn_path))
28 | print("pose_cnn.py last edited on %s" % pose_cnn_edit_time)
29 |
30 | # Set up matplotlib plotting parameters
31 | plt.rcParams["figure.figsize"] = (10.0, 8.0) # set default size of plots
32 | plt.rcParams["font.size"] = 16
33 | plt.rcParams["image.interpolation"] = "nearest"
34 | plt.rcParams["image.cmap"] = "gray"
35 |
36 | # Check for CUDA availability
37 | if torch.cuda.is_available():
38 | print("Good to go!")
39 | DEVICE = torch.device("cuda")
40 | else:
41 | DEVICE = torch.device("cpu")
42 |
43 | # -------------------------- Dataset Preparation --------------------------------
44 | # NOTE: Set `download=True` for the first time to download the dataset.
45 | # After downloading, set `download=False` for faster execution.
46 |
47 | data_root = "/home/yifeng/PycharmProjects/TestEnv/PoseCNN"
48 | # Prepare train dataset
49 | train_dataset = PROPSPoseDataset(
50 | root=data_root,
51 | split="train",
52 | download=False # Change to True for the first-time download
53 | )
54 |
55 | # Prepare validation dataset
56 | val_dataset = PROPSPoseDataset(
57 | root=data_root,
58 | split="val",
59 | download=False
60 | )
61 |
62 | # Print dataset sizes
63 | print(f"Dataset sizes: train ({len(train_dataset)}), val ({len(val_dataset)})")
64 |
65 |
66 | reset_seed(0)
67 |
68 | grid_vis = visualize_dataset(val_dataset,alpha = 0.25)
69 | plt.axis('off')
70 | plt.imshow(grid_vis)
71 | plt.show()
72 |
73 | vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
74 |
75 | # Based on PoseCNN section III.B, the output features should
76 | # be 1/8 and 1/16 the input's spatial resolution with 512 channels
77 | print('feature1 expected shape: (N, {}, {}, {})'.format(512, 480//8, 640//8))
78 | print('feature2 expected shape: (N, {}, {}, {})'.format(512, 480//16, 640//16))
79 | print()
80 |
81 | vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
82 | feature_extractor = FeatureExtraction(pretrained_model=vgg16)
83 |
84 | dummy_input = {'rgb': torch.zeros((2,3,480,640))}
85 | feature1, feature2 = feature_extractor(dummy_input)
86 |
87 | print('feature1 shape:', feature1.shape)
88 | print('feature2 shape:', feature2.shape)
89 |
90 |
--------------------------------------------------------------------------------
/image/6d1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d1.png
--------------------------------------------------------------------------------
/image/6d2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d2.png
--------------------------------------------------------------------------------
/image/6d3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d3.png
--------------------------------------------------------------------------------
/image/6d4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d4.png
--------------------------------------------------------------------------------
/image/6d5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d5.png
--------------------------------------------------------------------------------
/image/6d6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/image/6d6.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import torchvision.models as models
4 | import multiprocessing
5 | import rob599
6 | from pose_cnn import PoseCNN, eval
7 | from rob599 import PROPSPoseDataset
8 | import os
9 | import matplotlib.pyplot as plt
10 | # Check for CUDA availability
11 | if torch.cuda.is_available():
12 | print("Good to go!")
13 | DEVICE = torch.device("cuda")
14 | else:
15 | DEVICE = torch.device("cpu")
16 | rob599.reset_seed(0)
17 | NUM_CLASSES = 10
18 | BATCH_SIZE = 4
19 | NUM_WORKERS = multiprocessing.cpu_count()
20 |
21 | data_root = "/home/yifeng/PycharmProjects/TestEnv/PoseCNN"
22 | # Prepare train dataset
23 | train_dataset = PROPSPoseDataset(
24 | root=data_root,
25 | split="train",
26 | download=False # Change to True for the first-time download
27 | )
28 |
29 | # Prepare validation dataset
30 | val_dataset = PROPSPoseDataset(
31 | root=data_root,
32 | split="val",
33 | download=False
34 | )
35 |
36 |
37 | dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE)
38 |
39 | vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
40 | posecnn_model = PoseCNN(pretrained_backbone = vgg16,
41 | models_pcd = torch.tensor(val_dataset.models_pcd).to(DEVICE, dtype=torch.float32),
42 | cam_intrinsic = val_dataset.cam_intrinsic).to(DEVICE)
43 | posecnn_model.load_state_dict(torch.load(os.path.join("/home/yifeng/PycharmProjects/TestEnv/posecnn_model.pth")))
44 |
45 | num_samples = 5
46 | for i in range(num_samples):
47 | out = eval(posecnn_model, dataloader, DEVICE)
48 |
49 | plt.axis('off')
50 | plt.imshow(out)
51 | plt.show()
--------------------------------------------------------------------------------
/p3_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | from torchvision.ops import box_iou
5 | import sys, os
6 | import json
7 | import random
8 | import cv2
9 | from PIL import Image
10 | import trimesh
11 | import pyrender
12 | import tqdm
13 |
14 | _HOUGHVOTING_NUM_INLIER = 100
15 | _HOUGHVOTING_DIRECTION_INLIER = 0.9
16 | _LABEL2MASK_THRESHOL = 100
17 |
18 | def hello_helper():
19 | print("Hello from P3_helper.py!")
20 |
21 |
22 | def loss_cross_entropy(scores, labels):
23 | """
24 | scores: a tensor [batch_size, num_classes, height, width]
25 | labels: a tensor [batch_size, num_classes, height, width]
26 | """
27 |
28 | cross_entropy = -torch.sum(labels * torch.log(scores + 1e-10), dim=1)
29 | loss = torch.div(torch.sum(cross_entropy), torch.sum(labels)+1e-10)
30 |
31 | return loss
32 |
33 | def loss_Rotation(pred_R, gt_R, label, model):
34 | """
35 | pred_R: a tensor [N, 3, 3]
36 | gt_R: a tensor [N, 3, 3]
37 | label: a tensor [N, ]
38 | model: a tensor [N_cls, 1024, 3]
39 | """
40 | device = pred_R.device
41 | models_pcd = model[label - 1].to(device)
42 | gt_points = models_pcd @ gt_R
43 | pred_points = models_pcd @ pred_R
44 | loss = ((pred_points - gt_points) ** 2).sum(dim=2).sqrt().mean()
45 | return loss
46 |
47 |
48 | def IOUselection(pred_bbxes, gt_bbxes, threshold):
49 | """
50 |     pred_bbxes is N_pred_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
51 |     gt_bbxes is N_gt_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
52 |     threshold: IOU threshold for selecting predicted bounding boxes
53 | """
54 | device = pred_bbxes.device
55 | output_bbxes = torch.empty((0, 6)).to(device = device, dtype =torch.float)
56 | for pred_bbx in pred_bbxes:
57 | for gt_bbx in gt_bbxes:
58 | if pred_bbx[0] == gt_bbx[0] and pred_bbx[5] == gt_bbx[5]:
59 | iou = box_iou(pred_bbx[1:5].unsqueeze(dim=0), gt_bbx[1:5].unsqueeze(dim=0)).item()
60 | if iou > threshold:
61 | output_bbxes = torch.cat((output_bbxes, pred_bbx.unsqueeze(dim=0)), dim=0)
62 | return output_bbxes
63 |
64 |
65 | def HoughVoting(label, centermap, num_classes=10):
66 | """
67 |     label: [bs, H, W] segmentation map of class ids
68 |     centermap: [bs, 3*num_classes, H, W] center-direction (x, y) and depth channels per class
69 | """
70 | batches, H, W = label.shape
71 | x = np.linspace(0, W - 1, W)
72 | y = np.linspace(0, H - 1, H)
73 | xv, yv = np.meshgrid(x, y)
74 | xy = torch.from_numpy(np.array((xv, yv))).to(device = label.device, dtype=torch.float32)
75 | x_index = torch.from_numpy(x).to(device = label.device, dtype=torch.int32)
76 | centers = torch.zeros(batches, num_classes, 2)
77 | depths = torch.zeros(batches, num_classes)
78 | for bs in range(batches):
79 | for cls in range(1, num_classes + 1):
80 | if (label[bs] == cls).sum() >= _LABEL2MASK_THRESHOL:
81 | pixel_location = xy[:2, label[bs] == cls]
82 | pixel_direction = centermap[bs, (cls-1)*3:cls*3][:2, label[bs] == cls]
83 | y_index = x_index.unsqueeze(dim=0) - pixel_location[0].unsqueeze(dim=1)
84 | y_index = torch.round(pixel_location[1].unsqueeze(dim=1) + (pixel_direction[1]/pixel_direction[0]).unsqueeze(dim=1) * y_index).to(torch.int32)
85 | mask = (y_index >= 0) * (y_index < H)
86 | count = y_index * W + x_index.unsqueeze(dim=0)
87 | center, inlier_num = torch.bincount(count[mask]).argmax(), torch.bincount(count[mask]).max()
88 | center_x, center_y = center % W, torch.div(center, W, rounding_mode='trunc')
89 | if inlier_num > _HOUGHVOTING_NUM_INLIER:
90 | centers[bs, cls - 1, 0], centers[bs, cls - 1, 1] = center_x, center_y
91 | xyplane_dis = xy - torch.tensor([center_x, center_y])[:, None, None].to(device = label.device)
92 | xyplane_direction = xyplane_dis/(xyplane_dis**2).sum(dim=0).sqrt()[None, :, :]
93 | predict_direction = centermap[bs, (cls-1)*3:cls*3][:2]
94 | inlier_mask = ((xyplane_direction * predict_direction).sum(dim=0).abs() >= _HOUGHVOTING_DIRECTION_INLIER) * label[bs] == cls
95 | depths[bs, cls - 1] = centermap[bs, (cls-1)*3:cls*3][2, inlier_mask].mean()
96 | return centers, depths
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/pose_cnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the PoseCNN network architecture in PyTorch.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.init import kaiming_normal_
7 | import torchvision.models as models
8 | from torchvision.ops import RoIPool
9 |
10 | import numpy as np
11 | import random
12 | import statistics
13 | import time
14 | from typing import Dict, List, Callable, Optional
15 |
16 | from rob599 import quaternion_to_matrix
17 | from p3_helper import HoughVoting, _LABEL2MASK_THRESHOL, loss_cross_entropy, loss_Rotation, IOUselection
18 |
19 |
20 | def hello_pose_cnn():
21 | """
22 | This is a sample function that we will try to import and run to ensure that
23 | our environment is correctly set up on Google Colab.
24 | """
25 | print("Hello from pose_cnn.py!")
26 |
27 |
28 | class FeatureExtraction(nn.Module):
29 | """
30 | Feature Embedding Module for PoseCNN. Using pretrained VGG16 network as backbone.
31 | """
32 | def __init__(self, pretrained_model):
33 | super(FeatureExtraction, self).__init__()
34 | embedding_layers = list(pretrained_model.features)[:30]
35 |         ## Embedding module from the beginning of the backbone up to the first output feature map
36 | self.embedding1 = nn.Sequential(*embedding_layers[:23])
37 | ## Embedding Module from the first output feature map till the second output feature map
38 | self.embedding2 = nn.Sequential(*embedding_layers[23:])
39 |
40 | for i in [0, 2, 5, 7, 10, 12, 14]:
41 | self.embedding1[i].weight.requires_grad = False
42 | self.embedding1[i].bias.requires_grad = False
43 |
44 | def forward(self, datadict):
45 | """
46 | feature1: [bs, 512, H/8, W/8]
47 | feature2: [bs, 512, H/16, W/16]
48 | """
49 | feature1 = self.embedding1(datadict['rgb'])
50 | feature2 = self.embedding2(feature1)
51 | return feature1, feature2
52 |
53 | class SegmentationBranch(nn.Module):
54 | """
55 | Instance Segmentation Module for PoseCNN.
56 | """
57 | def __init__(self, num_classes = 10, hidden_layer_dim = 64):
58 | super(SegmentationBranch, self).__init__()
59 |
60 | self.num_classes = num_classes
61 |
62 | self.conv1_feat1 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1)
63 | self.relu1_feat1 = nn.ReLU()
64 | self.conv1_feat2 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1)
65 | self.relu1_feat2 = nn.ReLU()
66 |
67 |
68 | self.upsample_f1tof2 = nn.Upsample(scale_factor=2, mode='bilinear')
69 |
70 |         # Upsample the fused feature map back to the full input resolution
71 |         # using bilinear interpolation (scale factor 8).
72 |         self.upsample_f2tofullsize = nn.Upsample(scale_factor=8, mode='bilinear')
73 | 
74 | self.conv2 = nn.Conv2d(hidden_layer_dim, num_classes + 1, kernel_size=1)
75 | # self.relu2 = nn.ReLU()
76 | self.softmax_Seg = nn.Softmax(dim=1)
77 |
78 | # Initialize layers
79 | nn.init.kaiming_normal_(self.conv1_feat1.weight, nonlinearity='relu')
80 | nn.init.kaiming_normal_(self.conv1_feat2.weight, nonlinearity='relu')
81 | nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
82 |
83 | nn.init.zeros_(self.conv1_feat1.bias)
84 | nn.init.zeros_(self.conv1_feat2.bias)
85 | nn.init.zeros_(self.conv2.bias)
86 |
87 |
88 |
89 | def forward(self, feature1, feature2):
90 | """
91 | Args:
92 | feature1: Features from feature extraction backbone (B, 512, h, w)
93 | feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
94 | Returns:
95 | probability: Segmentation map of probability for each class at each pixel.
96 | probability size: (B,num_classes+1,H,W)
97 | segmentation: Segmentation map of class id's with highest prob at each pixel.
98 | segmentation size: (B,H,W)
99 |             bbx: Bounding boxes detected from the segmentation. Can be extracted
100 | from the predicted segmentation map using self.label2bbx(segmentation).
101 | bbx size: (N,6) with (batch_ids, x1, y1, x2, y2, cls)
102 | """
103 | probability = None
104 | segmentation = None
105 | bbx = None
106 |
107 | # Replace "pass" statement with your code
108 | feature1 = self.relu1_feat1(self.conv1_feat1(feature1))
109 | feature2 = self.relu1_feat2(self.conv1_feat2(feature2))
110 |
111 | feature2 = self.upsample_f1tof2(feature2)
112 |
113 | feature = torch.add(feature1, feature2)
114 | feature = self.upsample_f2tofullsize(feature)
115 |
116 | feature = self.conv2(feature)
117 | # feature = self.relu2(feature)
118 | probability = self.softmax_Seg(feature)
119 |
120 | segmentation = torch.argmax(probability, dim=1)
121 | bbx = self.label2bbx(segmentation)
122 |
123 | return probability, segmentation, bbx
124 |
125 | def label2bbx(self, label):
126 | bbx = []
127 | bs, H, W = label.shape
128 | device = label.device
129 | label_repeat = label.view(bs, 1, H, W).repeat(1, self.num_classes, 1, 1).to(device)
130 | label_target = torch.linspace(0, self.num_classes - 1, steps = self.num_classes).view(1, -1, 1, 1).repeat(bs, 1, H, W).to(device)
131 | mask = (label_repeat == label_target)
132 | for batch_id in range(mask.shape[0]):
133 | for cls_id in range(mask.shape[1]):
134 | if cls_id != 0:
135 | # cls_id == 0 is the background
136 | y, x = torch.where(mask[batch_id, cls_id] != 0)
137 | if y.numel() >= _LABEL2MASK_THRESHOL:
138 | bbx.append([batch_id, torch.min(x).item(), torch.min(y).item(),
139 | torch.max(x).item(), torch.max(y).item(), cls_id])
140 | bbx = torch.tensor(bbx).to(device)
141 | return bbx
142 |
143 |
144 | class TranslationBranch(nn.Module):
145 | """
146 | 3D Translation Estimation Module for PoseCNN.
147 | """
148 | def __init__(self, num_classes = 10, hidden_layer_dim = 128):
149 | super(TranslationBranch, self).__init__()
150 |
151 | # Replace "pass" statement with your code
152 | self.num_classes = num_classes
153 |
154 | self.conv1_feat1 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1)
155 | self.relu1_feat1 = nn.ReLU()
156 | self.conv1_feat2 = nn.Conv2d(512, hidden_layer_dim, kernel_size=1)
157 | self.relu1_feat2 = nn.ReLU()
158 |
159 | self.upsample_f1tof2 = nn.Upsample(scale_factor=2)
160 | self.upsample_f2tofullsize = nn.Upsample(scale_factor=8)
161 | self.conv2 = nn.Conv2d(hidden_layer_dim, num_classes*3, kernel_size=1)
162 |
163 | # Initialize layers
164 | nn.init.kaiming_normal_(self.conv1_feat1.weight, nonlinearity='relu')
165 | nn.init.kaiming_normal_(self.conv1_feat2.weight, nonlinearity='relu')
166 | nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
167 |
168 | nn.init.zeros_(self.conv1_feat1.bias)
169 | nn.init.zeros_(self.conv1_feat2.bias)
170 | nn.init.zeros_(self.conv2.bias)
171 |
172 |
173 | def forward(self, feature1, feature2):
174 | """
175 | Args:
176 | feature1: Features from feature extraction backbone (B, 512, h, w)
177 | feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
178 | Returns:
179 | translation: Map of object centroid predictions.
180 | translation size: (N,3*num_classes,H,W)
181 | """
182 | translation = None
183 |
184 | # Replace "pass" statement with your code
185 | feature1 = self.relu1_feat1(self.conv1_feat1(feature1))
186 | feature2 = self.relu1_feat2(self.conv1_feat2(feature2))
187 | feature2 = self.upsample_f1tof2(feature2)
188 | feature = torch.add(feature1, feature2)
189 | feature = self.upsample_f2tofullsize(feature)
190 |
191 | translation = self.conv2(feature)
192 |
193 | return translation
194 |
195 | class RotationBranch(nn.Module):
196 | """
197 | 3D Rotation Regression Module for PoseCNN.
198 | """
199 | def __init__(self, feature_dim = 512, roi_shape = 7, hidden_dim = 4096, num_classes = 10):
200 | super(RotationBranch, self).__init__()
201 |
202 |
203 | # Replace "pass" statement with your code
204 | self.roi_feat1 = RoIPool(output_size=roi_shape, spatial_scale=1/8)
205 | self.roi_feat2 = RoIPool(output_size=roi_shape, spatial_scale=1/16)
206 | self.fc1 = nn.Linear(feature_dim*roi_shape*roi_shape, hidden_dim)
207 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
208 | self.fc3 = nn.Linear(hidden_dim, 4 * num_classes)
209 |
210 |
211 |
212 | def forward(self, feature1, feature2, bbx):
213 | """
214 | Args:
215 | feature1: Features from feature extraction backbone (B, 512, h, w)
216 | feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
217 |             bbx: Bounding boxes of regions of interest (N, 5) with (batch_ids, x1, y1, x2, y2)
218 | Returns:
219 | quaternion: Regressed components of a quaternion for each class at each ROI.
220 | quaternion size: (N,4*num_classes)
221 | """
222 | quaternion = None
223 |
224 | # Replace "pass" statement with your code
225 |         feature1 = self.roi_feat1(feature1, bbx)  # RoI pooling on the 1/8-resolution feature map
226 |         feature2 = self.roi_feat2(feature2, bbx)  # RoI pooling on the 1/16-resolution feature map
227 |
228 | feature = torch.add(feature1, feature2)
229 | feature = feature.view(feature.shape[0], -1)
230 |
231 | feature = self.fc1.forward(feature)
232 | feature = self.fc2.forward(feature)
233 | quaternion = self.fc3.forward(feature)
234 |
235 |
236 | return quaternion
237 |
238 | class PoseCNN(nn.Module):
239 | """
240 | PoseCNN
241 | """
242 | def __init__(self, pretrained_backbone, models_pcd, cam_intrinsic):
243 | super(PoseCNN, self).__init__()
244 |
245 | self.iou_threshold = 0.7
246 | self.models_pcd = models_pcd
247 | self.cam_intrinsic = cam_intrinsic
248 |
249 |
250 | # Replace "pass" statement with your code
251 | self.feature_extraction = FeatureExtraction(pretrained_backbone)
252 | self.segmentation_branch = SegmentationBranch()
253 | self.translation_branch = TranslationBranch()
254 | self.rotation_branch = RotationBranch()
255 |
256 |
257 |
258 | def forward(self, input_dict):
259 | """
260 | input_dict = {
261 | 'rgb',
262 | 'depth',
263 | 'objs_id',
264 |             'label',
265 |             'bbx',
266 |             'RTs', 'centermaps'
267 | }
268 | """
269 |
270 |
271 | if self.training:
272 | loss_dict = {
273 | "loss_segmentation": 0,
274 | "loss_centermap": 0,
275 | "loss_R": 0
276 | }
277 |
278 | gt_bbx = self.getGTbbx(input_dict)
279 |
280 | # Important: the rotation loss should be calculated only for regions
281 | # of interest that match with a ground truth object instance.
282 | # Note that the helper function, IOUselection, may be used for
283 | # identifying the predicted regions of interest with acceptable IOU
284 | # with the ground truth bounding boxes.
285 | # If no ROIs result from the selection, don't compute the loss_R
286 |
287 | # Replace "pass" statement with your code
288 | # Feature extraction
289 | feature1, feature2 = self.feature_extraction.forward(input_dict)
290 |
291 | # loss_dict["loss_segmentation"] and calculated using the loss_cross_entropy(.) function.
292 | probability, segmentation, bbx = self.segmentation_branch.forward(feature1, feature2)
293 | loss_dict["loss_segmentation"] = loss_cross_entropy(probability, input_dict["label"])
294 |
295 | # The training loss for translation should be stored in loss_dict["loss_centermap"] using the L1loss function.
296 | translation = self.translation_branch.forward(feature1, feature2)
297 | loss_Translation = nn.L1Loss()
298 | # loss_dict["loss_centermap"] using the L1loss function.
299 | loss_dict["loss_centermap"] = loss_Translation(translation, input_dict["centermaps"])
300 |
301 | # The training loss for rotation should be stored in loss_dict["loss_R"] using the given loss_Rotation function.
302 | if bbx.numel() > 0:
303 | sel_bbx = IOUselection(bbx.to(feature1), gt_bbx, threshold=self.iou_threshold)
304 |
305 | if sel_bbx.numel() > 0:
306 | quaternion = self.rotation_branch.forward(feature1, feature2, sel_bbx[:, :-1])
307 | pred_Rs, label = self.estimateRotation(quaternion, sel_bbx)
308 | gt_Rs = self.gtRotation(sel_bbx, input_dict)
309 | loss_dict["loss_R"] = loss_Rotation(pred_Rs, gt_Rs, label, self.models_pcd)
310 |
311 |
312 |
313 | return loss_dict
314 | else:
315 | output_dict = None
316 | segmentation = None
317 |
318 | with torch.no_grad():
319 |
320 | # Replace "pass" statement with your code
321 | feature1, feature2 = self.feature_extraction.forward(input_dict)
322 | probability, segmentation, bbx = self.segmentation_branch.forward(feature1, feature2)
323 | translation = self.translation_branch.forward(feature1, feature2)
324 | pred_centers, depths = HoughVoting(segmentation, translation)
325 |
326 | quaternion = self.rotation_branch.forward(feature1, feature2, bbx[:, :-1].to(dtype=feature1.dtype))
327 | pred_Rs, _ = self.estimateRotation(quaternion, bbx)
328 |
329 |
330 | output_dict = self.generate_pose(pred_Rs, pred_centers, depths, bbx)
331 |
332 |
333 | return output_dict, segmentation
334 |
335 | def estimateTrans(self, translation_map, filter_bbx, pred_label):
336 | """
337 | translation_map: a tensor [batch_size, num_classes * 3, height, width]
338 | filter_bbx: N_filter_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
339 |         pred_label: a tensor [batch_size, height, width] of predicted class ids
340 | """
341 | N_filter_bbx = filter_bbx.shape[0]
342 | pred_Ts = torch.zeros(N_filter_bbx, 3)
343 | for idx, bbx in enumerate(filter_bbx):
344 | batch_id = int(bbx[0].item())
345 | cls = int(bbx[5].item())
346 | trans_map = translation_map[batch_id, (cls-1) * 3 : cls * 3, :]
347 | label = (pred_label[batch_id] == cls).detach()
348 | pred_T = trans_map[:, label].mean(dim=1)
349 | pred_Ts[idx] = pred_T
350 | return pred_Ts
351 |
352 | def gtTrans(self, filter_bbx, input_dict):
353 | N_filter_bbx = filter_bbx.shape[0]
354 | gt_Ts = torch.zeros(N_filter_bbx, 3)
355 | for idx, bbx in enumerate(filter_bbx):
356 | batch_id = int(bbx[0].item())
357 | cls = int(bbx[5].item())
358 | gt_Ts[idx] = input_dict['RTs'][batch_id][cls - 1][:3, [3]].T
359 | return gt_Ts
360 |
361 | def getGTbbx(self, input_dict):
362 | """
363 | bbx is N*6 (batch_ids, x1, y1, x2, y2, cls)
364 | """
365 | gt_bbx = []
366 | objs_id = input_dict['objs_id']
367 | device = objs_id.device
368 | ## [x_min, y_min, width, height]
369 | bbxes = input_dict['bbx']
370 | for batch_id in range(bbxes.shape[0]):
371 | for idx, obj_id in enumerate(objs_id[batch_id]):
372 | if obj_id.item() != 0:
373 | # the obj appears in this image
374 | bbx = bbxes[batch_id][idx]
375 | gt_bbx.append([batch_id, bbx[0].item(), bbx[1].item(),
376 | bbx[0].item() + bbx[2].item(), bbx[1].item() + bbx[3].item(), obj_id.item()])
377 | return torch.tensor(gt_bbx).to(device=device, dtype=torch.int16)
378 |
379 | def estimateRotation(self, quaternion_map, filter_bbx):
380 | """
381 |         quaternion_map: a tensor [N_filter_bbx, num_classes * 4]
382 | filter_bbx: N_filter_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
383 | """
384 | N_filter_bbx = filter_bbx.shape[0]
385 | pred_Rs = torch.zeros(N_filter_bbx, 3, 3)
386 | label = []
387 | for idx, bbx in enumerate(filter_bbx):
388 | batch_id = int(bbx[0].item())
389 | cls = int(bbx[5].item())
390 | quaternion = quaternion_map[idx, (cls-1) * 4 : cls * 4]
391 | quaternion = nn.functional.normalize(quaternion, dim=0)
392 | pred_Rs[idx] = quaternion_to_matrix(quaternion)
393 | label.append(cls)
394 | label = torch.tensor(label)
395 | return pred_Rs, label
396 |
397 | def gtRotation(self, filter_bbx, input_dict):
398 | N_filter_bbx = filter_bbx.shape[0]
399 | gt_Rs = torch.zeros(N_filter_bbx, 3, 3)
400 | for idx, bbx in enumerate(filter_bbx):
401 | batch_id = int(bbx[0].item())
402 | cls = int(bbx[5].item())
403 | gt_Rs[idx] = input_dict['RTs'][batch_id][cls - 1][:3, :3]
404 | return gt_Rs
405 |
406 | def generate_pose(self, pred_Rs, pred_centers, pred_depths, bbxs):
407 | """
408 | pred_Rs: a tensor [pred_bbx_size, 3, 3]
409 | pred_centers: [batch_size, num_classes, 2]
410 | pred_depths: a tensor [batch_size, num_classes]
411 | bbx: a tensor [pred_bbx_size, 6]
412 | """
413 | output_dict = {}
414 | for idx, bbx in enumerate(bbxs):
415 | bs, _, _, _, _, obj_id = bbx
416 | R = pred_Rs[idx].numpy()
417 | center = pred_centers[bs, obj_id - 1].numpy()
418 | depth = pred_depths[bs, obj_id - 1].numpy()
419 | if (center**2).sum().item() != 0:
420 | T = np.linalg.inv(self.cam_intrinsic) @ np.array([center[0], center[1], 1]) * depth
421 | T = T[:, np.newaxis]
422 | if bs.item() not in output_dict:
423 | output_dict[bs.item()] = {}
424 | output_dict[bs.item()][obj_id.item()] = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]])))
425 | return output_dict
426 |
427 |
428 | def eval(model, dataloader, device, alpha = 0.35):
429 | import cv2
430 | model.eval()
431 |
432 | sample_idx = random.randint(0,len(dataloader.dataset)-1)
433 | ## image version vis
434 | rgb = torch.tensor(dataloader.dataset[sample_idx]['rgb'][None, :]).to(device)
435 | inputdict = {'rgb': rgb}
436 | pose_dict, label = model(inputdict)
437 | poselist = []
438 | rgb = (rgb[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
439 | return dataloader.dataset.visualizer.vis_oneview(
440 | ipt_im = rgb,
441 | obj_pose_dict = pose_dict[0],
442 | alpha = alpha
443 | )
444 |
445 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.7.2
2 | numpy==1.24.3
3 | Pillow==11.0.0
5 | pyrender==0.1.45
6 | torch==2.3.1
7 | torchvision==0.18.1
8 | tqdm==4.66.4
9 | trimesh==4.4.3
10 |
--------------------------------------------------------------------------------
/rob599/PROPSPoseDataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import Any, Callable, Optional, Tuple
4 | import random
5 |
6 | import cv2
7 | from PIL import Image
8 |
9 | import torch
10 | import numpy as np
11 | from torch.utils.data import Dataset
12 | from torchvision.datasets.utils import download_and_extract_archive
13 |
14 | from rob599 import Visualize, chromatic_transform, add_noise
15 |
16 |
17 |
18 | class PROPSPoseDataset(Dataset):
19 |
20 | base_folder = "PROPS-Pose-Dataset"
21 | url = "https://drive.google.com/file/d/15rhwXhzHGKtBcxJAYMWJG7gN7BLLhyAq/view?usp=share_link"
22 | filename = "PROPS-Pose-Dataset.tar.gz"
23 | tgz_md5 = "a0c39fe326377dacd1d652f9fe11a7f4"
24 |
25 | def __init__(
26 | self,
27 | root: str,
28 | split: str = 'train',
29 | download: bool = False,
30 | ) -> None:
31 | assert split in ['train', 'val']
32 |
33 | self.root = root
34 | self.split = split
35 | self.dataset_dir = os.path.join(self.root, self.base_folder)
36 |
37 | if download:
38 | self.download()
39 |
40 |
41 |
42 | ## parameter
43 | self.max_instance_num = 10
44 | self.H = 480
45 | self.W = 640
46 | self.rgb_aug_prob = 0.4
47 | self.cam_intrinsic = np.array([
48 | [902.19, 0.0, 342.35],
49 | [0.0, 902.39, 252.23],
50 | [0.0, 0.0, 1.0]])
51 | self.resolution = [640, 480]
52 |
53 | self.all_lst = self.parse_dir()
54 | self.shuffle()
55 | self.models_pcd = self.parse_model()
56 |
57 |
58 | self.obj_id_list = [
59 | 1, # master chef
60 | 2, # cracker box
61 | 3, # sugar box
62 | 4, # soup can
63 | 5, # mustard bottle
64 | 6, # tuna can
65 | 8, # jello box
66 | 9, # meat can
67 | 14,# mug
68 | 18 # marker
69 | ]
70 | self.id2label = {}
71 | for idx, id in enumerate(self.obj_id_list):
72 | self.id2label[id] = idx + 1
73 |
74 | def parse_dir(self):
75 | data_dir = os.path.join(self.dataset_dir, self.split)
76 | rgb_path = os.path.join(data_dir, "rgb")
77 | depth_path = os.path.join(data_dir, "depth")
78 | mask_path = os.path.join(data_dir, "mask_visib")
79 | scene_gt_json = os.path.join(data_dir, self.split+"_gt.json")
80 | scene_gt_info_json = os.path.join(data_dir, self.split+"_gt_info.json")
81 | rgb_list = os.listdir(rgb_path)
82 | rgb_list.sort()
83 | depth_list = os.listdir(depth_path)
84 | depth_list.sort()
85 | mask_list = os.listdir(mask_path)
86 | mask_list.sort()
87 | scene_gt = json.load(open(scene_gt_json))
88 | scene_gt_info = json.load(open(scene_gt_info_json))
89 | assert len(rgb_list) == len(depth_list) == len(scene_gt) == len(scene_gt_info), "data files number mismatching"
90 | all_lst = []
91 | for rgb_file in rgb_list:
92 | idx = int(rgb_file.split(".png")[0])
93 | depth_file = f"{idx:06d}.png"
94 | scene_objs_gt = scene_gt[str(idx)]
95 | scene_objs_info_gt = scene_gt_info[str(idx)]
96 | objs_dict = {}
97 | for obj_idx in range(len(scene_objs_gt)):
98 | objs_dict[obj_idx] = {}
99 | objs_dict[obj_idx]['R'] = np.array(scene_objs_gt[obj_idx]['cam_R_m2c']).reshape(3, 3)
100 | objs_dict[obj_idx]['T'] = np.array(scene_objs_gt[obj_idx]['cam_t_m2c']).reshape(3, 1)
101 | objs_dict[obj_idx]['obj_id'] = scene_objs_gt[obj_idx]['obj_id']
102 | objs_dict[obj_idx]['bbox_visib'] = scene_objs_info_gt[obj_idx]['bbox_visib']
103 | assert f"{idx:006d}_{obj_idx:06d}.png" in mask_list
104 | objs_dict[obj_idx]['visible_mask_path'] = os.path.join(mask_path, f"{idx:006d}_{obj_idx:06d}.png")
105 | """
106 | obj_sample = (rgb_path, depth_path, objs_dict)
107 | objs_dict = {
108 | 0: {
109 |                     'R': 3x3 rotation (cam_R_m2c),
110 |                     'T': 3x1 translation (cam_t_m2c),
111 |                     'obj_id': object id,
112 |                     'bbox_visib': [x_min, y_min, width, height],
113 |                     'visible_mask_path': path to the visible-mask image,
114 | }
115 | ...
116 | }
117 | """
118 | obj_sample = (
119 | os.path.join(rgb_path, rgb_file),
120 | os.path.join(depth_path, depth_file),
121 | objs_dict
122 | )
123 | all_lst.append(obj_sample)
124 | return all_lst
125 |
126 | def parse_model(self):
127 | model_path = os.path.join(self.dataset_dir, "model")
128 | objpathdict = {
129 | 1: ["master_chef_can", os.path.join(model_path, "1_master_chef_can", "textured_simple.obj")],
130 | 2: ["cracker_box", os.path.join(model_path, "2_cracker_box", "textured_simple.obj")],
131 | 3: ["sugar_box", os.path.join(model_path, "3_sugar_box", "textured_simple.obj")],
132 | 4: ["tomato_soup_can", os.path.join(model_path, "4_tomato_soup_can", "textured_simple.obj")],
133 | 5: ["mustard_bottle", os.path.join(model_path, "5_mustard_bottle", "textured_simple.obj")],
134 | 6: ["tuna_fish_can", os.path.join(model_path, "6_tuna_fish_can", "textured_simple.obj")],
135 | 7: ["gelatin_box", os.path.join(model_path, "8_gelatin_box", "textured_simple.obj")],
136 | 8: ["potted_meat_can", os.path.join(model_path, "9_potted_meat_can", "textured_simple.obj")],
137 | 9: ["mug", os.path.join(model_path, "14_mug", "textured_simple.obj")],
138 | 10: ["large_marker", os.path.join(model_path, "18_large_marker", "textured_simple.obj")],
139 | }
140 | self.visualizer = Visualize(objpathdict, self.cam_intrinsic, self.resolution)
141 | models_pcd_dict = {index:np.array(self.visualizer.objnode[index]['mesh'].vertices) for index in self.visualizer.objnode}
142 | models_pcd = np.zeros((len(models_pcd_dict), 1024, 3))
143 | for m in models_pcd_dict:
144 | model = models_pcd_dict[m]
145 | models_pcd[m - 1] = model[np.random.randint(0, model.shape[0], 1024)]
146 | return models_pcd
147 |
148 | def __len__(self):
149 | return len(self.all_lst)
150 |
151 | def __getitem__(self, idx):
152 | """
153 | obj_sample = (rgb_path, depth_path, objs_dict)
154 | objs_dict = {
155 | 0: {
156 |                 'R': 3x3 rotation (cam_R_m2c),
157 |                 'T': 3x1 translation (cam_t_m2c),
158 |                 'obj_id': object id,
159 |                 'bbox_visib': [x_min, y_min, width, height],
160 |                 'visible_mask_path': path to the visible-mask image,
161 | }
162 | ...
163 | }
164 |
165 | data_dict = {
166 | 'rgb',
167 | 'depth',
168 | 'objs_id',
169 |             'label',
170 |             'bbx',
171 |             'RTs',
172 |             'centermaps', 'centers',
173 | }
174 | """
175 | rgb_path, depth_path, objs_dict = self.all_lst[idx]
176 | data_dict = {}
177 | with Image.open(rgb_path) as im:
178 | rgb = np.array(im)
179 |
180 | if self.split == 'train' and np.random.rand(1) > 1 - self.rgb_aug_prob:
181 | rgb = chromatic_transform(rgb)
182 | rgb = add_noise(rgb)
183 | rgb = rgb.astype(np.float32)/255
184 | data_dict['rgb'] = rgb.transpose((2,0,1))
185 |
186 | with Image.open(depth_path) as im:
187 | data_dict['depth'] = np.array(im)[np.newaxis, :]
188 | ## TODO data-augmentation of depth
189 | assert(len(objs_dict) <= self.max_instance_num)
190 | objs_id = np.zeros(self.max_instance_num, dtype=np.uint8)
191 | label = np.zeros((self.max_instance_num + 1, self.H, self.W), dtype=bool)
192 | bbx = np.zeros((self.max_instance_num, 4))
193 | RTs = np.zeros((self.max_instance_num, 3, 4))
194 | centers = np.zeros((self.max_instance_num, 2))
195 | centermaps = np.zeros((self.max_instance_num, 3, self.resolution[1], self.resolution[0]))
196 | ## test
197 | img = cv2.imread(rgb_path)
198 |
199 | for idx in objs_dict.keys():
200 | if len(objs_dict[idx]['bbox_visib']) > 0:
201 | ## have visible mask
202 | objs_id[idx] = self.id2label[objs_dict[idx]['obj_id']]
203 | assert(objs_id[idx] > 0)
204 | with Image.open(objs_dict[idx]['visible_mask_path']) as im:
205 | label[objs_id[idx]] = np.array(im, dtype=bool)
206 | ## [x_min, y_min, width, height]
207 | bbx[idx] = objs_dict[idx]['bbox_visib']
208 | RT = np.zeros((4, 4))
209 | RT[3, 3] = 1
210 | RT[:3, :3] = objs_dict[idx]['R']
211 | RT[:3, [3]] = objs_dict[idx]['T']
212 | RT = np.linalg.inv(RT)
213 | RTs[idx] = RT[:3]
214 | center_homo = self.cam_intrinsic @ RT[:3, [3]]
215 | center = center_homo[:2]/center_homo[2]
216 | x = np.linspace(0, self.resolution[0] - 1, self.resolution[0])
217 | y = np.linspace(0, self.resolution[1] - 1, self.resolution[1])
218 | xv, yv = np.meshgrid(x, y)
219 | dx, dy = center[0] - xv, center[1] - yv
220 | distance = np.sqrt(dx ** 2 + dy ** 2)
221 | nx, ny = dx / distance, dy / distance
222 | Tz = np.ones((self.resolution[1], self.resolution[0])) * RT[2, 3]
223 | centermaps[idx] = np.array([nx, ny, Tz])
224 | ## test
225 | img = cv2.circle(img, (int(center[0]), int(center[1])), radius=2, color=(0, 0, 255), thickness = -1)
226 | centers[idx] = np.array([int(center[0]), int(center[1])])
227 | label[0] = 1 - label[1:].sum(axis=0)
228 | # Image.fromarray(label[0].astype(np.uint8) * 255).save("testlabel.png")
229 | # Image.open(rgb_path).save("testrgb.png")
230 | # cv2.imwrite("testcenter.png", img)
231 | data_dict['objs_id'] = objs_id
232 | data_dict['label'] = label
233 | data_dict['bbx'] = bbx
234 | data_dict['RTs'] = RTs
235 | data_dict['centermaps'] = centermaps.reshape(-1, self.resolution[1], self.resolution[0])
236 | data_dict['centers'] = centers
237 | return data_dict
238 |
239 | def shuffle(self):
240 | random.shuffle(self.all_lst)
241 |
242 | def download(self) -> None:
243 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
244 |
--------------------------------------------------------------------------------
/rob599/__init__.py:
--------------------------------------------------------------------------------
1 | from . import grad, submit
2 | from .utils import reset_seed, tensor_to_image, visualize_dataset, chromatic_transform, add_noise, Visualize, quaternion_to_matrix, format_gt_RTs
3 | from .PROPSPoseDataset import PROPSPoseDataset
4 |
--------------------------------------------------------------------------------
/rob599/__pycache__/PROPSPoseDataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/PROPSPoseDataset.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/PROPSPoseDataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/PROPSPoseDataset.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/grad.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/grad.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/grad.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/grad.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/submit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/submit.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/rob599/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IlikeSukiyaki/PoseCNN_pytorch/e60a4dd6e32550893ed0fe660903efedf5b56c24/rob599/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/rob599/grad.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 |
5 | import rob599
6 |
7 | """ Utilities for computing and checking gradients. """
8 |
9 |
10 | def grad_check_sparse(f, x, analytic_grad, num_checks=10, h=1e-7):
11 | """
12 | Utility function to perform numeric gradient checking. We use the centered
13 | difference formula to compute a numeric derivative:
14 |
15 | f'(x) =~ (f(x + h) - f(x - h)) / (2h)
16 |
17 | Rather than computing a full numeric gradient, we sparsely sample a few
18 | dimensions along which to compute numeric derivatives.
19 |
20 | Inputs:
21 | - f: A function that inputs a torch tensor and returns a torch scalar
22 | - x: A torch tensor of the point at which to evaluate the numeric gradient
23 | - analytic_grad: A torch tensor giving the analytic gradient of f at x
24 | - num_checks: The number of dimensions along which to check
25 | - h: Step size for computing numeric derivatives
26 | """
27 | # fix random seed to 0
28 | rob599.reset_seed(0)
29 | for i in range(num_checks):
30 |
31 | ix = tuple([random.randrange(m) for m in x.shape])
32 |
33 | oldval = x[ix].item()
34 | x[ix] = oldval + h # increment by h
35 | fxph = f(x).item() # evaluate f(x + h)
36 |         x[ix] = oldval - h  # decrement by h
37 | fxmh = f(x).item() # evaluate f(x - h)
38 | x[ix] = oldval # reset
39 |
40 | grad_numerical = (fxph - fxmh) / (2 * h)
41 | grad_analytic = analytic_grad[ix]
42 | rel_error_top = abs(grad_numerical - grad_analytic)
43 | rel_error_bot = abs(grad_numerical) + abs(grad_analytic) + 1e-12
44 | rel_error = rel_error_top / rel_error_bot
45 | msg = "numerical: %f analytic: %f, relative error: %e"
46 | print(msg % (grad_numerical, grad_analytic, rel_error))
47 |
48 |
49 | def compute_numeric_gradient(f, x, dLdf=None, h=1e-7):
50 | """
51 | Compute the numeric gradient of f at x using a finite differences
52 | approximation. We use the centered difference:
53 |
54 | df f(x + h) - f(x - h)
55 | -- ~= -------------------
56 | dx 2 * h
57 |
58 | Function can also expand this easily to intermediate layers using the
59 | chain rule:
60 |
61 | dL df dL
62 | -- = -- * --
63 | dx dx df
64 |
65 | Inputs:
66 | - f: A function that inputs a torch tensor and returns a torch scalar
67 | - x: A torch tensor giving the point at which to compute the gradient
68 | - dLdf: optional upstream gradient for intermediate layers
69 | - h: epsilon used in the finite difference calculation
70 | Returns:
71 | - grad: A tensor of the same shape as x giving the gradient of f at x
72 | """
73 | flat_x = x.contiguous().flatten()
74 | grad = torch.zeros_like(x)
75 | flat_grad = grad.flatten()
76 |
77 |     # Initialize upstream gradient to be ones if not provided
78 | if dLdf is None:
79 | y = f(x)
80 | dLdf = torch.ones_like(y)
81 | dLdf = dLdf.flatten()
82 |
83 | # iterate over all indexes in x
84 | for i in range(flat_x.shape[0]):
85 | oldval = flat_x[i].item() # Store the original value
86 | flat_x[i] = oldval + h # Increment by h
87 | fxph = f(x).flatten() # Evaluate f(x + h)
88 | flat_x[i] = oldval - h # Decrement by h
89 | fxmh = f(x).flatten() # Evaluate f(x - h)
90 | flat_x[i] = oldval # Restore original value
91 |
92 | # compute the partial derivative with centered formula
93 | dfdxi = (fxph - fxmh) / (2 * h)
94 |
95 | # use chain rule to compute dLdx
96 | flat_grad[i] = dLdf.dot(dfdxi).item()
97 |
98 | # Note that since flat_grad was only a reference to grad,
99 | # we can just return the object in the shape of x by returning grad
100 | return grad
101 |
102 |
103 | def rel_error(x, y, eps=1e-10):
104 | """
105 | Compute the relative error between a pair of tensors x and y,
106 | which is defined as:
107 |
108 |                           max_i |x_i - y_i|
109 | rel_error(x, y) = -------------------------------
110 | max_i |x_i| + max_i |y_i| + eps
111 |
112 | Inputs:
113 | - x, y: Tensors of the same shape
114 | - eps: Small positive constant for numeric stability
115 |
116 | Returns:
117 | - rel_error: Scalar giving the relative error between x and y
118 | """
119 | """ returns relative error between x and y """
120 | top = (x - y).abs().max().item()
121 | bot = (x.abs() + y.abs()).clamp(min=eps).max().item()
122 | return top / bot
123 |
--------------------------------------------------------------------------------
/rob599/submit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 |
4 | _P3_FILES = [
5 | "pose_cnn.py",
6 | "pose_estimation.ipynb"
7 | ]
8 |
9 |
10 | def make_p3_submission(assignment_path, uniquename=None, umid=None):
11 | _make_submission(assignment_path, _P3_FILES, "P3", uniquename, umid)
12 |
13 |
14 | def _make_submission(
15 | assignment_path, file_list, assignment_no, uniquename=None, umid=None
16 | ):
17 | if uniquename is None or umid is None:
18 | uniquename, umid = _get_user_info()
19 | zip_path = "{}_{}_{}.zip".format(uniquename, umid, assignment_no)
20 | zip_path = os.path.join(assignment_path, zip_path)
21 | print("Writing zip file to: ", zip_path)
22 | with zipfile.ZipFile(zip_path, "w") as zf:
23 | for filename in file_list:
24 | if filename.startswith('rob599/'):
25 | filename_out = filename.split('/')[-1]
26 | else:
27 | filename_out = filename
28 | in_path = os.path.join(assignment_path, filename)
29 | if not os.path.isfile(in_path):
30 | raise ValueError('Could not find file "%s"' % filename)
31 | zf.write(in_path, filename_out)
32 |
33 |
34 | def _get_user_info():
35 | uniquename = None
36 | umid = None
37 | if uniquename is None:
38 | uniquename = input("Enter your uniquename (e.g. topipari): ")
39 | if umid is None:
40 | umid = input("Enter your umid (e.g. 12345678): ")
41 | return uniquename, umid
42 |
--------------------------------------------------------------------------------
/rob599/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 | import matplotlib as mpl
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | from torchvision.utils import make_grid
9 |
10 | from torchvision.ops import box_iou
11 | import sys, os
12 | import trimesh
13 | import pyrender
14 | import tqdm
15 |
16 | """
17 | General utilities to help with implementation
18 | """
19 |
20 |
21 | def reset_seed(number):
22 | """
23 | Reset random seed to the specific number
24 |
25 | Inputs:
26 | - number: A seed number to use
27 | """
28 | random.seed(number)
29 | np.random.seed(number)
30 | torch.manual_seed(number)
31 | return
32 |
33 |
34 | def tensor_to_image(tensor):
35 | """
36 | Convert a torch tensor into a numpy ndarray for visualization.
37 |
38 | Inputs:
39 | - tensor: A torch tensor of shape (3, H, W) with
40 | elements in the range [0, 1]
41 |
42 | Returns:
43 | - ndarr: A uint8 numpy array of shape (H, W, 3)
44 | """
45 | tensor = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
46 | ndarr = tensor.to("cpu", torch.uint8).numpy()
47 | return ndarr
48 |
49 | def format_gt_RTs(RTs):
50 | return {idx+1: np.concatenate((RTs[idx],[[0,0,0,1]])) for idx in range(len(RTs))}
51 |
52 | def visualize_dataset(pose_dataset, num_samples = 4, alpha = 0.5):
53 | """
54 | Make a grid-shape image to plot
55 |
56 | Inputs:
57 | - pose_dataset: instance of PROPSPoseDataset
58 |
59 | Outputs:
60 | - A grid-image that visualize num_samples
61 | number of image and pose label samples
62 | """
63 | plt.text(300, -40, 'RGB', ha="center")
64 | plt.text(950, -40, 'Pose', ha="center")
65 | plt.text(1600, -40, 'Depth', ha="center")
66 | plt.text(2250, -40, 'Segmentation', ha="center")
67 | plt.text(2900, -40, 'Centermaps[0]', ha="center")
68 |
69 | samples = []
70 | for sample_i in range(num_samples):
71 | sample_idx = random.randint(0,len(pose_dataset)-1)
72 | sample = pose_dataset[sample_idx]
73 | rgb = (sample['rgb'].transpose(1, 2, 0) * 255).astype(np.uint8)
74 | depth = ((np.tile(sample['depth'], (3, 1, 1)) / sample['depth'].max()) * 255).astype(np.uint8)
75 | segmentation = (sample['label']*np.arange(11).reshape((11,1,1))).sum(0,keepdims=True).astype(np.float64)
76 | segmentation /= segmentation.max()
77 | segmentation = (np.tile(segmentation, (3, 1, 1)) * 255).astype(np.uint8)
78 | ctrs = sample['centermaps'].reshape(10,3,480,640)[0]
79 | ctrs -= ctrs.min()
80 | ctrs /= ctrs.max()
81 | ctrs = (ctrs * 255).astype(np.uint8)
82 | pose_dict = format_gt_RTs(sample['RTs'])
83 | render = pose_dataset.visualizer.vis_oneview(
84 | ipt_im = rgb,
85 | obj_pose_dict = pose_dict,
86 | alpha = alpha
87 | )
88 | samples.append(torch.tensor(rgb.transpose(2, 0, 1)))
89 | samples.append(torch.tensor(render.transpose(2, 0, 1)))
90 | samples.append(torch.tensor(depth))
91 | samples.append(torch.tensor(segmentation))
92 | samples.append(torch.tensor(ctrs))
93 | img = make_grid(samples, nrow=5).permute(1, 2, 0)
94 | return img
95 |
96 |
97 |
98 | def chromatic_transform(image):
99 | """
100 | Add the hue, saturation and luminosity to the image.
101 |
102 | This is adapted from implicit-depth repository, ref: https://github.com/NVlabs/implicit_depth/blob/main/src/utils/data_augmentation.py
103 |
104 | Parameters
105 | ----------
106 |
107 | image: array, required, the given image.
108 |
109 | Returns
110 | -------
111 |
112 | The new image after augmentation in HLS space.
113 | """
114 |     # Randomly shift hue (up to ±5% of the hue range) and luminosity/saturation (up to ±10% of their ranges)
115 | d_h = (np.random.rand(1) - 0.5) * 0.1 * 180
116 | d_l = (np.random.rand(1) - 0.5) * 0.2 * 256
117 | d_s = (np.random.rand(1) - 0.5) * 0.2 * 256
118 | # Convert the BGR to HLS
119 | hls = cv2.cvtColor(image, cv2.COLOR_BGR2HLS)
120 | h, l, s = cv2.split(hls)
121 | # Add the values to the image H, L, S
122 | new_h = (h + d_h) % 180
123 | new_l = np.clip(l + d_l, 0, 255)
124 | new_s = np.clip(s + d_s, 0, 255)
125 | # Convert the HLS to BGR
126 | new_hls = cv2.merge((new_h, new_l, new_s)).astype('uint8')
127 | new_image = cv2.cvtColor(new_hls, cv2.COLOR_HLS2BGR)
128 | return new_image
129 |
130 |
131 |
132 | def add_noise(image, level = 0.1):
133 | """
134 | Add noise to the image.
135 |
136 | This is adapted from the implicit_depth repository, ref: https://github.com/NVlabs/implicit_depth/blob/main/src/utils/data_augmentation.py
137 |
138 | Parameters
139 | ----------
140 |
141 | image: array, required, the given image;
142 |
143 | level: float, optional, default: 0.1, the maximum noise level.
144 |
145 | Returns
146 | -------
147 |
148 | The new image after noise augmentation (Gaussian noise or motion blur).
149 | """
150 | # Draw a random number to choose between Gaussian noise (~90%) and motion blur (~10%)
151 | r = np.random.rand(1)
152 |
153 | # gaussian noise
154 | if r < 0.9:
155 | row,col,ch= image.shape
156 | mean = 0
157 | noise_level = random.uniform(0, level)
158 | sigma = np.random.rand(1) * noise_level * 256
159 | gauss = sigma * np.random.randn(row,col) + mean
160 | gauss = np.repeat(gauss[:, :, np.newaxis], ch, axis=2)
161 | noisy = image + gauss
162 | noisy = np.clip(noisy, 0, 255)
163 | else:
164 | # motion blur
165 | sizes = [3, 5, 7, 9, 11, 15]
166 | size = sizes[int(np.random.randint(len(sizes), size=1))]
167 | kernel_motion_blur = np.zeros((size, size))
168 | if np.random.rand(1) < 0.5:
169 | kernel_motion_blur[int((size-1)/2), :] = np.ones(size)
170 | else:
171 | kernel_motion_blur[:, int((size-1)/2)] = np.ones(size)
172 | kernel_motion_blur = kernel_motion_blur / size
173 | noisy = cv2.filter2D(image, -1, kernel_motion_blur)
174 |
175 | return noisy.astype('uint8')
176 |
177 |
178 | class Visualize:
179 | def __init__(self, object_dict, cam_intrinsic, resolution):
180 | '''
181 | object_dict is a dict that maps object labels to object names and object model paths,
182 | example:
183 | object_dict = {
184 | 1: ["beaker_1", path],
185 | 2: ["dropper_1", path],
186 | 3: ["dropper_2", path]
187 | }
188 | '''
189 | self.objnode = {}
190 | self.render = pyrender.OffscreenRenderer(resolution[0], resolution[1])
191 | self.scene = pyrender.Scene()
192 | cam = pyrender.camera.IntrinsicsCamera(cam_intrinsic[0, 0],
193 | cam_intrinsic[1, 1],
194 | cam_intrinsic[0, 2],
195 | cam_intrinsic[1, 2],
196 | znear=0.05, zfar=100.0, name=None)
197 | self.intrinsic = cam_intrinsic
198 | Axis_align = np.array([[1, 0, 0, 0],
199 | [0, -1, 0, 0],
200 | [0, 0, -1, 0],
201 | [0, 0, 0, 1]])
202 | self.nc = pyrender.Node(camera=cam, matrix=Axis_align)
203 | self.scene.add_node(self.nc)
204 |
205 | for obj_label in object_dict:
206 | objname = object_dict[obj_label][0]
207 | objpath = object_dict[obj_label][1]
208 | tm = trimesh.load(objpath)
209 | mesh = pyrender.Mesh.from_trimesh(tm, smooth = False)
210 | node = pyrender.Node(mesh=mesh, matrix=np.eye(4))
211 | node.mesh.is_visible = False
212 | self.objnode[obj_label] = {"name":objname, "node":node, "mesh":tm}
213 | self.scene.add_node(node)
214 | self.cmp = self.color_map(N=len(object_dict))
215 | self.object_dict = object_dict
216 |
217 | def vis_oneview(self, ipt_im, obj_pose_dict, alpha = 0.5, axis_len=30):
218 | '''
219 | Input:
220 | ipt_im: numpy [H, W, 3]
221 | input image
222 | obj_pose_dict:
223 | a dict that stores object poses within the input image
224 | example:
225 | poselist = {
226 | 15: numpy_pose 4X4,
227 | 37: numpy_pose 4X4,
228 | 39: numpy_pose 4X4,
229 | }
230 | alpha: float [0,1]
231 | blending factor for the labels' colormap overlaid on the image
232 | axis_len: int
233 | pixel length of the drawn pose axes
234 | '''
235 | img = ipt_im.copy()
236 | for obj_label in obj_pose_dict:
237 | if obj_label in self.object_dict:
238 | pose = obj_pose_dict[obj_label]
239 | node = self.objnode[obj_label]['node']
240 | node.mesh.is_visible = True
241 | self.scene.set_pose(node, pose=pose)
242 | full_depth = self.render.render(self.scene, flags = pyrender.constants.RenderFlags.DEPTH_ONLY)
243 | # Hide every object mesh again; each posed object is then rendered
244 | # one at a time below to compute its visibility mask.
245 | for obj_label in self.object_dict:
246 | node = self.objnode[obj_label]['node']
247 | node.mesh.is_visible = False
248 | # Overlay each object's mask color and pose axes on the image,
249 | # using the full-scene depth to keep only unoccluded pixels.
250 | for obj_label in obj_pose_dict:
251 | if obj_label in self.object_dict:
252 | node = self.objnode[obj_label]['node']
253 | node.mesh.is_visible = True
254 | depth = self.render.render(self.scene, flags = pyrender.constants.RenderFlags.DEPTH_ONLY)
255 | node.mesh.is_visible = False
256 | mask = np.logical_and(
257 | (np.abs(depth - full_depth) < 1e-6), np.abs(full_depth) > 0.2
258 | )
259 | if np.sum(mask) > 0:
260 | color = self.cmp[obj_label - 1]
261 | img[mask, :] = alpha * img[mask, :] + (1 - alpha) * color[:]
262 | obj_pose = obj_pose_dict[obj_label]
263 | obj_center = self.project2d(self.intrinsic, obj_pose[:3, -1])
264 | rgb_colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
265 | for j in range(3):
266 | obj_xyz_offset_2d = self.project2d(self.intrinsic, obj_pose[:3, -1] + obj_pose[:3, j] * 0.001)
267 | obj_axis_endpoint = obj_center + (obj_xyz_offset_2d - obj_center) / np.linalg.norm(obj_xyz_offset_2d - obj_center) * axis_len
268 | cv2.arrowedLine(img, (int(obj_center[0]), int(obj_center[1])), (int(obj_axis_endpoint[0]), int(obj_axis_endpoint[1])), rgb_colors[j], thickness=2, tipLength=0.15)
269 | return img
270 |
271 | def color_map(self, N=256, normalized=False):
272 | def bitget(byteval, idx):
273 | return ((byteval & (1 << idx)) != 0)
274 | dtype = 'float32' if normalized else 'uint8'
275 | cmap = np.zeros((N, 3), dtype=dtype)
276 | for i in range(N):
277 | r = g = b = 0
278 | c = i
279 | for j in range(8):
280 | r = r | (bitget(c, 0) << 7-j)
281 | g = g | (bitget(c, 1) << 7-j)
282 | b = b | (bitget(c, 2) << 7-j)
283 | c = c >> 3
284 | cmap[i] = np.array([r, g, b])
285 | cmap = cmap/255 if normalized else cmap
286 | return cmap
287 |
288 | def project2d(self, intrinsic, point3d):
289 | return (intrinsic @ (point3d / point3d[2]))[:2]
290 |
291 | def quaternion_to_matrix(quaternions):
292 | """
293 | Convert rotations given as quaternions to rotation matrices.
294 |
295 | Args:
296 | quaternions: quaternions with real part first,
297 | as tensor of shape (..., 4).
298 |
299 | Returns:
300 | Rotation matrices as tensor of shape (..., 3, 3).
301 | """
302 | r, i, j, k = torch.unbind(quaternions, -1)
303 | # Scale factor; dividing by the squared norm also handles non-unit quaternions.
304 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
305 |
306 | o = torch.stack(
307 | (
308 | 1 - two_s * (j * j + k * k),
309 | two_s * (i * j - k * r),
310 | two_s * (i * k + j * r),
311 | two_s * (i * j + k * r),
312 | 1 - two_s * (i * i + k * k),
313 | two_s * (j * k - i * r),
314 | two_s * (i * k - j * r),
315 | two_s * (j * k + i * r),
316 | 1 - two_s * (i * i + j * j),
317 | ),
318 | -1,
319 | )
320 | return o.reshape(quaternions.shape[:-1] + (3, 3))
321 |
322 |
--------------------------------------------------------------------------------
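The utilities above are easiest to trust with a quick sanity check. The sketch below is not part of the repository: it exercises `quaternion_to_matrix` on two well-known quaternions and confirms that the augmentation helpers return `uint8` images of the original shape. It assumes the functions are importable as `rob599.utils`, which is where this file sits in the tree.

```python
# Hedged sanity-check sketch for the helpers in rob599/utils.py.
# The import path rob599.utils is an assumption based on the repo layout.
import numpy as np
import torch

from rob599.utils import add_noise, chromatic_transform, quaternion_to_matrix

# The identity quaternion (real part first) should map to the identity matrix.
q_identity = torch.tensor([1.0, 0.0, 0.0, 0.0])
assert torch.allclose(quaternion_to_matrix(q_identity), torch.eye(3), atol=1e-6)

# A 90-degree rotation about z: q = (cos 45deg, 0, 0, sin 45deg).
q_z90 = torch.tensor([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)], dtype=torch.float32)
R_expected = torch.tensor([[0.0, -1.0, 0.0],
                           [1.0,  0.0, 0.0],
                           [0.0,  0.0, 1.0]])
assert torch.allclose(quaternion_to_matrix(q_z90), R_expected, atol=1e-6)

# The augmentation helpers take and return uint8 BGR images of the same shape.
dummy = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
assert chromatic_transform(dummy).dtype == np.uint8
assert add_noise(dummy, level=0.1).shape == dummy.shape
```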
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | os.environ["TZ"] = "US/Eastern"
5 | time.tzset()
6 |
7 | import multiprocessing
8 | 
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | import torchvision.models as models
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 | 
16 | from p3_helper import hello_helper
17 | from pose_cnn import hello_pose_cnn
18 | from pose_cnn import FeatureExtraction, SegmentationBranch
19 | from pose_cnn import PoseCNN
20 | from rob599 import PROPSPoseDataset
21 | from rob599 import reset_seed, visualize_dataset
22 | from rob599.grad import rel_error
23 | 
24 | 
25 | 
26 | 
27 |
28 | # Set a few constants related to data loading.
29 | NUM_CLASSES = 10
30 | BATCH_SIZE = 4
31 | NUM_WORKERS = multiprocessing.cpu_count()
32 |
33 | # Ensure helper functions run correctly
34 | hello_pose_cnn()
35 | hello_helper()
36 |
37 | # Check last modification time of pose_cnn.py
38 | pose_cnn_path = os.path.join("/home/yifeng/PycharmProjects/TestEnv/PoseCNN/pose_cnn.py")
39 | pose_cnn_edit_time = time.ctime(os.path.getmtime(pose_cnn_path))
40 | print("pose_cnn.py last edited on %s" % pose_cnn_edit_time)
41 |
42 | # Set up matplotlib plotting parameters
43 | plt.rcParams["figure.figsize"] = (10.0, 8.0) # set default size of plots
44 | plt.rcParams["font.size"] = 16
45 | plt.rcParams["image.interpolation"] = "nearest"
46 | plt.rcParams["image.cmap"] = "gray"
47 |
48 | # Check for CUDA availability
49 | if torch.cuda.is_available():
50 | print("Good to go!")
51 | DEVICE = torch.device("cuda")
52 | else:
53 | DEVICE = torch.device("cpu")
54 |
55 | # -------------------------- Dataset Preparation --------------------------------
56 | # NOTE: Set `download=True` for the first time to download the dataset.
57 | # After downloading, set `download=False` for faster execution.
58 |
59 | data_root = "/home/yifeng/PycharmProjects/TestEnv/PoseCNN"
60 | # Prepare train dataset
61 | train_dataset = PROPSPoseDataset(
62 | root=data_root,
63 | split="train",
64 | download=False # Change to True for the first-time download
65 | )
66 |
67 | # Prepare validation dataset
68 | val_dataset = PROPSPoseDataset(
69 | root=data_root,
70 | split="val",
71 | download=False
72 | )
73 |
74 | # Print dataset sizes
75 | print(f"Dataset sizes: train ({len(train_dataset)}), val ({len(val_dataset)})")
76 |
77 |
78 | reset_seed(0)
79 |
80 | grid_vis = visualize_dataset(val_dataset,alpha = 0.25)
81 | plt.axis('off')
82 | plt.imshow(grid_vis)
83 | plt.show()
84 |
85 | reset_seed(0)
86 |
87 | dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
88 |
89 | vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
90 | posecnn_model = PoseCNN(pretrained_backbone = vgg16,
91 | models_pcd = torch.tensor(train_dataset.models_pcd).to(DEVICE, dtype=torch.float32),
92 | cam_intrinsic = train_dataset.cam_intrinsic).to(DEVICE)
93 | posecnn_model.train()
94 |
95 | optimizer = torch.optim.Adam(posecnn_model.parameters(), lr=0.001,
96 | betas=(0.9, 0.999))
97 |
98 |
99 | loss_history = []
100 | log_period = 5
101 | _iter = 0
102 |
103 |
104 | st_time = time.time()
105 | for epoch in range(10):
106 | train_loss = []
107 | dataloader.dataset.dataset_type = 'train'
108 | for batch in dataloader:
109 | for item in batch:
110 | batch[item] = batch[item].to(DEVICE)
111 | loss_dict = posecnn_model(batch)
112 | optimizer.zero_grad()
113 | total_loss = 0
114 | for loss in loss_dict:
115 | total_loss += loss_dict[loss]
116 | total_loss.backward()
117 | optimizer.step()
118 | train_loss.append(total_loss.item())
119 |
120 | if _iter % log_period == 0:
121 | loss_str = f"[Iter {_iter}][loss: {total_loss:.3f}]"
122 | for key, value in loss_dict.items():
123 | loss_str += f"[{key}: {value:.3f}]"
124 |
125 | print(loss_str)
126 | loss_history.append(total_loss.item())
127 | _iter += 1
128 |
129 | print('Time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + \
130 | ', ' + 'Epoch %02d' % epoch + ', Training finished' + f', with mean training loss {np.array(train_loss).mean():.4f}'))
131 |
132 | torch.save(posecnn_model.state_dict(), os.path.join("your path here", "posecnn_model.pth"))  # replace "your path here" with your checkpoint directory
133 |
134 | plt.title("Training loss history")
135 | plt.xlabel(f"Iteration (x {log_period})")
136 | plt.ylabel("Loss")
137 | plt.plot(loss_history)
138 | plt.show()
--------------------------------------------------------------------------------
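Because `train.py` saves only the model's `state_dict`, reusing the checkpoint (for example from `inference.py`) means rebuilding `PoseCNN` with the same constructor arguments first. The sketch below mirrors the construction in `train.py`; the checkpoint directory is a placeholder, and it assumes the `val` split exposes the same `models_pcd` and `cam_intrinsic` attributes as the `train` split.

```python
# Hypothetical reload sketch: rebuild PoseCNN as in train.py, then
# restore the saved weights. Paths below are placeholders.
import torch
import torchvision.models as models

from pose_cnn import PoseCNN
from rob599 import PROPSPoseDataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_root = "."  # directory that contains PROPS-Pose-Dataset/

val_dataset = PROPSPoseDataset(root=data_root, split="val", download=False)

vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
posecnn_model = PoseCNN(
    pretrained_backbone=vgg16,
    models_pcd=torch.tensor(val_dataset.models_pcd).to(DEVICE, dtype=torch.float32),
    cam_intrinsic=val_dataset.cam_intrinsic,
).to(DEVICE)

# "your path here" mirrors the placeholder used in train.py.
state_dict = torch.load("your path here/posecnn_model.pth", map_location=DEVICE)
posecnn_model.load_state_dict(state_dict)
posecnn_model.eval()
```

From here, batches from a `DataLoader` over `val_dataset` can be moved to `DEVICE` and passed to the model just as the training loop does; the exact output format in eval mode depends on PoseCNN's forward implementation in `pose_cnn.py`.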