├── weights
│   └── README.md
├── data
│   └── WIDER
│       └── README.md
├── tinyfaces
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   └── utils.py
│   ├── clustering
│   │   ├── __init__.py
│   │   ├── k_medoids.py
│   │   └── cluster.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── visualize.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── templates.json
│   │   ├── dense_overlap.py
│   │   ├── wider_face.py
│   │   └── processor.py
│   ├── trainer.py
│   ├── evaluation.py
│   └── metrics.py
├── .pylintrc
├── tests
│   ├── test_metrics.py
│   └── test_dense_overlap.py
├── pyproject.toml
├── LICENSE
├── Makefile
├── .gitignore
├── README.md
├── detect_image.py
├── evaluate_model.py
└── main.py

/weights/README.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data/WIDER/README.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tinyfaces/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tinyfaces/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tinyfaces/clustering/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tinyfaces/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utils module"""
2 | 
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 | disable=missing-docstring,invalid-name,arguments-differ,assignment-from-no-return
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.io import loadmat
3 | 
4 | from tinyfaces.metrics import rect_dist
5 | 
6 | 
7 | def check_rect_dist(x, y, gt_dist):
8 |     """Compare rect_dist output against the Matlab ground truth (needs rect_dist.mat)."""
9 |     d = rect_dist(x, y)
10 |     assert np.array_equal(d, gt_dist), "rect_dist does not match the Matlab ground truth"
11 | 
12 | 
13 | def main():
14 |     #TODO Generate rect_dist.mat file from Matlab
15 |     truth = loadmat('rect_dist.mat')
16 |     gt_dist = truth['d'][:, 0]
17 |     x = truth['labelRect']
18 |     y = truth['tLabelRect']
19 |     check_rect_dist(x, y, gt_dist)
20 |     print(rect_dist(x[0, :], y[0, :]))
21 | 
22 | 
23 | if __name__ == "__main__":
24 |     main()
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "tinyfaces"
3 | version = "1.0.0"
4 | authors = ["Varun Agrawal "]
5 | description = "Finding Tiny Faces in PyTorch"
6 | readme = "README.md"
7 | 
8 | [tool.poetry.dependencies]
9 | python = "^3.10"
10 | loguru = "^0.7.2"
11 | numpy = "^1.26.4"
12 | scipy = "^1.12.0"
13 | Pillow = "^10.3.0"
14 | pyclust = "^0.2.0"
15 | pyclustering = "^0.10.1.2"
16 | torch = "^2.3.0"
17 | torchvision = "^0.18.0"
18 | tqdm = "^4.66.2"
19 | treelib = "^1.7.0"
20 | 
21 | 
22 | [tool.poetry.group.dev.dependencies]
23 | pytest = "^8.0.0"
24 | 
25 | [tool.poetry.urls]
26 | "Homepage" = "https://github.com/varunagrawal/tiny-faces-pytorch"
"https://github.com/varunagrawal/tiny-faces-pytorch" 27 | "Bug Tracker" = "https://github.com/varunagrawal/tiny-faces-pytorch/issues" 28 | 29 | [build-system] 30 | requires = ["poetry-core"] 31 | build-backend = "poetry.core.masonry.api" 32 | 33 | [tool.pytest.ini_options] 34 | filterwarnings = [ 35 | "error", 36 | "ignore::UserWarning", 37 | "ignore:.*:DeprecationWarning", 38 | ] 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Varun Agrawal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .RECIPEPREFIX += 2 | 3 | PYTHON=python 4 | ROOT=data/WIDER 5 | TRAINDATA=$(ROOT)/wider_face_split/wider_face_train_bbx_gt.txt 6 | VALDATA=$(ROOT)/wider_face_split/wider_face_val_bbx_gt.txt 7 | TESTDATA=$(ROOT)/wider_face_split/wider_face_test_filelist.txt 8 | 9 | CHECKPOINT=weights/checkpoint_50.pth 10 | 11 | main: 12 | $(PYTHON) main.py $(TRAINDATA) $(VALDATA) --dataset-root $(ROOT) 13 | 14 | resume: 15 | $(PYTHON) main.py $(TRAINDATA) $(VALDATA) --dataset-root $(ROOT) --resume $(CHECKPOINT) --epochs $(EPOCH) 16 | 17 | evaluate: 18 | $(PYTHON) evaluate_model.py $(VALDATA) --dataset-root $(ROOT) --checkpoint $(CHECKPOINT) --split val 19 | 20 | evaluation: 21 | cd eval_tools/ && octave wider_eval.m 22 | 23 | test: 24 | $(PYTHON) evaluate.py $(TESTDATA) --dataset-root $(ROOT) --checkpoint $(CHECKPOINT) --split test 25 | 26 | cluster: 27 | cd utils; $(PYTHON) cluster.py $(TRAIN_INSTANCES) 28 | 29 | debug: 30 | $(PYTHON) main.py $(TRAINDATA) $(VALDATA) --dataset-root $(ROOT) --batch_size 1 --workers 0 --debug 31 | 32 | debug-evaluate: 33 | $(PYTHON) evaluate.py $(VALDATA) --dataset-root $(ROOT) --checkpoint $(CHECKPOINT) --split val --batch_size 1 --workers 0 --debug 34 | -------------------------------------------------------------------------------- /tests/test_dense_overlap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | 4 | from tinyfaces.datasets.dense_overlap import compute_dense_overlap 5 | 6 | #TODO Use Peiyun's Matlab code to generate a dense_overlap.mat file 7 | d = loadmat("dense_overlap.mat") 8 | 9 | ofx, ofy = 
d['ofx'][0, 0], d['ofy'][0, 0] 10 | stx, sty = d['stx'][0, 0], d['sty'][0, 0] 11 | vsx, vsy = d['vsx'][0, 0], d['vsy'][0, 0] 12 | dx1, dy1, dx2, dy2 = d['dx1'], d['dy1'], d['dx2'], d['dy2'] 13 | dx1 = dx1.reshape(dx1.shape[2]) 14 | dy1 = dy1.reshape(dy1.shape[2]) 15 | dx2 = dx2.reshape(dx2.shape[2]) 16 | dy2 = dy2.reshape(dy2.shape[2]) 17 | 18 | gx1, gy1, gx2, gy2 = d['gx1'], d['gy1'], d['gx2'], d['gy2'] 19 | gx1 = gx1.reshape(gx1.shape[0]) 20 | gy1 = gy1.reshape(gy1.shape[0]) 21 | gx2 = gx2.reshape(gx2.shape[0]) 22 | gy2 = gy2.reshape(gy2.shape[0]) 23 | 24 | correct_iou = d['iou'] 25 | 26 | iou = compute_dense_overlap(ofx, ofy, stx, sty, vsx, vsy, dx1, dy1, dx2, dy2, 27 | gx1, gy1, gx2, gy2, 1, 1) 28 | 29 | print("Computed IOU") 30 | print("iou shape", iou.shape) 31 | print("correct iou shape", correct_iou.shape) 32 | print("Tensors are close enough?", np.allclose(iou, correct_iou)) 33 | print("Tensors are equal?", np.array_equal(iou, correct_iou)) 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | *.c 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | weights/*.pth 105 | 106 | data/WIDER 107 | 108 | # JetBrains IDE 109 | .idea 110 | 111 | val_results 112 | -------------------------------------------------------------------------------- /tinyfaces/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from torch.utils import data 6 | 7 | from tinyfaces.clustering.cluster import compute_kmedoids 8 | from tinyfaces.datasets.wider_face import WIDERFace 9 | 10 | 11 | def get_dataloader(datapath, 12 | args, 13 | num_templates=25, 14 | template_file="templates.json", 15 | img_transforms=None, 16 | train=True, 17 | split="train"): 18 | template_file = Path(__file__).parent / 
template_file
19 | 
20 |     if template_file.exists():
21 |         templates = json.load(open(template_file))
22 | 
23 |     else:
24 |         # Cluster the bounding boxes to get the templates
25 |         dataset = WIDERFace(Path(args.traindata).expanduser(), [])
26 |         clustering = compute_kmedoids(dataset.get_all_bboxes(),
27 |                                       1,
28 |                                       indices=num_templates,
29 |                                       option='pyclustering',
30 |                                       max_clusters=num_templates)
31 | 
32 |         print("Canonical bounding boxes computed")
33 |         templates = clustering[num_templates]['medoids'].tolist()
34 | 
35 |         # record templates
36 |         json.dump(templates, open(template_file, "w"))
37 | 
38 |     templates = np.round(np.array(templates), decimals=8)
39 | 
40 |     dataset = WIDERFace(Path(datapath).expanduser(),
41 |                         templates,
42 |                         split=split,
43 |                         img_transforms=img_transforms,
44 |                         dataset_root=Path(args.dataset_root).expanduser(),
45 |                         debug=args.debug)
46 |     data_loader = data.DataLoader(dataset,
47 |                                   batch_size=args.batch_size,
48 |                                   shuffle=train,
49 |                                   num_workers=args.workers,
50 |                                   pin_memory=True)
51 | 
52 |     return data_loader, templates
53 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tiny-faces-pytorch
2 | 
3 | This is a PyTorch implementation of Peiyun Hu's [awesome tiny face detector](https://github.com/peiyunh/tiny).
4 | 
5 | We require **Python 3.10+** for this codebase, matching the version pinned in `pyproject.toml`.
6 | 
7 | **NOTE** Be sure to cite Peiyun's CVPR paper and this repo if you use this code!
8 | 
9 | This code gives the following mAP results on the WIDER Face dataset:
10 | 
11 | | Setting | mAP   |
12 | |---------|-------|
13 | | easy    | 0.902 |
14 | | medium  | 0.892 |
15 | | hard    | 0.797 |
16 | 
17 | ## Getting Started
18 | 
19 | - Clone this repository.
20 | - Download the WIDER Face dataset and annotations files to `data/WIDER`.
21 | - Install dependencies with `poetry install` (the project is packaged with Poetry via `pyproject.toml`).
22 | 
23 | Your data directory should look like this for WIDERFace:
24 | 
25 | ```
26 | - data
27 |     - WIDER
28 |         - README.md
29 |         - wider_face_split
30 |         - WIDER_train
31 |         - WIDER_val
32 |         - WIDER_test
33 | ```
34 | 
35 | ## Pretrained Weights
36 | 
37 | You can find the pretrained weights which get the above mAP results [here](https://www.dropbox.com/scl/fi/md0lxok2uh2achx8r58mk/checkpoint_50.pth?rlkey=9y1acwj1k6c57tqck14t6as18&dl=0).
38 | 
39 | ## Training
40 | 
41 | Just type `make` at the repo root and you should be good to go!
42 | 
43 | In case you wish to change some settings (such as data location), you can modify the `Makefile`, which should be super easy to work with.
44 | 
45 | ## Evaluation
46 | 
47 | To run evaluation and generate the output files as per the WIDERFace specification, simply run `make evaluate`. The results will be stored in the `val_results` directory.
48 | 
49 | You can then use the dataset's `eval_tools` to generate the mAP numbers (this needs Matlab/Octave).
50 | 
51 | Similarly, to run the model on the test set, run `make test` to generate results in the `test_results` directory.
52 | 
53 | ## Deployment
54 | 
55 | To run the model on your own image, please use the `detect_image.py` script.
56 | You may have to adjust the probability and NMS thresholds to get the best results.
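
For example, assuming the pretrained checkpoint above is saved as `weights/checkpoint_50.pth` (adjust the paths for your setup):

```
python detect_image.py path/to/image.jpg --checkpoint weights/checkpoint_50.pth --prob_thresh 0.6 --nms_thresh 0.3
```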
57 | 
--------------------------------------------------------------------------------
/tinyfaces/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import ImageDraw, ImageFont
3 | 
4 | 
5 | def draw_bounding_box(img, bbox, labels):
6 |     draw = ImageDraw.Draw(img)
7 |     font = ImageFont.load_default()
8 |     color = tuple(np.random.choice(range(100, 256), size=3))
9 | 
10 |     draw.rectangle((bbox[0], bbox[1], bbox[2], bbox[3]), outline=color)
11 | 
12 |     for i, k in enumerate(labels.keys()):
13 |         # Pillow 10 removed font.getsize, so measure the text via getbbox
14 |         left, top, right, bottom = font.getbbox(labels[k])
15 |         w, h = right - left, bottom - top
16 |         # draw.rectangle((bbox[0], bbox[1] + i*h, bbox[0] + w, bbox[1] + (i+2)*h), fill=color)
17 |         draw.text((bbox[0], bbox[1] + i * h),
18 |                   "{0}:{1:.3} ".format(k, labels[k]),
19 |                   fill=color)
20 | 
21 |     return img
22 | 
23 | 
24 | def draw_all_boxes(img, bboxes, categories):
25 |     for bbox, c in zip(bboxes, categories):
26 |         img = draw_bounding_box(img, bbox, c)
27 | 
28 |     img.show()
29 | 
30 | 
31 | def visualize_bboxes(image, bboxes):
32 |     """
33 |     Draw each ground-truth bounding box on the image and display it.
34 | 
35 |     :param image: PIL image
36 |     :param bboxes:
37 |     :return:
38 |     """
39 |     print("Number of GT bboxes", bboxes.shape[0])
40 |     for idx, bbox in enumerate(bboxes):
41 |         bbox = np.round(np.array(bbox))
42 |         # print(bbox)
43 |         image = draw_bounding_box(image, bbox, {"name": "{0}".format(idx)})
44 | 
45 |     image.show(title="BBoxes")
46 | 
47 | 
48 | def render_and_save_bboxes(image,
49 |                            image_id,
50 |                            bboxes,
51 |                            scores,
52 |                            scales,
53 |                            directory="qualitative"):
54 |     """
55 |     Render the bboxes on the image and save the image
56 |     :param image: PIL image
57 |     :param image_id:
58 |     :param bboxes:
59 |     :param scores:
60 |     :param scales:
61 |     :param directory:
62 |     :return:
63 |     """
64 |     for idx, bbox in enumerate(bboxes):
65 |         bbox = np.round(np.array(bbox))
66 |         image = draw_bounding_box(image, bbox, {
67 |             'score': scores[idx],
68 |             'scale': scales[idx]
69 |         })
70 | 
71 |     image.save("{0}/{1}.jpg".format(directory, image_id))
--------------------------------------------------------------------------------
/tinyfaces/datasets/templates.json:
--------------------------------------------------------------------------------
1 | [
2 |     [-86.3382352941177, -113.444117647059, 86.3382352941177, 113.444117647059, 0.500000000000000],
3 |     [-48.7500000000000, -65.2500000000000, 48.7500000000000, 65.2500000000000, 0.500000000000000],
4 |     [-33.2500000000000, -43.7500000000000, 33.2500000000000, 43.7500000000000, 0.500000000000000],
5 |     [-25.7500000000000, -33.7500000000000, 25.7500000000000, 33.7500000000000, 0.500000000000000],
6 |     [-40.5000000000000, -54.5000000000000, 40.5000000000000, 54.5000000000000, 1],
7 |     [-34.5000000000000, -43.5000000000000, 34.5000000000000, 43.5000000000000, 1],
8 |     [-28.5000000000000, -38, 28.5000000000000, 38, 1],
9 |     [-25.6589050000000, -31.3221500000000, 25.6589050000000, 31.3221500000000, 1],
10 |     [-21.6137000000000, -27.5976700000000, 21.6137000000000, 27.5976700000000, 1],
11 |     [-20, -22.5000000000000, 20, 22.5000000000000, 1],
12 |     [-17.5000000000000, -25.5000000000000, 17.5000000000000, 25.5000000000000, 1],
13 |     [-16.3279854000000, -20.8855500000000, 16.3279854000000, 20.8855500000000, 1],
14 |     [-29.4755000000000, -34.4803000000000, 29.4755000000000, 34.4803000000000, 2],
15 |     [-25.4202000000000, -37.1060900000000, 25.4202000000000, 37.1060900000000, 2],
16 |     [-24.2118000000000, -30.2145600000000, 24.2118000000000, 30.2145600000000, 2],
17 |     [-22.0129000000000, -24.7059000000000, 22.0129000000000, 24.7059000000000, 2],
18 |     [-19.3142000000000, -28.0202000000000, 19.3142000000000, 28.0202000000000, 2],
19 |     [-17.9677000000000, -22.7849000000000, 17.9677000000000, 22.7849000000000, 2],
20 |     [-15.6123000000000, -19.5907000000000, 15.6123000000000, 19.5907000000000, 2],
21 |     [-13.1842000000000, -18.3421000000000, 13.1842000000000, 18.3421000000000, 2],
22 |     [-13.5027700000000, -15.0792000000000, 13.5027700000000, 15.0792000000000, 2],
23 |     [-11.1091000000000, -16.4909000000000, 11.1091000000000, 16.4909000000000, 2],
24 |     [-10.9768600000000, -14.0478000000000, 10.9768600000000, 14.0478000000000, 2],
25 |     [-9.97219999999993, -12.4105000000000, 9.97219999999993, 12.4105000000000, 2],
26 |     [-9.66129999999998, -10.6161000000000, 9.66129999999998, 10.6161000000000, 2]
27 | ]
--------------------------------------------------------------------------------
/tinyfaces/clustering/k_medoids.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import numpy as np
4 | 
5 | 
6 | def kMedoids(distances, k):
7 |     """
8 |     https://github.com/salspaugh/machine_learning/blob/master/clustering/kmedoids.py
9 | 
10 |     :param distances:
11 |     :param k:
12 |     :return:
13 |     """
14 |     n = distances.shape[0]
15 |     medoid_idxs = np.random.choice(n, size=k, replace=False)
16 |     old_medoids_idxs = np.zeros(k)
17 | 
18 |     while not np.all(
19 |             medoid_idxs == old_medoids_idxs):  # and n_iter_ < max_iter_
20 |         # retain a copy of the old assignments
21 |         old_medoids_idxs = np.copy(medoid_idxs)
22 | 
23 |         cluster_idxs = get_cluster_indices(distances, medoid_idxs)
24 | 
25 |         medoid_idxs = update_medoids(distances, cluster_idxs, medoid_idxs)
26 | 
27 |     return medoid_idxs, cluster_idxs
28 | 
29 | 
30 | def get_cluster_indices(distances, medoid_idxs):
31 |     cluster_idxs = np.argmin(distances[medoid_idxs, :], axis=0)
32 |     return cluster_idxs
33 | 
34 | 
35 | def update_medoids(distances, cluster_idxs, medoid_idxs):
36 |     for cluster_idx in range(medoid_idxs.shape[0]):
37 |         if sum(cluster_idxs == cluster_idx) == 0:
38 |             warnings.warn("Cluster {} is empty!".format(cluster_idx))
39 |             continue
40 | 
41 |         curr_cost = np.sum(distances[medoid_idxs[cluster_idx],
42 |                                      cluster_idxs == cluster_idx])
43 | 
44 |         # Extract the distance matrix between the data points
45 |         # inside the cluster_idx
46 |         D_in = distances[cluster_idxs == cluster_idx, :]
47 |         D_in = D_in[:, cluster_idxs == cluster_idx]
48 | 
49 |         # Calculate the total cost between each data point
50 |         # and all other data points in the cluster_idx
51 |         all_costs = np.sum(D_in, axis=1)
52 | 
53 |         # Find the index for the smallest cost in cluster_idx
54 |         min_cost_idx = np.argmin(all_costs)
55 | 
56 |         # find the value of the minimum cost in cluster_idx
57 |         min_cost = all_costs[min_cost_idx]
58 | 
59 |         # If the minimum cost is smaller than that
60 |         # exhibited by the currently used medoid,
61 |         # we switch to using the new medoid in cluster_idx
62 |         if min_cost < curr_cost:
63 |             # Find data points that belong to cluster_idx,
64 |             # and assign the newly found medoid as the medoid
65 |             # for cluster c
66 |             medoid_idxs[cluster_idx] = np.where(
67 |                 cluster_idxs == cluster_idx)[0][min_cost_idx]
68 | 
69 |     return medoid_idxs
--------------------------------------------------------------------------------
/detect_image.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run the trained face detector on a single image.
3 | Adjust --prob_thresh and --nms_thresh to trade precision against recall.
4 | """ 5 | 6 | import argparse 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image, ImageDraw 12 | from torchvision import transforms 13 | 14 | from tinyfaces.evaluation import get_detections, get_model 15 | 16 | 17 | def arguments(): 18 | parser = argparse.ArgumentParser("Image Evaluator") 19 | parser.add_argument("image_path") 20 | parser.add_argument("--checkpoint", 21 | help="The path to the model checkpoint", 22 | default="") 23 | parser.add_argument("--prob_thresh", type=float, default=0.6) 24 | parser.add_argument("--nms_thresh", type=float, default=0.3) 25 | 26 | return parser.parse_args() 27 | 28 | 29 | def run(model, image, templates, prob_thresh, nms_thresh, device): 30 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225]) 32 | img_transforms = transforms.Compose([transforms.ToTensor(), normalize]) 33 | 34 | # Convert to tensor 35 | img = transforms.functional.to_tensor(image) 36 | 37 | rf = {'size': [859, 859], 'stride': [8, 8], 'offset': [-1, -1]} 38 | 39 | dets = get_detections(model, 40 | img, 41 | templates, 42 | rf, 43 | img_transforms, 44 | prob_thresh, 45 | nms_thresh, 46 | scales=(0, ), 47 | device=device) 48 | 49 | return dets 50 | 51 | 52 | def main(): 53 | args = arguments() 54 | 55 | if torch.cuda.is_available(): 56 | device = torch.device('cuda:0') 57 | else: 58 | device = torch.device('cpu') 59 | 60 | templates = json.load(open('tinyfaces/datasets/templates.json')) 61 | templates = np.round(np.array(templates), decimals=8) 62 | 63 | num_templates = templates.shape[0] 64 | 65 | model = get_model(args.checkpoint, num_templates=num_templates) 66 | print("Loaded model", args.checkpoint) 67 | 68 | image = Image.open(args.image_path).convert('RGB') 69 | 70 | with torch.no_grad(): 71 | # run model on image 72 | dets = run(model, image, templates, args.prob_thresh, args.nms_thresh, 73 | device) 74 | 75 | draw = ImageDraw.Draw(image) 76 | for det in dets: 77 | draw.rectangle(((det[0], det[1]), (det[2], det[3])), width=4) 78 | 79 | image.show() 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /tinyfaces/datasets/dense_overlap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_dense_overlap(ofx, 5 | ofy, 6 | stx, 7 | sty, 8 | vsx, 9 | vsy, 10 | dx1, 11 | dy1, 12 | dx2, 13 | dy2, 14 | gx1, 15 | gy1, 16 | gx2, 17 | gy2, 18 | zmx=1, 19 | zmy=1): 20 | """ 21 | Compute the dense IoU 22 | """ 23 | num_templates = dx1.shape[0] 24 | num_gt = gx1.shape[0] 25 | 26 | ty, tx = (vsy - 1) * zmy + 1, ( 27 | vsx - 1) * zmx + 1 # + 1 is by definition of receptive field 28 | overlap = np.zeros((ty, tx, num_templates, num_gt)) 29 | 30 | for i in range(num_gt): 31 | bbox_x1, bbox_y1, bbox_x2, bbox_y2 = gx1[i], gy1[i], gx2[i], gy2[i] 32 | bbox_w, bbox_h = bbox_x2 - bbox_x1 + 1, bbox_y2 - bbox_y1 + 1 33 | bbox_area = bbox_w * bbox_h 34 | 35 | for j in range(num_templates): 36 | delta_x1, delta_y1, delta_x2, delta_y2 = dx1[j], dy1[j], dx2[ 37 | j], dy2[j] 38 | filter_h = delta_y2 - delta_y1 + 1 39 | filter_w = delta_x2 - delta_x1 + 1 40 | 41 | filter_area = filter_w * filter_h 42 | 43 | xmax = tx 44 | ymax = ty 45 | 46 | # enumerate spatial locations 47 | for x in range(xmax): 48 | for y in range(ymax): 49 | cx = ofx + x * (stx / zmx) 50 | cy = ofy + y * (sty / zmy) 51 | 52 | x1 = delta_x1 + cx 53 | y1 = delta_y1 + cy 54 | x2 = delta_x2 + cx 55 | y2 = 
delta_y2 + cy 56 | 57 | xx1 = max(x1, bbox_x1) 58 | yy1 = max(y1, bbox_y1) 59 | xx2 = min(x2, bbox_x2) 60 | yy2 = min(y2, bbox_y2) 61 | 62 | int_w = xx2 - xx1 + 1 63 | int_h = yy2 - yy1 + 1 64 | 65 | if int_h > 0 and int_w > 0: 66 | int_area = int_w * int_h 67 | union_area = filter_area + bbox_area - int_area 68 | 69 | overlap[y, x, j, i] = int_area / union_area 70 | 71 | else: 72 | overlap[y, x, j, i] = 0 73 | 74 | # truncate the number of decimals to match MATLAB behavior 75 | return np.around(overlap, decimals=14) 76 | -------------------------------------------------------------------------------- /tinyfaces/trainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from torch.nn import functional as nnfunc 6 | from torchvision import transforms 7 | 8 | 9 | def print_state(idx, epoch, size, loss_cls, loss_reg): 10 | if epoch >= 0: 11 | message = "Epoch: [{0}][{1}/{2}]\t".format(epoch, idx, size) 12 | else: 13 | message = "Val: [{0}/{1}]\t".format(idx, size) 14 | 15 | print(message + 16 | '\tloss_cls: {loss_cls:.6f}' \ 17 | '\tloss_reg: {loss_reg:.6f}'.format(loss_cls=loss_cls, loss_reg=loss_reg)) 18 | 19 | 20 | def save_checkpoint(state, filename="checkpoint.pth", save_path="weights"): 21 | # check if the save directory exists 22 | if not Path(save_path).exists(): 23 | Path(save_path).mkdir() 24 | 25 | save_path = Path(save_path, filename) 26 | torch.save(state, str(save_path)) 27 | 28 | 29 | def visualize_output(img, 30 | output, 31 | templates, 32 | proc, 33 | prob_thresh=0.55, 34 | nms_thresh=0.1): 35 | tensor_to_image = transforms.ToPILImage() 36 | 37 | mean = [0.485, 0.456, 0.406] 38 | std = [0.229, 0.224, 0.225] 39 | for t, m, s in zip(img[0], mean, std): 40 | t.mul_(s).add_(m) 41 | 42 | image = tensor_to_image(img[0]) # Index into the batch 43 | 44 | cls_map = nnfunc.sigmoid( 45 | output[:, 0:templates.shape[0], :, :]).data.cpu().numpy().transpose( 46 | (0, 2, 3, 1))[0, :, :, :] 47 | reg_map = output[:, 48 | templates.shape[0]:, :, :].data.cpu().numpy().transpose( 49 | (0, 2, 3, 1))[0, :, :, :] 50 | 51 | print(np.sort(np.unique(cls_map))[::-1]) 52 | proc.visualize_heatmaps(image, 53 | cls_map, 54 | reg_map, 55 | templates, 56 | prob_thresh=prob_thresh, 57 | nms_thresh=nms_thresh) 58 | 59 | p = input("Continue? [Yn]") 60 | if p.lower().strip() == 'n': 61 | exit(0) 62 | 63 | 64 | def draw_bboxes(image, img_id, bboxes, scores, scales, processor): 65 | processor.render_and_save_bboxes(image, img_id, bboxes, scores, scales) 66 | 67 | 68 | def train(model, loss_fn, optimizer, dataloader, epoch, device): 69 | model = model.to(device) 70 | model.train() 71 | 72 | for idx, (img, class_map, regression_map) in enumerate(dataloader): 73 | x = img.float().to(device) 74 | 75 | class_map_var = class_map.float().to(device) 76 | regression_map_var = regression_map.float().to(device) 77 | 78 | output = model(x) 79 | loss = loss_fn(output, class_map_var, regression_map_var) 80 | 81 | # visualize_output(img, output, dataloader.dataset.templates) 82 | 83 | optimizer.zero_grad() 84 | # Get the gradients 85 | # torch will automatically mask the gradients to 0 where applicable! 
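        # (DetectionCriterion multiplies the raw class/regression losses by 0/1
        # masks, so masked-out locations contribute zero gradient in backward)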
86 | loss.backward() 87 | optimizer.step() 88 | 89 | print_state(idx, epoch, len(dataloader), loss_fn.class_average.average, 90 | loss_fn.reg_average.average) 91 | -------------------------------------------------------------------------------- /evaluate_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to evaluate model. 3 | Look at Makefile to see `evaluate` command. 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from torchvision import transforms 10 | from tqdm import tqdm 11 | 12 | from tinyfaces.datasets import get_dataloader 13 | from tinyfaces.evaluation import get_detections, get_model, write_results 14 | 15 | 16 | def arguments(): 17 | parser = argparse.ArgumentParser("Model Evaluator") 18 | parser.add_argument("dataset") 19 | parser.add_argument("--split", default="val") 20 | parser.add_argument("--dataset-root") 21 | parser.add_argument("--checkpoint", 22 | help="The path to the model checkpoint", 23 | default="") 24 | parser.add_argument("--prob_thresh", type=float, default=0.03) 25 | parser.add_argument("--nms_thresh", type=float, default=0.3) 26 | parser.add_argument("--workers", default=8, type=int) 27 | parser.add_argument("--batch_size", default=1, type=int) 28 | parser.add_argument("--results_dir", default=None) 29 | parser.add_argument("--debug", action="store_true") 30 | 31 | return parser.parse_args() 32 | 33 | 34 | def dataloader(args): 35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | val_transforms = transforms.Compose([transforms.ToTensor(), normalize]) 38 | 39 | val_loader, templates = get_dataloader(args.dataset, 40 | args, 41 | train=False, 42 | split=args.split, 43 | img_transforms=val_transforms) 44 | return val_loader, templates 45 | 46 | 47 | def run(model, 48 | val_loader, 49 | templates, 50 | prob_thresh, 51 | nms_thresh, 52 | device, 53 | split, 54 | results_dir=None, 55 | debug=False): 56 | for _, (img, filename) in tqdm(enumerate(val_loader), 57 | total=len(val_loader)): 58 | dets = get_detections(model, 59 | img[0], 60 | templates, 61 | val_loader.dataset.rf, 62 | val_loader.dataset.transforms, 63 | prob_thresh, 64 | nms_thresh, 65 | device=device) 66 | 67 | write_results(dets, filename[0], split, results_dir) 68 | return dets 69 | 70 | 71 | def main(): 72 | args = arguments() 73 | 74 | if torch.cuda.is_available(): 75 | device = torch.device('cuda:0') 76 | else: 77 | device = torch.device('cpu') 78 | 79 | val_loader, templates = dataloader(args) 80 | num_templates = templates.shape[0] 81 | 82 | model = get_model(args.checkpoint, num_templates=num_templates) 83 | 84 | with torch.no_grad(): 85 | # run model on val/test set and generate results files 86 | run(model, 87 | val_loader, 88 | templates, 89 | args.prob_thresh, 90 | args.nms_thresh, 91 | device, 92 | args.split, 93 | results_dir=args.results_dir, 94 | debug=args.debug) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /tinyfaces/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from tinyfaces.models.utils import balance_sampling 5 | 6 | 7 | class AvgMeter: 8 | 9 | def __init__(self): 10 | self.average = 0 11 | self.num_averaged = 0 12 | 13 | def update(self, loss, size): 14 | n = self.num_averaged 15 | m = n + size 16 | self.average = ((n * self.average) + float(loss)) / m 17 | 
self.num_averaged = m 18 | 19 | def reset(self): 20 | self.average = 0 21 | self.num_averaged = 0 22 | 23 | 24 | class DetectionCriterion(nn.Module): 25 | """ 26 | The loss for the Tiny Faces detector 27 | """ 28 | 29 | def __init__(self, n_templates=25, reg_weight=1, pos_fraction=0.5): 30 | super().__init__() 31 | 32 | # We don't want per element averaging. 33 | # We want to normalize over the batch or positive samples. 34 | self.regression_criterion = nn.SmoothL1Loss(reduction='none') 35 | self.classification_criterion = nn.SoftMarginLoss(reduction='none') 36 | self.n_templates = n_templates 37 | self.reg_weight = reg_weight 38 | self.pos_fraction = pos_fraction 39 | 40 | self.class_average = AvgMeter() 41 | self.reg_average = AvgMeter() 42 | 43 | self.masked_class_loss = None 44 | self.masked_reg_loss = None 45 | self.total_loss = None 46 | 47 | def balance_sample(self, class_map): 48 | device = class_map.device 49 | label_class_np = class_map.cpu().numpy() 50 | # iterate through batch 51 | for idx in range(label_class_np.shape[0]): 52 | label_class_np[idx, ...] = balance_sampling( 53 | label_class_np[idx, ...], pos_fraction=self.pos_fraction) 54 | 55 | class_map = torch.from_numpy(label_class_np).to(device) 56 | 57 | return class_map 58 | 59 | def hard_negative_mining(self, classification, class_map): 60 | loss_class_map = nn.functional.soft_margin_loss( 61 | classification.detach(), class_map, reduction='none') 62 | class_map[loss_class_map < 0.03] = 0 63 | return class_map 64 | 65 | def forward(self, output, class_map, regression_map): 66 | classification = output[:, 0:self.n_templates, :, :] 67 | regression = output[:, self.n_templates:, :, :] 68 | 69 | # online hard negative mining 70 | class_map = self.hard_negative_mining(classification, class_map) 71 | # balance sampling 72 | class_map = self.balance_sample(class_map) 73 | 74 | class_loss = self.classification_criterion(classification, class_map) 75 | 76 | # weights used to mask out invalid regions i.e. 
where the label is 0 77 | class_mask = (class_map != 0).type(output.dtype) 78 | # Mask the classification loss 79 | self.masked_class_loss = class_mask * class_loss 80 | 81 | reg_loss = self.regression_criterion(regression, regression_map) 82 | # make same size as reg_map 83 | reg_mask = (class_map > 0).repeat(1, 4, 1, 1).type(output.dtype) 84 | 85 | self.masked_reg_loss = reg_mask * reg_loss # / reg_loss.size(0) 86 | 87 | self.total_loss = self.masked_class_loss.sum() + \ 88 | self.reg_weight * self.masked_reg_loss.sum() 89 | 90 | self.class_average.update(self.masked_class_loss.sum(), output.size(0)) 91 | self.reg_average.update(self.masked_reg_loss.sum(), output.size(0)) 92 | 93 | return self.total_loss 94 | 95 | def reset(self): 96 | self.class_average.reset() 97 | self.reg_average.reset() 98 | -------------------------------------------------------------------------------- /tinyfaces/evaluation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from torchvision import transforms 7 | 8 | from tinyfaces.models.model import DetectionModel 9 | from tinyfaces.models.utils import get_bboxes 10 | 11 | 12 | def get_model(checkpoint=None, num_templates=25): 13 | model = DetectionModel(num_templates=num_templates) 14 | if checkpoint: 15 | checkpoint = torch.load(checkpoint) 16 | model.load_state_dict(checkpoint["model"]) 17 | return model 18 | 19 | 20 | def get_detections(model, 21 | img, 22 | templates, 23 | rf, 24 | img_transforms, 25 | prob_thresh=0.65, 26 | nms_thresh=0.3, 27 | scales=(-2, -1, 0, 1), 28 | device=None): 29 | model = model.to(device) 30 | model.eval() 31 | 32 | dets = np.empty((0, 5)) # store bbox (x1, y1, x2, y2), score 33 | 34 | num_templates = templates.shape[0] 35 | 36 | # Evaluate over multiple scale 37 | scales_list = [2**x for x in scales] 38 | 39 | # convert tensor to PIL image so we can perform resizing 40 | image = transforms.functional.to_pil_image(img) 41 | 42 | min_side = np.min(image.size) 43 | 44 | for scale in scales_list: 45 | # scale the images 46 | scaled_image = transforms.functional.resize(image, 47 | int(min_side * scale)) 48 | 49 | # normalize the images 50 | img = img_transforms(scaled_image) 51 | 52 | # add batch dimension 53 | img.unsqueeze_(0) 54 | 55 | # now run the model 56 | x = img.float().to(device) 57 | 58 | output = model(x) 59 | 60 | # first `num_templates` channels are class maps 61 | score_cls = output[:, :num_templates, :, :] 62 | prob_cls = torch.sigmoid(score_cls) 63 | 64 | score_cls = score_cls.data.cpu().numpy().transpose((0, 2, 3, 1)) 65 | prob_cls = prob_cls.data.cpu().numpy().transpose((0, 2, 3, 1)) 66 | 67 | score_reg = output[:, num_templates:, :, :] 68 | score_reg = score_reg.data.cpu().numpy().transpose((0, 2, 3, 1)) 69 | 70 | t_bboxes, scores = get_bboxes(score_cls, score_reg, prob_cls, 71 | templates, prob_thresh, rf, scale) 72 | 73 | scales = np.ones((t_bboxes.shape[0], 1)) / scale 74 | 75 | # append scores at the end for NMS 76 | d = np.hstack((t_bboxes, scores)) 77 | 78 | dets = np.vstack((dets, d)) 79 | 80 | scores = torch.from_numpy(dets[:, 4]) 81 | dets = torch.from_numpy(dets[:, :4]) 82 | 83 | # Apply NMS 84 | keep = torchvision.ops.nms(dets, scores, nms_thresh) 85 | dets = dets[keep] 86 | 87 | return dets.numpy() 88 | 89 | 90 | def write_results(dets, img_path, split, results_dir=None): 91 | results_dir = results_dir or f"{split}_results" 92 | results_dir = Path(results_dir) 93 | 94 | 
if not results_dir.exists():
95 |         results_dir.mkdir(parents=True)
96 | 
97 |     filename = results_dir / img_path.replace('jpg', 'txt')
98 | 
99 |     file_dir = filename.parent
100 |     if not file_dir.exists():
101 |         file_dir.mkdir(parents=True)
102 | 
103 |     with open(filename, 'w') as f:
104 |         f.write(img_path.split('/')[-1] + "\n")
105 |         f.write(str(dets.shape[0]) + "\n")
106 | 
107 |         for x in dets:
108 |             left, top = np.round(x[0]), np.round(x[1])
109 |             width = np.round(x[2] - x[0] + 1)
110 |             height = np.round(x[3] - x[1] + 1)
111 |             score = x[4]
112 |             d = f"{int(left)} {int(top)} {int(width)} {int(height)} {score}\n"
113 | 
114 |             f.write(d)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Main model training script.
3 | See the Makefile `main` target for usage.
4 | """
5 | import argparse
6 | from pathlib import Path
7 | 
8 | import torch
9 | from torch import optim
10 | from torchvision import transforms
11 | 
12 | from tinyfaces import trainer
13 | from tinyfaces.datasets import get_dataloader
14 | from tinyfaces.models.loss import DetectionCriterion
15 | from tinyfaces.models.model import DetectionModel
16 | 
17 | 
18 | def arguments():
19 |     parser = argparse.ArgumentParser()
20 | 
21 |     parser.add_argument("traindata")
22 |     parser.add_argument("valdata")
23 |     parser.add_argument("--dataset-root", default="")
24 |     parser.add_argument("--dataset", default="WIDERFace")
25 |     parser.add_argument("--lr", default=1e-4, type=float)
26 |     parser.add_argument("--weight-decay", default=0.0005, type=float)
27 |     parser.add_argument("--momentum", default=0.9, type=float)
28 |     parser.add_argument("--batch_size", default=12, type=int)
29 |     parser.add_argument("--workers", default=8, type=int)
30 |     parser.add_argument("--start-epoch", default=0, type=int)
31 |     parser.add_argument("--epochs", default=50, type=int)
32 |     parser.add_argument("--save-every", default=10, type=int)
33 |     parser.add_argument("--resume", default="", help="Path to a checkpoint to resume from")
34 |     parser.add_argument("--debug", action="store_true")
35 | 
36 |     return parser.parse_args()
37 | 
38 | 
39 | def main():
40 |     args = arguments()
41 | 
42 |     num_templates = 25  # aka the number of clusters
43 | 
44 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
45 |                                      std=[0.229, 0.224, 0.225])
46 |     img_transforms = transforms.Compose([transforms.ToTensor(), normalize])
47 |     train_loader, _ = get_dataloader(args.traindata,
48 |                                      args,
49 |                                      num_templates,
50 |                                      img_transforms=img_transforms)
51 | 
52 |     model = DetectionModel(num_objects=1, num_templates=num_templates)
53 |     loss_fn = DetectionCriterion(num_templates)
54 | 
55 |     # directory where we'll store model weights
56 |     weights_dir = Path("weights")
57 |     if not weights_dir.exists():
58 |         weights_dir.mkdir()
59 | 
60 |     # check for CUDA
61 |     if torch.cuda.is_available():
62 |         device = torch.device('cuda:0')
63 |     else:
64 |         device = torch.device('cpu')
65 | 
66 |     # As per Peiyun, SGD is more robust than Adam and works really well
67 |     optimizer = optim.SGD(model.learnable_parameters(args.lr),
68 |                           lr=args.lr,
69 |                           momentum=args.momentum,
70 |                           weight_decay=args.weight_decay)
71 |     # optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
72 | 
73 |     if args.resume:
74 |         checkpoint = torch.load(args.resume)
75 |         model.load_state_dict(checkpoint['model'])
76 |         optimizer.load_state_dict(checkpoint['optimizer'])
77 |         # Set the start epoch if it has not been
78 |         if not args.start_epoch:
79 |             args.start_epoch = checkpoint['epoch']
80 | 
81 |         # StepLR with last_epoch > -1 expects an 'initial_lr' key in each
82 |         # param group, so backfill it before constructing the scheduler
83 |         for group in optimizer.param_groups:
84 |             group.setdefault('initial_lr', group['lr'])
85 | 
86 |     scheduler = optim.lr_scheduler.StepLR(optimizer,
87 |                                           step_size=20,
88 |                                           last_epoch=args.start_epoch - 1)
89 | 
90 |     # train and evaluate for `epochs`
91 |     for epoch in range(args.start_epoch, args.epochs):
92 |         trainer.train(model,
93 |                       loss_fn,
94 |                       optimizer,
95 |                       train_loader,
96 |                       epoch,
97 |                       device=device)
98 |         scheduler.step()
99 | 
100 |         if (epoch + 1) % args.save_every == 0:
101 |             trainer.save_checkpoint(
102 |                 {
103 |                     'epoch': epoch + 1,
104 |                     'batch_size': train_loader.batch_size,
105 |                     'model': model.state_dict(),
106 |                     'optimizer': optimizer.state_dict()
107 |                 },
108 |                 filename="checkpoint_{0}.pth".format(epoch + 1),
109 |                 save_path=weights_dir)
110 | 
111 | 
112 | if __name__ == '__main__':
113 |     main()
114 | 
--------------------------------------------------------------------------------
/tinyfaces/models/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torchvision.models import ResNet101_Weights, resnet101
5 | 
6 | 
7 | class DetectionModel(nn.Module):
8 |     """
9 |     Hybrid Model from Tiny Faces paper
10 |     """
11 | 
12 |     def __init__(self,
13 |                  base_model=resnet101,
14 |                  pretrained_weights=ResNet101_Weights.IMAGENET1K_V1,
15 |                  num_templates=1,
16 |                  num_objects=1):
17 |         super().__init__()
18 |         # 4 is for the bounding box offsets
19 |         output = (num_objects + 4) * num_templates
20 |         self.model = base_model(weights=pretrained_weights)
21 | 
22 |         # delete unneeded layer
23 |         del self.model.layer4
24 | 
25 |         self.score_res3 = nn.Conv2d(in_channels=512,
26 |                                     out_channels=output,
27 |                                     kernel_size=1,
28 |                                     padding=0)
29 |         self.score_res4 = nn.Conv2d(in_channels=1024,
30 |                                     out_channels=output,
31 |                                     kernel_size=1,
32 |                                     padding=0)
33 | 
34 |         self.score4_upsample = nn.ConvTranspose2d(in_channels=output,
35 |                                                   out_channels=output,
36 |                                                   kernel_size=4,
37 |                                                   stride=2,
38 |                                                   padding=1,
39 |                                                   bias=False)
40 |         self._init_bilinear()
41 | 
42 |     def _init_weights(self):
43 |         pass
44 | 
45 |     def _init_bilinear(self):
46 |         """
47 |         Initialize the ConvTranspose2d layer with a bilinear interpolation mapping
48 |         :return:
49 |         """
50 |         k = self.score4_upsample.kernel_size[0]
51 |         factor = np.floor((k + 1) / 2)
52 |         if k % 2 == 1:
53 |             center = factor
54 |         else:
55 |             center = factor + 0.5
56 |         C = np.arange(1, 5)  # tap positions 1..k; assumes the kernel_size of 4 set above
57 | 
58 |         f = np.zeros((self.score4_upsample.in_channels,
59 |                       self.score4_upsample.out_channels, k, k))
60 | 
61 |         for i in range(self.score4_upsample.out_channels):
62 |             f[i, i, :, :] = (np.ones((1, k)) - (np.abs(C-center)/factor)).T @ \
63 |                             (np.ones((1, k)) - (np.abs(C-center)/factor))
64 | 
65 |         self.score4_upsample.weight = torch.nn.Parameter(data=torch.Tensor(f))
66 | 
67 |     def learnable_parameters(self, lr):
68 |         parameters = [
69 |             # Be T'Challa. Don't freeze.
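            # i.e. fine-tune the whole ResNet backbone at the base learning rate;
            # the score heads below get scaled rates and the upsampler is frozen (lr=0)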
70 |             {
71 |                 'params': self.model.parameters(),
72 |                 'lr': lr
73 |             },
74 |             {
75 |                 'params': self.score_res3.parameters(),
76 |                 'lr': 0.1 * lr
77 |             },
78 |             {
79 |                 'params': self.score_res4.parameters(),
80 |                 'lr': 1 * lr
81 |             },
82 |             {
83 |                 'params': self.score4_upsample.parameters(),
84 |                 'lr': 0
85 |             }  # freeze UpConv layer
86 |         ]
87 |         return parameters
88 | 
89 |     def forward(self, x):
90 |         x = self.model.conv1(x)
91 |         x = self.model.bn1(x)
92 |         x = self.model.relu(x)
93 |         x = self.model.maxpool(x)
94 | 
95 |         x = self.model.layer1(x)
96 |         # res2 = x
97 | 
98 |         x = self.model.layer2(x)
99 |         res3 = x
100 | 
101 |         x = self.model.layer3(x)
102 |         res4 = x
103 | 
104 |         score_res3 = self.score_res3(res3)
105 | 
106 |         score_res4 = self.score_res4(res4)
107 |         score4 = self.score4_upsample(score_res4)
108 | 
109 |         # We need to do some fancy cropping to accommodate the difference in image sizes in eval
110 |         if not self.training:
111 |             # from vl_feats DagNN Crop
112 |             cropv = score4.size(2) - score_res3.size(2)
113 |             cropu = score4.size(3) - score_res3.size(3)
114 |             # if the crop is 0 (both the input sizes are the same)
115 |             # we do some arithmetic to allow python to index correctly
116 |             if cropv == 0:
117 |                 cropv = -score4.size(2)
118 |             if cropu == 0:
119 |                 cropu = -score4.size(3)
120 | 
121 |             score4 = score4[:, :, 0:-cropv, 0:-cropu]
122 |         else:
123 |             # match the dimensions arbitrarily
124 |             score4 = score4[:, :, 0:score_res3.size(2), 0:score_res3.size(3)]
125 | 
126 |         score = score_res3 + score4
127 | 
128 |         return score
--------------------------------------------------------------------------------
/tinyfaces/clustering/cluster.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pathlib import Path
3 | 
4 | import joblib
5 | import numpy as np
6 | from pyclust import KMedoids
7 | from pyclustering.cluster.kmedoids import kmedoids
8 | from tqdm import tqdm
9 | 
10 | from tinyfaces.clustering.k_medoids import kMedoids
11 | from tinyfaces.metrics import jaccard_index, rect_dist
12 | 
13 | 
14 | def centralize_bbox(bboxes):
15 |     """
16 |     Convert bounding boxes from (x1, y1, x2, y2) corners to the centered form
17 |     (-w/2, -h/2, w/2, h/2), so clustering depends only on box width and height.
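    For example, a box (x1, y1, x2, y2) = (10, 20, 19, 29) has w = h = 10 and
    maps to (-4.5, -4.5, 4.5, 4.5).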
18 | """ 19 | print("Centralize and vectorize") 20 | hs = bboxes[:, 3] - bboxes[:, 1] + 1 21 | ws = bboxes[:, 2] - bboxes[:, 0] + 1 22 | rects = np.vstack( 23 | [-(ws - 1) / 2, -(hs - 1) / 2, (ws - 1) / 2, (hs - 1) / 2]).T 24 | 25 | return rects 26 | 27 | 28 | def compute_distances(bboxes): 29 | print("Computing distances") 30 | distances = np.zeros((len(bboxes), len(bboxes))) 31 | for i in tqdm(range(len(bboxes)), total=len(bboxes)): 32 | for j in range(len(bboxes)): 33 | distances[i, j] = 1 - jaccard_index(bboxes[i, :], bboxes[j, :], 34 | (i, j)) 35 | 36 | return distances 37 | 38 | 39 | def compute_kmedoids(bboxes, 40 | cls, 41 | option='pyclustering', 42 | indices=15, 43 | max_clusters=35, 44 | max_limit=5000): 45 | print("Performing clustering using", option) 46 | clustering = [{} for _ in range(indices)] 47 | 48 | bboxes = centralize_bbox(bboxes) 49 | 50 | # subsample the number of bounding boxes so that it can fit in memory and is faster 51 | if bboxes.shape[0] > max_limit: 52 | sub_ind = np.random.choice(np.arange(bboxes.shape[0]), 53 | size=max_limit, 54 | replace=False) 55 | bboxes = bboxes[sub_ind] 56 | 57 | distances_cache = Path('distances_{0}.jbl'.format(cls)) 58 | if distances_cache.exists(): 59 | print("Loading distances") 60 | dist = joblib.load(distances_cache) 61 | else: 62 | dist = compute_distances(bboxes) 63 | joblib.dump(dist, distances_cache, compress=5) 64 | 65 | if option == 'pyclustering': 66 | for k in range(indices, max_clusters + 1): 67 | print(k, "clusters") 68 | 69 | initial_medoids = np.random.choice(bboxes.shape[0], 70 | size=k, 71 | replace=False) 72 | 73 | kmedoids_instance = kmedoids(dist, 74 | initial_medoids, 75 | ccore=True, 76 | data_type='distance_matrix') 77 | 78 | print("Running KMedoids") 79 | t1 = datetime.now() 80 | kmedoids_instance.process() 81 | dt = datetime.now() - t1 82 | print("Total time taken for clustering {k} medoids: {0}min:{1}s". 83 | format(dt.seconds // 60, dt.seconds % 60, k=k)) 84 | 85 | medoids_idx = kmedoids_instance.get_medoids() 86 | medoids = bboxes[medoids_idx] 87 | 88 | clustering.append({ 89 | 'n_clusters': k, 90 | 'medoids': medoids, 91 | 'class': cls 92 | }) 93 | 94 | elif option == 'pyclust': 95 | 96 | for k in range(indices, max_clusters + 1): 97 | print(k, "clusters") 98 | kmd = KMedoids(n_clusters=k, 99 | distance=rect_dist, 100 | n_trials=1, 101 | max_iter=2) 102 | t1 = datetime.now() 103 | kmd.fit(bboxes) 104 | dt = datetime.now() - t1 105 | print("Total time taken for clustering {k} medoids: {0}min:{1}s". 
106 | format(dt.seconds // 60, dt.seconds % 60, k=k)) 107 | 108 | medoids = kmd.centers_ 109 | 110 | clustering.append({ 111 | 'n_clusters': k, 112 | 'medoids': medoids, 113 | 'class': cls 114 | }) 115 | 116 | elif option == 'local': 117 | 118 | for k in range(indices, max_clusters + 1): 119 | print(k, "clusters") 120 | curr_medoids, cluster_idxs = kMedoids(dist, k=k) 121 | medoids = [] 122 | for m in curr_medoids: 123 | medoids.append(bboxes[m, :]) 124 | clustering.append({ 125 | 'n_clusters': k, 126 | 'medoids': medoids, 127 | 'class': cls 128 | }) 129 | 130 | return clustering 131 | -------------------------------------------------------------------------------- /tinyfaces/models/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_bboxes(score_cls, 5 | score_reg, 6 | prob_cls, 7 | templates, 8 | prob_thresh, 9 | rf, 10 | scale=1, 11 | refine=True): 12 | """ 13 | Convert model output tensor to a set of bounding boxes and their corresponding scores 14 | """ 15 | num_templates = templates.shape[0] 16 | 17 | # template to evaluate at every scale (Type A templates) 18 | all_scale_template_ids = np.arange(4, 12) 19 | 20 | # templates to evaluate at a single scale aka small scale (Type B templates) 21 | one_scale_template_ids = np.arange(18, 25) 22 | 23 | ignored_template_ids = np.setdiff1d( 24 | np.arange(25), 25 | np.concatenate((all_scale_template_ids, one_scale_template_ids))) 26 | 27 | template_scales = templates[:, 4] 28 | 29 | # if we down-sample, then we only need large templates 30 | if scale < 1: 31 | invalid_one_scale_idx = np.where( 32 | template_scales[one_scale_template_ids] >= 1.0) 33 | elif scale == 1: 34 | invalid_one_scale_idx = np.where( 35 | template_scales[one_scale_template_ids] != 1.0) 36 | elif scale > 1: 37 | invalid_one_scale_idx = np.where( 38 | template_scales[one_scale_template_ids] != 1.0) 39 | 40 | invalid_template_id = np.concatenate( 41 | (ignored_template_ids, one_scale_template_ids[invalid_one_scale_idx])) 42 | 43 | # zero out prediction from templates that are invalid on this scale 44 | prob_cls[:, :, invalid_template_id] = 0.0 45 | 46 | indices = np.where(prob_cls > prob_thresh) 47 | fb, fy, fx, fc = indices 48 | 49 | scores = score_cls[fb, fy, fx, fc] 50 | scores = scores.reshape((scores.shape[0], 1)) 51 | 52 | stride, offset = rf['stride'], rf['offset'] 53 | cy, cx = fy * stride[0] + offset[0], fx * stride[1] + offset[1] 54 | cw = templates[fc, 2] - templates[fc, 0] + 1 55 | ch = templates[fc, 3] - templates[fc, 1] + 1 56 | 57 | # bounding box refinements 58 | tx = score_reg[:, :, :, 0:num_templates] 59 | ty = score_reg[:, :, :, 1 * num_templates:2 * num_templates] 60 | tw = score_reg[:, :, :, 2 * num_templates:3 * num_templates] 61 | th = score_reg[:, :, :, 3 * num_templates:4 * num_templates] 62 | 63 | if refine: 64 | bboxes = regression_refinement(tx, ty, tw, th, cx, cy, cw, ch, indices) 65 | 66 | else: 67 | bboxes = np.array([cx - cw / 2, cy - ch / 2, cx + cw / 2, cy + ch / 2]) 68 | 69 | # bboxes has a channel dim so we remove that 70 | bboxes = bboxes[0] 71 | 72 | # scale the bboxes 73 | factor = 1 / scale 74 | bboxes = bboxes * factor 75 | 76 | return bboxes, scores 77 | 78 | 79 | def regression_refinement(tx, ty, tw, th, cx, cy, cw, ch, indices): 80 | # refine the bounding boxes 81 | dcx = cw * tx[indices] 82 | dcy = ch * ty[indices] 83 | 84 | rcx = cx + dcx 85 | rcy = cy + dcy 86 | 87 | rcw = cw * np.exp(tw[indices]) 88 | rch = ch * np.exp(th[indices]) 89 | 90 | # 
create bbox array 91 | rcx = rcx.reshape((rcx.shape[0], 1)) 92 | rcy = rcy.reshape((rcy.shape[0], 1)) 93 | rcw = rcw.reshape((rcw.shape[0], 1)) 94 | rch = rch.reshape((rch.shape[0], 1)) 95 | 96 | # transpose so that it is (N, 4) 97 | bboxes = np.array( 98 | [rcx - rcw / 2, rcy - rch / 2, rcx + rcw / 2, rcy + rch / 2]).T 99 | 100 | return bboxes 101 | 102 | 103 | def balance_sampling(label_cls, pos_fraction, sample_size=256): 104 | """ 105 | Perform balance sampling by always sampling `pos_fraction` positive samples and 106 | `(1-pos_fraction)` negative samples from the input 107 | :param label_cls: Class labels as numpy.array. 108 | :param pos_fraction: The maximum fraction of positive samples to keep. 109 | :return: 110 | """ 111 | pos_maxnum = sample_size * pos_fraction # sample 128 positive points 112 | 113 | # Find all the points where we have objects and ravel the indices to get a 1D array. 114 | # This makes the subsequent operations easier to reason about 115 | pos_idx_unraveled = np.where(label_cls == 1) 116 | pos_idx = np.array(np.ravel_multi_index(pos_idx_unraveled, 117 | label_cls.shape)) 118 | 119 | if pos_idx.size > pos_maxnum: 120 | # Get all the indices of the locations to be zeroed out 121 | didx = shuffle_index(pos_idx.size, pos_idx.size - pos_maxnum) 122 | # Get the locations and unravel it so we can index 123 | pos_idx_unraveled = np.unravel_index(pos_idx[didx], label_cls.shape) 124 | label_cls[pos_idx_unraveled] = 0 125 | 126 | neg_maxnum = pos_maxnum * (1 - pos_fraction) / pos_fraction 127 | neg_idx_unraveled = np.where(label_cls == -1) 128 | neg_idx = np.array(np.ravel_multi_index(neg_idx_unraveled, 129 | label_cls.shape)) 130 | 131 | if neg_idx.size > neg_maxnum: 132 | # Get all the indices of the locations to be zeroed out 133 | ridx = shuffle_index(neg_idx.size, neg_maxnum) 134 | didx = np.arange(0, neg_idx.size) 135 | didx = np.delete(didx, ridx) 136 | neg_idx = np.unravel_index(neg_idx[didx], label_cls.shape) 137 | label_cls[neg_idx] = 0 138 | 139 | return label_cls 140 | 141 | 142 | def shuffle_index(n, n_out): 143 | """ 144 | Randomly shuffle the indices and return a subset of them 145 | :param n: The number of indices to shuffle. 146 | :param n_out: The number of output indices. 147 | :return: 148 | """ 149 | n = int(n) 150 | n_out = int(n_out) 151 | 152 | if n == 0 or n_out == 0: 153 | return np.empty(0) 154 | 155 | x = np.random.permutation(n) 156 | 157 | # the output should be at most the size of the input 158 | assert n_out <= n 159 | 160 | if n_out != n: 161 | x = x[:n_out] 162 | 163 | return x 164 | -------------------------------------------------------------------------------- /tinyfaces/metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | def jaccard_index(box_a, box_b, indices=[]): 9 | """ 10 | Compute the Jaccard Index (Intersection over Union) of 2 boxes. Each box is (x1, y1, x2, y2). 11 | :param box_a: 12 | :param box_b: 13 | :param indices: The indices of box_a and box_b as [box_a_idx, box_b_idx]. 
Helps in debugging DivideByZero errors
14 |     :return:
15 |     """
16 |     # area of bounding boxes
17 |     area_A = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
18 |     area_B = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
19 | 
20 |     xA = max(box_a[0], box_b[0])
21 |     yA = max(box_a[1], box_b[1])
22 |     xB = min(box_a[2], box_b[2])
23 |     yB = min(box_a[3], box_b[3])
24 | 
25 |     intersection = max(0, xB - xA) * max(0, yB - yA)  # clamp so disjoint boxes yield 0
26 |     union = area_A + area_B - intersection
27 | 
28 |     # return the intersection over union value
29 |     try:
30 |         if union <= 0:
31 |             iou = 0
32 |         else:
33 |             iou = intersection / union
34 |     except ZeroDivisionError:
35 |         print(indices)
36 |         print(box_a)
37 |         print(box_b)
38 |         print(area_A, area_B, intersection)
39 |         exit(1)
40 | 
41 |     return iou
42 | 
43 | 
44 | def rect_dist(I, J):
45 |     if len(I.shape) == 1:
46 |         I = I[np.newaxis, :]
47 |         J = J[np.newaxis, :]
48 | 
49 |     # area of boxes
50 |     aI = (I[:, 2] - I[:, 0] + 1) * (I[:, 3] - I[:, 1] + 1)
51 |     aJ = (J[:, 2] - J[:, 0] + 1) * (J[:, 3] - J[:, 1] + 1)
52 | 
53 |     x1 = np.maximum(I[:, 0], J[:, 0])
54 |     y1 = np.maximum(I[:, 1], J[:, 1])
55 |     x2 = np.minimum(I[:, 2], J[:, 2])
56 |     y2 = np.minimum(I[:, 3], J[:, 3])
57 | 
58 |     aIJ = (x2 - x1 + 1) * (y2 - y1 + 1) * (np.logical_and(x2 > x1, y2 > y1))
59 | 
60 |     with warnings.catch_warnings():
61 |         warnings.filterwarnings('error')
62 |         try:
63 |             iou = aIJ / (aI + aJ - aIJ)
64 |         except (RuntimeWarning, Exception):
65 |             iou = np.zeros(aIJ.shape)
66 | 
67 |     # set NaN, inf, and -inf to 0
68 |     iou[np.isnan(iou)] = 0
69 |     iou[np.isinf(iou)] = 0
70 | 
71 |     dist = np.maximum(np.zeros(iou.shape),
72 |                       np.minimum(np.ones(iou.shape), 1 - iou))
73 | 
74 |     return dist
75 | 
76 | 
77 | def voc_ap(rec, prec):
78 |     """ ap = voc_ap(rec, prec)
79 |     Compute VOC AP given precision and recall.
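    AP is the area under the precision envelope of the precision-recall
    curve, accumulated at the points where recall changes.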
80 |     Always uses the newer metric (in contrast to the '07 metric)
81 |     """
82 |     # correct AP calculation
83 |     # first append sentinel values at the end
84 |     mrec = np.concatenate(([0.], rec, [1.]))
85 |     mpre = np.concatenate(([0.], prec, [0.]))
86 | 
87 |     # compute the precision envelope
88 |     for i in range(mpre.size - 1, 0, -1):
89 |         mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
90 | 
91 |     # to calculate area under PR curve, look for points
92 |     # where X axis (recall) changes value
93 |     i = np.where(mrec[1:] != mrec[:-1])[0]
94 | 
95 |     # and sum (\Delta recall) * prec
96 |     ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
97 |     return ap
98 | 
99 | 
100 | def average_precision(confidence,
101 |                       dets,
102 |                       image_ids,
103 |                       class_recs,
104 |                       npos,
105 |                       ovthresh=0.5):
106 |     sorted_ind = np.argsort(-confidence)
107 |     sorted_scores = np.sort(-confidence)
108 |     BB = dets[sorted_ind, :]
109 |     img_ids = [image_ids[x] for x in sorted_ind]
110 | 
111 |     nd = len(img_ids)  # num of detections
112 |     tp = np.zeros(nd)
113 |     fp = np.zeros(nd)
114 | 
115 |     for d in tqdm(range(nd), total=nd):
116 |         R = class_recs[img_ids[d]]
117 |         bb = BB[d, :].astype(float)
118 |         ovmax = -np.inf
119 |         BBGT = R['bbox'].astype(float)
120 |         BBGT[:, 2] = BBGT[:, 0] + BBGT[:, 2] - 1
121 |         BBGT[:, 3] = BBGT[:, 1] + BBGT[:, 3] - 1
122 | 
123 |         if BBGT.size > 0:
124 |             # compute overlaps
125 |             # intersection
126 |             ixmin = np.maximum(BBGT[:, 0], bb[0])
127 |             iymin = np.maximum(BBGT[:, 1], bb[1])
128 |             ixmax = np.minimum(BBGT[:, 2], bb[2])
129 |             iymax = np.minimum(BBGT[:, 3], bb[3])
130 |             iw = np.maximum(ixmax - ixmin, 0.0)
131 |             ih = np.maximum(iymax - iymin, 0.0)
132 |             inters = iw * ih
133 | 
134 |             # union
135 |             uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) +
136 |                    (BBGT[:, 2] - BBGT[:, 0]) * (BBGT[:, 3] - BBGT[:, 1]) -
137 |                    inters)
138 | 
139 |             overlaps = inters / uni
140 |             ovmax = np.max(overlaps)
141 |             jmax = np.argmax(overlaps)
142 | 
143 |         if ovmax > ovthresh:
144 |             if not R['det'][jmax]:
145 |                 tp[d] = 1.
146 |                 R['det'][jmax] = 1
147 |             else:
148 |                 fp[d] = 1.
149 |         else:
150 |             fp[d] = 1.
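    # NOTE: each ground-truth box can be claimed by at most one detection;
    # further detections matching an already-claimed box count as false positives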
151 | 
152 |     # compute precision recall
153 |     fp = np.cumsum(fp)
154 |     tp = np.cumsum(tp)
155 |     rec = tp / float(npos)
156 |     # avoid divide by zero in case the first detection matches a difficult
157 |     # ground truth
158 |     prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
159 |     ap = voc_ap(rec, prec)
160 | 
161 |     return ap, prec, rec
162 | 
163 | 
164 | def compute_model_score(pred_file, gt_file, class_id=3):
165 |     # load GT
166 |     with open(gt_file) as f:
167 |         GT = json.load(f)
168 |     recs = {}
169 |     for g in GT:
170 |         recs[g["image"]["id"]] = g["bboxes"]
171 | 
172 |     class_recs = {}
173 |     npos = 0
174 |     for img_id in recs.keys():
175 |         # get the list of all bboxes belonging to the particular class
176 |         R = [obj for obj in recs[img_id] if obj["category_id"] == class_id]
177 |         bboxes = np.array([x["bbox"] for x in R])
178 |         det = [False] * len(R)  # marks whether this GT box has been matched
179 |         npos = npos + len(R)
180 |         class_recs[img_id] = {'bbox': bboxes, 'det': det}
181 | 
182 |     print("Loaded GT")
183 | 
184 |     # Read the detections
185 |     with open(pred_file) as f:
186 |         preds = f.readlines()
187 |     preds = [json.loads(x) for x in preds]
188 | 
189 |     confidence, BB, image_ids = [], [], []
190 |     for x in tqdm(preds, total=len(preds)):
191 |         confidence.extend(x["confidences"])
192 |         BB.extend(x["bboxes"])
193 |         image_ids.extend([x["id"]] * len(x['confidences']))
194 | 
195 |     print("Loaded detections")
196 | 
197 |     confidence = np.array(confidence)
198 |     BB = np.array(BB)
199 | 
200 |     print(confidence.shape)
201 |     print(BB.shape)
202 | 
203 |     ap, prec, rec = average_precision(confidence, BB, image_ids, class_recs,
204 |                                       npos)
205 |     return ap, prec, rec
206 | 
--------------------------------------------------------------------------------
/tinyfaces/datasets/wider_face.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 | 
9 | from tinyfaces.datasets.processor import DataProcessor
10 | from tinyfaces.utils import visualize
11 | 
12 | 
13 | class WIDERFace(Dataset):
14 |     """The WIDERFace dataset is generated using MATLAB,
15 |     so a lot of small housekeeping elements have been added
16 |     to take care of the indexing discrepancies."""
17 | 
18 |     def __init__(self,
19 |                  path,
20 |                  templates,
21 |                  img_transforms=None,
22 |                  dataset_root="",
23 |                  split="train",
24 |                  input_size=(500, 500),
25 |                  heatmap_size=(63, 63),
26 |                  pos_thresh=0.7,
27 |                  neg_thresh=0.3,
28 |                  pos_fraction=0.5,
29 |                  debug=False):
30 |         super().__init__()
31 | 
32 |         self.data = []
33 |         self.split = split
34 | 
35 |         self.load(path)
36 | 
37 |         print("Dataset loaded")
38 |         print("{0} samples in the {1} dataset".format(len(self.data),
39 |                                                       self.split))
40 | 
41 |         # canonical object templates obtained via clustering
42 |         # NOTE we directly use the values from Peiyun's repository stored in "templates.json"
43 |         self.templates = templates
44 | 
45 |         self.transforms = img_transforms
46 |         self.dataset_root = Path(dataset_root)
47 |         self.input_size = input_size
48 |         self.heatmap_size = heatmap_size
49 |         self.pos_thresh = pos_thresh
50 |         self.neg_thresh = neg_thresh
51 |         self.pos_fraction = pos_fraction
52 | 
53 |         # receptive field computed using a combination of values from Matconvnet
54 |         # plus derived equations.
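        # A heatmap cell maps back to the image coordinate
        # (offset + stride * index) along each axis; e.g. with stride 8 and
        # offset -1, heatmap x-index 10 corresponds to image x = -1 + 8 * 10 = 79.
        # (The index value here is illustrative; the rf values below are real.)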
55 |         self.rf = {'size': [859, 859], 'stride': [8, 8], 'offset': [-1, -1]}
56 | 
57 |         self.processor = DataProcessor(input_size,
58 |                                        heatmap_size,
59 |                                        pos_thresh,
60 |                                        neg_thresh,
61 |                                        templates,
62 |                                        rf=self.rf)
63 |         self.debug = debug
64 | 
65 |     def load(self, path):
66 |         """Load the dataset from the text file."""
67 | 
68 |         if self.split in ("train", "val"):
69 |             lines = Path(path).read_text().splitlines()  # read and close the file
70 |             self.data = []
71 |             idx = 0
72 | 
73 |             while idx < len(lines):
74 |                 img = lines[idx].strip()
75 |                 idx += 1
76 |                 n = int(lines[idx].strip())
77 |                 idx += 1
78 | 
79 |                 bboxes = np.empty((n, 10))
80 | 
81 |                 if n == 0:
82 |                     idx += 1
83 |                 else:
84 |                     for b in range(n):
85 |                         bboxes[b, :] = [
86 |                             abs(float(x)) for x in lines[idx].strip().split()
87 |                         ]
88 |                         idx += 1
89 | 
90 |                 # remove invalid bboxes where w or h are 0
91 |                 invalid = np.where(
92 |                     np.logical_or(bboxes[:, 2] == 0, bboxes[:, 3] == 0))
93 |                 bboxes = np.delete(bboxes, invalid, 0)
94 | 
95 |                 # bounding boxes are 1 indexed so we keep them like that
96 |                 # and treat them as abstract geometrical objects
97 |                 # We only need to worry about the box indexing when actually rendering them
98 | 
99 |                 # convert from (x, y, w, h) to (x1, y1, x2, y2)
100 |                 # We work with the two point representation
101 |                 # since cropping becomes easier to deal with
102 |                 # -1 to ensure the same representation as in Matlab.
103 |                 bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] - 1
104 |                 bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] - 1
105 | 
106 |                 datum = {
107 |                     "img_path": img,
108 |                     "bboxes": bboxes[:, 0:4],
109 |                     "blur": bboxes[:, 4],
110 |                     "expression": bboxes[:, 5],
111 |                     "illumination": bboxes[:, 6],
112 |                     "invalid": bboxes[:, 7],
113 |                     "occlusion": bboxes[:, 8],
114 |                     "pose": bboxes[:, 9]
115 |                 }
116 | 
117 |                 self.data.append(datum)
118 | 
119 |         elif self.split == "test":
120 |             data = Path(path).read_text().splitlines()
121 |             self.data = [{'img_path': x.strip()} for x in data]
122 | 
123 |     def get_all_bboxes(self):
124 |         bboxes = np.empty((0, 4))
125 |         for datum in self.data:
126 |             bboxes = np.vstack((bboxes, datum['bboxes']))
127 | 
128 |         return bboxes
129 | 
130 |     def __len__(self):
131 |         return len(self.data)
132 | 
133 |     def process_inputs(self, image, bboxes):
134 |         # Randomly resize the image
135 |         rnd = np.random.rand()
136 |         if rnd < 1 / 3:
137 |             # resize by half
138 |             scaled_shape = (int(0.5 * image.height), int(0.5 * image.width))
139 |             image = transforms.functional.resize(image, scaled_shape)
140 |             bboxes = bboxes / 2
141 | 
142 |         elif rnd > 2 / 3:
143 |             # double size
144 |             scaled_shape = (int(2 * image.height), int(2 * image.width))
145 |             image = transforms.functional.resize(image, scaled_shape)
146 |             bboxes = bboxes * 2
147 | 
148 |         # convert from PIL Image to ndarray
149 |         img = np.array(image)
150 | 
151 |         # Get a random crop of the image and keep only relevant bboxes
152 |         img, bboxes, paste_box = self.processor.crop_image(img, bboxes)
153 |         pad_mask = self.processor.get_padding(paste_box)
154 | 
155 |         # Random Flip
156 |         flip = np.random.rand() > 0.5
157 |         if flip:
158 |             img = np.fliplr(img).copy()  # flip the image
159 | 
160 |             lx1, lx2 = np.array(bboxes[:, 0]), np.array(bboxes[:, 2])
161 |             # Flip the bounding box; the +1 keeps the Matlab-style 1-indexed convention.
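            # Illustrative example (assumed values): with input width 500, a
            # box with x1 = 10, x2 = 50 flips to x1' = 500 - 50 + 1 = 451 and
            # x2' = 500 - 10 + 1 = 491; the y coordinates are unchanged.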
162 |             bboxes[:, 0] = self.input_size[1] - lx2 + 1
163 |             bboxes[:, 2] = self.input_size[1] - lx1 + 1
164 | 
165 |             pad_mask = np.fliplr(pad_mask)
166 | 
167 |         # Get the ground truth class and regression maps
168 |         class_maps, regress_maps, iou = self.processor.get_heatmaps(
169 |             bboxes, pad_mask)
170 | 
171 |         if self.debug:
172 |             # Visualize stuff
173 |             visualize.visualize_bboxes(
174 |                 Image.fromarray(img.astype('uint8'), 'RGB'), bboxes)
175 |             self.processor.visualize_heatmaps(Image.fromarray(
176 |                 img.astype('uint8'), 'RGB'),
177 |                                               class_maps,
178 |                                               regress_maps,
179 |                                               self.templates,
180 |                                               iou=iou)
181 | 
182 |             # and now we exit
183 |             exit(0)
184 | 
185 |         # transpose so we get CxHxW
186 |         class_maps = class_maps.transpose((2, 0, 1))
187 |         regress_maps = regress_maps.transpose((2, 0, 1))
188 | 
189 |         # img is type float64. Convert it to uint8 so torch knows to treat it like an image
190 |         img = img.astype(np.uint8)
191 | 
192 |         return img, class_maps, regress_maps, bboxes
193 | 
194 |     def __getitem__(self, index):
195 |         datum = self.data[index]
196 | 
197 |         image_root = self.dataset_root / "WIDER_{0}".format(self.split)
198 |         image_path = image_root / "images" / datum['img_path']
199 |         image = Image.open(image_path).convert('RGB')
200 | 
201 |         if self.split == 'train':
202 |             bboxes = datum['bboxes']
203 | 
204 |             if self.debug:
205 |                 if bboxes.shape[0] == 0:
206 |                     print(image_path)
207 |                 print("Dataset index: \t", index)
208 |                 print("image path:\t", image_path)
209 | 
210 |             img, class_map, reg_map, bboxes = self.process_inputs(
211 |                 image, bboxes)
212 | 
213 |             # convert everything to tensors
214 |             if self.transforms is not None:
215 |                 # if img is a byte or uint8 array, it will convert from 0-255 to 0-1
216 |                 # this converts from (HxWxC) to (CxHxW) as well
217 |                 img = self.transforms(img)
218 | 
219 |             class_map = torch.from_numpy(class_map)
220 |             reg_map = torch.from_numpy(reg_map)
221 | 
222 |             return img, class_map, reg_map
223 | 
224 |         elif self.split == 'val':
225 |             # NOTE Return only the image and the image path.
226 |             # Use the eval_tools to get the final results.
227 |             # Convert to tensor unconditionally so `img` is always defined;
228 |             # normalization happens later, after rescaling.
229 |             img = transforms.functional.to_tensor(image)
230 | 
231 |             return img, datum['img_path']
232 | 
233 |         elif self.split == 'test':
234 |             filename = datum['img_path']
235 | 
236 |             img = (self.transforms(image) if self.transforms is not None
237 |                    else transforms.functional.to_tensor(image))
238 | 
239 |             return img, filename
--------------------------------------------------------------------------------
/tinyfaces/datasets/processor.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from copy import deepcopy
3 | 
4 | import numpy as np
5 | import torch
6 | from torchvision.ops import nms
7 | 
8 | from tinyfaces.datasets.dense_overlap import compute_dense_overlap
9 | from tinyfaces.metrics import rect_dist
10 | from tinyfaces.utils.visualize import draw_bounding_box
11 | 
12 | logger = logging.getLogger("detector")
13 | 
14 | class DataProcessor:
15 |     """
16 |     This is a helper class to abstract out all the operations needed during the data-loading
17 |     pipeline of the Tiny Faces object detector.
18 | 
19 |     The idea is that this can act as a mixin that enables torch dataloaders with the heatmap
20 |     generation semantics.
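
    A minimal usage sketch (illustrative; the argument values mirror the
    defaults that wider_face.py passes in, and `templates` is the (N, 4)
    array loaded from templates.json):

        rf = {'size': [859, 859], 'stride': [8, 8], 'offset': [-1, -1]}
        processor = DataProcessor((500, 500), (63, 63), pos_thresh=0.7,
                                  neg_thresh=0.3, templates=templates, rf=rf)
        img, bboxes, paste_box = processor.crop_image(img, bboxes)
        pad_mask = processor.get_padding(paste_box)
        class_maps, regress_maps, iou = processor.get_heatmaps(bboxes, pad_mask)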
21 | """ 22 | 23 | def __init__(self, 24 | input_size, 25 | heatmap_size, 26 | pos_thresh, 27 | neg_thresh, 28 | templates, 29 | img_means=None, 30 | rf=None): 31 | self.input_size = input_size 32 | self.heatmap_size = heatmap_size 33 | self.pos_thresh = pos_thresh 34 | self.neg_thresh = neg_thresh 35 | self.templates = templates 36 | self.rf = rf 37 | self.ofy, self.ofx = rf['offset'] 38 | self.sty, self.stx = rf['stride'] 39 | self.img_means = img_means or [0.485, 0.456, 0.406] 40 | 41 | def crop_image(self, img, bboxes): 42 | """ 43 | Crop a 500x500 patch from the image, taking care for smaller images. 44 | bboxes is the np.array of all bounding boxes [x1, y1, x2, y2] 45 | """ 46 | # randomly pick a cropping window for the image 47 | # We keep the second arg to randint at least 1 since randint is [low, high) 48 | max_crop_x = np.max([1, (img.shape[1] - self.input_size[1] + 1)]) 49 | max_crop_y = np.max([1, (img.shape[0] - self.input_size[0] + 1)]) 50 | crop_x1 = np.random.randint(0, max_crop_x) 51 | crop_y1 = np.random.randint(0, max_crop_y) 52 | crop_x2 = min(img.shape[1], crop_x1 + self.input_size[1]) 53 | crop_y2 = min(img.shape[0], crop_y1 + self.input_size[0]) 54 | crop_h = crop_y2 - crop_y1 55 | crop_w = crop_x2 - crop_x1 56 | 57 | # place the cropped image in a random location in a `input_size` image 58 | paste_box = [0, 0, 0, 0] # x1, y1, x2, y2 59 | paste_box[0] = np.random.randint(0, self.input_size[1] - crop_w + 1) 60 | paste_box[1] = np.random.randint(0, self.input_size[0] - crop_h + 1) 61 | paste_box[2] = paste_box[0] + crop_w 62 | paste_box[3] = paste_box[1] + crop_h 63 | 64 | # set this to average image colors 65 | # this will later be subtracted in mean image subtraction 66 | img_buf = np.zeros((self.input_size + (3, ))) 67 | 68 | # add the average image so it gets subtracted later. 
69 |         for i, c in enumerate(self.img_means):
70 |             img_buf[:, :, i] += c
71 |         # `img` is a uint8 array, so scale the 0-1 means up to 0-255 to match
72 |         img_buf = (img_buf * 255).astype(np.uint8)
73 | 
74 |         img_buf[paste_box[1]:paste_box[3], paste_box[0]:paste_box[2], :] = \
75 |             img[crop_y1:crop_y2, crop_x1:crop_x2, :]
76 | 
77 |         if bboxes.shape[0] > 0:
78 |             # check if overlap is above negative threshold
79 |             tbox = deepcopy(bboxes)
80 |             tbox[:, 0] = np.maximum(tbox[:, 0], crop_x1)
81 |             tbox[:, 1] = np.maximum(tbox[:, 1], crop_y1)
82 |             tbox[:, 2] = np.minimum(tbox[:, 2], crop_x2)
83 |             tbox[:, 3] = np.minimum(tbox[:, 3], crop_y2)
84 | 
85 |             overlap = 1 - rect_dist(tbox, bboxes)
86 | 
87 |             # adjust the bounding boxes - first for crop and then for random placement
88 |             bboxes[:, 0] = bboxes[:, 0] - crop_x1 + paste_box[0]
89 |             bboxes[:, 1] = bboxes[:, 1] - crop_y1 + paste_box[1]
90 |             bboxes[:, 2] = bboxes[:, 2] - crop_x1 + paste_box[0]
91 |             bboxes[:, 3] = bboxes[:, 3] - crop_y1 + paste_box[1]
92 | 
93 |             # correct for bbox to be within image border
94 |             bboxes[:, 0] = np.minimum(self.input_size[1],
95 |                                       np.maximum(0, bboxes[:, 0]))
96 |             bboxes[:, 1] = np.minimum(self.input_size[0],
97 |                                       np.maximum(0, bboxes[:, 1]))
98 |             bboxes[:, 2] = np.minimum(self.input_size[1],
99 |                                       np.maximum(1, bboxes[:, 2]))
100 |             bboxes[:, 3] = np.minimum(self.input_size[0],
101 |                                       np.maximum(1, bboxes[:, 3]))
102 | 
103 |             # check to see if the adjusted bounding box is invalid
104 |             invalid = np.logical_or(
105 |                 np.logical_or(bboxes[:, 2] <= bboxes[:, 0], bboxes[:, 3] <= bboxes[:, 1]),
106 |                 overlap < self.neg_thresh)
107 | 
108 |             # remove invalid bounding boxes
109 |             ind = np.where(invalid)
110 |             bboxes = np.delete(bboxes, ind, 0)
111 | 
112 |         return img_buf, bboxes, paste_box
113 | 
114 |     def get_padding(self, paste_box):
115 |         """
116 |         Get the padding of the image based on where the sampled image patch was placed.
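        Receptive-field template boxes are laid out on the heatmap grid at
        x = offset_x + stride_x * ix (and likewise for y); any template box
        that extends beyond the pasted region [x1, y1, x2, y2] is flagged as
        padding so it can be treated as neutral downstream.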
117 |         :param paste_box: [x1, y1, x2, y2] of where the crop was pasted
118 |         :return: (vsy, vsx, num_templates) boolean mask of padded locations
119 |         """
120 |         ofy, ofx = self.rf['offset']
121 |         sty, stx = self.rf['stride']
122 |         vsy, vsx = self.heatmap_size
123 |         coarse_x, coarse_y = np.meshgrid(ofx + np.array(range(vsx)) * stx,
124 |                                          ofy + np.array(range(vsy)) * sty)
125 | 
126 |         # each cluster is [x1, y1, x2, y2]
127 |         dx1 = self.templates[:, 0]
128 |         dy1 = self.templates[:, 1]
129 |         dx2 = self.templates[:, 2]
130 |         dy2 = self.templates[:, 3]
131 | 
132 |         # yapf: disable
133 |         # compute the bounds
134 |         # We add new axes so that the arrays are numpy broadcasting compatible
135 |         coarse_xx1 = coarse_x[:, :, np.newaxis] + dx1[np.newaxis, np.newaxis, :]  # (vsy, vsx, nt)
136 |         coarse_yy1 = coarse_y[:, :, np.newaxis] + dy1[np.newaxis, np.newaxis, :]  # (vsy, vsx, nt)
137 |         coarse_xx2 = coarse_x[:, :, np.newaxis] + dx2[np.newaxis, np.newaxis, :]  # (vsy, vsx, nt)
138 |         coarse_yy2 = coarse_y[:, :, np.newaxis] + dy2[np.newaxis, np.newaxis, :]  # (vsy, vsx, nt)
139 |         # yapf: enable
140 | 
141 |         # Matlab code indexes from 1, hence to check against it we need to add +1.
142 |         # However, in python we don't need the +1 during actual training
143 |         padx1 = coarse_xx1 < paste_box[0] + 1
144 |         pady1 = coarse_yy1 < paste_box[1] + 1
145 |         padx2 = coarse_xx2 > paste_box[2]
146 |         pady2 = coarse_yy2 > paste_box[3]
147 | 
148 |         pad_mask = padx1 | pady1 | padx2 | pady2
149 | 
150 |         return pad_mask
151 | 
152 |     def get_regression(self, bboxes, cluster_boxes, iou):
153 |         """
154 |         Compute the target bounding box regression values
155 |         :param bboxes: (num_gt, 4) ground-truth boxes as (x1, y1, x2, y2)
156 |         :param cluster_boxes: [dx1, dy1, dx2, dy2] per-template corner offsets
157 |         :param iou: (vsy, vsx, num_templates, num_gt) dense template/GT IoU
158 |         :return: (vsy, vsx, 4 * num_templates) regression targets and the perturbed iou
159 |         """
160 |         ofy, ofx = self.rf['offset']
161 |         sty, stx = self.rf['stride']
162 |         vsy, vsx = self.heatmap_size
163 | 
164 |         coarse_xx, coarse_yy = np.meshgrid(ofx + np.array(range(vsx)) * stx,
165 |                                            ofy + np.array(range(vsy)) * sty)
166 | 
167 |         dx1, dy1, dx2, dy2 = cluster_boxes
168 | 
169 |         # We reshape to take advantage of numpy broadcasting
170 |         fxx1 = bboxes[:, 0].reshape(1, 1, 1,
171 |                                     bboxes.shape[0])  # (1, 1, 1, bboxes)
172 |         fyy1 = bboxes[:, 1].reshape(1, 1, 1, bboxes.shape[0])
173 |         fxx2 = bboxes[:, 2].reshape(1, 1, 1, bboxes.shape[0])
174 |         fyy2 = bboxes[:, 3].reshape(1, 1, 1, bboxes.shape[0])
175 | 
176 |         h = dy2 - dy1 + 1
177 |         w = dx2 - dx1 + 1
178 |         dhh = h.reshape(1, 1, h.shape[0], 1)  # (1, 1, N, 1)
179 |         dww = w.reshape(1, 1, w.shape[0], 1)  # (1, 1, N, 1)
180 | 
181 |         fcx = (fxx1 + fxx2) / 2
182 |         fcy = (fyy1 + fyy2) / 2
183 | 
184 |         tx = np.divide((fcx - coarse_xx.reshape(vsy, vsx, 1, 1)), dww)
185 |         ty = np.divide((fcy - coarse_yy.reshape(vsy, vsx, 1, 1)), dhh)
186 | 
187 |         fhh = fyy2 - fyy1 + 1
188 |         fww = fxx2 - fxx1 + 1
189 | 
190 |         tw = np.log(np.divide(fww, dww))  # (1, 1, N, bboxes)
191 |         th = np.log(np.divide(fhh, dhh))
192 | 
193 |         # Randomly perturb the IOU so that if multiple candidates have the same IOU,
194 |         # we don't pick the same one every time.
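        # (The 1e-6 noise is far smaller than any meaningful IoU difference,
        # so only exact ties are affected.)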
This is useful when the template is smaller than the GT bbox 195 | iou = iou + (1e-6 * np.random.rand(*iou.shape)) 196 | 197 | best_obj_per_loc = iou.argmax(axis=3) 198 | idx0, idx1, idx2 = np.indices(iou.shape[:-1]) 199 | 200 | tx = tx[idx0, idx1, idx2, best_obj_per_loc] 201 | ty = ty[idx0, idx1, idx2, best_obj_per_loc] 202 | 203 | tw = np.repeat(tw, vsy, axis=0) # (vsy, 1, N, bboxes) 204 | tw = np.repeat(tw, vsx, axis=1) # (vsy, vsx, N, bboxes) 205 | tw = tw[idx0, idx1, idx2, best_obj_per_loc] 206 | 207 | th = np.repeat(th, vsy, axis=0) 208 | th = np.repeat(th, vsx, axis=1) 209 | th = th[idx0, idx1, idx2, best_obj_per_loc] 210 | 211 | return np.concatenate((tx, ty, tw, th), axis=2), iou 212 | 213 | def get_heatmaps(self, bboxes, pad_mask): 214 | ofy, ofx = self.rf['offset'] 215 | sty, stx = self.rf['stride'] 216 | vsy, vsx = self.heatmap_size 217 | 218 | nt = self.templates.shape[0] 219 | # Initiate heatmaps 220 | class_maps = -np.ones((vsy, vsx, nt)) 221 | regress_maps = np.zeros((vsy, vsx, nt * 4)) 222 | 223 | # each cluster is [-w/2, -h/2, w/2, h/2] 224 | dx1, dx2 = self.templates[:, 0], self.templates[:, 2] 225 | dy1, dy2 = self.templates[:, 1], self.templates[:, 3] 226 | 227 | # Filter out invalid bbox 228 | invalid_x = bboxes[:, 2] <= bboxes[:, 0] 229 | invalid_y = bboxes[:, 3] <= bboxes[:, 1] 230 | invalid = np.logical_or(invalid_x, invalid_y) 231 | ind = np.where(invalid) 232 | bboxes = np.delete(bboxes, ind, axis=0) 233 | 234 | ng = bboxes.shape[0] 235 | iou = np.zeros((vsy, vsx, self.templates.shape[0], bboxes.shape[0])) 236 | 237 | if ng > 0: 238 | gx1 = bboxes[:, 0] 239 | gy1 = bboxes[:, 1] 240 | gx2 = bboxes[:, 2] 241 | gy2 = bboxes[:, 3] 242 | 243 | iou = compute_dense_overlap(ofx, ofy, stx, sty, vsx, vsy, dx1, dy1, 244 | dx2, dy2, gx1, gy1, gx2, gy2, 1, 1) 245 | 246 | regress_maps, iou = self.get_regression(bboxes, 247 | [dx1, dy1, dx2, dy2], iou) 248 | 249 | best_iou = iou.max(axis=3) 250 | 251 | # Set max IoU values to 1 (even if they are < pos_thresh, as long as they are above neg_thresh) 252 | per_object_iou = np.reshape(iou, (-1, ng)) 253 | fbest_idx = np.argmax(per_object_iou, axis=0) 254 | iou_ = np.amax(per_object_iou, axis=0) 255 | fbest_idx = np.unravel_index(fbest_idx[iou_ > self.neg_thresh], 256 | iou.shape[:-1]) 257 | class_maps[fbest_idx] = 1 258 | 259 | # Assign positive labels 260 | class_maps = np.maximum(class_maps, 261 | (best_iou >= self.pos_thresh) * 2 - 1) 262 | 263 | # If between positive and negative, assign as gray area 264 | gray = -np.ones(class_maps.shape) 265 | gray_mask = np.bitwise_and(self.neg_thresh <= best_iou, best_iou 266 | < self.pos_thresh) 267 | gray[gray_mask] = 0 268 | # since we set the max IoU values to 1 269 | class_maps = np.maximum(class_maps, gray) 270 | 271 | # handle the boundary 272 | non_neg_border = np.bitwise_and(pad_mask, class_maps != -1) 273 | class_maps[non_neg_border] = 0 274 | regress_maps[:, :, :nt][non_neg_border] = 0 275 | 276 | # Return heatmaps 277 | return class_maps, regress_maps, iou 278 | 279 | def visualize_heatmaps(self, 280 | img, 281 | cls_map, 282 | reg_map, 283 | templates, 284 | prob_thresh=1, 285 | nms_thresh=1, 286 | iou=None): 287 | """ 288 | Expect cls_map and reg_map to be of the form HxWxC 289 | """ 290 | fy, fx, fc = np.where(cls_map >= prob_thresh) 291 | 292 | cy, cx = fy * self.sty + self.ofy, fx * self.stx + self.ofx 293 | cw = templates[fc, 2] - templates[fc, 0] 294 | ch = templates[fc, 3] - templates[fc, 1] 295 | 296 | # box_ovlp = best_iou[fc, fy, fx] 297 | num_templates = 
templates.shape[0]
298 | 
299 |         # refine bounding box
300 |         tx = reg_map[:, :, 0 * num_templates:1 * num_templates]
301 |         ty = reg_map[:, :, 1 * num_templates:2 * num_templates]
302 |         tw = reg_map[:, :, 2 * num_templates:3 * num_templates]
303 |         th = reg_map[:, :, 3 * num_templates:4 * num_templates]
304 | 
305 |         dcx = cw * tx[fy, fx, fc]
306 |         dcy = ch * ty[fy, fx, fc]
307 | 
308 |         rx = cx + dcx
309 |         ry = cy + dcy
310 | 
311 |         rw = cw * np.exp(tw[fy, fx, fc])
312 |         rh = ch * np.exp(th[fy, fx, fc])
313 | 
314 |         bboxes = np.array([
315 |             np.abs(rx - rw / 2),
316 |             np.abs(ry - rh / 2), rx + rw / 2, ry + rh / 2
317 |         ]).T
318 | 
319 |         scores = cls_map[fy, fx, fc]
320 | 
321 |         # torchvision's nms operates on torch tensors, so convert the numpy
322 |         # arrays before the call and map the kept indices back to numpy
323 |         keep = nms(torch.from_numpy(bboxes).float(),
324 |                    torch.from_numpy(scores).float(), nms_thresh)
325 |         bboxes = bboxes[keep.numpy()]
326 |         # bbox_iou = best_iou[fy, fx, fc]
327 | 
328 |         # print("Best bounding box", bboxes)
329 |         # print(bboxes.shape)
330 | 
331 |         print("Number of bboxes ", bboxes.shape[0])
332 |         for idx, bbox in enumerate(bboxes):
333 |             bbox = np.round(np.array(bbox))
334 |             print(bbox)
335 |             # img = draw_bounding_box(img, bbox, {"name": "{0}".format(np.around(bbox_iou[idx], decimals=2))})
336 |             img = draw_bounding_box(img, bbox, {"name": "{0}".format(idx)})
337 | 
338 |             # if idx == 20:
339 |             #     break
340 | 
341 |         img.show(title="Heatmap visualized")
342 | 
--------------------------------------------------------------------------------