├── images
│   └── teaser.png
├── code
│   ├── utils.py
│   ├── configs
│   │   ├── pix3d.yaml
│   │   ├── compcars.yaml
│   │   └── stanfordcars.yaml
│   ├── datasets
│   │   ├── shape_datasets.py
│   │   └── query_datasets.py
│   ├── Models.py
│   ├── binvox_rw.py
│   ├── ColorTransfer.py
│   ├── RetrievalNet.py
│   └── RetrievalNet_test.py
├── LICENSE
└── README.md
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IGLICT/IBSR_jittor/HEAD/images/teaser.png
--------------------------------------------------------------------------------
/code/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | def read_json(mdir):
4 |     with open(mdir, 'r') as f:
5 |         tmp = json.loads(f.read())
6 |     return tmp
7 | 
--------------------------------------------------------------------------------
/code/configs/pix3d.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   name: pix3d
3 |   num_workers: 8
4 |   batch_size: 30
5 |   root_dir: ./data
6 |   mask_dir: mask_ocrnet
7 |   images_dir: img
8 |   render_path: rendering_pix3d.pkl
9 |   pix_size: 224
10 |   view_num: 12
11 |   training_json: pix3d_train.json
12 |   test_json: pix3d_test.json
13 |   tau: 0.1
14 | 
15 | 
16 | trainer:
17 |   epochs: 400
18 |   seed: 3104
19 | 
20 | 
21 | models:
22 |   z_dim: 128
23 |   pre_trained_path: './pre_trained/pix3d.pt'
24 |   pre_train_resnet_root: './pretrained_resnet'
25 |   save_root: './'
26 | 
27 | 
28 | setting:
29 |   is_training: True
30 |   is_aug: True # for data augmentation, flip, random crop...
31 |   is_color: True # for data augmentation: color transfer
32 |   is_from_scratch: True
--------------------------------------------------------------------------------
/code/configs/compcars.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   name: compcars
3 |   num_workers: 8
4 |   batch_size: 30
5 |   root_dir: ./data
6 |   mask_dir: mask
7 |   images_dir: img
8 |   render_path: rendering_compcars.pkl
9 |   pix_size: 224
10 |   view_num: 12
11 |   training_json: CompCars_train.json
12 |   test_json: CompCars_test.json
13 |   tau: 0.1
14 | 
15 | 
16 | trainer:
17 |   epochs: 400
18 |   seed: 3104
19 | 
20 | 
21 | models:
22 |   z_dim: 128
23 |   pre_trained_path: './pre_trained/compcars.pt'
24 |   pre_train_resnet_root: './pretrained_resnet'
25 |   save_root: './'
26 | 
27 | 
28 | setting:
29 |   is_training: True
30 |   is_aug: True # for data augmentation, flip, random crop...
31 |   is_color: True # for data augmentation: color transfer
32 |   is_from_scratch: True
--------------------------------------------------------------------------------
/code/configs/stanfordcars.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   name: stanfordcars
3 |   num_workers: 8
4 |   batch_size: 30
5 |   root_dir: ./data
6 |   mask_dir: mask
7 |   images_dir: img
8 |   render_path: rendering_stanfordcars.pkl
9 |   pix_size: 224
10 |   view_num: 12
11 |   training_json: StanfordCars_train.json
12 |   test_json: StanfordCars_test.json
13 |   tau: 0.1
14 | 
15 | 
16 | trainer:
17 |   epochs: 400
18 |   seed: 3104
19 | 
20 | 
21 | models:
22 |   z_dim: 128
23 |   pre_trained_path: './pre_trained/stanfordcars.pt'
24 |   pre_train_resnet_root: './pretrained_resnet'
25 |   save_root: './'
26 | 
27 | 
28 | setting:
29 |   is_training: True
30 |   is_aug: True # for data augmentation, flip, random crop...
31 | is_color: True # for data augmentation: color transfer 32 | is_from_scratch: True -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 IGLICT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code/datasets/shape_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jittor.transform as transform 3 | from jittor.dataset.dataset import Dataset 4 | from PIL import Image 5 | import pickle 6 | import jittor as jt 7 | if jt.has_cuda: 8 | jt.flags.use_cuda = 1 9 | 10 | class ShapeDataset(Dataset): 11 | def __init__(self, cfg): 12 | super(ShapeDataset, self).__init__() 13 | 14 | render_path = os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.render_path) 15 | # self.dicts = np.load(render_path, allow_pickle=True).item() 16 | with open(render_path, 'rb') as f: 17 | self.dicts = pickle.load(f) 18 | self.labels = self.make_dataset(self.dicts) 19 | 20 | self.transform = self.get_transform(cfg.data.pix_size, cfg.setting.is_aug) 21 | self.view_num = cfg.data.view_num 22 | 23 | 24 | def __getitem__(self, index): 25 | labels = self.labels[index] 26 | cat = labels['cat'] 27 | idx = labels['instance'] 28 | renderings = self.dicts[cat][idx] # 12x224x224 29 | # debug = self.transform(Image.fromarray(renderings[0])) 30 | render_img = jt.concat([self.transform(renderings[vi]) for vi in range(self.view_num)], dim=0) 31 | 32 | return {'rendering_img': render_img, 'labels': labels} 33 | 34 | def __len__(self): 35 | return len(self.labels) 36 | 37 | @staticmethod 38 | def make_dataset(dicts): 39 | labels = [] 40 | for cat in dicts.keys(): 41 | for idx in dicts[cat].keys(): 42 | labels.append({'cat': cat, 'instance': idx}) 43 | return labels 44 | 45 | 46 | @staticmethod 47 | def get_transform(rsize=224, is_aug=False, method=Image.BICUBIC): 48 | transform_list = [] 49 | # transform_list.append(transforms.Resize(rsize, method)) 50 | if is_aug: 51 | transform_list.append(transform.RandomResizedCrop(rsize, scale=(0.65, 0.9))) 52 | transform_list.append(transform.RandomHorizontalFlip()) 53 | 54 | 55 | transform_list += [transform.ToTensor()] 56 | transform_list += [transform.ImageNormalize((0.5, ), (0.5, ))] 57 | return transform.Compose(transform_list) 58 | 59 | 60 | 
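# Note on the output: each item from this dataset carries 'rendering_img' of
# shape (view_num, pix_size, pix_size) -- e.g. 12 x 224 x 224 with the provided
# configs -- built by concatenating the view_num single-channel view renderings
# along dim 0, each normalized to roughly [-1, 1]; the accompanying 'labels'
# dict ({'cat', 'instance'}) is what the training loop uses to build instance-
# and category-level contrastive targets.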
61 | if __name__ =='__main__':
62 |     import yaml
63 |     import argparse
64 |     with open('./configs/pix3d.yaml', 'r') as f:
65 |         config = yaml.safe_load(f)
66 |     def dict2namespace(config):
67 |         namespace = argparse.Namespace()
68 |         for key, value in config.items():
69 |             if isinstance(value, dict):
70 |                 new_value = dict2namespace(value)
71 |             else:
72 |                 new_value = value
73 |             setattr(namespace, key, new_value)
74 |         return namespace
75 |     config = dict2namespace(config)
76 | 
77 |     dataset = ShapeDataset(cfg=config)
78 |     retrieval_loader = ShapeDataset(cfg=config).set_attrs(batch_size=5, shuffle=True, num_workers=2, drop_last=True)
79 | 
80 |     # for batch in retrieval_loader:
81 |     #     debug = 10
82 |     #     debug = 20
83 |     #     continue
84 |     # debug = 10
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CMIC-Retrieval
2 | Code for **Single Image 3D Shape Retrieval via Cross-Modal Instance and Category Contrastive Learning**. **ICCV 2021.**
3 | 
4 | ![Overview](/images/teaser.png)
5 | 
6 | 
7 | 
8 | ## Introduction
9 | 
10 | In this work, we tackle the problem of single image-based 3D shape retrieval (IBSR), where we seek to find the most matched shape of a given single 2D image from a shape repository. Most of the existing works learn to embed 2D images and 3D shapes into a common feature space and perform metric learning using a triplet loss. Inspired by the great success of recent contrastive learning works on self-supervised representation learning, we propose a novel IBSR pipeline leveraging contrastive learning. We note that adopting such cross-modal contrastive learning between 2D images and 3D shapes into IBSR tasks is non-trivial and challenging: contrastive learning requires very strong data augmentation in constructed positive pairs to learn the feature invariance, whereas traditional metric learning works do not have this requirement. Moreover, object shape and appearance are entangled in 2D query images, thus making the learning task more difficult than contrasting single-modal data. To mitigate the challenges, we propose to use multi-view grayscale rendered images from the 3D shapes as a shape representation. We then introduce a strong data augmentation technique based on color transfer, which can significantly but naturally change the appearance of the query image, effectively satisfying the need for contrastive learning. Finally, we propose to incorporate a novel category-level contrastive loss that helps distinguish similar objects from different categories, in addition to the classic instance-level contrastive loss. Our experiments demonstrate that our approach achieves the best performance on all three popular IBSR benchmarks, including Pix3D, Stanford Cars, and CompCars, outperforming the previous state-of-the-art by 4%-15% in retrieval accuracy.
11 | 
12 | 
13 | 
14 | ## About this repository
15 | 
16 | This repository provides **data**, **pre-trained models** and **code**.
17 | 
18 | 
19 | 
20 | ## Installation
21 | ```zsh
22 | # create anaconda environment
23 | ## please make sure that python version >= 3.7 (required by jittor)
24 | conda create -n ibsr_jittor python=3.7
25 | conda activate ibsr_jittor
26 | 
27 | # jittor installation
28 | python3.7 -m pip install jittor
29 | ## testing jittor
30 | ### if errors appear, you can follow the instructions of jittor to fix them.
31 | python3.7 -m jittor.test.test_example
32 | # testing for cudnn
33 | python3.7 -m jittor.test.test_cudnn_op
34 | 
35 | # other packages
36 | pip install pyyaml
37 | pip install scikit-learn
38 | pip install matplotlib
39 | pip install scikit-image
40 | pip install argparse
41 | ```
42 | 
43 | 
44 | 
45 | ## How to use
46 | 
47 | ```zsh
48 | # download the pre-trained models, data and official ResNet pre-trained models from this link:
49 | https://1drv.ms/u/s!Ams-YJGtFnP7mTQOACYHco1s2gXE?e=c87UnV
50 | 
51 | # put the unzipped folders pre_trained, pretrained_resnet and data under IBSR_jittor/code
52 | cd IBSR_jittor/code
53 | 
54 | # all code was tested on a single Nvidia RTX 3090 under Ubuntu 20.04
55 | # training
56 | python RetrievalNet.py --config ./configs/pix3d.yaml
57 | 
58 | # testing
59 | python RetrievalNet_test.py --config ./configs/pix3d.yaml --mode simple
60 | # for the full test
61 | python RetrievalNet_test.py --config ./configs/pix3d.yaml --mode full
62 | # for the shapenet test
63 | python RetrievalNet_test.py --config ./configs/pix3d.yaml --mode shapenet
64 | 
65 | # please note:
66 | # model_std_bin128 and model_std_ptc10k_npy are not uploaded.
67 | # For model_std_ptc10k_npy, we randomly sample 10k points from the mesh with the Python igl package.
68 | # For model_std_bin128, please refer to https://www.patrickmin.com/viewvox/ for more information.
69 | ```
70 | 
71 | 
72 | 
73 | 
74 | ## Citations
75 | ```
76 | @inProceedings{lin2021cmic,
77 |     title={Single Image 3D Shape Retrieval via Cross-Modal Instance and Category Contrastive Learning},
78 |     author={Lin, Ming-Xian and Yang, Jie and Wang, He and Lai, Yu-Kun and Jia, Rongfei and Zhao, Binqiang and Gao, Lin},
79 |     year={2021},
80 |     booktitle={International Conference on Computer Vision (ICCV)}
81 | }
82 | ```
83 | 
84 | 
85 | 
86 | ## Updates
87 | - [Apr 1, 2022] Pre-trained Models, Data and revised Code released.
88 | - [Oct 1, 2021] Preliminary version of Data and Code released. More code and data are coming soon; please follow our updates.
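
For readers skimming the code, the core training objective is an InfoNCE-style softmax over a `bs x bs` shape-image similarity matrix (see `loss_inst` in `code/RetrievalNet.py`). Below is a minimal NumPy sketch of the instance-level term only; `instance_nce_loss` and its argument names are illustrative, not part of this repository's API:

```python
import numpy as np

def instance_nce_loss(sim, inst_mat, tau=0.1):
    """sim[i, j]: similarity between shape i and image j in a batch;
    inst_mat[i, j] = 1 iff shape i is the ground-truth match of image j."""
    prod = np.exp(sim / tau)                  # temperature-scaled scores
    pos = (prod * inst_mat).sum(axis=0)       # positive-pair score per image
    return float(np.mean(-np.log(pos / prod.sum(axis=0))))

bs = 4
sim = np.random.uniform(-1.0, 1.0, (bs, bs))  # cosine similarities in [-1, 1]
print(instance_nce_loss(sim, np.eye(bs)))     # identity mask: image i matches shape i
```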
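
Likewise, each shape's `view_num` rendering embeddings are collapsed into one embedding per query by attention (the `Attention` module in `code/Models.py`). A dot-product-only sketch, assuming L2-normalized embeddings; the actual module additionally supports a learned 'general' score and a tanh output projection:

```python
import numpy as np

def attend_views(query, views):
    """query: (d,) image embedding; views: (V, d) per-view shape embeddings.
    Returns the attention-weighted mixture of the V view embeddings."""
    scores = views @ query                    # (V,) view-query similarities
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                  # softmax over the V views
    return weights @ views                    # (d,) pooled shape embedding

rng = np.random.default_rng(0)
query = rng.standard_normal(128)
views = rng.standard_normal((12, 128))        # 12 views, as in the configs
print(attend_views(query, views).shape)       # (128,)
```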
89 | 90 | -------------------------------------------------------------------------------- /code/Models.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn, models 3 | if jt.has_cuda: 4 | jt.flags.use_cuda = 1 # jt.flags.use_cuda 5 | 6 | class QueryEncoder(nn.Module): 7 | def __init__(self, out_dim=128): 8 | super(QueryEncoder, self).__init__() 9 | self.dim = out_dim 10 | self.resnet = models.resnet50(pretrained=False) 11 | self.resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False) 12 | fc_features = self.resnet.fc.in_features 13 | self.resnet.fc = nn.Sequential( 14 | nn.BatchNorm1d(fc_features*1), 15 | nn.Linear(fc_features*1, self.dim), 16 | ) 17 | 18 | def execute(self, input): 19 | embeddings = self.resnet(input) 20 | embeddings = jt.normalize(embeddings, p=2, dim=1) 21 | return embeddings 22 | 23 | 24 | class RenderingEncoder(nn.Module): 25 | def __init__(self, out_dim=128): 26 | super(RenderingEncoder, self).__init__() 27 | self.dim = out_dim 28 | self.resnet = models.resnet18(pretrained=False) 29 | self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 30 | fc_features = self.resnet.fc.in_features 31 | self.resnet.fc = nn.Sequential( 32 | nn.BatchNorm1d(fc_features*1), 33 | nn.Linear(fc_features*1, self.dim), 34 | ) 35 | 36 | def execute(self, inputs): 37 | embeddings = self.resnet(inputs) 38 | embeddings = jt.normalize(embeddings, p=2, dim=1) 39 | return embeddings 40 | 41 | 42 | class Attention(nn.Module): 43 | ''' 44 | Revised from pytorch version: 45 | ''' 46 | 47 | """ Applies attention mechanism on the `context` using the `query`. 48 | 49 | **Thank you** to IBM for their initial implementation of :class:`Attention`. Here is 50 | their `License 51 | `__. 52 | 53 | Args: 54 | dimensions (int): Dimensionality of the query and context. 55 | attention_type (str, optional): How to compute the attention score: 56 | 57 | * dot: :math:`score(H_j,q) = H_j^T q` 58 | * general: :math:`score(H_j, q) = H_j^T W_a q` 59 | 60 | Example: 61 | 62 | >>> attention = Attention(256) 63 | >>> query = torch.randn(5, 1, 256) 64 | >>> context = torch.randn(5, 5, 256) 65 | >>> output, weights = attention(query, context) 66 | >>> output.size() 67 | torch.Size([5, 1, 256]) 68 | >>> weights.size() 69 | torch.Size([5, 1, 5]) 70 | """ 71 | 72 | def __init__(self, dimensions, attention_type='general'): 73 | super(Attention, self).__init__() 74 | 75 | if attention_type not in ['dot', 'general']: 76 | raise ValueError('Invalid attention type selected.') 77 | 78 | self.attention_type = attention_type 79 | if self.attention_type == 'general': 80 | self.linear_in = nn.Linear(dimensions, dimensions, bias=False) 81 | 82 | self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False) 83 | self.softmax = nn.Softmax(dim=-1) 84 | self.tanh = nn.Tanh() 85 | 86 | def execute(self, query, context): 87 | """ 88 | Args: 89 | query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of 90 | queries to query the context. 91 | context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data 92 | overwhich to apply the attention mechanism. 93 | 94 | Returns: 95 | :class:`tuple` with `output` and `weights`: 96 | * **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]): 97 | Tensor containing the attended features. 
98 | * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]): 99 | Tensor containing attention weights. 100 | """ 101 | batch_size, output_len, dimensions = query.size() 102 | query_len = context.size(1) 103 | 104 | if self.attention_type == "general": 105 | query = query.view(batch_size * output_len, dimensions) 106 | query = self.linear_in(query) 107 | query = query.view(batch_size, output_len, dimensions) 108 | 109 | # TODO: Include mask on PADDING_INDEX? 110 | 111 | # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) -> 112 | # (batch_size, output_len, query_len) 113 | # attention_scores = nn.bmm(query, context.transpose(1, 2).contiguous()) 114 | attention_scores = nn.bmm(query, context.transpose(0, 2, 1)) 115 | 116 | # Compute weights across every context sequence 117 | attention_scores = attention_scores.view(batch_size * output_len, query_len) 118 | attention_weights = self.softmax(attention_scores) 119 | attention_weights = attention_weights.view(batch_size, output_len, query_len) 120 | 121 | # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) -> 122 | # (batch_size, output_len, dimensions) 123 | mix = nn.bmm(attention_weights, context) 124 | 125 | # concat -> (batch_size * output_len, 2*dimensions) 126 | combined = jt.concat((mix, query), dim=2) 127 | combined = combined.view(batch_size * output_len, 2 * dimensions) 128 | 129 | # Apply linear_out on every 2nd dimension of concat 130 | # output -> (batch_size, output_len, dimensions) 131 | output = self.linear_out(combined).view(batch_size, output_len, dimensions) 132 | output = self.tanh(output) 133 | 134 | return output, attention_weights 135 | 136 | 137 | class RetrievalNet(nn.Module): 138 | ''' 139 | QueryEncoder 140 | RenderingEncoder 141 | Attention 142 | ''' 143 | def __init__(self, cfg): 144 | super(RetrievalNet, self).__init__() 145 | self.dim = cfg.models.z_dim 146 | self.size = cfg.data.pix_size 147 | self.view_num = cfg.data.view_num 148 | self.query_encoder = QueryEncoder(self.dim) 149 | self.rendering_encoder = RenderingEncoder(self.dim) 150 | self.attention = Attention(self.dim) 151 | 152 | 153 | def execute(self, query, rendering): 154 | query_ebd = self.get_query_ebd(query) 155 | bs = query_ebd.shape[0] 156 | rendering = rendering.view(-1, 1, self.size, self.size) 157 | rendering_ebds = self.get_rendering_ebd(rendering).view(-1, self.view_num, self.dim) 158 | 159 | #(shape, image, ebd) -> (bs, bs, 128) 160 | query_ebd = query_ebd.unsqueeze(0).repeat(bs, 1, 1) 161 | # query_ebd: bs, bs, dim 162 | # rendering_ebds: bs, 12, dim 163 | _, weights = self.attention_query(query_ebd, rendering_ebds) 164 | 165 | # weights: bxxbsx12 166 | # rendering_ebds: bsx12x128 167 | # queried_rendering_ebd: bsxbsx128 (shape, model, 128) 168 | # reference to https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/attention.html#Attentionl 169 | queried_rendering_ebd = nn.bmm(weights, rendering_ebds) 170 | return query_ebd, queried_rendering_ebd 171 | 172 | def get_query_ebd(self, inputs): 173 | return self.query_encoder(inputs) 174 | 175 | def get_rendering_ebd(self, inputs): 176 | return self.rendering_encoder(inputs) 177 | 178 | def attention_query(self, ebd, pool_ebd): 179 | return self.attention(ebd, pool_ebd) 180 | 181 | 182 | 183 | if __name__ == '__main__': 184 | import yaml 185 | import argparse 186 | 187 | with open('./configs/pix3d.yaml', 'r') as f: 188 | config = yaml.load(f) 189 | def dict2namespace(config): 190 | namespace = 
argparse.Namespace() 191 | for key, value in config.items(): 192 | if isinstance(value, dict): 193 | new_value = dict2namespace(value) 194 | else: 195 | new_value = value 196 | setattr(namespace, key, new_value) 197 | return namespace 198 | config = dict2namespace(config) 199 | 200 | 201 | 202 | 203 | models = RetrievalNet(config) 204 | img = jt.random([2,4,224,224]).stop_grad() 205 | mask = jt.random([2,12,224,224]).stop_grad() 206 | 207 | # mm = models.resnet50(pretrained=False) 208 | # # print(mm) 209 | # a = mm(img) 210 | 211 | outputs = models(img, mask) -------------------------------------------------------------------------------- /code/datasets/query_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jittor.misc import set_global_seed 3 | import jittor.transform as transform 4 | from jittor.dataset.dataset import Dataset 5 | from PIL import Image 6 | import pickle 7 | import json 8 | import time 9 | import jittor as jt 10 | if jt.has_cuda: 11 | jt.flags.use_cuda = 1 12 | 13 | 14 | class QueryDataset(Dataset): 15 | def __init__(self, cfg): 16 | super(QueryDataset, self).__init__() 17 | self.is_training = cfg.setting.is_training 18 | if cfg.setting.is_training: 19 | self.json_path = os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.training_json) 20 | else: 21 | self.json_path = os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.test_json) 22 | 23 | self.json_dict = self.read_json(self.json_path) 24 | self.data_dir = os.path.join(cfg.data.root_dir, cfg.data.name) 25 | 26 | crop_scale = (0.85, 0.95) 27 | self.aug = cfg.setting.is_aug 28 | self.query_transform = self.get_query_transform(cfg.data.pix_size, crop_scale, self.aug) 29 | self.mask_transform = self.get_mask_transform(cfg.data.pix_size, crop_scale, self.aug) 30 | self.rendering_transform = self.get_rendering_transform(cfg.data.pix_size, self.aug) 31 | self.view_num = cfg.data.view_num 32 | self.mask_dir = cfg.data.mask_dir 33 | 34 | render_path = os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.render_path) 35 | with open(render_path, 'rb') as f: 36 | self.dicts = pickle.load(f) 37 | 38 | def __getitem__(self, index): 39 | info = self.json_dict[index] 40 | cat = info['category'] 41 | instance = info['model'].split('/')[-2] 42 | renderings = self.dicts[cat][instance] 43 | 44 | tmp_seed = int(time.time()) % 100000 45 | jt.set_global_seed(tmp_seed) 46 | query_img = self.query_transform(Image.open(os.path.join(self.data_dir, info['img'])).convert("RGB")) 47 | 48 | jt.set_global_seed(tmp_seed) 49 | mask_path_list = info['mask'].split('/') 50 | # if self.is_training: 51 | # mask_path_list[0] = 'mask' 52 | # else: 53 | mask_path_list[0] = self.mask_dir 54 | mask_path = '/'.join(mask_path_list) 55 | mask_img = self.mask_transform(Image.open(os.path.join(self.data_dir, mask_path))) 56 | 57 | render_img = jt.concat([self.rendering_transform(renderings[vi]) for vi in range(self.view_num)], dim=0) 58 | 59 | # debug = 10 60 | # embeddings_name = info['category'] + '-' + info['model'].split('/')[-2]+'.npy' 61 | # embeddings = np.load(os.path.join(self.embeddings_dir, embeddings_name)) 62 | return {'query_img': query_img, 'mask_img':mask_img, \ 63 | 'rendering_img':render_img, 'cat':cat, 'instance':instance } 64 | 65 | def __len__(self): 66 | return len(self.json_dict) 67 | 68 | @staticmethod 69 | def get_query_transform(rsize=(224, 224), crop_scale=(0.85, 0.95), is_aug=False): 70 | transform_list = [] 71 | if is_aug: 72 | 
transform_list.append(transform.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=None, shear=None, resample=False, fillcolor=0)) 73 | transform_list.append(transform.RandomResizedCrop(rsize, scale=crop_scale)) 74 | transform_list.append(transform.RandomHorizontalFlip()) 75 | 76 | transform_list += [transform.ToTensor()] 77 | # if not is_training: 78 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 79 | # we have add this 'Normalize' in train_retrieval.py 80 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 81 | return transform.Compose(transform_list) 82 | 83 | @staticmethod 84 | def get_mask_transform(rsize=(224, 224), crop_scale=(0.85, 0.95), is_aug=False): 85 | transform_list = [] 86 | if is_aug: 87 | transform_list.append(transform.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=None, shear=None, resample=False, fillcolor=0)) 88 | transform_list.append(transform.RandomResizedCrop(rsize, scale=crop_scale)) 89 | transform_list.append(transform.RandomHorizontalFlip()) 90 | 91 | transform_list += [transform.ToTensor()] 92 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 93 | return transform.Compose(transform_list) 94 | 95 | @staticmethod 96 | def get_rendering_transform(rsize=224, is_aug=False): 97 | transform_list = [] 98 | # transform_list.append(transforms.Resize(rsize, method)) 99 | if is_aug: 100 | transform_list.append(transform.RandomResizedCrop(rsize, scale=(0.65, 0.9))) 101 | transform_list.append(transform.RandomHorizontalFlip()) 102 | 103 | transform_list += [transform.ToTensor()] 104 | transform_list += [transform.ImageNormalize((0.5, ), (0.5, ))] 105 | return transform.Compose(transform_list) 106 | 107 | @staticmethod 108 | def read_json(mdir): 109 | with open(mdir, 'r') as f: 110 | tmp = json.loads(f.read()) 111 | return tmp 112 | 113 | 114 | class ImageDataset(Dataset): 115 | def __init__(self, cfg): 116 | super(ImageDataset, self).__init__() 117 | self.json_path = os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.test_json) 118 | self.json_dict = self.read_json(self.json_path) 119 | self.data_dir = os.path.join(cfg.data.root_dir, cfg.data.name) 120 | 121 | self.query_transform = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 122 | self.mask_transform = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, ), (0.5, ))]) 123 | 124 | 125 | def __getitem__(self, index): 126 | info = self.json_dict[index] 127 | cat = info['category'] 128 | instance = info['model'].split('/')[-2] 129 | renderings = self.dicts[cat][instance] 130 | 131 | query_img = self.query_transform(Image.open(os.path.join(self.data_dir, info['img']))) 132 | mask_img = self.mask_transform(Image.open(os.path.join(self.data_dir, info['mask']))) 133 | 134 | return {'query_img': query_img, 'mask_img':mask_img, \ 135 | 'cat':cat, 'instance':instance } 136 | 137 | def __len__(self): 138 | return len(self.json_dict) 139 | 140 | @staticmethod 141 | def read_json(mdir): 142 | with open(mdir, 'r') as f: 143 | tmp = json.loads(f.read()) 144 | return tmp 145 | 146 | 147 | if __name__ =='__main__': 148 | import yaml 149 | import argparse 150 | with open('./configs/pix3d.yaml', 'r') as f: 151 | config = yaml.load(f) 152 | def dict2namespace(config): 153 | namespace = argparse.Namespace() 154 | for key, value in config.items(): 155 | if isinstance(value, dict): 156 | new_value = dict2namespace(value) 157 | else: 158 | new_value = value 159 | 
setattr(namespace, key, new_value) 160 | return namespace 161 | config = dict2namespace(config) 162 | 163 | # dataset = QueryDataset(cfg=config) 164 | # retrieval_loader = torch.utils.data.DataLoader(dataset=dataset, \ 165 | # batch_size=1, shuffle=True, \ 166 | # drop_last=True, num_workers=1) 167 | retrieval_loader = QueryDataset(cfg=config).set_attrs(batch_size=1, shuffle=True, num_workers=1, drop_last=True) 168 | 169 | for meta in retrieval_loader: 170 | 171 | ##### dataset debug ###### 172 | with jt.no_grad(): 173 | mask_img = meta['mask_img'] 174 | embeddings = meta['rendering_img'] 175 | cats = meta['cat'] 176 | instances = meta['instance'] 177 | query_img = meta['query_img'] 178 | 179 | 180 | topil = transform.ToPILImage() 181 | masked_img = query_img*mask_img 182 | 183 | 184 | q_img = topil(jt.transpose(query_img[0], [1,2,0])) 185 | # s_img = topil(style_img[0]) 186 | # tf_img = topil(transfer_img[0]) 187 | 188 | mask_img_ = mask_img[0] 189 | mask_img_[mask_img_>0] = 255 190 | m_img = topil(jt.transpose(mask_img_, [1,2,0])).convert('L') 191 | md_img = topil(jt.transpose(masked_img[0], [1,2,0])) 192 | # md_tf_img = topil(masked_tf_img[0]) 193 | 194 | q_img.save('./debug/q_img.png') 195 | # s_img.save('./debug/s_img.png') 196 | # tf_img.save('./debug/tf_img.png') 197 | 198 | m_img.save('./debug/m_img.png') 199 | md_img.save('./debug/md_img.png') 200 | # md_tf_img.save('./debug/md_tf_img.png') 201 | ##### dataset debug ###### 202 | 203 | debug = 10 204 | debug = 20 205 | continue -------------------------------------------------------------------------------- /code/binvox_rw.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2012 Daniel Maturana 2 | # This file is part of binvox-rw-py. 3 | # 4 | # binvox-rw-py is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # binvox-rw-py is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with binvox-rw-py. If not, see . 16 | # 17 | 18 | """ 19 | Binvox to Numpy and back. 20 | 21 | 22 | >>> import numpy as np 23 | >>> import binvox_rw 24 | >>> with open('chair.binvox', 'rb') as f: 25 | ... m1 = binvox_rw.read_as_3d_array(f) 26 | ... 27 | >>> m1.dims 28 | [32, 32, 32] 29 | >>> m1.scale 30 | 41.133000000000003 31 | >>> m1.translate 32 | [0.0, 0.0, 0.0] 33 | >>> with open('chair_out.binvox', 'wb') as f: 34 | ... m1.write(f) 35 | ... 36 | >>> with open('chair_out.binvox', 'rb') as f: 37 | ... m2 = binvox_rw.read_as_3d_array(f) 38 | ... 39 | >>> m1.dims==m2.dims 40 | True 41 | >>> m1.scale==m2.scale 42 | True 43 | >>> m1.translate==m2.translate 44 | True 45 | >>> np.all(m1.data==m2.data) 46 | True 47 | 48 | >>> with open('chair.binvox', 'rb') as f: 49 | ... md = binvox_rw.read_as_3d_array(f) 50 | ... 51 | >>> with open('chair.binvox', 'rb') as f: 52 | ... ms = binvox_rw.read_as_coord_array(f) 53 | ... 
54 | >>> data_ds = binvox_rw.dense_to_sparse(md.data) 55 | >>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) 56 | >>> np.all(data_sd==md.data) 57 | True 58 | >>> # the ordering of elements returned by numpy.nonzero changes with axis 59 | >>> # ordering, so to compare for equality we first lexically sort the voxels. 60 | >>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) 61 | True 62 | """ 63 | 64 | import numpy as np 65 | 66 | class Voxels(object): 67 | """ Holds a binvox model. 68 | data is either a three-dimensional numpy boolean array (dense representation) 69 | or a two-dimensional numpy float array (coordinate representation). 70 | 71 | dims, translate and scale are the model metadata. 72 | 73 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. 74 | 75 | scale and translate relate the voxels to the original model coordinates. 76 | 77 | To translate voxel coordinates i, j, k to original coordinates x, y, z: 78 | 79 | x_n = (i+.5)/dims[0] 80 | y_n = (j+.5)/dims[1] 81 | z_n = (k+.5)/dims[2] 82 | x = scale*x_n + translate[0] 83 | y = scale*y_n + translate[1] 84 | z = scale*z_n + translate[2] 85 | 86 | """ 87 | 88 | def __init__(self, data, dims, translate, scale, axis_order): 89 | self.data = data 90 | self.dims = dims 91 | self.translate = translate 92 | self.scale = scale 93 | assert (axis_order in ('xzy', 'xyz')) 94 | self.axis_order = axis_order 95 | 96 | def clone(self): 97 | data = self.data.copy() 98 | dims = self.dims[:] 99 | translate = self.translate[:] 100 | return Voxels(data, dims, translate, self.scale, self.axis_order) 101 | 102 | def write(self, fp): 103 | write(self, fp) 104 | 105 | def read_header(fp): 106 | """ Read binvox header. Mostly meant for internal use. 107 | """ 108 | line = fp.readline().strip() 109 | if not line.startswith(b'#binvox'): 110 | raise IOError('Not a binvox file') 111 | dims = list(map(int, fp.readline().strip().split(b' ')[1:])) 112 | translate = list(map(float, fp.readline().strip().split(b' ')[1:])) 113 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] 114 | line = fp.readline() 115 | return dims, translate, scale 116 | 117 | def read_as_3d_array(fp, fix_coords=True): 118 | """ Read binary binvox format as array. 119 | 120 | Returns the model with accompanying metadata. 121 | 122 | Voxels are stored in a three-dimensional numpy array, which is simple and 123 | direct, but may use a lot of memory for large models. (Storage requirements 124 | are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy 125 | boolean arrays use a byte per element). 126 | 127 | Doesn't do any checks on input except for the '#binvox' line. 
128 | """ 129 | dims, translate, scale = read_header(fp) 130 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 131 | # if just using reshape() on the raw data: 132 | # indexing the array as array[i,j,k], the indices map into the 133 | # coords as: 134 | # i -> x 135 | # j -> z 136 | # k -> y 137 | # if fix_coords is true, then data is rearranged so that 138 | # mapping is 139 | # i -> x 140 | # j -> y 141 | # k -> z 142 | values, counts = raw_data[::2], raw_data[1::2] 143 | data = np.repeat(values, counts).astype(np.bool) 144 | data = data.reshape(dims) 145 | if fix_coords: 146 | # xzy to xyz TODO the right thing 147 | data = np.transpose(data, (0, 2, 1)) 148 | axis_order = 'xyz' 149 | else: 150 | axis_order = 'xzy' 151 | return Voxels(data, dims, translate, scale, axis_order) 152 | 153 | def read_as_coord_array(fp, fix_coords=True): 154 | """ Read binary binvox format as coordinates. 155 | 156 | Returns binvox model with voxels in a "coordinate" representation, i.e. an 157 | 3 x N array where N is the number of nonzero voxels. Each column 158 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates 159 | of the voxel. (The odd ordering is due to the way binvox format lays out 160 | data). Note that coordinates refer to the binvox voxels, without any 161 | scaling or translation. 162 | 163 | Use this to save memory if your model is very sparse (mostly empty). 164 | 165 | Doesn't do any checks on input except for the '#binvox' line. 166 | """ 167 | dims, translate, scale = read_header(fp) 168 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 169 | 170 | values, counts = raw_data[::2], raw_data[1::2] 171 | 172 | sz = np.prod(dims) 173 | index, end_index = 0, 0 174 | end_indices = np.cumsum(counts) 175 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) 176 | 177 | values = values.astype(np.bool) 178 | indices = indices[values] 179 | end_indices = end_indices[values] 180 | 181 | nz_voxels = [] 182 | for index, end_index in zip(indices, end_indices): 183 | nz_voxels.extend(range(index, end_index)) 184 | nz_voxels = np.array(nz_voxels) 185 | # TODO are these dims correct? 186 | # according to docs, 187 | # index = x * wxh + z * width + y; // wxh = width * height = d * d 188 | 189 | x = nz_voxels / (dims[0]*dims[1]) 190 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y 191 | z = zwpy / dims[0] 192 | y = zwpy % dims[0] 193 | if fix_coords: 194 | data = np.vstack((x, y, z)) 195 | axis_order = 'xyz' 196 | else: 197 | data = np.vstack((x, z, y)) 198 | axis_order = 'xzy' 199 | 200 | #return Voxels(data, dims, translate, scale, axis_order) 201 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) 202 | 203 | def dense_to_sparse(voxel_data, dtype=np.int): 204 | """ From dense representation to sparse (coordinate) representation. 205 | No coordinate reordering. 
206 | """ 207 | if voxel_data.ndim!=3: 208 | raise ValueError('voxel_data is wrong shape; should be 3D array.') 209 | return np.asarray(np.nonzero(voxel_data), dtype) 210 | 211 | def sparse_to_dense(voxel_data, dims, dtype=np.bool): 212 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: 213 | raise ValueError('voxel_data is wrong shape; should be 3xN array.') 214 | if np.isscalar(dims): 215 | dims = [dims]*3 216 | dims = np.atleast_2d(dims).T 217 | # truncate to integers 218 | xyz = voxel_data.astype(np.int) 219 | # discard voxels that fall outside dims 220 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) 221 | xyz = xyz[:,valid_ix] 222 | out = np.zeros(dims.flatten(), dtype=dtype) 223 | out[tuple(xyz)] = True 224 | return out 225 | 226 | #def get_linear_index(x, y, z, dims): 227 | #""" Assuming xzy order. (y increasing fastest. 228 | #TODO ensure this is right when dims are not all same 229 | #""" 230 | #return x*(dims[1]*dims[2]) + z*dims[1] + y 231 | 232 | def write(voxel_model, fp): 233 | """ Write binary binvox format. 234 | 235 | Note that when saving a model in sparse (coordinate) format, it is first 236 | converted to dense format. 237 | 238 | Doesn't check if the model is 'sane'. 239 | 240 | """ 241 | if voxel_model.data.ndim==2: 242 | # TODO avoid conversion to dense 243 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) 244 | else: 245 | dense_voxel_data = voxel_model.data 246 | 247 | fp.write('#binvox 1\n') 248 | fp.write('dim '+' '.join(map(str, voxel_model.dims))+'\n') 249 | fp.write('translate '+' '.join(map(str, voxel_model.translate))+'\n') 250 | fp.write('scale '+str(voxel_model.scale)+'\n') 251 | fp.write('data\n') 252 | if not voxel_model.axis_order in ('xzy', 'xyz'): 253 | raise ValueError('Unsupported voxel model axis order') 254 | 255 | if voxel_model.axis_order=='xzy': 256 | voxels_flat = dense_voxel_data.flatten() 257 | elif voxel_model.axis_order=='xyz': 258 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() 259 | 260 | # keep a sort of state machine for writing run length encoding 261 | state = voxels_flat[0] 262 | ctr = 0 263 | for c in voxels_flat: 264 | if c==state: 265 | ctr += 1 266 | # if ctr hits max, dump 267 | if ctr==255: 268 | fp.write(chr(state)) 269 | fp.write(chr(ctr)) 270 | ctr = 0 271 | else: 272 | # if switch state, dump 273 | fp.write(chr(state)) 274 | fp.write(chr(ctr)) 275 | state = c 276 | ctr = 1 277 | # flush out remainders 278 | if ctr > 0: 279 | fp.write(chr(state)) 280 | fp.write(chr(ctr)) 281 | 282 | if __name__ == '__main__': 283 | import doctest 284 | doctest.testmod() 285 | -------------------------------------------------------------------------------- /code/ColorTransfer.py: -------------------------------------------------------------------------------- 1 | # import the necessary packages 2 | import jittor as jt 3 | if jt.has_cuda: 4 | jt.flags.use_cuda = 1 # jt.flags.use_cuda 表示是否使用 gpu 训练。 5 | 6 | # Color conversion code 7 | def rgb2xyz(rgb): # rgb from [0,1] 8 | # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], 9 | # [0.212671, 0.715160, 0.072169], 10 | # [0.019334, 0.119193, 0.950227]]) 11 | 12 | # mask = (rgb > .04045).type(torch.float) 13 | mask = jt.float(rgb > .04045) 14 | 15 | rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask) 16 | 17 | x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:] 18 | y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:] 19 | z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:] 20 | out = 
jt.concat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1) 21 | 22 | # if(torch.sum(torch.isnan(out))>0): 23 | # print('rgb2xyz') 24 | # embed() 25 | return out 26 | 27 | def xyz2rgb(xyz): 28 | # array([[ 3.24048134, -1.53715152, -0.49853633], 29 | # [-0.96925495, 1.87599 , 0.04155593], 30 | # [ 0.05564664, -0.20404134, 1.05731107]]) 31 | 32 | r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:] 33 | g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:] 34 | b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:] 35 | 36 | rgb = jt.concat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1) 37 | rgb = jt.maximum(rgb,jt.zeros_like(rgb)) # sometimes reaches a small negative number, which causes NaNs 38 | 39 | # mask = (rgb > .0031308).type(torch.float) 40 | mask = jt.float(rgb > .0031308) 41 | # if(rgb.is_cuda): 42 | # mask = mask.cuda() 43 | # mask = mask.to(device) 44 | 45 | rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask) 46 | 47 | # if(torch.sum(torch.isnan(rgb))>0): 48 | # print('xyz2rgb') 49 | # embed() 50 | return rgb 51 | 52 | def xyz2lab(xyz): 53 | # 0.95047, 1., 1.08883 # white 54 | sc = jt.array((0.95047, 1., 1.08883))[None,:,None,None] 55 | # sc = jt.array((0.95047, 1., 1.08883)) 56 | # if(xyz.is_cuda): 57 | # sc = sc.cuda() 58 | # sc = sc.to(device) 59 | 60 | xyz_scale = xyz/sc 61 | 62 | # mask = (xyz_scale > .008856).type(torch.float) 63 | mask = jt.float(xyz_scale > .008856) 64 | 65 | # if(xyz_scale.is_cuda): 66 | # mask = mask.cuda() 67 | # mask = mask.to(device) 68 | 69 | xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask) 70 | 71 | L = 116.*xyz_int[:,1,:,:]-16. 72 | a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:]) 73 | b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:]) 74 | out = jt.concat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1) 75 | 76 | # if(torch.sum(torch.isnan(out))>0): 77 | # print('xyz2lab') 78 | # embed() 79 | 80 | return out 81 | 82 | def lab2xyz(lab): 83 | # device = lab.device 84 | y_int = (lab[:,0,:,:]+16.)/116. 85 | x_int = (lab[:,1,:,:]/500.) + y_int 86 | z_int = y_int - (lab[:,2,:,:]/200.) 
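    # x_int / y_int / z_int above are the intermediate f(X/Xn), f(Y/Yn), f(Z/Zn)
    # terms of the CIE L*a*b* definition; the masked cube-vs-linear branch below
    # inverts the f(t) nonlinearity, and the final multiplication by
    # (0.95047, 1., 1.08883) restores the D65 reference white point.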
87 | 
88 |     # z_int = torch.max(torch.Tensor((0,)).to(device), z_int)
89 |     z_int = jt.maximum(jt.Var((0,)), z_int)
90 | 
91 |     out = jt.concat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1)
92 |     # mask = (out > .2068966).type(torch.float)
93 |     mask = jt.float(out > .2068966)
94 |     # if(out.is_cuda):
95 |     #     mask = mask.cuda()
96 |     # mask = mask.to(device)
97 | 
98 | 
99 |     out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask)
100 | 
101 |     # sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
102 |     sc = jt.array((0.95047, 1., 1.08883))[None,:,None,None]
103 |     # sc = sc.to(out.device)
104 | 
105 |     out = out*sc
106 | 
107 |     # if(torch.sum(torch.isnan(out))>0):
108 |     #     print('lab2xyz')
109 |     #     embed()
110 | 
111 |     return out
112 | 
113 | def rgb2lab(rgb):
114 |     lab = xyz2lab(rgb2xyz(rgb))
115 |     l_rs = (lab[:,[0],:,:]-50)/100
116 |     ab_rs = lab[:,1:,:,:]/100
117 |     out = jt.concat((l_rs,ab_rs),dim=1)
118 |     # if(torch.sum(torch.isnan(out))>0):
119 |     #     print('rgb2lab')
120 |     #     embed()
121 |     return out
122 | 
123 | def lab2rgb(lab_rs):
124 |     l = lab_rs[:,[0],:,:]*100 + 50
125 |     ab = lab_rs[:,1:,:,:]*100
126 |     lab = jt.concat((l,ab),dim=1)
127 |     out = xyz2rgb(lab2xyz(lab))
128 |     # if(torch.sum(torch.isnan(out))>0):
129 |     #     print('lab2rgb')
130 |     #     embed()
131 |     return out
132 | 
133 | 
134 | # the testing function in main shows that
135 | # clip = True and preserve_paper = False will get better results
136 | def color_tranfer(source, target, clip=True, preserve_paper=False):
137 |     """
138 |     Transfers the color distribution from the source to the target
139 |     image using the mean and standard deviations of the L*a*b*
140 |     color space.
141 | 
142 |     This implementation is (loosely) based on the "Color Transfer
143 |     between Images" paper by Reinhard et al., 2001.
144 | 
145 |     Parameters:
146 |     -------
147 |     source: NumPy array
148 |         OpenCV image in BGR color space (the source image)
149 |     target: NumPy array
150 |         OpenCV image in BGR color space (the target image)
151 |     clip: Should components of L*a*b* image be scaled by np.clip before
152 |         converting back to BGR color space?
153 |         If False then components will be min-max scaled appropriately.
154 |         Clipping will keep target image brightness truer to the input.
155 |         Scaling will adjust image brightness to avoid washed out portions
156 |         in the resulting color transfer that can be caused by clipping.
157 |     preserve_paper: Should color transfer strictly follow the methodology
158 |         laid out in the original paper? The method does not always produce
159 |         aesthetically pleasing results.
160 |         If False then L*a*b* components will be scaled using the reciprocal of
161 |         the scaling factor proposed in the paper.
This method seems to produce 162 | more consistently aesthetically pleasing results 163 | 164 | Returns: 165 | ------- 166 | transfer: NumPy array 167 | OpenCV image (w, h, 3) NumPy array (uint8) 168 | 169 | -------- 170 | original source: https://github.com/jrosebr1/color_transfer/ 171 | converted into pytorch version 172 | """ 173 | source = source * 255 174 | target = target * 255 175 | 176 | source_ = source 177 | source = rgb2lab(source) 178 | target = rgb2lab(target) 179 | 180 | 181 | MeanSrc = source.mean(dims=(2, 3)) 182 | src_size=1 183 | for i in source.shape[2:]: 184 | src_size *= i 185 | # jt.unsqueeze(jt.unsqueeze(source, -1), -1) 186 | # src_out = (jt.unsqueeze(jt.unsqueeze(MeanSrc, -1), -1) - source).sqr().sum(dims=(2, 3)) 187 | src_out = (MeanSrc.unsqueeze(-1).unsqueeze(-1) - source).sqr().sum(dims=(2, 3)) 188 | src_out = src_out/(src_size-1) 189 | StdSrc = src_out.maximum(1e-6).sqrt() 190 | 191 | # StdSrc = source.std(dims=(2, 3)) 192 | 193 | MeanTar = target.mean(dims=(2, 3)) 194 | tar_size=1 195 | for i in target.shape[2:]: 196 | tar_size *= i 197 | tar_out = (MeanTar.unsqueeze(-1).unsqueeze(-1) - target).sqr().sum(dims=(2, 3)) 198 | tar_out = tar_out/(tar_size-1) 199 | StdTar = tar_out.maximum(1e-6).sqrt() 200 | 201 | # StdTar = target.std(dims=(2, 3)) 202 | target -= MeanTar.unsqueeze(-1).unsqueeze(-1) 203 | 204 | if preserve_paper: 205 | target = (StdTar/StdSrc).unsqueeze(-1).unsqueeze(-1) * target 206 | else: 207 | target = (StdSrc/StdTar).unsqueeze(-1).unsqueeze(-1) * target 208 | 209 | target += MeanSrc.unsqueeze(-1).unsqueeze(-1) 210 | target = lab2rgb(target) 211 | 212 | if clip: 213 | transfers = jt.clamp(target, 0, 255) 214 | transfers = transfers/255 215 | else: 216 | bmin = target.min(dim=-1)[0].min(dim=-1)[0].unsqueeze(-1).unsqueeze(-1) 217 | bmax = target.max(dim=-1)[0].max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1) 218 | transfers = (target - bmin) / (bmax - bmin) 219 | return transfers 220 | 221 | 222 | 223 | if __name__ == '__main__': 224 | import yaml 225 | import argparse 226 | from datasets.query_datasets import QueryDataset 227 | import jittor.transform as transform 228 | 229 | with open('./configs/pix3d.yaml', 'r') as f: 230 | config = yaml.load(f) 231 | def dict2namespace(config): 232 | namespace = argparse.Namespace() 233 | for key, value in config.items(): 234 | if isinstance(value, dict): 235 | new_value = dict2namespace(value) 236 | else: 237 | new_value = value 238 | setattr(namespace, key, new_value) 239 | return namespace 240 | cfg = dict2namespace(config) 241 | 242 | 243 | # query_dataset = QueryDataset(cfg=cfg) 244 | # query_loader = torch.utils.data.DataLoader(dataset=query_dataset, \ 245 | # batch_size=cfg.data.batch_size, shuffle=False, \ 246 | # drop_last=True, num_workers=cfg.data.num_workers) 247 | query_loader = QueryDataset(cfg=cfg).set_attrs(batch_size=cfg.data.batch_size, shuffle=False, num_workers=cfg.data.num_workers, drop_last=True) 248 | 249 | # device = torch.device('cuda:3') 250 | # cter = ColorTransfer(device) 251 | 252 | topil = transform.ToPILImage() 253 | with jt.no_grad(): 254 | for meta in query_loader: 255 | mask_img = meta['mask_img'] 256 | # rendering_img = meta['rendering_img'] 257 | # cats = meta['cat'] 258 | # instances = meta['instance'] 259 | query_img = meta['query_img'] 260 | 261 | # seq = torch.randperm(query_img.shape[0]) 262 | seq = [i for i in range(mask_img.shape[0])][::-1] 263 | style_img = query_img[seq] 264 | style_mask_img = mask_img[seq] 265 | transfer_img = color_tranfer(style_img, query_img) 266 | 
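        # style_img is just the same batch in reverse order, so every query is
        # paired with a different image, and color_tranfer imprints the style
        # image's per-channel L*a*b* mean/std onto the query while keeping its
        # content -- this is the color-transfer augmentation used for training.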
267 | 268 | # ##### dataset debug ###### 269 | bs = query_img.shape[0] 270 | for ii in range(bs): 271 | 272 | jt.transpose(transfer_img[ii], [1,2,0]) 273 | q_img = topil(jt.transpose(query_img[ii], [1,2,0])) 274 | s_img = topil(jt.transpose(style_img[ii], [1,2,0])) 275 | tf_img = topil(jt.transpose(transfer_img[ii], [1,2,0])) 276 | 277 | 278 | q_img.save('./debug/%d-q_img.png' %(ii, )) 279 | s_img.save('./debug/%d-s_img.png' %(ii, )) 280 | tf_img.save('./debug/%d-tf_img.png' %(ii, )) 281 | 282 | debug = 258 283 | break 284 | 285 | -------------------------------------------------------------------------------- /code/RetrievalNet.py: -------------------------------------------------------------------------------- 1 | from utils import read_json 2 | from PIL import Image 3 | from Models import RetrievalNet 4 | import yaml 5 | import argparse 6 | from ColorTransfer import color_tranfer 7 | import tqdm 8 | from datasets.shape_datasets import ShapeDataset 9 | from datasets.query_datasets import QueryDataset 10 | import jittor.transform as transform 11 | import jittor as jt 12 | from jittor import nn 13 | 14 | 15 | 16 | # from tensorboardX import SummaryWriter 17 | 18 | import warnings 19 | warnings.filterwarnings('ignore') 20 | 21 | import os 22 | 23 | 24 | class Retrieval(object): 25 | ''' 26 | training 27 | testing 28 | loading 29 | saving 30 | ''' 31 | def __init__(self, config): 32 | self.cfg =config 33 | self.retrieval_net = RetrievalNet(self.cfg) 34 | lr = 0.00005 35 | beta1 = 0.5 36 | beta2 = 0.999 37 | 38 | self.opt = nn.Adam(self.retrieval_net.parameters(), lr=lr, betas=(beta1, beta2)) 39 | 40 | self.normal_tf = transform.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 41 | self.tau = self.cfg.data.tau 42 | self.size = self.cfg.data.pix_size 43 | self.dim = self.cfg.models.z_dim 44 | self.view_num = self.cfg.data.view_num 45 | 46 | self.epoch = 0 47 | self.it = 0 48 | self.loading(self.cfg.models.pre_trained_path) 49 | 50 | 51 | self.logs_root = './logs/'+self.cfg.data.name 52 | self.best_acc = 0 53 | # writer 54 | # writer = SummaryWriter(os.path.join(self.logs_root, 'curves')) 55 | 56 | 57 | def training(self): 58 | # device = self.device 59 | cfg = self.cfg 60 | query_loader = QueryDataset(cfg=cfg).set_attrs(batch_size=cfg.data.batch_size, shuffle=True, num_workers=cfg.data.num_workers, drop_last=True) 61 | shape_loader = ShapeDataset(cfg=cfg).set_attrs(batch_size=cfg.data.batch_size, shuffle=True, num_workers=cfg.data.num_workers, drop_last=True) 62 | 63 | total_epoch = cfg.trainer.epochs 64 | epoch = self.epoch 65 | it = self.it 66 | while epoch < total_epoch: 67 | # self.testing() 68 | self.retrieval_net.train() 69 | 70 | pbar = tqdm.tqdm(query_loader) 71 | # print(len(pbar)) 72 | for meta in pbar: 73 | mask_img = meta['mask_img'] 74 | rendering_img = meta['rendering_img'] 75 | cats = meta['cat'] 76 | instances = meta['instance'] 77 | query_img = meta['query_img'] 78 | 79 | # doing color transfer 80 | seq = jt.misc.randperm(query_img.shape[0]) 81 | style_img = query_img[seq] 82 | transfer_img = color_tranfer(style_img, query_img) 83 | query_img = self.normal_tf(transfer_img).detach() 84 | 85 | 86 | # get Instance and Category level label 87 | inst_list = [] # record the unique idx for each model 88 | inst_index = [] # instance level idx 89 | idx_list = [] # using for torch.cat 90 | bs = len(cats) 91 | 92 | for ii in range(len(cats)): 93 | tmp_cat = cats[ii] 94 | tmp_inst = instances[ii] 95 | try: 96 | # model already existed 97 | idx = inst_list.index((tmp_cat, tmp_inst)) 98 | 
inst_index.append(idx) 99 | except ValueError: 100 | inst_index.append(len(inst_list)) 101 | inst_list.append((tmp_cat, tmp_inst)) 102 | idx_list.append(ii) 103 | 104 | rendering_img = jt.concat([rendering_img[idx:idx+1] for idx in idx_list], dim=0) 105 | 106 | 107 | while not len(inst_list) == bs: 108 | try: 109 | shape_meta = next(iter_shape) 110 | except: 111 | iter_shape = iter(shape_loader) 112 | shape_meta = next(iter_shape) 113 | 114 | tmp_cats = shape_meta['labels']['cat'] 115 | tmp_insts = shape_meta['labels']['instance'] 116 | tmp_reinderings = shape_meta['rendering_img'] 117 | # tmp_rendering_idx_list = [] # using this list to cat ebds in new datasets_iter 118 | tmp_rendering_list = [] 119 | 120 | for ii in range(len(tmp_cats)): 121 | tmp_cat = tmp_cats[ii] 122 | tmp_inst = tmp_insts[ii] 123 | try: 124 | # model already existed 125 | idx = inst_list.index((tmp_cat, tmp_inst)) 126 | except ValueError: 127 | inst_list.append((tmp_cat, tmp_inst)) 128 | tmp_rendering_list.append(tmp_reinderings[ii:ii+1]) 129 | 130 | if len(inst_list) == bs: 131 | break 132 | if not len(tmp_rendering_list) == 0: 133 | tmp_reindering = jt.concat(tmp_rendering_list, dim=0) 134 | rendering_img = jt.concat([rendering_img, tmp_reindering], dim=0) 135 | # bsx12x224x224 136 | 137 | cats_list = [] 138 | shape_cats_index = [] 139 | image_cats_index = [] 140 | # cats_index = [] # category level idx 141 | for ii, items in enumerate(inst_list): 142 | tmp_cat, tmp_inst = items 143 | try: 144 | idx = cats_list.index(tmp_cat) 145 | shape_cats_index.append(idx) 146 | except ValueError: 147 | shape_cats_index.append(len(cats_list)) 148 | cats_list.append(tmp_cat) 149 | tmp_cat, tmp_inst = inst_list[inst_index[ii]] 150 | idx = cats_list.index(tmp_cat) 151 | image_cats_index.append(idx) 152 | 153 | 154 | inst_label = jt.int(jt.array(inst_index).view(-1,1)) 155 | InstsMat = jt.zeros((inst_label.shape[0], bs)).scatter_(1, inst_label, jt.ones((inst_label.shape[0], bs))).t() 156 | 157 | shape_cats_label = jt.int(jt.array(shape_cats_index)) 158 | image_cats_label = jt.int(jt.array(image_cats_index)) 159 | 160 | shape_cats_labels = shape_cats_label.unsqueeze(1).repeat(1, bs) 161 | image_cats_labels = image_cats_label.unsqueeze(0).repeat(bs, 1) 162 | CatsMat = jt.float(shape_cats_labels==image_cats_labels) 163 | 164 | # ######## [ end ] ######## [create no repeat rendering batch data ] ######## 165 | # InstsMat bs, bs (shape, image) 166 | # CatsMat bs, bs (shape, image) 167 | 168 | # bsx4x224x224 169 | mquery = jt.concat((query_img, mask_img), dim=1) 170 | query_image_ebd, queried_rendering_ebd = self.retrieval_net(mquery, rendering_img) 171 | qi_ebd = query_image_ebd # bs, bs, 128 172 | qr_ebd = queried_rendering_ebd # bs, bs, 128 (shape, image, 128) 173 | 174 | 175 | prod_mat = (qi_ebd * qr_ebd).sum(dim=2) 176 | ProdMat = jt.exp(prod_mat * (1/self.tau)) # bs, bs (shape, image) 177 | 178 | ProdMat_sum = ProdMat.sum(dim=0) 179 | # Instance Loss 180 | loss_inst_ = (ProdMat * InstsMat).sum(dim=0) / ProdMat_sum 181 | loss_inst = -jt.log(loss_inst_) 182 | loss_inst = loss_inst.mean() 183 | 184 | # Category Loss 185 | if not (len(cats_list) == 1): 186 | CatsMat_exc = CatsMat 187 | pos_num = CatsMat_exc.sum(dim=0) 188 | 189 | pos_num[pos_num==0]=1 # In some cases, pos_num = 0 -->> nan 190 | SumMat = ProdMat.sum(dim=0).view(1, -1).repeat(bs, 1) # SumMat excluding InstMat 191 | ExcProdMat = ProdMat * CatsMat_exc / SumMat 192 | 193 | ExcProdMat[ExcProdMat==0] = 1 194 | # ExcProdMat[ExcProdMat<1e-5] = 1e-5 195 | loss_cats_ = 
-jt.log(ExcProdMat).sum(dim=0)/pos_num 196 | loss_cats = loss_cats_[loss_cats_ != 0] 197 | loss_cats = loss_cats.mean() 198 | 199 | loss = loss_inst + 0.2*loss_cats 200 | loss_cats_item = loss_cats.item() 201 | else: 202 | loss = loss_inst 203 | loss_cats_item = 0 204 | 205 | loss_item = loss.item() 206 | loss_inst_item = loss_inst.item() 207 | 208 | self.opt.step(loss) 209 | 210 | it += 1 211 | 212 | info_dict = {'loss': '%.3f' %(loss_item, ), 'loss_inst': '%.3f' %(loss_inst_item, ), 'loss_cats': '%.3f' %(loss_cats_item, ) } 213 | pbar.set_postfix(info_dict) 214 | pbar.set_description('Epoch: %d, Iter: %d ' % (epoch, it)) 215 | 216 | 217 | epoch += 1 218 | self.it = it 219 | self.epoch = epoch 220 | # self.testing() 221 | if epoch % 10 == 0 or epoch > 50: 222 | self.testing() 223 | 224 | 225 | def testing(self): 226 | # device = self.device 227 | cfg = self.cfg 228 | # writer = self.writer 229 | self.retrieval_net.eval() 230 | 231 | is_aug = cfg.setting.is_aug 232 | cfg.setting.is_aug = False 233 | 234 | shape_loader = ShapeDataset(cfg=cfg).set_attrs(batch_size=cfg.data.batch_size, shuffle=False, num_workers=cfg.data.num_workers, drop_last=False) 235 | cfg.setting.is_aug = is_aug 236 | 237 | 238 | shape_cats_list = [] 239 | shape_inst_list = [] 240 | shape_ebd_list = [] 241 | pbar = tqdm.tqdm(shape_loader) 242 | for meta in pbar: 243 | with jt.no_grad(): 244 | rendering_img = meta['rendering_img'] 245 | cats = meta['labels']['cat'] 246 | instances = meta['labels']['instance'] 247 | 248 | rendering = rendering_img.view(-1, 1, self.size, self.size) 249 | rendering_ebds = self.retrieval_net.get_rendering_ebd(rendering).view(-1, self.view_num, self.dim) 250 | shape_cats_list += cats 251 | shape_inst_list += instances 252 | shape_ebd_list.append(rendering_ebds) 253 | 254 | shape_ebd = jt.concat(shape_ebd_list, dim=0) # num, 12, dim 255 | 256 | # test image 257 | json_dict = read_json(os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.test_json)) 258 | json_lenth = len(json_dict) 259 | query_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 260 | mask_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, ), (0.5, ))]) 261 | bs = shape_ebd.shape[0] 262 | 263 | total_dict = {} 264 | acc_cats_dict = {} 265 | acc_inst_dict = {} 266 | with jt.no_grad(): 267 | pbar = tqdm.tqdm(range(json_lenth)) 268 | for i in pbar: 269 | info = json_dict[i] 270 | query_img = query_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['img']))) 271 | mask_img = mask_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['mask']))) 272 | 273 | query = jt.concat((query_img, mask_img), dim=0) 274 | # query = query.unsqueeze(dim=0) 275 | query = jt.unsqueeze(query, dim=0) 276 | query_ebd = self.retrieval_net.get_query_ebd(query) 277 | 278 | 279 | query_ebd = query_ebd.repeat(bs, 1, 1) 280 | _, weights = self.retrieval_net.attention_query(query_ebd, shape_ebd) 281 | queried_rendering_ebd = jt.nn.bmm(weights, shape_ebd) 282 | qr_ebd = queried_rendering_ebd 283 | qi_ebd = query_ebd 284 | prod_mat = (qi_ebd * qr_ebd).sum(dim=2) 285 | max_idx = prod_mat.argmax(dim=0) 286 | 287 | # print(max_idx[0].data[0]) 288 | pr_cats = shape_cats_list[max_idx[0].data[0]] 289 | pr_inst = shape_inst_list[max_idx[0].data[0]] 290 | 291 | gt_cats = info['category'] 292 | gt_inst = info['model'].split('/')[-2] 293 | 294 | try: 295 | total_dict[gt_cats] = total_dict[gt_cats] + 1 296 | except: 297 | 
298 | 
299 |                 if gt_cats == pr_cats:
300 |                     try:
301 |                         acc_cats_dict[gt_cats] = acc_cats_dict[gt_cats] + 1
302 |                     except KeyError:
303 |                         acc_cats_dict[gt_cats] = 1
304 |                 if gt_cats == pr_cats and gt_inst == pr_inst:
305 |                     try:
306 |                         acc_inst_dict[gt_cats] = acc_inst_dict[gt_cats] + 1
307 |                     except KeyError:
308 |                         acc_inst_dict[gt_cats] = 1
309 | 
310 |         total_num = 0
311 |         total_acc = 0
312 |         out_info = []
313 |         for keys in total_dict.keys():
314 |             num = total_dict[keys]
315 |             try:
316 |                 inst_num = acc_inst_dict[keys]
317 |             except KeyError:
318 |                 inst_num = 0
319 | 
320 |             try:
321 |                 cats_num = acc_cats_dict[keys]
322 |             except KeyError:
323 |                 cats_num = 0
324 | 
325 |             total_num += num
326 |             total_acc += inst_num
327 |             out_info.append('%s: inst: %d, cats: %d, total: %d\n' %(keys, inst_num, cats_num, num))
328 | 
329 |         out_infos = ''.join(out_info)
330 |         print(out_infos)
331 | 
332 |         final_acc = total_acc/total_num
333 |         print(final_acc)
334 | 
335 |         if final_acc > self.best_acc:
336 |             self.best_acc = final_acc
337 |             paths_list = self.cfg.models.pre_trained_path.split('/')[:-1]
338 |             paths_list.append(self.cfg.data.name+'.pt')
339 |             paths = '/'.join(paths_list)
340 |             self.saving(paths=paths)
341 | 
342 |         print('best acc: %.3f' %(self.best_acc, ))
343 | 
344 | 
345 |     def loading(self, paths=None):
346 |         # print(self.retrieval_net.state_dict().keys())
347 |         cfg = self.cfg
348 |         if paths is None or not os.path.exists(paths):
349 |             # init
350 |             model_dict = self.retrieval_net.state_dict()
351 |             res18_pre_path = os.path.join(cfg.models.pre_train_resnet_root, 'resnet18.pkl')
352 |             res50_pre_path = os.path.join(cfg.models.pre_train_resnet_root, 'resnet50.pkl')
353 |             save_model18 = jt.load(res18_pre_path)
54 | 
354 |             save_model50 = jt.load(res50_pre_path)
355 | 
356 |             # query encoder
357 |             # skip conv1 and fc
358 |             prefix = 'query_encoder.resnet'
359 |             for keys in save_model50.keys():
360 |                 key_prefix = keys.split('.')[0]
361 |                 if key_prefix == 'conv1' or key_prefix == 'fc':
362 |                     continue
363 |                 model_key = '%s.%s' %(prefix, keys)
364 |                 model_dict[model_key] = save_model50[keys]
365 | 
366 |             # rendering encoder
367 |             # skip conv1 and fc
368 |             prefix = 'rendering_encoder.resnet'
369 |             for keys in save_model18.keys():
370 |                 key_prefix = keys.split('.')[0]
371 |                 if key_prefix == 'conv1' or key_prefix == 'fc':
372 |                     continue
373 |                 model_key = '%s.%s' %(prefix, keys)
374 |                 model_dict[model_key] = save_model18[keys]
375 | 
376 |             self.retrieval_net.load_state_dict(model_dict)
377 |             print('No ckpt! Initializing RenderingEncoder from ResNet18 and QueryEncoder from ResNet50')
378 | 
379 |         else:
380 |             # loading
381 |             ckpt = jt.load(paths)
382 |             self.retrieval_net.load_state_dict(ckpt)
383 |             print('loaded %s successfully' %(paths))
384 | 
385 | 
386 |     def saving(self, paths=None):
387 |         cfg = self.cfg
388 | 
389 |         if paths is None:
390 |             save_name = "epoch_{}_iter_{}.pt".format(self.epoch, self.it)
391 |             save_path = os.path.join(cfg.save_dir, save_name)
392 |             print('model %s saved!\n' %(save_name, ))
393 |         else:
394 |             save_path = paths
395 |             print('model saved to %s\n' %(paths, ))
396 | 
397 |         jt.save(self.retrieval_net.state_dict(), save_path)
398 | 
399 | 
400 | if __name__ == '__main__':
401 |     parser = argparse.ArgumentParser()
402 |     parser.add_argument(
403 |         "--config", type=str, default="./configs/compcars.yaml", help="Path to (.yaml) config file."
404 |     )
405 | 
406 |     configargs = parser.parse_args()
407 | 
408 |     with open(configargs.config, 'r', encoding="utf-8") as f:
409 |         config = yaml.safe_load(f)
410 |     def dict2namespace(config):
411 |         namespace = argparse.Namespace()
412 |         for key, value in config.items():
413 |             if isinstance(value, dict):
414 |                 new_value = dict2namespace(value)
415 |             else:
416 |                 new_value = value
417 |             setattr(namespace, key, new_value)
418 |         return namespace
419 |     config = dict2namespace(config)
420 | 
421 |     Task = Retrieval(config)
422 |     # Task.testing()
423 |     Task.training()
424 | 
425 | 
426 | 
427 | 
--------------------------------------------------------------------------------
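
A note on the objective constructed above: `InstsMat[s, q] = 1` marks shape `s` as the ground-truth model of query image `q`, `CatsMat[s, q] = 1` marks a shared category, and `ProdMat = exp(prod_mat / tau)` holds the exponentiated query-shape similarities, so `loss_inst` is an InfoNCE-style cross-entropy over the shape axis. The following is a minimal NumPy sketch of that batch construction, using hypothetical toy labels rather than values from any dataset:

    import numpy as np

    bs, tau = 4, 0.1
    inst_index = [0, 1, 1, 2]        # per image: index of its GT shape in the batch
    shape_cats = [0, 0, 1, 1]        # per shape: category id
    image_cats = [shape_cats[i] for i in inst_index]

    # InstsMat[s, q] = 1 iff shape s is the ground-truth model of image q
    InstsMat = np.zeros((bs, bs))
    InstsMat[inst_index, np.arange(bs)] = 1.0

    # CatsMat[s, q] = 1 iff shape s and image q share a category
    CatsMat = (np.array(shape_cats)[:, None] == np.array(image_cats)[None, :]).astype(float)

    # prod_mat[s, q] = <query embedding, attended rendering embedding>;
    # faked here with random numbers in place of the network outputs
    prod_mat = np.random.default_rng(0).normal(size=(bs, bs))
    ProdMat = np.exp(prod_mat / tau)

    # instance loss: softmax over shapes, cross-entropy against InstsMat;
    # CatsMat plays the analogous role in the category loss
    loss_inst = -np.log((ProdMat * InstsMat).sum(0) / ProdMat.sum(0)).mean()
    print(loss_inst)

Given the argparse default above, training would presumably be launched from the code directory as `python RetrievalNet.py --config ./configs/compcars.yaml`.
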
/code/RetrievalNet_test.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | # from numpy.core.records import record  # unused (and shadowed by the loop variable below)
3 | import jittor as jt
4 | from datasets.query_datasets import QueryDataset
5 | from datasets.shape_datasets import ShapeDataset
6 | import tqdm
7 | import jittor.transform as transform
8 | import os
9 | # from tensorboardX import SummaryWriter
10 | import warnings
11 | from utils import read_json
12 | from PIL import Image
13 | # from RetrievalNet import RetrievalNet  # shadowed by the Models import below
14 | warnings.filterwarnings('ignore')
15 | from Models import RetrievalNet
16 | 
17 | import numpy as np
18 | import binvox_rw
19 | 
20 | import yaml
21 | import argparse
22 | # os.environ["CUDA_VISIBLE_DEVICES"]="3"
23 | 
24 | 
25 | def modified_averaged_hausdorff_distance(x, y):
26 |     '''
27 |     Input:  x is an Nxd array of points
28 |             y is an Mxd array of points
29 |     Output: a scalar: the two directed sums of nearest-neighbour
30 |             distances between x and y, averaged over N + M
31 |             (an averaged, symmetric Hausdorff distance)
32 |     '''
33 |     with jt.no_grad():
34 |         xt = jt.float32(x.astype(np.float32)).unsqueeze(1)
35 |         yt = jt.float32(y.astype(np.float32)).unsqueeze(0)
36 |         differences = xt - yt
37 |         # differences = x.unsqueeze(1).cuda() - y.unsqueeze(0).cuda()
38 |         distances = jt.sum(differences**2, -1).sqrt()
39 | 
40 |         num = distances.shape[0] + distances.shape[1]
41 |         res = jt.min(distances, dim=0).sum() + jt.min(distances, dim=1).sum()
42 |         res = float(res)
43 |     return res/num
44 | 
45 | 
46 | def cal_IoU_and_Haus(cfg, record_dict):
47 |     iou_dict = {}
48 |     haus_dict = {}
49 | 
50 |     name_src = cfg.data.name
51 |     roots = cfg.data.root_dir
52 | 
53 |     if cfg.mode == 'shapenet':
54 |         if cfg.data.name == 'pix3d':
55 |             name_tar = 'shapenet4'
56 |         else:
57 |             name_tar = 'shapenetcars'
58 |     else:
59 |         name_tar = cfg.data.name
60 | 
61 | 
62 |     cnt = 0
63 |     for cat in record_dict.keys():
64 |         # nums = [i for i in range(len(record_dict[cat]))]
65 |         records = [record_dict[cat][i] for i in range(len(record_dict[cat]))]
66 |         iou_dict[cat] = []
67 |         haus_dict[cat] = []
68 | 
69 |         for record in records:
70 |             pr_cats, pr_inst, gt_cats, gt_inst = record
71 | 
72 |             roots_src = os.path.join(roots, name_src)
73 |             roots_tar = os.path.join(roots, name_tar)
74 |             src_binvox_path = os.path.join(roots_src, 'model_std_bin128', gt_cats, '%s.binvox'%(gt_inst,))
75 |             tar_binvox_path = os.path.join(roots_tar, 'model_std_bin128', pr_cats, '%s.binvox'%(pr_inst,))
76 | 
77 |             with open(src_binvox_path, 'rb') as f:
78 |                 src_bin = binvox_rw.read_as_3d_array(f).data
79 | 
80 |             with open(tar_binvox_path, 'rb') as f:
81 |                 tar_bin = binvox_rw.read_as_3d_array(f).data
82 | 
83 |             # IoU (the epsilon guards an empty union)
84 |             Iou_st = np.sum(src_bin & tar_bin) / (np.sum(src_bin | tar_bin) + 1e-8)
85 |             # Haus
86 |             src_ptc_path = os.path.join(roots_src, 'model_std_ptc10k_npy', gt_cats, '%s.npy'%(gt_inst,))
87 |             tar_ptc_path = os.path.join(roots_tar, 'model_std_ptc10k_npy', pr_cats, '%s.npy'%(pr_inst,))
88 | 
89 |             src_ptc = np.load(src_ptc_path)[:10000]/2
90 |             tar_ptc = np.load(tar_ptc_path)[:10000]/2
91 | 
92 |             cnt += 1
93 |             if cnt % 500 == 0:
94 |                 print(cnt)
95 |             Haus_st = modified_averaged_hausdorff_distance(src_ptc, tar_ptc)
96 | 
97 |             iou_dict[cat].append(Iou_st)
98 |             haus_dict[cat].append(Haus_st)
99 | 
100 |     out_info = []
101 |     total_info = {}
102 |     for cat in record_dict.keys():
103 |         length = len(iou_dict[cat])
104 |         out_info.append('%s: haus: %.4f, iou: %.4f,\
105 |             ' %(cat, sum(haus_dict[cat])/length, sum(iou_dict[cat])/length))
106 | 
107 |         total_info[cat] = []
108 |         total_info[cat].append(sum(haus_dict[cat]))
109 |         total_info[cat].append(sum(iou_dict[cat]))
110 |         total_info[cat].append(length)
111 |     for msg in out_info:
112 |         print(msg)
113 | 
114 |     total_haus = sum([total_info[cat][0] for cat in record_dict.keys()])
115 |     total_iou = sum([total_info[cat][1] for cat in record_dict.keys()])
116 |     total_length = sum([total_info[cat][2] for cat in record_dict.keys()])
117 |     print('%s: haus: %.4f, iou: %.4f || %d\n\
118 |         ' %('[Total]', total_haus/total_length,
119 |             total_iou/total_length, total_length))
120 | 
121 | 
122 | 
123 | class Retrieval(object):
124 |     '''
125 |     Retrieval evaluation wrapper: loads a trained
126 |     checkpoint and runs the simple / full / shapenet
127 |     test modes
128 |     '''
129 |     def __init__(self, config):
130 |         self.cfg = config
131 |         self.retrieval_net = RetrievalNet(self.cfg)
132 | 
133 |         self.normal_tf = transform.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
134 |         self.size = self.cfg.data.pix_size
135 |         self.dim = self.cfg.models.z_dim
136 |         self.view_num = self.cfg.data.view_num
137 | 
138 |         self.loading(self.cfg.models.pre_trained_path)
139 | 
140 | 
141 |     def loading(self, paths=None):
142 |         if paths is None or not os.path.exists(paths):
143 |             print('No ckpt!')
144 |             exit(-1)
145 |         else:
146 |             # loading
147 |             ckpt = jt.load(paths)
148 |             self.retrieval_net.load_state_dict(ckpt)
149 |             print('loaded %s successfully' %(paths))
150 | 
151 | 
152 |     def test_simple(self):
153 |         # device = self.device
154 |         cfg = self.cfg
155 |         self.retrieval_net.eval()
156 | 
157 |         # dataset of shape renderings without repeated model embeddings
158 |         cfg.setting.is_aug = False
159 |         shape_loader = ShapeDataset(cfg=cfg).set_attrs(batch_size=cfg.data.batch_size, shuffle=False, num_workers=cfg.data.num_workers, drop_last=False)
160 | 
161 | 
162 |         shape_cats_list = []
163 |         shape_inst_list = []
164 |         shape_ebd_list = []
165 |         pbar = tqdm.tqdm(shape_loader)
166 |         for meta in pbar:
167 |             with jt.no_grad():
168 |                 rendering_img = meta['rendering_img']
169 |                 cats = meta['labels']['cat']
170 |                 instances = meta['labels']['instance']
171 | 
172 |                 rendering = rendering_img.view(-1, 1, self.size, self.size)
173 |                 rendering_ebds = self.retrieval_net.get_rendering_ebd(rendering).view(-1, self.view_num, self.dim)
174 |                 shape_cats_list += cats
175 |                 shape_inst_list += instances
176 |                 shape_ebd_list.append(rendering_ebds)
177 | 
178 |         shape_ebd = jt.concat(shape_ebd_list, dim=0)  # num, 12, dim
179 | 
180 |         # test images
181 |         json_dict = read_json(os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.test_json))
182 |         json_length = len(json_dict)
183 |         query_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
184 |         mask_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, ), (0.5, ))])
185 |         bs = shape_ebd.shape[0]
186 | 
187 |         total_dict = {}
188 |         acc_cats_dict = {}
189 |         acc_inst_dict = {}
190 |         with jt.no_grad():
191 |             pbar = tqdm.tqdm(range(json_length))
192 |             for i in pbar:
193 |                 info = json_dict[i]
194 |                 query_img = query_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['img'])))
195 |                 mask_img = mask_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['mask'])))
196 | 
197 |                 query = jt.concat((query_img, mask_img), dim=0)
198 |                 query = query.unsqueeze(dim=0)
199 |                 query_ebd = self.retrieval_net.get_query_ebd(query)
200 | 
201 | 
202 |                 query_ebd = query_ebd.repeat(bs, 1, 1)
203 |                 _, weights = self.retrieval_net.attention_query(query_ebd, shape_ebd)
204 |                 queried_rendering_ebd = jt.nn.bmm(weights, shape_ebd)
205 |                 qr_ebd = queried_rendering_ebd
206 |                 qi_ebd = query_ebd
207 |                 prod_mat = (qi_ebd * qr_ebd).sum(dim=2)
208 |                 max_idx = prod_mat.argmax(dim=0)
209 | 
210 | 
211 |                 pr_cats = shape_cats_list[int(max_idx[0])]
212 |                 pr_inst = shape_inst_list[int(max_idx[0])]
213 | 
214 |                 gt_cats = info['category']
215 |                 gt_inst = info['model'].split('/')[-2]
216 | 
217 |                 try:
218 |                     total_dict[gt_cats] = total_dict[gt_cats] + 1
219 |                 except KeyError:
220 |                     total_dict[gt_cats] = 1
221 | 
222 |                 if gt_cats == pr_cats:
223 |                     try:
224 |                         acc_cats_dict[gt_cats] = acc_cats_dict[gt_cats] + 1
225 |                     except KeyError:
226 |                         acc_cats_dict[gt_cats] = 1
227 |                 if gt_cats == pr_cats and gt_inst == pr_inst:
228 |                     try:
229 |                         acc_inst_dict[gt_cats] = acc_inst_dict[gt_cats] + 1
230 |                     except KeyError:
231 |                         acc_inst_dict[gt_cats] = 1
232 | 
233 |         total_num = 0
234 |         total_acc = 0
235 |         total_cat = 0
236 |         out_info = []
237 |         for keys in total_dict.keys():
238 |             num = total_dict[keys]
239 |             try:
240 |                 inst_num = acc_inst_dict[keys]
241 |             except KeyError:
242 |                 inst_num = 0
243 | 
244 |             try:
245 |                 cats_num = acc_cats_dict[keys]
246 |             except KeyError:
247 |                 cats_num = 0
248 | 
249 |             total_num += num
250 |             total_acc += inst_num
251 |             total_cat += cats_num
252 |             out_info.append('%s: inst: %.3f, cats: %.3f, total: %d\n' %(keys, inst_num/num, cats_num/num, num))
253 | 
254 |         out_infos = ''.join(out_info)
255 |         print(out_infos)
256 | 
257 |         print(total_acc/total_num)
258 |         print(total_cat/total_num)
259 | 
260 | 
261 |     def test_full(self):
262 |         cfg = self.cfg
263 |         self.retrieval_net.eval()
264 | 
265 |         # dataset of shape renderings without repeated model embeddings
266 |         cfg.setting.is_aug = False
267 |         shape_loader = ShapeDataset(cfg=cfg).set_attrs(batch_size=cfg.data.batch_size, shuffle=False, num_workers=cfg.data.num_workers, drop_last=False)
268 | 
269 |         shape_cats_list = []
270 |         shape_inst_list = []
271 |         shape_ebd_list = []
272 |         pbar = tqdm.tqdm(shape_loader)
273 |         for meta in pbar:
274 |             with jt.no_grad():
275 |                 rendering_img = meta['rendering_img']
276 |                 cats = meta['labels']['cat']
277 |                 instances = meta['labels']['instance']
278 | 
279 |                 rendering = rendering_img.view(-1, 1, self.size, self.size)
280 |                 rendering_ebds = self.retrieval_net.get_rendering_ebd(rendering).view(-1, self.view_num, self.dim)
281 |                 shape_cats_list += cats
282 |                 shape_inst_list += instances
283 |                 shape_ebd_list.append(rendering_ebds)
284 | 
285 |         shape_ebd = jt.concat(shape_ebd_list, dim=0)  # num, 12, dim
286 | 
287 |         # test images
288 |         json_dict = read_json(os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.test_json))
289 |         # json_dict = json_dict[:10]
290 |         # json_length = len(json_dict)
291 |         query_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
292 |         mask_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, ), (0.5, ))])
293 |         bs = shape_ebd.shape[0]
294 | 
295 |         # sorted by category
296 |         cats_dict = {}
297 |         for items in json_dict:
298 |             cat = items['category']
299 |             try:
300 |                 cats_dict[cat].append(items)
301 |             except KeyError:
302 |                 cats_dict[cat] = []
303 |                 cats_dict[cat].append(items)
304 | 
305 | 
306 |         top1_dict = {}
307 |         topk_dict = {}
308 |         category_dict = {}
309 |         record_dict = {}
310 | 
311 |         for cat in cats_dict.keys():
312 |             cats_list = cats_dict[cat]
313 |             top1_dict[cat] = [0 for i in range(len(cats_list))]
314 |             topk_dict[cat] = [0 for i in range(len(cats_list))]
315 |             category_dict[cat] = [0 for i in range(len(cats_list))]
316 |             record_dict[cat] = []
317 | 
318 |             with jt.no_grad():
319 |                 pbar = tqdm.tqdm(cats_list)
320 |                 for i, info in enumerate(pbar):
321 |                     # info = json_dict[i]
322 |                     query_img = query_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['img'])))
323 |                     mask_img = mask_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['mask'])))
324 | 
325 |                     query = jt.concat((query_img, mask_img), dim=0)
326 |                     query = query.unsqueeze(dim=0)
327 |                     query_ebd = self.retrieval_net.get_query_ebd(query)
328 | 
329 | 
330 |                     query_ebd = query_ebd.repeat(bs, 1, 1)
331 |                     _, weights = self.retrieval_net.attention_query(query_ebd, shape_ebd)
332 |                     queried_rendering_ebd = jt.nn.bmm(weights, shape_ebd)
333 |                     qr_ebd = queried_rendering_ebd
334 |                     qi_ebd = query_ebd
335 |                     prod_mat = (qi_ebd * qr_ebd).sum(dim=2)
336 |                     max_idx = prod_mat.argmax(dim=0)
337 | 
338 | 
339 |                     pr_cats = shape_cats_list[int(max_idx[0])]
340 |                     pr_inst = shape_inst_list[int(max_idx[0])]
341 | 
342 |                     gt_cats = info['category']
343 |                     gt_inst = info['model'].split('/')[-2]
344 | 
345 |                     if gt_cats == pr_cats:
346 |                         category_dict[cat][i] = 1
347 | 
348 |                     if gt_inst == pr_inst:
349 |                         top1_dict[cat][i] = 1
350 | 
351 |                     record_dict[cat].append((pr_cats, pr_inst, gt_cats, gt_inst))
352 | 
353 | 
354 | 
355 |                     max_idx = prod_mat.view(-1).topk(dim=0, k=10)[1]
356 |                     for kk in range(10):
357 |                         pr_cats = shape_cats_list[int(max_idx[kk])]
358 |                         pr_inst = shape_inst_list[int(max_idx[kk])]
359 |                         if gt_cats == pr_cats and gt_inst == pr_inst:
360 |                             topk_dict[cat][i] = 1
361 |                             break
362 | 
363 | 
364 |         # basic output: top1, top10, cats, total number
365 |         out_info = []
366 |         total_info = {}
367 |         for cat in cats_dict.keys():
368 |             length = len(top1_dict[cat])
369 |             out_info.append('%s: top1: %d, top10: %d, cats: %d, || top1_rt: %.3f, top10_rt: %.3f, cats_rt: %.3f, || total num: %d\
370 |                 ' %(cat, sum(top1_dict[cat]), sum(topk_dict[cat]), sum(category_dict[cat]),
371 |                 sum(top1_dict[cat])/length, sum(topk_dict[cat])/length,
372 |                 sum(category_dict[cat])/length, length))
373 | 
374 |             total_info[cat] = []
375 |             total_info[cat].append(sum(top1_dict[cat]))
376 |             total_info[cat].append(sum(topk_dict[cat]))
377 |             total_info[cat].append(sum(category_dict[cat]))
378 |             total_info[cat].append(length)
379 |         for msg in out_info:
380 |             print(msg)
381 | 
382 |         total_top1 = sum([total_info[cat][0] for cat in cats_dict.keys()])
383 |         total_top10 = sum([total_info[cat][1] for cat in cats_dict.keys()])
384 |         total_cats = sum([total_info[cat][2] for cat in cats_dict.keys()])
385 |         total_length = sum([total_info[cat][3] for cat in cats_dict.keys()])
386 |         print('%s: top1: %d, top10: %d, cats: %d, || top1_rt: %.3f, top10_rt: %.3f, cats_rt: %.3f, || total num: %d\n\
387 |             ' %('[Total]', total_top1, total_top10, total_cats,
388 |             total_top1/total_length, total_top10/total_length,
389 |             total_cats/total_length, total_length))
390 | 
391 |         # self.cal_IoU_and_Haus(record_dict)
392 |         return record_dict
393 | 
394 | 
395 |     def _test_shapenet(self, ccat):
396 |         cfg = self.cfg
397 |         self.retrieval_net.eval()
398 | 
399 |         # dataset of shape renderings without repeated model embeddings
400 |         cfg.setting.is_aug = False
401 |         cats_ = ccat
402 | 
403 |         data_name = cfg.data.name
404 |         data_render_path = cfg.data.render_path
405 |         if data_name == 'pix3d':
406 |             cfg.data.name = 'shapenet4'
407 |         else:
408 |             cfg.data.name = 'shapenetcars'
409 |         cfg.data.render_path = 'rendering_shapenet%ss.pkl' %(cats_, )
410 | 
411 |         shape_loader = ShapeDataset(cfg=cfg).set_attrs(batch_size=cfg.data.batch_size, shuffle=False, num_workers=cfg.data.num_workers, drop_last=False)
412 | 
413 |         cfg.data.name = data_name
414 |         cfg.data.render_path = data_render_path
415 | 
416 |         shape_cats_list = []
417 |         shape_inst_list = []
418 |         shape_ebd_list = []
419 |         pbar = tqdm.tqdm(shape_loader)
420 | 
421 |         s_Flag = True
422 |         for meta in pbar:
423 |             with jt.no_grad():
424 |                 rendering_img = meta['rendering_img']
425 |                 cats = meta['labels']['cat']
426 |                 instances = meta['labels']['instance']
427 | 
428 |                 rendering = rendering_img.view(-1, 1, self.size, self.size)
429 |                 rendering_ebds = self.retrieval_net.get_rendering_ebd(rendering).view(-1, self.view_num, self.dim)
430 |                 shape_cats_list += cats
431 |                 shape_inst_list += instances
432 |                 if s_Flag:
433 |                     shape_ebd = rendering_ebds
434 |                     s_Flag = False
435 |                 else:
436 |                     shape_ebd = jt.concat([shape_ebd, rendering_ebds], dim=0)  # num, 12, dim
437 |                 # shape_ebd_list.append(rendering_ebds)
438 |                 del rendering
439 |                 del rendering_img
440 |                 del rendering_ebds
441 |                 del meta
442 |             jt.jittor_core.cleanup()
443 |             jt.sync_all()
444 |             jt.gc()
445 | 
446 | 
447 |         # shape_ebd = jt.concat(shape_ebd_list, dim=0)  # num, 12, dim
448 | 
449 |         # test images
450 |         json_dict = read_json(os.path.join(cfg.data.root_dir, cfg.data.name, cfg.data.test_json))
451 |         query_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
452 |         mask_transformer = transform.Compose([transform.ToTensor(), transform.ImageNormalize((0.5, ), (0.5, ))])
453 |         bs = shape_ebd.shape[0]
454 | 
455 |         # sorted by category
456 |         cats_dict = {}
457 |         for items in json_dict:
458 |             cat = items['category']
459 |             if cat != ccat:
460 |                 continue
461 |             try:
462 |                 cats_dict[cat].append(items)
463 |             except KeyError:
464 |                 cats_dict[cat] = []
465 |                 cats_dict[cat].append(items)
466 | 
467 |         record_dict = {}
468 | 
469 |         for cat in cats_dict.keys():
470 |             cats_list = cats_dict[cat]
471 |             record_dict[cat] = []
472 | 
473 |             with jt.no_grad():
474 |                 pbar = tqdm.tqdm(cats_list)
475 |                 for i, info in enumerate(pbar):
476 |                     # info = json_dict[i]
477 |                     query_img = query_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['img'])))
478 |                     mask_img = mask_transformer(Image.open(os.path.join(cfg.data.root_dir, cfg.data.name, info['mask'])))
479 | 
480 |                     query = jt.concat((query_img, mask_img), dim=0)
481 |                     query = query.unsqueeze(dim=0)
482 |                     query_ebd = self.retrieval_net.get_query_ebd(query)
483 | 
484 | 
485 |                     query_ebd = query_ebd.repeat(bs, 1, 1)
486 |                     _, weights = self.retrieval_net.attention_query(query_ebd, shape_ebd)
487 |                     queried_rendering_ebd = jt.bmm(weights, shape_ebd)
488 |                     qr_ebd = queried_rendering_ebd
489 |                     qi_ebd = query_ebd
490 |                     prod_mat = (qi_ebd * qr_ebd).sum(dim=2)
491 |                     max_idx = prod_mat.argmax(dim=0)
492 | 
493 | 
494 |                     pr_cats = shape_cats_list[int(max_idx[0])]
495 |                     pr_inst = shape_inst_list[int(max_idx[0])]
496 | 
497 |                     gt_cats = info['category']
498 |                     gt_inst = info['model'].split('/')[-2]
499 | 
500 | 
501 |                     record_dict[cat].append((pr_cats, pr_inst, gt_cats, gt_inst))
502 | 
503 |         return record_dict[ccat]
504 | 
505 | 
506 |     def test_shapenet(self):
507 |         record_dict = {}
508 |         if self.cfg.data.name == 'pix3d':
509 |             cat_list = ['bed', 'chair', 'sofa', 'table']
510 |         else:
511 |             cat_list = ['car', ]
512 |         for cat in cat_list:
513 |             record_dict[cat] = self._test_shapenet(cat)
514 | 
515 |         return record_dict
516 | 
517 | 
518 | if __name__ == '__main__':
519 |     parser = argparse.ArgumentParser()
520 |     parser.add_argument(
521 |         "--config", type=str, default="./configs/stanfordcars.yaml", help="Path to (.yaml) config file."
522 |     )
523 |     parser.add_argument(
524 |         "--mode", type=str, default="shapenet", help="testing mode: simple | full | shapenet."
525 |     )
526 | 
527 |     configargs = parser.parse_args()
528 | 
529 |     with open(configargs.config, 'r', encoding="utf-8") as f:
530 |         config = yaml.safe_load(f)
531 |     def dict2namespace(config):
532 |         namespace = argparse.Namespace()
533 |         for key, value in config.items():
534 |             if isinstance(value, dict):
535 |                 new_value = dict2namespace(value)
536 |             else:
537 |                 new_value = value
538 |             setattr(namespace, key, new_value)
539 |         return namespace
540 |     config = dict2namespace(config)
541 |     setattr(config, 'mode', configargs.mode)
542 |     Task = Retrieval(config)
543 | 
544 |     if config.mode == 'simple':
545 |         Task.test_simple()
546 |     elif config.mode == 'full':
547 |         records = Task.test_full()
548 |         cal_IoU_and_Haus(config, records)
549 |     elif config.mode == 'shapenet':
550 |         records = Task.test_shapenet()
551 |         cal_IoU_and_Haus(config, records)
552 |     else:
553 |         pass
554 | 
555 | 
556 | 
557 | 
558 | 
559 | 
--------------------------------------------------------------------------------
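
For readers tracing the retrieval step in `test_simple`/`test_full`: `attention_query` is defined in Models.py and is not shown in this dump, so the sketch below simply assumes a standard dot-product softmax attention over the 12 per-shape view embeddings. It reproduces only the surrounding bookkeeping (repeat, bmm, dot product, argmax), not the exact network:

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    num, views, dim = 50, 12, 128                   # hypothetical database size
    shape_ebd = np.random.randn(num, views, dim)    # stands in for get_rendering_ebd output
    query_ebd = np.random.randn(1, 1, dim).repeat(num, axis=0)  # (num, 1, dim), as after .repeat(bs, 1, 1)

    # attention weights over views, then the attended rendering embedding (bmm)
    weights = softmax(np.einsum('nqd,nvd->nqv', query_ebd, shape_ebd))  # (num, 1, views)
    qr_ebd = np.einsum('nqv,nvd->nqd', weights, shape_ebd)              # (num, 1, dim)

    prod_mat = (query_ebd * qr_ebd).sum(-1)         # (num, 1): similarity per database shape
    best = int(prod_mat[:, 0].argmax())             # index into shape_cats_list / shape_inst_list
    print(best)
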
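The `modified_averaged_hausdorff_distance` used by `cal_IoU_and_Haus` sums the nearest-neighbour distances in both directions and divides by the total point count N + M. An equivalent plain-NumPy sketch (like the Jittor version, it materialises the full N x M distance matrix, so it is only suitable for modest point counts):

    import numpy as np

    def averaged_hausdorff(x, y):
        # x: (N, d), y: (M, d) point clouds
        d = np.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))  # (N, M) pairwise distances
        return (d.min(axis=0).sum() + d.min(axis=1).sum()) / (d.shape[0] + d.shape[1])

    x = np.random.rand(1000, 3)   # hypothetical 3D point clouds
    y = np.random.rand(1200, 3)
    print(averaged_hausdorff(x, y))
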
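Finally, both entry points turn the YAML config into a nested `argparse.Namespace` via `dict2namespace`, which is what makes attribute access such as `cfg.data.name` or `cfg.models.z_dim` work throughout the code. A small usage sketch, assuming it is run from the code directory next to the configs folder:

    import argparse
    import yaml

    def dict2namespace(d):
        # recursively convert nested dicts into attribute-style namespaces
        ns = argparse.Namespace()
        for key, value in d.items():
            setattr(ns, key, dict2namespace(value) if isinstance(value, dict) else value)
        return ns

    with open('./configs/stanfordcars.yaml', 'r', encoding='utf-8') as f:
        cfg = dict2namespace(yaml.safe_load(f))

    print(cfg.data.name, cfg.data.batch_size, cfg.models.z_dim)

Evaluation would then be invoked as, e.g., `python RetrievalNet_test.py --config ./configs/stanfordcars.yaml --mode simple`, with `--mode full` and `--mode shapenet` additionally reporting IoU and Hausdorff statistics via `cal_IoU_and_Haus`.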