├── .gitignore ├── Merge test field.ipynb ├── Merge_predict.ipynb ├── README.md ├── Split_predict.ipynb ├── data_utils ├── Merge_data.py ├── Split_data.py └── __init__.py ├── dataset ├── __init__.py └── dataset.py ├── images ├── merge_example.jpg ├── split_input.jpg └── split_output.jpg ├── loss ├── __init__.py └── loss.py ├── merge ├── __init__.py ├── test.py └── train.py ├── modules ├── __init__.py ├── merge_modules.py └── split_modules.py ├── requirements.txt └── split ├── __init__.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.pyc 3 | /__pycache__ -------------------------------------------------------------------------------- /Merge test field.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Predict script for Merge model, predict the D and R matrices, and visualize the result." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "from modules.merge_modules import MergeModel\n", 35 | "import torch.backends.cudnn as cudnn\n", 36 | "import torch\n", 37 | "import json\n", 38 | "from dataset.dataset import ImageDataset\n", 39 | "from PIL import Image\n", 40 | "import cv2\n", 41 | "import numpy as np" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# init the Merge model\n", 51 | "net = MergeModel(3).cuda()\n", 52 | "cudnn.benchmark = True\n", 53 | "cudnn.deterministic = True\n", 54 | "net = torch.nn.DataParallel(net).cuda()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# load saved checkpoint \n", 64 | "net.load_state_dict(torch.load('Merge_model.pth'))" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "scrolled": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "# change the model to eval mode\n", 76 | "net.eval()" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# init dataset\n", 86 | "folder = 'validation'\n", 87 | "with open('D:/dataset/table/table_line/Split1/'+ folder+'_merge_dict.json', 'r') as f:\n", 88 | " labels = json.load(f)\n", 89 | "dataset = ImageDataset('D:/dataset/table/table_line/Split1/'+ folder+'_input', labels, 8, scale=0.25,mode='merge')" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "index = 0\n", 99 | "img, label, arc = dataset[index]\n", 100 | "index += 1" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "scrolled": true 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "# predict \n", 112 | "input_img = img.unsqueeze(0)\n", 113 | "arc_c = [[torch.Tensor([y]) for y in x] for x in arc]\n", 114 | "pred = net(input_img,arc_c)\n", 115 
| "u,d,l,r = pred # up, down, left, right\n", 116 | "# calculate D and R matrice, \n", 117 | "D = 0.5 * u[:, :-1, :] * d[:, 1:, :] + 0.25 * (u[:, :-1, :] + d[:, 1:, :])\n", 118 | "R = 0.5 * r[:, :, :-1] * l[:, :, 1:] + 0.25 * (r[:, :, :-1] + l[:, :, 1:])\n", 119 | "D = D[0].detach().cpu().numpy()\n", 120 | "R = R[0].detach().cpu().numpy()\n", 121 | "D[D>0.5] = 1\n", 122 | "D[D<=0.5] = 0\n", 123 | "R[R>0.5] = 1\n", 124 | "R[R<=0.5] = 0\n", 125 | "\n", 126 | "rows, columns = arc\n", 127 | "h,w = img[2].shape\n", 128 | "rows = [round(h*x) for x in rows]\n", 129 | "columns = [round(w*x) for x in columns]\n", 130 | "rows = [0] + rows + [h]\n", 131 | "columns = [0] + columns +[w]\n", 132 | "\n", 133 | "# draw lines on the original image\n", 134 | "draw_img = img[2].numpy()*255.\n", 135 | "draw_img = cv2.cvtColor(draw_img, cv2.COLOR_GRAY2RGB)\n", 136 | "for i in range(R.shape[0]):\n", 137 | " for j in range(R.shape[1]):\n", 138 | " if R[i,j] == 0:\n", 139 | " pts1 = (columns[j+1],rows[i])\n", 140 | " pts2 = (columns[j+1],rows[i+1])\n", 141 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)\n", 142 | "for i in range(D.shape[0]):\n", 143 | " for j in range(D.shape[1]):\n", 144 | " if D[i,j] == 0:\n", 145 | " pts1 = (columns[j],rows[i+1])\n", 146 | " pts2 = (columns[j+1],rows[i+1])\n", 147 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "scrolled": true 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "# visualize original image\n", 159 | "Image.fromarray(img[2].numpy()*255.).convert('L')" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "scrolled": false 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "# visualize merged image\n", 171 | "Image.fromarray(np.array(draw_img,dtype=np.uint8))" 172 | ] 173 | } 174 | ], 175 | "metadata": { 176 | "kernelspec": { 177 | "display_name": "Python 3", 178 | "language": "python", 179 | "name": "python3" 180 | }, 181 | "language_info": { 182 | "codemirror_mode": { 183 | "name": "ipython", 184 | "version": 3 185 | }, 186 | "file_extension": ".py", 187 | "mimetype": "text/x-python", 188 | "name": "python", 189 | "nbconvert_exporter": "python", 190 | "pygments_lexer": "ipython3", 191 | "version": "3.6.5" 192 | } 193 | }, 194 | "nbformat": 4, 195 | "nbformat_minor": 2 196 | } 197 | -------------------------------------------------------------------------------- /Merge_predict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Predict script for Merge model, predict the D and R matrice, and visualize resule." 
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "from modules.merge_modules import MergeModel\n", 35 | "import torch.backends.cudnn as cudnn\n", 36 | "import torch\n", 37 | "import json\n", 38 | "from dataset.dataset import ImageDataset\n", 39 | "from PIL import Image\n", 40 | "import cv2\n", 41 | "import numpy as np" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# init the Merge model\n", 51 | "net = MergeModel(3).cuda()\n", 52 | "cudnn.benchmark = True\n", 53 | "cudnn.deterministic = True\n", 54 | "net = torch.nn.DataParallel(net).cuda()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# load saved checkpoint \n", 64 | "net.load_state_dict(torch.load('Merge_model.pth'))" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "scrolled": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "# change the model to eval mode\n", 76 | "net.eval()" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# init dataset\n", 86 | "folder = 'validation'\n", 87 | "with open('D:/dataset/table/table_line/Split1/'+ folder+'_merge_dict.json', 'r') as f:\n", 88 | " labels = json.load(f)\n", 89 | "dataset = ImageDataset('D:/dataset/table/table_line/Split1/'+ folder+'_input', labels, 8, scale=0.25,mode='merge')" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "index = 0\n", 99 | "img, label, arc = dataset[index]\n", 100 | "index += 1" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "scrolled": true 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "# predict \n", 112 | "input_img = img.unsqueeze(0)\n", 113 | "arc_c = [[torch.Tensor([y]) for y in x] for x in arc]\n", 114 | "pred = net(input_img,arc_c)\n", 115 | "u,d,l,r = pred # up, down, left, right\n", 116 | "# calculate D and R matrice, \n", 117 | "D = 0.5 * u[:, :-1, :] * d[:, 1:, :] + 0.25 * (u[:, :-1, :] + d[:, 1:, :])\n", 118 | "R = 0.5 * r[:, :, :-1] * l[:, :, 1:] + 0.25 * (r[:, :, :-1] + l[:, :, 1:])\n", 119 | "D = D[0].detach().cpu().numpy()\n", 120 | "R = R[0].detach().cpu().numpy()\n", 121 | "D[D>0.5] = 1\n", 122 | "D[D<=0.5] = 0\n", 123 | "R[R>0.5] = 1\n", 124 | "R[R<=0.5] = 0\n", 125 | "\n", 126 | "rows, columns = arc\n", 127 | "h,w = img[2].shape\n", 128 | "rows = [round(h*x) for x in rows]\n", 129 | "columns = [round(w*x) for x in columns]\n", 130 | "rows = [0] + rows + [h]\n", 131 | "columns = [0] + columns +[w]\n", 132 | "\n", 133 | "# draw lines on the original image\n", 134 | "draw_img = img[2].numpy()*255.\n", 135 | "draw_img = cv2.cvtColor(draw_img, cv2.COLOR_GRAY2RGB)\n", 136 | "for i in range(R.shape[0]):\n", 137 | " for j in range(R.shape[1]):\n", 138 | " if R[i,j] == 0:\n", 139 | " pts1 = (columns[j+1],rows[i])\n", 140 | " pts2 = 
(columns[j+1],rows[i+1])\n", 141 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)\n", 142 | "for i in range(D.shape[0]):\n", 143 | " for j in range(D.shape[1]):\n", 144 | " if D[i,j] == 0:\n", 145 | " pts1 = (columns[j],rows[i+1])\n", 146 | " pts2 = (columns[j+1],rows[i+1])\n", 147 | " draw_img = cv2.line(draw_img, pts1,pts2,(255.,0,0),2)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "scrolled": true 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "# visualize original image\n", 159 | "Image.fromarray(img[2].numpy()*255.).convert('L')" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "scrolled": false 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "# visualize merged image\n", 171 | "Image.fromarray(np.array(draw_img,dtype=np.uint8))" 172 | ] 173 | } 174 | ], 175 | "metadata": { 176 | "kernelspec": { 177 | "display_name": "Python 3", 178 | "language": "python", 179 | "name": "python3" 180 | }, 181 | "language_info": { 182 | "codemirror_mode": { 183 | "name": "ipython", 184 | "version": 3 185 | }, 186 | "file_extension": ".py", 187 | "mimetype": "text/x-python", 188 | "name": "python", 189 | "nbconvert_exporter": "python", 190 | "pygments_lexer": "ipython3", 191 | "version": "3.6.5" 192 | } 193 | }, 194 | "nbformat": 4, 195 | "nbformat_minor": 2 196 | } 197 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Split&Merge: Table Recognition with Pytorch 2 | 3 | An implementation of Table Recognition Model Split&Merge in Pytorch. Split&Merge is an efficient convolutional neural network architecture for recognizing table structure from images. For more detail, please check the paper from ICDAR 2019: Deep Splitting and Merging for Table Structure Decomposition 4 | 5 | ## Usage 6 | 7 | **Clone the repo:** 8 | 9 | ``` 10 | git clone https://github.com/solitaire2015/Split_Merge_table_recognition.git 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | **Prepare the training data:** 15 | 16 | The input of the Split model and Merge model should be an image which has one channels or three channels, I used three channels image, ones channel is gray scale image, the other two are segmentation mask in vertical and horizontal direction. You can use gray scale image only by setting number of channels to 1 and dataset suffix to `.jpg`. 17 | 18 | Split model 19 | 20 | The ground truth of Split model is loaded from a `.json` file, the structure of the file is like: 21 | 22 | `````` 23 | { 24 | “img_1”:{"rows":[0,0,1,0,1,1],"columns":[1,0,0,1,1,1,1]}, 25 | “img_2”:{"rows":[0,0,1,0,1,1],"columns":[1,0,0,1,1,1,1]} 26 | } 27 | `````` 28 | 29 | Where `row` indicates if it's a line in corresponding row of the image, the length of `row` is height of the image. `columns` indicates corresponding column the length is width of the image. 
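For example, a label entry can be sanity-checked against its source image like this (a minimal sketch: the file paths are placeholders and the key names follow the example above):

``````
import json
from PIL import Image

# placeholder paths: point these at your own label file and image folder
with open('train_labels.json', 'r') as f:
    labels = json.load(f)

label = labels['img_1']
img = Image.open('train_input/img_1.jpg')
w, h = img.size

# one value per image row / column; 1 marks a row or column lying on a separator line
assert len(label['rows']) == h
assert len(label['columns']) == w
``````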
30 | 31 | Merge model 32 | 33 | The ground truth of the Merge model is loaded from a `.json` file; here is the structure: 34 | 35 | `````` 36 | { 37 | "img_1":{"rows":[0,0,1,0,1,1], 38 | "columns":[1,0,0,1,1,1,1], 39 | "h_matrix":[[0,0,0], 40 | [1,0,1]], 41 | "v_matrix":[[0,1,0], 42 | [0,0,1]]} 43 | } 44 | `````` 45 | 46 | where `h_matrix` indicates which cells should be merged with their right neighbor and `v_matrix` indicates which cells should be merged with their bottom neighbor; see the original paper for details (`h_matrix` and `v_matrix` are `R` and `D` in that paper). 47 | 48 | There are two scripts in the `data_utils` folder to generate training data. 49 | 50 | **Training** 51 | 52 | Train the Split model: 53 | 54 | `````` 55 | python split/train.py --img_dir <your image dir> --json_dir <your json file> --saved_dir <where you want to save models> --val_img_dir <your validation image dir> --val_json <your validation json file> 56 | `````` 57 | 58 | Train the Merge model: 59 | 60 | `````` 61 | python merge/train.py --img_dir <your image dir> --json_dir <your json file> --saved_dir <where you want to save models> --val_img_dir <your validation image dir> --val_json <your validation json file> 62 | `````` 63 | Pre-trained models: 64 | * Merge: https://pan.baidu.com/s/112ElDK-MSMQ50CTKpkPCZQ code:rto4 65 | 66 | * Split: https://pan.baidu.com/s/1rSO6o23WbKV6jXo2DRrQcw code:lu23 67 | 68 | Run the prediction notebooks: 69 | 70 | `````` 71 | jupyter notebook 72 | `````` 73 | 74 | Open `Split_predict.ipynb` and `Merge_predict.ipynb` in your browser. 75 | 76 | ## Result 77 | 78 | I didn't test the model on the public ICDAR table recognition competition dataset; I tested it on my private dataset and got a 97% F-score. You can test it on the ICDAR dataset yourself. 79 | 80 | Here are some examples: 81 | 82 | ## Images 83 | 84 | ![](images/split_input.jpg) 85 | 86 | Fig1. Original image 87 | 88 | ![](images/split_output.jpg) 89 | 90 | Fig2. Split result 91 | 92 | ![](images/merge_example.jpg) 93 | 94 | Fig3. One Merge result. 95 | 96 | -------------------------------------------------------------------------------- /Split_predict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Predict script for the Split model: predict the row and column separator probabilities and visualize the result."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "from dataset.dataset import ImageDataset\n", 18 | "from modules.split_modules import SplitModel\n", 19 | "import json\n", 20 | "from PIL import Image\n", 21 | "import torch\n", 22 | "from torchsummary import summary\n", 23 | "import numpy as np\n", 24 | "import cv2\n", 25 | "import json" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 9, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# load dataset\n", 44 | "folder = 'train'\n", 45 | "with open('D:/dataset/table/table_line/Split1/'+ folder+'_labels.json', 'r') as f:\n", 46 | " labels = json.load(f)\n", 47 | "dataset = ImageDataset('D:/dataset/table/table_line/Split1/'+ folder+'_input', labels, 8, scale=0.25)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 4, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# init model\n", 57 | "net = SplitModel(3)\n", 58 | "net = torch.nn.DataParallel(net).cuda()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "data": { 68 | "text/plain": [ 69 | "IncompatibleKeys(missing_keys=[], unexpected_keys=[])" 70 | ] 71 | }, 72 | "execution_count": 5, 73 | "metadata": {}, 74 | "output_type": "execute_result" 75 | } 76 | ], 77 | "source": [ 78 | "# load saved checkpoint\n", 79 | "net.load_state_dict(torch.load('split_model.pth'))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "DataParallel(\n", 91 | " (module): SplitModel(\n", 92 | " (sfcn): SFCN(\n", 93 | " (conv1): Sequential(\n", 94 | " (0): Conv2d(3, 18, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)\n", 95 | " (1): ReLU(inplace)\n", 96 | " )\n", 97 | " (conv2): Sequential(\n", 98 | " (0): Conv2d(18, 18, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)\n", 99 | " (1): ReLU(inplace)\n", 100 | " )\n", 101 | " (conv3): Sequential(\n", 102 | " (0): Conv2d(18, 18, kernel_size=(7, 7), stride=(1, 1), padding=(6, 6), dilation=(2, 2), bias=False)\n", 103 | " (1): ReLU(inplace)\n", 104 | " )\n", 105 | " )\n", 106 | " (rpn1): ProjectionNet(\n", 107 | " (conv_branch1): Sequential(\n", 108 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 109 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 110 | " (2): ReLU(inplace)\n", 111 | " )\n", 112 | " (conv_branch2): Sequential(\n", 113 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 114 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 115 | " (2): ReLU(inplace)\n", 116 | " )\n", 117 | " (conv_branch3): Sequential(\n", 118 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n", 119 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 120 | " (2): ReLU(inplace)\n", 121 | " )\n", 122 | " (project_module): ProjectionModule(\n", 123 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n", 124 | " (feature_conv): Sequential(\n", 125 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 126 | " (1): 
GroupNorm(6, 18, eps=1e-05, affine=True)\n", 127 | " (2): ReLU(inplace)\n", 128 | " )\n", 129 | " (prediction_conv): Sequential(\n", 130 | " (0): Dropout2d(p=0)\n", 131 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 132 | " )\n", 133 | " (feature_project): ProjectPooling()\n", 134 | " (prediction_project): Sequential(\n", 135 | " (0): ProjectPooling()\n", 136 | " (1): Sigmoid()\n", 137 | " )\n", 138 | " )\n", 139 | " )\n", 140 | " (rpn2): ProjectionNet(\n", 141 | " (conv_branch1): Sequential(\n", 142 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 143 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 144 | " (2): ReLU(inplace)\n", 145 | " )\n", 146 | " (conv_branch2): Sequential(\n", 147 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 148 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 149 | " (2): ReLU(inplace)\n", 150 | " )\n", 151 | " (conv_branch3): Sequential(\n", 152 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n", 153 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 154 | " (2): ReLU(inplace)\n", 155 | " )\n", 156 | " (project_module): ProjectionModule(\n", 157 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n", 158 | " (feature_conv): Sequential(\n", 159 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 160 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n", 161 | " (2): ReLU(inplace)\n", 162 | " )\n", 163 | " (prediction_conv): Sequential(\n", 164 | " (0): Dropout2d(p=0)\n", 165 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 166 | " )\n", 167 | " (feature_project): ProjectPooling()\n", 168 | " (prediction_project): Sequential(\n", 169 | " (0): ProjectPooling()\n", 170 | " (1): Sigmoid()\n", 171 | " )\n", 172 | " )\n", 173 | " )\n", 174 | " (rpn3): ProjectionNet(\n", 175 | " (conv_branch1): Sequential(\n", 176 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 177 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 178 | " (2): ReLU(inplace)\n", 179 | " )\n", 180 | " (conv_branch2): Sequential(\n", 181 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 182 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 183 | " (2): ReLU(inplace)\n", 184 | " )\n", 185 | " (conv_branch3): Sequential(\n", 186 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n", 187 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 188 | " (2): ReLU(inplace)\n", 189 | " )\n", 190 | " (project_module): ProjectionModule(\n", 191 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n", 192 | " (feature_conv): Sequential(\n", 193 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 194 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n", 195 | " (2): ReLU(inplace)\n", 196 | " )\n", 197 | " (prediction_conv): Sequential(\n", 198 | " (0): Dropout2d(p=0.3)\n", 199 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 200 | " )\n", 201 | " (feature_project): ProjectPooling()\n", 202 | " (prediction_project): Sequential(\n", 203 | " (0): ProjectPooling()\n", 204 | " (1): Sigmoid()\n", 205 | " )\n", 206 | " )\n", 207 | " )\n", 208 | " (rpn4): ProjectionNet(\n", 209 | " (conv_branch1): Sequential(\n", 210 | " (0): 
Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 211 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 212 | " (2): ReLU(inplace)\n", 213 | " )\n", 214 | " (conv_branch2): Sequential(\n", 215 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 216 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 217 | " (2): ReLU(inplace)\n", 218 | " )\n", 219 | " (conv_branch3): Sequential(\n", 220 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n", 221 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 222 | " (2): ReLU(inplace)\n", 223 | " )\n", 224 | " (project_module): ProjectionModule(\n", 225 | " (max_pool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)\n", 226 | " (feature_conv): Sequential(\n", 227 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 228 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n", 229 | " (2): ReLU(inplace)\n", 230 | " )\n", 231 | " (prediction_conv): Sequential(\n", 232 | " (0): Dropout2d(p=0)\n", 233 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 234 | " )\n", 235 | " (feature_project): ProjectPooling()\n", 236 | " (prediction_project): Sequential(\n", 237 | " (0): ProjectPooling()\n", 238 | " (1): Sigmoid()\n", 239 | " )\n", 240 | " )\n", 241 | " )\n", 242 | " (cpn1): ProjectionNet(\n", 243 | " (conv_branch1): Sequential(\n", 244 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 245 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 246 | " (2): ReLU(inplace)\n", 247 | " )\n", 248 | " (conv_branch2): Sequential(\n", 249 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 250 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 251 | " (2): ReLU(inplace)\n", 252 | " )\n", 253 | " (conv_branch3): Sequential(\n", 254 | " (0): Conv2d(18, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n", 255 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 256 | " (2): ReLU(inplace)\n", 257 | " )\n", 258 | " (project_module): ProjectionModule(\n", 259 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", 260 | " (feature_conv): Sequential(\n", 261 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 262 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n", 263 | " (2): ReLU(inplace)\n", 264 | " )\n", 265 | " (prediction_conv): Sequential(\n", 266 | " (0): Dropout2d(p=0)\n", 267 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 268 | " )\n", 269 | " (feature_project): ProjectPooling()\n", 270 | " (prediction_project): Sequential(\n", 271 | " (0): ProjectPooling()\n", 272 | " (1): Sigmoid()\n", 273 | " )\n", 274 | " )\n", 275 | " )\n", 276 | " (cpn2): ProjectionNet(\n", 277 | " (conv_branch1): Sequential(\n", 278 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 279 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 280 | " (2): ReLU(inplace)\n", 281 | " )\n", 282 | " (conv_branch2): Sequential(\n", 283 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 284 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 285 | " (2): ReLU(inplace)\n", 286 | " )\n", 287 | " (conv_branch3): Sequential(\n", 288 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), 
dilation=(4, 4))\n", 289 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 290 | " (2): ReLU(inplace)\n", 291 | " )\n", 292 | " (project_module): ProjectionModule(\n", 293 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", 294 | " (feature_conv): Sequential(\n", 295 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 296 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n", 297 | " (2): ReLU(inplace)\n", 298 | " )\n", 299 | " (prediction_conv): Sequential(\n", 300 | " (0): Dropout2d(p=0)\n", 301 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 302 | " )\n", 303 | " (feature_project): ProjectPooling()\n", 304 | " (prediction_project): Sequential(\n", 305 | " (0): ProjectPooling()\n", 306 | " (1): Sigmoid()\n", 307 | " )\n", 308 | " )\n", 309 | " )\n", 310 | " (cpn3): ProjectionNet(\n", 311 | " (conv_branch1): Sequential(\n", 312 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 313 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 314 | " (2): ReLU(inplace)\n", 315 | " )\n", 316 | " (conv_branch2): Sequential(\n", 317 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 318 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 319 | " (2): ReLU(inplace)\n", 320 | " )\n", 321 | " (conv_branch3): Sequential(\n", 322 | " (0): Conv2d(36, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n", 323 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 324 | " (2): ReLU(inplace)\n", 325 | " )\n", 326 | " (project_module): ProjectionModule(\n", 327 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", 328 | " (feature_conv): Sequential(\n", 329 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 330 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n", 331 | " (2): ReLU(inplace)\n", 332 | " )\n", 333 | " (prediction_conv): Sequential(\n", 334 | " (0): Dropout2d(p=0.3)\n", 335 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 336 | " )\n", 337 | " (feature_project): ProjectPooling()\n", 338 | " (prediction_project): Sequential(\n", 339 | " (0): ProjectPooling()\n", 340 | " (1): Sigmoid()\n", 341 | " )\n", 342 | " )\n", 343 | " )\n", 344 | " (cpn4): ProjectionNet(\n", 345 | " (conv_branch1): Sequential(\n", 346 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n", 347 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 348 | " (2): ReLU(inplace)\n", 349 | " )\n", 350 | " (conv_branch2): Sequential(\n", 351 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n", 352 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 353 | " (2): ReLU(inplace)\n", 354 | " )\n", 355 | " (conv_branch3): Sequential(\n", 356 | " (0): Conv2d(37, 6, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n", 357 | " (1): GroupNorm(3, 6, eps=1e-05, affine=True)\n", 358 | " (2): ReLU(inplace)\n", 359 | " )\n", 360 | " (project_module): ProjectionModule(\n", 361 | " (max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", 362 | " (feature_conv): Sequential(\n", 363 | " (0): Conv2d(18, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 364 | " (1): GroupNorm(6, 18, eps=1e-05, affine=True)\n", 365 | " (2): ReLU(inplace)\n", 366 | " )\n", 367 | " (prediction_conv): Sequential(\n", 368 | 
" (0): Dropout2d(p=0)\n", 369 | " (1): Conv2d(18, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 370 | " )\n", 371 | " (feature_project): ProjectPooling()\n", 372 | " (prediction_project): Sequential(\n", 373 | " (0): ProjectPooling()\n", 374 | " (1): Sigmoid()\n", 375 | " )\n", 376 | " )\n", 377 | " )\n", 378 | " )\n", 379 | ")" 380 | ] 381 | }, 382 | "execution_count": 6, 383 | "metadata": {}, 384 | "output_type": "execute_result" 385 | } 386 | ], 387 | "source": [ 388 | "# change to eval mode\n", 389 | "net.eval()" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 24, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "img, label = dataset[50]\n", 399 | "r,c = net(img.unsqueeze(0))\n", 400 | "r = r[-1]>0.5\n", 401 | "c = c[-1]>0.5\n", 402 | "c = c.cpu().detach().numpy()\n", 403 | "r = r.cpu().detach().numpy()\n", 404 | "r_im = r.reshape((-1,1))*np.ones((r.shape[0],c.shape[0]))\n", 405 | "c_im = c.reshape((1,-1))*np.ones((r.shape[0],c.shape[0]))\n", 406 | "im = cv2.bitwise_or(r_im,c_im)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 25, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "Image.fromarray(img.numpy()[2]*255.).convert('L')" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 26, 421 | "metadata": { 422 | "scrolled": true 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "Image.fromarray(im*255.).convert('L')" 427 | ] 428 | } 429 | ], 430 | "metadata": { 431 | "kernelspec": { 432 | "display_name": "Python 3", 433 | "language": "python", 434 | "name": "python3" 435 | }, 436 | "language_info": { 437 | "codemirror_mode": { 438 | "name": "ipython", 439 | "version": 3 440 | }, 441 | "file_extension": ".py", 442 | "mimetype": "text/x-python", 443 | "name": "python", 444 | "nbconvert_exporter": "python", 445 | "pygments_lexer": "ipython3", 446 | "version": "3.6.5" 447 | } 448 | }, 449 | "nbformat": 4, 450 | "nbformat_minor": 2 451 | } 452 | -------------------------------------------------------------------------------- /data_utils/Merge_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Script for generating data of Merge model.. 4 | 5 | 6 | import cv2 7 | import json 8 | import numpy as np 9 | import os 10 | 11 | from PIL import Image 12 | 13 | 14 | def make_merge_data(label, mask_img, threshold=5): 15 | """ 16 | Generates merge labels 17 | 18 | Args: 19 | label(list): Table row and column vector, same as labels of Split model. 20 | mask_img(ndarray): Mask drawed with labeled line. 21 | threshold(int): Threshold . 22 | Returns: 23 | h_matrix(ndarray): Label of Merge model data in horizontal direction. 24 | v_matrix(ndarray): Label of Merge model data in vertical direction. 25 | columns(list): Position of vertical lines. 26 | rows(list): Position of horizontal lines. 
27 | """ 28 | h_line, v_line = label 29 | rows = find_connected_line(h_line) 30 | columns = find_connected_line(v_line) 31 | 32 | h = len(h_line) 33 | w = len(v_line) 34 | mask_img = cv2.resize(mask_img, (w, h), interpolation=cv2.INTER_AREA) 35 | append_columns = [0] + columns + [w] 36 | append_rows = [0] + rows + [h] 37 | 38 | h_matrix = np.zeros((len(rows), len(columns) + 1)) 39 | v_matrix = np.zeros((len(rows) + 1, len(columns))) 40 | 41 | for i in range(len(rows)): 42 | for j in range(len(append_columns) - 1): 43 | if np.count_nonzero(mask_img[rows[i], append_columns[j]:append_columns[j + 1]] < 10) > threshold: 44 | h_matrix[i, j] = 1 45 | else: 46 | h_matrix[i, j] = 0 47 | for i in range(len(append_rows) - 1): 48 | for j in range(len(columns)): 49 | if np.count_nonzero(mask_img[append_rows[i]:append_rows[i + 1], columns[j]] < 10) > threshold: 50 | v_matrix[i, j] = 1 51 | else: 52 | v_matrix[i, j] = 0 53 | rows = [x / h for x in rows] 54 | columns = [x / w for x in columns] 55 | return h_matrix, v_matrix, rows, columns 56 | 57 | 58 | def find_connected_line(lines, threshold=5): 59 | """ 60 | Gets center of lines. 61 | 62 | Args: 63 | lines(list): A vector indicates position of lines. 64 | threshold(int): Threshold for filtering lines that too close to the border. 65 | """ 66 | length = len(lines) 67 | i = 0 68 | blocks = [] 69 | 70 | def find_end(start): 71 | end = length - 1 72 | for j in range(start, length - 1): 73 | if lines[j + 1] == 0: 74 | end = j 75 | break 76 | return end 77 | 78 | while i < length: 79 | if lines[i] == 0: 80 | i += 1 81 | else: 82 | end = find_end(i) 83 | blocks.append((i, end)) 84 | i = end + 1 85 | if len(blocks) > 0: 86 | if blocks[0][0] <= threshold: 87 | blocks.pop(0) 88 | if len(blocks) > 0: 89 | if (length - blocks[-1][1]) <= threshold: 90 | blocks.pop(-1) 91 | lines_position = [int((x[0] + x[1]) / 2) for x in blocks] 92 | return lines_position 93 | 94 | 95 | if __name__ == "__main__": 96 | mask_img_dir = 'test_out_mask_dir' 97 | with open('test_labels.json', 'r') as f: 98 | dataset = json.load(f) 99 | merge_dict = {} 100 | for id, label in dataset.items(): 101 | labels = [np.array(label['row']), np.array(label['column'])] 102 | mask_img = np.array(Image.open(os.path.join(mask_img_dir, id + '.jpg'))) 103 | h_matrix, v_matrix, rows, columns = make_merge_data(labels, mask_img) 104 | merge_dict[id] = {'h_matrix': [list(x) for x in list(h_matrix)], 105 | 'v_matrix': [list(x) for x in list(v_matrix)], 'rows': rows, 'columns': columns} 106 | with open('merge_dict.json', 'w') as f: 107 | f.write(json.dumps(merge_dict)) 108 | -------------------------------------------------------------------------------- /data_utils/Split_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Script for generating data of Split model.. 4 | 5 | 6 | import cv2 7 | import json 8 | import numpy as np 9 | import os 10 | import xml.etree.ElementTree as ET 11 | 12 | from functools import reduce 13 | from PIL import Image, ImageDraw 14 | 15 | 16 | def make_split_data(file, lines, img_dir, out_img_dir, out_mask_dir, length_threshold=20): 17 | """ 18 | Generate data for training Split model, crops and saves image blocks of table from original image, 19 | returns labels of image blocks 20 | 21 | Args: 22 | file(str): File name of original image file. 23 | img_dir(str): The directory contains image files. 24 | out_img_dir(str): A directory to save croped image blocks. 
25 | out_mask_dir(str): A directory to save masks of croped image blocks. 26 | length_threshold(int):Threshold of filtering blocks with short width or height. 27 | Returns: 28 | label_dict(dict): A directory, for each item, key is the file id and value contains 'rows' and 'columns', 29 | which are two vectors, 1 indicates there is a line in corresponding row or column. 30 | """ 31 | # loads original image 32 | im_path = os.path.join(img_dir, file) 33 | im = Image.open(im_path) 34 | im = im.convert('L') 35 | im_array = np.array(im) 36 | # init blank images for drawing masks. 37 | masks = [Image.new('L', im.size, (0,)) for i in range(4)] 38 | draws = [ImageDraw.Draw(x) for x in masks] 39 | # Calculates angle of the image according to labeled boxes. 40 | thetas = [] 41 | for line in lines: 42 | draws[int(line['type'])].polygon(line['coordinates'], fill=1) 43 | if line['type'] == 'h': # line['type'] should be 'v' or 'h' for vertical and horizontal lines 44 | up_left = line['coordinates'][:2] 45 | up_right = line['coordinates'][2:4] 46 | theta = np.arctan((up_right[1] - up_left[1]) / (up_right[0] - up_left[0])) 47 | thetas.append(theta) 48 | theta = np.average(thetas) 49 | matrix = np.array([[np.cos(-theta), -np.sin(-theta), 0], 50 | [np.sin(-theta), np.cos(-theta), 0]]) 51 | # rotates mask images and original images. 52 | masks = [cv2.warpAffine(np.array(x), matrix, np.array(x).shape[::-1]) for x in masks] 53 | mask_data = np.array([np.array(x) for x in masks]) 54 | im_array = np.array(cv2.warpAffine(im_array, matrix, im_array.shape[::-1])) 55 | mask_img = cv2.bitwise_or(cv2.bitwise_or(mask_data[0], mask_data[2]), 56 | cv2.bitwise_or(mask_data[1], mask_data[3])) * 255. 57 | mask_img = np.array(mask_img, dtype=np.uint8) 58 | 59 | # Splits different tables by connected components. 60 | num, labels, stats, centroids = cv2.connectedComponentsWithStats(mask_img, 4, cv2.CV_32S) 61 | label_dict = {} 62 | for i in range(1, num): 63 | image, contours = cv2.findContours(np.array((labels == i) * 255., dtype=np.uint8), cv2.RETR_EXTERNAL, 64 | cv2.CHAIN_APPROX_SIMPLE) 65 | x, y, w, h = cv2.boundingRect(image[0]) 66 | if w < length_threshold or h < length_threshold: 67 | continue 68 | row_start = int(y) 69 | row_end = int(y + h) 70 | column_start = int(x) 71 | column_end = int(x + w) 72 | # filters blocks 73 | if (row_end - row_start) <= 0 or (column_end - column_start) <= 0: 74 | continue 75 | row_table = cv2.bitwise_or(mask_data[0][row_start:row_end, column_start:column_end], 76 | mask_data[2][row_start:row_end, column_start:column_end]) 77 | column_table = cv2.bitwise_or(mask_data[1][row_start:row_end, column_start:column_end], 78 | mask_data[3][row_start:row_end, column_start:column_end]) 79 | # makes column and row labels. 80 | row_label = [] 81 | for row in range(row_table.shape[0]): 82 | row_label.append(1 if np.sum(row_table[row, :]) > 0 else 0) 83 | column_lable = [] 84 | for column in range(column_table.shape[1]): 85 | column_lable.append(1 if np.sum(column_table[:, column]) > 0 else 0) 86 | key = file[:-4] + '_' + str(i) 87 | label_dict[key] = {'row': row_label, 'column': column_lable} 88 | # crops image blocks of tables. 89 | mask_table = mask_img[row_start:row_end, column_start:column_end] 90 | original_img = im_array[row_start:row_end, column_start:column_end] 91 | mask_table = Image.fromarray(np.array(mask_table, dtype=np.uint8)) 92 | original_img = Image.fromarray(np.array(original_img, dtype=np.uint8)) 93 | # saves cropped blocks. 
94 | original_img.save(os.path.join(out_img_dir, key + '.jpg')) 95 | mask_table.save(os.path.join(out_mask_dir, key + '.jpg')) 96 | return label_dict 97 | 98 | 99 | def extract_lines_from_xml(root_dir, file): 100 | """ 101 | Extracts labels of lines from original xml file. 102 | You can rewrite this function to extract lines from your original label. 103 | 104 | Args: 105 | root_dir(str): The directory to the folder which contains the file. 106 | file(str): The file name of the xml file. 107 | Returns: 108 | lines(list): A list of lines' coordinates, the coordinates should be a list of 8 integers, indicates x, y 109 | coordinates of a corner in the flowing order up_left->up_right->down_right->down_left. Type should be 110 | 'v' or 'h' for vertical and horizontal lines. 111 | """ 112 | # extracts data from xml file 113 | tree = ET.parse(os.path.join(root_dir, file)) 114 | root = tree.getroot() 115 | elements = root.findall('object') 116 | # Constructs a list contains lines. 117 | lines = [] 118 | for element in elements: 119 | category = element.find('name').text 120 | bbox = element.find('bndbox') 121 | coordinates = [] + bbox.find('leftup').text.split(',') + bbox.find('rightup').text.split(',') + bbox.find( 122 | 'rightdown').text.split(',') + bbox.find('leftdown').text.split(',') 123 | coordinates = [int(x) for x in coordinates] 124 | lines.append({'type': category, 'coordinates': coordinates}) 125 | return lines 126 | 127 | 128 | def merge_dicts(d0, d1): 129 | """ 130 | Merges two directories. 131 | """ 132 | for k, v in d1.items(): 133 | d0[k] = v 134 | return d0 135 | 136 | 137 | if __name__ == "__main__": 138 | root_dir = 'root_dir' 139 | img_dir = 'img_dir' 140 | out_img_dir = 'test_out_img_dir' 141 | out_mask_dir = 'test_out_mask_dir' 142 | if not os.path.exists(out_img_dir): 143 | os.mkdir(out_img_dir) 144 | if not os.path.exists(out_mask_dir): 145 | os.mkdir(out_mask_dir) 146 | files = os.listdir(img_dir) 147 | json_path = 'test_labels.json' 148 | ids = [x[:-4] for x in files] 149 | label_dicts = [] 150 | for id in ids: 151 | file = id + '.xml' 152 | lines = extract_lines_from_xml(root_dir, file) 153 | img_name = id + '.jpg' 154 | label_dict = make_split_data(img_name, lines, img_dir, out_img_dir, out_mask_dir) 155 | label_dicts.append(label_dict) 156 | labels = reduce(merge_dicts, label_dicts) 157 | with open(json_path, 'w') as f: 158 | f.write(json.dumps(labels)) 159 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/data_utils/__init__.py -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Image dataset for training and testing Split and Merge models(ICDAR 2019 :Deep Splitting and Merging for Table Structure Decomposition). 
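# Each sample is loaded from a channel-first .npy array (default) or a .jpg image,
# rescaled by `scale` (with a lower bound of `min_width` pixels) and normalized to
# [0, 1]. The README describes the intended 3-channel input: a grayscale image plus
# horizontal and vertical segmentation masks.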
4 | 5 | import cv2 6 | import numpy as np 7 | import os 8 | import torch 9 | 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class ImageDataset(Dataset): 15 | """Image Dataset""" 16 | 17 | def __init__(self, img_dir, labels_dict, output_width, scale=0.5, 18 | min_width=40, mode='split', suffix='.npy'): 19 | """ 20 | Initialization of the dataset 21 | 22 | Args: 23 | img_dir(str): The directory of images 24 | labels_dict(dict): A dictionary stores ids of images and 25 | corresponding ground truth, which are two vectors with 26 | w and h elements, where w is the width of image and h 27 | is the height of image, indicates the probability of 28 | there is a line in that row or column. 29 | output_width(int): Defines the width of the output tensor, 30 | scale(float): The scale of resizing image. 31 | min_width(int): Specifies minimal width of resizing image. 32 | mode(str): The model should be one of 'split' and 'merge' 33 | """ 34 | self.labels_dict = labels_dict 35 | self.ids = list(labels_dict.keys()) 36 | self.nSamples = len(self.ids) 37 | self.img_dir = img_dir 38 | self.output_width = output_width 39 | self.scale = scale 40 | self.min_width = min_width 41 | self.mode = mode 42 | self.suffix = suffix 43 | 44 | def __len__(self): 45 | return self.nSamples 46 | 47 | def __getitem__(self, index): 48 | assert index <= len(self), 'index range error' 49 | id = self.ids[index] 50 | if self.suffix == '.npy': 51 | img = np.load(os.path.join(self.img_dir, id + self.suffix)) 52 | elif self.suffix == '.jpg': 53 | img = Image.open(os.path.join(self.img_dir, id + self.suffix)) 54 | img = np.array(img) 55 | c, h, w = img.shape 56 | new_h = int(self.scale * h) if int( 57 | self.scale * h) > self.min_width else self.min_width 58 | new_w = int(self.scale * w) if int( 59 | self.scale * w) > self.min_width else self.min_width 60 | img = np.array( 61 | [cv2.resize(img[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in 62 | range(c)]) 63 | img_array = np.array(img) / 255. 
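    # Two label formats: 'merge' mode returns the merge ground truth (h_matrix /
    # v_matrix) plus the normalized (0-1) grid-line positions, while 'split' mode
    # returns row / column separator vectors resized to match the rescaled image.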
64 | 65 | if self.mode == 'merge': 66 | labels = self.labels_dict[id] 67 | rows = labels['rows'] 68 | columns = labels['columns'] 69 | h_matrix = labels['h_matrix'] 70 | v_matrix = labels['v_matrix'] 71 | 72 | img_tensor = torch.from_numpy(img_array).type(torch.FloatTensor) 73 | row_label = torch.from_numpy(np.array(h_matrix)).type(torch.FloatTensor) 74 | column_label = torch.from_numpy(np.array(v_matrix)).type( 75 | torch.FloatTensor) 76 | 77 | return img_tensor, (row_label, column_label), (rows, columns) 78 | else: 79 | labels = self.labels_dict[id] 80 | row_label = labels['row'] 81 | column_label = labels['column'] 82 | 83 | # resize ground truth to proper size 84 | width = int(np.floor(new_w / self.output_width)) 85 | height = int(np.floor(new_h / self.output_width)) 86 | row_label = np.array([row_label]).T * np.ones((len(row_label), width)) 87 | column_label = np.array(column_label) * np.ones( 88 | (height, len(column_label))) 89 | 90 | row_label = np.array(row_label, dtype=np.uint8) 91 | column_label = np.array(column_label, dtype=np.uint8) 92 | 93 | row_label = cv2.resize(row_label, (width, new_h)) 94 | column_label = cv2.resize(column_label, (new_w, height)) 95 | 96 | row_label = row_label[:, 0] 97 | column_label = column_label[0, :] 98 | 99 | img_tensor = torch.from_numpy(img_array).type(torch.FloatTensor) 100 | row_label = torch.from_numpy(row_label).type(torch.FloatTensor) 101 | column_label = torch.from_numpy(column_label).type(torch.FloatTensor) 102 | 103 | return img_tensor, (row_label, column_label) 104 | -------------------------------------------------------------------------------- /images/merge_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/images/merge_example.jpg -------------------------------------------------------------------------------- /images/split_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/images/split_input.jpg -------------------------------------------------------------------------------- /images/split_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/images/split_output.jpg -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/loss/__init__.py -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Defines two loss functions which are used in split model and merge model. 
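# merge_loss first combines the Merge model's four directional probability maps
# (pu, pd, pl, pr = up, down, left, right) into pairwise merge probabilities:
#   D = 0.5 * pu[:, :-1, :] * pd[:, 1:, :] + 0.25 * (pu[:, :-1, :] + pd[:, 1:, :])   # merge with the cell below
#   R = 0.5 * pr[:, :, :-1] * pl[:, :, 1:] + 0.25 * (pr[:, :, :-1] + pl[:, :, 1:])   # merge with the cell to the right
# and then applies weighted binary cross-entropy to D and R against the ground truth.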
4 | 5 | import torch 6 | 7 | from functools import reduce 8 | 9 | 10 | def bce_loss(pred, label): 11 | """ 12 | Loss function for Split model 13 | Args: 14 | pred(torch.tensor): Prediction 15 | label(torch.tensor): Ground truth 16 | Return: 17 | loss(torch.tensor): Loss of the input image 18 | """ 19 | row_pred, column_pred = pred 20 | row_label, column_label = label 21 | 22 | criterion = torch.nn.BCELoss(torch.tensor([10.])).cuda() 23 | 24 | lr3 = criterion(row_pred[0].view(-1), row_label.view(-1)) 25 | 26 | lc3 = criterion(column_pred[0].view(-1), column_label.view(-1)) 27 | 28 | loss = lr3 + lc3 29 | 30 | return loss 31 | 32 | 33 | def merge_loss(pred, label, weight): 34 | """ 35 | Loss function for training Merge model which is Binary Cross-Entropy loss. 36 | 37 | Args: 38 | pred(torch.tensor): Prediction of the input image 39 | label(torch.tensor): Ground truth of corresponding image 40 | weight(float): Weight to balance positive and negative samples 41 | Return: 42 | loss(torch.tensor): Loss of the input image 43 | D(torch.tensor): A matrix with size (M - 1) x N, indicates the probability 44 | of two neighbor cells should be merged in vertical direction. 45 | R(torch.tensor): A matrix with size M * (N - 1), indicates the probability 46 | of two neighbor cells should be merged in horizontal direction. 47 | Where M is the height of the label, and N is the width of the label 48 | """ 49 | pu, pd, pl, pr = pred 50 | D = 0.5 * pu[:, :-1, :] * pd[:, 1:, :] + 0.25 * (pu[:, :-1, :] + pd[:, 1:, :]) 51 | R = 0.5 * pr[:, :, :-1] * pl[:, :, 1:] + 0.25 * (pr[:, :, :-1] + pl[:, :, 1:]) 52 | 53 | DT, RT = label 54 | 55 | criterion = torch.nn.BCELoss(torch.tensor([weight])).cuda() 56 | 57 | ld = criterion(D.view(-1), DT.view(-1)) 58 | lr = criterion(R.view(-1), RT.view(-1)) 59 | losses = [] 60 | if D.view(-1).shape[0] != 0: 61 | losses.append(ld) 62 | if R.view(-1).shape[0] != 0: 63 | losses.append(lr) 64 | if len(losses) == 0: 65 | loss = torch.tensor(0).cuda() 66 | else: 67 | loss = reduce(lambda x, y: x + y, losses) 68 | return loss, D, R 69 | -------------------------------------------------------------------------------- /merge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/merge/__init__.py -------------------------------------------------------------------------------- /merge/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Script for testing Merge model. 4 | 5 | 6 | import argparse 7 | import json 8 | import numpy as np 9 | import os 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | from dataset.dataset import ImageDataset 14 | from modules.merge_modules import MergeModel 15 | from loss.loss import merge_loss 16 | from torch.utils.data import DataLoader 17 | 18 | 19 | def test(opt, net, data=None): 20 | """ 21 | Test script for Merge model 22 | Args: 23 | opt(dic): Options 24 | net(torch.model): Merge model instance 25 | data(dataloader): Dataloader or None, if load data with configuration in opt. 
26 | Return: 27 | total_loss(torch.tensor): The total loss of the dataset 28 | precision(torch.tensor): Precision (TP / TP + FP) 29 | recall(torch.tensor): Recall (TP / TP + FN) 30 | f1(torch.tensor): f1 score (2 * precision * recall / (precision + recall)) 31 | """ 32 | if not data: 33 | with open(opt.json_dir, 'r') as f: 34 | labels = json.load(f) 35 | dir_img = opt.img_dir 36 | 37 | test_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale, 38 | mode='merge') 39 | test_loader = DataLoader(test_set, batch_size=opt.batch_size, shuffle=False) 40 | else: 41 | test_loader = data 42 | 43 | loss_func = merge_loss 44 | 45 | for epoch in range(1): 46 | net.eval() 47 | epoch_loss = 0 48 | number_batchs = 0 49 | total_tp = 0 50 | total_tn = 0 51 | total_fp = 0 52 | total_fn = 0 53 | for i, b in enumerate(test_loader): 54 | with torch.no_grad(): 55 | img, label, arc = b 56 | if opt.gpu: 57 | img = img.cuda() 58 | label = [x.cuda() for x in label] 59 | pred_label = net(img, arc) 60 | loss, D, R = loss_func(pred_label, label, 10.) 61 | epoch_loss += loss 62 | 63 | tp = torch.sum( 64 | ((D.view(-1)[ 65 | (label[0].view(-1) > 0.5).type(torch.ByteTensor)] > 0.5).type( 66 | torch.IntTensor) == 67 | label[0].view(-1)[ 68 | (label[0].view(-1) > 0.5).type(torch.ByteTensor)].type( 69 | torch.IntTensor))).item() + torch.sum( 70 | ((R.view(-1)[ 71 | (label[1].view(-1) > 0.5).type(torch.ByteTensor)] > 0.5).type( 72 | torch.IntTensor) == 73 | label[1].view(-1)[ 74 | (label[1].view(-1) > 0.5).type(torch.ByteTensor)].type( 75 | torch.IntTensor))).item() 76 | tn = torch.sum( 77 | ((D.view(-1)[ 78 | (label[0].view(-1) <= 0.5).type(torch.ByteTensor)] > 0.5).type( 79 | torch.IntTensor) == 80 | label[0].view(-1)[ 81 | (label[0].view(-1) <= 0.5).type(torch.ByteTensor)].type( 82 | torch.IntTensor))).item() + torch.sum( 83 | ((R.view(-1)[ 84 | (label[1].view(-1) <= 0.5).type(torch.ByteTensor)] > 0.5).type( 85 | torch.IntTensor) == 86 | label[1].view(-1)[ 87 | (label[1].view(-1) <= 0.5).type(torch.ByteTensor)].type( 88 | torch.IntTensor))).item() 89 | fn = torch.sum( 90 | (label[0].view(-1) > 0.5).type(torch.ByteTensor)).item() + torch.sum( 91 | (label[1].view(-1) > 0.5).type(torch.ByteTensor)).item() - tp 92 | fp = torch.sum( 93 | (label[0].view(-1) < 0.5).type(torch.ByteTensor)).item() + torch.sum( 94 | (label[1].view(-1) < 0.5).type(torch.ByteTensor)).item() - tn 95 | 96 | total_fn += fn 97 | total_fp += fp 98 | total_tn += tn 99 | total_tp += tp 100 | number_batchs += 1 101 | total_loss = epoch_loss / number_batchs 102 | precision = total_tp / (total_tp + total_fp) 103 | recall = total_tp / (total_tp + total_fn) 104 | f1 = 2 * precision * recall / (precision + recall) 105 | print( 106 | 'Validation finished ! 
Loss: {0} ; Precision: {1} ; Recall: {2} ; F1 Score: {3}'.format( 107 | total_loss, 108 | precision, 109 | recall, 110 | f1)) 111 | return total_loss, precision, recall, f1 112 | 113 | 114 | def model_select(opt, net): 115 | """ 116 | Select best model with highest f1 score 117 | Args: 118 | opt(dict): Options 119 | net(torch.model): Merge model instance 120 | """ 121 | model_dir = opt.model_dir 122 | models = os.listdir(model_dir) 123 | losses = [] 124 | f1s = [] 125 | for model in models: 126 | print(model) 127 | net.load_state_dict(torch.load(os.path.join(model_dir, model))) 128 | loss, precision, recall, f1 = test(opt, net) 129 | losses.append(loss) 130 | f1s.append(f1) 131 | min_loss_index = np.argmin(np.array(losses)) 132 | max_f1_index = np.argmax(np.array(f1s)) 133 | print('losses', min_loss_index, losses[min_loss_index], 134 | models[min_loss_index]) 135 | print('f1 score', max_f1_index, f1s[max_f1_index], models[max_f1_index]) 136 | 137 | 138 | if __name__ == '__main__': 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('--batch_size', type=int, default=32, 141 | help='batch size of the training set') 142 | parser.add_argument('--gpu', type=bool, default=True, help='if use gpu') 143 | parser.add_argument('--gpu_list', type=str, default='0', 144 | help='which gpu could use') 145 | parser.add_argument('--model_dir', type=str, required=True, 146 | help='saved directory for output models') 147 | parser.add_argument('--json_dir', type=str, required=True, 148 | help='labels of the data') 149 | parser.add_argument('--img_dir', type=str, required=True, 150 | help='image directory for input data') 151 | parser.add_argument('--featureW', type=int, default=8, help='width of output') 152 | parser.add_argument('--scale', type=float, default=0.5, 153 | help='scale of the image') 154 | 155 | opt = parser.parse_args() 156 | 157 | net = MergeModel(3) 158 | if opt.gpu: 159 | cudnn.benchmark = True 160 | cudnn.deterministic = True 161 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list 162 | net = torch.nn.DataParallel(net).cuda() 163 | 164 | net.load_state_dict(torch.load('merge_models_1225/CP90.pth')) 165 | print(test(opt, net)) 166 | # model_select(opt, net) 167 | -------------------------------------------------------------------------------- /merge/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Script for training merge model 4 | 5 | 6 | import argparse 7 | import json 8 | import os 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | 12 | from torch import optim 13 | from torch.utils.data import DataLoader 14 | from dataset.dataset import ImageDataset 15 | from modules.merge_modules import MergeModel 16 | from merge.test import test 17 | from loss.loss import merge_loss 18 | 19 | 20 | def train(opt, net): 21 | """ 22 | Train the merge model 23 | Args: 24 | opt(dic): Options 25 | net(torch.model): Merge model instance 26 | """ 27 | # load labels 28 | with open(opt.json_dir, 'r') as f: 29 | labels = json.load(f) 30 | dir_img = opt.img_dir 31 | 32 | with open(opt.val_json, 'r') as f: 33 | val_labels = json.load(f) 34 | val_img_dir = opt.val_img_dir 35 | 36 | train_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale, 37 | mode='merge') 38 | train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True) 39 | 40 | val_set = ImageDataset(val_img_dir, val_labels, opt.featureW, scale=opt.scale, 41 | mode='merge') 42 | 
val_loader = DataLoader(val_set, batch_size=opt.batch_size, shuffle=False) 43 | 44 | print('Data loaded!') 45 | 46 | # define loss function and optimizer 47 | loss_func = merge_loss 48 | optimizer = optim.Adam(net.parameters(), 49 | lr=opt.lr, 50 | weight_decay=0.001) 51 | best_f1 = 0 52 | for epoch in range(opt.epochs): 53 | print('epoch:{}'.format(epoch + 1)) 54 | net.train() 55 | epoch_loss = 0 56 | number_batchs = 0 57 | for i, b in enumerate(train_loader): 58 | img, label, arc = b 59 | if opt.gpu: 60 | img = img.cuda() 61 | label = [x.cuda() for x in label] 62 | pred_label = net(img, arc) 63 | loss, _, _ = loss_func(pred_label, label, 10.) 64 | if loss.requires_grad: 65 | epoch_loss += loss 66 | optimizer.zero_grad() 67 | loss.backward() 68 | optimizer.step() 69 | number_batchs += 1 70 | 71 | print('Epoch finished ! Loss: {0} '.format(epoch_loss / number_batchs)) 72 | val_loss, precision, recall, f1 = test(opt, net, val_loader) 73 | # save the model if the current f1 score is higher than the best so far 74 | if f1 > best_f1: 75 | best_f1 = f1 76 | torch.save(net.state_dict(), 77 | opt.saved_dir + 'CP{}.pth'.format(epoch + 1)) 78 | # write training information of current epoch to the log file 79 | with open(os.path.join(opt.saved_dir, 'log.txt'), 'a') as f: 80 | f.write( 81 | 'Epoch {0}, val loss : {1}, precision : {2}, recall : {3}, f1 score : {4} \n train loss : {5} \n\n'.format( 82 | epoch + 1, val_loss, precision, recall, f1, 83 | epoch_loss / number_batchs)) 84 | 85 | 86 | if __name__ == '__main__': 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--batch_size', type=int, default=1, 89 | help='batch size of the training set') 90 | parser.add_argument('--epochs', type=int, default=50, help='epochs') 91 | parser.add_argument('--gpu', type=bool, default=True, help='if use gpu') 92 | parser.add_argument('--gpu_list', type=str, default='0', 93 | help='which gpu could use') 94 | parser.add_argument('--lr', type=float, default=0.00075, 95 | help='learning rate, default=0.00075 for Adam') 96 | parser.add_argument('--saved_dir', type=str, required=True, 97 | help='saved directory for output models') 98 | parser.add_argument('--json_dir', type=str, required=True, 99 | help='labels of the data') 100 | parser.add_argument('--img_dir', type=str, required=True, 101 | help='image directory for input data') 102 | parser.add_argument('--val_json', type=str, required=True, 103 | help='labels of the validation data') 104 | parser.add_argument('--val_img_dir', type=str, required=True, 105 | help='image directory for validation data') 106 | parser.add_argument('--featureW', type=int, default=8, help='width of output') 107 | parser.add_argument('--scale', type=float, default=0.5, 108 | help='scale of the image') 109 | 110 | opt = parser.parse_args() 111 | 112 | net = MergeModel(3) 113 | if opt.gpu: 114 | cudnn.benchmark = True 115 | cudnn.deterministic = True 116 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list 117 | net = torch.nn.DataParallel(net).cuda() 118 | 119 | if not os.path.exists(opt.saved_dir): 120 | os.mkdir(opt.saved_dir) 121 | 122 | with open(os.path.join(opt.saved_dir, 'log.txt'), 'w') as f: 123 | configuration = '--batch_size: {0} \n' \ 124 | '--epochs: {1} \n' \ 125 | '--gpu: {2} \n' \ 126 | '--gpu_list: {3} \n' \ 127 | '--lr: {4} \n' \ 128 | '--saved_dir: {5} \n' \ 129 | '--json_dir: {6} \n' \ 130 | '--img_dir: {7} \n' \ 131 | '--val_json: {8} \n' \ 132 | '--val_img_dir: {9} \n' \ 133 | '--featureW: {10} \n' \ 134 | '--scale: {11} \n'.format( 135 | opt.batch_size, opt.epochs, opt.gpu, 
opt.gpu_list, opt.lr, opt.saved_dir, 136 | opt.json_dir, opt.img_dir, 137 | opt.val_json, opt.val_img_dir, opt.featureW, opt.scale) 138 | f.write(configuration + '\n\n Logs: \n') 139 | 140 | train(opt, net) 141 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/modules/__init__.py -------------------------------------------------------------------------------- /modules/merge_modules.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Defines several modules to build up Merge model 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class MergeModel(nn.Module): 10 | """ 11 | Merge model, referring to the ICDAR 2019 paper "Deep Splitting and Merging for Table Structure Decomposition" 12 | """ 13 | 14 | def __init__(self, input_channels): 15 | """ 16 | Initialization of merge model 17 | Args: 18 | input_channels(int): The number of input data channels 19 | """ 20 | super(MergeModel, self).__init__() 21 | # shared fully convolutional net 22 | self.sfcn = SharedFCN(input_channels) 23 | # four branches: up, down, left, right 24 | self.rpn1 = ProjectionNet(18, True, 0.3) 25 | self.rpn2 = ProjectionNet(36, False, 0) 26 | self.rpn3 = ProjectionNet(36, True, 0.3) 27 | 28 | self.dpn1 = ProjectionNet(18, True, 0.3) 29 | self.dpn2 = ProjectionNet(36, False, 0) 30 | self.dpn3 = ProjectionNet(36, True, 0.3) 31 | 32 | self.upn1 = ProjectionNet(18, True, 0.3) 33 | self.upn2 = ProjectionNet(36, False, 0) 34 | self.upn3 = ProjectionNet(36, True, 0.3) 35 | 36 | self.lpn1 = ProjectionNet(18, True, 0.3) 37 | self.lpn2 = ProjectionNet(36, False, 0) 38 | self.lpn3 = ProjectionNet(36, True, 0.3) 39 | 40 | self._init_weights() 41 | 42 | def _init_weights(self): 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 45 | torch.nn.init.kaiming_normal_(m.weight) 46 | if m.bias is not None: 47 | m.bias.data.fill_(0.01) 48 | elif isinstance(m, nn.BatchNorm2d): 49 | m.weight.data.fill_(1) 50 | m.bias.data.zero_() 51 | elif isinstance(m, nn.Linear): 52 | m.weight.data.normal_(0, 0.01) 53 | m.bias.data.zero_() 54 | 55 | def forward(self, x, arc): 56 | """ 57 | Forward pass of the merge model 58 | Args: 59 | x(torch.tensor): Input tensor with shape (b,c,h,w) 60 | arc(list(torch.tensor)): Grid architecture for grid project pooling, 61 | contains two lists of tensors indicating the coordinates of horizontal and 62 | vertical lines 63 | Return: 64 | output(list(torch.tensor)): Output of the merge model, a list containing four matrices, 65 | one for each direction (up, down, left, right) 66 | """ 67 | feature = self.sfcn(x) 68 | 69 | right_feature, r3 = self.rpn1(feature, 70 | arc) # self.rpn3(self.rpn2(self.rpn1(feature, arc), arc), arc) 71 | 72 | left_feature, l3 = self.lpn1(feature, 73 | arc) # self.lpn3(self.lpn2(self.lpn1(feature, arc), arc), arc) 74 | 75 | up_feature, u3 = self.upn1(feature, 76 | arc) # self.upn3(self.upn2(self.upn1(feature, arc), arc), arc) 77 | 78 | down_feature, d3 = self.dpn1(feature, 79 | arc) # self.dpn3(self.dpn2(self.dpn1(feature, arc), arc), arc) 80 | 81 | output = [u3.squeeze(1), d3.squeeze(1), l3.squeeze(1), r3.squeeze(1)] 82 | return output 83 | 84 | 85 | class GridProjectPooling(nn.Module): 86 | """ 87 | Grid project pooling, every pixel location 
replaces its value with the average of all 88 | pixels within its grid element: 89 | $$ \hat F_{ij} = \frac{1}{\lvert \Omega(i,j) \rvert} \sum_{(i',j') \in \Omega(i,j)} F_{i'j'} $$ 90 | """ 91 | 92 | def __init__(self): 93 | """ 94 | Initialization of grid project pooling module 95 | """ 96 | super(GridProjectPooling, self).__init__() 97 | 98 | def forward(self, x, architecture): 99 | """ 100 | Forward pass of this module 101 | Args: 102 | x(torch.tensor): Input tensor with shape (b, c, h, w) 103 | architecture(list(torch.tensor)): Grid architecture for grid project pooling, 104 | contains two lists of tensors indicating the coordinates of horizontal and 105 | vertical lines 106 | Return: 107 | output(torch.tensor): Output tensor of this module, with the same shape as the 108 | input tensor 109 | matrix(torch.tensor): An M x N matrix, where M and N indicate the number of 110 | grid rows and columns in the horizontal and vertical directions. 111 | """ 112 | b, c, h, w = x.size() 113 | h_line, v_line = architecture 114 | self.h_line = [torch.Tensor([0]).type( 115 | torch.DoubleTensor).cuda()] + h_line + [ 116 | torch.Tensor([1]).type(torch.DoubleTensor).cuda()] 117 | self.v_line = [torch.Tensor([0]).type( 118 | torch.DoubleTensor).cuda()] + v_line + [ 119 | torch.Tensor([1]).type(torch.DoubleTensor).cuda()] 120 | self.h_line = [(h * x).round().type(torch.IntTensor) for x in self.h_line] 121 | self.v_line = [(w * x).round().type(torch.IntTensor) for x in self.v_line] 122 | 123 | rows = [self.h_line[i + 1] - self.h_line[i] for i in 124 | range(len(self.h_line) - 1)] 125 | columns = [self.v_line[i + 1] - self.v_line[i] for i in 126 | range(len(self.v_line) - 1)] 127 | 128 | slices = torch.split(x, rows, 2) 129 | means = [torch.mean(y, 2).unsqueeze(2) for y in slices] 130 | matrix = torch.cat(means, 2) 131 | blocks = [means[i].repeat(1, 1, rows[i], 1) for i in range(len(means))] 132 | block = torch.cat(blocks, 2) 133 | 134 | means = [torch.mean(y, 3).unsqueeze(3) for y in 135 | torch.split(matrix, columns, 3)] 136 | matrix = torch.cat(means, 3) 137 | 138 | block_mean = [torch.mean(y, 3).unsqueeze(3) for y in 139 | torch.split(block, columns, 3)] 140 | 141 | blocks = [block_mean[i].repeat(1, 1, 1, columns[i]) for i in 142 | range(len(block_mean))] 143 | output = torch.cat(blocks, 3) 144 | """ 145 | Old version Grid pooling 146 | v_blocks = [] 147 | matrix = torch.from_numpy(np.ones([b, c, len(self.h_line) - 1, len(self.v_line) - 1])).type( 148 | torch.FloatTensor).cuda() 149 | for i in range(len(self.h_line) - 1): 150 | h_blocks = [] 151 | for j in range(len(self.v_line) - 1): 152 | output_block = torch.from_numpy( 153 | np.ones([b, c, self.h_line[i + 1] - self.h_line[i], self.v_line[j + 1] - self.v_line[j]])).type( 154 | torch.FloatTensor).cuda() 155 | mean = torch.mean( 156 | torch.mean(x[:, :, self.h_line[i]:self.h_line[i + 1], self.v_line[j]:self.v_line[j + 1]], 2), 157 | 2).cuda() 158 | matrix[:, :, i, j] = mean 159 | h_blocks.append(mean.unsqueeze(0).transpose(0, 2) * output_block) 160 | h_block = torch.cat(h_blocks, 3) 161 | v_blocks.append(h_block) 162 | output = torch.cat(v_blocks, 2) 163 | """ 164 | return output, matrix 165 | 166 | 167 | class ProjectionNet(nn.Module): 168 | """ 169 | Projection module containing three parallel conv layers with dilation factors 1, 2 and 3, followed by 170 | a grid project pooling module 171 | """ 172 | 173 | def __init__(self, input_channels, sigmoid=False, dropout=0.5): 174 | """ 175 | Initialization of Projection module 176 | Args: 177 | input_channels(int): The number of input channels of the 
module 178 | sigmoid(bool): Whether to produce the output matrix 179 | dropout(float): Dropout ratio 180 | """ 181 | super(ProjectionNet, self).__init__() 182 | self.conv_branch1 = nn.Sequential( 183 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=1, dilation=1), 184 | nn.GroupNorm(3, 6), nn.ReLU(True)) 185 | self.conv_branch2 = nn.Sequential( 186 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=2, dilation=2), 187 | nn.GroupNorm(3, 6), nn.ReLU(True)) 188 | self.conv_branch3 = nn.Sequential( 189 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=3, dilation=3), 190 | nn.GroupNorm(3, 6), nn.ReLU(True)) 191 | self.sigmoid = sigmoid 192 | self.project_module = ProjectionModule(18, sigmoid, dropout=dropout) 193 | 194 | def forward(self, x, arc): 195 | """ 196 | Forward pass of the Projection module 197 | Args: 198 | x(torch.tensor): Input tensor with shape (b,c,h,w) 199 | arc(list(torch.tensor)): Grid architecture for grid project pooling, 200 | contains two lists of tensors indicating the coordinates of horizontal and 201 | vertical lines 202 | Return: 203 | output(torch.tensor): Output tensor of this module, with the same shape as the 204 | input tensor 205 | matrix(torch.tensor): An M x N matrix, where M and N indicate the number of 206 | grid rows and columns in the horizontal and vertical directions. 207 | """ 208 | conv_out = torch.cat( 209 | [m(x) for m in [self.conv_branch1, self.conv_branch2, self.conv_branch3]], 210 | 1) 211 | output, matrix = self.project_module(conv_out, arc) 212 | if self.sigmoid: 213 | return output, matrix 214 | else: 215 | return output 216 | 217 | 218 | class ProjectionModule(nn.Module): 219 | """ 220 | Projection block 221 | """ 222 | 223 | def __init__(self, input_channels, sigmoid=False, dropout=0.5): 224 | """ 225 | Initialization of Projection module 226 | Args: 227 | input_channels(int): The number of input channels of the module 228 | sigmoid(bool): Whether to produce the output matrix 229 | dropout(float): Dropout ratio 230 | """ 231 | super(ProjectionModule, self).__init__() 232 | self.sigmoid = sigmoid 233 | 234 | self.feature_conv = nn.Sequential( 235 | nn.Conv2d(input_channels, input_channels, 1, bias=False) 236 | , nn.GroupNorm(6, input_channels), nn.ReLU(True)) 237 | self.prediction_conv = nn.Sequential(nn.Dropout2d(p=dropout), 238 | nn.Conv2d(input_channels, 1, 1, 239 | bias=False)) 240 | 241 | self.feature_project = GridProjectPooling() 242 | self.prediction_project = GridProjectPooling() 243 | self.sigmoid_layer = nn.Sigmoid() 244 | 245 | def forward(self, x, arch): 246 | """ 247 | Forward pass of the Projection module 248 | Args: 249 | x(torch.tensor): Input tensor with shape (b,c,h,w) 250 | arch(list(torch.tensor)): Grid architecture for grid project pooling, 251 | contains two lists of tensors indicating the coordinates of horizontal and 252 | vertical lines 253 | Return: 254 | output(torch.tensor): Output tensor of this module, with the same shape as the 255 | input tensor 256 | matrix(torch.tensor): An M x N matrix, where M and N indicate the number of 257 | grid rows and columns in the horizontal and vertical directions. 
258 | """ 259 | base_input = x 260 | feature = self.feature_conv(base_input) 261 | feature, _ = self.feature_project(feature, arch) 262 | tensors = [base_input, feature] 263 | if self.sigmoid: 264 | prediction = self.prediction_conv(base_input) 265 | prediction, matrix = self.prediction_project(prediction, arch) 266 | prediction = self.sigmoid_layer(prediction) 267 | matrix = self.sigmoid_layer(matrix) 268 | tensors.append(prediction) 269 | output = torch.cat(tensors, 1) 270 | return output, matrix 271 | else: 272 | output = torch.cat(tensors, 1) 273 | return output, None 274 | 275 | 276 | class SharedFCN(nn.Module): 277 | """Shared fully convolutional module""" 278 | 279 | def __init__(self, input_channels): 280 | """ 281 | Initialization of SFCN instance 282 | Args: 283 | input_channels(int): The number of input channels of the module 284 | """ 285 | super(SharedFCN, self).__init__() 286 | self.conv = nn.Sequential( 287 | nn.Sequential( 288 | nn.Conv2d(input_channels, 18, 7, stride=1, padding=3, bias=False), 289 | nn.ReLU(True)), 290 | nn.Sequential(nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False), 291 | nn.ReLU(True)), 292 | nn.MaxPool2d((2, 2)), 293 | nn.Sequential(nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False), 294 | nn.ReLU(True)), 295 | nn.Sequential(nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False), 296 | nn.ReLU(True)), 297 | nn.MaxPool2d((2, 2)) 298 | ) 299 | 300 | def forward(self, x): 301 | x = self.conv(x) 302 | return x 303 | -------------------------------------------------------------------------------- /modules/split_modules.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Defines several modules to build up Split model 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class SplitModel(nn.Module): 11 | """ 12 | Split model, referring to the ICDAR 2019 paper "Deep Splitting and Merging for Table Structure Decomposition" 13 | """ 14 | 15 | def __init__(self, input_channels): 16 | """ 17 | Initialization of split model 18 | Args: 19 | input_channels(int): The number of input data channels 20 | """ 21 | super(SplitModel, self).__init__() 22 | self.sfcn = SFCN(input_channels) 23 | self.rpn1 = ProjectionNet(18, 0, True, False, 0) 24 | self.rpn2 = ProjectionNet(36, 0, True, False, 0) 25 | self.rpn3 = ProjectionNet(36, 0, False, True, 0.3) 26 | self.rpn4 = ProjectionNet(37, 0, False, True, 0) 27 | # self.rpn5 = ProjectionNet(37, 0, False, True, 0) 28 | 29 | self.cpn1 = ProjectionNet(18, 1, True, False, 0) 30 | self.cpn2 = ProjectionNet(36, 1, True, False, 0) 31 | self.cpn3 = ProjectionNet(36, 1, False, True, 0.3) 32 | self.cpn4 = ProjectionNet(37, 1, False, True, 0) 33 | # self.cpn5 = ProjectionNet(37, 1, False, True, 0) 34 | 35 | self._init_weights() 36 | 37 | def _init_weights(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 40 | torch.nn.init.kaiming_normal_(m.weight) 41 | if m.bias is not None: 42 | m.bias.data.fill_(0.01) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | elif isinstance(m, nn.Linear): 47 | m.weight.data.normal_(0, 0.01) 48 | m.bias.data.zero_() 49 | 50 | def forward(self, x): 51 | """ 52 | Forward pass of the split model 53 | Args: 54 | x(torch.tensor): Input tensor with shape (b,c,h,w) 55 | Return: 56 | output(list(torch.tensor)): Output of the split model, two vectors indicating whether there 57 | is a line in 
the horizontal and vertical directions 58 | """ 59 | feature = self.sfcn(x) 60 | 61 | row_feature = self.rpn3(self.rpn2(self.rpn1(feature))) 62 | r3 = row_feature[:, -1, :, :] 63 | # row_feature = self.rpn4(row_feature) 64 | # r4 = row_feature[:, -1, :, :] 65 | # row_feature = self.rpn5(row_feature) 66 | # r5 = row_feature[:, -1, :, :] 67 | 68 | col_feature = self.cpn3(self.cpn2(self.cpn1(feature))) 69 | c3 = col_feature[:, -1, :, :] 70 | # col_feature = self.cpn4(col_feature) 71 | # c4 = col_feature[:, -1, :, :] 72 | # col_feature = self.cpn5(col_feature) 73 | # c5 = col_feature[:, -1, :, :] 74 | 75 | # r = self.rpn1(feature)[:,-1,:,:] 76 | # c = self.cpn1(feature)[:,-1,:,:] 77 | # return ( 78 | # torch.cat([r3[:, :, 0], r4[:, :, 0]], 0),#, r5[:, :, 0] 79 | # torch.cat([c3[:, 0, :], c4[:, 0, :]], 0))#, c5[:, 0, :] 80 | return (r3[:, :, 0], c3[:, 0, :]) 81 | 82 | 83 | class ProjectPooling(nn.Module): 84 | """ 85 | Project pooling, replaces each value in the input with its row (column) average, 86 | Row project pooling: 87 | $$ \hat F_{ij} = \frac{1}{W} \sum_{j'=1}^{W} F_{ij'} $$ 88 | Column project pooling: 89 | $$ \hat F_{ij} = \frac{1}{H} \sum_{i'=1}^{H} F_{i'j} $$ 90 | """ 91 | 92 | def __init__(self, direction): 93 | """ 94 | Initialization of Project pooling layer 95 | Args: 96 | direction(int): Specifies the direction of this layer, 0 for row and 1 for column 97 | """ 98 | super(ProjectPooling, self).__init__() 99 | self.direction = direction 100 | 101 | def forward(self, x): 102 | """ 103 | Forward pass of project pooling layer 104 | Args: 105 | x(torch.tensor): Input tensor with shape (b,c,h,w) 106 | Return: 107 | output: Output of Project pooling layer with the same shape as the input tensor 108 | """ 109 | b, c, h, w = x.size() 110 | output_slice = torch.from_numpy(np.ones([b, c, h, w])).type( 111 | torch.FloatTensor).cuda() 112 | if self.direction == 0: 113 | return torch.mean(x, 3).unsqueeze(3) * output_slice 114 | elif self.direction == 1: 115 | return torch.mean(x, 2).unsqueeze(2) * output_slice 116 | else: 117 | raise Exception( 118 | 'Wrong direction, the direction should be 0 for row and 1 for column') 119 | 120 | 121 | class ProjectionNet(nn.Module): 122 | """ 123 | Projection module containing three parallel conv layers with dilation factors 2, 3 and 4, followed by 124 | a project pooling module 125 | """ 126 | 127 | def __init__(self, input_channels, direction, max_pooling=False, 128 | sigmoid=False, dropout=0.5): 129 | super(ProjectionNet, self).__init__() 130 | self.conv_branch1 = nn.Sequential( 131 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=2, dilation=2), 132 | nn.GroupNorm(3, 6), nn.ReLU(True)) 133 | self.conv_branch2 = nn.Sequential( 134 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=3, dilation=3), 135 | nn.GroupNorm(3, 6), nn.ReLU(True)) 136 | self.conv_branch3 = nn.Sequential( 137 | nn.Conv2d(input_channels, 6, 3, stride=1, padding=4, dilation=4), 138 | nn.GroupNorm(3, 6), nn.ReLU(True)) 139 | 140 | self.project_module = ProjectionModule(18, direction, max_pooling, sigmoid, 141 | dropout=dropout) 142 | 143 | def forward(self, x): 144 | """ 145 | Forward pass of the Projection module 146 | Args: 147 | x(torch.tensor): Input tensor with shape (b,c,h,w) 148 | Return: 149 | output(torch.tensor): Output tensor of this module, with the same shape as the 150 | input tensor 151 | """ 152 | conv_out = torch.cat( 153 | [m(x) for m in [self.conv_branch1, self.conv_branch2, self.conv_branch3]], 154 | 1) 155 | output = self.project_module(conv_out) 156 | return output 157 | 
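# Illustrative sketch (not part of the original split_modules.py): a tiny, self-contained
# demonstration of what ProjectPooling computes. The example tensor below is made up;
# row project pooling replaces every value with its row mean and column project pooling
# with its column mean, i.e. a mean over one spatial dimension broadcast back to the
# original (b, c, h, w) shape.
def _project_pooling_demo():
    import torch
    x = torch.tensor([[[[1., 2., 3.],
                        [4., 5., 6.]]]])  # shape (b=1, c=1, h=2, w=3)
    row_pooled = torch.mean(x, 3, keepdim=True).expand_as(x)
    col_pooled = torch.mean(x, 2, keepdim=True).expand_as(x)
    # row_pooled[0, 0] -> [[2., 2., 2.], [5., 5., 5.]]  (every entry becomes its row mean)
    # col_pooled[0, 0] -> [[2.5, 3.5, 4.5], [2.5, 3.5, 4.5]]  (every entry becomes its column mean)
    return row_pooled, col_pooled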
158 | 159 | class ProjectionModule(nn.Module): 160 | """ 161 | Projection block 162 | """ 163 | 164 | def __init__(self, input_channels, direction, max_pooling=False, 165 | sigmoid=False, dropout=0.5): 166 | """ 167 | Initialization of Projection module 168 | Args: 169 | input_channels(int): The number of input channels of the module 170 | direction(int): Direction of project pooling module, 0 for row, 1 for column 171 | max_pooling(bool): Whether to apply a max pooling layer in the module; for a 172 | row project pooling layer, a (1,2) max pooling layer is applied, and a 173 | (2,1) max pooling is applied for a column project pooling layer 174 | sigmoid(bool): Whether to produce the prediction output 175 | dropout(float): Dropout ratio 176 | """ 177 | super(ProjectionModule, self).__init__() 178 | self.direction = direction 179 | self.max_pooling = max_pooling 180 | self.sigmoid = sigmoid 181 | self.max_pool = nn.MaxPool2d((1, 2)) if direction == 0 else nn.MaxPool2d( 182 | (2, 1)) 183 | self.feature_conv = nn.Sequential( 184 | nn.Conv2d(input_channels, input_channels, 1, bias=False) 185 | , nn.GroupNorm(6, input_channels), nn.ReLU(True)) 186 | self.prediction_conv = nn.Sequential(nn.Dropout2d(p=dropout), 187 | nn.Conv2d(input_channels, 1, 1, 188 | bias=False)) 189 | 190 | self.feature_project = ProjectPooling(direction) 191 | self.prediction_project = nn.Sequential(ProjectPooling(direction), 192 | nn.Sigmoid()) 193 | 194 | def forward(self, x): 195 | """ 196 | Forward pass of the Projection module 197 | Args: 198 | x(torch.tensor): Input tensor with shape (b,c,h,w) 199 | Return: 200 | output(torch.tensor): Output tensor of this module; if a max pooling layer is 201 | applied, the output shape is halved relative to the original shape 202 | in the opposite direction 203 | """ 204 | base_input = x 205 | if self.max_pooling: 206 | base_input = self.max_pool(x) 207 | feature = self.feature_conv(base_input) 208 | feature = self.feature_project(feature) 209 | tensors = [base_input, feature] 210 | if self.sigmoid: 211 | prediction = self.prediction_conv(base_input) 212 | prediction = self.prediction_project(prediction) 213 | tensors.append(prediction) 214 | output = torch.cat(tensors, 1) 215 | return output 216 | 217 | 218 | class SFCN(nn.Module): 219 | """ 220 | Shared fully convolutional module composed of three conv layers; the last conv layer is 221 | a dilated conv layer with dilation factor 2 222 | """ 223 | 224 | def __init__(self, input_channels): 225 | """ 226 | Initialization of SFCN instance 227 | Args: 228 | input_channels(int): The number of input channels of the module 229 | """ 230 | super(SFCN, self).__init__() 231 | self.conv1 = nn.Sequential( 232 | nn.Conv2d(input_channels, 18, 7, stride=1, padding=3, bias=False), 233 | nn.ReLU(True)) 234 | self.conv2 = nn.Sequential( 235 | nn.Conv2d(18, 18, 7, stride=1, padding=3, bias=False), 236 | nn.ReLU(True)) 237 | self.conv3 = nn.Sequential( 238 | nn.Conv2d(18, 18, 7, stride=1, padding=6, dilation=2, bias=False), 239 | nn.ReLU(True)) 240 | 241 | def forward(self, x): 242 | x = self.conv1(x) 243 | x = self.conv2(x) 244 | x = self.conv3(x) 245 | return x 246 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | numpy 3 | opencv-python 4 | Pillow 5 | argparse 6 | -------------------------------------------------------------------------------- /split/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fireae/Split_Merge_table_recognition/a5c270cd0c3dc5cc52a8069356219764e587be37/split/__init__.py -------------------------------------------------------------------------------- /split/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Script for testing Split model. 4 | 5 | 6 | import argparse 7 | import json 8 | import numpy as np 9 | import os 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | from torch.utils.data import DataLoader 14 | from dataset.dataset import ImageDataset 15 | from modules.split_modules import SplitModel 16 | from loss.loss import bce_loss 17 | 18 | 19 | def test(opt, net, data=None): 20 | """ 21 | Test script for Split model 22 | Args: 23 | opt(dic): Options 24 | net(torch.model): Split model instance 25 | data(dataloader): Dataloader or None; if None, the data is loaded with the configuration in opt. 26 | Return: 27 | total_loss(torch.tensor): The total loss of the dataset 28 | accuracy(torch.tensor): The accuracy on the dataset 29 | """ 30 | if not data: 31 | with open(opt.json_dir, 'r') as f: 32 | labels = json.load(f) 33 | dir_img = opt.img_dir 34 | 35 | test_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale) 36 | test_loader = DataLoader(test_set, batch_size=opt.batch_size, shuffle=True) 37 | else: 38 | test_loader = data 39 | 40 | loss_func = bce_loss 41 | 42 | for epoch in range(1): 43 | net.eval() 44 | epoch_loss = 0 45 | correct_count = 0 46 | count = 0 47 | times = 1 48 | for i, b in enumerate(test_loader): 49 | with torch.no_grad(): 50 | img, label = b 51 | if opt.gpu: 52 | img = img.cuda() 53 | label = [x.cuda() for x in label] 54 | pred_label = net(img) 55 | loss = loss_func(pred_label, label, [0.1, 0.25, 1]) 56 | epoch_loss += loss 57 | correct_count += (torch.sum( 58 | (pred_label[0] > 0.5).type(torch.IntTensor) == label[0][0].repeat( 59 | times, 1).type( 60 | torch.IntTensor)).item() + torch.sum( 61 | (pred_label[1] > 0.5).type(torch.IntTensor) == label[1][0].repeat( 62 | times, 1).type( 63 | torch.IntTensor)).item()) 64 | count += label[0].view(-1).size()[0] * times + label[1].view(-1).size()[ 65 | 0] * times 66 | accuracy = correct_count / (count) 67 | total_loss = epoch_loss / (i + 1) 68 | print('Validation finished ! 
Loss: {0} , Accuracy: {1}'.format( 69 | epoch_loss / (i + 1), accuracy)) 70 | return total_loss, accuracy 71 | 72 | 73 | def model_select(opt, net): 74 | model_dir = opt.model_dir 75 | models = os.listdir(model_dir) 76 | losses = [] 77 | accuracies = [] 78 | for model in models: 79 | net.load_state_dict(torch.load(os.path.join(model_dir, model))) 80 | loss, accuracy = test(opt, net) 81 | losses.append(loss) 82 | accuracies.append(accuracy) 83 | min_loss_index = np.argmin(np.array(losses)) 84 | max_accuracy_index = np.argmax(np.array(accuracies)) 85 | print('accuracy:', max_accuracy_index, accuracies[max_accuracy_index], 86 | models[max_accuracy_index]) 87 | print('losses', min_loss_index, losses[min_loss_index], 88 | models[min_loss_index]) 89 | 90 | print(losses) 91 | print(accuracies) 92 | 93 | 94 | if __name__ == '__main__': 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('--batch_size', type=int, default=32, 97 | help='batch size of the training set') 98 | parser.add_argument('--gpu', type=bool, default=True, help='if use gpu') 99 | parser.add_argument('--gpu_list', type=str, default='0', 100 | help='which gpu could use') 101 | parser.add_argument('--model_dir', type=str, required=True, 102 | help='saved directory for output models') 103 | parser.add_argument('--json_dir', type=str, required=True, 104 | help='labels of the data') 105 | parser.add_argument('--img_dir', type=str, required=True, 106 | help='image directory for input data') 107 | parser.add_argument('--featureW', type=int, default=8, help='width of output') 108 | parser.add_argument('--scale', type=float, default=0.5, 109 | help='scale of the image') 110 | 111 | opt = parser.parse_args() 112 | 113 | net = SplitModel(3) 114 | if opt.gpu: 115 | cudnn.benchmark = True 116 | cudnn.deterministic = True 117 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list 118 | net = torch.nn.DataParallel(net).cuda() 119 | 120 | net.load_state_dict(torch.load('saved_models/CP53.pth')) 121 | print(test(opt, net)) 122 | # model_select(opt, net) 123 | -------------------------------------------------------------------------------- /split/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: craig.li(solitaire10@163.com) 3 | # Script for training Split model 4 | 5 | 6 | import argparse 7 | import json 8 | import os 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | 12 | from torch import optim 13 | from torch.utils.data import DataLoader 14 | from dataset.dataset import ImageDataset 15 | from loss.loss import bce_loss 16 | from modules.split_modules import SplitModel 17 | from split.test import test 18 | 19 | 20 | def train(opt, net): 21 | """ 22 | Train the split model 23 | Args: 24 | opt(dic): Options 25 | net(torch.model): Split model instance 26 | """ 27 | with open(opt.json_dir, 'r') as f: 28 | labels = json.load(f) 29 | dir_img = opt.img_dir 30 | 31 | with open(opt.val_json, 'r') as f: 32 | val_labels = json.load(f) 33 | val_img_dir = opt.val_img_dir 34 | 35 | train_set = ImageDataset(dir_img, labels, opt.featureW, scale=opt.scale) 36 | train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True) 37 | 38 | val_set = ImageDataset(val_img_dir, val_labels, opt.featureW, scale=opt.scale) 39 | val_loader = DataLoader(val_set, batch_size=opt.batch_size, shuffle=False) 40 | 41 | print('Data loaded!') 42 | 43 | loss_func = bce_loss 44 | optimizer = optim.Adam(net.parameters(), 45 | lr=opt.lr, 46 | weight_decay=0.001) 47 | 
best_accuracy = 0 48 | for epoch in range(opt.epochs): 49 | print('epoch:{}'.format(epoch + 1)) 50 | net.train() 51 | epoch_loss = 0 52 | correct_count = 0 53 | count = 0 54 | for i, b in enumerate(train_loader): 55 | img, label = b 56 | if opt.gpu: 57 | img = img.cuda() 58 | label = [x.cuda() for x in label] 59 | pred_label = net(img) 60 | loss = loss_func(pred_label, label, [0.1, 0.25, 1]) 61 | epoch_loss += loss 62 | optimizer.zero_grad() 63 | loss.backward() 64 | optimizer.step() 65 | times = 1 66 | correct_count += (torch.sum( 67 | (pred_label[0] > 0.5).type(torch.IntTensor) == label[0][0].repeat(times, 68 | 1).type( 69 | torch.IntTensor)).item() + torch.sum( 70 | (pred_label[1] > 0.5).type(torch.IntTensor) == label[1][0].repeat(times, 71 | 1).type( 72 | torch.IntTensor)).item()) 73 | count += label[0].view(-1).size()[0] * times + label[1].view(-1).size()[ 74 | 0] * times 75 | accuracy = correct_count / (count) 76 | print( 77 | 'Epoch finished ! Loss: {0} , Accuracy: {1}'.format(epoch_loss / (i + 1), 78 | accuracy)) 79 | val_loss, val_acc = test(opt, net, val_loader) 80 | if val_acc > best_accuracy: 81 | best_accuracy = val_acc 82 | torch.save(net.state_dict(), 83 | opt.saved_dir + 'CP{}.pth'.format(epoch + 1)) 84 | 85 | 86 | if __name__ == '__main__': 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--batch_size', type=int, default=1, 89 | help='batch size of the training set') 90 | parser.add_argument('--epochs', type=int, default=50, help='epochs') 91 | parser.add_argument('--gpu', type=bool, default=True, help='if use gpu') 92 | parser.add_argument('--gpu_list', type=str, default='0', 93 | help='which gpu could use') 94 | parser.add_argument('--lr', type=float, default=0.00075, 95 | help='learning rate, default=0.00075 for Adam') 96 | parser.add_argument('--saved_dir', type=str, required=True, 97 | help='saved directory for output models') 98 | parser.add_argument('--json_dir', type=str, required=True, 99 | help='labels of the data') 100 | parser.add_argument('--img_dir', type=str, required=True, 101 | help='image directory for input data') 102 | parser.add_argument('--val_json', type=str, required=True, 103 | help='labels of the validation data') 104 | parser.add_argument('--val_img_dir', type=str, required=True, 105 | help='image directory for validation data') 106 | parser.add_argument('--featureW', type=int, default=8, help='width of output') 107 | parser.add_argument('--scale', type=float, default=0.5, 108 | help='scale of the image') 109 | 110 | opt = parser.parse_args() 111 | 112 | net = SplitModel(3) 113 | if opt.gpu: 114 | cudnn.benchmark = True 115 | cudnn.deterministic = True 116 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_list 117 | net = torch.nn.DataParallel(net).cuda() 118 | 119 | if not os.path.exists(opt.saved_dir): 120 | os.mkdir(opt.saved_dir) 121 | 122 | train(opt, net) 123 | --------------------------------------------------------------------------------
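A minimal inference sketch (an assumption-laden illustration, not an official script from this repository): it mirrors how split/test.py builds and loads the Split model, then applies it to a single image. The checkpoint path, image path, and the resize/normalization below are placeholders, since the actual preprocessing lives in dataset/dataset.py, which is not reproduced above.

import cv2
import torch
from modules.split_modules import SplitModel

# build and load the model the same way split/test.py does (DataParallel + CUDA)
net = torch.nn.DataParallel(SplitModel(3)).cuda()
net.load_state_dict(torch.load('saved_models/CP53.pth'))  # placeholder checkpoint path
net.eval()

# placeholder preprocessing: scale the image by the default --scale 0.5 and map to [0, 1]
img = cv2.imread('table.png')  # placeholder input image
img = cv2.resize(img, (0, 0), fx=0.5, fy=0.5)
tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0).cuda() / 255.

with torch.no_grad():
    rows, columns = net(tensor)  # per-row and per-column separator probabilities

# threshold at 0.5, as in split/test.py, to get predicted separator positions
row_separators = (rows[0] > 0.5).nonzero().view(-1).tolist()
column_separators = (columns[0] > 0.5).nonzero().view(-1).tolist()
print(row_separators, column_separators)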