├── utils ├── __init__.py ├── add_size_to_annotations.py ├── fast_inverse.py ├── augmentation.py ├── string_utils.py ├── util.py └── transformation_utils.py ├── model ├── modules │ └── __init__.py ├── __init__.py ├── csrc │ ├── cpu │ │ ├── vision.h │ │ └── ROIAlign_cpu.cpp │ ├── vision.cpp │ ├── cuda │ │ ├── vision.h │ │ └── ROIAlign_cuda.cu │ └── ROIAlign.h ├── model.py ├── binary_pair_net.py ├── loss.py ├── roi_align.py ├── simpleNN.py ├── metric.py ├── coordconv.py ├── binary_pair_real.py ├── vgg.py ├── optimize.py └── yolo_box_detector.py ├── logger ├── __init__.py └── logger.py ├── data_loader ├── __init__.py └── data_loaders.py ├── .gitignore ├── base ├── __init__.py ├── base_model.py └── base_data_loader.py ├── .flake8 ├── evaluators ├── __init__.py └── draw_graph.py ├── trainer ├── __init__.py ├── trainer.py └── feature_pair_trainer.py ├── config.json ├── test_anchors.json ├── cf_no_vis_pairing.json ├── anchors_noRot_new_25.json ├── setup.py ├── cf_baseline_detector.json ├── notes.txt ├── cf_detector.json ├── cf_test_no_vis_pairing.json ├── cf_pairing.json ├── datasets ├── printforms_box_detect.py ├── testforms_graph_pair.py ├── testforms_feat_pair.py ├── testforms_box.py └── graph_pair.py ├── pruneClusters.py ├── graph.py ├── train.py └── run.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loaders import * 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | saved/ 4 | data/ 5 | -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * 4 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = F401, F403 3 | max-line-length = 120 4 | exclude = 5 | .git, 6 | __pycache__, 7 | -------------------------------------------------------------------------------- /evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from evaluators.formsboxdetect_printer import FormsBoxDetect_printer 3 | from evaluators.formsgraphpair_printer import FormsGraphPair_printer 4 | from evaluators.formsfeaturepair_printer import FormsFeaturePair_printer 5 | 6 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.yolo_box_detector import YoloBoxDetector 2 | from model.pairing_graph import PairingGraph 3 | from model.binary_pair_net import 
BinaryPairNet 4 | from model.binary_pair_real import BinaryPairReal 5 | from model.simpleNN import SimpleNN 6 | #from .roi_align import ROIAlign 7 | #from .roi_align import roi_align 8 | -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class Logger: 5 | """ 6 | Training process logger 7 | 8 | Note: 9 | Used by BaseTrainer to save training history. 10 | """ 11 | def __init__(self): 12 | self.entries = {} 13 | 14 | def add_entry(self, entry): 15 | self.entries[len(self.entries) + 1] = entry 16 | 17 | def __str__(self): 18 | return json.dumps(self.entries, sort_keys=True, indent=4) 19 | -------------------------------------------------------------------------------- /model/csrc/cpu/vision.h: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | #pragma once 4 | #include 5 | 6 | 7 | at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, 8 | const at::Tensor& rois, 9 | const float spatial_scale, 10 | const int pooled_height, 11 | const int pooled_width, 12 | const int sampling_ratio); 13 | -------------------------------------------------------------------------------- /model/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | //#include "nms.h" 3 | #include "ROIAlign.h" 4 | //#include "ROIPool.h" 5 | 6 | 7 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 8 | //m.def("nms", &nms, "non-maximum suppression"); 9 | m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward"); 10 | m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward"); 11 | //m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward"); 12 | //m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward"); 13 | } 14 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | def __init__(self, config): 11 | super(BaseModel, self).__init__() 12 | self.config = config 13 | self.logger = logging.getLogger(self.__class__.__name__) 14 | 15 | def forward(self, *input): 16 | """ 17 | Forward pass logic 18 | 19 | :return: Model output 20 | """ 21 | raise NotImplementedError 22 | 23 | def summary(self): 24 | """ 25 | Model summary 26 | """ 27 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 28 | params = sum([np.prod(p.size()) for p in model_parameters]) 29 | self.logger.info('Trainable parameters: {}'.format(params)) 30 | self.logger.info(self) 31 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 
7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | from .trainer import * 17 | from .box_detect_trainer import BoxDetectTrainer 18 | from .graph_pair_trainer import GraphPairTrainer 19 | from .feature_pair_trainer import FeaturePairTrainer 20 | -------------------------------------------------------------------------------- /model/csrc/cuda/vision.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | #pragma once 3 | #include 4 | 5 | 6 | at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, 7 | const at::Tensor& rois, 8 | const float spatial_scale, 9 | const int pooled_height, 10 | const int pooled_width, 11 | const int sampling_ratio); 12 | 13 | at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, 14 | const at::Tensor& rois, 15 | const float spatial_scale, 16 | const int pooled_height, 17 | const int pooled_width, 18 | const int batch_size, 19 | const int channels, 20 | const int height, 21 | const int width, 22 | const int sampling_ratio); 23 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "AI2D_UNet_wn", 3 | "cuda": true, 4 | "gpu": 0, 5 | "override": true, 6 | "data_loader": { 7 | "data_set_name": "AI2D", 8 | "data_dir": "/home/ubuntu/brian/data/ai2d", 9 | "batch_size": 16, 10 | "shuffle": true, 11 | "num_workers": 2, 12 | "patch_size": 300, 13 | "center_jitter": 0.1, 14 | "size_jitter": 0.2 15 | }, 16 | "validation": { 17 | "validation_split": 0.01, 18 | "shuffle": true 19 | }, 20 | 21 | 22 | "lr_scheduler_type": "none", 23 | 24 | "optimizer_type": "Adam", 25 | "optimizer": { 26 | "lr": 0.001, 27 | "weight_decay": 0 28 | }, 29 | "loss": "sigmoid_BCE_loss", 30 | "metrics": ["meanIOU"], 31 | "trainer": { 32 | "iterations": 1000000, 33 | "save_dir": "saved/", 34 | "val_step": 5000, 35 | "save_step": 10000, 36 | "log_step": 500, 37 | "verbosity": 2, 38 | "monitor": "val_meanIOU", 39 | "monitor_mode": "none" 40 | }, 41 | "arch": "UNet", 42 | "model": { 43 | "skip_last_sigmoid": true, 44 | "n_channels": 4, 45 | "norm_type": "weightNorm" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /test_anchors.json: -------------------------------------------------------------------------------- 1 | [{"height": 8.341853141784668, "width": 67.14251708984375, "rot": -0.0011615825351327658}, {"height": 7.695361137390137, "width": 288.52301025390625, "rot": -0.0012055743718519807}, {"height": 7.149867057800293, "width": 44.82994079589844, "rot": -0.0010948360431939363}, {"height": 10.486498832702637, "width": 74.67229461669922, "rot": 1.6216387748718262}, {"height": 6.005581855773926, "width": 21.706567764282227, "rot": 0.00011035850184271112}, {"height": 32.79598617553711, "width": 6.153600215911865, "rot": -0.01817847602069378}, {"height": 5.257901191711426, "width": 15.17779541015625, "rot": 0.0006144480430521071}, {"height": 8.575309753417969, "width": 168.4725799560547, "rot": 0.00022563549282494932}, {"height": 
7.376450538635254, "width": 9.10978889465332, "rot": -0.0004971762537024915}, {"height": 3.924494504928589, "width": 28.789209365844727, "rot": -1.5683108568191528}, {"height": 4.687230587005615, "width": 5.672689437866211, "rot": 0.0020368173718452454}, {"height": 3.1410233974456787, "width": 3.3848800659179688, "rot": -2.101607242366299e-05}, {"height": 9.570945739746094, "width": 103.58953094482422, "rot": -0.0010029035620391369}, {"height": 3.543055295944214, "width": 10.173822402954102, "rot": 0.0011912824120372534}, {"height": 6.470541477203369, "width": 31.112993240356445, "rot": -0.00018544778868090361}, {"height": 134.04791259765625, "width": 65.5813980102539, "rot": 0.0379830040037632}] -------------------------------------------------------------------------------- /utils/add_size_to_annotations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import sys 17 | import json 18 | import os 19 | from skimage import io 20 | 21 | 22 | dirPath = sys.argv[1] 23 | 24 | with open(os.path.join(dirPath,'categories.json')) as f: 25 | imageToCategories = json.loads(f.read()) 26 | 27 | for imageName in imageToCategories: 28 | image = io.imread(os.path.join(dirPath,'images',imageName)) 29 | with open(os.path.join(dirPath,'annotations',imageName+'.json')) as f: 30 | annotations = json.loads(f.read()) 31 | annotations['imageConsts']['height']=image.shape[0] 32 | annotations['imageConsts']['width']=image.shape[1] 33 | with open(os.path.join(dirPath,'annotationsMod',imageName+'.json'),'w') as f: 34 | f.write(json.dumps(annotations, sort_keys=True, indent=4, separators=(',', ': '))) 35 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
15 | """ 16 | from base import BaseModel 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | 21 | class MnistModel(BaseModel): 22 | def __init__(self, config): 23 | super(MnistModel, self).__init__(config) 24 | self.config = config 25 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 26 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 27 | self.conv2_drop = nn.Dropout2d() 28 | self.fc1 = nn.Linear(320, 50) 29 | self.fc2 = nn.Linear(50, 10) 30 | 31 | def forward(self, x): 32 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 33 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 34 | x = x.view(-1, 320) 35 | x = F.relu(self.fc1(x)) 36 | x = F.dropout(x, training=self.training) 37 | x = self.fc2(x) 38 | return F.log_softmax(x, dim=1) 39 | -------------------------------------------------------------------------------- /model/binary_pair_net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import torch 17 | from torch import nn 18 | #from base import BaseModel 19 | import torch.nn.functional as F 20 | #from torch.nn.utils.weight_norm import weight_norm 21 | import math 22 | import json 23 | #from .net_builder import make_layers 24 | 25 | #This assumes the classification of edges was done by the pairing_graph modules featurizer 26 | 27 | class BinaryPairNet(nn.Module): 28 | def __init__(self, config): # predCount, base_0, base_1): 29 | super(BinaryPairNet, self).__init__() 30 | raise NotImplemented('Changes have broken this class, use BinaryPairReal') 31 | 32 | 33 | def forward(self, node_features, adjacencyMatrix, numBBs): 34 | #expects edge_features as batch currently 35 | 36 | #adj = torch.spmm(self.weight,edge_features) + self.bias 37 | 38 | 39 | 40 | #return 41 | return None,node_features 42 | 43 | 44 | -------------------------------------------------------------------------------- /cf_no_vis_pairing.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "no_vis_pairing", 3 | "cuda": false, 4 | "gpu": 0, 5 | "save_mode": "state_dict", 6 | "override": true, 7 | "super_computer":false, 8 | "data_loader": { 9 | "data_set_name": "FormsFeaturePair", 10 | "simple_dataset": true, 11 | "alternate_json_dir": "out/detection_data/", 12 | "data_dir": "../data/NAF_dataset", 13 | "batch_size": 512, 14 | "shuffle": true, 15 | "num_workers": 0, 16 | "no_blanks": true, 17 | "swap_circle":true, 18 | "no_graphics":true, 19 | "cache_resized_images": true, 20 | "rotation": false, 21 | "balance": true, 22 | "only_opposite_pairs": true, 23 | "corners":true 24 | 25 | 26 | }, 27 | "validation": { 28 | "shuffle": false, 29 | "balance": false 30 | }, 31 | 32 | 33 | "lr_scheduler_type": "none", 34 | 35 | "optimizer_type": "Adam", 36 | "optimizer": { 37 | "lr": 0.001, 38 | 
"weight_decay": 0 39 | }, 40 | "loss": "sigmoid_BCE_loss", 41 | "loss_params": 42 | { 43 | }, 44 | "metrics": [], 45 | "trainer": { 46 | "class": "FeaturePairTrainer", 47 | "iterations": 10000, 48 | "save_dir": "saved/", 49 | "val_step": 2000, 50 | "save_step": 2000, 51 | "save_step_minor": 250, 52 | "log_step": 50, 53 | "verbosity": 1, 54 | "monitor": "loss", 55 | "monitor_mode": "none" 56 | }, 57 | "arch": "SimpleNN", 58 | "model": { 59 | "feat_size":18, 60 | "num_layers": 2, 61 | "hidden_size": 256, 62 | "out_size": 1 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import torch.nn.functional as F 17 | import utils 18 | #import torch.nn as nn 19 | from model.alignment_loss import alignment_loss, box_alignment_loss, iou_alignment_loss 20 | from model.yolo_loss import YoloLoss, YoloDistLoss, LineLoss 21 | 22 | def my_loss(y_input, y_target): 23 | return F.nll_loss(y_input, y_target) 24 | 25 | def sigmoid_BCE_loss(y_input, y_target): 26 | return F.binary_cross_entropy_with_logits(y_input, y_target) 27 | def MSE(y_input, y_target): 28 | return F.mse_loss(y_input, y_target.float()) 29 | 30 | 31 | 32 | def detect_alignment_loss(predictions, target,label_sizes,alpha_alignment, alpha_backprop): 33 | return alignment_loss(predictions, target, label_sizes, alpha_alignment, alpha_backprop) 34 | def detect_alignment_loss_points(predictions, target,label_sizes,alpha_alignment, alpha_backprop): 35 | return alignment_loss(predictions, target, label_sizes, alpha_alignment, alpha_backprop,points=True) 36 | 37 | -------------------------------------------------------------------------------- /model/csrc/ROIAlign.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | #pragma once 3 | #include 4 | 5 | #include "cpu/vision.h" 6 | 7 | #ifdef WITH_CUDA 8 | #include "cuda/vision.h" 9 | #endif 10 | 11 | // Interface for Python 12 | at::Tensor ROIAlign_forward(const at::Tensor& input, 13 | const at::Tensor& rois, 14 | const float spatial_scale, 15 | const int pooled_height, 16 | const int pooled_width, 17 | const int sampling_ratio) { 18 | if (input.type().is_cuda()) { 19 | #ifdef WITH_CUDA 20 | return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); 21 | #else 22 | AT_ERROR("Not compiled with GPU support"); 23 | #endif 24 | } 25 | return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); 26 | } 27 | 28 | at::Tensor ROIAlign_backward(const at::Tensor& grad, 29 | const at::Tensor& rois, 30 | const float spatial_scale, 31 | const int pooled_height, 32 | const int pooled_width, 33 | const int batch_size, 34 | const int channels, 35 | const int height, 36 | const int width, 37 | const int sampling_ratio) { 38 | if (grad.type().is_cuda()) { 39 | #ifdef WITH_CUDA 40 | return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio); 41 | #else 42 | AT_ERROR("Not compiled with GPU support"); 43 | #endif 44 | } 45 | AT_ERROR("Not implemented on the CPU"); 46 | } 47 | 48 | -------------------------------------------------------------------------------- /anchors_noRot_new_25.json: -------------------------------------------------------------------------------- 1 | [{"height": 6.478144645690918, "width": 7.405168533325195, "rot": 0.0, "popularity": 3210}, {"height": 10.60506534576416, "width": 12.923971176147461, "rot": 0.0, "popularity": 3631}, {"height": 61.28946304321289, "width": 8.811919212341309, "rot": 0.0, "popularity": 1502}, {"height": 457.6505126953125, "width": 22.403642654418945, "rot": 0.0, "popularity": 72}, {"height": 10.292879104614258, "width": 19.579254150390625, "rot": 0.0, "popularity": 4647}, {"height": 12.632172584533691, "width": 40.45283126831055, "rot": 0.0, "popularity": 4418}, {"height": 19.433130264282227, "width": 223.44007873535156, "rot": 0.0, "popularity": 1895}, {"height": 17.903671264648438, "width": 407.82684326171875, "rot": 0.0, "popularity": 877}, {"height": 10.031716346740723, "width": 28.747276306152344, "rot": 0.0, "popularity": 4427}, {"height": 13.08789348602295, "width": 56.0013542175293, "rot": 0.0, "popularity": 4230}, {"height": 16.375200271606445, "width": 136.93319702148438, "rot": 0.0, "popularity": 2950}, {"height": 21.589508056640625, "width": 277.5206298828125, "rot": 0.0, "popularity": 1368}, {"height": 28.561887741088867, "width": 11.362449645996094, "rot": 0.0, "popularity": 852}, {"height": 13.71330738067627, "width": 76.47283935546875, "rot": 0.0, "popularity": 3949}, {"height": 20.79901695251465, "width": 175.9959259033203, "rot": 0.0, "popularity": 2048}, {"height": 18.283203125, "width": 509.3600158691406, "rot": 0.0, "popularity": 484}, {"height": 153.66111755371094, "width": 19.17852210998535, "rot": 0.0, "popularity": 213}, {"height": 92.11408996582031, "width": 111.45216369628906, "rot": 0.0, "popularity": 157}, {"height": 15.930002212524414, "width": 102.31188201904297, "rot": 0.0, "popularity": 3237}, {"height": 16.083959579467773, "width": 339.891845703125, "rot": 0.0, "popularity": 1167}, {"height": 18.060001373291016, "width": 594.9429321289062, "rot": 0.0, "popularity": 893}, {"height": 24.569995880126953, "width": 933.2150268554688, 
"rot": 0.0, "popularity": 95}, {"height": 19.704673767089844, "width": 743.9151000976562, "rot": 0.0, "popularity": 215}, {"height": 27.12500762939453, "width": 1344.597900390625, "rot": 0.0, "popularity": 18}, {"height": 27.45748519897461, "width": 1060.3948974609375, "rot": 0.0, "popularity": 5}] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | #!/usr/bin/env python 3 | 4 | import glob 5 | import os 6 | 7 | import torch 8 | from setuptools import find_packages 9 | from setuptools import setup 10 | from torch.utils.cpp_extension import CUDA_HOME 11 | from torch.utils.cpp_extension import CppExtension 12 | from torch.utils.cpp_extension import CUDAExtension 13 | 14 | requirements = ["torch", "torchvision"] 15 | 16 | 17 | def get_extensions(): 18 | this_dir = os.path.dirname(os.path.abspath(__file__)) 19 | extensions_dir = os.path.join(this_dir, "model", "csrc") 20 | 21 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 22 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 23 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 24 | 25 | sources = main_file + source_cpu 26 | extension = CppExtension 27 | 28 | extra_compile_args = {"cxx": []} 29 | define_macros = [] 30 | 31 | if torch.cuda.is_available() and CUDA_HOME is not None: 32 | extension = CUDAExtension 33 | sources += source_cuda 34 | define_macros += [("WITH_CUDA", None)] 35 | extra_compile_args["nvcc"] = [ 36 | "-DCUDA_HAS_FP16=1", 37 | "-D__CUDA_NO_HALF_OPERATORS__", 38 | "-D__CUDA_NO_HALF_CONVERSIONS__", 39 | "-D__CUDA_NO_HALF2_OPERATORS__", 40 | ] 41 | print('Using CUDA') 42 | 43 | sources = [os.path.join(extensions_dir, s) for s in sources] 44 | 45 | include_dirs = [extensions_dir] 46 | 47 | ext_modules = [ 48 | extension( 49 | "model._C", 50 | sources, 51 | include_dirs=include_dirs, 52 | define_macros=define_macros, 53 | extra_compile_args=extra_compile_args, 54 | ) 55 | ] 56 | 57 | return ext_modules 58 | 59 | 60 | setup( 61 | name="pairingNet", 62 | version="0.1", 63 | author="brian davis", 64 | url="https://github.com/herobd/pairing", 65 | description="detection and pairing of form-image elements", 66 | packages=find_packages(exclude=("saved", "out","log_slurm")), 67 | # install_requires=requirements, 68 | ext_modules=get_extensions(), 69 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 70 | ) 71 | -------------------------------------------------------------------------------- /utils/fast_inverse.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 3 | it under the terms of the GNU General Public License as published by 4 | the Free Software Foundation, either version 3 of the License, or 5 | (at your option) any later version. 6 | 7 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 8 | but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | GNU General Public License for more details. 11 | 12 | You should have received a copy of the GNU General Public License 13 | along with Visual-Template-free-Form-Parsting. If not, see . 
14 | """ 15 | import numpy as np 16 | import torch 17 | 18 | def adjoint(A): 19 | """compute inverse without division by det; ...xv3xc3 input, or array of matrices assumed""" 20 | AI = np.empty_like(A) 21 | for i in xrange(3): 22 | AI[...,i,:] = np.cross(A[...,i-2,:], A[...,i-1,:]) 23 | return AI 24 | 25 | def inverse_transpose(A): 26 | """ 27 | efficiently compute the inverse-transpose for stack of 3x3 matrices 28 | """ 29 | I = adjoint(A) 30 | det = dot(I, A).mean(axis=-1) 31 | return I / det[...,None,None] 32 | 33 | def inverse(A): 34 | """inverse of a stack of 3x3 matrices""" 35 | return np.swapaxes( inverse_transpose(A), -1,-2) 36 | def dot(A, B): 37 | """dot arrays of vecs; contract over last indices""" 38 | return np.einsum('...i,...i->...', A, B) 39 | 40 | def adjoint_torch(A): 41 | AI = A.clone() 42 | for i in xrange(3): 43 | AI[...,i,:] = torch.cross(A[...,i-2,:], A[...,i-1,:]) 44 | return AI 45 | 46 | def inverse_transpose_torch(A): 47 | I = adjoint_torch(A) 48 | det = dot_torch(I, A).mean(dim=-1) 49 | return I / det[:,None,None] 50 | 51 | def inverse_torch(A): 52 | return inverse_transpose_torch(A).transpose(1, 2) 53 | 54 | def dot_torch(A, B): 55 | A_view = A.view(-1,1,3) 56 | B_view = B.contiguous().view(-1,3,1) 57 | out = torch.bmm(A_view, B_view) 58 | out_view = out.view(A.size()[:-1]) 59 | return out_view 60 | 61 | 62 | if __name__ == "__main__": 63 | A = np.random.rand(2,3,3) 64 | I = inverse(A) 65 | 66 | A_torch = torch.from_numpy(A) 67 | 68 | I_torch = inverse_torch(A_torch) 69 | print(I) 70 | print(I_torch) 71 | -------------------------------------------------------------------------------- /model/roi_align.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | # https://github.com/facebookresearch/maskrcnn-benchmark 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.nn.modules.utils import _pair 8 | 9 | from model import _C 10 | 11 | 12 | class _ROIAlign(Function): 13 | @staticmethod 14 | def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): 15 | ctx.save_for_backward(roi) 16 | ctx.output_size = _pair(output_size) 17 | ctx.spatial_scale = spatial_scale 18 | ctx.sampling_ratio = sampling_ratio 19 | ctx.input_shape = input.size() 20 | output = _C.roi_align_forward( 21 | input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio 22 | ) 23 | return output 24 | 25 | @staticmethod 26 | @once_differentiable 27 | def backward(ctx, grad_output): 28 | rois, = ctx.saved_tensors 29 | output_size = ctx.output_size 30 | spatial_scale = ctx.spatial_scale 31 | sampling_ratio = ctx.sampling_ratio 32 | bs, ch, h, w = ctx.input_shape 33 | grad_input = _C.roi_align_backward( 34 | grad_output, 35 | rois, 36 | spatial_scale, 37 | output_size[0], 38 | output_size[1], 39 | bs, 40 | ch, 41 | h, 42 | w, 43 | sampling_ratio, 44 | ) 45 | return grad_input, None, None, None, None 46 | 47 | 48 | roi_align = _ROIAlign.apply 49 | 50 | 51 | class ROIAlign(nn.Module): 52 | def __init__(self, output_H, output_W, spatial_scale, sampling_ratio=-1): 53 | super(ROIAlign, self).__init__() 54 | self.output_size = (output_H,output_W) 55 | self.spatial_scale = spatial_scale 56 | self.sampling_ratio = sampling_ratio 57 | 58 | def forward(self, input, rois): 59 | return roi_align( 60 | input, rois, self.output_size, self.spatial_scale, self.sampling_ratio 61 | ) 62 | 63 | def __repr__(self): 64 | tmpstr = self.__class__.__name__ + "(" 65 | tmpstr += "output_size=" + str(self.output_size) 66 | tmpstr += ", spatial_scale=" + str(self.spatial_scale) 67 | tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) 68 | tmpstr += ")" 69 | return tmpstr 70 | -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 3 | it under the terms of the GNU General Public License as published by 4 | the Free Software Foundation, either version 3 of the License, or 5 | (at your option) any later version. 6 | 7 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 8 | but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | GNU General Public License for more details. 11 | 12 | You should have received a copy of the GNU General Public License 13 | along with Visual-Template-free-Form-Parsting. If not, see . 
14 | """ 15 | import cv2 16 | import numpy as np 17 | 18 | def tensmeyer_brightness(img, foreground=0, background=0): 19 | if img.shape[2]==3: 20 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 21 | else: 22 | gray = img 23 | ret,th = cv2.threshold(gray ,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) 24 | 25 | th = (th.astype(np.float32) / 255)[...,None] 26 | 27 | img = img.astype(np.float32) 28 | img = img + (1.0 - th) * foreground 29 | img = img + th * background 30 | 31 | img[img>255] = 255 32 | img[img<0] = 0 33 | 34 | return img.astype(np.uint8) 35 | 36 | def apply_tensmeyer_brightness(img, sigma=20, **kwargs): 37 | random_state = np.random.RandomState(kwargs.get("random_seed", None)) 38 | foreground = random_state.normal(0,sigma) 39 | background = random_state.normal(0,sigma) 40 | #print('fore {}, back {}'.format(foreground,background)) 41 | 42 | img = tensmeyer_brightness(img, foreground, background) 43 | 44 | return img 45 | 46 | 47 | def increase_brightness(img, brightness=0, contrast=1): 48 | img = img.astype(np.float32) 49 | img = img * contrast + brightness 50 | img[img>255] = 255 51 | img[img<0] = 0 52 | 53 | return img.astype(np.uint8) 54 | 55 | def apply_random_brightness(img, b_range=[-50,51], **kwargs): 56 | random_state = np.random.RandomState(kwargs.get("random_seed", None)) 57 | brightness = random_state.randint(b_range[0], b_range[1]) 58 | 59 | img = increase_brightness(img, brightness) 60 | 61 | return input_data 62 | 63 | def apply_random_color_rotation(img, **kwargs): 64 | random_state = np.random.RandomState(kwargs.get("random_seed", None)) 65 | shift = random_state.randint(0,255) 66 | 67 | hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 68 | hsv[...,0] = hsv[...,0] + shift 69 | img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 70 | 71 | return img 72 | -------------------------------------------------------------------------------- /utils/string_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 3 | it under the terms of the GNU General Public License as published by 4 | the Free Software Foundation, either version 3 of the License, or 5 | (at your option) any later version. 6 | 7 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 8 | but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | GNU General Public License for more details. 11 | 12 | You should have received a copy of the GNU General Public License 13 | along with Visual-Template-free-Form-Parsting. If not, see . 
14 | """ 15 | import numpy as np 16 | import sys 17 | def str2label_single(value, characterToIndex={}, unknown_index=None): 18 | if unknown_index is None: 19 | unknown_index = len(characterToIndex) 20 | 21 | label = [] 22 | for v in value: 23 | if v not in characterToIndex: 24 | continue 25 | # raise "Unknown Charactor to Label conversion" 26 | label.append(characterToIndex[v]) 27 | return np.array(label, np.uint32) 28 | 29 | def label2input_single(value, num_of_inputs, char_break_interval): 30 | idx1 = len(value) * (char_break_interval + 1) + char_break_interval 31 | idx2 = num_of_inputs + 1 32 | input_data = [[0 for i in range(idx2)] for j in range(idx1)] 33 | 34 | cnt = 0 35 | for i in range(char_break_interval): 36 | input_data[cnt][idx2-1] = 1 37 | cnt += 1 38 | 39 | for i in range(len(value)): 40 | if value[i] == 0: 41 | input_data[cnt][idx2-1] = 1 42 | else: 43 | input_data[cnt][value[i]-1] = 1 44 | cnt += 1 45 | 46 | for i in range(char_break_interval): 47 | input_data[cnt][idx2-1] = 1 48 | cnt += 1 49 | 50 | return np.array(input_data) 51 | 52 | def label2str_single(label, indexToCharacter, asRaw, spaceChar = "~"): 53 | string = u"" 54 | for i in range(len(label)): 55 | if label[i] == 0: 56 | if asRaw: 57 | string += spaceChar 58 | else: 59 | break 60 | else: 61 | val = label[i] 62 | string += indexToCharacter[val] 63 | return string 64 | 65 | def naive_decode(output): 66 | rawPredData = np.argmax(output, axis=1) 67 | predData = [] 68 | for i in range(len(output)): 69 | if rawPredData[i] != 0 and not ( i > 0 and rawPredData[i] == rawPredData[i-1] ): 70 | predData.append(rawPredData[i]) 71 | return predData, list(rawPredData) 72 | -------------------------------------------------------------------------------- /model/simpleNN.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
15 | """ 16 | from base import BaseModel 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import math 21 | from torch.nn.utils.weight_norm import weight_norm 22 | 23 | class SimpleNN(BaseModel): 24 | def __init__(self, config): 25 | super(SimpleNN, self).__init__(config) 26 | 27 | featSize = config['feat_size'] if 'feat_size' in config else 10 28 | numLayers = config['num_layers'] if 'num_layers' in config else 2 29 | hiddenSize = config['hidden_size'] if 'hidden_size' in config else 1024 30 | outSize = config['out_size'] if 'out_size' in config else 1 31 | 32 | reverse = config['reverse_activation'] if 'reverse_activation' in config else False #for resnet stuff 33 | norm = config['norm'] if 'norm' in config else 'batch_norm' 34 | dropout = float(config['dropout']) if 'dropout' in config else 0.4 35 | 36 | if numLayers==0: 37 | assert(featSize==hiddenSize) 38 | 39 | layers= [] 40 | for i in range(numLayers): 41 | if i==0: 42 | inSize=featSize 43 | else: 44 | inSize=hiddenSize 45 | 46 | if not reverse: 47 | layers.append(nn.Linear(inSize,hiddenSize)) 48 | if norm=='batch_norm': 49 | layers.append(nn.BatchNorm1d(hiddenSize)) 50 | elif norm=='group_norm': 51 | layers.append(nn.GroupNorm(getGroupSize(hiddenSize),hiddenSize)) 52 | if not reverse or i!=0: 53 | layers += [ 54 | nn.ReLU(inplace=True), 55 | nn.Dropout(dropout) 56 | ] 57 | if reverse: 58 | layers.append(nn.Linear(inSize,hiddenSize)) 59 | if outSize>0: 60 | layers.append(nn.Linear(hiddenSize,outSize)) 61 | self.layers=nn.Sequential(*layers) 62 | 63 | def forward(self,input): 64 | return self.layers(input) 65 | -------------------------------------------------------------------------------- /cf_baseline_detector.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "baseline_detector", 3 | "cuda": true, 4 | "gpu": 0, 5 | "super_computer":false, 6 | "save_mode": "state_dict", 7 | "override": true, 8 | "data_loader": { 9 | "data_set_name": "FormsBoxDetect", 10 | "data_dir": "../data/NAF_dataset", 11 | "batch_size": 5, 12 | "shuffle": true, 13 | "num_workers": 2, 14 | "crop_to_page":false, 15 | "color":false, 16 | "rescale_range": [0.4,0.65], 17 | "crop_params": { 18 | "crop_size":[652,1608], 19 | "pad":0, 20 | "flip_horz": true, 21 | "rot_degree_std_dev": 0.7 22 | }, 23 | "no_blanks": true, 24 | "swap_circle":true, 25 | "no_graphics":true, 26 | "cache_resized_images": true, 27 | "only_types": { 28 | "boxes":true 29 | }, 30 | "rotation": false 31 | 32 | 33 | }, 34 | "validation": { 35 | "shuffle": false, 36 | "crop_to_page":false, 37 | "color":false, 38 | "rescale_range": [0.52,0.52], 39 | "only_types": { 40 | "boxes":true 41 | }, 42 | "no_blanks": true, 43 | "swap_circle":true, 44 | "no_graphics":true, 45 | "batch_size": 1, 46 | "rotation": false 47 | }, 48 | 49 | 50 | "lr_scheduler_type": "none", 51 | 52 | "optimizer_type": "Adam", 53 | "optimizer": { 54 | "lr": 0.01, 55 | "weight_decay": 0 56 | }, 57 | "loss": { 58 | "box":"YoloLoss" 59 | }, 60 | "loss_params": { 61 | "box": { 62 | "ignore_thresh": 0.5, 63 | "bad_conf_weight": 20.0 64 | } 65 | }, 66 | "loss_weights":{"box":1.0}, 67 | "metrics": [], 68 | "trainer": { 69 | "class": "BoxDetectTrainer", 70 | "iterations": 150000, 71 | "save_dir": "saved/", 72 | "val_step": 10000, 73 | "save_step": 50000, 74 | "save_step_minor": 500, 75 | "log_step": 500, 76 | "verbosity": 1, 77 | "monitor": "loss", 78 | "monitor_mode": "none", 79 | "warmup_steps": 1000, 80 | "thresh_conf":0.88, 81 | 
"thresh_intersect":0.4 82 | }, 83 | "arch": "YoloBoxDetector", 84 | "model": { 85 | "color":false, 86 | "number_of_box_types": 2, 87 | "number_of_point_types": 0, 88 | "number_of_pixel_types": 0, 89 | "norm_type": "group_norm", 90 | "dropout": true, 91 | "down_layers_cfg": [1,"k5-32", "M", 92 | 64, 64, "M", 93 | 128, 128, 128, "M", 94 | 128, 128, 128, "M", 95 | 256, 256, 256, 256], 96 | "up_layers_cfg":[], 97 | "anchors_file": "anchors_noRot_new_25.json", 98 | "rotation":false 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /notes.txt: -------------------------------------------------------------------------------- 1 | -a validation=eval=True,data_loader=num_workers=0,validation=batch_size=1,data_loader=eval=True,data_loader=batch_size=1,optimize=True 2 | 3 | python eval.py -c saved/Mix18_staggerLight_NN/checkpoint-iteration200000.pth.tar -g 0 -n 10000 -a save_json=out_json/Mix18_staggerLight_NN,data_loader=batch_size=1,data_loader=num_workers=0,data_loader=rescale_range=0.52,data_loader=crop_params=,validation=rescale_range=0.52,validation=crop_params= 4 | 5 | 6 | 7 | Toy maxpairs 8 | 1,fe,rcr,hal 0.69 9 | step at 10 0.76 10 | step 15 0.85 11 | step 20 @40 0.86 12 | step 30 @40 0.88 13 | 14 | Step sch, #40,000 15 | 1,fe,rcr,learnedQ 0.94 16 | 1,fe,rcr,half 0.80 17 | 1,fe,rcr 0.76 18 | 1,fe,rcr,learnedQ,avgE 0.92, 0.89 19 | 1,fe,rcr,learnedQ,avgE,rr 0.88, 0.91 20 | 1,fe,rcr,learnedQ,rr 0.86, 0.88 21 | 1,fe,rcr,lrn1,avgE,rr 0.94, 0.90 * 22 | 1,fe,rcr,lrn1,avgE,rr,none0 0.47, 0.68 23 | 1,fe,lrn1,avgE,rr 0.83, 0.82 24 | 1,fe,rcr,lrn1,avgE,rr,relu 0.91, 0.89 25 | 1,fe,rcr,lrn1 0.89, 0.85, 0.87 26 | 1,fe,lrn1 0.69, 0.64 27 | 1,fe,rcr,lrn1,avgE,rr,none1 0.66 28 | 1,fe,gru,lrn1,avgE,rr 0.90,0.88 29 | 1,fe,rcr,lrn1,avgE,rr,prune 0.84,0.82,0.86 30 | 1,fe,rcr,lrn4,avgE,rr,prune 0.90,0.86 31 | 1,fe,rcr,lrn4,avgE,rr 0.91,0.90,0.92,0.84 32 | 1,fe,rcr,tree,avgE,rr 0.93,1.0,0.93,0.99,0.92 33 | 1,fe,rcr,tree,avgE 0.97,0.97,0.96,0.97 +.02,-.01,-.03,-.01 using 20 reps 34 | 35 | step only 30000 36 | 1,fe,gru,lrn1,avgE,rr 0.94,0.88 37 | 38 | 39 | After running reproducibility instructions: 40 | bb_ap overall mean: 0.4245414744870879, std 0.05992214677924842 41 | bb_recall overall mean: [0.9085108 0.75969355], std [0.07431647 0.19411495] 42 | bb_prec overall mean: [0.78065136 0.71984299], std [0.20507893 0.196836 ] 43 | bb_Fm overall mean: -1.0, std 0.0 44 | nn_loss overall mean: 0.03530250337759131, std 0.0331923394725507 45 | rel_recall overall mean: 0.6944726554909622, std 0.21906173981880636 46 | rel_precision overall mean: 0.6514794508010785, std 0.15718074652845865 47 | rel_Fm overall mean: 0.6430445318270176, std 0.14279309841442422 48 | relMissedByHeur overall mean: 1.368421052631579, std 2.5175285774950313 49 | relMissedByDetect overall mean: 5.315789473684211, std 7.356004576133685 50 | heurRecall overall mean: 0.9638793891082221, std 0.06479530822436647 51 | detectRecall overall mean: 0.7740061464088924, std 0.19850782946862844 52 | prec@0.5 overall mean: 0.6514794508010785, std 0.15718074652845865 53 | recall@0.5 overall mean: 0.6944726554909622, std 0.21906173981880636 54 | F-M@0.5 overall mean: 0.6430445318270176, std 0.14279309841442422 55 | rel_AP@0.5 overall mean: 0.616308119655676, std 0.21664192844128685 56 | rel_AP overall mean: 0.616308119655676, std 0.21664192844128685 57 | no_targs overall mean: 0.0, std 0.0 58 | nn_loss_final overall mean: 0.33462982134599434, std 0.3748013931824183 59 | nn_loss_diff overall mean: 
0.29932731796840306, std 0.3421111374927549 60 | nn_acc_final overall mean: 0.7141149257469313, std 0.16618617068601443 61 | nn_acc_detector overall mean: 0.7141038546485179, std 0.1379463483665521 62 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import numpy as np 17 | 18 | 19 | def my_metric(y_input, y_target): 20 | assert len(y_input) == len(y_target) 21 | correct = 0 22 | for y0, y1 in zip(y_input, y_target): 23 | if np.array_equal(y0, y1): 24 | correct += 1 25 | return correct / len(y_input) 26 | 27 | 28 | def meanIOU(y_output, y_target): 29 | assert len(y_output) == len(y_target) 30 | epsilon = 0.001 31 | iouSum = 0 32 | for out, targ in zip(y_output, y_target): 33 | binary = out>0 #torch.where(out>0,1,0) 34 | #binary = torch.round(y_input) #threshold at 0.5 35 | intersection = (binary * targ).sum() 36 | union = (binary + targ).sum() - intersection 37 | iouSum += (intersection+epsilon) / (union+epsilon) 38 | return iouSum / float(len(y_output)) 39 | 40 | def mean_xy(xyrs_output, xyrs_target): 41 | assert len(xyrs_output) == len(xyrs_target) 42 | dists=0 43 | for out, targ in zip(xyrs_output, xyrs_target): 44 | dists+=( (out[0:2]-targ[0:2]).linalg.norm() ) 45 | return dists/float(len(xyrs_output)) 46 | def std_xy(xyrs_output, xyrs_target): 47 | assert len(xyrs_output) == len(xyrs_target) 48 | dists=[] 49 | for out, targ in zip(xyrs_output, xyrs_target): 50 | dists.append( (out[0:2]-targ[0:2]).linalg.norm() ) 51 | return np.std(dists) 52 | def mean_rot(xyrs_output, xyrs_target): 53 | assert len(xyrs_output) == len(xyrs_target) 54 | rotDiffs=0 55 | for out, targ in zip(xyrs_output, xyrs_target): 56 | rotDiffs+=(targ[2]-out[2]) 57 | return rotDiffs/float(len(xyrs_output)) 58 | def std_rot(xyrs_output, xyrs_target): 59 | assert len(xyrs_output) == len(xyrs_target) 60 | rotDiffs=[] 61 | for out, targ in zip(xyrs_output, xyrs_target): 62 | rotDiffs.append(targ[2]-out[2]) 63 | return np.std(rotDiffs) 64 | def mean_scale(xyrs_output, xyrs_target): 65 | assert len(xyrs_output) == len(xyrs_target) 66 | scaleDiffs=0 67 | for out, targ in zip(xyrs_output, xyrs_target): 68 | scaleDiffs+=(targ[3]-out[3]) 69 | return scaleDiffs/float(len(xyrs_output)) 70 | def std_scale(xyrs_output, xyrs_target): 71 | assert len(xyrs_output) == len(xyrs_target) 72 | scaleDiffs=[] 73 | for out, targ in zip(xyrs_output, xyrs_target): 74 | scaleDiffs.append(targ[3]-out[3]) 75 | return np.std(scaleDiffs) 76 | -------------------------------------------------------------------------------- /cf_detector.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "detector", 3 | "cuda": true, 4 | "gpu": 0, 5 | 
"super_computer":false, 6 | "save_mode": "state_dict", 7 | "override": true, 8 | "data_loader": { 9 | "data_set_name": "FormsBoxDetect", 10 | "data_dir": "../data/NAF_dataset", 11 | "batch_size": 5, 12 | "shuffle": true, 13 | "num_workers": 2, 14 | "crop_to_page":false, 15 | "color":false, 16 | "rescale_range": [0.4,0.65], 17 | "crop_params": { 18 | "crop_size":[652,1608], 19 | "pad":0, 20 | "flip_horz": true, 21 | "rot_degree_std_dev": 0.7 22 | }, 23 | "no_blanks": true, 24 | "swap_circle":true, 25 | "no_graphics":true, 26 | "cache_resized_images": true, 27 | "only_types": { 28 | "boxes":true 29 | }, 30 | "rotation": false 31 | 32 | 33 | }, 34 | "validation": { 35 | "shuffle": false, 36 | "crop_to_page":false, 37 | "color":false, 38 | "rescale_range": [0.52,0.52], 39 | "crop_params": null, 40 | "only_types": { 41 | "boxes":true 42 | }, 43 | "no_blanks": true, 44 | "swap_circle":true, 45 | "no_graphics":true, 46 | "batch_size": 1, 47 | "rotation": false 48 | }, 49 | 50 | 51 | "lr_scheduler_type": "none", 52 | 53 | "optimizer_type": "Adam", 54 | "optimizer": { 55 | "lr": 0.02, 56 | "weight_decay": 0 57 | }, 58 | "loss": { 59 | "box":"YoloLoss" 60 | }, 61 | "loss_params": { 62 | "box": { 63 | "ignore_thresh": 0.5, 64 | "bad_conf_weight": 20.0 65 | } 66 | }, 67 | "loss_weights":{"box":1.0}, 68 | "metrics": [], 69 | "trainer": { 70 | "class": "BoxDetectTrainer", 71 | "iterations": 150000, 72 | "save_dir": "saved/", 73 | "val_step": 10000, 74 | "save_step": 50000, 75 | "save_step_minor": 500, 76 | "log_step": 500, 77 | "verbosity": 1, 78 | "monitor": "loss", 79 | "monitor_mode": "none", 80 | "warmup_steps": 1000, 81 | "thresh_conf":0.88, 82 | "thresh_intersect":0.4, 83 | 84 | "use_learning_schedule": "detector" 85 | }, 86 | "arch": "YoloBoxDetector", 87 | "model": { 88 | "color":false, 89 | "pred_num_neighbors": true, 90 | "number_of_box_types": 2, 91 | "number_of_point_types": 0, 92 | "number_of_pixel_types": 0, 93 | "norm_type": "group_norm", 94 | "dropout": true, 95 | "down_layers_cfg": [1,"k5-32", "M", 96 | 64, 64, "M", 97 | "hd2-128", "vd1-128", "hd4-128", 128, "M", 98 | "hd4-128", "vd1-128", "hd8-128", 128, "M", 99 | "hd8-256","vd1-256","hd16-256",256, 256], 100 | "up_layers_cfg":[], 101 | "anchors_file": "anchors_noRot_new_25.json", 102 | "rotation":false 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /cf_test_no_vis_pairing.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "test no_vis_pairing", 3 | "cuda": true, 4 | "gpu": 0, 5 | "save_mode": "state_dict", 6 | "override": true, 7 | "super_computer":false, 8 | "data_loader": { 9 | "data_set_name": "FormsGraphPair", 10 | "special_dataset": "simple", 11 | "data_dir": "../data/NAF_dataset", 12 | "batch_size": 1, 13 | "shuffle": true, 14 | "num_workers": 1, 15 | "crop_to_page":false, 16 | "color":false, 17 | "rescale_range": [0.4,0.65], 18 | "crop_params": { 19 | "crop_size":[652,1608], 20 | "pad":0 21 | }, 22 | "no_blanks": true, 23 | "swap_circle":true, 24 | "no_graphics":true, 25 | "cache_resized_images": true, 26 | "rotation": false, 27 | "only_opposite_pairs": true 28 | 29 | 30 | }, 31 | "validation": { 32 | "shuffle": false, 33 | "rescale_range": [0.52,0.52], 34 | "crop_params": null, 35 | "batch_size": 1 36 | }, 37 | 38 | 39 | "lr_scheduler_type": "none", 40 | 41 | "optimizer_type": "Adam", 42 | "optimizer": { 43 | "lr": 0.008, 44 | "weight_decay": 0 45 | }, 46 | "loss": { 47 | "box": "YoloLoss", 48 | "edge": 
"sigmoid_BCE_loss" 49 | }, 50 | "loss_weights": { 51 | "box": 1, 52 | "edge": 1 53 | }, 54 | "loss_params": 55 | { 56 | "box": {"ignore_thresh": 0.5, 57 | "bad_conf_weight": 20.0, 58 | "multiclass":true} 59 | }, 60 | "metrics": [], 61 | "trainer": { 62 | "class": "GraphPairTrainer", 63 | "iterations": 200000, 64 | "save_dir": "saved/", 65 | "val_step": 10000, 66 | "save_step": 50000, 67 | "save_step_minor": 250, 68 | "log_step": 250, 69 | "verbosity": 1, 70 | "monitor": "loss", 71 | "monitor_mode": "none", 72 | "warmup_steps": 1000, 73 | "conf_thresh_init": 0.9, 74 | "conf_thresh_change_iters": 5000, 75 | "retry_count":3, 76 | 77 | "unfreeze_detector": 99999000, 78 | "partial_from_gt": 8000, 79 | "stop_from_gt": 200000, 80 | 81 | "use_bad_bb_pred_for_rel_loss": true 82 | }, 83 | "arch": "PairingGraph", 84 | "model": { 85 | "detector_checkpoint": "saved/detector/checkpoint-iteration150000.pth", 86 | "conf_thresh": 0.92, 87 | "start_frozen": true, 88 | "use_rel_shape_feats": "corner", 89 | "expand_rel_context": 100, 90 | "use_detect_layer_feats": 16, 91 | "graph_config": { 92 | "arch": "BinaryPairReal", 93 | "in_channels": 256, 94 | "node_channels": 0, 95 | "edge_channels": 1, 96 | "layers": ["FC256","FC256"], 97 | "shape_layers": "saved/no_vis_pairing/checkpoint-iteration6000.pth", 98 | "weight_split": 0.0 99 | }, 100 | "featurizer_start_h": 10, 101 | "featurizer_start_w": 10, 102 | "featurizer_conv": [128,128,"M",256,240], 103 | "featurizer_fc": null 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | import numpy as np 3 | 4 | 5 | class BaseDataLoader: 6 | """ 7 | Base class for all data loaders 8 | """ 9 | def __init__(self, config): 10 | self.config = config 11 | self.batch_size = config['data_loader']['batch_size'] 12 | self.shuffle = config['data_loader']['shuffle'] 13 | self.batch_idx = 0 14 | 15 | def __iter__(self): 16 | """ 17 | :return: Iterator 18 | """ 19 | assert self.__len__() > 0 20 | self.batch_idx = 0 21 | if self.shuffle: 22 | self._shuffle_data() 23 | return self 24 | 25 | def __next__(self): 26 | """ 27 | :return: Next batch 28 | """ 29 | packed = self._pack_data() 30 | if self.batch_idx < self.__len__(): 31 | batch = packed[self.batch_idx * self.batch_size:(self.batch_idx + 1) * self.batch_size] 32 | self.batch_idx = self.batch_idx + 1 33 | return self._unpack_data(batch) 34 | else: 35 | raise StopIteration 36 | 37 | def __len__(self): 38 | """ 39 | :return: Total number of batches 40 | """ 41 | return self._n_samples() // self.batch_size 42 | 43 | def _n_samples(self): 44 | """ 45 | :return: Total number of samples 46 | """ 47 | return NotImplementedError 48 | 49 | def _pack_data(self): 50 | """ 51 | Pack all data into a list/tuple/ndarray/... 
52 | 53 | :return: Packed data in the data loader 54 | """ 55 | return NotImplementedError 56 | 57 | def _unpack_data(self, packed): 58 | """ 59 | Unpack packed data (from _pack_data()) 60 | 61 | :param packed: Packed data 62 | :return: Unpacked data 63 | """ 64 | return NotImplementedError 65 | 66 | def _update_data(self, unpacked): 67 | """ 68 | Update data member in the data loader 69 | 70 | :param unpacked: Unpacked data (from _update_data()) 71 | """ 72 | return NotImplementedError 73 | 74 | def _shuffle_data(self): 75 | """ 76 | Shuffle data members in the data loader 77 | """ 78 | packed = self._pack_data() 79 | rand_idx = np.random.permutation(len(packed)) 80 | packed = [packed[i] for i in rand_idx] 81 | self._update_data(self._unpack_data(packed)) 82 | 83 | def split_validation(self): 84 | """ 85 | Split validation data from data loader based on self.config['validation'] 86 | """ 87 | validation_split = self.config['validation']['validation_split'] 88 | shuffle = self.config['validation']['shuffle'] 89 | if validation_split == 0.0: 90 | return None 91 | if shuffle: 92 | self._shuffle_data() 93 | valid_data_loader = copy(self) 94 | split = int(self._n_samples() * validation_split) 95 | packed = self._pack_data() 96 | train_data = self._unpack_data(packed[split:]) 97 | val_data = self._unpack_data(packed[:split]) 98 | valid_data_loader._update_data(val_data) 99 | self._update_data(train_data) 100 | return valid_data_loader 101 | -------------------------------------------------------------------------------- /model/coordconv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
15 | """ 16 | import torch 17 | from torch import nn 18 | import numpy as np 19 | 20 | class CoordConv(nn.Module): 21 | def __init__(self,in_ch,out_ch,kernel_size=3,padding=1,dilation=1,groups=1,features='wave'): 22 | super(CoordConv, self).__init__() 23 | self.features=features 24 | if 'wave' in features: 25 | if 'Big' in features: 26 | self.numChX=10 27 | self.numChY=7 28 | self.minCycle=8 29 | self.maxCycleX=2000 30 | self.maxCycleY=1400 31 | elif 'Med' in features: 32 | self.numChX=10 33 | self.numChY=7 34 | self.minCycle=4 35 | self.maxCycleX=1000 36 | self.maxCycleY=700 37 | elif 'Small' in features: 38 | self.numChX=10 39 | self.numChY=7 40 | self.minCycle=2 41 | self.maxCycleX=500 42 | self.maxCycleY=350 43 | else: 44 | self.numChX=5 45 | self.numChY=4 46 | self.minCycle=16 47 | self.maxCycleX=1000 48 | self.maxCycleY=700 49 | 50 | self.cycleStepX = (self.maxCycleX-self.minCycle)/((self.numChX-1)**2) 51 | self.cycleStepY = (self.maxCycleY-self.minCycle)/((self.numChY-1)**2) 52 | self.numExtra=self.numChX+self.numChY 53 | 54 | self.conv = nn.Conv2d(in_ch+self.numExtra,out_ch, kernel_size=kernel_size, padding=padding,dilation=dilation,groups=groups) 55 | 56 | def forward(self,input): 57 | batch_size = input.size(0) 58 | dimY=input.size(2) 59 | dimX=input.size(3) 60 | if 'wave' in self.features: 61 | if self.training: 62 | xOffset = np.random.randint(0,self.maxCycleX) 63 | yOffset = np.random.randint(0,self.maxCycleY) 64 | else: 65 | xOffset=0 66 | yOffset=0 67 | 68 | extraX = torch.FloatTensor(self.numChX,dimX) 69 | x_range = torch.arange(dimX, dtype=torch.float64) + xOffset 70 | for i in range(self.numChX): 71 | cycle = self.minCycle + self.cycleStepX*(i**2) 72 | extraX[i] = torch.sin(x_range*np.pi*2/cycle) 73 | extraX = extraX[:,None,:].expand(self.numChX,dimY,dimX) 74 | 75 | extraY = torch.FloatTensor(self.numChY,dimY) 76 | y_range = torch.arange(dimY, dtype=torch.float64) + yOffset 77 | for i in range(self.numChY): 78 | cycle = self.minCycle + self.cycleStepY*(i**2) 79 | extraY[i] = torch.sin(y_range*np.pi*2/cycle) 80 | extraY = extraY[:,:,None].expand(self.numChY,dimY,dimX) 81 | extra = torch.cat((extraY,extraX),dim=0) 82 | 83 | 84 | extra = extra[None,...].repeat(batch_size,1,1,1).to(input.device) 85 | data = torch.cat((input,extra),dim=1) 86 | 87 | return self.conv(data) 88 | -------------------------------------------------------------------------------- /cf_pairing.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "pairing", 3 | "cuda": true, 4 | "gpu": 0, 5 | "save_mode": "state_dict", 6 | "override": true, 7 | "super_computer":false, 8 | "data_loader": { 9 | "data_set_name": "FormsGraphPair", 10 | "special_dataset": "simple", 11 | "data_dir": "../data/NAF_dataset", 12 | "batch_size": 1, 13 | "shuffle": true, 14 | "num_workers": 1, 15 | "crop_to_page":false, 16 | "color":false, 17 | "rescale_range": [0.4,0.65], 18 | "crop_params": { 19 | "crop_size":[652,1608], 20 | "pad":0 21 | }, 22 | "no_blanks": true, 23 | "swap_circle":true, 24 | "no_graphics":true, 25 | "cache_resized_images": true, 26 | "rotation": false, 27 | "only_opposite_pairs": true 28 | 29 | 30 | }, 31 | "validation": { 32 | "shuffle": false, 33 | "rescale_range": [0.52,0.52], 34 | "crop_params": null, 35 | "batch_size": 1 36 | }, 37 | 38 | 39 | "lr_scheduler_type": "none", 40 | 41 | "optimizer_type": "Adam", 42 | "optimizer": { 43 | "lr": 0.001, 44 | "weight_decay": 0 45 | }, 46 | "loss": { 47 | "box": "YoloLoss", 48 | "edge": "sigmoid_BCE_loss", 49 | 
"nn": "MSE", 50 | "class": "sigmoid_BCE_loss" 51 | }, 52 | "loss_weights": { 53 | "box": 1.0, 54 | "edge": 0.5, 55 | "nn": 0.25, 56 | "class": 0.25 57 | }, 58 | "loss_params": 59 | { 60 | "box": {"ignore_thresh": 0.5, 61 | "bad_conf_weight": 20.0, 62 | "multiclass":true} 63 | }, 64 | "metrics": [], 65 | "trainer": { 66 | "class": "GraphPairTrainer", 67 | "iterations": 125000, 68 | "save_dir": "saved/", 69 | "val_step": 5000, 70 | "save_step": 25000, 71 | "save_step_minor": 250, 72 | "log_step": 250, 73 | "verbosity": 1, 74 | "monitor": "loss", 75 | "monitor_mode": "none", 76 | "warmup_steps": 1000, 77 | "conf_thresh_init": 0.5, 78 | "conf_thresh_change_iters": 0, 79 | "retry_count":1, 80 | 81 | "unfreeze_detector": 2000, 82 | "partial_from_gt": 0, 83 | "stop_from_gt": 20000, 84 | "max_use_pred": 0.5, 85 | "use_all_bb_pred_for_rel_loss": true, 86 | 87 | "use_learning_schedule": true, 88 | "adapt_lr": false 89 | }, 90 | "arch": "PairingGraph", 91 | "model": { 92 | "detector_checkpoint": "saved/detector/checkpoint-iteration150000.pth", 93 | "conf_thresh": 0.5, 94 | "start_frozen": true, 95 | "use_rel_shape_feats": "corner", 96 | "use_detect_layer_feats": 16, 97 | "use_2nd_detect_layer_feats": 0, 98 | "use_2nd_detect_scale_feats": 2, 99 | "use_2nd_detect_feats_size": 64, 100 | "use_fixed_masks": true, 101 | "no_grad_feats": true, 102 | 103 | "expand_rel_context": 150, 104 | "featurizer_start_h": 32, 105 | "featurizer_start_w": 32, 106 | "featurizer_conv": ["sep128","M","sep128","sep128","M","sep256","sep256","M",238], 107 | "featurizer_fc": null, 108 | 109 | "pred_nn": true, 110 | "pred_class": false, 111 | "expand_bb_context": 150, 112 | "featurizer_bb_start_h": 32, 113 | "featurizer_bb_start_w": 32, 114 | "bb_featurizer_conv": ["sep64","M","sep64","sep64","M","sep128","sep128","M",250], 115 | 116 | "graph_config": { 117 | "arch": "BinaryPairReal", 118 | "in_channels": 256, 119 | "bb_out": 1, 120 | "rel_out": 1, 121 | "layers": ["FC256","FC256"], 122 | "layers_bb": ["FC256"] 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /datasets/printforms_box_detect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
15 | """ 16 | from datasets.forms_box_detect import FormsBoxDetect 17 | from datasets import forms_box_detect 18 | import math 19 | import sys 20 | import os, cv2 21 | import numpy as np 22 | import torch 23 | 24 | def saveBoxes(data,dest): 25 | batchSize = data['img'].size(0) 26 | for b in range(batchSize): 27 | #print (data['img'].size()) 28 | #img = (data['img'][0].permute(1,2,0)+1)/2.0 29 | img = 255*(1-data['img'][b].permute(1,2,0))/2.0 30 | #print(img.shape) 31 | #print(data['pixel_gt']['table_pixels'].shape) 32 | if 'pixel_gt' in data and data['pixel_gt'] is not None: 33 | img[:,:,1] = data['pixel_gt'][b,0,:,:] 34 | imgName=(data['imgName'][b]) 35 | 36 | 37 | img=img.numpy().astype(np.uint8) 38 | 39 | for i in range(data['bb_sizes'][b]): 40 | xc=data['bb_gt'][b,i,0] 41 | yc=data['bb_gt'][b,i,1] 42 | rot=data['bb_gt'][b,i,2] 43 | h=data['bb_gt'][b,i,3] 44 | w=data['bb_gt'][b,i,4] 45 | text=data['bb_gt'][b,i,13] 46 | field=data['bb_gt'][b,i,14] 47 | if text>0: 48 | sub = 'text' 49 | else: 50 | sub = 'field' 51 | tr = (math.cos(rot)*w-math.sin(rot)*h +xc, -math.sin(rot)*w-math.cos(rot)*h +yc) 52 | tl = (-math.cos(rot)*w-math.sin(rot)*h +xc, math.sin(rot)*w-math.cos(rot)*h +yc) 53 | br = (math.cos(rot)*w+math.sin(rot)*h +xc, -math.sin(rot)*w+math.cos(rot)*h +yc) 54 | bl = (-math.cos(rot)*w+math.sin(rot)*h +xc, math.sin(rot)*w+math.cos(rot)*h +yc) 55 | #print([tr,tl,br,bl]) 56 | assert(rot==0) 57 | crop = img[int(tl[1]):int(br[1])+1,int(tl[0]):int(br[0])+1] 58 | path = os.path.join(dest,sub,'{}_b{}.png'.format(imgName,i)) 59 | cv2.imwrite(path,crop) 60 | 61 | 62 | 63 | 64 | 65 | if __name__ == "__main__": 66 | dirPath = sys.argv[1] 67 | if len(sys.argv)>2: 68 | directory = sys.argv[2] 69 | else: 70 | print('need dest dir') 71 | exit() 72 | dirText = os.path.join(directory,'text') 73 | dirField = os.path.join(directory,'field') 74 | if not os.path.exists(dirText): 75 | os.makedirs(dirText) 76 | if not os.path.exists(dirField): 77 | os.makedirs(dirField) 78 | data=FormsBoxDetect(dirPath=dirPath,split='valid',config={'crop_to_page':False, 79 | #'rescale_range':[0.45,0.6], 80 | #'rescale_range':[0.52,0.52], 81 | 'rescale_range':[1,1], 82 | #'crop_params':{ "crop_size":[652,1608], 83 | # "pad":0, 84 | # "rot_degree_std_dev":1.5, 85 | # "flip_horz": True, 86 | # "flip_vert": True}, 87 | 'no_blanks':True, 88 | 'use_paired_class':True, 89 | "swap_circle":True, 90 | 'no_graphics':True, 91 | 'rotation':False, 92 | "only_types": {"boxes":True} 93 | }) 94 | #data.cluster(start,repeat,'anchors_rot_{}.json') 95 | 96 | dataLoader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False, num_workers=0, collate_fn=forms_box_detect.collate) 97 | dataLoaderIter = iter(dataLoader) 98 | 99 | try: 100 | while True: 101 | #print('?') 102 | saveBoxes(dataLoaderIter.next(),directory) 103 | except StopIteration: 104 | print('done') 105 | -------------------------------------------------------------------------------- /datasets/testforms_graph_pair.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 
7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | from datasets.forms_graph_pair import FormsGraphPair 17 | from datasets import forms_graph_pair 18 | import math 19 | import sys 20 | from matplotlib import pyplot as plt 21 | from matplotlib import gridspec 22 | from matplotlib.patches import Polygon 23 | import numpy as np 24 | import torch 25 | 26 | def display(data): 27 | b=0 28 | 29 | #print (data['img'].size()) 30 | #img = (data['img'][0].permute(1,2,0)+1)/2.0 31 | img = (data['img'][b].permute(1,2,0)+1)/2.0 32 | #print(img.shape) 33 | #print(data['pixel_gt']['table_pixels'].shape) 34 | print(data['imgName']) 35 | 36 | 37 | 38 | fig = plt.figure() 39 | #gs = gridspec.GridSpec(1, 3) 40 | 41 | ax_im = plt.subplot() 42 | ax_im.set_axis_off() 43 | ax_im.imshow(img[:,:,0]) 44 | 45 | colors = { 'text_start_gt':'g-', 46 | 'text_end_gt':'b-', 47 | 'field_start_gt':'r-', 48 | 'field_end_gt':'y-', 49 | 'table_points':'co' 50 | } 51 | #print('num bb:{}'.format(data['bb_sizes'][b])) 52 | for i in range(data['bb_gt'].size(1)): 53 | xc=data['bb_gt'][b,i,0] 54 | yc=data['bb_gt'][b,i,1] 55 | rot=data['bb_gt'][b,i,2] 56 | h=data['bb_gt'][b,i,3] 57 | w=data['bb_gt'][b,i,4] 58 | text=data['bb_gt'][b,i,13] 59 | field=data['bb_gt'][b,i,14] 60 | if text>0: 61 | color = 'b-' 62 | else: 63 | color = 'r-' 64 | tr = (math.cos(rot)*w-math.sin(rot)*h +xc, math.sin(rot)*w+math.cos(rot)*h +yc) 65 | tl = (math.cos(rot)*-w-math.sin(rot)*h +xc, math.sin(rot)*-w+math.cos(rot)*h +yc) 66 | br = (math.cos(rot)*w-math.sin(rot)*-h +xc, math.sin(rot)*w+math.cos(rot)*-h +yc) 67 | bl = (math.cos(rot)*-w-math.sin(rot)*-h +xc, math.sin(rot)*-w+math.cos(rot)*-h +yc) 68 | #print([tr,tl,br,bl]) 69 | 70 | ax_im.plot([tr[0],tl[0],bl[0],br[0],tr[0]],[tr[1],tl[1],bl[1],br[1],tr[1]],color) 71 | for ind1,ind2 in data['adj']: 72 | x1=data['bb_gt'][b,ind1,0] 73 | y1=data['bb_gt'][b,ind1,1] 74 | x2=data['bb_gt'][b,ind2,0] 75 | y2=data['bb_gt'][b,ind2,1] 76 | 77 | ax_im.plot([x1,x2],[y1,y2],'g-') 78 | #print('{} to {}, {} - {}'.format(ind1,ind2,(x1,y1),(x2,y2))) 79 | plt.show() 80 | 81 | 82 | if __name__ == "__main__": 83 | dirPath = sys.argv[1] 84 | if len(sys.argv)>2: 85 | start = int(sys.argv[2]) 86 | else: 87 | start=0 88 | if len(sys.argv)>3: 89 | repeat = int(sys.argv[3]) 90 | else: 91 | repeat=1 92 | data=FormsGraphPair(dirPath=dirPath,split='train',config={ 93 | 'color':False, 94 | 'crop_to_page':False, 95 | 'rescale_range':[1,1], 96 | 'Xrescale_range':[0.4,0.65], 97 | 'Xcrop_params':{"crop_size":[652,1608],"pad":0}, 98 | 'no_blanks':True, 99 | "swap_circle":True, 100 | 'no_graphics':True, 101 | 'rotation':False, 102 | 'only_opposite_pairs':True, 103 | #"only_types": ["text_start_gt"] 104 | }) 105 | 106 | dataLoader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False, num_workers=0, collate_fn=forms_graph_pair.collate) 107 | dataLoaderIter = iter(dataLoader) 108 | 109 | #if start==0: 110 | #display(data[0]) 111 | for i in range(0,start): 112 | print(i) 113 | dataLoaderIter.next() 114 | #display(data[i]) 115 | try: 116 | while True: 117 | #print('?') 118 | display(dataLoaderIter.next()) 119 | except StopIteration: 120 | print('done') 121 | 
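A minimal sketch of consuming the same FormsGraphPair loader with a plain for-loop; newer PyTorch releases drop the iterator's .next() alias, so this form is the portable equivalent of the manual loop above (it assumes `data` and `display()` are defined exactly as in the script):

import torch
from datasets import forms_graph_pair

# `data` is the FormsGraphPair instance built in the __main__ block above
dataLoader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False,
                                         num_workers=0, collate_fn=forms_graph_pair.collate)
for instance in dataLoader:   # iteration ends on its own; no StopIteration handling needed
    display(instance)         # draws the boxes and ground-truth pairings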
-------------------------------------------------------------------------------- /pruneClusters.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import math 4 | import sys 5 | import cv2 6 | 7 | if len(sys.argv)<2: 8 | print('usage: '+sys.argv[0]+' in.json k out.json') 9 | exit() 10 | 11 | def makePointsAndRects(h,w): 12 | return np.array([-w/2.0,0,w/2.0,0,0,-h/2.0,0,h/2.0, 0,0, 0, h,w]) 13 | 14 | with open(sys.argv[1]) as file: 15 | anchors = json.loads(file.read()) 16 | goalK = int(sys.argv[2]) 17 | outPath = sys.argv[3] 18 | 19 | #remove very unpopular anchors 20 | toRemove=[] 21 | for i in range(len(anchors)): 22 | if anchors[i]['popularity']<5: 23 | toRemove.append(i) 24 | toRemove.sort(reverse=True) 25 | #print(toRemove) 26 | for idx in toRemove: 27 | del anchors[idx] 28 | 29 | points = np.zeros([len(anchors),13]) 30 | for i in range(len(anchors)): 31 | points[i]=makePointsAndRects(anchors[i]['height'],anchors[i]['width']) 32 | expanded_points1_points = points[:,None,0:8] 33 | expanded_points1_heights = points[:,None,11] 34 | expanded_points1_widths = points[:,None,12] 35 | 36 | expanded_points2_points = points[None,:,0:8] 37 | expanded_points2_heights = points[None,:,11] 38 | expanded_points2_widths = points[None,:,12] 39 | 40 | #expanded_all_points = expanded_all_points.expand(all_points.shape[0], all_points.shape[1], means_points.shape[1], all_points.shape[2]) 41 | expanded_points1_points = np.tile(expanded_points1_points,(1,points.shape[0],1)) 42 | expanded_points1_heights = np.tile(expanded_points1_heights,(1,points.shape[0])) 43 | expanded_points1_widths = np.tile(expanded_points1_widths,(1,points.shape[0])) 44 | #expanded_means_points = expanded_means_points.expand(means_points.shape[0], all_points.shape[0], means_points.shape[0], means_points.shape[2]) 45 | expanded_points2_points = np.tile(expanded_points2_points,(points.shape[0],1,1)) 46 | expanded_points2_heights = np.tile(expanded_points2_heights,(points.shape[0],1)) 47 | expanded_points2_widths = np.tile(expanded_points2_widths,(points.shape[0],1)) 48 | 49 | point_deltas = (expanded_points1_points - expanded_points2_points) 50 | #avg_heights = ((expanded_means_heights+expanded_all_heights)/2) 51 | #avg_widths = ((expanded_means_widths+expanded_all_widths)/2) 52 | avg_heights=avg_widths = (expanded_points1_heights+expanded_points1_widths)/2 53 | #print point_deltas 54 | 55 | normed_difference = ( 56 | np.linalg.norm(point_deltas[:,:,0:2],2,2)/avg_widths + 57 | np.linalg.norm(point_deltas[:,:,2:4],2,2)/avg_widths + 58 | np.linalg.norm(point_deltas[:,:,4:6],2,2)/avg_heights + 59 | np.linalg.norm(point_deltas[:,:,6:8],2,2)/avg_heights 60 | )**2 61 | np.fill_diagonal(normed_difference,float('inf')) 62 | toRemove=[] 63 | for i in range(len(anchors)-goalK): 64 | cord = np.argmin(normed_difference) 65 | a=cord//len(anchors) 66 | b=cord%len(anchors) 67 | #print('{} {} {}'.format(cord,a,b)) 68 | 69 | normed_difference[a,b] = float('inf') 70 | normed_difference[b,a] = float('inf') 71 | #toRemove.append(a) 72 | #normed_difference[a,:] = float('inf') 73 | #normed_difference[:,a] = float('inf') 74 | 75 | if anchors[a]['popularity'] > anchors[b]['popularity']: 76 | toRemove.append(b) 77 | normed_difference[b,:] = float('inf') 78 | normed_difference[:,b] = float('inf') 79 | else: 80 | toRemove.append(a) 81 | normed_difference[a,:] = float('inf') 82 | normed_difference[:,a] = float('inf') 83 | 84 | toRemove.sort(reverse=True) 85 | #print(toRemove) 86 | for idx 
in toRemove: 87 | del anchors[idx] 88 | 89 | with open(outPath,'w') as out: 90 | out.write(json.dumps(anchors)) 91 | 92 | drawH=1000 93 | drawW=4000 94 | draw = np.zeros([drawH,drawW,3],dtype=np.float) 95 | for anchor in anchors: 96 | color = np.random.uniform(0.2,1,3).tolist() 97 | h=anchor['height'] 98 | w=anchor['width'] 99 | rot=anchor['rot'] 100 | tr = ( int(math.cos(rot)*w-math.sin(rot)*h)+(drawW//2), int(math.sin(rot)*w+math.cos(rot)*h)+(drawH//2) ) 101 | tl = ( int(math.cos(rot)*-w-math.sin(rot)*h)+(drawW//2), int(math.sin(rot)*-w+math.cos(rot)*h)+(drawH//2) ) 102 | br = ( int(math.cos(rot)*w-math.sin(rot)*-h)+(drawW//2), int(math.sin(rot)*w+math.cos(rot)*-h)+(drawH//2) ) 103 | bl = ( int(math.cos(rot)*-w-math.sin(rot)*-h)+(drawW//2), int(math.sin(rot)*-w+math.cos(rot)*-h)+(drawH//2) ) 104 | 105 | cv2.line(draw,tl,tr,color) 106 | cv2.line(draw,tr,br,color) 107 | cv2.line(draw,br,bl,color) 108 | cv2.line(draw,bl,tl,color,2) 109 | cv2.imshow('pruned',draw) 110 | cv2.waitKey() 111 | cv2.waitKey() 112 | -------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import os 17 | import sys 18 | import signal 19 | import json 20 | import logging 21 | import argparse 22 | import torch 23 | from collections import defaultdict 24 | import numpy as np 25 | 26 | logging.basicConfig(level=logging.INFO, format='') 27 | 28 | 29 | def graph(log,plot=True,prefix=None): 30 | graphs=defaultdict(lambda:{'iters':[], 'values':[]}) 31 | for index, entry in log.entries.items(): 32 | iteration = entry['iteration'] 33 | for metric, value in entry.items(): 34 | if metric!='iteration': 35 | graphs[metric]['iters'].append(iteration) 36 | graphs[metric]['values'].append(value) 37 | 38 | print('summed') 39 | skip=[] 40 | for metric, data in graphs.items(): 41 | #print('{} max: {}, min {}'.format(metric,max(data['values']),min(data['values']))) 42 | ndata = np.array(data['values']) 43 | if ndata.dtype is not np.dtype(object): 44 | maxV = ndata.max(axis=0) 45 | minV = ndata.min(axis=0) 46 | print('{} max: {}, min {}'.format(metric,maxV,minV)) 47 | else: 48 | skip.append(metric) 49 | 50 | if plot: 51 | import matplotlib.pyplot as plt 52 | i=1 53 | for metric, data in graphs.items(): 54 | if metric in skip: 55 | continue 56 | if (prefix is None and (metric[:3]=='avg' or metric[:3]=='val')) or (prefix is not None and metric[:len(prefix)]==prefix): 57 | #print('{} == {}? 
{}'.format(metric[:len(prefix)],prefix,metric[:len(prefix)]==prefix)) 58 | plt.figure(i) 59 | i+=1 60 | plt.plot(data['iters'], data['values'], '.-') 61 | plt.xlabel('iterations') 62 | plt.ylabel(metric) 63 | plt.title(metric) 64 | plt.show() 65 | else: 66 | i=1 67 | for metric, data in graphs.items(): 68 | if metric[:3]=='avg' or metric[:3]=='val': 69 | print(metric) 70 | print(data['values']) 71 | 72 | 73 | 74 | 75 | 76 | if __name__ == '__main__': 77 | logger = logging.getLogger() 78 | 79 | parser = argparse.ArgumentParser(description='PyTorch Template') 80 | parser.add_argument('-c', '--checkpoint', default=None, type=str, 81 | help='checkpoint file path (default: None)') 82 | parser.add_argument('-p', '--plot', default=1, type=int, 83 | help='plot (default: True)') 84 | parser.add_argument('-o', '--only', default=None, type=str, 85 | help='only stats with this prefix (default: None)') 86 | parser.add_argument('-e', '--extract', default=None, type=str, 87 | help='instead of ploting, save a new file with only the log (default: None)') 88 | parser.add_argument('-C', '--printconfig', default=False, type=bool, 89 | help='print config (defaut False') 90 | 91 | args = parser.parse_args() 92 | 93 | assert args.checkpoint is not None 94 | saved = torch.load(args.checkpoint,map_location=lambda storage, loc: storage) 95 | log = saved['logger'] 96 | iteration = saved['iteration'] 97 | print('loaded iteration {}'.format(iteration)) 98 | 99 | if args.printconfig: 100 | print(saved['config']) 101 | exit() 102 | 103 | saved=None 104 | 105 | if args.extract is None: 106 | graph(log,args.plot,args.only) 107 | else: 108 | new_save = { 109 | 'iteration': iteration, 110 | 'logger': log 111 | } 112 | new_file = args.extract #args.checkpoint+'.ex' 113 | torch.save(new_save,new_file) 114 | print('saved '+new_file) 115 | -------------------------------------------------------------------------------- /evaluators/draw_graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math, os, random 4 | import torch 5 | 6 | def getCorners(xyrhw): 7 | xc=xyrhw[0].item() 8 | yc=xyrhw[1].item() 9 | r=xyrhw[2].item() 10 | h=xyrhw[3].item() 11 | w=xyrhw[4].item() 12 | h = min(30000,h) 13 | w = min(30000,w) 14 | 15 | tlX = int(-w*math.cos(r) -h*math.sin(r) +xc) 16 | tlY = int(-h*math.cos(r) +w*math.sin(r) +yc) 17 | trX = int( w*math.cos(r) -h*math.sin(r) +xc) 18 | trY = int(-h*math.cos(r) -w*math.sin(r) +yc) 19 | brX = int( w*math.cos(r) +h*math.sin(r) +xc) 20 | brY = int( h*math.cos(r) -w*math.sin(r) +yc) 21 | blX = int(-w*math.cos(r) +h*math.sin(r) +xc) 22 | blY = int( h*math.cos(r) +w*math.sin(r) +yc) 23 | return [[tlX,tlY],[trX,trY],[brX,brY],[blX,blY]] 24 | 25 | def plotRect(img,color,xyrhw,lineWidth=1): 26 | tl,tr,br,bl=getCorners(xyrhw) 27 | 28 | cv2.line(img,tl,tr,np.array(color),lineWidth) 29 | cv2.line(img,tr,br,np.array(color),lineWidth) 30 | cv2.line(img,br,bl,np.array(color),lineWidth) 31 | cv2.line(img,bl,tl,np.array(color),lineWidth) 32 | 33 | 34 | def draw_graph(outputBoxes,edgePred,edgeIndexes,image,pair_threshold,verbosity=1): 35 | if outputBoxes is not None: 36 | outputBoxes = outputBoxes.data.numpy() 37 | #image = image.cpu().numpy() 38 | b=0 39 | #image = (1-((1+np.transpose(image[b][:,:,:],(1,2,0)))/2.0)) 40 | #if image.shape[2]==1: 41 | # image = cv2.gray2rgb(image) 42 | 43 | 44 | 45 | to_write_text=[] 46 | bbs = outputBoxes 47 | 48 | 49 | #Draw pred groups (based on bb pred) 50 | groupCenters=[] 51 | predGroups = [[i] 
for i in range(len(bbs))] 52 | 53 | for group in predGroups: 54 | maxX=maxY=0 55 | minY=minX=99999999 56 | idColor = [random.random()/2+0.5 for i in range(3)] 57 | for j in group: 58 | conf = bbs[j,0] 59 | maxIndex =np.argmax(bbs[j,7:]) #TODO is this the right index? 60 | shade = conf#(conf-bb_thresh)/(1-bb_thresh) 61 | if maxIndex==0: 62 | color=(0,0,shade) #header 63 | elif maxIndex==1: 64 | color=(0,shade,shade) #question 65 | elif maxIndex==2: 66 | color=(shade,shade,0) #answer 67 | elif maxIndex==3: 68 | color=(shade,0,shade) #other 69 | else: 70 | raise NotImplementedError('Only 4 colors/classes implemented for drawing') 71 | lineWidth=1 72 | 73 | if verbosity>1 or len(group)==1: 74 | plotRect(image,color,bbs[j,1:6],lineWidth) 75 | x=int(bbs[j,1]) 76 | y=int(bbs[j,2]) 77 | 78 | tr,tl,br,bl=getCorners(outputBoxes[j,1:6]) 79 | if verbosity>1: 80 | image[tl[1]:tl[1]+2,tl[0]:tl[0]+2]=idColor 81 | image[tr[1]:tr[1]+1,tr[0]:tr[0]+1]=idColor 82 | image[bl[1]:bl[1]+1,bl[0]:bl[0]+1]=idColor 83 | image[br[1]:br[1]+1,br[0]:br[0]+1]=idColor 84 | maxX=max(maxX,tr[0],tl[0],br[0],bl[0]) 85 | minX=min(minX,tr[0],tl[0],br[0],bl[0]) 86 | maxY=max(maxY,tr[1],tl[1],br[1],bl[1]) 87 | minY=min(minY,tr[1],tl[1],br[1],bl[1]) 88 | minX-=2 89 | minY-=2 90 | maxX+=2 91 | maxY+=2 92 | lineWidth=2 93 | #color=(0.5,0,1) 94 | if len(group)>1: 95 | cv2.line(image,(minX,minY),(maxX,minY),color,lineWidth) 96 | cv2.line(image,(maxX,minY),(maxX,maxY),color,lineWidth) 97 | cv2.line(image,(maxX,maxY),(minX,maxY),color,lineWidth) 98 | cv2.line(image,(minX,maxY),(minX,minY),color,lineWidth) 99 | if verbosity>1: 100 | image[minY:minY+3,minX:minX+3]=idColor 101 | if verbosity>1: 102 | image[maxY:maxY+1,minX:minX+1]=idColor 103 | image[maxY:maxY+1,maxX:maxX+1]=idColor 104 | image[minY:minY+1,maxX:maxX+1]=idColor 105 | groupCenters.append(((minX+maxX)//2,(minY+maxY)//2)) 106 | 107 | 108 | 109 | #Draw pred pairings 110 | #draw_rel_thresh = relPred.max() * draw_rel_thresh_over 111 | numrelpred=0 112 | #hits = [False]*len(adjacency) 113 | edgesToDraw=[] 114 | if edgeIndexes is not None: 115 | for i,(g1,g2) in enumerate(edgeIndexes): 116 | 117 | if edgePred[i]>pair_threshold: 118 | x1,y1 = groupCenters[g1] 119 | x2,y2 = groupCenters[g2] 120 | edgesToDraw.append((i,x1,y1,x2,y2)) 121 | 122 | 123 | lineColor = (0,0.8,0) 124 | for i,x1,y1,x2,y2 in edgesToDraw: 125 | cv2.line(image,(x1,y1),(x2,y2),lineColor,2) 126 | 127 | 128 | 129 | 130 | 131 | 132 | image*=255 133 | return image 134 | #cv2.imwrite(path,image) 135 | 136 | 137 | -------------------------------------------------------------------------------- /model/binary_pair_real.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
15 | """ 16 | import torch 17 | from torch import nn 18 | #from base import BaseModel 19 | import torch.nn.functional as F 20 | #from torch.nn.utils.weight_norm import weight_norm 21 | import math 22 | import json 23 | from .net_builder import make_layers 24 | from model.simpleNN import SimpleNN 25 | 26 | #This assumes the classification of edges was done by the pairing_graph modules featurizer 27 | 28 | class BinaryPairReal(nn.Module): 29 | def __init__(self, config): # predCount, base_0, base_1): 30 | super(BinaryPairReal, self).__init__() 31 | numBBOut = config['bb_out'] if 'bb_out' in config else 0 32 | numRelOut = config['rel_out'] if 'rel_out' in config else 1 33 | 34 | in_ch=config['in_channels'] 35 | 36 | norm = config['norm'] if 'norm' in config else 'group_norm' 37 | dropout = config['dropout'] if 'dropout' in config else True 38 | 39 | layer_desc = config['layers'] if 'layers' in config else ['FC256','FC256','FC256'] 40 | layer_desc = [in_ch]+layer_desc+['FCnR{}'.format(numRelOut)] 41 | layers, last_ch_relC = make_layers(layer_desc,norm=norm,dropout=dropout) 42 | self.layers = nn.Sequential(*layers) 43 | 44 | if numBBOut>0: 45 | layer_desc = config['layers_bb'] if 'layers_bb' in config else ['FC256','FC256','FC256'] 46 | layer_desc = [in_ch]+layer_desc+['FCnR{}'.format(numBBOut)] 47 | layers, last_ch_relC = make_layers(layer_desc,norm=norm,dropout=dropout) 48 | self.layersBB = nn.Sequential(*layers) 49 | 50 | #This is written to by the PairingGraph object (which holds this one) 51 | self.numShapeFeats = config['num_shape_feats'] if 'num_shape_feats' in config else 16 52 | 53 | 54 | 55 | if 'shape_layers' in config: 56 | if type(config['shape_layers']) is list: 57 | layer_desc = config['shape_layers'] 58 | layer_desc = [self.numShapeFeats]+layer_desc+['FCnR{}'.format(numRelOut)] 59 | layers, last_ch_relC = make_layers(layer_desc,norm=norm,dropout=dropout) 60 | self.shape_layers = nn.Sequential(*layers) 61 | self.frozen_shape_layers=False 62 | else: 63 | checkpoint = torch.load(config['shape_layers']) 64 | shape_config = checkpoint['config']['model'] 65 | if 'state_dict' in checkpoint: 66 | self.shape_layers = eval(checkpoint['config']['arch'])(shape_config) 67 | self.shape_layers.load_state_dict(checkpoint['state_dict']) 68 | else: 69 | self.shape_layers = checkpoint['model'] 70 | for param in self.shape_layers.parameters(): 71 | param.requires_grad=False 72 | self.frozen_shape_layers=True 73 | if 'weight_split' in config: 74 | if type(config['weight_split']) is float: 75 | init = config['weight_split'] 76 | else: 77 | init = 0.5 78 | self.split_weighting = nn.Parameter(torch.tensor(init, requires_grad=True)) 79 | else: 80 | self.split_weighting = None 81 | else: 82 | self.shape_layers=None 83 | 84 | 85 | 86 | def forward(self, node_features, adjacencyMatrix, numBBs): 87 | node_featuresR = node_features[numBBs:] 88 | res = self.layers(node_featuresR) 89 | if self.shape_layers is not None: 90 | if self.frozen_shape_layers: 91 | self.shape_layers.eval() 92 | res2 = self.shape_layers(node_featuresR[:,-self.numShapeFeats:]) 93 | if self.split_weighting is None: 94 | res = (res+res2)/2 95 | else: 96 | weight = self.split_weighting.clamp(0,1) 97 | res = weight*res + (1-weight)*res2 98 | if numBBs>0: 99 | node_featuresB = node_features[:numBBs] 100 | resB = self.layersBB(node_featuresB) 101 | else: 102 | resB=None 103 | #import pdb;pdb.set_trace() 104 | return resB,res 105 | 106 | 107 | -------------------------------------------------------------------------------- 
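A minimal usage sketch of BinaryPairReal, assuming the graph_config from cf_pairing.json; the feature layout (box-node rows first, then relationship rows) mirrors the slicing in forward(), while the tensor sizes below are purely illustrative and the FC256 layer strings are assumed to be handled by net_builder.make_layers as they are when training from that config:

import torch
from model.binary_pair_real import BinaryPairReal

config = {"in_channels": 256, "bb_out": 1, "rel_out": 1,
          "layers": ["FC256", "FC256"], "layers_bb": ["FC256"]}
net = BinaryPairReal(config)

num_bbs, num_rels = 4, 6                                # hypothetical graph size
node_features = torch.randn(num_bbs + num_rels, 256)    # box-node rows first, then relationship rows
bb_pred, rel_pred = net(node_features, None, num_bbs)   # the adjacency-matrix argument is unused by this head
# under these assumptions bb_pred is [4, 1] (per-box output) and rel_pred is [6, 1] (edge scores)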
/datasets/testforms_feat_pair.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | from datasets.forms_feature_pair import FormsFeaturePair 17 | from datasets import forms_feature_pair 18 | import math 19 | import sys 20 | from matplotlib import pyplot as plt 21 | from matplotlib import gridspec 22 | from matplotlib.patches import Polygon 23 | import numpy as np 24 | import torch 25 | 26 | def display(data): 27 | b=0 28 | return data['data'].numpy() 29 | if False: 30 | #print (data['img'].size()) 31 | #img = (data['img'][0].permute(1,2,0)+1)/2.0 32 | img = cv2.imread(data['imgPath']) 33 | #print(img.shape) 34 | #print(data['pixel_gt']['table_pixels'].shape) 35 | print(data['imgName']) 36 | 37 | 38 | 39 | fig = plt.figure() 40 | #gs = gridspec.GridSpec(1, 3) 41 | 42 | ax_im = plt.subplot() 43 | ax_im.set_axis_off() 44 | ax_im.imshow(img[:,:,0]) 45 | 46 | colors = { 'text_start_gt':'g-', 47 | 'text_end_gt':'b-', 48 | 'field_start_gt':'r-', 49 | 'field_end_gt':'y-', 50 | 'table_points':'co' 51 | } 52 | #print('num bb:{}'.format(data['bb_sizes'][b])) 53 | for i in range(data['bb_gt'].size(1)): 54 | xc=data['bb_gt'][b,i,0] 55 | yc=data['bb_gt'][b,i,1] 56 | rot=data['bb_gt'][b,i,2] 57 | h=data['bb_gt'][b,i,3] 58 | w=data['bb_gt'][b,i,4] 59 | text=data['bb_gt'][b,i,13] 60 | field=data['bb_gt'][b,i,14] 61 | if text>0: 62 | color = 'b-' 63 | else: 64 | color = 'r-' 65 | tr = (math.cos(rot)*w-math.sin(rot)*h +xc, math.sin(rot)*w+math.cos(rot)*h +yc) 66 | tl = (math.cos(rot)*-w-math.sin(rot)*h +xc, math.sin(rot)*-w+math.cos(rot)*h +yc) 67 | br = (math.cos(rot)*w-math.sin(rot)*-h +xc, math.sin(rot)*w+math.cos(rot)*-h +yc) 68 | bl = (math.cos(rot)*-w-math.sin(rot)*-h +xc, math.sin(rot)*-w+math.cos(rot)*-h +yc) 69 | #print([tr,tl,br,bl]) 70 | 71 | ax_im.plot([tr[0],tl[0],bl[0],br[0],tr[0]],[tr[1],tl[1],bl[1],br[1],tr[1]],color) 72 | for ind1,ind2 in data['adj']: 73 | x1=data['bb_gt'][b,ind1,0] 74 | y1=data['bb_gt'][b,ind1,1] 75 | x2=data['bb_gt'][b,ind2,0] 76 | y2=data['bb_gt'][b,ind2,1] 77 | 78 | ax_im.plot([x1,x2],[y1,y2],'g-') 79 | #print('{} to {}, {} - {}'.format(ind1,ind2,(x1,y1),(x2,y2))) 80 | plt.show() 81 | 82 | 83 | if __name__ == "__main__": 84 | dirPath = sys.argv[1] 85 | if len(sys.argv)>2: 86 | start = int(sys.argv[2]) 87 | else: 88 | start=0 89 | if len(sys.argv)>3: 90 | repeat = int(sys.argv[3]) 91 | else: 92 | repeat=1 93 | data=FormsFeaturePair(dirPath=dirPath,split='train',config={ 94 | "data_set_name": "FormsFeaturePair", 95 | "simple_dataset": True, 96 | "alternate_json_dir": "out_json/Forms21_augRFh_staggerLighter_NN", 97 | "data_dir": "../data/forms", 98 | "batch_size": 1, 99 | "shuffle": False, 100 | "num_workers": 0, 101 | "no_blanks": True, 102 | "swap_circle":True, 103 | "no_graphics":True, 104 | 
"cache_resized_images": True, 105 | "rotation": False, 106 | "balance": True, 107 | "only_opposite_pairs": True, 108 | "corners":True 109 | #"only_types": ["text_start_gt"] 110 | }) 111 | 112 | dataLoader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False, num_workers=0, collate_fn=forms_feature_pair.collate) 113 | dataLoaderIter = iter(dataLoader) 114 | 115 | #if start==0: 116 | #display(data[0]) 117 | for i in range(0,start): 118 | print(i) 119 | dataLoaderIter.next() 120 | #display(data[i]) 121 | datas=[] 122 | try: 123 | while True: 124 | #print('?') 125 | data = display(dataLoaderIter.next()) 126 | datas.append(data) 127 | except StopIteration: 128 | print('done') 129 | 130 | data = np.concatenate(datas,axis=0) 131 | 132 | #print(data.mean(axis=0)) 133 | #print(data.std(axis=0)) 134 | #toprint = ['']*data.shape[1] 135 | print('feat:\tmean:\tstddev:') 136 | for i in range(data.shape[1]): 137 | print('{}:\t{:.3}\t{:.3}'.format(i,data[:,i].mean(),data[:,i].std())) 138 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import os 17 | import struct 18 | import torch 19 | from . 
import string_utils 20 | 21 | 22 | def ensure_dir(path): 23 | if not os.path.exists(path): 24 | os.makedirs(path) 25 | 26 | 27 | def pt_xyrs_2_xyxy(state): 28 | out = torch.ones(state.data.shape[0], 5).type(state.data.type()) 29 | 30 | x = state[:,:,1:2] 31 | y = state[:,:,2:3] 32 | r = state[:,:,3:4] 33 | s = state[:,:,4:5] 34 | 35 | x0 = -torch.sin(r) * s + x 36 | y0 = -torch.cos(r) * s + y 37 | x1 = torch.sin(r) * s + x 38 | y1 = torch.cos(r) * s + y 39 | 40 | return torch.cat([ 41 | state[:,:,0:1], 42 | x0, y0, x1, y1 43 | ], 2) 44 | def pt_xyxy_2_xyrs(state): 45 | out = torch.ones(state.data.shape[0], 5).type(state.data.type()) 46 | 47 | x0 = state[:,0:1] 48 | y0 = state[:,1:2] 49 | x1 = state[:,2:3] 50 | y1 = state[:,3:4] 51 | 52 | dx = x0-x1 53 | dy = y0-y1 54 | 55 | d = torch.sqrt(dx**2.0 + dy**2.0)/2.0 56 | 57 | mx = (x0+x1)/2.0 58 | my = (y0+y1)/2.0 59 | 60 | theta = -torch.atan2(dx, -dy) 61 | 62 | return torch.cat([ 63 | mx, my, theta, d, 64 | state[:,4:5] 65 | ], 1) 66 | #------------------------------------------------------------------------------- 67 | # Name: get_image_size 68 | # Purpose: extract image dimensions given a file path using just 69 | # core modules 70 | # 71 | # Author: Paulo Scardine (based on code from Emmanuel VAÏSSE) 72 | # 73 | # Created: 26/09/2013 74 | # Copyright: (c) Paulo Scardine 2013 75 | # Licence: MIT 76 | # From: https://stackoverflow.com/questions/15800704/get-image-size-without-loading-image-into-memory 77 | #------------------------------------------------------------------------------- 78 | class UnknownImageFormat(Exception): 79 | pass 80 | 81 | def get_image_size(file_path): 82 | """ 83 | Return (width, height) for a given img file content - no external 84 | dependencies except the os and struct modules from core 85 | """ 86 | size = os.path.getsize(file_path) 87 | 88 | with open(file_path) as input: 89 | height = -1 90 | width = -1 91 | data = input.read(25) 92 | 93 | if (size >= 10) and data[:6] in ('GIF87a', 'GIF89a'): 94 | # GIFs 95 | w, h = struct.unpack("<HH", data[6:10]) 96 | width = int(w) 97 | height = int(h) 98 | elif ((size >= 24) and data.startswith('\211PNG\r\n\032\n') 99 | and (data[12:16] == 'IHDR')): 100 | # PNGs 101 | w, h = struct.unpack(">LL", data[16:24]) 102 | width = int(w) 103 | height = int(h) 104 | elif (size >= 16) and data.startswith('\211PNG\r\n\032\n'): 105 | # older PNGs? 106 | w, h = struct.unpack(">LL", data[8:16]) 107 | width = int(w) 108 | height = int(h) 109 | elif (size >= 2) and data.startswith('\377\330'): 110 | # JPEG 111 | msg = " raised while trying to decode as JPEG." 112 | input.seek(0) 113 | input.read(2) 114 | b = input.read(1) 115 | try: 116 | while (b and ord(b) != 0xDA): 117 | while (ord(b) != 0xFF): b = input.read(1) 118 | while (ord(b) == 0xFF): b = input.read(1) 119 | if (ord(b) >= 0xC0 and ord(b) <= 0xC3): 120 | input.read(3) 121 | h, w = struct.unpack(">HH", input.read(4)) 122 | break 123 | else: 124 | input.read(int(struct.unpack(">H", input.read(2))[0])-2) 125 | b = input.read(1) 126 | width = int(w) 127 | height = int(h) 128 | except struct.error: 129 | raise UnknownImageFormat("StructError" + msg) 130 | except ValueError: 131 | raise UnknownImageFormat("ValueError" + msg) 132 | except Exception as e: 133 | raise UnknownImageFormat(e.__class__.__name__ + msg) 134 | else: 135 | raise UnknownImageFormat( 136 | "Sorry, don't know how to get information from this file."
137 | ) 138 | 139 | return width, height 140 | 141 | def decode_handwriting(out, idx_to_char): 142 | hw_out = out#['hw'] 143 | list_of_pred = [] 144 | list_of_raw_pred = [] 145 | for i in range(hw_out.shape[0]): 146 | logits = hw_out[i,...] 147 | pred, raw_pred = string_utils.naive_decode(logits) 148 | pred_str = string_utils.label2str_single(pred, idx_to_char, False) 149 | raw_pred_str = string_utils.label2str_single(raw_pred, idx_to_char, True) 150 | list_of_pred.append(pred_str) 151 | list_of_raw_pred.append(raw_pred_str) 152 | 153 | return list_of_pred, list_of_raw_pred 154 | 155 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import os 17 | import sys 18 | import signal 19 | import json 20 | import logging 21 | import argparse 22 | import torch 23 | from model import * 24 | from model.loss import * 25 | from model.metric import * 26 | from data_loader import getDataLoader 27 | from trainer import * 28 | from logger import Logger 29 | 30 | logging.basicConfig(level=logging.INFO, format='') 31 | def set_procname(newname): 32 | from ctypes import cdll, byref, create_string_buffer 33 | newname=os.fsencode(newname) 34 | libc = cdll.LoadLibrary('libc.so.6') #Loading a 3rd party library C 35 | buff = create_string_buffer(len(newname)+1) #Note: One larger than the name (man prctl says that) 36 | buff.value = newname #Null terminated string as it should be 37 | libc.prctl(15, byref(buff), 0, 0, 0) #Refer to "#define" of "/usr/include/linux/prctl.h" for the misterious value 16 & arg[3..5] are zero as the man page says. 
38 | 39 | def main(config, resume): 40 | set_procname(config['name']) 41 | #np.random.seed(1234) I don't have a way of restarting the DataLoader at the same place, so this makes it totaly random 42 | train_logger = Logger() 43 | 44 | split = config['split'] if 'split' in config else 'train' 45 | data_loader, valid_data_loader = getDataLoader(config,split) 46 | #valid_data_loader = data_loader.split_validation() 47 | 48 | model = eval(config['arch'])(config['model']) 49 | model.summary() 50 | if type(config['loss'])==dict: 51 | loss={}#[eval(l) for l in config['loss']] 52 | for name,l in config['loss'].items(): 53 | loss[name]=eval(l) 54 | else: 55 | loss = eval(config['loss']) 56 | if type(config['metrics'])==dict: 57 | metrics={} 58 | for name,m in config['metrics'].items(): 59 | metrics[name]=[eval(metric) for metric in m] 60 | else: 61 | metrics = [eval(metric) for metric in config['metrics']] 62 | 63 | if 'class' in config['trainer']: 64 | trainerClass = eval(config['trainer']['class']) 65 | else: 66 | trainerClass = Trainer 67 | trainer = trainerClass(model, loss, metrics, 68 | resume=resume, 69 | config=config, 70 | data_loader=data_loader, 71 | valid_data_loader=valid_data_loader, 72 | train_logger=train_logger) 73 | 74 | def handleSIGINT(sig, frame): 75 | trainer.save() 76 | sys.exit(0) 77 | signal.signal(signal.SIGINT, handleSIGINT) 78 | 79 | print("Begin training") 80 | trainer.train() 81 | 82 | 83 | if __name__ == '__main__': 84 | logger = logging.getLogger() 85 | 86 | parser = argparse.ArgumentParser(description='PyTorch Template') 87 | parser.add_argument('-c', '--config', default=None, type=str, 88 | help='config file path (default: None)') 89 | parser.add_argument('-r', '--resume', default=None, type=str, 90 | help='path to checkpoint (default: None)') 91 | parser.add_argument('-s', '--soft_resume', default=None, type=str, 92 | help='path to checkpoint that may or may not exist (default: None)') 93 | parser.add_argument('-g', '--gpu', default=None, type=int, 94 | help='gpu to use (overrides config) (default: None)') 95 | #parser.add_argument('-m', '--merged', default=False, action='store_const', const=True, 96 | # help='Use combine train and valid sets.') 97 | 98 | args = parser.parse_args() 99 | 100 | config = None 101 | if args.config is not None: 102 | config = json.load(open(args.config)) 103 | if args.resume is None and args.soft_resume is not None: 104 | if not os.path.exists(args.soft_resume): 105 | print('WARNING: resume path ({}) was not found, starting from scratch'.format(args.soft_resume)) 106 | else: 107 | args.resume = args.soft_resume 108 | if args.resume is not None and (config is None or 'override' not in config or not config['override']): 109 | if args.config is not None: 110 | logger.warning('Warning: --config overridden by --resume') 111 | config = torch.load(args.resume)['config'] 112 | elif args.config is not None and args.resume is None: 113 | path = os.path.join(config['trainer']['save_dir'], config['name']) 114 | if os.path.exists(path): 115 | directory = os.fsencode(path) 116 | for file in os.listdir(directory): 117 | filename = os.fsdecode(file) 118 | if filename!='config.json': 119 | assert False, "Path {} already used!".format(path) 120 | 121 | assert config is not None 122 | 123 | if args.gpu is not None: 124 | config['gpu']=args.gpu 125 | print('override gpu to '+str(config['gpu'])) 126 | 127 | if config['cuda']: 128 | with torch.cuda.device(config['gpu']): 129 | main(config, args.resume) 130 | else: 131 | main(config, args.resume) 132 | 
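Typical invocations of train.py, as a sketch (the resume checkpoint path is hypothetical, following the saved/<name>/ layout from the trainer config and the checkpoint-iteration<N>.pth naming seen in cf_pairing.json):

python train.py -c cf_pairing.json -g 0
python train.py -r saved/pairing/checkpoint-iteration25000.pth

-s/--soft_resume, used together with -c, resumes from the given checkpoint if it exists and otherwise starts fresh from the config; -g overrides the gpu entry of the loaded config.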
-------------------------------------------------------------------------------- /datasets/testforms_box.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | from datasets.forms_box_detect import FormsBoxDetect 17 | from datasets import forms_box_detect 18 | import math 19 | import sys 20 | from matplotlib import pyplot as plt 21 | from matplotlib import gridspec 22 | from matplotlib.patches import Polygon 23 | import numpy as np 24 | import torch 25 | 26 | def display(data): 27 | batchSize = data['img'].size(0) 28 | for b in range(batchSize): 29 | #print (data['img'].size()) 30 | #img = (data['img'][0].permute(1,2,0)+1)/2.0 31 | img = (data['img'][b].permute(1,2,0)+1)/2.0 32 | #print(img.shape) 33 | #print(data['pixel_gt']['table_pixels'].shape) 34 | if 'pixel_gt' in data and data['pixel_gt'] is not None: 35 | img[:,:,1] = data['pixel_gt'][b,0,:,:] 36 | print(data['imgName'][b]) 37 | 38 | 39 | 40 | fig = plt.figure() 41 | #gs = gridspec.GridSpec(1, 3) 42 | 43 | ax_im = plt.subplot() 44 | ax_im.set_axis_off() 45 | if img.shape[2]==1: 46 | ax_im.imshow(img[0]) 47 | else: 48 | ax_im.imshow(img) 49 | 50 | colors = { 'text_start_gt':'g-', 51 | 'text_end_gt':'b-', 52 | 'field_start_gt':'r-', 53 | 'field_end_gt':'y-', 54 | 'table_points':'co', 55 | 'start_of_line':'y-', 56 | 'end_of_line':'c-', 57 | } 58 | print('num bb:{}'.format(data['bb_sizes'][b])) 59 | for i in range(data['bb_sizes'][b]): 60 | xc=data['bb_gt'][b,i,0] 61 | yc=data['bb_gt'][b,i,1] 62 | rot=data['bb_gt'][b,i,2] 63 | h=data['bb_gt'][b,i,3] 64 | w=data['bb_gt'][b,i,4] 65 | text=data['bb_gt'][b,i,13] 66 | field=data['bb_gt'][b,i,14] 67 | if text>0: 68 | color = 'b-' 69 | else: 70 | color = 'r-' 71 | tr = (math.cos(rot)*w-math.sin(rot)*h +xc, -math.sin(rot)*w-math.cos(rot)*h +yc) 72 | tl = (-math.cos(rot)*w-math.sin(rot)*h +xc, math.sin(rot)*w-math.cos(rot)*h +yc) 73 | br = (math.cos(rot)*w+math.sin(rot)*h +xc, -math.sin(rot)*w+math.cos(rot)*h +yc) 74 | bl = (-math.cos(rot)*w+math.sin(rot)*h +xc, math.sin(rot)*w+math.cos(rot)*h +yc) 75 | #print([tr,tl,br,bl]) 76 | 77 | ax_im.plot([tr[0],tl[0],bl[0],br[0],tr[0]],[tr[1],tl[1],bl[1],br[1],tr[1]],color) 78 | 79 | if data['bb_gt'].shape[2]>15: 80 | blank = data['bb_gt'][b,i,15] 81 | if blank>0: 82 | ax_im.plot(tr[0],tr[1],'mo') 83 | paired = data['bb_gt'][b,i,16] 84 | if paired>0: 85 | ax_im.plot(br[0],br[1],'go') 86 | 87 | 88 | if 'line_gt' in data and data['line_gt'] is not None: 89 | for name, gt in data['line_gt'].items(): 90 | if gt is not None: 91 | #print (gt.size()) 92 | for i in range(data['line_label_sizes'][name][b]): 93 | x0=gt[b,i,0] 94 | y0=gt[b,i,1] 95 | x1=gt[b,i,2] 96 | y1=gt[b,i,3] 97 | #print(1,'{},{} {},{}'.format(x0,y0,x1,y1)) 98 | 99 | ax_im.plot([x0,x1],[y0,y1],colors[name]) 100 | 101 | 102 | if 
'point_gt' in data and data['point_gt'] is not None: 103 | for name, gt in data['point_gt'].items(): 104 | if gt is not None: 105 | #print (gt.size()) 106 | #print(data) 107 | for i in range(data['point_label_sizes'][name][b]): 108 | x0=gt[b,i,0] 109 | y0=gt[b,i,1] 110 | 111 | ax_im.plot([x0],[y0],colors[name]) 112 | plt.show() 113 | print('batch complete') 114 | 115 | 116 | if __name__ == "__main__": 117 | dirPath = sys.argv[1] 118 | if len(sys.argv)>2: 119 | start = int(sys.argv[2]) 120 | else: 121 | start=0 122 | if len(sys.argv)>3: 123 | repeat = int(sys.argv[3]) 124 | else: 125 | repeat=1 126 | data=FormsBoxDetect(dirPath=dirPath,split='train',config={'crop_to_page':False,'rescale_range':[0.4,0.65], 127 | 'crop_params':{ "crop_size":[652,1608], 128 | "pad":0, 129 | "rot_degree_std_dev":1.5, 130 | "flip_horz": True, 131 | "flip_vert": True}, 132 | 'no_blanks':False, 133 | 'use_paired_class':True, 134 | "swap_circle":True, 135 | 'no_graphics':True, 136 | 'rotation':False, 137 | "only_types": {"boxes":True} 138 | }) 139 | #data.cluster(start,repeat,'anchors_rot_{}.json') 140 | 141 | dataLoader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False, num_workers=0, collate_fn=forms_box_detect.collate) 142 | dataLoaderIter = iter(dataLoader) 143 | 144 | #if start==0: 145 | #display(data[0]) 146 | for i in range(0,start): 147 | print(i) 148 | dataLoaderIter.next() 149 | #display(data[i]) 150 | try: 151 | while True: 152 | #print('?') 153 | display(dataLoaderIter.next()) 154 | except StopIteration: 155 | print('done') 156 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | from torch.nn.utils.weight_norm import weight_norm 5 | 6 | 7 | __all__ = [ 8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 9 | 'vgg19_bn', 'vgg19', 10 | ] 11 | 12 | 13 | class VGG(nn.Module): 14 | 15 | def __init__(self, features, num_classes=1000): 16 | super(VGG, self).__init__() 17 | self.features = features 18 | 19 | def forward(self, x): 20 | x = self.features(x) 21 | return x 22 | 23 | def _initialize_weights(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 27 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 28 | if m.bias is not None: 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | elif isinstance(m, nn.Linear): 34 | m.weight.data.normal_(0, 0.01) 35 | m.bias.data.zero_() 36 | 37 | 38 | def make_layers(cfg, batch_norm=False, weight_norm=False): 39 | layers = [] 40 | in_channels = 3 41 | for i,v in enumerate(cfg): 42 | if v == 'M': 43 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 44 | else: 45 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 46 | if i == len(cfg)-1: 47 | layers += [conv2d] 48 | break 49 | if batch_norm: 50 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 51 | elif weight_norm: 52 | layers += [weight_norm(conv2d), nn.ReLU(inplace=True)] 53 | else: 54 | layers += [conv2d, nn.ReLU(inplace=True)] 55 | in_channels = v 56 | return nn.Sequential(*layers) 57 | 58 | OUTPUT_FEATURES = 5 59 | cfg = { 60 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, OUTPUT_FEATURES], 61 | "A'": [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 62 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, OUTPUT_FEATURES], 63 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, OUTPUT_FEATURES], 64 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, OUTPUT_FEATURES], 65 | } 66 | 67 | 68 | def vgg11(pretrained=False, **kwargs): 69 | """VGG 11-layer model (configuration "A") 70 | Args: 71 | pretrained (bool): If True, returns a model pre-trained on ImageNet 72 | """ 73 | model = VGG(make_layers(cfg['A'], **kwargs)) 74 | if pretrained: 75 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 76 | return model 77 | 78 | 79 | def vgg11_custOut(numOut,pretrained=False, **kwargs): 80 | """VGG 11-layer model (configuration "A'") 81 | Args: 82 | pretrained (bool): If True, returns a model pre-trained on ImageNet 83 | """ 84 | model = VGG(make_layers(cfg["A'"]+[numOut], **kwargs)) 85 | if pretrained: 86 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 87 | scale=1 88 | for a in cfg["A'"]: 89 | if a=='M': 90 | scale*=2 91 | return model, scale 92 | 93 | 94 | 95 | 96 | def vgg11_bn(pretrained=False, **kwargs): 97 | """VGG 11-layer model (configuration "A") with batch normalization 98 | Args: 99 | pretrained (bool): If True, returns a model pre-trained on ImageNet 100 | """ 101 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 102 | if pretrained: 103 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 104 | return model 105 | 106 | 107 | def vgg13(pretrained=False, **kwargs): 108 | """VGG 13-layer model (configuration "B") 109 | Args: 110 | pretrained (bool): If True, returns a model pre-trained on ImageNet 111 | """ 112 | model = VGG(make_layers(cfg['B']), **kwargs) 113 | if pretrained: 114 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 115 | return model 116 | 117 | 118 | def vgg13_bn(pretrained=False, **kwargs): 119 | """VGG 13-layer model (configuration "B") with batch normalization 120 | Args: 121 | pretrained (bool): If True, returns a model pre-trained on ImageNet 122 | """ 123 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 124 | if pretrained: 125 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 126 | return model 127 | 128 | 129 | def vgg16(pretrained=False, **kwargs): 130 | """VGG 16-layer model (configuration "D") 131 | Args: 132 | pretrained 
(bool): If True, returns a model pre-trained on ImageNet 133 | """ 134 | model = VGG(make_layers(cfg['D']), **kwargs) 135 | if pretrained: 136 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 137 | return model 138 | 139 | 140 | def vgg16_bn(pretrained=False, **kwargs): 141 | """VGG 16-layer model (configuration "D") with batch normalization 142 | Args: 143 | pretrained (bool): If True, returns a model pre-trained on ImageNet 144 | """ 145 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 146 | if pretrained: 147 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 148 | return model 149 | 150 | 151 | def vgg19(pretrained=False, **kwargs): 152 | """VGG 19-layer model (configuration "E") 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | """ 156 | model = VGG(make_layers(cfg['E']), **kwargs) 157 | if pretrained: 158 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 159 | return model 160 | 161 | 162 | def vgg19_bn(pretrained=False, **kwargs): 163 | """VGG 19-layer model (configuration 'E') with batch normalization 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 168 | if pretrained: 169 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 170 | return model 171 | -------------------------------------------------------------------------------- /model/optimize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
15 | """ 16 | import sys 17 | 18 | def optimizeRelationships(relPred,relNodes,gtNodeNeighbors,penalty=490): 19 | #if 'cvxpy' not in sys.modules: 20 | import cvxpy 21 | useRel = cvxpy.Variable(relPred.size(0),boolean=True) 22 | 23 | obj =0 24 | huh=0 25 | for i in range(relPred.size(0)): 26 | obj += relPred[i].item()*useRel[i] 27 | huh +=useRel[i] 28 | 29 | 30 | constraint = [0]*len(gtNodeNeighbors) 31 | for i in range(len(gtNodeNeighbors)): 32 | relI=0 33 | for a,b in relNodes: 34 | j=None 35 | if a==i: 36 | j=b 37 | elif b==i: 38 | j=a 39 | if j is not None: 40 | constraint[i] += useRel[relI] 41 | relI+=1 42 | constraint[i] -= gtNodeNeighbors[i] 43 | #obj -= cvxpy.power(penalty,(cvxpy.abs(constraint[i]))) #this causes it to not miss on the same node more than once 44 | constraint[i] = cvxpy.abs(constraint[i]) 45 | obj -= penalty*constraint[i] 46 | 47 | 48 | cs=[] 49 | for i in range(len(gtNodeNeighbors)): 50 | cs.append(constraint[i]<=1) 51 | problem = cvxpy.Problem(cvxpy.Maximize(obj),cs) 52 | #problem.solve(solver=cvxpy.GLPK_MI) 53 | problem.solve(solver=cvxpy.ECOS_BB) 54 | assert(useRel.value is not None) 55 | return useRel.value 56 | def optimizeRelationshipsSoft(relPred,relNodes,predNodeNeighbors,penalty=1.2,threshold=0.5): 57 | #if 'cvxpy' not in sys.modules: 58 | import cvxpy 59 | useRel = cvxpy.Variable(relPred.size(0),boolean=True) 60 | 61 | obj =0 62 | huh=0 63 | for i in range(relPred.size(0)): 64 | obj += (relPred[i].item()-threshold)*useRel[i] 65 | huh +=useRel[i] 66 | 67 | 68 | difference = [0]*len(predNodeNeighbors) 69 | for i in range(len(predNodeNeighbors)): 70 | relI=0 71 | for a,b in relNodes: 72 | j=None 73 | if a==i: 74 | j=b 75 | elif b==i: 76 | j=a 77 | if j is not None: 78 | difference[i] += useRel[relI] 79 | relI+=1 80 | difference[i] -= predNodeNeighbors[i] 81 | #difference[i] = cvxpy.abs(difference[i]) 82 | #obj -= cvxpy.power(penalty,difference[i]) #this causes it to not miss on the same node more than once 83 | obj -= penalty*cvxpy.power(difference[i],2) 84 | #obj -= penalty*cvxpy.maximum(1,difference[i]) - penalty #double penalty if difference>1 85 | #obj -= penalty*cvxpy.maximum(2,difference[i]) - 2*penalty #triple penalty if difference>2 86 | 87 | 88 | cs=[] 89 | #for i in range(len(predNodeNeighbors)): 90 | # cs.append(difference[i]<=4) 91 | problem = cvxpy.Problem(cvxpy.Maximize(obj),cs) 92 | #problem.solve(solver=cvxpy.GLPK_MI) 93 | problem.solve(solver=cvxpy.ECOS_BB) 94 | return useRel.value 95 | 96 | def optimizeRelationshipsBlind(relPred,relNodes,penalty=0.5): 97 | #if 'cvxpy' not in sys.modules: 98 | import cvxpy 99 | useRel = cvxpy.Variable(relPred.size(0),boolean=True) 100 | 101 | obj =0 102 | huh=0 103 | for i in range(relPred.size(0)): 104 | obj += relPred[i].item()*useRel[i] 105 | huh +=useRel[i] 106 | 107 | maxId=0 108 | for a,b in relNodes: 109 | maxId=max(maxId,a,b) 110 | numNodes=maxId+1 111 | 112 | constraint = [0]*numNodes 113 | for i in range(numNodes): 114 | relI=0 115 | for a,b in relNodes: 116 | j=None 117 | if a==i: 118 | j=b 119 | elif b==i: 120 | j=a 121 | if j is not None: 122 | constraint[i] += useRel[relI] 123 | relI+=1 124 | #constraint[i] -= gtNodeNeighbors[i] 125 | #obj -= cvxpy.power(penalty,(cvxpy.abs(constraint[i]))) #this causes it to not miss on the same node more than once 126 | #constraint[i] = cvxpy.abs(constraint[i]) 127 | 128 | obj -= penalty*(cvxpy.maximum(constraint[i],1)-1) 129 | 130 | 131 | cs=[] 132 | for i in range(numNodes): 133 | cs.append(constraint[i]<=2) 134 | problem = 
cvxpy.Problem(cvxpy.Maximize(obj),cs) 135 | problem.solve(solver=cvxpy.GLPK_MI) 136 | return useRel.value 137 | #from gurobipy import * 138 | # 139 | # 140 | #def optimizeRelationshipsGUROBI(relPred,relNodes,gtNodeNeighbors): 141 | # m = Model("mip1") 142 | # 143 | # x = m.addVar(vtype=GRB.BINARY, name="x") 144 | # 145 | # useRel=[] 146 | # for i in range(relPred.size(0)): 147 | # useRel.append( m.addVar(vtype=GRB.BINARY, name="e{}".format(i)) ) 148 | # 149 | # obj = LinExpr() 150 | # for i in range(relPred.size(0)): 151 | # obj += relPred[i].item()*useRel[i] 152 | # 153 | # for i in range(numNodes): 154 | # constraint = LinExpr() 155 | # relI=0 156 | # for a,b in relNodes: 157 | # j=None 158 | # if a==i: 159 | # j=b 160 | # elif b==i: 161 | # j=a 162 | # if j is not None: 163 | # constraint += useRel[relI] 164 | # constraint -= gtNodeNeighbors[i] 165 | # obj -= penalty**(abs(constraint)) #this causes it to not miss on the same node more than once 166 | # 167 | # m.setObjective(obj, GRB.MAXIMIZE) 168 | # 169 | # m.optimize() 170 | # 171 | # #for v in m.getVars(): 172 | # # print(v.varName, v.x) 173 | # ret = [0]*relPred.size(0) 174 | # for i in range(relPred.size(0)): 175 | # ret[i]=useRel[i].x 176 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import numpy as np 17 | import torch 18 | from base import BaseTrainer 19 | import timeit 20 | 21 | 22 | class Trainer(BaseTrainer): 23 | """ 24 | Trainer class 25 | 26 | Note: 27 | Inherited from BaseTrainer. 28 | self.optimizer is by default handled by BaseTrainer based on config. 29 | """ 30 | def __init__(self, model, loss, metrics, resume, config, 31 | data_loader, valid_data_loader=None, train_logger=None): 32 | super(Trainer, self).__init__(model, loss, metrics, resume, config, train_logger) 33 | #self.config = config #uggh, why is this getting overwritten everywhere? 
We'll let super handle it 34 | self.batch_size = data_loader.batch_size 35 | self.data_loader = data_loader 36 | self.data_loader_iter = iter(data_loader) 37 | #for i in range(self.start_iteration, 38 | self.valid_data_loader = valid_data_loader 39 | self.valid = True if self.valid_data_loader is not None else False 40 | #self.log_step = int(np.sqrt(self.batch_size)) 41 | 42 | #def _to_tensor(self, data, target): 43 | # return self._to_tensor_individual(data), _to_tensor_individual(target) 44 | def _to_tensor(self, *datas): 45 | ret=(self._to_tensor_individual(datas[0]),) 46 | for i in range(1,len(datas)): 47 | ret+=(self._to_tensor_individual(datas[i]),) 48 | return ret 49 | def _to_tensor_individual(self, data): 50 | if type(data)==str: 51 | return data 52 | if type(data)==list or type(data)==tuple: 53 | return [self._to_tensor_individual(d) for d in data] 54 | if (len(data.size())==1 and data.size(0)==1): 55 | return data[0] 56 | if type(data) is np.ndarray: 57 | data = torch.FloatTensor(data.astype(np.float32)) 58 | elif type(data) is torch.Tensor: 59 | data = data.type(torch.FloatTensor) 60 | if self.with_cuda: 61 | data = data.to(self.gpu) 62 | return data 63 | 64 | def _eval_metrics(self, output, target): 65 | acc_metrics = np.zeros(len(self.metrics)) 66 | if len(self.metrics)>0: 67 | output = output.cpu().data.numpy() 68 | target = target.cpu().data.numpy() 69 | for i, metric in enumerate(self.metrics): 70 | acc_metrics[i] += metric(output, target) 71 | return acc_metrics 72 | 73 | def _train_iteration(self, iteration): 74 | """ 75 | Training logic for an iteration 76 | 77 | :param iteration: Current training iteration. 78 | :return: A log that contains all information you want to save. 79 | 80 | Note: 81 | If you have additional information to record, for example: 82 | > additional_log = {"x": x, "y": y} 83 | merge it with log before return. i.e. 84 | > log = {**log, **additional_log} 85 | > return log 86 | 87 | The metrics in log must have the key 'metrics'. 
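            For example, a minimal log returned by this method is
            > log = {'loss': 0.123, 'metrics': [0.91]}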
88 | """ 89 | self.model.train() 90 | 91 | #tic=timeit.default_timer() 92 | batch_idx = (iteration-1) % len(self.data_loader) 93 | try: 94 | data, target = self._to_tensor(*self.data_loader_iter.next()) 95 | except StopIteration: 96 | self.data_loader_iter = iter(self.data_loader) 97 | data, target = self._to_tensor(*self.data_loader_iter.next()) 98 | #toc=timeit.default_timer() 99 | #print('data: '+str(toc-tic)) 100 | 101 | #tic=timeit.default_timer() 102 | 103 | self.optimizer.zero_grad() 104 | output = self.model(data) 105 | loss = self.loss(output, target) 106 | loss.backward() 107 | self.optimizer.step() 108 | 109 | #toc=timeit.default_timer() 110 | #print('for/bac: '+str(toc-tic)) 111 | 112 | #tic=timeit.default_timer() 113 | metrics = self._eval_metrics(output, target) 114 | #toc=timeit.default_timer() 115 | #print('metric: '+str(toc-tic)) 116 | 117 | #tic=timeit.default_timer() 118 | loss = loss.item() 119 | #toc=timeit.default_timer() 120 | #print('item: '+str(toc-tic)) 121 | 122 | 123 | log = { 124 | 'loss': loss, 125 | 'metrics': metrics 126 | } 127 | 128 | 129 | return log 130 | 131 | def _minor_log(self, log): 132 | ls='' 133 | for key,val in log.items(): 134 | ls += key 135 | if type(val) is float: 136 | ls +=': {:.6f}, '.format(val) 137 | else: 138 | ls +=': {}, '.format(val) 139 | self.logger.info('Train '+ls) 140 | 141 | def _valid_epoch(self): 142 | """ 143 | Validate after training an epoch 144 | 145 | :return: A log that contains information about validation 146 | 147 | Note: 148 | The validation metrics in log must have the key 'val_metrics'. 149 | """ 150 | self.model.eval() 151 | total_val_loss = 0 152 | total_val_metrics = np.zeros(len(self.metrics)) 153 | with torch.no_grad(): 154 | for batch_idx, (data, target) in enumerate(self.valid_data_loader): 155 | data, target = self._to_tensor(data, target) 156 | 157 | output = self.model(data) 158 | loss = self.loss(output, target) 159 | 160 | total_val_loss += loss.item() 161 | total_val_metrics += self._eval_metrics(output, target) 162 | 163 | return { 164 | 'val_loss': total_val_loss / len(self.valid_data_loader), 165 | 'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist() 166 | } 167 | -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
15 | """ 16 | import torch 17 | import torch.utils.data 18 | import numpy as np 19 | from datasets import forms_box_detect 20 | from datasets.forms_box_detect import FormsBoxDetect 21 | from datasets import forms_graph_pair 22 | from datasets.forms_feature_pair import FormsFeaturePair 23 | from datasets import forms_feature_pair 24 | #from torchvision import datasets, transforms 25 | from base import BaseDataLoader 26 | 27 | 28 | 29 | #class MnistDataLoader(BaseDataLoader): 30 | # """ 31 | # MNIST data loading demo using BaseDataLoader 32 | # """ 33 | # def __init__(self, config): 34 | # super(MnistDataLoader, self).__init__(config) 35 | # self.data_dir = config['data_loader']['data_dir'] 36 | # self.data_loader = torch.utils.data.DataLoader( 37 | # datasets.MNIST('../data', train=True, download=True, 38 | # transform=transforms.Compose([ 39 | # transforms.ToTensor(), 40 | # transforms.Normalize((0.1307,), (0.3081,)) 41 | # ])), batch_size=256, shuffle=False) 42 | # self.x = [] 43 | # self.y = [] 44 | # for data, target in self.data_loader: 45 | # self.x += [i for i in data.numpy()] 46 | # self.y += [i for i in target.numpy()] 47 | # self.x = np.array(self.x) 48 | # self.y = np.array(self.y) 49 | # 50 | # def __next__(self): 51 | # batch = super(MnistDataLoader, self).__next__() 52 | # batch = [np.array(sample) for sample in batch] 53 | # return batch 54 | # 55 | # def _pack_data(self): 56 | # packed = list(zip(self.x, self.y)) 57 | # return packed 58 | # 59 | # def _unpack_data(self, packed): 60 | # unpacked = list(zip(*packed)) 61 | # unpacked = [list(item) for item in unpacked] 62 | # return unpacked 63 | # 64 | # def _update_data(self, unpacked): 65 | # self.x, self.y = unpacked 66 | # 67 | # def _n_samples(self): 68 | # return len(self.x) 69 | 70 | def getDataLoader(config,split): 71 | data_set_name = config['data_loader']['data_set_name'] 72 | data_dir = config['data_loader']['data_dir'] 73 | batch_size = config['data_loader']['batch_size'] 74 | valid_batch_size = config['validation']['batch_size'] if 'batch_size' in config['validation'] else batch_size 75 | 76 | #copy info from main dataloader to validation (but don't overwrite) 77 | #helps insure same data 78 | for k,v in config['data_loader'].items(): 79 | if k not in config['validation']: 80 | config['validation'][k]=v 81 | 82 | if 'augmentation_params' in config['data_loader']: 83 | aug_param = config['data_loader']['augmentation_params'] 84 | else: 85 | aug_param = None 86 | shuffle = config['data_loader']['shuffle'] 87 | if 'num_workers' in config['data_loader']: 88 | numDataWorkers = config['data_loader']['num_workers'] 89 | else: 90 | numDataWorkers = 1 91 | shuffleValid = config['validation']['shuffle'] 92 | 93 | if data_set_name=='FormsBoxDetect': 94 | return withCollate(FormsBoxDetect,forms_box_detect.collate,batch_size,valid_batch_size,shuffle,shuffleValid,numDataWorkers,split,data_dir,config) 95 | elif data_set_name=='FormsGraphPair': 96 | return withCollate(forms_graph_pair.FormsGraphPair,forms_graph_pair.collate,batch_size,valid_batch_size,shuffle,shuffleValid,numDataWorkers,split,data_dir,config) 97 | elif data_set_name=='FormsFeaturePair': 98 | return withCollate(FormsFeaturePair,forms_feature_pair.collate,batch_size,valid_batch_size,shuffle,shuffleValid,numDataWorkers,split,data_dir,config) 99 | else: 100 | print('Error, no dataloader has no set for {}'.format(data_set_name)) 101 | exit() 102 | 103 | 104 | 105 | def basic(setObj,batch_size,valid_batch_size,shuffle,shuffleValid,numDataWorkers,split,data_dir,config): 
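    # basic() mirrors withCollate() below but builds its DataLoaders without a
    # custom collate_fn. Both read the same config keys as getDataLoader above;
    # an illustrative (hypothetical) shape of that config is:
    #   "data_loader": {"data_set_name": "FormsGraphPair", "data_dir": "../data/forms",
    #                   "batch_size": 1, "shuffle": true, "num_workers": 2},
    #   "validation": {"shuffle": false, "batch_size": 1}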
106 | if split=='train': 107 | trainData = setObj(dirPath=data_dir, split='train', config=config['data_loader']) 108 | trainLoader = torch.utils.data.DataLoader(trainData, batch_size=batch_size, shuffle=shuffle, num_workers=numDataWorkers) 109 | validData = setObj(dirPath=data_dir, split='valid', config=config['validation']) 110 | validLoader = torch.utils.data.DataLoader(validData, batch_size=valid_batch_size, shuffle=shuffleValid, num_workers=numDataWorkers) 111 | return trainLoader, validLoader 112 | elif split=='test': 113 | testData = setObj(dirPath=data_dir, split='test', config=config['validation']) 114 | testLoader = torch.utils.data.DataLoader(testData, batch_size=valid_batch_size, shuffle=False, num_workers=numDataWorkers) 115 | elif split=='merge' or split=='merged' or split=='train-valid' or split=='train+valid': 116 | trainData = setObj(dirPath=data_dir, split=['train','valid'], config=config['data_loader']) 117 | trainLoader = torch.utils.data.DataLoader(trainData, batch_size=batch_size, shuffle=shuffle, num_workers=numDataWorkers) 118 | validData = setObj(dirPath=data_dir, split=['train','valid'], config=config['validation']) 119 | validLoader = torch.utils.data.DataLoader(validData, batch_size=valid_batch_size, shuffle=shuffleValid, num_workers=numDataWorkers) 120 | return trainLoader, validLoader 121 | def withCollate(setObj,collateFunc,batch_size,valid_batch_size,shuffle,shuffleValid,numDataWorkers,split,data_dir,config): 122 | if split=='train': 123 | trainData = setObj(dirPath=data_dir, split='train', config=config['data_loader']) 124 | trainLoader = torch.utils.data.DataLoader(trainData, batch_size=batch_size, shuffle=shuffle, num_workers=numDataWorkers, collate_fn=collateFunc) 125 | validData = setObj(dirPath=data_dir, split='valid', config=config['validation']) 126 | validLoader = torch.utils.data.DataLoader(validData, batch_size=valid_batch_size, shuffle=shuffleValid, num_workers=numDataWorkers, collate_fn=collateFunc) 127 | return trainLoader, validLoader 128 | elif split=='test': 129 | testData = setObj(dirPath=data_dir, split='test', config=config['validation']) 130 | testLoader = torch.utils.data.DataLoader(testData, batch_size=valid_batch_size, shuffle=False, num_workers=numDataWorkers, collate_fn=collateFunc) 131 | return testLoader, None 132 | elif split=='merge' or split=='merged' or split=='train-valid' or split=='train+valid': 133 | trainData = setObj(dirPath=data_dir, split=['train','valid'], config=config['data_loader']) 134 | trainLoader = torch.utils.data.DataLoader(trainData, batch_size=batch_size, shuffle=shuffle, num_workers=numDataWorkers, collate_fn=collateFunc) 135 | validData = setObj(dirPath=data_dir, split=['train','valid'], config=config['validation']) 136 | validLoader = torch.utils.data.DataLoader(validData, batch_size=valid_batch_size, shuffle=shuffleValid, num_workers=numDataWorkers, collate_fn=collateFunc) 137 | return trainLoader, validLoader 138 | 139 | 140 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import torch 4 | import cv2 5 | import numpy as np 6 | #from tqdm import tqdm 7 | from skimage import color, io 8 | from model import * 9 | from evaluators.draw_graph import draw_graph 10 | 11 | DETECTOR_TRAINED_MODEL = "saved/detector/checkpoint-iteration150000.pth" 12 | TRAINED_MODEL = "saved/pairing/checkpoint-iteration125000.pth" 13 | SCALE_IMAGE_DEFAULT = 0.52 # 
percent of original size 14 | INCLUDE_THRESHOLD_DEFAULT = 0.9 # threshold for using the bounding box (0 to 1) 15 | PAIR_THRESHOLD_DEFAULT = 0.7 # threshold for using the bounding box (0 to 1) 16 | 17 | 18 | def getCorners(xyrhw): 19 | xc=xyrhw[0] 20 | yc=xyrhw[1] 21 | rot=xyrhw[2] 22 | h=xyrhw[3] 23 | w=xyrhw[4] 24 | h = min(30000,h) 25 | w = min(30000,w) 26 | tr = ( int(w*math.cos(rot)-h*math.sin(rot) + xc), int(w*math.sin(rot)+h*math.cos(rot) + yc) ) 27 | tl = ( int(-w*math.cos(rot)-h*math.sin(rot) + xc), int(-w*math.sin(rot)+h*math.cos(rot) + yc) ) 28 | br = ( int(w*math.cos(rot)+h*math.sin(rot) + xc), int(w*math.sin(rot)-h*math.cos(rot) + yc) ) 29 | bl = ( int(-w*math.cos(rot)+h*math.sin(rot) + xc), int(-w*math.sin(rot)-h*math.cos(rot) + yc) ) 30 | return tl,tr,br,bl 31 | 32 | def plotRect(img,color,xyrhw,lineW=1): 33 | tl,tr,br,bl = getCorners(xyrhw) 34 | 35 | cv2.line(img,tl,tr,color,lineW) 36 | cv2.line(img,tr,br,color,lineW) 37 | cv2.line(img,br,bl,color,lineW) 38 | cv2.line(img,bl,tl,color,lineW) 39 | 40 | def detect_boxes(run_img,np_img, include_threshold=INCLUDE_THRESHOLD_DEFAULT, output_image=None,model_checkpoint=DETECTOR_TRAINED_MODEL,use_gpu=None): 41 | 42 | if gpu is not None: 43 | device="cuda" 44 | else: 45 | device="cpu" 46 | 47 | # device= "cuda" if use_gpu else "cpu" 48 | print(f"Using {device} device") 49 | 50 | # fetch the model 51 | checkpoint = torch.load(model_checkpoint, map_location=lambda storage, location: storage) 52 | print(f"Using {checkpoint['arch']}") 53 | model = eval(checkpoint['arch'])(checkpoint['config']['model']) 54 | model.load_state_dict(checkpoint['state_dict']) 55 | 56 | 57 | # run the image through the model 58 | print(f"Run image through model: {imagePath}") 59 | result = model(run_img) 60 | 61 | # produce the output 62 | boundingboxes = result[0].tolist() 63 | output = [] 64 | 65 | print(f"Process bounding boxes: {imagePath}") 66 | #for i in tqdm(boundingboxes[0]): 67 | for i in boundingboxes[0]: 68 | if i[0] < include_threshold: 69 | continue 70 | print(i) 71 | tl,tr,br,bl = getCorners(i[1:]) 72 | scale=1 73 | bb = { 74 | 'poly_points': [ [float(tl[0]/scale),float(tl[1]/scale)], 75 | [float(tr[0]/scale),float(tr[1]/scale)], 76 | [float(br[0]/scale),float(br[1]/scale)], 77 | [float(bl[0]/scale),float(bl[1]/scale)] ], 78 | 'type':'detectorPrediction', 79 | 'textPred': float(i[7]), 80 | 'fieldPred': float(i[8]) 81 | } 82 | colour = (255,0,0) # red 83 | if bb['textPred'] > bb['fieldPred']: 84 | colour = (0,0,255) # blue 85 | output.append(bb) 86 | if output_image: 87 | plotRect(np_img, colour, i[1:6]) 88 | 89 | if output_image: 90 | print(f"Saving output: {output_image}") 91 | io.imsave(output_image, np_img) 92 | return output 93 | 94 | 95 | def detect_boxes_and_pairs(run_img,np_img, output_image=None,model_checkpoint=TRAINED_MODEL,pair_threshold=PAIR_THRESHOLD_DEFAULT,use_gpu=None): 96 | 97 | if gpu is not None: 98 | device="cuda" 99 | else: 100 | device="cpu" 101 | 102 | # device= "cuda" if use_gpu else "cpu" 103 | print(f"Using {device} device") 104 | 105 | # fetch the model 106 | checkpoint = torch.load(model_checkpoint, map_location=lambda storage, location: storage) 107 | print(f"Using {checkpoint['arch']}") 108 | model = eval(checkpoint['arch'])(checkpoint['config']['model']) 109 | model.load_state_dict(checkpoint['state_dict']) 110 | model.to(device) 111 | # run the image through the model 112 | print(f"Run image through model: {imagePath}") 113 | run_img=run_img.to(device) 114 | result = model(run_img) 115 | outputBoxes, outputOffsets, 
relPred, relIndexes, bbPred = result 116 | relPred = torch.sigmoid(relPred) 117 | np_img = draw_graph(outputBoxes,relPred,relIndexes,np_img,pair_threshold) 118 | 119 | if output_image: 120 | print(f"Saving output: {output_image}") 121 | io.imsave(output_image, np_img) 122 | return result 123 | 124 | def main(imagePath,scale_image,detection,checkpoint,detect_threshold,output_image,pair_threshold,gpu=None): 125 | 126 | print(f"Loading image: {imagePath}") 127 | np_img = cv2.imread(imagePath, cv2.IMREAD_COLOR) 128 | 129 | print(f"Transforming image: {imagePath}") 130 | width = int(np_img.shape[1] * scale_image) 131 | height = int(np_img.shape[0] * scale_image) 132 | new_size = (width, height) 133 | np_img = cv2.resize(np_img,new_size) 134 | img = cv2.cvtColor(np_img, cv2.COLOR_BGR2GRAY) 135 | img = img[None,None,:,:] 136 | img = img.astype(np.float32) 137 | img = torch.from_numpy(img) 138 | img = 1.0 - img / 128.0 139 | 140 | if detection: 141 | if checkpoint is None: 142 | checkpoint = DETECTOR_TRAINED_MODEL 143 | result = detect_boxes( 144 | img, 145 | np_img, 146 | include_threshold=args.detect_threshold, 147 | output_image=output_image, 148 | model_checkpoint = checkpoint, 149 | use_gpu=gpu 150 | ) 151 | else: 152 | if checkpoint is None: 153 | checkpoint = TRAINED_MODEL 154 | np_img=np_img.astype(np.float32)/255 155 | result = detect_boxes_and_pairs( 156 | img, 157 | np_img, 158 | output_image=output_image, 159 | pair_threshold=args.pair_threshold, 160 | model_checkpoint = checkpoint, 161 | use_gpu=gpu 162 | ) 163 | 164 | 165 | if __name__ == "__main__": 166 | parser = argparse.ArgumentParser(description='Run on a single image') 167 | parser.add_argument('image', type=str, help='Path to the image to convert') 168 | parser.add_argument('output_image', type=str, help="A path to save a version of the original image with form boxes overlaid") 169 | parser.add_argument('--scale-image', type=float, default=SCALE_IMAGE_DEFAULT, 170 | help='Scale the image by this proportion (between 0 and 1). 0.52 for pretrained model on NAF images') 171 | parser.add_argument('--detect-threshold', type=float, default=INCLUDE_THRESHOLD_DEFAULT, 172 | help='Include boxes where the confidence is above this threshold (between 0 and 1)') 173 | parser.add_argument('--pair-threshold', type=float, default=INCLUDE_THRESHOLD_DEFAULT, 174 | help='Include relationships where the confidence is above this threshold (between 0 and 1) default: 0.7') 175 | parser.add_argument('-c', '--checkpoint', default=None, type=str, 176 | help='path to checkpoint (default: pretrained model)') 177 | parser.add_argument('-d', '--detection', default=False, action='store_const', const=True, 178 | help='Run detection model. 
Default is full (pairing) model') 179 | parser.add_argument('-g', '--gpu', default=None, type=int, 180 | help='gpu number (default: cpu only)') 181 | args = parser.parse_args() 182 | 183 | imagePath = args.image 184 | output_image = args.output_image 185 | scale_image = args.scale_image 186 | checkpoint = args.checkpoint 187 | detection=args.detection 188 | detect_threshold=args.detect_threshold 189 | pair_threshold=args.pair_threshold 190 | gpu=args.gpu 191 | 192 | if gpu is not None: 193 | with torch.cuda.device(gpu): 194 | main(imagePath,scale_image,detection,checkpoint,detect_threshold,output_image,pair_threshold,gpu) 195 | else: 196 | main(imagePath,scale_image,detection,checkpoint,detect_threshold,output_image,pair_threshold,gpu) 197 | -------------------------------------------------------------------------------- /utils/transformation_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 3 | it under the terms of the GNU General Public License as published by 4 | the Free Software Foundation, either version 3 of the License, or 5 | (at your option) any later version. 6 | 7 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 8 | but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | GNU General Public License for more details. 11 | 12 | You should have received a copy of the GNU General Public License 13 | along with Visual-Template-free-Form-Parsting. If not, see . 14 | """ 15 | import torch 16 | import torch.nn as nn 17 | #from torch.autograd import Variable 18 | from torch.nn.modules.module import Module 19 | 20 | from utils.fast_inverse import inverse_torch 21 | import numpy as np 22 | 23 | def compute_renorm_matrix(img): 24 | inv_c = np.array([ 25 | [1.0/img.size(2), 0, 0], 26 | [0, 1.0/img.size(3), 0], 27 | [0,0,1] 28 | ], dtype=np.float32) 29 | 30 | inv_b = np.array([ 31 | [2,0,-1], 32 | [0,2,-1], 33 | [0,0, 1] 34 | ], dtype=np.float32) 35 | 36 | inv_c = torch.from_numpy(inv_c).type(img.data.type()) 37 | inv_b = torch.from_numpy(inv_b).type(img.data.type()) 38 | 39 | return inv_b.mm(inv_c) 40 | 41 | def compute_next_state(delta, state): 42 | out = torch.zeros(*state.data.shape).type(state.data.type()) 43 | for i in xrange(0,3): 44 | out[:,i+2] = delta[:,i] + state[:,i+2] 45 | #r*cos(theta) + x = x' 46 | out[:,0] = out[:,3] * torch.cos(out[:,2]) + state[:,0] 47 | #r*sin(theta) + y = y' 48 | out[:,1] = out[:,3] * torch.sin(out[:,2]) + state[:,1] 49 | return out 50 | 51 | def compute_points(state): 52 | out = torch.zeros(state.data.shape[0],2,2).type(state.data.type()) 53 | out[:,0,0] = state[:,4] * torch.sin(state[:,2]) 54 | out[:,0,1] = state[:,4] * torch.cos(state[:,2]) 55 | 56 | out[:,1] = -out[:,0] 57 | 58 | out[:,:,0] = out[:,:,0] + state[:,0] 59 | out[:,:,1] = out[:,:,1] + state[:,1] 60 | 61 | return out 62 | 63 | import time 64 | def compute_basis(pts): 65 | #start = time.time() 66 | A = pts[:,:3,:3] 67 | b = pts[:,:3,3:4] 68 | #A_inv = A.clone() 69 | #for i in xrange(A.data.shape[0]): 70 | # A_inv[i,:,:] = torch.inverse(A[i,:,:]) 71 | 72 | #A_inv = [t.inverse() for t in torch.functional.unbind(A)] 73 | #A_inv = torch.functional.stack(A_inv) 74 | A_inv = inverse_torch(A) 75 | 76 | 77 | #print "s", time.time() - start 78 | x = A_inv.bmm(b) 79 | 80 | B = A.clone() 81 | for i in xrange(3): 82 | B[:,:,i] = A[:,:,i] * x[:,i] 
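    # This is the classical four-point projective-basis construction: x solves
    # A x = b, expressing the fourth point in terms of the first three, and
    # scaling each column of A by its coefficient gives the basis matrix B.
    # compute_perspective() below combines B for the input points with the
    # (pre-inverted) basis of the canonical +/-1 target square to obtain the
    # projective transform mapping the target quadrilateral onto pts.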
83 | return B 84 | 85 | DEFAULT_TARGET = np.array([[ 86 | [-1.0,-1, 1, 1], 87 | [ 1.0,-1, 1,-1], 88 | [ 1.0, 1, 1, 1] 89 | ]]) 90 | BASIS = None 91 | def compute_perspective(pts, target=None): 92 | global BASIS 93 | if target is None: 94 | target = torch.from_numpy(DEFAULT_TARGET).type(pts.data.type()) 95 | if BASIS is None: 96 | B = compute_basis(target) 97 | BASIS = inverse_torch(B) 98 | 99 | basis = BASIS.expand(pts.size(0), BASIS.size(1), BASIS.size(2)) 100 | 101 | A = compute_basis(pts) 102 | return A.bmm(basis) 103 | 104 | def pt_ori_sca_2_pts(state): 105 | # Input: b x [x, y, theta, scale] 106 | out = torch.ones(state.data.shape[0], 3, 2).type(state.data.type()) 107 | out[:,0,0] = torch.sin(state[:,2]) * state[:,3] + state[:,0] 108 | out[:,1,0] = torch.cos(state[:,2]) * state[:,3] + state[:,1] 109 | out[:,0,1] = -torch.sin(state[:,2]) * state[:,3] + state[:,0] 110 | out[:,1,1] = -torch.cos(state[:,2]) * state[:,3] + state[:,1] 111 | 112 | return out 113 | 114 | def get_init_matrix(input): 115 | output = torch.zeros((input.size(0), 3, 3)).type(input.data.type()) 116 | output[:,0,0] = 1 117 | output[:,1,1] = 1 118 | output[:,2,2] = 1 119 | 120 | x = input[:,0:1] 121 | y = input[:,1:2] 122 | angles = input[:,2:3] 123 | scaler = input[:,3:4] 124 | 125 | cosines = torch.cos(angles) 126 | sinuses = torch.sin(angles) 127 | output[:,0,0] = cosines * scaler 128 | output[:,1,1] = cosines * scaler 129 | output[:,1,0] = -sinuses * scaler 130 | output[:,0,1] = sinuses * scaler 131 | 132 | output[:,0,2] = x 133 | output[:,1,2] = y 134 | 135 | return output 136 | 137 | #the input is a delta, I allow either x,y,theta or just theta 138 | def get_step_matrix(input,no_xy,scale_index): 139 | output = torch.zeros((input.size(0), 3, 3)).type(input.data.type()) 140 | output[:,0,0] = 1 141 | output[:,1,1] = 1 142 | output[:,2,2] = 1 143 | 144 | if scale_index is None: 145 | scale=torch.ones_like(input[:,0]) 146 | else: 147 | scale=input[:,scale_index] 148 | if no_xy: 149 | x = y = 0 150 | angles = input[:,0:1] 151 | #if use_scale: 152 | # scale = input[:,1:2] 153 | else: 154 | x = input[:,0:1] 155 | y = input[:,1:2] 156 | angles = input[:,2:3] 157 | #if use_scale: 158 | # scale = input[:,3:4] 159 | 160 | cosines = torch.cos(angles) 161 | sinuses = torch.sin(angles) 162 | output[:,0,0] = cosines*scale 163 | output[:,1,1] = cosines*scale 164 | output[:,1,0] = -sinuses*scale 165 | output[:,0,1] = sinuses*scale 166 | 167 | output[:,0,2] = x 168 | output[:,1,2] = y 169 | 170 | return output 171 | 172 | class ScaleRotateMatrixGenerator(Module): 173 | def __init__(self): 174 | super(ScaleRotateMatrixGenerator, self).__init__() 175 | 176 | def forward(self, input): 177 | output = torch.zeros((input.size(0), 3, 2)).type(input.data.type()) 178 | output[:,0,0] = 1 179 | output[:,1,1] = 1 180 | 181 | angles = input[:,0] 182 | scaler = input[:,1] 183 | 184 | cosines = torch.cos(angles) 185 | sinuses = torch.sin(angles) 186 | output[:,0,0] = cosines * scaler 187 | output[:,1,1] = cosines * scaler 188 | output[:,1,0] = -sinuses * scaler 189 | output[:,0,1] = sinuses * scaler 190 | 191 | return output 192 | 193 | #from https://gist.github.com/ncullen93/425ca642955f73452ebc097b3b46c493 194 | def transform_matrix_offset_center(matrix, x, y): 195 | """Apply offset to a transform matrix so that the image is 196 | transformed about the center of the image. 197 | NOTE: This is a fairly simple operaion, so can easily be 198 | moved to full torch. 
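    In effect this returns offset_matrix @ matrix @ reset_matrix, so that e.g. a
    rotation matrix ends up rotating about the image center rather than about the origin.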
199 | Arguments 200 | --------- 201 | matrix : 3x3 matrix/array 202 | x : integer 203 | height dimension of image to be transformed 204 | y : integer 205 | width dimension of image to be transformed 206 | """ 207 | o_x = float(x) / 2 + 0.5 208 | o_y = float(y) / 2 + 0.5 209 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 210 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 211 | transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) 212 | return transform_matrix 213 | 214 | def apply_transform(x, transform, fill_mode='nearest', fill_value=0., out_shape=None): 215 | """Applies an affine transform to a 2D array, or to each channel of a 3D array. 216 | NOTE: this can and certainly should be moved to full torch operations. 217 | Arguments 218 | --------- 219 | x : np.ndarray 220 | array to transform. NOTE: array should be ordered CHW 221 | 222 | transform : 3x3 affine transform matrix 223 | matrix to apply 224 | """ 225 | x = x.astype('float32') 226 | transform = transform_matrix_offset_center(transform, x.shape[1], x.shape[2]) 227 | final_affine_matrix = transform[:2, :2] 228 | final_offset = transform[:2, 2] 229 | channel_images = [ndi.interpolation.affine_transform(x_channel, final_affine_matrix, 230 | final_offset,output_shape=out_shape, order=0, mode=fill_mode, cval=fill_value) for x_channel in x] 231 | x = np.stack(channel_images, axis=0) 232 | return x 233 | 234 | def rotate(input, rotation, crop_to): 235 | #if rotation==0: 236 | # return input 237 | theta = rotation #math.pi / 180 * degree 238 | rotation_matrix = np.array([[math.cos(theta), -math.sin(theta), 0], 239 | [math.sin(theta), math.cos(theta), 0], 240 | [0, 0, 1]]) 241 | x_transformed = torch.from_numpy(apply_transform(input.numpy(), rotation_matrix, 242 | fill_mode=self.fill_mode, fill_value=self.fill_value, out_shape=crop_to)) 243 | return x_transformed 244 | -------------------------------------------------------------------------------- /model/csrc/cpu/ROIAlign_cpu.cpp: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
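// CPU implementation of the ROIAlign forward pass (adapted from Caffe2, per the
// comment below). For each ROI, the bilinear-interpolation source positions and
// weights of every sampling point are precomputed once by
// pre_calc_for_bilinear_interpolate and then reused across all channels inside
// ROIAlignForward_cpu_kernel, which is the main optimization here.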
3 | #include "cpu/vision.h" 4 | 5 | // implementation taken from Caffe2 6 | template 7 | struct PreCalc { 8 | int pos1; 9 | int pos2; 10 | int pos3; 11 | int pos4; 12 | T w1; 13 | T w2; 14 | T w3; 15 | T w4; 16 | }; 17 | 18 | template 19 | void pre_calc_for_bilinear_interpolate( 20 | const int height, 21 | const int width, 22 | const int pooled_height, 23 | const int pooled_width, 24 | const int iy_upper, 25 | const int ix_upper, 26 | T roi_start_h, 27 | T roi_start_w, 28 | T bin_size_h, 29 | T bin_size_w, 30 | int roi_bin_grid_h, 31 | int roi_bin_grid_w, 32 | std::vector>& pre_calc) { 33 | int pre_calc_index = 0; 34 | for (int ph = 0; ph < pooled_height; ph++) { 35 | for (int pw = 0; pw < pooled_width; pw++) { 36 | for (int iy = 0; iy < iy_upper; iy++) { 37 | const T yy = roi_start_h + ph * bin_size_h + 38 | static_cast(iy + .5f) * bin_size_h / 39 | static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 40 | for (int ix = 0; ix < ix_upper; ix++) { 41 | const T xx = roi_start_w + pw * bin_size_w + 42 | static_cast(ix + .5f) * bin_size_w / 43 | static_cast(roi_bin_grid_w); 44 | 45 | T x = xx; 46 | T y = yy; 47 | // deal with: inverse elements are out of feature map boundary 48 | if (y < -1.0 || y > height || x < -1.0 || x > width) { 49 | // empty 50 | PreCalc pc; 51 | pc.pos1 = 0; 52 | pc.pos2 = 0; 53 | pc.pos3 = 0; 54 | pc.pos4 = 0; 55 | pc.w1 = 0; 56 | pc.w2 = 0; 57 | pc.w3 = 0; 58 | pc.w4 = 0; 59 | pre_calc[pre_calc_index] = pc; 60 | pre_calc_index += 1; 61 | continue; 62 | } 63 | 64 | if (y <= 0) { 65 | y = 0; 66 | } 67 | if (x <= 0) { 68 | x = 0; 69 | } 70 | 71 | int y_low = (int)y; 72 | int x_low = (int)x; 73 | int y_high; 74 | int x_high; 75 | 76 | if (y_low >= height - 1) { 77 | y_high = y_low = height - 1; 78 | y = (T)y_low; 79 | } else { 80 | y_high = y_low + 1; 81 | } 82 | 83 | if (x_low >= width - 1) { 84 | x_high = x_low = width - 1; 85 | x = (T)x_low; 86 | } else { 87 | x_high = x_low + 1; 88 | } 89 | 90 | T ly = y - y_low; 91 | T lx = x - x_low; 92 | T hy = 1. - ly, hx = 1. 
- lx; 93 | T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 94 | 95 | // save weights and indeces 96 | PreCalc pc; 97 | pc.pos1 = y_low * width + x_low; 98 | pc.pos2 = y_low * width + x_high; 99 | pc.pos3 = y_high * width + x_low; 100 | pc.pos4 = y_high * width + x_high; 101 | pc.w1 = w1; 102 | pc.w2 = w2; 103 | pc.w3 = w3; 104 | pc.w4 = w4; 105 | pre_calc[pre_calc_index] = pc; 106 | 107 | pre_calc_index += 1; 108 | } 109 | } 110 | } 111 | } 112 | } 113 | 114 | template 115 | void ROIAlignForward_cpu_kernel( 116 | const int nthreads, 117 | const T* bottom_data, 118 | const T& spatial_scale, 119 | const int channels, 120 | const int height, 121 | const int width, 122 | const int pooled_height, 123 | const int pooled_width, 124 | const int sampling_ratio, 125 | const T* bottom_rois, 126 | //int roi_cols, 127 | T* top_data) { 128 | //AT_ASSERT(roi_cols == 4 || roi_cols == 5); 129 | int roi_cols = 5; 130 | 131 | int n_rois = nthreads / channels / pooled_width / pooled_height; 132 | // (n, c, ph, pw) is an element in the pooled output 133 | // can be parallelized using omp 134 | // #pragma omp parallel for num_threads(32) 135 | for (int n = 0; n < n_rois; n++) { 136 | int index_n = n * channels * pooled_width * pooled_height; 137 | 138 | // roi could have 4 or 5 columns 139 | const T* offset_bottom_rois = bottom_rois + n * roi_cols; 140 | int roi_batch_ind = 0; 141 | if (roi_cols == 5) { 142 | roi_batch_ind = offset_bottom_rois[0]; 143 | offset_bottom_rois++; 144 | } 145 | 146 | // Do not using rounding; this implementation detail is critical 147 | T roi_start_w = offset_bottom_rois[0] * spatial_scale; 148 | T roi_start_h = offset_bottom_rois[1] * spatial_scale; 149 | T roi_end_w = offset_bottom_rois[2] * spatial_scale; 150 | T roi_end_h = offset_bottom_rois[3] * spatial_scale; 151 | // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale); 152 | // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale); 153 | // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale); 154 | // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale); 155 | 156 | // Force malformed ROIs to be 1x1 157 | T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); 158 | T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); 159 | T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); 160 | T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); 161 | 162 | // We use roi_bin_grid to sample the grid and mimic integral 163 | int roi_bin_grid_h = (sampling_ratio > 0) 164 | ? sampling_ratio 165 | : ceil(roi_height / pooled_height); // e.g., = 2 166 | int roi_bin_grid_w = 167 | (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); 168 | 169 | // We do average (integral) pooling inside a bin 170 | const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 171 | 172 | // we want to precalculate indeces and weights shared by all chanels, 173 | // this is the key point of optimiation 174 | std::vector> pre_calc( 175 | roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); 176 | pre_calc_for_bilinear_interpolate( 177 | height, 178 | width, 179 | pooled_height, 180 | pooled_width, 181 | roi_bin_grid_h, 182 | roi_bin_grid_w, 183 | roi_start_h, 184 | roi_start_w, 185 | bin_size_h, 186 | bin_size_w, 187 | roi_bin_grid_h, 188 | roi_bin_grid_w, 189 | pre_calc); 190 | 191 | for (int c = 0; c < channels; c++) { 192 | int index_n_c = index_n + c * pooled_width * pooled_height; 193 | const T* offset_bottom_data = 194 | bottom_data + (roi_batch_ind * channels + c) * height * width; 195 | int pre_calc_index = 0; 196 | 197 | for (int ph = 0; ph < pooled_height; ph++) { 198 | for (int pw = 0; pw < pooled_width; pw++) { 199 | int index = index_n_c + ph * pooled_width + pw; 200 | 201 | T output_val = 0.; 202 | for (int iy = 0; iy < roi_bin_grid_h; iy++) { 203 | for (int ix = 0; ix < roi_bin_grid_w; ix++) { 204 | PreCalc pc = pre_calc[pre_calc_index]; 205 | output_val += pc.w1 * offset_bottom_data[pc.pos1] + 206 | pc.w2 * offset_bottom_data[pc.pos2] + 207 | pc.w3 * offset_bottom_data[pc.pos3] + 208 | pc.w4 * offset_bottom_data[pc.pos4]; 209 | 210 | pre_calc_index += 1; 211 | } 212 | } 213 | output_val /= count; 214 | 215 | top_data[index] = output_val; 216 | } // for pw 217 | } // for ph 218 | } // for c 219 | } // for n 220 | } 221 | 222 | at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, 223 | const at::Tensor& rois, 224 | const float spatial_scale, 225 | const int pooled_height, 226 | const int pooled_width, 227 | const int sampling_ratio) { 228 | AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor"); 229 | AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor"); 230 | 231 | auto num_rois = rois.size(0); 232 | auto channels = input.size(1); 233 | auto height = input.size(2); 234 | auto width = input.size(3); 235 | 236 | auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); 237 | auto output_size = num_rois * pooled_height * pooled_width * channels; 238 | 239 | if (output.numel() == 0) { 240 | return output; 241 | } 242 | 243 | AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { 244 | ROIAlignForward_cpu_kernel( 245 | output_size, 246 | input.data(), 247 | spatial_scale, 248 | channels, 249 | height, 250 | width, 251 | pooled_height, 252 | pooled_width, 253 | sampling_ratio, 254 | rois.data(), 255 | output.data()); 256 | }); 257 | return output; 258 | } 259 | -------------------------------------------------------------------------------- /datasets/graph_pair.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 
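GraphPairDataset (below) loads a form image and its JSON annotations, rescales
(and optionally crops) it, applies brightness/color augmentation, and returns the
image tensor together with the ground-truth boxes, the set of paired box indices
('adj'), and per-box neighbor counts.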
15 | """ 16 | import torch.utils.data 17 | import numpy as np 18 | import json 19 | #from skimage import io 20 | #from skimage import draw 21 | #import skimage.transform as sktransform 22 | import os 23 | import math 24 | from utils.crop_transform import CropBoxTransform 25 | from utils import augmentation 26 | from collections import defaultdict, OrderedDict 27 | from utils.forms_annotations import fixAnnotations, convertBBs, getBBWithPoints, getStartEndGT 28 | import timeit 29 | 30 | import cv2 31 | 32 | 33 | def collate(batch): 34 | assert(len(batch)==1) 35 | return batch[0] 36 | 37 | 38 | class GraphPairDataset(torch.utils.data.Dataset): 39 | """ 40 | Class for reading dataset and creating starting and ending gt 41 | """ 42 | 43 | 44 | def __init__(self, dirPath=None, split=None, config=None, images=None): 45 | #if 'augmentation_params' in config['data_loader']: 46 | # self.augmentation_params=config['augmentation_params'] 47 | #else: 48 | # self.augmentation_params=None 49 | self.color = config['color'] if 'color' in config else True 50 | self.rotate = config['rotation'] if 'rotation' in config else False 51 | #patchSize=config['patch_size'] 52 | if 'crop_params' in config and config['crop_params'] is not None: 53 | self.transform = CropBoxTransform(config['crop_params'],self.rotate) 54 | else: 55 | self.transform = None 56 | self.rescale_range = config['rescale_range'] 57 | if type(self.rescale_range) is float: 58 | self.rescale_range = [self.rescale_range,self.rescale_range] 59 | if self.rescale_range[0]==450: 60 | self.rescale_range[0]=0.2 61 | elif self.rescale_range[0]>1.0: 62 | self.rescale_range[0]=0.27 63 | if self.rescale_range[1]==800: 64 | self.rescale_range[1]=0.33 65 | elif self.rescale_range[1]>1.0: 66 | self.rescale_range[1]=0.27 67 | if 'cache_resized_images' in config: 68 | self.cache_resized = config['cache_resized_images'] 69 | if self.cache_resized: 70 | self.cache_path = os.path.join(dirPath,'cache_'+str(self.rescale_range[1])) 71 | if not os.path.exists(self.cache_path): 72 | os.mkdir(self.cache_path) 73 | else: 74 | self.cache_resized = False 75 | self.pixel_count_thresh = config['pixel_count_thresh'] if 'pixel_count_thresh' in config else 10000000 76 | self.max_dim_thresh = config['max_dim_thresh'] if 'max_dim_thresh' in config else 2700 77 | 78 | 79 | 80 | 81 | 82 | 83 | def __len__(self): 84 | return len(self.images) 85 | 86 | def __getitem__(self,index): 87 | return self.getitem(index) 88 | def getitem(self,index,scaleP=None,cropPoint=None): 89 | ##ticFull=timeit.default_timer() 90 | imagePath = self.images[index]['imagePath'] 91 | imageName = self.images[index]['imageName'] 92 | annotationPath = self.images[index]['annotationPath'] 93 | #print(annotationPath) 94 | rescaled = self.images[index]['rescaled'] 95 | with open(annotationPath) as annFile: 96 | annotations = json.loads(annFile.read()) 97 | 98 | ##tic=timeit.default_timer() 99 | np_img = cv2.imread(imagePath, 1 if self.color else 0)#/255.0 100 | if np_img is None or np_img.shape[0]==0: 101 | print("ERROR, could not open "+imagePath) 102 | return self.__getitem__((index+1)%self.__len__()) 103 | if scaleP is None: 104 | s = np.random.uniform(self.rescale_range[0], self.rescale_range[1]) 105 | else: 106 | s = scaleP 107 | partial_rescale = s/rescaled 108 | if self.transform is None: #we're doing the whole image 109 | #this is a check to be sure we don't send too big images through 110 | pixel_count = partial_rescale*partial_rescale*np_img.shape[0]*np_img.shape[1] 111 | if pixel_count > 
self.pixel_count_thresh: 112 | partial_rescale = math.sqrt(partial_rescale*partial_rescale*self.pixel_count_thresh/pixel_count) 113 | print('{} exceed thresh: {}: {}, new {}: {}'.format(imageName,s,pixel_count,rescaled*partial_rescale,partial_rescale*partial_rescale*np_img.shape[0]*np_img.shape[1])) 114 | s = rescaled*partial_rescale 115 | 116 | 117 | max_dim = partial_rescale*max(np_img.shape[0],np_img.shape[1]) 118 | if max_dim > self.max_dim_thresh: 119 | partial_rescale = partial_rescale*(self.max_dim_thresh/max_dim) 120 | print('{} exceed thresh: {}: {}, new {}: {}'.format(imageName,s,max_dim,rescaled*partial_rescale,partial_rescale*max(np_img.shape[0],np_img.shape[1]))) 121 | s = rescaled*partial_rescale 122 | 123 | 124 | 125 | ##tic=timeit.default_timer() 126 | #np_img = cv2.resize(np_img,(target_dim1, target_dim0), interpolation = cv2.INTER_CUBIC) 127 | np_img = cv2.resize(np_img,(0,0), 128 | fx=partial_rescale, 129 | fy=partial_rescale, 130 | interpolation = cv2.INTER_CUBIC) 131 | if not self.color: 132 | np_img=np_img[...,None] #add 'color' channel 133 | ##print('resize: {} [{}, {}]'.format(timeit.default_timer()-tic,np_img.shape[0],np_img.shape[1])) 134 | 135 | ##tic=timeit.default_timer() 136 | 137 | bbs,ids,numClasses,trans = self.parseAnn(annotations,s) 138 | 139 | #start_of_line, end_of_line = getStartEndGT(annotations['byId'].values(),s) 140 | #Try: 141 | # table_points, table_pixels = self.getTables( 142 | # fieldBBs, 143 | # s, 144 | # np_img.shape[0], 145 | # np_img.shape[1], 146 | # annotations['samePairs']) 147 | #Except Exception as inst: 148 | # if imageName not in self.errors: 149 | # table_points=None 150 | # table_pixels=None 151 | # print(inst) 152 | # print('Table error on: '+imagePath) 153 | # self.errors.append(imageName) 154 | 155 | 156 | #pixel_gt = table_pixels 157 | 158 | ##ticTr=timeit.default_timer() 159 | if self.transform is not None: 160 | out, cropPoint = self.transform({ 161 | "img": np_img, 162 | "bb_gt": bbs, 163 | 'bb_auxs':ids, 164 | #"line_gt": { 165 | # "start_of_line": start_of_line, 166 | # "end_of_line": end_of_line 167 | # }, 168 | #"point_gt": { 169 | # "table_points": table_points 170 | # }, 171 | #"pixel_gt": pixel_gt, 172 | 173 | }, cropPoint) 174 | np_img = out['img'] 175 | bbs = out['bb_gt'] 176 | ids= out['bb_auxs'] 177 | 178 | 179 | ##tic=timeit.default_timer() 180 | if np_img.shape[2]==3: 181 | np_img = augmentation.apply_random_color_rotation(np_img) 182 | np_img = augmentation.apply_tensmeyer_brightness(np_img) 183 | else: 184 | np_img = augmentation.apply_tensmeyer_brightness(np_img) 185 | ##print('augmentation: {}'.format(timeit.default_timer()-tic)) 186 | ##print('transfrm: {} [{}, {}]'.format(timeit.default_timer()-ticTr,org_img.shape[0],org_img.shape[1])) 187 | pairs=set() 188 | #import pdb;pdb.set_trace() 189 | numNeighbors=[0]*len(ids) 190 | for index1,id in enumerate(ids): #updated 191 | responseBBIdList = self.getResponseBBIdList(id,annotations) 192 | for bbId in responseBBIdList: 193 | try: 194 | index2 = ids.index(bbId) 195 | #adjMatrix[min(index1,index2),max(index1,index2)]=1 196 | pairs.add((min(index1,index2),max(index1,index2))) 197 | numNeighbors[index1]+=1 198 | except ValueError: 199 | pass 200 | #ones = torch.ones(len(pairs)) 201 | #if len(pairs)>0: 202 | # pairs = torch.LongTensor(list(pairs)).t() 203 | #else: 204 | # pairs = torch.LongTensor(pairs) 205 | #adjMatrix = torch.sparse.FloatTensor(pairs,ones,(len(ids),len(ids))) # This is an upper diagonal matrix as pairings are bi-directional 206 | 207 | #if 
len(np_img.shape)==2: 208 | # img=np_img[None,None,:,:] #add "color" channel and batch 209 | #else: 210 | img = np_img.transpose([2,0,1])[None,...] #from [row,col,color] to [batch,color,row,col] 211 | img = img.astype(np.float32) 212 | img = torch.from_numpy(img) 213 | img = 1.0 - img / 128.0 #ideally the median value would be 0 214 | #if pixel_gt is not None: 215 | # pixel_gt = pixel_gt.transpose([2,0,1])[None,...] 216 | # pixel_gt = torch.from_numpy(pixel_gt) 217 | 218 | #start_of_line = None if start_of_line is None or start_of_line.shape[1] == 0 else torch.from_numpy(start_of_line) 219 | #end_of_line = None if end_of_line is None or end_of_line.shape[1] == 0 else torch.from_numpy(end_of_line) 220 | 221 | bbs = convertBBs(bbs,self.rotate,numClasses) 222 | if len(numNeighbors)>0: 223 | numNeighbors = torch.tensor(numNeighbors)[None,:] #add batch dim 224 | else: 225 | numNeighbors=None 226 | #if table_points is not None: 227 | # table_points = None if table_points.shape[1] == 0 else torch.from_numpy(table_points) 228 | 229 | return { 230 | "img": img, 231 | "bb_gt": bbs, 232 | "num_neighbors": numNeighbors, 233 | "adj": pairs,#adjMatrix, 234 | "imgName": imageName, 235 | "scale": s, 236 | "cropPoint": cropPoint, 237 | "transcription": [trans[id] for id in ids if id in trans] 238 | } 239 | 240 | 241 | -------------------------------------------------------------------------------- /trainer/feature_pair_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see . 15 | """ 16 | import numpy as np 17 | import torch 18 | from torch.nn import functional as F 19 | #from base import BaseTrainer 20 | from .trainer import Trainer 21 | import timeit 22 | from utils import util 23 | from collections import defaultdict 24 | from evaluators import FormsBoxDetect_printer 25 | from utils.yolo_tools import non_max_sup_iou, AP_iou, computeAP 26 | 27 | 28 | class FeaturePairTrainer(Trainer): 29 | """ 30 | Trainer class 31 | 32 | Note: 33 | Inherited from BaseTrainer. 34 | self.optimizer is by default handled by BaseTrainer based on config. 
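        This trainer works on precomputed feature pairs: column 0 of the model
        output is the relationship score, and when the output has three columns
        the remaining two predict the paired entities' numbers of neighbors and
        are trained with an extra MSE term on top of the relationship loss.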
35 | """ 36 | def __init__(self, model, loss, metrics, resume, config, 37 | data_loader, valid_data_loader=None, train_logger=None): 38 | super(FeaturePairTrainer, self).__init__(model, loss, metrics, resume, config, 39 | data_loader, valid_data_loader, train_logger) 40 | #self.config = config 41 | #self.batch_size = data_loader.batch_size 42 | #self.data_loader = data_loader 43 | #self.data_loader_iter = iter(data_loader) 44 | #self.valid_data_loader = valid_data_loader 45 | #self.valid = True if self.valid_data_loader is not None else False 46 | 47 | def _eval_metrics(self, typ,name,output, target): 48 | if len(self.metrics[typ])>0: 49 | #acc_metrics = np.zeros(len(self.metrics[typ])) 50 | met={} 51 | cpu_output=[] 52 | for pred in output: 53 | cpu_output.append(output.cpu().data.numpy()) 54 | target = target.cpu().data.numpy() 55 | for i, metric in enumerate(self.metrics[typ]): 56 | met[name+metric.__name__] = metric(cpu_output, target) 57 | return acc_metrics 58 | else: 59 | #return np.zeros(0) 60 | return {} 61 | 62 | def _train_iteration(self, iteration): 63 | """ 64 | Training logic for an iteration 65 | 66 | :param iteration: Current training iteration. 67 | :return: A log that contains all information you want to save. 68 | 69 | Note: 70 | If you have additional information to record, for example: 71 | > additional_log = {"x": x, "y": y} 72 | merge it with log before return. i.e. 73 | > log = {**log, **additional_log} 74 | > return log 75 | 76 | The metrics in log must have the key 'metrics'. 77 | """ 78 | self.model.train() 79 | #self.lr_schedule.step() 80 | 81 | ##tic=timeit.default_timer() 82 | batch_idx = (iteration-1) % len(self.data_loader) 83 | try: 84 | thisInstance = self.data_loader_iter.next() 85 | except StopIteration: 86 | self.data_loader_iter = iter(self.data_loader) 87 | thisInstance = self.data_loader_iter.next() 88 | ##toc=timeit.default_timer() 89 | ##print('data: '+str(toc-tic)) 90 | 91 | ##tic=timeit.default_timer() 92 | 93 | self.optimizer.zero_grad() 94 | 95 | ##toc=timeit.default_timer() 96 | ##print('for: '+str(toc-tic)) 97 | #loss = self.loss(output, target) 98 | index=0 99 | losses={} 100 | ##tic=timeit.default_timer() 101 | 102 | #if self.iteration % self.save_step == 0: 103 | # targetPoints={} 104 | # targetPixels=None 105 | # _,lossC=FormsBoxPair_printer(None,thisInstance,self.model,self.gpu,self._eval_metrics,self.checkpoint_dir,self.iteration,self.loss['box']) 106 | # loss, position_loss, conf_loss, class_loss, recall, precision = lossC 107 | #else: 108 | data,label = self._to_tensor(thisInstance['data'],thisInstance['label']) 109 | output = self.model(data) 110 | outputRel = output[:,0] 111 | if output.size(1)==3: 112 | outputNN = output[:,1:] 113 | gtNN = self._to_tensor(thisInstance['numNeighbors']) 114 | lossNN = F.mse_loss(outputNN,gtNN[0]) 115 | else: 116 | lossNN=0 117 | #import pdb;pdb.set_trace() 118 | lossRel = self.loss(outputRel,label) 119 | scoreTrue = (outputRel*label).sum()/label.sum() 120 | scoreFalse = (outputRel*(1-label)).sum()/(1-label).sum() 121 | 122 | loss = lossRel+lossNN 123 | 124 | ##toc=timeit.default_timer() 125 | ##print('loss: '+str(toc-tic)) 126 | ##tic=timeit.default_timer() 127 | loss.backward() 128 | #what is grads? 
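        # One way to inspect gradient magnitudes at this point (an illustrative
        # sketch, not something the original loop does) is:
        #   grad_maxes = [p.grad.abs().max().item()
        #                 for p in self.model.parameters() if p.grad is not None]
        #   print('max |grad|:', max(grad_maxes))
        # The clip_grad_value_ call below then bounds every gradient element to [-1, 1].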
129 | #minGrad=9999999999 130 | #maxGrad=-9999999999 131 | #for p in filter(lambda p: p.grad is not None, self.model.parameters()): 132 | # minGrad = min(minGrad,p.min()) 133 | # maxGrad = max(maxGrad,p.max()) 134 | #import pdb; pdb.set_trace() 135 | torch.nn.utils.clip_grad_value_(self.model.parameters(),1) 136 | self.optimizer.step() 137 | 138 | ##toc=timeit.default_timer() 139 | ##print('bac: '+str(toc-tic)) 140 | 141 | #tic=timeit.default_timer() 142 | metrics={} 143 | #index=0 144 | #for name, target in targetBoxes.items(): 145 | # metrics = {**metrics, **self._eval_metrics('box',name,output, target)} 146 | #for name, target in targetPoints.items(): 147 | # metrics = {**metrics, **self._eval_metrics('point',name,output, target)} 148 | # metrics = self._eval_metrics(name,output, target) 149 | #toc=timeit.default_timer() 150 | #print('metric: '+str(toc-tic)) 151 | 152 | ##tic=timeit.default_timer() 153 | loss = loss.item() 154 | lossRel=lossRel.item() 155 | if type(lossNN) is not int: 156 | lossNN=lossNN.item() 157 | ##toc=timeit.default_timer() 158 | ##print('item: '+str(toc-tic)) 159 | #perAnchor={} 160 | #for i in range(avg_conf_per_anchor.size(0)): 161 | # perAnchor['anchor{}'.format(i)]=avg_conf_per_anchor[i] 162 | 163 | log = { 164 | 'loss': loss, 165 | 'lossRel':lossRel, 166 | 'lossNN':lossNN, 167 | 'scoreTrue': scoreTrue, 168 | 'scoreFalse': scoreFalse, 169 | 170 | **metrics, 171 | **losses 172 | } 173 | 174 | #if iteration%10==0: 175 | #image=None 176 | #queryMask=None 177 | #targetBoxes=None 178 | #outputBoxes=None 179 | #outputOffsets=None 180 | #loss=None 181 | #torch.cuda.empty_cache() 182 | 183 | 184 | return log# 185 | 186 | 187 | def _valid_epoch(self): 188 | """ 189 | Validate after training an epoch 190 | 191 | :return: A log that contains information about validation 192 | 193 | Note: 194 | The validation metrics in log must have the key 'val_metrics'. 
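            Besides 'val_loss' and 'val_metrics', this implementation also reports
            per-image relationship precision and recall (thresholding the predicted
            score at 0.5) and a mean AP over images, under the keys 'val_recall*',
            'val_precision*' and 'val_mAP*'.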
195 | """ 196 | self.model.eval() 197 | total_val_loss = 0 198 | total_val_lossRel = 0 199 | total_val_lossNN = 0 200 | total_val_metrics = np.zeros(len(self.metrics)) 201 | 202 | tp_image=defaultdict(lambda:0) 203 | fp_image=defaultdict(lambda:0) 204 | tn_image=defaultdict(lambda:0) 205 | fn_image=defaultdict(lambda:0) 206 | images=set() 207 | scores=defaultdict(list) 208 | 209 | with torch.no_grad(): 210 | losses = defaultdict(lambda: 0) 211 | for batch_idx, instance in enumerate(self.valid_data_loader): 212 | if not self.logged: 213 | print('iter:{} valid batch: {}/{}'.format(self.iteration,batch_idx,len(self.valid_data_loader)), end='\r') 214 | 215 | data,label = self._to_tensor(instance['data'],instance['label']) 216 | output = self.model(data) 217 | outputRel = output[:,0] 218 | if output.size(1)==3: 219 | outputNN = output[:,1:] 220 | gtNN = self._to_tensor(instance['numNeighbors']) 221 | lossNN = F.mse_loss(outputNN,gtNN[0]) 222 | else: 223 | lossNN=0 224 | lossRel = self.loss(outputRel,label) 225 | 226 | loss = lossRel+lossNN 227 | 228 | 229 | for b in range(len(output)): 230 | image = instance['imgName'][b] 231 | images.add(image) 232 | scores[image].append( (outputRel[b],label[b]) ) 233 | if outputRel[b]<0.5: 234 | if label[b]==0: 235 | tn_image[image]+=1 236 | else: 237 | fn_image[image]+=1 238 | else: 239 | if label[b]==0: 240 | fp_image[image]+=1 241 | else: 242 | tp_image[image]+=1 243 | 244 | total_val_loss += loss.item() 245 | total_val_lossRel += lossRel.item() 246 | if type(lossNN) is not int: 247 | lossNN=lossNN.item() 248 | total_val_lossNN += lossNN 249 | 250 | mRecall=0 251 | mPrecision=0 252 | mAP=0 253 | mAP_count=0 254 | 255 | for image in images: 256 | ap = computeAP(scores[image]) 257 | if ap is not None: 258 | mAP+=ap 259 | mAP_count+=1 260 | if tp_image[image]+fn_image[image]>0: 261 | mRecall += tp_image[image]/(tp_image[image]+fn_image[image]) 262 | else: 263 | mRecall += 1 264 | if tp_image[image]+fp_image[image]>0: 265 | mPrecision += tp_image[image]/(tp_image[image]+fp_image[image]) 266 | else: 267 | mPrecision += 1 268 | mRecall /= len(images) 269 | mPrecision /= len(images) 270 | if mAP_count>0: 271 | mAP /= mAP_count 272 | 273 | return { 274 | 'val_loss': total_val_loss / len(self.valid_data_loader), 275 | 'val_lossRel': total_val_lossRel / len(self.valid_data_loader), 276 | 'val_lossNN': total_val_lossNN / len(self.valid_data_loader), 277 | 'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist(), 278 | 'val_recall*':mRecall, 279 | 'val_precision*':mPrecision, 280 | 'val_mAP*': mAP 281 | } 282 | -------------------------------------------------------------------------------- /model/csrc/cuda/ROIAlign_cuda.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
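// CUDA implementation of ROIAlign. RoIAlignForward assigns one thread per output
// element (n, c, ph, pw) through the grid-stride CUDA_1D_KERNEL_LOOP macro; each
// thread averages bilinearly interpolated samples over a roi_bin_grid_h x
// roi_bin_grid_w grid inside its bin. RoIAlignBackwardFeature uses
// bilinear_interpolate_gradient to recover, for every sample point, the four
// neighboring input positions and weights needed to route the gradient back.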
2 | #include <ATen/ATen.h> 3 | #include <ATen/cuda/CUDAContext.h> 4 | 5 | #include <THC/THC.h> 6 | #include <THC/THCAtomics.cuh> 7 | #include <THC/THCDeviceUtils.cuh> 8 | 9 | // TODO make it in a common file 10 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 11 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 12 | i += blockDim.x * gridDim.x) 13 | 14 | 15 | template <typename T> 16 | __device__ T bilinear_interpolate(const T* bottom_data, 17 | const int height, const int width, 18 | T y, T x, 19 | const int index /* index for debug only*/) { 20 | 21 | // deal with cases that inverse elements are out of feature map boundary 22 | if (y < -1.0 || y > height || x < -1.0 || x > width) { 23 | //empty 24 | return 0; 25 | } 26 | 27 | if (y <= 0) y = 0; 28 | if (x <= 0) x = 0; 29 | 30 | int y_low = (int) y; 31 | int x_low = (int) x; 32 | int y_high; 33 | int x_high; 34 | 35 | if (y_low >= height - 1) { 36 | y_high = y_low = height - 1; 37 | y = (T) y_low; 38 | } else { 39 | y_high = y_low + 1; 40 | } 41 | 42 | if (x_low >= width - 1) { 43 | x_high = x_low = width - 1; 44 | x = (T) x_low; 45 | } else { 46 | x_high = x_low + 1; 47 | } 48 | 49 | T ly = y - y_low; 50 | T lx = x - x_low; 51 | T hy = 1. - ly, hx = 1. - lx; 52 | // do bilinear interpolation 53 | T v1 = bottom_data[y_low * width + x_low]; 54 | T v2 = bottom_data[y_low * width + x_high]; 55 | T v3 = bottom_data[y_high * width + x_low]; 56 | T v4 = bottom_data[y_high * width + x_high]; 57 | T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 58 | 59 | T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 60 | 61 | return val; 62 | } 63 | 64 | template <typename T> 65 | __global__ void RoIAlignForward(const int nthreads, const T* bottom_data, 66 | const T spatial_scale, const int channels, 67 | const int height, const int width, 68 | const int pooled_height, const int pooled_width, 69 | const int sampling_ratio, 70 | const T* bottom_rois, T* top_data) { 71 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 72 | // (n, c, ph, pw) is an element in the pooled output 73 | int pw = index % pooled_width; 74 | int ph = (index / pooled_width) % pooled_height; 75 | int c = (index / pooled_width / pooled_height) % channels; 76 | int n = index / pooled_width / pooled_height / channels; 77 | 78 | const T* offset_bottom_rois = bottom_rois + n * 5; 79 | int roi_batch_ind = offset_bottom_rois[0]; 80 | 81 | // Do not use rounding; this implementation detail is critical 82 | T roi_start_w = offset_bottom_rois[1] * spatial_scale; 83 | T roi_start_h = offset_bottom_rois[2] * spatial_scale; 84 | T roi_end_w = offset_bottom_rois[3] * spatial_scale; 85 | T roi_end_h = offset_bottom_rois[4] * spatial_scale; 86 | // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); 87 | // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); 88 | // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); 89 | // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); 90 | 91 | // Force malformed ROIs to be 1x1 92 | T roi_width = max(roi_end_w - roi_start_w, (T)1.); 93 | T roi_height = max(roi_end_h - roi_start_h, (T)1.); 94 | T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); 95 | T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); 96 | 97 | const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; 98 | 99 | // We use roi_bin_grid to sample the grid and mimic integral 100 | int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 101 | int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
102 | 103 | // We do average (integral) pooling inside a bin 104 | const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 105 | 106 | T output_val = 0.; 107 | for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 108 | { 109 | const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5 110 | for (int ix = 0; ix < roi_bin_grid_w; ix ++) 111 | { 112 | const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); 113 | 114 | T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index); 115 | output_val += val; 116 | } 117 | } 118 | output_val /= count; 119 | 120 | top_data[index] = output_val; 121 | } 122 | } 123 | 124 | 125 | template <typename T> 126 | __device__ void bilinear_interpolate_gradient( 127 | const int height, const int width, 128 | T y, T x, 129 | T & w1, T & w2, T & w3, T & w4, 130 | int & x_low, int & x_high, int & y_low, int & y_high, 131 | const int index /* index for debug only*/) { 132 | 133 | // deal with cases that inverse elements are out of feature map boundary 134 | if (y < -1.0 || y > height || x < -1.0 || x > width) { 135 | //empty 136 | w1 = w2 = w3 = w4 = 0.; 137 | x_low = x_high = y_low = y_high = -1; 138 | return; 139 | } 140 | 141 | if (y <= 0) y = 0; 142 | if (x <= 0) x = 0; 143 | 144 | y_low = (int) y; 145 | x_low = (int) x; 146 | 147 | if (y_low >= height - 1) { 148 | y_high = y_low = height - 1; 149 | y = (T) y_low; 150 | } else { 151 | y_high = y_low + 1; 152 | } 153 | 154 | if (x_low >= width - 1) { 155 | x_high = x_low = width - 1; 156 | x = (T) x_low; 157 | } else { 158 | x_high = x_low + 1; 159 | } 160 | 161 | T ly = y - y_low; 162 | T lx = x - x_low; 163 | T hy = 1. - ly, hx = 1. - lx;
164 | 165 | // reference in forward 166 | // T v1 = bottom_data[y_low * width + x_low]; 167 | // T v2 = bottom_data[y_low * width + x_high]; 168 | // T v3 = bottom_data[y_high * width + x_low]; 169 | // T v4 = bottom_data[y_high * width + x_high]; 170 | // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 171 | 172 | w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 173 | 174 | return; 175 | } 176 | 177 | template <typename T> 178 | __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, 179 | const int num_rois, const T spatial_scale, 180 | const int channels, const int height, const int width, 181 | const int pooled_height, const int pooled_width, 182 | const int sampling_ratio, 183 | T* bottom_diff, 184 | const T* bottom_rois) { 185 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 186 | // (n, c, ph, pw) is an element in the pooled output 187 | int pw = index % pooled_width; 188 | int ph = (index / pooled_width) % pooled_height; 189 | int c = (index / pooled_width / pooled_height) % channels; 190 | int n = index / pooled_width / pooled_height / channels; 191 | 192 | const T* offset_bottom_rois = bottom_rois + n * 5; 193 | int roi_batch_ind = offset_bottom_rois[0]; 194 | 195 | // Do not use rounding; this implementation detail is critical 196 | T roi_start_w = offset_bottom_rois[1] * spatial_scale; 197 | T roi_start_h = offset_bottom_rois[2] * spatial_scale; 198 | T roi_end_w = offset_bottom_rois[3] * spatial_scale; 199 | T roi_end_h = offset_bottom_rois[4] * spatial_scale; 200 | // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); 201 | // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); 202 | // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); 203 | // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); 204 | 205 | // Force malformed ROIs to be 1x1 206 | T roi_width = max(roi_end_w - roi_start_w, (T)1.); 207 | T roi_height = max(roi_end_h - roi_start_h, (T)1.); 208 | T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); 209 | T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); 210 | 211 | T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; 212 | 213 | int top_offset = (n * channels + c) * pooled_height * pooled_width; 214 | const T* offset_top_diff = top_diff + top_offset; 215 | const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; 216 | 217 | // We use roi_bin_grid to sample the grid and mimic integral 218 | int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 219 | int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); 220 | 221 | // We do average (integral) pooling inside a bin 222 | const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
223 | 224 | for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 225 | { 226 | const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5 227 | for (int ix = 0; ix < roi_bin_grid_w; ix ++) 228 | { 229 | const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); 230 | 231 | T w1, w2, w3, w4; 232 | int x_low, x_high, y_low, y_high; 233 | 234 | bilinear_interpolate_gradient(height, width, y, x, 235 | w1, w2, w3, w4, 236 | x_low, x_high, y_low, y_high, 237 | index); 238 | 239 | T g1 = top_diff_this_bin * w1 / count; 240 | T g2 = top_diff_this_bin * w2 / count; 241 | T g3 = top_diff_this_bin * w3 / count; 242 | T g4 = top_diff_this_bin * w4 / count; 243 | 244 | if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) 245 | { 246 | atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1)); 247 | atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2)); 248 | atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3)); 249 | atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4)); 250 | } // if 251 | } // ix 252 | } // iy 253 | } // CUDA_1D_KERNEL_LOOP 254 | } // RoIAlignBackward 255 | 256 | 257 | at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, 258 | const at::Tensor& rois, 259 | const float spatial_scale, 260 | const int pooled_height, 261 | const int pooled_width, 262 | const int sampling_ratio) { 263 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 264 | AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); 265 | 266 | auto num_rois = rois.size(0); 267 | auto channels = input.size(1); 268 | auto height = input.size(2); 269 | auto width = input.size(3); 270 | 271 | auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); 272 | auto output_size = num_rois * pooled_height * pooled_width * channels; 273 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 274 | 275 | dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); 276 | dim3 block(512); 277 | 278 | if (output.numel() == 0) { 279 | THCudaCheck(cudaGetLastError()); 280 | return output; 281 | } 282 | 283 | AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { 284 | RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>( 285 | output_size, 286 | input.contiguous().data<scalar_t>(), 287 | spatial_scale, 288 | channels, 289 | height, 290 | width, 291 | pooled_height, 292 | pooled_width, 293 | sampling_ratio, 294 | rois.contiguous().data<scalar_t>(), 295 | output.data<scalar_t>()); 296 | }); 297 | THCudaCheck(cudaGetLastError()); 298 | return output; 299 | } 300 | 301 | // TODO remove the dependency on input and use instead its sizes -> save memory 302 | at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, 303 | const at::Tensor& rois, 304 | const float spatial_scale, 305 | const int pooled_height, 306 | const int pooled_width, 307 | const int batch_size, 308 | const int channels, 309 | const int height, 310 | const int width, 311 | const int sampling_ratio) { 312 | AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); 313 | AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); 314 | 315 | auto num_rois = rois.size(0); 316 | auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); 317 | 318 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 319 | 320 | dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); 321 | dim3 block(512); 322 | 323 | // handle possibly empty gradients
324 | if (grad.numel() == 0) { 325 | THCudaCheck(cudaGetLastError()); 326 | return grad_input; 327 | } 328 | 329 | AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { 330 | RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>( 331 | grad.numel(), 332 | grad.contiguous().data<scalar_t>(), 333 | num_rois, 334 | spatial_scale, 335 | channels, 336 | height, 337 | width, 338 | pooled_height, 339 | pooled_width, 340 | sampling_ratio, 341 | grad_input.data<scalar_t>(), 342 | rois.contiguous().data<scalar_t>()); 343 | }); 344 | THCudaCheck(cudaGetLastError()); 345 | return grad_input; 346 | } 347 | -------------------------------------------------------------------------------- /model/yolo_box_detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Brian Davis 3 | Visual-Template-free-Form-Parsting is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | Visual-Template-free-Form-Parsting is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with Visual-Template-free-Form-Parsting. If not, see <https://www.gnu.org/licenses/>. 15 | """ 16 | import torch 17 | from torch import nn 18 | from base import BaseModel 19 | import math 20 | import json 21 | import numpy as np 22 | from .net_builder import make_layers 23 | 24 | 25 | 26 | 27 | 28 | class YoloBoxDetector(nn.Module): #BaseModel 29 | def __init__(self, config): # predCount, base_0, base_1): 30 | #super(YoloBoxDetector, self).__init__(config) 31 | super(YoloBoxDetector, self).__init__() 32 | self.forPairing=False 33 | self.config = config 34 | self.rotation = config['rotation'] if 'rotation' in config else True 35 | self.numBBTypes = config['number_of_box_types'] 36 | self.numBBParams = 6 #conf,x-off,y-off,h-scale,w-scale,rot-off 37 | self.numLineParams = 5 #conf,x-off,y-off,h-scale,rot 38 | if 'pred_num_neighbors' in config and config['pred_num_neighbors']: 39 | self.predNumNeighbors=True 40 | self.numBBParams+=1 41 | print("Detecting number of neighbors!") 42 | else: 43 | self.predNumNeighbors=False 44 | 45 | self.predPointCount = config['number_of_point_types'] if 'number_of_point_types' in config else 0 46 | self.predPixelCount = config['number_of_pixel_types'] if 'number_of_pixel_types' in config else 0 47 | self.predLineCount = config['number_of_line_types'] if 'number_of_line_types' in config else 0 48 | 49 | with open(config['anchors_file']) as f: 50 | self.anchors = json.loads(f.read()) #array of objects {rot,height,width} 51 | if self.rotation: 52 | self.meanH=48.0046359128/2 53 | else: 54 | self.meanH=62.1242376857/2 55 | self.numAnchors = len(self.anchors) 56 | if self.predLineCount>0: 57 | print('Warning, using hardcoded mean H (yolo_box_detector)') 58 | 59 | in_ch = 3 if 'color' not in config or config['color'] else 1 60 | norm = config['norm_type'] if "norm_type" in config else None 61 | if norm is None: 62 | print('Warning: YoloBoxDetector has no normalization!') 63 | dilation = config['dilation'] if 'dilation' in config else 1 64 | dropout = config['dropout'] if 'dropout' in config else None 65 | #self.cnn, self.scale =
vgg.vgg11_custOut(self.predLineCount*5+self.predPointCount*3,batch_norm=batch_norm, weight_norm=weight_norm) 66 | self.numOutBB = (self.numBBTypes+self.numBBParams)*self.numAnchors 67 | self.numOutLine = (self.numBBTypes+self.numLineParams)*self.predLineCount 68 | self.numOutPoint = self.predPointCount*3 69 | 70 | if 'down_layers_cfg' in config: 71 | layers_cfg = config['down_layers_cfg'] 72 | else: 73 | layers_cfg=[in_ch,64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512] 74 | 75 | self.net_down_modules, down_last_channels = make_layers(layers_cfg, dilation,norm,dropout=dropout) 76 | self.final_features=None 77 | self.last_channels=down_last_channels 78 | self.net_down_modules.append(nn.Conv2d(down_last_channels, self.numOutBB+self.numOutLine+self.numOutPoint, kernel_size=1)) 79 | self._hack_down = nn.Sequential(*self.net_down_modules) 80 | scaleX=1 81 | scaleY=1 82 | for a in layers_cfg: 83 | if a=='M' or (type(a) is str and a[0]=='D'): 84 | scaleX*=2 85 | scaleY*=2 86 | elif type(a) is str and a[0]=='U': 87 | scaleX/=2 88 | scaleY/=2 89 | elif type(a) is str and a[0:4]=='long': #long pool 90 | scaleX*=3 91 | scaleY*=2 92 | self.scale=(scaleX,scaleY) 93 | 94 | if self.predPixelCount>0: 95 | if 'up_layers_cfg' in config: 96 | up_layers_cfg = config['up_layers_cfg'] 97 | else: 98 | up_layers_cfg=[512, 'U+512', 256, 'U+256', 128, 'U+128', 64, 'U+64'] 99 | self.net_up_modules, up_last_channels = make_layers(up_layers_cfg, 1, norm,dropout=dropout) 100 | self.net_up_modules.append(nn.Conv2d(up_last_channels, self.predPixelCount, kernel_size=1)) 101 | self._hack_up = nn.Sequential(*self.net_up_modules) 102 | 103 | #self.base_0 = config['base_0'] 104 | #self.base_1 = config['base_1'] 105 | if 'DEBUG' in config: 106 | self.setDEBUG() 107 | 108 | def forward(self, img): 109 | #import pdb; pdb.set_trace() 110 | y = self._hack_down(img) 111 | if self.forPairing: 112 | return y[:,:(self.numBBParams+self.numBBTypes)*self.numAnchors,:,:] 113 | #levels=[img] 114 | #for module in self.net_down_modules: 115 | # levels.append(module(levels[-1])) 116 | #y=levels[-1] 117 | 118 | 119 | #priors_0 = Variable(torch.arange(0,y.size(2)).type_as(img.data), requires_grad=False)[None,:,None] 120 | priors_0 = torch.arange(0,y.size(2)).type_as(img.data)[None,:,None] 121 | priors_0 = (priors_0 + 0.5) * self.scale[1] #self.base_0 122 | priors_0 = priors_0.expand(y.size(0), priors_0.size(1), y.size(3)) 123 | priors_0 = priors_0[:,None,:,:].to(img.device) 124 | 125 | #priors_1 = Variable(torch.arange(0,y.size(3)).type_as(img.data), requires_grad=False)[None,None,:] 126 | priors_1 = torch.arange(0,y.size(3)).type_as(img.data)[None,None,:] 127 | priors_1 = (priors_1 + 0.5) * self.scale[0] #elf.base_1 128 | priors_1 = priors_1.expand(y.size(0), y.size(2), priors_1.size(2)) 129 | priors_1 = priors_1[:,None,:,:].to(img.device) 130 | 131 | anchor = self.anchors 132 | pred_boxes=[] 133 | pred_offsets=[] #we seperate anchor predictions here. And compute actual bounding boxes 134 | for i in range(self.numAnchors): 135 | 136 | offset = i*(self.numBBParams+self.numBBTypes) 137 | if self.rotation: 138 | rot_dif = (math.pi/2)*torch.tanh(y[:,3+offset:4+offset,:,:]) 139 | else: 140 | rot_dif = torch.zeros_like(y[:,3+offset:4+offset,:,:]) 141 | 142 | stackedPred = [ 143 | torch.sigmoid(y[:,0+offset:1+offset,:,:]), #0. confidence 144 | torch.tanh(y[:,1+offset:2+offset,:,:])*self.scale[0] + priors_1, #1. x-center 145 | torch.tanh(y[:,2+offset:3+offset,:,:])*self.scale[1] + priors_0, #2. y-center 146 | rot_dif + anchor[i]['rot'], #3. 
rotation (radians) 147 | torch.exp(y[:,4+offset:5+offset,:,:]) * anchor[i]['height'], #4. height (half), I don't think this needs scaled 148 | torch.exp(y[:,5+offset:6+offset,:,:]) * anchor[i]['width'], #5. width (half) as we scale the anchors in training 149 | ] 150 | 151 | 152 | if self.predNumNeighbors: 153 | stackedPred.append(1+y[:,6+offset:7+offset,:,:]) #+1 so predicted -1 is 0 neighbors 154 | extra=1 155 | else: 156 | extra=0 157 | for j in range(self.numBBTypes): 158 | stackedPred.append(torch.sigmoid(y[:,6+j+extra+offset:7+j+extra+offset,:,:])) #x. class prediction 159 | #stackedOffsets.append(y[:,6+j+offset:7+j+offset,:,:]) #x. class prediction 160 | pred_boxes.append(torch.cat(stackedPred, dim=1)) 161 | #pred_offsets.append(torch.cat(stackedOffsets, dim=1)) 162 | pred_offsets.append(y[:,offset:offset+self.numBBParams+self.numBBTypes,:,:]) 163 | 164 | if len(pred_boxes)>0: 165 | bbPredictions = torch.stack(pred_boxes, dim=1) 166 | offsetPredictions = torch.stack(pred_offsets, dim=1) 167 | 168 | bbPredictions = bbPredictions.transpose(2,4).contiguous()#from [batch, anchors, channel, rows, cols] to [batch, anchros, cols, rows, channels] 169 | bbPredictions = bbPredictions.view(bbPredictions.size(0),bbPredictions.size(1),-1,bbPredictions.size(4))#flatten to [batch, anchors, instances, channel] 170 | #avg_conf_per_anchor = bbPredictions[:,:,:,0].mean(dim=0).mean(dim=1) 171 | bbPredictions = bbPredictions.view(bbPredictions.size(0),-1,bbPredictions.size(3)) #[batch, instances+anchors, channel] 172 | 173 | offsetPredictions = offsetPredictions.permute(0,1,3,4,2).contiguous() 174 | else: 175 | bbPredictions=None 176 | offsetPredictions=None 177 | 178 | linePreds=[] 179 | offsetLinePreds=[] 180 | for i in range(self.predLineCount): 181 | offset = i*(self.numLineParams+self.numBBTypes) + self.numAnchors*(self.numBBParams+self.numBBTypes) 182 | stackedPred=[ 183 | torch.sigmoid(y[:,0+offset:1+offset,:,:]), #confidence 184 | torch.tanh(y[:,1+offset:2+offset,:,:])*self.scale[0] + priors_1, #x-center 185 | torch.tanh(y[:,2+offset:3+offset,:,:])*self.scale[1] + priors_0, #y-center 186 | (math.pi)*torch.tanh(y[:,3+offset:4+offset,:,:]), #rotation (radians) 187 | torch.exp(y[:,4+offset:5+offset,:,:])*self.meanH #scale (half-height), 188 | 189 | ] 190 | for j in range(self.numBBTypes): 191 | stackedPred.append(y[:,5+j+offset:6+j+offset,:,:]) #x. 
class prediction 192 | 193 | predictions = torch.cat(stackedPred, dim=1) 194 | predictions = predictions.transpose(1,3).contiguous()#from [batch, channel, rows, cols] to [batch, cols, rows, channels] 195 | predictions = predictions.view(predictions.size(0),-1,predictions.size(3))#flatten to [batch, instances, channel] 196 | linePreds.append(predictions) 197 | 198 | offsets = y[:,offset:offset+self.numLineParams+self.numBBTypes,:,:] 199 | offsets = offsets.permute(0,2,3,1).contiguous() 200 | offsetLinePreds.append(offsets) 201 | 202 | pointPreds=[] 203 | for i in range(self.predPointCount): 204 | offset = i*3 + self.numAnchors*(self.numBBParams+self.numBBTypes) 205 | predictions = torch.cat([ 206 | torch.sigmoid(y[:,0+offset:1+offset,:,:]), #confidence 207 | y[:,1+offset:2+offset,:,:] + priors_1, #x 208 | y[:,2+offset:3+offset,:,:] + priors_0 #y 209 | ], dim=1) 210 | 211 | predictions = predictions.transpose(1,3).contiguous()#from [batch, channel, rows, cols] to [batch, cols, rows, channels] 212 | predictions = predictions.view(predictions.size(0),-1,3)#flatten to [batch, instances, channel] 213 | pointPreds.append(predictions) 214 | 215 | pixelPreds=None 216 | if self.predPixelCount>0: 217 | y2=levels[-2] 218 | p=-3 219 | for module in self.net_up_modules[:-1]: 220 | #print('uping {} , {}'.format(y2.size(), levels[p].size())) 221 | y2 = module(y2,levels[p]) 222 | p-=1 223 | pixelPreds = self.net_up_modules[-1](y2) 224 | 225 | 226 | 227 | 228 | return bbPredictions, offsetPredictions, linePreds, offsetLinePreds, pointPreds, pixelPreds #, avg_conf_per_anchor 229 | 230 | def summary(self): 231 | """ 232 | Model summary 233 | """ 234 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 235 | params = sum([np.prod(p.size()) for p in model_parameters]) 236 | print('Trainable parameters: {}'.format(params)) 237 | print(self) 238 | 239 | def setForPairing(self): 240 | self.forPairing=True 241 | def save_final(module,input,output): 242 | self.final_features=output 243 | self.net_down_modules[-2].register_forward_hook(save_final) 244 | def setForGraphPairing(self,beginningOfLast=False,featuresFromHere=-1,featuresFromScale=-2,f2Here=None,f2Scale=None): 245 | def save_feats(module,input,output): 246 | self.saved_features=output 247 | if beginningOfLast: 248 | self.net_down_modules[-2][0].register_forward_hook(save_final) #after max pool 249 | self.last_channels= self.last_channels//2 #HACK 250 | else: 251 | typ = type( self.net_down_modules[featuresFromScale][featuresFromHere]) 252 | if typ == torch.nn.modules.activation.ReLU or typ == torch.nn.modules.MaxPool2d: 253 | self.net_down_modules[featuresFromScale][featuresFromHere].register_forward_hook(save_feats) 254 | if featuresFromScale<0: 255 | featuresFromScale = len(self.net_down_modules)+featuresFromScale 256 | self.save_scale = 2**featuresFromScale 257 | else: 258 | print('Layer {},{} of the final conv block was specified, but it is not a ReLU layer. 
Did you choose the right layer?'.format(featuresFromScale,featuresFromHere)) 259 | exit() 260 | if f2Here is not None: 261 | def save_feats2(module,input,output): 262 | self.saved_features2=output 263 | typ = type( self.net_down_modules[f2Scale][f2Here]) 264 | if typ == torch.nn.modules.activation.ReLU or typ==torch.nn.modules.MaxPool2d: 265 | self.net_down_modules[f2Scale][f2Here].register_forward_hook(save_feats2) 266 | if f2Scale<0: 267 | f2Scale = len(self.net_down_modules)+f2Scale 268 | self.save2_scale = 2**f2Scale 269 | else: 270 | print('Layer {},{} of the final conv block was specified, but it is not a ReLU layer. Did you choose the right layer?'.format(f2Scale,f2Here)) 271 | def setDEBUG(self): 272 | #self.debug=[None]*5 273 | #for i in range(0,1): 274 | # def save_layer(module,input,output): 275 | # self.debug[i]=output.cpu() 276 | # self.net_down_modules[i].register_forward_hook(save_layer) 277 | 278 | def save_layer0(module,input,output): 279 | self.debug0=output.cpu() 280 | self.net_down_modules[0].register_forward_hook(save_layer0) 281 | def save_layer1(module,input,output): 282 | self.debug1=output.cpu() 283 | self.net_down_modules[1].register_forward_hook(save_layer1) 284 | def save_layer2(module,input,output): 285 | self.debug2=output.cpu() 286 | self.net_down_modules[2].register_forward_hook(save_layer2) 287 | def save_layer3(module,input,output): 288 | self.debug3=output.cpu() 289 | self.net_down_modules[3].register_forward_hook(save_layer3) 290 | def save_layer4(module,input,output): 291 | self.debug4=output.cpu() 292 | self.net_down_modules[4].register_forward_hook(save_layer4) 293 | --------------------------------------------------------------------------------
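Note on the anchor decoding in YoloBoxDetector.forward above: each grid cell's raw channels are turned into a box by a sigmoid on confidence, tanh offsets added to the cell's center prior, and exp scaling of the anchor's half-height/half-width. The following is a minimal illustrative sketch, not repository code; the function name decode_cell, the NumPy stand-ins, and the example anchor values are hypothetical and only mirror the per-channel math.

import numpy as np

def decode_cell(raw, prior_x, prior_y, anchor, scale, rotation=True):
    # raw: the 6 raw network outputs for one anchor at one grid cell
    # prior_x, prior_y: (col+0.5)*scale[0] and (row+0.5)*scale[1], the cell's center prior
    # anchor: dict with 'rot', 'height', 'width' (half-sizes), as in the anchors file
    conf = 1.0 / (1.0 + np.exp(-raw[0]))        # sigmoid -> confidence
    x = np.tanh(raw[1]) * scale[0] + prior_x    # x-center, bounded offset from the prior
    y = np.tanh(raw[2]) * scale[1] + prior_y    # y-center
    rot = anchor['rot'] + ((np.pi / 2) * np.tanh(raw[3]) if rotation else 0.0)
    h = np.exp(raw[4]) * anchor['height']       # half-height, relative to the anchor
    w = np.exp(raw[5]) * anchor['width']        # half-width, relative to the anchor
    return conf, x, y, rot, h, w

# Hypothetical example: all-zero raw outputs reproduce the cell prior and anchor size exactly.
print(decode_cell(np.zeros(6), prior_x=24.0, prior_y=40.0,
                  anchor={'rot': 0.0, 'height': 16.0, 'width': 48.0}, scale=(8, 8)))

With zero raw outputs the decoded box sits exactly on the cell prior with the anchor's shape, which is why the anchor file entries are meant to reflect typical box shapes in the training data.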