├── .github └── workflows │ ├── executor-cd.yml │ └── executor-ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile.gpu ├── README.md ├── config.yml ├── datasets ├── __init__.py └── modelnet40.py ├── executor ├── __init__.py ├── encoder.py └── models │ ├── __init__.py │ ├── classifier_pl.py │ ├── curvenet │ ├── __init__.py │ ├── curvenet.py │ ├── curvenet_utils.py │ └── walk.py │ ├── encoder_pl.py │ ├── modeling.py │ ├── pointconv │ ├── __init__.py │ ├── pointconv.py │ └── pointconv_utils.py │ ├── pointmlp │ ├── __init__.py │ └── pointmlp.py │ ├── pointnet │ ├── __init__.py │ └── pointnet.py │ ├── pointnet2 │ ├── __init__.py │ └── pointnet2.py │ ├── pooling.py │ └── repsurf │ ├── __init__.py │ ├── polar_utils.py │ ├── recons_utils.py │ ├── repsurf.py │ └── repsurf_utils.py ├── finetune.py ├── finetune_pl.py ├── gpu_requirements.txt ├── manifest.yml ├── preprocess ├── __init__.py └── utils.py ├── pretrain_pl.py ├── requirements.txt └── tests ├── __init__.py ├── conftest.py ├── integration ├── __init__.py └── test_encoder.py ├── requirements.txt └── unit ├── __init__.py └── test_exec.py /.github/workflows/executor-cd.yml: -------------------------------------------------------------------------------- 1 | name: CD 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | release: 8 | types: 9 | - created 10 | workflow_dispatch: 11 | # pull_request: 12 | # uncomment the above to test CD in a PR 13 | 14 | jobs: 15 | call-external: 16 | uses: jina-ai/workflows-executors/.github/workflows/cd.yml@master 17 | with: 18 | event_name: ${{ github.event_name }} 19 | secrets: 20 | secret: ${{ secrets.EXECUTOR_SECRET }} 21 | -------------------------------------------------------------------------------- /.github/workflows/executor-ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | call-external: 7 | uses: jina-ai/workflows-executors/.github/workflows/ci.yml@master 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | workspace 132 | Dockerfile 133 | __jina__.Dockerfile 134 | nohup.out 135 | checkpoints 136 | finetune_logs/ 137 | lightning_logs/ 138 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | #- repo: https://github.com/terrencepreilly/darglint 3 | # rev: v1.5.8 4 | # hooks: 5 | # - id: darglint 6 | # files: pqlite/ 7 | # exclude: docs/ 8 | # args: 9 | # - --message-template={path}:{line} {msg_id} {msg} 10 | # - -s=sphinx 11 | # - -z=full 12 | # - -v=2 13 | #- repo: https://github.com/pycqa/pydocstyle 14 | # rev: 5.1.1 # pick a git hash / tag to point to 15 | # hooks: 16 | # - id: pydocstyle 17 | # files: pqlite/ 18 | # exclude: docs/ 19 | # args: 20 | # - --select=D101,D102,D103 21 | - repo: https://github.com/timothycrosley/isort 22 | rev: 5.8.0 23 | hooks: 24 | - id: isort 25 | args: ["--profile", "black"] 26 | - repo: https://github.com/ambv/black 27 | # rev: 20.8b1 28 | rev: 22.3.0 29 | hooks: 30 | - id: black 31 | types: [python] 32 | exclude: docs/ 33 | args: 34 | - -S 35 | - repo: https://github.com/pre-commit/pre-commit-hooks 36 | rev: v4.0.1 37 | hooks: 38 | - id: trailing-whitespace 39 | - id: check-yaml 40 | - id: end-of-file-fixer 41 | - id: requirements-txt-fixer 42 | - id: double-quote-string-fixer 43 | - id: check-merge-conflict 44 | - id: fix-encoding-pragma 45 | args: ["--remove"] 46 | - id: mixed-line-ending 47 | args: ["--fix=lf"] 48 | -------------------------------------------------------------------------------- /Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM jinaai/jina:2-py37-perf 2 | 3 | COPY gpu_requirements.txt gpu_requirements.txt 4 | RUN pip install --no-cache-dir -r gpu_requirements.txt 5 | 6 | COPY . /workdir/ 7 | WORKDIR /workdir 8 | 9 | ENTRYPOINT ["jina", "executor", "--uses", "config.yml"] 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D Mesh Encoder 2 | 3 | An Executor that receives Documents containing point sets data in its `tensor` attribute, with shape `(N, 3)` and encodes it to embeddings of shape `(D,)`. 4 | Now, the following pretrained models are ready to be used to create embeddings: 5 | 6 | - **PointConv-Shapenet-d512**: A **PointConv** model resulted in **512** dimension of embeddings, which is finetuned based on ShapeNet dataset. 7 | - **PointConv-Shapenet-d1024**: A **PointConv** model resulted in **1024** dimension of embeddings, which is finetuned based on ShapeNet dataset. 8 | 9 | 10 | 11 | ## Usage 12 | 13 | #### via Docker image (recommended) 14 | 15 | ```python 16 | from jina import Flow 17 | 18 | f = Flow().add(uses='jinahub+docker://MeshDataEncoder', \ 19 | uses_with={'pretrained_model': 'PointConv-Shapenet-d512'}) 20 | ``` 21 | 22 | #### via source code 23 | 24 | ```python 25 | from jina import Flow 26 | 27 | f = Flow().add(uses='jinahub://MeshDataEncoder', \ 28 | uses_with={'pretrained_model': 'PointConv-Shapenet-d512'}) 29 | ``` 30 | 31 | This Executor offers a GPU tag to speed up encoding. For more information on how to run the executor on GPU, check out the documentation. 32 | 33 | 34 | ## How to finetune pretrained-model? 35 | 36 | ### Finetune pretrained-model with finetuner 37 | #### install finetuner 38 | 39 | ```bash 40 | $ pip install finetuner 41 | ``` 42 | 43 | #### prepare dataset 44 | 45 | TBD... 46 | 47 | #### finetuning model with labeled dataset 48 | 49 | ```bash 50 | $ python finetune.py --help 51 | 52 | $ python finetune.py --model_name pointconv \ 53 | --train_dataset /path/to/train.bin \ 54 | --eval_dataset /path/to/eval.bin \ 55 | --batch_size 128 \ 56 | --epochs 50 57 | ``` 58 | 59 | #### finetuning model with unlabeled dataset 60 | 61 | ```bash 62 | $ python finetune.py --model_name pointconv \ 63 | --train_dataset /path/to/unlabeled_data.bin \ 64 | --interactive 65 | ``` 66 | 67 | ### Finetune pretrained-model with Pytorch Lightning 68 | #### prepare dataset 69 | 70 | To use your customized dataset, you should design your own dataset code, like those in `datasets/` directory. Here `datasets/modelnet40.py` is an example, you must at least implement `__len__` and `__getitem__` functions according to your logics. 71 | 72 | 73 | ```python 74 | class ModelNet40(torch.utils.data.Dataset): 75 | def __init__(self, data_path, sample_points=1024, seed=10) -> None: 76 | super().__init__() 77 | # extract point data and labels from your file, e.g. npz, h5, etc. 78 | data = np.load(data_path) 79 | self.points = data['tensor'] 80 | self.labels = data['labels'] 81 | self.sample_points = sample_points 82 | 83 | def __len__(self): 84 | # return the total length of your data 85 | return len(self.labels) 86 | 87 | def __getitem__(self, index): 88 | return ( 89 | # process on the fly, if needed 90 | preprocess(self.points[index], num_points=self.sample_points), 91 | self.labels[index], 92 | ) 93 | ``` 94 | 95 | #### finetuning model with labeled dataset 96 | 97 | Now we support PointNet, PointConv, PointNet++, PointMLP, RepSurf and Curvenet. To know more details about the arguments, please run `python finetune_pl.py --help` in cmd. 98 | ```bash 99 | $ python finetune_pl.py --help 100 | 101 | $ python finetune_pl.py --model_name pointconv \ 102 | --train_dataset /path/to/train.bin \ 103 | --eval_dataset /path/to/eval.bin \ 104 | --split_ratio 0.8 \ 105 | --checkpoint_path /path/to/checkpoint/ \ 106 | --embed_dim 512 \ 107 | --hidden_dim 1024 \ 108 | --batch_size 128 \ 109 | --epochs 50 110 | ``` 111 | 112 | ## Benchmark 113 | 114 | Below is our pretrained models' performance of 3D point cloud classification on ModelNet40 official test dataset. 115 | 116 | | dataset | model name | batch size | embedding dims | test loss | test overall accuracy | 117 | |------------|------------|------------|----------------|-----------|-----------------------| 118 | | modelnet40 | PointNet | 32 | 256 | 0.63 | 0.8225 | 119 | | modelnet40 | PointNet | 32 | 512 | 0.63 | 0.8254 | 120 | | modelnet40 | PointNet | 32 | 1024 | 0.65 | 0.8148 | 121 | | modelnet40 | PointNet++ | 32 | 256 | 0.48 | 0.863 | 122 | | modelnet40 | PointNet++ | 32 | 512 | 0.44 | 0.8712 | 123 | | modelnet40 | PointNet++ | 32 | 1024 | 0.47 | 0.8655 | 124 | | modelnet40 | PointConv | 32 | 128 | 0.55 | 0.8452 | 125 | | modelnet40 | PointConv | 32 | 256 | 0.53 | 0.8517 | 126 | | modelnet40 | PointConv | 32 | 512 | 0.54 | 0.8505 | 127 | | modelnet40 | PointConv | 32 | 1024 | 0.58 | 0.8533 | 128 | | modelnet40 | PointMLP | 32 | 64 | 0.46 | 0.8728 | 129 | | modelnet40 | RepSurf | 32 | 256 | 0.44 | 0.8776 | 130 | | modelnet40 | RepSurf | 32 | 512 | 0.45 | 0.8655 | 131 | | modelnet40 | RepSurf | 32 | 1024 | 0.43 | 0.8724 | 132 | | modelnet40 | CurveNet | 32 | 128 | 0.45 | 0.8651 | 133 | | modelnet40 | CurveNet | 32 | 256 | 0.45 | 0.8647 | 134 | | modelnet40 | CurveNet | 32 | 512 | 0.47 | 0.8687 | 135 | | modelnet40 | CurveNet | 32 | 1024 | 0.48 | 0.857 | 136 | 137 | ## References 138 | 139 | - [PointNet](https://arxiv.org/abs/1612.00593): Deep Learning on Point Sets for 3D Classification and Segmentation 140 | - [PointConv](https://arxiv.org/abs/1811.07246): Deep Convolutional Networks on 3D Point Clouds 141 | - [PointNet++](http://arxiv.org/abs/1706.02413): PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space 142 | - [PointMLP](http://arxiv.org/abs/2202.07123): Rethinking Network Design and Local Geometry in Point Cloud 143 | - [RepSurf](https://arxiv.org/abs/2205.05740): Surface Representation for Point Clouds 144 | - [CurveNet](https://arxiv.org/abs/2105.01288): Walk in the Cloud: Learning Curves for Point Clouds Shape Analysis 145 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | jtype: MeshDataEncoder 2 | metas: 3 | py_modules: 4 | - executor/__init__.py 5 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .modelnet40 import ModelNet40 2 | -------------------------------------------------------------------------------- /datasets/modelnet40.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | from preprocess import preprocess 5 | 6 | 7 | class ModelNet40(Dataset): 8 | def __init__(self, data_path, sample_points=1024, seed=10) -> None: 9 | super().__init__() 10 | data = np.load(data_path) 11 | self.points = data['tensor'] 12 | self.labels = data['labels'] 13 | self.sample_points = sample_points 14 | 15 | def __len__(self): 16 | return len(self.labels) 17 | 18 | def __getitem__(self, index): 19 | return ( 20 | preprocess(self.points[index], num_points=self.sample_points), 21 | self.labels[index], 22 | ) 23 | -------------------------------------------------------------------------------- /executor/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import MeshDataEncoder 2 | from .models.classifier_pl import MeshDataClassifierPL 3 | from .models.encoder_pl import MeshDataEncoderPL 4 | -------------------------------------------------------------------------------- /executor/encoder.py: -------------------------------------------------------------------------------- 1 | __copyright__ = 'Copyright (c) 2022 Jina AI Limited. All rights reserved.' 2 | __license__ = 'Apache-2.0' 3 | 4 | from typing import Dict, Optional 5 | 6 | import numpy as np 7 | import torch 8 | from jina import Document, DocumentArray, Executor, requests 9 | 10 | from .models import MeshDataModel 11 | 12 | AVAILABLE_MODELS = { 13 | 'PointNet-Shapenet-d1024': { 14 | 'model_name': 'pointnet', 15 | 'hidden_dim': 1024, 16 | 'embed_dim': 1024, 17 | 'model_path': '', 18 | }, 19 | 'PointConv-Shapenet-d1024': { 20 | 'model_name': 'pointconv', 21 | 'hidden_dim': 1024, 22 | 'embed_dim': 1024, 23 | 'model_path': 'https://jina-pretrained-models.s3.us-west-1.amazonaws.com/mesh_models/pointconv-shapenet-d1024.pth', 24 | }, 25 | 'PointNet-Shapenet-d512': { 26 | 'model_name': 'pointnet', 27 | 'hidden_dim': 1024, 28 | 'embed_dim': 512, 29 | 'model_path': '', 30 | }, 31 | 'PointConv-Shapenet-d512': { 32 | 'model_name': 'pointconv', 33 | 'hidden_dim': 1024, 34 | 'embed_dim': 512, 35 | 'model_path': 'https://jina-pretrained-models.s3.us-west-1.amazonaws.com/mesh_models/pointconv-shapenet-d512.pth', 36 | }, 37 | } 38 | 39 | 40 | def normalize(doc: 'Document'): 41 | points = doc.tensor 42 | points = points - np.expand_dims(np.mean(points, axis=0), 0) # center 43 | dist = np.max(np.sqrt(np.sum(points ** 2, axis=1)), 0) 44 | points = points / dist # scale 45 | doc.tensor = points.astype(np.float32) 46 | return doc 47 | 48 | 49 | class MeshDataEncoder(Executor): 50 | """ 51 | An executor that encodes 3D mesh data document. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | pretrained_model: str = 'PointConv-Shapenet-d512', 57 | default_model_name: str = 'pointconv', 58 | model_path: Optional[str] = None, 59 | hidden_dim: int = 1024, 60 | embed_dim: int = 1024, 61 | input_shape: str = 'bnc', 62 | device: str = 'cpu', 63 | batch_size: int = 64, 64 | filters: Optional[dict] = None, 65 | **kwargs, 66 | ) -> None: 67 | """ 68 | :param pretrained_model: The pretrained model path. 69 | :param default_model_name: The name of the default model. Models listed on: 70 | https://github.com/jina-ai/executor-3d-encoder 71 | :param model_path: The path of the trained models checkpoint. 72 | :param emb_dims: The dimension of embeddings. 73 | :param input_shape: The shape of Input Point Cloud (b: batch, n: no of points, c: channels) 74 | :param device: The device to use. 75 | :param batch_size: The batch size to use. 76 | :param filters: The filter condition that the documents need to fulfill before reaching the Executor. 77 | The condition can be defined in the form of a `DocArray query condition ` 78 | """ 79 | super().__init__(**kwargs) 80 | 81 | model_path = None 82 | if pretrained_model in AVAILABLE_MODELS: 83 | config = AVAILABLE_MODELS[pretrained_model] 84 | model_name = config.pop('model_name') 85 | model_path = config.pop('model_path') 86 | embed_dim = config.pop('embed_dim') 87 | hidden_dim = config.pop('hidden_dim') 88 | else: 89 | model_name = default_model_name 90 | 91 | self._model = MeshDataModel( 92 | model_name=model_name, 93 | hidden_dim=hidden_dim, 94 | embed_dim=embed_dim, 95 | pretrained=False if model_path else True, 96 | ) 97 | self._model.eval() 98 | 99 | if model_path: 100 | if model_path.startswith('http'): 101 | import os 102 | import urllib.request 103 | from pathlib import Path 104 | 105 | cache_dir = Path.home() / '.cache' / 'jina-models' 106 | cache_dir.mkdir(parents=True, exist_ok=True) 107 | 108 | file_url = model_path 109 | file_name = os.path.basename(model_path) 110 | model_path = cache_dir / file_name 111 | 112 | if not model_path.exists(): 113 | print(f'=> download {file_url} to {model_path}') 114 | urllib.request.urlretrieve(file_url, model_path) 115 | 116 | checkpoint = torch.load(model_path, map_location='cpu') 117 | self._model.load_state_dict(checkpoint) 118 | 119 | self._device = device 120 | self._batch_size = batch_size 121 | self._filters = filters 122 | 123 | @requests 124 | def encode(self, docs: 'DocumentArray', **_): 125 | """Encode docs.""" 126 | if docs is None: 127 | return 128 | 129 | if self._filters: 130 | filtered_docs = docs.find(self._filters) 131 | else: 132 | filtered_docs = docs 133 | 134 | filtered_docs.apply(normalize) 135 | filtered_docs.embed( 136 | self._model, 137 | device=self._device, 138 | batch_size=self._batch_size, 139 | to_numpy=True, 140 | ) 141 | -------------------------------------------------------------------------------- /executor/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling import MeshDataModel 2 | -------------------------------------------------------------------------------- /executor/models/classifier_pl.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch.nn import functional as F 4 | from torchmetrics.functional import accuracy 5 | 6 | from .modeling import get_model 7 | 8 | DEFAULT_MODEL_NAME = 'pointconv' 9 | 10 | 11 | class MeshDataClassifierPL(pl.LightningModule): 12 | def __init__( 13 | self, 14 | model_name: str = DEFAULT_MODEL_NAME, 15 | hidden_dim: int = 1024, 16 | input_shape: str = 'bnc', 17 | device: str = 'cpu', 18 | batch_size: int = 64, 19 | **kwargs, 20 | ) -> None: 21 | super().__init__(**kwargs) 22 | 23 | self.save_hyperparameters() 24 | 25 | self._model = get_model( 26 | model_name=model_name, 27 | hidden_dim=hidden_dim, 28 | input_shape=input_shape, 29 | classifier=True, 30 | ) 31 | 32 | self._device = device 33 | self._batch_size = batch_size 34 | # bnc 35 | self.example_input_array = torch.zeros((batch_size, 1024, 3)) 36 | self._model_name = model_name 37 | 38 | def forward(self, x): 39 | return self._model(x) 40 | 41 | def configure_optimizers(self): 42 | # optimizer and scheduler adapted from upstream 43 | if self._model_name == 'pointmlp': 44 | # 300 epochs 45 | optimizer = torch.optim.SGD( 46 | self.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-4 47 | ) 48 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 49 | optimizer, 300, eta_min=0.005, last_epoch=-1 50 | ) 51 | elif self._model_name == 'repsurf': 52 | # 500 epochs 53 | optimizer = torch.optim.Adam( 54 | self.parameters(), 55 | lr=0.001, 56 | betas=(0.9, 0.999), 57 | eps=1e-08, 58 | weight_decay=1e-4, 59 | ) 60 | scheduler = torch.optim.lr_scheduler.StepLR( 61 | optimizer, step_size=20, gamma=0.7 62 | ) 63 | else: 64 | optimizer = torch.optim.Adam(self.parameters(), lr=5e-4) 65 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 66 | optimizer, milestones=[30, 60], gamma=0.5 67 | ) 68 | 69 | return {'optimizer': optimizer, 'lr_scheduler': scheduler} 70 | 71 | def training_step(self, train_batch, _batch_idx): 72 | x, y = train_batch 73 | logits = self._model(x) 74 | loss = F.nll_loss(F.log_softmax(logits, dim=1), y) 75 | self.log('train_loss', loss) 76 | return loss 77 | 78 | def evaluate(self, batch, stage): 79 | x, y = batch 80 | logits = self._model(x) 81 | loss = F.nll_loss(F.log_softmax(logits, dim=1), y) 82 | 83 | preds = torch.argmax(logits, dim=1) 84 | acc = accuracy(preds, y) 85 | self.log(f'{stage}_loss', loss, prog_bar=True) 86 | self.log(f'{stage}_acc', acc, prog_bar=True) 87 | 88 | def validation_step(self, val_batch, _batch_idx): 89 | self.evaluate(val_batch, 'val') 90 | 91 | def test_step(self, test_batch, _batch_idx): 92 | self.evaluate(test_batch, 'test') 93 | -------------------------------------------------------------------------------- /executor/models/curvenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .curvenet import CurveNet 2 | -------------------------------------------------------------------------------- /executor/models/curvenet/curvenet.py: -------------------------------------------------------------------------------- 1 | """Adapted from: https://github.com/tiangexiang/CurveNet/blob/main/core/models/curvenet_cls.py 2 | """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .curvenet_utils import CIC, LPFA 9 | 10 | curve_config = { 11 | 'default': [[100, 5], [100, 5], None, None], 12 | 'long': [[10, 30], None, None, None], 13 | } 14 | 15 | 16 | class CurveNet(nn.Module): 17 | def __init__( 18 | self, 19 | emb_dims=1024, 20 | num_classes=40, 21 | k=20, 22 | setting='default', 23 | classifier=False, 24 | input_shape='bnc', 25 | ): 26 | """The code is adapted from https://github.com/tiangexiang/CurveNet.git 27 | 28 | Args: 29 | emb_dims (int, optional): The dimensions for output embeddings. Defaults to 1024. 30 | num_classes (int, optional): the number of classes in classification tasks. Defaults to 40. 31 | k (int, optional): top k for clustering. Defaults to 20. 32 | setting (str, optional): choices for curvenet's settings. Defaults to 'default'. 33 | classifier (bool, optional): use curvenet as a classifier or a representative model. Defaults to False. 34 | input_shape (str, optional): the shape of the input data, which can be 'bnc' or 'bcn'. Defaults to 'bnc'. 35 | """ 36 | super(CurveNet, self).__init__() 37 | 38 | self.use_classifier = classifier 39 | self.input_shape = input_shape 40 | 41 | assert setting in curve_config 42 | 43 | additional_channel = 32 44 | self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True) 45 | 46 | # encoder 47 | self.cic11 = CIC( 48 | npoint=1024, 49 | radius=0.05, 50 | k=k, 51 | in_channels=additional_channel, 52 | output_channels=64, 53 | bottleneck_ratio=2, 54 | mlp_num=1, 55 | curve_config=curve_config[setting][0], 56 | ) 57 | self.cic12 = CIC( 58 | npoint=1024, 59 | radius=0.05, 60 | k=k, 61 | in_channels=64, 62 | output_channels=64, 63 | bottleneck_ratio=4, 64 | mlp_num=1, 65 | curve_config=curve_config[setting][0], 66 | ) 67 | 68 | self.cic21 = CIC( 69 | npoint=1024, 70 | radius=0.05, 71 | k=k, 72 | in_channels=64, 73 | output_channels=128, 74 | bottleneck_ratio=2, 75 | mlp_num=1, 76 | curve_config=curve_config[setting][1], 77 | ) 78 | self.cic22 = CIC( 79 | npoint=1024, 80 | radius=0.1, 81 | k=k, 82 | in_channels=128, 83 | output_channels=128, 84 | bottleneck_ratio=4, 85 | mlp_num=1, 86 | curve_config=curve_config[setting][1], 87 | ) 88 | 89 | self.cic31 = CIC( 90 | npoint=256, 91 | radius=0.1, 92 | k=k, 93 | in_channels=128, 94 | output_channels=256, 95 | bottleneck_ratio=2, 96 | mlp_num=1, 97 | curve_config=curve_config[setting][2], 98 | ) 99 | self.cic32 = CIC( 100 | npoint=256, 101 | radius=0.2, 102 | k=k, 103 | in_channels=256, 104 | output_channels=256, 105 | bottleneck_ratio=4, 106 | mlp_num=1, 107 | curve_config=curve_config[setting][2], 108 | ) 109 | 110 | self.cic41 = CIC( 111 | npoint=64, 112 | radius=0.2, 113 | k=k, 114 | in_channels=256, 115 | output_channels=512, 116 | bottleneck_ratio=2, 117 | mlp_num=1, 118 | curve_config=curve_config[setting][3], 119 | ) 120 | self.cic42 = CIC( 121 | npoint=64, 122 | radius=0.4, 123 | k=k, 124 | in_channels=512, 125 | output_channels=512, 126 | bottleneck_ratio=4, 127 | mlp_num=1, 128 | curve_config=curve_config[setting][3], 129 | ) 130 | 131 | self.conv0 = nn.Sequential( 132 | nn.Conv1d(512, 1024, kernel_size=1, bias=False), 133 | nn.BatchNorm1d(1024), 134 | nn.ReLU(inplace=True), 135 | ) 136 | self.conv1 = nn.Linear(1024 * 2, emb_dims, bias=False) 137 | self.conv2 = nn.Linear(emb_dims, num_classes) 138 | self.bn1 = nn.BatchNorm1d(emb_dims) 139 | self.dp1 = nn.Dropout(p=0.5) 140 | 141 | def forward(self, xyz): 142 | if self.input_shape == 'bnc': 143 | xyz = xyz.permute(0, 2, 1) 144 | l0_points = self.lpfa(xyz, xyz) 145 | 146 | l1_xyz, l1_points = self.cic11(xyz, l0_points) 147 | l1_xyz, l1_points = self.cic12(l1_xyz, l1_points) 148 | 149 | l2_xyz, l2_points = self.cic21(l1_xyz, l1_points) 150 | l2_xyz, l2_points = self.cic22(l2_xyz, l2_points) 151 | 152 | l3_xyz, l3_points = self.cic31(l2_xyz, l2_points) 153 | l3_xyz, l3_points = self.cic32(l3_xyz, l3_points) 154 | 155 | l4_xyz, l4_points = self.cic41(l3_xyz, l3_points) 156 | l4_xyz, l4_points = self.cic42(l4_xyz, l4_points) 157 | 158 | x = self.conv0(l4_points) 159 | x_max = F.adaptive_max_pool1d(x, 1) 160 | x_avg = F.adaptive_avg_pool1d(x, 1) 161 | 162 | x = torch.cat((x_max, x_avg), dim=1).squeeze(-1) 163 | x = F.relu(self.bn1(self.conv1(x).unsqueeze(-1)), inplace=True).squeeze(-1) 164 | if self.use_classifier: 165 | x = self.dp1(x) 166 | x = self.conv2(x) 167 | return x 168 | -------------------------------------------------------------------------------- /executor/models/curvenet/curvenet_utils.py: -------------------------------------------------------------------------------- 1 | """Adapted from: https://github.com/tiangexiang/CurveNet/blob/main/core/models/curvenet_util.py 2 | """ 3 | from time import time 4 | from tracemalloc import start 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .walk import Walk 12 | 13 | 14 | def knn(x, k): 15 | k = k + 1 16 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 17 | xx = torch.sum(x**2, dim=1, keepdim=True) 18 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 19 | 20 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 21 | return idx 22 | 23 | 24 | def normal_knn(x, k): 25 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 26 | xx = torch.sum(x**2, dim=1, keepdim=True) 27 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 28 | 29 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 30 | return idx 31 | 32 | 33 | def pc_normalize(pc): 34 | l = pc.shape[0] 35 | centroid = np.mean(pc, axis=0) 36 | pc = pc - centroid 37 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 38 | pc = pc / m 39 | return pc 40 | 41 | 42 | def square_distance(src, dst): 43 | """ 44 | Calculate Euclid distance between each two points. 45 | """ 46 | B, N, _ = src.shape 47 | _, M, _ = dst.shape 48 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 49 | dist += torch.sum(src**2, -1).view(B, N, 1) 50 | dist += torch.sum(dst**2, -1).view(B, 1, M) 51 | return dist 52 | 53 | 54 | def index_points(points, idx): 55 | """ 56 | Input: 57 | points: input points data, [B, N, C] 58 | idx: sample index data, [B, S] 59 | Return: 60 | new_points:, indexed points data, [B, S, C] 61 | """ 62 | device = points.device 63 | B = points.shape[0] 64 | view_shape = list(idx.shape) 65 | view_shape[1:] = [1] * (len(view_shape) - 1) 66 | repeat_shape = list(idx.shape) 67 | repeat_shape[0] = 1 68 | batch_indices = ( 69 | torch.arange(B, dtype=torch.long) 70 | .to(device) 71 | .view(view_shape) 72 | .repeat(repeat_shape) 73 | ) 74 | new_points = points[batch_indices, idx, :] 75 | return new_points 76 | 77 | 78 | def farthest_point_sample(xyz, npoint): 79 | """ 80 | Input: 81 | xyz: pointcloud data, [B, N, 3] 82 | npoint: number of samples 83 | Return: 84 | centroids: sampled pointcloud index, [B, npoint] 85 | """ 86 | device = xyz.device 87 | B, N, C = xyz.shape 88 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 89 | distance = torch.ones(B, N).to(device) * 1e10 90 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) * 0 91 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 92 | for i in range(npoint): 93 | centroids[:, i] = farthest 94 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 95 | dist = torch.sum((xyz - centroid) ** 2, -1) 96 | mask = dist < distance 97 | distance[mask] = dist[mask] 98 | farthest = torch.max(distance, -1)[1] 99 | return centroids 100 | 101 | 102 | def query_ball_point(radius, nsample, xyz, new_xyz): 103 | """ 104 | Input: 105 | radius: local region radius 106 | nsample: max sample number in local region 107 | xyz: all points, [B, N, 3] 108 | new_xyz: query points, [B, S, 3] 109 | Return: 110 | group_idx: grouped points index, [B, S, nsample] 111 | """ 112 | device = xyz.device 113 | B, N, C = xyz.shape 114 | _, S, _ = new_xyz.shape 115 | group_idx = ( 116 | torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 117 | ) 118 | sqrdists = square_distance(new_xyz, xyz) 119 | group_idx[sqrdists > radius**2] = N 120 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 121 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 122 | mask = group_idx == N 123 | group_idx[mask] = group_first[mask] 124 | return group_idx 125 | 126 | 127 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 128 | """ 129 | Input: 130 | npoint: 131 | radius: 132 | nsample: 133 | xyz: input points position data, [B, N, 3] 134 | points: input points data, [B, N, D] 135 | Return: 136 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 137 | new_points: sampled points data, [B, npoint, nsample, 3+D] 138 | """ 139 | new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint)) 140 | torch.cuda.empty_cache() 141 | 142 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 143 | torch.cuda.empty_cache() 144 | 145 | new_points = index_points(points, idx) 146 | torch.cuda.empty_cache() 147 | 148 | if returnfps: 149 | return new_xyz, new_points, idx 150 | else: 151 | return new_xyz, new_points 152 | 153 | 154 | class Attention_block(nn.Module): 155 | ''' 156 | Used in attention U-Net. 157 | ''' 158 | 159 | def __init__(self, F_g, F_l, F_int): 160 | super(Attention_block, self).__init__() 161 | self.W_g = nn.Sequential( 162 | nn.Conv1d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), 163 | nn.BatchNorm1d(F_int), 164 | ) 165 | 166 | self.W_x = nn.Sequential( 167 | nn.Conv1d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), 168 | nn.BatchNorm1d(F_int), 169 | ) 170 | 171 | self.psi = nn.Sequential( 172 | nn.Conv1d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 173 | nn.BatchNorm1d(1), 174 | nn.Sigmoid(), 175 | ) 176 | 177 | def forward(self, g, x): 178 | g1 = self.W_g(g) 179 | x1 = self.W_x(x) 180 | psi = F.leaky_relu(g1 + x1, negative_slope=0.2) 181 | psi = self.psi(psi) 182 | 183 | return psi, 1.0 - psi 184 | 185 | 186 | class LPFA(nn.Module): 187 | def __init__( 188 | self, in_channel, out_channel, k, mlp_num=2, initial=False, shape='bnc' 189 | ): 190 | super(LPFA, self).__init__() 191 | self.k = k 192 | self.initial = initial 193 | self.shape = shape 194 | 195 | if not initial: 196 | self.xyz2feature = nn.Sequential( 197 | nn.Conv2d(9, in_channel, kernel_size=1, bias=False), 198 | nn.BatchNorm2d(in_channel), 199 | ) 200 | 201 | self.mlp = [] 202 | for _ in range(mlp_num): 203 | self.mlp.append( 204 | nn.Sequential( 205 | nn.Conv2d(in_channel, out_channel, 1, bias=False), 206 | nn.BatchNorm2d(out_channel), 207 | nn.LeakyReLU(0.2), 208 | ) 209 | ) 210 | in_channel = out_channel 211 | self.mlp = nn.Sequential(*self.mlp) 212 | 213 | def forward(self, x, xyz, idx=None): 214 | x = self.group_feature(x, xyz, idx) 215 | x = self.mlp(x) 216 | 217 | if self.initial: 218 | x = x.max(dim=-1, keepdim=False)[0] 219 | else: 220 | x = x.mean(dim=-1, keepdim=False) 221 | 222 | return x 223 | 224 | def group_feature(self, x, xyz, idx): 225 | batch_size, num_dims, num_points = x.size() 226 | 227 | if idx is None: 228 | idx = knn(xyz, k=self.k)[:, :, : self.k] # (batch_size, num_points, k) 229 | 230 | idx_base = ( 231 | torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points 232 | ) 233 | idx = idx + idx_base 234 | idx = idx.view(-1) 235 | 236 | xyz = xyz.transpose(2, 1).contiguous() # bs, n, 3 237 | point_feature = xyz.view(batch_size * num_points, -1)[idx, :] 238 | point_feature = point_feature.view( 239 | batch_size, num_points, self.k, -1 240 | ) # bs, n, k, 3 241 | points = xyz.view(batch_size, num_points, 1, 3).expand( 242 | -1, -1, self.k, -1 243 | ) # bs, n, k, 3 244 | 245 | point_feature = ( 246 | torch.cat((points, point_feature, point_feature - points), dim=3) 247 | .permute(0, 3, 1, 2) 248 | .contiguous() 249 | ) 250 | 251 | if self.initial: 252 | return point_feature 253 | 254 | x = x.transpose(2, 1).contiguous() # bs, n, c 255 | feature = x.view(batch_size * num_points, -1)[idx, :] 256 | feature = feature.view(batch_size, num_points, self.k, num_dims) # bs, n, k, c 257 | x = x.view(batch_size, num_points, 1, num_dims) 258 | feature = feature - x 259 | 260 | feature = feature.permute(0, 3, 1, 2).contiguous() 261 | point_feature = self.xyz2feature(point_feature) # bs, c, n, k 262 | feature = F.leaky_relu(feature + point_feature, 0.2) 263 | return feature # bs, c, n, k 264 | 265 | 266 | class PointNetFeaturePropagation(nn.Module): 267 | def __init__(self, in_channel, mlp, att=None): 268 | super(PointNetFeaturePropagation, self).__init__() 269 | self.mlp_convs = nn.ModuleList() 270 | self.mlp_bns = nn.ModuleList() 271 | last_channel = in_channel 272 | self.att = None 273 | if att is not None: 274 | self.att = Attention_block(F_g=att[0], F_l=att[1], F_int=att[2]) 275 | 276 | for out_channel in mlp: 277 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 278 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 279 | last_channel = out_channel 280 | 281 | def forward(self, xyz1, xyz2, points1, points2): 282 | """ 283 | Input: 284 | xyz1: input points position data, [B, C, N] 285 | xyz2: sampled input points position data, [B, C, S], skipped xyz 286 | points1: input points data, [B, D, N] 287 | points2: input points data, [B, D, S], skipped features 288 | Return: 289 | new_points: upsampled points data, [B, D', N] 290 | """ 291 | xyz1 = xyz1.permute(0, 2, 1) 292 | xyz2 = xyz2.permute(0, 2, 1) 293 | 294 | points2 = points2.permute(0, 2, 1) 295 | B, N, C = xyz1.shape 296 | _, S, _ = xyz2.shape 297 | 298 | if S == 1: 299 | interpolated_points = points2.repeat(1, N, 1) 300 | else: 301 | dists = square_distance(xyz1, xyz2) 302 | dists, idx = dists.sort(dim=-1) 303 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 304 | 305 | dist_recip = 1.0 / (dists + 1e-8) 306 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 307 | weight = dist_recip / norm 308 | interpolated_points = torch.sum( 309 | index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2 310 | ) 311 | 312 | # skip attention 313 | if self.att is not None: 314 | psix, psig = self.att(interpolated_points.permute(0, 2, 1), points1) 315 | points1 = points1 * psix 316 | 317 | if points1 is not None: 318 | points1 = points1.permute(0, 2, 1) 319 | new_points = torch.cat([points1, interpolated_points], dim=-1) 320 | else: 321 | new_points = interpolated_points 322 | 323 | new_points = new_points.permute(0, 2, 1) 324 | 325 | for i, conv in enumerate(self.mlp_convs): 326 | bn = self.mlp_bns[i] 327 | new_points = F.leaky_relu(bn(conv(new_points)), 0.2) 328 | 329 | return new_points 330 | 331 | 332 | class CIC(nn.Module): 333 | def __init__( 334 | self, 335 | npoint, 336 | radius, 337 | k, 338 | in_channels, 339 | output_channels, 340 | bottleneck_ratio=2, 341 | mlp_num=2, 342 | curve_config=None, 343 | ): 344 | super(CIC, self).__init__() 345 | self.in_channels = in_channels 346 | self.output_channels = output_channels 347 | self.bottleneck_ratio = bottleneck_ratio 348 | self.radius = radius 349 | self.k = k 350 | self.npoint = npoint 351 | 352 | planes = in_channels // bottleneck_ratio 353 | 354 | self.use_curve = curve_config is not None 355 | if self.use_curve: 356 | self.curveaggregation = CurveAggregation(planes) 357 | self.curvegrouping = CurveGrouping( 358 | planes, k, curve_config[0], curve_config[1] 359 | ) 360 | 361 | self.conv1 = nn.Sequential( 362 | nn.Conv1d(in_channels, planes, kernel_size=1, bias=False), 363 | nn.BatchNorm1d(in_channels // bottleneck_ratio), 364 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 365 | ) 366 | 367 | self.conv2 = nn.Sequential( 368 | nn.Conv1d(planes, output_channels, kernel_size=1, bias=False), 369 | nn.BatchNorm1d(output_channels), 370 | ) 371 | 372 | if in_channels != output_channels: 373 | self.shortcut = nn.Sequential( 374 | nn.Conv1d(in_channels, output_channels, kernel_size=1, bias=False), 375 | nn.BatchNorm1d(output_channels), 376 | ) 377 | 378 | self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 379 | 380 | self.maxpool = MaskedMaxPool(npoint, radius, k) 381 | 382 | self.lpfa = LPFA(planes, planes, k, mlp_num=mlp_num, initial=False) 383 | 384 | def forward(self, xyz, x): 385 | 386 | # max pool 387 | if xyz.size(-1) != self.npoint: 388 | xyz, x = self.maxpool(xyz.transpose(1, 2).contiguous(), x) 389 | xyz = xyz.transpose(1, 2) 390 | 391 | shortcut = x 392 | x = self.conv1(x) # bs, c', n 393 | 394 | idx = knn(xyz, self.k) 395 | 396 | if self.use_curve: 397 | # curve grouping 398 | curves = self.curvegrouping(x, xyz, idx[:, :, 1:]) # avoid self-loop 399 | 400 | # curve aggregation 401 | x = self.curveaggregation(x, curves) 402 | 403 | x = self.lpfa(x, xyz, idx=idx[:, :, : self.k]) # bs, c', n, k 404 | 405 | x = self.conv2(x) # bs, c, n 406 | 407 | if self.in_channels != self.output_channels: 408 | shortcut = self.shortcut(shortcut) 409 | 410 | x = self.relu(x + shortcut) 411 | 412 | return xyz, x 413 | 414 | 415 | class CurveAggregation(nn.Module): 416 | def __init__(self, in_channel): 417 | super(CurveAggregation, self).__init__() 418 | self.in_channel = in_channel 419 | mid_feature = in_channel // 2 420 | self.conva = nn.Conv1d(in_channel, mid_feature, kernel_size=1, bias=False) 421 | self.convb = nn.Conv1d(in_channel, mid_feature, kernel_size=1, bias=False) 422 | self.convc = nn.Conv1d(in_channel, mid_feature, kernel_size=1, bias=False) 423 | self.convn = nn.Conv1d(mid_feature, mid_feature, kernel_size=1, bias=False) 424 | self.convl = nn.Conv1d(mid_feature, mid_feature, kernel_size=1, bias=False) 425 | self.convd = nn.Sequential( 426 | nn.Conv1d(mid_feature * 2, in_channel, kernel_size=1, bias=False), 427 | nn.BatchNorm1d(in_channel), 428 | ) 429 | self.line_conv_att = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False) 430 | 431 | def forward(self, x, curves): 432 | curves_att = self.line_conv_att(curves) # bs, 1, c_n, c_l 433 | 434 | curver_inter = torch.sum( 435 | curves * F.softmax(curves_att, dim=-1), dim=-1 436 | ) # bs, c, c_n 437 | curves_intra = torch.sum( 438 | curves * F.softmax(curves_att, dim=-2), dim=-2 439 | ) # bs, c, c_l 440 | 441 | curver_inter = self.conva(curver_inter) # bs, mid, n 442 | curves_intra = self.convb(curves_intra) # bs, mid ,n 443 | 444 | x_logits = self.convc(x).transpose(1, 2).contiguous() 445 | x_inter = F.softmax(torch.bmm(x_logits, curver_inter), dim=-1) # bs, n, c_n 446 | x_intra = F.softmax(torch.bmm(x_logits, curves_intra), dim=-1) # bs, l, c_l 447 | 448 | curver_inter = self.convn(curver_inter).transpose(1, 2).contiguous() 449 | curves_intra = self.convl(curves_intra).transpose(1, 2).contiguous() 450 | 451 | x_inter = torch.bmm(x_inter, curver_inter) 452 | x_intra = torch.bmm(x_intra, curves_intra) 453 | 454 | curve_features = ( 455 | torch.cat((x_inter, x_intra), dim=-1).transpose(1, 2).contiguous() 456 | ) 457 | x = x + self.convd(curve_features) 458 | 459 | return F.leaky_relu(x, negative_slope=0.2) 460 | 461 | 462 | class CurveGrouping(nn.Module): 463 | def __init__(self, in_channel, k, curve_num, curve_length): 464 | super(CurveGrouping, self).__init__() 465 | self.curve_num = curve_num 466 | self.curve_length = curve_length 467 | self.in_channel = in_channel 468 | self.k = k 469 | 470 | self.att = nn.Conv1d(in_channel, 1, kernel_size=1, bias=False) 471 | 472 | self.walk = Walk(in_channel, k, curve_num, curve_length) 473 | 474 | def forward(self, x, xyz, idx): 475 | # starting point selection in self attention style 476 | x_att = torch.sigmoid(self.att(x)) 477 | x = x * x_att 478 | 479 | _, start_index = torch.topk(x_att, self.curve_num, dim=2, sorted=False) 480 | 481 | # print(start_index.shape) 482 | # torch.Size([2, 1, 100]) 483 | start_index = start_index.squeeze(1).unsqueeze(2) 484 | 485 | curves = self.walk(xyz, x, idx, start_index) # bs, c, c_n, c_l 486 | 487 | return curves 488 | 489 | 490 | class MaskedMaxPool(nn.Module): 491 | def __init__(self, npoint, radius, k): 492 | super(MaskedMaxPool, self).__init__() 493 | self.npoint = npoint 494 | self.radius = radius 495 | self.k = k 496 | 497 | def forward(self, xyz, features): 498 | sub_xyz, neighborhood_features = sample_and_group( 499 | self.npoint, self.radius, self.k, xyz, features.transpose(1, 2) 500 | ) 501 | 502 | neighborhood_features = neighborhood_features.permute(0, 3, 1, 2).contiguous() 503 | sub_features = F.max_pool2d( 504 | neighborhood_features, kernel_size=[1, neighborhood_features.shape[3]] 505 | ) # bs, c, n, 1 506 | sub_features = torch.squeeze(sub_features, -1) # bs, c, n 507 | return sub_xyz, sub_features 508 | -------------------------------------------------------------------------------- /executor/models/curvenet/walk.py: -------------------------------------------------------------------------------- 1 | """Adapted from: https://github.com/tiangexiang/CurveNet/blob/main/core/models/walk.py 2 | """ 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched_index_select(input, dim, index): 10 | views = [input.shape[0]] + [ 11 | 1 if i != dim else -1 for i in range(1, len(input.shape)) 12 | ] 13 | expanse = list(input.shape) 14 | expanse[0] = -1 15 | expanse[dim] = -1 16 | index = index.view(views).expand(expanse) 17 | return torch.gather(input, dim, index) 18 | 19 | 20 | def gumbel_softmax(logits, dim, temperature=1): 21 | """ 22 | ST-gumple-softmax w/o random gumbel samplings 23 | input: [*, n_class] 24 | return: flatten --> [*, n_class] an one-hot vector 25 | """ 26 | y = F.softmax(logits / temperature, dim=dim) 27 | 28 | shape = y.size() 29 | _, ind = y.max(dim=-1) 30 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 31 | y_hard.scatter_(1, ind.view(-1, 1), 1) 32 | y_hard = y_hard.view(*shape) 33 | 34 | y_hard = (y_hard - y).detach() + y 35 | return y_hard 36 | 37 | 38 | class Walk(nn.Module): 39 | ''' 40 | Walk in the cloud 41 | ''' 42 | 43 | def __init__(self, in_channel, k, curve_num, curve_length): 44 | super(Walk, self).__init__() 45 | self.curve_num = curve_num 46 | self.curve_length = curve_length 47 | self.k = k 48 | 49 | self.agent_mlp = nn.Sequential( 50 | nn.Conv2d(in_channel * 2, 1, kernel_size=1, bias=False), nn.BatchNorm2d(1) 51 | ) 52 | self.momentum_mlp = nn.Sequential( 53 | nn.Conv1d(in_channel * 2, 2, kernel_size=1, bias=False), nn.BatchNorm1d(2) 54 | ) 55 | 56 | def crossover_suppression(self, cur, neighbor, bn, n, k): 57 | # cur: bs*n, 3 58 | # neighbor: bs*n, 3, k 59 | neighbor = neighbor.detach() 60 | cur = cur.unsqueeze(-1).detach() 61 | dot = torch.bmm(cur.transpose(1, 2), neighbor) # bs*n, 1, k 62 | norm1 = torch.norm(cur, dim=1, keepdim=True) 63 | norm2 = torch.norm(neighbor, dim=1, keepdim=True) 64 | divider = torch.clamp(norm1 * norm2, min=1e-8) 65 | ans = torch.div(dot, divider).squeeze(1) # bs*n, k 66 | 67 | # normalize to [0, 1] 68 | ans = 1.0 + ans 69 | ans = torch.clamp(ans, 0.0, 1.0) 70 | 71 | return ans.detach() 72 | 73 | def forward(self, xyz, x, adj, cur): 74 | bn, c, tot_points = x.size() 75 | 76 | # raw point coordinates 77 | xyz = xyz.transpose(1, 2).contiguous # bs, n, 3 78 | 79 | # point features 80 | x = x.transpose(1, 2).contiguous() # bs, n, c 81 | 82 | flatten_x = x.view(bn * tot_points, -1) 83 | batch_offset = torch.arange(0, bn, device=x.device).detach() * tot_points 84 | 85 | # indices of neighbors for the starting points 86 | tmp_adj = (adj + batch_offset.view(-1, 1, 1)).view( 87 | adj.size(0) * adj.size(1), -1 88 | ) # bs, n, k 89 | 90 | # batch flattened indices for teh starting points 91 | flatten_cur = (cur + batch_offset.view(-1, 1, 1)).view(-1) 92 | 93 | curves = [] 94 | 95 | # one step at a time 96 | for step in range(self.curve_length): 97 | 98 | if step == 0: 99 | # get starting point features using flattend indices 100 | starting_points = flatten_x[flatten_cur, :].contiguous() 101 | pre_feature = starting_points.view(bn, self.curve_num, -1, 1).transpose( 102 | 1, 2 103 | ) # bs * n, c 104 | else: 105 | # dynamic momentum 106 | cat_feature = torch.cat( 107 | (cur_feature.squeeze(3), pre_feature.squeeze(3)), dim=1 108 | ) 109 | att_feature = F.softmax(self.momentum_mlp(cat_feature), dim=1).view( 110 | bn, 1, self.curve_num, 2 111 | ) # bs, 1, n, 2 112 | cat_feature = torch.cat( 113 | (cur_feature, pre_feature), dim=-1 114 | ) # bs, c, n, 2 115 | 116 | # update curve descriptor 117 | pre_feature = torch.sum( 118 | cat_feature * att_feature, dim=-1, keepdim=True 119 | ) # bs, c, n 120 | pre_feature_cos = ( 121 | pre_feature.transpose(1, 2) 122 | .contiguous() 123 | .view(bn * self.curve_num, -1) 124 | ) 125 | 126 | pick_idx = tmp_adj[flatten_cur] # bs*n, k 127 | 128 | # get the neighbors of current points 129 | pick_values = flatten_x[pick_idx.view(-1), :] 130 | 131 | # reshape to fit crossover suppresion below 132 | pick_values_cos = pick_values.view(bn * self.curve_num, self.k, c) 133 | pick_values = pick_values_cos.view(bn, self.curve_num, self.k, c) 134 | pick_values_cos = pick_values_cos.transpose(1, 2).contiguous() 135 | 136 | pick_values = pick_values.permute(0, 3, 1, 2) # bs, c, n, k 137 | 138 | pre_feature_expand = pre_feature.expand_as(pick_values) 139 | 140 | # concat current point features with curve descriptors 141 | pre_feature_expand = torch.cat((pick_values, pre_feature_expand), dim=1) 142 | 143 | # which node to pick next? 144 | pre_feature_expand = self.agent_mlp(pre_feature_expand) # bs, 1, n, k 145 | 146 | if step != 0: 147 | # cross over supression 148 | d = self.crossover_suppression( 149 | cur_feature_cos - pre_feature_cos, 150 | pick_values_cos - cur_feature_cos.unsqueeze(-1), 151 | bn, 152 | self.curve_num, 153 | self.k, 154 | ) 155 | d = d.view(bn, self.curve_num, self.k).unsqueeze(1) # bs, 1, n, k 156 | pre_feature_expand = torch.mul(pre_feature_expand, d) 157 | 158 | pre_feature_expand = gumbel_softmax(pre_feature_expand, -1) # bs, 1, n, k 159 | 160 | cur_feature = torch.sum( 161 | pick_values * pre_feature_expand, dim=-1, keepdim=True 162 | ) # bs, c, n, 1 163 | 164 | cur_feature_cos = ( 165 | cur_feature.transpose(1, 2).contiguous().view(bn * self.curve_num, c) 166 | ) 167 | 168 | cur = torch.argmax(pre_feature_expand, dim=-1).view(-1, 1) # bs * n, 1 169 | 170 | flatten_cur = batched_index_select(pick_idx, 1, cur).squeeze(1) # bs * n 171 | 172 | # collect curve progress 173 | curves.append(cur_feature) 174 | 175 | return torch.cat(curves, dim=-1) 176 | -------------------------------------------------------------------------------- /executor/models/encoder_pl.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from finetuner.tuner.pytorch.losses import TripletLoss 6 | from finetuner.tuner.pytorch.miner import TripletEasyHardMiner 7 | from torch.nn import functional as F 8 | from torchmetrics.functional import accuracy 9 | 10 | from .modeling import MeshDataModel 11 | 12 | AVAILABLE_MODELS = { 13 | 'PointNet-Shapenet-d1024': { 14 | 'model_name': 'pointnet', 15 | 'hidden_dim': 1024, 16 | 'embed_dim': 1024, 17 | 'model_path': '', 18 | }, 19 | 'PointConv-Shapenet-d1024': { 20 | 'model_name': 'pointconv', 21 | 'hidden_dim': 1024, 22 | 'embed_dim': 1024, 23 | 'model_path': 'https://jina-pretrained-models.s3.us-west-1.amazonaws.com/mesh_models/pointconv-shapenet-d1024.pth', 24 | }, 25 | 'PointNet-Shapenet-d512': { 26 | 'model_name': 'pointnet', 27 | 'hidden_dim': 1024, 28 | 'embed_dim': 512, 29 | 'model_path': '', 30 | }, 31 | 'PointConv-Shapenet-d512': { 32 | 'model_name': 'pointconv', 33 | 'hidden_dim': 1024, 34 | 'embed_dim': 512, 35 | 'model_path': 'https://jina-pretrained-models.s3.us-west-1.amazonaws.com/mesh_models/pointconv-shapenet-d512.pth', 36 | }, 37 | } 38 | 39 | DEFAULT_MODEL_NAME = 'pointconv' 40 | 41 | 42 | class MeshDataEncoderPL(pl.LightningModule): 43 | def __init__( 44 | self, 45 | pretrained_model: str = None, 46 | default_model_name=DEFAULT_MODEL_NAME, 47 | model_path: Optional[str] = None, 48 | hidden_dim: int = 1024, 49 | embed_dim: int = 1024, 50 | input_shape: str = 'bnc', 51 | device: str = 'cpu', 52 | batch_size: int = 64, 53 | filters: Optional[dict] = None, 54 | **kwargs, 55 | ) -> None: 56 | super().__init__(**kwargs) 57 | 58 | self.save_hyperparameters() 59 | 60 | model_path = None 61 | if pretrained_model in AVAILABLE_MODELS: 62 | config = AVAILABLE_MODELS[pretrained_model] 63 | model_name = config.pop('model_name') 64 | model_path = config.pop('model_path') 65 | embed_dim = config.pop('embed_dim') 66 | hidden_dim = config.pop('hidden_dim') 67 | else: 68 | model_name = default_model_name 69 | self._model = MeshDataModel( 70 | model_name=model_name, 71 | hidden_dim=hidden_dim, 72 | embed_dim=embed_dim, 73 | pretrained=True if model_path else False, 74 | input_shape=input_shape, 75 | ) 76 | 77 | if model_path: 78 | if model_path.startswith('http'): 79 | import os 80 | import urllib.request 81 | from pathlib import Path 82 | 83 | cache_dir = Path.home() / '.cache' / 'jina-models' 84 | cache_dir.mkdir(parents=True, exist_ok=True) 85 | 86 | file_url = model_path 87 | file_name = os.path.basename(model_path) 88 | model_path = cache_dir / file_name 89 | 90 | if not model_path.exists(): 91 | print(f'=> download {file_url} to {model_path}') 92 | urllib.request.urlretrieve(file_url, model_path) 93 | 94 | checkpoint = torch.load(model_path, map_location='cpu') 95 | self._model.load_state_dict(checkpoint) 96 | 97 | self._device = device 98 | self._batch_size = batch_size 99 | self._filters = filters 100 | # bnc 101 | self.example_input_array = torch.zeros((batch_size, 1024, 3)) 102 | 103 | def forward(self, x): 104 | embedding = self._model(x) 105 | return embedding 106 | 107 | def configure_optimizers(self): 108 | optimizer = torch.optim.Adam(self.parameters(), lr=5e-4) 109 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 110 | optimizer, milestones=[30, 60], gamma=0.5 111 | ) 112 | 113 | return {'optimizer': optimizer, 'lr_scheduler': scheduler} 114 | 115 | def training_step(self, train_batch, _batch_idx): 116 | x, y = train_batch 117 | loss_fn = TripletLoss( 118 | miner=TripletEasyHardMiner(pos_strategy='easy', neg_strategy='semihard') 119 | ) 120 | embeddings = self._model(x) 121 | loss = loss_fn(embeddings, y) 122 | self.log('train_loss', loss) 123 | return loss 124 | 125 | def evaluate(self, batch, stage): 126 | x, y = batch 127 | loss_fn = TripletLoss( 128 | miner=TripletEasyHardMiner(pos_strategy='easy', neg_strategy='semihard') 129 | ) 130 | embeddings = self._model(x) 131 | loss = loss_fn(embeddings, y) 132 | self.log(f'{stage}_loss', loss, prog_bar=True) 133 | 134 | def validation_step(self, val_batch, _batch_idx): 135 | self.evaluate(val_batch, 'val') 136 | 137 | def test_step(self, test_batch, _batch_idx): 138 | self.evaluate(test_batch, 'test') 139 | -------------------------------------------------------------------------------- /executor/models/modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .curvenet import CurveNet 5 | from .pointconv import MLP, PointConv 6 | from .pointmlp import pointMLP, pointMLPElite 7 | from .pointnet import PointNet 8 | from .pointnet2 import PointNet2 9 | from .repsurf import RepSurf 10 | 11 | PRETRAINED_MODELS = { 12 | 'pointnet': { 13 | 'model_path': '', 14 | 'hidden_dim': 1024, 15 | }, 16 | 'pointconv': { 17 | 'model_path': 'https://jina-pretrained-models.s3.us-west-1.amazonaws.com/mesh_models/pointconv_class_encoder.pth', 18 | 'hidden_dim': 1024, 19 | }, 20 | } 21 | 22 | 23 | def get_model(model_name: str, hidden_dim: int, input_shape: str, classifier: bool): 24 | if model_name == 'pointnet': 25 | # classifier ignored 26 | return PointNet( 27 | emb_dims=hidden_dim, 28 | input_shape=input_shape, 29 | use_bn=True, 30 | global_feat=True, 31 | classifier=classifier, 32 | ) 33 | elif model_name == 'pointconv': 34 | return PointConv( 35 | emb_dims=hidden_dim, 36 | input_channel_dim=3, 37 | input_shape=input_shape, 38 | classifier=classifier, 39 | ) 40 | elif model_name == 'pointnet2': 41 | return PointNet2( 42 | emb_dims=hidden_dim, 43 | normal_channel=False, 44 | input_shape=input_shape, 45 | classifier=classifier, 46 | density_adaptive_type='ssg', 47 | ) 48 | elif model_name == 'pointnet2msg': 49 | return PointNet2( 50 | emb_dims=hidden_dim, 51 | normal_channel=False, 52 | input_shape=input_shape, 53 | classifier=classifier, 54 | density_adaptive_type='msg', 55 | ) 56 | elif model_name == 'repsurf': 57 | return RepSurf( 58 | num_points=1024, 59 | emb_dims=hidden_dim, 60 | input_shape=input_shape, 61 | classifier=classifier, 62 | ) 63 | elif model_name == 'pointmlp': 64 | return pointMLP(classifier=classifier, embed_dim=hidden_dim) 65 | elif model_name == 'pointmlp-elite': 66 | return pointMLPElite(classifier=classifier, embed_dim=hidden_dim) 67 | elif model_name == 'curvenet': 68 | return CurveNet( 69 | emb_dims=hidden_dim, 70 | input_shape=input_shape, 71 | classifier=classifier, 72 | ) 73 | else: 74 | raise NotImplementedError('The model has not been implemented yet!') 75 | 76 | 77 | class MeshDataModel(nn.Module): 78 | def __init__( 79 | self, 80 | model_name: str = 'pointnet', 81 | hidden_dim: int = 1024, 82 | embed_dim: int = 512, 83 | input_shape: str = 'bnc', 84 | dropout_rate: float = 0.1, 85 | pretrained: bool = True, 86 | ): 87 | super().__init__() 88 | 89 | model_path = None 90 | if pretrained and model_name in PRETRAINED_MODELS: 91 | config = PRETRAINED_MODELS[model_name] 92 | model_path = config['model_path'] 93 | hidden_dim = config['hidden_dim'] 94 | 95 | self._point_encoder = get_model(model_name, hidden_dim, input_shape, False) 96 | 97 | if model_path: 98 | if model_path.startswith('http'): 99 | import os 100 | import urllib.request 101 | from pathlib import Path 102 | 103 | cache_dir = Path.home() / '.cache' / 'jina-models' 104 | cache_dir.mkdir(parents=True, exist_ok=True) 105 | 106 | file_url = model_path 107 | file_name = os.path.basename(model_path) 108 | model_path = cache_dir / file_name 109 | 110 | if not model_path.exists(): 111 | print(f'=> download {file_url} to {model_path}') 112 | urllib.request.urlretrieve(file_url, model_path) 113 | 114 | print(f'==> restore {model_name} from: {model_path}') 115 | checkpoint = torch.load(model_path, map_location='cpu') 116 | self._point_encoder.load_state_dict(checkpoint) 117 | 118 | self._dropout = nn.Dropout(dropout_rate) 119 | 120 | # Projector 121 | self._projector = MLP(hidden_dim, hidden_dim * 4, embed_dim) 122 | 123 | @property 124 | def encoder(self): 125 | return self._point_encoder 126 | 127 | def forward(self, points): 128 | feats = self._point_encoder(points) 129 | feats = self._dropout(feats) 130 | return self._projector(feats) 131 | -------------------------------------------------------------------------------- /executor/models/pointconv/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointconv import PointConv 2 | from .pointconv_utils import ( 3 | MLP, 4 | PointNetSetAbstraction, 5 | PointNetSetAbstractionMsg, 6 | farthest_point_sample, 7 | index_points, 8 | knn_point, 9 | query_ball_point, 10 | ) 11 | -------------------------------------------------------------------------------- /executor/models/pointconv/pointconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .pointconv_utils import PointConvDensityNet 6 | 7 | 8 | class PointConv(torch.nn.Module): 9 | def __init__( 10 | self, 11 | emb_dims=1024, 12 | input_shape='bnc', 13 | input_channel_dim=3, 14 | classifier=False, 15 | num_classes=40, 16 | pretrained=None, 17 | ): 18 | super(PointConv, self).__init__() 19 | if input_shape not in ['bnc', 'bcn']: 20 | raise ValueError( 21 | "Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' " 22 | ) 23 | self.input_shape = input_shape 24 | self.emb_dims = emb_dims 25 | self.classifier = classifier 26 | self.input_channel_dim = input_channel_dim 27 | self.create_structure() 28 | if self.classifier: 29 | self.create_classifier(num_classes) 30 | 31 | def create_structure(self): 32 | # Arguments to define PointConv network using PointConvDensityNet class. 33 | # npoint: number of points sampled from input. 34 | # nsample: number of neighbours chosen for each point in sampled point cloud. 35 | # in_channel: number of channels in input. 36 | # mlp: sizes of multi-layer perceptrons. 37 | # bandwidth: used to compute gaussian density. 38 | # group_all: group all points from input to a single point if set to True. 39 | self.sa1 = PointConvDensityNet( 40 | npoint=512, 41 | nsample=32, 42 | in_channel=self.input_channel_dim, 43 | mlp=[64, 64, 128], 44 | bandwidth=0.1, 45 | group_all=False, 46 | ) 47 | self.sa2 = PointConvDensityNet( 48 | npoint=128, 49 | nsample=64, 50 | in_channel=128 + 3, 51 | mlp=[128, 128, 256], 52 | bandwidth=0.2, 53 | group_all=False, 54 | ) 55 | self.sa3 = PointConvDensityNet( 56 | npoint=1, 57 | nsample=None, 58 | in_channel=256 + 3, 59 | mlp=[256, 512, self.emb_dims], 60 | bandwidth=0.4, 61 | group_all=True, 62 | ) 63 | 64 | def create_classifier(self, num_classes): 65 | # These are simple fully-connected layers with batch-norm and dropouts. 66 | # This architecture is given by PointConv paper. Hence, I used it here as a default version. 67 | # This can be easily modified by overwriting this function or by using classifier.py class. 68 | self.fc1 = nn.Linear(self.emb_dims, 512) 69 | self.bn1 = nn.BatchNorm1d(512) 70 | self.drop1 = nn.Dropout(0.7) 71 | self.fc2 = nn.Linear(512, 256) 72 | self.bn2 = nn.BatchNorm1d(256) 73 | self.drop2 = nn.Dropout(0.7) 74 | self.fc3 = nn.Linear(256, num_classes) 75 | 76 | def forward(self, input_data): 77 | if self.input_shape == 'bnc': 78 | input_data = input_data.permute(0, 2, 1) 79 | batch_size = input_data.shape[0] 80 | 81 | # Convert point clouds to latent features using PointConv network. 82 | l1_points, l1_features = self.sa1(input_data[:, :3, :], input_data[:, 3:, :]) 83 | l2_points, l2_features = self.sa2(l1_points, l1_features) 84 | l3_points, l3_features = self.sa3(l2_points, l2_features) 85 | features = l3_features.view(batch_size, self.emb_dims) 86 | 87 | if self.classifier: 88 | # Use these features to classify the input point cloud. 89 | features = self.drop1(F.relu(self.bn1(self.fc1(features)))) 90 | features = self.drop2(F.relu(self.bn2(self.fc2(features)))) 91 | features = self.fc3(features) 92 | output = F.log_softmax(features, -1) 93 | else: 94 | # Return the PointConv features for the use of other higher level tasks. 95 | output = features 96 | 97 | return output 98 | -------------------------------------------------------------------------------- /executor/models/pointconv/pointconv_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility function for PointConv 3 | Originally from : https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/utils.py 4 | Modify by Wenxuan Wu 5 | Date: September 2019 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def square_distance(src, dst): 13 | """ 14 | Calculate Euclid distance between each two points. 15 | 16 | src^T * dst = xn * xm + yn * ym + zn * zm; 17 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 18 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 19 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 20 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 21 | 22 | Input: 23 | src: source points, [B, N, C] 24 | dst: target points, [B, M, C] 25 | Output: 26 | dist: per-point square distance, [B, N, M] 27 | """ 28 | B, N, _ = src.shape 29 | _, M, _ = dst.shape 30 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 31 | dist += torch.sum(src**2, -1).view(B, N, 1) 32 | dist += torch.sum(dst**2, -1).view(B, 1, M) 33 | return dist 34 | 35 | 36 | def index_points(points, idx): 37 | """ 38 | 39 | Input: 40 | points: input points data, [B, N, C] 41 | idx: sample index data, [B, S] 42 | Return: 43 | new_points: coordinates of the centroids, indexed points data, [B, S, C] 44 | """ 45 | device = points.device 46 | B = points.shape[0] 47 | view_shape = list(idx.shape) 48 | view_shape[1:] = [1] * (len(view_shape) - 1) 49 | repeat_shape = list(idx.shape) 50 | repeat_shape[0] = 1 51 | batch_indices = ( 52 | torch.arange(B, dtype=torch.long) 53 | .to(device) 54 | .view(view_shape) 55 | .repeat(repeat_shape) 56 | ) 57 | new_points = points[batch_indices, idx, :] 58 | return new_points 59 | 60 | 61 | def farthest_point_sample(xyz, npoint): 62 | """ 63 | Input: 64 | xyz: pointcloud data, which is xyz coordinates and corresponding 65 | features (if exist), [B, N, C] (B: batch size, N: number of points C: number of channels) 66 | npoint: number of selected samples 67 | Return: 68 | centroids: sampled pointcloud index, [B, npoint] 69 | """ 70 | # import ipdb; ipdb.set_trace() 71 | device = xyz.device 72 | B, N, C = xyz.shape 73 | # centroids : the indices of sampled points of each sample in batch 74 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 75 | # distance : the minimum distance of point clouds to the last selected point, 1e10 : infinity value 76 | distance = torch.ones(B, N).to(device) * 1e10 77 | # farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 78 | # farthest : indices of farthest points of each batch sample 79 | farthest = torch.zeros(B, dtype=torch.long).to(device) 80 | # batch_indices : batch indices of batch samples, [0, 1, 2, ..., B - 1] 81 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 82 | # n iterations 83 | for i in range(npoint): 84 | centroids[:, i] = farthest 85 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 86 | # xyz - centroid is the difference of points and last selected centroid 87 | dist = torch.sum((xyz - centroid) ** 2, -1) 88 | mask = dist < distance 89 | # update corresponding distance item if dist < distance 90 | distance[mask] = dist[mask] 91 | # indices of farthest ones to the last selected centroids of each batch sample 92 | farthest = torch.max(distance, -1)[1] 93 | return centroids 94 | 95 | 96 | def query_ball_point(radius, nsample, xyz, new_xyz): 97 | """ 98 | Input: 99 | radius: local region radius 100 | nsample: max sample number in local region 101 | xyz: all points, [B, N, C] 102 | new_xyz: query points, [B, S, C] 103 | Return: 104 | group_idx: grouped points index, [B, S, nsample] 105 | """ 106 | device = xyz.device 107 | B, N, C = xyz.shape 108 | _, S, _ = new_xyz.shape 109 | group_idx = ( 110 | torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 111 | ) 112 | sqrdists = square_distance(new_xyz, xyz) 113 | group_idx[sqrdists > radius**2] = N 114 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 115 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 116 | mask = group_idx == N 117 | group_idx[mask] = group_first[mask] 118 | return group_idx 119 | 120 | 121 | def knn_point(nsample, xyz, new_xyz): 122 | """ 123 | Input: 124 | nsample: max sample number in local region 125 | xyz: all points, [B, N, C] 126 | new_xyz: query points, [B, S, C] 127 | Return: 128 | group_idx: grouped points index, [B, S, nsample] 129 | """ 130 | sqrdists = square_distance(new_xyz, xyz) 131 | _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) 132 | return group_idx 133 | 134 | 135 | def sample_and_group(npoint, nsample, xyz, points, density_scale=None): 136 | """ 137 | Input: 138 | npoint: 139 | nsample: 140 | xyz: input points position data, [B, N, C] 141 | points: input points data, [B, N, D] 142 | density_scale: 143 | Return: 144 | new_xyz: sampled points position data, [B, 1, C] 145 | new_points: sampled points data, [B, 1, N, C+D] 146 | """ 147 | B, N, C = xyz.shape 148 | S = npoint 149 | # fps_idx : indices of centroids 150 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 151 | # new_xyz : xyz of centroids 152 | new_xyz = index_points(xyz, fps_idx) 153 | # idx : indices of members in each knn group 154 | idx = knn_point(nsample, xyz, new_xyz) 155 | # grouped_xyz : xyz of members in each knn group 156 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 157 | # grouped_xyz_norm : the relative xyz to centroids in each knn group 158 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 159 | if points is not None: 160 | # grouped_points : features learned from previous layers 161 | grouped_points = index_points(points, idx) 162 | new_points = torch.cat( 163 | # relative xyz, features learned from previous layers 164 | [grouped_xyz_norm, grouped_points], 165 | dim=-1, 166 | ) # [B, npoint, nsample, C+D] 167 | else: 168 | new_points = grouped_xyz_norm 169 | 170 | if density_scale is None: 171 | return new_xyz, new_points, grouped_xyz_norm, idx 172 | else: 173 | grouped_density = index_points(density_scale, idx) 174 | return new_xyz, new_points, grouped_xyz_norm, idx, grouped_density 175 | 176 | 177 | def sample_and_group_all(xyz, points, density_scale=None): 178 | """ 179 | Input: 180 | xyz: input points position data, [B, N, C] 181 | points: input points data, [B, N, D] 182 | Return: 183 | new_xyz: sampled points position data, [B, 1, C] 184 | new_points: sampled points data, [B, 1, N, C+D] 185 | """ 186 | device = xyz.device 187 | B, N, C = xyz.shape 188 | # new_xyz = torch.zeros(B, 1, C).to(device) 189 | new_xyz = xyz.mean(dim=1, keepdim=True) 190 | grouped_xyz = xyz.view(B, 1, N, C) - new_xyz.view(B, 1, 1, C) 191 | if points is not None: 192 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 193 | else: 194 | new_points = grouped_xyz 195 | if density_scale is None: 196 | return new_xyz, new_points, grouped_xyz 197 | else: 198 | grouped_density = density_scale.view(B, 1, N, 1) 199 | return new_xyz, new_points, grouped_xyz, grouped_density 200 | 201 | 202 | def group(nsample, xyz, points): 203 | """ 204 | Input: 205 | npoint: 206 | nsample: 207 | xyz: input points position data, [B, N, C] 208 | points: input points data, [B, N, D] 209 | Return: 210 | new_xyz: sampled points position data, [B, 1, C] 211 | new_points: sampled points data, [B, 1, N, C+D] 212 | """ 213 | B, N, C = xyz.shape 214 | S = N 215 | new_xyz = xyz 216 | idx = knn_point(nsample, xyz, new_xyz) 217 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 218 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 219 | if points is not None: 220 | grouped_points = index_points(points, idx) 221 | new_points = torch.cat( 222 | [grouped_xyz_norm, grouped_points], dim=-1 223 | ) # [B, npoint, nsample, C+D] 224 | else: 225 | new_points = grouped_xyz_norm 226 | 227 | return new_points, grouped_xyz_norm 228 | 229 | 230 | def compute_density(xyz, bandwidth): 231 | """_summary_ 232 | 233 | Args: 234 | xyz (_type_): input points position data, [B, N, C] 235 | bandwidth (_type_): _description_ 236 | 237 | Returns: 238 | _type_: _description_ 239 | """ 240 | # import ipdb; ipdb.set_trace() 241 | B, N, C = xyz.shape 242 | sqrdists = square_distance(xyz, xyz) 243 | gaussion_density = torch.exp(-sqrdists / (2.0 * bandwidth * bandwidth)) / ( 244 | 2.5 * bandwidth 245 | ) 246 | xyz_density = gaussion_density.mean(dim=-1) 247 | 248 | return xyz_density 249 | 250 | 251 | class DensityNet(nn.Module): 252 | def __init__(self, hidden_unit=[16, 8]): 253 | super(DensityNet, self).__init__() 254 | self.mlp_convs = nn.ModuleList() 255 | self.mlp_bns = nn.ModuleList() 256 | 257 | self.mlp_convs.append(nn.Conv2d(1, hidden_unit[0], 1)) 258 | self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[0])) 259 | for i in range(1, len(hidden_unit)): 260 | self.mlp_convs.append(nn.Conv2d(hidden_unit[i - 1], hidden_unit[i], 1)) 261 | self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[i])) 262 | self.mlp_convs.append(nn.Conv2d(hidden_unit[-1], 1, 1)) 263 | self.mlp_bns.append(nn.BatchNorm2d(1)) 264 | 265 | def forward(self, density_scale): 266 | for i, conv in enumerate(self.mlp_convs): 267 | bn = self.mlp_bns[i] 268 | density_scale = bn(conv(density_scale)) 269 | if i == len(self.mlp_convs): 270 | density_scale = F.sigmoid(density_scale) 271 | else: 272 | density_scale = F.relu(density_scale) 273 | 274 | return density_scale 275 | 276 | 277 | class WeightNet(nn.Module): 278 | def __init__(self, in_channel, out_channel, hidden_unit=[8, 8]): 279 | super(WeightNet, self).__init__() 280 | 281 | self.mlp_convs = nn.ModuleList() 282 | self.mlp_bns = nn.ModuleList() 283 | if hidden_unit is None or len(hidden_unit) == 0: 284 | self.mlp_convs.append(nn.Conv2d(in_channel, out_channel, 1)) 285 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 286 | else: 287 | self.mlp_convs.append(nn.Conv2d(in_channel, hidden_unit[0], 1)) 288 | self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[0])) 289 | for i in range(1, len(hidden_unit)): 290 | self.mlp_convs.append(nn.Conv2d(hidden_unit[i - 1], hidden_unit[i], 1)) 291 | self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[i])) 292 | self.mlp_convs.append(nn.Conv2d(hidden_unit[-1], out_channel, 1)) 293 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 294 | 295 | def forward(self, localized_xyz): 296 | # xyz : BxCxKxN 297 | 298 | weights = localized_xyz 299 | for i, conv in enumerate(self.mlp_convs): 300 | bn = self.mlp_bns[i] 301 | weights = F.relu(bn(conv(weights))) 302 | 303 | return weights 304 | 305 | 306 | class PointConvDensityNet(nn.Module): 307 | def __init__(self, npoint, nsample, in_channel, mlp, bandwidth, group_all): 308 | super(PointConvDensityNet, self).__init__() 309 | self.npoint = npoint 310 | self.nsample = nsample 311 | self.mlp_convs = nn.ModuleList() 312 | self.mlp_bns = nn.ModuleList() 313 | last_channel = in_channel 314 | for out_channel in mlp: 315 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 316 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 317 | last_channel = out_channel 318 | 319 | self.weightnet = WeightNet(3, 16) 320 | self.linear = nn.Linear(16 * mlp[-1], mlp[-1]) 321 | self.bn_linear = nn.BatchNorm1d(mlp[-1]) 322 | self.densitynet = DensityNet() 323 | self.group_all = group_all 324 | self.bandwidth = bandwidth 325 | 326 | def forward(self, xyz, points): 327 | """ 328 | Input: 329 | xyz: input points position data, [B, C, N] 330 | points: input points data, [B, D, N] 331 | Return: 332 | new_xyz: sampled points position data, [B, C, S] 333 | new_points_concat: sample points feature data, [B, D', S] 334 | """ 335 | B = xyz.shape[0] 336 | N = xyz.shape[2] 337 | xyz = xyz.permute(0, 2, 1) 338 | if points is not None: 339 | points = points.permute(0, 2, 1) 340 | 341 | xyz_density = compute_density(xyz, self.bandwidth) 342 | inverse_density = 1.0 / xyz_density 343 | 344 | if self.group_all: 345 | ( 346 | new_xyz, 347 | new_points, 348 | grouped_xyz_norm, 349 | grouped_density, 350 | ) = sample_and_group_all(xyz, points, inverse_density.view(B, N, 1)) 351 | else: 352 | ( 353 | new_xyz, 354 | new_points, 355 | grouped_xyz_norm, 356 | _, 357 | grouped_density, 358 | ) = sample_and_group( 359 | self.npoint, self.nsample, xyz, points, inverse_density.view(B, N, 1) 360 | ) 361 | # new_xyz: sampled points position data, [B, npoint, C] 362 | # new_points: sampled points data, [B, npoint, nsample, C+D] 363 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 364 | for i, conv in enumerate(self.mlp_convs): 365 | bn = self.mlp_bns[i] 366 | new_points = F.relu(bn(conv(new_points))) 367 | 368 | inverse_max_density = grouped_density.max(dim=2, keepdim=True)[0] 369 | density_scale = grouped_density / inverse_max_density 370 | density_scale = self.densitynet(density_scale.permute(0, 3, 2, 1)) 371 | new_points = new_points * density_scale 372 | 373 | grouped_xyz = grouped_xyz_norm.permute(0, 3, 2, 1) 374 | weights = self.weightnet(grouped_xyz) 375 | new_points = torch.matmul( 376 | input=new_points.permute(0, 3, 1, 2), other=weights.permute(0, 3, 2, 1) 377 | ).view(B, self.npoint, -1) 378 | new_points = self.linear(new_points) 379 | new_points = self.bn_linear(new_points.permute(0, 2, 1)) 380 | new_points = F.relu(new_points) 381 | new_xyz = new_xyz.permute(0, 2, 1) 382 | 383 | return new_xyz, new_points 384 | 385 | 386 | class MLP(nn.Module): 387 | def __init__( 388 | self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256 389 | ): 390 | super().__init__() 391 | self.output_dim = output_dim 392 | self.input_dim = input_dim 393 | self.model = nn.Sequential( 394 | nn.Linear(input_dim, hidden_size, bias=False), 395 | nn.BatchNorm1d(hidden_size), 396 | nn.ReLU(inplace=True), 397 | nn.Linear(hidden_size, output_dim, bias=True), 398 | ) 399 | 400 | def forward(self, x): 401 | x = self.model(x) 402 | return x 403 | 404 | 405 | class PointNetSetAbstraction(nn.Module): 406 | def __init__(self, npoint, nsample, radius, in_channel, mlp, group_all): 407 | """_summary_ 408 | 409 | Args: 410 | npoint (_type_): _description_ 411 | nsample (_type_): _description_ 412 | radius (_type_): _description_ 413 | in_channel (_type_): _description_ 414 | mlp (_type_): _description_ 415 | group_all (_type_): _description_ 416 | """ 417 | super(PointNetSetAbstraction, self).__init__() 418 | self.npoint = npoint 419 | self.nsample = nsample 420 | self.radius = radius 421 | self.mlp_convs = nn.ModuleList() 422 | self.mlp_bns = nn.ModuleList() 423 | last_channel = in_channel 424 | for out_channel in mlp: 425 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 426 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 427 | last_channel = out_channel 428 | self.group_all = group_all 429 | 430 | def forward(self, xyz, points): 431 | xyz = xyz.permute(0, 2, 1) 432 | if points is not None: 433 | points = points.permute(0, 2, 1) 434 | if self.group_all: 435 | new_xyz, new_points, _ = sample_and_group_all(xyz, points) 436 | else: 437 | # sample layer and group layer 438 | new_xyz, new_points, _, _ = sample_and_group( 439 | self.npoint, self.nsample, xyz, points 440 | ) 441 | 442 | # new_xyz : sampled points position data [B, npoint, C] 443 | # new_points: sampled points data [B, npoint, nsample, C + D] 444 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 445 | 446 | # PointNet layer 447 | for i, conv in enumerate(self.mlp_convs): 448 | bn = self.mlp_bns[i] 449 | new_points = F.relu(bn(conv(new_points))) 450 | new_points = torch.max(new_points, 2)[0] 451 | new_xyz = new_xyz.permute(0, 2, 1) 452 | return new_xyz, new_points 453 | 454 | 455 | class PointNetSetAbstractionMsg(nn.Module): 456 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 457 | super(PointNetSetAbstractionMsg, self).__init__() 458 | self.npoint = npoint 459 | self.radius_list = radius_list 460 | self.nsample_list = nsample_list 461 | self.conv_blocks = nn.ModuleList() 462 | self.bn_blocks = nn.ModuleList() 463 | 464 | # assert len(mlp_list) == len(radius_list) 465 | for i in range(len(mlp_list)): 466 | convs = nn.ModuleList() 467 | bns = nn.ModuleList() 468 | last_channel = in_channel + 3 469 | for out_channel in mlp_list[i]: 470 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 471 | bns.append(nn.BatchNorm2d(out_channel)) 472 | last_channel = out_channel 473 | self.conv_blocks.append(convs) 474 | self.bn_blocks.append(bns) 475 | 476 | def forward(self, xyz, points): 477 | xyz = xyz.permute(0, 2, 1) 478 | 479 | if points is not None: 480 | points = points.permute(0, 2, 1) 481 | 482 | B, N, C = xyz.shape 483 | S = self.npoint 484 | 485 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 486 | new_points_list = [] 487 | 488 | for i, radius in enumerate(self.radius_list): 489 | K = self.nsample_list[i] 490 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 491 | grouped_xyz = index_points(xyz, group_idx) 492 | grouped_xyz -= new_xyz.view(B, S, 1, C) 493 | if points is not None: 494 | grouped_points = index_points(points, group_idx) 495 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 496 | else: 497 | grouped_points = grouped_xyz 498 | 499 | grouped_points = grouped_points.permute(0, 3, 2, 1) 500 | for j in range(len(self.conv_blocks[i])): 501 | conv = self.conv_blocks[i][j] 502 | bn = self.bn_blocks[i][j] 503 | grouped_points = F.relu(bn(conv(grouped_points))) 504 | new_points = torch.max(grouped_points, 2)[0] 505 | new_points_list.append(new_points) 506 | 507 | new_xyz = new_xyz.permute(0, 2, 1) 508 | new_points_concat = torch.cat(new_points_list, dim=1) 509 | return new_xyz, new_points_concat 510 | -------------------------------------------------------------------------------- /executor/models/pointmlp/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointmlp import pointMLP, pointMLPElite 2 | -------------------------------------------------------------------------------- /executor/models/pointmlp/pointmlp.py: -------------------------------------------------------------------------------- 1 | """ref : https://github.com/ma-xu/pointMLP-pytorch/blob/main/classification_ModelNet40/models/pointmlp.py 2 | """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ..pointconv import farthest_point_sample, index_points, knn_point, query_ball_point 9 | 10 | 11 | def get_activation(activation): 12 | if activation.lower() == 'gelu': 13 | return nn.GELU() 14 | elif activation.lower() == 'rrelu': 15 | return nn.RReLU(inplace=True) 16 | elif activation.lower() == 'selu': 17 | return nn.SELU(inplace=True) 18 | elif activation.lower() == 'silu': 19 | return nn.SiLU(inplace=True) 20 | elif activation.lower() == 'hardswish': 21 | return nn.Hardswish(inplace=True) 22 | elif activation.lower() == 'leakyrelu': 23 | return nn.LeakyReLU(inplace=True) 24 | else: 25 | return nn.ReLU(inplace=True) 26 | 27 | 28 | class LocalGrouper(nn.Module): 29 | def __init__( 30 | self, channel, groups, kneighbors, use_xyz=True, normalize='center', **kwargs 31 | ): 32 | """ 33 | Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d] 34 | :param groups: groups number 35 | :param kneighbors: k-nerighbors 36 | :param kwargs: others 37 | """ 38 | super(LocalGrouper, self).__init__() 39 | self.groups = groups 40 | self.kneighbors = kneighbors 41 | self.use_xyz = use_xyz 42 | if normalize is not None: 43 | self.normalize = normalize.lower() 44 | else: 45 | self.normalize = None 46 | if self.normalize not in ['center', 'anchor']: 47 | print( 48 | f'Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].' 49 | ) 50 | self.normalize = None 51 | if self.normalize is not None: 52 | add_channel = 3 if self.use_xyz else 0 53 | self.affine_alpha = nn.Parameter( 54 | torch.ones([1, 1, 1, channel + add_channel]) 55 | ) 56 | self.affine_beta = nn.Parameter( 57 | torch.zeros([1, 1, 1, channel + add_channel]) 58 | ) 59 | 60 | def forward(self, xyz, points): 61 | B, N, C = xyz.shape 62 | S = self.groups 63 | xyz = xyz.contiguous() # xyz [btach, points, xyz] 64 | 65 | # fps_idx = torch.multinomial(torch.linspace(0, N - 1, steps=N).repeat(B, 1).to(xyz.device), num_samples=self.groups, replacement=False).long() 66 | fps_idx = farthest_point_sample(xyz, self.groups).long() 67 | # fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.groups).long() # [B, npoint] 68 | new_xyz = index_points(xyz, fps_idx) # [B, npoint, 3] 69 | new_points = index_points(points, fps_idx) # [B, npoint, d] 70 | 71 | idx = knn_point(self.kneighbors, xyz, new_xyz) 72 | # idx = query_ball_point(radius, nsample, xyz, new_xyz) 73 | grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3] 74 | grouped_points = index_points(points, idx) # [B, npoint, k, d] 75 | if self.use_xyz: 76 | grouped_points = torch.cat( 77 | [grouped_points, grouped_xyz], dim=-1 78 | ) # [B, npoint, k, d+3] 79 | if self.normalize is not None: 80 | if self.normalize == 'center': 81 | mean = torch.mean(grouped_points, dim=2, keepdim=True) 82 | if self.normalize == 'anchor': 83 | mean = ( 84 | torch.cat([new_points, new_xyz], dim=-1) 85 | if self.use_xyz 86 | else new_points 87 | ) 88 | mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] 89 | diff = (grouped_points - mean).reshape(B, -1).clone() 90 | std_old = torch.std(diff, dim=-1, keepdim=True) 91 | std = std_old.unsqueeze(dim=-1).unsqueeze(dim=-1) 92 | grouped_points = (grouped_points - mean) / (std + 1e-5) 93 | grouped_points = self.affine_alpha * grouped_points + self.affine_beta 94 | 95 | new_points = torch.cat( 96 | [ 97 | grouped_points, 98 | new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1), 99 | ], 100 | dim=-1, 101 | ) 102 | return new_xyz, new_points 103 | 104 | 105 | class ConvBNReLU1D(nn.Module): 106 | def __init__( 107 | self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu' 108 | ): 109 | super(ConvBNReLU1D, self).__init__() 110 | self.act = get_activation(activation) 111 | self.net = nn.Sequential( 112 | nn.Conv1d( 113 | in_channels=in_channels, 114 | out_channels=out_channels, 115 | kernel_size=kernel_size, 116 | bias=bias, 117 | ), 118 | nn.BatchNorm1d(out_channels), 119 | self.act, 120 | ) 121 | 122 | def forward(self, x): 123 | return self.net(x) 124 | 125 | 126 | class ConvBNReLURes1D(nn.Module): 127 | def __init__( 128 | self, 129 | channel, 130 | kernel_size=1, 131 | groups=1, 132 | res_expansion=1.0, 133 | bias=True, 134 | activation='relu', 135 | ): 136 | super(ConvBNReLURes1D, self).__init__() 137 | self.act = get_activation(activation) 138 | self.net1 = nn.Sequential( 139 | nn.Conv1d( 140 | in_channels=channel, 141 | out_channels=int(channel * res_expansion), 142 | kernel_size=kernel_size, 143 | groups=groups, 144 | bias=bias, 145 | ), 146 | nn.BatchNorm1d(int(channel * res_expansion)), 147 | self.act, 148 | ) 149 | if groups > 1: 150 | self.net2 = nn.Sequential( 151 | nn.Conv1d( 152 | in_channels=int(channel * res_expansion), 153 | out_channels=channel, 154 | kernel_size=kernel_size, 155 | groups=groups, 156 | bias=bias, 157 | ), 158 | nn.BatchNorm1d(channel), 159 | self.act, 160 | nn.Conv1d( 161 | in_channels=channel, 162 | out_channels=channel, 163 | kernel_size=kernel_size, 164 | bias=bias, 165 | ), 166 | nn.BatchNorm1d(channel), 167 | ) 168 | else: 169 | self.net2 = nn.Sequential( 170 | nn.Conv1d( 171 | in_channels=int(channel * res_expansion), 172 | out_channels=channel, 173 | kernel_size=kernel_size, 174 | bias=bias, 175 | ), 176 | nn.BatchNorm1d(channel), 177 | ) 178 | 179 | def forward(self, x): 180 | return self.act(self.net2(self.net1(x)) + x) 181 | 182 | 183 | class PreExtraction(nn.Module): 184 | def __init__( 185 | self, 186 | channels, 187 | out_channels, 188 | blocks=1, 189 | groups=1, 190 | res_expansion=1, 191 | bias=True, 192 | activation='relu', 193 | use_xyz=True, 194 | ): 195 | """ 196 | input: [b,g,k,d]: output:[b,d,g] 197 | :param channels: 198 | :param blocks: 199 | """ 200 | super(PreExtraction, self).__init__() 201 | in_channels = 3 + 2 * channels if use_xyz else 2 * channels 202 | self.transfer = ConvBNReLU1D( 203 | in_channels, out_channels, bias=bias, activation=activation 204 | ) 205 | operation = [] 206 | for _ in range(blocks): 207 | operation.append( 208 | ConvBNReLURes1D( 209 | out_channels, 210 | groups=groups, 211 | res_expansion=res_expansion, 212 | bias=bias, 213 | activation=activation, 214 | ) 215 | ) 216 | self.operation = nn.Sequential(*operation) 217 | 218 | def forward(self, x): 219 | b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) 220 | x = x.permute(0, 1, 3, 2) 221 | x = x.reshape(-1, d, s) 222 | x = self.transfer(x) 223 | batch_size, _, _ = x.size() 224 | x = self.operation(x) # [b, d, k] 225 | x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 226 | x = x.reshape(b, n, -1).permute(0, 2, 1) 227 | return x 228 | 229 | 230 | class PosExtraction(nn.Module): 231 | def __init__( 232 | self, 233 | channels, 234 | blocks=1, 235 | groups=1, 236 | res_expansion=1, 237 | bias=True, 238 | activation='relu', 239 | ): 240 | """ 241 | input[b,d,g]; output[b,d,g] 242 | :param channels: 243 | :param blocks: 244 | """ 245 | super(PosExtraction, self).__init__() 246 | operation = [] 247 | for _ in range(blocks): 248 | operation.append( 249 | ConvBNReLURes1D( 250 | channels, 251 | groups=groups, 252 | res_expansion=res_expansion, 253 | bias=bias, 254 | activation=activation, 255 | ) 256 | ) 257 | self.operation = nn.Sequential(*operation) 258 | 259 | def forward(self, x): # [b, d, g] 260 | return self.operation(x) 261 | 262 | 263 | class Model(nn.Module): 264 | def __init__( 265 | self, 266 | points=1024, 267 | class_num=40, 268 | embed_dim=64, 269 | groups=1, 270 | res_expansion=1.0, 271 | activation='relu', 272 | bias=True, 273 | use_xyz=True, 274 | normalize='center', 275 | dim_expansion=[2, 2, 2, 2], 276 | pre_blocks=[2, 2, 2, 2], 277 | pos_blocks=[2, 2, 2, 2], 278 | k_neighbors=[32, 32, 32, 32], 279 | reducers=[2, 2, 2, 2], 280 | input_shape='bnc', 281 | classifier=False, 282 | **kwargs, 283 | ): 284 | super(Model, self).__init__() 285 | self.stages = len(pre_blocks) 286 | self.class_num = class_num 287 | self.points = points 288 | self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation) 289 | assert ( 290 | len(pre_blocks) 291 | == len(k_neighbors) 292 | == len(reducers) 293 | == len(pos_blocks) 294 | == len(dim_expansion) 295 | ), 'Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers.' 296 | self.local_grouper_list = nn.ModuleList() 297 | self.pre_blocks_list = nn.ModuleList() 298 | self.pos_blocks_list = nn.ModuleList() 299 | self.embed_dim = embed_dim 300 | last_channel = embed_dim 301 | anchor_points = self.points 302 | for i in range(len(pre_blocks)): 303 | out_channel = last_channel * dim_expansion[i] 304 | pre_block_num = pre_blocks[i] 305 | pos_block_num = pos_blocks[i] 306 | kneighbor = k_neighbors[i] 307 | reduce = reducers[i] 308 | anchor_points = anchor_points // reduce 309 | # append local_grouper_list 310 | local_grouper = LocalGrouper( 311 | last_channel, anchor_points, kneighbor, use_xyz, normalize 312 | ) # [b,g,k,d] 313 | self.local_grouper_list.append(local_grouper) 314 | # append pre_block_list 315 | pre_block_module = PreExtraction( 316 | last_channel, 317 | out_channel, 318 | pre_block_num, 319 | groups=groups, 320 | res_expansion=res_expansion, 321 | bias=bias, 322 | activation=activation, 323 | use_xyz=use_xyz, 324 | ) 325 | self.pre_blocks_list.append(pre_block_module) 326 | # append pos_block_list 327 | pos_block_module = PosExtraction( 328 | out_channel, 329 | pos_block_num, 330 | groups=groups, 331 | res_expansion=res_expansion, 332 | bias=bias, 333 | activation=activation, 334 | ) 335 | self.pos_blocks_list.append(pos_block_module) 336 | 337 | last_channel = out_channel 338 | 339 | self.act = get_activation(activation) 340 | self.classifier = nn.Sequential( 341 | nn.Linear(last_channel, 512), 342 | nn.BatchNorm1d(512), 343 | self.act, 344 | nn.Dropout(0.5), 345 | nn.Linear(512, 256), 346 | nn.BatchNorm1d(256), 347 | self.act, 348 | nn.Dropout(0.5), 349 | nn.Linear(256, self.class_num), 350 | ) 351 | self.get_embedding = nn.Linear(last_channel, self.embed_dim) 352 | self.use_classifier = classifier 353 | self.input_shape = input_shape 354 | 355 | def forward(self, x): 356 | if self.input_shape == 'bnc': 357 | x = x.permute(0, 2, 1) 358 | xyz = x.permute(0, 2, 1) 359 | batch_size, _, _ = x.size() 360 | x = self.embedding(x) # B,D,N 361 | for i in range(self.stages): 362 | # Give xyz[b, p, 3] and fea[b, p, d], return new_xyz[b, g, 3] and new_fea[b, g, k, d] 363 | xyz, x = self.local_grouper_list[i]( 364 | xyz, x.permute(0, 2, 1) 365 | ) # [b,g,3] [b,g,k,d] 366 | x = self.pre_blocks_list[i](x) # [b,d,g] 367 | x = self.pos_blocks_list[i](x) # [b,d,g] 368 | 369 | x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1) 370 | 371 | if self.use_classifier: 372 | x = self.classifier(x) 373 | else: 374 | x = self.get_embedding(x) 375 | return x 376 | 377 | 378 | def pointMLP(num_classes=40, embed_dim=64, classifier=False, **kwargs) -> Model: 379 | return Model( 380 | points=1024, 381 | class_num=num_classes, 382 | embed_dim=embed_dim, 383 | groups=1, 384 | res_expansion=1.0, 385 | activation='relu', 386 | bias=False, 387 | use_xyz=False, 388 | normalize='anchor', 389 | dim_expansion=[2, 2, 2, 2], 390 | pre_blocks=[2, 2, 2, 2], 391 | pos_blocks=[2, 2, 2, 2], 392 | k_neighbors=[24, 24, 24, 24], 393 | reducers=[2, 2, 2, 2], 394 | classifier=classifier, 395 | **kwargs, 396 | ) 397 | 398 | 399 | def pointMLPElite(num_classes=40, embed_dim=32, classifier=False, **kwargs) -> Model: 400 | return Model( 401 | points=1024, 402 | class_num=num_classes, 403 | embed_dim=embed_dim, 404 | groups=1, 405 | res_expansion=0.25, 406 | activation='relu', 407 | bias=False, 408 | use_xyz=False, 409 | normalize='anchor', 410 | dim_expansion=[2, 2, 2, 1], 411 | pre_blocks=[1, 1, 2, 1], 412 | pos_blocks=[1, 1, 2, 1], 413 | k_neighbors=[24, 24, 24, 24], 414 | reducers=[2, 2, 2, 2], 415 | classifier=classifier, 416 | **kwargs, 417 | ) 418 | -------------------------------------------------------------------------------- /executor/models/pointnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointnet import PointNet 2 | -------------------------------------------------------------------------------- /executor/models/pointnet/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..pooling import Pooling 5 | 6 | 7 | class PointNet(nn.Module): 8 | def __init__( 9 | self, 10 | emb_dims=1024, 11 | input_shape='bnc', 12 | use_bn=True, 13 | global_feat=True, 14 | num_classes=40, 15 | classifier=False, 16 | ): 17 | # emb_dims: Embedding Dimensions for PointNet. 18 | # input_shape: Shape of Input Point Cloud (b: batch, n: no of points, c: channels) 19 | super(PointNet, self).__init__() 20 | if input_shape not in ['bcn', 'bnc']: 21 | raise ValueError( 22 | "Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' " 23 | ) 24 | self.input_shape = input_shape 25 | self.emb_dims = emb_dims 26 | self.use_bn = use_bn 27 | self.global_feat = global_feat 28 | if self.global_feat: 29 | self.pooling = Pooling('max') 30 | 31 | self.layers = self.create_model() 32 | self.classifier = nn.Sequential( 33 | nn.Linear(self.emb_dims, 512), 34 | nn.BatchNorm1d(512), 35 | nn.ReLU(), 36 | nn.Linear(512, 256), 37 | nn.Dropout(0.4), 38 | nn.BatchNorm1d(256), 39 | nn.ReLU(), 40 | nn.Linear(256, num_classes), 41 | ) 42 | self.use_classifier = classifier 43 | 44 | def create_model(self): 45 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 46 | self.conv2 = torch.nn.Conv1d(64, 64, 1) 47 | self.conv3 = torch.nn.Conv1d(64, 64, 1) 48 | self.conv4 = torch.nn.Conv1d(64, 128, 1) 49 | self.conv5 = torch.nn.Conv1d(128, self.emb_dims, 1) 50 | self.relu = torch.nn.ReLU() 51 | 52 | if self.use_bn: 53 | self.bn1 = torch.nn.BatchNorm1d(64) 54 | self.bn2 = torch.nn.BatchNorm1d(64) 55 | self.bn3 = torch.nn.BatchNorm1d(64) 56 | self.bn4 = torch.nn.BatchNorm1d(128) 57 | self.bn5 = torch.nn.BatchNorm1d(self.emb_dims) 58 | 59 | if self.use_bn: 60 | layers = [ 61 | self.conv1, 62 | self.bn1, 63 | self.relu, 64 | self.conv2, 65 | self.bn2, 66 | self.relu, 67 | self.conv3, 68 | self.bn3, 69 | self.relu, 70 | self.conv4, 71 | self.bn4, 72 | self.relu, 73 | self.conv5, 74 | self.bn5, 75 | self.relu, 76 | ] 77 | else: 78 | layers = [ 79 | self.conv1, 80 | self.relu, 81 | self.conv2, 82 | self.relu, 83 | self.conv3, 84 | self.relu, 85 | self.conv4, 86 | self.relu, 87 | self.conv5, 88 | self.relu, 89 | ] 90 | return layers 91 | 92 | def forward(self, input_data): 93 | # input_data: Point Cloud having shape input_shape. 94 | # output: PointNet features (Batch x emb_dims) 95 | if self.input_shape == 'bnc': 96 | num_points = input_data.shape[1] 97 | input_data = input_data.permute(0, 2, 1) 98 | else: 99 | num_points = input_data.shape[2] 100 | if input_data.shape[1] != 3: 101 | raise RuntimeError('shape of x must be of [Batch x 3 x NumInPoints]') 102 | 103 | output = input_data 104 | for idx, layer in enumerate(self.layers): 105 | output = layer(output) 106 | if idx == 1 and not self.global_feat: 107 | point_feature = output 108 | 109 | # output = torch.max(output, 2, keepdim=True)[0] 110 | # output = output.view(-1, self.emb_dims) 111 | 112 | if self.global_feat: 113 | output = self.pooling(output) 114 | embedding = output 115 | else: 116 | # output = output.view(-1, self.emb_dims, 1).repeat(1, 1, num_points) 117 | # output = self.pooling(output) 118 | output = output.view(-1, self.emb_dims, 1).repeat(1, 1, num_points) 119 | embedding = torch.cat([output, point_feature], 1) 120 | 121 | if self.use_classifier: 122 | return self.classifier(embedding) 123 | else: 124 | return embedding 125 | -------------------------------------------------------------------------------- /executor/models/pointnet2/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointnet2 import PointNet2 2 | -------------------------------------------------------------------------------- /executor/models/pointnet2/pointnet2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from ..pointconv import PointNetSetAbstraction, PointNetSetAbstractionMsg 5 | 6 | 7 | class PointNet2(nn.Module): 8 | def __init__( 9 | self, 10 | emb_dims=1024, 11 | input_shape='bnc', 12 | normal_channel=True, 13 | classifier=False, 14 | num_classes=40, 15 | density_adaptive_type='ssg', 16 | pretrained=None, 17 | ): 18 | super(PointNet2, self).__init__() 19 | 20 | if input_shape not in ['bnc', 'bcn']: 21 | raise ValueError( 22 | "Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' " 23 | ) 24 | 25 | self.emb_dims = emb_dims 26 | self.input_shape = input_shape 27 | self.classifier = classifier 28 | self.normal_channel = normal_channel 29 | 30 | if density_adaptive_type == 'ssg': 31 | if normal_channel: 32 | in_channel = 6 33 | else: 34 | in_channel = 3 35 | self.sa1 = PointNetSetAbstraction( 36 | npoint=512, 37 | radius=0.2, 38 | nsample=32, 39 | in_channel=in_channel, 40 | mlp=[64, 64, 128], 41 | group_all=False, 42 | ) 43 | self.sa2 = PointNetSetAbstraction( 44 | npoint=128, 45 | radius=0.4, 46 | nsample=64, 47 | in_channel=128 + 3, 48 | mlp=[128, 128, 256], 49 | group_all=False, 50 | ) 51 | self.sa3 = PointNetSetAbstraction( 52 | npoint=None, 53 | radius=None, 54 | nsample=None, 55 | in_channel=256 + 3, 56 | mlp=[256, 512, self.emb_dims], 57 | group_all=True, 58 | ) 59 | else: 60 | if normal_channel: 61 | in_channel = 3 62 | else: 63 | in_channel = 0 64 | self.sa1 = PointNetSetAbstractionMsg( 65 | npoint=512, 66 | radius_list=[0.1, 0.2, 0.4], 67 | nsample_list=[16, 32, 128], 68 | in_channel=in_channel, 69 | mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]], 70 | ) 71 | self.sa2 = PointNetSetAbstractionMsg( 72 | npoint=128, 73 | radius_list=[0.2, 0.4, 0.8], 74 | nsample_list=[32, 64, 128], 75 | in_channel=320, 76 | mlp_list=[[64, 64, 128], [128, 128, 256], [128, 128, 256]], 77 | ) 78 | self.sa3 = PointNetSetAbstraction( 79 | npoint=None, 80 | radius=None, 81 | nsample=None, 82 | in_channel=640 + 3, 83 | mlp=[256, 512, self.emb_dims], 84 | group_all=True, 85 | ) 86 | 87 | self.fc1 = nn.Linear(self.emb_dims, 512) 88 | self.bn1 = nn.BatchNorm1d(512) 89 | self.drop1 = nn.Dropout(0.4) 90 | self.fc2 = nn.Linear(512, 256) 91 | self.bn2 = nn.BatchNorm1d(256) 92 | self.drop2 = nn.Dropout(0.4) 93 | self.fc3 = nn.Linear(256, num_classes) 94 | 95 | def forward(self, xyz): 96 | if self.input_shape == 'bnc': 97 | xyz = xyz.permute(0, 2, 1) 98 | batch_size = xyz.shape[0] 99 | if self.normal_channel: 100 | norm = xyz[:, 3:, :] 101 | xyz = xyz[:, :3, :] 102 | else: 103 | norm = None 104 | 105 | l1_xyz, l1_points = self.sa1(xyz, norm) 106 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 107 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 108 | x = l3_points.view(batch_size, self.emb_dims) 109 | 110 | if self.classifier: 111 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 112 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 113 | x = self.fc3(x) 114 | output = F.log_softmax(x, -1) 115 | else: 116 | output = x 117 | 118 | return output 119 | -------------------------------------------------------------------------------- /executor/models/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Pooling(torch.nn.Module): 5 | def __init__(self, pool_type='max'): 6 | super(Pooling, self).__init__() 7 | 8 | self.pool_type = pool_type 9 | 10 | def forward(self, input): 11 | if self.pool_type == 'max': 12 | return torch.max(input, 2)[0].contiguous() 13 | elif self.pool_type == 'avg' or self.pool_type == 'average': 14 | return torch.mean(input, 2).contiguous() 15 | -------------------------------------------------------------------------------- /executor/models/repsurf/__init__.py: -------------------------------------------------------------------------------- 1 | from .repsurf import RepSurf 2 | -------------------------------------------------------------------------------- /executor/models/repsurf/polar_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | ref: https://github.com/hancyran/RepSurf/blob/44b8da1a40/modules/polar_utils.py 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def xyz2sphere(xyz, normalize=True): 10 | """ 11 | Convert XYZ to Spherical Coordinate 12 | reference: https://en.wikipedia.org/wiki/Spherical_coordinate_system 13 | :param xyz: [B, N, 3] / [B, N, G, 3] 14 | :return: (rho, theta, phi) [B, N, 3] / [B, N, G, 3] 15 | """ 16 | rho = torch.sqrt(torch.sum(torch.pow(xyz, 2), dim=-1, keepdim=True)) 17 | rho = torch.clamp(rho, min=0) # range: [0, inf] 18 | theta = torch.acos(xyz[..., 2, None] / rho) # range: [0, pi] 19 | phi = torch.atan2(xyz[..., 1, None], xyz[..., 0, None]) # range: [-pi, pi] 20 | # check nan 21 | idx = rho == 0 22 | theta[idx] = 0 23 | 24 | if normalize: 25 | theta = theta / np.pi # [0, 1] 26 | phi = phi / (2 * np.pi) + 0.5 # [0, 1] 27 | out = torch.cat([rho, theta, phi], dim=-1) 28 | return out 29 | 30 | 31 | def xyz2cylind(xyz, normalize=True): 32 | """ 33 | Convert XYZ to Cylindrical Coordinate 34 | reference: https://en.wikipedia.org/wiki/Cylindrical_coordinate_system 35 | :param normalize: Normalize phi & z 36 | :param xyz: [B, N, 3] / [B, N, G, 3] 37 | :return: (rho, phi, z) [B, N, 3] 38 | """ 39 | rho = torch.sqrt(torch.sum(torch.pow(xyz[..., :2], 2), dim=-1, keepdim=True)) 40 | rho = torch.clamp(rho, 0, 1) # range: [0, 1] 41 | phi = torch.atan2(xyz[..., 1, None], xyz[..., 0, None]) # range: [-pi, pi] 42 | z = xyz[..., 2, None] 43 | z = torch.clamp(z, -1, 1) # range: [-1, 1] 44 | 45 | if normalize: 46 | phi = phi / (2 * np.pi) + 0.5 47 | z = (z + 1.0) / 2.0 48 | out = torch.cat([rho, phi, z], dim=-1) 49 | return 50 | -------------------------------------------------------------------------------- /executor/models/repsurf/recons_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | ref: https://github.com/hancyran/RepSurf/blob/44b8da1a40/modules/recons_utils.py 3 | """ 4 | 5 | """ 6 | Author: Haoxi Ran 7 | Date: 05/10/2022 8 | """ 9 | 10 | import torch 11 | from torch import nn 12 | 13 | from ..pointconv import index_points, knn_point 14 | 15 | 16 | def _recons_factory(type, cuda=False): 17 | if type == 'knn': 18 | return knn_recons 19 | # elif type == 'lknn': 20 | # return limited_knn_recons 21 | else: 22 | raise Exception('Not Implemented Reconstruction Type') 23 | 24 | 25 | def knn_recons(k, center, context, cuda=False): 26 | idx = knn_point(k, context, center) 27 | torch.cuda.empty_cache() 28 | 29 | group_xyz = index_points(context, idx) # [B, N, K, C] 30 | torch.cuda.empty_cache() 31 | return group_xyz 32 | 33 | 34 | def cal_normal(group_xyz, random_inv=False, is_group=False): 35 | """ 36 | Calculate Normal Vector (Unit Form + First Term Positive) 37 | :param group_xyz: [B, N, K=3, 3] / [B, N, G, K=3, 3] 38 | :param random_inv: 39 | :param return_intersect: 40 | :param return_const: 41 | :return: [B, N, 3] 42 | """ 43 | edge_vec1 = group_xyz[..., 1, :] - group_xyz[..., 0, :] # [B, N, 3] 44 | edge_vec2 = group_xyz[..., 2, :] - group_xyz[..., 0, :] # [B, N, 3] 45 | 46 | nor = torch.cross(edge_vec1, edge_vec2, dim=-1) 47 | unit_nor = nor / torch.norm(nor, dim=-1, keepdim=True) # [B, N, 3] / [B, N, G, 3] 48 | if not is_group: 49 | pos_mask = (unit_nor[..., 0] > 0).float() * 2.0 - 1.0 # keep x_n positive 50 | else: 51 | pos_mask = (unit_nor[..., 0:1, 0] > 0).float() * 2.0 - 1.0 52 | unit_nor = unit_nor * pos_mask.unsqueeze(-1) 53 | 54 | # batch-wise random inverse normal vector (prob: 0.5) 55 | if random_inv: 56 | random_mask = torch.randint(0, 2, (group_xyz.size(0), 1, 1)).float() * 2.0 - 1.0 57 | random_mask = random_mask.to(unit_nor.device) 58 | if not is_group: 59 | unit_nor = unit_nor * random_mask 60 | else: 61 | unit_nor = unit_nor * random_mask.unsqueeze(-1) 62 | 63 | return unit_nor 64 | 65 | 66 | def pca(X, k, center=True): 67 | """ 68 | Principal Components Analysis impl. with SVD function 69 | :param X: 70 | :param k: 71 | :param center: 72 | :return: 73 | """ 74 | 75 | n = X.size()[0] 76 | ones = torch.ones(n).view([n, 1]) 77 | h = ( 78 | ((1 / n) * torch.mm(ones, ones.t())) 79 | if center 80 | else torch.zeros(n * n).view([n, n]) 81 | ) 82 | H = torch.eye(n) - h 83 | X_center = torch.mm(H.double(), X.double()) 84 | u, s, v = torch.svd(X_center) 85 | components = v[:k].t() 86 | explained_variance = torch.mul(s[:k], s[:k]) / (n - 1) 87 | return { 88 | 'X': X, 89 | 'k': k, 90 | 'components': components, 91 | 'explained_variance': explained_variance, 92 | } 93 | 94 | 95 | def cal_center(group_xyz): 96 | """ 97 | Calculate Global Coordinates of the Center of Triangle 98 | :param group_xyz: [B, N, K, 3] / [B, N, G, K, 3]; K >= 3 99 | :return: [B, N, 3] / [B, N, G, 3] 100 | """ 101 | center = torch.mean(group_xyz, dim=-2) 102 | return center 103 | 104 | 105 | def cal_area(group_xyz): 106 | """ 107 | Calculate Area of Triangle 108 | :param group_xyz: [B, N, K, 3] / [B, N, G, K, 3]; K = 3 109 | :return: [B, N, 1] / [B, N, G, 1] 110 | """ 111 | pad_shape = group_xyz[..., 0, None].shape 112 | det_xy = torch.det( 113 | torch.cat( 114 | [group_xyz[..., 0, None], group_xyz[..., 1, None], torch.ones(pad_shape)], 115 | dim=-1, 116 | ) 117 | ) 118 | det_yz = torch.det( 119 | torch.cat( 120 | [group_xyz[..., 1, None], group_xyz[..., 2, None], torch.ones(pad_shape)], 121 | dim=-1, 122 | ) 123 | ) 124 | det_zx = torch.det( 125 | torch.cat( 126 | [group_xyz[..., 2, None], group_xyz[..., 0, None], torch.ones(pad_shape)], 127 | dim=-1, 128 | ) 129 | ) 130 | area = torch.sqrt(det_xy**2 + det_yz**2 + det_zx**2).unsqueeze(-1) 131 | return area 132 | 133 | 134 | def cal_const(normal, center, is_normalize=True): 135 | """ 136 | Calculate Constant Term (Standard Version, with x_normal to be 1) 137 | math:: 138 | const = x_nor * x_0 + y_nor * y_0 + z_nor * z_0 139 | :param is_normalize: 140 | :param normal: [B, N, 3] / [B, N, G, 3] 141 | :param center: [B, N, 3] / [B, N, G, 3] 142 | :return: [B, N, 1] / [B, N, G, 1] 143 | """ 144 | const = torch.sum(normal * center, dim=-1, keepdim=True) 145 | factor = torch.sqrt(torch.Tensor([3])).to(normal.device) 146 | const = const / factor if is_normalize else const 147 | 148 | return const 149 | 150 | 151 | def check_nan(normal, center, pos=None): 152 | """ 153 | Check & Remove NaN in normal tensor 154 | :param pos: [B, N, 1] 155 | :param center: [B, N, 3] 156 | :param normal: [B, N, 3] 157 | :return: 158 | """ 159 | B, N, _ = normal.shape 160 | mask = torch.sum(torch.isnan(normal), dim=-1) > 0 161 | mask_first = torch.argmax((~mask).int(), dim=-1) 162 | 163 | normal_first = normal[torch.arange(B), None, mask_first].repeat([1, N, 1]) 164 | normal[mask] = normal_first[mask] 165 | center_first = center[torch.arange(B), None, mask_first].repeat([1, N, 1]) 166 | center[mask] = center_first[mask] 167 | 168 | if pos is not None: 169 | pos_first = pos[torch.arange(B), None, mask_first].repeat([1, N, 1]) 170 | pos[mask] = pos_first[mask] 171 | return normal, center, pos 172 | return normal, center 173 | 174 | 175 | def check_nan_umb(normal, center, pos=None): 176 | """ 177 | Check & Remove NaN in normal tensor 178 | :param pos: [B, N, G, 1] 179 | :param center: [B, N, G, 3] 180 | :param normal: [B, N, G, 3] 181 | :return: 182 | """ 183 | B, N, G, _ = normal.shape 184 | mask = torch.sum(torch.isnan(normal), dim=-1) > 0 185 | mask_first = torch.argmax((~mask).int(), dim=-1) 186 | b_idx = torch.arange(B).unsqueeze(1).repeat([1, N]) 187 | n_idx = torch.arange(N).unsqueeze(0).repeat([B, 1]) 188 | 189 | normal_first = normal[b_idx, n_idx, None, mask_first].repeat([1, 1, G, 1]) 190 | normal[mask] = normal_first[mask] 191 | center_first = center[b_idx, n_idx, None, mask_first].repeat([1, 1, G, 1]) 192 | center[mask] = center_first[mask] 193 | 194 | if pos is not None: 195 | pos_first = pos[b_idx, n_idx, None, mask_first].repeat([1, 1, G, 1]) 196 | pos[mask] = pos_first[mask] 197 | return normal, center, pos 198 | return normal, center 199 | 200 | 201 | class SurfaceConstructor(nn.Module): 202 | """ 203 | Surface Constructor for Point Clouds 204 | Formulation of A Surface: 205 | A * (x - x_0) + B * (y - y_0) + C * (z - z_0) = 0, 206 | where A^2 + B^2 + C^2 = 1 & A > 0 207 | """ 208 | 209 | def __init__( 210 | self, 211 | r=None, 212 | k=3, 213 | recons_type='knn', 214 | return_dist=False, 215 | random_inv=True, 216 | cuda=False, 217 | ): 218 | super(SurfaceConstructor, self).__init__() 219 | self.K = k 220 | self.R = r 221 | self.recons = _recons_factory(recons_type) 222 | self.cuda = cuda 223 | 224 | self.return_dist = return_dist 225 | self.random_inv = random_inv 226 | 227 | def forward(self, center, context): 228 | """ 229 | Input: 230 | center: input points position as centroid points, [B, 3, N] 231 | context: input points position as context points, [B, 3, N'] 232 | Output: 233 | normal: normals of constructed triangles, [B, 3, N] 234 | center: centroids of constructed triangles, [B, 3, N] 235 | pos: position info of constructed triangles, [B, 1, N] 236 | """ 237 | center = center.permute(0, 2, 1) 238 | context = context.permute(0, 2, 1) 239 | 240 | group_xyz = self.recons(self.K, center, context, cuda=self.cuda) 241 | normal = cal_normal(group_xyz, random_inv=self.random_inv) 242 | center = cal_center(group_xyz) 243 | 244 | if self.return_dist: 245 | pos = cal_const(normal, center) 246 | normal, center, pos = check_nan(normal, center, pos) 247 | normal = normal.permute(0, 2, 1) 248 | center = center.permute(0, 2, 1) 249 | pos = pos.permute(0, 2, 1) 250 | return normal, center, pos 251 | 252 | normal, center = check_nan(normal, center) 253 | normal = normal.permute(0, 2, 1) 254 | center = center.permute(0, 2, 1) 255 | 256 | return normal, center 257 | -------------------------------------------------------------------------------- /executor/models/repsurf/repsurf.py: -------------------------------------------------------------------------------- 1 | """ref: https://github.com/hancyran/RepSurf/blob/main/models/repsurf/scanobjectnn/repsurf_ssg_umb.py 2 | """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .repsurf_utils import SurfaceAbstractionCD, UmbrellaSurfaceConstructor 9 | 10 | 11 | class RepSurf(torch.nn.Module): 12 | def __init__( 13 | self, 14 | num_points, 15 | return_center=True, 16 | return_polar=True, 17 | return_dist=True, 18 | group_size=8, 19 | umb_pool_type='sum', 20 | num_classes=40, 21 | input_shape='bnc', 22 | emb_dims=1024, 23 | classifier=False, 24 | ) -> None: 25 | super(RepSurf, self).__init__() 26 | 27 | if input_shape not in ['bnc', 'bcn']: 28 | raise ValueError( 29 | "Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' " 30 | ) 31 | self.input_shape = input_shape 32 | 33 | center_channel = 0 if not return_center else (6 if return_polar else 3) 34 | repsurf_channel = 10 35 | 36 | self.init_nsample = num_points 37 | self.return_dist = return_dist 38 | self.emb_dims = emb_dims 39 | 40 | self.surface_constructor = UmbrellaSurfaceConstructor( 41 | group_size + 1, 42 | repsurf_channel, 43 | return_dist=return_dist, 44 | aggr_type=umb_pool_type, 45 | cuda=False, 46 | ) 47 | 48 | self.sa1 = SurfaceAbstractionCD( 49 | npoint=512, 50 | radius=0.2, 51 | nsample=32, 52 | feat_channel=repsurf_channel, 53 | pos_channel=center_channel, 54 | mlp=[64, 64, 128], 55 | group_all=False, 56 | return_polar=return_polar, 57 | cuda=False, 58 | ) 59 | 60 | self.sa2 = SurfaceAbstractionCD( 61 | npoint=128, 62 | radius=0.4, 63 | nsample=64, 64 | feat_channel=128 + repsurf_channel, 65 | pos_channel=center_channel, 66 | mlp=[128, 128, 256], 67 | group_all=False, 68 | return_polar=return_polar, 69 | cuda=False, 70 | ) 71 | 72 | self.sa3 = SurfaceAbstractionCD( 73 | npoint=None, 74 | radius=None, 75 | nsample=None, 76 | feat_channel=256 + repsurf_channel, 77 | pos_channel=center_channel, 78 | mlp=[256, 512, self.emb_dims], 79 | group_all=True, 80 | return_polar=return_polar, 81 | cuda=False, 82 | ) 83 | 84 | # modelnet40 85 | self.use_classifier = classifier 86 | self.classfier = nn.Sequential( 87 | nn.Linear(self.emb_dims, 512), 88 | nn.BatchNorm1d(512), 89 | nn.ReLU(True), 90 | nn.Dropout(0.4), 91 | nn.Linear(512, 256), 92 | nn.BatchNorm1d(256), 93 | nn.ReLU(True), 94 | nn.Dropout(0.4), 95 | nn.Linear(256, num_classes), 96 | ) 97 | 98 | def forward(self, points): 99 | if self.input_shape == 'bnc': 100 | points = points.permute(0, 2, 1) 101 | 102 | center = points[:, :3, :] 103 | 104 | normal = self.surface_constructor(center) 105 | 106 | center, normal, feature = self.sa1(center, normal, None) 107 | center, normal, feature = self.sa2(center, normal, feature) 108 | center, normal, feature = self.sa3(center, normal, feature) 109 | 110 | feature = feature.view(-1, self.emb_dims) 111 | 112 | if self.use_classifier: 113 | feature = self.classfier(feature) 114 | feature = F.log_softmax(feature, -1) 115 | 116 | return feature 117 | -------------------------------------------------------------------------------- /executor/models/repsurf/repsurf_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from ..pointconv import farthest_point_sample, index_points, knn_point, query_ball_point 10 | from .polar_utils import xyz2sphere 11 | from .recons_utils import cal_center, cal_const, cal_normal, check_nan_umb 12 | 13 | 14 | def sample_and_group( 15 | npoint, 16 | radius, 17 | nsample, 18 | center, 19 | normal, 20 | feature, 21 | return_normal=True, 22 | return_polar=False, 23 | cuda=False, 24 | ): 25 | """ 26 | Input: 27 | center: input points position data 28 | normal: input points normal data 29 | feature: input points feature 30 | Return: 31 | new_center: sampled points position data 32 | new_normal: sampled points normal data 33 | new_feature: sampled points feature 34 | """ 35 | # sample 36 | fps_idx = farthest_point_sample(center, npoint) # [B, npoint, A] 37 | torch.cuda.empty_cache() 38 | # sample center 39 | new_center = index_points(center, fps_idx) 40 | torch.cuda.empty_cache() 41 | # sample normal 42 | new_normal = index_points(normal, fps_idx) 43 | torch.cuda.empty_cache() 44 | 45 | # group 46 | idx = query_ball_point(radius, nsample, center, new_center) 47 | torch.cuda.empty_cache() 48 | # group normal 49 | group_normal = index_points(normal, idx) # [B, npoint, nsample, B] 50 | torch.cuda.empty_cache() 51 | # group center 52 | group_center = index_points(center, idx) # [B, npoint, nsample, A] 53 | torch.cuda.empty_cache() 54 | group_center_norm = group_center - new_center.unsqueeze(2) 55 | torch.cuda.empty_cache() 56 | 57 | # group polar 58 | if return_polar: 59 | group_polar = xyz2sphere(group_center_norm) 60 | group_center_norm = torch.cat([group_center_norm, group_polar], dim=-1) 61 | if feature is not None: 62 | group_feature = index_points(feature, idx) 63 | new_feature = ( 64 | torch.cat([group_center_norm, group_normal, group_feature], dim=-1) 65 | if return_normal 66 | else torch.cat([group_center_norm, group_feature], dim=-1) 67 | ) 68 | else: 69 | new_feature = torch.cat([group_center_norm, group_normal], dim=-1) 70 | 71 | return new_center, new_normal, new_feature 72 | 73 | 74 | def sample_and_group_all( 75 | center, normal, feature, return_normal=True, return_polar=False 76 | ): 77 | """ 78 | Input: 79 | center: input centroid position data 80 | normal: input normal data 81 | feature: input feature data 82 | Return: 83 | new_center: sampled points position data 84 | new_normal: sampled points position data 85 | new_feature: sampled points data 86 | """ 87 | device = center.device 88 | B, N, C = normal.shape 89 | 90 | new_center = torch.zeros(B, 1, 3).to(device) 91 | new_normal = new_center 92 | 93 | group_normal = normal.view(B, 1, N, C) 94 | group_center = center.view(B, 1, N, 3) 95 | if return_polar: 96 | group_polar = xyz2sphere(group_center) 97 | group_center = torch.cat([group_center, group_polar], dim=-1) 98 | 99 | new_feature = ( 100 | torch.cat([group_center, group_normal, feature.view(B, 1, N, -1)], dim=-1) 101 | if return_normal 102 | else torch.cat([group_center, feature.view(B, 1, N, -1)], dim=-1) 103 | ) 104 | 105 | return new_center, new_normal, new_feature 106 | 107 | 108 | def resort_points(points, idx): 109 | """ 110 | Resort Set of points along G dim 111 | """ 112 | device = points.device 113 | B, N, G, _ = points.shape 114 | 115 | view_shape = [B, 1, 1] 116 | repeat_shape = [1, N, G] 117 | b_indices = ( 118 | torch.arange(B, dtype=torch.long) 119 | .to(device) 120 | .view(view_shape) 121 | .repeat(repeat_shape) 122 | ) 123 | 124 | view_shape = [1, N, 1] 125 | repeat_shape = [B, 1, G] 126 | n_indices = ( 127 | torch.arange(N, dtype=torch.long) 128 | .to(device) 129 | .view(view_shape) 130 | .repeat(repeat_shape) 131 | ) 132 | 133 | new_points = points[b_indices, n_indices, idx, :] 134 | 135 | return new_points 136 | 137 | 138 | def group_by_umbrella(xyz, new_xyz, k=9, cuda=False): 139 | """ 140 | Group a set of points into umbrella surfaces 141 | """ 142 | idx = knn_point(k, xyz, new_xyz) 143 | torch.cuda.empty_cache() 144 | group_xyz = index_points(xyz, idx)[:, :, 1:] # [B, N', K-1, 3] 145 | torch.cuda.empty_cache() 146 | 147 | group_xyz_norm = group_xyz - new_xyz.unsqueeze(-2) 148 | group_phi = xyz2sphere(group_xyz_norm)[..., 2] # [B, N', K-1] 149 | sort_idx = group_phi.argsort(dim=-1) # [B, N', K-1] 150 | 151 | # [B, N', K-1, 1, 3] 152 | sorted_group_xyz = resort_points(group_xyz_norm, sort_idx).unsqueeze(-2) 153 | sorted_group_xyz_roll = torch.roll(sorted_group_xyz, -1, dims=-3) 154 | group_centriod = torch.zeros_like(sorted_group_xyz) 155 | umbrella_group_xyz = torch.cat( 156 | [group_centriod, sorted_group_xyz, sorted_group_xyz_roll], dim=-2 157 | ) 158 | 159 | return umbrella_group_xyz 160 | 161 | 162 | class SurfaceAbstraction(nn.Module): 163 | """ 164 | Surface Abstraction Module 165 | """ 166 | 167 | def __init__( 168 | self, 169 | npoint, 170 | radius, 171 | nsample, 172 | in_channel, 173 | mlp, 174 | group_all, 175 | return_polar=True, 176 | return_normal=True, 177 | cuda=False, 178 | ): 179 | super(SurfaceAbstraction, self).__init__() 180 | self.npoint = npoint 181 | self.radius = radius 182 | self.nsample = nsample 183 | self.return_normal = return_normal 184 | self.return_polar = return_polar 185 | self.cuda = cuda 186 | self.group_all = group_all 187 | self.mlp_convs = nn.ModuleList() 188 | self.mlp_bns = nn.ModuleList() 189 | 190 | last_channel = in_channel 191 | for out_channel in mlp: 192 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 193 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 194 | last_channel = out_channel 195 | 196 | def forward(self, center, normal, feature): 197 | normal = normal.permute(0, 2, 1) 198 | center = center.permute(0, 2, 1) 199 | if feature is not None: 200 | feature = feature.permute(0, 2, 1) 201 | 202 | if self.group_all: 203 | new_center, new_normal, new_feature = sample_and_group_all( 204 | center, 205 | normal, 206 | feature, 207 | return_polar=self.return_polar, 208 | return_normal=self.return_normal, 209 | ) 210 | else: 211 | new_center, new_normal, new_feature = sample_and_group( 212 | self.npoint, 213 | self.radius, 214 | self.nsample, 215 | center, 216 | normal, 217 | feature, 218 | return_polar=self.return_polar, 219 | return_normal=self.return_normal, 220 | cuda=self.cuda, 221 | ) 222 | 223 | new_feature = new_feature.permute(0, 3, 2, 1) 224 | for i, conv in enumerate(self.mlp_convs): 225 | bn = self.mlp_bns[i] 226 | new_feature = F.relu(bn(conv(new_feature))) 227 | new_feature = torch.max(new_feature, 2)[0] 228 | 229 | new_center = new_center.permute(0, 2, 1) 230 | new_normal = new_normal.permute(0, 2, 1) 231 | 232 | return new_center, new_normal, new_feature 233 | 234 | 235 | class SurfaceAbstractionCD(nn.Module): 236 | """ 237 | Surface Abstraction Module 238 | """ 239 | 240 | def __init__( 241 | self, 242 | npoint, 243 | radius, 244 | nsample, 245 | feat_channel, 246 | pos_channel, 247 | mlp, 248 | group_all, 249 | return_normal=True, 250 | return_polar=False, 251 | cuda=False, 252 | ): 253 | super(SurfaceAbstractionCD, self).__init__() 254 | self.npoint = npoint 255 | self.radius = radius 256 | self.nsample = nsample 257 | self.return_normal = return_normal 258 | self.return_polar = return_polar 259 | self.cuda = cuda 260 | self.mlp_convs = nn.ModuleList() 261 | self.mlp_bns = nn.ModuleList() 262 | self.pos_channel = pos_channel 263 | self.group_all = group_all 264 | 265 | self.mlp_l0 = nn.Conv2d(self.pos_channel, mlp[0], 1) 266 | self.mlp_f0 = nn.Conv2d(feat_channel, mlp[0], 1) 267 | self.bn_l0 = nn.BatchNorm2d(mlp[0]) 268 | self.bn_f0 = nn.BatchNorm2d(mlp[0]) 269 | 270 | # mlp_l0+mlp_f0 can be considered as the first layer of mlp_convs 271 | last_channel = mlp[0] 272 | for out_channel in mlp[1:]: 273 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 274 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 275 | last_channel = out_channel 276 | 277 | def forward(self, center, normal, feature): 278 | normal = normal.permute(0, 2, 1) 279 | center = center.permute(0, 2, 1) 280 | if feature is not None: 281 | feature = feature.permute(0, 2, 1) 282 | 283 | if self.group_all: 284 | new_center, new_normal, new_feature = sample_and_group_all( 285 | center, 286 | normal, 287 | feature, 288 | return_normal=self.return_normal, 289 | return_polar=self.return_polar, 290 | ) 291 | else: 292 | new_center, new_normal, new_feature = sample_and_group( 293 | self.npoint, 294 | self.radius, 295 | self.nsample, 296 | center, 297 | normal, 298 | feature, 299 | return_normal=self.return_normal, 300 | return_polar=self.return_polar, 301 | cuda=self.cuda, 302 | ) 303 | 304 | new_feature = new_feature.permute(0, 3, 2, 1) 305 | 306 | # init layer 307 | loc = self.bn_l0(self.mlp_l0(new_feature[:, : self.pos_channel])) 308 | feat = self.bn_f0(self.mlp_f0(new_feature[:, self.pos_channel :])) 309 | new_feature = loc + feat 310 | new_feature = F.relu(new_feature) 311 | 312 | for i, conv in enumerate(self.mlp_convs): 313 | bn = self.mlp_bns[i] 314 | new_feature = F.relu(bn(conv(new_feature))) 315 | new_feature = torch.max(new_feature, 2)[0] 316 | 317 | new_center = new_center.permute(0, 2, 1) 318 | new_normal = new_normal.permute(0, 2, 1) 319 | 320 | return new_center, new_normal, new_feature 321 | 322 | 323 | class UmbrellaSurfaceConstructor(nn.Module): 324 | """ 325 | Umbrella-based Surface Abstraction Module 326 | """ 327 | 328 | def __init__( 329 | self, 330 | k, 331 | in_channel, 332 | aggr_type='sum', 333 | return_dist=False, 334 | random_inv=True, 335 | cuda=False, 336 | ): 337 | super(UmbrellaSurfaceConstructor, self).__init__() 338 | self.k = k 339 | self.return_dist = return_dist 340 | self.random_inv = random_inv 341 | self.aggr_type = aggr_type 342 | self.cuda = cuda 343 | 344 | self.mlps = nn.Sequential( 345 | nn.Conv2d(in_channel, in_channel, 1, bias=False), 346 | nn.BatchNorm2d(in_channel), 347 | nn.ReLU(True), 348 | nn.Conv2d(in_channel, in_channel, 1, bias=True), 349 | nn.BatchNorm2d(in_channel), 350 | nn.ReLU(True), 351 | nn.Conv2d(in_channel, in_channel, 1, bias=True), 352 | ) 353 | 354 | def forward(self, center): 355 | center = center.permute(0, 2, 1) 356 | # surface construction 357 | group_xyz = group_by_umbrella( 358 | center, center, k=self.k, cuda=self.cuda 359 | ) # [B, N, K-1, 3 (points), 3 (coord.)] 360 | 361 | # normal 362 | group_normal = cal_normal(group_xyz, random_inv=self.random_inv, is_group=True) 363 | # coordinate 364 | group_center = cal_center(group_xyz) 365 | # polar 366 | group_polar = xyz2sphere(group_center) 367 | if self.return_dist: 368 | group_pos = cal_const(group_normal, group_center) 369 | group_normal, group_center, group_pos = check_nan_umb( 370 | group_normal, group_center, group_pos 371 | ) 372 | # new_feature = torch.cat([group_normal], dim=-1) # N: 3 373 | # new_feature = torch.cat([group_normal, group_pos], dim=-1) # N+P: 4 374 | # new_feature = torch.cat([group_center, group_normal], dim=-1) # N+C: 6 375 | # new_feature = torch.cat([group_center, group_normal, group_pos], dim=-1) # N+P+C: 7 376 | new_feature = torch.cat( 377 | [group_center, group_polar, group_normal, group_pos], dim=-1 378 | ) # N+P+CP: 10 379 | else: 380 | group_normal, group_center = check_nan_umb(group_normal, group_center) 381 | new_feature = torch.cat([group_center, group_polar, group_normal], dim=-1) 382 | new_feature = new_feature.permute(0, 3, 2, 1) # [B, C, G, N] 383 | 384 | # mapping 385 | new_feature = self.mlps(new_feature) 386 | 387 | # aggregation 388 | if self.aggr_type == 'max': 389 | new_feature = torch.max(new_feature, 2)[0] 390 | elif self.aggr_type == 'avg': 391 | new_feature = torch.mean(new_feature, 2) 392 | else: 393 | new_feature = torch.sum(new_feature, 2) 394 | 395 | return new_feature 396 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | __copyright__ = 'Copyright (c) 2022 Jina AI Limited. All rights reserved.' 2 | __license__ = 'Apache-2.0' 3 | 4 | import pathlib 5 | from functools import partial 6 | 7 | import click 8 | import finetuner 9 | import numpy as np 10 | import torch 11 | from docarray import Document, DocumentArray 12 | from finetuner.tuner.callback import BestModelCheckpoint 13 | from finetuner.tuner.pytorch.losses import TripletLoss 14 | from finetuner.tuner.pytorch.miner import TripletEasyHardMiner 15 | 16 | from executor.models import MeshDataModel 17 | 18 | 19 | def random_sample(pc, num): 20 | permutation = np.arange(len(pc)) 21 | np.random.shuffle(permutation) 22 | pc = np.array(pc).astype('float32') 23 | pc = pc[permutation[:num]] 24 | return pc 25 | 26 | 27 | def preprocess(doc: 'Document', num_points: int = 1024, data_aug: bool = True): 28 | points = random_sample(doc.tensor, num_points) 29 | # points = np.transpose(points) 30 | 31 | points = points - np.expand_dims(np.mean(points, axis=0), 0) # center 32 | dist = np.max(np.sqrt(np.sum(points**2, axis=1)), 0) 33 | points = points / dist # scale 34 | 35 | if data_aug: 36 | theta = np.random.uniform(0, np.pi * 2) 37 | rotation_matrix = np.array( 38 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] 39 | ) 40 | points[:, [0, 2]] = points[:, [0, 2]].dot(rotation_matrix) # random rotation 41 | points += np.random.normal(0, 0.02, size=points.shape) # random jitter 42 | doc.tensor = points 43 | return doc 44 | 45 | 46 | @click.command() 47 | @click.option('--train_dataset', help='The training dataset file path') 48 | @click.option('--eval_dataset', help='The evaluation dataset file path') 49 | @click.option('--embed_dim', default=512, help='The embedding dimension') 50 | @click.option('--restore_from', help='The restore checkpoitn path of pretrained model') 51 | @click.option( 52 | '--checkpoint_dir', 53 | default='checkpoints', 54 | type=click.Path(file_okay=False, path_type=pathlib.Path), 55 | help='The directory of checkpoints', 56 | ) 57 | @click.option('--model_name', default='pointnet', help='The model name') 58 | @click.option('--batch_size', default=128, help='The pretrained clip model path') 59 | @click.option('--epochs', default=50, help='The pretrained clip model path') 60 | @click.option('--use-gpu/--no-use-gpu', default=True, help='If True to use gpu') 61 | @click.option( 62 | '--interactive', default=False, help='set to True if you have unlabeled data' 63 | ) 64 | def main( 65 | train_dataset, 66 | eval_dataset, 67 | model_name, 68 | embed_dim, 69 | batch_size, 70 | epochs, 71 | use_gpu, 72 | restore_from, 73 | checkpoint_dir, 74 | interactive, 75 | ): 76 | model = MeshDataModel(model_name=model_name, embed_dim=embed_dim) 77 | if restore_from: 78 | print(f'==> restore from: {restore_from}') 79 | ckpt = torch.load(checkpoint_dir, map_location='cpu') 80 | model.load_state_dict(ckpt) 81 | 82 | train_da = DocumentArray.load_binary(train_dataset) 83 | eval_da = DocumentArray.load_binary(eval_dataset) if eval_dataset else None 84 | 85 | def configure_optimizer(model): 86 | from torch.optim import Adam 87 | from torch.optim.lr_scheduler import MultiStepLR 88 | 89 | optimizer = Adam(model.parameters(), lr=5e-4) 90 | scheduler = MultiStepLR(optimizer, milestones=[30, 60], gamma=0.5) 91 | 92 | return optimizer, scheduler 93 | 94 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 95 | ckpt_callback = BestModelCheckpoint(str(checkpoint_dir)) 96 | 97 | tuned_model = finetuner.fit( 98 | model, 99 | train_da, 100 | eval_data=eval_da, 101 | preprocess_fn=partial(preprocess, num_points=1024, data_aug=True), 102 | epochs=epochs, 103 | batch_size=batch_size, 104 | loss=TripletLoss( 105 | miner=TripletEasyHardMiner(pos_strategy='easy', neg_strategy='semihard') 106 | ), 107 | configure_optimizer=configure_optimizer, 108 | num_items_per_class=8, 109 | learning_rate=5e-4, 110 | device='cuda' if use_gpu else 'cpu', 111 | callbacks=[ckpt_callback], 112 | interactive=interactive, 113 | ) 114 | 115 | torch.save( 116 | tuned_model.state_dict(), 117 | str(checkpoint_dir / f'finetuned-{model_name}-d{embed_dim}.pth'), 118 | ) 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /finetune_pl.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import time 3 | 4 | import click 5 | import torch 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning.loggers import TensorBoardLogger 9 | from torch.utils.data import DataLoader, random_split 10 | 11 | from datasets import ModelNet40 12 | from executor import MeshDataEncoderPL 13 | 14 | 15 | @click.command() 16 | @click.option('--train_dataset', help='The training dataset file path') 17 | @click.option( 18 | '--split_ratio', 19 | default=0.8, 20 | help='The proportion of training samples out of the whole training dataset', 21 | ) 22 | @click.option('--eval_dataset', help='The evaluation dataset file path') 23 | @click.option( 24 | '--embed_dim', default=512, help='The embedding dimension of the final outputs' 25 | ) 26 | @click.option('--hidden_dim', default=1024, help='The dimension of the used models') 27 | @click.option( 28 | '--checkpoint_path', 29 | type=click.Path(file_okay=True, path_type=pathlib.Path), 30 | help='The path of checkpoint', 31 | ) 32 | @click.option( 33 | '--output_path', 34 | type=click.Path(file_okay=True, path_type=pathlib.Path), 35 | help='The path of output files', 36 | ) 37 | @click.option( 38 | '--model_name', 39 | default='pointnet', 40 | type=click.Choice( 41 | ['pointnet', 'pointnet2', 'curvenet', 'pointmlp', 'pointconv', 'repsurf'] 42 | ), 43 | help='The model name', 44 | ) 45 | @click.option('--batch_size', default=128, help='The size of each batch') 46 | @click.option('--epochs', default=50, help='The epochs of training process') 47 | @click.option('--use-gpu/--no-use-gpu', default=True, help='If True to use gpu') 48 | @click.option( 49 | '--interactive', default=False, help='set to True if you have unlabeled data' 50 | ) 51 | @click.option( 52 | '--devices', default=7, help='The number of gpus/tpus you can use for training' 53 | ) 54 | @click.option('--seed', default=10, help='The random seed for reproducing results') 55 | def main( 56 | train_dataset, 57 | split_ratio, 58 | eval_dataset, 59 | model_name, 60 | embed_dim, 61 | hidden_dim, 62 | batch_size, 63 | epochs, 64 | use_gpu, 65 | checkpoint_path, 66 | output_path, 67 | interactive, 68 | devices, 69 | seed, 70 | ): 71 | seed = int(time.time()) 72 | 73 | torch.manual_seed(seed) 74 | torch.cuda.manual_seed_all(seed) 75 | 76 | if use_gpu: 77 | device = 'cuda' 78 | else: 79 | device = 'cpu' 80 | 81 | if checkpoint_path: 82 | model = MeshDataEncoderPL.load_from_checkpoint( 83 | checkpoint_path, map_location=device 84 | ) 85 | else: 86 | model = MeshDataEncoderPL( 87 | default_model_name=model_name, 88 | embed_dim=embed_dim, 89 | device=device, 90 | hidden_dim=hidden_dim, 91 | batch_size=batch_size, 92 | ) 93 | 94 | train_and_val_data = ModelNet40(train_dataset, seed=seed) 95 | tot_len = len(train_and_val_data) 96 | train_len = int(tot_len * split_ratio) 97 | validate_len = tot_len - train_len 98 | train_data, validate_data = random_split( 99 | train_and_val_data, [train_len, validate_len] 100 | ) 101 | test_data = ModelNet40(eval_dataset, seed=seed) 102 | 103 | # drop_last=True, avoid batch=1 error from BatchNorm 104 | train_loader = DataLoader( 105 | train_data, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True 106 | ) 107 | validate_loader = DataLoader( 108 | validate_data, 109 | batch_size=batch_size, 110 | shuffle=False, 111 | num_workers=8, 112 | drop_last=True, 113 | ) 114 | 115 | test_loader = DataLoader( 116 | test_data, batch_size=batch_size, shuffle=False, num_workers=8, drop_last=True 117 | ) 118 | 119 | logger = TensorBoardLogger( 120 | save_dir='./finetune_logs' if output_path is None else output_path, 121 | log_graph=True, 122 | name='{}_hidden_{}_embed_{}_batch_{}_epochs_{}_seed_{}'.format( 123 | model_name, hidden_dim, embed_dim, batch_size, epochs, seed 124 | ), 125 | ) 126 | 127 | checkpoint_callback = ModelCheckpoint( 128 | save_top_k=5, 129 | monitor='val_loss', 130 | mode='min', 131 | filename='{epoch:02d}-{val_loss:.2f}', 132 | ) 133 | 134 | trainer = Trainer( 135 | accelerator='gpu' if use_gpu else 'cpu', 136 | devices=devices, 137 | max_epochs=epochs, 138 | check_val_every_n_epoch=1, 139 | enable_checkpointing=True, 140 | logger=logger, 141 | callbacks=[checkpoint_callback], 142 | ) 143 | model.train() 144 | trainer.fit(model, train_loader, validate_loader) 145 | print(checkpoint_callback.best_model_path) 146 | 147 | model.eval() 148 | print('Validation set:') 149 | trainer.test(model, dataloaders=validate_loader) 150 | print('Testing set:') 151 | trainer.test(model, dataloaders=test_loader) 152 | 153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /gpu_requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf>=3.20.0 2 | pytest 3 | torch==1.12.0 4 | -------------------------------------------------------------------------------- /manifest.yml: -------------------------------------------------------------------------------- 1 | manifest_version: 1 2 | name: MeshDataEncoder 3 | description: An executor that loads 3D mesh models and embeds documents. 4 | url: https://github.com/jina-ai/executor-3d-encoder 5 | keywords: [encoder, 3D-Mesh, pytorch, 3.0-exclusive] 6 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import preprocess, random_sample 2 | -------------------------------------------------------------------------------- /preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def random_sample(pc, num): 5 | permutation = np.arange(len(pc)) 6 | np.random.shuffle(permutation) 7 | pc = np.array(pc).astype('float32') 8 | pc = pc[permutation[:num]] 9 | return pc 10 | 11 | 12 | def preprocess(points, num_points: int = 1024, data_aug: bool = True): 13 | points = random_sample(points, num_points) 14 | # points = np.transpose(points) 15 | 16 | points = points - np.expand_dims(np.mean(points, axis=0), 0) # center 17 | dist = np.max(np.sqrt(np.sum(points**2, axis=1)), 0) 18 | points = points / dist # scale 19 | 20 | if data_aug: 21 | theta = np.random.uniform(0, np.pi * 2) 22 | rotation_matrix = np.array( 23 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] 24 | ) 25 | points[:, [0, 2]] = points[:, [0, 2]].dot(rotation_matrix) # random rotation 26 | points += np.random.normal(0, 0.02, size=points.shape) # random jitter 27 | return points 28 | -------------------------------------------------------------------------------- /pretrain_pl.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import time 3 | 4 | import click 5 | import torch 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning.loggers import TensorBoardLogger 9 | from torch.utils.data import DataLoader, random_split 10 | 11 | from datasets import ModelNet40 12 | from executor import MeshDataClassifierPL 13 | 14 | 15 | @click.command() 16 | @click.option('--train_dataset', help='The training dataset file path') 17 | @click.option( 18 | '--split_ratio', 19 | default=0.8, 20 | help='The proportion of training samples out of the whole training dataset', 21 | ) 22 | @click.option('--eval_dataset', help='The evaluation dataset file path') 23 | @click.option('--hidden_dim', default=1024, help='The dimension of the used models') 24 | @click.option( 25 | '--checkpoint_path', 26 | type=click.Path(file_okay=True, path_type=pathlib.Path), 27 | help='The path of checkpoint', 28 | ) 29 | @click.option( 30 | '--output_path', 31 | type=click.Path(file_okay=True, path_type=pathlib.Path), 32 | help='The path of output files', 33 | ) 34 | @click.option('--model_name', default='pointnet', help='The model name') 35 | @click.option('--batch_size', default=128, help='The size of each batch') 36 | @click.option('--epochs', default=50, help='The epochs of training process') 37 | @click.option('--use-gpu/--no-use-gpu', default=False, help='If True to use gpu') 38 | @click.option( 39 | '--devices', default=7, help='The number of gpus/tpus you can use for training' 40 | ) 41 | @click.option('--seed', default=10, help='The random seed for reproducing results') 42 | def main( 43 | train_dataset, 44 | split_ratio, 45 | eval_dataset, 46 | model_name, 47 | hidden_dim, 48 | batch_size, 49 | epochs, 50 | use_gpu, 51 | checkpoint_path, 52 | output_path, 53 | devices, 54 | seed, 55 | ): 56 | seed = int(time.time()) 57 | 58 | torch.manual_seed(seed) 59 | torch.cuda.manual_seed_all(seed) 60 | 61 | if use_gpu: 62 | device = 'cuda' 63 | else: 64 | device = 'cpu' 65 | 66 | if checkpoint_path: 67 | model = MeshDataClassifierPL.load_from_checkpoint( 68 | checkpoint_path, map_location=device 69 | ) 70 | else: 71 | model = MeshDataClassifierPL( 72 | model_name=model_name, 73 | device=device, 74 | hidden_dim=hidden_dim, 75 | batch_size=batch_size, 76 | ) 77 | 78 | train_and_val_data = ModelNet40(train_dataset, seed=seed) 79 | tot_len = len(train_and_val_data) 80 | train_len = int(tot_len * split_ratio) 81 | validate_len = tot_len - train_len 82 | train_data, validate_data = random_split( 83 | train_and_val_data, [train_len, validate_len] 84 | ) 85 | test_data = ModelNet40(eval_dataset, seed=seed) 86 | 87 | # drop_last=True, avoid batch=1 error from BatchNorm 88 | train_loader = DataLoader( 89 | train_data, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True 90 | ) 91 | validate_loader = DataLoader( 92 | validate_data, 93 | batch_size=batch_size, 94 | shuffle=False, 95 | num_workers=8, 96 | drop_last=True, 97 | ) 98 | 99 | test_loader = DataLoader( 100 | test_data, batch_size=batch_size, shuffle=False, num_workers=8 101 | ) 102 | 103 | logger = TensorBoardLogger( 104 | save_dir='./logs' if output_path is None else output_path, 105 | log_graph=True, 106 | name='{}_dim_{}_batch_{}_epochs_{}_seed_{}'.format( 107 | model_name, hidden_dim, batch_size, epochs, seed 108 | ), 109 | ) 110 | 111 | checkpoint_callback = ModelCheckpoint( 112 | save_top_k=5, 113 | monitor='val_loss', 114 | mode='min', 115 | filename='{epoch:02d}-{val_loss:.2f}-{val_acc:.4f}', 116 | ) 117 | 118 | trainer = Trainer( 119 | accelerator='gpu' if use_gpu else 'cpu', 120 | devices=devices, 121 | max_epochs=epochs, 122 | check_val_every_n_epoch=1, 123 | enable_checkpointing=True, 124 | logger=logger, 125 | callbacks=[checkpoint_callback], 126 | gradient_clip_val=1.0, 127 | ) 128 | model.train() 129 | trainer.fit(model, train_loader, validate_loader) 130 | print(checkpoint_callback.best_model_path) 131 | 132 | model.eval() 133 | print('Validation set:') 134 | trainer.test(model, dataloaders=validate_loader) 135 | print('Testing set:') 136 | trainer.test(model, dataloaders=test_loader) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/torch_stable.html 2 | finetuner==0.4.1 3 | protobuf>=3.20.0 4 | pytest 5 | pytorch_lightning==1.6.5 6 | torch==1.12.0+cpu 7 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jina-ai/executor-3d-encoder/d9a39ddcbf6e2efd6f7b894f07289af68d0d9714/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class RandomDataset(Dataset): 7 | def __init__(self, n_samples, n_points=1024, n_classes=40) -> None: 8 | super().__init__() 9 | self.points = np.random.random((n_samples, n_points, 3)) 10 | self.labels = np.random.randint(n_classes, size=(n_samples)) 11 | 12 | def __len__(self): 13 | return len(self.labels) 14 | 15 | def __getitem__(self, index): 16 | return ( 17 | self.points[index, :, :], 18 | self.labels[index], 19 | ) 20 | 21 | 22 | def create_torch_dataset(n_samples, n_points, n_classes=40): 23 | return RandomDataset(n_samples, n_points, n_classes) 24 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jina-ai/executor-3d-encoder/d9a39ddcbf6e2efd6f7b894f07289af68d0d9714/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from jina import DocumentArray, Flow 4 | from pytorch_lightning import Trainer 5 | from torch.utils.data import DataLoader, random_split 6 | 7 | from executor import MeshDataEncoder, MeshDataEncoderPL 8 | from tests.conftest import create_torch_dataset 9 | 10 | 11 | @pytest.mark.parametrize( 12 | 'model_name, hidden_dim, embed_dim', 13 | [ 14 | ('pointconv', 1024, 1024), 15 | ('pointnet', 1024, 1024), 16 | ('pointnet2', 1024, 1024), 17 | ('pointmlp', 64, 32), 18 | ('repsurf', 1024, 1024), 19 | ('curvenet', 1024, 1024), 20 | ], 21 | ) 22 | def test_integration(model_name: str, hidden_dim: int, embed_dim: int): 23 | docs = DocumentArray.empty(5) 24 | docs.tensors = np.random.random((5, 1024, 3)) 25 | with Flow(return_results=True).add( 26 | uses=MeshDataEncoder, 27 | uses_with={ 28 | 'pretrained_model': None, 29 | 'default_model_name': model_name, 30 | 'hidden_dim': hidden_dim, 31 | 'embed_dim': embed_dim, 32 | }, 33 | ) as flow: 34 | resp = flow.post( 35 | on='/encoder', 36 | inputs=docs, 37 | return_results=True, 38 | ) 39 | 40 | for doc in resp: 41 | assert doc.embedding is not None 42 | assert doc.embedding.shape == (embed_dim,) 43 | 44 | 45 | @pytest.fixture(name='create_torch_dataset') 46 | @pytest.mark.parametrize( 47 | 'model_name, hidden_dim, embed_dim', 48 | [ 49 | ('pointconv', 1024, 1024), 50 | ('pointnet', 1024, 1024), 51 | ('pointnet2', 1024, 1024), 52 | ('pointmlp', 64, 32), 53 | ('repsurf', 1024, 1024), 54 | ('curvenet', 1024, 1024), 55 | ], 56 | ) 57 | def test_integration_pytorch_lightning( 58 | model_name: str, hidden_dim: int, embed_dim: int, create_torch_dataset 59 | ): 60 | encoder = MeshDataEncoderPL( 61 | default_model_name=model_name, hidden_dim=hidden_dim, embed_dim=embed_dim 62 | ) 63 | 64 | train_and_val_data = create_torch_dataset(200, 1024) 65 | test_data = create_torch_dataset(100, 1024) 66 | 67 | train_data, validate_data = random_split(train_and_val_data, [4, 1]) 68 | 69 | train_loader = DataLoader(train_data, batch_size=2, shuffle=True) 70 | validate_loader = DataLoader(validate_data, batch_size=1, shuffle=True) 71 | test_loader = DataLoader(test_data, batch_size=2, shuffle=True) 72 | 73 | trainer = Trainer( 74 | accelerator='cpu', 75 | max_epochs=5, 76 | check_val_every_n_epoch=1, 77 | enable_checkpointing=True, 78 | ) 79 | 80 | encoder.train() 81 | trainer.fit(encoder, train_loader, validate_loader) 82 | 83 | encoder.eval() 84 | trainer.test(encoder, dataloaders=test_loader) 85 | 86 | data = np.random.random((5, 1024, 3)) 87 | embedding = encoder.forward(data) 88 | 89 | assert embedding is not None 90 | assert embedding.shape == ( 91 | 5, 92 | embed_dim, 93 | ) 94 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jina-ai/executor-3d-encoder/d9a39ddcbf6e2efd6f7b894f07289af68d0d9714/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_exec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from jina import Document, DocumentArray 4 | 5 | from executor import MeshDataEncoder 6 | 7 | 8 | @pytest.mark.parametrize( 9 | 'model_name, hidden_dim, embed_dim', 10 | [ 11 | ('pointconv', 1024, 1024), 12 | ('pointnet', 1024, 1024), 13 | ('pointnet2', 1024, 1024), 14 | ('pointmlp', 64, 32), 15 | ('repsurf', 1024, 1024), 16 | ('curvenet', 1024, 1024), 17 | ], 18 | ) 19 | def test_encoder(model_name, hidden_dim, embed_dim): 20 | encoder = MeshDataEncoder( 21 | pretrained_model=None, 22 | default_model_name=model_name, 23 | hidden_dim=hidden_dim, 24 | embed_dim=embed_dim, 25 | ) 26 | 27 | docs = DocumentArray(Document(tensor=np.random.random((1024, 3)))) 28 | 29 | encoder.encode(docs) 30 | 31 | assert docs[0].embedding is not None 32 | assert docs[0].embedding.shape == (embed_dim,) 33 | 34 | 35 | def test_filter(): 36 | encoder = MeshDataEncoder( 37 | pretrained_model=None, 38 | default_model_name='pointconv', 39 | filters={'embedding': {'$exists': False}}, 40 | ) 41 | 42 | docs = DocumentArray(Document(tensor=np.random.random((1024, 3)))) 43 | 44 | embedding = np.random.random((512,)) 45 | docs.append(Document(tensor=np.random.random((1024, 3)), embedding=embedding)) 46 | 47 | encoder.encode(docs) 48 | 49 | assert docs[0].embedding.shape == (1024,) 50 | 51 | assert docs[1].embedding is not None 52 | assert docs[1].embedding.shape == (512,) 53 | --------------------------------------------------------------------------------