├── .gitignore
├── Merge test field.ipynb
├── Merge_predict.ipynb
├── README.md
├── Split_predict.ipynb
├── data_utils
│   ├── Merge_data.py
│   ├── Split_data.py
│   └── __init__.py
├── dataset
│   ├── __init__.py
│   └── dataset.py
├── images
│   ├── merge_example.jpg
│   ├── split_input.jpg
│   └── split_output.jpg
├── loss
│   ├── __init__.py
│   └── loss.py
├── merge
│   ├── __init__.py
│   ├── test.py
│   └── train.py
├── modules
│   ├── __init__.py
│   ├── merge_modules.py
│   └── split_modules.py
├── requirements.txt
└── split
    ├── __init__.py
    ├── test.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.pyc
3 | /__pycache__
--------------------------------------------------------------------------------
/Merge test field.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Predict script for Merge model, predict the D and R matrices, and visualize the result."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "from modules.merge_modules import MergeModel\n",
35 | "import torch.backends.cudnn as cudnn\n",
36 | "import torch\n",
37 | "import json\n",
38 | "from dataset.dataset import ImageDataset\n",
39 | "from PIL import Image\n",
40 | "import cv2\n",
41 | "import numpy as np"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "# init the Merge model\n",
51 | "net = MergeModel(3).cuda()\n",
52 | "cudnn.benchmark = True\n",
53 | "cudnn.deterministic = True\n",
54 | "net = torch.nn.DataParallel(net).cuda()"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "# load saved checkpoint \n",
64 | "net.load_state_dict(torch.load('Merge_model.pth'))"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {
71 | "scrolled": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "# change the model to eval mode\n",
76 | "net.eval()"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "# init dataset\n",
86 | "folder = 'validation'\n",
87 | "with open('D:/dataset/table/table_line/Split1/'+ folder+'_merge_dict.json', 'r') as f:\n",
88 | " labels = json.load(f)\n",
89 | "dataset = ImageDataset('D:/dataset/table/table_line/Split1/'+ folder+'_input', labels, 8, scale=0.25,mode='merge')"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "index = 0\n",
99 | "img, label, arc = dataset[index]\n",
100 | "index += 1"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {
107 | "scrolled": true
108 | },
109 | "outputs": [],
110 | "source": [
111 | "# predict \n",
112 | "input_img = img.unsqueeze(0)\n",
113 | "arc_c = [[torch.Tensor([y]) for y in x] for x in arc]\n",
114 | "pred = net(input_img,arc_c)\n",
115 | "u,d,l,r = pred # up, down, left, right\n",
116 |     "# calculate the D and R matrices\n",
117 | "D = 0.5 * u[:, :-1, :] * d[:, 1:, :] + 0.25 * (u[:, :-1, :] + d[:, 1:, :])\n",
118 | "R = 0.5 * r[:, :, :-1] * l[:, :, 1:] + 0.25 * (r[:, :, :-1] + l[:, :, 1:])\n",
119 | "D = D[0].detach().cpu().numpy()\n",
120 | "R = R[0].detach().cpu().numpy()\n",
121 | "D[D>0.5] = 1\n",
122 | "D[D<=0.5] = 0\n",
123 | "R[R>0.5] = 1\n",
124 | "R[R<=0.5] = 0\n",
125 | "\n",
126 | "rows, columns = arc\n",
127 | "h,w = img[2].shape\n",
128 | "rows = [round(h*x) for x in rows]\n",
129 | "columns = [round(w*x) for x in columns]\n",
130 | "rows = [0] + rows + [h]\n",
131 | "columns = [0] + columns +[w]\n",
132 | "\n",
133 | "# draw lines on the original image\n",
134 | "draw_img = img[2].numpy()*255.\n",
135 | "draw_img = cv2.cvtColor(draw_img, cv2.COLOR_GRAY2RGB)\n",
136 | "for i in range(R.shape[0]):\n",
137 | " for j in range(R.shape[1]):\n",
138 | " if R[i,j] == 0:\n",
139 | " pts1 = (columns[j+1],rows[i])\n",
140 | " pts2 = (columns[j+1],rows[i+1])\n",
141 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)\n",
142 | "for i in range(D.shape[0]):\n",
143 | " for j in range(D.shape[1]):\n",
144 | " if D[i,j] == 0:\n",
145 | " pts1 = (columns[j],rows[i+1])\n",
146 | " pts2 = (columns[j+1],rows[i+1])\n",
147 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {
154 | "scrolled": true
155 | },
156 | "outputs": [],
157 | "source": [
158 | "# visualize original image\n",
159 | "Image.fromarray(img[2].numpy()*255.).convert('L')"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {
166 | "scrolled": false
167 | },
168 | "outputs": [],
169 | "source": [
170 | "# visualize merged image\n",
171 | "Image.fromarray(np.array(draw_img,dtype=np.uint8))"
172 | ]
173 | }
174 | ],
175 | "metadata": {
176 | "kernelspec": {
177 | "display_name": "Python 3",
178 | "language": "python",
179 | "name": "python3"
180 | },
181 | "language_info": {
182 | "codemirror_mode": {
183 | "name": "ipython",
184 | "version": 3
185 | },
186 | "file_extension": ".py",
187 | "mimetype": "text/x-python",
188 | "name": "python",
189 | "nbconvert_exporter": "python",
190 | "pygments_lexer": "ipython3",
191 | "version": "3.6.5"
192 | }
193 | },
194 | "nbformat": 4,
195 | "nbformat_minor": 2
196 | }
197 |
--------------------------------------------------------------------------------
/Merge_predict.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 |     "Predict script for the Merge model: predict the D and R matrices and visualize the result."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "from modules.merge_modules import MergeModel\n",
35 | "import torch.backends.cudnn as cudnn\n",
36 | "import torch\n",
37 | "import json\n",
38 | "from dataset.dataset import ImageDataset\n",
39 | "from PIL import Image\n",
40 | "import cv2\n",
41 | "import numpy as np"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "# init the Merge model\n",
51 | "net = MergeModel(3).cuda()\n",
52 | "cudnn.benchmark = True\n",
53 | "cudnn.deterministic = True\n",
54 | "net = torch.nn.DataParallel(net).cuda()"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "# load saved checkpoint \n",
64 | "net.load_state_dict(torch.load('Merge_model.pth'))"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {
71 | "scrolled": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "# change the model to eval mode\n",
76 | "net.eval()"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "# init dataset\n",
86 | "folder = 'validation'\n",
87 | "with open('D:/dataset/table/table_line/Split1/'+ folder+'_merge_dict.json', 'r') as f:\n",
88 | " labels = json.load(f)\n",
89 | "dataset = ImageDataset('D:/dataset/table/table_line/Split1/'+ folder+'_input', labels, 8, scale=0.25,mode='merge')"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "index = 0\n",
99 | "img, label, arc = dataset[index]\n",
100 | "index += 1"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {
107 | "scrolled": true
108 | },
109 | "outputs": [],
110 | "source": [
111 | "# predict \n",
112 | "input_img = img.unsqueeze(0)\n",
113 | "arc_c = [[torch.Tensor([y]) for y in x] for x in arc]\n",
114 | "pred = net(input_img,arc_c)\n",
115 | "u,d,l,r = pred # up, down, left, right\n",
116 |     "# calculate the D and R matrices\n",
117 | "D = 0.5 * u[:, :-1, :] * d[:, 1:, :] + 0.25 * (u[:, :-1, :] + d[:, 1:, :])\n",
118 | "R = 0.5 * r[:, :, :-1] * l[:, :, 1:] + 0.25 * (r[:, :, :-1] + l[:, :, 1:])\n",
119 | "D = D[0].detach().cpu().numpy()\n",
120 | "R = R[0].detach().cpu().numpy()\n",
121 | "D[D>0.5] = 1\n",
122 | "D[D<=0.5] = 0\n",
123 | "R[R>0.5] = 1\n",
124 | "R[R<=0.5] = 0\n",
125 | "\n",
126 | "rows, columns = arc\n",
127 | "h,w = img[2].shape\n",
128 | "rows = [round(h*x) for x in rows]\n",
129 | "columns = [round(w*x) for x in columns]\n",
130 | "rows = [0] + rows + [h]\n",
131 | "columns = [0] + columns +[w]\n",
132 | "\n",
133 | "# draw lines on the original image\n",
134 | "draw_img = img[2].numpy()*255.\n",
135 | "draw_img = cv2.cvtColor(draw_img, cv2.COLOR_GRAY2RGB)\n",
136 | "for i in range(R.shape[0]):\n",
137 | " for j in range(R.shape[1]):\n",
138 | " if R[i,j] == 0:\n",
139 | " pts1 = (columns[j+1],rows[i])\n",
140 | " pts2 = (columns[j+1],rows[i+1])\n",
141 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)\n",
142 | "for i in range(D.shape[0]):\n",
143 | " for j in range(D.shape[1]):\n",
144 | " if D[i,j] == 0:\n",
145 | " pts1 = (columns[j],rows[i+1])\n",
146 | " pts2 = (columns[j+1],rows[i+1])\n",
147 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {
154 | "scrolled": true
155 | },
156 | "outputs": [],
157 | "source": [
158 | "# visualize original image\n",
159 | "Image.fromarray(img[2].numpy()*255.).convert('L')"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {
166 | "scrolled": false
167 | },
168 | "outputs": [],
169 | "source": [
170 | "# visualize merged image\n",
171 | "Image.fromarray(np.array(draw_img,dtype=np.uint8))"
172 | ]
173 | }
174 | ],
175 | "metadata": {
176 | "kernelspec": {
177 | "display_name": "Python 3",
178 | "language": "python",
179 | "name": "python3"
180 | },
181 | "language_info": {
182 | "codemirror_mode": {
183 | "name": "ipython",
184 | "version": 3
185 | },
186 | "file_extension": ".py",
187 | "mimetype": "text/x-python",
188 | "name": "python",
189 | "nbconvert_exporter": "python",
190 | "pygments_lexer": "ipython3",
191 | "version": "3.6.5"
192 | }
193 | },
194 | "nbformat": 4,
195 | "nbformat_minor": 2
196 | }
197 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Split&Merge: Table Recognition with PyTorch
2 |
3 | An implementation of the table recognition model Split&Merge in PyTorch. Split&Merge is an efficient convolutional neural network architecture for recognizing table structure from images. For more details, please see the ICDAR 2019 paper: Deep Splitting and Merging for Table Structure Decomposition.
4 |
5 | ## Usage
6 |
7 | **Clone the repo:**
8 |
9 | ```
10 | git clone https://github.com/solitaire2015/Split_Merge_table_recognition.git
11 | pip install -r requirements.txt
12 | ```
13 |
14 | **Prepare the training data:**
15 |
16 | The input of the Split model and the Merge model should be an image with either one channel or three channels. I used a three-channel image: one channel is the grayscale image and the other two are segmentation masks of the separator lines in the vertical and horizontal directions. You can use the grayscale image alone by setting the number of channels to 1 and the dataset suffix to `.jpg`.
17 |
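As an illustration, here is a minimal sketch (not part of the repo) of how a three-channel `.npy` sample for `dataset.ImageDataset` could be assembled. The file names are hypothetical, and the order of the two mask channels is an assumption; the predict notebooks only show that channel 2 holds the grayscale image.

```
import numpy as np
from PIL import Image

# grayscale table crop and the two separator-line masks, all 2-D uint8 arrays (0-255)
gray = np.array(Image.open('img_1.jpg').convert('L'))
h_mask = np.array(Image.open('img_1_h_mask.jpg').convert('L'))  # horizontal separators (assumed order)
v_mask = np.array(Image.open('img_1_v_mask.jpg').convert('L'))  # vertical separators (assumed order)

# channel-first layout (3, h, w); ImageDataset divides the values by 255 when loading
sample = np.stack([h_mask, v_mask, gray]).astype(np.uint8)
np.save('img_1.npy', sample)
```
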
18 | **Split model**
19 |
20 | The ground truth of the Split model is loaded from a `.json` file; the structure of the file looks like this:
21 |
22 | ```
23 | {
24 |     "img_1":{"row":[0,0,1,0,1,1],"column":[1,0,0,1,1,1,1]},
25 |     "img_2":{"row":[0,0,1,0,1,1],"column":[1,0,0,1,1,1,1]}
26 | }
27 | ```
28 |
29 | Here `row` indicates whether there is a separator line in the corresponding row of the image, so its length equals the height of the image. `column` indicates the same for each column, and its length equals the width of the image.
30 |
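For illustration, here is a minimal sketch of how such vectors can be derived from separator-line masks (assuming `h_mask` and `v_mask` are 2-D uint8 arrays of the same size as the table image, with non-zero pixels where a line is drawn); this is essentially what `data_utils/Split_data.py` does as part of its full pipeline:

```
import numpy as np

def split_labels(h_mask, v_mask):
    # row[y] = 1 if any horizontal-separator pixel falls in image row y
    row = [1 if np.any(h_mask[y, :] > 0) else 0 for y in range(h_mask.shape[0])]
    # column[x] = 1 if any vertical-separator pixel falls in image column x
    column = [1 if np.any(v_mask[:, x] > 0) else 0 for x in range(v_mask.shape[1])]
    return {'row': row, 'column': column}
```
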
31 | **Merge model**
32 |
33 | The ground truth of the Merge model is loaded from a `.json` file; here is the structure:
34 |
35 | ```
36 | {
37 |     "img_1":{"rows":[0,0,1,0,1,1],
38 |              "columns":[1,0,0,1,1,1,1],
39 |              "h_matrix":[[0,0,0],
40 |                          [1,0,1]],
41 |              "v_matrix":[[0,1,0],
42 |                          [0,0,1]]}
43 | }
44 | ```
45 |
46 | Here `h_matrix` indicates which cells should be merged with their right neighbor, and `v_matrix` indicates which cells should be merged with their bottom neighbor; see the original paper for more detail. `h_matrix` and `v_matrix` correspond to `R` and `D` in that paper.
47 |
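At prediction time, the Merge model outputs four per-cell probability maps (up, down, left, right); the repository combines them into the `D` and `R` grids in `loss/loss.py` and in the predict notebooks. A minimal restatement of that combination (inputs are torch tensors of shape (batch, M, N)):

```
def combine_merge_outputs(u, d, l, r):
    # D: probability that a cell merges with the cell below it (vertical merge)
    D = 0.5 * u[:, :-1, :] * d[:, 1:, :] + 0.25 * (u[:, :-1, :] + d[:, 1:, :])
    # R: probability that a cell merges with the cell to its right (horizontal merge)
    R = 0.5 * r[:, :, :-1] * l[:, :, 1:] + 0.25 * (r[:, :, :-1] + l[:, :, 1:])
    return D, R
```
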
48 | There are two scripts in the `data_utils` folder to generate the training data.
49 |
50 | **Training**
51 |
52 | Train the Split model:
53 |
54 | ```
55 | python split/train.py --img_dir <your image dir> --json_dir <your json file> --saved_dir <where to save models> --val_img_dir <your validation image dir> --val_json <your validation json file>
56 | ```
57 |
58 | Train the Merge model:
59 |
60 | ```
61 | python merge/train.py --img_dir <your image dir> --json_dir <your json file> --saved_dir <where to save models> --val_img_dir <your validation image dir> --val_json <your validation json file>
62 | ```
63 | Pre-trained models:
64 | * Merge: https://pan.baidu.com/s/112ElDK-MSMQ50CTKpkPCZQ code:rto4
65 |
66 | * Split: https://pan.baidu.com/s/1rSO6o23WbKV6jXo2DRrQcw code:lu23
67 |
68 | Run the prediction notebooks:
69 |
70 | ```
71 | jupyter notebook
72 | ```
73 |
74 | Open `Split_predict.ipynb` and `Merge_predict.ipynb` in your browser.
75 |
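If you prefer to run prediction outside Jupyter, the following is a minimal sketch based on the code in `Merge_predict.ipynb`; the checkpoint name and the dataset paths are assumptions and should be replaced with your own.

```
import json
import torch
from modules.merge_modules import MergeModel
from dataset.dataset import ImageDataset

# build the Merge model exactly as in the notebook (3 input channels, DataParallel wrapper)
net = torch.nn.DataParallel(MergeModel(3)).cuda()
net.load_state_dict(torch.load('Merge_model.pth'))  # hypothetical checkpoint path
net.eval()

with open('validation_merge_dict.json', 'r') as f:  # hypothetical label file
    labels = json.load(f)
dataset = ImageDataset('validation_input', labels, 8, scale=0.25, mode='merge')

img, label, arc = dataset[0]
arc = [[torch.Tensor([y]) for y in x] for x in arc]
with torch.no_grad():
    u, d, l, r = net(img.unsqueeze(0), arc)  # up/down/left/right probability maps
```
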
76 | ## Result
77 |
78 | I didn't test the model on the public ICDAR table recognition competition dataset; I tested it on my private dataset and got a 97% F-score. You can test it on the ICDAR dataset yourself.
79 |
80 | Here are some examples:
81 |
82 | ## Images
83 |
84 | 
85 |
86 | Fig1. Original image
87 |
88 | 
89 |
90 | Fig2. Split result
91 |
92 | 
93 |
94 | Fig3. One Merge result
95 |
96 |
--------------------------------------------------------------------------------
/Split_predict.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 |     "Predict script for the Split model: predict the row and column separator probabilities and visualize the result."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os\n",
17 | "from dataset.dataset import ImageDataset\n",
18 | "from modules.split_modules import SplitModel\n",
19 | "import json\n",
20 | "from PIL import Image\n",
21 | "import torch\n",
22 | "from torchsummary import summary\n",
23 | "import numpy as np\n",
24 | "import cv2\n",
25 | "import json"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 9,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "# load dataset\n",
44 | "folder = 'train'\n",
45 | "with open('D:/dataset/table/table_line/Split1/'+ folder+'_labels.json', 'r') as f:\n",
46 | " labels = json.load(f)\n",
47 | "dataset = ImageDataset('D:/dataset/table/table_line/Split1/'+ folder+'_input', labels, 8, scale=0.25)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 4,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "# init model\n",
57 | "net = SplitModel(3)\n",
58 | "net = torch.nn.DataParallel(net).cuda()"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 5,
64 | "metadata": {},
65 | "outputs": [
66 | {
67 | "data": {
68 | "text/plain": [
69 | "IncompatibleKeys(missing_keys=[], unexpected_keys=[])"
70 | ]
71 | },
72 | "execution_count": 5,
73 | "metadata": {},
74 | "output_type": "execute_result"
75 | }
76 | ],
77 | "source": [
78 | "# load saved checkpoint\n",
79 | "net.load_state_dict(torch.load('split_model.pth'))"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 6,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "data": {
89 | "text/plain": [
90 | "DataParallel(\n",
91 | " (module): SplitModel(\n",
92 | " (sfcn): SFCN(\n",
93 | " (conv1): Sequential(\n",
94 | " (0): Conv2d(3, 18, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)\n",
95 | " (1): ReLU(inplace)\n",
96 | " )\n",
97 | " (conv2): Sequential(\n",
98 | " (0): Conv2d(18, 18, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)\n",
99 | " (1): ReLU(inplace)\n",
100 | " )\n",
101 | " (conv3): Sequential(\n",
102 | " (0): Conv2d(18, 18, kernel_size=(7, 7), stride=(1, 1), padding=(6, 6), dilation=(2, 2), bias=False)\n",
103 | " (1): ReLU(inplace)\n",
104 | " )\n",
105 | " )\n",
106 | " (rpn1): ProjectionNet(\n",
107 | " (conv_branch1): Sequential(\n",
108 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
109 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
110 | " (2): ReLU(inplace)\n",
111 | " )\n",
112 | " (conv_branch2): Sequential(\n",
113 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
114 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
115 | " (2): ReLU(inplace)\n",
116 | " )\n",
117 | " (conv_branch3): Sequential(\n",
118 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
119 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
120 | " (2): ReLU(inplace)\n",
121 | " )\n",
122 | " (project_module): ProjectionModule(\n",
123 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n",
124 | " (feature_conv): Sequential(\n",
125 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
126 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
127 | " (2): ReLU(inplace)\n",
128 | " )\n",
129 | " (prediction_conv): Sequential(\n",
130 | " (0): Dropout2d(p=0)\n",
131 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
132 | " )\n",
133 | " (feature_project): ProjectPooling()\n",
134 | " (prediction_project): Sequential(\n",
135 | " (0): ProjectPooling()\n",
136 | " (1): Sigmoid()\n",
137 | " )\n",
138 | " )\n",
139 | " )\n",
140 | " (rpn2): ProjectionNet(\n",
141 | " (conv_branch1): Sequential(\n",
142 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
143 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
144 | " (2): ReLU(inplace)\n",
145 | " )\n",
146 | " (conv_branch2): Sequential(\n",
147 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
148 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
149 | " (2): ReLU(inplace)\n",
150 | " )\n",
151 | " (conv_branch3): Sequential(\n",
152 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
153 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
154 | " (2): ReLU(inplace)\n",
155 | " )\n",
156 | " (project_module): ProjectionModule(\n",
157 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n",
158 | " (feature_conv): Sequential(\n",
159 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
160 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
161 | " (2): ReLU(inplace)\n",
162 | " )\n",
163 | " (prediction_conv): Sequential(\n",
164 | " (0): Dropout2d(p=0)\n",
165 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
166 | " )\n",
167 | " (feature_project): ProjectPooling()\n",
168 | " (prediction_project): Sequential(\n",
169 | " (0): ProjectPooling()\n",
170 | " (1): Sigmoid()\n",
171 | " )\n",
172 | " )\n",
173 | " )\n",
174 | " (rpn3): ProjectionNet(\n",
175 | " (conv_branch1): Sequential(\n",
176 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
177 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
178 | " (2): ReLU(inplace)\n",
179 | " )\n",
180 | " (conv_branch2): Sequential(\n",
181 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
182 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
183 | " (2): ReLU(inplace)\n",
184 | " )\n",
185 | " (conv_branch3): Sequential(\n",
186 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
187 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
188 | " (2): ReLU(inplace)\n",
189 | " )\n",
190 | " (project_module): ProjectionModule(\n",
191 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n",
192 | " (feature_conv): Sequential(\n",
193 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
194 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
195 | " (2): ReLU(inplace)\n",
196 | " )\n",
197 | " (prediction_conv): Sequential(\n",
198 | " (0): Dropout2d(p=0.3)\n",
199 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
200 | " )\n",
201 | " (feature_project): ProjectPooling()\n",
202 | " (prediction_project): Sequential(\n",
203 | " (0): ProjectPooling()\n",
204 | " (1): Sigmoid()\n",
205 | " )\n",
206 | " )\n",
207 | " )\n",
208 | " (rpn4): ProjectionNet(\n",
209 | " (conv_branch1): Sequential(\n",
210 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
211 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
212 | " (2): ReLU(inplace)\n",
213 | " )\n",
214 | " (conv_branch2): Sequential(\n",
215 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
216 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
217 | " (2): ReLU(inplace)\n",
218 | " )\n",
219 | " (conv_branch3): Sequential(\n",
220 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
221 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
222 | " (2): ReLU(inplace)\n",
223 | " )\n",
224 | " (project_module): ProjectionModule(\n",
225 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n",
226 | " (feature_conv): Sequential(\n",
227 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
228 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
229 | " (2): ReLU(inplace)\n",
230 | " )\n",
231 | " (prediction_conv): Sequential(\n",
232 | " (0): Dropout2d(p=0)\n",
233 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
234 | " )\n",
235 | " (feature_project): ProjectPooling()\n",
236 | " (prediction_project): Sequential(\n",
237 | " (0): ProjectPooling()\n",
238 | " (1): Sigmoid()\n",
239 | " )\n",
240 | " )\n",
241 | " )\n",
242 | " (cpn1): ProjectionNet(\n",
243 | " (conv_branch1): Sequential(\n",
244 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
245 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
246 | " (2): ReLU(inplace)\n",
247 | " )\n",
248 | " (conv_branch2): Sequential(\n",
249 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
250 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
251 | " (2): ReLU(inplace)\n",
252 | " )\n",
253 | " (conv_branch3): Sequential(\n",
254 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
255 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
256 | " (2): ReLU(inplace)\n",
257 | " )\n",
258 | " (project_module): ProjectionModule(\n",
259 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
260 | " (feature_conv): Sequential(\n",
261 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
262 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
263 | " (2): ReLU(inplace)\n",
264 | " )\n",
265 | " (prediction_conv): Sequential(\n",
266 | " (0): Dropout2d(p=0)\n",
267 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
268 | " )\n",
269 | " (feature_project): ProjectPooling()\n",
270 | " (prediction_project): Sequential(\n",
271 | " (0): ProjectPooling()\n",
272 | " (1): Sigmoid()\n",
273 | " )\n",
274 | " )\n",
275 | " )\n",
276 | " (cpn2): ProjectionNet(\n",
277 | " (conv_branch1): Sequential(\n",
278 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
279 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
280 | " (2): ReLU(inplace)\n",
281 | " )\n",
282 | " (conv_branch2): Sequential(\n",
283 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
284 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
285 | " (2): ReLU(inplace)\n",
286 | " )\n",
287 | " (conv_branch3): Sequential(\n",
288 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
289 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
290 | " (2): ReLU(inplace)\n",
291 | " )\n",
292 | " (project_module): ProjectionModule(\n",
293 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
294 | " (feature_conv): Sequential(\n",
295 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
296 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
297 | " (2): ReLU(inplace)\n",
298 | " )\n",
299 | " (prediction_conv): Sequential(\n",
300 | " (0): Dropout2d(p=0)\n",
301 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
302 | " )\n",
303 | " (feature_project): ProjectPooling()\n",
304 | " (prediction_project): Sequential(\n",
305 | " (0): ProjectPooling()\n",
306 | " (1): Sigmoid()\n",
307 | " )\n",
308 | " )\n",
309 | " )\n",
310 | " (cpn3): ProjectionNet(\n",
311 | " (conv_branch1): Sequential(\n",
312 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
313 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
314 | " (2): ReLU(inplace)\n",
315 | " )\n",
316 | " (conv_branch2): Sequential(\n",
317 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
318 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
319 | " (2): ReLU(inplace)\n",
320 | " )\n",
321 | " (conv_branch3): Sequential(\n",
322 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
323 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
324 | " (2): ReLU(inplace)\n",
325 | " )\n",
326 | " (project_module): ProjectionModule(\n",
327 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
328 | " (feature_conv): Sequential(\n",
329 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
330 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
331 | " (2): ReLU(inplace)\n",
332 | " )\n",
333 | " (prediction_conv): Sequential(\n",
334 | " (0): Dropout2d(p=0.3)\n",
335 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
336 | " )\n",
337 | " (feature_project): ProjectPooling()\n",
338 | " (prediction_project): Sequential(\n",
339 | " (0): ProjectPooling()\n",
340 | " (1): Sigmoid()\n",
341 | " )\n",
342 | " )\n",
343 | " )\n",
344 | " (cpn4): ProjectionNet(\n",
345 | " (conv_branch1): Sequential(\n",
346 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
347 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
348 | " (2): ReLU(inplace)\n",
349 | " )\n",
350 | " (conv_branch2): Sequential(\n",
351 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
352 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
353 | " (2): ReLU(inplace)\n",
354 | " )\n",
355 | " (conv_branch3): Sequential(\n",
356 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
357 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n",
358 | " (2): ReLU(inplace)\n",
359 | " )\n",
360 | " (project_module): ProjectionModule(\n",
361 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
362 | " (feature_conv): Sequential(\n",
363 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
364 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n",
365 | " (2): ReLU(inplace)\n",
366 | " )\n",
367 | " (prediction_conv): Sequential(\n",
368 | " (0): Dropout2d(p=0)\n",
369 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
370 | " )\n",
371 | " (feature_project): ProjectPooling()\n",
372 | " (prediction_project): Sequential(\n",
373 | " (0): ProjectPooling()\n",
374 | " (1): Sigmoid()\n",
375 | " )\n",
376 | " )\n",
377 | " )\n",
378 | " )\n",
379 | ")"
380 | ]
381 | },
382 | "execution_count": 6,
383 | "metadata": {},
384 | "output_type": "execute_result"
385 | }
386 | ],
387 | "source": [
388 | "# change to eval mode\n",
389 | "net.eval()"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 24,
395 | "metadata": {},
396 | "outputs": [],
397 | "source": [
398 | "img, label = dataset[50]\n",
399 | "r,c = net(img.unsqueeze(0))\n",
400 | "r = r[-1]>0.5\n",
401 | "c = c[-1]>0.5\n",
402 | "c = c.cpu().detach().numpy()\n",
403 | "r = r.cpu().detach().numpy()\n",
404 | "r_im = r.reshape((-1,1))*np.ones((r.shape[0],c.shape[0]))\n",
405 | "c_im = c.reshape((1,-1))*np.ones((r.shape[0],c.shape[0]))\n",
406 | "im = cv2.bitwise_or(r_im,c_im)"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 25,
412 | "metadata": {},
413 | "outputs": [],
414 | "source": [
415 | "Image.fromarray(img.numpy()[2]*255.).convert('L')"
416 | ]
417 | },
418 | {
419 | "cell_type": "code",
420 | "execution_count": 26,
421 | "metadata": {
422 | "scrolled": true
423 | },
424 | "outputs": [],
425 | "source": [
426 | "Image.fromarray(im*255.).convert('L')"
427 | ]
428 | }
429 | ],
430 | "metadata": {
431 | "kernelspec": {
432 | "display_name": "Python 3",
433 | "language": "python",
434 | "name": "python3"
435 | },
436 | "language_info": {
437 | "codemirror_mode": {
438 | "name": "ipython",
439 | "version": 3
440 | },
441 | "file_extension": ".py",
442 | "mimetype": "text/x-python",
443 | "name": "python",
444 | "nbconvert_exporter": "python",
445 | "pygments_lexer": "ipython3",
446 | "version": "3.6.5"
447 | }
448 | },
449 | "nbformat": 4,
450 | "nbformat_minor": 2
451 | }
452 |
--------------------------------------------------------------------------------
/data_utils/Merge_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Script for generating data for the Merge model.
4 |
5 |
6 | import cv2
7 | import json
8 | import numpy as np
9 | import os
10 |
11 | from PIL import Image
12 |
13 |
14 | def make_merge_data(label, mask_img, threshold=5):
15 | """
16 | Generates merge labels
17 |
18 | Args:
19 |     label(list): Table row and column vectors, same as the labels of the Split model.
20 |     mask_img(ndarray): Mask image drawn with the labeled lines.
21 |     threshold(int): Minimum number of gap (non-line) pixels along a separator segment for two neighboring cells to be marked as merged.
22 |   Returns:
23 |     h_matrix(ndarray): Label of the Merge model data in the horizontal direction.
24 |     v_matrix(ndarray): Label of the Merge model data in the vertical direction.
25 |     rows(list): Positions of horizontal lines, normalized by the image height.
26 |     columns(list): Positions of vertical lines, normalized by the image width.
27 | """
28 | h_line, v_line = label
29 | rows = find_connected_line(h_line)
30 | columns = find_connected_line(v_line)
31 |
32 | h = len(h_line)
33 | w = len(v_line)
34 | mask_img = cv2.resize(mask_img, (w, h), interpolation=cv2.INTER_AREA)
35 | append_columns = [0] + columns + [w]
36 | append_rows = [0] + rows + [h]
37 |
38 | h_matrix = np.zeros((len(rows), len(columns) + 1))
39 | v_matrix = np.zeros((len(rows) + 1, len(columns)))
40 |
41 | for i in range(len(rows)):
42 | for j in range(len(append_columns) - 1):
43 | if np.count_nonzero(mask_img[rows[i], append_columns[j]:append_columns[j + 1]] < 10) > threshold:
44 | h_matrix[i, j] = 1
45 | else:
46 | h_matrix[i, j] = 0
47 | for i in range(len(append_rows) - 1):
48 | for j in range(len(columns)):
49 | if np.count_nonzero(mask_img[append_rows[i]:append_rows[i + 1], columns[j]] < 10) > threshold:
50 | v_matrix[i, j] = 1
51 | else:
52 | v_matrix[i, j] = 0
53 | rows = [x / h for x in rows]
54 | columns = [x / w for x in columns]
55 | return h_matrix, v_matrix, rows, columns
56 |
57 |
58 | def find_connected_line(lines, threshold=5):
59 | """
60 | Gets center of lines.
61 |
62 | Args:
63 |     lines(list): A vector indicating the positions of lines.
64 |     threshold(int): Threshold for filtering out lines that are too close to the border.
65 | """
66 | length = len(lines)
67 | i = 0
68 | blocks = []
69 |
70 | def find_end(start):
71 | end = length - 1
72 | for j in range(start, length - 1):
73 | if lines[j + 1] == 0:
74 | end = j
75 | break
76 | return end
77 |
78 | while i < length:
79 | if lines[i] == 0:
80 | i += 1
81 | else:
82 | end = find_end(i)
83 | blocks.append((i, end))
84 | i = end + 1
85 | if len(blocks) > 0:
86 | if blocks[0][0] <= threshold:
87 | blocks.pop(0)
88 | if len(blocks) > 0:
89 | if (length - blocks[-1][1]) <= threshold:
90 | blocks.pop(-1)
91 | lines_position = [int((x[0] + x[1]) / 2) for x in blocks]
92 | return lines_position
93 |
94 |
95 | if __name__ == "__main__":
96 | mask_img_dir = 'test_out_mask_dir'
97 | with open('test_labels.json', 'r') as f:
98 | dataset = json.load(f)
99 | merge_dict = {}
100 | for id, label in dataset.items():
101 | labels = [np.array(label['row']), np.array(label['column'])]
102 | mask_img = np.array(Image.open(os.path.join(mask_img_dir, id + '.jpg')))
103 | h_matrix, v_matrix, rows, columns = make_merge_data(labels, mask_img)
104 | merge_dict[id] = {'h_matrix': [list(x) for x in list(h_matrix)],
105 | 'v_matrix': [list(x) for x in list(v_matrix)], 'rows': rows, 'columns': columns}
106 | with open('merge_dict.json', 'w') as f:
107 | f.write(json.dumps(merge_dict))
108 |
--------------------------------------------------------------------------------
/data_utils/Split_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Script for generating data for the Split model.
4 |
5 |
6 | import cv2
7 | import json
8 | import numpy as np
9 | import os
10 | import xml.etree.ElementTree as ET
11 |
12 | from functools import reduce
13 | from PIL import Image, ImageDraw
14 |
15 |
16 | def make_split_data(file, lines, img_dir, out_img_dir, out_mask_dir, length_threshold=20):
17 | """
18 |   Generates data for training the Split model: crops and saves image blocks of tables from the original image, and
19 |   returns the labels of the image blocks.
20 |
21 | Args:
22 | file(str): File name of original image file.
23 |     img_dir(str): The directory that contains the image files.
24 |     out_img_dir(str): A directory to save cropped image blocks.
25 |     out_mask_dir(str): A directory to save masks of the cropped image blocks.
26 |     length_threshold(int): Threshold for filtering out blocks with a short width or height.
27 | Returns:
28 |     label_dict(dict): A dictionary; for each item, the key is the file id and the value contains 'row' and 'column',
29 |       which are two vectors where 1 indicates there is a line in the corresponding row or column.
30 | """
31 | # loads original image
32 | im_path = os.path.join(img_dir, file)
33 | im = Image.open(im_path)
34 | im = im.convert('L')
35 | im_array = np.array(im)
36 | # init blank images for drawing masks.
37 | masks = [Image.new('L', im.size, (0,)) for i in range(4)]
38 | draws = [ImageDraw.Draw(x) for x in masks]
39 | # Calculates angle of the image according to labeled boxes.
40 | thetas = []
41 | for line in lines:
42 | draws[int(line['type'])].polygon(line['coordinates'], fill=1)
43 | if line['type'] == 'h': # line['type'] should be 'v' or 'h' for vertical and horizontal lines
44 | up_left = line['coordinates'][:2]
45 | up_right = line['coordinates'][2:4]
46 | theta = np.arctan((up_right[1] - up_left[1]) / (up_right[0] - up_left[0]))
47 | thetas.append(theta)
48 | theta = np.average(thetas)
49 | matrix = np.array([[np.cos(-theta), -np.sin(-theta), 0],
50 | [np.sin(-theta), np.cos(-theta), 0]])
51 | # rotates mask images and original images.
52 | masks = [cv2.warpAffine(np.array(x), matrix, np.array(x).shape[::-1]) for x in masks]
53 | mask_data = np.array([np.array(x) for x in masks])
54 | im_array = np.array(cv2.warpAffine(im_array, matrix, im_array.shape[::-1]))
55 | mask_img = cv2.bitwise_or(cv2.bitwise_or(mask_data[0], mask_data[2]),
56 | cv2.bitwise_or(mask_data[1], mask_data[3])) * 255.
57 | mask_img = np.array(mask_img, dtype=np.uint8)
58 |
59 | # Splits different tables by connected components.
60 | num, labels, stats, centroids = cv2.connectedComponentsWithStats(mask_img, 4, cv2.CV_32S)
61 | label_dict = {}
62 | for i in range(1, num):
63 | image, contours = cv2.findContours(np.array((labels == i) * 255., dtype=np.uint8), cv2.RETR_EXTERNAL,
64 | cv2.CHAIN_APPROX_SIMPLE)
65 | x, y, w, h = cv2.boundingRect(image[0])
66 | if w < length_threshold or h < length_threshold:
67 | continue
68 | row_start = int(y)
69 | row_end = int(y + h)
70 | column_start = int(x)
71 | column_end = int(x + w)
72 | # filters blocks
73 | if (row_end - row_start) <= 0 or (column_end - column_start) <= 0:
74 | continue
75 | row_table = cv2.bitwise_or(mask_data[0][row_start:row_end, column_start:column_end],
76 | mask_data[2][row_start:row_end, column_start:column_end])
77 | column_table = cv2.bitwise_or(mask_data[1][row_start:row_end, column_start:column_end],
78 | mask_data[3][row_start:row_end, column_start:column_end])
79 | # makes column and row labels.
80 | row_label = []
81 | for row in range(row_table.shape[0]):
82 | row_label.append(1 if np.sum(row_table[row, :]) > 0 else 0)
83 | column_lable = []
84 | for column in range(column_table.shape[1]):
85 | column_lable.append(1 if np.sum(column_table[:, column]) > 0 else 0)
86 | key = file[:-4] + '_' + str(i)
87 | label_dict[key] = {'row': row_label, 'column': column_lable}
88 | # crops image blocks of tables.
89 | mask_table = mask_img[row_start:row_end, column_start:column_end]
90 | original_img = im_array[row_start:row_end, column_start:column_end]
91 | mask_table = Image.fromarray(np.array(mask_table, dtype=np.uint8))
92 | original_img = Image.fromarray(np.array(original_img, dtype=np.uint8))
93 | # saves cropped blocks.
94 | original_img.save(os.path.join(out_img_dir, key + '.jpg'))
95 | mask_table.save(os.path.join(out_mask_dir, key + '.jpg'))
96 | return label_dict
97 |
98 |
99 | def extract_lines_from_xml(root_dir, file):
100 | """
101 | Extracts labels of lines from original xml file.
102 | You can rewrite this function to extract lines from your original label.
103 |
104 | Args:
105 | root_dir(str): The directory to the folder which contains the file.
106 | file(str): The file name of the xml file.
107 | Returns:
108 |     lines(list): A list of line annotations; each 'coordinates' entry should be a list of 8 integers giving the x, y
109 |       coordinates of the corners in the following order: up_left->up_right->down_right->down_left. The 'type' should be
110 |       'v' or 'h' for vertical and horizontal lines.
111 | """
112 | # extracts data from xml file
113 | tree = ET.parse(os.path.join(root_dir, file))
114 | root = tree.getroot()
115 | elements = root.findall('object')
116 | # Constructs a list contains lines.
117 | lines = []
118 | for element in elements:
119 | category = element.find('name').text
120 | bbox = element.find('bndbox')
121 | coordinates = [] + bbox.find('leftup').text.split(',') + bbox.find('rightup').text.split(',') + bbox.find(
122 | 'rightdown').text.split(',') + bbox.find('leftdown').text.split(',')
123 | coordinates = [int(x) for x in coordinates]
124 | lines.append({'type': category, 'coordinates': coordinates})
125 | return lines
126 |
127 |
128 | def merge_dicts(d0, d1):
129 | """
130 |   Merges two dictionaries.
131 | """
132 | for k, v in d1.items():
133 | d0[k] = v
134 | return d0
135 |
136 |
137 | if __name__ == "__main__":
138 | root_dir = 'root_dir'
139 | img_dir = 'img_dir'
140 | out_img_dir = 'test_out_img_dir'
141 | out_mask_dir = 'test_out_mask_dir'
142 | if not os.path.exists(out_img_dir):
143 | os.mkdir(out_img_dir)
144 | if not os.path.exists(out_mask_dir):
145 | os.mkdir(out_mask_dir)
146 | files = os.listdir(img_dir)
147 | json_path = 'test_labels.json'
148 | ids = [x[:-4] for x in files]
149 | label_dicts = []
150 | for id in ids:
151 | file = id + '.xml'
152 | lines = extract_lines_from_xml(root_dir, file)
153 | img_name = id + '.jpg'
154 | label_dict = make_split_data(img_name, lines, img_dir, out_img_dir, out_mask_dir)
155 | label_dicts.append(label_dict)
156 | labels = reduce(merge_dicts, label_dicts)
157 | with open(json_path, 'w') as f:
158 | f.write(json.dumps(labels))
159 |
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/data_utils/__init__.py
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Image dataset for training and testing the Split and Merge models (ICDAR 2019: Deep Splitting and Merging for Table Structure Decomposition).
4 |
5 | import cv2
6 | import numpy as np
7 | import os
8 | import torch
9 |
10 | from PIL import Image
11 | from torch.utils.data import Dataset
12 |
13 |
14 | class ImageDataset(Dataset):
15 | """Image Dataset"""
16 |
17 | def __init__(self, img_dir, labels_dict, output_width, scale=0.5,
18 | min_width=40, mode='split', suffix='.npy'):
19 | """
20 | Initialization of the dataset
21 |
22 | Args:
23 | img_dir(str): The directory of images
24 |       labels_dict(dict): A dictionary that maps image ids to the
25 |         corresponding ground truth: two vectors with w and h elements
26 |         (where w is the width and h the height of the image), indicating
27 |         the probability that there is a line in that column or row.
28 |       output_width(int): Defines the width of the output tensor.
29 |       scale(float): The scale for resizing the image.
30 |       min_width(int): The minimal width of the resized image.
31 |       mode(str): The mode, which should be 'split' or 'merge'.
32 |       suffix(str): The file extension of the input images, '.npy' or '.jpg'.
33 | """
34 | self.labels_dict = labels_dict
35 | self.ids = list(labels_dict.keys())
36 | self.nSamples = len(self.ids)
37 | self.img_dir = img_dir
38 | self.output_width = output_width
39 | self.scale = scale
40 | self.min_width = min_width
41 | self.mode = mode
42 | self.suffix = suffix
43 |
44 | def __len__(self):
45 | return self.nSamples
46 |
47 | def __getitem__(self, index):
48 |     assert index < len(self), 'index range error'
49 | id = self.ids[index]
50 | if self.suffix == '.npy':
51 | img = np.load(os.path.join(self.img_dir, id + self.suffix))
52 | elif self.suffix == '.jpg':
53 | img = Image.open(os.path.join(self.img_dir, id + self.suffix))
54 | img = np.array(img)
55 | c, h, w = img.shape
56 | new_h = int(self.scale * h) if int(
57 | self.scale * h) > self.min_width else self.min_width
58 | new_w = int(self.scale * w) if int(
59 | self.scale * w) > self.min_width else self.min_width
60 | img = np.array(
61 | [cv2.resize(img[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in
62 | range(c)])
63 | img_array = np.array(img) / 255.
64 |
65 | if self.mode == 'merge':
66 | labels = self.labels_dict[id]
67 | rows = labels['rows']
68 | columns = labels['columns']
69 | h_matrix = labels['h_matrix']
70 | v_matrix = labels['v_matrix']
71 |
72 | img_tensor = torch.from_numpy(img_array).type(torch.FloatTensor)
73 | row_label = torch.from_numpy(np.array(h_matrix)).type(torch.FloatTensor)
74 | column_label = torch.from_numpy(np.array(v_matrix)).type(
75 | torch.FloatTensor)
76 |
77 | return img_tensor, (row_label, column_label), (rows, columns)
78 | else:
79 | labels = self.labels_dict[id]
80 | row_label = labels['row']
81 | column_label = labels['column']
82 |
83 | # resize ground truth to proper size
84 | width = int(np.floor(new_w / self.output_width))
85 | height = int(np.floor(new_h / self.output_width))
86 | row_label = np.array([row_label]).T * np.ones((len(row_label), width))
87 | column_label = np.array(column_label) * np.ones(
88 | (height, len(column_label)))
89 |
90 | row_label = np.array(row_label, dtype=np.uint8)
91 | column_label = np.array(column_label, dtype=np.uint8)
92 |
93 | row_label = cv2.resize(row_label, (width, new_h))
94 | column_label = cv2.resize(column_label, (new_w, height))
95 |
96 | row_label = row_label[:, 0]
97 | column_label = column_label[0, :]
98 |
99 | img_tensor = torch.from_numpy(img_array).type(torch.FloatTensor)
100 | row_label = torch.from_numpy(row_label).type(torch.FloatTensor)
101 | column_label = torch.from_numpy(column_label).type(torch.FloatTensor)
102 |
103 | return img_tensor, (row_label, column_label)
104 |
--------------------------------------------------------------------------------
/images/merge_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/images/merge_example.jpg
--------------------------------------------------------------------------------
/images/split_input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/images/split_input.jpg
--------------------------------------------------------------------------------
/images/split_output.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/images/split_output.jpg
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/loss/__init__.py
--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Defines the two loss functions used by the Split model and the Merge model.
4 |
5 | import torch
6 |
7 | from functools import reduce
8 |
9 |
10 | def bce_loss(pred, label):
11 | """
12 | Loss function for Split model
13 | Args:
14 | pred(torch.tensor): Prediction
15 | label(torch.tensor): Ground truth
16 | Return:
17 | loss(torch.tensor): Loss of the input image
18 | """
19 | row_pred, column_pred = pred
20 | row_label, column_label = label
21 |
22 | criterion = torch.nn.BCELoss(torch.tensor([10.])).cuda()
23 |
24 | lr3 = criterion(row_pred[0].view(-1), row_label.view(-1))
25 |
26 | lc3 = criterion(column_pred[0].view(-1), column_label.view(-1))
27 |
28 | loss = lr3 + lc3
29 |
30 | return loss
31 |
32 |
33 | def merge_loss(pred, label, weight):
34 | """
35 | Loss function for training Merge model which is Binary Cross-Entropy loss.
36 |
37 | Args:
38 | pred(torch.tensor): Prediction of the input image
39 | label(torch.tensor): Ground truth of corresponding image
40 | weight(float): Weight to balance positive and negative samples
41 | Return:
42 | loss(torch.tensor): Loss of the input image
43 |     D(torch.tensor): A matrix with size (M - 1) x N, indicating the probability
44 |       that two neighboring cells should be merged in the vertical direction.
45 |     R(torch.tensor): A matrix with size M x (N - 1), indicating the probability
46 |       that two neighboring cells should be merged in the horizontal direction.
47 |       Here M is the height of the label and N is the width of the label.
48 | """
49 | pu, pd, pl, pr = pred
50 | D = 0.5 * pu[:, :-1, :] * pd[:, 1:, :] + 0.25 * (pu[:, :-1, :] + pd[:, 1:, :])
51 | R = 0.5 * pr[:, :, :-1] * pl[:, :, 1:] + 0.25 * (pr[:, :, :-1] + pl[:, :, 1:])
52 |
53 | DT, RT = label
54 |
55 | criterion = torch.nn.BCELoss(torch.tensor([weight])).cuda()
56 |
57 | ld = criterion(D.view(-1), DT.view(-1))
58 | lr = criterion(R.view(-1), RT.view(-1))
59 | losses = []
60 | if D.view(-1).shape[0] != 0:
61 | losses.append(ld)
62 | if R.view(-1).shape[0] != 0:
63 | losses.append(lr)
64 | if len(losses) == 0:
65 | loss = torch.tensor(0).cuda()
66 | else:
67 | loss = reduce(lambda x, y: x + y, losses)
68 | return loss, D, R
69 |
--------------------------------------------------------------------------------
/merge/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/merge/__init__.py
--------------------------------------------------------------------------------
/merge/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Script for testing Merge model.
4 |
5 |
6 | import argparse
7 | import json
8 | import numpy as np
9 | import os
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 |
13 | from dataset.dataset import ImageDataset
14 | from modules.merge_modules import MergeModel
15 | from loss.loss import merge_loss
16 | from torch.utils.data import DataLoader
17 |
18 |
19 | def test(opt, net, data=None):
20 | """
21 | Test script for Merge model
22 | Args:
23 | opt(dic): Options
24 | net(torch.model): Merge model instance
25 |     data(dataloader): A DataLoader or None; if None, the data is loaded using the configuration in opt.
26 | Return:
27 | total_loss(torch.tensor): The total loss of the dataset
28 |     precision(torch.tensor): Precision (TP / (TP + FP))
29 |     recall(torch.tensor): Recall (TP / (TP + FN))
30 | f1(torch.tensor): f1 score (2 * precision * recall / (precision + recall))
31 | """
32 | if not data:
33 | with open(opt.json_dir, 'r') as f:
34 | labels = json.load(f)
35 | dir_img = opt.img_dir
36 |
37 | test_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale,
38 | mode='merge')
39 | test_loader = DataLoader(test_set, batch_size=opt.batch_size, shuffle=False)
40 | else:
41 | test_loader = data
42 |
43 | loss_func = merge_loss
44 |
45 | for epoch in range(1):
46 | net.eval()
47 | epoch_loss = 0
48 | number_batchs = 0
49 | total_tp = 0
50 | total_tn = 0
51 | total_fp = 0
52 | total_fn = 0
53 | for i, b in enumerate(test_loader):
54 | with torch.no_grad():
55 | img, label, arc = b
56 | if opt.gpu:
57 | img = img.cuda()
58 | label = [x.cuda() for x in label]
59 | pred_label = net(img, arc)
60 | loss, D, R = loss_func(pred_label, label, 10.)
61 | epoch_loss += loss
62 |
63 | tp = torch.sum(
64 | ((D.view(-1)[
65 | (label[0].view(-1) > 0.5).type(torch.ByteTensor)] > 0.5).type(
66 | torch.IntTensor) ==
67 | label[0].view(-1)[
68 | (label[0].view(-1) > 0.5).type(torch.ByteTensor)].type(
69 | torch.IntTensor))).item() + torch.sum(
70 | ((R.view(-1)[
71 | (label[1].view(-1) > 0.5).type(torch.ByteTensor)] > 0.5).type(
72 | torch.IntTensor) ==
73 | label[1].view(-1)[
74 | (label[1].view(-1) > 0.5).type(torch.ByteTensor)].type(
75 | torch.IntTensor))).item()
76 | tn = torch.sum(
77 | ((D.view(-1)[
78 | (label[0].view(-1) <= 0.5).type(torch.ByteTensor)] > 0.5).type(
79 | torch.IntTensor) ==
80 | label[0].view(-1)[
81 | (label[0].view(-1) <= 0.5).type(torch.ByteTensor)].type(
82 | torch.IntTensor))).item() + torch.sum(
83 | ((R.view(-1)[
84 | (label[1].view(-1) <= 0.5).type(torch.ByteTensor)] > 0.5).type(
85 | torch.IntTensor) ==
86 | label[1].view(-1)[
87 | (label[1].view(-1) <= 0.5).type(torch.ByteTensor)].type(
88 | torch.IntTensor))).item()
89 | fn = torch.sum(
90 | (label[0].view(-1) > 0.5).type(torch.ByteTensor)).item() + torch.sum(
91 | (label[1].view(-1) > 0.5).type(torch.ByteTensor)).item() - tp
92 | fp = torch.sum(
93 | (label[0].view(-1) < 0.5).type(torch.ByteTensor)).item() + torch.sum(
94 | (label[1].view(-1) < 0.5).type(torch.ByteTensor)).item() - tn
95 |
96 | total_fn += fn
97 | total_fp += fp
98 | total_tn += tn
99 | total_tp += tp
100 | number_batchs += 1
101 | total_loss = epoch_loss / number_batchs
102 | precision = total_tp / (total_tp + total_fp)
103 | recall = total_tp / (total_tp + total_fn)
104 | f1 = 2 * precision * recall / (precision + recall)
105 | print(
106 | 'Validation finished ! Loss: {0} ; Precision: {1} ; Recall: {2} ; F1 Score: {3}'.format(
107 | total_loss,
108 | precision,
109 | recall,
110 | f1))
111 | return total_loss, precision, recall, f1
112 |
113 |
114 | def model_select(opt, net):
115 | """
116 | Select best model with highest f1 score
117 | Args:
118 | opt(dict): Options
119 | net(torch.model): Merge model instance
120 | """
121 | model_dir = opt.model_dir
122 | models = os.listdir(model_dir)
123 | losses = []
124 | f1s = []
125 | for model in models:
126 | print(model)
127 | net.load_state_dict(torch.load(os.path.join(model_dir, model)))
128 | loss, precision, recall, f1 = test(opt, net)
129 | losses.append(loss)
130 | f1s.append(f1)
131 | min_loss_index = np.argmin(np.array(losses))
132 | max_f1_index = np.argmax(np.array(f1s))
133 | print('losses', min_loss_index, losses[min_loss_index],
134 | models[min_loss_index])
135 | print('f1 score', max_f1_index, f1s[max_f1_index], models[max_f1_index])
136 |
137 |
138 | if __name__ == '__main__':
139 | parser = argparse.ArgumentParser()
140 | parser.add_argument('--batch_size', type=int, default=32,
141 | help='batch size of the training set')
142 | parser.add_argument('--gpu', type=bool, default=True, help='if use gpu')
143 | parser.add_argument('--gpu_list', type=str, default='0',
144 | help='which gpu could use')
145 | parser.add_argument('--model_dir', type=str, required=True,
146 | help='saved directory for output models')
147 | parser.add_argument('--json_dir', type=str, required=True,
148 | help='labels of the data')
149 | parser.add_argument('--img_dir', type=str, required=True,
150 | help='image directory for input data')
151 | parser.add_argument('--featureW', type=int, default=8, help='width of output')
152 | parser.add_argument('--scale', type=float, default=0.5,
153 | help='scale of the image')
154 |
155 | opt = parser.parse_args()
156 |
157 | net = MergeModel(3)
158 | if opt.gpu:
159 | cudnn.benchmark = True
160 | cudnn.deterministic = True
161 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list
162 | net = torch.nn.DataParallel(net).cuda()
163 |
164 | net.load_state_dict(torch.load('merge_models_1225/CP90.pth'))
165 | print(test(opt, net))
166 | # model_select(opt, net)
167 |
--------------------------------------------------------------------------------
/merge/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Script for training the Merge model.
4 |
5 |
6 | import argparse
7 | import json
8 | import os
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 |
12 | from torch import optim
13 | from torch.utils.data import DataLoader
14 | from dataset.dataset import ImageDataset
15 | from modules.merge_modules import MergeModel
16 | from merge.test import test
17 | from loss.loss import merge_loss
18 |
19 |
20 | def train(opt, net):
21 | """
22 | Train the merge model
23 | Args:
24 | opt(dict): Options
25 | net(torch.model): Merge model instance
26 | """
27 | # load labels
28 | with open(opt.json_dir, 'r') as f:
29 | labels = json.load(f)
30 | dir_img = opt.img_dir
31 |
32 | with open(opt.val_json, 'r') as f:
33 | val_labels = json.load(f)
34 | val_img_dir = opt.val_img_dir
35 |
36 | train_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale,
37 | mode='merge')
38 | train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True)
39 |
40 | val_set = ImageDataset(val_img_dir, val_labels, opt.featureW, scale=opt.scale,
41 | mode='merge')
42 | val_loader = DataLoader(val_set, batch_size=opt.batch_size, shuffle=False)
43 |
44 | print('Data loaded!')
45 |
46 | # defines loss function
47 | loss_func = merge_loss
48 | optimizer = optim.Adam(net.parameters(),
49 | lr=opt.lr,
50 | weight_decay=0.001)
51 | best_f1 = 0
52 | for epoch in range(opt.epochs):
53 | print('epoch:{}'.format(epoch + 1))
54 | net.train()
55 | epoch_loss = 0
56 | number_batchs = 0
57 | for i, b in enumerate(train_loader):
58 | img, label, arc = b
59 | if opt.gpu:
60 | img = img.cuda()
61 | label = [x.cuda() for x in label]
62 | pred_label = net(img, arc)
63 | loss, _, _ = loss_func(pred_label, label, 10.)
64 | if loss.requires_grad:
65 | epoch_loss += loss
66 | optimizer.zero_grad()
67 | loss.backward()
68 | optimizer.step()
69 | number_batchs += 1
70 |
71 | print('Epoch finished ! Loss: {0} '.format(epoch_loss / number_batchs))
72 | val_loss, precision, recall, f1 = test(opt, net, val_loader)
73 | # save the model when the current f1 score beats the best f1 score so far
74 | if f1 > best_f1:
75 | best_f1 = f1
76 | torch.save(net.state_dict(),
77 | opt.saved_dir + 'CP{}.pth'.format(epoch + 1))
78 | # write training information of current epoch to the log file
79 | with open(os.path.join(opt.saved_dir, 'log.txt'), 'a') as f:
80 | f.write(
81 | 'Epoch {0}, val loss : {1}, precision : {2}, recall : {3}, f1 score : {4} \n train loss : {5} \n\n'.format(
82 | epoch + 1, val_loss, precision, recall, f1,
83 | epoch_loss / number_batchs))
84 |
85 |
86 | if __name__ == '__main__':
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument('--batch_size', type=int, default=1,
89 | help='batch size of the training set')
90 | parser.add_argument('--epochs', type=int, default=50, help='epochs')
91 | parser.add_argument('--gpu', type=bool, default=True, help='whether to use GPU')
92 | parser.add_argument('--gpu_list', type=str, default='0',
93 | help='which GPU(s) to use')
94 | parser.add_argument('--lr', type=float, default=0.00075,
95 | help='learning rate, default=0.00075 for Adam')
96 | parser.add_argument('--saved_dir', type=str, required=True,
97 | help='saved directory for output models')
98 | parser.add_argument('--json_dir', type=str, required=True,
99 | help='labels of the data')
100 | parser.add_argument('--img_dir', type=str, required=True,
101 | help='image directory for input data')
102 | parser.add_argument('--val_json', type=str, required=True,
103 | help='labels of the validation data')
104 | parser.add_argument('--val_img_dir', type=str, required=True,
105 | help='image directory for validation data')
106 | parser.add_argument('--featureW', type=int, default=8, help='width of output')
107 | parser.add_argument('--scale', type=float, default=0.5,
108 | help='scale of the image')
109 |
110 | opt = parser.parse_args()
111 |
112 | net = MergeModel(3)
113 | if opt.gpu:
114 | cudnn.benchmark = True
115 | cudnn.deterministic = True
116 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list
117 | net = torch.nn.DataParallel(net).cuda()
118 |
119 | if not os.path.exists(opt.saved_dir):
120 | os.mkdir(opt.saved_dir)
121 |
122 | with open(os.path.join(opt.saved_dir, 'log.txt'), 'w') as f:
123 | configuration = '--batch_size: {0} \n' \
124 | '--epochs: {1} \n' \
125 | '--gpu: {2} \n' \
126 | '--gpu_list: {3} \n' \
127 | '--lr: {4} \n' \
128 | '--saved_dir: {5} \n' \
129 | '--json_dir: {6} \n' \
130 | '--img_dir: {7} \n' \
131 | '--val_json: {8} \n' \
132 | '--val_img_dir: {9} \n' \
133 | '--featureW: {10} \n' \
134 | '--scale: {11} \n'.format(
135 | opt.batch_size, opt.epochs, opt.gpu, opt.gpu_list, opt.lr, opt.saved_dir,
136 | opt.json_dir, opt.img_dir,
137 | opt.val_json, opt.val_img_dir, opt.featureW, opt.scale)
138 | f.write(configuration + '\n\n Logs: \n')
139 |
140 | train(opt, net)
141 |
--------------------------------------------------------------------------------
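merge/train.py is normally launched from the command line with the flags defined above; the sketch below is a programmatic equivalent using argparse.Namespace. All paths are placeholders that must point at real image folders and label JSON files, and a CUDA device is assumed because the script moves the model to GPU.

import os
from argparse import Namespace
import torch
import torch.backends.cudnn as cudnn
from modules.merge_modules import MergeModel
from merge.train import train

# values mirror the argparse defaults; paths are hypothetical
opt = Namespace(batch_size=1, epochs=50, gpu=True, gpu_list='0', lr=0.00075,
                saved_dir='merge_models/', json_dir='train_merge_dict.json',
                img_dir='train_input/', val_json='val_merge_dict.json',
                val_img_dir='val_input/', featureW=8, scale=0.5)

net = MergeModel(3)
if opt.gpu:
    cudnn.benchmark = True
    cudnn.deterministic = True
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list
    net = torch.nn.DataParallel(net).cuda()
if not os.path.exists(opt.saved_dir):
    os.mkdir(opt.saved_dir)
train(opt, net)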
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/modules/__init__.py
--------------------------------------------------------------------------------
/modules/merge_modules.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Defines several modules to build up Merge model
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class MergeModel(nn.Module):
10 | """
11 | Merge model, following "Deep Splitting and Merging for Table Structure Decomposition" (ICDAR 2019)
12 | """
13 |
14 | def __init__(self, input_channels):
15 | """
16 | Initialization of merge model
17 | Args:
18 | input_channels(int): The number of input data channel
19 | """
20 | super(MergeModel, self).__init__()
21 | # shared full conv net
22 | self.sfcn = SharedFCN(input_channels)
23 | # four branches: up, down, left, right
24 | self.rpn1 = ProjectionNet(18, True, 0.3)
25 | self.rpn2 = ProjectionNet(36, False, 0)
26 | self.rpn3 = ProjectionNet(36, True, 0.3)
27 |
28 | self.dpn1 = ProjectionNet(18, True, 0.3)
29 | self.dpn2 = ProjectionNet(36, False, 0)
30 | self.dpn3 = ProjectionNet(36, True, 0.3)
31 |
32 | self.upn1 = ProjectionNet(18, True, 0.3)
33 | self.upn2 = ProjectionNet(36, False, 0)
34 | self.upn3 = ProjectionNet(36, True, 0.3)
35 |
36 | self.lpn1 = ProjectionNet(18, True, 0.3)
37 | self.lpn2 = ProjectionNet(36, False, 0)
38 | self.lpn3 = ProjectionNet(36, True, 0.3)
39 |
40 | self._init_weights()
41 |
42 | def _init_weights(self):
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
45 | torch.nn.init.kaiming_normal_(m.weight)
46 | if m.bias is not None:
47 | m.bias.data.fill_(0.01)
48 | elif isinstance(m, nn.BatchNorm2d):
49 | m.weight.data.fill_(1)
50 | m.bias.data.zero_()
51 | elif isinstance(m, nn.Linear):
52 | m.weight.data.normal_(0, 0.01)
53 | m.bias.data.zero_()
54 |
55 | def forward(self, x, arc):
56 | """
57 | Forward pass of the merge model
58 | Args:
59 | x(torch.tensor): Input tensor with shape (b,c,h,w)
60 | arc(list(torch.tensor)): Grid architecture for grid project pooling,
61 | two lists of tensors giving the coordinates of the horizontal and
62 | vertical separator lines
63 | Return:
64 | output(list(torch.tensor)): output of the merge model, a list of four matrices,
65 | one per direction (up, down, left, right)
66 | """
67 | feature = self.sfcn(x)
68 |
69 | right_feature, r3 = self.rpn1(feature,
70 | arc) # self.rpn3(self.rpn2(self.rpn1(feature, arc), arc), arc)
71 |
72 | left_feature, l3 = self.lpn1(feature,
73 | arc) # self.lpn3(self.lpn2(self.lpn1(feature, arc), arc), arc)
74 |
75 | up_feature, u3 = self.upn1(feature,
76 | arc) # self.upn3(self.upn2(self.upn1(feature, arc), arc), arc)
77 |
78 | down_feature, d3 = self.dpn1(feature,
79 | arc) # self.dpn3(self.dpn2(self.dpn1(feature, arc), arc), arc)
80 |
81 | output = [u3.squeeze(1), d3.squeeze(1), l3.squeeze(1), r3.squeeze(1)]
82 | return output
83 |
84 |
85 | class GridProjectPooling(nn.Module):
86 | """
87 | Grid project pooling: every pixel location replaces its value with the average of all
88 | pixels within its grid element:
89 | $$ \hat F_{ij} = \frac{1}{\lvert\Omega(i,j)\rvert} \sum_{(i',j') \in \Omega(i,j)} F_{i'j'} $$
90 | """
91 |
92 | def __init__(self):
93 | """
94 | Initialization of grid project pooling module
95 | """
96 | super(GridProjectPooling, self).__init__()
97 |
98 | def forward(self, x, architecture):
99 | """
100 | Forward pass of this module
101 | Args:
102 | x(torch.tensor): Input tensor with shape (b, c, h, w)
103 | architecture(list(torch.tensor)): Grid architecture for grid project pooling,
104 | two lists of tensors giving the coordinates of the horizontal and
105 | vertical separator lines
106 | Return:
107 | output(torch.tensor): Output tensor of this module, with the same shape as
108 | the input tensor
109 | matrix(torch.tensor): An M x N matrix of grid-cell means, where M and N are
110 | the number of grid rows and columns.
111 | """
112 | b, c, h, w = x.size()
113 | h_line, v_line = architecture
114 | self.h_line = [torch.Tensor([0]).type(
115 | torch.DoubleTensor).cuda()] + h_line + [
116 | torch.Tensor([1]).type(torch.DoubleTensor).cuda()]
117 | self.v_line = [torch.Tensor([0]).type(
118 | torch.DoubleTensor).cuda()] + v_line + [
119 | torch.Tensor([1]).type(torch.DoubleTensor).cuda()]
120 | self.h_line = [(h * x).round().type(torch.IntTensor) for x in self.h_line]
121 | self.v_line = [(w * x).round().type(torch.IntTensor) for x in self.v_line]
122 |
123 | rows = [self.h_line[i + 1] - self.h_line[i] for i in
124 | range(len(self.h_line) - 1)]
125 | columns = [self.v_line[i + 1] - self.v_line[i] for i in
126 | range(len(self.v_line) - 1)]
127 |
128 | slices = torch.split(x, rows, 2)
129 | means = [torch.mean(y, 2).unsqueeze(2) for y in slices]
130 | matrix = torch.cat(means, 2)
131 | blocks = [means[i].repeat(1, 1, rows[i], 1) for i in range(len(means))]
132 | block = torch.cat(blocks, 2)
133 |
134 | means = [torch.mean(y, 3).unsqueeze(3) for y in
135 | torch.split(matrix, columns, 3)]
136 | matrix = torch.cat(means, 3)
137 |
138 | block_mean = [torch.mean(y, 3).unsqueeze(3) for y in
139 | torch.split(block, columns, 3)]
140 |
141 | blocks = [block_mean[i].repeat(1, 1, 1, columns[i]) for i in
142 | range(len(block_mean))]
143 | output = torch.cat(blocks, 3)
144 | """
145 | Old version Grid pooling
146 | v_blocks = []
147 | matrix = torch.from_numpy(np.ones([b, c, len(self.h_line) - 1, len(self.v_line) - 1])).type(
148 | torch.FloatTensor).cuda()
149 | for i in range(len(self.h_line) - 1):
150 | h_blocks = []
151 | for j in range(len(self.v_line) - 1):
152 | output_block = torch.from_numpy(
153 | np.ones([b, c, self.h_line[i + 1] - self.h_line[i], self.v_line[j + 1] - self.v_line[j]])).type(
154 | torch.FloatTensor).cuda()
155 | mean = torch.mean(
156 | torch.mean(x[:, :, self.h_line[i]:self.h_line[i + 1], self.v_line[j]:self.v_line[j + 1]], 2),
157 | 2).cuda()
158 | matrix[:, :, i, j] = mean
159 | h_blocks.append(mean.unsqueeze(0).transpose(0, 2) * output_block)
160 | h_block = torch.cat(h_blocks, 3)
161 | v_blocks.append(h_block)
162 | output = torch.cat(v_blocks, 2)
163 | """
164 | return output, matrix
165 |
166 |
167 | class ProjectionNet(nn.Module):
168 | """
169 | Projection module: three parallel conv layers with dilation factors 1, 2 and 3, followed by
170 | a grid project pooling module
171 | """
172 |
173 | def __init__(self, input_channels, sigmoid=False, dropout=0.5):
174 | """
175 | Initialization of Project module
176 | Args:
177 | input_channels(int): The number of input channels of the module
178 | sigmoid(bool): Whether to compute and return the prediction matrix
179 | dropout(float): Dropout ratio
180 | """
181 | super(ProjectionNet, self).__init__()
182 | self.conv_branch1 = nn.Sequential(
183 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=1, dilation=1),
184 | nn.GroupNorm(3, 6), nn.ReLU(True))
185 | self.conv_branch2 = nn.Sequential(
186 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=2, dilation=2),
187 | nn.GroupNorm(3, 6), nn.ReLU(True))
188 | self.conv_branch3 = nn.Sequential(
189 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=3, dilation=3),
190 | nn.GroupNorm(3, 6), nn.ReLU(True))
191 | self.sigmoid = sigmoid
192 | self.project_module = ProjectionModule(18, sigmoid, dropout=dropout)
193 |
194 | def forward(self, x, arc):
195 | """
196 | Forward pass of Project module
197 | Args:
198 | x(torch.tensor): Input tensor with shape (b,c,h,w)
199 | arc(list(torch.tensor)): Grid architecture for grid project pooling,
200 | two lists of tensors giving the coordinates of the horizontal and
201 | vertical separator lines
202 | Return:
203 | output(torch.tensor): Output tensor of this module, with the same shape as
204 | the input tensor
205 | matrix(torch.tensor): An M x N matrix of grid-cell predictions, only returned
206 | when sigmoid is True.
207 | """
208 | conv_out = torch.cat(
209 | [m(x) for m in [self.conv_branch1, self.conv_branch2, self.conv_branch3]],
210 | 1)
211 | output, matrix = self.project_module(conv_out, arc)
212 | if self.sigmoid:
213 | return output, matrix
214 | else:
215 | return output
216 |
217 |
218 | class ProjectionModule(nn.Module):
219 | """
220 | Projection block
221 | """
222 |
223 | def __init__(self, input_channels, sigmoid=False, dropout=0.5):
224 | """
225 | Initialization of Project module
226 | Args:
227 | input_channels(int): The number of input channels of the module
228 | sigmoid(bool): Whether to compute and return the prediction matrix
229 | dropout(float): Dropout ratio
230 | """
231 | super(ProjectionModule, self).__init__()
232 | self.sigmoid = sigmoid
233 |
234 | self.feature_conv = nn.Sequential(
235 | nn.Conv2d(input_channels, input_channels, 1, bias=False)
236 | , nn.GroupNorm(6, input_channels), nn.ReLU(True))
237 | self.prediction_conv = nn.Sequential(nn.Dropout2d(p=dropout),
238 | nn.Conv2d(input_channels, 1, 1,
239 | bias=False))
240 |
241 | self.feature_project = GridProjectPooling()
242 | self.prediction_project = GridProjectPooling()
243 | self.sigmoid_layer = nn.Sigmoid()
244 |
245 | def forward(self, x, arch):
246 | """
247 | Forward pass of Project module
248 | Args:
249 | x(torch.tensor): Input tensor with shape (b,c,h,w)
250 | arch(list(torch.tensor)): Grid architecture for grid project pooling,
251 | two lists of tensors giving the coordinates of the horizontal and
252 | vertical separator lines
253 | Return:
254 | output(torch.tensor): Output tensor of this module, with the same shape as
255 | the input tensor
256 | matrix(torch.tensor): An M x N matrix of grid-cell predictions, or None
257 | when sigmoid is False.
258 | """
259 | base_input = x
260 | feature = self.feature_conv(base_input)
261 | feature, _ = self.feature_project(feature, arch)
262 | tensors = [base_input, feature]
263 | if self.sigmoid:
264 | prediction = self.prediction_conv(base_input)
265 | prediction, matrix = self.prediction_project(prediction, arch)
266 | prediction = self.sigmoid_layer(prediction)
267 | matrix = self.sigmoid_layer(matrix)
268 | tensors.append(prediction)
269 | output = torch.cat(tensors, 1)
270 | return output, matrix
271 | else:
272 | output = torch.cat(tensors, 1)
273 | return output, None
274 |
275 |
276 | class SharedFCN(nn.Module):
277 | """Shared fully convolution module"""
278 |
279 | def __init__(self, input_channels):
280 | """
281 | Initialization of SFCN instance
282 | Args:
283 | input_channels(int): The number of input channels of the module
284 | """
285 | super(SharedFCN, self).__init__()
286 | self.conv = nn.Sequential(
287 | nn.Sequential(
288 | nn.Conv2d(input_channels, 18, 7, stride=1, padding=3, bias=False),
289 | nn.ReLU(True)),
290 | nn.Sequential(nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False),
291 | nn.ReLU(True)),
292 | nn.MaxPool2d((2, 2)),
293 | nn.Sequential(nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False),
294 | nn.ReLU(True)),
295 | nn.Sequential(nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False),
296 | nn.ReLU(True)),
297 | nn.MaxPool2d((2, 2))
298 | )
299 |
300 | def forward(self, x):
301 | x = self.conv(x)
302 | return x
303 |
--------------------------------------------------------------------------------
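For intuition, the snippet below re-implements the grid projection pooling operation described in GridProjectPooling on the CPU, with explicit pixel-index boundaries. It is a standalone illustration of the formula above under modern PyTorch, not a call into the module itself (which expects fractional line coordinates and CUDA tensors).

import torch

def grid_project_pool(x, row_bounds, col_bounds):
    """x: (b, c, h, w); bounds are pixel indices including 0 and h (resp. w)."""
    out = torch.empty_like(x)
    cells = torch.empty(x.size(0), x.size(1), len(row_bounds) - 1, len(col_bounds) - 1)
    for i in range(len(row_bounds) - 1):
        for j in range(len(col_bounds) - 1):
            patch = x[:, :, row_bounds[i]:row_bounds[i + 1], col_bounds[j]:col_bounds[j + 1]]
            m = patch.mean(dim=(2, 3), keepdim=True)           # mean of this grid cell
            out[:, :, row_bounds[i]:row_bounds[i + 1], col_bounds[j]:col_bounds[j + 1]] = m
            cells[:, :, i, j] = m.view(x.size(0), x.size(1))
    return out, cells                                          # per-pixel map and M x N cell means

x = torch.arange(16.).reshape(1, 1, 4, 4)
pooled, matrix = grid_project_pool(x, [0, 2, 4], [0, 2, 4])
print(matrix.squeeze())  # tensor([[ 2.5,  4.5], [10.5, 12.5]])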
/modules/split_modules.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Defines several modules to build up Split model
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class SplitModel(nn.Module):
11 | """
12 | Split model, following "Deep Splitting and Merging for Table Structure Decomposition" (ICDAR 2019)
13 | """
14 |
15 | def __init__(self, input_channels):
16 | """
17 | Initialization of split model
18 | Args:
19 | input_channels(int): The number of input data channel
20 | """
21 | super(SplitModel, self).__init__()
22 | self.sfcn = SFCN(input_channels)
23 | self.rpn1 = ProjectionNet(18, 0, True, False, 0)
24 | self.rpn2 = ProjectionNet(36, 0, True, False, 0)
25 | self.rpn3 = ProjectionNet(36, 0, False, True, 0.3)
26 | self.rpn4 = ProjectionNet(37, 0, False, True, 0)
27 | # self.rpn5 = ProjectionNet(37, 0, False, True, 0)
28 |
29 | self.cpn1 = ProjectionNet(18, 1, True, False, 0)
30 | self.cpn2 = ProjectionNet(36, 1, True, False, 0)
31 | self.cpn3 = ProjectionNet(36, 1, False, True, 0.3)
32 | self.cpn4 = ProjectionNet(37, 1, False, True, 0)
33 | # self.cpn5 = ProjectionNet(37, 1, False, True, 0)
34 |
35 | self._init_weights()
36 |
37 | def _init_weights(self):
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
40 | torch.nn.init.kaiming_normal_(m.weight)
41 | if m.bias is not None:
42 | m.bias.data.fill_(0.01)
43 | elif isinstance(m, nn.BatchNorm2d):
44 | m.weight.data.fill_(1)
45 | m.bias.data.zero_()
46 | elif isinstance(m, nn.Linear):
47 | m.weight.data.normal_(0, 0.01)
48 | m.bias.data.zero_()
49 |
50 | def forward(self, x):
51 | """
52 | Forward pass of the split model
53 | Args:
54 | x(torch.tensor): Input tensor with shape (b,c,h,w)
55 | Return:
56 | output(list(torch.tensor)): output of the split model, two vectors of probabilities
57 | indicating whether each row / column position contains a separator line
58 | """
59 | feature = self.sfcn(x)
60 |
61 | row_feature = self.rpn3(self.rpn2(self.rpn1(feature)))
62 | r3 = row_feature[:, -1, :, :]
63 | # row_feature = self.rpn4(row_feature)
64 | # r4 = row_feature[:, -1, :, :]
65 | # row_feature = self.rpn5(row_feature)
66 | # r5 = row_feature[:, -1, :, :]
67 |
68 | cow_feature = self.cpn3(self.cpn2(self.cpn1(feature)))
69 | c3 = cow_feature[:, -1, :, :]
70 | # cow_feature = self.cpn4(cow_feature)
71 | # c4 = cow_feature[:, -1, :, :]
72 | # cow_feature = self.cpn5(cow_feature)
73 | # c5 = cow_feature[:, -1, :, :]
74 |
75 | # r = self.rpn1(feature)[:,-1,:,:]
76 | # c = self.cpn1(feature)[:,-1,:,:]
77 | # return (
78 | # torch.cat([r3[:, :, 0], r4[:, :, 0]], 0),#, r5[:, :, 0]
79 | # torch.cat([c3[:, 0, :], c4[:, 0, :]], 0))#, c5[:, 0, :]
80 | return (r3[:, :, 0], c3[:, 0, :])
81 |
82 |
83 | class ProjectPooling(nn.Module):
84 | """
85 | Project pooling: replace each value in the input with its row (or column) average.
86 | Row project pooling:
87 | $$ \hat F_{ij} = \frac{1}{W} \sum_{j'=1}^{W} F_{ij'} $$
88 | Column project pooling:
89 | $$ \hat F_{ij} = \frac{1}{H} \sum_{i'=1}^{H} F_{i'j} $$
90 | """
91 |
92 | def __init__(self, direction):
93 | """
94 | Initialization of Project pooling layer
95 | Args:
96 | direction(int): Specifies the direction of this layer, 0 for row and 1 for column
97 | """
98 | super(ProjectPooling, self).__init__()
99 | self.direction = direction
100 |
101 | def forward(self, x):
102 | """
103 | Forward pass of project pooling layer
104 | Args:
105 | x(torch.tensor): Input tensor with shape (b,c,h,w)
106 | Return:
107 | output: Output of the project pooling layer, with the same shape as the input tensor
108 | """
109 | b, c, h, w = x.size()
110 | output_slice = torch.from_numpy(np.ones([b, c, h, w])).type(
111 | torch.FloatTensor).cuda()
112 | if self.direction == 0:
113 | return torch.mean(x, 3).unsqueeze(3) * output_slice
114 | elif self.direction == 1:
115 | return torch.mean(x, 2).unsqueeze(2) * output_slice
116 | else:
117 | raise Exception(
118 | 'Wrong direction, the direction should be 0 for horizontal and 1 for vertical')
119 |
120 |
121 | class ProjectionNet(nn.Module):
122 | """
123 | Projection module: three parallel conv layers with dilation factors 2, 3 and 4, followed by
124 | a project pooling module
125 | """
126 |
127 | def __init__(self, input_channels, direction, max_pooling=False,
128 | sigmoid=False, dropout=0.5):
129 | super(ProjectionNet, self).__init__()
130 | self.conv_branch1 = nn.Sequential(
131 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=2, dilation=2),
132 | nn.GroupNorm(3, 6), nn.ReLU(True))
133 | self.conv_branch2 = nn.Sequential(
134 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=3, dilation=3),
135 | nn.GroupNorm(3, 6), nn.ReLU(True))
136 | self.conv_branch3 = nn.Sequential(
137 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=4, dilation=4),
138 | nn.GroupNorm(3, 6), nn.ReLU(True))
139 |
140 | self.project_module = ProjectionModule(18, direction, max_pooling, sigmoid,
141 | dropout=dropout)
142 |
143 | def forward(self, x):
144 | """
145 | Forward pass of Project module
146 | Args:
147 | x(torch.tensor): Input tensor with shape (b,c,h,w)
148 | Return:
149 | output(torch.tensor): Output tensor of this module, the shape is same with
150 | input tensor
151 | """
152 | conv_out = torch.cat(
153 | [m(x) for m in [self.conv_branch1, self.conv_branch2, self.conv_branch3]],
154 | 1)
155 | output = self.project_module(conv_out)
156 | return output
157 |
158 |
159 | class ProjectionModule(nn.Module):
160 | """
161 | Projection block
162 | """
163 |
164 | def __init__(self, input_channels, direction, max_pooling=False,
165 | sigmoid=False, dropout=0.5):
166 | """
167 | Initialization of Project module
168 | Args:
169 | input_channels(int): The number of input channels of the module
170 | direction(int): Direction of project pooling module, 0 for row, 1 for column
171 | max_pooling(bool): Whether to apply a max pooling layer in the module; a
172 | (1,2) max pooling is applied for row project pooling and a (2,1) max
173 | pooling is applied for column project pooling
174 | sigmoid(bool): Whether to append the sigmoid prediction channel to the output
175 | dropout(float): Dropout ratio
176 | """
177 | super(ProjectionModule, self).__init__()
178 | self.direction = direction
179 | self.max_pooling = max_pooling
180 | self.sigmoid = sigmoid
181 | self.max_pool = nn.MaxPool2d((1, 2)) if direction == 0 else nn.MaxPool2d(
182 | (2, 1))
183 | self.feature_conv = nn.Sequential(
184 | nn.Conv2d(input_channels, input_channels, 1, bias=False)
185 | , nn.GroupNorm(6, input_channels), nn.ReLU(True))
186 | self.prediction_conv = nn.Sequential(nn.Dropout2d(p=dropout),
187 | nn.Conv2d(input_channels, 1, 1,
188 | bias=False))
189 |
190 | self.feature_project = ProjectPooling(direction)
191 | self.prediction_project = nn.Sequential(ProjectPooling(direction),
192 | nn.Sigmoid())
193 |
194 | def forward(self, x):
195 | """
196 | Forward pass of Project module
197 | Args:
198 | x(torch.tensor): Input tensor with shape (b,c,h,w)
199 | Return:
200 | output(torch.tensor): Output tensor of this module; if max pooling is
201 | applied, the output is halved along the direction opposite to the
202 | projection direction
203 | """
204 | base_input = x
205 | if self.max_pooling:
206 | base_input = self.max_pool(x)
207 | feature = self.feature_conv(base_input)
208 | feature = self.feature_project(feature)
209 | tensors = [base_input, feature]
210 | if self.sigmoid:
211 | prediction = self.prediction_conv(base_input)
212 | prediction = self.prediction_project(prediction)
213 | tensors.append(prediction)
214 | output = torch.cat(tensors, 1)
215 | return output
216 |
217 |
218 | class SFCN(nn.Module):
219 | """
220 | Shared fully convolutional module composed of three conv layers; the last layer is
221 | a dilated conv layer with dilation factor 2
222 | """
223 |
224 | def __init__(self, input_channels):
225 | """
226 | Initialization of SFCN instance
227 | Args:
228 | input_channels(int): The number of input channels of the module
229 | """
230 | super(SFCN, self).__init__()
231 | self.conv1 = nn.Sequential(
232 | nn.Conv2d(input_channels, 18, 7, stride=1, padding=3, bias=False),
233 | nn.ReLU(True))
234 | self.conv2 = nn.Sequential(
235 | nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False),
236 | nn.ReLU(True))
237 | self.conv3 = nn.Sequential(
238 | nn.Conv2d(18, 18, 7, stride=1, padding=6, dilation=2, bias=False),
239 | nn.ReLU(True))
240 |
241 | def forward(self, x):
242 | x = self.conv1(x)
243 | x = self.conv2(x)
244 | x = self.conv3(x)
245 | return x
246 |
--------------------------------------------------------------------------------
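A quick shape check for SplitModel, assuming a CUDA device (ProjectPooling allocates its helper tensor with .cuda(), so the model cannot run on CPU as written). Input height and width should be divisible by 4 so the two max-pooling stages divide evenly; the sizes below are arbitrary.

import torch
from modules.split_modules import SplitModel

net = SplitModel(3).cuda().eval()
img = torch.rand(1, 3, 128, 96).cuda()      # (b, c, h, w)
with torch.no_grad():
    row_prob, col_prob = net(img)
# one separator probability per row and per column of the input
print(row_prob.shape, col_prob.shape)        # expected: (1, 128) and (1, 96)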
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.0
2 | numpy
3 | opencv-python
4 | Pillow
5 | argparse
6 |
--------------------------------------------------------------------------------
/split/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/split/__init__.py
--------------------------------------------------------------------------------
/split/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Script for testing Split model.
4 |
5 |
6 | import argparse
7 | import json
8 | import numpy as np
9 | import os
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 |
13 | from torch.utils.data import DataLoader
14 | from dataset.dataset import ImageDataset
15 | from modules.split_modules import SplitModel
16 | from loss.loss import bce_loss
17 |
18 |
19 | def test(opt, net, data=None):
20 | """
21 | Test script for Split model
22 | Args:
23 | opt(dic): Options
24 | net(torch.model): Split model instance
25 | data(dataloader): Dataloader or None; if None, data is loaded using the configuration in opt.
26 | Return:
27 | total_loss(torch.tensor): The mean batch loss over the dataset
28 | accuracy(float): The accuracy over the dataset
29 | """
30 | if not data:
31 | with open(opt.json_dir, 'r') as f:
32 | labels = json.load(f)
33 | dir_img = opt.img_dir
34 |
35 | test_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale)
36 | test_loader = DataLoader(test_set, batch_size=opt.batch_size, shuffle=True)
37 | else:
38 | test_loader = data
39 |
40 | loss_func = bce_loss
41 |
42 | for epoch in range(1):
43 | net.eval()
44 | epoch_loss = 0
45 | correct_count = 0
46 | count = 0
47 | times = 1
48 | for i, b in enumerate(test_loader):
49 | with torch.no_grad():
50 | img, label = b
51 | if opt.gpu:
52 | img = img.cuda()
53 | label = [x.cuda() for x in label]
54 | pred_label = net(img)
55 | loss = loss_func(pred_label, label, [0.1, 0.25, 1])
56 | epoch_loss += loss
57 | correct_count += (torch.sum(
58 | (pred_label[0] > 0.5).type(torch.IntTensor) == label[0][0].repeat(
59 | times, 1).type(
60 | torch.IntTensor)).item() + torch.sum(
61 | (pred_label[1] > 0.5).type(torch.IntTensor) == label[1][0].repeat(
62 | times, 1).type(
63 | torch.IntTensor)).item())
64 | count += label[0].view(-1).size()[0] * times + label[1].view(-1).size()[
65 | 0] * times
66 | accuracy = correct_count / (count)
67 | total_loss = epoch_loss / (i + 1)
68 | print('Validation finished ! Loss: {0} , Accuracy: {1}'.format(
69 | epoch_loss / (i + 1), accuracy))
70 | return total_loss, accuracy
71 |
72 |
73 | def model_select(opt, net):
74 | model_dir = opt.model_dir
75 | models = os.listdir(model_dir)
76 | losses = []
77 | accuracies = []
78 | for model in models:
79 | net.load_state_dict(torch.load(os.path.join(model_dir, model)))
80 | loss, accuracy = test(opt, net)
81 | losses.append(loss)
82 | accuracies.append(accuracy)
83 | min_loss_index = np.argmin(np.array(losses))
84 | max_accuracy_index = np.argmax(np.array(accuracies))
85 | print('accuracy:', max_accuracy_index, accuracies[max_accuracy_index],
86 | models[max_accuracy_index])
87 | print('losses', min_loss_index, losses[min_loss_index],
88 | models[min_loss_index])
89 |
90 | print(losses)
91 | print(accuracies)
92 |
93 |
94 | if __name__ == '__main__':
95 | parser = argparse.ArgumentParser()
96 | parser.add_argument('--batch_size', type=int, default=32,
97 | help='batch size for evaluation')
98 | parser.add_argument('--gpu', type=bool, default=True, help='whether to use GPU')
99 | parser.add_argument('--gpu_list', type=str, default='0',
100 | help='which GPU(s) to use')
101 | parser.add_argument('--model_dir', type=str, required=True,
102 | help='directory containing saved model checkpoints')
103 | parser.add_argument('--json_dir', type=str, required=True,
104 | help='labels of the data')
105 | parser.add_argument('--img_dir', type=str, required=True,
106 | help='image directory for input data')
107 | parser.add_argument('--featureW', type=int, default=8, help='width of output')
108 | parser.add_argument('--scale', type=float, default=0.5,
109 | help='scale of the image')
110 |
111 | opt = parser.parse_args()
112 |
113 | net = SplitModel(3)
114 | if opt.gpu:
115 | cudnn.benchmark = True
116 | cudnn.deterministic = True
117 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list
118 | net = torch.nn.DataParallel(net).cuda()
119 |
120 | net.load_state_dict(torch.load('saved_models/CP53.pth'))
121 | print(test(opt, net))
122 | # model_select(opt, net)
123 |
--------------------------------------------------------------------------------
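The accuracy in split/test.py is plain element-wise agreement after thresholding the row and column probabilities at 0.5. A minimal CPU illustration with made-up tensors (the real script gets these from the model and ImageDataset):

import torch

pred_rows = torch.tensor([[0.9, 0.2, 0.7, 0.4]])   # model output for one image
gt_rows   = torch.tensor([[1.0, 0.0, 0.0, 0.0]])    # ground-truth row separators

correct = torch.sum((pred_rows > 0.5).int() == gt_rows.int()).item()
total = gt_rows.numel()
print(correct / total)   # 0.75: one of the four positions is misclassified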
/split/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: craig.li(solitaire10@163.com)
3 | # Script for training split model
4 |
5 |
6 | import argparse
7 | import json
8 | import os
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 |
12 | from torch import optim
13 | from torch.utils.data import DataLoader
14 | from dataset.dataset import ImageDataset
15 | from loss.loss import bce_loss
16 | from modules.split_modules import SplitModel
17 | from split.test import test
18 |
19 |
20 | def train(opt, net):
21 | """
22 | Train the split model
23 | Args:
24 | opt(dict): Options
25 | net(torch.model): Split model instance
26 | """
27 | with open(opt.json_dir, 'r') as f:
28 | labels = json.load(f)
29 | dir_img = opt.img_dir
30 |
31 | with open(opt.val_json, 'r') as f:
32 | val_labels = json.load(f)
33 | val_img_dir = opt.val_img_dir
34 |
35 | train_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale)
36 | train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True)
37 |
38 | val_set = ImageDataset(val_img_dir, val_labels, opt.featureW, scale=opt.scale)
39 | val_loader = DataLoader(val_set, batch_size=opt.batch_size, shuffle=False)
40 |
41 | print('Data loaded!')
42 |
43 | loss_func = bce_loss
44 | optimizer = optim.Adam(net.parameters(),
45 | lr=opt.lr,
46 | weight_decay=0.001)
47 | best_accuracy = 0
48 | for epoch in range(opt.epochs):
49 | print('epoch:{}'.format(epoch + 1))
50 | net.train()
51 | epoch_loss = 0
52 | correct_count = 0
53 | count = 0
54 | for i, b in enumerate(train_loader):
55 | img, label = b
56 | if opt.gpu:
57 | img = img.cuda()
58 | label = [x.cuda() for x in label]
59 | pred_label = net(img)
60 | loss = loss_func(pred_label, label, [0.1, 0.25, 1])
61 | epoch_loss += loss
62 | optimizer.zero_grad()
63 | loss.backward()
64 | optimizer.step()
65 | times = 1
66 | correct_count += (torch.sum(
67 | (pred_label[0] > 0.5).type(torch.IntTensor) == label[0][0].repeat(times,
68 | 1).type(
69 | torch.IntTensor)).item() + torch.sum(
70 | (pred_label[1] > 0.5).type(torch.IntTensor) == label[1][0].repeat(times,
71 | 1).type(
72 | torch.IntTensor)).item())
73 | count += label[0].view(-1).size()[0] * times + label[1].view(-1).size()[
74 | 0] * times
75 | accuracy = correct_count / (count)
76 | print(
77 | 'Epoch finished ! Loss: {0} , Accuracy: {1}'.format(epoch_loss / (i + 1),
78 | accuracy))
79 | val_loss, val_acc = test(opt, net, val_loader)
80 | if val_acc > best_accuracy:
81 | best_accuracy = val_acc
82 | torch.save(net.state_dict(),
83 | opt.saved_dir + 'CP{}.pth'.format(epoch + 1))
84 |
85 |
86 | if __name__ == '__main__':
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument('--batch_size', type=int, default=1,
89 | help='batch size of the training set')
90 | parser.add_argument('--epochs', type=int, default=50, help='epochs')
91 | parser.add_argument('--gpu', type=bool, default=True, help='whether to use GPU')
92 | parser.add_argument('--gpu_list', type=str, default='0',
93 | help='which GPU(s) to use')
94 | parser.add_argument('--lr', type=float, default=0.00075,
95 | help='learning rate, default=0.00075 for Adam')
96 | parser.add_argument('--saved_dir', type=str, required=True,
97 | help='saved directory for output models')
98 | parser.add_argument('--json_dir', type=str, required=True,
99 | help='labels of the data')
100 | parser.add_argument('--img_dir', type=str, required=True,
101 | help='image directory for input data')
102 | parser.add_argument('--val_json', type=str, required=True,
103 | help='labels of the validation data')
104 | parser.add_argument('--val_img_dir', type=str, required=True,
105 | help='image directory for validation data')
106 | parser.add_argument('--featureW', type=int, default=8, help='width of output')
107 | parser.add_argument('--scale', type=float, default=0.5,
108 | help='scale of the image')
109 |
110 | opt = parser.parse_args()
111 |
112 | net = SplitModel(3)
113 | if opt.gpu:
114 | cudnn.benchmark = True
115 | cudnn.deterministic = True
116 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list
117 | net = torch.nn.DataParallel(net).cuda()
118 |
119 | if not os.path.exists(opt.saved_dir):
120 | os.mkdir(opt.saved_dir)
121 |
122 | train(opt, net)
123 |
--------------------------------------------------------------------------------
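For reference, this is roughly how train() above wires up the split-model data. The file names are placeholders, and the exact structure of each sample comes from dataset/dataset.py; the training loop only assumes that a batch yields an image tensor plus a [row, column] pair of label tensors.

import json
from torch.utils.data import DataLoader
from dataset.dataset import ImageDataset

with open('train_labels.json', 'r') as f:            # placeholder label file
    labels = json.load(f)
# featureW=8 and scale=0.5 mirror the argparse defaults; omitting mode= selects split mode
train_set = ImageDataset('train_images/', labels, 8, scale=0.5)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

img, label = next(iter(train_loader))
print(img.shape, len(label))                         # image batch and the two label tensors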