├── LICENSE
├── README.md
├── compress.py
├── data
    ├── ScanNet
    │   └── scene0011_00_vh_clean_2.ply
    ├── ScanNet_10bit
    │   └── scene0011_00_vh_clean_2.ply
    ├── ScanNet_12bit
    │   └── scene0011_00_vh_clean_2.ply
    ├── SemanticKITTI
    │   └── 000000.bin.ply
    ├── SemanticKITTI_10bit
    │   └── 000000.bin.ply
    ├── SemanticKITTI_12bit
    │   └── 000000.bin.ply
    └── ShapeNet_10bit
    │   ├── 1.ply
    │   ├── 10.ply
    │   ├── 2.ply
    │   ├── 3.ply
    │   ├── 4.ply
    │   ├── 5.ply
    │   ├── 6.ply
    │   ├── 7.ply
    │   ├── 8.ply
    │   └── 9.ply
├── decompress.py
├── env_create.sh
├── environment.yml
├── figure
    ├── framework.png
    ├── humanbodies.png
    ├── lidar.png
    └── objects_and_scenes.png
├── kit.py
├── model
    └── ckpt.pt
├── net.py
├── reflectance_compress.py
├── reflectance_decompress.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 I2-Multimedia-Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Efficient and Generic Point Model for Lossless Point Cloud Attribute Compression

[![arXiv](https://img.shields.io/badge/Arxiv-2404.06936-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2404.06936)
[![GitHub issues](https://img.shields.io/github/issues/i2-multimedia-lab/polopcac?color=critical&label=Issues)](https://github.com/i2-multimedia-lab/polopcac/issues?q=is%3Aopen+is%3Aissue)
[![GitHub closed issues](https://img.shields.io/github/issues-closed/i2-multimedia-lab/polopcac?color=success&label=Issues)](https://github.com/i2-multimedia-lab/polopcac/issues?q=is%3Aissue+is%3Aclosed)
[![MIT License](https://img.shields.io/badge/License-MIT-green.svg)](https://choosealicense.com/licenses/mit/)

🔥 **High Performance**: Lower bitrate than G-PCCv23 predlift. \
🚀 **High Efficiency**: Faster than G-PCCv23 on a single RTX 2080Ti. \
🌌 **Robust Generalizability**: Instantly applicable to various samples once trained on small-scale objects. \
🌍 **Scale&Density Independent**: Runs directly on point clouds of arbitrary scale and density. \
🌱 **Lightweight**: Only 676k parameters (about 2.6MB).

## Abstract

> The past several years have witnessed the emergence of learned point cloud compression (PCC) techniques. However, current learning-based lossless point cloud attribute compression (PCAC) methods suffer from either high computational complexity or deteriorated compression performance. Moreover, the significant variations in point cloud scale and sparsity encountered in real-world applications make developing an all-in-one neural model a challenging task. In this paper, we propose PoLoPCAC, an efficient and generic lossless PCAC method that achieves high compression efficiency and strong generalizability simultaneously.
We formulate lossless PCAC as the task of inferring explicit distributions of attributes from group-wise autoregressive priors. A progressive random grouping strategy is first devised to efficiently resolve the point cloud into groups, and then the attributes of each group are modeled sequentially from accumulated antecedents. A locality-aware attention mechanism is utilized to exploit prior knowledge from context windows in parallel. Since our method directly operates on points, it naturally avoids distortion caused by voxelization and can be executed on point clouds with arbitrary scale and density. Experiments show that our method can be instantly deployed once trained on a Synthetic 2k-ShapeNet dataset while enjoying continuous bit-rate reduction over the latest G-PCCv23 on various datasets (ShapeNet, ScanNet, MVUB, 8iVFB). Meanwhile, our method reports shorter coding time than G-PCCv23 on the majority of sequences with a lightweight model size (2.6MB), which is highly attractive for practical applications.

## Overview

![](./figure/framework.png)

## Environment

The environment we use is as follows:

- Python 3.10.10
- PyTorch 2.0.0 with CUDA 11.7
- PyTorch3D 0.7.3
- torchac 0.9.3

To ease reproduction, we provide three ways to create the environment:

#### Option 1: Using yml

```
conda env create -f=environment.yml
```

#### Option 2: Using .sh

```
source ./env_create.sh
```

#### Option 3: CodeWithGPU (AutoDL image)

🤗 [PoLoPCAC Image](https://www.codewithgpu.com/i/I2-Multimedia-Lab/PoLoPCAC/PoLoPCAC) has been uploaded to the [CodeWithGPU](https://www.codewithgpu.com/image) community.
The required environment is ready as soon as you create an [AutoDL](https://www.autodl.com) container instance and select our image `I2-Multimedia-Lab/PoLoPCAC/PoLoPCAC` from the community image list.

## Data

Example point clouds are saved in ``./data/`` with the following layout:

```
data/
├── ScanNet/               // A ScanNet point cloud at its original scale
├── ScanNet_10bit/         // The same point cloud rescaled to [0, 1023] (not voxelized)
├── ScanNet_12bit/         // The same point cloud rescaled to [0, 4095] (not voxelized)
|
├── SemanticKITTI/         // A SemanticKITTI point cloud with a reflectance attribute at its original scale
├── SemanticKITTI_10bit/   // The same point cloud rescaled to [0, 1023] (not voxelized)
├── SemanticKITTI_12bit/   // The same point cloud rescaled to [0, 4095] (not voxelized)
|
└── ShapeNet_10bit/        // Several sparse object point clouds (2k points per object) in [0, 1023] (not voxelized)
```

The Synthetic 2k-ShapeNet dataset used in the paper is available on [Google Drive](https://drive.google.com/file/d/1mhUBx4_6joG0KxPkHfXol8fw0tSHxvq6/view?usp=sharing).

## Training

📢 The model trained on Synthetic 2k-ShapeNet is saved in ``./model/``.
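When this checkpoint is used through `compress.py` and `decompress.py` with their default flags (`--init_ratio=128`, `--expand_ratio=2`, `--granularity=2**14`), each point cloud is split into groups by progressive random grouping. As a rough illustration, the sketch below reproduces the `cursor`/`window_size` loop of `decompress.py` and lists how many points fall into each group for a cloud of N points. `group_sizes` is a hypothetical helper written for this note, not part of the repository; it only counts points and leaves out the random permutation, context gathering and entropy coding.

```python
def group_sizes(n_points: int, init_ratio: int = 128, expand_ratio: int = 2,
                granularity: int = 2**14) -> list[int]:
    """Hypothetical helper: group sizes produced by progressive random grouping.

    Mirrors the loop in decompress.py. The first group (stored as raw 8-bit RGB
    in the `.c.bin` file) holds min(N // init_ratio, granularity) points; each
    later group is expand_ratio times larger than the previous one, capped at
    granularity.
    """
    base_size = min(n_points // init_ratio, granularity)
    if base_size == 0:
        # with an empty first group, window_size stays 0 and the loop never advances
        raise ValueError("need at least init_ratio points")
    sizes = [base_size]
    window_size = base_size
    cursor = base_size
    while cursor < n_points:
        window_size = min(window_size * expand_ratio, granularity)
        # the last slice [cursor, cursor + window_size) may run past the end of the cloud
        sizes.append(min(window_size, n_points - cursor))
        cursor += window_size
    return sizes


if __name__ == "__main__":
    # a 2048-point cloud
    print(group_sizes(2048))  # [16, 32, 64, 128, 256, 512, 1024, 16]
```

Under these defaults, every cloud of 128 to 32,767 points is split into exactly eight groups, because the base group holds N // 128 points and the following windows double; larger clouds get more groups once windows are clipped to `granularity`.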

Otherwise, you can train the model from scratch:

```
python ./train.py \
    --training_data='./data/Synthetic_2k_ShapeNet/train_64/*.ply' \
    --model_save_folder='./retrained_model/'
```

(PoLoPCAC may perform better when trained on non-voxelized samples, such as Synthetic 2k-ShapeNet.)

## Compression

```
python ./compress.py \
    --ckpt='./model/ckpt.pt' \
    --input_glob='./data/ScanNet/*.ply' \
    --compressed_path='./data/ScanNet_compressed'
```

💡 You should get the same Bpp for a point cloud regardless of its scale (e.g., the same Bpp for ``./data/ScanNet/scene0011_00_vh_clean_2.ply``, ``./data/ScanNet_10bit/scene0011_00_vh_clean_2.ply``, and ``./data/ScanNet_12bit/scene0011_00_vh_clean_2.ply``), since neighbor search and the per-window unit-ball normalization (``kit.n_scale_ball``) do not depend on the absolute scale.

## Decompression

```
python ./decompress.py --ckpt='./model/ckpt.pt' \
    --compressed_path='./data/ScanNet_compressed' \
    --decompressed_path='./data/ScanNet_decompressed'
```

## Compression For Reflectance

```
python ./reflectance_compress.py \
    --ckpt='./model/ckpt.pt' \
    --input_glob='./data/SemanticKITTI/*.ply' \
    --compressed_path='./data/SemanticKITTI_compressed'
```

## Decompression For Reflectance

```
python ./reflectance_decompress.py \
    --ckpt='./model/ckpt.pt' \
    --compressed_path='./data/SemanticKITTI_compressed' \
    --decompressed_path='./data/SemanticKITTI_decompressed'
```

## Examples

### Objects&Scenes

![](./figure/objects_and_scenes.png)

### Human Bodies

![](./figure/humanbodies.png)

### LiDAR Reflectance

![](./figure/lidar.png)

## Citation

😊 If you find this work useful, please consider citing our work:

```
@article{kang2024polopcac,
  title={Efficient and Generic Point Model
for Lossless Point Cloud Attribute Compression},
  author={Kang You and Pan Gao and Zhan Ma},
  journal={arXiv preprint arXiv:2404.06936},
  year={2024}
}
```

--------------------------------------------------------------------------------
/compress.py:
--------------------------------------------------------------------------------
import os
import time
import argparse

import numpy as np

from glob import glob
from tqdm import tqdm

import torch
import torchac
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points

import kit
from net import Network

torch.cuda.manual_seed(1)
torch.manual_seed(1)
np.random.seed(1)


parser = argparse.ArgumentParser(
    prog='compress.py',
    description='Compress Point Cloud Attributes.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--ckpt', required=True, help='Trained checkpoint file.')
parser.add_argument('--input_glob', required=True, help='Glob pattern of the point clouds to be compressed.')
parser.add_argument('--compressed_path', required=True, help='Directory for saving compressed files.')

parser.add_argument('--local_region', type=int, help='Neighboring scope for context windows (i.e., K).', default=8)
parser.add_argument('--granularity', type=int, help='Upper limit of each group size (i.e., s*).', default=2**14)
parser.add_argument('--init_ratio', type=int, help='Ratio that sets the size of the very first group (i.e., alpha).', default=128)
parser.add_argument('--expand_ratio', type=int, help='Expand ratio (i.e., r).', default=2)
parser.add_argument('--prg_seed', type=int, help='Pseudorandom seed for PRG.', default=2147483647)

args = parser.parse_args()


if not os.path.exists(args.compressed_path):
    os.makedirs(args.compressed_path)

files = np.array(glob(args.input_glob, recursive=True))
np.random.shuffle(files)

net = Network(local_region=args.local_region, granularity=args.granularity, init_ratio=args.init_ratio, expand_ratio=args.expand_ratio)
net.load_state_dict(torch.load(args.ckpt))
net = torch.compile(net, mode='max-autotune')
net.cuda().eval()

# warm up our model
# since the very first step of network is extremely slow...
_ = net.mu_sigma_pred(net.pt(torch.rand((1, 32, 8, 3)).cuda(), torch.rand((1, 32, 8, 3)).cuda()))

enc_times = []
fnames, bpps = [], []
with torch.no_grad():
    for f in tqdm(files):
        fname = os.path.split(f)[-1]

        pc = kit.read_point_cloud_ycocg(f)
        batch_x = torch.tensor(pc).unsqueeze(0)
        batch_x = batch_x.cuda()

        B, N, _ = batch_x.shape

        torch.cuda.synchronize()
        TIME_STAMP = time.time()

        #####################################
        # progressive random grouping

        g_cpu = torch.Generator()
        g_cpu.manual_seed(args.prg_seed)

        batch_x = batch_x[:, torch.randperm(batch_x.size()[1], generator=g_cpu), :]
        _, N, _ = batch_x.shape

        base_size = min(N//args.init_ratio, args.granularity)
        window_size = base_size

        context_ls, target_ls = [], []
        cursor = base_size

        while cursor faster
            # byte_stream = torchac.encode_float_cdf(cdf.cpu(), target_feature.cpu(), check_input_bounds=True)
            byte_stream = torchac.encode_int16_normalized_cdf(
                kit._convert_to_int_and_normalize(cdf, True).cpu(),
                target_feature.cpu())

            # save current group to a file
            # concat to a single bitstream is also practicable
            comp_f = os.path.join(args.compressed_path, fname+f'.{i}.bin')
            with open(comp_f, 'wb') as fout:
                fout.write(byte_stream)

            # record bitrate of current group
            total_bits += kit.get_file_size_in_bits(comp_f)

        # save the first group directly using 8bit code
        comp_base_f = os.path.join(args.compressed_path, fname+'.c.bin')
        context_base = context_ls[0][0, :, 3:].detach().cpu().numpy().astype(np.int16)
        context_base = kit.transformYCoCgToRGB(8, context_base)

        torch.cuda.synchronize()
        enc_times.append(time.time() - TIME_STAMP)

        context_base.astype(np.uint8).tofile(comp_base_f)
        total_bits += kit.get_file_size_in_bits(comp_base_f)

        # save geometry (for decompression only)
        geo_f = os.path.join(args.compressed_path, fname+'.geo.bin')
        batch_x[:, :, :3].detach().cpu().numpy().astype(np.float32).tofile(geo_f)

        # record
        fnames.append(fname)
        bpps.append(np.round(total_bits/N, 3))

print('Max GPU Memory:', round(torch.cuda.max_memory_allocated()/1024/1024, 3), 'MB')
print(f'Done! Total {len(fnames)} \
| color bpp: {round(np.array(bpps).mean(), 3)}\
| ave enc time: {round(np.array(enc_times).mean(), 3)} s')
print('Params:', sum(p.numel() for p in net.parameters()),
      'Trainable params:', sum(p.numel() for p in net.parameters() if p.requires_grad))

--------------------------------------------------------------------------------
/data/ScanNet/scene0011_00_vh_clean_2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ScanNet/scene0011_00_vh_clean_2.ply
--------------------------------------------------------------------------------
/data/ScanNet_10bit/scene0011_00_vh_clean_2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ScanNet_10bit/scene0011_00_vh_clean_2.ply
--------------------------------------------------------------------------------
/data/ScanNet_12bit/scene0011_00_vh_clean_2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ScanNet_12bit/scene0011_00_vh_clean_2.ply
--------------------------------------------------------------------------------
/data/SemanticKITTI/000000.bin.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/SemanticKITTI/000000.bin.ply
--------------------------------------------------------------------------------
/data/SemanticKITTI_10bit/000000.bin.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/SemanticKITTI_10bit/000000.bin.ply
--------------------------------------------------------------------------------
/data/SemanticKITTI_12bit/000000.bin.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/SemanticKITTI_12bit/000000.bin.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/1.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/1.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/10.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/10.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/2.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/3.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/3.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/4.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/4.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/5.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/5.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/6.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/6.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/7.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/7.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/8.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/8.ply
--------------------------------------------------------------------------------
/data/ShapeNet_10bit/9.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/data/ShapeNet_10bit/9.ply
--------------------------------------------------------------------------------
/decompress.py:
--------------------------------------------------------------------------------
import os
import time
import argparse

import numpy as np

from glob import glob
from tqdm import tqdm

import torch
import torchac
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points

import kit
from net import Network

torch.cuda.manual_seed(1)
torch.manual_seed(1)
np.random.seed(1)


parser = argparse.ArgumentParser(
    prog='decompress.py',
    description='Decompress Point Cloud Attributes.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--ckpt', required=True, help='Trained checkpoint file.')
parser.add_argument('--compressed_path', required=True, help='Directory containing the compressed files.')
parser.add_argument('--decompressed_path', required=True, help='Directory for saving decompressed files.')

parser.add_argument('--local_region', type=int, help='Neighboring scope for context windows (i.e., K).', default=8)
parser.add_argument('--granularity', type=int, help='Upper limit of each group size (i.e., s*).', default=2**14)
parser.add_argument('--init_ratio', type=int, help='Ratio that sets the size of the very first group (i.e., alpha).', default=128)
parser.add_argument('--expand_ratio', type=int, help='Expand ratio (i.e., r).', default=2)
parser.add_argument('--prg_seed', type=int, help='Pseudorandom seed for PRG.', default=2147483647)

args = parser.parse_args()

if not os.path.exists(args.decompressed_path):
    os.makedirs(args.decompressed_path)

comp_glob = os.path.join(args.compressed_path, '*.c.bin')
files = np.array(glob(comp_glob, recursive=True))

net = Network(local_region=args.local_region, granularity=args.granularity, init_ratio=args.init_ratio, expand_ratio=args.expand_ratio)
net.load_state_dict(torch.load(args.ckpt))
net = torch.compile(net, mode='max-autotune')
net.cuda().eval()

# warm up our model
# since the very first step of network is extremely slow...
_ = net.mu_sigma_pred(net.pt(torch.rand((1, 32, 8, 3)).cuda(), torch.rand((1, 32, 8, 3)).cuda()))

dec_times = []
with torch.no_grad():
    for comp_c_f in tqdm(files):
        fname = os.path.split(comp_c_f)[-1].split('.c.bin')[0]
        geo_f_path = os.path.join(args.compressed_path, fname+'.geo.bin')

        # read geometry
        batch_x_geo = torch.tensor(np.fromfile(geo_f_path, dtype=np.float32)).view(1, -1, 3)
        context_attr_base = np.array(np.fromfile(comp_c_f, dtype=np.uint8)).reshape(-1, 3)

        # convert base attr to ycocg
        torch.cuda.synchronize()
        TIME_STAMP = time.time()
        context_attr_base = context_attr_base.astype(np.int16)
        context_attr_base = kit.transformRGBToYCoCg(8, context_attr_base)
        context_attr_base = torch.tensor(context_attr_base.astype(float)).view(1, -1, 3)

        _, N, _ = batch_x_geo.shape
        base_size = min(N//args.init_ratio, args.granularity)
        window_size = base_size
        cursor = base_size
        i=0
        while cursor < N:
            window_size = min(window_size*args.expand_ratio, args.granularity)

            # get context info
            context_geo = batch_x_geo[:, :cursor, :].cuda()
            target_geo = batch_x_geo[:, cursor:cursor+window_size, :].cuda()
            cursor += window_size

            # rescale input attributes to [0, 1] in GPU
            context_attr = context_attr_base.clone().float().cuda()
            context_attr[:, :, 0] = context_attr[:, :, 0] / 255
            context_attr[:, :, 1:] = context_attr[:, :, 1:] / 511

            # context window gathering
            _, idx, context_grouped_geo = knn_points(target_geo, context_geo, K=net.local_region, return_nn=True)
            context_grouped_attr = knn_gather(context_attr, idx)

            # spatial normalization
            context_grouped_geo = context_grouped_geo - target_geo.view(1, -1, 1, 3)
            context_grouped_geo = kit.n_scale_ball(context_grouped_geo)

            # network
            feature = net.pt(context_grouped_geo, context_grouped_attr)
            mu_sigma = net.mu_sigma_pred(feature)
            mu, sigma = mu_sigma[:, :, :3]+0.5, torch.exp(mu_sigma[:, :, 3:])

            cdf = kit.get_cdf_ycocg(mu[0]*255, sigma[0]*32)
            comp_f = os.path.join(args.compressed_path, fname+f'.{i}.bin')
            with open(comp_f, 'rb') as fin:
                byte_stream = fin.read()

            # put _convert_to_int_and_normalize in GPU -> faster
            # original version: decomp_attr = torchac.decode_float_cdf(cdf.cpu(), byte_stream)
            decomp_attr = torchac.decode_int16_normalized_cdf(
                kit._convert_to_int_and_normalize(cdf, True).cpu(),
                byte_stream)

            # concat current decoded group to context
            context_attr_base = torch.cat((context_attr_base, decomp_attr.unsqueeze(0)), dim=1)
            i+=1

        decompressed_pc = torch.cat((batch_x_geo, context_attr_base), dim=-1)
        torch.cuda.synchronize()
        dec_times.append(time.time()-TIME_STAMP)
        decompressed_path = os.path.join(args.decompressed_path, fname+'.bin.ply')
        kit.save_point_cloud_ycocg(decompressed_pc[0].detach().cpu().numpy(), path=decompressed_path)

print('Max GPU Memory:', round(torch.cuda.max_memory_allocated(device=None)/1024/1024, 3), 'MB')
print('ave dec time:', round(np.array(dec_times).mean(), 3), 's')

--------------------------------------------------------------------------------
/env_create.sh:
--------------------------------------------------------------------------------
# create an anaconda environment
# with python3.10, pytorch2.0.1 (CUDA 11.7), pytorch3d, and other dependencies
# tested on Ubuntu 20.04 and Debian GNU/Linux 10 in April 2024

# create environment
conda create -n polopcac python=3.10
conda activate polopcac

# install pytorch
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

# install pytorch3d
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
conda install pytorch3d -c pytorch3d

# install torchac and others
pip install torchac
pip install ninja

pip install pandas matplotlib plyfile pyntcloud

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: polopcac
channels:
  - pytorch3d
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py310h6a678d5_7
  - bzip2=1.0.8=h5eee18b_5
  - ca-certificates=2024.3.11=h06a4308_0
  - certifi=2024.2.2=py310h06a4308_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - cuda-cudart=11.7.99=0
  - cuda-cupti=11.7.101=0
  - cuda-libraries=11.7.1=0
  - cuda-nvrtc=11.7.99=0
  - cuda-nvtx=11.7.91=0
  - cuda-runtime=11.7.1=0
  - dataclasses=0.8=pyhc8e2a94_3
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.1=py310h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - fvcore=0.1.5.post20221221=pyhd8ed1ab_0
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py310heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py310h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - iopath=0.1.9=pyhd8ed1ab_0
  - jinja2=3.1.3=py310h06a4308_0
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=11.10.3.66=0
  - libcufft=10.7.2.124=h4fbf590_0
  - libcufile=1.9.0.20=0
  - libcurand=10.3.5.119=0
  - libcusolver=11.4.0.1=0
  - libcusparse=11.7.4.91=0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libnpp=11.7.4.75=0
  - libnvjpeg=11.8.0.2=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp-base=1.3.2=h5eee18b_0
  - lz4-c=1.9.4=h6a678d5_0
  - markupsafe=2.1.3=py310h5eee18b_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py310h5eee18b_1
  - mkl_fft=1.3.8=py310h5eee18b_0
  - mkl_random=1.2.4=py310hdb19cb5_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py310h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=3.1=py310h06a4308_0
  - numpy=1.26.4=py310h5f9d8c6_0
  - numpy-base=1.26.4=py310hb5e798b_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=3.0.13=h7f8727e_0
  - pillow=10.2.0=py310h5eee18b_0
  - pip=23.3.1=py310h06a4308_0
  - portalocker=2.8.2=py310hff52083_1
  - pysocks=1.7.1=py310h06a4308_0
  - python=3.10.14=h955ad1f_0
  - python_abi=3.10=2_cp310
  - pytorch=2.0.1=py3.10_cuda11.7_cudnn8.5.0_0
  - pytorch-cuda=11.7=h778d358_5
  - pytorch-mutex=1.0=cuda
  - pytorch3d=0.7.5=py310_cu117_pyt201
  - pyyaml=6.0=py310h5764c6d_4
  - readline=8.2=h5eee18b_0
  - requests=2.31.0=py310h06a4308_1
  - setuptools=68.2.2=py310h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - sympy=1.12=py310h06a4308_0
  - tabulate=0.9.0=pyhd8ed1ab_1
  - tbb=2021.8.0=hdb19cb5_0
  - termcolor=2.4.0=pyhd8ed1ab_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=2.0.2=py310_cu117
  - torchtriton=2.0.0=py310
  - torchvision=0.15.2=py310_cu117
  - tqdm=4.66.2=pyhd8ed1ab_0
  - typing_extensions=4.9.0=py310h06a4308_1
  - urllib3=2.1.0=py310h06a4308_1
  - wheel=0.41.2=py310h06a4308_0
  - xz=5.4.6=h5eee18b_0
  - yacs=0.1.8=pyhd8ed1ab_0
  - yaml=0.2.5=h7f98852_2
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
    - contourpy==1.2.0
    - cycler==0.12.1
    - fonttools==4.50.0
    - kiwisolver==1.4.5
    - matplotlib==3.8.3
    - ninja==1.11.1.1
    - packaging==24.0
    - pandas==2.2.1
    - plyfile==1.0.3
    - pyntcloud==0.3.1
    - pyparsing==3.1.2
    - python-dateutil==2.9.0.post0
    - pytz==2024.1
    - scipy==1.12.0
    - six==1.16.0
    - torchac==0.9.3
    - tzdata==2024.1
--------------------------------------------------------------------------------
/figure/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/figure/framework.png
--------------------------------------------------------------------------------
/figure/humanbodies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/figure/humanbodies.png
--------------------------------------------------------------------------------
/figure/lidar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/figure/lidar.png
--------------------------------------------------------------------------------
/figure/objects_and_scenes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/figure/objects_and_scenes.png
--------------------------------------------------------------------------------
/kit.py:
--------------------------------------------------------------------------------
import os
import math
import multiprocessing

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from tqdm import tqdm
from plyfile import PlyElement, PlyData
from pyntcloud import PyntCloud
from pytorch3d.ops.knn import knn_gather, knn_points


# core transformation function (reversible YCoCg-R lifting)
def transformRGBToYCoCg(bitdepth, rgb):
    r = rgb[:, 0]
    g = rgb[:, 1]
    b = rgb[:, 2]
    co = r - b
    t = b + (co >> 1) # co >> 1 i.e. co // 2
    cg = g - t
    y = t + (cg >> 1)

    offset = 1 << bitdepth # 2^bitdepth

    # NB: YCoCg-R needs one extra bit for chroma
    return np.column_stack((y, co + offset, cg + offset))


def transformYCoCgToRGB(bitDepth, ycocg):

    offset = 1 << bitDepth
    y0 = ycocg[:,0]
    co = ycocg[:,1] - offset
    cg = ycocg[:,2] - offset

    t = y0 - (cg >> 1)

    g = cg + t
    b = t - (co >> 1)
    r = co + b

    maxVal = (1 << bitDepth) - 1
    r = np.clip(r, 0, maxVal)
    g = np.clip(g, 0, maxVal)
    b = np.clip(b, 0, maxVal)

    return np.column_stack((r,g,b))


def read_point_cloud_ycocg(filepath):
    pc = PyntCloud.from_file(filepath)
    try:
        cols=['x', 'y', 'z','red', 'green', 'blue']
        points=pc.points[cols].values
    except KeyError:
        # some files name the color columns r/g/b instead of red/green/blue
        cols = ['x', 'y', 'z', 'r', 'g', 'b']
        points = pc.points[cols].values
    color = points[:, 3:].astype(np.int16)
    color = transformRGBToYCoCg(8, color)
    # color: int
    # y channel: 0~255
    # co channel: 0~511 (1~511 in our dataset)
    # cg channel: 0~511 (34~476 in our dataset)
    points[:, 3:] = color.astype(float)
    return points


def save_point_cloud_ycocg(pc, path):
    color = pc[:, 3:]
    color = np.round(color).astype(np.int16) # be sure to round before casting with astype
    color = transformYCoCgToRGB(8, color)

    pc = pd.DataFrame(pc, columns=['x', 'y', 'z', 'red', 'green', 'blue'])
    pc[['red','green','blue']] = np.round(color).astype(np.uint8)
    cloud = PyntCloud(pc)
    cloud.to_file(path)

def read_point_cloud_reflactance(filepath):
    plydata = PlyData.read(filepath)
    pc = np.array(np.transpose(np.stack((plydata['vertex']['x'],plydata['vertex']['y'],plydata['vertex']['z'], plydata['vertex']['reflectance'])))).astype(np.float32)
    return pc


def save_point_cloud_reflactance(pc, path, to_rgb=False):

    if to_rgb:
        pc[:, 3:] = pc[:, 3:] / 100
cmap = plt.get_cmap('jet') 94 | color = np.round(cmap(pc[:, 3])[:, :3] * 255) 95 | pc = np.hstack((pc[:, :3], color)) 96 | pc = pd.DataFrame(pc, columns=['x', 'y', 'z', 'red', 'green', 'blue']) 97 | pc[['red','green','blue']] = np.round(np.clip(pc[['red','green','blue']], 0, 255)).astype(np.uint8) 98 | cloud = PyntCloud(pc) 99 | cloud.to_file(path) 100 | else: 101 | scan = pc 102 | vertex = np.array( 103 | [(scan[i,0], scan[i,1], scan[i,2], scan[i,3]) for i in range(scan.shape[0])], 104 | dtype=[ 105 | ("x", np.dtype("float32")), 106 | ("y", np.dtype("float32")), 107 | ("z", np.dtype("float32")), 108 | ("reflectance", np.dtype("uint8")), 109 | ] 110 | ) 111 | PlyElement.describe(vertex, 'vertex', comments=['vertices']) 112 | output_pc = PlyElement.describe(vertex, "vertex") 113 | output_pc = PlyData([output_pc]) 114 | output_pc.write(path) 115 | 116 | 117 | def read_point_clouds_ycocg(file_path_list, bar=True): 118 | print('loading point clouds...') 119 | with multiprocessing.Pool() as p: 120 | if bar: 121 | pcs = list(tqdm(p.imap(read_point_cloud_ycocg, file_path_list, 32), total=len(file_path_list))) 122 | else: 123 | pcs = list(p.imap(read_point_cloud_ycocg, file_path_list, 32)) 124 | return pcs 125 | 126 | 127 | 128 | def n_scale_ball(grouped_xyz): 129 | B, N, K, _ = grouped_xyz.shape 130 | 131 | longest = (grouped_xyz**2).sum(dim=-1).sqrt().max(dim=-1)[0] 132 | scaling = (1) / longest 133 | 134 | grouped_xyz = grouped_xyz * scaling.view(B, N, 1, 1) 135 | 136 | return grouped_xyz 137 | 138 | 139 | class MLP(nn.Module): 140 | def __init__(self, in_channel, mlp, relu, bn): 141 | super(MLP, self).__init__() 142 | 143 | mlp.insert(0, in_channel) 144 | self.mlp_Modules = nn.ModuleList() 145 | for i in range(len(mlp) - 1): 146 | if relu[i]: 147 | if bn[i]: 148 | mlp_Module = nn.Sequential( 149 | nn.Conv2d(mlp[i], mlp[i+1], 1), 150 | nn.BatchNorm2d(mlp[i+1]), 151 | nn.ReLU(), 152 | ) 153 | else: 154 | mlp_Module = nn.Sequential( 155 | nn.Conv2d(mlp[i], mlp[i+1], 1), 
156 | nn.ReLU(), 157 | ) 158 | else: 159 | mlp_Module = nn.Sequential( 160 | nn.Conv2d(mlp[i], mlp[i+1], 1), 161 | ) 162 | self.mlp_Modules.append(mlp_Module) 163 | 164 | 165 | def forward(self, points, squeeze=False): 166 | """ 167 | Input: 168 | points: input points position data, [B, C, N] 169 | Return: 170 | points: feature data, [B, D, N] 171 | """ 172 | if squeeze: 173 | points = points.unsqueeze(-1) # [B, C, N, 1] 174 | 175 | for m in self.mlp_Modules: 176 | points = m(points) 177 | # [B, D, N, 1] 178 | 179 | if squeeze: 180 | points = points.squeeze(-1) # [B, D, N] 181 | 182 | return points 183 | 184 | 185 | class QueryMaskedAttention(nn.Module): 186 | def __init__(self, channel): 187 | super(QueryMaskedAttention, self).__init__() 188 | self.channel = channel 189 | self.k_mlp = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=1) 190 | self.v_mlp = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=1) 191 | self.pe_multiplier, self.pe_bias = True, True 192 | if self.pe_multiplier: 193 | self.linear_p_multiplier = nn.Sequential( 194 | nn.Conv2d(in_channels=3, out_channels=channel, kernel_size=1), 195 | nn.ReLU(inplace=True), 196 | nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=1), 197 | ) 198 | if self.pe_bias: 199 | self.linear_p_bias = nn.Sequential( 200 | nn.Conv2d(in_channels=3, out_channels=channel, kernel_size=1), 201 | nn.ReLU(inplace=True), 202 | nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=1), 203 | ) 204 | self.weight_encoding = nn.Sequential( 205 | nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=1), 206 | nn.ReLU(inplace=True), 207 | nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=1), 208 | ) 209 | self.residual_emb = nn.Sequential( 210 | nn.ReLU(), 211 | nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=1), 212 | ) 213 | 214 | self.softmax = nn.Softmax(dim=2) 215 | 216 | def forward(self, grouped_xyz, grouped_feature): 217 | 218 | 
key = self.k_mlp(grouped_feature) # B, C, K, M 219 | value = self.v_mlp(grouped_feature) # B, C, K, M 220 | 221 | relation_qk = key # - query 222 | if self.pe_multiplier: 223 | pem = self.linear_p_multiplier(grouped_xyz) 224 | relation_qk = relation_qk * pem 225 | if self.pe_bias: 226 | peb = self.linear_p_bias(grouped_xyz) 227 | relation_qk = relation_qk + peb 228 | value = value + peb 229 | 230 | weight = self.weight_encoding(relation_qk) 231 | score = self.softmax(weight) # B, C, K, M 232 | 233 | feature = score*value # B, C, K, M 234 | feature = self.residual_emb(feature) # B, C, K, M 235 | 236 | return feature 237 | 238 | 239 | class PT(nn.Module): 240 | def __init__(self, in_channel, out_channel, n_layers): 241 | super(PT, self).__init__() 242 | self.in_channel = in_channel 243 | self.out_channel = out_channel 244 | self.n_layers = n_layers 245 | self.sa_ls, self.sa_emb_ls = nn.ModuleList(), nn.ModuleList() 246 | self.linear_in = nn.Conv2d(in_channel, out_channel, kernel_size=1) 247 | for i in range(n_layers): 248 | self.sa_emb_ls.append(nn.Sequential( 249 | nn.Conv2d(out_channel, out_channel, kernel_size=1), 250 | nn.ReLU(), 251 | )) 252 | self.sa_ls.append(QueryMaskedAttention(out_channel)) 253 | def forward(self, groped_geo, grouped_attr): 254 | """ 255 | Input: 256 | groped_geo: input points position data, [B, M, K, 3] 257 | groped_attr: input points feature data, [B, M, K, 3] 258 | Return: 259 | feature: output feature data, [B, M, C] 260 | """ 261 | groped_geo, grouped_attr = groped_geo.permute((0, 3, 2, 1)), grouped_attr.permute((0, 3, 2, 1)) # B, _, K, M 262 | feature = self.linear_in(grouped_attr) 263 | for i in range(self.n_layers): 264 | identity = feature 265 | feature = self.sa_emb_ls[i](feature) 266 | output = self.sa_ls[i](groped_geo, feature) 267 | feature = output + identity 268 | feature = feature.sum(dim=2).transpose(1, 2) 269 | return feature 270 | 271 | 272 | def get_cdf(mu, sigma): 273 | M, d = sigma.shape 274 | mu = 
mu.unsqueeze(-1).repeat(1, 1, 256) 275 | sigma = sigma.unsqueeze(-1).repeat(1, 1, 256).clamp(1e-10, 1e10) 276 | gaussian = torch.distributions.laplace.Laplace(mu, sigma) 277 | flag = torch.arange(0, 256).to(sigma.device).view(1, 1, 256).repeat((M, d, 1)) 278 | cdf = gaussian.cdf(flag + 0.5) 279 | 280 | spatial_dimensions = cdf.shape[:-1] + (1,) 281 | zeros = torch.zeros(spatial_dimensions, dtype=cdf.dtype, device=cdf.device) 282 | cdf_with_0 = torch.cat([zeros, cdf], dim=-1) 283 | return cdf_with_0 284 | 285 | 286 | def get_cdf_ycocg(mu, sigma): 287 | M, d = sigma.shape 288 | mu = mu.unsqueeze(-1).repeat(1, 1, 512) 289 | sigma = sigma.unsqueeze(-1).repeat(1, 1, 512).clamp(1e-10, 1e10) 290 | gaussian = torch.distributions.laplace.Laplace(mu, sigma) 291 | flag = torch.arange(0, 512).to(sigma.device).view(1, 1, 512).repeat((M, d, 1)) 292 | cdf = gaussian.cdf(flag + 0.5) 293 | 294 | spatial_dimensions = cdf.shape[:-1] + (1,) 295 | zeros = torch.zeros(spatial_dimensions, dtype=cdf.dtype, device=cdf.device) 296 | cdf_with_0 = torch.cat([zeros, cdf], dim=-1) 297 | return cdf_with_0 298 | 299 | 300 | def get_cdf_reflactance(mu, sigma): 301 | M, d = sigma.shape 302 | mu = mu.unsqueeze(-1).repeat(1, 1, 128) 303 | sigma = sigma.unsqueeze(-1).repeat(1, 1, 128).clamp(1e-10, 1e10) 304 | gaussian = torch.distributions.laplace.Laplace(mu, sigma) 305 | flag = torch.arange(0, 128).to(sigma.device).view(1, 1, 128).repeat((M, d, 1)) 306 | cdf = gaussian.cdf(flag + 0.5) 307 | 308 | spatial_dimensions = cdf.shape[:-1] + (1,) 309 | zeros = torch.zeros(spatial_dimensions, dtype=cdf.dtype, device=cdf.device) 310 | cdf_with_0 = torch.cat([zeros, cdf], dim=-1) 311 | return cdf_with_0 312 | 313 | 314 | def feature_probs_based_mu_sigma(feature, mu, sigma): 315 | sigma = sigma.clamp(1e-10, 1e10) 316 | gaussian = torch.distributions.laplace.Laplace(mu, sigma) 317 | probs = gaussian.cdf(feature + 0.5) - gaussian.cdf(feature - 0.5) 318 | total_bits = torch.sum(torch.clamp(-1.0 * torch.log(probs + 
1e-10) / math.log(2.0), 0, 50)) 319 | return total_bits, probs 320 | 321 | 322 | def get_file_size_in_bits(f): 323 | return os.stat(f).st_size * 8 324 | 325 | 326 | def _convert_to_int_and_normalize(cdf_float, needs_normalization): 327 | """Convert floatingpoint CDF to integers. See README for more info. 328 | 329 | The idea is the following: 330 | When we get the cdf here, it is (assumed to be) between 0 and 1, i.e, 331 | cdf \in [0, 1) 332 | (note that 1 should not be included.) 333 | We now want to convert this to int16 but make sure we do not get 334 | the same value twice, as this would break the arithmetic coder 335 | (you need a strictly monotonically increasing function). 336 | So, if needs_normalization==True, we multiply the input CDF 337 | with 2**16 - (Lp - 1). This means that now, 338 | cdf \in [0, 2**16 - (Lp - 1)]. 339 | Then, in a final step, we add an arange(Lp), which is just a line with 340 | slope one. This ensure that for sure, we will get unique, strictly 341 | monotonically increasing CDFs, which are \in [0, 2**16) 342 | """ 343 | Lp = cdf_float.shape[-1] 344 | factor = torch.tensor( 345 | 2, dtype=torch.float32, device=cdf_float.device).pow_(16) 346 | new_max_value = factor 347 | if needs_normalization: 348 | new_max_value = new_max_value - (Lp - 1) 349 | cdf_float = cdf_float.mul(new_max_value) 350 | cdf_float = cdf_float.round() 351 | cdf = cdf_float.to(dtype=torch.int16, non_blocking=True) 352 | if needs_normalization: 353 | r = torch.arange(Lp, dtype=torch.int16, device=cdf.device) 354 | cdf.add_(r) 355 | return cdf 356 | -------------------------------------------------------------------------------- /model/ckpt.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/I2-Multimedia-Lab/PoLoPCAC/03ebbc1d08a750af516f2f2b87d6a34a509ac9e1/model/ckpt.pt -------------------------------------------------------------------------------- /net.py: 
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from pytorch3d.ops.knn import knn_gather, knn_points

import kit

class Network(nn.Module):
    def __init__(self, local_region, granularity, init_ratio, expand_ratio):
        super(Network, self).__init__()

        self.local_region = local_region
        self.init_ratio = init_ratio
        self.expand_ratio = expand_ratio
        self.granularity = granularity

        self.pt = kit.PT(in_channel=3, out_channel=128, n_layers=5)
        self.mu_sigma_pred = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 3*2),
        )

    def forward(self, batch_x):
        B, N, _ = batch_x.shape

        # random grouping
        base_size = min(N//self.init_ratio, self.granularity)
        window_size = base_size

        context_ls, target_ls = [], []
        cursor = base_size

        while cursor= args.max_steps:
            break

    if global_step >= args.max_steps:
        break
--------------------------------------------------------------------------------
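`transformRGBToYCoCg` and `transformYCoCgToRGB` implement the integer YCoCg-R lifting transform, which is what makes the colour conversion lossless. The sketch below re-implements it on plain Python integers, one RGB triple at a time, to check two things the comments rely on: the inverse recovers the input exactly, and with an offset of 2**bitdepth the chroma channels need one extra bit (Co and Cg fall in 1..511 for 8-bit input). The function names are hypothetical, and because the first lines of `transformRGBToYCoCg` are not shown here, the sketch assumes the standard YCoCg-R definitions `co = r - b` and `t = b + floor(co / 2)`.

```python
def rgb_to_ycocg_r(r: int, g: int, b: int, bitdepth: int = 8) -> tuple[int, int, int]:
    """Forward integer YCoCg-R lifting; chroma is shifted by 2**bitdepth."""
    co = r - b
    t = b + (co >> 1)  # >> 1 on Python ints is floor division by 2
    cg = g - t
    y = t + (cg >> 1)
    offset = 1 << bitdepth
    return y, co + offset, cg + offset


def ycocg_r_to_rgb(y: int, co: int, cg: int, bitdepth: int = 8) -> tuple[int, int, int]:
    """Inverse lifting: undo the forward steps in reverse order."""
    offset = 1 << bitdepth
    co -= offset
    cg -= offset
    t = y - (cg >> 1)
    g = cg + t
    b = t - (co >> 1)
    r = co + b
    return r, g, b


if __name__ == "__main__":
    samples = range(0, 256, 15)  # 0, 15, ..., 255
    for r in samples:
        for g in samples:
            for b in samples:
                assert ycocg_r_to_rgb(*rgb_to_ycocg_r(r, g, b)) == (r, g, b)
    print(rgb_to_ycocg_r(255, 0, 0))  # (63, 511, 129)
```

Each lifting step only adds or subtracts a value the decoder can recompute, so no information is lost even though `>> 1` discards a bit; the extra chroma bit is also why `get_cdf_ycocg` works over 512 symbols rather than 256.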
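`get_cdf` and `_convert_to_int_and_normalize` together turn a predicted Laplace(mu, sigma) into the integer CDF an arithmetic coder needs. The following dependency-free sketch reproduces that pipeline for a single prediction: it evaluates the Laplace CDF at k + 0.5 for every symbol k, prepends a 0, scales by 2**16 - (Lp - 1), rounds, and adds a slope-one ramp so the result is strictly increasing. The helpers `laplace_cdf`, `discretized_cdf` and `to_int_cdf` are hypothetical names, not part of this repository, and plain Python integers stand in for the int16 tensors used above.

```python
import math


def laplace_cdf(x: float, mu: float, b: float) -> float:
    """CDF of a Laplace(mu, b) distribution evaluated at x."""
    z = (x - mu) / b
    if z < 0:
        return 0.5 * math.exp(z)
    return 1.0 - 0.5 * math.exp(-z)


def discretized_cdf(mu: float, b: float, num_symbols: int) -> list[float]:
    """Float CDF over symbols 0..num_symbols-1 with a leading 0.

    Entry k + 1 holds F(k + 0.5), so symbol k owns [F(k - 0.5), F(k + 0.5))
    and symbol 0 also absorbs the left tail.
    """
    b = min(max(b, 1e-10), 1e10)  # same clamp as the repository's get_cdf
    return [0.0] + [laplace_cdf(k + 0.5, mu, b) for k in range(num_symbols)]


def to_int_cdf(cdf_float: list[float], precision: int = 16) -> list[int]:
    """Scale to [0, 2**precision - (Lp - 1)], round, then add 0..Lp-1.

    The slope-one ramp keeps the integer CDF strictly increasing, so every
    symbol keeps a non-zero frequency even where the float CDF is flat.
    """
    lp = len(cdf_float)
    new_max = 2 ** precision - (lp - 1)
    return [round(c * new_max) + i for i, c in enumerate(cdf_float)]


if __name__ == "__main__":
    cdf = to_int_cdf(discretized_cdf(mu=3.0, b=1.0, num_symbols=8))
    print(cdf)
    print([hi - lo for lo, hi in zip(cdf, cdf[1:])])  # per-symbol frequencies
```

When the float CDF saturates at 1.0 (a very small scale), the top entry becomes exactly 2**16, one past the unsigned 16-bit range, so whichever coder consumes these values has to treat the final bound specially.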