├── .gitignore
├── README.md
├── data
├── bed.txt
├── bed_00000001.png
├── bed_00000002.jpg
├── bed_00000003.jpg
├── chair.txt
├── chair_00000001.jpg
├── chair_00000002.jpg
├── chair_00000003.jpg
├── chair_00000004.jpg
├── sofa.txt
├── sofa_00000001.jpg
├── sofa_00000002.jpg
├── sofa_00000003.jpg
├── swivelchair.txt
├── swivelchair_00000001.png
├── swivelchair_00000002.png
├── swivelchair_00000003.jpg
└── swivelchair_00000004.jpg
├── download_models.sh
└── src
├── 3D
├── genSynData
│ ├── getStickFigure.m
│ └── sctheta2theta.m
└── tools
│ ├── alpha2X3D.m
│ ├── alpha2x_proj.m
│ ├── propval.m
│ ├── renderImage.m
│ ├── show3DLD.m
│ └── visualize3Dpara.m
├── evaluate.m
├── main.lua
├── nn
└── Scale.lua
└── pyramidLCN.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | archive/
2 | gen/
3 | models/
4 | results/
5 | www/
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Single Image 3D Interpreter Network
2 |
3 | This repository contains pre-trained models and evaluation code for the project 'Single Image 3D Interpreter Network' (ECCV 2016).
4 |
5 | http://3dinterpreter.csail.mit.edu
6 |
7 | ## Prerequisites
8 | #### Torch
9 | We use Torch 7 (http://torch.ch) for our implementation.
10 |
11 | #### fb.mattorch and Matlab (optional)
12 | We save results as `.mat` files using [`fb.mattorch`](https://github.com/facebook/fblualib/tree/master/fblualib/mattorch), and use `Matlab` (R2015a or later, with Computer Vision System Toolbox) for visualization.
13 |
14 | ## Installation
15 | Our current release has been tested on Ubuntu 14.04.
16 |
17 | #### Clone the repository
18 | ```sh
19 | git clone git@github.com:jiajunwu/3dinn.git
20 | ```
21 | #### Download pretrained models (1.8GB)
22 | ```sh
23 | cd 3dinn
24 | ./download_models.sh
25 | ```
26 |
27 | ## Steps for evaluation
28 |
29 | #### I) List input images, one file name per line, in `data/[classname].txt`
30 |
31 | #### II) Estimate 3D object structure
32 |
33 | The file (`src/main.lua`) has the following options.
34 | - `-gpuID`: specifies the GPU to run on (1-indexed)
35 | - `-class`: which model to use for evaluation. Our current release contains four models: `chair`, `swivelchair`, `bed`, and `sofa`.
36 | - `-batchSize`: the batch size to use
37 |
38 | Sample usages include
39 | - Estimate chair structure for images listed in `data/chair.txt`
40 | ```sh
41 | cd src
42 | th main.lua -gpuID 1 -class chair
43 | ```
44 |
45 | #### III) Check visualization in `www`, and estimated parameters in `results`
46 |
47 | ## Sample input & output
48 |
49 |
50 |
51 |  |
52 |  |
53 |  |
54 |  |
55 |  |
56 |  |
57 |
58 |
59 |  |
60 |  |
61 |  |
62 |  |
63 |  |
64 |  |
65 |
66 |
67 |  |
68 |  |
69 |  |
70 |  |
71 |  |
72 |  |
73 |
74 |
75 |  |
76 |  |
77 |  |
78 |  |
79 |  |
80 |  |
81 |
82 |
83 |
84 |
85 | ## Datasets we used
86 |
87 | - Keypoint-5 dataset [(zip, 208MB)](http://3dinterpreter.csail.mit.edu/data/keypoint-5.zip)
88 |
89 | - Extended IKEA dataset with additional 3D keypoint labels [(zip, 171MB)](http://3dinterpreter.csail.mit.edu/data/ikea_3DINN.zip)
90 |
91 | - [SUN Database](http://groups.csail.mit.edu/vision/SUN/)
92 |
93 | ## Reference
94 | ```bibtex
95 | @inproceedings{3dinterpreter,
96 | title={{Single Image 3D Interpreter Network}},
97 | author={Wu, Jiajun and Xue, Tianfan and Lim, Joseph J and Tian, Yuandong and Tenenbaum, Joshua B and Torralba, Antonio and Freeman, William T},
98 | booktitle={European Conference on Computer Vision},
99 | pages={365--382},
100 | year={2016}
101 | }
102 | ```
103 | For any questions, please contact Jiajun Wu (jiajunwu@mit.edu) or Tianfan Xue (tfxue@mit.edu).
104 |
--------------------------------------------------------------------------------
/data/bed.txt:
--------------------------------------------------------------------------------
1 | bed_00000001.png
2 | bed_00000002.jpg
3 | bed_00000003.jpg
4 |
--------------------------------------------------------------------------------
/data/bed_00000001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/bed_00000001.png
--------------------------------------------------------------------------------
/data/bed_00000002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/bed_00000002.jpg
--------------------------------------------------------------------------------
/data/bed_00000003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/bed_00000003.jpg
--------------------------------------------------------------------------------
/data/chair.txt:
--------------------------------------------------------------------------------
1 | chair_00000001.jpg
2 | chair_00000002.jpg
3 | chair_00000003.jpg
4 | chair_00000004.jpg
5 |
--------------------------------------------------------------------------------
/data/chair_00000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000001.jpg
--------------------------------------------------------------------------------
/data/chair_00000002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000002.jpg
--------------------------------------------------------------------------------
/data/chair_00000003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000003.jpg
--------------------------------------------------------------------------------
/data/chair_00000004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/chair_00000004.jpg
--------------------------------------------------------------------------------
/data/sofa.txt:
--------------------------------------------------------------------------------
1 | sofa_00000001.jpg
2 | sofa_00000002.jpg
3 | sofa_00000003.jpg
4 |
--------------------------------------------------------------------------------
/data/sofa_00000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/sofa_00000001.jpg
--------------------------------------------------------------------------------
/data/sofa_00000002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/sofa_00000002.jpg
--------------------------------------------------------------------------------
/data/sofa_00000003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/sofa_00000003.jpg
--------------------------------------------------------------------------------
/data/swivelchair.txt:
--------------------------------------------------------------------------------
1 | swivelchair_00000001.png
2 | swivelchair_00000002.png
3 | swivelchair_00000003.jpg
4 | swivelchair_00000004.jpg
5 |
--------------------------------------------------------------------------------
/data/swivelchair_00000001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000001.png
--------------------------------------------------------------------------------
/data/swivelchair_00000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000002.png
--------------------------------------------------------------------------------
/data/swivelchair_00000003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000003.jpg
--------------------------------------------------------------------------------
/data/swivelchair_00000004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajunwu/3dinn/7d09607e211e75dd2a717e92edf4f3282f4e401e/data/swivelchair_00000004.jpg
--------------------------------------------------------------------------------
/download_models.sh:
--------------------------------------------------------------------------------
1 | mkdir -p models
2 |
3 | wget http://3dinterpreter.csail.mit.edu/repo/models/chair_model.torch -O models/chair_model.torch
4 | wget http://3dinterpreter.csail.mit.edu/repo/models/swivelchair_model.torch -O models/swivelchair_model.torch
5 | wget http://3dinterpreter.csail.mit.edu/repo/models/bed_model.torch -O models/bed_model.torch
6 | wget http://3dinterpreter.csail.mit.edu/repo/models/sofa_model.torch -O models/sofa_model.torch
7 |
--------------------------------------------------------------------------------
/src/3D/genSynData/getStickFigure.m:
--------------------------------------------------------------------------------
1 | function stickStruct = getStickFigure(varargin)
2 |
3 | % function stickStruct = getStickFigure(varargin)
4 | %
5 | % Parameters (property-value pairs, parsed by propval):
6 | % scale ([] by default; 64 or 128 overrides h, w and scaleRange), class ('chair' by default; also 'swivelchair', 'table', 'bed', 'sofa'), indexPermute (true by default)
7 | %
8 | % Fields of stickStruct (edgeAdj, baseShape, alphaRange, the *Range fields and indices are 1x1 cells, e.g. edgeAdj{1})
9 | % - nclass: ignore this field, this is always 1
10 | % - edgeAdj: M x 2 list of node-index pairs, where M is the number of edges
11 | % - baseShape: 3 x N x B, where N is number of nodes, and B is number of
12 | % bases. baseShape(:,:,1) is the mean base shape, and baseShape(:, :, i)
13 | % are the additional deformations. See the demo code below for details.
14 | % - alphaRange: (B-1) x 2 rows of [min,max], presumably one per deformation.
15 | % - np: number of nodes
16 | % - nbasis: number of bases
17 | % - alphaGSif: ignore this field
18 | % - scaleRange: Range of scaling in camera parameters.
19 | % - thetaRange: Range of rotation angles in camera parameters.
20 | % - tranRange: Range of translation in camera parameters.
21 | % - fRange: Range of focal length
22 | % - h: the height of the input image
23 | % - w: the width of the input image
24 | % - indices: node permutation; 1:N for most classes, a different order for
25 | % 'bed' and 'sofa' to deal with an index-ordering inconsistency.
26 | %
27 | % Demo code:
28 | %
29 | % cd [gitroot]/src
30 | % addpath(genpath('3D'));
31 | % addpath(genpath('nn'));
32 | % stickStruct = getStickFigure('class', 'chair');
33 | % subplot(2,3,1); show3DLD(stickStruct.baseShape{1}(:,:,1), stickStruct.edgeAdj{1});
34 | % title('Mean shape');
35 | % for i = 2:5
36 | % subplot(2,3,i); show3DLD(stickStruct.baseShape{1}(:,:,1) + ...
37 | % stickStruct.baseShape{1}(:,:,i), stickStruct.edgeAdj{1});
38 | % title(sprintf('The %d-th deformation', i-1));
39 | % end
40 |
41 | para.scale = [];
42 | para.class = 'chair';
43 | para.indexPermute = true;
44 | para = propval(varargin, para);
45 |
46 | stickStruct.nclass = 1;
47 | stickStruct.edgeAdj = cell(1,1);
48 | stickStruct.baseShape = cell(1,1);
49 | stickStruct.alphaRange = cell(1,1);
50 | % The optional 'scale' override (h, w and scaleRange) is applied at the end,
51 | % after the class branch below. Assigning a numeric scaleRange here, before
52 | % the per-class scaleRange{1} = ... assignments, would make those brace
53 | % assignments fail.
54 |
55 |
56 | if strcmp(para.class, 'chair')
57 | stickStruct.np = 10;
58 | stickStruct.nbasis = 5;
59 | stickStruct.edgeAdj{1} = [1,5;2,6;3,7;4,8;5,6;6,7;7,8;5,8;8,9;7,10;9,10];
60 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis);
61 | stickStruct.baseShape{1}(:,:,1) = [-1,-2,1; 1,-2,1; 1,-2,-1; -1,-2,-1; -1,0,1; 1,0,1; 1,0,-1; -1,0,-1; -1,2,-1; 1,2,-1]';
62 | stickStruct.baseShape{1}(:,:,2) = [0,-1,0;0,-1,0;0,-1,0;0,-1,0; 0,0,0;0,0,0;0,0,0;0,0,0; 0,0,0;0,0,0]';
63 | stickStruct.baseShape{1}(:,:,3) = [0,0,0;0,0,0;0,0,0;0,0,0; 0,0,0;0,0,0;0,0,0;0,0,0; 0,1,0;0,1,0]';
64 | stickStruct.baseShape{1}(:,:,4) = [-1,0,0;1,0,0;1,0,0;-1,0,0; -1,0,0;1,0,0;1,0,0;-1,0,0; -1,0,0;1,0,0]';
65 | stickStruct.baseShape{1}(:,:,5) = [-1,0,1;1,0,1;1,0,-1;-1,0,-1; 0,0,0;0,0,0;0,0,0;0,0,0; 0,0,0;0,0,0]';
66 | stickStruct.alphaRange{1} = [-1.4,2;-1.4,2;-0.3,1;0,0.3];
67 | stickStruct.alphaGSif = false;
68 | stickStruct.scaleRange{1} = [30,50];
69 | stickStruct.thetaRange{1} = [1.0*pi,1.15*pi; 0,2*pi; 0,0];
70 | stickStruct.tranRange{1} = [-60,60;-80,80];
71 | stickStruct.fRange{1} = [120,450];
72 | stickStruct.h = 320;
73 | stickStruct.w = 240;
74 | stickStruct.indices{1} = 1:10;
75 | stickStruct.shapecheckFunc = [];
76 | elseif strcmp(para.class, 'swivelchair')
77 | stickStruct.np = 13;
78 | stickStruct.nbasis = 7;
79 | stickStruct.edgeAdj{1} = [1,6;2,6;3,6;4,6;5,6; 6,7;8,9;9,10;8,11;10,11; 11,12;12,13;10,13];
80 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis);
81 | tmpScale = 1.2;
82 | stickStruct.baseShape{1}(:,:,1) = [0.3090*tmpScale,-1.5,0.9511*tmpScale; -0.8090*tmpScale,-1.5,0.5878*tmpScale; -0.8090*tmpScale,-1.5,-0.5878*tmpScale;
83 | 0.3090*tmpScale,-1.5,-0.9511*tmpScale; 1.0000*tmpScale,-1.5,0.0000*tmpScale;
84 | 0,-1.5,0; 0,0,0; -1,0,1; 1,0,1; 1,0,-1; -1,0,-1; -1,1.5,-1; 1,1.5,-1]';
85 | stickStruct.baseShape{1}(:,:,2) = [0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,0,0;
86 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]';
87 | stickStruct.baseShape{1}(:,:,3) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;
88 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0;]';
89 | stickStruct.baseShape{1}(:,:,4) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;
90 | -1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0;]';
91 | stickStruct.baseShape{1}(:,:,5) = [0.3090,0,0.9511; -0.8090,0,0.5878; -0.8090,0,-0.5878; 0.3090,0,-0.9511; 1.0000,0,0.0000;
92 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]'; % wheels
93 | stickStruct.baseShape{1}(:,:,6) = [-0.9511,0,0.3090; -0.5878,0,-0.8090; 0.5878,0,-0.8090; 0.9511,0,0.3090; 0.0000,0,1.0000;
94 | 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]'; % wheels
95 | stickStruct.baseShape{1}(:,:,7) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;
96 | 0,1,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0;]'; % wheels
97 | stickStruct.alphaRange{1} = [-0.5,0.5; -1,2.5; -0.5,1; -0.5,0.8; -0.3,0.6; 0,0.6];
98 | stickStruct.alphaGSif = false;
99 | stickStruct.scaleRange{1} = [35,85];
100 | stickStruct.tranRange{1} = [-60,60;-80,80];
101 | stickStruct.thetaRange{1} = [1.0*pi,1.15*pi; 0,2*pi; 0,0];
102 | stickStruct.fRange{1} = [120,450];
103 | stickStruct.h = 320;
104 | stickStruct.w = 240;
105 | stickStruct.indices{1} = 1:13;
106 | stickStruct.shapecheckFunc = [];
107 | elseif strcmp(para.class, 'table')
108 | stickStruct.np = 8;
109 | stickStruct.nbasis = 6;
110 | stickStruct.edgeAdj{1} = [1,2;2,4;1,3;3,4;1,5;2,6;3,7;4,8];
111 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis);
112 | stickStruct.baseShape{1}(:,:,1) = [-1,1,-1; 1,1,-1; -1,1,1; 1,1,1; -1,-1,-1; 1,-1,-1; -1,-1,1; 1,-1,1;]';
113 | stickStruct.baseShape{1}(:,:,2) = [0,1,0;0,1,0;0,1,0;0,1,0; 0,-1,0;0,-1,0;0,-1,0;0,-1,0;]';
114 | stickStruct.baseShape{1}(:,:,3) = [-1,0,0;1,0,0;-1,0,0;1,0,0; -1,0,0;1,0,0;-1,0,0;1,0,0;]';
115 | stickStruct.baseShape{1}(:,:,4) = [0,0,-1;0,0,-1;0,0,1;0,0,1; 0,0,-1;0,0,-1;0,0,1;0,0,1;]';
116 | stickStruct.baseShape{1}(:,:,5) = [0,0,0;0,0,0;0,0,0;0,0,0; 1,0,0;-1,0,0;1,0,0;-1,0,0;]';
117 | stickStruct.baseShape{1}(:,:,6) = [0,0,0;0,0,0;0,0,0;0,0,0; 0,0,-1;0,0,-1;0,0,1;0,0,1;]';
118 | stickStruct.alphaRange{1} = [-0.3,0.3;-0.6,2.5;-0.3,1; 0,1;0,0.2];
119 | stickStruct.alphaGSif = false;
120 | stickStruct.scaleRange{1} = [30,42];
121 | stickStruct.tranRange{1} = [-80,80; -60,60];
122 | % stickStruct.tranRange{1} = [-10,10; -120,120];
123 | stickStruct.thetaRange{1} = [1.0*pi,1.3*pi; -0.23*pi,0.23*pi; 0,0];
124 | stickStruct.fRange{1} = [120,450];
125 | stickStruct.h = 320;
126 | stickStruct.w = 240;
127 | stickStruct.indices{1} = 1:8;
128 | stickStruct.shapecheckFunc = [];
129 | elseif strcmp(para.class, 'bed')
130 | stickStruct.np = 10;
131 | stickStruct.nbasis = 5;
132 | stickStruct.edgeAdj{1} = [1,5;2,6;3,7;4,8;3,4;5,6;6,7;7,8;5,8;8,9;7,10;9,10];
133 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis);
134 | stickStruct.baseShape{1}(:,:,1) = [-1.00,-1.00,2.00; 1.00,-1.00,2.00; 1.00,-1.00,-2.00; -1.00,-1.00,-2.00; -1.00,0.00,2.00; 1.00,0.00,2.00;
135 | 1.00,0.00,-2.00; -1.00, 0.00,-2.00; -1.00,1.00,-2.00; 1.00,1.00,-2.00]';
136 | stickStruct.baseShape{1}(:,:,2) = [0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,-1]'; % length
137 | stickStruct.baseShape{1}(:,:,3) = [-1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0]'; % width
138 | stickStruct.baseShape{1}(:,:,4) = [0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0]'; % height of leg
139 | stickStruct.baseShape{1}(:,:,5) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0]'; % height of back
140 | stickStruct.alphaRange{1} = [-1.2,0.5; -0.5,1.5; -0.5,0.5; -0.5,1];
141 | stickStruct.alphaGSif = false;
142 | stickStruct.scaleRange{1} = [30,42];
143 | stickStruct.tranRange{1} = [-80,80; -60,60];
144 | % stickStruct.tranRange{1} = [-10,10; -120,120];
145 | stickStruct.thetaRange{1} = [1.0*pi,1.25*pi; -0.75*pi,0.75*pi; 0,0];
146 | stickStruct.fRange{1} = [200,450];
147 | stickStruct.h = 320;
148 | stickStruct.w = 240;
149 | % stickStruct.indices{1} = [2,7,9,4,1,6,8,3,5,10];
150 | stickStruct.indices{1} = [2,7,6,1,4,9,8,3,5,10];
151 | stickStruct.shapecheckFunc = [];
152 | elseif strcmp(para.class, 'sofa')
153 | stickStruct.np = 14;
154 | stickStruct.nbasis = 7;
155 | stickStruct.edgeAdj{1} = [1,5;2,6;3,7;4,8;5,6;6,7;7,8;5,8;8,9;7,10;9,10;11,12;12,5;13,14;14,6];
156 | stickStruct.baseShape{1} = zeros(3,stickStruct.np,stickStruct.nbasis);
157 | stickStruct.baseShape{1}(:,:,1) = [-2.00,-2.00,1.00; 2.00,-2.00,1.00; 2.00,-2.00,-1.00; -2.00,-2.00,-1.00; -2.00,0.00,1.00;
158 | 2.00,0.00,1.00; 2.00,0.00,-1.00; -2.00,0.00,-1.00; -2.00,2.00,-1.00; 2.00,2.00,-1.00; -2.00,1.00,-1.00; -2.00,1.00,1.00; 2.00,1.00,-1.00; 2.00,1.00,1.00]';
159 | stickStruct.baseShape{1}(:,:,2) = [0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,1; 0,0,1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,-1; 0,0,1; 0,0,-1; 0,0,1]'; % length
160 | stickStruct.baseShape{1}(:,:,3) = [-1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; -1,0,0; -1,0,0; 1,0,0; 1,0,0]'; % width
161 | stickStruct.baseShape{1}(:,:,4) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0; 0,0.5,0; 0,0.5,0; 0,0.5,0; 0,0.5,0]'; % height of back
162 | stickStruct.baseShape{1}(:,:,5) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,1,0; 0,1,0; 0,1,0; 0,1,0]'; % height of arm rest
163 | stickStruct.baseShape{1}(:,:,6) = [0,-1,0; 0,-1,0; 0,-1,0; 0,-1,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0]'; % height of leg
164 | stickStruct.baseShape{1}(:,:,7) = [0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,0; 0,0,-1; 0,0,0; 0,0,-1]'; % loc of arm rest
165 | stickStruct.alphaRange{1} = [-0.5,0.5; 0,2; -0.5,1; -1,2; -1,1; 0,2];
166 | stickStruct.alphaGSif = false;
167 | stickStruct.scaleRange{1} = [30,42];
168 | stickStruct.tranRange{1} = [-80,80; -60,60];
169 | stickStruct.thetaRange{1} = [1.0*pi,1.15*pi; -0.5*pi,0.5*pi; 0,0];
170 | stickStruct.fRange{1} = [200,450];
171 | stickStruct.h = 320;
172 | stickStruct.w = 240;
173 | stickStruct.indices{1} = [2,9,8,1,4,11,10,3,7,14,5,6,12,13];
174 | stickStruct.shapecheckFunc = [];
175 | else
176 | error('getStickFigure:unknownClass', 'Unknown class ''%s''.', para.class);
177 | end
178 |
179 | % Relabel nodes: original node k becomes node indices(k). baseShape columns
180 | % are gathered with the inverse permutation and edgeAdj entries are mapped
181 | % forward, so edges keep connecting the same physical keypoints.
182 | if para.indexPermute
183 | forwardidx = stickStruct.indices{1};
184 | revertidx = zeros(size(forwardidx));
185 | revertidx(stickStruct.indices{1}) = 1:length(stickStruct.indices{1});
186 | stickStruct.baseShape{1} = stickStruct.baseShape{1}(:,revertidx,:);
187 | stickStruct.edgeAdj{1} = forwardidx(stickStruct.edgeAdj{1});
188 | end
189 |
190 | % Optional output-resolution override (replaces the class defaults).
191 | % NOTE(review): tranRange and fRange are not rescaled for 64/128, and both
192 | % sizes use the same scaleRange -- confirm this is intended.
193 | if ~isempty(para.scale) && (para.scale == 64)
194 | stickStruct.h = 64;
195 | stickStruct.w = 64;
196 | stickStruct.scaleRange{1} = [9,11]/2;
197 | elseif ~isempty(para.scale) && (para.scale == 128)
198 | stickStruct.h = 128;
199 | stickStruct.w = 128;
200 | stickStruct.scaleRange{1} = [9,11]/2;
201 | end
202 |
203 | end
204 |
--------------------------------------------------------------------------------
/src/3D/genSynData/sctheta2theta.m:
--------------------------------------------------------------------------------
1 | function theta = sctheta2theta(sincosTheta)
2 |
3 | % function theta = sctheta2theta(sincosTheta)
4 | %
5 | % Recover three angles per column from their stacked sine/cosine encoding.
6 | % sincosTheta is 6 x K: rows 1:3 hold (possibly unnormalized) sines and rows
7 | % 4:6 the matching cosines. Each (sin, cos) pair is first rescaled to unit
8 | % length; theta (3 x K) is asin(sin), reflected to pi - asin(sin) where the
9 | % cosine is negative, so values lie in [-pi/2, 3*pi/2).
10 |
11 | radius = sqrt(sincosTheta(1:3,:).^2 + sincosTheta(4:6,:).^2);
12 | unitSinCos = sincosTheta ./ [radius; radius];
13 | sines = unitSinCos(1:3,:);
14 | cosines = unitSinCos(4:6,:);
15 |
16 | theta = asin(sines);
17 | % asin only covers [-pi/2, pi/2]; angles with a negative cosine lie in the
18 | % other half-plane and are mirrored about pi/2.
19 | onBackSide = cosines < 0;
20 | theta(onBackSide) = pi - theta(onBackSide);
21 |
22 | end
23 |
--------------------------------------------------------------------------------
/src/3D/tools/alpha2X3D.m:
--------------------------------------------------------------------------------
1 | function X = alpha2X3D(alpha, baseShape)
2 |
3 | % function X = alpha2X3D(alpha, baseShape)
4 | %
5 | % Linear combination of basis shapes: X = sum_b alpha(b) * baseShape(:,:,b).
6 | % alpha is a B x 1 coefficient vector (a B x K matrix yields K shapes as a
7 | % 3 x N x 1 x K array); baseShape is 3 x N x B.
8 |
9 | coeffs = shiftdim(alpha, -2); % B x K -> 1 x 1 x B x K
10 | weightedBases = bsxfun(@times, baseShape, coeffs);
11 | X = sum(weightedBases, 3);
12 |
13 | end
--------------------------------------------------------------------------------
/src/3D/tools/alpha2x_proj.m:
--------------------------------------------------------------------------------
1 | function x = alpha2x_proj(tran,alpha,theta,f, baseShape, varargin)
2 |
3 | % function x = alpha2x_proj(tran,alpha,theta,f, baseShape, varargin)
4 | %
5 | % Project K shapes, built from basis coefficients, to 2D image coordinates.
6 | %
7 | % Inputs (K = size(alpha,2) shapes):
8 | % tran - 2 x K x/y translation, added after rotation and before the
9 | % perspective division
10 | % alpha - B x K basis coefficients for baseShape (3 x N x B)
11 | % theta - 3 x K rotation angles about x, y and z; R = Rx*Ry*Rz
12 | % f - K focal lengths (f(i) is used for shape i)
13 | % Parameters: h, w (default 64) -- image size; projected points are offset
14 | % by [w/2; h/2].
15 | %
16 | % Output: x - 2 x N x K, x = xy ./ (1 + z/f) + [w/2; h/2] per point, where
17 | % (xy, z) are the rotated and translated 3D coordinates.
18 |
19 | para.h = 64;
20 | para.w = 64;
21 | para = propval(varargin, para);
22 | h = para.h; h2 = h/2;
23 | w = para.w; w2 = w/2;
24 |
25 | np = size(baseShape,2);
26 | nShape = size(alpha,2);
27 | x = zeros([2,np,nShape]);
28 | for i=1:nShape
29 | cx = cos(theta(1,i)); sx = sin(theta(1,i));
30 | cy = cos(theta(2,i)); sy = sin(theta(2,i));
31 | % Index theta(3,i), not theta(3): each shape needs its own z-rotation.
32 | cz = cos(theta(3,i)); sz = sin(theta(3,i));
33 | Rx = [1,0,0; 0,cx,-sx; 0,sx,cx];
34 | Ry = [cy,0,sy; 0,1,0; -sy,0,cy];
35 | Rz = [cz,-sz,0; sz,cz,0; 0,0,1];
36 | R = Rx*Ry*Rz;
37 |
38 | coor3 = sum(bsxfun(@times,baseShape,shiftdim(alpha(:,i),-2)),3);
39 | coor3RT = bsxfun(@plus, R * coor3, [tran(:,i);0]);
40 | % Perspective division by (1 + z/f), then shift to the image centre.
41 | x(:,:,i) = bsxfun(@plus, bsxfun(@times, coor3RT(1:2,:), 1./(1+coor3RT(3,:)/f(i))), [w2;h2]);
42 | end
43 |
44 | end
45 |
--------------------------------------------------------------------------------
/src/3D/tools/propval.m:
--------------------------------------------------------------------------------
1 | function [merged unused] = propval(propvals, defaults, varargin)
2 |
3 | % Create a structure combining property-value pairs with default values.
4 | %
5 | % [MERGED UNUSED] = PROPVAL(PROPVALS, DEFAULTS, ...)
6 | %
7 | % Given a cell array or structure of property-value pairs
8 | % (i.e. from VARARGIN or a structure of parameters), PROPVAL will
9 | % merge the user specified values with those specified in the
10 | % DEFAULTS structure and return the result in the structure
11 | % MERGED. Any user specified values that were not listed in
12 | % DEFAULTS are output as property-value arguments in the cell array
13 | % UNUSED. STRICT is disabled in this mode.
14 | %
15 | % ALTERNATIVE USAGE:
16 | %
17 | % [ ARGS ] = PROPVAL(PROPVALS, DEFAULTS, ...)
18 | %
19 | % In this case, propval will assume that no user specified
20 | % properties are meant to be "picked up" and STRICT mode will be enforced.
21 | %
22 | % ARGUMENTS:
23 | %
24 | % PROPVALS - Either a cell array of property-value pairs
25 | % (i.e. {'Property', Value, ...}) or a structure of equivalent form
26 | % (i.e. struct.Property = Value), to be merged with the values in
27 | % DEFAULTS.
28 | %
29 | % DEFAULTS - A structure where field names correspond to the
30 | % default value for any properties in PROPVALS.
31 | %
32 | % OPTIONAL ARGUMENTS:
33 | %
34 | % STRICT (default = true unless two outputs are requested) - Use strict
35 | % guidelines when processing the property value pairs. This raises an
36 | % error if an empty DEFAULTS structure is passed in or if there are
37 | % properties in PROPVALS for which no default was provided.
38 | %
39 | % EXAMPLES:
40 | %
41 | % Simple function with two optional numerical parameters:
42 | %
43 | % function [result] = myfunc(data, varargin)
44 | %
45 | % defaults.X = 5;
46 | % defaults.Y = 10;
47 | %
48 | % args = propval(varargin, defaults);
49 | %
50 | % data = data * args.Y / args.X;
51 | %
52 | % >> myfunc(data)
53 | % This will run myfunc with X=5, Y=10 on the variable 'data'.
54 | %
55 | % >> myfunc(data, 'X', 0)
56 | % This will run myfunc with X=0, Y=10 (thus giving a
57 | % divide-by-zero error)
58 | %
59 | % >> myfunc(data, 'foo', 'bar') will make PROPVAL raise an error
60 | % that 'foo' has no default value, since STRICT is true by
61 | % default.
62 | %
63 |
64 | % License:
65 | %=====================================================================
66 | %
67 | % This is part of the Princeton MVPA toolbox, released under
68 | % the GPL. See http://www.csbmb.princeton.edu/mvpa for more
69 | % information.
70 | %
71 | % The Princeton MVPA toolbox is available free and
72 | % unsupported to those who might find it useful. We do not
73 | % take any responsibility whatsoever for any problems that
74 | % you have related to the use of the MVPA toolbox.
75 | %
76 | % ======================================================================
77 |
78 | % Backwards compatibility
79 | pvdef.ignore_missing_default = false;
80 | pvdef.ignore_empty_defaults = false;
81 |
82 | % check for the number of outputs
83 | if nargout == 2
84 | pvdef.strict = false;
85 | else
86 | pvdef.strict = true;
87 | end
88 |
89 | pvargs = pvdef;
90 |
91 | % Recursively process the propval optional arguments (possible
92 | % because we only recurse if optional parameters are given)
93 | if ~isempty(varargin)
94 | pvargs = propval(varargin, pvdef);
95 | end
96 |
97 | % NOTE: Backwards compatibility with previous version of propval
98 | if pvargs.ignore_missing_default | pvargs.ignore_empty_defaults
99 | pvargs.strict = false;
100 | end
101 |
102 | % check for a single cell argument; assume propvals is that argument
103 | if iscell(propvals) && numel(propvals) == 1
104 | propvals = propvals{1};
105 | end
106 |
107 | % check for valid inputs
108 | if ~iscell(propvals) & ~isstruct(propvals)
109 | error('Property-value pairs must be a cell array or a structure.');
110 | end
111 |
112 | if ~isstruct(defaults) & ~isempty(defaults)
113 | error('Defaults struct must be a structure.');
114 | end
115 |
116 | % check for empty defaults structure
117 | if isempty(defaults)
118 | if pvargs.strict & ~pvargs.ignore_missing_default
119 | error('Empty defaults structure passed to propval.');
120 | end
121 | defaults = struct();
122 | end
123 |
124 | defaultnames = fieldnames(defaults);
125 | defaultvalues = struct2cell(defaults);
126 |
127 | % prepare the defaults structure, but also prepare casechecking
128 | % structure keyed by the lower-cased default names
129 | defaults = struct();
130 | casecheck = struct();
131 |
132 | for i = 1:numel(defaultnames)
133 | defaults.(defaultnames{i}) = defaultvalues{i};
134 | casecheck.(lower(defaultnames{i})) = defaultvalues{i};
135 | end
136 |
137 | % merged starts with the default values
138 | merged = defaults;
139 | unused = {};
140 | used = struct();
141 |
142 | properties = [];
143 | values = [];
144 |
145 | % To extract property value pairs, we use different methods
146 | % depending on how they were passed in
147 | if isstruct(propvals)
148 | properties = fieldnames(propvals);
149 | values = struct2cell(propvals);
150 | else
151 | properties = { propvals{1:2:end} };
152 | values = { propvals{2:2:end} };
153 | end
154 |
155 | if numel(properties) ~= numel(values)
156 | error('Found %g properties but only %g values.', numel(properties), ...
157 | numel(values));
158 | end
159 |
160 | % merge new properties with defaults
161 | for i = 1:numel(properties)
162 |
163 | if ~ischar(properties{i})
164 | error('Property %g is not a string.', i);
165 | end
166 |
167 | % Property names are used as given (NOT lower-cased): matching is
168 | % case-sensitive, and case-only mismatches are rejected below.
169 |
170 | % check for multiple usage
171 | if isfield(used, properties{i})
172 | error('Property %s is defined more than once.\n', ...
173 | properties{i});
174 | end
175 |
176 | % Check for case errors
177 | if isfield(casecheck, lower(properties{i})) & ...
178 | ~isfield(merged, properties{i})
179 | error(['Property ''%s'' is equal to a default property except ' ...
180 | 'for case.'], properties{i});
181 | end
182 |
183 | % Merge with defaults
184 | if isfield(merged, properties{i})
185 | merged.(properties{i}) = values{i};
186 | else
187 | % add to unused property value pairs
188 | unused{end+1} = properties{i};
189 | unused{end+1} = values{i};
190 |
191 | % add to defaults, just in case, if the user isn't picking up "unused"
192 | if (nargout == 1 & ~pvargs.strict)
193 | merged.(properties{i}) = values{i};
194 | end
195 |
196 | if pvargs.strict
197 | error('Property ''%s'' has no default value.', properties{i});
198 | end
199 |
200 | end
201 |
202 | % mark as used
203 | used.(properties{i}) = true;
204 | end
205 |
206 |
--------------------------------------------------------------------------------
/src/3D/tools/renderImage.m:
--------------------------------------------------------------------------------
1 | function imlist = renderImage(x, edgeAdj, varargin)
2 |
3 | % function imlist = renderImage(x, edgeAdj, varargin)
4 | %
5 | % Render 2D stick figures as grayscale line drawings.
6 | %
7 | % Inputs:
8 | % x - 2 x N x K node coordinates (h x w pixel frame) for K shapes
9 | % edgeAdj - M x 2 node-index pairs; one line segment is drawn per row
10 | % Output:
11 | % imlist - h x w x 1 x K images (channel 1 of each rendered RGB image)
12 | %
13 | % Parameter: addNoise, addNode, outputStep, nRandLine, nRandTri,
14 | % noiseLevel, h, w, scale, lineWidth, circSize, parallel
15 | % Each shape is drawn on a (scale*h) x (scale*w) canvas and resized to h x w.
16 | % With addNoise, nRandLine random thin lines, nRandTri random filled
17 | % triangles and Gaussian pixel noise (std noiseLevel) are added as clutter.
18 | % With parallel (default true) shapes are rendered in a parfor loop.
19 |
20 | para.addNoise = true;
21 | para.addNode = false;
22 | para.outputStep = 1000;
23 | para.nRandLine = 15;
24 | para.nRandTri = 3;
25 | para.noiseLevel = 0.1;
26 | para.h = 64;
27 | para.w = 64;
28 | para.scale = 2;
29 | para.lineWidth = 3;
30 | para.circSize = 4;
31 | para.parallel = true;
32 | para = propval(varargin, para);
33 |
34 | h=para.h; w=para.w;
35 | nShape = size(x,3);
36 | imlist = zeros(h,w,1,nShape);
37 | % Both loops call the same per-shape renderer so they cannot drift apart.
38 | if para.parallel
39 | parfor i=1:nShape
40 | imlist(:,:,:,i) = renderOne(x(1:2,:,i), edgeAdj, para);
41 | if mod(i,para.outputStep) == 0
42 | fprintf(1,'shape %d\n', i);
43 | end
44 | end
45 | else
46 | for i=1:nShape
47 | imlist(:,:,:,i) = renderOne(x(1:2,:,i), edgeAdj, para);
48 | if mod(i,para.outputStep) == 0
49 | fprintf(1,'shape %d\n', i);
50 | end
51 | end
52 | end
53 |
54 | end
55 |
56 | function im1 = renderOne(xy, edgeAdj, para)
57 | % Render one shape (2 x N coordinates xy) and return its h x w image.
58 | hb = para.h*para.scale; wb = para.w*para.scale;
59 | edgeColor = [0,0,0];
60 | np = size(xy,2);
61 | im = ones(hb,wb,3);
62 | coor = xy * para.scale;
63 | if para.addNode
64 | im = insertShape(im, 'FilledCircle', [coor',para.circSize*ones(np,1)], 'Color',edgeColor);
65 | end
66 | im = insertShape(im, 'Line', [coor(1,edgeAdj(:,1))',coor(2,edgeAdj(:,1))',...
67 | coor(1,edgeAdj(:,2))',coor(2,edgeAdj(:,2))'], 'Color', edgeColor, 'LineWidth', para.lineWidth);
68 |
69 | if para.addNoise
70 | % Clutter 1: thin random line segments anywhere on the canvas.
71 | xyStart = bsxfun(@times,rand(para.nRandLine,2),[wb,hb]);
72 | xyLen = rand(para.nRandLine,1) * (hb/8) + hb/8; ang = rand(para.nRandLine,1)*(2*pi);
73 | xyEnd = bsxfun(@plus, xyStart, bsxfun(@times,xyLen,[cos(ang),sin(ang)]));
74 | im = insertShape(im, 'Line', [xyStart,xyEnd], 'LineWidth', 1, 'Color', edgeColor);
75 | % Clutter 2: filled triangles (3 vertices at distance xyLen from a centre
76 | % in the middle half of the canvas), gray levels in [0, 0.5).
77 | xyCenter = bsxfun(@plus, bsxfun(@times,rand(para.nRandTri,2),[wb/2,hb/2]), [wb/4,hb/4]);
78 | xyLen = rand(para.nRandTri,1) * (hb/8) + hb/8; ang = rand([para.nRandTri,1,3])*(2*pi);
79 | xyEnd = reshape(bsxfun(@plus, xyCenter, bsxfun(@times,xyLen,[cos(ang),sin(ang)])),[para.nRandTri,6]);
80 | im = insertShape(im, 'FilledPolygon', xyEnd, 'Opacity', 1, 'Color', repmat(rand(para.nRandTri,1)*0.5,[1,3]));
81 | % Clutter 3: additive Gaussian pixel noise.
82 | im = im + para.noiseLevel * randn(size(im));
83 | end
84 |
85 | im1 = imresize(im(:,:,1),[para.h,para.w]);
86 | end
87 |
--------------------------------------------------------------------------------
/src/3D/tools/show3DLD.m:
--------------------------------------------------------------------------------
function show3DLD(X, adj, color, varargin)
% function show3DLD(X, adj, color, varargin)
%
% Plot a 3D stick figure into the current axes: one dot per column of X and
% one line segment per row of adj. Leaves 'hold on' set on the axes.
%
% X     - keypoint coordinates; rows 1, 2 and 3 are used as x, y and z.
% adj   - nedge x 2 keypoint indices, one edge per row.
% color - RGB color; red ([1,0,0]) when omitted or empty.
% Options (name/value): showNumber (label each point with its index,
%   default false), showAxis (label the axes, default true), fontsize
%   (label font size, default 20).
%
% This file is created by Tianfan Xue (tfxue@mit.edu).

para.showNumber = false;
para.showAxis = true;
para.fontsize = 20;
para = propval(varargin, para);

if nargin == 2 || isempty(color)
    color = [1,0,0];
end

xs = X(1,:);
ys = X(2,:);
zs = X(3,:);

plot3(xs, ys, zs, '.', 'Color', color, 'MarkerSize', 10);
hold on;
% Indexing each coordinate row with adj yields nedge x 2 endpoint matrices;
% after transposing, every column is one segment for line().
line(xs(adj)', ys(adj)', zs(adj)', 'Color', color, 'LineWidth', 3);

if para.showNumber
    for k = 1:size(X,2)
        text(xs(k), ys(k), zs(k), sprintf('%d', k), 'FontSize', para.fontsize);
    end
end

if para.showAxis
    xlabel('x');
    ylabel('y');
    zlabel('z');
end
axis equal

end
33 |
34 |
--------------------------------------------------------------------------------
/src/3D/tools/visualize3Dpara.m:
--------------------------------------------------------------------------------
function im = visualize3Dpara(alpha, sincosTheta, tran, f, baseShape, edgeAdj, varargin)
% function im = visualize3Dpara(alpha, sincosTheta, tran, f, baseShape, edgeAdj, varargin)
%
% Render the stick figure described by a set of 3D interpreter parameters
% as a noise-free image.
%
% Parameters:
% alpha: nbasis x 1
% sincosTheta: 6 x 1
% tran: 2x1
% f: 1
% baseShape: 3 x np x nbasis
% edgeAdj: nedge x 2
%
% Additional parameters (name/value pairs):
% h, w, lineWidth, addNode, circSize
%   (defaults: h = 320, w = 240, lineWidth = 6, addNode = true, circSize = 8)
%
% Example:
% stickStruct = getStickFigure('class', 'chair');
% im = visualize3Dpara(alpha, sincosTheta, tran, f, stickStruct.baseShape{1}, stickStruct.edgeAdj{1});

para.h = 320;
para.w = 240;
para.lineWidth = 6;
para.addNode = true;
para.circSize = 8;
para = propval(varargin, para);

% sctheta2theta / alpha2x_proj are defined elsewhere; presumably they turn
% the sin/cos encoding into angles and project the shape to 2D keypoints in
% an h x w frame.
theta = sctheta2theta(sincosTheta);
x = alpha2x_proj(tran, alpha, theta, f, baseShape, 'w', para.w, 'h', para.h);
% Honor the caller's lineWidth (previously a hard-coded 6 was passed here).
im = renderImage(x, edgeAdj, 'addNoise', false, 'noiseLevel', 0, ...
    'h', para.h, 'w', para.w, 'lineWidth', para.lineWidth, 'parallel', false, ...
    'addNode', para.addNode, 'circSize', para.circSize);

end
33 |
34 |
35 |
--------------------------------------------------------------------------------
/src/evaluate.m:
--------------------------------------------------------------------------------
function evaluate(resPath, wwwPath, class, width, height)
% function evaluate(resPath, wwwPath, class, width, height)
%
% Render the predictions saved by main.lua as stick-figure images.
%
% resPath       - .mat file holding an 'outputs' matrix, one row per image.
% wwwPath       - directory receiving one image per row, named %08d.jpg.
% class         - 'chair', 'bed', 'swivelchair' or 'sofa'.
% width, height - size of each rendered image in pixels.

% Each row of outputs is laid out as [alpha, 1/f, sincosTheta (6), tran (2)];
% param_length holds the last column index of each group.
switch class
    case {'chair', 'bed'}
        param_length = [5, 6, 12, 14];
    case {'swivelchair', 'sofa'}
        param_length = [7, 8, 14, 16];
    otherwise
        error('evaluate:unknownClass', ...
            'Unknown class ''%s'' (expected chair, bed, swivelchair or sofa).', class);
end

data = load(resPath, 'outputs');
outputs = data.outputs;
numInst = size(outputs, 1);
alphaPred = outputs(:, 1 : param_length(1));
finvPred = outputs(:, (param_length(1) + 1) : param_length(2));
sincosThetaPred = outputs(:, (param_length(2) + 1) : param_length(3));
% The 3rd and 6th sin/cos entries are overwritten with 0 and 1 before
% rendering. NOTE(review): presumably this fixes one angle to zero - confirm.
sincosThetaPred(:, 3) = 0;
sincosThetaPred(:, 6) = 1;
tranPred = outputs(:, (param_length(3) + 1) : param_length(4));

% Relative paths: assumes MATLAB is started from src/ (as main.lua does).
addpath(fullfile('3D', 'genSynData'));
addpath(fullfile('3D', 'tools'));
stickStruct = getStickFigure('class', class);

if ~exist(wwwPath, 'dir')
    mkdir(wwwPath);
end

for i = 1 : numInst
    im = visualize3Dpara(alphaPred(i, :)', sincosThetaPred(i, :)', ...
        tranPred(i, :)', 1 ./ finvPred(i, :), stickStruct.baseShape{1}, ...
        stickStruct.edgeAdj{1}, 'h', height, 'w', width, 'lineWidth', 6);
    imwrite(im, fullfile(wwwPath, sprintf('%08d.jpg', i)));
end

end
30 |
31 |
--------------------------------------------------------------------------------
/src/main.lua:
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------
print('==> Initializing...')
-- Search the current directory first, so require 'nn.Scale' below resolves
-- to nn/Scale.lua next to this script.
package.path = table.concat({'?.lua', package.path}, ';')

require 'torch'
require 'image'

require 'nn'
require 'cunn'
require 'cutorch'

require 'nn.Scale'

dofile 'pyramidLCN.lua'

-------------------------------------------------------------------------------
-- Command line (option order determines the order in the help text)
local parser = torch.CmdLine()
parser:text()
parser:text('3D-INN evaluation script')
parser:text()
parser:text('Options:')
local optionSpecs = {
   {'-seed', 0, 'Random seed'},
   {'-gpuID', 1, 'ID of GPUs to use'},
   {'-batchSize', 16, 'Batch Size'},
   {'-class', 'chair', 'Model class for evaluation'},
}
for _, spec in ipairs(optionSpecs) do
   parser:option(spec[1], spec[2], spec[3])
end
parser:text()

local opt = parser:parse(arg or {})

-- Select the GPU, then seed torch, Lua's math.random and every CUDA device
-- with the same value.
cutorch.setDevice(opt.gpuID)
torch.manualSeed(opt.seed)
math.randomseed(opt.seed)
cutorch.manualSeedAll(opt.seed)
36 |
--------------------------------------------------------------------------------
-- constants
opt.chnNum = 3 -- RGB
opt.inputWidth = 240
opt.inputHeight = 320
opt.heatmapWidth = 30
opt.heatmapHeight = 40

opt.bankSize = {320, 160, 80} -- input sizes (long edge) for each bank
opt.pyramidSize = {} -- pyramid sizes {height, width} for each scale
for i = 1, #opt.bankSize do
   local scale = opt.bankSize[i] / math.max(opt.inputWidth, opt.inputHeight)
   opt.pyramidSize[i] = {math.floor(scale * opt.inputHeight),
                         math.floor(scale * opt.inputWidth)}
end

-- Per-class output layout. unnormOutput maps network output k back to
--   paramMin[k] + out[k] / paramWeight[k] * (paramMax[k] - paramMin[k]).
-- The entry weighted by 10 holds 1/f (evaluate.m inverts it to get f).
-- Unknown classes are rejected here, before the class name is used to build
-- file paths and shell commands.
if opt.class == 'chair' then
   opt.paramNum = 14
   opt.paramMin = {30, -70, -70, -15, 0, 1/450, -1, -1, -1, -1, -1, -1, -60, -80}
   opt.paramMax = {50, 100, 100, 100, 150, 1/120, 1, 1, 1, 1, 1, 1, 60, 80}
   opt.paramWeight = torch.ones(opt.paramNum)
   opt.paramWeight[6] = 10
elseif opt.class == 'swivelchair' then
   opt.paramNum = 16
   opt.paramMin = {35, -42.5, -85, -42.5, -42.5, -25.5, 0, 1/450, -1, -1, -1, -1, -1, -1, -60, -80}
   opt.paramMax = {85, 42.5, 212.5, 85, 65, 51, 51, 1/120, 1, 1, 1, 1, 1, 1, 60, 80}
   opt.paramWeight = torch.ones(opt.paramNum)
   opt.paramWeight[8] = 10
elseif opt.class == 'bed' then
   opt.paramNum = 14
   opt.paramMin = {30, -50.4, -21, -21, -21, 1/450, -1, -1, -1, -1, -1, -1, -80, -60}
   opt.paramMax = {42, 21, 63, 21, 42, 1/200, 1, 1, 1, 1, 1, 1, 80, 60}
   opt.paramWeight = torch.ones(opt.paramNum)
   opt.paramWeight[6] = 10
elseif opt.class == 'sofa' then
   opt.paramNum = 16
   opt.paramMin = {30, -21, 0, -21, -42, -42, 0, 1/450, -1, -1, -1, -1, -1, -1, -80, -60}
   opt.paramMax = {42, 21, 84, 42, 84, 42, 84, 1/200, 1, 1, 1, 1, 1, 1, 80, 60}
   opt.paramWeight = torch.ones(opt.paramNum)
   opt.paramWeight[8] = 10
else
   error(string.format(
      "unknown class '%s' (expected 'chair', 'swivelchair', 'bed' or 'sofa')",
      tostring(opt.class)))
end
assert(#opt.paramMin == opt.paramNum and #opt.paramMax == opt.paramNum,
       'paramMin/paramMax must have opt.paramNum entries')

-- per-channel RGB mean, rescaled from [0, 255] to [0, 1]
opt.mean = torch.Tensor({129.67, 114.43, 107.26})
opt.mean:mul(1 / 255)
81 |
--------------------------------------------------------------------------------
-- paths (relative to the working directory; '..' assumes we run from src/)
opt.dataPath = paths.concat('..', 'data')
opt.modelPath = paths.concat('..', 'models', opt.class .. '_model.torch')

opt.wwwPath = paths.concat('..', 'www', opt.class)
opt.resPath = paths.concat('..', 'results')
opt.genPath = paths.concat('..', 'gen')

-- Quote a string as one POSIX shell word, so paths containing spaces or
-- shell metacharacters reach mkdir unchanged instead of being interpreted.
local function shellQuote(s)
   return "'" .. (s:gsub("'", [['\'']])) .. "'"
end

-- create the output directories (mkdir -p also creates missing parents)
for _, dir in ipairs({opt.wwwPath, opt.resPath, opt.genPath}) do
   os.execute('mkdir -p ' .. shellQuote(dir))
end

opt.resPath = paths.concat(opt.resPath, opt.class .. '.mat')
opt.scalePath = paths.concat(opt.genPath, opt.class .. '_scale.torch')
opt.inputPath = paths.concat(opt.genPath, opt.class .. '_inputs.torch')
98 |
99 | -------------------------------------------------------------------------------
--- Map normalized network outputs back to parameter space, in place.
-- For k = 1..opt.paramNum, params[k] becomes
--   paramMin[k] + params[k] / paramWeight[k] * (paramMax[k] - paramMin[k]),
-- or paramMax[k] when the range is empty (paramMin[k] == paramMax[k]).
-- Entries beyond opt.paramNum are left untouched.
-- @param params indexable container of numbers (e.g. a tensor row); modified
-- @param opt table with paramNum, paramMin, paramMax, paramWeight
-- @return params (the same object)
function unnormOutput(params, opt)
   local lo, hi, weight = opt.paramMin, opt.paramMax, opt.paramWeight
   for k = 1, opt.paramNum do
      if lo[k] == hi[k] then
         params[k] = hi[k]
      else
         params[k] = lo[k] + params[k] / weight[k] * (hi[k] - lo[k])
      end
   end
   return params
end
112 |
-- load data and build lcn
--- Prepare the network inputs for every image listed in
-- <opt.dataPath>/<opt.class>.txt and cache them on disk.
--
-- Each non-blank line of the list is an image file name relative to
-- opt.dataPath (surrounding whitespace is ignored). For every image:
--   * each of the first opt.chnNum channels is min-max scaled to [0, 1],
--     resized to opt.inputWidth x opt.inputHeight, and opt.mean[k] is
--     subtracted;
--   * the box the image covers when fitted (aspect ratio kept, centered)
--     into the opt.heatmapWidth x opt.heatmapHeight heatmap is stored as
--     {{x1, y1}, {x2, y2}}, rounded and clamped to the heatmap.
-- The boxes (imageNum x 2 x 2) are saved to opt.scalePath and the result of
-- buildPyramidsAndLCN(images, opt) to opt.inputPath.
function loadData(opt)
   local imageList = paths.concat(opt.dataPath, opt.class .. '.txt')
   local imageNames = {}
   for line in io.lines(imageList) do
      -- trim (this also drops a trailing '\r' from CRLF files) and skip blank
      -- lines, which would otherwise become invalid image paths
      local name = line:match('^%s*(.-)%s*$')
      if name ~= '' then
         imageNames[#imageNames + 1] = name
      end
   end
   local imageNum = #imageNames
   assert(imageNum > 0, 'no image names found in ' .. imageList)

   print('generate scales')
   local oriCoords = torch.zeros(imageNum, 2, 2)
   local images = torch.Tensor(imageNum, opt.chnNum, opt.inputHeight, opt.inputWidth)

   for i = 1, imageNum do
      local oriIm = image.load(paths.concat(opt.dataPath, imageNames[i]))
      -- replicate the first channel when the image has fewer channels than the
      -- network expects (e.g. grayscale, or grayscale + alpha)
      if oriIm:size(1) < opt.chnNum then
         oriIm = torch.repeatTensor(oriIm:narrow(1, 1, 1), opt.chnNum, 1, 1)
      end
      for k = 1, opt.chnNum do
         local maxValue = oriIm[k]:max()
         local minValue = oriIm[k]:min()
         local range = maxValue - minValue
         -- a constant channel has zero range: leave it at 0 rather than
         -- multiplying by 1/0, which would fill it with NaN
         local invRange = (range > 0) and (1.0 / range) or 0
         local imScaled = oriIm[k]:clone():add(-minValue):mul(invRange)
         images[i][k] = image.scale(imScaled, opt.inputWidth, opt.inputHeight)
         images[i][k]:add(-opt.mean[k])
      end

      -- Fit the original image into the heatmap keeping its aspect ratio: the
      -- limiting side spans the heatmap, the other side is centered.
      local imH, imW = oriIm:size(2), oriIm:size(3)
      local scaleX = opt.heatmapWidth / imW
      local scaleY = opt.heatmapHeight / imH
      local scale = math.min(scaleX, scaleY)
      local x1, y1
      if scaleX < scaleY then
         x1 = 1
         y1 = math.floor((opt.heatmapHeight - imH * scale) / 2 + 1 + 0.5)
      else
         x1 = math.floor((opt.heatmapWidth - imW * scale) / 2 + 1 + 0.5)
         y1 = 1
      end
      x1 = math.min(math.max(x1, 1), opt.heatmapWidth)
      y1 = math.min(math.max(y1, 1), opt.heatmapHeight)

      -- pixel (u, v) of the original image maps to
      -- (scale * u + x1 - scale, scale * v + y1 - scale) in the heatmap
      local newTrans = torch.Tensor({{scale, 0, x1 - scale}, {0, scale, y1 - scale}})

      -- heatmap positions of the first and the last pixel, rounded and clamped
      local corners = {newTrans * torch.Tensor({1, 1, 1}),
                       newTrans * torch.Tensor({imW, imH, 1})}
      local limits = {opt.heatmapWidth, opt.heatmapHeight}
      for c = 1, 2 do
         for d = 1, 2 do
            local v = math.floor(corners[c][d] + 0.5)
            oriCoords[i][c][d] = math.min(math.max(v, 1), limits[d])
         end
      end
   end

   torch.save(opt.scalePath, oriCoords)

   -- build pyramid and do lcn
   print('generate normalized input')
   local input = buildPyramidsAndLCN(images, opt)
   torch.save(opt.inputPath, input)
end
185 |
-------------------------------------------------------------------------------
-- measure function, save all predictions for evaluation
--- Run the model on the cached inputs and save the predictions.
-- Loads the model from opt.modelPath and the tensors written by loadData
-- (opt.inputPath, opt.scalePath), runs them through the model on the GPU in
-- batches of opt.batchSize, maps each output row back to parameter space with
-- unnormOutput, prints the average time per sample, and saves the resulting
-- N x opt.paramNum matrix as `outputs` in the .mat file opt.resPath.
function test(opt)
   local startTime = sys.clock()

   local model = torch.load(opt.modelPath)
   model:evaluate()

   -- load test input/output
   local inputs = torch.load(opt.inputPath)
   local scales = torch.load(opt.scalePath)
   local numInputs = inputs:size(1)
   local outputs = torch.zeros(numInputs, opt.paramNum)

   -- test over given dataset
   for batchStart = 1, numInputs, opt.batchSize do
      local batchEnd = math.min(batchStart + opt.batchSize - 1, numInputs)
      local batchIndices = {batchStart, batchEnd}
      local batchInputs = inputs[{batchIndices}]:cuda()
      local batchScales = scales[{batchIndices}]:cuda()

      -- converting the CUDA result to a DoubleTensor already makes a separate
      -- host copy, so no extra :clone() is needed
      local batchOutputs = model:forward({batchInputs, batchScales}):double()
      for i = batchStart, batchEnd do
         outputs[i] = unnormOutput(batchOutputs[i - batchStart + 1], opt)
      end

      -- report the number of finished samples so the bar ends at 100%
      xlua.progress(batchEnd, numInputs)
   end

   -- timing (includes loading the model and the cached inputs)
   local time = (sys.clock() - startTime) / numInputs
   print(" time to test 1 sample = " .. (time * 1000) .. 'ms')

   -- save .mat file
   require('fb.mattorch').save(opt.resPath, {outputs = outputs})
end
225 |
-------------------------------------------------------------------------------
print('==> prepare data')
loadData(opt)

-------------------------------------------------------------------------------
print('==> evaluate')
test(opt)

-- Render the saved predictions with evaluate.m in a headless MATLAB; the
-- trailing exit makes MATLAB quit once evaluate returns.
local matlabCmd = string.format(
   [[matlab -nodisplay -nodesktop -nojvm -r "evaluate('%s', '%s', '%s', %s, %s); exit;"]],
   opt.resPath, opt.wwwPath, opt.class, opt.inputWidth, opt.inputHeight)
print('Testing: ' .. matlabCmd)
os.execute(matlabCmd)
240 |
241 |
--------------------------------------------------------------------------------
/src/nn/Scale.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Input: A length-2 Table
3 | Input[1]: Responses, batchSize x nChannels x iH x iW
4 | Input[2]: Pairs (minW, minH, maxW, maxH), batchSize x 2 x 2
5 |
6 | Output: Scaled responses with zero paddings
7 | batchSize x nChannels x iH x iW
8 | --]]
9 | require 'image'
10 |
local Scale, parent = torch.class('nn.Scale', 'nn.Module')

--- Create the module.
-- @param sum value that updateOutput rescales every output channel to sum to
function Scale:__init(sum)
   parent.__init(self)
   -- one gradient tensor per input (responses, boxes), allocated lazily in
   -- updateGradInput
   self.gradInput = {}
   self.sum = sum
end
19 |
20 |
--- Forward pass.
-- For each sample i, input[1][i] (nChannels x iH x iW) is resized with
-- image.scale to fill the box [minW, maxW] x [minH, maxH] read from
-- input[2][i] as {{minW, minH}, {maxW, maxH}}; the rest of the output is
-- zero. The resized map is divided by the box/full-map area ratio, then each
-- channel is multiplied so it sums to self.sum (factor capped at 100;
-- all-zero channels get factor 0). The factors are kept in self.buffer for
-- updateGradInput. A message is printed if NaN appears in a channel.
function Scale:updateOutput(input)
   assert(#input == 2)
   assert(input[1]:size(1) == input[2]:size(1))

   local batchSize = input[1]:size(1)
   local nChannels = input[1]:size(2)
   local iH = input[1]:size(3)
   local iW = input[1]:size(4)
   self.output:resize(batchSize, nChannels, iH, iW):zero()

   self.buffer = self.buffer or input[1].new()
   self.buffer:resize(batchSize, nChannels)

   for i = 1, batchSize do
      local box = input[2][i]
      local minW, minH = box[1][1], box[1][2]
      local maxW, maxH = box[2][1], box[2][2]
      local boxW, boxH = maxW - minW + 1, maxH - minH + 1
      local ratio = boxW * boxH / (iW * iH)

      -- image.scale works on double CPU tensors; convert the result back to
      -- the output's tensor type (CUDA or CPU) instead of assuming CUDA
      self.output[{i, {}, {minH, maxH}, {minW, maxW}}] =
         image.scale(input[1][i]:double(), boxW, boxH):typeAs(self.output):mul(1 / ratio)

      for j = 1, nChannels do
         local channel = self.output[i][j]
         local channelSum = channel:sum()
         if channelSum == 0 then
            self.buffer[i][j] = 0
         else
            self.buffer[i][j] = math.min(self.sum / channelSum, 100)
            channel:mul(self.buffer[i][j])
         end
         -- NaN is the only value that is not equal to itself
         if channel:ne(channel):sum() > 0 then
            print(self.buffer[i][j])
            print(i..' '..j)
            print('!!!!!!!!!!!!!!!!!!!!!!! Found NaN in output !!!!!!!!!!!!!!!!!!!!!!!')
         end
      end
   end

   return self.output
end
61 |
--- Backward pass (approximate).
-- For each sample, the gradient inside its box is resized back to iW x iH
-- with image.scale and each channel is multiplied by the factor stored in
-- self.buffer by the last updateOutput. The gradient w.r.t. the boxes
-- (input[2]) is zero. NOTE(review): the forward 1/ratio area factor and the
-- dependence of the factors on the input are not differentiated here -
-- confirm this approximation is intended.
function Scale:updateGradInput(input, gradOutput)
   assert(self.buffer ~= nil,
          'nn.Scale: updateOutput must run before updateGradInput')

   for k = 1, #input do
      if self.gradInput[k] == nil then
         self.gradInput[k] = input[k].new()
      end
      self.gradInput[k]:resizeAs(input[k]):zero()
   end

   local batchSize = input[1]:size(1)
   local nChannels = input[1]:size(2)
   local iH = input[1]:size(3)
   local iW = input[1]:size(4)
   local gradIn = self.gradInput[1]

   for i = 1, batchSize do
      local box = input[2][i]
      local minW, minH = box[1][1], box[1][2]
      local maxW, maxH = box[2][1], box[2][2]

      -- resize on the CPU, then convert back to the gradient's tensor type
      gradIn[i] = image.scale(gradOutput[{i, {}, {minH, maxH}, {minW, maxW}}]:double(),
                              iW, iH):typeAs(gradIn)
      for j = 1, nChannels do
         gradIn[i][j]:mul(self.buffer[i][j])
      end
   end

   return self.gradInput
end
91 |
--- Printable name of the module, e.g. 'nn.Scale()'.
function Scale:__tostring__()
   local typeName = torch.type(self)
   return typeName .. '()'
end
95 |
--------------------------------------------------------------------------------
/src/pyramidLCN.lua:
--------------------------------------------------------------------------------
--- Build a pyramid of single-channel maps with Gaussian anti-aliasing.
-- Call as gaussianPyramid(dst, src, scales) to fill an existing table, or as
-- gaussianPyramid(src, scales) to get a new one.
--   src    : N x H x W tensor (a batch of maps) or H x W tensor
--   scales : resize factors relative to src; level i is sized
--            floor(H * scales[i]) x floor(W * scales[i])
-- Level i is a copy (factor 1) or a 'simple' image.scale of the previous
-- level blurred with a separable 3-tap Gaussian (level 1 starts from src).
-- The blur is divided by the filter's response to a constant image, so zero
-- padding does not darken the borders. The stored levels are not blurred.
-- @return the filled dst table
function gaussianPyramid(...)
   local nArgs = select('#', ...)
   local dst, src, scales
   if nArgs == 3 then
      dst, src, scales = ...
   elseif nArgs == 2 then
      dst = {}
      src, scales = ...
   else
      error('gaussianPyramid: expected (dst, src, scales) or (src, scales), got '
            .. nArgs .. ' arguments', 2)
   end

   local nDim = src:nDimension()
   assert(nDim == 2 or nDim == 3, 'gaussianPyramid: src must be a 2D or 3D tensor')
   local is2D = (nDim == 2)

   -- allocate the output levels
   for i = 1, #scales do
      dst[i] = dst[i] or torch.Tensor()
      if is2D then
         dst[i]:resize(math.floor(src:size(1) * scales[i]),
                       math.floor(src:size(2) * scales[i]))
      else
         dst[i]:resize(src:size(1),
                       math.floor(src:size(2) * scales[i]),
                       math.floor(src:size(3) * scales[i]))
      end
   end

   -- separable 3-tap Gaussian: horizontal pass, then vertical pass
   local kernel = image.gaussian1D{width = 3, normalize = true}
   local kSize = kernel:size(1)
   local pad = math.floor(kSize / 2)

   local gaussianFilter = nn.Sequential()
   gaussianFilter:add(nn.SpatialZeroPadding(pad, pad, pad, pad))
   gaussianFilter:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(1), kSize, 1))
   gaussianFilter:add(nn.SpatialConvolution(1, 1, 1, kSize, 1))
   gaussianFilter.modules[2].weight[1]:copy(kernel)
   gaussianFilter.modules[3].weight[1][1]:copy(kernel)
   gaussianFilter.modules[2].bias:zero()
   gaussianFilter.modules[3].bias:zero()

   local tmp = src
   for i = 1, #scales do
      if scales[i] == 1 then
         dst[i][{}] = tmp
      else
         image.scale(dst[i], tmp, 'simple')
      end

      -- view the level as a batch of 1-channel maps for the nn modules
      local n, h, w
      if is2D then
         n, h, w = 1, dst[i]:size(1), dst[i]:size(2)
      else
         n, h, w = dst[i]:size(1), dst[i]:size(2), dst[i]:size(3)
      end
      local reshaped = dst[i]:reshape(n, 1, h, w)

      -- filter response to an all-ones image, used to renormalize the borders
      local coef = gaussianFilter:updateOutput(
         reshaped.new():resizeAs(reshaped):fill(1)):clone()
      local filtered = gaussianFilter:updateOutput(reshaped)
      local filteredNormalized = nn.CDivTable():updateOutput{filtered, coef}

      if is2D then
         tmp = filteredNormalized:reshape(h, w)
      else
         tmp = filteredNormalized:reshape(n, h, w)
      end
   end
   return dst
end
62 |
--- Local contrast normalization (LCN) of a batch of single-channel maps.
-- Subtracts a local kernel-weighted mean, then divides by a local spread
-- estimate: the square root of the kernel-weighted sum of squared deviations.
-- Both the mean and the spread are divided by the kernel's response to a
-- constant image, so that zero padding at the borders is compensated.
-- @param inputs    N x 1 x H x W tensor
-- @param kernel    1D kernel applied separably; a normalized copy is used and
--                  the caller's tensor is left unchanged
-- @param threshold spread values not above this threshold ...
-- @param thresval  ... are replaced by thresval before dividing
-- @return N x 1 x H x W normalized tensor
function customLCN(inputs, kernel, threshold, thresval)
   assert(inputs:dim() == 4, "Input should be of the form nSamples x nChannels x height x width")
   assert(inputs:size(2) == 1, "customLCN expects single-channel input (nChannels == 1)")

   -- normalize a copy so repeated calls never modify the caller's kernel
   kernel = kernel:clone():div(kernel:sum())
   local kSize = kernel:size(1)
   local pad = math.floor(kSize / 2)

   local meanestimator = nn.Sequential()
   meanestimator:add(nn.SpatialZeroPadding(pad, pad, pad, pad))
   meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(1), kSize, 1))
   meanestimator:add(nn.SpatialConvolution(1, 1, 1, kSize, 1))

   local stdestimator = nn.Sequential()
   stdestimator:add(nn.Square())
   stdestimator:add(nn.SpatialZeroPadding(pad, pad, pad, pad))
   stdestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(1), kSize, 1))
   stdestimator:add(nn.SpatialConvolution(1, 1, 1, kSize))
   stdestimator:add(nn.Sqrt())

   meanestimator.modules[2].weight[1]:copy(kernel)
   meanestimator.modules[3].weight[1][1]:copy(kernel)
   stdestimator.modules[3].weight[1]:copy(kernel)
   stdestimator.modules[4].weight[1][1]:copy(kernel)
   meanestimator.modules[2].bias:zero()
   meanestimator.modules[3].bias:zero()
   stdestimator.modules[3].bias:zero()
   stdestimator.modules[4].bias:zero()

   -- response to an all-ones input: the border renormalization coefficient
   local coef = meanestimator:updateOutput(inputs.new():resizeAs(inputs):fill(1))
   coef = coef:clone()

   local localSums = meanestimator:updateOutput(inputs)
   local adjustedSums = nn.CDivTable():updateOutput{localSums, coef}
   local meanSubtracted = nn.CSubTable():updateOutput{inputs, adjustedSums}

   local localStds = stdestimator:updateOutput(meanSubtracted)
   local adjustedStds = nn.CDivTable():updateOutput{localStds, coef}
   -- avoid dividing by (near) zero in flat regions
   local thresholdedStds = nn.Threshold(threshold, thresval):updateOutput(adjustedStds)
   local outputs = nn.CDivTable():updateOutput{meanSubtracted, thresholdedStds}

   return outputs
