├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE ├── README.md ├── flow4D.py ├── lightning_logs └── version_0 │ ├── hparams.yaml │ └── metrics.csv └── network_4D.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.json 3 | logs/* 4 | *.pyc 5 | 6 | tmp_sbatch.sh 7 | *.log 8 | 9 | *.hydra 10 | 11 | *.h5 12 | 13 | *.zip 14 | 15 | cancel.sh 16 | 17 | # slurm log files 18 | *.err 19 | *.out 20 | 21 | # figure 22 | *.png 23 | *.pdf -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "OpenSceneFlow"] 2 | path = OpenSceneFlow 3 | url = https://github.com/KTH-RPL/OpenSceneFlow 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # check more: https://hub.docker.com/r/nvidia/cuda 2 | FROM nvidia/cuda:11.7.1-devel-ubuntu20.04 3 | ENV DEBIAN_FRONTEND noninteractive 4 | 5 | RUN apt update && apt install -y --no-install-recommends \ 6 | git curl vim rsync htop 7 | 8 | RUN curl -o ~/miniconda.sh -LO https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 9 | chmod +x ~/miniconda.sh && \ 10 | ~/miniconda.sh -b -p /opt/conda && \ 11 | rm ~/miniconda.sh && \ 12 | /opt/conda/bin/conda clean -ya && /opt/conda/bin/conda init bash 13 | 14 | RUN curl -o ~/mamba.sh -LO https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh && \ 15 | chmod +x ~/mamba.sh && \ 16 | ~/mamba.sh -b -p /opt/mambaforge && \ 17 | rm ~/mamba.sh && /opt/mambaforge/bin/mamba init bash 18 | 19 | # install zsh and oh-my-zsh 20 | RUN apt install -y wget git zsh tmux vim g++ 21 | RUN sh -c "$(wget -O- https://github.com/deluan/zsh-in-docker/releases/download/v1.1.5/zsh-in-docker.sh)" -- \ 22 | -t robbyrussell -p git \ 23 | -p https://github.com/agkozak/zsh-z \ 24 | -p https://github.com/zsh-users/zsh-autosuggestions \ 25 | -p https://github.com/zsh-users/zsh-completions \ 26 | -p https://github.com/zsh-users/zsh-syntax-highlighting 27 | 28 | RUN printf "y\ny\ny\n\n" | bash -c "$(curl -fsSL https://raw.githubusercontent.com/Kin-Zhang/Kin-Zhang/main/scripts/setup_ohmyzsh.sh)" 29 | RUN /opt/conda/bin/conda init zsh && /opt/mambaforge/bin/mamba init zsh 30 | 31 | # change to conda env 32 | ENV PATH /opt/conda/bin:$PATH 33 | ENV PATH /opt/mambaforge/bin:$PATH 34 | 35 | RUN mkdir -p /home/kin/workspace && cd /home/kin/workspace && git clone --recursive https://github.com/KTH-RPL/DeFlow.git 36 | WORKDIR /home/kin/workspace/DeFlow 37 | RUN apt-get update && apt-get install libgl1 -y 38 | # need read the gpu device info to compile the cuda extension 39 | RUN cd /home/kin/workspace/DeFlow && /opt/mambaforge/bin/mamba env create -f environment.yaml 40 | RUN cd /home/kin/workspace/DeFlow/mmcv && export MMCV_WITH_OPS=1 && export FORCE_CUDA=1 && /opt/mambaforge/envs/deflow/bin/pip install -e . 41 | 42 | 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Robotics, Perception and Learning @KTH 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flow4D: Leveraging 4D Voxel Network for LiDAR Scene Flow Estimation (RA-L 2025) 2 | 3 | This repository contains the code for the [Flow4D paper (RA-L 2025)](https://ieeexplore.ieee.org/document/10887254) 4 | 5 | 6 | ## Notice 7 | 8 | **Flow4D has been integrated into [OpenSceneFlow](https://github.com/KTH-RPL/OpenSceneFlow).** 9 | Please visit the [OpenSceneFlow](https://github.com/KTH-RPL/OpenSceneFlow) repository for the latest updates and developments. 10 | 11 | This repo saved README, and quick core file in Flow4D for a quick reference. 12 | The old source code branch is also [available here](https://github.com/dgist-cvlab/Flow4D/tree/source). 13 | 14 | ## Requirements 15 | 16 | This code is based on DeFlow.
17 | Please follow the installation instructions from the [DeFlow repository](https://github.com/KTH-RPL/DeFlow). 18 | 19 | Additionally, you need to install `spconv 2.3.6`.
20 | You can find the installation instructions here: [spconv](https://github.com/traveller59/spconv). 21 | 22 | 23 | ## Training 24 | 25 | To train the model, use the following command: 26 | 27 | ```bash 28 | python train.py model=flow4d lr=1e-3 epochs=15 batch_size=8 num_frames=5 loss_fn=deflowLoss "voxel_size=[0.2, 0.2, 0.2]" "point_cloud_range=[-51.2, -51.2, -3.2, 51.2, 51.2, 3.2]" 29 | ``` 30 | 31 | 32 | ## Inference 33 | 34 | To perform inference, use the following command: 35 | 36 | ```bash 37 | python eval.py checkpoint=path_to_checkpoint av2_mode=(val, test) 38 | ``` 39 | 40 | Replace `path_to_checkpoint` with the actual path to your checkpoint file and choose either `val` or `test`. 41 | 42 | 43 | ## Gratitude 44 | This code is based on the [DeFlow code](https://github.com/KTH-RPL/DeFlow) by Qingwen Zhang. 45 | We extend our deepest gratitude to her.
46 | Additionally, we would like to express our sincere thanks to Kyle Vedder et al. for hosting and providing extensive support for [Argoverse2 2024 Scene Flow Challenge](https://www.argoverse.org/sceneflow.html) 47 | 48 | -------------------------------------------------------------------------------- /flow4D.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | # Created: 2023-07-18 15:08 4 | # Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology 5 | # Author: Kin ZHANG (https://kin-zhang.github.io/) 6 | # 7 | # This work is licensed under the terms of the MIT license. 8 | # For a copy, see . 9 | """ 10 | 11 | import torch.nn as nn 12 | import dztimer, torch 13 | 14 | from .basic.embedder_model_flow4D import DynamicEmbedder_4D 15 | from .basic import cal_pose0to1 16 | 17 | from .basic.network_4D import Network_4D, Seperate_to_3D, Point_head 18 | 19 | 20 | def replace_feature(out, new_features): 21 | if "replace_feature" in out.__dir__(): 22 | # spconv 2.x behaviour 23 | return out.replace_feature(new_features) 24 | else: 25 | out.features = new_features 26 | return out 27 | 28 | 29 | class Flow4D(nn.Module): 30 | def __init__(self, voxel_size = [0.2, 0.2, 0.2], 31 | point_cloud_range = [-51.2, -51.2, -2.2, 51.2, 51.2, 4.2], 32 | grid_feature_size = [512, 512, 32], 33 | num_frames = 5): 34 | super().__init__() 35 | 36 | point_output_ch = 16 37 | voxel_output_ch = 16 38 | 39 | self.num_frames = num_frames 40 | print('voxel_size = {}, pseudo_dims = {}, input_num_frames = {}'.format(voxel_size, grid_feature_size, self.num_frames)) 41 | 42 | self.embedder_4D = DynamicEmbedder_4D(voxel_size=voxel_size, 43 | pseudo_image_dims=[grid_feature_size[0], grid_feature_size[1], grid_feature_size[2], num_frames], 44 | point_cloud_range=point_cloud_range, 45 | feat_channels=point_output_ch) 46 | 47 | self.network_4D = Network_4D(in_channel=point_output_ch, out_channel=voxel_output_ch) 48 | 49 | self.seperate_feat = Seperate_to_3D(num_frames) 50 | 51 | self.pointhead_3D = Point_head(voxel_feat_dim=voxel_output_ch, point_feat_dim=point_output_ch) 52 | 53 | self.timer = dztimer.Timing() 54 | self.timer.start("Total") 55 | 56 | def load_from_checkpoint(self, ckpt_path): 57 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 58 | state_dict = { 59 | k[len("model.") :]: v for k, v in ckpt.items() if k.startswith("model.") 60 | } 61 | print("\nLoading... model weight from: ", ckpt_path, "\n") 62 | return self.load_state_dict(state_dict=state_dict, strict=False) 63 | 64 | def forward(self, batch): 65 | #t_deflow_start = time.time() 66 | """ 67 | input: using the batch from dataloader, which is a dict 68 | Detail: [pc0, pc1, pose0, pose1] 69 | output: the predicted flow, pose_flow, and the valid point index of pc0 70 | """ 71 | 72 | self.timer[0].start("Data Preprocess") 73 | batch_sizes = len(batch["pose0"]) 74 | 75 | pose_flows = [] 76 | transform_pc0s = [] 77 | transform_pc_m_frames = [[] for _ in range(self.num_frames - 2)] 78 | 79 | 80 | for batch_id in range(batch_sizes): 81 | selected_pc0 = batch["pc0"][batch_id] 82 | self.timer[0][0].start("pose") 83 | with torch.no_grad(): 84 | if 'ego_motion' in batch: 85 | pose_0to1 = batch['ego_motion'][batch_id] 86 | else: 87 | pose_0to1 = cal_pose0to1(batch["pose0"][batch_id], batch["pose1"][batch_id]) 88 | 89 | if self.num_frames > 2: 90 | past_poses = [] 91 | for i in range(1, self.num_frames - 1): 92 | past_pose = cal_pose0to1(batch[f"pose_m{i}"][batch_id], batch["pose1"][batch_id]) 93 | past_poses.append(past_pose) 94 | self.timer[0][0].stop() 95 | 96 | self.timer[0][1].start("transform") 97 | transform_pc0 = selected_pc0 @ pose_0to1[:3, :3].T + pose_0to1[:3, 3] #t -> t+1 warping 98 | self.timer[0][1].stop() 99 | pose_flows.append(transform_pc0 - selected_pc0) 100 | transform_pc0s.append(transform_pc0) 101 | 102 | for i in range(1, self.num_frames - 1): 103 | selected_pc_m = batch[f"pc_m{i}"][batch_id] 104 | transform_pc_m = selected_pc_m @ past_poses[i-1][:3, :3].T + past_poses[i-1][:3, 3] 105 | transform_pc_m_frames[i-1].append(transform_pc_m) 106 | 107 | pc_m_frames = [torch.stack(transform_pc_m_frames[i], dim=0) for i in range(self.num_frames - 2)] 108 | 109 | pc0s = torch.stack(transform_pc0s, dim=0) 110 | pc1s = batch["pc1"] 111 | self.timer[0].stop() 112 | 113 | 114 | pcs_dict = { 115 | 'pc0s': pc0s, 116 | 'pc1s': pc1s, 117 | } 118 | for i in range(1, self.num_frames - 1): 119 | pcs_dict[f'pc_m{i}s'] = pc_m_frames[i-1] 120 | 121 | 122 | self.timer[1].start("4D_voxelization") 123 | dict_4d = self.embedder_4D(pcs_dict) 124 | pc01_tesnor_4d = dict_4d['4d_tensor'] 125 | pc0_3dvoxel_infos_lst =dict_4d['pc0_3dvoxel_infos_lst'] 126 | pc0_point_feats_lst =dict_4d['pc0_point_feats_lst'] 127 | pc0_num_voxels = dict_4d['pc0_mum_voxels'] 128 | self.timer[1].stop() 129 | 130 | self.timer[2].start("4D_backbone") 131 | pc_all_output_4d = self.network_4D(pc01_tesnor_4d) #all = past, current, next 다 합친것 132 | self.timer[2].stop() 133 | 134 | self.timer[3].start("4D pc01 to 3D pc0") 135 | pc0_last = self.seperate_feat(pc_all_output_4d) 136 | assert pc0_last.features.shape[0] == pc0_num_voxels, 'voxel number mismatch' 137 | self.timer[3].stop() 138 | 139 | self.timer[4].start("3D_sparsetensor_to_point and head") 140 | flows = self.pointhead_3D(pc0_last, pc0_3dvoxel_infos_lst, pc0_point_feats_lst) 141 | self.timer[4].stop() 142 | 143 | pc0_points_lst = [e["points"] for e in pc0_3dvoxel_infos_lst] 144 | pc0_valid_point_idxes = [e["point_idxes"] for e in pc0_3dvoxel_infos_lst] 145 | 146 | model_res = { 147 | "flow": flows, 148 | 'pose_flow': pose_flows, 149 | 150 | "pc0_valid_point_idxes": pc0_valid_point_idxes, 151 | "pc0_points_lst": pc0_points_lst, 152 | 153 | } 154 | return model_res -------------------------------------------------------------------------------- /lightning_logs/version_0/hparams.yaml: -------------------------------------------------------------------------------- 1 | cfg: 2 | model: 3 | name: flow4D 4 | target: 5 | _target_: scripts.network.models.flow4D.Flow4D 6 | voxel_size: 7 | - 0.2 8 | - 0.2 9 | - 0.2 10 | point_cloud_range: 11 | - -51.2 12 | - -51.2 13 | - -2.2 14 | - 51.2 15 | - 51.2 16 | - 4.2 17 | num_frames: 5 18 | grid_feature_size: 19 | - 512 20 | - 512 21 | - 32 22 | val_monitor: val/Dynamic/Mean 23 | slurm_id: 0 24 | wandb_mode: offline 25 | dataset_path: /data/datasets/argoverse2/argoverse2/preprocess/sensor 26 | output: 0709_223159_04__0709_223427 27 | checkpoint: logs/wandb/0709_223159/checkpoints/04__0.194_0.011_0.025_0.055_0.006_0.014_0.77.ckpt 28 | av2_mode: val 29 | save_res: false 30 | submit_version: '2024' 31 | num_frames: 5 32 | loss_fn: deflowLoss 33 | gpus: 1 34 | seed: 42069 35 | eval: true 36 | -------------------------------------------------------------------------------- /lightning_logs/version_0/metrics.csv: -------------------------------------------------------------------------------- 1 | epoch,step,val/Dynamic/BACKGROUND,val/Dynamic/CAR,val/Dynamic/Mean,val/Dynamic/OTHER_VEHICLES,val/Dynamic/PEDESTRIAN,val/Dynamic/WHEELED_VRU,val/EPE_BS,val/EPE_FD,val/EPE_FS,val/IoU,val/Static/BACKGROUND,val/Static/CAR,val/Static/Mean,val/Static/OTHER_VEHICLES,val/Static/PEDESTRIAN,val/Static/WHEELED_VRU,val/Three-way 2 | 0,0,nan,0.12572486698627472,0.18557094037532806,0.2090398073196411,0.2297278642654419,0.17779125273227692,0.00477928202599287,0.07636797428131104,0.01351962797343731,0.7429938912391663,0.004683776292949915,0.012337284162640572,0.011209582909941673,0.017054351046681404,0.011305715888738632,0.010666786693036556,0.031555626541376114 3 | -------------------------------------------------------------------------------- /network_4D.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb, json 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import spconv as spconv_core 7 | #from easydict import EasyDict 8 | import yaml 9 | spconv_core.constants.SPCONV_ALLOW_TF32 = True 10 | 11 | import spconv.pytorch as spconv 12 | import time 13 | from spconv.utils import Point2VoxelCPU3d as VoxelGenerator 14 | tv = None 15 | try: 16 | import cumm.tensorview as tv 17 | except: 18 | pass 19 | 20 | from torch.autograd import Function 21 | from torch.autograd.function import once_differentiable 22 | import torch.cuda.amp as amp 23 | _TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16) 24 | _TORCH_CUSTOM_BWD = amp.custom_bwd 25 | 26 | #from typing import List, Tuple, Dict 27 | 28 | 29 | def conv1x1x1x3(in_planes, out_planes, stride=1, indice_key=None): 30 | return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(1,1,1,3), stride=stride, 31 | padding=(0,0,0,1), bias=False, indice_key=indice_key) 32 | 33 | def conv3x3x3x1(in_planes, out_planes, stride=1, indice_key=None): 34 | return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(3,3,3,1), stride=stride, 35 | padding=(1,1,1,0), bias=False, indice_key=indice_key) 36 | 37 | def conv1x1x1x1(in_planes, out_planes, stride=1, indice_key=None): 38 | return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(1,1,1,1), stride=stride, 39 | padding=0, bias=False, indice_key=indice_key) 40 | 41 | 42 | def conv3x3x3x3(in_planes, out_planes, stride=1, indice_key=None): 43 | return spconv.SubMConv4d(in_planes, out_planes, kernel_size=(3,3,3,3), stride=stride, 44 | padding=(1,1,1,1), bias=False, indice_key=indice_key) 45 | 46 | norm_cfg = { 47 | # format: layer_type: (abbreviation, module) 48 | "BN": ("bn", nn.BatchNorm2d), 49 | "BN1d": ("bn1d", nn.BatchNorm1d), 50 | "GN": ("gn", nn.GroupNorm), 51 | } 52 | 53 | 54 | class Seperate_to_3D(nn.Module): 55 | def __init__(self, num_frames): 56 | super(Seperate_to_3D, self).__init__() 57 | self.num_frames = num_frames 58 | #self.return_pc1 = return_pc1 59 | 60 | def forward(self, sparse_4D_tensor): 61 | 62 | indices_4d = sparse_4D_tensor.indices 63 | features_4d = sparse_4D_tensor.features 64 | 65 | pc0_time_value = self.num_frames-2 66 | 67 | mask_pc0 = (indices_4d[:, -1] == pc0_time_value) 68 | 69 | pc0_indices = indices_4d[mask_pc0][:, :-1] 70 | pc0_features = features_4d[mask_pc0] 71 | 72 | pc0_sparse_3D = sparse_4D_tensor.replace_feature(pc0_features) 73 | pc0_sparse_3D.spatial_shape = sparse_4D_tensor.spatial_shape[:-1] 74 | pc0_sparse_3D.indices = pc0_indices 75 | 76 | return pc0_sparse_3D 77 | 78 | 79 | 80 | class SpatioTemporal_Decomposition_Block(nn.Module): 81 | def __init__(self, in_filters, mid_filters, out_filters, indice_key=None, down_key = None, pooling=False, z_pooling=True, interact=False): 82 | super(SpatioTemporal_Decomposition_Block, self).__init__() 83 | 84 | 85 | self.pooling = pooling 86 | 87 | self.act = nn.LeakyReLU() 88 | 89 | self.spatial_conv_1 = conv3x3x3x1(in_filters, mid_filters, indice_key=indice_key + "bef") 90 | self.bn_s_1 = nn.BatchNorm1d(mid_filters) 91 | 92 | self.temporal_conv_1 = conv1x1x1x3(in_filters, mid_filters) 93 | self.bn_t_1 = nn.BatchNorm1d(mid_filters) 94 | 95 | self.fusion_conv_1 = conv1x1x1x1(mid_filters*2+in_filters, mid_filters, indice_key=indice_key + "1D") 96 | self.bn_fusion_1 = nn.BatchNorm1d(mid_filters) 97 | 98 | 99 | self.spatial_conv_2 = conv3x3x3x1(mid_filters, mid_filters, indice_key=indice_key + "bef") 100 | self.bn_s_2 = nn.BatchNorm1d(mid_filters) 101 | 102 | self.temporal_conv_2 = conv1x1x1x3(mid_filters, mid_filters) 103 | self.bn_t_2 = nn.BatchNorm1d(mid_filters) 104 | 105 | self.fusion_conv_2 = conv1x1x1x1(mid_filters*3, out_filters, indice_key=indice_key + "1D") 106 | self.bn_fusion_2 = nn.BatchNorm1d(out_filters) 107 | 108 | 109 | if self.pooling: 110 | if z_pooling == True: 111 | self.pool = spconv.SparseConv4d(out_filters, out_filters, kernel_size=(2,2,2,1), stride=(2,2,2,1), indice_key=down_key, bias=False) 112 | else: 113 | self.pool = spconv.SparseConv4d(out_filters, out_filters, kernel_size=(2,2,1,1), stride=(2,2,1,1), indice_key=down_key, bias=False) 114 | 115 | self.weight_initialization() 116 | 117 | def weight_initialization(self): 118 | for m in self.modules(): 119 | if isinstance(m, nn.BatchNorm1d): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | def forward(self, x): 124 | 125 | #ST block 126 | S_feat_1 = self.spatial_conv_1(x) 127 | S_feat_1 = S_feat_1.replace_feature(self.bn_s_1(S_feat_1.features)) 128 | S_feat_1 = S_feat_1.replace_feature(self.act(S_feat_1.features)) 129 | 130 | T_feat_1 = self.temporal_conv_1(x) 131 | T_feat_1 = T_feat_1.replace_feature(self.bn_t_1(T_feat_1.features)) 132 | T_feat_1 = T_feat_1.replace_feature(self.act(T_feat_1.features)) 133 | 134 | ST_feat_1 = x.replace_feature(torch.cat([S_feat_1.features, T_feat_1.features, x.features], 1)) #residual까지 concate 135 | 136 | ST_feat_1 = self.fusion_conv_1(ST_feat_1) 137 | ST_feat_1 = ST_feat_1.replace_feature(self.bn_fusion_1(ST_feat_1.features)) 138 | ST_feat_1 = ST_feat_1.replace_feature(self.act(ST_feat_1.features)) 139 | 140 | #TS block 141 | S_feat_2 = self.spatial_conv_2(ST_feat_1) 142 | S_feat_2 = S_feat_2.replace_feature(self.bn_s_2(S_feat_2.features)) 143 | S_feat_2 = S_feat_2.replace_feature(self.act(S_feat_2.features)) 144 | 145 | T_feat_2 = self.temporal_conv_2(ST_feat_1) 146 | T_feat_2 = T_feat_2.replace_feature(self.bn_t_2(T_feat_2.features)) 147 | T_feat_2 = T_feat_2.replace_feature(self.act(T_feat_2.features)) 148 | 149 | ST_feat_2 = x.replace_feature(torch.cat([S_feat_2.features, T_feat_2.features, ST_feat_1.features], 1)) #residual까지 concate 150 | 151 | ST_feat_2 = self.fusion_conv_2(ST_feat_2) 152 | ST_feat_2 = ST_feat_2.replace_feature(self.bn_fusion_2(ST_feat_2.features)) 153 | ST_feat_2 = ST_feat_2.replace_feature(self.act(ST_feat_2.features)) 154 | 155 | if self.pooling: 156 | pooled = self.pool(ST_feat_2) 157 | return pooled, ST_feat_2 158 | else: 159 | return ST_feat_2 160 | 161 | 162 | 163 | class Network_4D(nn.Module): 164 | def __init__(self, in_channel=16, out_channel=16, model_size = 16): 165 | super().__init__() 166 | 167 | SpatioTemporal_Block = SpatioTemporal_Decomposition_Block 168 | 169 | self.model_size = model_size 170 | 171 | 172 | self.STDB_1_1_1 = SpatioTemporal_Block(in_channel, model_size, model_size, indice_key="st1_1", down_key='floor1') 173 | self.STDB_1_1_2 = SpatioTemporal_Block(model_size, model_size, model_size*2, indice_key="st1_1", down_key='floor1', pooling=True) #512 512 32 -> 256 256 16 174 | 175 | self.STDB_2_1_1 = SpatioTemporal_Block(model_size*2, model_size*2, model_size*2, indice_key="st2_1", down_key='floor2') 176 | self.STDB_2_1_2 = SpatioTemporal_Block(model_size*2, model_size*2, model_size*4, indice_key="st2_1", down_key='floor2', pooling=True) #256 256 16 -> 128 128 8 177 | 178 | self.STDB_3_1_1 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st3_1", down_key='floor3') 179 | self.STDB_3_1_2 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st3_1", down_key='floor3', pooling=True) #128 128 8 -> 64 64 4 180 | 181 | self.STDB_4_1_1 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st4_1", down_key='floor4') 182 | self.STDB_4_1_2 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st4_1", down_key='floor4', pooling=True, z_pooling=False) #64 64 4 -> 64 64 4 183 | 184 | self.STDB_5_1_1 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st5_1") 185 | self.STDB_5_1_2 = SpatioTemporal_Block(model_size*4, model_size*4, model_size*4, indice_key="st5_1") 186 | self.up_subm_5 = spconv.SparseInverseConv4d(model_size*4, model_size*4, kernel_size=(2,2,1,1), indice_key='floor4', bias=False) #zpooling false 187 | 188 | self.STDB_4_2_1 = SpatioTemporal_Block(model_size*8, model_size*8, model_size*4, indice_key="st4_2") 189 | self.up_subm_4 = spconv.SparseInverseConv4d(model_size*4, model_size*4, kernel_size=(2,2,2,1), indice_key='floor3', bias=False) 190 | 191 | self.STDB_3_2_1 = SpatioTemporal_Block(model_size*8, model_size*8, model_size*4, indice_key="st3_2") 192 | self.up_subm_3 = spconv.SparseInverseConv4d(model_size*4, model_size*4, kernel_size=(2,2,2,1), indice_key='floor2', bias=False) 193 | 194 | self.STDB_2_2_1 = SpatioTemporal_Block(model_size*8, model_size*4, model_size*4, indice_key="st_2_2") 195 | self.up_subm_2 = spconv.SparseInverseConv4d(model_size*4, model_size*2, kernel_size=(2,2,2,1), indice_key='floor1', bias=False) 196 | 197 | self.STDB_1_2_1 = SpatioTemporal_Block(model_size*4, model_size*2, out_channel, indice_key="st_1_2") 198 | 199 | 200 | def forward(self, sp_tensor): 201 | 202 | sp_tensor = self.STDB_1_1_1(sp_tensor) 203 | down_2, skip_1 = self.STDB_1_1_2(sp_tensor) 204 | 205 | down_2 = self.STDB_2_1_1(down_2) 206 | down_3, skip_2 = self.STDB_2_1_2(down_2) 207 | 208 | down_3 = self.STDB_3_1_1(down_3) 209 | down_4, skip_3 = self.STDB_3_1_2(down_3) 210 | 211 | down_4 = self.STDB_4_1_1(down_4) 212 | down_5, skip_4 = self.STDB_4_1_2(down_4) 213 | 214 | down_5 = self.STDB_5_1_1(down_5) 215 | down_5 = self.STDB_5_1_2(down_5) 216 | 217 | up_4 = self.up_subm_5(down_5) 218 | up_4 = up_4.replace_feature(torch.cat((up_4.features, skip_4.features), 1)) 219 | up_4 = self.STDB_4_2_1(up_4) 220 | 221 | up_3 = self.up_subm_4(up_4) 222 | up_3 = up_3.replace_feature(torch.cat((up_3.features, skip_3.features), 1)) 223 | up_3 = self.STDB_3_2_1(up_3) 224 | 225 | up_2 = self.up_subm_3(up_3) 226 | up_2 = up_2.replace_feature(torch.cat((up_2.features, skip_2.features), 1)) 227 | up_2 = self.STDB_2_2_1(up_2) 228 | 229 | up_1 = self.up_subm_2(up_2) 230 | up_1 = up_1.replace_feature(torch.cat((up_1.features, skip_1.features), 1)) 231 | up_1 = self.STDB_1_2_1(up_1) 232 | 233 | return up_1 234 | 235 | 236 | 237 | class Point_head(nn.Module): 238 | def __init__(self, voxel_feat_dim: int = 96, point_feat_dim: int = 32): 239 | super().__init__() 240 | 241 | self.input_dim = voxel_feat_dim + point_feat_dim 242 | 243 | self.PPmodel_flow = nn.Sequential( 244 | nn.Linear(self.input_dim, 32), 245 | nn.BatchNorm1d(32), 246 | nn.ReLU(), 247 | 248 | nn.Linear(32, 3) 249 | ) 250 | 251 | def forward_single(self, voxel_feat, voxel_coords, point_feat): 252 | 253 | voxel_to_point_feat = voxel_feat[:, voxel_coords[:,2], voxel_coords[:,1], voxel_coords[:,0]].T 254 | concated_point_feat = torch.cat([voxel_to_point_feat, point_feat],dim=-1) 255 | 256 | flow = self.PPmodel_flow(concated_point_feat) 257 | 258 | return flow 259 | 260 | def forward(self, sparse_tensor, voxelizer_infos, pc0_point_feats_lst): 261 | 262 | voxel_feats = sparse_tensor.dense() 263 | 264 | flow_outputs = [] 265 | batch_idx = 0 266 | for voxelizer_info in voxelizer_infos: 267 | voxel_coords = voxelizer_info["voxel_coords"] 268 | point_feat = pc0_point_feats_lst[batch_idx] 269 | voxel_feat = voxel_feats[batch_idx, :] 270 | flow = self.forward_single(voxel_feat, voxel_coords, point_feat) 271 | batch_idx += 1 272 | flow_outputs.append(flow) 273 | 274 | return flow_outputs 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | --------------------------------------------------------------------------------