├── .github ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── cache_data ├── .gitignore ├── README.md ├── aflw_from_mat.py ├── cache │ ├── demo-sbr.mp4 │ ├── demo-sbrs │ │ └── .gitignore │ ├── demo.gif │ ├── dir-layout.png │ └── self.jpeg ├── demo_list.py ├── extrct_300VW.py ├── generate_300VW.py ├── generate_300W.py └── init_path.py ├── configs ├── Detector.config ├── LK.SGD.config ├── SGD.config ├── lk.config └── mix.lk.config ├── exps ├── basic_main.py ├── eval.py ├── lk_main.py └── vis.py ├── lib ├── config_utils │ ├── __init__.py │ ├── basic_args.py │ ├── configure_utils.py │ └── lk_args.py ├── datasets │ ├── GeneralDataset.py │ ├── VideoDataset.py │ ├── __init__.py │ ├── dataset_utils.py │ ├── file_utils.py │ ├── parse_utils.py │ └── point_meta.py ├── lk │ ├── __init__.py │ ├── basic_lk.py │ ├── basic_lk_batch.py │ ├── basic_utils.py │ └── basic_utils_batch.py ├── log_utils │ ├── __init__.py │ ├── logger.py │ ├── meter.py │ └── time_utils.py ├── models │ ├── LK.py │ ├── __init__.py │ ├── basic.py │ ├── basic_batch.py │ ├── cpm_vgg16.py │ ├── initialization.py │ └── model_utils.py ├── optimizer │ ├── __init__.py │ └── opt_utils.py ├── procedure │ ├── __init__.py │ ├── basic_eval.py │ ├── basic_train.py │ ├── lk_loss.py │ ├── lk_train.py │ ├── losses.py │ ├── saver.py │ └── starts.py ├── pts_utils │ ├── __init__.py │ └── generation.py ├── utils │ ├── __init__.py │ └── file_utils.py └── xvision │ ├── __init__.py │ ├── common_eval.py │ ├── evaluation_util.py │ ├── transforms.py │ └── visualization.py └── scripts ├── 300W-DET.sh ├── AFLW-DET.sh ├── demo_pam.sh ├── demo_sbr.sh └── sbr_example.sh /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Supervision-by-Registration (SBR) 2 | We want to make contributions to this project as easy and transparent as possible. 3 | 4 | ## Our Development Process 5 | Preliminary Implementations. 6 | 7 | ## Pull Requests 8 | We actively welcome your pull requests. 9 | 10 | 1. Fork the repo and create your branch from `master`. 11 | 2. If you've added code that should be tested, add tests. 12 | 3. If you've changed APIs, update the documentation. 13 | 4. Ensure the test suite passes. 14 | 5. Make sure your code lints. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ## Contributor License Agreement ("CLA") 18 | In order to accept your pull request, we need you to submit a CLA. You only need 19 | to do this once to work on any of Facebook's open source projects. 20 | 21 | Complete your CLA here: 22 | 23 | ## Issues 24 | We use GitHub issues to track public bugs. Please ensure your description is 25 | clear and has sufficient instructions to be able to reproduce the issue. 26 | 27 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 28 | disclosure of security bugs. In those cases, please go through the process 29 | outlined on that page and do not file a public issue. 30 | 31 | ## Coding Style 32 | * 2 spaces for indentation rather than tabs 33 | * ... 34 | 35 | ## License 36 | By contributing to SBR, you agree that your contributions will be licensed 37 | under the LICENSE file in the root directory of this source tree. 
38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | If you have a question or would like help and support, please ask at [Issues](https://github.com/facebookresearch/supervision-by-registration/issues) 2 | 3 | If you are submitting a feature request, please preface the title with [feature request]. 4 | If you are submitting a bug report, please fill in the following details. 5 | 6 | ## Issue description 7 | 8 | Provide a short description. 9 | 10 | ## Code example 11 | 12 | Please try to provide a minimal example to repro the bug. 13 | Error messages and stack traces are also helpful. 14 | 15 | ## System Info 16 | Please copy and paste the output from our 17 | [environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py) 18 | (or fill out the checklist below manually). 19 | 20 | You can get the script and run it with: 21 | ``` 22 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 23 | # For security purposes, please check the contents of collect_env.py before running it. 24 | python collect_env.py 25 | ``` 26 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/supervision-by-registration/4e7fcedfafa5b176a6d4e0c035af67430aad2982/.github/PULL_REQUEST_TEMPLATE.md -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | */__pycache__ 3 | */*/__pycache__ 4 | snapshots 5 | cache_data/cache 6 | cache_data/lists 7 | .DS_Store 8 | */.DS_Store 9 | */*/.DS_Store 10 | *.swp 11 | */*.swp 12 | */*/*.swp 13 | */AFLWinfo_release.mat 14 | AFLWinfo_release.mat 15 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | October 19 2018 2 | 3 | ## Add missing copyright 4 | 5 | June 22 2018 6 | 7 | ## Create the project 8 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Supervision-by-Registration: An Unsupervised Approach to Improve the Precision of Facial Landmark Detectors 2 | By Xuanyi Dong, Shoou-I Yu, Xinshuo Weng, Shih-En Wei, Yi Yang, Yaser Sheikh 3 | 4 | University of Technology Sydney, Facebook Reality Labs 5 | 6 | ## Introduction 7 | We propose a method to find facial landmarks (e.g. corner of eyes, corner of mouth, tip of nose, etc) more precisely. 8 | Our method utilizes the fact that objects move smoothly in a video sequence (i.e. optical flow registration) to improve an existing facial landmark detector. 
The key novelty is that no additional human annotations are necessary to improve the detector, hence it is an “unsupervised approach”.

![demo](https://github.com/facebookresearch/supervision-by-registration/blob/master/cache_data/cache/demo.gif)

## Citation
If you find that Supervision-by-Registration helps your research, please cite the paper:
```
@inproceedings{dong2018sbr,
  title={{Supervision-by-Registration}: An Unsupervised Approach to Improve the Precision of Facial Landmark Detectors},
  author={Dong, Xuanyi and Yu, Shoou-I and Weng, Xinshuo and Wei, Shih-En and Yang, Yi and Sheikh, Yaser},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  pages={360--368},
  year={2018}
}
```

## Requirements
- PyTorch >= 0.4.0
- Python 3.6

## Data Preparation

See the README in `cache_data`.

### Dataset Format
Each dataset is saved as one file, in which each row describes one specific face in one image or one video frame.
The format of one line is:
```
image_path annotation_path x1 y1 x2 y2 (face_size)
```
- *image_path*: the image (video frame) file path of that face.
- *annotation_path*: the annotation file path of that face (the annotation contains the coordinates of all landmarks).
- *x1, y1, x2, y2*: the coordinates of the upper-left and lower-right corners of the face bounding box.
- *face_size*: an optional item. If this value is set, we use `face_size` to compute the NME; otherwise, we use the distance between two pre-defined points to compute the NME.
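
For illustration, one line in such a list file might look like the following (the paths and numbers here are hypothetical; the real lists are generated by the scripts described in `cache_data`):
```
$HOME/datasets/landmark-datasets/300W/afw/815038_1.jpg $HOME/datasets/landmark-datasets/300W/afw/815038_1.pts 129.6 93.2 327.7 346.0
```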

## Training

See the `configs` directory for some example configurations.
### Basic Training
```
python ./exps/basic_main.py []
```
The argument list is loaded by `./lib/config_utils/basic_args.py`.
An example script is `./scripts/300W-DET.sh`, and you can simply run it to train the base detector on the `300-W` dataset.
```
bash scripts/300W-DET.sh
```

### Improving the Detector by SBR
```
python ./exps/lk_main.py []
```
The argument list is loaded by `./lib/config_utils/lk_args.py`.


#### An example to train SBR on the unlabeled sequences
The `init_model` parameter is the path to the detector trained in the `Basic Training` section.
```
bash scripts/demo_sbr.sh
```
To see visualization results, use the commands in `Visualization`.

#### An example to train SBR on your own data
See the script `./scripts/sbr_example.sh`; some parameters should be replaced by your own data.


## Evaluation

When using `basic_main.py` or `lk_main.py`, we evaluate the testing datasets automatically.

To evaluate a single image, you can use the following script to compute the coordinates of the 68 facial landmarks of the target image:
```
python ./exps/eval.py --image ./cache_data/cache/self.jpeg --model ./snapshots/300W-CPM-DET/checkpoint/cpm_vgg16-epoch-049-050.pth --face 250 150 900 1100 --save ./cache_data/cache/test.jpeg
```
- image : the input image path
- model : the snapshot path
- face : the face bounding box
- save : save the visualized results
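
For reference, the NME reported here normalizes the mean landmark error by a face-scale term, as described in `Dataset Format`. Below is a minimal sketch of this metric, not the repository's exact evaluation code (the actual evaluation utilities live under `lib/xvision`); it assumes `(N, 2)` NumPy arrays, and the eye-corner indices 36/45 used for the fallback normalization are an assumption for the 68-point annotation:
```
import numpy as np

def compute_nme(pred, gt, face_size=None):
  # pred, gt : (N, 2) arrays of predicted / ground-truth landmark coordinates
  per_point = np.linalg.norm(pred - gt, axis=1)   # Euclidean error per landmark
  if face_size is None:
    # fall back to the distance between two pre-defined points,
    # here the outer eye corners of the 68-point annotation (assumed indices)
    face_size = np.linalg.norm(gt[36] - gt[45])
  return per_point.mean() / face_size
```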

## Visualization

After training SBR on the demo video, or models on other datasets, you can use `./exps/vis.py` to generate the visualization results.
```
python ./exps/vis.py --meta snapshots/CPM-SBR/metas/eval-start-eval-00-01.pth --save cache_data/cache/demo-detsbr-vis
ffmpeg -start_number 3 -i cache_data/cache/demo-detsbr-vis/image%04d.png -b:v 30000k -vf "fps=30" -pix_fmt yuv420p cache_data/cache/demo-detsbr-vis.mp4

python ./exps/vis.py --meta snapshots/CPM-SBR/metas/eval-epoch-049-050-00-01.pth --save cache_data/cache/demo-sbr-vis
ffmpeg -start_number 3 -i cache_data/cache/demo-sbr-vis/image%04d.png -b:v 30000k -vf "fps=30" -pix_fmt yuv420p cache_data/cache/demo-sbr-vis.mp4
```
- meta : the saved prediction files
- save : the directory path to save the visualization results


## License
supervision-by-registration is released under the [CC-BY-NC license](https://github.com/facebookresearch/supervision-by-registration/blob/master/LICENSE).


## Useful Information

### 1. Train on your own video data
You should look at `./lib/datasets/VideoDataset.py` and `./lib/datasets/parse_utils.py`, and add logic for finding the neighboring frames given one image path, as in the sketch below.
For more details, see the `parse_basic` function in `lib/datasets/parse_utils.py`.
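
As a rough illustration, such a neighbour-frame parser could look like the sketch below (the function name, the `image0007.png` naming pattern, and the offset are hypothetical; the repository's actual logic is the `parse_basic` function mentioned above):
```
import os.path as osp

def parse_my_video(image_path, offset=1):
  # split '/some/dir/image0007.png' into prefix, frame index, and extension
  directory, name = osp.split(image_path)
  stem, ext = osp.splitext(name)                  # 'image0007', '.png'
  prefix, index = stem[:-4], int(stem[-4:])
  # build the paths of the previous and the next frame
  prev_frame = osp.join(directory, '{:}{:04d}{:}'.format(prefix, index - offset, ext))
  next_frame = osp.join(directory, '{:}{:04d}{:}'.format(prefix, index + offset, ext))
  # fall back to the current frame at the boundaries of the sequence
  if not osp.isfile(prev_frame): prev_frame = image_path
  if not osp.isfile(next_frame): next_frame = image_path
  return prev_frame, next_frame
```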

### 2. Warnings when training on the AFLW dataset
It is OK to see the following warnings. Since some images in the AFLW dataset are in the wrong format, PIL will raise some warnings when loading these images. These warnings do not affect the training performance.
```
TiffImagePlugin.py:756: UserWarning: Corrupt EXIF data. Expecting to read 12 bytes but only got 6.
```

### Contact
To ask questions or report issues, please open an issue on [the issues tracker](https://github.com/facebookresearch/supervision-by-registration/issues).
--------------------------------------------------------------------------------
/cache_data/.gitignore:
--------------------------------------------------------------------------------
*.lst
temp
EX300.sh
--------------------------------------------------------------------------------
/cache_data/README.md:
--------------------------------------------------------------------------------
# Dataset Preparation
The raw datasets should be put into `$HOME/datasets/landmark-datasets`. The layout should be organized as in the following screenshot.

![layout](https://github.com/facebookresearch/supervision-by-registration/blob/master/cache_data/cache/dir-layout.png)

## [300-W](https://ibug.doc.ic.ac.uk/resources/300-W/)

### Download
- 300-W consists of several different datasets.
- Create a directory to save the images and annotations: `mkdir ~/datasets/landmark-datasets/300W`
- To download i-bug: https://ibug.doc.ic.ac.uk/download/annotations/ibug.zip
- To download afw: https://ibug.doc.ic.ac.uk/download/annotations/afw.zip
- To download helen: https://ibug.doc.ic.ac.uk/download/annotations/helen.zip
- To download lfpw: https://ibug.doc.ic.ac.uk/download/annotations/lfpw.zip
- To download the bounding box annotations: https://ibug.doc.ic.ac.uk/media/uploads/competitions/bounding_boxes.zip
- In the folder `~/datasets/landmark-datasets/300W`, there are now four zip files: ibug.zip, afw.zip, helen.zip, and lfpw.zip
```
unzip ibug.zip -d ibug
mv ibug/image_092\ _01.jpg ibug/image_092_01.jpg
mv ibug/image_092\ _01.pts ibug/image_092_01.pts

unzip afw.zip -d afw
unzip helen.zip -d helen
unzip lfpw.zip -d lfpw
unzip bounding_boxes.zip ; mv Bounding\ Boxes Bounding_Boxes
```
The 300W directory is `$HOME/datasets/landmark-datasets/300W` and its structure is:
```
-- afw
-- Bounding_Boxes
-- helen
-- ibug
-- lfpw
```

Then use the following script to generate the 300-W list files.
```
python generate_300W.py
```
All list files will be saved into `./lists/300W/`. The `*.DET` files use the face detector results for the face bounding box; the `*.GTB` files use the ground-truth face bounding box.

#### If you cannot find the `*.mat` files for 300-W
The download link is on the official [300-W website](https://ibug.doc.ic.ac.uk/resources/300-W).
```
https://ibug.doc.ic.ac.uk/media/uploads/competitions/bounding_boxes.zip
```
The zip file should be unzipped, and all extracted mat files should be put into `$HOME/datasets/landmark-datasets/300W/Bounding_Boxes`.

## [AFLW](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/)

Download the aflw.tar.gz file into `$HOME/datasets/landmark-datasets` and extract it with `tar xzvf aflw.tar.gz`.
```
mkdir $HOME/datasets/landmark-datasets/AFLW
cp -r aflw/data/flickr $HOME/datasets/landmark-datasets/AFLW/images
```

The structure of AFLW is:
```
--images
  --0
  --2
  --3
```

Download the [AFLWinfo_release.mat](http://mmlab.ie.cuhk.edu.hk/projects/compositional/AFLWinfo_release.mat) from [this website](http://mmlab.ie.cuhk.edu.hk/projects/compositional.html) into `./cache_data`. This is the revised annotation of the full AFLW dataset.

Generate the AFLW dataset list files into `./lists/AFLW`.
```
python aflw_from_mat.py
```
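
To sanity-check the downloaded annotation file before generating the lists, a small snippet such as the following can be used (a sketch assuming `scipy` is installed; the keys are the ones read by `aflw_from_mat.py` later in this repository):
```
from scipy.io import loadmat

mat = loadmat('AFLWinfo_release.mat')
# 'nameList' : image paths, 'data' : 19 landmarks per face, 'bbox' : face boxes,
# 'mask_new' : visibility masks, 'ra' : the train/test split order
for key in ['nameList', 'data', 'bbox', 'mask_new', 'ra']:
  print(key, mat[key].shape)
```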

## [300VW](https://ibug.doc.ic.ac.uk/resources/300-VW/)
Download `300VW_Dataset_2015_12_14.zip` into `$HOME/datasets/landmark-datasets` and unzip it into `$HOME/datasets/landmark-datasets/300VW_Dataset_2015_12_14`.

Use the following commands to extract the raw videos into image frames.
```
python extrct_300VW.py
sh ./cache/Extract300VW.sh
```

Generate the 300-VW dataset list files.
```
python generate_300VW.py
```

## A short demo video sequence

The raw video is `./cache_data/cache/demo-sbr.mp4`.
- Use `ffmpeg -i ./cache/demo-sbr.mp4 ./cache/demo-sbrs/image%04d.png` to extract the frames into `./cache/demo-sbrs/`.

Then use `python demo_list.py` to generate the list file for the demo video.

# Citation
If you use the 300-W dataset, please cite the following papers.
```
@article{sagonas2016300,
  title={300 faces in-the-wild challenge: Database and results},
  author={Sagonas, Christos and Antonakos, Epameinondas and Tzimiropoulos, Georgios and Zafeiriou, Stefanos and Pantic, Maja},
  journal={Image and Vision Computing},
  volume={47},
  pages={3--18},
  year={2016},
  publisher={Elsevier}
}
@inproceedings{sagonas2013300,
  title={300 faces in-the-wild challenge: The first facial landmark localization challenge},
  author={Sagonas, Christos and Tzimiropoulos, Georgios and Zafeiriou, Stefanos and Pantic, Maja},
  booktitle={Proceedings of the IEEE International Conference on Computer Vision Workshops},
  pages={397--403},
  year={2013},
  organization={IEEE}
}
```
If you use the 300-VW dataset, please cite the following papers.
```
@inproceedings{chrysos2015offline,
  title={Offline deformable face tracking in arbitrary videos},
  author={Chrysos, Grigoris G and Antonakos, Epameinondas and Zafeiriou, Stefanos and Snape, Patrick},
  booktitle={Proceedings of the IEEE International Conference on Computer Vision Workshops},
  pages={1--9},
  year={2015}
}
@inproceedings{shen2015first,
  title={The first facial landmark tracking in-the-wild challenge: Benchmark and results},
  author={Shen, Jie and Zafeiriou, Stefanos and Chrysos, Grigoris G and Kossaifi, Jean and Tzimiropoulos, Georgios and Pantic, Maja},
  booktitle={Proceedings of the IEEE International Conference on Computer Vision Workshops},
  pages={50--58},
  year={2015}
}
@inproceedings{tzimiropoulos2015project,
  title={Project-out cascaded regression with an application to face alignment},
  author={Tzimiropoulos, Georgios},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={3659--3667},
  year={2015}
}
```
If you use the AFLW dataset, please cite the following paper.
```
@inproceedings{koestinger2011annotated,
  title={Annotated facial landmarks in the wild: A large-scale, real-world database for facial landmark localization},
  author={Koestinger, Martin and Wohlhart, Paul and Roth, Peter M and Bischof, Horst},
  booktitle={IEEE International Conference on Computer Vision Workshops},
  pages={2144--2151},
  year={2011},
  organization={IEEE}
}
```
--------------------------------------------------------------------------------
/cache_data/aflw_from_mat.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import sys, sqlite3
import os, math
import os.path as osp
from pathlib import Path
import copy
import numpy as np
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info)
from scipy.io import loadmat
import datasets

# Change these paths according to your directories
this_dir = osp.dirname(os.path.abspath(__file__))
SAVE_DIR = osp.join(this_dir, 'lists', 'AFLW')
HOME_STR = 'DOME_HOME'
if HOME_STR not in os.environ: HOME_STR = 'HOME'
assert HOME_STR in os.environ, 'Does not find the HOME dir : {}'.format(HOME_STR)
print ('This dir : {}, HOME : [{}] : {}'.format(this_dir, HOME_STR, os.environ[HOME_STR]))
if not osp.isdir(SAVE_DIR): os.makedirs(SAVE_DIR)
image_dir = osp.join(os.environ[HOME_STR], 'datasets', 'landmark-datasets', 'AFLW', 'images')
annot_dir = osp.join(os.environ[HOME_STR], 'datasets', 'landmark-datasets', 'AFLW', 'annotations')
print ('AFLW image dir : {}'.format(image_dir))
print ('AFLW annotation dir : {}'.format(annot_dir))
assert osp.isdir(image_dir), 'The image dir : {} does not exist'.format(image_dir)
#assert osp.isdir(annot_dir), 'The annotation dir : {} does not exist'.format(annot_dir)


class AFLWFace():
  def __init__(self, index, name, mask, landmark, box):
    self.image_path = name
    self.face_id = index
    # the mat file stores boxes as [x1, x2, y1, y2]; reorder to [x1, y1, x2, y2]
    self.face_box = [float(box[0]), float(box[2]), float(box[1]), float(box[3])]
    mask = np.expand_dims(mask, axis=1)
    landmark = landmark.copy()
    self.landmarks = np.concatenate((landmark, mask), axis=1)

  def get_face_size(self, use_box):
    box = []
    if use_box == 'GTL':
      box = datasets.dataset_utils.PTSconvert2box(self.landmarks.copy().T)
    elif use_box == 'GTB':
      box = [self.face_box[0], self.face_box[1], self.face_box[2], self.face_box[3]]
    else:
      assert False, 'The box indicator not found : {}'.format(use_box)
    assert box[2] > box[0], 'The size of box is not right [{}] : {}'.format(self.face_id, box)
    assert box[3] > box[1], 'The size of box is not right [{}] : {}'.format(self.face_id, box)
    face_size = math.sqrt( float(box[3]-box[1]) * float(box[2]-box[0]) )
    box_str = '{:.2f} {:.2f} {:.2f} {:.2f}'.format(box[0], box[1], box[2], box[3])
    return box_str, face_size

  def check_front(self):
    # a face is treated as frontal when all 19 annotated landmarks
    # are visible and fall inside the face bounding box
    oks = 0
    box = self.face_box
    for idx in range(self.landmarks.shape[0]):
      if bool(self.landmarks[idx,2]):
        x, y = self.landmarks[idx,0], self.landmarks[idx,1]
        if x > self.face_box[0] and x < self.face_box[2]:
          if y > self.face_box[1] and y < self.face_box[3]:
            oks = oks + 1
    return oks == 19

  def __repr__(self):
    return ('{name}(path={image_path}, face-id={face_id})'.format(name=self.__class__.__name__, **self.__dict__))

def save_to_list_file(allfaces, lst_file, image_style_dir, annotation_dir, face_indexes, use_front, use_box):
  save_faces = []
  for index in face_indexes:
    face = allfaces[index]
    if use_front == False or face.check_front():
      save_faces.append( face )
  print ('Prepare to save {} face images into {}'.format(len(save_faces), lst_file))

  lst_file = open(lst_file, 'w')
  all_face_sizes = []
  for face in save_faces:
    image_path = face.image_path
    sub_dir, base_name = image_path.split('/')
    cannot_dir = osp.join(annotation_dir, sub_dir)
    cannot_path = osp.join(cannot_dir, base_name.split('.')[0] + '-{}.pts'.format(face.face_id))
    if not osp.isdir(cannot_dir): os.makedirs(cannot_dir)
    image_path = osp.join(image_style_dir, image_path)
    assert osp.isfile(image_path), 'The image [{}/{}] {} does not exist'.format(index, len(save_faces), image_path)

    if not osp.isfile(cannot_path):
      pts_str = datasets.PTSconvert2str( face.landmarks.T )
      pts_file = open(cannot_path, 'w')
      pts_file.write('{}'.format(pts_str))
      pts_file.close()
    else: pts_str = None

    box_str, face_size = face.get_face_size(use_box)

    lst_file.write('{} {} {} {}\n'.format(image_path, cannot_path, box_str, face_size))
    all_face_sizes.append( face_size )
  lst_file.close()

  all_faces = np.array( all_face_sizes )
  print ('all faces : mean={}, std={}'.format(all_faces.mean(), all_faces.std()))
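
# Note (added comment, not part of the original script): each line written above
# follows the dataset format described in the top-level README, i.e.
# "image_path annotation_path x1 y1 x2 y2 face_size".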

if __name__ == "__main__":
  mat_path = osp.join(this_dir, 'AFLWinfo_release.mat')
  aflwinfo = dict()
  mat = loadmat(mat_path)
  total_image = 24386
  # load train/test splits
  ra = np.squeeze(mat['ra']-1).tolist()
  aflwinfo['train-index'] = ra[:20000]
  aflwinfo['test-index'] = ra[20000:]
  aflwinfo['name-list'] = []
  # load name-list
  for i in range(total_image):
    name = mat['nameList'][i,0][0]
    #name = name[:-4] + '.jpg'
    aflwinfo['name-list'].append( name )
  aflwinfo['mask'] = mat['mask_new'].copy()
  aflwinfo['landmark'] = mat['data'].reshape((total_image, 2, 19))
  aflwinfo['landmark'] = np.transpose(aflwinfo['landmark'], (0,2,1))
  aflwinfo['box'] = mat['bbox'].copy()
  allfaces = []
  for i in range(total_image):
    face = AFLWFace(i, aflwinfo['name-list'][i], aflwinfo['mask'][i], aflwinfo['landmark'][i], aflwinfo['box'][i])
    allfaces.append( face )

  USE_BOXES = ['GTL', 'GTB']
  for USE_BOX in USE_BOXES:
    save_to_list_file(allfaces, osp.join(SAVE_DIR, 'train.{}'.format(USE_BOX)), image_dir, annot_dir, aflwinfo['train-index'], False, USE_BOX)
    save_to_list_file(allfaces, osp.join(SAVE_DIR, 'test.{}'.format(USE_BOX)), image_dir, annot_dir, aflwinfo['test-index'], False, USE_BOX)
    save_to_list_file(allfaces, osp.join(SAVE_DIR, 'test.front.{}'.format(USE_BOX)), image_dir, annot_dir, aflwinfo['test-index'], True, USE_BOX)
    save_to_list_file(allfaces, osp.join(SAVE_DIR, 'all.{}'.format(USE_BOX)), image_dir, annot_dir, aflwinfo['train-index'] + aflwinfo['test-index'], False, USE_BOX)
--------------------------------------------------------------------------------
/cache_data/cache/demo-sbr.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/supervision-by-registration/4e7fcedfafa5b176a6d4e0c035af67430aad2982/cache_data/cache/demo-sbr.mp4
--------------------------------------------------------------------------------
/cache_data/cache/demo-sbrs/.gitignore:
--------------------------------------------------------------------------------
*
--------------------------------------------------------------------------------
/cache_data/cache/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/supervision-by-registration/4e7fcedfafa5b176a6d4e0c035af67430aad2982/cache_data/cache/demo.gif
--------------------------------------------------------------------------------
/cache_data/cache/dir-layout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/supervision-by-registration/4e7fcedfafa5b176a6d4e0c035af67430aad2982/cache_data/cache/dir-layout.png
--------------------------------------------------------------------------------
/cache_data/cache/self.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/supervision-by-registration/4e7fcedfafa5b176a6d4e0c035af67430aad2982/cache_data/cache/self.jpeg
--------------------------------------------------------------------------------
/cache_data/demo_list.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, pdb, sys, glob, cv2
from os import path as osp
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info)
print ('lib-dir : {:}'.format(lib_dir))
from datasets import pil_loader
from utils.file_utils import load_list_from_folders, load_txt_file
# ffmpeg -i shooui.mp4 -filter:v "crop=450:680:10:120" -c:a copy ~/Desktop/demo.mp4


def generate(demo_dir, list_dir, savename, check):
  imagelist, num_image = load_list_from_folders(demo_dir, ext_filter=['png'], depth=1)
  assert num_image == check, 'The number of images is not right : {:} vs {:}'.format(num_image, check)
  if not osp.isdir(list_dir): os.makedirs(list_dir)

  gap, x1, y1, x2, y2 = 5, 5, 5, 450, 680

  imagelist.sort()

  txtfile = open(osp.join(list_dir, savename), 'w')
  for idx, image in enumerate(imagelist):
    if idx < 2 or idx + 2 >= len(imagelist): continue
    box_str = '{:.1f} {:.1f} {:.1f} {:.1f}'.format(gap, gap, x2-x1-gap, y2-y1-gap)
    txtfile.write('{:} {:} {:}\n'.format(image, 'none', box_str))
    txtfile.flush()
  txtfile.close()
  print('there are {:} images for the demo video sequence'.format(num_image))

if __name__ == '__main__':
  HOME_STR = 'DOME_HOME'
  if HOME_STR not in os.environ: HOME_STR = 'HOME'
  assert HOME_STR in os.environ, 'Does not find the HOME dir : {}'.format(HOME_STR)

  this_dir = osp.dirname(os.path.abspath(__file__))
  demo_dir = osp.join(this_dir, 'cache', 'demo-sbrs')
  list_dir = osp.join(this_dir, 'lists', 'demo')
  print ('This dir : {}, HOME : [{}] : {}'.format(this_dir, HOME_STR, os.environ[HOME_STR]))
  generate(demo_dir, list_dir, 'demo-sbr.lst', 275)

  #demo_dir = osp.join(this_dir, 'cache', 'demo-pams')
  #list_dir = osp.join(this_dir, 'lists', 'demo')
  #print ('This dir : {}, HOME : [{}] : {}'.format(this_dir, HOME_STR, os.environ[HOME_STR]))
  #generate(demo_dir, list_dir, 'demo-pam.lst', 253)
--------------------------------------------------------------------------------
/cache_data/extrct_300VW.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
6 | # 7 | import os, pdb, sys, glob 8 | from os import path as osp 9 | 10 | def generate_extract_300vw(P300VW): 11 | allfiles = glob.glob(osp.join(P300VW, '*')) 12 | alldirs = [] 13 | for xfile in allfiles: 14 | if osp.isdir( xfile ): 15 | alldirs.append(xfile) 16 | assert len(alldirs) == 114, 'The directories of 300VW should be 114 not {}'.format(len(alldirs)) 17 | cmds = [] 18 | for xdir in alldirs: 19 | video = osp.join(xdir, 'vid.avi') 20 | exdir = osp.join(xdir, 'extraction') 21 | if not osp.isdir(exdir): os.makedirs(exdir) 22 | cmd = 'ffmpeg -i {:} {:}/%06d.png'.format(video, exdir) 23 | cmds.append( cmd ) 24 | 25 | if not osp.isdir('./cache'): 26 | os.makedirs('./cache') 27 | 28 | with open('./cache/Extract300VW.sh', 'w') as txtfile: 29 | for cmd in cmds: 30 | txtfile.write('{}\n'.format(cmd)) 31 | txtfile.close() 32 | 33 | if __name__ == '__main__': 34 | HOME = 'DOME_HOME' if 'DOME_HOME' in os.environ else 'HOME' 35 | P300VW = osp.join(os.environ[HOME], 'datasets', 'landmark-datasets', '300VW_Dataset_2015_12_14') 36 | generate_extract_300vw(P300VW) 37 | -------------------------------------------------------------------------------- /cache_data/generate_300VW.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import numpy as np 8 | import math, os, pdb, sys, glob 9 | from os import path as osp 10 | from pathlib import Path 11 | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() 12 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) 13 | assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info) 14 | print ('lib-dir : {:}'.format(lib_dir)) 15 | import datasets 16 | 17 | EXPAND_RATIO = 0.0 18 | afterfix='.10' 19 | 20 | 21 | def str2size(box_str): 22 | splits = box_str.split(' ') 23 | x1, y1, x2, y2 = float(splits[0]), float(splits[1]), float(splits[2]), float(splits[3]) 24 | return math.sqrt( (x2-x1) * (y2-y1) ) 25 | 26 | 27 | def load_video_dir(root, dirs, save_dir, save_name): 28 | videos, sparse_videos = [], [] 29 | first_videos = [] 30 | for idx, cdir in enumerate(dirs): 31 | annot_path = osp.join(root, cdir, 'annot') 32 | frame_path = osp.join(root, cdir, 'extraction') 33 | all_frames = glob.glob( osp.join(frame_path, '*.png') ) 34 | all_annots = glob.glob( osp.join(annot_path, '*.pts') ) 35 | assert len(all_frames) == len(all_annots), 'The length is not right for {} : {} vs {}'.format(cdir, len(all_frames), len(all_annots)) 36 | all_frames = sorted(all_frames) 37 | all_annots = sorted(all_annots) 38 | current_video = [] 39 | txtfile = open(osp.join(save_dir, save_name + cdir), 'w') 40 | nonefile = open(osp.join(save_dir, save_name + cdir + '.none'), 'w') 41 | 42 | all_sizes = [] 43 | for frame, annot in zip(all_frames, all_annots): 44 | basename_f = osp.basename(frame) 45 | basename_a = osp.basename(annot) 46 | assert basename_a[:6] == basename_f[:6], 'The name of {} is not right with {}'.format(frame, annot) 47 | current_video.append( (frame, annot) ) 48 | box_str = datasets.dataset_utils.for_generate_box_str(annot, 68, EXPAND_RATIO) 49 | txtfile.write('{} {} {}\n'.format(frame, annot, box_str)) 50 | nonefile.write('{} None {}\n'.format(frame, box_str)) 51 | all_sizes.append( str2size(box_str) ) 52 | if len(current_video) == 1: 53 | first_videos.append( 
(frame, annot) ) 54 | txtfile.close() 55 | nonefile.close() 56 | videos.append( current_video ) 57 | all_sizes = np.array( all_sizes ) 58 | print ('--->>> {:} : [{:02d}/{:02d}] : {:} has {:} frames | face size : mean={:.2f}, std={:.2f}'.format(save_name, idx, len(dirs), cdir, len(all_frames), all_sizes.mean(), all_sizes.std())) 59 | 60 | for jxj, video in enumerate(current_video): 61 | if jxj <= 3 or jxj + 3 >= len(current_video): continue 62 | if jxj % 10 == 3: 63 | sparse_videos.append( video ) 64 | 65 | txtfile = open(osp.join(save_dir, save_name), 'w') 66 | nonefile = open(osp.join(save_dir, save_name + '.none'), 'w') 67 | for video in videos: 68 | for cpair in video: 69 | box_str = datasets.dataset_utils.for_generate_box_str(cpair[1], 68, EXPAND_RATIO) 70 | txtfile.write('{} {} {}\n'.format(cpair[0], cpair[1], box_str)) 71 | nonefile.write('{} {} {}\n'.format(cpair[0], 'None', box_str)) 72 | txtfile.flush() 73 | nonefile.flush() 74 | txtfile.close() 75 | nonefile.close() 76 | 77 | txtfile = open(osp.join(save_dir, save_name + '.sparse' + afterfix), 'w') 78 | nonefile = open(osp.join(save_dir, save_name + '.sparse.none' + afterfix), 'w') 79 | for cpair in sparse_videos: 80 | box_str = datasets.dataset_utils.for_generate_box_str(cpair[1], 68, EXPAND_RATIO) 81 | txtfile.write('{} {} {}\n'.format(cpair[0], cpair[1], box_str)) 82 | nonefile.write('{} {} {}\n'.format(cpair[0], 'None', box_str)) 83 | txtfile.close() 84 | nonefile.close() 85 | 86 | txtfile = open(osp.join(save_dir, save_name + '.first'), 'w') 87 | for cpair in first_videos: 88 | box_str = datasets.dataset_utils.for_generate_box_str(cpair[1], 68, EXPAND_RATIO) 89 | txtfile.write('{} {} {}\n'.format(cpair[0], cpair[1], box_str)) 90 | txtfile.close() 91 | 92 | print ('{} finish save into {}'.format(save_name, save_dir)) 93 | return videos 94 | 95 | def generate_300vw_list(root, save_dir): 96 | assert osp.isdir(root), '{} is not dir'.format(root) 97 | if not osp.isdir(save_dir): os.makedirs(save_dir) 98 | test_1_dirs = [114, 124, 125, 126, 150, 158, 401, 402, 505, 506, 507, 508, 509, 510, 511, 514, 515, 518, 519, 520, 521, 522, 524, 525, 537, 538, 540, 541, 546, 547, 548] 99 | test_2_dirs = [203, 208, 211, 212, 213, 214, 218, 224, 403, 404, 405, 406, 407, 408, 409, 412, 550, 551, 553] 100 | test_3_dirs = [410, 411, 516, 517, 526, 528, 529, 530, 531, 533, 557, 558, 559, 562] 101 | train_dirs = ['009', '059', '002', '033', '020', '035', '018', '119', '120', '025', '205', '047', '007', '013', '004', '143', 102 | '034', '028', '053', '225', '041', '010', '031', '046', '049', '011', '027', '003', '016', '160', '113', '001', '029', '043', 103 | '112', '138', '144', '204', '057', '015', '044', '048', '017', '115', '223', '037', '123', '019', '039', '022'] 104 | 105 | test_1_dirs, test_2_dirs, test_3_dirs = [ '{}'.format(x) for x in test_1_dirs], [ '{}'.format(x) for x in test_2_dirs], [ '{}'.format(x) for x in test_3_dirs] 106 | #all_dirs = os.listdir(root) 107 | #train_dirs = set(all_dirs) - set(test_1_dirs) - set(test_2_dirs) - set(test_3_dirs) - set(['ReadMe.txt', 'extra.zip']) 108 | #train_dirs = list( train_dirs ) 109 | assert len(train_dirs) == 50, 'The length of train_dirs is not right : {}'.format( len(train_dirs) ) 110 | assert len(test_3_dirs) == 14, 'The length of test_3_dirs is not right : {}'.format( len(test_3_dirs) ) 111 | 112 | load_video_dir(root, train_dirs, save_dir, '300VW.train.lst') 113 | load_video_dir(root, test_1_dirs, save_dir, '300VW.test-1.lst') 114 | load_video_dir(root, test_2_dirs, save_dir, 
'300VW.test-2.lst')
  load_video_dir(root, test_3_dirs, save_dir, '300VW.test-3.lst')

if __name__ == '__main__':
  HOME_STR = 'DOME_HOME'
  if HOME_STR not in os.environ: HOME_STR = 'HOME'
  assert HOME_STR in os.environ, 'Does not find the HOME dir : {}'.format(HOME_STR)

  this_dir = osp.dirname(os.path.abspath(__file__))
  SAVE_DIR = osp.join(this_dir, 'lists', '300VW')
  print ('This dir : {}, HOME : [{}] : {}'.format(this_dir, HOME_STR, os.environ[HOME_STR]))
  path_300vw = osp.join(os.environ[HOME_STR], 'datasets', 'landmark-datasets', '300VW_Dataset_2015_12_14')
  generate_300vw_list(path_300vw, SAVE_DIR)
--------------------------------------------------------------------------------
/cache_data/generate_300W.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, pdb, sys, glob
from os import path as osp
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info)
print ('lib-dir : {:}'.format(lib_dir))
import datasets
from scipy.io import loadmat
from utils.file_utils import load_list_from_folders, load_txt_file

def load_box(mat_path, cdir):
  mat = loadmat(mat_path)
  mat = mat['bounding_boxes']
  mat = mat[0]
  assert len(mat) > 0, 'The length of this mat file should be greater than 0 vs {}'.format(len(mat))
  all_object = []
  for cobject in mat:
    name = cobject[0][0][0][0]
    bb_detector = cobject[0][0][1][0]
    bb_ground_t = cobject[0][0][2][0]
    image_path = osp.join(cdir, name)
    image_path = image_path[:-4]
    all_object.append( (image_path, bb_detector, bb_ground_t) )
  return all_object

def load_mats(lists):
  all_objects = []
  for dataset in lists:
    cobjects = load_box(dataset[0], dataset[1])
    all_objects = all_objects + cobjects
  return all_objects

def return_box(image_path, pts_path, all_dict, USE_BOX):
  image_path = image_path[:-4]
  assert image_path in all_dict, '{} not found'.format(image_path)
  np_boxes = all_dict[ image_path ]
  if USE_BOX == 'GTL':
    box_str = datasets.dataset_utils.for_generate_box_str(pts_path, 68, 0)
  elif USE_BOX == 'GTB':
    box_str = '{:.4f} {:.4f} {:.4f} {:.4f}'.format(np_boxes[1][0], np_boxes[1][1], np_boxes[1][2], np_boxes[1][3])
  elif USE_BOX == 'DET':
    box_str = '{:.4f} {:.4f} {:.4f} {:.4f}'.format(np_boxes[0][0], np_boxes[0][1], np_boxes[0][2], np_boxes[0][3])
  else:
    assert False, 'The box indicator not found : {}'.format(USE_BOX)
  return box_str

def load_all_300w(root_dir):
  print ('300W Root Dir : {}'.format(root_dir))
  mat_dir = osp.join(root_dir, 'Bounding_Boxes')
  pairs = [(osp.join(mat_dir, 'bounding_boxes_lfpw_testset.mat'), osp.join(root_dir, 'lfpw', 'testset')),
           (osp.join(mat_dir, 'bounding_boxes_lfpw_trainset.mat'), osp.join(root_dir, 'lfpw', 'trainset')),
           (osp.join(mat_dir, 'bounding_boxes_ibug.mat'), osp.join(root_dir, 'ibug')),
           (osp.join(mat_dir, 'bounding_boxes_afw.mat'), osp.join(root_dir, 'afw')),
           (osp.join(mat_dir, 'bounding_boxes_helen_testset.mat'), osp.join(root_dir, 'helen', 'testset')),
           (osp.join(mat_dir, 'bounding_boxes_helen_trainset.mat'), osp.join(root_dir, 'helen', 'trainset')),]

  all_datas = load_mats(pairs)
  data_dict = {}
  for i, cpair in enumerate(all_datas):
    image_path = cpair[0].replace(' ', '')
    data_dict[ image_path ] = (cpair[1], cpair[2])
  return data_dict

def generate_300w_list(root, save_dir, box_data, SUFFIX):
  assert osp.isdir(root), '{} is not dir'.format(root)
  #assert osp.isdir(save_dir), '{} is not dir'.format(save_dir)
  if not osp.isdir(save_dir): os.makedirs(save_dir)
  train_length, common_length, challenge_length = 3148, 554, 135
  subsets = ['afw', 'helen', 'ibug', 'lfpw']
  dir_lists = [osp.join(root, subset) for subset in subsets]
  imagelist, num_image = load_list_from_folders(dir_lists, ext_filter=['png', 'jpg', 'jpeg'], depth=3)

  train_set, common_set, challenge_set = [], [], []
  for image_path in imagelist:
    name, ext = osp.splitext(image_path)
    anno_path = name + '.pts'
    assert osp.isfile(anno_path), 'annotation for : {} does not exist'.format(image_path)
    if name.find('ibug') > 0:
      challenge_set.append( (image_path, anno_path) )
    elif name.find('afw') > 0:
      train_set.append( (image_path, anno_path) )
    elif name.find('helen') > 0 or name.find('lfpw') > 0:
      if name.find('trainset') > 0:
        train_set.append( (image_path, anno_path) )
      elif name.find('testset') > 0:
        common_set.append( (image_path, anno_path) )
      else:
        raise Exception('Unknown name : {}'.format(name))
    else:
      raise Exception('Unknown name : {}'.format(name))
  assert len(train_set) == train_length, 'The length is not right for train : {} vs {}'.format(len(train_set), train_length)
  assert len(common_set) == common_length, 'The length is not right for common : {} vs {}'.format(len(common_set), common_length)
  assert len(challenge_set) == challenge_length, 'The length is not right for challenge : {} vs {}'.format(len(challenge_set), challenge_length)

  with open(osp.join(save_dir, '300w.train.' + SUFFIX), 'w') as txtfile:
    for cpair in train_set:
      #box_str = datasets.dataset_utils.for_generate_box_str(cpair[1], 68, EXPAND_RATIO)
      box_str = return_box(cpair[0], cpair[1], box_data, SUFFIX)
      txtfile.write('{} {} {}\n'.format(cpair[0], cpair[1], box_str))
    txtfile.close()

  with open(osp.join(save_dir, '300w.test.common.' + SUFFIX), 'w') as txtfile:
    for cpair in common_set:
      #box_str = datasets.dataset_utils.for_generate_box_str(cpair[1], 68, EXPAND_RATIO)
      box_str = return_box(cpair[0], cpair[1], box_data, SUFFIX)
      txtfile.write('{} {} {}\n'.format(cpair[0], cpair[1], box_str))
    txtfile.close()

  with open(osp.join(save_dir, '300w.test.challenge.' + SUFFIX), 'w') as txtfile:
    for cpair in challenge_set:
      #box_str = datasets.dataset_utils.for_generate_box_str(cpair[1], 68, EXPAND_RATIO)
      box_str = return_box(cpair[0], cpair[1], box_data, SUFFIX)
      txtfile.write('{} {} {}\n'.format(cpair[0], cpair[1], box_str))
    txtfile.close()

  with open(osp.join(save_dir, '300w.test.full.' + SUFFIX), 'w') as txtfile:
    fullset = common_set + challenge_set
    for cpair in fullset:
      #box_str = datasets.dataset_utils.for_generate_box_str(cpair[1], 68, EXPAND_RATIO)
      box_str = return_box(cpair[0], cpair[1], box_data, SUFFIX)
      txtfile.write('{} {} {}\n'.format(cpair[0], cpair[1], box_str))
    txtfile.close()
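
# Note (added for orientation, not part of the original script): with a correct
# 300-W layout, the generated lists contain 3148 training faces, 554 common-subset
# faces, 135 challenging-subset faces, and 689 full-set faces, matching the
# assertions inside generate_300w_list above.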

if __name__ == '__main__':
  HOME_STR = 'DOME_HOME'
  if HOME_STR not in os.environ: HOME_STR = 'HOME'
  assert HOME_STR in os.environ, 'Does not find the HOME dir : {}'.format(HOME_STR)
  this_dir = osp.dirname(os.path.abspath(__file__))
  SAVE_DIR = osp.join(this_dir, 'lists', '300W')
  print ('This dir : {}, HOME : [{}] : {}'.format(this_dir, HOME_STR, os.environ[HOME_STR]))
  path_300w = osp.join( os.environ[HOME_STR], 'datasets', 'landmark-datasets', '300W')
  USE_BOXES = ['GTB', 'DET']
  box_datas = load_all_300w(path_300w)

  for USE_BOX in USE_BOXES:
    generate_300w_list(path_300w, SAVE_DIR, box_datas, USE_BOX)
--------------------------------------------------------------------------------
/cache_data/init_path.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
"""Set up paths."""

import sys, os
from os import path as osp

def add_path(path):
  if path not in sys.path:
    sys.path.insert(0, path)

this_dir = osp.dirname(osp.abspath(__file__))

# Add lib to PYTHONPATH
lib_path = osp.abspath(osp.join(this_dir, '..', 'lib'))
add_path(lib_path)
--------------------------------------------------------------------------------
/configs/Detector.config:
--------------------------------------------------------------------------------
{
  "arch" : ["str", "cpm_vgg16"],
  "stages" : ["int", "3"],
  "dilation": ["int", [1]],
  "pooling" : ["bool", [1, 1, 1]],
  "downsample": ["int", 8],
  "argmax" : ["int", "4"],
  "pretrained" : ["bool", [1]]
}
--------------------------------------------------------------------------------
/configs/LK.SGD.config:
--------------------------------------------------------------------------------
{
  "optimizer" : ["str", "sgd"],
  "LR" : ["float", "0.0001"],
  "momentum" : ["float", "0.9"],
  "Decay" : ["float", "0.0005"],
  "nesterov" : ["bool", "1"],
  "criterion" : ["str", "MSE-none"],
  "loss_norm" : ["bool", "1"],
  "lossnorm" : ["bool", "1"],
  "epochs" : ["int", 50],
  "schedule" : ["int", [30, 35, 40, 45]],
  "gamma" : ["float", 0.5]
}
--------------------------------------------------------------------------------
/configs/SGD.config:
--------------------------------------------------------------------------------
{
  "optimizer" : ["str", "sgd"],
  "LR" : ["float", "0.00005"],
  "momentum" : ["float", "0.9"],
  "Decay" : ["float", "0.0005"],
  "nesterov" : ["bool", "1"],
  "criterion" : ["str", "MSE-none"],
  "loss_norm" : ["bool", "1"],
  "lossnorm" : ["bool", "1"],
  "epochs" : ["int", 50],
  "schedule" : ["int", [30, 40]],
  "gamma" : ["float", 0.5]
}
--------------------------------------------------------------------------------
/configs/lk.config:
--------------------------------------------------------------------------------
1 | { 2 | "start" : ["int", 0], 3 | "steps" : ["int", 20], 4 | "window" : ["int", 8], 5 | "weight" : ["float", 0.2], 6 | "stable" : ["int", 1], 7 | "conf_thresh" : ["float", 0.3], 8 | "forward_max" : ["float", 2], 9 | "fb_thresh" : ["float", 1], 10 | "eps" : ["float", 0.00001] 11 | } 12 | -------------------------------------------------------------------------------- /configs/mix.lk.config: -------------------------------------------------------------------------------- 1 | { 2 | "start" : ["int", 0], 3 | "steps" : ["int", 20], 4 | "window" : ["int", 8], 5 | "weight" : ["float", 0.1], 6 | "stable" : ["int", 1], 7 | "conf_thresh" : ["float", 0.3], 8 | "forward_max" : ["float", 2], 9 | "fb_thresh" : ["float", 1], 10 | "eps" : ["float", 0.00001] 11 | } 12 | -------------------------------------------------------------------------------- /exps/basic_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import division 8 | 9 | import sys, time, torch, random, argparse, PIL 10 | from PIL import ImageFile 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | from copy import deepcopy 13 | from pathlib import Path 14 | from shutil import copyfile 15 | import numbers, numpy as np 16 | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() 17 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) 18 | assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info) 19 | from config_utils import obtain_basic_args 20 | from procedure import prepare_seed, save_checkpoint, basic_train as train, basic_eval_all as eval_all 21 | from datasets import GeneralDataset as Dataset 22 | from xvision import transforms 23 | from log_utils import Logger, AverageMeter, time_for_file, convert_secs2time, time_string 24 | from config_utils import load_configure 25 | from models import obtain_model 26 | from optimizer import obtain_optimizer 27 | 28 | def main(args): 29 | assert torch.cuda.is_available(), 'CUDA is not available.' 
  torch.backends.cudnn.enabled = True
  torch.backends.cudnn.benchmark = True
  prepare_seed(args.rand_seed)

  logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
  logger = Logger(args.save_path, logstr)
  logger.log('Main Function with logger : {:}'.format(logger))
  logger.log('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    logger.log('{:16} : {:}'.format(name, value))
  logger.log("Python version : {}".format(sys.version.replace('\n', ' ')))
  logger.log("Pillow version : {}".format(PIL.__version__))
  logger.log("PyTorch version : {}".format(torch.__version__))
  logger.log("cuDNN version : {}".format(torch.backends.cudnn.version()))

  # General Data Augmentation
  mean_fill = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
  assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
  train_transform = [transforms.PreCrop(args.pre_crop_expand)]
  train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
  train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
  #if args.arg_flip:
  #  train_transform += [transforms.AugHorizontalFlip()]
  if args.rotate_max:
    train_transform += [transforms.AugRotate(args.rotate_max)]
  train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
  train_transform += [transforms.ToTensor(), normalize]
  train_transform = transforms.Compose( train_transform )

  eval_transform = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)), transforms.ToTensor(), normalize])
  assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)

  # Model Configure Load
  model_config = load_configure(args.model_config, logger)
  args.sigma = args.sigma * args.scale_eval
  logger.log('Real Sigma : {:}'.format(args.sigma))

  # Training Dataset
  train_data = Dataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
  train_data.load_list(args.train_lists, args.num_pts, True)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)


  # Evaluation Dataloader
  eval_loaders = []
  if args.eval_vlists is not None:
    for eval_vlist in args.eval_vlists:
      eval_vdata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_vdata.load_list(eval_vlist, args.num_pts, True)
      eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_vloader, True))

  if args.eval_ilists is not None:
    for eval_ilist in args.eval_ilists:
      eval_idata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_idata.load_list(eval_ilist, args.num_pts, True)
      eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_iloader, False))

  # Define network
  logger.log('configure : {:}'.format(model_config))
  net = obtain_model(model_config, args.num_pts + 1)
  assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample)
  logger.log("=> network :\n {}".format(net))

  logger.log('Training-data : {:}'.format(train_data))
  for i, eval_loader in enumerate(eval_loaders):
    eval_loader, is_video = eval_loader
    logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))

  logger.log('arguments : {:}'.format(args))

  opt_config = load_configure(args.opt_config, logger)

  if hasattr(net, 'specify_parameter'):
    net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
  else:
    net_param_dict = net.parameters()

  optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
  logger.log('criterion : {:}'.format(criterion))
  net, criterion = net.cuda(), criterion.cuda()
  net = torch.nn.DataParallel(net)

  last_info = logger.last_info()
  if last_info.exists():
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch'] + 1
    checkpoint = torch.load(last_info['last_checkpoint'])
    assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(logger.last_info(), checkpoint['epoch']))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0


  if args.eval_once:
    logger.log("=> only evaluate the model once")
    eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
    logger.close() ; return


  # Main Training and Evaluation Loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(start_epoch, opt_config.epochs):

    scheduler.step()
    need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True)
    epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
    LRs = scheduler.get_lr()
    logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))

    # train for one epoch
    train_loss, train_nme = train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)
    # log the results
    logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss, train_nme*100))

    # save the checkpoint and the last-info file for resuming
    save_path = save_checkpoint({
      'epoch': epoch,
      'args' : deepcopy(args),
      'arch' : model_config.arch,
      'state_dict': net.state_dict(),
      'detector' : net.state_dict(),
      'scheduler' : scheduler.state_dict(),
      'optimizer' : optimizer.state_dict(),
    }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

    last_info = save_checkpoint({
      'epoch': epoch,
      'last_checkpoint': save_path,
    }, logger.last_info(), logger)
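
    # Note (added comment, not in the original source): the 'detector' entry
    # duplicates 'state_dict' so that stand-alone tools such as exps/eval.py,
    # which load snapshot['detector'], can consume these snapshots directly.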
epoch_str), logger)
166 | 
167 |     last_info = save_checkpoint({
168 |       'epoch': epoch,
169 |       'last_checkpoint': save_path,
170 |       }, logger.last_info(), logger)
171 | 
172 |     eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
173 |     logger.log('NME Results : {:}'.format( eval_results ))
174 | 
175 |     # measure elapsed time
176 |     epoch_time.update(time.time() - start_time)
177 |     start_time = time.time()
178 | 
179 |   logger.close()
180 | 
181 | if __name__ == '__main__':
182 |   args = obtain_basic_args()
183 |   main(args)
184 | -------------------------------------------------------------------------------- /exps/eval.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from __future__ import division
8 | 
9 | import sys, time, torch, random, argparse, PIL
10 | from PIL import ImageFile
11 | ImageFile.LOAD_TRUNCATED_IMAGES = True
12 | from copy import deepcopy
13 | from pathlib import Path
14 | import numbers, numpy as np
15 | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
16 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
17 | assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info)
18 | from datasets import GeneralDataset as Dataset
19 | from xvision import transforms, draw_image_by_points
20 | from models import obtain_model, remove_module_dict
21 | from config_utils import load_configure
22 | 
23 | 
24 | def evaluate(args):
25 |   assert torch.cuda.is_available(), 'CUDA is not available.'
26 |   torch.backends.cudnn.enabled = True
27 |   torch.backends.cudnn.benchmark = True
28 | 
29 |   print ('The image is {:}'.format(args.image))
30 |   print ('The model is {:}'.format(args.model))
31 |   snapshot = Path(args.model)
32 |   assert snapshot.exists(), 'The model path {:} does not exist'.format(snapshot)
33 |   print ('The face bounding box is {:}'.format(args.face))
34 |   assert len(args.face) == 4, 'Invalid face input : {:}'.format(args.face)
35 |   snapshot = torch.load(snapshot)
36 | 
37 |   # General Data Augmentation
38 |   mean_fill = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
39 |   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
40 |                                    std=[0.229, 0.224, 0.225])
41 | 
42 |   param = snapshot['args']
43 | 
44 |   eval_transform = transforms.Compose([transforms.PreCrop(param.pre_crop_expand), transforms.TrainScale2WH((param.crop_width, param.crop_height)), transforms.ToTensor(), normalize])
45 |   model_config = load_configure(param.model_config, None)
46 |   dataset = Dataset(eval_transform, param.sigma, model_config.downsample, param.heatmap_type, param.data_indicator)
47 |   dataset.reset(param.num_pts)
48 | 
49 |   net = obtain_model(model_config, param.num_pts + 1)
50 |   net = net.cuda()
51 |   weights = remove_module_dict(snapshot['detector'])
52 |   net.load_state_dict(weights)
53 |   print ('Prepare input data')
54 |   [image, _, _, _, _, _, cropped_size], meta = dataset.prepare_input(args.image, args.face)
55 |   inputs = image.unsqueeze(0).cuda()
56 |   # network forward
57 |   with torch.no_grad():
58 |     batch_heatmaps, batch_locs, batch_scos = net(inputs)
59 |   # obtain the locations on the image in the original size
60 |   cpu = torch.device('cpu')
61 |   np_batch_locs, np_batch_scos, cropped_size = batch_locs.to(cpu).numpy(), batch_scos.to(cpu).numpy(), cropped_size.numpy()
62 |   locations, scores = np_batch_locs[0,:-1,:], np.expand_dims(np_batch_scos[0,:-1], -1)
63 | 
64 |   scale_h, scale_w = cropped_size[0] * 1. / inputs.size(-2) , cropped_size[1] * 1. / inputs.size(-1)
65 | 
66 |   locations[:, 0], locations[:, 1] = locations[:, 0] * scale_w + cropped_size[2], locations[:, 1] * scale_h + cropped_size[3]
67 |   prediction = np.concatenate((locations, scores), axis=1).transpose(1,0)
68 | 
69 |   print ('the coordinates for {:} facial landmarks:'.format(param.num_pts))
70 |   for i in range(param.num_pts):
71 |     point = prediction[:, i]
72 |     print ('the {:02d}/{:02d}-th point : ({:.1f}, {:.1f}), score = {:.2f}'.format(i, param.num_pts, float(point[0]), float(point[1]), float(point[2])))
73 | 
74 |   if args.save:
75 |     resize = 512
76 |     image = draw_image_by_points(args.image, prediction, 2, (255, 0, 0), args.face, resize)
77 |     image.save(args.save)
78 |     print ('save the visualization results into {:}'.format(args.save))
79 |   else:
80 |     print ('ignore the visualization procedure')
81 | 
82 | 
83 | if __name__ == '__main__':
84 |   parser = argparse.ArgumentParser(description='Evaluate a single image by the trained model', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
85 |   parser.add_argument('--image', type=str, help='The evaluation image path.')
86 |   parser.add_argument('--model', type=str, help='The path to the saved detector snapshot.')
87 |   parser.add_argument('--face',  nargs='+', type=float, help='The coordinate [x1,y1,x2,y2] of a face')
88 |   parser.add_argument('--save',  type=str, help='The path to save the visualized results.')
89 |   args = parser.parse_args()
90 |   evaluate(args)
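
For reference, a hypothetical command line for the script above. The sample image ships in `cache_data/cache/`; the snapshot path merely follows the `{arch}-{epoch_str}.pth` naming used by `save_checkpoint` in `basic_main.py` (with `cpm_vgg16` as the architecture), and the face-box numbers are placeholders to replace with a real detection for your image:

```
python ./exps/eval.py --image ./cache_data/cache/self.jpeg \
    --model ./snapshots/300W-DET/cpm_vgg16-epoch-049-050.pth \
    --face 250 150 900 1100 --save ./cache_data/cache/self-eval.jpg
```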
91 | -------------------------------------------------------------------------------- /exps/lk_main.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from __future__ import division
8 | 
9 | import sys, time, torch, random, argparse, PIL
10 | from PIL import ImageFile
11 | ImageFile.LOAD_TRUNCATED_IMAGES = True
12 | from copy import deepcopy
13 | from pathlib import Path
14 | from shutil import copyfile
15 | import numbers, numpy as np
16 | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
17 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
18 | assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info)
19 | from config_utils import obtain_lk_args as obtain_args
20 | from procedure import prepare_seed, save_checkpoint, lk_train as train, basic_eval_all as eval_all
21 | from datasets import VideoDataset as VDataset, GeneralDataset as IDataset
22 | from xvision import transforms
23 | from log_utils import Logger, AverageMeter, time_for_file, convert_secs2time, time_string
24 | from config_utils import load_configure
25 | from models import obtain_LK as obtain_model, remove_module_dict
26 | from optimizer import obtain_optimizer
27 | 
28 | def main(args):
29 |   assert torch.cuda.is_available(), 'CUDA is not available.'
30 |   torch.backends.cudnn.enabled = True
31 |   torch.backends.cudnn.benchmark = True
32 |   prepare_seed(args.rand_seed)
33 | 
34 |   logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
35 |   logger = Logger(args.save_path, logstr)
36 |   logger.log('Main Function with logger : {:}'.format(logger))
37 |   logger.log('Arguments : -------------------------------')
38 |   for name, value in args._get_kwargs():
39 |     logger.log('{:16} : {:}'.format(name, value))
40 |   logger.log("Python version  : {}".format(sys.version.replace('\n', ' ')))
41 |   logger.log("Pillow version  : {}".format(PIL.__version__))
42 |   logger.log("PyTorch version : {}".format(torch.__version__))
43 |   logger.log("cuDNN version   : {}".format(torch.backends.cudnn.version()))
44 | 
45 |   # General Data Augmentation
46 |   mean_fill = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
47 |   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
48 |                                    std=[0.229, 0.224, 0.225])
49 |   assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
50 |   train_transform  = [transforms.PreCrop(args.pre_crop_expand)]
51 |   train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
52 |   train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
53 |   #if args.arg_flip:
54 |   #  train_transform += [transforms.AugHorizontalFlip()]
55 |   if args.rotate_max:
56 |     train_transform += [transforms.AugRotate(args.rotate_max)]
57 |   train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
58 |   train_transform += [transforms.ToTensor(), normalize]
59 |   train_transform  = transforms.Compose( train_transform )
60 | 
61 |   eval_transform = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)), transforms.ToTensor(), normalize])
62 |   assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)
63 | 
64 |   # Model Configure Load
65 |   model_config = load_configure(args.model_config, logger)
66 |   args.sigma = args.sigma * args.scale_eval
67 |   logger.log('Real Sigma : {:}'.format(args.sigma))
68 | 
69 |   # Training Dataset
70 |   train_data = VDataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator, args.video_parser)
71 |   train_data.load_list(args.train_lists, args.num_pts, True)
72 |   train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
73 | 
74 | 
75 |   # Evaluation Dataloader
76 |   eval_loaders = []
77 |   if args.eval_vlists is not None:
78 |     for eval_vlist in args.eval_vlists:
79 |       eval_vdata = IDataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
80 |       eval_vdata.load_list(eval_vlist, args.num_pts, True)
81 |       eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False,
82 |                                                  num_workers=args.workers, pin_memory=True)
83 |       eval_loaders.append((eval_vloader, True))
84 | 
85 |   if args.eval_ilists is not None:
86 |     for eval_ilist in args.eval_ilists:
87 |       eval_idata = IDataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
88 |       eval_idata.load_list(eval_ilist, args.num_pts, True)
89 |       eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
90 |
num_workers=args.workers, pin_memory=True) 91 | eval_loaders.append((eval_iloader, False)) 92 | 93 | # Define network 94 | lk_config = load_configure(args.lk_config, logger) 95 | logger.log('model configure : {:}'.format(model_config)) 96 | logger.log('LK configure : {:}'.format(lk_config)) 97 | net = obtain_model(model_config, lk_config, args.num_pts + 1) 98 | assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample) 99 | logger.log("=> network :\n {}".format(net)) 100 | 101 | logger.log('Training-data : {:}'.format(train_data)) 102 | for i, eval_loader in enumerate(eval_loaders): 103 | eval_loader, is_video = eval_loader 104 | logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset)) 105 | 106 | logger.log('arguments : {:}'.format(args)) 107 | 108 | opt_config = load_configure(args.opt_config, logger) 109 | 110 | if hasattr(net, 'specify_parameter'): 111 | net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay) 112 | else: 113 | net_param_dict = net.parameters() 114 | 115 | optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger) 116 | logger.log('criterion : {:}'.format(criterion)) 117 | net, criterion = net.cuda(), criterion.cuda() 118 | net = torch.nn.DataParallel(net) 119 | 120 | last_info = logger.last_info() 121 | if last_info.exists(): 122 | logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) 123 | last_info = torch.load(last_info) 124 | start_epoch = last_info['epoch'] + 1 125 | checkpoint = torch.load(last_info['last_checkpoint']) 126 | assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch']) 127 | net.load_state_dict(checkpoint['state_dict']) 128 | optimizer.load_state_dict(checkpoint['optimizer']) 129 | scheduler.load_state_dict(checkpoint['scheduler']) 130 | logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done" .format(logger.last_info(), checkpoint['epoch'])) 131 | elif args.init_model is not None: 132 | init_model = Path(args.init_model) 133 | assert init_model.exists(), 'init-model {:} does not exist'.format(init_model) 134 | checkpoint = torch.load(init_model) 135 | checkpoint = remove_module_dict(checkpoint['state_dict'], True) 136 | net.module.detector.load_state_dict( checkpoint ) 137 | logger.log("=> initialize the detector : {:}".format(init_model)) 138 | start_epoch = 0 139 | else: 140 | logger.log("=> do not find the last-info file : {:}".format(last_info)) 141 | start_epoch = 0 142 | 143 | detector = torch.nn.DataParallel(net.module.detector) 144 | 145 | eval_results = eval_all(args, eval_loaders, detector, criterion, 'start-eval', logger, opt_config) 146 | if args.eval_once: 147 | logger.log("=> only evaluate the model once") 148 | logger.close() ; return 149 | 150 | # Main Training and Evaluation Loop 151 | start_time = time.time() 152 | epoch_time = AverageMeter() 153 | for epoch in range(start_epoch, opt_config.epochs): 154 | 155 | scheduler.step() 156 | need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True) 157 | epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs) 158 | LRs = scheduler.get_lr() 159 | logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config)) 160 | 161 | # train for one epoch 162 | train_loss = 
train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config, lk_config, epoch>=lk_config.start)
163 |     # log the results
164 |     logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}'.format(time_string(), epoch_str, train_loss))
165 | 
166 |     # remember best prec@1 and save checkpoint
167 |     save_path = save_checkpoint({
168 |       'epoch': epoch,
169 |       'args' : deepcopy(args),
170 |       'arch' : model_config.arch,
171 |       'state_dict': net.state_dict(),
172 |       'detector'  : detector.state_dict(),
173 |       'scheduler' : scheduler.state_dict(),
174 |       'optimizer' : optimizer.state_dict(),
175 |       }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)
176 | 
177 |     last_info = save_checkpoint({
178 |       'epoch': epoch,
179 |       'last_checkpoint': save_path,
180 |       }, logger.last_info(), logger)
181 | 
182 |     eval_results = eval_all(args, eval_loaders, detector, criterion, epoch_str, logger, opt_config)
183 | 
184 |     # measure elapsed time
185 |     epoch_time.update(time.time() - start_time)
186 |     start_time = time.time()
187 | 
188 |   logger.close()
189 | 
190 | if __name__ == '__main__':
191 |   args = obtain_args()
192 |   main(args)
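
A sketch of how `lk_main.py` might be launched for supervision-by-registration, assuming a detector checkpoint produced by `basic_main.py` is passed via `--init_model`; all paths and numeric values below are illustrative placeholders rather than the project's reference settings (see `scripts/demo_sbr.sh` and `scripts/sbr_example.sh` for the maintained commands). The `--video_parser` token follows the `<name>-<left>-<right>` pattern that `lib/datasets/parse_utils.py` splits on `'-'`, so a value like `SBR-1-1` asks `VideoDataset` for one frame before and one frame after each annotated frame:

```
python ./exps/lk_main.py \
    --train_lists ./cache_data/lists/300VW/300VW.train.lst \
    --eval_ilists ./cache_data/lists/300W/300w.test.common.lst \
    --num_pts 68 --data_indicator 300W-68 --video_parser SBR-1-1 \
    --model_config ./configs/Detector.config \
    --opt_config   ./configs/LK.SGD.config \
    --lk_config    ./configs/lk.config \
    --init_model   ./snapshots/300W-DET/cpm_vgg16-epoch-049-050.pth \
    --save_path    ./snapshots/300W-SBR \
    --pre_crop_expand 0.2 --sigma 4 --batch_size 8 --heatmap_type gaussian \
    --crop_height 256 --crop_width 256 --crop_perturb_max 30 \
    --scale_prob 1.0 --scale_min 0.9 --scale_max 1.1 --scale_eval 1 --rotate_max 20
```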
193 | -------------------------------------------------------------------------------- /exps/vis.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from __future__ import division
8 | 
9 | import os, sys, time, random, argparse, PIL
10 | from os import path as osp
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 | from copy import deepcopy
14 | from pathlib import Path
15 | import numbers, numpy as np
16 | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
17 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
18 | assert sys.version_info.major == 3, 'Please upgrade from {:} to Python 3.x'.format(sys.version_info)
19 | from xvision import draw_image_by_points
20 | from xvision import Eval_Meta
21 | 
22 | def visualize(args):
23 | 
24 |   print ('The result file is {:}'.format(args.meta))
25 |   print ('The save path is {:}'.format(args.save))
26 |   meta = Path(args.meta)
27 |   save = Path(args.save)
28 |   assert meta.exists(), 'The meta file {:} does not exist'.format(meta)
29 |   xmeta = Eval_Meta()
30 |   xmeta.load(meta)
31 |   print ('this meta file has {:} predictions'.format(len(xmeta)))
32 |   if not save.exists(): os.makedirs( args.save )
33 |   for i in range(len(xmeta)):
34 |     image, prediction = xmeta.image_lists[i], xmeta.predictions[i]
35 |     name = osp.basename(image)
36 |     image = draw_image_by_points(image, prediction, 2, (255, 0, 0), False, False)
37 |     path = save / name
38 |     image.save(path)
39 |     print ('{:03d}-th image is saved into {:}'.format(i, path))
40 | 
41 | 
42 | if __name__ == '__main__':
43 |   parser = argparse.ArgumentParser(description='visualize the predictions saved in an evaluation meta file', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
44 |   parser.add_argument('--meta', type=str, help='The path to the saved evaluation meta file.')
45 |   parser.add_argument('--save', type=str, help='The path to save the visualized results.')
46 |   args = parser.parse_args()
47 |   visualize(args)
48 | -------------------------------------------------------------------------------- /lib/config_utils/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from .configure_utils import load_configure
8 | from .basic_args import obtain_args as obtain_basic_args
9 | from .lk_args import obtain_args as obtain_lk_args
10 | -------------------------------------------------------------------------------- /lib/config_utils/basic_args.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os, sys, time, random, argparse
8 | 
9 | def obtain_args():
10 |   parser = argparse.ArgumentParser(description='Train facial landmark detectors on 300-W or AFLW', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
11 |   parser.add_argument('--train_lists',      type=str, nargs='+', help='The list file path(s) to the training dataset.')
12 |   parser.add_argument('--eval_vlists',      type=str, nargs='+', help='The list file path to the video testing dataset.')
13 |   parser.add_argument('--eval_ilists',      type=str, nargs='+', help='The list file path to the image testing dataset.')
14 |   parser.add_argument('--num_pts',          type=int, help='Number of points.')
15 |   parser.add_argument('--model_config',     type=str, help='The path to the model configuration')
16 |   parser.add_argument('--opt_config',       type=str, help='The path to the optimizer configuration')
17 |   # Data Generation
18 |   parser.add_argument('--heatmap_type',     type=str, choices=['gaussian','laplacian'], help='The method for generating the heatmap.')
19 |   parser.add_argument('--data_indicator',   type=str, default='300W-68', help='The dataset indicator.')
20 |   # Data Transform
21 |   parser.add_argument('--pre_crop_expand',  type=float, help='parameters for pre-crop expand ratio')
22 |   parser.add_argument('--sigma',            type=float, help='sigma distance for CPM.')
23 |   parser.add_argument('--scale_prob',       type=float, help='augment scale : probability of applying the scale augmentation.')
24 |   parser.add_argument('--scale_min',        type=float, help='augment scale : minimum scale factor.')
25 |   parser.add_argument('--scale_max',        type=float, help='augment scale : maximum scale factor.')
26 |   parser.add_argument('--scale_eval',       type=float, help='the scale factor used for evaluation.')
27 |   parser.add_argument('--rotate_max',       type=int,   help='augment rotate : maximum rotation degree.')
28 |   parser.add_argument('--crop_height',      type=int,   default=256, help='augment crop : crop height.')
29 |   parser.add_argument('--crop_width',       type=int,   default=256, help='augment crop : crop width.')
30 |   parser.add_argument('--crop_perturb_max', type=int,   help='augment crop : maximum perturbation distance of the crop center.')
31 |   parser.add_argument('--arg_flip',         action='store_true', help='Use the flip data augmentation or not.')
32 |   # Optimization options
33 |   parser.add_argument('--eval_once',        action='store_true', help='only evaluate the model once.')
34 |   parser.add_argument('--error_bar',        type=float, help='For drawing the image with large distance error.')
35 |   parser.add_argument('--batch_size',       type=int,   default=2, help='Batch size for training.')
36 |   # Checkpoints
37 |   parser.add_argument('--print_freq',       type=int,   default=100, help='print frequency (default: 100)')
38 |   parser.add_argument('--save_path',        type=str,   help='Folder to save checkpoints and log.')
39 |   # Acceleration
40 |   parser.add_argument('--workers',          type=int,   default=8, help='number of data loading workers (default: 8)')
41 |   # Random Seed
42 |   parser.add_argument('--rand_seed',        type=int,   help='manual seed')
43 |   args = parser.parse_args()
44 | 
45 |   if args.rand_seed is None:
46 |     args.rand_seed = random.randint(1, 100000)
47 |   assert args.save_path is not None, 'save-path argument can not be None'
48 | 
49 |   #state = {k: v for k, v in args._get_kwargs()}
50 |   #Arguments = namedtuple('Arguments', ' '.join(state.keys()))
51 |   #arguments = Arguments(**state)
52 |   return args
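
To make the flag set above concrete, here is a hypothetical launch of `exps/basic_main.py`, which consumes these arguments. The two config paths exist in this repository, but the list files and every numeric value are illustrative placeholders (the scripts under `scripts/`, e.g. `300W-DET.sh`, hold the maintained settings); note that `basic_main.py` asserts `(scale_min + scale_max) / 2 == scale_eval`, which the values below satisfy:

```
python ./exps/basic_main.py \
    --train_lists ./cache_data/lists/300W/300w.train.lst \
    --eval_ilists ./cache_data/lists/300W/300w.test.common.lst \
    --num_pts 68 --data_indicator 300W-68 --heatmap_type gaussian \
    --model_config ./configs/Detector.config \
    --opt_config   ./configs/SGD.config \
    --save_path    ./snapshots/300W-DET \
    --pre_crop_expand 0.2 --sigma 4 --batch_size 8 \
    --crop_height 256 --crop_width 256 --crop_perturb_max 30 \
    --scale_prob 1.0 --scale_min 0.9 --scale_max 1.1 --scale_eval 1 \
    --rotate_max 20
```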
53 | -------------------------------------------------------------------------------- /lib/config_utils/configure_utils.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os, sys, json
8 | from pathlib import Path
9 | from collections import namedtuple
10 | 
11 | support_types = ('str', 'int', 'bool', 'float')
12 | 
13 | def convert_param(original_lists):
14 |   assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
15 |   ctype, value = original_lists[0], original_lists[1]
16 |   assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
17 |   is_list = isinstance(value, list)
18 |   if not is_list: value = [value]
19 |   outs = []
20 |   for x in value:
21 |     if ctype == 'int':
22 |       x = int(x)
23 |     elif ctype == 'str':
24 |       x = str(x)
25 |     elif ctype == 'bool':
26 |       x = bool(int(x))
27 |     elif ctype == 'float':
28 |       x = float(x)
29 |     else:
30 |       raise TypeError('Does not know this type : {:}'.format(ctype))
31 |     outs.append(x)
32 |   if not is_list: outs = outs[0]
33 |   return outs
34 | 
35 | def load_configure(path, logger):
36 |   path = str(path)
37 |   if logger is not None: logger.log(path)
38 |   assert os.path.exists(path), 'Can not find {:}'.format(path)
39 |   # Reading data back
40 |   with open(path, 'r') as f:
41 |     data = json.load(f)
42 |     f.close()
43 |   content = { k: convert_param(v) for k,v in data.items()}
44 |   Arguments = namedtuple('Configure', ' '.join(content.keys()))
45 |   content = Arguments(**content)
46 |   if logger is not None: logger.log('{:}'.format(content))
47 |   return content
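
`load_configure` therefore reads a JSON dictionary whose values are `[type, value]` (or `[type, [value, ...]]`) pairs, with booleans encoded as `0`/`1` integers. The sketch below only illustrates that encoding: `LR`, `Decay`, and `epochs` are fields the training scripts actually read (`opt_config.LR`, `opt_config.Decay`, `opt_config.epochs`), while the remaining keys are invented examples and need not match the shipped `configs/SGD.config`:

```
{
  "LR"       : ["float", 0.0005],
  "Decay"    : ["float", 0.0005],
  "epochs"   : ["int"  , 50],
  "momentum" : ["float", 0.9],
  "nesterov" : ["bool" , 1],
  "schedule" : ["int"  , [30, 40, 45]],
  "gammas"   : ["float", [0.5, 0.5, 0.2]]
}
```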
48 | -------------------------------------------------------------------------------- /lib/config_utils/lk_args.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os, sys, time, random, argparse
8 | 
9 | def obtain_args():
10 |   parser = argparse.ArgumentParser(description='Train facial landmark detectors on 300-W, AFLW or Mugsy', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
11 |   parser.add_argument('--train_lists',      type=str, nargs='+', help='The list file path to the video training dataset.')
12 |   parser.add_argument('--eval_vlists',      type=str, nargs='+', help='The list file path to the video testing dataset.')
13 |   parser.add_argument('--eval_ilists',      type=str, nargs='+', help='The list file path to the image testing dataset.')
14 |   parser.add_argument('--num_pts',          type=int, help='Number of points.')
15 |   parser.add_argument('--model_config',     type=str, help='The path to the model configuration')
16 |   parser.add_argument('--opt_config',       type=str, help='The path to the optimizer configuration')
17 |   parser.add_argument('--lk_config',        type=str, help='The path to the LK configuration')
18 |   # Data Generation
19 |   parser.add_argument('--heatmap_type',     type=str, choices=['gaussian','laplacian'], help='The method for generating the heatmap.')
20 |   parser.add_argument('--data_indicator',   type=str, default='300W-68', help='The dataset indicator.')
21 |   parser.add_argument('--video_parser',     type=str, help='The video-parser indicator.')
22 |   # Data Transform
23 |   parser.add_argument('--pre_crop_expand',  type=float, help='parameters for pre-crop expand ratio')
24 |   parser.add_argument('--sigma',            type=float, help='sigma distance for CPM.')
25 |   parser.add_argument('--scale_prob',       type=float, help='augment scale : probability of applying the scale augmentation.')
26 |   parser.add_argument('--scale_min',        type=float, help='augment scale : minimum scale factor.')
27 |   parser.add_argument('--scale_max',        type=float, help='augment scale : maximum scale factor.')
28 |   parser.add_argument('--scale_eval',       type=float, help='the scale factor used for evaluation.')
29 |   parser.add_argument('--rotate_max',       type=int,   help='augment rotate : maximum rotation degree.')
30 |   parser.add_argument('--crop_height',      type=int,   default=256, help='augment crop : crop height.')
31 |   parser.add_argument('--crop_width',       type=int,   default=256, help='augment crop : crop width.')
32 |   parser.add_argument('--crop_perturb_max', type=int,   help='augment crop : maximum perturbation distance of the crop center.')
33 |   parser.add_argument('--arg_flip',         action='store_true', help='Use the flip data augmentation or not.')
34 |   # Optimization options
35 |   parser.add_argument('--eval_once',        action='store_true', help='only evaluate the model once.')
36 |   parser.add_argument('--error_bar',        type=float, help='For drawing the image with large distance error.')
37 |   parser.add_argument('--batch_size',       type=int,   default=2, help='Batch size for training.')
38 |   # Checkpoints
39 |   parser.add_argument('--print_freq',       type=int,   default=100, help='print frequency (default: 100)')
40 |   parser.add_argument('--init_model',       type=str,   help='The detector model to be initialized.')
41 |   parser.add_argument('--save_path',        type=str,   help='Folder to save checkpoints and log.')
42 |   # Acceleration
43 |   parser.add_argument('--workers',          type=int,   default=8, help='number of data loading workers (default: 8)')
44 |   # Random Seed
45 |   parser.add_argument('--rand_seed',        type=int,   help='manual seed')
46 |   args = parser.parse_args()
47 | 
48 |   if args.rand_seed is None:
49 |     args.rand_seed = random.randint(1, 100000)
50 |   assert args.save_path is not None, 'save-path argument can not be None'
51 | 
52 |   #state = {k: v for k, v in args._get_kwargs()}
53 |
#Arguments = namedtuple('Arguments', ' '.join(state.keys())) 54 | #arguments = Arguments(**state) 55 | return args 56 | -------------------------------------------------------------------------------- /lib/datasets/GeneralDataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import print_function 8 | from PIL import Image 9 | from os import path as osp 10 | import numpy as np 11 | import math 12 | 13 | from pts_utils import generate_label_map 14 | from .file_utils import load_file_lists 15 | from .dataset_utils import pil_loader 16 | from .dataset_utils import anno_parser 17 | from .point_meta import Point_Meta 18 | import torch 19 | import torch.utils.data as data 20 | 21 | class GeneralDataset(data.Dataset): 22 | 23 | def __init__(self, transform, sigma, downsample, heatmap_type, data_indicator): 24 | 25 | self.transform = transform 26 | self.sigma = sigma 27 | self.downsample = downsample 28 | self.heatmap_type = heatmap_type 29 | self.dataset_name = data_indicator 30 | 31 | self.reset() 32 | print ('The general dataset initialization done : {:}'.format(self)) 33 | 34 | def __repr__(self): 35 | return ('{name}(point-num={NUM_PTS}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, dataset={dataset_name})'.format(name=self.__class__.__name__, **self.__dict__)) 36 | 37 | def reset(self, num_pts=-1): 38 | self.length = 0 39 | self.NUM_PTS = num_pts 40 | self.datas = [] 41 | self.labels = [] 42 | self.face_sizes = [] 43 | assert self.dataset_name is not None, 'The dataset name is None' 44 | 45 | def __len__(self): 46 | assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length) 47 | return self.length 48 | 49 | def append(self, data, label, box, face_size): 50 | assert osp.isfile(data), 'The image path is not a file : {}'.format(data) 51 | self.datas.append( data ) 52 | if (label is not None) and (label.lower() != 'none'): 53 | if isinstance(label, str): 54 | assert osp.isfile(label), 'The annotation path is not a file : {}'.format(label) 55 | np_points, _ = anno_parser(label, self.NUM_PTS) 56 | meta = Point_Meta(self.NUM_PTS, np_points, box, data, self.dataset_name) 57 | elif isinstance(label, Point_Meta): 58 | meta = label.copy() 59 | else: 60 | raise NameError('Do not know this label : {}'.format(label)) 61 | else: 62 | meta = Point_Meta(self.NUM_PTS, None, box, data, self.dataset_name) 63 | self.labels.append( meta ) 64 | self.face_sizes.append( face_size ) 65 | self.length = self.length + 1 66 | 67 | def prepare_input(self, image, box): 68 | meta = Point_Meta(self.NUM_PTS, None, np.array(box), image, self.dataset_name) 69 | image = pil_loader( image ) 70 | return self._process_(image, meta, -1), meta 71 | 72 | def load_data(self, datas, labels, boxes, face_sizes, num_pts, reset): 73 | # each data is a png file name 74 | # each label is a Point_Meta class or the general pts format file (anno_parser_v1) 75 | assert isinstance(datas, list), 'The type of the datas is not correct : {}'.format( type(datas) ) 76 | assert isinstance(labels, list) and len(datas) == len(labels), 'The type of the labels is not correct : {}'.format( type(labels) ) 77 | assert isinstance(boxes, list) and len(datas) == len(boxes), 'The type of the boxes is not correct : {}'.format( type(boxes) ) 78 
| assert isinstance(face_sizes, list) and len(datas) == len(face_sizes), 'The type of the face_sizes is not correct : {}'.format( type(face_sizes) )
79 |     if reset: self.reset(num_pts)
80 |     else: assert self.NUM_PTS == num_pts, 'The number of points is inconsistent : {} vs {}'.format(self.NUM_PTS, num_pts)
81 | 
82 |     print ('[GeneralDataset] load-data {:} datas begin'.format(len(datas)))
83 | 
84 |     for idx, data in enumerate(datas):
85 |       assert isinstance(data, str), 'The type of data is not correct : {}'.format(data)
86 |       assert osp.isfile(datas[idx]), '{} is not a file'.format(datas[idx])
87 |       self.append(datas[idx], labels[idx], boxes[idx], face_sizes[idx])
88 | 
89 |     assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
90 |     assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
91 |     assert len(self.face_sizes) == self.length, 'The length and the face_sizes is not right {} vs {}'.format(self.length, len(self.face_sizes))
92 |     print ('Load data done for the general dataset, which has {} images.'.format(self.length))
93 | 
94 |   def load_list(self, file_lists, num_pts, reset):
95 |     lists = load_file_lists(file_lists)
96 |     print ('GeneralDataset : load-list : load {:} lines'.format(len(lists)))
97 | 
98 |     datas, labels, boxes, face_sizes = [], [], [], []
99 | 
100 |     for idx, data in enumerate(lists):
101 |       alls = [x for x in data.split(' ') if x != '']
102 | 
103 |       assert len(alls) == 6 or len(alls) == 7, 'The {:04d}-th line in {:} is wrong : {:}'.format(idx, file_lists, data)
104 |       datas.append( alls[0] )
105 |       if alls[1] == 'None':
106 |         labels.append( None )
107 |       else:
108 |         labels.append( alls[1] )
109 |       box = np.array( [ float(alls[2]), float(alls[3]), float(alls[4]), float(alls[5]) ] )
110 |       boxes.append( box )
111 |       if len(alls) == 6:
112 |         face_sizes.append( None )
113 |       else:
114 |         face_sizes.append( float(alls[6]) )
115 |     self.load_data(datas, labels, boxes, face_sizes, num_pts, reset)
116 | 
117 |   def __getitem__(self, index):
118 |     assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
119 |     image = pil_loader( self.datas[index] )
120 |     target = self.labels[index].copy()
121 |     return self._process_(image, target, index)
122 | 
123 |   def _process_(self, image, target, index):
124 | 
125 |     # transform the image and points
126 |     if self.transform is not None:
127 |       image, target = self.transform(image, target)
128 | 
129 |     # obtain the visible indicator vector
130 |     if target.is_none(): nopoints = True
131 |     else               : nopoints = False
132 | 
133 |     # If no label is loaded (evaluation mode), keep the original data
134 |     temp_save_wh = target.temp_save_wh
135 |     ori_size = torch.IntTensor( [temp_save_wh[1], temp_save_wh[0], temp_save_wh[2], temp_save_wh[3]] ) # H, W, Cropped_[x1,y1]
136 | 
137 |     if isinstance(image, Image.Image):
138 |       height, width = image.size[1], image.size[0]
139 |     elif isinstance(image, torch.FloatTensor):
140 |       height, width = image.size(1), image.size(2)
141 |     else:
142 |       raise Exception('Unknown type of image : {}'.format( type(image) ))
143 | 
144 |     if target.is_none() == False:
145 |       target.apply_bound(width, height)
146 |       points = target.points.copy()
147 |       points = torch.from_numpy(points.transpose((1,0))).type(torch.FloatTensor)
148 |       Hpoint = target.points.copy()
149 |     else:
150 |       points = torch.from_numpy(np.zeros((self.NUM_PTS,3))).type(torch.FloatTensor)
151 |       Hpoint = np.zeros((3, self.NUM_PTS))
152 | 
153 |     heatmaps, mask = generate_label_map(Hpoint, height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C
154 | 
155 |     heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
156 |     mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
157 | 
158 |     torch_index = torch.IntTensor([index])
159 |     torch_nopoints = torch.ByteTensor( [ nopoints ] )
160 | 
161 |     return image, heatmaps, mask, points, torch_index, torch_nopoints, ori_size
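
`load_list` above fixes the on-disk format of the list files: each line carries an image path, an annotation path (or the literal `None` for unlabeled frames), the four face-box coordinates `x1 y1 x2 y2`, and an optional face size as the seventh field. Two illustrative lines follow; the paths are placeholders for whatever the `cache_data` scripts generate locally:

```
/data/300W/helen/trainset/146827737_1.jpg /data/300W/helen/trainset/146827737_1.pts 110.5 98.2 412.9 401.6 303.0
/data/300VW/001/extraction/000001.png None 120.0 80.0 400.0 390.0
```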
162 | -------------------------------------------------------------------------------- /lib/datasets/VideoDataset.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from __future__ import print_function
8 | from PIL import Image
9 | from os import path as osp
10 | import numpy as np
11 | import math
12 | 
13 | from pts_utils import generate_label_map
14 | from .file_utils import load_file_lists
15 | from .dataset_utils import pil_loader
16 | from .dataset_utils import anno_parser
17 | from .point_meta import Point_Meta
18 | from .parse_utils import parse_video_by_indicator
19 | import torch
20 | import torch.utils.data as data
21 | 
22 | class VideoDataset(data.Dataset):
23 | 
24 |   def __init__(self, transform, sigma, downsample, heatmap_type, data_indicator, video_parser):
25 | 
26 |     self.transform = transform
27 |     self.sigma = sigma
28 |     self.downsample = downsample
29 |     self.heatmap_type = heatmap_type
30 |     self.dataset_name = data_indicator
31 |     self.video_parser = video_parser
32 |     L, R = parse_video_by_indicator(None, self.video_parser, True)
33 |     self.video_length = L + R + 1
34 |     self.center_idx = L
35 | 
36 |     self.reset()
37 |     print ('The video dataset initialization done : {:}'.format(self))
38 | 
39 |   def __repr__(self):
40 |     return ('{name}(point-num={NUM_PTS}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, dataset={dataset_name}, parser={video_parser})'.format(name=self.__class__.__name__, **self.__dict__))
41 | 
42 |   def reset(self, num_pts=-1):
43 |     self.length = 0
44 |     self.NUM_PTS = num_pts
45 |     self.datas = []
46 |     self.labels = []
47 |     self.face_sizes = []
48 |     assert self.dataset_name is not None, 'The dataset name is None'
49 | 
50 |   def __len__(self):
51 |     assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
52 |     return self.length
53 | 
54 |   def append(self, data, label, box, face_size):
55 |     assert osp.isfile(data), 'The image path is not a file : {}'.format(data)
56 |     self.datas.append( data )
57 |     if (label is not None) and (label.lower() != 'none'):
58 |       if isinstance(label, str):
59 |         assert osp.isfile(label), 'The annotation path is not a file : {}'.format(label)
60 |         np_points, _ = anno_parser(label, self.NUM_PTS)
61 |         meta = Point_Meta(self.NUM_PTS, np_points, box, data, self.dataset_name)
62 |       elif isinstance(label, Point_Meta):
63 |         meta = label.copy()
64 |       else:
65 |         raise NameError('Do not know this label : {}'.format(label))
66 |     else:
67 |       meta = Point_Meta(self.NUM_PTS, None, box, data, self.dataset_name)
68 |     self.labels.append( meta )
69 |     self.face_sizes.append( face_size )
70 |     self.length = self.length + 1
71 | 
72 |   def load_data(self, datas, labels, boxes, face_sizes, num_pts, reset):
73 |     # each data is a png file name
74 |     # each label is a Point_Meta class or the general pts format file (anno_parser_v1)
75 |     assert isinstance(datas, list), 'The type of the datas is not correct : {}'.format( type(datas) )
76 |     assert isinstance(labels, list) and len(datas) == len(labels), 'The type of the labels is not correct : {}'.format( type(labels) )
77 |     assert isinstance(boxes, list) and len(datas) == len(boxes), 'The type of the boxes is not correct : {}'.format( type(boxes) )
78 |     assert isinstance(face_sizes, list) and len(datas) == len(face_sizes), 'The type of the face_sizes is not correct : {}'.format( type(face_sizes) )
79 |     if reset: self.reset(num_pts)
80 |     else: assert self.NUM_PTS == num_pts, 'The number of points is inconsistent : {} vs {}'.format(self.NUM_PTS, num_pts)
81 | 
82 |     print ('[VideoDataset] load-data {:} datas begin'.format(len(datas)))
83 | 
84 |     for idx, data in enumerate(datas):
85 |       assert isinstance(data, str), 'The type of data is not correct : {}'.format(data)
86 |       assert osp.isfile(datas[idx]), '{} is not a file'.format(datas[idx])
87 |       self.append(datas[idx], labels[idx], boxes[idx], face_sizes[idx])
88 | 
89 |     assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
90 |     assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
91 |     assert len(self.face_sizes) == self.length, 'The length and the face_sizes is not right {} vs {}'.format(self.length, len(self.face_sizes))
92 |     print ('Load data done for the video dataset, which has {} images.'.format(self.length))
93 | 
94 |   def load_list(self, file_lists, num_pts, reset):
95 |     lists = load_file_lists(file_lists)
96 |     print ('VideoDataset : load-list : load {:} lines'.format(len(lists)))
97 | 
98 |     datas, labels, boxes, face_sizes = [], [], [], []
99 | 
100 |     for idx, data in enumerate(lists):
101 |       alls = [x for x in data.split(' ') if x != '']
102 | 
103 |       assert len(alls) == 6 or len(alls) == 7, 'The {:04d}-th line in {:} is wrong : {:}'.format(idx, file_lists, data)
104 |       datas.append( alls[0] )
105 |       if alls[1] == 'None':
106 |         labels.append( None )
107 |       else:
108 |         labels.append( alls[1] )
109 |       box = np.array( [ float(alls[2]), float(alls[3]), float(alls[4]), float(alls[5]) ] )
110 |       boxes.append( box )
111 |       if len(alls) == 6:
112 |         face_sizes.append( None )
113 |       else:
114 |         face_sizes.append( float(alls[6]) )
115 |     self.load_data(datas, labels, boxes, face_sizes, num_pts, reset)
116 | 
117 |   def __getitem__(self, index):
118 |     assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
119 |     images, is_video_or_not = parse_video_by_indicator(self.datas[index], self.video_parser, False)
120 |     images = [pil_loader(image) for image in images]
121 | 
122 |     target = self.labels[index].copy()
123 | 
124 |     # transform the image and points
125 |     if self.transform is not None:
126 |       images, target = self.transform(images, target)
127 | 
128 |     # obtain the visible indicator vector
129 |     if target.is_none(): nopoints = True
130 |     else               : nopoints = False
131 | 
132 |     # If no label is loaded (evaluation mode), keep the original data
133 |     temp_save_wh = target.temp_save_wh
134 |     ori_size = torch.IntTensor( [temp_save_wh[1], temp_save_wh[0], temp_save_wh[2], temp_save_wh[3]] ) # H, W, Cropped_[x1,y1]
135 | 
136 |     if isinstance(images[0], Image.Image):
137 |       height, width = images[0].size[1], images[0].size[0]
138 |     elif isinstance(images[0], torch.FloatTensor):
139 |       height, width = images[0].size(1), images[0].size(2)
140 | else: 141 | raise Exception('Unknown type of image : {}'.format( type(images[0]) )) 142 | 143 | if target.is_none() == False: 144 | target.apply_bound(width, height) 145 | points = target.points.copy() 146 | points = torch.from_numpy(points.transpose((1,0))).type(torch.FloatTensor) 147 | Hpoint = target.points.copy() 148 | else: 149 | points = torch.from_numpy(np.zeros((self.NUM_PTS,3))).type(torch.FloatTensor) 150 | Hpoint = np.zeros((3, self.NUM_PTS)) 151 | 152 | heatmaps, mask = generate_label_map(Hpoint, height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C 153 | 154 | heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor) 155 | mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor) 156 | 157 | torch_index = torch.IntTensor([index]) 158 | torch_nopoints = torch.ByteTensor( [ nopoints ] ) 159 | video_indicator = torch.ByteTensor( [is_video_or_not] ) 160 | 161 | return torch.stack(images), heatmaps, mask, points, torch_index, torch_nopoints, video_indicator, ori_size 162 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .GeneralDataset import GeneralDataset 8 | from .VideoDataset import VideoDataset 9 | from .dataset_utils import pil_loader 10 | from .point_meta import Point_Meta 11 | from .dataset_utils import PTSconvert2str 12 | from .dataset_utils import PTSconvert2box 13 | from .dataset_utils import merge_lists_from_file 14 | -------------------------------------------------------------------------------- /lib/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from os import path as osp 8 | from PIL import Image 9 | from scipy.ndimage.interpolation import zoom 10 | from utils.file_utils import load_txt_file 11 | import numpy as np 12 | import copy, math, pdb 13 | 14 | def pil_loader(path): 15 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 16 | with open(path, 'rb') as f: 17 | with Image.open(f) as img: 18 | return img.convert('RGB') 19 | 20 | def remove_item_from_list(list_to_remove, item): 21 | ''' 22 | remove a single item from a list 23 | ''' 24 | assert isinstance(list_to_remove, list), 'input list is not a list' 25 | 26 | try: 27 | list_to_remove.remove(item) 28 | except ValueError: 29 | print('Warning!!!!!! Item to remove is not in the list. 
Remove operation is not done.') 30 | 31 | return list_to_remove 32 | 33 | def anno_parser(anno_path, num_pts): 34 | data, num_lines = load_txt_file(anno_path) 35 | if data[0].find('version: ') == 0: # 300-W 36 | return anno_parser_v0(anno_path, num_pts) 37 | else: 38 | return anno_parser_v1(anno_path, num_pts) 39 | 40 | def anno_parser_v0(anno_path, num_pts): 41 | ''' 42 | parse the annotation for 300W dataset, which has a fixed format for .pts file 43 | return: 44 | pts: 3 x num_pts (x, y, oculusion) 45 | ''' 46 | data, num_lines = load_txt_file(anno_path) 47 | assert data[0].find('version: ') == 0, 'version is not correct' 48 | assert data[1].find('n_points: ') == 0, 'number of points in second line is not correct' 49 | assert data[2] == '{' and data[-1] == '}', 'starting and end symbol is not correct' 50 | 51 | assert data[0] == 'version: 1' or data[0] == 'version: 1.0', 'The version is wrong : {}'.format(data[0]) 52 | n_points = int(data[1][len('n_points: '):]) 53 | 54 | assert num_lines == n_points + 4, 'number of lines is not correct' # 4 lines for general information: version, n_points, start and end symbol 55 | assert num_pts == n_points, 'number of points is not correct' 56 | 57 | # read points coordinate 58 | pts = np.zeros((3, n_points), dtype='float32') 59 | line_offset = 3 # first point starts at fourth line 60 | point_set = set() 61 | for point_index in range(n_points): 62 | try: 63 | pts_list = data[point_index + line_offset].split(' ') # x y format 64 | if len(pts_list) > 2: # handle edge case where additional whitespace exists after point coordinates 65 | pts_list = remove_item_from_list(pts_list, '') 66 | pts[0, point_index] = float(pts_list[0]) 67 | pts[1, point_index] = float(pts_list[1]) 68 | pts[2, point_index] = float(1) # oculusion flag, 0: oculuded, 1: visible. 
We use 1 for all points since no visibility is provided by 300-W 69 | point_set.add( point_index ) 70 | except ValueError: 71 | print('error in loading points in %s' % anno_path) 72 | return pts, point_set 73 | 74 | def anno_parser_v1(anno_path, NUM_PTS, one_base=True): 75 | ''' 76 | parse the annotation for MUGSY-Full-Face dataset, which has a fixed format for .pts file 77 | return: pts: 3 x num_pts (x, y, oculusion) 78 | ''' 79 | data, n_points = load_txt_file(anno_path) 80 | assert n_points <= NUM_PTS, '{} has {} points'.format(anno_path, n_points) 81 | # read points coordinate 82 | pts = np.zeros((3, NUM_PTS), dtype='float32') 83 | point_set = set() 84 | for line in data: 85 | try: 86 | idx, point_x, point_y, oculusion = line.split(' ') 87 | idx, point_x, point_y, oculusion = int(idx), float(point_x), float(point_y), oculusion == 'True' 88 | if one_base==False: idx = idx+1 89 | assert idx >= 1 and idx <= NUM_PTS, 'Wrong idx of points : {:02d}-th in {:s}'.format(idx, anno_path) 90 | pts[0, idx-1] = point_x 91 | pts[1, idx-1] = point_y 92 | pts[2, idx-1] = float( oculusion ) 93 | point_set.add(idx) 94 | except ValueError: 95 | raise Exception('error in loading points in {}'.format(anno_path)) 96 | return pts, point_set 97 | 98 | def PTSconvert2str(points): 99 | assert isinstance(points, np.ndarray) and len(points.shape) == 2, 'The points is not right : {}'.format(points) 100 | assert points.shape[0] == 2 or points.shape[0] == 3, 'The shape of points is not right : {}'.format(points.shape) 101 | string = '' 102 | num_pts = points.shape[1] 103 | for i in range(num_pts): 104 | ok = False 105 | if points.shape[0] == 3 and bool(points[2, i]) == True: 106 | ok = True 107 | elif points.shape[0] == 2: 108 | ok = True 109 | 110 | if ok: 111 | string = string + '{:02d} {:.2f} {:.2f} True\n'.format(i+1, points[0, i], points[1, i]) 112 | string = string[:-1] 113 | return string 114 | 115 | def PTSconvert2box(points, expand_ratio=None): 116 | assert isinstance(points, np.ndarray) and len(points.shape) == 2, 'The points is not right : {}'.format(points) 117 | assert points.shape[0] == 2 or points.shape[0] == 3, 'The shape of points is not right : {}'.format(points.shape) 118 | if points.shape[0] == 3: 119 | points = points[:2, points[-1,:].astype('bool') ] 120 | elif points.shape[0] == 2: 121 | points = points[:2, :] 122 | else: 123 | raise Exception('The shape of points is not right : {}'.format(points.shape)) 124 | assert points.shape[1] >= 2, 'To get the box of points, there should be at least 2 vs {}'.format(points.shape) 125 | box = np.array([ points[0,:].min(), points[1,:].min(), points[0,:].max(), points[1,:].max() ]) 126 | W = box[2] - box[0] 127 | H = box[3] - box[1] 128 | assert W > 0 and H > 0, 'The size of box should be greater than 0 vs {}'.format(box) 129 | if expand_ratio is not None: 130 | box[0] = int( math.floor(box[0] - W * expand_ratio) ) 131 | box[1] = int( math.floor(box[1] - H * expand_ratio) ) 132 | box[2] = int( math.ceil(box[2] + W * expand_ratio) ) 133 | box[3] = int( math.ceil(box[3] + H * expand_ratio) ) 134 | return box 135 | 136 | def for_generate_box_str(anno_path, num_pts, extend): 137 | if isinstance(anno_path, str): 138 | points, _ = anno_parser(anno_path, num_pts) 139 | else: 140 | points = anno_path.copy() 141 | box = PTSconvert2box(points, extend) 142 | return '{:.2f} {:.2f} {:.2f} {:.2f}'.format(box[0], box[1], box[2], box[3]) 143 | 144 | def resize_heatmap(maps, height, width, order=3): 145 | # maps = np.ndarray with shape [height, width, channels] 146 | # order 
= 0 Nearest 147 | # order = 1 Bilinear 148 | # order = 2 Cubic 149 | assert isinstance(maps, np.ndarray) and len(maps.shape) == 3, 'maps type : {}'.format(type(maps)) 150 | 151 | scale = tuple(np.array([height,width], dtype=float) / np.array(maps.shape[:2])) 152 | return zoom(maps, scale + (1,), order=order) 153 | 154 | def analysis_dataset(dataset): 155 | all_values = np.zeros((3,len(dataset.datas)), dtype=np.float64) 156 | hs = np.zeros((len(dataset.datas),), dtype=np.float64) 157 | ws = np.zeros((len(dataset.datas),), dtype=np.float64) 158 | 159 | for index, image_path in enumerate(dataset.datas): 160 | img = pil_loader(image_path) 161 | ws[index] = img.size[0] 162 | hs[index] = img.size[1] 163 | img = np.array(img) 164 | all_values[:, index] = np.mean(np.mean(img, axis=0), axis=0).astype('float64') 165 | mean = np.mean(all_values, axis=1) 166 | std = np.std (all_values, axis=1) 167 | return mean, std, ws, hs 168 | 169 | def split_datasets(dataset, point_ids): 170 | sub_dataset = copy.deepcopy(dataset) 171 | assert len(point_ids) > 0 172 | assert False, 'un finished' 173 | 174 | def convert68to49(points): 175 | points = points.copy() 176 | assert len(points.shape) == 2 and (points.shape[0] == 3 or points.shape[0] == 2) and points.shape[1] == 68, 'The shape of points is not right : {}'.format(points.shape) 177 | out = np.ones((68,)).astype('bool') 178 | out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = False 179 | cpoints = points[:, out] 180 | assert len(cpoints.shape) == 2 and cpoints.shape[1] == 49 181 | return cpoints 182 | 183 | def convert68to51(points): 184 | points = points.copy() 185 | assert len(points.shape) == 2 and (points.shape[0] == 3 or points.shape[0] == 2) and points.shape[1] == 68, 'The shape of points is not right : {}'.format(points.shape) 186 | out = np.ones((68,)).astype('bool') 187 | out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]] = False 188 | cpoints = points[:, out] 189 | assert len(cpoints.shape) == 2 and cpoints.shape[1] == 51 190 | return cpoints 191 | 192 | def merge_lists_from_file(file_paths, seed=None): 193 | assert file_paths is not None, 'The input can not be None' 194 | if isinstance(file_paths, str): 195 | file_paths = [ file_paths ] 196 | print ('merge lists from {} files with seed={} for random shuffle'.format(len(file_paths), seed)) 197 | # load the data 198 | all_data = [] 199 | for file_path in file_paths: 200 | assert osp.isfile(file_path), '{} does not exist'.format(file_path) 201 | listfile = open(file_path, 'r') 202 | listdata = listfile.read().splitlines() 203 | listfile.close() 204 | all_data = all_data + listdata 205 | total = len(all_data) 206 | print ('merge all the lists done, total : {}'.format(total)) 207 | # random shuffle 208 | if seed is not None: 209 | np.random.seed(seed) 210 | order = np.random.permutation(total).tolist() 211 | new_data = [ all_data[idx] for idx in order ] 212 | all_data = new_data 213 | return all_data 214 | -------------------------------------------------------------------------------- /lib/datasets/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from os import path as osp 8 | 9 | def load_file_lists(file_paths): 10 | if isinstance(file_paths, str): 11 | file_paths = [ file_paths ] 12 | print ('Function [load_lists] input {:} files'.format(len(file_paths))) 13 | all_strings = [] 14 | for file_idx, file_path in enumerate(file_paths): 15 | assert osp.isfile(file_path), 'The {:}-th path : {:} is not a file.'.format(file_idx, file_path) 16 | listfile = open(file_path, 'r') 17 | listdata = listfile.read().splitlines() 18 | listfile.close() 19 | print ('Load [{:d}/{:d}]-th list : {:} with {:} images'.format(file_idx, len(file_paths), file_path, len(listdata))) 20 | all_strings += listdata 21 | return all_strings 22 | -------------------------------------------------------------------------------- /lib/datasets/parse_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import warnings 9 | from os import path as osp 10 | 11 | def parse_basic(ori_filename, length_l, length_r): 12 | folder = osp.dirname(ori_filename) 13 | filename = osp.basename(ori_filename) 14 | # 300-VW 15 | if folder[-10:] == 'extraction': 16 | assert filename[-4:] == '.png', 'The filename is not right : {}'.format(filename) 17 | idx = int(filename[: filename.find('.png')]) 18 | assert idx >= 0, 'The index must be greater than 0' 19 | images = [] 20 | for i in range(idx-length_l, idx+length_r+1): 21 | path = osp.join(folder, '{:06d}.png'.format(i)) 22 | if not osp.isfile(path): 23 | xpath = osp.join(folder, '{:06d}.png'.format(idx)) 24 | warnings.warn('Path [{}] does not exist, maybe it reaches the start or end of the video, use {} instead.'.format(path, xpath), UserWarning) 25 | path = xpath 26 | assert osp.isfile(path), '!!WRONG file path : {}, the original frame is {}'.format(path, filename) 27 | images.append(path) 28 | return images, True 29 | # YouTube Cele.. 30 | elif folder.find('YouTube_Celebrities_Annotation') > 0: 31 | assert filename[-4:] == '.png', 'The filename is not right : {}'.format(filename) 32 | idx = int(filename[filename.find('_')+1: filename.find('.png')]) 33 | assert idx >= 0, 'The index must be greater than 0' 34 | images = [] 35 | for i in range(idx-length_l, idx+length_r+1): 36 | path = osp.join(folder, 'frame_{:05d}.png'.format(i)) 37 | if not osp.isfile(path): 38 | xpath = osp.join(folder, 'frame_{:05d}.png'.format(idx)) 39 | warnings.warn('Path [{}] does not exist, maybe it reaches the start or end of the video, use {} instead.'.format(path, xpath), UserWarning) 40 | path = xpath 41 | assert osp.isfile(path), '!!WRONG file path : {}, the original frame is {}'.format(path, filename) 42 | images.append(path) 43 | return images, True 44 | # Talking Face.. 
45 | elif folder.find('talking_face') > 0: 46 | assert filename[-4:] == '.jpg', 'The filename is not right : {}'.format(filename) 47 | idx = int(filename[filename.find('_')+1: filename.find('.jpg')]) 48 | assert idx >= 0, 'The index must be greater than 0' 49 | images = [] 50 | for i in range(idx-length_l, idx+length_r+1): 51 | path = osp.join(folder, 'franck_{:05d}.png'.format(i)) 52 | if not osp.isfile(path): 53 | xpath = osp.join(folder, 'franck_{:05d}.png'.format(idx)) 54 | warnings.warn('Path [{}] does not exist, maybe it reaches the start or end of the video, use {} instead.'.format(path, xpath), UserWarning) 55 | path = xpath 56 | assert osp.isfile(path), '!!WRONG file path : {}, the original frame is {}'.format(path, filename) 57 | images.append(path) 58 | return images, True 59 | # YouTube Face.. 60 | elif folder.find('YouTube-Face') > 0: 61 | assert filename[-4:] == '.jpg', 'The filename is not right : {}'.format(filename) 62 | splits = filename.split('.') 63 | assert len(splits) == 3, 'The format is not right : {}'.format(filename) 64 | idx = int(splits[1]) 65 | images = [] 66 | for i in range(idx-length_l, idx+length_r+1): 67 | path = osp.join(folder, '{}.{}.{}'.format(splits[0], i, splits[2])) 68 | if not osp.isfile(path): 69 | xpath = osp.join(folder, '{}.{}.{}'.format(splits[0], idx, splits[2])) 70 | warnings.warn('Path [{}] does not exist, maybe it reaches the start or end of the video, use {} instead.'.format(path, xpath), UserWarning) 71 | path = xpath 72 | assert osp.isfile(path), '!!WRONG file path : {}, the original frame is {}'.format(path, filename) 73 | images.append(path) 74 | return images, True 75 | elif folder.find('demo-pams') > 0 or folder.find('demo-sbrs') > 0: 76 | assert filename[-4:] == '.png', 'The filename is not right : {}'.format(filename) 77 | assert filename[:5] == 'image', 'The filename is not right : {}'.format(filename) 78 | splits = filename.split('.') 79 | assert len(splits) == 2, 'The format is not right : {}'.format(filename) 80 | idx = int(splits[0][5:]) 81 | images = [] 82 | for i in range(idx-length_l, idx+length_r+1): 83 | path = osp.join(folder, 'image{:04d}.{:}'.format(i, splits[1])) 84 | if not osp.isfile(path): 85 | xpath = osp.join(folder, 'image{:04d}.{:}'.format(idx, splits[1])) 86 | warnings.warn('Path [{}] does not exist, maybe it reaches the start or end of the video, use {} instead.'.format(path, xpath), UserWarning) 87 | path = xpath 88 | assert osp.isfile(path), '!!WRONG file path : {}, the original frame is {}'.format(path, filename) 89 | images.append(path) 90 | return images, True 91 | else: 92 | return [ori_filename] * (length_l+length_r+1), False 93 | 94 | def parse_video_by_indicator(image_path, parser, return_info=False): 95 | if parser is None or parser.lower() == 'none': 96 | method, offset_l, offset_r = 'None', 0, 0 97 | else: 98 | parser = parser.split('-') 99 | assert len(parser) == 3, 'The video parser must be 3 elements : {:}'.format(parser) 100 | method, offset_l, offset_r = parser[0], int(parser[1]), int(parser[2]) 101 | if return_info: 102 | return offset_l, offset_r 103 | else: 104 | images, is_video_or_not = parse_basic(image_path, offset_l, offset_r) 105 | return images, is_video_or_not 106 | -------------------------------------------------------------------------------- /lib/datasets/point_meta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from PIL import Image 8 | from scipy.ndimage.interpolation import zoom 9 | from utils.file_utils import load_txt_file 10 | from .dataset_utils import convert68to49 as _convert68to49 11 | from .dataset_utils import convert68to51 as _convert68to51 12 | import numpy as np 13 | import copy, math 14 | 15 | class Point_Meta(): 16 | # points: 3 x num_pts (x, y, occlusion) 17 | # image_size: original [width, height] 18 | def __init__(self, num_point, points, box, image_path, dataset_name): 19 | 20 | self.num_point = num_point 21 | assert len(box.shape) == 1 and box.shape[0] == 4, 'The shape of box is not right : {}'.format( box ) 22 | self.box = box.copy() 23 | if points is None: 24 | self.points = points 25 | else: 26 | assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points ) 27 | self.points = points.copy() 28 | self.update_center() 29 | self.image_path = image_path 30 | self.datasets = dataset_name 31 | self.temp_save_wh = None 32 | 33 | def __repr__(self): 34 | return ('{name}(number of point={num_point})'.format(name=self.__class__.__name__, **self.__dict__)) 35 | 36 | def convert68to49(self): 37 | if self.points is not None: 38 | self.points = _convert68to49(self.points) 39 | 40 | def convert68to51(self): 41 | if self.points is not None: 42 | self.points = _convert68to51(self.points) 43 | 44 | def update_center(self): 45 | if self.points is not None: 46 | self.center = np.mean(self.points[:2, self.points[2,:]>0], axis=1) 47 | else: 48 | self.center = np.array([ (self.box[0]+self.box[2])/2, (self.box[1]+self.box[3])/2 ]) 49 | 50 | def apply_bound(self, width, height): 51 | if self.points is not None: 52 | oks = np.vstack((self.points[0, :] >= 0, self.points[1, :] >=0, self.points[0, :] <= width, self.points[1, :] <= height, self.points[2, :].astype('bool'))) 53 | oks = oks.transpose((1,0)) 54 | self.points[2, :] = np.sum(oks, axis=1) == 5 55 | self.box[0], self.box[1] = np.max([self.box[0], 0]), np.max([self.box[1], 0]) 56 | self.box[2], self.box[3] = np.min([self.box[2], width]), np.min([self.box[3], height]) 57 | 58 | def apply_scale(self, scale): 59 | if len(scale) == 1: # scale the same size for both x and y 60 | if self.points is not None: 61 | self.points[:2, self.points[2,:]>0] = self.points[:2, self.points[2,:]>0] * scale[0] 62 | self.center = self.center * scale[0] 63 | self.box[0], self.box[1] = self.box[0] * scale[0], self.box[1] * scale[0] 64 | self.box[2], self.box[3] = self.box[2] * scale[0], self.box[3] * scale[0] 65 | elif len(scale) == 2: # scale the width and height 66 | if self.points is not None: 67 | self.points[0, self.points[2,:]>0] = self.points[0, self.points[2,:]>0] * scale[0] 68 | self.points[1, self.points[2,:]>0] = self.points[1, self.points[2,:]>0] * scale[1] 69 | self.center[0] = self.center[0] * scale[0] 70 | self.center[1] = self.center[1] * scale[1] 71 | self.box[0], self.box[1] = self.box[0] * scale[0], self.box[1] * scale[1] 72 | self.box[2], self.box[3] = self.box[2] * scale[0], self.box[3] * scale[1] 73 | else: 74 | assert False, 'Does not support this scale : {}'.format(scale) 75 | 76 | def apply_offset(self, ax=None, ay=None): 77 | if ax is not None: 78 | if self.points is not None: 79 | self.points[0, self.points[2,:]>0] = self.points[0, self.points[2,:]>0] + ax 80 | self.center[0] = self.center[0] + ax 81 | 
self.box[0], self.box[2] = self.box[0] + ax, self.box[2] + ax 82 | if ay is not None: 83 | if self.points is not None: 84 | self.points[1, self.points[2,:]>0] = self.points[1, self.points[2,:]>0] + ay 85 | self.center[1] = self.center[1] + ay 86 | self.box[1], self.box[3] = self.box[1] + ay, self.box[3] + ay 87 | 88 | def apply_rotate(self, center, degree): 89 | degree = math.radians(-degree) 90 | if self.points is not None: 91 | vis_xs = self.points[0, self.points[2,:]>0] 92 | vis_ys = self.points[1, self.points[2,:]>0] 93 | self.points[0, self.points[2,:]>0] = (vis_xs - center[0]) * np.cos(degree) - (vis_ys - center[1]) * np.sin(degree) + center[0] 94 | self.points[1, self.points[2,:]>0] = (vis_xs - center[0]) * np.sin(degree) + (vis_ys - center[1]) * np.cos(degree) + center[1] 95 | # rotate the box 96 | corners = np.zeros((4,2)) 97 | corners[0,0], corners[0,1] = self.box[0], self.box[1] 98 | corners[1,0], corners[1,1] = self.box[0], self.box[3] 99 | corners[2,0], corners[2,1] = self.box[2], self.box[1] 100 | corners[3,0], corners[3,1] = self.box[2], self.box[3] 101 | corner_xs, corner_ys = corners[:, 0].copy(), corners[:, 1].copy() # keep the original coordinates: both updates below must read the un-rotated values (the in-place version rotated y with the already-updated x and a flipped sign) 102 | corners[:, 0], corners[:, 1] = (corner_xs - center[0]) * np.cos(degree) - (corner_ys - center[1]) * np.sin(degree) + center[0], (corner_xs - center[0]) * np.sin(degree) + (corner_ys - center[1]) * np.cos(degree) + center[1] 103 | self.box[0], self.box[1] = corners[0,0], corners[0,1] 104 | self.box[2], self.box[3] = corners[3,0], corners[3,1] 105 | 106 | def apply_horizontal_flip(self, width): 107 | self.points[0, :] = width - self.points[0, :] - 1 108 | # Mugsy specific or Synthetic 109 | if self.datasets == 'Mugsy.full_face_v1': 110 | ori = np.array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]) 111 | pos = np.array([ 3, 4, 1, 2, 9, 10, 11, 12, 5, 6, 7, 8, 14, 13, 15, 16, 17, 18, 19, 20]) 112 | self.points[:, pos-1] = self.points[:, ori-1] 113 | elif self.datasets == 'Synthetic.v1': 114 | ori = np.array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]) 115 | pos = np.array([ 3, 4, 1, 2, 9, 10, 11, 12, 5, 6, 7, 8, 14, 13, 15, 16, 17, 18, 19, 20]) 116 | self.points[:, pos-1] = self.points[:, ori-1] 117 | else: 118 | assert False, 'Does not support {}'.format(self.datasets) 119 | 120 | # all points' range [0, w) [0, h) 121 | def check_nan(self): 122 | if math.isnan(self.center[0]) or math.isnan(self.center[1]): 123 | return True 124 | for i in range(self.num_point): 125 | if self.points[2, i] > 0: 126 | if math.isnan(self.points[0, i]) or math.isnan(self.points[1, i]): 127 | return True 128 | return False 129 | 130 | def visiable_pts_num(self): 131 | ans = self.points[2,:]>0 132 | return np.sum(ans) 133 | 134 | def set_precrop_wh(self, W, H, x1, y1, x2, y2): 135 | self.temp_save_wh = [W, H, x1, y1, x2, y2] 136 | 137 | def get_box(self): 138 | return self.box.copy() 139 | 140 | def get_points(self): 141 | if self.points is not None: 142 | return self.points.copy() 143 | else: 144 | return np.zeros((3, self.num_point), dtype='float32') 145 | 146 | def is_none(self): 147 | assert self.box is not None, 'The box should not be None' 148 | return self.points is None 149 | 150 | def copy(self): 151 | return copy.deepcopy(self) 152 | -------------------------------------------------------------------------------- /lib/lk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
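A small lifecycle sketch for `Point_Meta` above; the coordinates, box, and dataset tag are made-up values, and the import again assumes `lib/` is on `sys.path`:

```
import numpy as np
from datasets.point_meta import Point_Meta  # assumes lib/ is on sys.path

pts = np.array([[30., 60., 45.],          # x
                [40., 40., 70.],          # y
                [ 1.,  1.,  1.]])         # visibility flag
box = np.array([20., 30., 70., 80.])      # x1, y1, x2, y2
meta = Point_Meta(3, pts, box, 'face.jpg', 'demo')   # hypothetical path/tag
meta.apply_scale([0.5])                   # resize image and annotation together
meta.apply_offset(ax=10, ay=5)            # shift after a crop
meta.apply_bound(128, 128)                # points outside a 128x128 canvas become invisible
print(meta.get_points(), meta.visiable_pts_num())
```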
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .basic_lk import lk_tensor_track 8 | from .basic_lk_batch import lk_tensor_track_batch 9 | from .basic_lk_batch import lk_forward_backward_batch 10 | -------------------------------------------------------------------------------- /lib/lk/basic_lk.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numbers, math 11 | import numpy as np 12 | import models.model_utils as MU 13 | from .basic_utils import SobelConv 14 | from .basic_utils import ComputeGradient, Generate_Weight, warp_feature 15 | from .basic_utils import torch_inverse 16 | 17 | def lk_tensor_track(feature_old, feature_new, pts_locations, patch_size, max_step, threshold=0.0001): 18 | # feature[old,new] : 3-D tensor [C, H, W] 19 | # pts_locations is a 2-D point [ Y, X ] 20 | assert feature_old.dim() == 3 and feature_new.dim() == 3, 'The dimension is not right : {} and {}'.format(feature_old.dim(), feature_new.dim()) 21 | C, H, W = feature_old.size(0), feature_old.size(1), feature_old.size(2) 22 | assert C == feature_new.size(0) and H == feature_new.size(1) and W == feature_new.size(2), 'The size is not right : {}'.format(feature_new.size()) 23 | assert pts_locations.dim() == 1 and pts_locations.size(0) == 2, 'The location is not right : {}'.format(pts_locations) 24 | if isinstance(patch_size, int): patch_size = (patch_size, patch_size) 25 | assert isinstance(patch_size, tuple) and len(patch_size) == 2 and isinstance(max_step, int), 'The format of lk-parameters are not right : {}, {}'.format(patch_size, max_step) 26 | assert isinstance(patch_size[0], int) and isinstance(patch_size[1], int), 'The format of lk-parameters are not right : {}'.format(patch_size) 27 | 28 | def abserror(deltap): 29 | deltap = MU.variable2np(deltap) 30 | return float(np.sqrt(np.sum(deltap*deltap))) 31 | 32 | weight_map = Generate_Weight( [patch_size[0]*2+1, patch_size[1]*2+1] ) # [H, W] 33 | with torch.cuda.device_of(feature_old): 34 | weight_map = MU.np2variable(weight_map, feature_old.is_cuda, False).unsqueeze(0) 35 | 36 | feature_T = warp_feature(feature_old, pts_locations, patch_size) 37 | gradiant_x = ComputeGradient(feature_T, 'x') 38 | gradiant_y = ComputeGradient(feature_T, 'y') 39 | J = torch.stack([gradiant_x, gradiant_y]) 40 | weightedJ = J*weight_map 41 | H = torch.mm( weightedJ.view(2,-1), J.view(2, -1).transpose(1,0) ) 42 | inverseH = torch_inverse(H) 43 | 44 | for step in range(max_step): 45 | # Step-1 Warp I with W(x,p) to compute I(W(x;p)) 46 | feature_I = warp_feature(feature_new, pts_locations, patch_size) 47 | # Step-2 Compute the error feature 48 | r = feature_I - feature_T 49 | # Step-7 Compute sigma 50 | sigma = torch.mm(weightedJ.view(2,-1), r.view(-1, 1)) 51 | # Step-8 Compute delta-p 52 | deltap = torch.mm(inverseH, sigma).squeeze(1) 53 | pts_locations = pts_locations - deltap 54 | if abserror(deltap) < threshold: break 55 | 56 | return pts_locations 57 | -------------------------------------------------------------------------------- /lib/lk/basic_lk_batch.py: -------------------------------------------------------------------------------- 1 | 
# Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numbers, math 11 | import numpy as np 12 | import models.model_utils as MU 13 | from .basic_utils import SobelConv, Generate_Weight 14 | from .basic_utils_batch import torch_inverse_batch, warp_feature_batch 15 | 16 | """ 17 | peak_config = {} 18 | def obtain_config(heatmap, radius): 19 | identity_str = '{}-{}'.format(radius, heatmap.get_device() if heatmap.is_cuda else -1 ) 20 | if identity_str not in peak_config: 21 | if heatmap.is_cuda: 22 | with torch.cuda.device_of(heatmap): 23 | X = MU.np2variable(torch.arange(-radius, radius+1), heatmap.is_cuda, False).view(1, 1, radius*2+1) 24 | Y = MU.np2variable(torch.arange(-radius, radius+1), heatmap.is_cuda, False).view(1, radius*2+1, 1) 25 | else: 26 | X = MU.np2variable(torch.arange(-radius, radius+1), heatmap.is_cuda, False).view(1, 1, radius*2+1) 27 | Y = MU.np2variable(torch.arange(-radius, radius+1), heatmap.is_cuda, False).view(1, radius*2+1, 1) 28 | peak_config[ identity_str ] = [X, Y] 29 | return peak_config[ identity_str ] 30 | """ 31 | 32 | 33 | def lk_tensor_track_batch(feature_old, feature_new, pts_locations, patch_size, max_step, feature_template=None): 34 | # feature[old,new] : 4-D tensor [1, C, H, W] 35 | # pts_locations is a 2-D tensor [Num-Pts, (Y,X)] 36 | if feature_new.dim() == 3: 37 | feature_new = feature_new.unsqueeze(0) 38 | if feature_old is not None and feature_old.dim() == 3: 39 | feature_old = feature_old.unsqueeze(0) 40 | assert feature_new.dim() == 4, 'The dimension of feature-new is not right : {}.'.format(feature_new.dim()) 41 | BB, C, H, W = list(feature_new.size()) 42 | if feature_old is not None: 43 | assert 1 == feature_old.size(0) and 1 == BB, 'The first dimension of feature should be one not {}'.format(feature_old.size()) 44 | assert C == feature_old.size(1) and H == feature_old.size(2) and W == feature_old.size(3), 'The size is not right : {}'.format(feature_old.size()) 45 | assert isinstance(patch_size, int), 'The format of lk-parameters are not right : {}'.format(patch_size) 46 | num_pts = pts_locations.size(0) 47 | device = feature_new.device 48 | 49 | weight_map = Generate_Weight( [patch_size*2+1, patch_size*2+1] ) # [H, W] 50 | with torch.no_grad(): 51 | weight_map = torch.tensor(weight_map).view(1, 1, 1, patch_size*2+1, patch_size*2+1).to(device) 52 | 53 | sobelconvx = SobelConv('x', feature_new.dtype).to(device) 54 | sobelconvy = SobelConv('y', feature_new.dtype).to(device) 55 | 56 | # feature_T should be a [num_pts, C, patch, patch] tensor 57 | if feature_template is None: 58 | feature_T = warp_feature_batch(feature_old, pts_locations, patch_size) 59 | else: 60 | assert feature_old is None, 'When feature_template is not None. 
feature_old must be None' 61 | feature_T = feature_template 62 | assert feature_T.size(2) == patch_size * 2 + 1 and feature_T.size(3) == patch_size * 2 + 1, 'The size of feature-template is not ok : {}'.format(feature_T.size()) 63 | gradiant_x = sobelconvx(feature_T) 64 | gradiant_y = sobelconvy(feature_T) 65 | J = torch.stack([gradiant_x, gradiant_y], dim=1) 66 | weightedJ = J * weight_map 67 | H = torch.bmm( weightedJ.view(num_pts,2,-1), J.view(num_pts, 2, -1).transpose(2,1) ) 68 | inverseH = torch_inverse_batch(H) 69 | 70 | #print ('PTS : {}'.format(pts_locations)) 71 | for step in range(max_step): 72 | # Step-1 Warp I with W(x,p) to compute I(W(x;p)) 73 | feature_I = warp_feature_batch(feature_new, pts_locations, patch_size) 74 | # Step-2 Compute the error feature 75 | r = feature_I - feature_T 76 | # Step-7 Compute sigma 77 | sigma = torch.bmm(weightedJ.view(num_pts,2,-1), r.view(num_pts,-1, 1)) 78 | # Step-8 Compute delta-p 79 | deltap = torch.bmm(inverseH, sigma).squeeze(-1) 80 | pts_locations = pts_locations - deltap 81 | 82 | return pts_locations 83 | 84 | 85 | def lk_forward_backward_batch(features, locations, window, steps): 86 | sequence, C, H, W = list(features.size()) 87 | seq, num_pts, _ = list(locations.size()) 88 | assert seq == sequence, '{:} vs {:}'.format(features.size(), locations.size()) 89 | 90 | previous_pts = [ locations[0] ] 91 | for iseq in range(1, sequence): 92 | feature_old = features.narrow(0, iseq-1, 1) 93 | feature_new = features.narrow(0, iseq , 1) 94 | nextPts = lk_tensor_track_batch(feature_old, feature_new, previous_pts[iseq-1], window, steps, None) 95 | previous_pts.append(nextPts) 96 | 97 | fback_pts = [None] * (sequence-1) + [ previous_pts[-1] ] 98 | for iseq in range(sequence-2, -1, -1): 99 | feature_old = features.narrow(0, iseq+1, 1) 100 | feature_new = features.narrow(0, iseq , 1) 101 | backPts = lk_tensor_track_batch(feature_old, feature_new, fback_pts[iseq+1] , window, steps, None) 102 | fback_pts[iseq] = backPts 103 | 104 | back_pts = [None] * (sequence-1) + [ locations[-1] ] 105 | for iseq in range(sequence-2, -1, -1): 106 | feature_old = features.narrow(0, iseq+1, 1) 107 | feature_new = features.narrow(0, iseq , 1) 108 | backPts = lk_tensor_track_batch(feature_old, feature_new, back_pts[iseq+1] , window, steps, None) 109 | back_pts[iseq] = backPts 110 | 111 | return torch.stack(previous_pts), torch.stack(fback_pts), torch.stack(back_pts) 112 | -------------------------------------------------------------------------------- /lib/lk/basic_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
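A shape-level sketch of `lk_forward_backward_batch` above, on random tensors (the tracked coordinates are therefore meaningless; only the calling contract and the output shapes are illustrated):

```
import torch
from lk import lk_forward_backward_batch  # assumes lib/ is on sys.path

features  = torch.rand(4, 16, 32, 32)     # 4 frames with C=16 feature maps
locations = torch.rand(4, 68, 2) * 31     # per-frame (x, y) landmark coordinates
window, steps = 8, 10                     # LK patch radius and iteration count
nxt, fback, back = lk_forward_backward_batch(features, locations, window, steps)
# nxt   : frame-0 landmarks tracked forward through the sequence
# fback : those forward-tracked points tracked back again (consistency check)
# back  : last-frame landmarks tracked backward
print(nxt.size(), fback.size(), back.size())   # each torch.Size([4, 68, 2])
```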
6 | # 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numbers, math 11 | import numpy as np 12 | import models.model_utils as MU 13 | 14 | #### The utils for LK 15 | 16 | def torch_inverse(deltp): 17 | assert deltp.dim() == 2 and deltp.size(0) == 2 and deltp.size(1) == 2, 'The deltp format is not right : {}'.format( deltp.size() ) 18 | a, b, c, d = deltp[0,0], deltp[0,1], deltp[1,0], deltp[1,1] 19 | a = a + np.finfo(float).eps 20 | d = d + np.finfo(float).eps 21 | divide = a*d-b*c 22 | inverse = torch.stack([d, -b, -c, a]).view(2,2) # stack: a/b/c/d are 0-dim tensors, which torch.cat rejects 23 | return inverse / divide 24 | 25 | class SobelConv(nn.Module): 26 | def __init__(self, tag, dtype): 27 | super(SobelConv, self).__init__() 28 | if tag == 'x': 29 | Sobel = np.array([ [-1./8, 0, 1./8], [-2./8, 0, 2./8], [ -1./8, 0, 1./8] ]) 30 | #Sobel = np.array([ [ 0, 0, 0], [-0.5,0,0.5], [ 0, 0, 0] ]) 31 | elif tag == 'y': 32 | Sobel = np.array([ [ -1./8, -2./8, -1./8], [ 0, 0, 0], [ 1./8, 2./8, 1./8] ]) 33 | #Sobel = np.array([ [ 0,-0.5, 0], [ 0, 0, 0], [ 0, 0.5, 0] ]) 34 | else: 35 | raise NameError('Do not know this tag for Sobel Kernel : {}'.format(tag)) 36 | Sobel = torch.from_numpy(Sobel).type(dtype) 37 | Sobel = Sobel.view(1, 1, 3, 3) 38 | self.register_buffer('weight', Sobel) 39 | self.tag = tag 40 | 41 | def forward(self, input): 42 | weight = self.weight.expand(input.size(1), 1, 3, 3).contiguous() 43 | return F.conv2d(input, weight, groups=input.size(1), padding=1) 44 | 45 | def __repr__(self): 46 | return ('{name}(tag={tag})'.format(name=self.__class__.__name__, **self.__dict__)) 47 | 48 | def ComputeGradient(feature, tag): 49 | if feature.dim() == 3: 50 | feature = feature.unsqueeze(0) 51 | squeeze = True 52 | else: 53 | squeeze = False 54 | assert feature.dim() == 4, 'feature must be [batch x C x H x W] not {}'.format(feature.size()) 55 | sobel = SobelConv(tag, feature.dtype) # SobelConv takes (tag, dtype); calling it with tag alone raised a TypeError 56 | if feature.is_cuda: sobel.cuda() 57 | if squeeze: return sobel(feature).squeeze(0) 58 | else: return sobel(feature) 59 | 60 | def Generate_Weight(patch_size, sigma=None): 61 | assert isinstance(patch_size, list) or isinstance(patch_size, tuple) 62 | assert patch_size[0] > 0 and patch_size[1] > 0, 'the patch size must be > 0 : {}'.format(patch_size) 63 | center = [(patch_size[0]-1.)/2, (patch_size[1]-1.)/2] 64 | maps = np.fromfunction( lambda x, y: (x-center[0])**2 + (y-center[1])**2, (patch_size[0], patch_size[1]), dtype=int) 65 | if sigma is None: sigma = min(patch_size[0], patch_size[1])/2. 66 | maps = np.exp(maps / -2.0 / sigma / sigma) 67 | maps[0, :] = maps[-1, :] = maps[:, 0] = maps[:, -1] = 0 68 | return maps.astype(np.float32) 69 | 70 | def warp_feature(feature, pts_location, patch_size): 71 | # pts_location is [X,Y], patch_size is [H,W] 72 | C, H, W = feature.size(0), feature.size(1), feature.size(2) 73 | def normalize(x, L): 74 | return -1. + 2. 
* x / (L-1) 75 | 76 | crop_box = [pts_location[0]-patch_size[1], pts_location[1]-patch_size[0], pts_location[0]+patch_size[1], pts_location[1]+patch_size[0]] 77 | crop_box[0] = normalize(crop_box[0], W) 78 | crop_box[1] = normalize(crop_box[1], H) 79 | crop_box[2] = normalize(crop_box[2], W) 80 | crop_box[3] = normalize(crop_box[3], H) 81 | affine_parameter = [(crop_box[2]-crop_box[0])/2, MU.np2variable(torch.zeros(1),feature.is_cuda,False), (crop_box[0]+crop_box[2])/2, 82 | MU.np2variable(torch.zeros(1),feature.is_cuda,False), (crop_box[3]-crop_box[1])/2, (crop_box[1]+crop_box[3])/2] 83 | 84 | affine_parameter = torch.cat(affine_parameter).view(2, 3) 85 | 86 | theta = affine_parameter.unsqueeze(0) 87 | feature = feature.unsqueeze(0) 88 | grid_size = torch.Size([1, 1, 2*patch_size[0]+1, 2*patch_size[1]+1]) 89 | grid = F.affine_grid(theta, grid_size) 90 | sub_feature = F.grid_sample(feature, grid).squeeze(0) 91 | return sub_feature 92 | -------------------------------------------------------------------------------- /lib/lk/basic_utils_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numbers, math 11 | import numpy as np 12 | import models.model_utils as MU 13 | 14 | #### The utils for LK 15 | def torch_inverse_batch(deltp): 16 | # deltp must be [K,2,2] 17 | assert deltp.dim() == 3 and deltp.size(1) == 2 and deltp.size(2) == 2, 'The deltp format is not right : {}'.format( deltp.size() ) 18 | a, b, c, d = deltp[:,0,0], deltp[:,0,1], deltp[:,1,0], deltp[:,1,1] 19 | a = a + np.finfo(float).eps 20 | d = d + np.finfo(float).eps 21 | divide = a*d-b*c+np.finfo(float).eps 22 | inverse = torch.stack([d, -b, -c, a], dim=1) / divide.unsqueeze(1) 23 | return inverse.view(-1,2,2) 24 | 25 | 26 | def warp_feature_batch(feature, pts_location, patch_size): 27 | # feature must be [1,C,H,W] and pts_location must be [Num-Pts, (x,y)] 28 | _, C, H, W = list(feature.size()) 29 | num_pts = pts_location.size(0) 30 | assert isinstance(patch_size, int) and feature.size(0) == 1 and pts_location.size(1) == 2, 'The shapes of feature or points are not right : {} vs {}'.format(feature.size(), pts_location.size()) 31 | assert W > 1 and H > 1, 'To guarantee normalization {}, {}'.format(W, H) 32 | 33 | def normalize(x, L): 34 | return -1. + 2. 
* x / (L-1) 35 | 36 | crop_box = torch.cat([pts_location-patch_size, pts_location+patch_size], 1) 37 | crop_box[:, [0,2]] = normalize(crop_box[:, [0,2]], W) 38 | crop_box[:, [1,3]] = normalize(crop_box[:, [1,3]], H) 39 | 40 | affine_parameter = [(crop_box[:,2]-crop_box[:,0])/2, crop_box[:,0]*0, (crop_box[:,2]+crop_box[:,0])/2, 41 | crop_box[:,0]*0, (crop_box[:,3]-crop_box[:,1])/2, (crop_box[:,3]+crop_box[:,1])/2] 42 | #affine_parameter = [(crop_box[:,2]-crop_box[:,0])/2, MU.np2variable(torch.zeros(num_pts),feature.is_cuda,False), (crop_box[:,2]+crop_box[:,0])/2, 43 | # MU.np2variable(torch.zeros(num_pts),feature.is_cuda,False), (crop_box[:,3]-crop_box[:,1])/2, (crop_box[:,3]+crop_box[:,1])/2] 44 | theta = torch.stack(affine_parameter, 1).view(num_pts, 2, 3) 45 | feature = feature.expand(num_pts,C, H, W) 46 | grid_size = torch.Size([num_pts, 1, 2*patch_size+1, 2*patch_size+1]) 47 | grid = F.affine_grid(theta, grid_size) 48 | sub_feature = F.grid_sample(feature, grid) 49 | return sub_feature 50 | -------------------------------------------------------------------------------- /lib/log_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .logger import Logger 8 | from .meter import AverageMeter 9 | from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time, print_log 10 | -------------------------------------------------------------------------------- /lib/log_utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
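A minimal usage sketch for the Logger defined just below; the snapshot directory is an illustrative choice:

```
from log_utils import Logger, time_for_file  # assumes lib/ is on sys.path

logger = Logger('./snapshots/demo', time_for_file())   # creates checkpoint/ and metas/ too
logger.log('=> training starts')                       # printed and appended to the .log file
logger.log('checkpoints go to {:}'.format(logger.path('model')))
logger.close()
```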
6 | # 7 | from pathlib import Path 8 | import importlib, warnings 9 | import os, sys, time, numpy as np 10 | import scipy.misc 11 | if sys.version_info.major == 2: # Python 2.x 12 | from StringIO import StringIO as BIO 13 | else: # Python 3.x 14 | from io import BytesIO as BIO 15 | 16 | class Logger(object): 17 | 18 | def __init__(self, log_dir, logstr): 19 | """Create a plain-text logger that writes into log_dir.""" 20 | self.log_dir = Path(log_dir) 21 | self.model_dir = Path(log_dir) / 'checkpoint' 22 | self.meta_dir = Path(log_dir) / 'metas' 23 | self.log_dir.mkdir(mode=0o775, parents=True, exist_ok=True) 24 | self.model_dir.mkdir(mode=0o775, parents=True, exist_ok=True) 25 | self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True) 26 | 27 | self.logger_path = self.log_dir / '{:}.log'.format(logstr) 28 | self.logger_file = open(self.logger_path, 'w') 29 | 30 | 31 | def __repr__(self): 32 | return ('{name}(dir={log_dir})'.format(name=self.__class__.__name__, **self.__dict__)) 33 | 34 | def path(self, mode): 35 | if mode == 'meta' : return self.meta_dir 36 | elif mode == 'model': return self.model_dir 37 | elif mode == 'log' : return self.log_dir 38 | else: raise TypeError('Unknown mode = {:}'.format(mode)) 39 | 40 | def last_info(self): 41 | return self.log_dir / 'last-info.pth' 42 | 43 | def extract_log(self): 44 | return self.logger_file 45 | 46 | def close(self): 47 | self.logger_file.close() 48 | 49 | def log(self, string, save=True): 50 | print (string) 51 | if save: 52 | self.logger_file.write('{:}\n'.format(string)) 53 | self.logger_file.flush() 54 | -------------------------------------------------------------------------------- /lib/log_utils/meter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import time, sys 8 | import numpy as np 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | -------------------------------------------------------------------------------- /lib/log_utils/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
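AverageMeter above keeps a batch-size-weighted running average; a quick numeric sketch:

```
from log_utils import AverageMeter  # assumes lib/ is on sys.path

meter = AverageMeter()
for loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(loss, n=batch_size)   # weight each value by its batch size
print('last={:.2f} avg={:.3f}'.format(meter.val, meter.avg))  # last=0.50 avg=0.740
```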
6 | # 7 | import time, sys 8 | import numpy as np 9 | from .logger import Logger 10 | 11 | def time_for_file(): 12 | ISOTIMEFORMAT='%d-%h-at-%H-%M-%S' 13 | return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 14 | 15 | def time_string(): 16 | ISOTIMEFORMAT='%Y-%m-%d %X' 17 | string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 18 | return string 19 | 20 | def time_string_short(): 21 | ISOTIMEFORMAT='%Y%m%d' 22 | string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 23 | return string 24 | 25 | def time_print(string, is_print=True): 26 | if (is_print): 27 | print('{} : {}'.format(time_string(), string)) 28 | 29 | def convert_size2str(torch_size): 30 | dims = len(torch_size) 31 | string = '[' 32 | for idim in range(dims): 33 | string = string + ' {}'.format(torch_size[idim]) 34 | return string + ']' 35 | 36 | def convert_secs2time(epoch_time, return_str=False): 37 | need_hour = int(epoch_time / 3600) 38 | need_mins = int((epoch_time - 3600*need_hour) / 60) 39 | need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) 40 | if return_str: 41 | str = '[Time Left: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) 42 | return str 43 | else: 44 | return need_hour, need_mins, need_secs 45 | 46 | def print_log(print_string, log): 47 | if isinstance(log, Logger): log.log('{:}'.format(print_string)) 48 | else: 49 | print("{:}".format(print_string)) 50 | if log is not None: 51 | log.write('{:}\n'.format(print_string)) 52 | log.flush() 53 | -------------------------------------------------------------------------------- /lib/models/LK.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
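A quick sketch of the time helpers above:

```
from log_utils import convert_secs2time, time_string, print_log  # assumes lib/ is on sys.path

h, m, s = convert_secs2time(4000)                 # -> (1, 6, 40)
print(time_string(), '{:d}h {:d}m {:d}s'.format(h, m, s))
print(convert_secs2time(4000, return_str=True))   # '[Time Left: 01:06:40]'
print_log('falls back to plain print when no Logger is given', None)
```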
6 | # 7 | import torch, copy 8 | import torch.nn as nn 9 | import lk 10 | 11 | class LK(nn.Module): 12 | def __init__(self, model, lkconfig, points): 13 | super(LK, self).__init__() 14 | self.detector = model 15 | self.downsample = self.detector.downsample 16 | self.config = copy.deepcopy(lkconfig) 17 | self.points = points 18 | 19 | def forward(self, inputs): 20 | assert inputs.dim() == 5, 'This model accepts a 5-dimensional input tensor: {}'.format(inputs.size()) 21 | batch_size, sequence, C, H, W = list( inputs.size() ) 22 | gathered_inputs = inputs.view(batch_size * sequence, C, H, W) 23 | heatmaps, batch_locs, batch_scos = self.detector(gathered_inputs) 24 | heatmaps = [x.view(batch_size, sequence, self.points, H//self.downsample, W//self.downsample) for x in heatmaps] 25 | batch_locs, batch_scos = batch_locs.view(batch_size, sequence, self.points, 2), batch_scos.view(batch_size, sequence, self.points) 26 | batch_next, batch_fback, batch_back = [], [], [] 27 | 28 | for ibatch in range(batch_size): 29 | # track on the raw frames of this sample (the unused feature_old alias was dropped) 30 | nextPts, fbackPts, backPts = lk.lk_forward_backward_batch(inputs[ibatch], batch_locs[ibatch], self.config.window, self.config.steps) 31 | 32 | batch_next.append(nextPts) 33 | batch_fback.append(fbackPts) 34 | batch_back.append(backPts) 35 | batch_next, batch_fback, batch_back = torch.stack(batch_next), torch.stack(batch_fback), torch.stack(batch_back) 36 | return heatmaps, batch_locs, batch_scos, batch_next, batch_fback, batch_back 37 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .basic import obtain_model 8 | from .basic import obtain_LK 9 | from .model_utils import remove_module_dict 10 | -------------------------------------------------------------------------------- /lib/models/basic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .cpm_vgg16 import cpm_vgg16 8 | from .LK import LK 9 | 10 | def obtain_model(configure, points): 11 | if configure.arch == 'cpm_vgg16': 12 | net = cpm_vgg16(configure, points) 13 | else: 14 | raise TypeError('Unknown type : {:}'.format(configure.arch)) 15 | return net 16 | 17 | def obtain_LK(configure, lkconfig, points): 18 | model = obtain_model(configure, points) 19 | lk_model = LK(model, lkconfig, points) 20 | return lk_model 21 | -------------------------------------------------------------------------------- /lib/models/basic_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
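The LK wrapper above consumes 5-D video batches; a construction-and-forward sketch in which the config objects are stand-ins carrying exactly the attributes this code reads (arch/stages/pretrained/argmax and window/steps), the numeric values are assumptions, and 69 = 68 landmarks plus the background channel used throughout the training code:

```
import torch
from collections import namedtuple
from models import obtain_LK  # assumes lib/ is on sys.path

DetCfg = namedtuple('DetCfg', ['arch', 'stages', 'pretrained', 'argmax'])
LKCfg  = namedtuple('LKCfg',  ['window', 'steps'])
lk_net = obtain_LK(DetCfg('cpm_vgg16', 3, False, 4), LKCfg(8, 10), 69)

video = torch.rand(1, 3, 3, 64, 64)    # batch=1, sequence=3, RGB 64x64 frames
heatmaps, locs, scos, nxt, fback, back = lk_net(video)
print(locs.size(), nxt.size())          # torch.Size([1, 3, 69, 2]) for both
```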
6 | # 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numbers, math 11 | import numpy as np 12 | 13 | def find_tensor_peak_batch(heatmap, radius, downsample, threshold = 0.000001): 14 | assert heatmap.dim() == 3, 'The dimension of the heatmap is wrong : {}'.format(heatmap.size()) 15 | assert radius > 0 and isinstance(radius, numbers.Number), 'The radius is not ok : {}'.format(radius) 16 | num_pts, H, W = heatmap.size(0), heatmap.size(1), heatmap.size(2) 17 | assert W > 1 and H > 1, 'To avoid division by zero in the normalization' 18 | # find the approximate location: 19 | score, index = torch.max(heatmap.view(num_pts, -1), 1) 20 | index_w = (index % W).float() 21 | index_h = (index // W).float() # floor division: plain '/' on integer tensors is true division in newer PyTorch and would yield fractional row indices 22 | 23 | def normalize(x, L): 24 | return -1. + 2. * x.data / (L-1) 25 | boxes = [index_w - radius, index_h - radius, index_w + radius, index_h + radius] 26 | boxes[0] = normalize(boxes[0], W) 27 | boxes[1] = normalize(boxes[1], H) 28 | boxes[2] = normalize(boxes[2], W) 29 | boxes[3] = normalize(boxes[3], H) 30 | #affine_parameter = [(boxes[2]-boxes[0])/2, boxes[0]*0, (boxes[2]+boxes[0])/2, 31 | # boxes[0]*0, (boxes[3]-boxes[1])/2, (boxes[3]+boxes[1])/2] 32 | #theta = torch.stack(affine_parameter, 1).view(num_pts, 2, 3) 33 | 34 | affine_parameter = torch.zeros((num_pts, 2, 3)) 35 | affine_parameter[:,0,0] = (boxes[2]-boxes[0])/2 36 | affine_parameter[:,0,2] = (boxes[2]+boxes[0])/2 37 | affine_parameter[:,1,1] = (boxes[3]-boxes[1])/2 38 | affine_parameter[:,1,2] = (boxes[3]+boxes[1])/2 39 | # extract the sub-region heatmap 40 | theta = affine_parameter.to(heatmap.device) 41 | grid_size = torch.Size([num_pts, 1, radius*2+1, radius*2+1]) 42 | grid = F.affine_grid(theta, grid_size) 43 | sub_feature = F.grid_sample(heatmap.unsqueeze(1), grid).squeeze(1) 44 | sub_feature = F.threshold(sub_feature, threshold, np.finfo(float).eps) 45 | 46 | X = torch.arange(-radius, radius+1).to(heatmap).view(1, 1, radius*2+1) 47 | Y = torch.arange(-radius, radius+1).to(heatmap).view(1, radius*2+1, 1) 48 | 49 | sum_region = torch.sum(sub_feature.view(num_pts,-1),1) 50 | x = torch.sum((sub_feature*X).view(num_pts,-1),1) / sum_region + index_w 51 | y = torch.sum((sub_feature*Y).view(num_pts,-1),1) / sum_region + index_h 52 | 53 | x = x * downsample + downsample / 2.0 - 0.5 54 | y = y * downsample + downsample / 2.0 - 0.5 55 | return torch.stack([x, y],1), score 56 | -------------------------------------------------------------------------------- /lib/models/cpm_vgg16.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
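A toy check of `find_tensor_peak_batch` above: place one Gaussian bump on a heatmap and recover its sub-pixel location (the radius and downsample values are arbitrary):

```
import torch
from models.basic_batch import find_tensor_peak_batch  # assumes lib/ is on sys.path

H = W = 32
ys, xs = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
heatmap = torch.exp(-((xs - 10.3) ** 2 + (ys - 20.7) ** 2) / 8.0).unsqueeze(0)
locs, score = find_tensor_peak_batch(heatmap, radius=4, downsample=8)
# the weighted mean around the argmax recovers ~(10.3, 20.7) on the heatmap grid,
# then maps to input coordinates via v * 8 + 3.5 -> roughly (85.9, 169.1)
print(locs, score)
```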
6 | # 7 | from __future__ import division 8 | import time, math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.model_zoo as model_zoo 13 | from collections import OrderedDict 14 | from copy import deepcopy 15 | from .model_utils import get_parameters 16 | from .basic_batch import find_tensor_peak_batch 17 | from .initialization import weights_init_cpm 18 | 19 | class VGG16_base(nn.Module): 20 | def __init__(self, config, pts_num): 21 | super(VGG16_base, self).__init__() 22 | 23 | self.config = deepcopy(config) 24 | self.downsample = 8 25 | self.pts_num = pts_num 26 | 27 | self.features = nn.Sequential( 28 | nn.Conv2d( 3, 64, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 29 | nn.Conv2d( 64, 64, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=2, stride=2), 31 | nn.Conv2d( 64, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 32 | nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 33 | nn.MaxPool2d(kernel_size=2, stride=2), 34 | nn.Conv2d(128, 256, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 35 | nn.Conv2d(256, 256, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 36 | nn.Conv2d(256, 256, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 37 | nn.MaxPool2d(kernel_size=2, stride=2), 38 | nn.Conv2d(256, 512, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 39 | nn.Conv2d(512, 512, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 40 | nn.Conv2d(512, 512, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True)) 41 | 42 | 43 | self.CPM_feature = nn.Sequential( 44 | nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), #CPM_1 45 | nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True)) #CPM_2 46 | 47 | assert self.config.stages >= 1, 'stages of cpm must >= 1 not : {:}'.format(self.config.stages) 48 | stage1 = nn.Sequential( 49 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), 50 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), 51 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), 52 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), 53 | nn.Conv2d(128, 512, kernel_size=1, padding=0), nn.ReLU(inplace=True), 54 | nn.Conv2d(512, pts_num, kernel_size=1, padding=0)) 55 | stages = [stage1] 56 | for i in range(1, self.config.stages): 57 | stagex = nn.Sequential( 58 | nn.Conv2d(128+pts_num, 128, kernel_size=7, dilation=1, padding=3), nn.ReLU(inplace=True), 59 | nn.Conv2d(128, 128, kernel_size=7, dilation=1, padding=3), nn.ReLU(inplace=True), 60 | nn.Conv2d(128, 128, kernel_size=7, dilation=1, padding=3), nn.ReLU(inplace=True), 61 | nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 62 | nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 63 | nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 64 | nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), 65 | nn.Conv2d(128, 128, kernel_size=1, padding=0), nn.ReLU(inplace=True), 66 | nn.Conv2d(128, pts_num, kernel_size=1, padding=0)) 67 | stages.append( stagex ) 68 | self.stages = nn.ModuleList(stages) 69 | 70 | def specify_parameter(self, base_lr, base_weight_decay): 71 | params_dict = [ {'params': get_parameters(self.features, bias=False), 'lr': base_lr , 'weight_decay': base_weight_decay}, 
72 | {'params': get_parameters(self.features, bias=True ), 'lr': base_lr*2, 'weight_decay': 0}, 73 | {'params': get_parameters(self.CPM_feature, bias=False), 'lr': base_lr , 'weight_decay': base_weight_decay}, 74 | {'params': get_parameters(self.CPM_feature, bias=True ), 'lr': base_lr*2, 'weight_decay': 0}, 75 | ] 76 | for stage in self.stages: 77 | params_dict.append( {'params': get_parameters(stage, bias=False), 'lr': base_lr*4, 'weight_decay': base_weight_decay} ) 78 | params_dict.append( {'params': get_parameters(stage, bias=True ), 'lr': base_lr*8, 'weight_decay': 0} ) 79 | return params_dict 80 | 81 | # return : cpm-stages, locations 82 | def forward(self, inputs): 83 | assert inputs.dim() == 4, 'This model accepts a 4-dimensional input tensor: {}'.format(inputs.size()) 84 | batch_size, feature_dim = inputs.size(0), inputs.size(1) 85 | batch_cpms, batch_locs, batch_scos = [], [], [] 86 | 87 | feature = self.features(inputs) 88 | xfeature = self.CPM_feature(feature) 89 | for i in range(self.config.stages): 90 | if i == 0: cpm = self.stages[i]( xfeature ) 91 | else: cpm = self.stages[i]( torch.cat([xfeature, batch_cpms[i-1]], 1) ) 92 | batch_cpms.append( cpm ) 93 | 94 | # The location of the current batch 95 | for ibatch in range(batch_size): 96 | batch_location, batch_score = find_tensor_peak_batch(batch_cpms[-1][ibatch], self.config.argmax, self.downsample) 97 | batch_locs.append( batch_location ) 98 | batch_scos.append( batch_score ) 99 | batch_locs, batch_scos = torch.stack(batch_locs), torch.stack(batch_scos) 100 | 101 | return batch_cpms, batch_locs, batch_scos 102 | 103 | # use vgg16 conv1_1 to conv4_3 as the feature extraction backbone 104 | model_urls = 'https://download.pytorch.org/models/vgg16-397923af.pth' 105 | 106 | def cpm_vgg16(config, pts): 107 | 108 | print ('Initialize cpm-vgg16 with configuration : {}'.format(config)) 109 | model = VGG16_base(config, pts) 110 | model.apply(weights_init_cpm) 111 | 112 | if config.pretrained: 113 | print ('vgg16_base uses pre-trained weights') 114 | weights = model_zoo.load_url(model_urls) 115 | model.load_state_dict(weights, strict=False) 116 | return model 117 | -------------------------------------------------------------------------------- /lib/models/initialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import init 10 | 11 | def weights_init_cpm(m): 12 | classname = m.__class__.__name__ 13 | # print(classname) 14 | if classname.find('Conv') != -1: 15 | m.weight.data.normal_(0, 0.01) 16 | if m.bias is not None: m.bias.data.zero_() 17 | elif classname.find('BatchNorm2d') != -1: 18 | m.weight.data.fill_(1) 19 | m.bias.data.zero_() 20 | -------------------------------------------------------------------------------- /lib/models/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
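A sketch of applying `weights_init_cpm` above to a throwaway head (the layer sizes are arbitrary): conv weights become N(0, 0.01) and biases zero, while non-conv modules pass through untouched.

```
import torch.nn as nn
from models.initialization import weights_init_cpm  # assumes lib/ is on sys.path

head = nn.Sequential(nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(inplace=True),
                     nn.Conv2d(128, 69, 1))
head.apply(weights_init_cpm)
print(head[0].bias.abs().sum().item())   # 0.0
print(head[2].weight.std().item())       # close to 0.01
```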
6 | # 7 | from scipy.ndimage.interpolation import zoom 8 | from collections import OrderedDict 9 | import utils 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import copy, numbers, numpy as np 14 | 15 | def get_parameters(model, bias): 16 | for m in model.modules(): 17 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 18 | if bias: 19 | yield m.bias 20 | else: 21 | yield m.weight 22 | elif isinstance(m, nn.BatchNorm2d): 23 | if bias: 24 | yield m.bias 25 | else: 26 | yield m.weight 27 | 28 | def remove_module_dict(state_dict, is_print=False): 29 | new_state_dict = OrderedDict() 30 | for k, v in state_dict.items(): 31 | if k[:7] == 'module.': 32 | name = k[7:] # remove `module.` 33 | else: 34 | name = k 35 | new_state_dict[name] = v 36 | if is_print: print(new_state_dict.keys()) 37 | return new_state_dict 38 | -------------------------------------------------------------------------------- /lib/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .opt_utils import obtain_optimizer 8 | -------------------------------------------------------------------------------- /lib/optimizer/opt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | 9 | def obtain_optimizer(params, config, logger): 10 | assert hasattr(config, 'optimizer'), 'Must have the optimizer attribute' 11 | optimizer = config.optimizer.lower() 12 | if optimizer == 'sgd': 13 | opt = torch.optim.SGD(params, lr=config.LR, momentum=config.momentum, 14 | weight_decay=config.Decay, nesterov=config.nesterov) 15 | elif optimizer == 'rmsprop': 16 | opt = torch.optim.RMSprop(params, lr=config.LR, momentum=config.momentum, 17 | alpha = config.alpha, eps=config.epsilon, 18 | weight_decay = config.weight_decay) 19 | elif optimizer == 'adam': 20 | opt = torch.optim.Adam(params, lr=config.LR, amsgrad=config.amsgrad) 21 | else: 22 | raise TypeError('Does not know this optimizer : {:}'.format(config)) 23 | 24 | scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=config.schedule, gamma=config.gamma) 25 | 26 | strs = config.criterion.split('-') 27 | assert len(strs) == 2, 'illegal criterion : {:}'.format(config.criterion) 28 | if strs[0].lower() == 'mse': 29 | size_average = strs[1].lower() == 'avg' 30 | criterion = torch.nn.MSELoss(size_average) # note: size_average is deprecated; newer PyTorch spells this reduction='mean' (avg) or reduction='sum' 31 | message = 'Optimizer : {:}, MSE Loss with size-average={:}'.format(opt, size_average) 32 | if logger is not None: logger.log(message) 33 | else : print(message) 34 | else: 35 | raise TypeError('Does not know this criterion : {:}'.format(config.criterion)) 36 | 37 | return opt, scheduler, criterion 38 | -------------------------------------------------------------------------------- /lib/procedure/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
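A construction sketch for `obtain_optimizer` above; the config stand-in carries the attribute names this factory actually reads, and the numeric values are assumptions (the real values live in configs/*.config):

```
import torch.nn as nn
from collections import namedtuple
from optimizer import obtain_optimizer  # assumes lib/ is on sys.path

Cfg = namedtuple('Cfg', ['optimizer', 'LR', 'momentum', 'Decay', 'nesterov',
                         'schedule', 'gamma', 'criterion'])
cfg = Cfg('sgd', 0.00005, 0.9, 0.0005, True, [30, 35], 0.1, 'MSE-avg')
params = [{'params': nn.Conv2d(3, 8, 3).parameters(), 'lr': cfg.LR}]
opt, scheduler, criterion = obtain_optimizer(params, cfg, None)
print(type(opt).__name__, type(scheduler).__name__, type(criterion).__name__)
```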
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .starts import prepare_seed 8 | from .basic_train import basic_train 9 | from .saver import save_checkpoint 10 | from .basic_eval import basic_eval_all 11 | # LK 12 | from .lk_train import lk_train 13 | -------------------------------------------------------------------------------- /lib/procedure/basic_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import time, os, sys, numpy as np 8 | import torch 9 | from copy import deepcopy 10 | from pathlib import Path 11 | from xvision import Eval_Meta 12 | from log_utils import AverageMeter, time_for_file, time_string, convert_secs2time 13 | from .losses import compute_stage_loss, show_stage_loss 14 | 15 | def basic_eval_all(args, loaders, net, criterion, epoch_str, logger, opt_config): 16 | args = deepcopy(args) 17 | logger.log('Basic-Eval-All evaluates {:} datasets'.format(len(loaders))) 18 | nmes = [] 19 | for i, (loader, is_video) in enumerate(loaders): 20 | logger.log('==>>{:}, [{:}], evaluate the {:}/{:}-th dataset [{:}] : {:}'.format(time_string(), epoch_str, i, len(loaders), 'video' if is_video else 'image', loader.dataset)) 21 | with torch.no_grad(): 22 | eval_loss, eval_meta = basic_eval(args, loader, net, criterion, epoch_str+"::{:}/{:}".format(i,len(loaders)), logger, opt_config) 23 | nme, _, _ = eval_meta.compute_mse(logger) 24 | meta_path = logger.path('meta') / 'eval-{:}-{:02d}-{:02d}.pth'.format(epoch_str, i, len(loaders)) 25 | eval_meta.save(meta_path) 26 | nmes.append(nme*100) 27 | return ', '.join(['{:.2f}'.format(x) for x in nmes]) 28 | 29 | 30 | def basic_eval(args, loader, net, criterion, epoch_str, logger, opt_config): 31 | batch_time, data_time, forward_time, eval_time = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 32 | visible_points, losses = AverageMeter(), AverageMeter() 33 | eval_meta = Eval_Meta() 34 | cpu = torch.device('cpu') 35 | 36 | # switch to evaluation mode 37 | net.eval() 38 | criterion.eval() 39 | 40 | end = time.time() 41 | for i, (inputs, target, mask, points, image_index, nopoints, cropped_size) in enumerate(loader): 42 | # inputs : Batch, Channel, Height, Width 43 | 44 | target = target.cuda(non_blocking=True) 45 | 46 | image_index = image_index.numpy().squeeze(1).tolist() 47 | batch_size, num_pts = inputs.size(0), args.num_pts 48 | visible_point_num = float(np.sum(mask.numpy()[:,:-1,:,:])) / batch_size 49 | visible_points.update(visible_point_num, batch_size) 50 | nopoints = nopoints.numpy().squeeze(1).tolist() 51 | annotated_num = batch_size - sum(nopoints) 52 | 53 | # measure data loading time 54 | mask = mask.cuda(non_blocking=True) 55 | data_time.update(time.time() - end) 56 | 57 | # batch_heatmaps is a list for stage-predictions, each element should be [Batch, C, H, W] 58 | batch_heatmaps, batch_locs, batch_scos = net(inputs) 59 | forward_time.update(time.time() - end) 60 | 61 | if annotated_num > 0: 62 | loss, each_stage_loss_value = compute_stage_loss(criterion, target, batch_heatmaps, mask) 63 | if opt_config.lossnorm: 64 | loss, each_stage_loss_value = loss / annotated_num, [x/annotated_num for x in each_stage_loss_value] 65 | each_stage_loss_value = 
show_stage_loss(each_stage_loss_value) 66 | # measure accuracy and record loss 67 | losses.update(loss.item(), batch_size) 68 | else: 69 | loss, each_stage_loss_value = 0, 'no-det-loss' 70 | 71 | eval_time.update(time.time() - end) 72 | 73 | np_batch_locs, np_batch_scos = batch_locs.to(cpu).numpy(), batch_scos.to(cpu).numpy() 74 | cropped_size = cropped_size.numpy() 75 | # evaluate the training data 76 | for ibatch, (imgidx, nopoint) in enumerate(zip(image_index, nopoints)): 77 | #if nopoint == 1: continue 78 | locations, scores = np_batch_locs[ibatch,:-1,:], np.expand_dims(np_batch_scos[ibatch,:-1], -1) 79 | xpoints = loader.dataset.labels[imgidx].get_points() 80 | assert cropped_size[ibatch,0] > 0 and cropped_size[ibatch,1] > 0, 'The ibatch={:}, imgidx={:} is not right.'.format(ibatch, imgidx, cropped_size[ibatch]) 81 | scale_h, scale_w = cropped_size[ibatch,0] * 1. / inputs.size(-2) , cropped_size[ibatch,1] * 1. / inputs.size(-1) 82 | locations[:, 0], locations[:, 1] = locations[:, 0] * scale_w + cropped_size[ibatch,2], locations[:, 1] * scale_h + cropped_size[ibatch,3] 83 | assert xpoints.shape[1] == num_pts and locations.shape[0] == num_pts and scores.shape[0] == num_pts, 'The number of points is {} vs {} vs {} vs {}'.format(num_pts, xpoints.shape, locations.shape, scores.shape) 84 | # recover the original resolution 85 | prediction = np.concatenate((locations, scores), axis=1).transpose(1,0) 86 | image_path = loader.dataset.datas[imgidx] 87 | face_size = loader.dataset.face_sizes[imgidx] 88 | if nopoint == 1: 89 | eval_meta.append(prediction, None, image_path, face_size) 90 | else: 91 | eval_meta.append(prediction, xpoints, image_path, face_size) 92 | 93 | # measure elapsed time 94 | batch_time.update(time.time() - end) 95 | last_time = convert_secs2time(batch_time.avg * (len(loader)-i-1), True) 96 | end = time.time() 97 | 98 | if i % (args.print_freq) == 0 or i+1 == len(loader): 99 | logger.log(' -->>[Eval]: [{:}][{:03d}/{:03d}] ' 100 | 'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) ' 101 | 'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) ' 102 | 'Forward {forward_time.val:4.2f} ({forward_time.avg:4.2f}) ' 103 | 'Loss {loss.val:7.4f} ({loss.avg:7.4f}) '.format( 104 | epoch_str, i, len(loader), batch_time=batch_time, 105 | data_time=data_time, forward_time=forward_time, loss=losses) 106 | + last_time + each_stage_loss_value \ 107 | + ' In={:} Tar={:}'.format(list(inputs.size()), list(target.size())) \ 108 | + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg)) 109 | return losses.avg, eval_meta 110 | -------------------------------------------------------------------------------- /lib/procedure/basic_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
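Both the eval loop above and the train loop below map predictions from the resized crop back to original image coordinates; a numeric illustration of that arithmetic, where the `cropped_size` layout (crop-h, crop-w, x-offset, y-offset) is inferred from how indices 0-3 are used, and all numbers are made up:

```
import numpy as np

net_h, net_w = 256, 256                  # network input resolution
crop_h, crop_w, off_x, off_y = 320, 300, 40, 25
loc = np.array([128.0, 64.0])            # (x, y) predicted on the network grid
scale_h, scale_w = crop_h / net_h, crop_w / net_w
orig = np.array([loc[0] * scale_w + off_x, loc[1] * scale_h + off_y])
print(orig)                              # [190. 105.]
```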
6 | # 7 | import time, os 8 | import numpy as np 9 | import torch 10 | from copy import deepcopy 11 | from pathlib import Path 12 | from xvision import Eval_Meta 13 | from log_utils import AverageMeter, time_for_file, convert_secs2time 14 | from .losses import compute_stage_loss, show_stage_loss 15 | 16 | # train function (forward, backward, update) 17 | def basic_train(args, loader, net, criterion, optimizer, epoch_str, logger, opt_config): 18 | args = deepcopy(args) 19 | batch_time, data_time, forward_time, eval_time = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 20 | visible_points, losses = AverageMeter(), AverageMeter() 21 | eval_meta = Eval_Meta() 22 | cpu = torch.device('cpu') 23 | 24 | # switch to train mode 25 | net.train() 26 | criterion.train() 27 | 28 | end = time.time() 29 | for i, (inputs, target, mask, points, image_index, nopoints, cropped_size) in enumerate(loader): 30 | # inputs : Batch, Channel, Height, Width 31 | 32 | target = target.cuda(non_blocking=True) 33 | 34 | image_index = image_index.numpy().squeeze(1).tolist() 35 | batch_size, num_pts = inputs.size(0), args.num_pts 36 | visible_point_num = float(np.sum(mask.numpy()[:,:-1,:,:])) / batch_size 37 | visible_points.update(visible_point_num, batch_size) 38 | nopoints = nopoints.numpy().squeeze(1).tolist() 39 | annotated_num = batch_size - sum(nopoints) 40 | 41 | # measure data loading time 42 | mask = mask.cuda(non_blocking=True) 43 | data_time.update(time.time() - end) 44 | 45 | # batch_heatmaps is a list for stage-predictions, each element should be [Batch, C, H, W] 46 | batch_heatmaps, batch_locs, batch_scos = net(inputs) 47 | forward_time.update(time.time() - end) 48 | 49 | loss, each_stage_loss_value = compute_stage_loss(criterion, target, batch_heatmaps, mask) 50 | 51 | if opt_config.lossnorm: 52 | loss, each_stage_loss_value = loss / annotated_num / 2, [x/annotated_num/2 for x in each_stage_loss_value] 53 | 54 | # measure accuracy and record loss 55 | losses.update(loss.item(), batch_size) 56 | 57 | # compute gradient and do SGD step 58 | optimizer.zero_grad() 59 | loss.backward() 60 | optimizer.step() 61 | eval_time.update(time.time() - end) 62 | 63 | np_batch_locs, np_batch_scos = batch_locs.detach().to(cpu).numpy(), batch_scos.detach().to(cpu).numpy() 64 | cropped_size = cropped_size.numpy() 65 | # evaluate the training data 66 | for ibatch, (imgidx, nopoint) in enumerate(zip(image_index, nopoints)): 67 | if nopoint == 1: continue 68 | locations, scores = np_batch_locs[ibatch,:-1,:], np.expand_dims(np_batch_scos[ibatch,:-1], -1) 69 | xpoints = loader.dataset.labels[imgidx].get_points() 70 | assert cropped_size[ibatch,0] > 0 and cropped_size[ibatch,1] > 0, 'The ibatch={:}, imgidx={:} is not right.'.format(ibatch, imgidx, cropped_size[ibatch]) 71 | scale_h, scale_w = cropped_size[ibatch,0] * 1. / inputs.size(-2) , cropped_size[ibatch,1] * 1. 
/ inputs.size(-1) 72 | locations[:, 0], locations[:, 1] = locations[:, 0] * scale_w + cropped_size[ibatch,2], locations[:, 1] * scale_h + cropped_size[ibatch,3] 73 | assert xpoints.shape[1] == num_pts and locations.shape[0] == num_pts and scores.shape[0] == num_pts, 'The number of points is {} vs {} vs {} vs {}'.format(num_pts, xpoints.shape, locations.shape, scores.shape) 74 | # recover the original resolution 75 | prediction = np.concatenate((locations, scores), axis=1).transpose(1,0) 76 | image_path = loader.dataset.datas[imgidx] 77 | face_size = loader.dataset.face_sizes[imgidx] 78 | eval_meta.append(prediction, xpoints, image_path, face_size) 79 | 80 | # measure elapsed time 81 | batch_time.update(time.time() - end) 82 | last_time = convert_secs2time(batch_time.avg * (len(loader)-i-1), True) 83 | end = time.time() 84 | 85 | if i % args.print_freq == 0 or i+1 == len(loader): 86 | logger.log(' -->>[Train]: [{:}][{:03d}/{:03d}] ' 87 | 'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) ' 88 | 'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) ' 89 | 'Forward {forward_time.val:4.2f} ({forward_time.avg:4.2f}) ' 90 | 'Loss {loss.val:7.4f} ({loss.avg:7.4f}) '.format( 91 | epoch_str, i, len(loader), batch_time=batch_time, 92 | data_time=data_time, forward_time=forward_time, loss=losses) 93 | + last_time + show_stage_loss(each_stage_loss_value) \ 94 | + ' In={:} Tar={:}'.format(list(inputs.size()), list(target.size())) \ 95 | + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg)) 96 | nme, _, _ = eval_meta.compute_mse(logger) 97 | return losses.avg, nme 98 | -------------------------------------------------------------------------------- /lib/procedure/lk_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import models 8 | import torch 9 | import numpy as np 10 | import pdb, math, numbers 11 | 12 | def lk_input_check(batch_locs, batch_scos, batch_next, batch_fback, batch_back): 13 | batch, sequence, num_pts, _ = list(batch_locs.size()) 14 | assert batch_locs.size() == batch_next.size() == batch_fback.size() == batch_back.size(), '{:} vs {:} vs {:} vs {:}'.format(batch_locs.size(), batch_next.size(), batch_fback.size(), batch_back.size()) 15 | assert _ == 2, '{:}'.format(batch_locs.size()) 16 | assert batch_scos.size(0) == batch and batch_scos.size(1) == sequence and batch_scos.size(2) == num_pts, '{:} vs {:}'.format(batch_locs.size(), batch_scos.size()) 17 | return batch, sequence, num_pts 18 | 19 | def p2string(point): 20 | if isinstance(point, numbers.Number): 21 | return '{:.1f}'.format(point*1.0) 22 | elif point.size == 2: 23 | return '{:.1f},{:.1f}'.format(point[0], point[1]) 24 | else: 25 | return '{}'.format(point) 26 | 27 | def lk_target_loss(batch_locs, batch_scos, batch_next, batch_fbak, batch_back, lk_config, video_or_not, mask, nopoints): 28 | # return the calculate target from the first frame to the whole sequence. 
29 | batch, sequence, num_pts = lk_input_check(batch_locs, batch_scos, batch_next, batch_fbak, batch_back) 30 | 31 | # remove the background 32 | num_pts = num_pts - 1 33 | sequence_checks = np.ones((batch, num_pts), dtype='bool') 34 | 35 | # Check the confidence score for each point 36 | for ibatch in range(batch): 37 | if video_or_not[ibatch] == False: 38 | sequence_checks[ibatch, :] = False 39 | else: 40 | for iseq in range(sequence): 41 | for ipts in range(num_pts): 42 | score = batch_scos[ibatch, iseq, ipts] 43 | if mask[ibatch, ipts] == False and nopoints[ibatch] == 0: 44 | sequence_checks[ibatch, ipts] = False 45 | if score.item() < lk_config.conf_thresh: 46 | sequence_checks[ibatch, ipts] = False 47 | 48 | losses = [] 49 | for ibatch in range(batch): 50 | for ipts in range(num_pts): 51 | if not sequence_checks[ibatch, ipts]: continue 52 | loss = 0 53 | for iseq in range(sequence): 54 | 55 | targets = batch_locs[ibatch, iseq, ipts] 56 | nextPts = batch_next[ibatch, iseq, ipts] 57 | fbakPts = batch_fbak[ibatch, iseq, ipts] 58 | backPts = batch_back[ibatch, iseq, ipts] 59 | 60 | with torch.no_grad(): 61 | fbak_distance = torch.dist(nextPts, fbakPts) 62 | back_distance = torch.dist(targets, backPts) 63 | forw_distance = torch.dist(targets, nextPts) 64 | 65 | #print ('[{:02d},{:02d},{:02d}] : {:.2f}, {:.2f}, {:.2f}'.format(ibatch, ipts, iseq, fbak_distance.item(), back_distance.item(), forw_distance.item())) 66 | #loss += back_distance + forw_distance 67 | 68 | if fbak_distance.item() > lk_config.fb_thresh or fbak_distance.item() < lk_config.eps: # forward-backward-check 69 | if iseq+1 < sequence: sequence_checks[ibatch, ipts] = False 70 | if forw_distance.item() > lk_config.forward_max or forw_distance.item() < lk_config.eps: # to avoid the tracker point is too far 71 | if iseq > 0 : sequence_checks[ibatch, ipts] = False 72 | if back_distance.item() > lk_config.forward_max or back_distance.item() < lk_config.eps: # to avoid the tracker point is too far 73 | if iseq+1 < sequence: sequence_checks[ibatch, ipts] = False 74 | 75 | if iseq > 0: 76 | if lk_config.stable: loss += torch.dist(targets, backPts.detach()) 77 | else : loss += torch.dist(targets, backPts) 78 | if iseq + 1 < sequence: 79 | if lk_config.stable: loss += torch.dist(targets, nextPts.detach()) 80 | else : loss += torch.dist(targets, nextPts) 81 | 82 | if sequence_checks[ibatch, ipts]: 83 | losses.append(loss) 84 | 85 | avaliable = int(np.sum(sequence_checks)) 86 | if avaliable == 0: return None, avaliable 87 | else : return torch.mean(torch.stack(losses)), avaliable 88 | -------------------------------------------------------------------------------- /lib/procedure/lk_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import time, os, numpy as np 8 | import torch 9 | import numbers, warnings 10 | from copy import deepcopy 11 | from pathlib import Path 12 | from log_utils import AverageMeter, time_for_file, convert_secs2time 13 | from .losses import compute_stage_loss, show_stage_loss 14 | from .lk_loss import lk_target_loss 15 | 16 | # train function (forward, backward, update) 17 | def lk_train(args, loader, net, criterion, optimizer, epoch_str, logger, opt_config, lk_config, use_lk): 18 | args = deepcopy(args) 19 | batch_time, data_time, forward_time, eval_time = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 20 | visible_points, detlosses, lklosses = AverageMeter(), AverageMeter(), AverageMeter() 21 | alk_points, losses = AverageMeter(), AverageMeter() 22 | cpu = torch.device('cpu') 23 | 24 | annotate_index = loader.dataset.center_idx 25 | 26 | # switch to train mode 27 | net.train() 28 | criterion.train() 29 | 30 | end = time.time() 31 | for i, (inputs, target, mask, points, image_index, nopoints, video_or_not, cropped_size) in enumerate(loader): 32 | # inputs : Batch, Sequence Channel, Height, Width 33 | 34 | target = target.cuda(non_blocking=True) 35 | 36 | image_index = image_index.numpy().squeeze(1).tolist() 37 | batch_size, sequence, num_pts = inputs.size(0), inputs.size(1), args.num_pts 38 | mask_np = mask.numpy().squeeze(-1).squeeze(-1) 39 | visible_point_num = float(np.sum(mask.numpy()[:,:-1,:,:])) / batch_size 40 | visible_points.update(visible_point_num, batch_size) 41 | nopoints = nopoints.numpy().squeeze(1).tolist() 42 | video_or_not= video_or_not.numpy().squeeze(1).tolist() 43 | annotated_num = batch_size - sum(nopoints) 44 | 45 | # measure data loading time 46 | mask = mask.cuda(non_blocking=True) 47 | data_time.update(time.time() - end) 48 | 49 | # batch_heatmaps is a list for stage-predictions, each element should be [Batch, Sequence, PTS, H/Down, W/Down] 50 | batch_heatmaps, batch_locs, batch_scos, batch_next, batch_fback, batch_back = net(inputs) 51 | annot_heatmaps = [x[:, annotate_index] for x in batch_heatmaps] 52 | forward_time.update(time.time() - end) 53 | 54 | if annotated_num > 0: 55 | # have the detection loss 56 | detloss, each_stage_loss_value = compute_stage_loss(criterion, target, annot_heatmaps, mask) 57 | if opt_config.lossnorm: 58 | detloss, each_stage_loss_value = detloss / annotated_num / 2, [x/annotated_num/2 for x in each_stage_loss_value] 59 | # measure accuracy and record loss 60 | detlosses.update(detloss.item(), batch_size) 61 | each_stage_loss_value = show_stage_loss(each_stage_loss_value) 62 | else: 63 | detloss, each_stage_loss_value = 0, 'no-det-loss' 64 | 65 | if use_lk: 66 | lkloss, avaliable = lk_target_loss(batch_locs, batch_scos, batch_next, batch_fback, batch_back, lk_config, video_or_not, mask_np, nopoints) 67 | if lkloss is not None: 68 | lklosses.update(lkloss.item(), avaliable) 69 | else: lkloss = 0 70 | alk_points.update(float(avaliable)/batch_size, batch_size) 71 | else : lkloss = 0 72 | 73 | loss = detloss + lkloss * lk_config.weight 74 | 75 | if isinstance(loss, numbers.Number): 76 | warnings.warn('The {:}-th iteration has no detection loss and no lk loss'.format(i)) 77 | else: 78 | losses.update(loss.item(), batch_size) 79 | # compute gradient and do SGD step 80 | optimizer.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | eval_time.update(time.time() - end) 85 | 86 | # measure elapsed time 87 | batch_time.update(time.time() - end) 88 | last_time = convert_secs2time(batch_time.avg * 
(len(loader)-i-1), True)
89 |     end = time.time()
90 | 
91 |     if i % args.print_freq == 0 or i+1 == len(loader):
92 |       logger.log(' -->>[Train]: [{:}][{:03d}/{:03d}] '
93 |                 'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) '
94 |                 'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) '
95 |                 'Forward {forward_time.val:4.2f} ({forward_time.avg:4.2f}) '
96 |                 'Loss {loss.val:7.4f} ({loss.avg:7.4f}) [LK={lk.val:7.4f} ({lk.avg:7.4f})] '.format(
97 |                     epoch_str, i, len(loader), batch_time=batch_time,
98 |                     data_time=data_time, forward_time=forward_time, loss=losses, lk=lklosses)
99 |                 + each_stage_loss_value + ' ' + last_time \
100 |                 + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg) \
101 |                 + ' Ava-PTS : {:.1f} ({:.1f})'.format(alk_points.val, alk_points.avg))
102 | 
103 |   return losses.avg
104 | 
--------------------------------------------------------------------------------
/lib/procedure/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import numpy as np
8 | import numbers, torch
9 | import torch.nn.functional as F
10 | 
11 | def compute_stage_loss(criterion, targets, outputs, masks):
12 |   assert isinstance(outputs, list), 'The outputs type is wrong : {:}'.format(type(outputs))
13 |   total_loss = 0
14 |   each_stage_loss = []
15 | 
16 |   for output in outputs:
17 |     stage_loss = 0
18 |     output = torch.masked_select(output, masks)
19 |     target = torch.masked_select(targets, masks)
20 | 
21 |     stage_loss = criterion(output, target)
22 |     total_loss = total_loss + stage_loss
23 |     each_stage_loss.append(stage_loss.item())
24 |   return total_loss, each_stage_loss
25 | 
26 | 
27 | def show_stage_loss(each_stage_loss):
28 |   if each_stage_loss is None: return 'None'
29 |   elif isinstance(each_stage_loss, str): return each_stage_loss
30 |   answer = ''
31 |   for index, loss in enumerate(each_stage_loss):
32 |     answer = answer + ' : L{:1d}={:7.4f}'.format(index+1, loss)
33 |   return answer
34 | 
35 | 
36 | def sum_stage_loss(losses):
37 |   total_loss = None
38 |   each_stage_loss = []
39 |   for loss in losses:
40 |     if total_loss is None:
41 |       total_loss = loss
42 |     else:
43 |       total_loss = total_loss + loss
44 |     each_stage_loss.append(loss.item())
45 |   return total_loss, each_stage_loss
46 | 
--------------------------------------------------------------------------------
/lib/procedure/saver.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import torch
8 | 
9 | def save_checkpoint(state, filename, logger):
10 |   torch.save(state, filename)
11 |   logger.log('save checkpoint into {}'.format(filename))
12 |   return filename
13 | 
--------------------------------------------------------------------------------
/lib/procedure/starts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | import os, sys, time 8 | import numpy as np 9 | import torch 10 | import random 11 | 12 | def prepare_seed(rand_seed): 13 | np.random.seed(rand_seed) 14 | random.seed(rand_seed) 15 | torch.manual_seed(rand_seed) 16 | torch.cuda.manual_seed_all(rand_seed) 17 | -------------------------------------------------------------------------------- /lib/pts_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .generation import generate_label_map 8 | -------------------------------------------------------------------------------- /lib/pts_utils/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from scipy.ndimage.interpolation import zoom 8 | import numbers, math 9 | import numpy as np 10 | 11 | ## pts = 3 * N numpy array; points location is based on the image with size (height*downsample, width*downsample) 12 | 13 | def generate_label_map(pts, height, width, sigma, downsample, nopoints, ctype): 14 | #if isinstance(pts, numbers.Number): 15 | # this image does not provide the annotation, pts is a int number representing the number of points 16 | #return np.zeros((height,width,pts+1), dtype='float32'), np.ones((1,1,1+pts), dtype='float32') 17 | # nopoints == True means this image does not provide the annotation, pts is a int number representing the number of points 18 | 19 | assert isinstance(pts, np.ndarray) and len(pts.shape) == 2 and pts.shape[0] == 3, 'The shape of points : {}'.format(pts.shape) 20 | if isinstance(sigma, numbers.Number): 21 | sigma = np.zeros((pts.shape[1])) + sigma 22 | assert isinstance(sigma, np.ndarray) and len(sigma.shape) == 1 and sigma.shape[0] == pts.shape[1], 'The shape of sigma : {}'.format(sigma.shape) 23 | 24 | offset = downsample / 2.0 - 0.5 25 | num_points, threshold = pts.shape[1], 0.01 26 | 27 | if nopoints == False: visiable = pts[2, :].astype('bool') 28 | else : visiable = (pts[2, :]*0).astype('bool') 29 | #assert visiable.shape[0] == num_points 30 | 31 | transformed_label = np.fromfunction( lambda y, x, pid : ((offset + x*downsample - pts[0,pid])**2 \ 32 | + (offset + y*downsample - pts[1,pid])**2) \ 33 | / -2.0 / sigma[pid] / sigma[pid], 34 | (height, width, num_points), dtype=int) 35 | 36 | mask_heatmap = np.ones((1, 1, num_points+1), dtype='float32') 37 | mask_heatmap[0, 0, :num_points] = visiable 38 | mask_heatmap[0, 0, num_points] = (nopoints==False) 39 | 40 | if ctype == 'laplacian': 41 | transformed_label = (1+transformed_label) * np.exp(transformed_label) 42 | elif ctype == 'gaussian': 43 | transformed_label = np.exp(transformed_label) 44 | else: 45 | raise TypeError('Does not know this type [{:}] for label generation'.format(ctype)) 46 | transformed_label[ transformed_label < threshold ] = 0 47 | transformed_label[ transformed_label > 1 ] = 1 48 | transformed_label = transformed_label * mask_heatmap[:, :, :num_points] 49 | 50 | background_label = 1 - np.amax(transformed_label, axis=2) 51 | background_label[ background_label < 0 ] = 0 52 | heatmap = np.concatenate((transformed_label, 
np.expand_dims(background_label, axis=2)), axis=2).astype('float32') 53 | 54 | return heatmap*mask_heatmap, mask_heatmap 55 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .file_utils import load_list_from_folders, load_txt_file 8 | -------------------------------------------------------------------------------- /lib/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os, sys, glob, numbers 8 | from os import path as osp 9 | 10 | def mkdir_if_missing(path): 11 | if not osp.isdir(path): 12 | os.makedirs(path) 13 | 14 | def is_path_exists(pathname): 15 | try: 16 | return isinstance(pathname, str) and pathname and os.path.exists(pathname) 17 | except OSError: 18 | return False 19 | 20 | def fileparts(pathname): 21 | ''' 22 | this function return a tuple, which contains (directory, filename, extension) 23 | if the file has multiple extension, only last one will be displayed 24 | ''' 25 | pathname = osp.normpath(pathname) 26 | if len(pathname) == 0: 27 | return ('', '', '') 28 | if pathname[-1] == '/': 29 | if len(pathname) > 1: 30 | return (pathname[:-1], '', '') # ignore the final '/' 31 | else: 32 | return (pathname, '', '') # ignore the final '/' 33 | directory = osp.dirname(osp.abspath(pathname)) 34 | filename = osp.splitext(osp.basename(pathname))[0] 35 | ext = osp.splitext(pathname)[1] 36 | return (directory, filename, ext) 37 | 38 | def load_txt_file(file_path): 39 | ''' 40 | load data or string from text file. 41 | ''' 42 | with open(file_path, 'r') as cfile: 43 | content = cfile.readlines() 44 | cfile.close() 45 | content = [x.strip() for x in content] 46 | num_lines = len(content) 47 | return content, num_lines 48 | 49 | def load_list_from_folder(folder_path, ext_filter=None, depth=1): 50 | ''' 51 | load a list of files or folders from a system path 52 | 53 | parameter: 54 | folder_path: root to search 55 | ext_filter: a string to represent the extension of files interested 56 | depth: maximum depth of folder to search, when it's None, all levels of folders will be searched 57 | ''' 58 | folder_path = osp.normpath(folder_path) 59 | assert isinstance(depth, int) , 'input depth is not correct {}'.format(depth) 60 | assert ext_filter is None or (isinstance(ext_filter, list) and all(isinstance(ext_tmp, str) for ext_tmp in ext_filter)) or isinstance(ext_filter, str), 'extension filter is not correct' 61 | if isinstance(ext_filter, str): # convert to a list 62 | ext_filter = [ext_filter] 63 | 64 | fulllist = list() 65 | wildcard_prefix = '*' 66 | for index in range(depth): 67 | if ext_filter is not None: 68 | for ext_tmp in ext_filter: 69 | curpath = osp.join(folder_path, wildcard_prefix + '.' 
+ ext_tmp) 70 | fulllist += glob.glob(curpath) 71 | else: 72 | curpath = osp.join(folder_path, wildcard_prefix) 73 | fulllist += glob.glob(curpath) 74 | wildcard_prefix = osp.join(wildcard_prefix, '*') 75 | 76 | fulllist = [osp.normpath(path_tmp) for path_tmp in fulllist] 77 | num_elem = len(fulllist) 78 | 79 | return fulllist, num_elem 80 | 81 | def load_list_from_folders(folder_path_list, ext_filter=None, depth=1): 82 | ''' 83 | load a list of files or folders from a list of system path 84 | ''' 85 | assert isinstance(folder_path_list, list) or isinstance(folder_path_list, str), 'input path list is not correct' 86 | if isinstance(folder_path_list, str): 87 | folder_path_list = [folder_path_list] 88 | 89 | fulllist = list() 90 | num_elem = 0 91 | for folder_path_tmp in folder_path_list: 92 | fulllist_tmp, num_elem_tmp = load_list_from_folder(folder_path_tmp, ext_filter=ext_filter, depth=depth) 93 | fulllist += fulllist_tmp 94 | num_elem += num_elem_tmp 95 | 96 | return fulllist, num_elem 97 | -------------------------------------------------------------------------------- /lib/xvision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .evaluation_util import Eval_Meta 8 | from .visualization import draw_image_by_points 9 | -------------------------------------------------------------------------------- /lib/xvision/common_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | #
7 | import numpy as np
8 | import pdb, os, time
9 | from log_utils import print_log
10 | from datasets.dataset_utils import convert68to49, convert68to51
11 | from sklearn.metrics import auc
12 | 
13 | def evaluate_normalized_mean_error(predictions, groundtruth, log, extra_faces):
14 |   ## compute the overall average normalized mean error
15 |   assert len(predictions) == len(groundtruth), 'The lengths of predictions and ground-truth are not consistent : {} vs {}'.format( len(predictions), len(groundtruth) )
16 |   assert len(predictions) > 0, 'The number of predictions must be positive : {}'.format( len(predictions) )
17 |   if extra_faces is not None: assert len(extra_faces) == len(predictions), 'The length of extra_faces is not right {} vs {}'.format( len(extra_faces), len(predictions) )
18 |   num_images = len(predictions)
19 |   for i in range(num_images):
20 |     c, g = predictions[i], groundtruth[i]
21 |     assert isinstance(c, np.ndarray) and isinstance(g, np.ndarray), 'The type of predictions is not right : [{:}] :: {} vs {}'.format(i, type(c), type(g))
22 | 
23 |   num_points = predictions[0].shape[1]
24 |   error_per_image = np.zeros((num_images,1))
25 |   for i in range(num_images):
26 |     detected_points = predictions[i]
27 |     ground_truth_points = groundtruth[i]
28 |     if num_points == 68:
29 |       interocular_distance = np.linalg.norm(ground_truth_points[:2, 36] - ground_truth_points[:2, 45])
30 |       assert bool(ground_truth_points[2,36]) and bool(ground_truth_points[2,45])
31 |     elif num_points == 51 or num_points == 49:
32 |       interocular_distance = np.linalg.norm(ground_truth_points[:2, 19] - ground_truth_points[:2, 28])
33 |       assert bool(ground_truth_points[2,19]) and bool(ground_truth_points[2,28])
34 |     elif num_points == 19:
35 |       assert extra_faces is not None and extra_faces[i] is not None
36 |       interocular_distance = extra_faces[i]
37 |     else:
38 |       raise Exception('----> Unknown number of points : {}'.format(num_points))
39 |     dis_sum, pts_sum = 0, 0
40 |     for j in range(num_points):
41 |       if bool(ground_truth_points[2, j]):
42 |         dis_sum = dis_sum + np.linalg.norm(detected_points[:2, j] - ground_truth_points[:2, j])
43 |         pts_sum = pts_sum + 1
44 |     error_per_image[i] = dis_sum / (pts_sum*interocular_distance)
45 | 
46 |   normalise_mean_error = error_per_image.mean()
47 |   # calculate the AUC at the 0.07 threshold
48 |   max_threshold = 0.07
49 |   threshold  = np.linspace(0, max_threshold, num=2000)
50 |   accuracies = np.zeros(threshold.shape)
51 |   for i in range(threshold.size):
52 |     accuracies[i] = np.sum(error_per_image < threshold[i]) * 1.0 / error_per_image.size
53 |   area_under_curve07 = auc(threshold, accuracies) / max_threshold
54 |   # calculate the AUC at the 0.08 threshold
55 |   max_threshold = 0.08
56 |   threshold  = np.linspace(0, max_threshold, num=2000)
57 |   accuracies = np.zeros(threshold.shape)
58 |   for i in range(threshold.size):
59 |     accuracies[i] = np.sum(error_per_image < threshold[i]) * 1.0 / error_per_image.size
60 |   area_under_curve08 = auc(threshold, accuracies) / max_threshold
61 | 
62 |   accuracy_under_007 = np.sum(error_per_image<0.07) * 100. / error_per_image.size
63 |   accuracy_under_008 = np.sum(error_per_image<0.08) * 100. / error_per_image.size
64 | 
65 |   print_log('Compute NME and AUC for {:} images with {:} points :: [(NME): mean={:.3f}, std={:.3f}], auc@0.07={:.3f}, auc@0.08={:.3f}, acc@0.07={:.3f}, acc@0.08={:.3f}'.format(num_images, num_points, normalise_mean_error*100, error_per_image.std()*100, area_under_curve07*100, area_under_curve08*100, accuracy_under_007, accuracy_under_008), log)
66 | 
67 |   for_pck_curve = []
68 |   for x in range(0, 3501, 1):
69 |     error_bar = x * 0.0001
70 |     accuracy = np.sum(error_per_image < error_bar) * 1.0 / error_per_image.size
71 |     for_pck_curve.append((error_bar, accuracy))
72 | 
73 |   return normalise_mean_error, accuracy_under_008, for_pck_curve
74 | 
--------------------------------------------------------------------------------
/lib/xvision/evaluation_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os, time
8 | import numpy as np
9 | import torch
10 | import json
11 | from log_utils import print_log
12 | from collections import OrderedDict
13 | from scipy import interpolate
14 | from mpl_toolkits.mplot3d import Axes3D
15 | from .common_eval import evaluate_normalized_mean_error
16 | 
17 | class Eval_Meta():
18 | 
19 |   def __init__(self):
20 |     self.reset()
21 | 
22 |   def __repr__(self):
23 |     return ('{name}'.format(name=self.__class__.__name__)+'(number of data = {:})'.format(len(self)))
24 | 
25 |   def reset(self):
26 |     self.predictions = []
27 |     self.groundtruth = []
28 |     self.image_lists = []
29 |     self.face_sizes  = []
30 | 
31 |   def __len__(self):
32 |     return len(self.image_lists)
33 | 
34 |   def append(self, _pred, _ground, image_path, face_size):
35 |     assert _pred.shape[0] == 3 and len(_pred.shape) == 2, 'Prediction\'s shape is {:} vs [should be (3, pts)]'.format(_pred.shape)
36 |     if _ground is not None:
37 |       assert _pred.shape == _ground.shape, 'shapes must be the same : {} vs {}'.format(_pred.shape, _ground.shape)
38 |     if self.predictions:
39 |       assert _pred.shape == self.predictions[-1].shape, 'shapes must be the same : {} vs {}'.format(_pred.shape, self.predictions[-1].shape)
40 |     self.predictions.append(_pred)
41 |     self.groundtruth.append(_ground)
42 |     self.image_lists.append(image_path)
43 |     self.face_sizes.append(face_size)
44 | 
45 |   def save(self, filename):
46 |     meta = {'predictions': self.predictions,
47 |             'groundtruth': self.groundtruth,
48 |             'image_lists': self.image_lists,
49 |             'face_sizes' : self.face_sizes}
50 |     torch.save(meta, filename)
51 |     print ('save eval-meta into {}'.format(filename))
52 | 
53 |   def load(self, filename):
54 |     assert os.path.isfile(filename), '{} is not a file'.format(filename)
55 |     checkpoint = torch.load(filename)
56 |     self.predictions = checkpoint['predictions']
57 |     self.groundtruth = checkpoint['groundtruth']
58 |     self.image_lists = checkpoint['image_lists']
59 |     self.face_sizes  = checkpoint['face_sizes']
60 | 
61 |   def compute_mse(self, log):
62 |     predictions, groundtruth, face_sizes, num = [], [], [], 0
63 |     for x, gt, face in zip(self.predictions, self.groundtruth, self.face_sizes):
64 |       if gt is None: continue
65 |       predictions.append(x)
66 |       groundtruth.append(gt)
67 |       face_sizes.append(face)
68 |       num += 1
69 |     print_log('Filter the unlabeled data from {:} into {:} data'.format(len(self), num), log)
70 |     if num == 0:
71 |       nme, auc, pck_curves = -1, None, None
72 |     else:
73 |       nme, auc, pck_curves = evaluate_normalized_mean_error(predictions, groundtruth, log, face_sizes)
74 |     return nme, auc, pck_curves
75 | 
--------------------------------------------------------------------------------
/lib/xvision/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from __future__ import division
8 | import torch
9 | import sys, math, random, PIL
10 | from PIL import Image, ImageOps
11 | import numpy as np
12 | import numbers
13 | import types
14 | import collections
15 | 
16 | if sys.version_info.major == 2:
17 |   import cPickle as pickle
18 | else:
19 |   import pickle
20 | 
21 | class Compose(object):
22 |   def __init__(self, transforms):
23 |     self.transforms = transforms
24 | 
25 |   def __call__(self, img, points):
26 |     for t in self.transforms:
27 |       img, points = t(img, points)
28 |     return img, points
29 | 
30 | class TrainScale2WH(object):
31 |   """Rescale the input PIL.Image to the given target size.
32 |   Args:
33 |     target_size (sequence): Desired output size as a (width, height) pair
34 |       of ints; both sides are resized to exactly this size, so the aspect
35 |       ratio of the input is not necessarily preserved. The landmark
36 |       coordinates in point_meta are rescaled accordingly.
37 | 
38 |     interpolation (int, optional): Desired interpolation. Default is
39 |       ``PIL.Image.BILINEAR``
40 |   """
41 | 
42 |   def __init__(self, target_size, interpolation=Image.BILINEAR):
43 |     assert isinstance(target_size, tuple) or isinstance(target_size, list), 'The type of target_size is not right : {}'.format(target_size)
44 |     assert len(target_size) == 2, 'The length of target_size is not right : {}'.format(target_size)
45 |     assert isinstance(target_size[0], int) and isinstance(target_size[1], int), 'The type of target_size is not right : {}'.format(target_size)
46 |     self.target_size   = target_size
47 |     self.interpolation = interpolation
48 | 
49 |   def __call__(self, imgs, point_meta):
50 |     """
51 |     Args:
52 |       imgs (PIL.Image or list of PIL.Image): Image(s) to be scaled.
53 |       point_meta : 3 * N numpy.ndarray [x, y, visible] or Point_Meta
54 |     Returns:
55 |       PIL.Image: Rescaled image.
56 |     """
57 |     point_meta = point_meta.copy()
58 | 
59 |     if isinstance(imgs, list): is_list = True
60 |     else: is_list, imgs = False, [imgs]
61 | 
62 |     w, h = imgs[0].size
63 |     ow, oh = self.target_size[0], self.target_size[1]
64 |     point_meta.apply_scale( [ow*1./w, oh*1./h] )
65 | 
66 |     imgs = [ img.resize((ow, oh), self.interpolation) for img in imgs ]
67 |     if is_list == False: imgs = imgs[0]
68 | 
69 |     return imgs, point_meta
70 | 
71 | 
72 | 
73 | class ToPILImage(object):
74 |   """Convert a tensor to PIL Image.
75 |   Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
76 |   H x W x C to a PIL.Image while preserving the value range.
77 |   """
78 | 
79 |   def __call__(self, pic):
80 |     """
81 |     Args:
82 |       pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
83 |     Returns:
84 |       PIL.Image: Image converted to PIL.Image.
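    Example (a usage sketch; the random tensor is purely illustrative):
        >>> to_pil = ToPILImage()
        >>> img = to_pil(torch.rand(3, 64, 64))   # CHW float in [0, 1] -> an 'RGB' PIL.Image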
85 | """ 86 | npimg = pic 87 | mode = None 88 | if isinstance(pic, torch.FloatTensor): 89 | pic = pic.mul(255).byte() 90 | if torch.is_tensor(pic): 91 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 92 | assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' 93 | if npimg.shape[2] == 1: 94 | npimg = npimg[:, :, 0] 95 | 96 | if npimg.dtype == np.uint8: 97 | mode = 'L' 98 | if npimg.dtype == np.int16: 99 | mode = 'I;16' 100 | if npimg.dtype == np.int32: 101 | mode = 'I' 102 | elif npimg.dtype == np.float32: 103 | mode = 'F' 104 | else: 105 | if npimg.dtype == np.uint8: 106 | mode = 'RGB' 107 | assert mode is not None, '{} is not supported'.format(npimg.dtype) 108 | return Image.fromarray(npimg, mode=mode) 109 | 110 | 111 | 112 | class ToTensor(object): 113 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 114 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 115 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 116 | """ 117 | 118 | def __call__(self, pics, points): 119 | """ 120 | Args: 121 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 122 | points 3 * N numpy.ndarray [x, y, visiable] or Point_Meta 123 | Returns: 124 | Tensor: Converted image. 125 | """ 126 | ## add to support list 127 | if isinstance(pics, list): is_list = True 128 | else: is_list, pics = False, [pics] 129 | 130 | returned = [] 131 | for pic in pics: 132 | if isinstance(pic, np.ndarray): 133 | # handle numpy array 134 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 135 | # backward compatibility 136 | returned.append( img.float().div(255) ) 137 | continue 138 | 139 | # handle PIL Image 140 | if pic.mode == 'I': 141 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 142 | elif pic.mode == 'I;16': 143 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 144 | else: 145 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 146 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 147 | if pic.mode == 'YCbCr': 148 | nchannel = 3 149 | elif pic.mode == 'I;16': 150 | nchannel = 1 151 | else: 152 | nchannel = len(pic.mode) 153 | img = img.view(pic.size[1], pic.size[0], nchannel) 154 | # put it from HWC to CHW format 155 | # yikes, this transpose takes 80% of the loading time/CPU 156 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 157 | if isinstance(img, torch.ByteTensor): 158 | img = img.float().div(255) 159 | returned.append(img) 160 | 161 | if is_list == False: 162 | assert len(returned) == 1, 'For non-list data, length of answer must be one not {}'.format(len(returned)) 163 | returned = returned[0] 164 | 165 | return returned, points 166 | 167 | 168 | class Normalize(object): 169 | """Normalize an tensor image with mean and standard deviation. 170 | Given mean: (R, G, B) and std: (R, G, B), 171 | will normalize each channel of the torch.*Tensor, i.e. 172 | channel = (channel - mean) / std 173 | Args: 174 | mean (sequence): Sequence of means for R, G, B channels respecitvely. 175 | std (sequence): Sequence of standard deviations for R, G, B channels 176 | respecitvely. 177 | """ 178 | 179 | def __init__(self, mean, std): 180 | self.mean = mean 181 | self.std = std 182 | 183 | def __call__(self, tensors, points): 184 | """ 185 | Args: 186 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 187 | Returns: 188 | Tensor: Normalized image. 
189 | """ 190 | # TODO: make efficient 191 | if isinstance(tensors, list): is_list = True 192 | else: is_list, tensors = False, [tensors] 193 | 194 | for tensor in tensors: 195 | for t, m, s in zip(tensor, self.mean, self.std): 196 | t.sub_(m).div_(s) 197 | 198 | if is_list == False: tensors = tensors[0] 199 | 200 | return tensors, points 201 | 202 | 203 | class PreCrop(object): 204 | """Crops the given PIL.Image at the center. 205 | 206 | Args: 207 | size (sequence or int): Desired output size of the crop. If size is an 208 | int instead of sequence like (w, h), a square crop (size, size) is 209 | made. 210 | """ 211 | 212 | def __init__(self, expand_ratio): 213 | assert expand_ratio is None or isinstance(expand_ratio, numbers.Number), 'The expand_ratio should not be {}'.format(expand_ratio) 214 | if expand_ratio is None: 215 | self.expand_ratio = 0 216 | else: 217 | self.expand_ratio = expand_ratio 218 | assert self.expand_ratio >= 0, 'The expand_ratio should not be {}'.format(expand_ratio) 219 | 220 | def __call__(self, imgs, point_meta): 221 | ## AugCrop has something wrong... For unsupervised data 222 | 223 | if isinstance(imgs, list): is_list = True 224 | else: is_list, imgs = False, [imgs] 225 | 226 | w, h = imgs[0].size 227 | box = point_meta.get_box().tolist() 228 | face_ex_w, face_ex_h = (box[2] - box[0]) * self.expand_ratio, (box[3] - box[1]) * self.expand_ratio 229 | x1, y1 = int(max(math.floor(box[0]-face_ex_w), 0)), int(max(math.floor(box[1]-face_ex_h), 0)) 230 | x2, y2 = int(min(math.ceil(box[2]+face_ex_w), w)), int(min(math.ceil(box[3]+face_ex_h), h)) 231 | 232 | imgs = [ img.crop((x1, y1, x2, y2)) for img in imgs ] 233 | point_meta.set_precrop_wh( imgs[0].size[0], imgs[0].size[1], x1, y1, x2, y2) 234 | point_meta.apply_offset(-x1, -y1) 235 | point_meta.apply_bound(imgs[0].size[0], imgs[0].size[1]) 236 | 237 | if is_list == False: imgs = imgs[0] 238 | return imgs, point_meta 239 | 240 | 241 | class AugScale(object): 242 | """Rescale the input PIL.Image to the given size. 243 | 244 | Args: 245 | size (sequence or int): Desired output size. If size is a sequence like 246 | (w, h), output size will be matched to this. If size is an int, 247 | smaller edge of the image will be matched to this number. 248 | i.e, if height > width, then image will be rescaled to 249 | (size * height / width, size) 250 | interpolation (int, optional): Desired interpolation. Default is 251 | ``PIL.Image.BILINEAR`` 252 | """ 253 | 254 | def __init__(self, scale_prob, scale_min, scale_max, interpolation=Image.BILINEAR): 255 | assert isinstance(scale_prob, numbers.Number) and scale_prob >= 0, 'scale_prob : {:}'.format(scale_prob) 256 | assert isinstance(scale_min, numbers.Number) and isinstance(scale_max, numbers.Number), 'scales : {:}, {:}'.format(scale_min, scale_max) 257 | self.scale_prob = scale_prob 258 | self.scale_min = scale_min 259 | self.scale_max = scale_max 260 | self.interpolation = interpolation 261 | 262 | def __call__(self, imgs, point_meta): 263 | """ 264 | Args: 265 | img (PIL.Image): Image to be scaled. 266 | points 3 * N numpy.ndarray [x, y, visiable] 267 | Returns: 268 | PIL.Image: Rescaled image. 
269 | """ 270 | point_meta = point_meta.copy() 271 | 272 | dice = random.random() 273 | if dice > self.scale_prob: 274 | return imgs, point_meta 275 | 276 | if isinstance(imgs, list): is_list = True 277 | else: is_list, imgs = False, [imgs] 278 | 279 | scale_multiplier = (self.scale_max - self.scale_min) * random.random() + self.scale_min 280 | 281 | w, h = imgs[0].size 282 | ow, oh = int(w * scale_multiplier), int(h * scale_multiplier) 283 | 284 | imgs = [ img.resize((ow, oh), self.interpolation) for img in imgs ] 285 | point_meta.apply_scale( [scale_multiplier] ) 286 | 287 | if is_list == False: imgs = imgs[0] 288 | 289 | return imgs, point_meta 290 | 291 | 292 | class AugCrop(object): 293 | 294 | def __init__(self, crop_x, crop_y, center_perterb_max, fill=0): 295 | assert isinstance(crop_x, int) and isinstance(crop_y, int) and isinstance(center_perterb_max, numbers.Number) 296 | self.crop_x = crop_x 297 | self.crop_y = crop_y 298 | self.center_perterb_max = center_perterb_max 299 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 300 | self.fill = fill 301 | 302 | def __call__(self, imgs, point_meta=None): 303 | ## AugCrop has something wrong... For unsupervised data 304 | 305 | point_meta = point_meta.copy() 306 | if isinstance(imgs, list): is_list = True 307 | else: is_list, imgs = False, [imgs] 308 | 309 | dice_x, dice_y = random.random(), random.random() 310 | x_offset = int( (dice_x-0.5) * 2 * self.center_perterb_max) 311 | y_offset = int( (dice_y-0.5) * 2 * self.center_perterb_max) 312 | 313 | x1 = int(round( point_meta.center[0] + x_offset - self.crop_x / 2. )) 314 | y1 = int(round( point_meta.center[1] + y_offset - self.crop_y / 2. )) 315 | x2 = x1 + self.crop_x 316 | y2 = y1 + self.crop_y 317 | 318 | w, h = imgs[0].size 319 | if x1 < 0 or y1 < 0 or x2 >= w or y2 >= h: 320 | pad = max(0-x1, 0-y1, x2-w+1, y2-h+1) 321 | assert pad > 0, 'padding operation in crop must be greater than 0' 322 | imgs = [ ImageOps.expand(img, border=pad, fill=self.fill) for img in imgs ] 323 | x1, x2, y1, y2 = x1 + pad, x2 + pad, y1 + pad, y2 + pad 324 | point_meta.apply_offset(pad, pad) 325 | point_meta.apply_bound(imgs[0].size[0], imgs[0].size[1]) 326 | 327 | point_meta.apply_offset(-x1, -y1) 328 | imgs = [ img.crop((x1, y1, x2, y2)) for img in imgs ] 329 | point_meta.apply_bound(imgs[0].size[0], imgs[0].size[1]) 330 | 331 | if is_list == False: imgs = imgs[0] 332 | return imgs, point_meta 333 | 334 | class AugRotate(object): 335 | """Rotate the given PIL.Image at the center. 336 | Args: 337 | size (sequence or int): Desired output size of the crop. If size is an 338 | int instead of sequence like (w, h), a square crop (size, size) is 339 | made. 340 | """ 341 | 342 | def __init__(self, max_rotate_degree): 343 | assert isinstance(max_rotate_degree, numbers.Number) 344 | self.max_rotate_degree = max_rotate_degree 345 | 346 | def __call__(self, imgs, point_meta): 347 | """ 348 | Args: 349 | img (PIL.Image): Image to be cropped. 350 | point_meta : Point_Meta 351 | Returns: 352 | PIL.Image: Rotated image. 
353 | """ 354 | point_meta = point_meta.copy() 355 | if isinstance(imgs, list): is_list = True 356 | else: is_list, imgs = False, [imgs] 357 | 358 | degree = (random.random() - 0.5) * 2 * self.max_rotate_degree 359 | center = (imgs[0].size[0] / 2, imgs[0].size[1] / 2) 360 | if PIL.__version__[0] == '4': 361 | imgs = [ img.rotate(degree, center=center) for img in imgs ] 362 | else: 363 | imgs = [ img.rotate(degree) for img in imgs ] 364 | 365 | point_meta.apply_rotate(center, degree) 366 | point_meta.apply_bound(imgs[0].size[0], imgs[0].size[1]) 367 | 368 | if is_list == False: imgs = imgs[0] 369 | 370 | return imgs, point_meta 371 | -------------------------------------------------------------------------------- /lib/xvision/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from PIL import Image 8 | from PIL import ImageDraw 9 | from PIL import ImageFont 10 | import numpy as np 11 | import datasets 12 | 13 | def draw_image_by_points(_image, pts, radius, color, crop, resize): 14 | if isinstance(_image, str): 15 | _image = datasets.pil_loader(_image) 16 | assert isinstance(_image, Image.Image), 'image type is not PIL.Image.Image' 17 | assert isinstance(pts, np.ndarray) and (pts.shape[0] == 2 or pts.shape[0] == 3), 'input points are not correct' 18 | image, pts = _image.copy(), pts.copy() 19 | 20 | num_points = pts.shape[1] 21 | visiable_points = [] 22 | for idx in range(num_points): 23 | if pts.shape[0] == 2 or bool(pts[2,idx]): 24 | visiable_points.append( True ) 25 | else: 26 | visiable_points.append( False ) 27 | visiable_points = np.array( visiable_points ) 28 | #print ('visiable points : {}'.format( np.sum(visiable_points) )) 29 | 30 | if crop: 31 | if isinstance(crop, list): 32 | x1, y1, x2, y2 = int(crop[0]), int(crop[1]), int(crop[2]), int(crop[3]) 33 | else: 34 | x1, x2 = pts[0, visiable_points].min(), pts[0, visiable_points].max() 35 | y1, y2 = pts[1, visiable_points].min(), pts[1, visiable_points].max() 36 | face_h, face_w = (y2-y1)*0.1, (x2-x1)*0.1 37 | x1, x2 = int(x1 - face_w), int(x2 + face_w) 38 | y1, y2 = int(y1 - face_h), int(y2 + face_h) 39 | image = image.crop((x1, y1, x2, y2)) 40 | pts[0, visiable_points] = pts[0, visiable_points] - x1 41 | pts[1, visiable_points] = pts[1, visiable_points] - y1 42 | 43 | if resize: 44 | width, height = image.size 45 | image = image.resize((resize,resize), Image.BICUBIC) 46 | pts[0, visiable_points] = pts[0, visiable_points] * 1.0 / width * resize 47 | pts[1, visiable_points] = pts[1, visiable_points] * 1.0 / height * resize 48 | 49 | finegrain = True 50 | if finegrain: 51 | owidth, oheight = image.size 52 | image = image.resize((owidth*8,oheight*8), Image.BICUBIC) 53 | pts[0, visiable_points] = pts[0, visiable_points] * 8.0 54 | pts[1, visiable_points] = pts[1, visiable_points] * 8.0 55 | radius = radius * 8 56 | 57 | draw = ImageDraw.Draw(image) 58 | for idx in range(num_points): 59 | if visiable_points[ idx ]: 60 | # draw hollow circle 61 | point = (pts[0,idx]-radius, pts[1,idx]-radius, pts[0,idx]+radius, pts[1,idx]+radius) 62 | if radius > 0: 63 | draw.ellipse(point, fill=color, outline=color) 64 | 65 | if finegrain: 66 | image = image.resize((owidth,oheight), Image.BICUBIC) 67 | 68 | return image 69 | 
-------------------------------------------------------------------------------- /scripts/300W-DET.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | CUDA_VISIBLE_DEVICES=0,1 python ./exps/basic_main.py \ 3 | --train_lists ./cache_data/lists/300W/300w.train.DET \ 4 | --eval_ilists ./cache_data/lists/300W/300w.test.common.DET \ 5 | ./cache_data/lists/300W/300w.test.challenge.DET \ 6 | ./cache_data/lists/300W/300w.test.full.DET \ 7 | --num_pts 68 \ 8 | --model_config ./configs/Detector.config \ 9 | --opt_config ./configs/SGD.config \ 10 | --save_path ./snapshots/300W-CPM-DET \ 11 | --pre_crop_expand 0.2 --sigma 4 --batch_size 8 \ 12 | --crop_perturb_max 30 --rotate_max 20 \ 13 | --scale_prob 1.0 --scale_min 0.9 --scale_max 1.1 --scale_eval 1 \ 14 | --heatmap_type gaussian 15 | -------------------------------------------------------------------------------- /scripts/AFLW-DET.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | CUDA_VISIBLE_DEVICES=0,1 python ./exps/basic_main.py \ 3 | --train_lists ./cache_data/lists/AFLW/train.GTB \ 4 | --eval_ilists ./cache_data/lists/AFLW/test.GTB \ 5 | --num_pts 19 \ 6 | --model_config ./configs/Detector.config \ 7 | --opt_config ./configs/SGD.config \ 8 | --save_path ./snapshots/AFLW-CPM-DET \ 9 | --pre_crop_expand 0.2 --sigma 4 --batch_size 8 \ 10 | --crop_perturb_max 30 --rotate_max 20 \ 11 | --scale_prob 1.0 --scale_min 0.9 --scale_max 1.1 --scale_eval 1 \ 12 | --heatmap_type gaussian 13 | -------------------------------------------------------------------------------- /scripts/demo_pam.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | CUDA_VISIBLE_DEVICES=0,1 python ./exps/lk_main.py \ 3 | --train_lists ./cache_data/lists/demo/demo-pam.lst \ 4 | --eval_ilists ./cache_data/lists/demo/demo-pam.lst \ 5 | --num_pts 68 \ 6 | --model_config ./configs/Detector.config \ 7 | --opt_config ./configs/LK.SGD.config \ 8 | --lk_config ./configs/lk.config \ 9 | --video_parser x-1-1 --save_path ./snapshots/CPM-PAM \ 10 | --init_model ./snapshots/300W-CPM-DET/checkpoint/cpm_vgg16-epoch-049-050.pth \ 11 | --pre_crop_expand 0.2 --sigma 4 \ 12 | --batch_size 8 --crop_perturb_max 5 --scale_prob 1 --scale_min 1 --scale_max 1 --scale_eval 1 --heatmap_type gaussian \ 13 | --print_freq 10 14 | -------------------------------------------------------------------------------- /scripts/demo_sbr.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
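# (This script fine-tunes the 300-W-pretrained CPM detector with the SBR
#  registration loss, mixing the unlabeled demo video with the labeled 300-W
#  training list; it assumes scripts/300W-DET.sh has already produced the
#  checkpoint referenced by --init_model below.)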
2 | CUDA_VISIBLE_DEVICES=2,3 python ./exps/lk_main.py \
3 |     --train_lists ./cache_data/lists/demo/demo-sbr.lst \
4 |                   ./cache_data/lists/300W/300w.train.DET \
5 |     --eval_ilists ./cache_data/lists/demo/demo-sbr.lst \
6 |     --num_pts 68 \
7 |     --model_config ./configs/Detector.config \
8 |     --opt_config ./configs/LK.SGD.config \
9 |     --lk_config ./configs/mix.lk.config \
10 |     --video_parser x-1-1 --save_path ./snapshots/CPM-SBR \
11 |     --init_model ./snapshots/300W-CPM-DET/checkpoint/cpm_vgg16-epoch-049-050.pth \
12 |     --pre_crop_expand 0.2 --sigma 4 \
13 |     --batch_size 8 --crop_perturb_max 5 --scale_prob 1 --scale_min 1 --scale_max 1 --scale_eval 1 --heatmap_type gaussian \
14 |     --print_freq 10
15 | 
--------------------------------------------------------------------------------
/scripts/sbr_example.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | CUDA_VISIBLE_DEVICES=0,1 python ./exps/lk_main.py \
3 |     --train_lists ${dataset_with_annotation} \
4 |                   ${dataset_without_annotation} \
5 |     --eval_ilists ${evaluation_dataset_lists} \
6 |     --num_pts 68 \
7 |     --model_config ${detection_configuration} \
8 |     --opt_config ${optimization_configuration} \
9 |     --lk_config ${LK_configuration} \
10 |     --video_parser x-1-1 \
11 |     --save_path ${snapshot_path} \
12 |     --init_model ./snapshots/300W-CPM-DET/checkpoint/cpm_vgg16-epoch-049-050.pth \
13 |     --pre_crop_expand 0.2 --sigma 4 --batch_size 8 --crop_perturb_max 5 --scale_prob 1 --scale_min 1 --scale_max 1 --scale_eval 1 --heatmap_type gaussian --print_freq 10
14 | 
--------------------------------------------------------------------------------
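A usage note on `scripts/sbr_example.sh`: the `${...}` names above are placeholders that the caller must supply as environment variables; they are not files shipped with the repository. A hypothetical invocation, reusing the list paths and configs referenced by the demo scripts, could look like:

```
export dataset_with_annotation=./cache_data/lists/300W/300w.train.DET
export dataset_without_annotation=./cache_data/lists/demo/demo-sbr.lst
export evaluation_dataset_lists=./cache_data/lists/demo/demo-sbr.lst
export detection_configuration=./configs/Detector.config
export optimization_configuration=./configs/LK.SGD.config
export LK_configuration=./configs/mix.lk.config
export snapshot_path=./snapshots/CPM-SBR
bash ./scripts/sbr_example.sh
```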