end
109 |
-- Build a pyramid and run LCN on each of the inputs
--- Build per-channel pyramids and apply LCN to every level.
-- @param convInputs N x opt.chnNum x H x W tensor
-- @param opt uses bankSize (long-edge size per bank, largest first),
--   pyramidSize ({height, width} per bank) and chnNum
-- @return N x #opt.bankSize x opt.chnNum x (pyramidSize[1][1] * pyramidSize[1][2])
--   tensor. For bank b, the LCN'd level is copied into the first
--   pyramidSize[b][1] * pyramidSize[b][2] entries; the remainder stays zero.
function buildPyramidsAndLCN(convInputs, opt)
   local bankNum = #opt.bankSize
   local numImages = convInputs:size(1)

   -- opt.bankSize[1] is the biggest size, so bank 1 determines the row length
   local convInputsNormalized = torch.Tensor(numImages,
      bankNum, opt.chnNum, opt.pyramidSize[1][1] * opt.pyramidSize[1][2]):zero()

   local pyramid = {}
   for j = 1, bankNum do
      pyramid[j] = torch.Tensor(numImages,
         opt.chnNum, opt.pyramidSize[j][1], opt.pyramidSize[j][2])
   end

   local timer = torch.Timer()
   print('Building '..numImages..' pyramids...')

   -- Build pyramids
   local imageSize = math.max(convInputs:size(3), convInputs:size(4))
   local pyScales = {}
   for i = 1, bankNum do
      pyScales[i] = opt.bankSize[i] / imageSize
   end

   for i = 1, opt.chnNum do
      local chnPyramid = {}
      gaussianPyramid(chnPyramid, convInputs[{{}, i, {}, {}}], pyScales)
      for j = 1, bankNum do
         pyramid[j][{{}, i, {}, {}}] = chnPyramid[j]
      end
   end
   print('Built pyramids : ' .. timer:time().real .. ' seconds')

   -- Define the normalization neighborhood:
   local neighborhood = image.gaussian1D(7)

   -- Use spatial contrastive normalization (LCN) to filter the input
   print('Applying LCN to '..numImages..' pyramids...')
   timer:reset()
   for scale = 1, bankNum do
      local level = pyramid[scale]
      local h, w = level:size(3), level:size(4)
      for chn = 1, opt.chnNum do
         convInputsNormalized[{{}, scale, chn, {1, h * w}}] =
            customLCN(level[{{}, chn, {}, {}}]:reshape(numImages, 1, h, w),
                      neighborhood, 1, 1)
      end
   end
   print('Applied LCN : ' .. timer:time().real .. ' seconds')

   return convInputsNormalized
end
164 |
165 |
--------------------------------------------------------------------------------