├── .gitignore ├── README.md ├── data ├── bed.txt ├── bed_00000001.png ├── bed_00000002.jpg ├── bed_00000003.jpg ├── chair.txt ├── chair_00000001.jpg ├── chair_00000002.jpg ├── chair_00000003.jpg ├── chair_00000004.jpg ├── sofa.txt ├── sofa_00000001.jpg ├── sofa_00000002.jpg ├── sofa_00000003.jpg ├── swivelchair.txt ├── swivelchair_00000001.png ├── swivelchair_00000002.png ├── swivelchair_00000003.jpg └── swivelchair_00000004.jpg ├── download_models.sh └── src ├── 3D ├── genSynData │ ├── getStickFigure.m │ └── sctheta2theta.m └── tools │ ├── alpha2X3D.m │ ├── alpha2x_proj.m │ ├── propval.m │ ├── renderImage.m │ ├── show3DLD.m │ └── visualize3Dpara.m ├── evaluate.m ├── main.lua ├── nn └── Scale.lua └── pyramidLCN.lua /.gitignore: -------------------------------------------------------------------------------- 1 | archive/ 2 | gen/ 3 | models/ 4 | results/ 5 | www/ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Single Image 3D Interpreter Network 2 | 3 | This repository contains pre-trained models and evaluation code for the project 'Single Image 3D Interpreter Network' (ECCV 2016). 4 | 5 | http://3dinterpreter.csail.mit.edu 6 | 7 | ## Prerequisites 8 | #### Torch 9 | We use Torch 7 (http://torch.ch) for our implementation. 10 | 11 | #### fb.mattorch and Matlab (optional) 12 | We use `.mat` file with [`fb.mattorch`](https://github.com/facebook/fblualib/tree/master/fblualib/mattorch) for saving results, and `Matlab` (R2015a or later, with Computer Vision System Toolbox) for visualization. 13 | 14 | ## Installation 15 | Our current release has been tested on Ubuntu 14.04. 
16 | 17 | #### Clone the repository 18 | ```sh 19 | git clone git@github.com:jiajunwu/3dinn.git 20 | ``` 21 | #### Download pretrained models (1.8GB) 22 | ```sh 23 | cd 3dinn 24 | ./download_models.sh 25 | ``` 26 | 27 | ## Steps for evaluation 28 | 29 | #### I) List input images in `data/[classname].txt` 30 | 31 | #### II) Estimate 3D object structure 32 | 33 | The script `src/main.lua` accepts the following options. 34 | - `-gpuID`: specifies the GPU to run on (1-indexed) 35 | - `-class`: which model to use for evaluation. Our current release contains four models: `chair`, `swivelchair`, `bed`, and `sofa`. 36 | - `-batchSize`: the batch size to use 37 | 38 | Sample usage: 39 | - Estimate chair structure for images listed in `data/chair.txt` 40 | ```sh 41 | cd src 42 | th main.lua -gpuID 1 -class chair 43 | ``` 44 | 45 | #### III) Check visualization in `www`, and estimated parameters in `results` 46 | 47 | ## Sample input & output 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 |
83 | 84 | 85 | ## Datasets we used 86 | 87 | - Keypoint-5 dataset [(zip, 208MB)](http://3dinterpreter.csail.mit.edu/data/keypoint-5.zip) 88 | 89 | - Extended IKEA dataset with additional 3D keypoint labels [(zip, 171MB)](http://3dinterpreter.csail.mit.edu/data/ikea_3DINN.zip) 90 | 91 | - [SUN Database](http://groups.csail.mit.edu/vision/SUN/) 92 | 93 | ## Reference 94 | 95 | @inproceedings{3dinterpreter, 96 | title={{Single Image 3D Interpreter Network}}, 97 | author={Wu, Jiajun and Xue, Tianfan and Lim, Joseph J and Tian, Yuandong and Tenenbaum, Joshua B and Torralba, Antonio and Freeman, William T}, 98 | booktitle={European Conference on Computer Vision}, 99 | pages={365--382}, 100 | year={2016} 101 | } 102 | 103 | For any questions, please contact Jiajun Wu (jiajunwu@mit.edu) and Tianfan Xue (tfxue@mit.edu). 104 | -------------------------------------------------------------------------------- /data/bed.txt: -------------------------------------------------------------------------------- 1 | bed_00000001.png 2 | bed_00000002.jpg 3 | bed_00000003.jpg 4 | -------------------------------------------------------------------------------- /data/bed_00000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/bed_00000001.png -------------------------------------------------------------------------------- /data/bed_00000002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/bed_00000002.jpg -------------------------------------------------------------------------------- /data/bed_00000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/bed_00000003.jpg 
-------------------------------------------------------------------------------- /data/chair.txt: -------------------------------------------------------------------------------- 1 | chair_00000001.jpg 2 | chair_00000002.jpg 3 | chair_00000003.jpg 4 | chair_00000004.jpg 5 | -------------------------------------------------------------------------------- /data/chair_00000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000001.jpg -------------------------------------------------------------------------------- /data/chair_00000002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000002.jpg -------------------------------------------------------------------------------- /data/chair_00000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000003.jpg -------------------------------------------------------------------------------- /data/chair_00000004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000004.jpg -------------------------------------------------------------------------------- /data/sofa.txt: -------------------------------------------------------------------------------- 1 | sofa_00000001.jpg 2 | sofa_00000002.jpg 3 | sofa_00000003.jpg 4 | -------------------------------------------------------------------------------- /data/sofa_00000001.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/sofa_00000001.jpg -------------------------------------------------------------------------------- /data/sofa_00000002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/sofa_00000002.jpg -------------------------------------------------------------------------------- /data/sofa_00000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/sofa_00000003.jpg -------------------------------------------------------------------------------- /data/swivelchair.txt: -------------------------------------------------------------------------------- 1 | swivelchair_00000001.png 2 | swivelchair_00000002.png 3 | swivelchair_00000003.jpg 4 | swivelchair_00000004.jpg 5 | -------------------------------------------------------------------------------- /data/swivelchair_00000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000001.png -------------------------------------------------------------------------------- /data/swivelchair_00000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000002.png -------------------------------------------------------------------------------- /data/swivelchair_00000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000003.jpg 
-------------------------------------------------------------------------------- /data/swivelchair_00000004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000004.jpg -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | mkdir -p models 2 | 3 | wget http://3dinterpreter.csail.mit.edu/repo/models/chair_model.torch -O models/chair_model.torch 4 | wget http://3dinterpreter.csail.mit.edu/repo/models/swivelchair_model.torch -O models/swivelchair_model.torch 5 | wget http://3dinterpreter.csail.mit.edu/repo/models/bed_model.torch -O models/bed_model.torch 6 | wget http://3dinterpreter.csail.mit.edu/repo/models/sofa_model.torch -O models/sofa_model.torch 7 | -------------------------------------------------------------------------------- /src/3D/genSynData/getStickFigure.m: -------------------------------------------------------------------------------- 1 | function stickStruct = getStickFigure(varargin) 2 | 3 | % function stickStruct = getStickFigure(varargin) 4 | % 5 | % Parameters: 6 | % scale ([] by default; the body checks for 64 and 128), class ('chair' by default; 'swivelchair', 'table', 'bed' and 'sofa' are also defined), indexPermute (true by default; reorders baseShape columns and edgeAdj entries according to indices{1}) 7 | % 8 | % Fields of stickStruct 9 | % - nclass: ignore this field, this is always 1 10 | % - edgeAdj: 1x1 cell holding an M x 2 edge list (each row is a pair of connected node indices), where M is the number of edges 11 | % - baseShape: 1x1 cell holding a 3 x N x B array, where N is number of nodes, and B is number of 12 | % bases. baseShape(:,:,1) is the mean base shape, and baseShape(:, :, i) 13 | % are the additional deformations. See the demo code below for more 14 | % details. 15 | % - np: number of nodes 16 | % - nbasis: number of bases 17 | % - alphaGSif: ignore this field 18 | % - scaleRange: Range of scaling in camera parameters. 19 | % - thetaRange: Range of rotation angles in camera parameters. 20 | % - tranRange: Range of translation in camera parameters.
21 | % - fRange: Range of focal length 22 | % - h: the height of the input image 23 | % - w: the width of the input image 24 | % - indices: node ordering, normally 1:N. For some models (bed, sofa) it is a permutation 25 | % used to reconcile inconsistent node indices; it is applied when indexPermute is true. 26 | % 27 | % Demo code: 28 | % 29 | % cd [gitroot]/src 30 | % addpath(genpath('3D')); 31 | % addpath(genpath('nn')); 32 | % stickStruct = getStickFigure('class', 'chair'); 33 | % subplot(2,3,1); show3DLD(stickStruct.baseShape{1}(:,:,1), stickStruct.edgeAdj{1}); 34 | % title('Mean shape'); 35 | % for i = 2:5 36 | % subplot(2,3,i); show3DLD(stickStruct.baseShape{1}(:,:,1) + ... 37 | % stickStruct.baseShape{1}(:,:,i), stickStruct.edgeAdj{1}); 38 | % title(sprintf('The %d-th deformation', i-1)); 39 | % end 40 | 41 | para.scale = []; 42 | para.class = 'chair'; 43 | para.indexPermute = true; 44 | para = propval(varargin, para); 45 | 46 | stickStruct.nclass = 1; 47 | stickStruct.edgeAdj = cell(1,1); 48 | stickStruct.baseShape = cell(1,1); 49 | stickStruct.alphaRange = cell(1,1); 50 | if ~isempty(para.scale) && (para.scale == 64) 51 | stickStruct.scaleRange = {[9,11]/2}; 52 | elseif ~isempty(para.scale) && (para.scale == 128) 53 | stickStruct.scaleRange = {[9,11]/2}; 54 | end 55 | % scaleRange must be a cell: every class branch below assigns scaleRange{1}, and brace assignment into a numeric array raises an error (which made scale 64/128 unusable). NOTE(review): the class branches overwrite this value -- confirm whether the small-scale range should instead be applied after the class setup. 56 | if strcmp(para.class, 'chair') 57 | stickStruct.np = 10; 58 | stickStruct.nbasis = 5; 59 | stickStruct.edgeAdj{1} = [1,5;2,6;3,7;4,8;5,6;6,7;7,8;5,8;8,9;7,10;9,10]; 60 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis); 61 | stickStruct.baseShape{1}(:,:,1) = [-1,-2,1; 1,-2,1; 1,-2,-1; -1,-2,-1; -1,0,1; 1,0,1; 1,0,-1; -1,0,-1; -1,2,-1; 1,2,-1]'; 62 | stickStruct.baseShape{1}(:,:,2) = [0,-1,0;0,-1,0;0,-1,0;0,-1,0; 0,0,0;0,0,0;0,0,0;0,0,0; 0,0,0;0,0,0]'; 63 | stickStruct.baseShape{1}(:,:,3) = [0,0,0;0,0,0;0,0,0;0,0,0; 0,0,0;0,0,0;0,0,0;0,0,0; 0,1,0;0,1,0]'; 64 | stickStruct.baseShape{1}(:,:,4) = [-1,0,0;1,0,0;1,0,0;-1,0,0; -1,0,0;1,0,0;1,0,0;-1,0,0; -1,0,0;1,0,0]'; 65 | stickStruct.baseShape{1}(:,:,5) = [-1,0,1;1,0,1;1,0,-1;-1,0,-1;
0,0,0;0,0,0;0,0,0;0,0,0; 0,0,0;0,0,0]'; 66 | stickStruct.alphaRange{1} = [-1.4,2;-1.4,2;-0.3,1;0,0.3]; 67 | stickStruct.alphaGSif = false; 68 | stickStruct.scaleRange{1} = [30,50]; 69 | stickStruct.thetaRange{1} = [1.0*pi,1.15*pi; 0,2*pi; 0,0]; 70 | stickStruct.tranRange{1} = [-60,60;-80,80]; 71 | stickStruct.fRange{1} = [120,450]; 72 | stickStruct.h = 320; 73 | stickStruct.w = 240; 74 | stickStruct.indices{1} = 1:10; 75 | stickStruct.shapecheckFunc = []; 76 | elseif strcmp(para.class, 'swivelchair') 77 | stickStruct.np = 13; 78 | stickStruct.nbasis = 7; 79 | stickStruct.edgeAdj{1} = [1,6;2,6;3,6;4,6;5,6; 6,7;8,9;9,10;8,11;10,11; 11,12;12,13;10,13]; 80 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis); 81 | tmpScale = 1.2; 82 | stickStruct.baseShape{1}(:,:,1) = [0.3090*tmpScale,-1.5,0.9511*tmpScale; -0.8090*tmpScale,-1.5,0.5878*tmpScale; -0.8090*tmpScale,-1.5,-0.5878*tmpScale; 83 | 0.3090*tmpScale,-1.5,-0.9511*tmpScale; 1.0000*tmpScale,-1.5,0.0000*tmpScale; 84 | 0,-1.5,0; 0,0,0; -1,0,1; 1,0,1; 1,0,-1; -1,0,-1; -1,1.5,-1; 1,1.5,-1]'; 85 | stickStruct.baseShape{1}(:,:,2) = [0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,0,0; 86 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]'; 87 | stickStruct.baseShape{1}(:,:,3) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 88 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0;]'; 89 | stickStruct.baseShape{1}(:,:,4) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 90 | -1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0;]'; 91 | stickStruct.baseShape{1}(:,:,5) = [0.3090,0,0.9511; -0.8090,0,0.5878; -0.8090,0,-0.5878; 0.3090,0,-0.9511; 1.0000,0,0.0000; 92 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]'; % wheels 93 | stickStruct.baseShape{1}(:,:,6) = [-0.9511,0,0.3090; -0.5878,0,-0.8090; 0.5878,0,-0.8090; 0.9511,0,0.3090; 0.0000,0,1.0000; 94 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]'; % wheels 95 | stickStruct.baseShape{1}(:,:,7) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 96 | 0,1,0; 0,0,0; 
0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]'; % wheels 97 | stickStruct.alphaRange{1} = [-0.5,0.5; -1,2.5; -0.5,1; -0.5,0.8; -0.3,0.6; 0,0.6]; 98 | stickStruct.alphaGSif = false; 99 | stickStruct.scaleRange{1} = [35,85]; 100 | stickStruct.tranRange{1} = [-60,60;-80,80]; 101 | stickStruct.thetaRange{1} = [1.0*pi,1.15*pi; 0,2*pi; 0,0]; 102 | stickStruct.fRange{1} = [120,450]; 103 | stickStruct.h = 320; 104 | stickStruct.w = 240; 105 | stickStruct.indices{1} = 1:13; 106 | stickStruct.shapecheckFunc = []; 107 | elseif strcmp(para.class, 'table') 108 | stickStruct.np = 8; 109 | stickStruct.nbasis = 6; 110 | stickStruct.edgeAdj{1} = [1,2;2,4;1,3;3,4;1,5;2,6;3,7;4,8]; 111 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis); 112 | stickStruct.baseShape{1}(:,:,1) = [-1,1,-1; 1,1,-1; -1,1,1; 1,1,1; -1,-1,-1; 1,-1,-1; -1,-1,1; 1,-1,1;]'; 113 | stickStruct.baseShape{1}(:,:,2) = [0,1,0;0,1,0;0,1,0;0,1,0; 0,-1,0;0,-1,0;0,-1,0;0,-1,0;]'; 114 | stickStruct.baseShape{1}(:,:,3) = [-1,0,0;1,0,0;-1,0,0;1,0,0; -1,0,0;1,0,0;-1,0,0;1,0,0;]'; 115 | stickStruct.baseShape{1}(:,:,4) = [0,0,-1;0,0,-1;0,0,1;0,0,1; 0,0,-1;0,0,-1;0,0,1;0,0,1;]'; 116 | stickStruct.baseShape{1}(:,:,5) = [0,0,0;0,0,0;0,0,0;0,0,0; 1,0,0;-1,0,0;1,0,0;-1,0,0;]'; 117 | stickStruct.baseShape{1}(:,:,6) = [0,0,0;0,0,0;0,0,0;0,0,0; 0,0,-1;0,0,-1;0,0,1;0,0,1;]'; 118 | stickStruct.alphaRange{1} = [-0.3,0.3;-0.6,2.5;-0.3,1; 0,1;0,0.2]; 119 | stickStruct.alphaGSif = false; 120 | stickStruct.scaleRange{1} = [30,42]; 121 | stickStruct.tranRange{1} = [-80,80; -60,60]; 122 | % stickStruct.tranRange{1} = [-10,10; -120,120]; 123 | stickStruct.thetaRange{1} = [1.0*pi,1.3*pi; -0.23*pi,0.23*pi; 0,0]; 124 | stickStruct.fRange{1} = [120,450]; 125 | stickStruct.h = 320; 126 | stickStruct.w = 240; 127 | stickStruct.indices{1} = 1:8; 128 | stickStruct.shapecheckFunc = []; 129 | elseif strcmp(para.class, 'bed') 130 | stickStruct.np = 10; 131 | stickStruct.nbasis = 5; 132 | stickStruct.edgeAdj{1} = 
[1,5;2,6;3,7;4,8;3,4;5,6;6,7;7,8;5,8;8,9;7,10;9,10]; 133 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis); 134 | stickStruct.baseShape{1}(:,:,1) = [-1.00,-1.00,2.00; 1.00,-1.00,2.00; 1.00,-1.00,-2.00; -1.00,-1.00,-2.00; -1.00,0.00,2.00; 1.00,0.00,2.00; 135 | 1.00,0.00,-2.00; -1.00, 0.00,-2.00; -1.00,1.00,-2.00; 1.00,1.00,-2.00]'; 136 | stickStruct.baseShape{1}(:,:,2) = [0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,-1]'; % length 137 | stickStruct.baseShape{1}(:,:,3) = [-1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0]'; % width 138 | stickStruct.baseShape{1}(:,:,4) = [0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0]'; % height of leg 139 | stickStruct.baseShape{1}(:,:,5) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0]'; % height of back 140 | stickStruct.alphaRange{1} = [-1.2,0.5; -0.5,1.5; -0.5,0.5; -0.5,1]; 141 | stickStruct.alphaGSif = false; 142 | stickStruct.scaleRange{1} = [30,42]; 143 | stickStruct.tranRange{1} = [-80,80; -60,60]; 144 | % stickStruct.tranRange{1} = [-10,10; -120,120]; 145 | stickStruct.thetaRange{1} = [1.0*pi,1.25*pi; -0.75*pi,0.75*pi; 0,0]; 146 | stickStruct.fRange{1} = [200,450]; 147 | stickStruct.h = 320; 148 | stickStruct.w = 240; 149 | % stickStruct.indices{1} = [2,7,9,4,1,6,8,3,5,10]; 150 | stickStruct.indices{1} = [2,7,6,1,4,9,8,3,5,10]; 151 | stickStruct.shapecheckFunc = []; 152 | elseif strcmp(para.class, 'sofa') 153 | stickStruct.np = 14; 154 | stickStruct.nbasis = 7; 155 | stickStruct.edgeAdj{1} = [1,5;2,6;3,7;4,8;5,6;6,7;7,8;5,8;8,9;7,10;9,10;11,12;12,5;13,14;14,6]; 156 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis); 157 | stickStruct.baseShape{1}(:,:,1) = [-2.00,-2.00,1.00; 2.00,-2.00,1.00; 2.00,-2.00,-1.00; -2.00,-2.00,-1.00; -2.00,0.00,1.00; 158 | 2.00,0.00,1.00; 2.00,0.00,-1.00; -2.00,0.00,-1.00; -2.00,2.00,-1.00; 2.00,2.00,-1.00; -2.00,1.00,-1.00; -2.00,1.00,1.00; 
2.00,1.00,-1.00; 2.00,1.00,1.00]'; 159 | stickStruct.baseShape{1}(:,:,2) = [0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,1; 0,0,-1; 0,0,1]'; % length 160 | stickStruct.baseShape{1}(:,:,3) = [-1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; 1,0,0]'; % width 161 | stickStruct.baseShape{1}(:,:,4) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0; 0,0.5,0; 0,0.5,0; 0,0.5,0; 0,0.5,0]'; % height of back 162 | stickStruct.baseShape{1}(:,:,5) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0; 0,1,0; 0,1,0]'; % height of arm rest 163 | stickStruct.baseShape{1}(:,:,6) = [0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0]'; % height of leg 164 | stickStruct.baseShape{1}(:,:,7) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,-1; 0,0,0; 0,0,-1]'; % loc of arm rest 165 | stickStruct.alphaRange{1} = [-0.5,0.5; 0,2; -0.5,1; -1,2; -1,1; 0,2]; 166 | stickStruct.alphaGSif = false; 167 | stickStruct.scaleRange{1} = [30,42]; 168 | stickStruct.tranRange{1} = [-80,80; -60,60]; 169 | stickStruct.thetaRange{1} = [1.0*pi,1.15*pi; -0.5*pi,0.5*pi; 0,0]; 170 | stickStruct.fRange{1} = [200,450]; 171 | stickStruct.h = 320; 172 | stickStruct.w = 240; 173 | stickStruct.indices{1} = [2,9,8,1,4,11,10,3,7,14,5,6,12,13]; 174 | stickStruct.shapecheckFunc = []; 175 | end 176 | 177 | if para.indexPermute 178 | forwardidx = stickStruct.indices{1}; 179 | revertidx = zeros(size(forwardidx)); 180 | revertidx(stickStruct.indices{1}) = 1:length(stickStruct.indices{1}); 181 | stickStruct.baseShape{1} = stickStruct.baseShape{1}(:,revertidx,:); 182 | stickStruct.edgeAdj{1} = forwardidx(stickStruct.edgeAdj{1}); 183 | end 184 | 185 | 186 | 187 | if ~isempty(para.scale) && (para.scale == 64) 188 | stickStruct.h = 64; 189 | stickStruct.w = 64; 190 | elseif ~isempty(para.scale) && 
(para.scale == 128) 191 | stickStruct.h = 128; 192 | stickStruct.w = 128; 193 | end 194 | 195 | end 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /src/3D/genSynData/sctheta2theta.m: -------------------------------------------------------------------------------- 1 | function theta = sctheta2theta(sincosTheta) 2 | 3 | % function theta = sctheta2theta(sincosTheta) 4 | % sincosTheta: 6 x K, rows 1-3 hold sines and rows 4-6 cosines of three angles (each sin/cos pair is renormalized first); theta: 3 x K angles in [-pi/2, 3*pi/2). 5 | sincosTheta = sincosTheta ./ repmat(sqrt(sincosTheta(1:3,:).^2 + sincosTheta(4:6,:).^2),[2,1]); 6 | theta = asin(sincosTheta(1:3,:)); 7 | mask = sincosTheta(4:6,:) < 0; 8 | theta(mask) = pi - theta(mask); 9 | 10 | end 11 | 12 | -------------------------------------------------------------------------------- /src/3D/tools/alpha2X3D.m: -------------------------------------------------------------------------------- 1 | function X = alpha2X3D(alpha, baseShape) 2 | % Linear combination of the bases: X = sum over k of alpha(k) * baseShape(:,:,k) (alpha is expected as a column vector, one weight per basis). 3 | X = sum(bsxfun(@times,baseShape,shiftdim(alpha,-2)),3); 4 | 5 | end -------------------------------------------------------------------------------- /src/3D/tools/alpha2x_proj.m: -------------------------------------------------------------------------------- 1 | function x = alpha2x_proj(tran,alpha,theta,f, baseShape, varargin) 2 | 3 | % function x = alpha2x_proj(tran,alpha,theta,f, baseShape, varargin) 4 | % 5 | % Parameters: h, w (64 each by default); projected points are offset by [w/2; h/2]. Column i of tran, alpha and theta, together with f(i), describe shape i. 6 | 7 | para.h = 64; 8 | para.w = 64; 9 | para = propval(varargin, para); 10 | h = para.h; h2 = h/2; 11 | w = para.w; w2 = w/2; 12 | 13 | np = size(baseShape,2); 14 | nShape = size(alpha,2); 15 | x = zeros([2,np,nShape]); 16 | for i=1:nShape 17 | Rx = [1,0,0; 0,cos(theta(1,i)),-sin(theta(1,i));0,sin(theta(1,i)),cos(theta(1,i))]; 18 | Ry = [cos(theta(2,i)),0,sin(theta(2,i)); 0,1,0;-sin(theta(2,i)),0,cos(theta(2,i))]; 19 | Rz = [cos(theta(3,i)),-sin(theta(3,i)),0; sin(theta(3,i)),cos(theta(3,i)),0;0,0,1]; % every entry must use shape i's z-angle 20 | R = Rx*Ry*Rz; 21 | 22 | coor3 =
sum(bsxfun(@times,baseShape,shiftdim(alpha(:,i),-2)),3); 23 | coor3RT = bsxfun(@plus, R * coor3, [tran(:,i);0]); 24 | x(:,:,i) = bsxfun(@plus, bsxfun(@times, coor3RT(1:2,:), 1./(1+coor3RT(3,:)/f(i))), [w2;h2]); 25 | end 26 | 27 | end 28 | 29 | -------------------------------------------------------------------------------- /src/3D/tools/propval.m: -------------------------------------------------------------------------------- 1 | function [merged unused] = propval(propvals, defaults, varargin) 2 | 3 | % Create a structure combining property-value pairs with default values. 4 | % 5 | % [MERGED UNUSED] = PROPVAL(PROPVALS, DEFAULTS, ...) 6 | % 7 | % Given a cell array or structure of property-value pairs 8 | % (i.e. from VARARGIN or a structure of parameters), PROPVAL will 9 | % merge the user specified values with those specified in the 10 | % DEFAULTS structure and return the result in the structure 11 | % MERGED. Any user specified values that were not listed in 12 | % DEFAULTS are output as property-value arguments in the cell array 13 | % UNUSED. STRICT is disabled in this mode. 14 | % 15 | % ALTERNATIVE USAGE: 16 | % 17 | % [ ARGS ] = PROPVAL(PROPVALS, DEFAULTS, ...) 18 | % 19 | % In this case, propval will assume that no user specified 20 | % properties are meant to be "picked up" and STRICT mode will be enforced. 21 | % 22 | % ARGUMENTS: 23 | % 24 | % PROPVALS - Either a cell array of property-value pairs 25 | % (i.e. {'Property', Value, ...}) or a structure of equivalent form 26 | % (i.e. struct.Property = Value), to be merged with the values in 27 | % DEFAULTS. 28 | % 29 | % DEFAULTS - A structure where field names correspond to the 30 | % default value for any properties in PROPVALS. 31 | % 32 | % OPTIONAL ARGUMENTS: 33 | % 34 | % STRICT (default = true) - Use strict guidelines when processing 35 | % the property value pairs. 
This will warn the user if an empty 36 | % DEFAULTS structure is passed in or if there are properties in 37 | % PROPVALS for which no default was provided. 38 | % 39 | % EXAMPLES: 40 | % 41 | % Simple function with two optional numerical parameters: 42 | % 43 | % function [result] = myfunc(data, varargin) 44 | % 45 | % defaults.X = 5; 46 | % defaults.Y = 10; 47 | % 48 | % args = propvals(varargin, defaults) 49 | % 50 | % data = data * Y / X; 51 | % 52 | % >> myfunc(data) 53 | % This will run myfunc with X=5, Y=10 on the variable 'data'. 54 | % 55 | % >> myfunc(data, 'X', 0) 56 | % This will run myfunc with X=0, Y=10 (thus giving a 57 | % divide-by-zero error) 58 | % 59 | % >> myfunc(data, 'foo', 'bar') will run myfunc with X=5, Y=10, and 60 | % PROPVAL will give a warning that 'foo' has no default value, 61 | % since STRICT is true by default. 62 | % 63 | 64 | % License: 65 | %===================================================================== 66 | % 67 | % This is part of the Princeton MVPA toolbox, released under 68 | % the GPL. See http://www.csbmb.princeton.edu/mvpa for more 69 | % information. 70 | % 71 | % The Princeton MVPA toolbox is available free and 72 | % unsupported to those who might find it useful. We do not 73 | % take any responsibility whatsoever for any problems that 74 | % you have related to the use of the MVPA toolbox. 
75 | % 76 | % ====================================================================== 77 | 78 | % Backwards compatibility 79 | pvdef.ignore_missing_default = false; 80 | pvdef.ignore_empty_defaults = false; 81 | 82 | % check for the number of outputs 83 | if nargout == 2 84 | pvdef.strict = false; 85 | else 86 | pvdef.strict = true; 87 | end 88 | 89 | pvargs = pvdef; 90 | 91 | % Recursively process the propval optional arguments (possible 92 | % because we only recurse if optional parameters are given) 93 | if ~isempty(varargin) 94 | pvargs = propval(varargin, pvdef); 95 | end 96 | 97 | % NOTE: Backwards compatibility with previous version of propval 98 | if pvargs.ignore_missing_default | pvargs.ignore_empty_defaults 99 | pvargs.strict = false; 100 | end 101 | 102 | % check for a single cell argument; assume propvals is that argument 103 | if iscell(propvals) && numel(propvals) == 1 104 | propvals = propvals{1}; 105 | end 106 | 107 | % check for valid inputs 108 | if ~iscell(propvals) & ~isstruct(propvals) 109 | error('Property-value pairs must be a cell array or a structure.'); 110 | end 111 | 112 | if ~isstruct(defaults) & ~isempty(defaults) 113 | error('Defaults struct must be a structure.'); 114 | end 115 | 116 | % check for empty defaults structure 117 | if isempty(defaults) 118 | if pvargs.strict & ~pvargs.ignore_missing_default 119 | error('Empty defaults structure passed to propval.'); 120 | end 121 | defaults = struct(); 122 | end 123 | 124 | defaultnames = fieldnames(defaults); 125 | defaultvalues = struct2cell(defaults); 126 | 127 | % prepare the defaults structure, but also prepare casechecking 128 | % structure with all case stripped 129 | defaults = struct(); 130 | casecheck = struct(); 131 | 132 | for i = 1:numel(defaultnames) 133 | defaults.(defaultnames{i}) = defaultvalues{i}; 134 | casecheck.(lower(defaultnames{i})) = defaultvalues{i}; 135 | end 136 | 137 | % merged starts with the default values 138 | merged = defaults; 139 | unused = {}; 140 | 
140 | used = struct(); 141 | 142 | properties = []; 143 | values = []; 144 | 145 | % To extract property value pairs, we use different methods 146 | % depending on how they were passed in 147 | if isstruct(propvals) 148 | properties = fieldnames(propvals); 149 | values = struct2cell(propvals); 150 | else 151 | properties = { propvals{1:2:end} }; 152 | values = { propvals{2:2:end} }; 153 | end 154 | 155 | if numel(properties) ~= numel(values) 156 | error(sprintf('Found %g properties but only %g values.', numel(properties), ... 157 | numel(values))); 158 | end 159 | 160 | % merge new properties with defaults 161 | for i = 1:numel(properties) 162 | 163 | if ~ischar(properties{i}) 164 | error(sprintf('Property %g is not a string.', i)); 165 | end 166 | 167 | % property names are matched case-sensitively (no lowercasing is done); a name that matches a default only up to case is rejected below 168 | properties{i} = properties{i}; % NOTE(review): self-assignment is a no-op, presumably left over from a removed lowercasing step 169 | 170 | % check for multiple usage 171 | if isfield(used, properties{i}) 172 | error(sprintf('Property %s is defined more than once.\n', ... 173 | properties{i})); 174 | end 175 | 176 | % Check for case errors 177 | if isfield(casecheck, lower(properties{i})) & ... 178 | ~isfield(merged, properties{i}) 179 | error(['Property ''%s'' is equal to a default property except ' ...
180 | 'for case.'], properties{i});
181 | end
182 |
183 | % Merge with defaults
184 | if isfield(merged, properties{i})
185 | merged.(properties{i}) = values{i};
186 | else
187 | % add to unused property value pairs
188 | unused{end+1} = properties{i};
189 | unused{end+1} = values{i};
190 |
191 | % add to defaults, just in case, if the user isn't picking up "unused"
192 | if (nargout == 1 & ~pvargs.strict)
193 | merged.(properties{i}) = values{i};
194 | end
195 |
196 | if pvargs.strict
197 | error('Property ''%s'' has no default value.', properties{i});
198 | end
199 |
200 | end
201 |
202 | % mark as used
203 | used.(properties{i}) = true;
204 | end
205 |
206 |
--------------------------------------------------------------------------------
/src/3D/tools/renderImage.m:
--------------------------------------------------------------------------------
1 | function imlist = renderImage(x, edgeAdj, varargin)
2 |
3 | % function imlist = renderImage(x, edgeAdj, varargin)
4 | %
5 | % Rasterize 2D stick figures into a stack of single-channel images.
6 | %
7 | %   x       - keypoint coordinates; x(1:2,:,i) are the (x, y) positions of
8 | %             the keypoints of shape i, in output-image pixels
9 | %   edgeAdj - nedge x 2 keypoint indices; each row is drawn as one line
10 | %   imlist  - h x w x 1 x nShape, first channel of each rendering
11 | %
12 | % Parameter: addNoise, addNode, outputStep, nRandLine, nRandTri,
13 | %            noiseLevel, h, w, scale, lineWidth, circSize, parallel
14 | %
15 | % Each shape is drawn on a white (scale*h) x (scale*w) canvas, optionally
16 | % cluttered with random lines, gray triangles and Gaussian noise, and then
17 | % resized to h x w. 'parallel' selects a parfor loop over shapes instead of
18 | % a plain for loop; both loops call the same per-shape renderer (renderOne).
19 |
20 | para.addNoise = true;
21 | para.addNode = false;
22 | para.outputStep = 1000;
23 | para.nRandLine = 15;
24 | para.nRandTri = 3;
25 | para.noiseLevel = 0.1;
26 | para.h = 64;
27 | para.w = 64;
28 | para.scale = 2;
29 | para.lineWidth = 3;
30 | para.circSize = 4;
31 | para.parallel = true;
32 | para = propval(varargin, para);
33 |
34 | nShape = size(x,3);
35 | imlist = zeros(para.h,para.w,1,nShape);
36 | if para.parallel
37 |     parfor i = 1:nShape
38 |         imlist(:,:,:,i) = renderOne(x(1:2,:,i), edgeAdj, para);
39 |         if mod(i,para.outputStep) == 0
40 |             fprintf(1,'shape %d\n', i);
41 |         end
42 |     end
43 | else
44 |     for i = 1:nShape
45 |         imlist(:,:,:,i) = renderOne(x(1:2,:,i), edgeAdj, para);
46 |         if mod(i,para.outputStep) == 0
47 |             fprintf(1,'shape %d\n', i);
48 |         end
49 |     end
50 | end
51 |
52 | end
53 |
54 | function im = renderOne(xy, edgeAdj, para)
55 | % Render one shape: xy is 2 x np keypoints in output pixels. Returns the
56 | % first channel of the drawing resized to para.h x para.w.
57 | hb = para.h*para.scale; wb = para.w*para.scale;
58 | edgeColor = [0,0,0];
59 | coor = xy * para.scale;
60 | im = ones(hb,wb,3);
61 | if para.addNode
62 |     im = insertShape(im, 'FilledCircle', [coor',para.circSize*ones(size(coor,2),1)], 'Color',edgeColor);
63 | end
64 | im = insertShape(im, 'Line', [coor(1,edgeAdj(:,1))',coor(2,edgeAdj(:,1))',...
65 |     coor(1,edgeAdj(:,2))',coor(2,edgeAdj(:,2))'], 'Color', edgeColor, 'LineWidth', para.lineWidth);
66 |
67 | if para.addNoise
68 |     % clutter lines: random start anywhere, length in [hb/8, hb/4), random angle
69 |     xyStart = bsxfun(@times,rand(para.nRandLine,2),[wb,hb]);
70 |     xyLen = rand(para.nRandLine,1) * (hb/8) + hb/8; angle = rand(para.nRandLine,1)*(2*pi);
71 |     xyEnd = bsxfun(@plus, xyStart, bsxfun(@times,xyLen,[cos(angle),sin(angle)]));
72 |     im = insertShape(im, 'Line', [xyStart,xyEnd], 'LineWidth', 1, 'Color', edgeColor);
73 |     % clutter triangles: centre in the middle half of the canvas, three
74 |     % vertices at one random radius, filled with a random gray below 0.5
75 |     xyCenter = bsxfun(@plus, bsxfun(@times,rand(para.nRandTri,2),[wb/2,hb/2]), [wb/4,hb/4]);
76 |     xyLen = rand(para.nRandTri,1) * (hb/8) + hb/8; angle = rand([para.nRandTri,1,3])*(2*pi);
77 |     xyEnd = reshape(bsxfun(@plus, xyCenter, bsxfun(@times,xyLen,[cos(angle),sin(angle)])),[para.nRandTri,6]);
78 |     im = insertShape(im, 'FilledPolygon', xyEnd, 'Opacity', 1, 'Color', repmat(rand(para.nRandTri,1)*0.5,[1,3]));
79 |     im = im + para.noiseLevel * randn(size(im));
80 | end
81 |
82 | im = imresize(im(:,:,1),[para.h,para.w]);
83 |
84 | end
85 |
86 |
--------------------------------------------------------------------------------
/src/3D/tools/show3DLD.m:
--------------------------------------------------------------------------------
1 | function show3DLD(X, adj, color, varargin)
2 |
3 | % function show3DLD(X, adj, color, varargin)
4 | %
5 | % This file is created by Tianfan Xue (tfxue@mit.edu).
6 |
7 | para.showNumber = false;
8 | para.showAxis = true;
9 | para.fontsize = 20;
10 | para = propval(varargin, para);
11 |
12 | if (nargin == 2) || isempty(color)
13 |     color = [1,0,0];
14 | end
15 |
16 | X1 = X(1,:); X2 = X(2,:); X3 = X(3,:);
17 | plot3(X1,X2,X3,'.','Color',color,'MarkerSize',10);
18 | hold on;
19 | line(X1(adj)', X2(adj)', X3(adj)', 'Color', color,'LineWidth',3);
20 | if para.showNumber
21 |     for i=1:size(X,2)
22 |         text(X1(i),X2(i),X3(i),sprintf('%d',i),'FontSize',para.fontsize);
23 |     end
24 | end
25 | if para.showAxis
26 |     xlabel('x');
27 |     ylabel('y');
28 |     zlabel('z');
29 | end
30 | axis equal
31 |
32 | end
33 |
34 |
--------------------------------------------------------------------------------
/src/3D/tools/visualize3Dpara.m:
--------------------------------------------------------------------------------
1 | function im = visualize3Dpara(alpha, sincosTheta, tran, f, baseShape, edgeAdj, varargin)
2 |
3 | % function im = visualize3Dpara(alpha, sincosTheta, tran, f, baseShape, edgeAdj, varargin)
4 | %
5 | % Project a parametric 3D skeleton to 2D (alpha2x_proj) and draw it with
6 | % renderImage, sequentially and without any noise.
7 | %
8 | % Parameters:
9 | %   alpha: nbasis x 1
10 | %   sincosTheta: 6 x 1
11 | %   tran: 2x1
12 | %   f: 1
13 | %   baseShape: 3 x np x nbasis
14 | %   edgeAdj: nedge x 2
15 | %
16 | % Additional parameters:
17 | %   h, w (output size, also passed to alpha2x_proj), lineWidth, addNode,
18 | %   circSize (forwarded to renderImage)
19 | %
20 | % Example:
21 | %   stickStruct = getStickFigure('class', 'chair');
22 | %   im = visualize3Dpara(alpha, sincosTheta, tran, f, stickStruct.baseShape{1}, stickStruct.edgeAdj{1});
23 |
24 | para.h = 320;
25 | para.w = 240;
26 | para.lineWidth = 6;
27 | para.addNode = true;
28 | para.circSize = 8;
29 | para = propval(varargin, para);
30 |
31 | theta = sctheta2theta(sincosTheta);
32 | x = alpha2x_proj(tran,alpha,theta,f,baseShape, 'w', para.w, 'h',para.h);
33 | % lineWidth comes from para so the caller's option is honoured.
34 | im = renderImage(x,edgeAdj,'addNoise',false,'noiseLevel',0,'h',para.h,'w',para.w,'lineWidth',para.lineWidth,'parallel',false,...
35 |     'addNode',para.addNode, 'circSize', para.circSize);
36 |
37 | end
38 |
39 |
40 |
--------------------------------------------------------------------------------
/src/evaluate.m:
--------------------------------------------------------------------------------
1 | function evaluate(resPath, wwwPath, class, width, height)
2 | % Render the 3D-INN predictions stored in resPath as images in wwwPath.
3 | %
4 | % resPath: .mat file holding 'outputs', one row of un-normalized network
5 | %          parameters per image (written by main.lua)
6 | % wwwPath: output directory; image i is written as %08d.jpg
7 | % class:   'chair', 'bed', 'swivelchair' or 'sofa'
8 | % width, height: size of the rendered images
9 |
10 | % Cumulative end column of each parameter group in a row:
11 | % [alpha, 1/f, sincosTheta, translation].
12 | if strcmp(class, 'chair') || strcmp(class, 'bed')
13 |     param_length = [5, 6, 12, 14];
14 | elseif strcmp(class, 'swivelchair') || strcmp(class, 'sofa')
15 |     param_length = [7, 8, 14, 16];
16 | else
17 |     error('evaluate: unknown class ''%s''.', class);
18 | end
19 |
20 | load(resPath, 'outputs');
21 | numInst = size(outputs, 1);
22 | alphaPred = outputs(:, 1 : param_length(1));
23 | finvPred = outputs(:, (param_length(1) + 1) : param_length(2));
24 | sincosThetaPred = outputs(:, (param_length(2) + 1) : param_length(3));
25 | % Entries 3 and 6 are pinned to 0 and 1 (presumably sin/cos of an angle
26 | % held at zero -- NOTE(review): confirm against sctheta2theta).
27 | sincosThetaPred(:, 3) = 0;
28 | sincosThetaPred(:, 6) = 1;
29 | tranPred = outputs(:, (param_length(3) + 1) : param_length(4));
30 |
31 | addpath(fullfile('3D', 'genSynData'));
32 | addpath(fullfile('3D', 'tools'));
33 | stickStruct = getStickFigure('class', class);
34 |
35 | for i = 1 : numInst
36 |     im = visualize3Dpara(alphaPred(i, :)', sincosThetaPred(i, :)', ...
37 |         tranPred(i, :)', 1 ./ finvPred(i, :), stickStruct.baseShape{1}, ...
38 |         stickStruct.edgeAdj{1}, 'h', height, 'w', width, 'lineWidth', 6);
39 |     imwrite(im, fullfile(wwwPath, sprintf('%08d.jpg', i)));
40 | end
41 |
42 | end
43 |
44 |
--------------------------------------------------------------------------------
/src/main.lua:
--------------------------------------------------------------------------------
1 | -------------------------------------------------------------------------------
2 | print '==> Initializing...'
3 | package.path = '?.lua;' ..
package.path
4 |
5 | require 'torch'
6 | require 'image'
7 |
8 | require 'nn'
9 | require 'cunn'
10 | require 'cutorch'
11 |
12 | require 'nn.Scale'
13 |
14 | dofile 'pyramidLCN.lua'
15 |
16 | -------------------------------------------------------------------------------
17 |
18 | local cmd = torch.CmdLine()
19 |
20 | cmd:text()
21 | cmd:text('3D-INN evaluation script')
22 | cmd:text()
23 | cmd:text('Options:')
24 | cmd:option('-seed', 0, 'Random seed')
25 | cmd:option('-gpuID', 1, 'ID of GPUs to use')
26 | cmd:option('-batchSize', 16, 'Batch Size')
27 | cmd:option('-class', 'chair', 'Model class for evaluation')
28 | cmd:text()
29 |
30 | local opt = cmd:parse(arg or {})
31 |
32 | cutorch.setDevice(opt.gpuID)
33 | torch.manualSeed(opt.seed)
34 | math.randomseed(opt.seed)
35 | cutorch.manualSeedAll(opt.seed)
36 |
37 | --------------------------------------------------------------------------------
38 | -- constants
39 | opt.chnNum = 3 -- RGB
40 | opt.inputWidth = 240
41 | opt.inputHeight = 320
42 | opt.heatmapWidth = 30
43 | opt.heatmapHeight = 40
44 |
45 | opt.bankSize = {320, 160, 80} -- input sizes (long edge) for each bank
46 | opt.pyramidSize = {} -- pyramid sizes for each scale
47 | for i = 1, #opt.bankSize do
48 |   local scale = opt.bankSize[i] / math.max(opt.inputWidth, opt.inputHeight)
49 |   opt.pyramidSize[i] = {math.floor(scale * opt.inputHeight),
50 |                         math.floor(scale * opt.inputWidth)}
51 | end
52 |
53 | -- Per-class parameter layout. unnormOutput maps network output i back to
54 | -- paramMin[i] + out / paramWeight[i] * (paramMax[i] - paramMin[i]).
55 | if opt.class == 'chair' then
56 |   opt.paramNum = 14
57 |   opt.paramMin = {30, -70, -70, -15, 0, 1/450, -1, -1, -1, -1, -1, -1, -60, -80}
58 |   opt.paramMax = {50, 100, 100, 100, 150, 1/120, 1, 1, 1, 1, 1, 1, 60, 80}
59 |   opt.paramWeight = torch.ones(opt.paramNum)
60 |   opt.paramWeight[6] = 10
61 | elseif opt.class == 'swivelchair' then
62 |   opt.paramNum = 16
63 |   opt.paramMin = {35, -42.5, -85, -42.5, -42.5, -25.5, 0, 1/450, -1, -1, -1, -1, -1, -1, -60, -80}
64 |   opt.paramMax = {85, 42.5, 212.5, 85, 65, 51, 51, 1/120, 1, 1, 1, 1, 1, 1, 60, 80}
65 |   opt.paramWeight = torch.ones(opt.paramNum)
66 |   opt.paramWeight[8] = 10
67 | elseif opt.class == 'bed' then
68 |   opt.paramNum = 14
69 |   opt.paramMin = {30, -50.4, -21, -21, -21, 1/450, -1, -1, -1, -1, -1, -1, -80, -60}
70 |   opt.paramMax = {42, 21, 63, 21, 42, 1/200, 1, 1, 1, 1, 1, 1, 80, 60}
71 |   opt.paramWeight = torch.ones(opt.paramNum)
72 |   opt.paramWeight[6] = 10
73 | elseif opt.class == 'sofa' then
74 |   opt.paramNum = 16
75 |   opt.paramMin = {30, -21, 0, -21, -42, -42, 0, 1/450, -1, -1, -1, -1, -1, -1, -80, -60}
76 |   opt.paramMax = {42, 21, 84, 42, 84, 42, 84, 1/200, 1, 1, 1, 1, 1, 1, 80, 60}
77 |   opt.paramWeight = torch.ones(opt.paramNum)
78 |   opt.paramWeight[8] = 10
79 | else
80 |   -- The parameter layout, model/data paths and the Matlab visualization
81 |   -- below all depend on a known class, so stop here instead of later.
82 |   error('Unknown class "' .. opt.class .. '"; expected chair, swivelchair, bed or sofa')
83 | end
84 |
85 | opt.mean = torch.Tensor({129.67, 114.43, 107.26})
86 | opt.mean:mul(1 / 255)
87 |
88 | --------------------------------------------------------------------------------
89 | -- paths
90 | opt.dataPath = paths.concat('..', 'data')
91 | opt.modelPath = paths.concat('..', 'models', opt.class .. '_model.torch')
92 |
93 | opt.wwwPath = paths.concat('..', 'www', opt.class)
94 | opt.resPath = paths.concat('..', 'results')
95 | opt.genPath = paths.concat('..', 'gen')
96 |
97 | os.execute('mkdir -p ' .. opt.wwwPath)
98 | os.execute('mkdir -p ' .. opt.resPath)
99 | os.execute('mkdir -p ' .. opt.genPath)
100 |
101 | opt.resPath = paths.concat(opt.resPath, opt.class .. '.mat')
102 | opt.scalePath = paths.concat(opt.genPath, opt.class .. '_scale.torch')
103 | opt.inputPath = paths.concat(opt.genPath, opt.class .. '_inputs.torch')
104 |
105 | -------------------------------------------------------------------------------
106 | -- Map a normalized network output vector back to parameter space (in place):
107 | -- params[i] = paramMin[i] + params[i] / paramWeight[i] * (paramMax[i] - paramMin[i]),
108 | -- or paramMax[i] when paramMin[i] == paramMax[i]. Returns params.
109 | function unnormOutput(params, opt)
110 |   for i = 1, opt.paramNum do
111 |     if opt.paramMax[i] ~= opt.paramMin[i] then
112 |       params[i] = opt.paramMin[i] + params[i] /
113 |         opt.paramWeight[i] * (opt.paramMax[i] - opt.paramMin[i])
114 |     else
115 |       params[i] = opt.paramMax[i]
116 |     end
117 |   end
118 |
119 |   return params
120 | end
121 |
122 | -- load data and build lcn
123 | -- Reads the image names listed in data/<class>.txt, saves the box each image
124 | -- occupies in the heatmap grid (N x 2 x 2, rows {x1, y1} and {x2, y2}) to
125 | -- opt.scalePath, and the pyramid + LCN normalized images to opt.inputPath.
126 | function loadData(opt)
127 |   local imageList = paths.concat(opt.dataPath, opt.class .. '.txt')
128 |   local imageNum = 0
129 |   local imageNames = {}
130 |   for line in io.lines(imageList) do
131 |     imageNum = imageNum + 1
132 |     imageNames[imageNum] = line
133 |   end
134 |
135 |   print('generate scales')
136 |   local oriCoords = torch.zeros(imageNum, 2, 2)
137 |   local images = torch.Tensor(imageNum, opt.chnNum, opt.inputHeight, opt.inputWidth)
138 |
139 |   for i = 1, imageNum do
140 |     local oriIm = image.load(paths.concat(opt.dataPath, imageNames[i]))
141 |     if oriIm:size(1) == 1 and opt.chnNum == 3 then
142 |       oriIm = torch.repeatTensor(oriIm, 3, 1, 1)
143 |     end
144 |     for k = 1, opt.chnNum do
145 |       -- Stretch the channel to [0, 1] and subtract the mean. A constant
146 |       -- channel has no range; leave it at 0 instead of computing 0 * (1/0) = NaN.
147 |       local maxValue = oriIm[k]:max()
148 |       local minValue = oriIm[k]:min()
149 |       local imScaled = oriIm[k]:clone():add(-minValue)
150 |       if maxValue > minValue then
151 |         imScaled:mul(1.0 / (maxValue - minValue))
152 |       end
153 |       images[i][k] = image.scale(imScaled, opt.inputWidth, opt.inputHeight)
154 |       images[i][k]:add(-opt.mean[k])
155 |     end
156 |
157 |     -- Fit the original image into the heatmap grid at a uniform scale,
158 |     -- centred along the axis it does not fill; (x1, y1) is its top-left cell.
159 |     local scaleX = opt.heatmapWidth / oriIm:size(3)
160 |     local scaleY = opt.heatmapHeight / oriIm:size(2)
161 |     local scale = math.min(scaleX, scaleY)
162 |     local x1, y1
163 |     if scaleX < scaleY then
164 |       x1 = 1
165 |       y1 = math.floor((opt.heatmapHeight - oriIm:size(2) * scale) / 2 + 1 + 0.5)
166 |     else
167 |       x1 = math.floor((opt.heatmapWidth - oriIm:size(3) * scale) / 2 + 1 + 0.5)
168 |       y1 = 1
169 |     end
170 |     x1 = math.min(math.max(x1, 1), opt.heatmapWidth)
171 |     y1 = math.min(math.max(y1, 1), opt.heatmapHeight)
172 |
173 |     -- Affine map from original pixel (u, v, 1) to heatmap coordinates,
174 |     -- applied to the original corners (1, 1) and (width, height).
175 |     local newTrans = torch.Tensor({{scale, 0, x1 - scale}, {0, scale, y1 - scale}})
176 |     oriCoords[i][1] = newTrans * torch.Tensor({1, 1, 1})
177 |     oriCoords[i][2] = newTrans * torch.Tensor({oriIm:size(3), oriIm:size(2), 1})
178 |     -- Round both corners and clamp them to the grid.
179 |     for c = 1, 2 do
180 |       local cx = math.floor(oriCoords[i][c][1] + 0.5)
181 |       local cy = math.floor(oriCoords[i][c][2] + 0.5)
182 |       oriCoords[i][c][1] = math.min(math.max(cx, 1), opt.heatmapWidth)
183 |       oriCoords[i][c][2] = math.min(math.max(cy, 1), opt.heatmapHeight)
184 |     end
185 |   end
186 |
187 |   torch.save(opt.scalePath, oriCoords)
188 |
189 |   -- build pyramid and do lcn
190 |   print('generate normalized input')
191 |   local input = buildPyramidsAndLCN(images, opt)
192 |   torch.save(opt.inputPath, input)
193 | end
194 |
195 | -------------------------------------------------------------------------------
196 | -- measure function, save all predictions for evaluation
197 | -- Runs the model over the saved inputs in batches of opt.batchSize and saves
198 | -- the un-normalized predictions as 'outputs' (N x paramNum) to opt.resPath.
199 | function test(opt)
200 |   local time = sys.clock()
201 |
202 |   local model = torch.load(opt.modelPath)
203 |   model:evaluate()
204 |
205 |   -- load test input/output
206 |   local inputs = torch.load(opt.inputPath)
207 |   local scales = torch.load(opt.scalePath)
208 |   local outputs = torch.zeros(inputs:size(1), opt.paramNum)
209 |
210 |   -- test over given dataset
211 |   for t = 1, inputs:size(1), opt.batchSize do
212 |     -- disp progress
213 |     xlua.progress(t, inputs:size(1))
214 |
215 |     -- create mini batch
216 |     local batchStart = t
217 |     local batchEnd = math.min(t + opt.batchSize - 1, inputs:size(1))
218 |     local batchIndices = {batchStart, batchEnd}
219 |     local batchInputs = inputs[{batchIndices}]:cuda()
220 |     local batchScales = scales[{batchIndices}]:cuda()
221 |
222 |     local batchOutputs = model:forward({batchInputs, batchScales}):clone():double()
223 |     for i = batchStart, batchEnd do
224 |       outputs[i] = unnormOutput(batchOutputs[i - batchStart + 1], opt)
225 |     end
226 |   end
227 |
228 |   -- timing
229 |   time = sys.clock() - time
230 |   time = time / inputs:size(1)
231 |   print(" time to test 1 sample = " .. (time * 1000) .. 'ms')
232 |
233 |   -- save .mat file
234 |   require('fb.mattorch').save(opt.resPath, {outputs = outputs})
235 | end
236 |
237 | -------------------------------------------------------------------------------
238 | print '==> prepare data'
239 | loadData(opt)
240 |
241 | -------------------------------------------------------------------------------
242 | print '==> evaluate'
243 | test(opt)
244 |
245 | -- Render the saved predictions into opt.wwwPath with evaluate.m.
246 | local cmd = 'matlab -nodisplay -nodesktop -nojvm -r ' ..
247 |   '"evaluate(\'' .. opt.resPath .. '\', \'' .. opt.wwwPath ..
248 |   '\', \'' .. opt.class .. '\', ' .. opt.inputWidth ..
249 |   ', ' .. opt.inputHeight .. '); exit;"'
250 | print('Testing: ' .. cmd)
251 | os.execute(cmd)
252 |
253 |
--------------------------------------------------------------------------------
/src/nn/Scale.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Input: A length-2 Table
3 | Input[1]: Responses, batchSize x nChannels x iH x iW
4 | Input[2]: Pairs (minW, minH, maxW, maxH), batchSize x 2 x 2
5 |
6 | Output: Scaled responses with zero paddings
7 | batchSize x nChannels x iH x iW
8 | --]]
9 | require 'image'
10 |
11 | local Scale, parent = torch.class('nn.Scale', 'nn.Module')
12 |
13 | function Scale:__init(sum)
14 |   parent.__init(self)
15 |
16 |   self.sum = sum
17 |   self.gradInput = {}
18 | end
19 |
20 |
21 | function Scale:updateOutput(input)
22 |   assert(#input == 2)
23 |   assert(input[1]:size(1) == input[2]:size(1))
24 |
25 |   local batchSize = input[1]:size(1)
26 |   local nChannels = input[1]:size(2)
27 |   local iH = input[1]:size(3)
28 |   local iW = input[1]:size(4)
29 |   self.output:resize(batchSize, nChannels, iH, iW):zero()
30 |
31 |   self.buffer = self.buffer or input[1].new()
32 |   self.buffer:resize(batchSize, nChannels)
33 |
34 |   for i = 1, batchSize do
35 |     local minW = input[2][i][1][1]
36 |     local minH = input[2][i][1][2]
37 |     local maxW = input[2][i][2][1]
38 |     local maxH = input[2][i][2][2]
39 |     local ratio = (maxW - minW + 1) * (maxH - minH + 1) / (iW * iH)
40 |
41 |     self.output[{i, {}, {minH, maxH}, {minW, maxW}}] =
42 |       image.scale(input[1][i]:double(), maxW - minW + 1, maxH - minH + 1):cuda():mul(1/ratio)
43 |
44 |     for j = 1, nChannels do
45 |       if self.output[i][j]:sum() == 0 then
46 |         self.buffer[i][j] = 0
47 |       else
48 |         self.buffer[i][j] = math.min(self.sum / self.output[i][j]:sum(), 100)
49 |         self.output[i][j]:mul(self.buffer[i][j])
50 |       end
51 |       if self.output[i][j]:ne(self.output[i][j]):sum() > 0 then
52 |         print(self.buffer[i][j])
53 |         print(i..' '..j)
54 |         print('!!!!!!!!!!!!!!!!!!!!!!!
Found NaN in output !!!!!!!!!!!!!!!!!!!!!!!')
55 |       end
56 |     end
57 |   end
58 |
59 |   return self.output
60 | end
61 |
62 | function Scale:updateGradInput(input, gradOutput)
63 |   for i = 1, #input do
64 |     if self.gradInput[i] == nil then
65 |       self.gradInput[i] = input[i].new()
66 |     end
67 |     self.gradInput[i]:resizeAs(input[i]):zero()
68 |   end
69 |
70 |   local batchSize = input[1]:size(1)
71 |   local nChannels = input[1]:size(2)
72 |   local iH = input[1]:size(3)
73 |   local iW = input[1]:size(4)
74 |
75 |   for i = 1, batchSize do
76 |     local minW = input[2][i][1][1]
77 |     local minH = input[2][i][1][2]
78 |     local maxW = input[2][i][2][1]
79 |     local maxH = input[2][i][2][2]
80 |     local ratio = (maxW - minW + 1) * (maxH - minH + 1) / (iW * iH)
81 |
82 |     self.gradInput[1][i] =
83 |       image.scale(gradOutput[{i, {}, {minH, maxH}, {minW, maxW}}]:double(), iW, iH):cuda()
84 |     for j = 1, nChannels do
85 |       self.gradInput[1][i][j]:mul(self.buffer[i][j])
86 |     end
87 |   end
88 |
89 |   return self.gradInput
90 | end
91 |
92 | function Scale:__tostring__()
93 |   return string.format('%s()', torch.type(self))
94 | end
95 |
--------------------------------------------------------------------------------
/src/pyramidLCN.lua:
--------------------------------------------------------------------------------
1 | -- Build a multi-scale pyramid of a batch of single-channel images.
2 | --
3 | -- Usage: gaussianPyramid(dst, src, scales) fills the table dst in place;
4 | -- gaussianPyramid(src, scales) starts from an empty table. Either way the
5 | -- table is returned. Level i is resized to scales[i] times src's spatial
6 | -- size. Level 1 is a resize (or copy) of src; every later level is a resize
7 | -- of the previous level after that level was smoothed with a normalized
8 | -- 3-tap Gaussian (renormalized at the borders). The stored levels
9 | -- themselves are not smoothed.
10 | -- NOTE(review): the smoothing step reshapes each level to N x 1 x H x W, so
11 | -- only 3D src (N x H x W) appears to work end to end.
12 | function gaussianPyramid(...)
13 |   local dst, src, scales
14 |   local args = {...}
15 |   if select('#',...) == 3 then
16 |     dst = args[1]
17 |     src = args[2]
18 |     scales = args[3]
19 |   elseif select('#',...) == 2 then
20 |     dst = {}
21 |     src = args[1]
22 |     scales = args[2]
23 |   end
24 |   if src:nDimension() == 2 then
25 |     for i = 1, #scales do
26 |       dst[i] = dst[i] or torch.Tensor()
27 |       dst[i]:resize(src:size(1) * scales[i], src:size(2) * scales[i])
28 |     end
29 |   elseif src:nDimension() == 3 then
30 |     for i = 1, #scales do
31 |       dst[i] = dst[i] or torch.Tensor()
32 |       dst[i]:resize(src:size(1), src:size(2) * scales[i], src:size(3) * scales[i])
33 |     end
34 |   end
35 |   local kernel = image.gaussian1D{width = 3, normalize = true}
36 |
37 |   local padH = math.floor(kernel:size(1) / 2)
38 |   local padW = padH
39 |
40 |   -- Separable filter: zero padding, a horizontal (kW = k) pass and a
41 |   -- vertical (kH = k) pass, both with the Gaussian weights and no bias.
42 |   local gaussianFilter = nn.Sequential()
43 |   gaussianFilter:add(nn.SpatialZeroPadding(padW, padW, padH, padH))
44 |   gaussianFilter:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(1), kernel:size(1), 1))
45 |   gaussianFilter:add(nn.SpatialConvolution(1, 1, 1, kernel:size(1), 1))
46 |   gaussianFilter.modules[2].weight[1]:copy(kernel)
47 |   gaussianFilter.modules[3].weight[1][1]:copy(kernel)
48 |   gaussianFilter.modules[2].bias:zero()
49 |   gaussianFilter.modules[3].bias:zero()
50 |
51 |   local tmp = src
52 |   for i = 1, #scales do
53 |     if scales[i] == 1 then
54 |       dst[i][{}] = tmp
55 |     else
56 |       image.scale(dst[i], tmp, 'simple')
57 |     end
58 |
59 |     local reshaped = dst[i]:reshape(dst[i]:size(1), 1, dst[i]:size(2), dst[i]:size(3))
60 |
61 |     -- Filtering an all-ones image gives, per pixel, the kernel mass inside
62 |     -- the image; dividing by it compensates for the zero padding. clone()
63 |     -- because the next updateOutput call reuses the module's output tensor.
64 |     local coef = gaussianFilter:updateOutput(reshaped.new():resizeAs(reshaped):fill(1))
65 |     coef = coef:clone()
66 |
67 |     local filtered = gaussianFilter:updateOutput(reshaped)
68 |     local filteredNormalized = nn.CDivTable():updateOutput{filtered, coef}
69 |
70 |     -- The smoothed level is only used as the source of the next resize.
71 |     tmp = filteredNormalized:reshape(dst[i]:size(1), dst[i]:size(2), dst[i]:size(3))
72 |   end
73 |   return dst
74 | end
75 |
76 | -- Local contrast normalization (LCN) of a batch of single-channel images.
77 | --
78 | -- inputs:  N x 1 x H x W (the filters below have one input/output plane)
79 | -- kernel:  1D weights applied separably; normalized to sum to 1 on a copy,
80 | --          so the caller's tensor is left unchanged
81 | -- threshold, thresval: passed to nn.Threshold, which replaces local spread
82 | --          values not above threshold with thresval
83 | -- Returns inputs minus their kernel-weighted local mean, divided by the
84 | -- thresholded local spread (square root of the kernel-weighted sum of
85 | -- squares of the mean-subtracted input); both local statistics are divided
86 | -- by the filtered all-ones image to correct for the zero padding.
87 | function customLCN(inputs, kernel, threshold, thresval)
88 |   assert (inputs:dim() == 4, "Input should be of the form nSamples x nChannels x height x width")
89 |
90 |   local padH = math.floor(kernel:size(1)/2)
91 |   local padW = padH
92 |
93 |   -- normalize a private copy of the kernel
94 |   kernel = kernel:clone():div(kernel:sum())
95 |
96 |   local meanestimator = nn.Sequential()
97 |   meanestimator:add(nn.SpatialZeroPadding(padW, padW, padH, padH))
98 |   meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(1), kernel:size(1), 1))
99 |   meanestimator:add(nn.SpatialConvolution(1, 1, 1, kernel:size(1), 1))
100 |
101 |   local stdestimator = nn.Sequential()
102 |   stdestimator:add(nn.Square())
103 |   stdestimator:add(nn.SpatialZeroPadding(padW, padW, padH, padH))
104 |   stdestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(1), kernel:size(1), 1))
105 |   stdestimator:add(nn.SpatialConvolution(1, 1, 1, kernel:size(1)))
106 |   stdestimator:add(nn.Sqrt())
107 |
108 |   meanestimator.modules[2].weight[1]:copy(kernel)
109 |   meanestimator.modules[3].weight[1][1]:copy(kernel)
110 |   stdestimator.modules[3].weight[1]:copy(kernel)
111 |   stdestimator.modules[4].weight[1][1]:copy(kernel)
112 |   meanestimator.modules[2].bias:zero()
113 |   meanestimator.modules[3].bias:zero()
114 |   stdestimator.modules[3].bias:zero()
115 |   stdestimator.modules[4].bias:zero()
116 |
117 |   local coef = meanestimator:updateOutput(inputs.new():resizeAs(inputs):fill(1))
118 |   coef = coef:clone()
119 |
120 |   local localSums = meanestimator:updateOutput(inputs)
121 |   local adjustedSums = nn.CDivTable():updateOutput{localSums, coef}
122 |   local meanSubtracted = nn.CSubTable():updateOutput{inputs, adjustedSums}
123 |
124 |   local localStds = stdestimator:updateOutput(meanSubtracted)
125 |   local adjustedStds = nn.CDivTable():updateOutput{localStds, coef}
126 |   local thresholdedStds = nn.Threshold(threshold, thresval):updateOutput(adjustedStds)
127 |   local outputs = nn.CDivTable():updateOutput{meanSubtracted, thresholdedStds}
128 |
129 |   return outputs
130 | end
131 |
132 | -- Build a pyramid and run LCN on each of the inputs
133 | --
134 | -- convInputs: N x opt.chnNum x H x W images. Returns an
135 | -- N x #opt.bankSize x opt.chnNum x (pyramidSize[1][1] * pyramidSize[1][2])
136 | -- tensor: level j of each channel is LCN-normalized, flattened and written to
137 | -- the first pyramidSize[j][1] * pyramidSize[j][2] entries; the rest stay 0.
138 | -- NOTE(review): levels are sized from convInputs but stored into tensors
139 | -- sized from opt.pyramidSize; these agree when convInputs is
140 | -- opt.inputHeight x opt.inputWidth, as in main.lua -- confirm for other callers.
141 | function buildPyramidsAndLCN(convInputs, opt)
142 |   local bankNum = #opt.bankSize
143 |
144 |   -- opt.bankSize[1] is the biggest size.
145 |   -- From one image, we get multiple banks of different sizes (pyramid)
146 |   local convInputsNormalized = torch.Tensor(convInputs:size(1),
147 |     bankNum, opt.chnNum, opt.pyramidSize[1][1] * opt.pyramidSize[1][2]):zero()
148 |
149 |   local pyramid = {}
150 |   for j = 1, bankNum do
151 |     pyramid[j] = torch.Tensor(convInputs:size(1),
152 |       opt.chnNum, opt.pyramidSize[j][1], opt.pyramidSize[j][2])
153 |   end
154 |
155 |   local timer = torch.Timer()
156 |   print('Building '..convInputs:size(1)..' pyramids...')
157 |
158 |   -- Build pyramids
159 |   local imageSize = math.max(convInputs:size(3), convInputs:size(4))
160 |   local pyScales = {}
161 |   for i = 1, bankNum do
162 |     pyScales[i] = opt.bankSize[i] / imageSize
163 |   end
164 |
165 |   for i = 1, opt.chnNum do
166 |     -- local so the per-channel pyramid does not leak into the global table
167 |     local chnPyramid = {}
168 |     gaussianPyramid(chnPyramid, convInputs[{{}, i, {}, {}}], pyScales)
169 |
170 |     for j = 1, bankNum do
171 |       pyramid[j][{{}, i, {}, {}}] = chnPyramid[j]
172 |     end
173 |   end
174 |   print('Built pyramids : ' .. timer:time().real .. ' seconds')
175 |
176 |   -- Define the normalization neighborhood:
177 |   local neighborhood = image.gaussian1D(7)
178 |
179 |   -- Use spatial contrastive normalization (LCN) to filter the input
180 |   print('Applying LCN to '..convInputs:size(1)..' pyramids...')
181 |   timer:reset()
182 |   for scale = 1, bankNum do
183 |     local levelH, levelW = pyramid[scale]:size(3), pyramid[scale]:size(4)
184 |     for chn = 1, opt.chnNum do
185 |       convInputsNormalized[{{}, scale, chn, {1, levelH * levelW}}] =
186 |         customLCN(pyramid[scale][{{}, chn, {}, {}}]:
187 |           reshape(pyramid[scale]:size(1), 1, levelH, levelW), neighborhood, 1, 1)
188 |     end
189 |   end
190 |   print('Applied LCN : ' .. timer:time().real .. ' seconds')
191 |
192 |   return convInputsNormalized
193 | end
194 |
195 |
--------------------------------------------------------------------------------