├── src ├── sfm │ ├── __init__.py │ ├── loss_functions.py │ ├── model.py │ ├── utils.py │ ├── datasets.py │ ├── custom_transforms.py │ ├── train_sfm.py │ └── inverse_warp.py ├── segmentation │ ├── __init__.py │ ├── utils.py │ ├── model.py │ └── train_segmentation_model.py ├── reconstruction_utils.py ├── video_utils.py └── reconstruct.py ├── .gitattributes ├── .DS_Store ├── sfm_net.pth ├── segmentation_net.pth ├── example_inputs ├── intrinsics_eucm.json ├── transect_single_video_gopro │ └── single_video.sh ├── transect_two_videos_gopro │ └── two_video.sh ├── transect_single_video_other_camera │ └── single_video.sh ├── class_to_label.json └── class_to_color.json ├── pyproject.toml ├── Dockerfile ├── .github └── workflows │ └── publish-container.yml ├── .gitignore └── README.md /src/sfm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josauder/mee-deepreefmap/HEAD/.DS_Store -------------------------------------------------------------------------------- /sfm_net.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fc9ca83ea551c7ea007f7126ee0b70faa8fbc5c4f694d8d34806b8888d912887 3 | size 128762663 4 | -------------------------------------------------------------------------------- /segmentation_net.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:75c8f9b8bbbdd3af287448dff28ae4ef512fb8282b3ba3b786faa634e97a1f80 3 | size 105049648 4 | -------------------------------------------------------------------------------- /example_inputs/intrinsics_eucm.json: -------------------------------------------------------------------------------- 1 | { 2 | "fx": 365.383616615, 3 | "fy": 383.145828595, 4 | "cx": 319.7500, 5 | "cy": 191.7500, 6 | "alpha": -0.6, 7 | "beta":0.915 8 | } -------------------------------------------------------------------------------- /example_inputs/transect_single_video_gopro/single_video.sh: -------------------------------------------------------------------------------- 1 | python3 reconstruct.py \ 2 | --input_video=../../example_data/input_videos/GX_SINGLE_VIDEO.MP4 \ 3 | --timestamp=0-367 \ 4 | --out_dir=../../out_test \ 5 | --fps=10 -------------------------------------------------------------------------------- /example_inputs/transect_two_videos_gopro/two_video.sh: -------------------------------------------------------------------------------- 1 | python3 reconstruct.py \ 2 | --input_video=../../example_data/input_videos/GX_VIDEO_1_OF_2.MP4,../../example_data/input_videos/GX_VIDEO_2_OF_2.MP4 \ 3 | --timestamp=310-end,begin-100 \ 4 | --out_dir=../../out_test \ 5 | --fps=10 -------------------------------------------------------------------------------- /src/sfm/loss_functions.py: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | def l2_pose_regularization(poses): 5 | l2loss = [] 6 | for pose in poses: 7 | for p in pose: 8 | if len(p)>0: 9 | l2loss.append((p[0]**2).mean()) 10 | return sum(l2loss) / len(l2loss) 11 | -------------------------------------------------------------------------------- /example_inputs/transect_single_video_other_camera/single_video.sh: -------------------------------------------------------------------------------- 1 | python3 reconstruct.py \ 2 | --input_video=../../example_data/input_videos/OTHER_CAMERA_VID.MP4 \ 3 | --timestamp=10-120 \ 4 | --out_dir=../../out_test \ 5 | --intrinsics_file=../../example_data/transect_single_video_other_camera/intrinsics.json \ 6 | --fps=10 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mee-deepreefmap" 3 | version = "1.1.0" 4 | description = "Scalable 3D Semantic Mapping of Coral Reefs using Deep Learning" 5 | readme = "README.md" 6 | requires-python = ">=3.10, <3.12" 7 | authors = ["Jonathan Sauder "] 8 | dependencies = [ 9 | "addict==2.4.0", 10 | "h5py==3.7.0", 11 | "matplotlib==3.5.3", 12 | "numpy==1.26.4", 13 | "open3d==0.16.0", 14 | "pandas==1.4.4", 15 | "pillow==9.1", 16 | "scikit-image==0.24.0", 17 | "scikit-learn==1.5.2", 18 | "scipy==1.14.1", 19 | "segmentation-models-pytorch==0.3.3", 20 | "torch==2.0.1", 21 | "torchaudio==2.0.2", 22 | "torchvision==0.15.2", 23 | "tqdm==4.64.1", 24 | "wandb==0.13.7", 25 | ] 26 | -------------------------------------------------------------------------------- /example_inputs/class_to_label.json: -------------------------------------------------------------------------------- 1 | {"human": 7, "background": 13, "fish": 9, "sand": 5, "rubble": 18, "unknown hard substrate": 12, "algae covered substrate": 10, "dark": 14, "branching bleached": 19, "branching dead": 20, "branching alive": 22, "stylophora alive": 34, "pocillopora alive": 31, "acropora alive": 25, "table acropora alive": 28, "table acropora dead": 32, "millepora": 21, "turbinaria": 27, "other coral bleached": 4, "other coral dead": 3, "other coral alive": 6, "massive/meandering alive": 17, "massive/meandering dead": 23, "massive/meandering bleached": 16, "meandering alive": 36, "meandering dead": 37, "meandering bleached": 33, "transect line": 15, "transect tools": 8, "sea urchin": 35, "sea cucumber": 26, "anemone": 30, "sponge": 29, "clam": 24, "other animal": 11, "trash": 2, "seagrass": 1} -------------------------------------------------------------------------------- /example_inputs/class_to_color.json: -------------------------------------------------------------------------------- 1 | {"human": [255, 0, 0], "background": [29, 162, 216], "fish": [255, 255, 0], "sand": [194, 178, 128], "rubble": [161, 153, 128], "unknown hard substrate": [125, 125, 125], "algae covered substrate": [125, 163, 125], "dark": [31, 31, 31], "branching bleached": [252, 231, 240], "branching dead": [123, 50, 86], "branching alive": [226, 91, 157], "stylophora alive": [255, 111, 194], "pocillopora alive": [255, 146, 150], "acropora alive": [236, 128, 255], "table acropora alive": [189, 119, 255], "table acropora dead": [85, 53, 116], "millepora": [244, 150, 115], "turbinaria": [228, 255, 119], "other coral bleached": [250, 224, 225], "other coral dead": [114, 60, 61], "other coral alive": [224, 118, 119], "massive/meandering 
alive": [236, 150, 21], "massive/meandering dead": [134, 86, 18], "massive/meandering bleached": [255, 248, 228], "meandering alive": [230, 193, 0], "meandering dead": [119, 100, 14], "meandering bleached": [251, 243, 216], "transect line": [0, 255, 0], "transect tools": [8, 205, 12], "sea urchin": [0, 142, 255], "sea cucumber": [0, 231, 255], "anemone": [0, 255, 189], "sponge": [240, 80, 80], "clam": [189, 255, 234], "other animal": [0, 255, 255], "trash": [255, 0, 134], "seagrass": [125, 222, 125]} -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-base-ubuntu22.04 2 | 3 | # Set commit hash for gpmfstream repo as there are no versions 4 | ARG GPMFSTREAM_GIT_HASH=f1a9742 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | 7 | # # Get the gpmfstream library 8 | RUN apt-get update && \ 9 | apt-get install -y git ffmpeg pipx python3-pip \ 10 | && pip3 install uv 11 | 12 | # RUN pip3 install poetry 13 | WORKDIR /app 14 | 15 | # Change to non-root user 16 | RUN groupadd mygroup --gid 1000 && \ 17 | useradd -m -U -s /bin/bash -G mygroup -u 1000 myuser && \ 18 | chown -R 1000:1000 /app && \ 19 | chmod -R 755 /app && \ 20 | mkdir /output /input && \ 21 | chown -R 1000:1000 /output /input /tmp && \ 22 | chmod -R o+w /input /output 23 | 24 | # Copy the model checkpoints and environment 25 | COPY --chown=1000:1000 segmentation_net.pth sfm_net.pth uv.lock pyproject.toml /app/ 26 | 27 | # Build gpmfstream 28 | RUN git clone https://github.com/hovren/gpmfstream.git 29 | WORKDIR /app/gpmfstream 30 | RUN pip3 install pybind11 setuptools 31 | RUN git checkout ${GPMFSTREAM_GIT_HASH} \ 32 | && git submodule update --init 33 | RUN python3 setup.py bdist_wheel 34 | 35 | # Install dependencies 36 | WORKDIR /app 37 | RUN uv sync --no-dev --no-cache 38 | RUN uv pip install "./gpmfstream/dist/gpmfstream-0.5-cp310-cp310-linux_x86_64.whl" 39 | 40 | COPY --chown=1000:1000 src /app/src 41 | COPY --chown=1000:1000 example_inputs /app/example_inputs 42 | 43 | WORKDIR /app/src 44 | USER 1000 45 | 46 | ENTRYPOINT ["uv", "run", "python3", "reconstruct.py"] 47 | -------------------------------------------------------------------------------- /src/sfm/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import segmentation_models_pytorch as smp 4 | from .utils import get_depths_and_poses 5 | 6 | 7 | class SfMModel(nn.Module): 8 | def __init__(self,in_channels=3): 9 | super().__init__() 10 | 11 | self.depth_net = smp.DeepLabV3Plus(in_channels=in_channels, encoder_name="resnext50_32x4d", encoder_weights='swsl', activation=None) 12 | self.pose_reduction = nn.Sequential( 13 | nn.Conv2d(2048, 512, (1, 1)), nn.ReLU(), nn.BatchNorm2d(512), 14 | ) 15 | self.squeeze_unsqueeze = nn.Sequential( 16 | nn.Conv2d(4096, 512, (1, 1)), nn.ReLU(), nn.BatchNorm2d(512), 17 | nn.Conv2d(512, 512, (1, 1)), nn.ReLU(), nn.BatchNorm2d(512), 18 | nn.Conv2d(512, 2048, (1, 1)), nn.ReLU(), nn.BatchNorm2d(2048), 19 | ) 20 | self.pose_decoder = nn.Sequential( 21 | nn.Conv2d(1024, 256, (1, 1)), nn.ReLU(), nn.BatchNorm2d(256), 22 | nn.Conv2d(256, 256, (3, 3)), nn.ReLU(), nn.BatchNorm2d(256), 23 | nn.Conv2d(256, 256, (3, 3)), nn.ReLU(), nn.BatchNorm2d(256), 24 | nn.Conv2d(256, 6, (3, 3), bias=False), 25 | ) 26 | 27 | def extract_features(self, x): 28 | return self.depth_net.encoder(x) 29 | 30 | def 
get_depth_and_poses_from_features(self, images, features, intrinsics): 31 | depth, pose = get_depths_and_poses( 32 | self.depth_net.encoder, 33 | self.depth_net.segmentation_head, 34 | self.depth_net.decoder, 35 | self.pose_decoder, 36 | torch.stack(images).transpose(1,0), 37 | [torch.stack([features[0][flen]] + [feature[flen] for feature in features]).squeeze() for flen in range(len(features[0]))], 38 | self.pose_reduction, 39 | self.squeeze_unsqueeze, 40 | ) 41 | depth = (1 / (25 * torch.sigmoid(depth) + 0.1)) 42 | return depth.squeeze(), pose.squeeze(), intrinsics.repeat(len(images),1) 43 | -------------------------------------------------------------------------------- /.github/workflows/publish-container.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # This workflow builds a docker image and pushes it 4 | # to the github package registry 5 | 6 | on: 7 | push: 8 | 9 | env: 10 | # Use docker.io for Docker Hub if empty 11 | REGISTRY: ghcr.io 12 | # github.repository as / 13 | IMAGE_NAME: ${{ github.repository }} 14 | 15 | 16 | jobs: 17 | build: 18 | runs-on: ubuntu-latest 19 | permissions: 20 | contents: read 21 | packages: write 22 | # This is used to complete the identity challenge 23 | # with sigstore/fulcio when running outside of PRs. 24 | id-token: write 25 | outputs: 26 | # The digest of the built image. 27 | digest: ${{ steps.build-and-push.outputs.digest }} 28 | 29 | steps: 30 | - name: Checkout repository 31 | uses: actions/checkout@v3 32 | with: 33 | lfs: 'true' 34 | 35 | # Install the cosign tool 36 | # https://github.com/sigstore/cosign-installer 37 | - name: Install cosign 38 | uses: sigstore/cosign-installer@v3.3.0 39 | with: 40 | cosign-release: 'v2.2.2' 41 | 42 | # Workaround: https://github.com/docker/build-push-action/issues/461 43 | - name: Setup Docker buildx 44 | uses: docker/setup-buildx-action@79abd3f86f79a9d68a23c75a09a9a85889262adf 45 | 46 | # Login against a Docker registry 47 | # https://github.com/docker/login-action 48 | - name: Log into registry ${{ env.REGISTRY }} 49 | uses: docker/login-action@28218f9b04b4f3f62068d7b6ce6ca5b26e35336c 50 | with: 51 | registry: ${{ env.REGISTRY }} 52 | username: ${{ github.actor }} 53 | password: ${{ secrets.GITHUB_TOKEN }} 54 | 55 | # Extract metadata (tags, labels) for Docker 56 | # https://github.com/docker/metadata-action 57 | - name: Extract Docker metadata 58 | id: meta 59 | uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 60 | with: 61 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 62 | 63 | # Build and push Docker image with Buildx 64 | # https://github.com/docker/build-push-action 65 | - name: Build and push Docker image 66 | id: build-and-push 67 | uses: docker/build-push-action@ac9327eae2b366085ac7f6a2d02df8aa8ead720a 68 | with: 69 | context: . 70 | push: true 71 | tags: ${{ steps.meta.outputs.tags }} 72 | labels: ${{ steps.meta.outputs.labels }} 73 | cache-from: type=gha 74 | cache-to: type=gha,mode=max 75 | 76 | 77 | # Sign the resulting Docker image digest. 78 | # This will only write to the public Rekor transparency log when the Docker 79 | # repository is public to avoid leaking data. If you would like to publish 80 | # transparency data even for private images, pass --force to cosign below. 
81 | # https://github.com/sigstore/cosign 82 | - name: Sign the published Docker image 83 | env: 84 | COSIGN_EXPERIMENTAL: "true" 85 | # This step uses the identity token to provision an ephemeral certificate 86 | # against the sigstore community Fulcio instance. 87 | run: echo "${{ steps.meta.outputs.tags }}" | xargs -I {} cosign sign --yes {}@${{ steps.build-and-push.outputs.digest }} 88 | -------------------------------------------------------------------------------- /src/sfm/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import cm 3 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 4 | from torch import nn 5 | import torch 6 | 7 | 8 | def high_res_colormap(low_res_cmap, resolution=1000, max_value=1): 9 | # Construct the list colormap, with interpolated values for higer resolution 10 | # For a linear segmented colormap, you can just specify the number of point in 11 | # cm.get_cmap(name, lutsize) with the parameter lutsize 12 | x = np.linspace(0, 1, low_res_cmap.N) 13 | low_res = low_res_cmap(x) 14 | new_x = np.linspace(0, max_value, resolution) 15 | high_res = np.stack([np.interp(new_x, x, low_res[:, i]) 16 | for i in range(low_res.shape[1])], axis=1) 17 | return ListedColormap(high_res) 18 | 19 | 20 | def opencv_rainbow(resolution=1000): 21 | # Construct the opencv equivalent of Rainbow 22 | opencv_rainbow_data = ( 23 | (0.000, (1.00, 0.00, 0.00)), 24 | (0.400, (1.00, 1.00, 0.00)), 25 | (0.600, (0.00, 1.00, 0.00)), 26 | (0.800, (0.00, 0.00, 1.00)), 27 | (1.000, (0.60, 0.00, 1.00)) 28 | ) 29 | 30 | return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution) 31 | 32 | 33 | COLORMAPS = {'rainbow': opencv_rainbow(), 34 | 'magma': high_res_colormap(cm.get_cmap('magma')), 35 | 'bone': cm.get_cmap('bone', 10000)} 36 | 37 | 38 | def tensor2array(tensor, max_value=None, colormap='rainbow'): 39 | tensor = tensor.detach().cpu() 40 | if max_value is None: 41 | max_value = tensor.max().item() 42 | if tensor.ndimension() == 2 or tensor.size(0) == 1: 43 | norm_array = tensor.squeeze().numpy()/max_value 44 | array = COLORMAPS[colormap](norm_array).astype(np.float32) 45 | array = array.transpose(2, 0, 1) 46 | 47 | elif tensor.ndimension() == 3: 48 | assert(tensor.size(0) == 3) 49 | array = 0.45 + tensor.numpy()*0.225 50 | return array 51 | 52 | 53 | def change_bn_momentum(model, new_value): 54 | for name, module in model.named_modules(): 55 | if isinstance(module, nn.BatchNorm2d): 56 | module.momentum = new_value 57 | 58 | 59 | def get_depths_and_poses(encoder, segmentation_head, decoder, pose_decoder, images, features_, reduction, squeeze_unsqueeze): 60 | b, l, c, h, w = images.shape 61 | 62 | ref_features = [x[:b] for x in features_] 63 | 64 | features = [] 65 | lf = len(features_) 66 | _, c_feat, h_feat, w_feat = features_[-1].shape 67 | 68 | for i in range(lf): 69 | if i == lf - 1: 70 | features.append( 71 | features_[-1][b:] + squeeze_unsqueeze(torch.cat([ 72 | features_[-1][b:], 73 | ref_features[-1].reshape(b,1,c_feat,h_feat,w_feat).expand(b, l, c_feat, h_feat, w_feat).reshape(b*l,c_feat, h_feat, w_feat) 74 | ], dim=1))) 75 | else: 76 | features.append(features_[i][b:]) 77 | 78 | 79 | depths = segmentation_head(decoder(*features)).reshape(b, l, 1, h, w) 80 | 81 | _, c_feat, h_feat, w_feat = features[-1].shape 82 | last_feat = features[-1] 83 | 84 | last_feat = reduction(last_feat.reshape(b* l, c_feat, h_feat, w_feat)) 85 | c_feat = last_feat.shape[1] 86 | 
last_feat = last_feat.reshape(b, l, c_feat, h_feat, w_feat) 87 | 88 | 89 | last_feat = last_feat.unsqueeze(2).expand(b, l, l, c_feat, h_feat, w_feat).contiguous() 90 | features_sq = torch.cat([last_feat, last_feat.transpose(1, 2)], dim=3).reshape(b*l*l, c_feat*2, h_feat, w_feat) 91 | 92 | 93 | poses = pose_decoder(features_sq) 94 | poses = poses.mean(dim=(2,3)).reshape(b,l,l,6) 95 | 96 | return depths, poses * 0.01 97 | -------------------------------------------------------------------------------- /src/segmentation/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | import torch.nn as nn 5 | from collections import defaultdict 6 | 7 | def load_files(base_dir, test_splits, ignore_splits): 8 | train_images = [] 9 | test_images = [] 10 | train_labels = [] 11 | test_labels = [] 12 | test_polygons = [] 13 | counts = defaultdict(int) 14 | 15 | for root, dirs, files in os.walk(base_dir): 16 | for file in files: 17 | file = root + "/" + file 18 | fileending = "." + file.split('.')[-1] 19 | if fileending in ['.jpg', '.png', '.jpeg']: 20 | split = file.split("/")[-2] 21 | if split not in ignore_splits and split not in test_splits: 22 | train_images.append(file) 23 | train_labels.append(file.replace(fileending, '_seg.npy')) 24 | elif split not in ignore_splits and split in test_splits: 25 | test_images.append(file) 26 | test_labels.append(file.replace(fileending, '_seg.npy')) 27 | test_polygons.append(file.replace(fileending, '_poly.npy')) 28 | all_counts = json.loads(open(base_dir+"/counts.json").read()) 29 | for split in all_counts.keys(): 30 | if split not in ignore_splits and split not in test_splits: 31 | for class_name, count in all_counts[split].items(): 32 | counts[int(class_name)] += int(count) 33 | return train_images, test_images, train_labels, test_labels, test_polygons, counts 34 | 35 | 36 | def color_rgb_image(x, classes, colors): 37 | classes_inverse = {v: k for k, v in classes.items()} 38 | """Takes a semantic segmentation image and returns a colored RGB image.""" 39 | semseg = np.ones((x.shape[0], x.shape[1], 3)) 40 | for val in np.unique(x): 41 | if val > 0: 42 | name = classes_inverse[val] 43 | semseg[x==val] = np.array(colors[name])/255. 44 | return semseg 45 | 46 | def color_by_correctness(prediction, label): 47 | semseg = np.zeros((prediction.shape[0], prediction.shape[1], 3)) 48 | semseg[prediction==label] = np.array([0, 0, 1]) 49 | semseg[prediction!=label] = np.array([1, 0, 0]) 50 | semseg[label==0] = 1 51 | return semseg 52 | 53 | def rotatedRectWithMaxArea(h, w, angle): 54 | """ 55 | From https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders 56 | Given a rectangle of size wxh that has been rotated by 'angle' (in 57 | radians), computes the width and height of the largest possible 58 | axis-aligned rectangle (maximal area) within the rotated rectangle. 
59 | """ 60 | angle = np.deg2rad(angle) 61 | if w <= 0 or h <= 0: 62 | return 0,0 63 | 64 | width_is_longer = w >= h 65 | side_long, side_short = (w,h) if width_is_longer else (h,w) 66 | 67 | # since the solutions for angle, -angle and 180-angle are all the same, 68 | # if suffices to look at the first quadrant and the absolute values of sin,cos: 69 | sin_a, cos_a = abs(np.sin(angle)), abs(np.cos(angle)) 70 | if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10: 71 | # half constrained case: two crop corners touch the longer side, 72 | # the other two corners are on the mid-line parallel to the longer line 73 | x = 0.5*side_short 74 | wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a) 75 | else: 76 | # fully constrained case: crop touches all 4 sides 77 | cos_2a = cos_a*cos_a - sin_a*sin_a 78 | wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a 79 | 80 | return int(hr),int(wr) 81 | 82 | 83 | def change_bn_momentum(model, new_value): 84 | for name, module in model.named_modules(): 85 | if isinstance(module, nn.BatchNorm2d): 86 | module.momentum = new_value 87 | 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Pytorch 2 | *.pt 3 | checkpoints/ 4 | 5 | # Video processing 6 | src/tmp 7 | tmp 8 | *.MP4 9 | *.mp4 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 170 | #.idea/ 171 | .history 172 | .vscode 173 | gpmfstream/ 174 | -------------------------------------------------------------------------------- /src/sfm/datasets.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import random 5 | from path import Path 6 | from PIL import Image 7 | import custom_transforms 8 | import torchvision 9 | 10 | class SequenceDataset(Dataset): 11 | def __init__(self, data_path, train=True, transform=None, individual_transform=None, seed=0, long_sequence_length=7, subsampled_sequence_length=5, with_replacement=True): 12 | """The data_path argument points to a directory structured in the KITTI format: 13 | which means that it contains a subdirectory for each sequence, and each sequence 14 | //0000001.jpg 15 | //0000002.jpg 16 | .. 17 | //0000001.jpg 18 | //0000002.jpg 19 | 20 | One sample from the dataset is chosen as follows: 21 | 1. A random sequence is chosen 22 | 2. A sequence of subsequent frames is randomly chosen from the sequence 23 | 3. 
A sequence of random frames is subsampled from the sequence, with or without replacement 24 | """ 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | 28 | self.root = Path(data_path) 29 | scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' 30 | scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' 31 | self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] 32 | self.transform = transform 33 | self.individual_transform = individual_transform 34 | self.long_sequence_length = long_sequence_length 35 | self.subsampled_sequence_length = subsampled_sequence_length 36 | 37 | self.backward = {} 38 | with open( self.root/"forward.txt", "r") as f: 39 | for line in f.readlines(): 40 | self.backward[line.strip()] = False 41 | with open( self.root/"backward.txt", "r") as f: 42 | for line in f.readlines(): 43 | self.backward[line.strip()] = True 44 | 45 | 46 | self.imgs = {(scene): sorted(list(scene.files('*.jpg'))+list(scene.files('*.jpeg'))) for scene in self.scenes} 47 | 48 | self.num_samples = 0 49 | self.index_to_sequence = [] 50 | self.index_to_index_in_sequence = [] 51 | for scene in self.scenes: 52 | self.num_samples += len(self.imgs[scene]) - long_sequence_length + 1 53 | self.index_to_sequence += [scene] * (len(self.imgs[scene]) - long_sequence_length + 1) 54 | self.index_to_index_in_sequence += list(range(len(self.imgs[scene]) - long_sequence_length + 1)) 55 | self.index_to_sequence = np.array(self.index_to_sequence) 56 | self.with_replacement = with_replacement 57 | 58 | normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45], 59 | std=[0.225, 0.225, 0.225]) 60 | self.totensor = torchvision.transforms.ToTensor() 61 | self.out_transform = normalize 62 | 63 | 64 | def __len__(self): 65 | return self.num_samples 66 | 67 | def __getitem__(self, idx): 68 | # First, get right scene via self.index_to_sequence 69 | scene = self.index_to_sequence[idx] 70 | # Then, get corresponding sequence of images within scene 71 | index = self.index_to_index_in_sequence[idx] 72 | backward = False 73 | long_sequence = self.imgs[scene][index:index+self.long_sequence_length] 74 | 75 | # Then, subsample from sequence, but keep the order of the sequence intact 76 | if self.with_replacement: 77 | indices = sorted(np.random.choice(len(long_sequence), self.subsampled_sequence_length, replace=True), reverse=backward) 78 | else: 79 | indices = sorted(np.random.choice(len(long_sequence), self.subsampled_sequence_length, replace=False), reverse=backward) 80 | subsampled_sequence = [long_sequence[i] for i in indices] 81 | 82 | # TODO Fix 83 | intrinsics = np.eye(3) 84 | 85 | imgs = [Image.open(img) for img in subsampled_sequence] 86 | if self.transform: 87 | # TODO remove intrinsics 88 | imgs, intrinsics = self.transform(imgs, intrinsics) 89 | imgs_out, intrinsics = self.out_transform([self.totensor(img) for img in imgs], intrinsics) 90 | else: 91 | imgs_out, intrinsics = self.out_transform([self.totensor(img) for img in imgs], intrinsics) 92 | 93 | if self.individual_transform: 94 | # TODO remove intrinsics 95 | imgs_individually_transformed = [self.individual_transform(img) for img in imgs] 96 | imgs_individually_transformed, intrinsics = self.out_transform([self.totensor(img) for img in imgs_individually_transformed], intrinsics) 97 | else: 98 | imgs_individually_transformed = imgs_out 99 | 100 | return imgs_out, imgs_individually_transformed, 0 101 | -------------------------------------------------------------------------------- 
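A note on the sampling scheme implemented by `SequenceDataset` above: each sample is drawn by picking a window of `long_sequence_length` consecutive frames from a random sequence and then subsampling `subsampled_sequence_length` of them while preserving temporal order. A minimal, self-contained sketch of that subsampling step is shown below (illustration only, not part of the repository; the helper name `subsample_window` is made up):

```python
import numpy as np

def subsample_window(frame_paths, long_sequence_length=7,
                     subsampled_sequence_length=5, with_replacement=True, rng=None):
    """Pick a window of consecutive frames, then an ordered subsample of it."""
    rng = np.random.default_rng() if rng is None else rng
    # Choose where the window of consecutive frames starts
    start = rng.integers(0, len(frame_paths) - long_sequence_length + 1)
    window = frame_paths[start:start + long_sequence_length]
    # Sorting keeps the subsequence in temporal order, mirroring
    # sorted(np.random.choice(...)) in SequenceDataset.__getitem__
    indices = sorted(rng.choice(len(window), subsampled_sequence_length,
                                replace=with_replacement))
    return [window[i] for i in indices]

# Example: subsample_window(["%07d.jpg" % i for i in range(20)])
```
--------------------------------------------------------------------------------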
/src/sfm/custom_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | from PIL import Image 6 | import torchvision.transforms.functional as F 7 | '''Set of tranform random routines that takes list of inputs as arguments, 8 | in order to have random but coherent transformations.''' 9 | 10 | class ColorJitter(object): 11 | def __init__(self, brightness, contrast, saturation, hue): 12 | self.brightness = brightness 13 | self.contrast = contrast 14 | self.saturation = saturation 15 | self.hue = hue 16 | 17 | def __call__(self, images, intrinsics=None): 18 | fn_idx = torch.randperm(4) 19 | 20 | brightness_factor = float(torch.empty(1).uniform_(-self.brightness, self.brightness)) 21 | contrast_factor = float(torch.empty(1).uniform_(-self.contrast, self.contrast)) 22 | saturation_factor = float(torch.empty(1).uniform_(-self.saturation, self.saturation)) 23 | hue_factor = float(torch.empty(1).uniform_(-self.hue, self.hue)) 24 | 25 | output_images = [] 26 | for img in images: 27 | 28 | for fn_id in fn_idx: 29 | if fn_id == 0 and brightness_factor is not None: 30 | img = F.adjust_brightness(img, 1+brightness_factor) 31 | elif fn_id == 1 and contrast_factor is not None: 32 | img = F.adjust_contrast(img, 1+contrast_factor) 33 | elif fn_id == 2 and saturation_factor is not None: 34 | img = F.adjust_saturation(img, 1+saturation_factor) 35 | elif fn_id == 3 and hue_factor is not None: 36 | img = F.adjust_hue(img, hue_factor) 37 | output_images.append(img) 38 | 39 | return output_images, intrinsics 40 | 41 | 42 | class Compose(object): 43 | def __init__(self, transforms): 44 | self.transforms = transforms 45 | 46 | def __call__(self, images, intrinsics=None): 47 | for t in self.transforms: 48 | images, intrinsics = t(images, intrinsics) 49 | 50 | return images, intrinsics 51 | 52 | 53 | class Normalize(object): 54 | def __init__(self, mean, std): 55 | self.mean = mean 56 | self.std = std 57 | 58 | def __call__(self, images, intrinsics): 59 | for tensor in images: 60 | for t, m, s in zip(tensor, self.mean, self.std): 61 | t.sub_(m).div_(s) 62 | return images, intrinsics 63 | 64 | def invert(self, img): 65 | img = torch.clone(img.detach()) 66 | for t, m, s in zip(img, self.mean, self.std): 67 | t.mul_(s).add_(m) 68 | return img 69 | 70 | class ArrayToTensor(object): 71 | """Converts a list of numpy.ndarray (H x W x C) along with a intrinsics matrix to a list of torch.FloatTensor of shape (C x H x W) with a intrinsics tensor.""" 72 | 73 | def __call__(self, images, intrinsics): 74 | tensors = [] 75 | for im in images: 76 | # put it from HWC to CHW format 77 | im = np.transpose(im, (2, 0, 1)) 78 | # handle numpy array 79 | tensors.append(torch.from_numpy(im).float()/255) 80 | return tensors, intrinsics 81 | 82 | 83 | class RandomHorizontalFlip(object): 84 | """Randomly horizontally flips the given numpy array with a probability of 0.5""" 85 | 86 | def __call__(self, images, intrinsics): 87 | if intrinsics is not None: 88 | if random.random() < 0.5: 89 | output_images = [F.hflip(im) for im in images] 90 | output_intrinsics = intrinsics 91 | 92 | w = output_images[0].shape[1] 93 | output_intrinsics[0, 2] = w - output_intrinsics[0, 2] 94 | else: 95 | output_images = images 96 | output_intrinsics = intrinsics 97 | else: 98 | if random.random() < 0.5: 99 | # Only works if intrinsics are centered, otherwise intrinsics will be wrong! 
100 |                 output_images = [F.hflip(im) for im in images]
101 |                 output_intrinsics = intrinsics
102 |             else:
103 |                 # No flip: pass the inputs through unchanged so the outputs are always defined
104 |                 output_images = images
105 |                 output_intrinsics = intrinsics
106 | 
107 |         return output_images, output_intrinsics
108 | 
109 | class RandomScaleCrop(object):
110 |     """Randomly zooms images up to 15% and crops them to keep the same size as before."""
111 | 
112 |     def __call__(self, images, intrinsics):
113 |         assert intrinsics is not None
114 |         output_intrinsics = np.copy(intrinsics)
115 | 
116 |         in_h, in_w, _ = images[0].shape
117 |         x_scaling, y_scaling = np.random.uniform(1, 1.15, 2)
118 |         scaled_h, scaled_w = int(in_h * y_scaling), int(in_w * x_scaling)
119 | 
120 |         output_intrinsics[0] *= x_scaling
121 |         output_intrinsics[1] *= y_scaling
122 |         scaled_images = [np.array(Image.fromarray(im.astype(np.uint8)).resize((scaled_w, scaled_h))).astype(np.float32) for im in images]
123 | 
124 |         offset_y = np.random.randint(scaled_h - in_h + 1)
125 |         offset_x = np.random.randint(scaled_w - in_w + 1)
126 |         cropped_images = [im[offset_y:offset_y + in_h, offset_x:offset_x + in_w] for im in scaled_images]
127 | 
128 |         output_intrinsics[0, 2] -= offset_x
129 |         output_intrinsics[1, 2] -= offset_y
130 | 
131 |         return cropped_images, output_intrinsics
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scalable 3D Semantic Mapping of Coral Reefs using Deep Learning
2 | 
3 | This repository contains the source code for the paper [Scalable 3D Semantic Mapping of Coral Reefs using Deep Learning](https://arxiv.org/abs/2309.12804).
4 | 
5 | [The project page](https://josauder.github.io/deepreefmap/) contains updated information on the state of the project.
6 | 
7 | 
8 | ## Quick start (docker)
9 | 
10 | A Docker image is available for this repository, allowing you to run the code
11 | without installing any dependencies.
12 | 
13 | ```bash
14 | docker run \
15 |     -v ./example_data/input_videos:/input \
16 |     -v ./output:/output \
17 |     ghcr.io/josauder/mee-deepreefmap \
18 |     --input_video=/input/GX_SINGLE_VIDEO.MP4 \
19 |     --timestamp=0-100 \
20 |     --out_dir=/output \
21 |     --fps=10
22 | ```
23 | 
24 | If using a GPU, make sure your Docker runtime is configured to use it
25 | (installing, for example, `nvidia-container-toolkit` for NVIDIA GPUs), and run
26 | the Docker image with the `--gpus all` flag.
27 | 
28 | The Dockerfile used for building the above image is also included should
29 | you wish to build the image yourself.
30 | 
31 | To build the Docker image, use the following command:
32 | 
33 | ```bash
34 | docker build -t deepreefmap .
35 | ```
36 | 
37 | **Note**: The `docker run` command above assumes that the input video is located in
38 | `./example_data/input_videos`; the output will be saved to `./output`. The
39 | input video can be obtained from the example data in the Zenodo archive
40 | described below.
41 | 
42 | 
43 | ## Installation
44 | 
45 | This repository depends on [gpmfstream](https://github.com/hovren/gpmfstream), which in turn depends on [gpmf-parser](https://github.com/gopro/gpmf-parser).
46 | To install gpmfstream, run the following:
47 | 
48 | ```
49 | git clone https://github.com/hovren/gpmfstream.git
50 | cd gpmfstream
51 | git submodule update --init
52 | python3 setup.py install
53 | ```
54 | 
55 | ### uv (recommended)
56 | 
57 | This repository uses [uv](https://docs.astral.sh/uv/) to manage the rest
58 | of the dependencies. To install the dependencies, run the following in the main
59 | repository directory:
60 | 
61 | ```bash
62 | uv sync
63 | ```
64 | 
65 | If in doubt, refer to the `Dockerfile` for the complete method of installing
66 | the dependencies.
67 | 
68 | ## Download Example Data and Pre-Trained Models
69 | 
70 | Pre-trained model checkpoints and example input videos can be downloaded from the [Zenodo archive](https://zenodo.org/record/10624794).
71 | 
72 | Checkpoints are also included in this repository with Git-LFS. Ensure you have Git-LFS installed before cloning the repository.
73 | 
74 | ## Running 3D Reconstructions of GoPro Hero 10 Videos
75 | 
76 | Simple usage: the input is one MP4 video taken with a GoPro Hero 10 camera, as well as the timestamps of when the transect begins and ends in the video (TODO: discuss format).
77 | 
78 | ```
79 | python3 reconstruct.py \
80 |     --input_video=<path_to_video> \
81 |     --timestamp=<begin>-<end> \
82 |     --out_dir=<output_directory>
83 | ```
84 | 
85 | Advanced usage: the GoPro Hero 10 camera cuts videos into 4GB chunks. If the transect is spread over two or more videos, the following command can be used to reconstruct the transect.
86 | 
87 | ```
88 | python3 reconstruct.py \
89 |     --input_video=<first_video>,<second_video> \
90 |     --timestamp=<begin>-end,begin-<end> \
91 |     --out_dir=<output_directory>
92 | ```
93 | 
94 | ## Running 3D Reconstructions of Videos from Other Cameras
95 | 
96 | For now, this repository supports the GoPro Hero 10 camera. If you want to use a different camera, be sure to provide the correct camera intrinsics as `intrinsics.json`, which is passed as a command-line argument to the other scripts. For now, the intrinsics follow a simplified UCM format, with the focal lengths `fx, fy` in pixels, the principal point `cx, cy` in pixels, and the `alpha` value to account for distortion, which is set to zero in the default case to assume linear camera intrinsics.
97 | 
98 | ## Training the 3D Reconstruction Network on Your Own Data
99 | 
100 | To train the 3D reconstruction network on your own data, use
101 | 
102 | ```
103 | python3 sfm/train_sfm.py \
104 |     --data <path_to_dataset> \
105 |     --checkpoint <path_to_checkpoint> \
106 |     --name <experiment_name>
107 | ```
108 | 
109 | Where your data should be a directory of the following structure (same as the KITTI VO dataset):
110 | 
111 | ```
112 | train.txt
113 | val.txt
114 | sequence1/
115 |     000001.jpg
116 |     000002.jpg
117 |     ...
118 | sequence2/
119 |     000001.jpg
120 |     000002.jpg
121 |     ...
122 | sequence3/
123 |     000001.jpg
124 |     000002.jpg
125 |     ...
126 | ```
127 | 
128 | With `train.txt` containing, for example
129 | 
130 | ```
131 | sequence1
132 | sequence2
133 | ```
134 | And `val.txt` containing, for example
135 | 
136 | ```
137 | sequence3
138 | ```
139 | 
140 | ## Training the Semantic Segmentation Network on Your Own Data
141 | 
142 | For training the segmentation model, use
143 | 
144 | ```
145 | python3 segmentation/train_segmentation_model.py \
146 |     --data <path_to_dataset> \
147 |     --checkpoint <path_to_checkpoint> \
148 |     --test_splits <split_1>,<split_2> \
149 |     --name <experiment_name>
150 | ```
151 | 
152 | Where your data should be a directory with the following structure:
153 | ```
154 | data/
155 |     classes.json
156 |     colors.json
157 |     counts.json
158 |     scene_1/
159 |         image_0.png
160 |         image_0_seg.npy
161 |         image_0_poly.npy
162 |         image_1.png
163 |         image_1_seg.npy
164 |         image_1_poly.npy
165 |         ...
166 |     scene_2/
167 |         image_0.png
168 |         image_0_seg.npy
169 |         image_0_poly.npy
170 |         ...
171 |     ...
172 | ```
173 | 
174 | Following the example dataset from the [Zenodo archive](https://zenodo.org/record/10624794).
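
As a quick check of such a dataset directory before training, the following minimal sketch (illustration only, not part of the repository; the helper name `check_dataset` is made up) verifies that the three JSON files are present and that every image has a matching `_seg.npy` label, mirroring how `src/segmentation/utils.py` pairs images with labels:

```python
import json
import os

def check_dataset(base_dir):
    """Report images that are missing their `_seg.npy` label file."""
    for name in ("classes.json", "colors.json", "counts.json"):
        assert os.path.exists(os.path.join(base_dir, name)), f"missing {name}"
    classes = json.load(open(os.path.join(base_dir, "classes.json")))
    print(f"{len(classes)} classes found")
    missing = []
    for root, _, files in os.walk(base_dir):
        for f in files:
            stem, ext = os.path.splitext(f)
            if ext.lower() in (".jpg", ".jpeg", ".png"):
                if not os.path.exists(os.path.join(root, stem + "_seg.npy")):
                    missing.append(os.path.join(root, f))
    print(f"{len(missing)} images without a _seg.npy label")
    return missing
```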
-------------------------------------------------------------------------------- /src/reconstruction_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.spatial import cKDTree 4 | from scipy.stats import mode 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | 8 | def rotation_matrix_from_vectors(vec1, vec2): 9 | """ Find the rotation matrix that aligns vec1 to vec2 10 | :param vec1: A 3d "source" vector 11 | :param vec2: A 3d "destination" vector 12 | :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2. 13 | """ 14 | a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3) 15 | v = np.cross(a, b) 16 | c = np.dot(a, b) 17 | s = np.linalg.norm(v) 18 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 19 | rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2)) 20 | return rotation_matrix 21 | 22 | def get_closest_to_centroid(g): 23 | if len(g) == 1: 24 | return g[0] 25 | return g[np.argmin(np.linalg.norm(g[:,:3] - g[:,:3].mean(axis=0), axis=1))] 26 | 27 | 28 | def get_closest_to_centroid_with_attributes_of_closest_to_cam(g): 29 | if len(g) == 1: 30 | return g 31 | xyz = g[np.argmin(np.linalg.norm(g[:,:3] - g[:,:3].mean(axis=0), axis=1)),:3] 32 | attributes = g[np.argmin(g[:,3], axis=0)][3:] 33 | return np.concatenate([xyz, attributes]).reshape(1, -1) 34 | 35 | def remove_outliers(g): 36 | if len(g) == 1: 37 | return [] 38 | return g 39 | 40 | def map_3d(x, fn, size=0.03): 41 | # Makes Floats into Int-Bins 42 | to_bin = np.floor(x[:, :3] / size).astype(np.int32) 43 | # Lexsort Bins 44 | inds = np.lexsort(np.transpose(to_bin)[::-1]) 45 | to_bin = to_bin[inds] 46 | x = x[inds] 47 | del inds 48 | splits = np.split(x, np.cumsum(np.unique(to_bin, return_counts=True, axis=0)[1])[:-1]) 49 | del to_bin 50 | del x 51 | results = np.concatenate([x for x in np.vectorize(fn, otypes=[np.ndarray])(splits) if len(x)>0], axis=0) 52 | return results 53 | 54 | 55 | def get_matching_indices(arr_a, arr_b): 56 | tree = cKDTree(arr_b) 57 | dist, index = tree.query(arr_a, workers=64) 58 | return index 59 | 60 | def get_rotation_matrix_to_align_pose_with_gravity(pose, g): 61 | """Used to find rotation that rotates pose matrix to align with gravity vector g""" 62 | xx = np.array([0,0,1]) # Vector to which gravity is aligned 63 | return rotation_matrix_from_vectors(pose[:3,:3] @ (g / np.linalg.norm(g)), xx) 64 | 65 | 66 | def get_edgeness(x): 67 | edgeness_x = torch.abs(x[:-1] - x[1:]) # has shape (height, width-1) 68 | edgeness_y = torch.abs(x[:,:-1] - x[:,1:]) # has shape (height-1, width) 69 | edgeness = torch.zeros_like(x) 70 | edgeness[:,:-1] += edgeness_y 71 | edgeness[:,1:] += edgeness_y 72 | edgeness[:-1,:] += edgeness_x 73 | edgeness[1:,:] += edgeness_x 74 | return edgeness 75 | 76 | 77 | def aggregate_2d_grid(inp, size): 78 | """ Builds a 2D Grid along the two principal components of the point cloud 79 | in each grid element, the points in the point cloud are aggregated to give a final 80 | semantic class, height, and color""" 81 | to_bin = np.floor(inp[:, 0:2] / size).astype(np.int32) 82 | inds = np.lexsort(np.transpose(to_bin)[::-1]) 83 | to_bin = to_bin[inds] 84 | inp = inp[inds] 85 | 86 | def aggregate_2d_grid_cell(group): 87 | # If only one point in grid element, return this one point, set the counter of points in grid element to 1 88 | if len(group) == 1: 89 | return np.concatenate([group, 
np.array([[1]])], axis=1) 90 | # If two points in grid element, return the one with higher z value, set counter of points in grid element to 2 91 | if len(group) == 2: 92 | return np.concatenate([group[np.argmax(group[:,2])], np.array([2])]).reshape(1, -1) 93 | # If more than two points in grid element, discard points below the mean height 94 | z = group[:,2] 95 | mean_height = z.mean() 96 | 97 | #TODO doublecheck orientation of z axis 98 | group_ = group[z >= mean_height] 99 | if len(group_) == 0: 100 | return np.concatenate([group[:1], np.array([[1]])], axis=1) 101 | x, y, z, r, g, b, distance_to_cam, class_, class_r, class_g, class_b, frame_index, depth_unc = group_.T 102 | 103 | most_common_class = mode(class_, keepdims=False)[0] 104 | 105 | class_r = class_r[class_==most_common_class][0] 106 | class_g = class_g[class_==most_common_class][0] 107 | class_b = class_b[class_==most_common_class][0] 108 | 109 | return np.array([[ 110 | x[0], # Is the same for all points in group 111 | y[0], # Is the same for all points in group 112 | np.mean(z), # Height is calculated as mean 113 | np.mean(r), # Color is calculated as mean 114 | np.mean(g), # Color is calculated as mean 115 | np.mean(b), # Color is calculated as mean 116 | np.mean(distance_to_cam), # Distance to camera is calculated as mean 117 | most_common_class, # Class is the most common class 118 | class_r, # Class color is the color of most common class 119 | class_g, # Class color is the color of most common class 120 | class_b, # Class color is the color of most common class 121 | np.mean(frame_index), # Frame index is calculated as mean 122 | np.mean(depth_unc), # Depth uncertainty is calculated as mean 123 | len(group) # Number of points in 2D Grid element 124 | ]]) 125 | 126 | 127 | del inds 128 | splits = np.split(inp, np.cumsum(np.unique(to_bin, return_counts=True, axis=0)[1])[:-1]) 129 | del to_bin 130 | del inp 131 | results = np.concatenate([inp for inp in np.vectorize(aggregate_2d_grid_cell, otypes=[np.ndarray])(splits) if len(inp)>0], axis=0) 132 | results[:,0] /= size 133 | results[:,1] /= size 134 | return results 135 | 136 | def get_legend(class_to_colors, tmp_dir): 137 | 138 | class_to_colors = dict(sorted(class_to_colors.items(), key=lambda item: len(item[0]))) 139 | 140 | labels = list(class_to_colors.keys()) 141 | fig,ax = plt.subplots() 142 | 143 | def f(m, c,l): 144 | return plt.plot([],[],marker=m, color=c, ls="none", label=l)[0] 145 | 146 | [f("s", class_color/255., class_name) for class_name, class_color in class_to_colors.items()] 147 | 148 | ax.axis('off') 149 | legend = plt.legend(ncol=5) 150 | fig = legend.figure 151 | fig.canvas.draw() 152 | bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) 153 | fig.savefig(tmp_dir + "/legend.png", dpi="figure", bbox_inches=bbox) 154 | return np.array(Image.open(tmp_dir + "/legend.png"))[:,:,:3].transpose(1, 0, 2)/255. 
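A standalone sketch (illustration only, not part of the repository; the helper name `bin_points` is made up) of the lexsort-and-split binning pattern used by `map_3d` and `aggregate_2d_grid` above: coordinates are discretized to integer cells, rows are sorted so equal cells become contiguous, and the array is split into per-cell groups that an aggregation function can then reduce.

```python
import numpy as np

def bin_points(points, size=0.03):
    """Group rows of an (N, D) array by the grid cell of their first three columns."""
    cells = np.floor(points[:, :3] / size).astype(np.int32)
    order = np.lexsort(cells.T[::-1])      # primary sort key: x, then y, then z
    cells, points = cells[order], points[order]
    # Unique cells come back in the same lexicographic order, so their counts
    # give the boundaries between consecutive groups of the sorted array
    counts = np.unique(cells, axis=0, return_counts=True)[1]
    return np.split(points, np.cumsum(counts)[:-1])

# Example: replace each occupied cell by its centroid
pts = np.random.rand(1000, 6)              # xyz + rgb
centroids = np.array([g.mean(axis=0) for g in bin_points(pts, size=0.1)])
```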
-------------------------------------------------------------------------------- /src/video_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import os 4 | from gpmfstream import Stream 5 | from tqdm import tqdm 6 | from PIL import Image 7 | from skimage.transform import resize 8 | import matplotlib.pyplot as plt 9 | from reconstruction_utils import get_legend 10 | 11 | def get_video_length(filename): 12 | result = subprocess.run(["ffprobe", "-v", "error", "-show_entries", 13 | "format=duration", "-of", 14 | "default=noprint_wrappers=1:nokey=1", filename], 15 | stdout=subprocess.PIPE, 16 | stderr=subprocess.STDOUT) 17 | return float(result.stdout) 18 | 19 | def extract_frames_and_gopro_gravity_vector(video_names, timestamps, width, height, fps, tmp_dir, reverse=False): 20 | """As input, takes a list of video names of the form: [, ], 21 | and a list of (minute:second)-timestamps of the example form ["-end","begin-"] 22 | using FFMPEG, extracts frames at frames per second, with the height and width set accordingly, 23 | and in reverse for videos where the camera moves backwards. 24 | """ 25 | 26 | os.makedirs(tmp_dir + "/rgb", exist_ok=True) 27 | 28 | gravity_vectors = [] 29 | total_frames = 0 30 | 31 | for video_id, (video_name, timestamp) in enumerate(zip(video_names, timestamps)): 32 | 33 | targetpath = tmp_dir + "/" + video_name.split("/")[-1].split(".")[0].replace(" ", "_").replace("/", "_") 34 | os.makedirs(targetpath, exist_ok=True) 35 | 36 | begin, end = timestamp.split("-") 37 | ss = "" 38 | to = "" 39 | if begin != "begin": 40 | ss += " -ss " + begin 41 | if end != "end": 42 | to += " -to " + end 43 | 44 | reverse_flag = "" 45 | if reverse: 46 | reverse_flag = "reverse," 47 | 48 | # First: cut video 49 | #os.system("ffmpeg-7.0.2-amd64-static/ffmpeg -hide_banner -loglevel error"+ss+" " +to+" -y -i '"+video_name+"' -c copy "+tmp_dir+"/"+str(video_id)+".mp4") 50 | os.system("ffmpeg -hide_banner -loglevel error"+ss+" " +to+" -y -i '"+video_name+"' -c copy "+tmp_dir+"/"+str(video_id)+".mp4") 51 | # Second: scale video to right dimensions 52 | os.system("ffmpeg -hide_banner -loglevel error -y -i "+tmp_dir+"/"+str(video_id)+".mp4 -vf scale="+str(width)+":"+str(height)+" "+tmp_dir+"/"+str(video_id)+"_.mp4") 53 | # Third: extract frames 54 | os.system("ffmpeg -hide_banner -loglevel error -y -i "+tmp_dir+"/"+str(video_id)+"_.mp4 -vf "+reverse_flag+"fps="+str(fps)+" -qscale:v 2 "+targetpath +"/%07d.jpg") 55 | 56 | num_frames = len(os.listdir(targetpath)) 57 | 58 | for frame in os.listdir(targetpath): 59 | frameid = int(frame.split(".")[0]) 60 | os.system("mv "+targetpath + "/" + frame + " "+tmp_dir+"/rgb/" + str(frameid + total_frames).zfill(7) + ".jpg") 61 | 62 | total_frames += num_frames 63 | gravity_vectors.append(get_gravity_vectors(video_name, timestamp, num_frames)) 64 | if gravity_vectors[0] is None: 65 | return None 66 | return np.concatenate(gravity_vectors) 67 | 68 | 69 | def get_gravity_vectors(video, timestamp, number_of_frames): 70 | """Uses gpmfstream to extract gravity vectors from an MP4 video file.""" 71 | try: 72 | grav = Stream.extract_streams(video)["GRAV"].data 73 | except Exception as e: 74 | print("WARNING: Could not extract gravity vectors from video file:", video, " is your video an unedited GoPro video?") 75 | return None 76 | length = get_video_length(video) 77 | 78 | begin, end = timestamp.split("-") 79 | #TODO: timestamp parsing! 
80 | if begin == "begin": 81 | begin = 0 82 | if end == "end": 83 | end = length 84 | begin = float(begin) 85 | end = float(end) 86 | 87 | grav = grav[int(begin/length*len(grav)):int(end/length*len(grav))] 88 | inds = np.linspace(0, len(grav)-1, number_of_frames).astype(np.int32) 89 | grav = grav[inds] 90 | grav /= np.linalg.norm(grav, axis=1).reshape(-1, 1) 91 | return grav 92 | 93 | def render_video(img_list, depths, semantic_segmentation, results_npy, fps, class_to_label, label_to_color, tmp_dir, reverse): 94 | """Renders a video from the given images, depths, semantic_segmentation and 2d maps.""" 95 | os.makedirs(tmp_dir + "/render", exist_ok=True) 96 | 97 | # For visualization, its nicer when depths are scaled between 0 and 1, and sqrt_scaled 98 | q2, q98 = np.nanquantile(depths, [0.02, 0.98]) 99 | depths = np.nan_to_num(depths, nan=q2) 100 | depths = np.clip(depths, q2, q98) 101 | depths = np.clip(depths, q2, 0.35) 102 | depths_ = np.sqrt(depths) 103 | depths_ = (depths_ - np.min(depths_)/2) / (np.max(depths_)-np.min(depths_)/2) 104 | #depths_ = depths_ / np.max(depths_) 105 | 106 | 107 | class_to_color = {class_name: label_to_color[class_label] for class_name, class_label in class_to_label.items()} 108 | legend = get_legend(class_to_color, tmp_dir) 109 | 110 | final_rgb = results_npy[:,:,1:4] 111 | final_class_rgb = results_npy[:,:,6:9] 112 | frame_index = results_npy[:,:,9:10].astype(np.int16) 113 | 114 | 115 | for i in tqdm(range(len(depths))): 116 | 117 | 118 | color_semseg = np.zeros((semantic_segmentation.shape[1], semantic_segmentation.shape[2], 3), dtype=np.uint8) 119 | for class_name, class_label in class_to_label.items(): 120 | color_semseg[semantic_segmentation[i]==class_label] = label_to_color[class_label] 121 | 122 | 123 | depths_[i][semantic_segmentation[i]==class_to_label['fish']] = 0 124 | depths_[i][semantic_segmentation[i]==class_to_label['human']] = 0 125 | 126 | 127 | rgb = np.array(Image.open(img_list[i]).resize((640, 384)))/255. 
128 | if reverse: 129 | ind = (frame_index >= i).astype(np.uint8) 130 | else: 131 | ind = (frame_index <= i).astype(np.uint8) 132 | 133 | results_npy_rgb = final_rgb * ind 134 | results_npy_class_rgb = final_class_rgb * ind 135 | if results_npy_rgb.shape[0]0: 70 | mask = polygon==polygons 71 | polygon_mask = label[mask] 72 | polygon_value = polygon_mask[0] 73 | if polygon_value > 0: 74 | class_name = classes_inverse[polygon_value] 75 | 76 | polygon_prediction = prediction[mask] 77 | polygons_confusion_matrix[polygon_value-1, polygon_prediction-1] += 1 78 | per_class_correct_polygons[class_name].append((np.bincount(polygon_prediction.flatten()).argmax()==polygon_value)) 79 | per_class_correct50_polygons[class_name].append((polygon_prediction==polygon_value).mean() > 0.5) 80 | per_class_correct90_polygons[class_name].append((polygon_prediction==polygon_value).mean() > 0.9) 81 | total_accuracy = np.array(correctly_classified_pixels).sum() / np.array(annotated_pixels).sum() 82 | per_class_accuracy = {class_label: np.array(correct).sum() / np.array(annotated).sum() for class_label, correct, annotated in zip(class_labels, per_class_correct.values(), per_class_annotated_pixels.values())} 83 | per_class_iou = {class_label: np.array(intersection).sum() / np.array(union).sum() for class_label, intersection, union in zip(class_labels, per_class_intersection.values(), per_class_union.values())} 84 | miou = np.array([iou for iou in per_class_iou.values() if not np.isnan(iou)]).mean() 85 | 86 | per_class_correct_polygons = {class_label: np.mean(per_class_correct_polygons[class_name]) for class_name, class_label in classes.items()} 87 | per_class_correct50_polygons = {class_label: np.mean(per_class_correct50_polygons[class_name]) for class_name, class_label in classes.items()} 88 | per_class_correct90_polygons = {class_label: np.mean(per_class_correct90_polygons[class_name]) for class_name, class_label in classes.items()} 89 | 90 | # Confusion Matrix 91 | # Sorted class names by class label 92 | display_labels = [class_name for class_name,class_label in sorted(classes.items(), key=lambda x: x[1])] 93 | disp = ConfusionMatrixDisplay(confusion_matrix=pixels_confusion_matrix/1000, display_labels=display_labels) 94 | fig, ax = plt.subplots(figsize=(15, 15)) 95 | disp.plot(cmap='viridis', ax=ax, values_format='.0f', xticks_rotation=90) 96 | plt.savefig('pixel_confusion.png') 97 | plt.close() 98 | log_dict["eval/pixel_confusion_matrix"] = wandb.Image('pixel_confusion.png') 99 | 100 | disp = ConfusionMatrixDisplay(confusion_matrix=polygons_confusion_matrix, display_labels=display_labels) 101 | fig, ax = plt.subplots(figsize=(15, 15)) 102 | disp.plot(cmap='viridis', ax=ax, values_format='.0f',xticks_rotation=90) 103 | plt.savefig('polygon_confusion.png') 104 | plt.close() 105 | log_dict["eval/polygon_confusion_matrix"] = wandb.Image('polygon_confusion.png') 106 | 107 | log_dict["eval/general/total_accuracy"] = total_accuracy 108 | log_dict["eval/general/miou"] = miou 109 | log_dict["eval/general/total_time"]= total_time 110 | log_dict['eval/general/per_polygon_accuracy'] = np.mean([acc for acc in per_class_correct_polygons.values() if not np.isnan(acc)]) 111 | 112 | for class_name, class_label in classes.items(): 113 | log_dict[f"eval/per_class/per_class_accuracy/{class_name.replace('/','_')}"] = per_class_accuracy[class_label] 114 | log_dict[f"eval/per_class/per_class_iou/{class_name.replace('/','_')}"] = per_class_iou[class_label] 115 | 
log_dict[f"eval/per_class/per_class_correct_polygons/{class_name.replace('/','_')}"] = per_class_correct_polygons[class_label] 116 | log_dict[f"eval/per_class/per_class_polygon_accuracy50/{class_name.replace('/','_')}"] = per_class_correct50_polygons[class_label] 117 | log_dict[f"eval/per_class/per_class_polygon_accuracy90/{class_name.replace('/','_')}"] = per_class_correct90_polygons[class_label] 118 | 119 | logger.log(log_dict, step=epoch) 120 | 121 | def launch_experiment(experiment, name): 122 | 123 | logger = wandb.init( 124 | project="coral-segmentation", 125 | name=name, 126 | ) 127 | 128 | for epoch in range(experiment.epochs): 129 | experiment.train_epoch(epoch, logger) 130 | 131 | if epoch % 10 == 0 or epoch == experiment.epochs-1: 132 | evaluate(experiment, logger, epoch) 133 | torch.save(experiment.model.state_dict(), f"{name}.pth") 134 | 135 | if __name__ == '__main__': 136 | args = parser.parse_args() 137 | classes = json.loads(open(args.base_dir+"/classes.json").read()) 138 | print(classes) 139 | classes_inverse = {v: k for k, v in classes.items()} 140 | class_labels = list(classes.values()) 141 | class_names = list(classes.keys()) 142 | colors = json.loads(open(args.base_dir+"/colors.json").read()) 143 | 144 | train_images, test_images, train_labels, test_labels, test_polygon_files, counts = utils.load_files(args.base_dir, test_splits=["gabi_split_test"], ignore_splits=[]) 145 | print("Number of training images:", len(train_images) ," number of test images ", len(test_images)) 146 | 147 | 148 | launch_experiment( 149 | experiment=Experiment(classes, train_images, train_labels, counts, classes, colors, output_size=(352*2, 608*2)), 150 | name=args.name 151 | ) 152 | -------------------------------------------------------------------------------- /src/sfm/train_sfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import SequenceDataset 3 | import custom_transforms 4 | from tqdm import tqdm 5 | from model import SfMModel 6 | from loss_functions import get_all_loss_fn, l2_pose_regularization 7 | import argparse 8 | import wandb 9 | import numpy as np 10 | import torch.nn as nn 11 | from utils import tensor2array, change_bn_momentum 12 | import torchvision 13 | 14 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 15 | parser = argparse.ArgumentParser(description='Structure from Motion Learner training on video sequences', 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | 18 | parser.add_argument('--data', type=str, default='data/sequences', help='path to dataset') 19 | parser.add_argument('--batch_size', type=int, default=4, help='input batch size') 20 | parser.add_argument('--num_workers', type=int, default=4, help='number of workers to load data') 21 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train') 22 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 23 | parser.add_argument('--long_sequence_length', type=int, default=2, help='length of long sequences') 24 | parser.add_argument('--subsampled_sequence_length', type=int, default=2, help='length of subsampled sequences') 25 | parser.add_argument('--with_replacement', action='store_true', default=False, help='use replacement when subsampling') 26 | parser.add_argument('--with_ssim', action='store_true', help='use SSIM loss') 27 | parser.add_argument('--with_mask', action='store_true', help='use mask loss') 28 | 
parser.add_argument('--with_auto_mask', action='store_true', help='use auto mask loss') 29 | parser.add_argument('--padding_mode', type=str, default='zeros', help='padding mode for inverse warp') 30 | parser.add_argument('--neighbor_range', type=int, default=1, help='neighbor range for pairwise loss') 31 | parser.add_argument('--photometric_loss_weight', type=float, default=1.0, help='weight for photometric loss') 32 | parser.add_argument('--geometric_consistency_loss_weight', type=float, default=0.5, help='weight for geometric consistency loss') 33 | parser.add_argument('--smoothness_loss_weight', type=float, default=0.1, help='weight for smoothness loss') 34 | parser.add_argument('--seed', type=int, default=1, help='random seed') 35 | parser.add_argument('--name', type=str, default='default', help='name of the experiment') 36 | parser.add_argument('--checkpoint', type=str, help='path to checkpoint to continue training from ') 37 | parser.add_argument('--l2_pose_reg_weight', type=float, default=0.0, help='weight for l2 pose regularization loss') 38 | parser.add_argument('--accumulate_steps', type=int, default=1, help='number of steps to accumulate gradients over') 39 | parser.add_argument('--intrinsics_file', default='example_inputs/intrinsics.json', help='path to intrinsics file') 40 | 41 | def main(): 42 | args = parser.parse_args() 43 | torch.manual_seed(args.seed) 44 | np.random.seed(args.seed) 45 | 46 | run_wandb = wandb.init( 47 | project="deepreefmap", 48 | name=args.name, 49 | config=vars(args) 50 | ) 51 | 52 | compute_loss = get_all_loss_fn( 53 | args.neighbor_range, 54 | args.subsampled_sequence_length, 55 | args.photometric_loss_weight, 56 | args.geometric_consistency_loss_weight, 57 | args.smoothness_loss_weight, 58 | args.with_ssim, 59 | args.with_mask, 60 | args.with_auto_mask, 61 | args.padding_mode, 62 | return_reprojections=False 63 | ) 64 | train_transform = custom_transforms.Compose([ 65 | custom_transforms.ColorJitter(0.10, 0.10, 0.10, 0.05), 66 | custom_transforms.RandomHorizontalFlip(), 67 | ]) 68 | individual_transform = torchvision.transforms.ColorJitter(0.07, 0.07, 0.07, 0.03) 69 | train_dataset = SequenceDataset(args.data, train=True, transform=train_transform, individual_transform=individual_transform, 70 | seed=args.seed, long_sequence_length=args.long_sequence_length, subsampled_sequence_length=args.subsampled_sequence_length, with_replacement=args.with_replacement) 71 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True) 72 | 73 | val_dataset = SequenceDataset(args.data, train=False, transform=None, seed=args.seed, long_sequence_length=args.subsampled_sequence_length, subsampled_sequence_length=args.subsampled_sequence_length, with_replacement=False) 74 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) 75 | 76 | intrinsics = intrinsics.to(device) 77 | 78 | model = SfMModel(args).to(device) 79 | if args.checkpoint is not None: 80 | model.load_state_dict(torch.load(args.checkpoint)) 81 | 82 | change_bn_momentum(model, 0.01) 83 | 84 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 85 | best_loss = 100000 86 | 87 | for epoch in range(args.epochs): 88 | 89 | model.train() 90 | 91 | for batch_idx, (images, images_individually_jittered, _) in tqdm(enumerate(train_loader)): 92 | 93 | images_individually_jittered = [img.to(device) for 
img in images_individually_jittered] 94 | 95 | intrinsics = (model.intrinsics.unsqueeze(0)*model.const_mul + model.const_add).repeat(args.batch_size, 1).detach() 96 | updated_intrinsics = intrinsics 97 | 98 | depths, poses = model(images_individually_jittered, intrinsics) 99 | del images_individually_jittered 100 | images = [img.to(device) for img in images] 101 | 102 | photometric_loss, geometric_consistency_loss, smoothness_loss = compute_loss(images, depths, poses, updated_intrinsics) 103 | 104 | loss = photometric_loss + geometric_consistency_loss + smoothness_loss 105 | 106 | l2loss = args.l2_pose_reg_weight * l2_pose_regularization(poses) 107 | loss += l2loss 108 | (loss / args.accumulate_steps).backward() 109 | 110 | 111 | if batch_idx % args.accumulate_steps == 0: 112 | optimizer.step() 113 | optimizer.zero_grad() 114 | 115 | run_wandb.log({ 116 | "train/photometric_loss": photometric_loss.item(), 117 | "train/geometric_consistency_loss": geometric_consistency_loss.item(), 118 | "train/smoothness_loss": smoothness_loss.item(), 119 | "train/total_loss": loss.item(), 120 | }) 121 | 122 | if batch_idx % 1000 == 0: 123 | log_images = [] 124 | log_images.append(wandb.Image( 125 | np.swapaxes(tensor2array(images[0][0]).T, 0, 1), 126 | caption="val Input [0][0]")) 127 | log_images.append(wandb.Image( 128 | np.swapaxes(tensor2array(depths[0][0], max_value=None, colormap='magma').T, 0, 1), 129 | caption="val Dispnet Output Normalized [0][0]")) 130 | log_images.append(wandb.Image( 131 | np.swapaxes(tensor2array(1/ (depths[0][0]), max_value=10, colormap='magma').T, 0, 1), 132 | caption="val Depth Output [0][0]")) 133 | 134 | log_images.append(wandb.Image( 135 | np.swapaxes(tensor2array(images[1][0]).T, 0, 1), 136 | caption="val Input [1][0]")) 137 | log_images.append(wandb.Image( 138 | np.swapaxes(tensor2array(depths[1][0], max_value=None, colormap='magma').T, 0, 1), 139 | caption="val Dispnet Output Normalized [1][0]")) 140 | log_images.append(wandb.Image( 141 | np.swapaxes(tensor2array(1/ (depths[1][0]), max_value=10, colormap='magma').T, 0, 1), 142 | caption="val Depth Output [1][0]")) 143 | 144 | log_images.append(wandb.Image( 145 | np.swapaxes(tensor2array(images[0][1]).T, 0, 1), 146 | caption="val Input [0][1]")) 147 | log_images.append(wandb.Image( 148 | np.swapaxes(tensor2array(depths[0][1], max_value=None, colormap='magma').T, 0, 1), 149 | caption="val Dispnet Output Normalized [0][1]")) 150 | log_images.append(wandb.Image( 151 | np.swapaxes(tensor2array(1/ (depths[0][1]), max_value=10, colormap='magma').T, 0, 1), 152 | caption="val Depth Output [0][1]")) 153 | 154 | log_images.append(wandb.Image( 155 | np.swapaxes(tensor2array(images[1][1]).T, 0, 1), 156 | caption="val Input [1][1]")) 157 | log_images.append(wandb.Image( 158 | np.swapaxes(tensor2array(depths[1][1], max_value=None, colormap='magma').T, 0, 1), 159 | caption="val Dispnet Output Normalized [1][1]")) 160 | log_images.append(wandb.Image( 161 | np.swapaxes(tensor2array(1/ (depths[1][1]), max_value=10, colormap='magma').T, 0, 1), 162 | caption="val Depth Output [1][1]")) 163 | 164 | run_wandb.log({f"Train {batch_idx}": log_images, "epoch": epoch}) 165 | import matplotlib.pyplot as plt 166 | plt.imsave(f"train_{batch_idx}.png", np.swapaxes(tensor2array(1/depths[0][0], max_value=10, colormap='magma').T, 0, 1)) 167 | if batch_idx == 50000: 168 | break 169 | 170 | model.eval() 171 | with torch.no_grad(): 172 | val_photometric_loss = [] 173 | val_geometric_consistency_loss = [] 174 | val_smoothness_loss = [] 175 | for batch_idx, 
(images, images_individually_jittered, _) in tqdm(enumerate(val_loader)): 176 | intrinsics = (model.intrinsics.unsqueeze(0)*model.const_mul + model.const_add).repeat(images[0].shape[0], 1) 177 | 178 | images = [img.to(device) for img in images] 179 | intrinsics = intrinsics.to(device) 180 | updated_intrinsics = intrinsics 181 | depths, poses, _ = model(images, intrinsics) 182 | 183 | photometric_loss, geometric_consistency_loss, smoothness_loss = compute_loss(images, depths, poses, updated_intrinsics) 184 | val_photometric_loss.append(photometric_loss.item()) 185 | val_geometric_consistency_loss.append(geometric_consistency_loss.item()) 186 | val_smoothness_loss.append(smoothness_loss.item()) 187 | 188 | 189 | val_loss = np.mean(np.array(val_photometric_loss)+np.array(val_geometric_consistency_loss) + np.array(val_smoothness_loss)) 190 | run_wandb.log({ 191 | "val/photometric_loss": np.mean(val_photometric_loss), 192 | "val/geometric_consistency_loss": np.mean(val_geometric_consistency_loss), 193 | "val/smoothness_loss": np.mean(val_smoothness_loss), 194 | "val/total_loss": val_loss, 195 | }) 196 | torch.save(model.state_dict(), args.name + "_last.pth") 197 | if val_loss < best_loss: 198 | print("BEST!") 199 | best_loss = val_loss 200 | torch.save(model.state_dict(), args.name + "_best.pth") 201 | if epoch % 10 == 0 and epoch > 0: 202 | torch.save(model.state_dict(), args.name + "_" + str(epoch) + ".pth") 203 | if __name__ == '__main__': 204 | main() -------------------------------------------------------------------------------- /src/reconstruct.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | import torchvision.transforms.functional as F 7 | from PIL import Image 8 | import pandas as pd 9 | import torch.nn as nn 10 | from time import time 11 | import segmentation_models_pytorch as smp 12 | import segmentation 13 | from sfm.model import SfMModel 14 | from segmentation.model import SegmentationModel 15 | from video_utils import extract_frames_and_gopro_gravity_vector, render_video 16 | from tqdm import tqdm 17 | import h5py 18 | import open3d as o3d 19 | from sklearn.decomposition import PCA 20 | from reconstruction_utils import get_closest_to_centroid_with_attributes_of_closest_to_cam, map_3d, get_matching_indices, get_rotation_matrix_to_align_pose_with_gravity, get_edgeness, aggregate_2d_grid 21 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 22 | from sfm.inverse_warp import EUCMCamera, Pose, pose_vec2mat, rectify_eucm 23 | import scipy 24 | import json 25 | 26 | import sys 27 | sys.path.append('/home/jonathan/mee-deepreefmap/src/ffmpeg-7.0.2-amd64-static/ffmpeg') 28 | 29 | # Start by parsing args 30 | parser = argparse.ArgumentParser(description='Reconstruct a 3D model from a video') 31 | parser.add_argument('--input_video', type=str, help='Path to video file - can be multiple files, in which case the paths should be comma separated') 32 | parser.add_argument('--out_dir', type=str, default='out', help='Path to output directory - will be created if does not exist') 33 | parser.add_argument('--tmp_dir', type=str, default='tmp', help='Path to temporary directory - will be created if does not exist') 34 | parser.add_argument('--timestamp', type=str, help='Begin and End timestamp of the transect. In case multiple videos are supplied, the format should be comma separated e.g. 
of the form "0:23-end,begin-1:44"') 35 | parser.add_argument('--sfm_checkpoint', type=str, default='../sfm_net.pth', help='Path to the sfm_net checkpoint') 36 | parser.add_argument('--segmentation_checkpoint', type=str, default='../segmentation_net.pth', help='Path to the segmentation_net checkpoint') 37 | parser.add_argument('--height', type=int, default=384, help='Height in pixels to which input video is scaled') 38 | parser.add_argument('--width', type=int, default=640, help='Width in pixels to which input video is scaled') 39 | parser.add_argument('--seg_height', type=int, default=384*2, help='Height in pixels to which input video is scaled') 40 | parser.add_argument('--seg_width', type=int, default=640*2, help='Width in pixels to which input video is scaled') 41 | parser.add_argument('--fps', type=int, default=8, help='FPS of the input video') 42 | parser.add_argument('--reverse', action='store_true', help='Whether the transect video is filmed backwards (with a back-facing camera)') 43 | parser.add_argument('--number_of_points_per_image', type=int, default=2000, help='Number of points to sample from each image') 44 | parser.add_argument('--frames_per_volume', type=int, default=500, help='Number of frames per TSDF Volume') 45 | parser.add_argument('--tsdf_overlap', type=int, default=100, help='Overlap in frames over TSDF Volumes') 46 | parser.add_argument('--distance_thresh', type=float, default=0.2, help='Distance threshold for points added to cloud') 47 | parser.add_argument('--ignore_classes_in_point_cloud', type=str, default="background,fish,human", help='Classes to ignore when adding points to cloud') 48 | parser.add_argument('--ignore_classes_in_benthic_cover', type=str, default="background,fish,human,transect tools,transect line,dark", help='Classes to ignore when calculating benthic cover percentages') 49 | parser.add_argument('--intrinsics_file', type=str, default="../example_inputs/intrinsics_eucm.json", help='Path to intrinsics file') 50 | parser.add_argument('--class_to_label_file', type=str, default="../example_inputs/class_to_label.json", help='Path to label_to_class_file') 51 | parser.add_argument('--class_to_color_file', type=str, default="../example_inputs/class_to_color.json", help='Path to class_to_color_file') 52 | parser.add_argument('--output_2d_grid_size', type=int, default=2000, help='Size of the 2D grid used for benthic cover analysis - a higher grid size will produce higher resolution outputs but takes longer to compute and may have empty grid cells') 53 | parser.add_argument('--buffer_size', type=int, default=2, help='Number of frames to use for temporal smoothing') 54 | parser.add_argument('--render_video', action='store_true', help='Whether to render output 4-panel video') 55 | args = parser.parse_args() 56 | 57 | def main(args): 58 | 59 | t = time() 60 | 61 | with open(args.class_to_color_file) as f: 62 | class_to_color = {k: (np.array(v)).astype(np.uint8) for k,v in json.load(f).items()} 63 | with open(args.class_to_label_file) as f: 64 | class_to_label = json.load(f) 65 | label_to_class = {v:k for k,v in class_to_label.items()} 66 | label_to_color = {k: class_to_color[v] for k,v in label_to_class.items()} 67 | 68 | 69 | grav = extract_frames_and_gopro_gravity_vector( 70 | args.input_video.split(","), 71 | args.timestamp.split(","), 72 | args.seg_width, 73 | args.seg_height, 74 | args.fps, 75 | args.tmp_dir, 76 | args.reverse, 77 | ) 78 | print("Extracted Frames And Gravity Vector in", time() - t, "seconds") 79 | 80 | 81 | h5f = h5py.File(args.tmp_dir + 
'/tmp.hdf5', 'w') 82 | 83 | img_list = [args.tmp_dir + "/rgb/" +file for file in sorted(os.listdir(args.tmp_dir + "/rgb")) if "jpg" in file] 84 | print("Running Neural Networks ...") 85 | 86 | depths, depth_uncertainties, poses, semantic_segmentation, intrinsics = get_nn_predictions( 87 | img_list, 88 | grav, 89 | len(class_to_label) + 1, 90 | h5f, 91 | args, 92 | ) 93 | 94 | print("Ran NN Predictions in ", time() - t, "seconds") 95 | print("Building Point Cloud ...") 96 | os.makedirs(args.out_dir + "/videos", exist_ok=True) 97 | xyz_index_arr, distance2cam_arr, seg_arr, frame_index_arr, depth_unc_arr, keep_masks, dist_cutoffs = get_point_cloud( 98 | img_list, 99 | depths, 100 | poses, 101 | depth_uncertainties, 102 | semantic_segmentation, 103 | intrinsics, 104 | label_to_color, 105 | class_to_label, 106 | h5f, 107 | args 108 | ) 109 | 110 | print("Integrating TSDF!") 111 | tsdf_xyz, tsdf_rgb = tsdf_point_cloud(img_list, depths, keep_masks, poses, intrinsics, np.mean(depths), args.frames_per_volume, args.tsdf_overlap, dist_cutoffs) 112 | print("Integrated TSDF Point Cloud in ", time() - t, "seconds") 113 | 114 | idx = get_matching_indices(tsdf_xyz, xyz_index_arr) 115 | print("Matched TSDF to Point Cloud in ", time() - t, "seconds") 116 | rgb_seg_arr = np.vectorize(lambda k: label_to_color[k], signature='()->(n)')(seg_arr[idx]) 117 | tsdf_pc = pd.DataFrame({ 118 | 'x':tsdf_xyz[:,0], 119 | 'y':tsdf_xyz[:,1], 120 | 'z':tsdf_xyz[:,2], 121 | 'r':tsdf_rgb[:,0], 122 | 'g':tsdf_rgb[:,1], 123 | 'b':tsdf_rgb[:,2], 124 | 'distance_to_cam': distance2cam_arr[idx], 125 | 'class': seg_arr[idx], 126 | 'class_r': rgb_seg_arr[:,0], 127 | 'class_g': rgb_seg_arr[:,1], 128 | 'class_b': rgb_seg_arr[:,2], 129 | 'frame_index': frame_index_arr[idx], 130 | 'depth_uncertainty': depth_unc_arr[idx], 131 | }) 132 | tsdf_pc.to_csv(args.out_dir + "/point_cloud_tsdf.csv", index=False) 133 | print("Saved TSDF Point Cloud in ", time() - t, "seconds") 134 | 135 | 136 | print("Starting Benthic Cover Analsysis after ", time() - t, "seconds") 137 | results, percentage_covers = benthic_cover_analysis(tsdf_pc, label_to_class, args.ignore_classes_in_benthic_cover.split(","), bins=args.output_2d_grid_size) 138 | np.save(args.out_dir + "/results.npy", results) 139 | json.dump(percentage_covers, open(args.out_dir + "/percentage_covers.json", "w")) 140 | print("Finished Benthic Cover Analysis in ", time() - t, "seconds") 141 | 142 | os.system("cp "+args.class_to_color_file+" "+ args.out_dir) 143 | if args.render_video: 144 | os.system("cp "+args.tmp_dir+"/*_.mp4 "+ args.out_dir + "/videos") 145 | render_video(img_list, depths, semantic_segmentation, results, args.fps, class_to_label, label_to_color, args.tmp_dir, args.reverse) 146 | os.system("mv " + args.tmp_dir + "/out.mp4 " + args.out_dir + "/videos") 147 | print("Rendered Video in ", time() - t, "seconds") 148 | return 149 | 150 | 151 | def reset_batchnorm_layers(model): 152 | for module in model.modules(): 153 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 154 | module.eps = 1e-4 155 | 156 | 157 | def change_bn_momentum(model, new_value): 158 | for name, module in model.named_modules(): 159 | if isinstance(module, nn.BatchNorm2d): 160 | module.momentum = new_value 161 | 162 | 163 | def expand_zeros(mask): 164 | # Add an extra batch dimension and channel dimension to the mask for convolution 165 | mask = mask.unsqueeze(0).unsqueeze(0).float() # Shape: 1x1xHxW 166 | 167 | # Define a 3x3 kernel filled with ones 168 | kernel = torch.ones((1, 1, 3, 3), 
dtype=torch.float32, device=mask.device) 169 | 170 | # Perform 2D convolution with padding=1 to keep the same output size 171 | conv_result = torch.nn.functional.conv2d(mask, kernel, padding=1) 172 | 173 | # Any place where the convolution result is less than 9 means it had a zero in the neighborhood 174 | result_mask = (conv_result == 9).squeeze().bool() 175 | 176 | return result_mask 177 | 178 | def get_nn_predictions(img_list, grav, num_classes, h5f, args): 179 | totensor = torchvision.transforms.ToTensor() 180 | normalize = torchvision.transforms.Normalize(mean=[0.45, 0.45, 0.45], 181 | std=[0.225, 0.225, 0.225]) 182 | sfm_model = SfMModel().to(device) 183 | sfm_model.load_state_dict(torch.load(args.sfm_checkpoint, map_location=device)) 184 | change_bn_momentum(sfm_model, 0.01) 185 | reset_batchnorm_layers(sfm_model) 186 | sfm_model.eval() 187 | 188 | segmentation_model = SegmentationModel(num_classes).to(device) 189 | segmentation_model.load_state_dict(torch.load(args.segmentation_checkpoint, map_location=device)) 190 | segmentation_model.eval() 191 | 192 | intrinsics = torch.tensor(list(json.load(open(args.intrinsics_file)).values())).float().to(device).unsqueeze(0) 193 | 194 | buffer_size = args.buffer_size 195 | # Start by initializing, load the images in a buffer of buffer_size subsequent frames 196 | images = [normalize(totensor(Image.open(img_list[i]))).to(device).unsqueeze(0) for i in range(buffer_size-1)] 197 | 198 | counts = np.zeros(len(img_list), dtype=np.uint8) 199 | depths_buffered = h5f.create_dataset("depths_buffered", (args.buffer_size, len(img_list), args.height, args.width), dtype='f4') 200 | depths = h5f.create_dataset("depths", (len(img_list), args.height, args.width), dtype='f4') 201 | depth_uncertainties = h5f.create_dataset("depth_uncertainties", (len(img_list), args.height, args.width), dtype='f4') 202 | semantic_segmentation = h5f.create_dataset("semantic_segmentation", (len(img_list), args.height, args.width), dtype='u1') 203 | intrinsics_predicted_buffered = h5f.create_dataset("intrinsics_buffered", (args.buffer_size, len(img_list), 6), dtype='f4') 204 | intrinsics_predicted = h5f.create_dataset("intrinsics", (len(img_list), 6), dtype='f4') 205 | poses = {} 206 | 207 | semseg_buffer = torch.zeros((3, num_classes, args.height, args.width), requires_grad=False).to(device) 208 | wtens = torch.tensor([1.0, 2.0, 1.0], requires_grad=False).to(device).unsqueeze(1).unsqueeze(1).unsqueeze(1) 209 | with torch.no_grad(): 210 | semseg_logits = [] 211 | for i in range(buffer_size-1): 212 | semseg_logits.append(segmentation.model.predict(segmentation_model, images[i], num_classes, args.height, args.width)) 213 | 214 | 215 | for i in range(buffer_size-2): 216 | semantic_segmentation[i] = torch.stack(semseg_logits[max(0,i-1):i+1]).mean(dim=0).argmax(dim=0).cpu().numpy() 217 | 218 | if len(semseg_logits)==1: 219 | semseg_logits.append(semseg_logits[-1]) 220 | semseg_buffer[0] = semseg_logits[-2] 221 | semseg_buffer[1] = semseg_logits[-1] 222 | del semseg_logits 223 | images = [F.resize(x, (args.height, args.width)) for x in images] 224 | depth_features = [[f.detach() for f in sfm_model.extract_features(x)] for x in images] 225 | 226 | 227 | for end_index in tqdm(range(buffer_size-1, len(img_list))): 228 | new_im = normalize(totensor(Image.open(img_list[end_index]))).to(device).unsqueeze(0) 229 | 230 | 231 | with torch.no_grad(): 232 | semseg_buffer[2] = segmentation.model.predict(segmentation_model, new_im, num_classes, args.height, args.width) 233 | 
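# The line below labels the frame sitting in the centre of the segmentation buffer:
# semseg_buffer holds the logits of three consecutive frames and wtens = [1, 2, 1]
# gives the centre frame double weight, i.e. a small triangular temporal smoothing
# is applied before the per-pixel argmax. Minimal illustrative sketch of the same
# operation (shapes are placeholders):
#   logits = torch.randn(3, num_classes, args.height, args.width)
#   w = torch.tensor([1.0, 2.0, 1.0]).view(3, 1, 1, 1)
#   labels = (logits * w).mean(dim=0).argmax(dim=0)   # [H, W] label map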
semantic_segmentation[end_index-1] = (semseg_buffer*wtens).mean(dim=0).argmax(dim=0).cpu().numpy() 234 | semseg_buffer[:2] = semseg_buffer[1:].clone() 235 | 236 | images.append(F.resize(new_im, (args.height, args.width))) 237 | 238 | with torch.no_grad(): 239 | depth_features.append([f.detach() for f in sfm_model.extract_features(images[-1])]) 240 | 241 | depth, pose, intrinsics_updated = sfm_model.get_depth_and_poses_from_features(images, depth_features, intrinsics) 242 | for i in range(buffer_size): 243 | idx = end_index - buffer_size + i + 1 244 | count = counts[idx] 245 | depths_buffered[count, idx] = depth[i].squeeze().detach().cpu().numpy() 246 | intrinsics_predicted_buffered[count, idx] = intrinsics_updated[i].detach().cpu().numpy() 247 | counts[idx]+=1 248 | for j in range(buffer_size): 249 | jdx = end_index - buffer_size + j + 1 250 | if pose[i][j] != []: 251 | if idx not in poses: 252 | poses[idx] = {} 253 | if jdx not in poses[idx]: 254 | poses[idx][jdx] = [] 255 | poses[idx][jdx].append(pose[i][j].detach().unsqueeze(0).cpu().numpy()) 256 | 257 | 258 | images.pop(0) 259 | depth_features.pop(0) 260 | 261 | for i in tqdm(range(len(img_list))): 262 | depths[i] = np.mean(depths_buffered[:counts[i],i],axis=0) 263 | for i in tqdm(range(len(img_list))): 264 | depth_uncertainties[i] = np.std(depths_buffered[:counts[i],i],axis=0) 265 | for i in tqdm(range(len(img_list))): 266 | intrinsics_predicted[i] = np.median(intrinsics_predicted_buffered[:counts[i],i],axis=0) 267 | 268 | l = len(img_list) 269 | 270 | semantic_segmentation[-1] = (semseg_buffer * wtens).mean(dim=0).argmax(dim=0).cpu().numpy() 271 | 272 | poses = [(torch.tensor(np.median(poses[i+1][i],axis=0) - np.median(poses[i][i+1],axis=0))/2) for i in range(len(poses)-1)] 273 | poses = [pose_vec2mat(p).squeeze().cpu().numpy() for p in poses] 274 | poses = [np.vstack([p, np.array([0, 0, 0, 1]).reshape(1, 4)]) for p in poses] 275 | 276 | poses = np.array(poses) 277 | med_rot = np.median(poses[:,:3,:3]-np.eye(3), axis=0) 278 | poses[:,:3,:3] -= med_rot 279 | 280 | grav_buffer = 100 281 | if grav is not None: 282 | pose0 = np.eye(4) 283 | new_cum_poses = np.zeros((len(poses)+1,4,4)) 284 | 285 | 286 | grav0 = np.mean(grav[:grav_buffer],axis=0) 287 | 288 | correction = get_rotation_matrix_to_align_pose_with_gravity(pose0, grav0) 289 | pose0[:3,:3] = correction @ pose0[:3,:3] 290 | new_cum_poses[0] = pose0.copy() 291 | 292 | for i, (pose, g_) in enumerate(zip(poses,grav[1:])): 293 | 294 | g = np.mean(grav[max(0, 1 + i - grav_buffer):min(i + grav_buffer, len(grav-1))], axis=0) 295 | pose0 = pose0 @ pose 296 | correction = get_rotation_matrix_to_align_pose_with_gravity(pose0, g) 297 | pose0[:3,:3] = correction @ pose0[:3,:3] 298 | new_cum_poses[i+1] = pose0.copy() 299 | poses = np.array(new_cum_poses) 300 | else: 301 | new_cum_poses = np.zeros((len(poses)+1,4,4)) 302 | new_cum_poses[0] = np.eye(4) 303 | for i in range(len(poses)): 304 | new_cum_poses[i+1] = new_cum_poses[i] @ poses[i] 305 | poses = np.array(new_cum_poses) 306 | 307 | return depths, depth_uncertainties, poses, semantic_segmentation, intrinsics_predicted 308 | 309 | 310 | def get_point_cloud(image_list, depths, poses, depth_uncertainties, semantic_segmentation, intrinsics, label_to_color, class_to_label, h5f, args): 311 | ignore_classes = args.ignore_classes_in_point_cloud.split(",") 312 | 313 | def class_to_color(class_arr): 314 | color_arr = np.zeros((class_arr.shape[0], 3), dtype=np.uint8) 315 | for val in np.unique(class_arr): 316 | color_arr[class_arr==val] = 
label_to_color[val] 317 | return color_arr 318 | dist_cutoffs = [] 319 | with torch.no_grad(): 320 | 321 | 322 | xyz_arr = h5f.create_dataset("xyz_arr", (len(image_list)*args.number_of_points_per_image, 3), dtype='f4') 323 | distance2cam_arr = h5f.create_dataset("distance2cam_arr", (len(image_list)*args.number_of_points_per_image), dtype='f4') 324 | seg_arr = h5f.create_dataset("seg_arr", (len(image_list)*args.number_of_points_per_image), dtype='u1') 325 | keep_masks = h5f.create_dataset("keep_masks", (len(image_list), args.height, args.width), dtype='u1') 326 | depth_unc_arr = h5f.create_dataset("depth_unc_arr", (len(image_list)*args.number_of_points_per_image), dtype='f4') 327 | frame_index_arr = h5f.create_dataset("frame_index_arr", (len(image_list)*args.number_of_points_per_image), dtype='u2') 328 | 329 | cursor = 0 330 | 331 | for i in tqdm(range(len(poses))): 332 | 333 | pose = torch.tensor((poses[i])[:3]).float().to(device) 334 | 335 | 336 | cam = EUCMCamera(torch.tensor(intrinsics[i]).unsqueeze(0).to(device), Tcw=Pose(T=1)) 337 | depth_i_tensor = torch.tensor(depths[i]).to(device) 338 | coords = cam.reconstruct_depth_map(depth_i_tensor.unsqueeze(0).unsqueeze(0).to(device)).squeeze() 339 | coords = coords.reshape(3, -1) 340 | coords = (pose @ torch.cat([coords, torch.ones_like(coords[:1])], dim=0).reshape(4,-1)).T.cpu() 341 | 342 | 343 | dist_cutoffs.append(args.distance_thresh) 344 | keep_mask = depth_i_tensor.squeeze() < args.distance_thresh 345 | seg = torch.tensor(semantic_segmentation[i]).to(device) 346 | 347 | #Exclude points on the 'edge' of objects, i.e. where the countours depth map varies a lot 348 | # TODO Magic Numbers 349 | keep_mask = torch.logical_and(get_edgeness(depth_i_tensor) < 0.04, keep_mask) 350 | keep_mask[30:170,30:-30] = 0 351 | for class_name in ignore_classes: 352 | keep_mask = torch.logical_and(seg != class_to_label[class_name], keep_mask) 353 | keep_mask = expand_zeros(keep_mask) 354 | keep_mask = keep_mask.cpu().numpy() 355 | keep_masks[i] = keep_mask.astype(np.uint8) 356 | keep_mask = keep_mask.reshape(-1) 357 | valid_points = keep_mask.sum().item() 358 | random_selection = np.random.permutation(valid_points)[:args.number_of_points_per_image] 359 | offset = min(valid_points, args.number_of_points_per_image) 360 | 361 | xyz_arr[cursor:cursor+offset]=coords[keep_mask][random_selection] 362 | distance2cam_arr[cursor:cursor+offset]= depths[i].reshape(-1)[keep_mask][random_selection] 363 | 364 | 365 | seg_arr[cursor:cursor+offset] = semantic_segmentation[i].reshape(-1).astype(np.uint8)[keep_mask][random_selection] 366 | dunc = depth_uncertainties[i].reshape(-1)[keep_mask][random_selection] 367 | depth_unc_arr[cursor:cursor+offset]=dunc 368 | frame_index_arr[cursor:cursor+offset]= np.zeros_like(dunc, dtype=np.uint16)+i 369 | 370 | cursor += offset 371 | 372 | 373 | 374 | print("Filtering redundant points") 375 | xyz_index_arr = map_3d(np.concatenate([ 376 | xyz_arr[:cursor], 377 | distance2cam_arr[:cursor].reshape(-1,1), 378 | np.arange(len(xyz_arr)).reshape(-1, 1)[:cursor]], axis=1), get_closest_to_centroid_with_attributes_of_closest_to_cam, 0.003) 379 | filtered_indices = xyz_index_arr[:, -1].astype(np.uint32) 380 | 381 | return xyz_index_arr[:,:3], distance2cam_arr[:cursor][filtered_indices], seg_arr[:cursor][filtered_indices], frame_index_arr[:cursor][filtered_indices], depth_unc_arr[:cursor][filtered_indices], keep_masks, dist_cutoffs 382 | 383 | 384 | def tsdf_point_cloud(img_list, depths, masks, poses, intrinsics, cutoff, frames_per_volume, 
tsdf_overlap,dist_cutoffs): 385 | # TODO Magic Numbers 386 | xyz = [] 387 | rgb = [] 388 | 389 | 390 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 391 | voxel_length=0.3 / 512.0, 392 | sdf_trunc=0.035, 393 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8) 394 | totensor = torchvision.transforms.ToTensor() 395 | 396 | mask_out_background = np.ones_like(masks[0].astype(np.float32)) 397 | intrinsics = torch.tensor(intrinsics).float() 398 | 399 | #TODO: Magic numbers 400 | mask_out_background[:170,80:-80] *= 0 401 | for i in tqdm(range(len(poses))): 402 | 403 | if i > len(poses)-10: 404 | mask_out_background = np.ones_like(masks[0]) 405 | 406 | # Rectify to linear intrinsics 407 | projected_img, projected_mask, projected_depth = rectify_eucm( 408 | totensor(Image.open(img_list[i])).unsqueeze(0), 409 | torch.tensor(masks[i].astype(np.float32)*mask_out_background).unsqueeze(0).unsqueeze(0).float(), 410 | torch.tensor(depths[i]).unsqueeze(0).unsqueeze(0), 411 | intrinsics[i] 412 | ) 413 | 414 | depth = o3d.geometry.Image(projected_depth*projected_mask) 415 | color = o3d.geometry.Image(np.ascontiguousarray(projected_img.transpose(1, 2, 0)*255.).astype(np.uint8)) 416 | 417 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 418 | color, depth, depth_trunc=dist_cutoffs[i], convert_rgb_to_intensity=False, depth_scale=1) 419 | 420 | volume.integrate( 421 | rgbd, 422 | o3d.camera.PinholeCameraIntrinsic( 423 | width= args.width, 424 | height= args.height, 425 | fx=intrinsics[i][0], 426 | fy=intrinsics[i][1], 427 | cx=intrinsics[i][2], 428 | cy=intrinsics[i][3], 429 | ), 430 | np.linalg.inv(poses[i])) 431 | if (i % frames_per_volume) == (frames_per_volume - tsdf_overlap): 432 | volume2 = o3d.pipelines.integration.ScalableTSDFVolume( 433 | voxel_length=0.3 / 512.0, 434 | sdf_trunc=0.015, 435 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 436 | ) 437 | 438 | if i % frames_per_volume >= (frames_per_volume - tsdf_overlap): 439 | volume2.integrate( 440 | rgbd, 441 | o3d.camera.PinholeCameraIntrinsic( 442 | width= args.width, 443 | height= args.height, 444 | fx=intrinsics[i][0], 445 | fy=intrinsics[i][1], 446 | cx=intrinsics[i][2], 447 | cy=intrinsics[i][3], 448 | ), 449 | np.linalg.inv(poses[i])) 450 | if (i % frames_per_volume) == frames_per_volume - 1: 451 | pc = volume.extract_point_cloud() 452 | pc = volume.extract_point_cloud() 453 | xyz.append(np.array(pc.points)) 454 | rgb.append((np.array(pc.colors)*255).astype(np.uint8)) 455 | volume = volume2 456 | pc = volume.extract_point_cloud() 457 | pc = volume.extract_point_cloud() 458 | xyz.append(np.array(pc.points)) 459 | rgb.append((np.array(pc.colors)*255).astype(np.uint8)) 460 | return np.concatenate(xyz), np.concatenate(rgb) 461 | 462 | 463 | def benthic_cover_analysis(pc, label_to_class, ignore_classes_in_benthic_cover, bins=1000): 464 | #step 1: fit PCA 465 | pca = PCA(n_components=2) 466 | pca.fit(pc[['x','y', 'z']].values) 467 | x_axis = pca.components_[0] # Estimated x-axis 468 | y_axis = pca.components_[1] # Estimated y-axis 469 | 470 | # Step 2: Calculate the normal vector to the x-y plane as the z-axis 471 | z_axis = np.cross(x_axis, y_axis) 472 | z_axis /= np.linalg.norm(z_axis) 473 | 474 | # Step 3: Create the transformation matrix 475 | transformation_matrix = np.vstack((x_axis, y_axis, z_axis)).T 476 | 477 | # Now, you can apply this transformation matrix to your point cloud 478 | transformed = np.dot(pc[['x','y', 'z']].values, transformation_matrix) 479 | transformed -= np.min(transformed, 
axis=0) 480 | xmax, ymax, zmax = np.max(transformed, axis=0) 481 | 482 | discretization = xmax / bins 483 | pcarr = np.concatenate([transformed, pc.drop(columns=["x", "y", "z"]).values], axis=1) 484 | out = aggregate_2d_grid(pcarr, size=discretization) 485 | 486 | xcoords = out[:,0].astype(np.int32) 487 | ycoords = out[:,1].astype(np.int32) 488 | 489 | img = np.zeros((xcoords.max()+1, ycoords.max()+1, 12)) 490 | 491 | img[xcoords, ycoords] = out[:,2:] 492 | 493 | percentage_covers = {} 494 | benthic_class = out[:,7].astype(np.uint8) 495 | all_classes = (benthic_class!=0) 496 | for class_label, class_name in label_to_class.items(): 497 | if class_name not in ignore_classes_in_benthic_cover: 498 | percentage_covers[class_name] = (benthic_class==class_label).sum() 499 | else: 500 | all_classes = np.logical_and(all_classes, benthic_class!=class_label) 501 | all_classes = all_classes.sum() 502 | percentage_covers = {k: v / all_classes for k,v in percentage_covers.items()} 503 | return img, percentage_covers 504 | 505 | if __name__ == "__main__": 506 | main(args) -------------------------------------------------------------------------------- /src/sfm/inverse_warp.py: -------------------------------------------------------------------------------- 1 | # TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 2 | 3 | import torch 4 | import numpy as np 5 | from abc import ABC 6 | from functools import partial 7 | # TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 8 | 9 | ## TODO: link to repository! 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | def is_tuple(data): 15 | """Checks if data is a tuple.""" 16 | return isinstance(data, tuple) 17 | def is_seq(data): 18 | """Checks if data is a list or tuple.""" 19 | return is_tuple(data) or is_list(data) 20 | def is_dict(data): 21 | """Checks if data is a dictionary.""" 22 | return isinstance(data, dict) 23 | 24 | def is_list(data): 25 | """Checks if data is a list.""" 26 | return isinstance(data, list) 27 | 28 | def iterate1(func): 29 | """Decorator to iterate over a list (first argument)""" 30 | def inner(var, *args, **kwargs): 31 | if is_seq(var): 32 | return [func(v, *args, **kwargs) for v in var] 33 | elif is_dict(var): 34 | return {key: func(val, *args, **kwargs) for key, val in var.items()} 35 | else: 36 | return func(var, *args, **kwargs) 37 | return inner 38 | def is_int(data): 39 | """Checks if data is an integer.""" 40 | return isinstance(data, int) 41 | 42 | def to_global_pose(pose, zero_origin=False): 43 | """Get global pose coordinates from current and context poses""" 44 | if zero_origin: 45 | pose[0].T[[0]] = torch.eye(4, device=pose[0].device, dtype=pose[0].dtype) 46 | for b in range(1, len(pose[0])): 47 | pose[0].T[[b]] = (pose[0][b] * pose[0][0]).T.float() 48 | for key in pose.keys(): 49 | if key != 0: 50 | pose[key] = pose[key] * pose[0] 51 | return pose 52 | 53 | 54 | # def to_global_pose(pose, zero_origin=False): 55 | # """Get global pose coordinates from current and context poses""" 56 | # if zero_origin: 57 | # pose[(0, 0)].T = torch.eye(4, device=pose[(0, 0)].device, dtype=pose[(0, 0)].dtype). 
\ 58 | # repeat(pose[(0, 0)].shape[0], 1, 1) 59 | # for key in pose.keys(): 60 | # if key[0] == 0 and key[1] != 0: 61 | # pose[key].T = (pose[key] * pose[(0, 0)]).T 62 | # for key in pose.keys(): 63 | # if key[0] != 0: 64 | # pose[key] = pose[key] * pose[(0, 0)] 65 | # return pose 66 | 67 | 68 | def euler2mat(angle): 69 | """Convert euler angles to rotation matrix""" 70 | B = angle.size(0) 71 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] 72 | 73 | cosz = torch.cos(z) 74 | sinz = torch.sin(z) 75 | 76 | zeros = z.detach() * 0 77 | ones = zeros.detach() + 1 78 | zmat = torch.stack([ cosz, -sinz, zeros, 79 | sinz, cosz, zeros, 80 | zeros, zeros, ones], dim=1).view(B, 3, 3) 81 | 82 | cosy = torch.cos(y) 83 | siny = torch.sin(y) 84 | 85 | ymat = torch.stack([ cosy, zeros, siny, 86 | zeros, ones, zeros, 87 | -siny, zeros, cosy], dim=1).view(B, 3, 3) 88 | 89 | cosx = torch.cos(x) 90 | sinx = torch.sin(x) 91 | 92 | xmat = torch.stack([ ones, zeros, zeros, 93 | zeros, cosx, -sinx, 94 | zeros, sinx, cosx], dim=1).view(B, 3, 3) 95 | 96 | rot_mat = xmat.bmm(ymat).bmm(zmat) 97 | return rot_mat 98 | 99 | def pose_vec2mat(vec, mode='euler'): 100 | """Convert translation and Euler rotation to a [B,4,4] torch.Tensor transformation matrix""" 101 | if mode is None: 102 | return vec 103 | trans, rot = vec[:, :3].unsqueeze(-1), vec[:, 3:] 104 | if mode == 'euler': 105 | rot_mat = euler2mat(rot) 106 | else: 107 | raise ValueError('Rotation mode not supported {}'.format(mode)) 108 | mat = torch.cat([rot_mat, trans], dim=2) # [B,3,4] 109 | return mat 110 | 111 | def is_tensor(data): 112 | """Checks if data is a torch tensor.""" 113 | return type(data) == torch.Tensor 114 | 115 | 116 | @iterate1 117 | def invert_pose(T): 118 | """Invert a [B,4,4] torch.Tensor pose""" 119 | Tinv = torch.eye(4, device=T.device, dtype=T.dtype).repeat([len(T), 1, 1]) 120 | Tinv[:, :3, :3] = torch.transpose(T[:, :3, :3], -2, -1) 121 | Tinv[:, :3, -1] = torch.bmm(-1. 
* Tinv[:, :3, :3], T[:, :3, -1].unsqueeze(-1)).squeeze(-1) 122 | return Tinv 123 | # return torch.linalg.inv(T) 124 | 125 | 126 | def tvec_to_translation(tvec): 127 | """Convert translation vector to translation matrix (no rotation)""" 128 | batch_size = tvec.shape[0] 129 | T = torch.eye(4).to(device=tvec.device).repeat(batch_size, 1, 1) 130 | t = tvec.contiguous().view(-1, 3, 1) 131 | T[:, :3, 3, None] = t 132 | return T 133 | 134 | 135 | def euler2rot(euler): 136 | """Convert Euler parameters to a [B,3,3] torch.Tensor rotation matrix""" 137 | euler_norm = torch.norm(euler, 2, 2, True) 138 | axis = euler / (euler_norm + 1e-7) 139 | 140 | cos_a = torch.cos(euler_norm) 141 | sin_a = torch.sin(euler_norm) 142 | cos1_a = 1 - cos_a 143 | 144 | x = axis[..., 0].unsqueeze(1) 145 | y = axis[..., 1].unsqueeze(1) 146 | z = axis[..., 2].unsqueeze(1) 147 | 148 | x_sin = x * sin_a 149 | y_sin = y * sin_a 150 | z_sin = z * sin_a 151 | x_cos1 = x * cos1_a 152 | y_cos1 = y * cos1_a 153 | z_cos1 = z * cos1_a 154 | 155 | xx_cos1 = x * x_cos1 156 | yy_cos1 = y * y_cos1 157 | zz_cos1 = z * z_cos1 158 | xy_cos1 = x * y_cos1 159 | yz_cos1 = y * z_cos1 160 | zx_cos1 = z * x_cos1 161 | 162 | batch_size = euler.shape[0] 163 | rot = torch.zeros((batch_size, 4, 4)).to(device=euler.device) 164 | 165 | rot[:, 0, 0] = torch.squeeze(xx_cos1 + cos_a) 166 | rot[:, 0, 1] = torch.squeeze(xy_cos1 - z_sin) 167 | rot[:, 0, 2] = torch.squeeze(zx_cos1 + y_sin) 168 | rot[:, 1, 0] = torch.squeeze(xy_cos1 + z_sin) 169 | rot[:, 1, 1] = torch.squeeze(yy_cos1 + cos_a) 170 | rot[:, 1, 2] = torch.squeeze(yz_cos1 - x_sin) 171 | rot[:, 2, 0] = torch.squeeze(zx_cos1 - y_sin) 172 | rot[:, 2, 1] = torch.squeeze(yz_cos1 + x_sin) 173 | rot[:, 2, 2] = torch.squeeze(zz_cos1 + cos_a) 174 | rot[:, 3, 3] = 1 175 | 176 | return rot 177 | 178 | 179 | def vec2mat(euler, translation, invert=False): 180 | """Convert Euler rotation and translation to a [B,4,4] torch.Tensor transformation matrix""" 181 | R = euler2rot(euler) 182 | t = translation.clone() 183 | 184 | if invert: 185 | R = R.transpose(1, 2) 186 | t *= -1 187 | 188 | T = tvec_to_translation(t) 189 | 190 | if invert: 191 | M = torch.matmul(R, T) 192 | else: 193 | M = torch.matmul(T, R) 194 | 195 | return M 196 | 197 | 198 | def rot2quat(R): 199 | """Convert a [B,3,3] rotation matrix to [B,4] quaternions""" 200 | b, _, _ = R.shape 201 | q = torch.ones((b, 4), device=R.device) 202 | 203 | R00 = R[:, 0, 0] 204 | R01 = R[:, 0, 1] 205 | R02 = R[:, 0, 2] 206 | R10 = R[:, 1, 0] 207 | R11 = R[:, 1, 1] 208 | R12 = R[:, 1, 2] 209 | R20 = R[:, 2, 0] 210 | R21 = R[:, 2, 1] 211 | R22 = R[:, 2, 2] 212 | 213 | q[:, 3] = torch.sqrt(1.0 + R00 + R11 + R22) / 2 214 | q[:, 0] = (R21 - R12) / (4 * q[:, 3]) 215 | q[:, 1] = (R02 - R20) / (4 * q[:, 3]) 216 | q[:, 2] = (R10 - R01) / (4 * q[:, 3]) 217 | 218 | return q 219 | 220 | 221 | def quat2rot(q): 222 | """Convert [B,4] quaternions to [B,3,3] rotation matrix""" 223 | b, _ = q.shape 224 | q = F.normalize(q, dim=1) 225 | R = torch.ones((b, 3, 3), device=q.device) 226 | 227 | qr = q[:, 0] 228 | qi = q[:, 1] 229 | qj = q[:, 2] 230 | qk = q[:, 3] 231 | 232 | R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2) 233 | R[:, 0, 1] = 2 * (qj * qi - qk * qr) 234 | R[:, 0, 2] = 2 * (qi * qk + qr * qj) 235 | R[:, 1, 0] = 2 * (qj * qi + qk * qr) 236 | R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2) 237 | R[:, 1, 2] = 2 * (qj * qk - qi * qr) 238 | R[:, 2, 0] = 2 * (qk * qi - qj * qr) 239 | R[:, 2, 1] = 2 * (qj * qk + qi * qr) 240 | R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2) 241 | 242 | return R 243 
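# Note on conventions: rot2quat() writes the scalar part to q[:, 3], i.e. it returns
# quaternions ordered (x, y, z, w), while quat2rot() reads the scalar part from
# q[:, 0], i.e. it expects (w, x, y, z). Round-tripping the two therefore requires
# reordering, for example:
#   R_back = quat2rot(q[:, [3, 0, 1, 2]])   # q as returned by rot2quat(R)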
| 244 | 245 | def from_dict_sample(T, to_global=False, zero_origin=False, to_matrix=False): 246 | """ 247 | Create poses from a sample dictionary 248 | 249 | Parameters 250 | ---------- 251 | T : Dict 252 | Dictionary containing input poses [B,4,4] 253 | to_global : Bool 254 | Whether poses should be converted to global frame of reference 255 | zero_origin : Bool 256 | Whether the target camera should be the center of the frame of reference 257 | to_matrix : Bool 258 | Whether output poses should be classes or tensors 259 | 260 | Returns 261 | ------- 262 | pose : Dict 263 | Dictionary containing output poses 264 | """ 265 | pose = {key: Pose(val) for key, val in T.items()} 266 | if to_global: 267 | pose = to_global_pose(pose, zero_origin=zero_origin) 268 | if to_matrix: 269 | pose = {key: val.T for key, val in pose.items()} 270 | return pose 271 | 272 | 273 | def from_dict_batch(T, **kwargs): 274 | """Create poses from a batch dictionary""" 275 | pose_batch = [from_dict_sample({key: val[b] for key, val in T.items()}, **kwargs) 276 | for b in range(T[0].shape[0])] 277 | return {key: torch.stack([v[key] for v in pose_batch], 0) for key in pose_batch[0]} 278 | 279 | 280 | class Pose: 281 | """ 282 | Pose class for 3D operations 283 | 284 | Parameters 285 | ---------- 286 | T : torch.Tensor or Int 287 | Transformation matrix [B,4,4], or batch size (poses initialized as identity) 288 | """ 289 | def __init__(self, T=1): 290 | if is_int(T): 291 | T = torch.eye(4).repeat(T, 1, 1) 292 | self.T = T if T.dim() == 3 else T.unsqueeze(0) 293 | 294 | def __len__(self): 295 | """Return batch size""" 296 | return len(self.T) 297 | 298 | def __getitem__(self, i): 299 | """Return batch-wise pose""" 300 | return Pose(self.T[[i]]) 301 | 302 | def __mul__(self, data): 303 | """Transforms data (pose or 3D points)""" 304 | if isinstance(data, Pose): 305 | return Pose(self.T.bmm(data.T)) 306 | elif isinstance(data, torch.Tensor): 307 | return self.T[:, :3, :3].bmm(data) + self.T[:, :3, -1].unsqueeze(-1) 308 | else: 309 | raise NotImplementedError() 310 | 311 | def detach(self): 312 | """Return detached pose""" 313 | return Pose(self.T.detach()) 314 | 315 | @property 316 | def shape(self): 317 | """Return pose shape""" 318 | return self.T.shape 319 | 320 | @property 321 | def device(self): 322 | """Return pose device""" 323 | return self.T.device 324 | 325 | @property 326 | def dtype(self): 327 | """Return pose type""" 328 | return self.T.dtype 329 | 330 | @classmethod 331 | def identity(cls, N=1, device=None, dtype=torch.float): 332 | """Initializes as a [4,4] identity matrix""" 333 | return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1])) 334 | 335 | @staticmethod 336 | def from_dict(T, to_global=False, zero_origin=False, to_matrix=False): 337 | """Create poses from a dictionary""" 338 | if T[0].dim() == 3: 339 | return from_dict_sample(T, to_global=to_global, zero_origin=zero_origin, to_matrix=to_matrix) 340 | elif T[0].dim() == 4: 341 | return from_dict_batch(T, to_global=to_global, zero_origin=zero_origin, to_matrix=True) 342 | 343 | @classmethod 344 | def from_vec(cls, vec, mode): 345 | """Initializes from a [B,6] batch vector""" 346 | mat = pose_vec2mat(vec, mode) 347 | pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1]) 348 | pose[:, :3, :3] = mat[:, :3, :3] 349 | pose[:, :3, -1] = mat[:, :3, -1] 350 | return cls(pose) 351 | 352 | def repeat(self, *args, **kwargs): 353 | """Repeats the transformation matrix multiple times""" 354 | self.T = self.T.repeat(*args, 
**kwargs) 355 | return self 356 | 357 | def inverse(self): 358 | """Returns a new Pose that is the inverse of this one""" 359 | return Pose(invert_pose(self.T)) 360 | 361 | def to(self, *args, **kwargs): 362 | """Copy pose to device""" 363 | self.T = self.T.to(*args, **kwargs) 364 | return self 365 | 366 | def cuda(self, *args, **kwargs): 367 | """Copy pose to CUDA""" 368 | self.to('cuda') 369 | return self 370 | 371 | def translate(self, xyz): 372 | """Translate pose""" 373 | self.T[:, :3, -1] = self.T[:, :3, -1] + xyz.to(self.device) 374 | return self 375 | 376 | def rotate(self, rpw): 377 | """Rotate pose""" 378 | rot = euler2mat(rpw) 379 | T = invert_pose(self.T).clone() 380 | T[:, :3, :3] = T[:, :3, :3] @ rot.to(self.device) 381 | self.T = invert_pose(T) 382 | return self 383 | 384 | def rotateRoll(self, r): 385 | """Rotate pose in the roll axis""" 386 | return self.rotate(torch.tensor([[0, 0, r]])) 387 | 388 | def rotatePitch(self, p): 389 | """Rotate pose in the pitcv axis""" 390 | return self.rotate(torch.tensor([[p, 0, 0]])) 391 | 392 | def rotateYaw(self, w): 393 | """Rotate pose in the yaw axis""" 394 | return self.rotate(torch.tensor([[0, w, 0]])) 395 | 396 | def translateForward(self, t): 397 | """Translate pose forward""" 398 | return self.translate(torch.tensor([[0, 0, -t]])) 399 | 400 | def translateBackward(self, t): 401 | """Translate pose backward""" 402 | return self.translate(torch.tensor([[0, 0, +t]])) 403 | 404 | def translateLeft(self, t): 405 | """Translate pose left""" 406 | return self.translate(torch.tensor([[+t, 0, 0]])) 407 | 408 | def translateRight(self, t): 409 | """Translate pose right""" 410 | return self.translate(torch.tensor([[-t, 0, 0]])) 411 | 412 | def translateUp(self, t): 413 | """Translate pose up""" 414 | return self.translate(torch.tensor([[0, +t, 0]])) 415 | 416 | def translateDown(self, t): 417 | """Translate pose down""" 418 | return self.translate(torch.tensor([[0, -t, 0]])) 419 | 420 | 421 | #def inverse_warp2(img, depth, ref_depth, pose, intrinsics, padding_mode='zeros'): 422 | import torch.nn as nn 423 | 424 | from functools import lru_cache 425 | import torch 426 | import torch.nn as nn 427 | 428 | 429 | def pixel_grid(hw, b=None, with_ones=False, device=None, normalize=False): 430 | """ 431 | Creates a pixel grid for image operations 432 | Parameters 433 | ---------- 434 | hw : Tuple 435 | Height/width of the grid 436 | b : Int 437 | Batch size 438 | with_ones : Bool 439 | Stack an extra channel with 1s 440 | device : String 441 | Device where the grid will be created 442 | normalize : Bool 443 | Whether the grid is normalized between [-1,1] 444 | Returns 445 | ------- 446 | grid : torch.Tensor 447 | Output pixel grid [B,2,H,W] 448 | """ 449 | if is_tensor(hw): 450 | b, hw = hw.shape[0], hw.shape[-2:] 451 | if is_tensor(device): 452 | device = device.device 453 | hi, hf = 0, hw[0] - 1 454 | wi, wf = 0, hw[1] - 1 455 | yy, xx = torch.meshgrid([torch.linspace(hi, hf, hw[0], device=device), 456 | torch.linspace(wi, wf, hw[1], device=device)], indexing='ij') 457 | if with_ones: 458 | grid = torch.stack([xx, yy, torch.ones(hw, device=device)], 0) 459 | else: 460 | grid = torch.stack([xx, yy], 0) 461 | if b is not None: 462 | grid = grid.unsqueeze(0).repeat(b, 1, 1, 1) 463 | if normalize: 464 | grid = norm_pixel_grid(grid) 465 | return grid 466 | 467 | 468 | def norm_pixel_grid(grid, hw=None, in_place=False): 469 | """ 470 | Normalize a pixel grid to be between [0,1] 471 | Parameters 472 | ---------- 473 | grid : torch.Tensor 474 | Grid to 
be normalized [B,2,H,W] 475 | hw : Tuple 476 | Height/Width for normalization 477 | in_place : Bool 478 | Whether the operation is done in place or not 479 | Returns 480 | ------- 481 | grid : torch.Tensor 482 | Normalized grid [B,2,H,W] 483 | """ 484 | if hw is None: 485 | hw = grid.shape[-2:] 486 | if not in_place: 487 | grid = grid.clone() 488 | grid[:, 0] = 2.0 * grid[:, 0] / (hw[1] - 1) - 1.0 489 | grid[:, 1] = 2.0 * grid[:, 1] / (hw[0] - 1) - 1.0 490 | return grid 491 | 492 | 493 | 494 | class EUCMCamera(nn.Module): 495 | """ 496 | Differentiable camera class implementing reconstruction and projection 497 | functions for the extended unified camera model (EUCM). 498 | """ 499 | def __init__(self, I, Tcw=None): 500 | """ 501 | Initializes the Camera class 502 | 503 | Parameters 504 | ---------- 505 | I : torch.Tensor [6] 506 | Camera intrinsics parameter vector 507 | Tcw : Pose 508 | Camera -> World pose transformation 509 | """ 510 | super().__init__() 511 | self.I = I 512 | if Tcw is None: 513 | self.Tcw = Pose.identity(len(I)) 514 | elif isinstance(Tcw, Pose): 515 | self.Tcw = Tcw 516 | else: 517 | self.Tcw = Pose(Tcw) 518 | 519 | self.Tcw.to(self.I.device) 520 | 521 | def __len__(self): 522 | """Batch size of the camera intrinsics""" 523 | return len(self.I) 524 | 525 | def to(self, *args, **kwargs): 526 | """Moves object to a specific device""" 527 | self.I = self.I.to(*args, **kwargs) 528 | self.Tcw = self.Tcw.to(*args, **kwargs) 529 | return self 530 | 531 | @property 532 | def fx(self): 533 | """Focal length in x""" 534 | return self.I[:, 0].unsqueeze(1).unsqueeze(2) 535 | 536 | @property 537 | def fy(self): 538 | """Focal length in y""" 539 | return self.I[:, 1].unsqueeze(1).unsqueeze(2) 540 | 541 | @property 542 | def cx(self): 543 | """Principal point in x""" 544 | return self.I[:, 2].unsqueeze(1).unsqueeze(2) 545 | 546 | @property 547 | def cy(self): 548 | """Principal point in y""" 549 | return self.I[:, 3].unsqueeze(1).unsqueeze(2) 550 | 551 | @property 552 | def alpha(self): 553 | """alpha in EUCM model""" 554 | return self.I[:, 4].unsqueeze(1).unsqueeze(2) 555 | 556 | @property 557 | def beta(self): 558 | """beta in EUCM model""" 559 | return self.I[:, 5].unsqueeze(1).unsqueeze(2) 560 | 561 | @property 562 | @lru_cache() 563 | def Twc(self): 564 | """World -> Camera pose transformation (inverse of Tcw)""" 565 | return self.Tcw.inverse() 566 | 567 | def reconstruct(self, depth, frame='w'): 568 | """ 569 | Reconstructs pixel-wise 3D points from a depth map. 
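In outline (matching the implementation below): each pixel (u, v) is mapped to
mx = (u - cx) / fx, my = (v - cy) / fy, r**2 = mx**2 + my**2, and
mz = (1 - beta * alpha**2 * r**2) / (alpha * sqrt(1 - (2 * alpha - 1) * beta * r**2) + (1 - alpha));
the ray (mx, my, mz) is then rescaled so that its z component is 1 and multiplied
by the depth map, i.e. depth is interpreted as distance along the optical axis.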
570 | 571 | Parameters 572 | ---------- 573 | depth : torch.Tensor [B,1,H,W] 574 | Depth map for the camera 575 | frame : 'w' 576 | Reference frame: 'c' for camera and 'w' for world 577 | 578 | Returns 579 | ------- 580 | points : torch.tensor [B,3,H,W] 581 | Pixel-wise 3D points 582 | """ 583 | 584 | if depth is None: 585 | return None 586 | b, c, h, w = depth.shape 587 | assert c == 1 588 | 589 | grid = pixel_grid(depth, with_ones=True, device=depth.device) 590 | 591 | # Estimate the outward rays in the camera frame 592 | fx, fy, cx, cy, alpha, beta = self.fx, self.fy, self.cx, self.cy, self.alpha, self.beta 593 | 594 | if torch.any(torch.isnan(alpha)): 595 | raise ValueError('alpha is nan') 596 | 597 | u = grid[:,0,:,:] 598 | v = grid[:,1,:,:] 599 | 600 | mx = (u - cx) / fx 601 | my = (v - cy) / fy 602 | r_square = mx ** 2 + my ** 2 603 | mz = (1 - beta * alpha ** 2 * r_square) / (alpha * torch.sqrt(1 - (2 * alpha - 1) * beta * r_square) + (1 - alpha)) 604 | coeff = 1 / torch.sqrt(mx ** 2 + my ** 2 + mz ** 2) 605 | 606 | x = coeff * mx 607 | y = coeff * my 608 | z = coeff * mz 609 | z = z.clamp(min=1e-7) 610 | 611 | x_norm = x / z 612 | y_norm = y / z 613 | z_norm = z / z 614 | xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1) 615 | 616 | # Scale rays to metric depth 617 | Xc = xnorm * depth 618 | 619 | # If in camera frame of reference 620 | if frame == 'c': 621 | return Xc 622 | # If in world frame of reference 623 | elif frame == 'w': 624 | return (self.Twc * Xc.view(b, 3, -1)).view(b,3,h,w) 625 | # If none of the above 626 | else: 627 | raise ValueError('Unknown reference frame {}'.format(frame)) 628 | 629 | def project(self, X, frame='w'): 630 | """ 631 | Projects 3D points onto the image plane 632 | 633 | Parameters 634 | ---------- 635 | X : torch.Tensor [B,3,H,W] 636 | 3D points to be projected 637 | frame : 'w' 638 | Reference frame: 'c' for camera and 'w' for world 639 | 640 | Returns 641 | ------- 642 | points : torch.Tensor [B,H,W,2] 643 | 2D projected points that are within the image boundaries 644 | """ 645 | B, C, H, W = X.shape 646 | assert C == 3 647 | 648 | # Project 3D points onto the camera image plane 649 | if frame == 'c': 650 | X = X 651 | elif frame == 'w': 652 | X = (self.Tcw * X.view(B,3,-1)).view(B,3,H,W) 653 | else: 654 | raise ValueError('Unknown reference frame {}'.format(frame)) 655 | 656 | fx, fy, cx, cy, alpha, beta = self.fx, self.fy, self.cx, self.cy, self.alpha, self.beta 657 | x, y, z = X[:,0,:], X[:,1,:], X[:,2,:] 658 | d = torch.sqrt( beta * ( x ** 2 + y ** 2 ) + z ** 2 ) 659 | z = z.clamp(min=1e-7) 660 | 661 | Xnorm = fx * x / (alpha * d + (1 - alpha) * z + 1e-7) + cx 662 | Ynorm = fy * y / (alpha * d + (1 - alpha) * z + 1e-7) + cy 663 | Xnorm = 2 * Xnorm / (W-1) - 1 664 | Ynorm = 2 * Ynorm / (H-1) - 1 665 | 666 | coords = torch.stack([Xnorm, Ynorm], dim=-1).permute(0,3,1,2) 667 | z = z.unsqueeze(1) 668 | 669 | invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ 670 | (coords[:, 1] < -1) | (coords[:, 1] > 1) | (z[:, 0] < 0) 671 | coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2 672 | 673 | # Return pixel coordinates 674 | return coords.permute(0, 2, 3, 1), z 675 | 676 | def reconstruct_depth_map(self, depth, to_world=True): 677 | if to_world: 678 | return self.reconstruct(depth, frame='w') 679 | else: 680 | return self.reconstruct(depth, frame='c') 681 | 682 | def project_points(self, points, from_world=True, normalize=True, return_z=False): 683 | if from_world: 684 | return self.project(points, frame='w') 685 | else: 686 | return 
self.project(points, frame='c') 687 | 688 | def coords_from_depth(self, depth, ref_cam=None): 689 | if ref_cam is None: 690 | return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True) 691 | else: 692 | return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True) 693 | 694 | 695 | class Camera(nn.Module, ABC): 696 | """ 697 | Camera class for 3D reconstruction 698 | 699 | Parameters 700 | ---------- 701 | K : torch.Tensor 702 | Camera intrinsics [B,3,3] 703 | hw : Tuple 704 | Camera height and width 705 | Twc : Pose or torch.Tensor 706 | Camera pose (world to camera) [B,4,4] 707 | Tcw : Pose or torch.Tensor 708 | Camera pose (camera to world) [B,4,4] 709 | """ 710 | def __init__(self, K, hw, Twc=None, Tcw=None): 711 | super().__init__() 712 | 713 | # Asserts 714 | 715 | assert Twc is None or Tcw is None 716 | 717 | # Fold if multi-batch 718 | 719 | if K.dim() == 4: 720 | K = rearrange(K, 'b n h w -> (b n) h w') 721 | if Twc is not None: 722 | Twc = rearrange(Twc, 'b n h w -> (b n) h w') 723 | if Tcw is not None: 724 | Tcw = rearrange(Tcw, 'b n h w -> (b n) h w') 725 | 726 | # Intrinsics 727 | 728 | if same_shape(K.shape[-2:], (3, 3)): 729 | self._K = torch.eye(4, dtype=K.dtype, device=K.device).repeat(K.shape[0], 1, 1) 730 | self._K[:, :3, :3] = K 731 | else: 732 | self._K = K 733 | 734 | # Pose 735 | 736 | if Twc is None and Tcw is None: 737 | self._Twc = torch.eye(4, dtype=K.dtype, device=K.device).unsqueeze(0).repeat(K.shape[0], 1, 1) 738 | else: 739 | self._Twc = invert_pose(Tcw) if Tcw is not None else Twc 740 | if is_tensor(self._Twc): 741 | self._Twc = Pose(self._Twc) 742 | 743 | # Resolution 744 | 745 | self._hw = hw 746 | if is_tensor(self._hw): 747 | self._hw = self._hw.shape[-2:] 748 | 749 | def __getitem__(self, idx): 750 | """Return batch-wise pose""" 751 | if is_seq(idx): 752 | return type(self).from_list([self.__getitem__(i) for i in idx]) 753 | else: 754 | return type(self)( 755 | K=self._K[[idx]], 756 | Twc=self._Twc[[idx]] if self._Twc is not None else None, 757 | hw=self._hw, 758 | ) 759 | 760 | def __len__(self): 761 | """Return length as intrinsics batch""" 762 | return self._K.shape[0] 763 | 764 | def __eq__(self, cam): 765 | """Check if two cameras are the same""" 766 | if not isinstance(cam, type(self)): 767 | return False 768 | if self._hw[0] != cam.hw[0] or self._hw[1] != cam.hw[1]: 769 | return False 770 | if not torch.allclose(self._K, cam.K): 771 | return False 772 | if not torch.allclose(self._Twc.T, cam.Twc.T): 773 | return False 774 | return True 775 | 776 | def clone(self): 777 | """Return a copy of this camera""" 778 | return deepcopy(self) 779 | 780 | @property 781 | def pose(self): 782 | """Return camera pose (world to camera)""" 783 | return self._Twc.T 784 | 785 | @property 786 | def K(self): 787 | """Return camera intrinsics""" 788 | return self._K 789 | 790 | @K.setter 791 | def K(self, K): 792 | """Set camera intrinsics""" 793 | self._K = K 794 | 795 | @property 796 | def invK(self): 797 | """Return inverse of camera intrinsics""" 798 | return invert_intrinsics(self._K) 799 | 800 | @property 801 | def batch_size(self): 802 | """Return batch size""" 803 | return self._Twc.T.shape[0] 804 | 805 | @property 806 | def hw(self): 807 | """Return camera height and width""" 808 | return self._hw 809 | 810 | @hw.setter 811 | def hw(self, hw): 812 | """Set camera height and width""" 813 | self._hw = hw 814 | 815 | @property 816 | def wh(self): 817 | """Get camera width and height""" 818 | 
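# _hw is stored as (height, width); reversing it yields (width, height).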
return self._hw[::-1] 819 | 820 | @property 821 | def n_pixels(self): 822 | """Return number of pixels""" 823 | return self._hw[0] * self._hw[1] 824 | 825 | @property 826 | def fx(self): 827 | """Return horizontal focal length""" 828 | return self._K[:, 0, 0] 829 | 830 | @property 831 | def fy(self): 832 | """Return vertical focal length""" 833 | return self._K[:, 1, 1] 834 | 835 | @property 836 | def cx(self): 837 | """Return horizontal principal point""" 838 | return self._K[:, 0, 2] 839 | 840 | @property 841 | def cy(self): 842 | """Return vertical principal point""" 843 | return self._K[:, 1, 2] 844 | 845 | @property 846 | def fxy(self): 847 | """Return focal length""" 848 | return torch.tensor([self.fx, self.fy], dtype=self.dtype, device=self.device) 849 | 850 | @property 851 | def cxy(self): 852 | """Return principal points""" 853 | return self._K[:, :2, 2] 854 | # return torch.tensor([self.cx, self.cy], dtype=self.dtype, device=self.device) 855 | 856 | @property 857 | def Tcw(self): 858 | """Return camera pose (camera to world)""" 859 | return None if self._Twc is None else self._Twc.inverse() 860 | 861 | @Tcw.setter 862 | def Tcw(self, Tcw): 863 | """Set camera pose (camera to world)""" 864 | self._Twc = Tcw.inverse() 865 | 866 | @property 867 | def Twc(self): 868 | """Return camera pose (world to camera)""" 869 | return self._Twc 870 | 871 | @Twc.setter 872 | def Twc(self, Twc): 873 | """Set camera pose (world to camera)""" 874 | self._Twc = Twc 875 | 876 | @property 877 | def dtype(self): 878 | """Return tensor type""" 879 | return self._K.dtype 880 | 881 | @property 882 | def device(self): 883 | """Return device""" 884 | return self._K.device 885 | 886 | def detach_pose(self): 887 | """Detach pose from the graph""" 888 | return type(self)(K=self._K, hw=self._hw, 889 | Twc=self._Twc.detach() if self._Twc is not None else None) 890 | 891 | def detach_K(self): 892 | """Detach intrinsics from the graph""" 893 | return type(self)(K=self._K.detach(), hw=self._hw, Twc=self._Twc) 894 | 895 | def detach(self): 896 | """Detach camera from the graph""" 897 | return type(self)(K=self._K.detach(), hw=self._hw, 898 | Twc=self._Twc.detach() if self._Twc is not None else None) 899 | 900 | def inverted_pose(self): 901 | """Invert camera pose""" 902 | return type(self)(K=self._K, hw=self._hw, 903 | Twc=self._Twc.inverse() if self._Twc is not None else None) 904 | 905 | def no_translation(self): 906 | """Return new camera without translation""" 907 | Twc = self.pose.clone() 908 | Twc[:, :-1, -1] = 0 909 | return type(self)(K=self._K, hw=self._hw, Twc=Twc) 910 | 911 | @staticmethod 912 | def from_dict(K, hw, Twc=None): 913 | """Create cameras from a pose dictionary""" 914 | return {key: Camera(K=K[0], hw=hw[0], Twc=val) for key, val in Twc.items()} 915 | 916 | # @staticmethod 917 | # def from_dict(K, hw, Twc=None): 918 | # return {key: Camera(K=K[(0, 0)], hw=hw[(0, 0)], Twc=val) for key, val in Twc.items()} 919 | 920 | @staticmethod 921 | def from_list(cams): 922 | """Create cameras from a list""" 923 | K = torch.cat([cam.K for cam in cams], 0) 924 | Twc = torch.cat([cam.Twc.T for cam in cams], 0) 925 | return Camera(K=K, Twc=Twc, hw=cams[0].hw) 926 | 927 | def scaled(self, scale_factor): 928 | """Return a scaled camera""" 929 | if scale_factor is None or scale_factor == 1: 930 | return self 931 | if is_seq(scale_factor): 932 | if len(scale_factor) == 4: 933 | scale_factor = scale_factor[-2:] 934 | scale_factor = [float(scale_factor[i]) / float(self._hw[i]) for i in range(2)] 935 | else: 936 | 
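            # A scalar scale factor is applied uniformly to both height and width.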
scale_factor = [scale_factor] * 2 937 | return type(self)( 938 | K=scale_intrinsics(self._K, scale_factor), 939 | hw=[int(self._hw[i] * scale_factor[i]) for i in range(len(self._hw))], 940 | Twc=self._Twc 941 | ) 942 | 943 | def offset_start(self, start): 944 | """Offset camera intrinsics based on a crop""" 945 | new_cam = self.clone() 946 | start = start.to(self.device) 947 | new_cam.K[:, 0, 2] -= start[:, 1] 948 | new_cam.K[:, 1, 2] -= start[:, 0] 949 | return new_cam 950 | 951 | def interpolate(self, rgb): 952 | """Interpolate an image to fit the camera""" 953 | if rgb.dim() == 5: 954 | rgb = rearrange(rgb, 'b n c h w -> (b n) c h w') 955 | return interpolate(rgb, scale_factor=None, size=self.hw, mode='bilinear', align_corners=True) 956 | 957 | def interleave_K(self, b): 958 | """Interleave camera intrinsics to fit multiple batches""" 959 | return type(self)( 960 | K=interleave(self._K, b), 961 | Twc=self._Twc, 962 | hw=self._hw, 963 | ) 964 | 965 | def interleave_Twc(self, b): 966 | """Interleave camera pose to fit multiple batches""" 967 | return type(self)( 968 | K=self._K, 969 | Twc=interleave(self._Twc, b), 970 | hw=self._hw, 971 | ) 972 | 973 | def interleave(self, b): 974 | """Interleave camera to fit multiple batches""" 975 | return type(self)( 976 | K=interleave(self._K, b), 977 | Twc=interleave(self._Twc, b), 978 | hw=self._hw, 979 | ) 980 | 981 | def Pwc(self, from_world=True): 982 | """Return projection matrix""" 983 | return self._K[:, :3] if not from_world or self._Twc is None else \ 984 | torch.matmul(self._K, self._Twc.T)[:, :3] 985 | 986 | def to_world(self, points): 987 | """Transform points to world coordinates""" 988 | if points.dim() > 3: 989 | points = points.reshape(points.shape[0], 3, -1) 990 | return points if self.Tcw is None else self.Tcw * points 991 | 992 | def from_world(self, points): 993 | """Transform points back to camera coordinates""" 994 | if points.dim() > 3: 995 | points = points.reshape(points.shape[0], 3, -1) 996 | return points if self._Twc is None else \ 997 | torch.matmul(self._Twc.T, cat_channel_ones(points, 1))[:, :3] 998 | 999 | def to(self, *args, **kwargs): 1000 | """Copy camera to device""" 1001 | self._K = self._K.to(*args, **kwargs) 1002 | if self._Twc is not None: 1003 | self._Twc = self._Twc.to(*args, **kwargs) 1004 | return self 1005 | 1006 | def cuda(self, *args, **kwargs): 1007 | """Copy camera to CUDA""" 1008 | return self.to('cuda') 1009 | 1010 | def relative_to(self, cam): 1011 | """Create a new camera relative to another camera""" 1012 | return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc.inverse()) 1013 | 1014 | def global_from(self, cam): 1015 | """Create a new camera in global coordinates relative to another camera""" 1016 | return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc) 1017 | 1018 | def reconstruct_depth_map(self, depth, to_world=False): 1019 | """ 1020 | Reconstruct a depth map from the camera viewpoint 1021 | 1022 | Parameters 1023 | ---------- 1024 | depth : torch.Tensor 1025 | Input depth map [B,1,H,W] 1026 | to_world : Bool 1027 | Transform points to world coordinates 1028 | 1029 | Returns 1030 | ------- 1031 | points : torch.Tensor 1032 | Output 3D points [B,3,H,W] 1033 | """ 1034 | if depth is None: 1035 | return None 1036 | b, _, h, w = depth.shape 1037 | grid = pixel_grid(depth, with_ones=True, device=depth.device).view(b, 3, -1) 1038 | points = depth.view(b, 1, -1) * torch.matmul(self.invK[:, :3, :3], grid) 1039 | if to_world and self.Tcw is not None: 1040 | points = self.Tcw * points 
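        # Reshape the flattened point cloud back to the depth-map layout [B,3,H,W].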
1041 | return points.view(b, 3, h, w) 1042 | 1043 | def reconstruct_cost_volume(self, volume, to_world=True, flatten=True): 1044 | """ 1045 | Reconstruct a cost volume from the camera viewpoint 1046 | 1047 | Parameters 1048 | ---------- 1049 | volume : torch.Tensor 1050 | Input depth map [B,1,D,H,W] 1051 | to_world : Bool 1052 | Transform points to world coordinates 1053 | flatten: Bool 1054 | Flatten volume points 1055 | 1056 | Returns 1057 | ------- 1058 | points : torch.Tensor 1059 | Output 3D points [B,3,D,H,W] 1060 | """ 1061 | c, d, h, w = volume.shape 1062 | grid = pixel_grid((h, w), with_ones=True, device=volume.device).view(3, -1).repeat(1, d) 1063 | points = torch.stack([ 1064 | (volume.view(c, -1) * torch.matmul(invK[:3, :3].unsqueeze(0), grid)).view(3, d * h * w) 1065 | for invK in self.invK], 0) 1066 | if to_world and self.Tcw is not None: 1067 | points = self.Tcw * points 1068 | if flatten: 1069 | return points.view(-1, 3, d, h * w).permute(0, 2, 1, 3) 1070 | else: 1071 | return points.view(-1, 3, d, h, w) 1072 | 1073 | def project_points(self, points, from_world=True, normalize=True, return_z=False): 1074 | """ 1075 | Project points back to image plane 1076 | 1077 | Parameters 1078 | ---------- 1079 | points : torch.Tensor 1080 | Input 3D points [B,3,H,W] or [B,3,N] 1081 | from_world : Bool 1082 | Whether points are in the global frame 1083 | normalize : Bool 1084 | Whether projections should be normalized to [-1,1] 1085 | return_z : Bool 1086 | Whether projected depth is return as well 1087 | 1088 | Returns 1089 | ------- 1090 | coords : torch.Tensor 1091 | Projected 2D coordinates [B,2,H,W] 1092 | depth : torch.Tensor 1093 | Projected depth [B,1,H,W] 1094 | """ 1095 | is_depth_map = points.dim() == 4 1096 | hw = self._hw if not is_depth_map else points.shape[-2:] 1097 | 1098 | if is_depth_map: 1099 | points = points.reshape(points.shape[0], 3, -1) 1100 | b, _, n = points.shape 1101 | 1102 | points = torch.matmul(self.Pwc(from_world), cat_channel_ones(points, 1)) 1103 | 1104 | coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7) 1105 | depth = points[:, 2] 1106 | 1107 | if not is_depth_map: 1108 | if normalize: 1109 | coords = norm_pixel_grid(coords, hw=self._hw, in_place=True) 1110 | invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ 1111 | (coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth < 0) 1112 | coords[invalid.unsqueeze(1).repeat(1, 2, 1)] = -2 1113 | if return_z: 1114 | return coords.permute(0, 2, 1), depth 1115 | else: 1116 | return coords.permute(0, 2, 1) 1117 | 1118 | coords = coords.view(b, 2, *hw) 1119 | depth = depth.view(b, 1, *hw) 1120 | 1121 | if normalize: 1122 | coords = norm_pixel_grid(coords, hw=self._hw, in_place=True) 1123 | invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ 1124 | (coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth[:, 0] < 0) 1125 | coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2 1126 | 1127 | if return_z: 1128 | return coords.permute(0, 2, 3, 1), depth 1129 | else: 1130 | return coords.permute(0, 2, 3, 1) 1131 | 1132 | def project_cost_volume(self, points, from_world=True, normalize=True): 1133 | """ 1134 | Project points back to image plane 1135 | 1136 | Parameters 1137 | ---------- 1138 | points : torch.Tensor 1139 | Input 3D points [B,3,H,W] or [B,3,N] 1140 | from_world : Bool 1141 | Whether points are in the global frame 1142 | normalize : Bool 1143 | Whether projections should be normalized to [-1,1] 1144 | 1145 | Returns 1146 | ------- 1147 | coords : torch.Tensor 1148 | Projected 2D coordinates [B,2,H,W] 
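        Note: for cost-volume inputs the depth-bin dimension is preserved,
        i.e. the returned coordinates have shape [B,D,H,W,2].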
1149 | """ 1150 | if points.dim() == 4: 1151 | points = points.permute(0, 2, 1, 3).reshape(points.shape[0], 3, -1) 1152 | b, _, n = points.shape 1153 | 1154 | points = torch.matmul(self.Pwc(from_world), cat_channel_ones(points, 1)) 1155 | 1156 | coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7) 1157 | coords = coords.view(b, 2, -1, *self._hw).permute(0, 2, 3, 4, 1) 1158 | 1159 | if normalize: 1160 | coords[..., 0] /= self._hw[1] - 1 1161 | coords[..., 1] /= self._hw[0] - 1 1162 | return 2 * coords - 1 1163 | else: 1164 | return coords 1165 | 1166 | def coords_from_cost_volume(self, volume, ref_cam=None): 1167 | """ 1168 | Get warp coordinates from a cost volume 1169 | 1170 | Parameters 1171 | ---------- 1172 | volume : torch.Tensor 1173 | Input cost volume [B,1,D,H,W] 1174 | ref_cam : Camera 1175 | Optional to generate cross-camera coordinates 1176 | 1177 | Returns 1178 | ------- 1179 | coords : torch.Tensor 1180 | Projected 2D coordinates [B,2,H,W] 1181 | """ 1182 | if ref_cam is None: 1183 | return self.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=False), from_world=True) 1184 | else: 1185 | return ref_cam.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=True), from_world=True) 1186 | 1187 | def coords_from_depth(self, depth, ref_cam=None): 1188 | """ 1189 | Get warp coordinates from a depth map 1190 | 1191 | Parameters 1192 | ---------- 1193 | depth : torch.Tensor 1194 | Input cost volume [B,1,D,H,W] 1195 | ref_cam : Camera 1196 | Optional to generate cross-camera coordinates 1197 | 1198 | Returns 1199 | ------- 1200 | coords : torch.Tensor 1201 | Projected 2D coordinates [B,2,H,W] 1202 | """ 1203 | if ref_cam is None: 1204 | return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True, return_z=True) 1205 | else: 1206 | return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True, return_z=True) 1207 | import torch.nn.functional as tfn 1208 | 1209 | 1210 | def grid_sample(tensor, grid, padding_mode, mode, align_corners): 1211 | return tfn.grid_sample(tensor, grid, 1212 | padding_mode=padding_mode, mode=mode, align_corners=align_corners) 1213 | 1214 | 1215 | def invert_intrinsics(K): 1216 | """Invert camera intrinsics""" 1217 | Kinv = K.clone() 1218 | Kinv[:, 0, 0] = 1. / K[:, 0, 0] 1219 | Kinv[:, 1, 1] = 1. / K[:, 1, 1] 1220 | Kinv[:, 0, 2] = -1. * K[:, 0, 2] / K[:, 0, 0] 1221 | Kinv[:, 1, 2] = -1. 
* K[:, 1, 2] / K[:, 1, 1] 1222 | return Kinv 1223 | 1224 | def interpolate(tensor, size, scale_factor, mode, align_corners): 1225 | """ 1226 | Interpolate a tensor to a different resolution 1227 | 1228 | Parameters 1229 | ---------- 1230 | tensor : torch.Tensor 1231 | Input tensor [B,?,H,W] 1232 | size : Tuple 1233 | Interpolation size (H,W) 1234 | scale_factor : Float 1235 | Scale factor for interpolation 1236 | mode : String 1237 | Interpolation mode 1238 | align_corners : Bool 1239 | Corner alignment flag 1240 | 1241 | Returns 1242 | ------- 1243 | tensor : torch.Tensor 1244 | Interpolated tensor [B,?,h,w] 1245 | """ 1246 | if is_tensor(size): 1247 | size = size.shape[-2:] 1248 | return tfn.interpolate( 1249 | tensor, size=size, scale_factor=scale_factor, 1250 | mode=mode, align_corners=align_corners, recompute_scale_factor=False, 1251 | ) 1252 | 1253 | 1254 | class ViewSynthesis(nn.Module, ABC): 1255 | """ 1256 | Class for view synthesis calculation based on image warping 1257 | 1258 | Parameters 1259 | ---------- 1260 | cfg : Config 1261 | Configuration with parameters 1262 | """ 1263 | def __init__(self, cfg=None): 1264 | super().__init__() 1265 | self.grid_sample = partial( 1266 | grid_sample, mode='bilinear', padding_mode='border', align_corners=True) 1267 | self.interpolate = partial( 1268 | interpolate, mode='bilinear', scale_factor=None, align_corners=True) 1269 | self.grid_sample_zeros = partial( 1270 | grid_sample, mode='nearest', padding_mode='zeros', align_corners=True) 1271 | self.upsample_depth = False 1272 | @staticmethod 1273 | def get_num_scales(depths, optical_flow): 1274 | """Return number of scales based on input""" 1275 | if depths is not None: 1276 | return len(depths) 1277 | if optical_flow is not None: 1278 | return len(optical_flow) 1279 | else: 1280 | raise ValueError('Invalid inputs for view synthesis') 1281 | 1282 | @staticmethod 1283 | def get_tensor_ones(depths, optical_flow, scale): 1284 | """Return unitary tensor based on input""" 1285 | if depths is not None: 1286 | return torch.ones_like(depths[scale]) 1287 | elif optical_flow is not None: 1288 | b, _, h, w = optical_flow[scale].shape 1289 | return torch.ones((b, 1, h, w), device=optical_flow[scale].device) 1290 | else: 1291 | raise ValueError('Invalid inputs for view synthesis') 1292 | 1293 | def get_coords(self, rgbs, depths, cams, context, scale): 1294 | """ 1295 | Calculate projection coordinates for warping 1296 | 1297 | Parameters 1298 | ---------- 1299 | rgbs : list[torch.Tensor] 1300 | Input images (for dimensions) [B,3,H,W] 1301 | depths : list[torch.Tensor] 1302 | Target depth maps [B,1,H,W] 1303 | cams : list[Camera] 1304 | Input cameras 1305 | optical_flow : list[torch.Tensor] 1306 | Input optical flow for alternative warping 1307 | context : list[Int] 1308 | Context indices 1309 | scale : Int 1310 | Current scale 1311 | tgt : Int 1312 | Target index 1313 | 1314 | Returns 1315 | ------- 1316 | output : Dict 1317 | Dictionary containing warped images and masks 1318 | """ 1319 | if depths is not None and cams is not None: 1320 | cams_tgt = cams[0] if is_list(cams) else cams 1321 | cams_ctx = cams[1] if is_list(cams) else cams 1322 | depth = depths[scale] 1323 | return { 1324 | ctx: cams_tgt[0].coords_from_depth(depth, cams_ctx[ctx]) for ctx in context 1325 | } 1326 | else: 1327 | raise ValueError('Invalid input for view synthesis') 1328 | 1329 | def forward(self, rgbs, depths=None, 1330 | #tgt_depths=None, 1331 | cams=None, 1332 | return_masks=False, tgt=0): 1333 | 1334 | num_scales = 1 1335 | 
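        # Single-scale, single-context case: only scale 0 and the context view at index 1 are used below.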
        warps, warped_depths, masks = [], [], []
        scale = 0

        # get_coords returns, per context view, a (warp coordinates, projected depth) tuple
        coords, warped_depths = self.get_coords(rgbs, depths, cams, [1], scale)[1]
        src = 0
        warps = self.grid_sample(
            rgbs[1][src], coords.type(rgbs[1][src].dtype))
        if return_masks:
            ones = self.get_tensor_ones(depths, None, scale)
            masks = self.grid_sample_zeros(
                ones, coords.type(ones.dtype))

        return {
            'warps': warps,
            'warped_depths': warped_depths,
            'masks': masks if return_masks else None
        }


def same_shape(shape1, shape2):
    """Checks if two shapes are the same"""
    if len(shape1) != len(shape2):
        return False
    for i in range(len(shape1)):
        if shape1[i] != shape2[i]:
            return False
    return True


def cat_channel_ones(tensor, n=1):
    """
    Concatenate tensor with an extra channel of ones

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to be concatenated
    n : Int
        Which channel will be concatenated

    Returns
    -------
    cat_tensor : torch.Tensor
        Concatenated tensor
    """
    # Get tensor shape with 1 channel
    shape = list(tensor.shape)
    shape[n] = 1
    # Return concatenation of tensor with ones
    return torch.cat([tensor, torch.ones(shape,
                      device=tensor.device, dtype=tensor.dtype)], n)


def rectify_eucm(img, mask, depth, intrinsic):
    """Warp an image, mask, and depth map from the EUCM (fisheye) camera model into the corresponding pinhole view."""
    with torch.no_grad():
        cam1 = EUCMCamera(intrinsic.unsqueeze(0))
        linear_intrinsic = torch.tensor([[
            [intrinsic[0], 0, intrinsic[2]],
            [0, intrinsic[1], intrinsic[3]],
            [0, 0, 1]
        ]])
        cam0 = Camera(linear_intrinsic, hw=(img.shape[2], img.shape[1]))

        coords, warped_depths = cam0.coords_from_depth(depth, cam1)
        projected_img = torch.nn.functional.grid_sample(img, coords,
            padding_mode='zeros', mode='bilinear', align_corners=False)
        projected_mask = torch.nn.functional.grid_sample(mask, coords,
            padding_mode='zeros', mode='bilinear', align_corners=False)
        projected_depth = torch.nn.functional.grid_sample(depth, coords,
            padding_mode='zeros', mode='bilinear', align_corners=False)
    return projected_img.squeeze().cpu().numpy(), (projected_mask == 1).float().squeeze().cpu().numpy(), projected_depth.squeeze().cpu().numpy()
--------------------------------------------------------------------------------