├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── chainer ├── commands │ ├── __init__.py │ └── interactive_train.py ├── datasets │ ├── __init__.py │ ├── concatenated_dataset.py │ ├── file_dataset.py │ └── sub_dataset.py ├── evaluate.py ├── evaluation │ ├── __init__.py │ └── evaluator.py ├── fsns_demo.py ├── functions │ ├── __init__.py │ ├── disable_shearing.py │ ├── disable_translation.py │ └── rotation_droput.py ├── insights │ ├── __init__.py │ ├── bbox_plotter.py │ ├── fsns_bbox_plotter.py │ ├── lstm_per_step_plotter.py │ ├── svhn_bbox_plotter.py │ ├── text_rec_bbox_plotter.py │ ├── textrec_bbox_plotter.py │ └── visual_backprop.py ├── metrics │ ├── __init__.py │ ├── ctc_metrics.py │ ├── loss_metrics.py │ ├── lstm_per_step_metrics.py │ ├── softmax_metrics.py │ ├── svhn_ctc_metrics.py │ ├── svhn_softmax_metrics.py │ └── textrec_metrics.py ├── models │ ├── __init__.py │ ├── fsns.py │ ├── fsns_resnet.py │ ├── ic_stn.py │ ├── svhn.py │ └── text_recognition.py ├── optimizers │ ├── __init__.py │ └── multi_net_optimizer.py ├── text_recognition_demo.py ├── train_fsns.py ├── train_mnist.py ├── train_svhn.py ├── train_text_recognition.py └── utils │ ├── README.md │ ├── __init__.py │ ├── baby_step_curriculum.py │ ├── create_gif.py │ ├── crop_images.py │ ├── datatypes.py │ ├── dict_eval.py │ ├── intelligent_attribute_shifter.py │ ├── logger.py │ ├── multi_accuracy_classifier.py │ ├── plotting.py │ └── train_utils.py ├── datasets ├── fsns │ ├── __init__.py │ ├── change_file_names.py │ ├── download_fsns.py │ ├── extract_words.py │ ├── fonts │ │ ├── DejaVuSansMono-Bold.ttf │ │ └── DejaVuSansMono.ttf │ ├── fsns_char_map.json │ ├── render_text_on_signs.py │ ├── slice_fsns_dataset.py │ ├── swap_classes.py │ ├── tfrecord_to_image.py │ ├── transform_back_to_single_line.py │ └── transform_gt.py ├── svhn │ ├── create_svhn_csv_gt.py │ ├── create_svhn_dataset.py │ ├── create_svhn_dataset_4_images.py │ ├── filter_large_images.py │ ├── prepare_svhn_crops.py │ ├── svhn_char_map.json │ └── svhn_dataextract_tojson.py └── textrec │ └── ctc_char_map.json ├── requirements.txt └── utils ├── create_video.py └── show_progress.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | logs 104 | .idea 105 | 106 | nccl* 107 | *.csv 108 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | ARG FROM_IMAGE=nvidia/cuda:9.1-cudnn7-devel-ubuntu16.04 3 | FROM ${FROM_IMAGE} 4 | 5 | # install opencv for python 3 6 | RUN apt-get update && \ 7 | apt-get install -y \ 8 | build-essential \ 9 | git \ 10 | libasound2-dev \ 11 | libavformat-dev \ 12 | libcanberra-gtk3-module \ 13 | libgtk-3-dev \ 14 | libjasper-dev \ 15 | libjpeg-dev \ 16 | libpng-dev \ 17 | libpq-dev \ 18 | libswscale-dev \ 19 | libtbb-dev \ 20 | libtbb2 \ 21 | libtiff-dev \ 22 | pkg-config \ 23 | python3 \ 24 | python3-numpy \ 25 | python3-pip \ 26 | unzip \ 27 | wget \ 28 | yasm 29 | 30 | # Set the working directory to /app 31 | WORKDIR /app 32 | 33 | ARG NCCL_NAME=nccl-repo-ubuntu1604-2.1.15-ga-cuda9.1_1-1_amd64.deb 34 | COPY ${NCCL_NAME} /app 35 | RUN dpkg -i ${NCCL_NAME} 36 | RUN apt-get update && apt-get install -y libnccl2 libnccl-dev 37 | 38 | COPY requirements.txt /app/ 39 | RUN pip3 install -v --trusted-host pypi.python.org -r requirements.txt 40 | 41 | # Copy the current directory contents into the container at /app 42 | COPY . 
/app 43 | 44 | # Define environment variable 45 | ENV NAME See 46 | -------------------------------------------------------------------------------- /chainer/commands/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/commands/__init__.py -------------------------------------------------------------------------------- /chainer/commands/interactive_train.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | from cmd import Cmd 4 | 5 | 6 | class InteractiveTrain(Cmd): 7 | 8 | prompt = '' 9 | 10 | def __init__(self, *args, **kwargs): 11 | self.bbox_plotter = kwargs.pop('bbox_plotter') 12 | self.curriculum = kwargs.pop('curriculum', None) 13 | self.lr_shifter = kwargs.pop('lr_shifter', None) 14 | 15 | super().__init__(*args, **kwargs) 16 | 17 | def do_enablebboxvis(self, arg): 18 | """Enable sending of bboxes to remote host""" 19 | self.bbox_plotter.send_bboxes = True 20 | 21 | def do_increasedifficulty(self, arg): 22 | """Increase difficulty of learning curriculum""" 23 | if self.curriculum is not None: 24 | self.curriculum.force_enlarge_dataset = True 25 | 26 | def do_shiftlr(self, arg): 27 | if self.lr_shifter is not None: 28 | self.lr_shifter.force_shift = True 29 | 30 | def do_quit(self, arg): 31 | return True 32 | 33 | def do_echo(self, arg): 34 | print(arg) 35 | 36 | 37 | def open_interactive_prompt(*args, **kwargs): 38 | cmd_interface = InteractiveTrain(*args, **kwargs) 39 | 40 | thread = threading.Thread(target=lambda: cmd_interface.cmdloop()) 41 | thread.daemon = True 42 | thread.start() 43 | -------------------------------------------------------------------------------- /chainer/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/datasets/__init__.py -------------------------------------------------------------------------------- /chainer/datasets/concatenated_dataset.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from chainer.dataset import dataset_mixin 5 | 6 | 7 | class ConcatenatedDataset(dataset_mixin.DatasetMixin): 8 | 9 | """Dataset which concatenates some base datasets. 10 | This dataset wraps some base datasets and works as a concatenated dataset. 11 | For example, if a base dataset with 10 samples and 12 | another base dataset with 20 samples are given, this dataset works as 13 | a dataset which has 30 samples. 14 | Args: 15 | datasets: The underlying datasets. Each dataset has to support 16 | :meth:`__len__` and :meth:`__getitem__`. 
17 | """ 18 | 19 | def __init__(self, *datasets): 20 | self._datasets = datasets 21 | 22 | def __len__(self): 23 | return sum(len(dataset) for dataset in self._datasets) 24 | 25 | def pad_labels(self, new_label_length, pad_value): 26 | for dataset in self._datasets: 27 | dataset.pad_labels(new_label_length, pad_value) 28 | 29 | def get_example(self, i): 30 | if i < 0: 31 | raise IndexError 32 | for dataset in self._datasets: 33 | if i < len(dataset): 34 | return dataset[i] 35 | i -= len(dataset) 36 | raise IndexError -------------------------------------------------------------------------------- /chainer/datasets/file_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | 5 | try: 6 | import cv2 7 | except ImportError: 8 | pass 9 | 10 | import numpy as np 11 | import random 12 | from PIL import Image 13 | 14 | from chainer.dataset import dataset_mixin 15 | 16 | 17 | class FileBasedDataset(dataset_mixin.DatasetMixin): 18 | 19 | def __init__(self, dataset_file, file_contains_metadata=True, resize_size=None): 20 | self.file_names = [] 21 | self.labels = [] 22 | self.base_dir = os.path.dirname(dataset_file) 23 | self.num_timesteps = None 24 | self.num_labels = None 25 | self.resize_size = resize_size 26 | 27 | with open(dataset_file, 'r') as f: 28 | reader = csv.reader(f, delimiter='\t') 29 | if file_contains_metadata: 30 | # first read metadata about file 31 | self.num_timesteps, self.num_labels = (int(i) for i in next(reader)) 32 | # then read all data 33 | for line in reader: 34 | file_name = line[0] 35 | labels = np.array(line[1:], dtype=np.int32) 36 | self.file_names.append(file_name) 37 | self.labels.append(labels) 38 | 39 | assert len(self.file_names) == len(self.labels) 40 | label_length = len(self.labels[0]) 41 | for i, label in enumerate(self.labels): 42 | if len(label) != label_length: 43 | print("Label of file {} is not as long as all others ({} vs {})".format(self.file_names[i], len(label), label_length)) 44 | 45 | def __len__(self): 46 | return len(self.file_names) 47 | 48 | def pad_labels(self, new_num_timesteps, pad_value): 49 | padded_labels = [ 50 | np.concatenate( 51 | (label, np.array([pad_value] * (new_num_timesteps - self.num_timesteps) * self.num_labels, dtype=np.int32)), 52 | axis=0 53 | ) for label in self.labels 54 | ] 55 | self.num_timesteps = new_num_timesteps 56 | self.labels = padded_labels 57 | 58 | def get_label_length(self, num_timesteps, check_length=True): 59 | label_length, rest = divmod(len(self.labels[0]), num_timesteps) 60 | if check_length: 61 | assert rest == 0, "Number of labels does not evenly divide by number of timesteps! 
(Rest: {})".format(rest) 62 | return label_length 63 | 64 | def load_image(self, file_name): 65 | with Image.open(os.path.join(self.base_dir, file_name)) as the_image: 66 | the_image = the_image.convert("RGB") 67 | if self.resize_size is not None: 68 | the_image = the_image.resize((self.resize_size.width, self.resize_size.height), Image.LANCZOS) 69 | assert the_image.width == self.resize_size.width 70 | assert the_image.height == self.resize_size.height 71 | assert the_image.mode == 'RGB' 72 | 73 | image = np.asarray(the_image, dtype=np.float32) 74 | image /= 255 75 | 76 | # put color channels to the front, as expected by Chainer 77 | image = image.transpose(2, 0, 1) 78 | num_channels, height, width = image.shape 79 | assert num_channels == 3 80 | if self.resize_size is not None: 81 | assert height == self.resize_size.height 82 | assert width == self.resize_size.width 83 | 84 | return image 85 | 86 | def get_example(self, i): 87 | while True: 88 | try: 89 | image = self.load_image(self.file_names[i]) 90 | break 91 | except Exception as e: 92 | print("could not load image: {}".format(self.file_names[i])) 93 | i = random.randint(0, len(self) - 1) 94 | 95 | label = self.labels[i] 96 | return image, label 97 | 98 | 99 | class TextRecFileDataset(dataset_mixin.DatasetMixin): 100 | 101 | def __init__(self, dataset_file, char_map=None, file_contains_metadata=True, resize_size=None, blank_label=0): 102 | self.file_names = [] 103 | self.labels = [] 104 | self.base_dir = os.path.dirname(dataset_file) 105 | self.num_timesteps = None 106 | self.num_labels = None 107 | self.resize_size = resize_size 108 | self.blank_label = blank_label 109 | 110 | with open(char_map) as the_map: 111 | self.char_map = json.load(the_map) 112 | self.reverse_char_map = {v: k for k, v in self.char_map.items()} 113 | 114 | with open(dataset_file, 'r') as f: 115 | reader = csv.reader(f, delimiter='\t') 116 | if file_contains_metadata: 117 | # first read metadata about file 118 | self.num_timesteps, self.num_labels = (int(i) for i in next(reader)) 119 | # then read all data 120 | for line in reader: 121 | file_name = line[0] 122 | labels = line[1] 123 | self.file_names.append(file_name) 124 | self.labels.append(labels) 125 | 126 | assert len(self.file_names) == len(self.labels) 127 | 128 | def __len__(self): 129 | return len(self.file_names) 130 | 131 | def get_label_length(self, num_timesteps, check_length=True): 132 | return self.num_labels 133 | 134 | def get_example(self, i): 135 | try: 136 | image = self.load_image(self.file_names[i]) 137 | except Exception as e: 138 | print("can not load image: {}".format(self.file_names[i])) 139 | i = random.randint(0, len(self) - 1) 140 | image = self.load_image(self.file_names[i]) 141 | 142 | labels = self.get_labels(self.labels[i]) 143 | return image, labels 144 | 145 | def load_image(self, file_name): 146 | with Image.open(os.path.join(self.base_dir, file_name)) as the_image: 147 | the_image = the_image.convert('RGB') 148 | if self.resize_size is not None: 149 | the_image = the_image.resize((self.resize_size.width, self.resize_size.height), Image.LANCZOS) 150 | image = np.asarray(the_image, dtype=np.float32) 151 | image /= 255 152 | del the_image 153 | 154 | image = image.transpose(2, 0, 1) 155 | return image 156 | 157 | def get_labels(self, word): 158 | labels = [int(self.reverse_char_map[ord(character)]) for character in word] 159 | labels += [self.blank_label] * (self.num_timesteps - len(labels)) 160 | return np.array(labels, dtype=np.int32) 161 | 162 | 163 | class 
OpencvTextRecFileDataset(TextRecFileDataset): 164 | 165 | def load_image(self, file_name): 166 | the_image = cv2.imread(file_name, cv2.IMREAD_COLOR) 167 | the_image = cv2.cvtColor(the_image, cv2.COLOR_BGR2RGB) 168 | if self.resize_size is not None: 169 | the_image = cv2.resize(the_image, self.resize_size) 170 | 171 | the_image = the_image.astype(np.float32) / 255 172 | image = np.transpose(the_image, (2, 0, 1)) 173 | del the_image 174 | 175 | return image 176 | 177 | -------------------------------------------------------------------------------- /chainer/datasets/sub_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from chainer.datasets.sub_dataset import SubDataset 4 | 5 | 6 | class PaddableSubDataset(SubDataset): 7 | 8 | def pad_labels(self, new_label_length, pad_value): 9 | self._dataset.pad_labels(new_label_length, pad_value) 10 | 11 | 12 | def split_dataset(dataset, split_at, order=None): 13 | """Splits a dataset into two subsets. 14 | This function creates two instances of :class:`SubDataset`. These instances 15 | do not share any examples, and they together cover all examples of the 16 | original dataset. 17 | Args: 18 | dataset: Dataset to split. 19 | split_at (int): Position at which the base dataset is split. 20 | order (sequence of ints): Permutation of indexes in the base dataset. 21 | See the document of :class:`SubDataset` for details. 22 | Returns: 23 | tuple: Two :class:`SubDataset` objects. The first subset represents the 24 | examples of indexes ``order[:split_at]`` while the second subset 25 | represents the examples of indexes ``order[split_at:]``. 26 | """ 27 | n_examples = len(dataset) 28 | if split_at < 0: 29 | raise ValueError('split_at must be non-negative') 30 | if split_at >= n_examples: 31 | raise ValueError('split_at exceeds the dataset size') 32 | subset1 = PaddableSubDataset(dataset, 0, split_at, order) 33 | subset2 = PaddableSubDataset(dataset, split_at, n_examples, order) 34 | return subset1, subset2 35 | 36 | 37 | def split_dataset_random(dataset, first_size, seed=None): 38 | """Splits a dataset into two subsets randomly. 39 | This function creates two instances of :class:`SubDataset`. These instances 40 | do not share any examples, and they together cover all examples of the 41 | original dataset. The split is automatically done randomly. 42 | Args: 43 | dataset: Dataset to split. 44 | first_size (int): Size of the first subset. 45 | seed (int): Seed the generator used for the permutation of indexes. 46 | If an integer being convertible to 32 bit unsigned integers is 47 | specified, it is guaranteed that each sample 48 | in the given dataset always belongs to a specific subset. 49 | If ``None``, the permutation is changed randomly. 50 | Returns: 51 | tuple: Two :class:`SubDataset` objects. The first subset contains 52 | ``first_size`` examples randomly chosen from the dataset without 53 | replacement, and the second subset contains the rest of the 54 | dataset. 55 | """ 56 | order = numpy.random.RandomState(seed).permutation(len(dataset)) 57 | return split_dataset(dataset, first_size, order) 58 | 59 | def split_dataset_n(dataset, n, order=None): 60 | """Splits a dataset into ``n`` subsets. 61 | Args: 62 | dataset: Dataset to split. 63 | n(int): The number of subsets. 64 | order (sequence of ints): Permutation of indexes in the base dataset. 65 | See the document of :class:`SubDataset` for details. 66 | Returns: 67 | list: List of ``n`` :class:`SubDataset` objects. 
68 | Each subset contains the examples of indexes 69 | ``order[i * (len(dataset) // n):(i + 1) * (len(dataset) // n)]`` 70 | . 71 | """ 72 | n_examples = len(dataset) 73 | sub_size = n_examples // n 74 | return [PaddableSubDataset(dataset, sub_size * i, sub_size * (i + 1), order) 75 | for i in range(n)] 76 | 77 | 78 | def split_dataset_n_random(dataset, n, seed=None): 79 | """Splits a dataset into ``n`` subsets randomly. 80 | Args: 81 | dataset: Dataset to split. 82 | n(int): The number of subsets. 83 | seed (int): Seed the generator used for the permutation of indexes. 84 | If an integer being convertible to 32 bit unsigned integers is 85 | specified, it is guaranteed that each sample 86 | in the given dataset always belongs to a specific subset. 87 | If ``None``, the permutation is changed randomly. 88 | Returns: 89 | list: List of ``n`` :class:`SubDataset` objects. 90 | Each subset contains ``len(dataset) // n`` examples randomly chosen 91 | from the dataset without replacement. 92 | """ 93 | n_examples = len(dataset) 94 | sub_size = n_examples // n 95 | order = numpy.random.RandomState(seed).permutation(len(dataset)) 96 | return [PaddableSubDataset(dataset, sub_size * i, sub_size * (i + 1), order) 97 | for i in range(n)] 98 | -------------------------------------------------------------------------------- /chainer/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from evaluation.evaluator import FSNSEvaluator, SVHNEvaluator, TextRecognitionEvaluator 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser(description="Tool that evaluates a trained model on a chosen test set (either FSNS or SVHN)") 7 | subparsers = parser.add_subparsers(help="choice of evaluation type") 8 | parser.add_argument("model_dir", help='path to model dir') 9 | parser.add_argument("snapshot_name", help="name of snapshot in model dir") 10 | parser.add_argument("eval_gt", help="path to evaluation groundtruth file") 11 | parser.add_argument("char_map", help="Path to char map") 12 | parser.add_argument("num_labels", help="number of labels per sample", type=int) 13 | parser.add_argument("--dropout-ratio", type=float, default=0.5, help="dropout ratio") 14 | parser.add_argument("--target-shape", default="75,100", help="input shape for recognition network in form: height,width [default: 75,100]") 15 | parser.add_argument("--timesteps", default=3, type=int, help="number of timesteps localization net shall perform [default: 3]") 16 | parser.add_argument("--blank-symbol", type=int, default=0, help="blank symbol used for padding [default: 0]") 17 | parser.add_argument("--gpu", type=int, default=-1, help="gpu to use [default: use cpu]") 18 | parser.add_argument("--save-rois", action='store_true', default=False, help="save rois of each image for further inspection") 19 | parser.add_argument("--num-rois", type=int, default=1000, help="number of rois to save [default: 1000]") 20 | parser.add_argument('--log-name', default='log', help='name of the log file [default: log]') 21 | 22 | fsns_parser = subparsers.add_parser("fsns", help="evaluate fsns model") 23 | fsns_parser.set_defaults(evaluator=FSNSEvaluator) 24 | 25 | svhn_parser = subparsers.add_parser("svhn", help="evaluate svhn model") 26 | svhn_parser.set_defaults(evaluator=SVHNEvaluator) 27 | 28 | text_recognition_parser = subparsers.add_parser("textrec", help="evaluate text recognition model") 29 | text_recognition_parser.set_defaults(evaluator=TextRecognitionEvaluator) 30 | 31 | args = 
parser.parse_args() 32 | args.is_original_fsns = True 33 | args.refinement_steps = 0 34 | args.refinement = False 35 | args.render_all_bboxes = False 36 | 37 | evaluator = args.evaluator(args) 38 | evaluator.evaluate() 39 | -------------------------------------------------------------------------------- /chainer/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/evaluation/__init__.py -------------------------------------------------------------------------------- /chainer/fsns_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | 4 | import os 5 | 6 | import json 7 | from collections import OrderedDict 8 | 9 | import chainer 10 | from pprint import pprint 11 | 12 | import chainer.functions as F 13 | import numpy as np 14 | 15 | from PIL import Image 16 | from chainer import configuration 17 | 18 | from utils.datatypes import Size 19 | 20 | 21 | def get_class_and_module(log_data): 22 | if not isinstance(log_data, list): 23 | module_name = 'fsns.py' 24 | klass_name = log_data 25 | else: 26 | klass_name, module_name = log_data 27 | return klass_name, module_name 28 | 29 | 30 | def load_module(module_file): 31 | module_spec = importlib.util.spec_from_file_location("models.model", module_file) 32 | module = importlib.util.module_from_spec(module_spec) 33 | module_spec.loader.exec_module(module) 34 | return module 35 | 36 | 37 | def build_recognition_net(recognition_net_class, target_shape, args): 38 | return recognition_net_class( 39 | target_shape, 40 | args.num_labels, 41 | args.timesteps, 42 | uses_original_data=args.is_original_fsns, 43 | use_blstm=True 44 | ) 45 | 46 | 47 | def build_localization_net(localization_net_class, args): 48 | return localization_net_class(args.dropout_ratio, args.timesteps) 49 | 50 | 51 | def build_fusion_net(fusion_net_class, localization_net, recognition_net): 52 | return fusion_net_class(localization_net, recognition_net, uses_original_data=args.is_original_fsns) 53 | 54 | 55 | def create_network(args, log_data): 56 | # Step 1: build network 57 | localization_net_class_name, localization_module_name = get_class_and_module(log_data['localization_net']) 58 | module = load_module(os.path.abspath(os.path.join(args.model_dir, localization_module_name))) 59 | localization_net_class = eval('module.{}'.format(localization_net_class_name)) 60 | localization_net = build_localization_net(localization_net_class, args) 61 | 62 | recognition_net_class_name, recognition_module_name = get_class_and_module(log_data['recognition_net']) 63 | module = load_module(os.path.abspath(os.path.join(args.model_dir, recognition_module_name))) 64 | recognition_net_class = eval('module.{}'.format(recognition_net_class_name)) 65 | recognition_net = build_recognition_net(recognition_net_class, target_shape, args) 66 | 67 | fusion_net_class_name, fusion_module_name = get_class_and_module(log_data['fusion_net']) 68 | module = load_module(os.path.abspath(os.path.join(args.model_dir, fusion_module_name))) 69 | fusion_net_class = eval('module.{}'.format(fusion_net_class_name)) 70 | net = build_fusion_net(fusion_net_class, localization_net, recognition_net) 71 | 72 | if args.gpu >= 0: 73 | net.to_gpu(args.gpu) 74 | 75 | return net 76 | 77 | 78 | def load_image(image_file, xp): 79 | with Image.open(image_file) as the_image: 80 | image = xp.asarray(the_image.convert('RGB'), 
dtype=np.float32) 81 | image /= 255 82 | image = image.transpose(2, 0, 1) 83 | 84 | return image 85 | 86 | 87 | def strip_prediction(predictions, xp, blank_symbol): 88 | words = [] 89 | for prediction in predictions: 90 | stripped_prediction = xp.empty((0,), dtype=xp.int32) 91 | for char in prediction: 92 | if char == blank_symbol: 93 | continue 94 | stripped_prediction = xp.hstack((stripped_prediction, char.reshape(1, ))) 95 | words.append(stripped_prediction) 96 | return words 97 | 98 | 99 | def extract_bbox(bbox, image_size, target_shape, xp): 100 | bbox.data[...] = (bbox.data[...] + 1) / 2 101 | bbox.data[0, :] *= image_size.width 102 | bbox.data[1, :] *= image_size.height 103 | 104 | x = xp.clip(bbox.data[0, :].reshape(target_shape), 0, image_size.width) 105 | y = xp.clip(bbox.data[1, :].reshape(target_shape), 0, image_size.height) 106 | 107 | top_left = (float(x[0, 0]), float(y[0, 0])) 108 | bottom_right = (float(x[-1, -1]), float(y[-1, -1])) 109 | 110 | return top_left, bottom_right 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser(description="Tool that loads model and predicts on a given image") 115 | parser.add_argument("model_dir", help="path to directory where model is saved") 116 | parser.add_argument("snapshot_name", help="name of the snapshot to load") 117 | parser.add_argument("image_path", help="path to the image that shall be evaluated") 118 | parser.add_argument("char_map", help="path to char map, that maps class id to character") 119 | parser.add_argument("--gpu", type=int, default=-1, help="id of gpu to use [default: use cpu]") 120 | 121 | args = parser.parse_args() 122 | # set standard args that should always hold true if using the supplied model 123 | args.is_original_fsns = True 124 | args.log_name = 'log' 125 | args.dropout_ratio = 0.5 126 | args.blank_symbol = 0 127 | # max number of text regions in the image 128 | args.timesteps = 4 129 | # max number of characters per word 130 | args.num_labels = 21 131 | 132 | # open log and extract meta information 133 | with open(os.path.join(args.model_dir, args.log_name)) as the_log: 134 | log_data = json.load(the_log)[0] 135 | 136 | target_shape = Size._make(log_data['target_size']) 137 | image_size = Size._make(log_data['image_size']) 138 | 139 | xp = chainer.cuda.cupy if args.gpu >= 0 else np 140 | network = create_network(args, log_data) 141 | 142 | # load weights 143 | with np.load(os.path.join(args.model_dir, args.snapshot_name)) as f: 144 | chainer.serializers.NpzDeserializer(f).load(network) 145 | 146 | # load char map 147 | with open(args.char_map) as the_map: 148 | char_map = json.load(the_map) 149 | 150 | # load image 151 | image = load_image(args.image_path, xp) 152 | with configuration.using_config('train', False): 153 | predictions, crops, grids = network(image[xp.newaxis, ...]) 154 | 155 | # extract class scores for each word 156 | words = OrderedDict({}) 157 | for prediction, bbox in zip(predictions, grids): 158 | classification = F.softmax(prediction, axis=2) 159 | classification = classification.data 160 | classification = xp.argmax(classification, axis=2) 161 | classification = xp.transpose(classification, (1, 0)) 162 | 163 | word = strip_prediction(classification, xp, args.blank_symbol)[0] 164 | 165 | word = "".join(map(lambda x: chr(char_map[str(x)]), word)) 166 | 167 | bbox = extract_bbox(bbox, image_size, target_shape, xp) 168 | words[word] = OrderedDict({ 169 | 'top_left': bbox[0], 170 | 'bottom_right': bbox[1] 171 | }) 172 | 173 | pprint(words) 174 | 175 | 176 | 177 
| 178 | -------------------------------------------------------------------------------- /chainer/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/functions/__init__.py -------------------------------------------------------------------------------- /chainer/functions/disable_shearing.py: -------------------------------------------------------------------------------- 1 | from chainer import Function 2 | from chainer.utils import type_check, force_array 3 | 4 | 5 | class DisableShearing(Function): 6 | 7 | def check_type_forward(self, in_types): 8 | type_check.expect(in_types.size() == 1) 9 | grid_type, = in_types 10 | 11 | type_check.expect( 12 | grid_type.dtype.kind == 'f', 13 | grid_type.ndim == 3, 14 | grid_type.shape[1] == 2, 15 | grid_type.shape[2] == 3, 16 | ) 17 | 18 | def forward(self, inputs): 19 | grid, = inputs 20 | 21 | grid[:, 0, 1] = 0 22 | grid[:, 1, 0] = 0 23 | 24 | return force_array(grid, dtype=grid.dtype), 25 | 26 | def backward(self, inputs, grad_outputs): 27 | grad_outputs = grad_outputs[0] 28 | grad_outputs[:, 0, 1] = 0 29 | grad_outputs[:, 1, 0] = 0 30 | return force_array(grad_outputs, dtype=inputs[0].dtype), 31 | 32 | 33 | def disable_shearing(grid): 34 | return DisableShearing()(grid) 35 | -------------------------------------------------------------------------------- /chainer/functions/disable_translation.py: -------------------------------------------------------------------------------- 1 | from chainer import Function 2 | from chainer.utils import type_check, force_array 3 | 4 | 5 | class DisableTranslation(Function): 6 | 7 | def check_type_forward(self, in_types): 8 | type_check.expect(in_types.size() == 1) 9 | grid_type, = in_types 10 | 11 | type_check.expect( 12 | grid_type.dtype.kind == 'f', 13 | grid_type.ndim == 3, 14 | grid_type.shape[1] == 2, 15 | grid_type.shape[2] == 3, 16 | ) 17 | 18 | def forward(self, inputs): 19 | self.retain_inputs(()) 20 | grid, = inputs 21 | grid[:, :, 2] = 0 22 | 23 | return grid, 24 | 25 | def backward(self, inputs, grad_outputs): 26 | grad_outputs = grad_outputs[0] 27 | grad_outputs[:, :, 2] = 0 28 | return grad_outputs, 29 | 30 | 31 | def disable_translation(grid): 32 | return DisableTranslation()(grid) 33 | -------------------------------------------------------------------------------- /chainer/functions/rotation_droput.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from chainer import configuration 4 | from chainer import cuda 5 | from chainer import function 6 | from chainer.utils import type_check 7 | 8 | 9 | class RotationDropout(function.Function): 10 | 11 | """Dropout regularisation for training rotation of spatial transformer""" 12 | 13 | def __init__(self, dropout_ratio): 14 | self.dropout_ratio = dropout_ratio 15 | 16 | def check_type_forward(self, in_types): 17 | type_check.expect(in_types.size() == 1) 18 | x_type = in_types[0] 19 | type_check.expect( 20 | x_type.dtype.kind == 'f', 21 | x_type.ndim == 3, 22 | x_type.shape[1] == 2, 23 | x_type.shape[2] == 3, 24 | ) 25 | 26 | def forward(self, x): 27 | self.retain_inputs(()) 28 | xp = cuda.get_array_module(*x) 29 | 30 | if not configuration.config.train: 31 | # scale affected weights if we are testing 32 | mask = xp.ones_like(x[0]) 33 | mask[:, 0, 1] = self.dropout_ratio 34 | mask[:, 1, 0] = self.dropout_ratio 35 | 36 | return x[0] * mask, 37 | 38 | if 
not hasattr(self, 'mask'): 39 | self.mask = xp.ones_like(x[0]) 40 | 41 | flag_data = xp.random.rand(1) < self.dropout_ratio 42 | self.mask[:, 0, 1] = flag_data 43 | self.mask[:, 1, 0] = flag_data 44 | 45 | return x[0] * self.mask, 46 | 47 | def backward(self, x, gy): 48 | return gy[0] * self.mask, 49 | 50 | 51 | def rotation_dropout(x, ratio=.5, **kwargs): 52 | return RotationDropout(ratio)(x) 53 | -------------------------------------------------------------------------------- /chainer/insights/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/insights/__init__.py -------------------------------------------------------------------------------- /chainer/insights/bbox_plotter.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import socket 5 | 6 | from io import BytesIO 7 | 8 | import chainer.functions as F 9 | import numpy as np 10 | from PIL import Image, ImageFont 11 | 12 | from PIL import ImageDraw 13 | 14 | import chainer 15 | from chainer import cuda 16 | from chainer.training import Extension 17 | 18 | from insights.visual_backprop import VisualBackprop 19 | from utils.datatypes import Size 20 | 21 | 22 | COLOR_MAP = [ 23 | "#00B3FF", # Vivid Yellow 24 | "#753E80", # Strong Purple 25 | "#0068FF", # Vivid Orange 26 | "#D7BDA6", # Very Light Blue 27 | "#2000C1", # Vivid Red 28 | "#62A2CE", # Grayish Yellow 29 | "#667081", # Medium Gray 30 | 31 | # The following don't work well for people with defective color vision 32 | "#347D00", # Vivid Green 33 | "#8E76F6", # Strong Purplish Pink 34 | "#8A5300", # Strong Blue 35 | "#5C7AFF", # Strong Yellowish Pink 36 | "#7A3753", # Strong Violet 37 | "#008EFF", # Vivid Orange Yellow 38 | "#5128B3", # Strong Purplish Red 39 | "#00C8F4", # Vivid Greenish Yellow 40 | "#0D187F", # Strong Reddish Brown 41 | "#00AA93", # Vivid Yellowish Green 42 | "#153359", # Deep Yellowish Brown 43 | "#133AF1", # Vivid Reddish Orange 44 | "#162C23", # Dark Olive Green 45 | 46 | # extend colour map 47 | "#00B3FF", # Vivid Yellow 48 | "#753E80", # Strong Purple 49 | "#0068FF", # Vivid Orange 50 | "#D7BDA6", # Very Light Blue 51 | "#2000C1", # Vivid Red 52 | "#62A2CE", # Grayish Yellow 53 | "#667081", # Medium Gray 54 | ] 55 | 56 | 57 | class BBOXPlotter(Extension): 58 | 59 | def __init__(self, image, out_dir, out_size, loss_metrics, **kwargs): 60 | super(BBOXPlotter, self).__init__() 61 | self.image = image 62 | self.render_extracted_rois = kwargs.pop("render_extracted_rois", True) 63 | self.image_size = Size(height=image.shape[1], width=image.shape[2]) 64 | self.out_dir = out_dir 65 | os.makedirs(self.out_dir, exist_ok=True) 66 | self.out_size = out_size 67 | self.colours = COLOR_MAP 68 | self.send_bboxes = kwargs.pop("send_bboxes", False) 69 | self.upstream_ip = kwargs.pop("upstream_ip", '127.0.0.1') 70 | self.upstream_port = kwargs.pop("upstream_port", 1337) 71 | self.loss_metrics = loss_metrics 72 | self.font = ImageFont.truetype("utils/DejaVuSans.ttf", 20) 73 | self.visualization_anchors = kwargs.pop("visualization_anchors", []) 74 | self.visual_backprop = VisualBackprop() 75 | self.xp = np 76 | 77 | def send_image(self, data): 78 | height = data.height 79 | width = data.width 80 | channels = len(data.getbands()) 81 | 82 | # convert image to png in order to save network bandwidth 83 | png_stream = BytesIO() 84 | data.save(png_stream, 
format="PNG") 85 | png_stream = png_stream.getvalue() 86 | 87 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 88 | try: 89 | sock.connect((self.upstream_ip, self.upstream_port)) 90 | except Exception as e: 91 | print(e) 92 | print("could not connect to display server, disabling image rendering") 93 | self.send_bboxes = False 94 | return 95 | data = { 96 | 'width': width, 97 | 'height': height, 98 | 'channels': channels, 99 | 'image': base64.b64encode(png_stream).decode('utf-8'), 100 | } 101 | sock.send(bytes(json.dumps(data), 'utf-8')) 102 | 103 | def array_to_image(self, array): 104 | if array.shape[0] == 1: 105 | # image is black and white, we need to trick the system into thinking, that we are having an RGB image 106 | array = self.xp.tile(array, (3, 1, 1)) 107 | return Image.fromarray(cuda.to_cpu(array.transpose(1, 2, 0) * 255).astype(np.uint8), "RGB").convert("RGBA") 108 | 109 | def variable_to_image(self, variable): 110 | return self.array_to_image(variable.data) 111 | 112 | def __call__(self, trainer): 113 | iteration = trainer.updater.iteration 114 | 115 | with cuda.get_device_from_id(trainer.updater.get_optimizer('main').target._device_id), chainer.using_config('train', False): 116 | self.xp = np if trainer.updater.get_optimizer('main').target._device_id < 0 else cuda.cupy 117 | image = self.xp.asarray(self.image) 118 | predictor = trainer.updater.get_optimizer('main').target.predictor 119 | predictions, rois, bboxes = predictor(image[self.xp.newaxis, ...]) 120 | 121 | backprop_visualizations = [] 122 | for visanchor in self.visualization_anchors: 123 | vis_target = predictor 124 | for target in visanchor: 125 | vis_target = getattr(vis_target, target) 126 | backprop_visualizations.append(self.visual_backprop.perform_visual_backprop(vis_target)) 127 | 128 | self.render_rois(predictions, rois, bboxes, iteration, self.image.copy(), backprop_vis=backprop_visualizations) 129 | 130 | @property 131 | def original_image_paste_location(self): 132 | return 0, 0 133 | 134 | def render_rois(self, predictions, rois, bboxes, iteration, image, backprop_vis=()): 135 | # get the predicted text 136 | text = self.decode_predictions(predictions) 137 | 138 | image = self.array_to_image(image) 139 | 140 | num_timesteps = self.get_num_timesteps(bboxes) 141 | bboxes, dest_image = self.set_output_sizes(backprop_vis, bboxes, image, num_timesteps) 142 | if self.render_extracted_rois: 143 | self.render_extracted_regions(dest_image, image, rois, num_timesteps) 144 | 145 | if len(backprop_vis) != 0: 146 | # if we have a backprop visualization we can show it now 147 | self.show_backprop_vis(backprop_vis, dest_image, image, num_timesteps) 148 | 149 | self.draw_bboxes(bboxes, image) 150 | dest_image.paste(image, self.original_image_paste_location) 151 | if len(text) > 0: 152 | dest_image = self.render_text(dest_image, text) 153 | dest_image.save("{}.png".format(os.path.join(self.out_dir, str(iteration))), 'png') 154 | if self.send_bboxes: 155 | self.send_image(dest_image) 156 | 157 | def get_num_timesteps(self, bboxes): 158 | return bboxes.shape[0] 159 | 160 | def set_output_sizes(self, backprop_vis, bboxes, image, num_timesteps): 161 | _, num_channels, height, width = bboxes.shape 162 | 163 | image_height = image.height if len(backprop_vis) == 0 else image.height + self.image_size.height 164 | image_width = image.width + image.width * num_timesteps if self.render_extracted_rois else image.width 165 | 166 | dest_image = Image.new("RGBA", (image_width, image_height), color='black') 167 | bboxes 
= F.reshape(bboxes, (num_timesteps, 1, num_channels, height, width)) 168 | 169 | return bboxes, dest_image 170 | 171 | def show_backprop_vis(self, backprop_vis, dest_image, image, num_timesteps): 172 | count = 0 173 | for visualization in backprop_vis: 174 | for vis in visualization: 175 | backprop_image = self.array_to_image(self.xp.tile(vis[0], (3, 1, 1))).resize( 176 | (self.image_size.width, self.image_size.height)) 177 | dest_image.paste(backprop_image, (count * backprop_image.width, image.height)) 178 | count += 1 179 | 180 | def decode_predictions(self, predictions): 181 | words = [] 182 | for prediction in predictions: 183 | if isinstance(prediction, list): 184 | prediction = F.concat([F.expand_dims(p, axis=0) for p in prediction], axis=0) 185 | 186 | prediction = self.xp.transpose(prediction.data, (1, 0, 2)) 187 | prediction = self.xp.squeeze(prediction, axis=0) 188 | prediction = self.xp.argmax(prediction, axis=1) 189 | word = self.loss_metrics.strip_prediction(prediction[self.xp.newaxis, ...])[0] 190 | if len(word) == 1 and word[0] == 0: 191 | continue 192 | word = "".join(map(self.loss_metrics.label_to_char, word)) 193 | word = word.replace(chr(self.loss_metrics.char_map[str(self.loss_metrics.blank_symbol)]), '') 194 | if len(word) > 0: 195 | words.append(word) 196 | text = " ".join(words) 197 | return text 198 | 199 | def render_extracted_regions(self, dest_image, image, rois, num_timesteps): 200 | _, num_channels, height, width = rois.shape 201 | rois = self.xp.reshape(rois, (num_timesteps, -1, num_channels, height, width)) 202 | 203 | for i, roi in enumerate(rois, start=1): 204 | roi_image = self.variable_to_image(roi[0]) 205 | paste_location = i * image.width, 0 206 | dest_image.paste(roi_image.resize((self.image_size.width, self.image_size.height)), paste_location) 207 | 208 | def render_text(self, dest_image, text): 209 | label_image = Image.new(dest_image.mode, dest_image.size) 210 | # only keep ascii characters 211 | # labels = ''.join(filter(lambda x: len(x) == len(x.encode()), labels)) 212 | draw = ImageDraw.Draw(label_image) 213 | text_width, text_height = draw.textsize(text, font=self.font) 214 | draw.rectangle([dest_image.width - text_width - 1, 0, dest_image.width, text_height], 215 | fill=(255, 255, 255, 160)) 216 | draw.text((dest_image.width - text_width - 1, 0), text, fill='green', font=self.font) 217 | dest_image = Image.alpha_composite(dest_image, label_image) 218 | return dest_image 219 | 220 | def draw_bboxes(self, bboxes, image): 221 | draw = ImageDraw.Draw(image) 222 | for i, sub_box in enumerate(F.separate(bboxes, axis=1)): 223 | for bbox, colour in zip(F.separate(sub_box, axis=0), self.colours): 224 | bbox.data[...] = (bbox.data[...] 
+ 1) / 2 225 | bbox.data[0, :] *= self.image_size.width 226 | bbox.data[1, :] *= self.image_size.height 227 | 228 | x = self.xp.clip(bbox.data[0, :].reshape(self.out_size), 0, self.image_size.width) + i * self.image_size.width 229 | y = self.xp.clip(bbox.data[1, :].reshape(self.out_size), 0, self.image_size.height) 230 | 231 | top_left = (x[0, 0], y[0, 0]) 232 | top_right = (x[0, -1], y[0, -1]) 233 | bottom_left = (x[-1, 0], y[-1, 0]) 234 | bottom_right = (x[-1, -1], y[-1, -1]) 235 | 236 | corners = [top_left, top_right, bottom_right, bottom_left] 237 | next_corners = corners[1:] + [corners[0]] 238 | 239 | for first_corner, next_corner in zip(corners, next_corners): 240 | draw.line([first_corner, next_corner], fill=colour, width=3) 241 | -------------------------------------------------------------------------------- /chainer/insights/fsns_bbox_plotter.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | from PIL import Image 3 | 4 | from insights.bbox_plotter import BBOXPlotter 5 | from utils.datatypes import Size 6 | 7 | 8 | class FSNSBBOXPlotter(BBOXPlotter): 9 | 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.image_size = Size(height=self.image.shape[1], width=self.image.shape[2] // 4) 13 | 14 | @property 15 | def original_image_paste_location(self): 16 | return self.image_size.width, 0 17 | 18 | def get_num_timesteps(self, bboxes): 19 | return bboxes.shape[0] // 4 20 | 21 | def set_output_sizes(self, backprop_vis, bboxes, image, num_timesteps): 22 | _, num_channels, height, width = bboxes.shape 23 | image_height = image.height 24 | if len(backprop_vis) > 0: 25 | image_height = image.height + self.image_size.height 26 | if self.render_extracted_rois: 27 | image_height = image.height + self.image_size.height * (1 + num_timesteps) 28 | 29 | dest_image = Image.new("RGBA", (image.width + self.image_size.width, image_height), color='black') 30 | bboxes = F.reshape(bboxes, (num_timesteps, 4, num_channels, height, self.out_size.width)) 31 | 32 | return bboxes, dest_image 33 | 34 | def show_backprop_vis(self, backprop_vis, dest_image, image, num_timesteps): 35 | # first render localization visualization 36 | for j, vis in enumerate(backprop_vis[0]): 37 | backprop_image = self.array_to_image(self.xp.tile(vis, (3, 1, 1))) 38 | dest_image.paste(backprop_image, ((j + 1) * self.image_size.width, image.height)) 39 | # second render recognition visualization 40 | _, num_channels, height, width = backprop_vis[1].shape 41 | recognition_vis = self.xp.reshape(backprop_vis[1], (num_timesteps, -1, num_channels, height, width)) 42 | for i in range(len(recognition_vis)): 43 | for j, vis in enumerate(recognition_vis[i]): 44 | backprop_image = self.array_to_image(self.xp.tile(vis, (3, 1, 1)))\ 45 | .resize((self.image_size.width, self.image_size.height)) 46 | dest_image.paste(backprop_image, ((j + 1) * self.image_size.width, (i + 2) * image.height)) 47 | 48 | def render_extracted_regions(self, dest_image, image, rois, num_timesteps): 49 | _, num_channels, height, width = rois.shape 50 | rois = self.xp.reshape(rois, (num_timesteps, -1, num_channels, height, width)) 51 | 52 | for i, roi in enumerate(rois, start=1): 53 | roi_image = self.variable_to_image(roi[0]) 54 | paste_location = 0, (i + 1) * image.height 55 | dest_image.paste(roi_image.resize((self.image_size.width, self.image_size.height)), paste_location) 56 | -------------------------------------------------------------------------------- 
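Usage note: BBOXPlotter and its subclasses (such as FSNSBBOXPlotter above) are ordinary Chainer trainer extensions; they receive an example image, an output directory, the size of the localization grid and a loss-metrics object, and render the predicted bounding boxes whenever the trainer triggers them. The following minimal sketch only illustrates how such an extension could be registered with a trainer; the names train_image, metrics and trainer, the output directory and the trigger interval are assumptions made for illustration and are not taken from the repository's training scripts (which are not part of this excerpt).

import numpy as np

from insights.bbox_plotter import BBOXPlotter
from utils.datatypes import Size

# placeholder example image (channels x height x width, float32 in [0, 1]);
# in a real training script this would be one sample taken from the dataset
train_image = np.zeros((3, 150, 600), dtype=np.float32)

# assumptions: 'metrics' is an already constructed LossMetrics instance and
# 'trainer' an already constructed chainer.training.Trainer
bbox_plotter = BBOXPlotter(
    train_image,                  # image that is re-rendered with bboxes on every call
    'logs/bboxes',                # directory where the rendered PNG files are written
    Size(height=75, width=100),   # shape used to reshape the predicted sampling grids
    metrics,                      # used to decode the network predictions into text
    send_bboxes=False,            # set to True to stream renderings to a display server
)
trainer.extend(bbox_plotter, trigger=(100, 'iteration'))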
/chainer/insights/lstm_per_step_plotter.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | 3 | from insights.bbox_plotter import BBOXPlotter 4 | 5 | 6 | class LSTMPerStepBBOXPlotter(BBOXPlotter): 7 | 8 | def decode_predictions(self, predictions, xp): 9 | predictions = F.concat(predictions, axis=0) 10 | predictions = xp.argmax(predictions.data, axis=1) 11 | word = self.loss_metrics.strip_prediction(predictions[xp.newaxis, ...])[0] 12 | if len(word) == 1 and word[0] == 0: 13 | return "" 14 | word = "".join(map(self.loss_metrics.label_to_char, word)) 15 | word = word.replace(chr(self.loss_metrics.char_map[str(self.loss_metrics.blank_symbol)]), '') 16 | return word 17 | -------------------------------------------------------------------------------- /chainer/insights/svhn_bbox_plotter.py: -------------------------------------------------------------------------------- 1 | from chainer import cuda 2 | 3 | import chainer.functions as F 4 | 5 | from insights.bbox_plotter import BBOXPlotter 6 | 7 | 8 | class SVHNBBoxPlotter(BBOXPlotter): 9 | 10 | def decode_predictions(self, predictions): 11 | # concat all individual predictions and slice for each time step 12 | predictions = F.concat([F.expand_dims(p, axis=0) for p in predictions], axis=0) 13 | 14 | words = [] 15 | with cuda.get_device_from_array(predictions.data): 16 | for prediction in F.separate(predictions, axis=0): 17 | prediction = F.squeeze(prediction, axis=0) 18 | prediction = F.softmax(prediction, axis=1) 19 | prediction = self.xp.argmax(prediction.data, axis=1) 20 | word = self.loss_metrics.strip_prediction(prediction[self.xp.newaxis, ...])[0] 21 | if len(word) == 1 and word[0] == 0: 22 | return '' 23 | 24 | word = "".join(map(self.loss_metrics.label_to_char, word)) 25 | word = word.replace(chr(self.loss_metrics.char_map[str(self.loss_metrics.blank_symbol)]), '') 26 | words.append(word) 27 | 28 | text = " ".join(words) 29 | return text 30 | -------------------------------------------------------------------------------- /chainer/insights/text_rec_bbox_plotter.py: -------------------------------------------------------------------------------- 1 | from chainer import cuda 2 | 3 | import chainer.functions as F 4 | from PIL import Image, ImageDraw 5 | 6 | from insights.bbox_plotter import BBOXPlotter 7 | 8 | 9 | class TextRecBBOXPlotter(BBOXPlotter): 10 | 11 | def __init__(self, *args, **kwargs): 12 | self.render_intermediate_bboxes = kwargs.pop('render_intermediate_bboxes', False) 13 | super().__init__(*args, **kwargs) 14 | 15 | def get_num_timesteps(self, bboxes): 16 | return bboxes[-1].shape[0] 17 | 18 | def set_output_sizes(self, backprop_vis, bboxes, image, num_timesteps): 19 | _, num_channels, height, width = bboxes[-1].shape 20 | 21 | image_height = image.height if len(backprop_vis) == 0 else image.height + self.image_size.height 22 | image_width = image.width + image.width * num_timesteps if self.render_extracted_rois else image.width 23 | 24 | dest_image = Image.new("RGBA", (image_width, image_height), color='black') 25 | bboxes = F.concat([F.reshape(bbox, (num_timesteps, 1, num_channels, height, width)) for bbox in bboxes], axis=1) 26 | 27 | return bboxes, dest_image 28 | 29 | def render_extracted_regions(self, dest_image, image, rois, num_timesteps): 30 | rois = rois[-1] 31 | _, num_channels, height, width = rois.shape 32 | rois = self.xp.reshape(rois, (num_timesteps, -1, num_channels, height, width)) 33 | 34 | for i, roi in enumerate(rois, start=1): 35 | 
roi_image = self.variable_to_image(roi[0]) 36 | paste_location = i * image.width, 0 37 | dest_image.paste(roi_image.resize((self.image_size.width, self.image_size.height)), paste_location) 38 | 39 | def decode_predictions(self, predictions): 40 | # concat all individual predictions and slice for each time step 41 | predictions = F.concat([F.expand_dims(prediction, axis=0) for prediction in predictions], axis=0) 42 | 43 | with cuda.get_device_from_array(predictions.data): 44 | prediction = F.squeeze(predictions, axis=1) 45 | classification = F.softmax(prediction, axis=1) 46 | classification = classification.data 47 | classification = self.xp.argmax(classification, axis=1) 48 | 49 | words = self.loss_metrics.strip_prediction(classification[self.xp.newaxis, ...])[0] 50 | word = "".join(map(self.loss_metrics.label_to_char, words)) 51 | 52 | return word 53 | 54 | def draw_bboxes(self, bboxes, image): 55 | draw = ImageDraw.Draw(image) 56 | for boxes, colour in zip(F.separate(bboxes, axis=0), self.colours): 57 | num_boxes = boxes.shape[0] 58 | 59 | for i, bbox in enumerate(F.separate(boxes, axis=0)): 60 | # render all intermediate results with lower alpha as the others 61 | fill_colour = colour 62 | if i < num_boxes - 1: 63 | if not self.render_intermediate_bboxes: 64 | continue 65 | fill_colour += '88' 66 | 67 | bbox.data[...] = (bbox.data[...] + 1) / 2 68 | bbox.data[0, :] *= self.image_size.width 69 | bbox.data[1, :] *= self.image_size.height 70 | 71 | x = self.xp.clip(bbox.data[0, :].reshape(self.out_size), 0, self.image_size.width) 72 | y = self.xp.clip(bbox.data[1, :].reshape(self.out_size), 0, self.image_size.height) 73 | 74 | top_left = (x[0, 0], y[0, 0]) 75 | top_right = (x[0, -1], y[0, -1]) 76 | bottom_left = (x[-1, 0], y[-1, 0]) 77 | bottom_right = (x[-1, -1], y[-1, -1]) 78 | 79 | corners = [top_left, top_right, bottom_right, bottom_left] 80 | next_corners = corners[1:] + [corners[0]] 81 | 82 | for first_corner, next_corner in zip(corners, next_corners): 83 | draw.line([first_corner, next_corner], fill=fill_colour, width=3) 84 | -------------------------------------------------------------------------------- /chainer/insights/textrec_bbox_plotter.py: -------------------------------------------------------------------------------- 1 | from chainer import cuda 2 | 3 | import chainer.functions as F 4 | 5 | from insights.bbox_plotter import BBOXPlotter 6 | 7 | 8 | class TextRectBBoxPlotter(BBOXPlotter): 9 | 10 | def decode_predictions(self, predictions): 11 | # concat all individual predictions and slice for each time step 12 | predictions = predictions[0] 13 | 14 | with cuda.get_device_from_array(predictions.data): 15 | prediction = F.squeeze(predictions, axis=1) 16 | classification = F.softmax(prediction, axis=1) 17 | classification = classification.data 18 | classification = self.xp.argmax(classification, axis=1) 19 | 20 | words = self.loss_metrics.strip_prediction(classification[self.xp.newaxis, ...])[0] 21 | word = "".join(map(self.loss_metrics.label_to_char, words)) 22 | 23 | return word 24 | -------------------------------------------------------------------------------- /chainer/insights/visual_backprop.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer import cuda 3 | from chainer.functions.connection.convolution_2d import Convolution2DFunction 4 | from chainer.functions.pooling.pooling_2d import Pooling2D 5 | 6 | import chainer.functions as F 7 | from chainer.training import Extension 8 | 9 | 10 | class 
VisualBackprop(Extension): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.xp = None 15 | 16 | def traverse_computational_graph(self, node, feature_map): 17 | if node.inputs[0].creator is None: 18 | return feature_map 19 | 20 | if isinstance(node, Convolution2DFunction) or isinstance(node, Pooling2D): 21 | feature_map = self.scale_layer(feature_map, node) 22 | 23 | return self.traverse_computational_graph(node.inputs[0].creator, feature_map) 24 | 25 | def scale_layer(self, feature_map, node): 26 | input_data = node.inputs[0].data 27 | _, _, in_height, in_width = input_data.shape 28 | _, _, feature_height, feature_width = feature_map.shape 29 | kernel_height = in_height + 2 * node.ph - node.sy * (feature_height - 1) 30 | kernel_width = in_width + 2 * node.pw - node.sx * (feature_width - 1) 31 | scaled_feature = F.deconvolution_2d( 32 | feature_map, 33 | self.xp.ones((1, 1, kernel_height, kernel_width)), 34 | stride=(node.sy, node.sx), 35 | pad=(node.ph, node.pw), 36 | outsize=(in_height, in_width), 37 | ) 38 | averaged_feature_map = F.average(input_data, axis=1, keepdims=True) 39 | feature_map = scaled_feature * averaged_feature_map 40 | return feature_map 41 | 42 | def perform_visual_backprop(self, variable): 43 | with chainer.no_backprop_mode(), chainer.cuda.get_device_from_array(variable.data): 44 | self.xp = cuda.get_array_module(variable) 45 | averaged_feature = F.average(variable, axis=1, keepdims=True) 46 | 47 | visualization = self.traverse_computational_graph(variable.creator, averaged_feature) 48 | visualization = visualization.data 49 | for i in range(len(visualization)): 50 | min_val = visualization[i].min() 51 | max_val = visualization[i].max() 52 | visualization[i] -= min_val 53 | visualization[i] *= 1.0 / (max_val - min_val) 54 | return visualization 55 | -------------------------------------------------------------------------------- /chainer/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/metrics/__init__.py -------------------------------------------------------------------------------- /chainer/metrics/ctc_metrics.py: -------------------------------------------------------------------------------- 1 | from chainer import cuda 2 | 3 | import chainer.functions as F 4 | 5 | from metrics.loss_metrics import LossMetrics 6 | 7 | 8 | class CTCMetrics(LossMetrics): 9 | 10 | def calc_actual_loss(self, predictions, grid, labels): 11 | loss = F.connectionist_temporal_classification(predictions, labels, self.blank_symbol) 12 | return loss 13 | 14 | def strip_prediction(self, predictions): 15 | # TODO Parallelize 16 | words = [] 17 | for prediction in predictions: 18 | blank_symbol_seen = False 19 | stripped_prediction = self.xp.full((1,), prediction[0], dtype=self.xp.int32) 20 | for char in prediction: 21 | if char == self.blank_symbol: 22 | blank_symbol_seen = True 23 | continue 24 | if char == stripped_prediction[-1] and not blank_symbol_seen: 25 | continue 26 | blank_symbol_seen = False 27 | stripped_prediction = self.xp.hstack((stripped_prediction, char.reshape(1,))) 28 | words.append(stripped_prediction) 29 | return words 30 | -------------------------------------------------------------------------------- /chainer/metrics/loss_metrics.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | 4 | import chainer 5 | from chainer import cuda 6 | import 
numpy 7 | 8 | import chainer.functions as F 9 | 10 | 11 | class LossMetrics: 12 | def __init__(self, blank_symbol, char_map, timesteps, image_size, area_loss_factor=0, aspect_ratio_loss_factor=0, uses_original_data=False, 13 | area_scaling_factor=2): 14 | self.aspect_ratio_loss_factor = aspect_ratio_loss_factor 15 | self.blank_symbol = blank_symbol 16 | self.xp = None 17 | with open(char_map, 'r') as the_char_map: 18 | self.char_map = json.load(the_char_map) 19 | self.image_size = image_size 20 | self.num_timesteps = timesteps 21 | self.base_area_loss_factor = area_loss_factor 22 | self.area_scaling_factor = area_scaling_factor 23 | self.uses_original_data = uses_original_data 24 | self.area_loss_factor = self.base_area_loss_factor 25 | 26 | def get_label_lengths(self, labels): 27 | if self.xp == numpy: 28 | label_lengths = self.xp.zeros(len(labels)) 29 | 30 | for i in range(len(labels)): 31 | for j in range(len(labels[i])): 32 | if labels.data[i][j] == self.blank_symbol: 33 | label_lengths[i] = j 34 | break 35 | else: 36 | import cupy 37 | label_length_kernel = cupy.ElementwiseKernel( 38 | 'raw T labels, int32 blank_symbol, int32 num_labels', 39 | 'T length', 40 | ''' 41 | for (int j = 0; j < num_labels; ++j) { 42 | T label_value = labels[i * num_labels + j]; 43 | if (label_value == blank_symbol) { 44 | length = j; 45 | break; 46 | } 47 | } 48 | ''', 49 | 'get_label_lengths' 50 | ) 51 | label_lengths = label_length_kernel(labels.data, self.blank_symbol, labels.shape[1], size=len(labels)) 52 | return label_lengths 53 | 54 | def strip_prediction(self, predictions): 55 | # TODO Parallelize 56 | words = [] 57 | for prediction in predictions: 58 | stripped_prediction = self.xp.empty((0,), dtype=self.xp.int32) 59 | for char in prediction: 60 | if char == self.blank_symbol: 61 | continue 62 | stripped_prediction = self.xp.hstack((stripped_prediction, char.reshape(1,))) 63 | words.append(stripped_prediction) 64 | return words 65 | 66 | def get_bbox_side_lengths(self, grids): 67 | x0, x1, x2, y0, y1, y2 = self.get_corners(grids) 68 | 69 | width = F.sqrt( 70 | F.square(x1 - x0) + F.square(y1 - y0) 71 | ) 72 | 73 | height = F.sqrt( 74 | F.square(x2 - x0) + F.square(y2 - y0) 75 | ) 76 | return width, height 77 | 78 | def get_corners(self, grids): 79 | _, _, height, width = grids.shape 80 | grids = (grids + 1) / 2 81 | x_points = grids[:, 0, ...] * self.image_size.width 82 | y_points = grids[:, 1, ...] 
* self.image_size.height 83 | top_left_x = F.get_item(x_points, [..., 0, 0]) 84 | top_left_y = F.get_item(y_points, [..., 0, 0]) 85 | top_right_x = F.get_item(x_points, [..., 0, width - 1]) 86 | top_right_y = F.get_item(y_points, [..., 0, width - 1]) 87 | bottom_left_x = F.get_item(x_points, [..., height - 1, 0]) 88 | bottom_left_y = F.get_item(y_points, [..., height - 1, 0]) 89 | return top_left_x, top_right_x, bottom_left_x, top_left_y, top_right_y, bottom_left_y 90 | 91 | def calc_direction_loss(self, grids): 92 | top_left_x, top_right_x, _, top_left_y, _, bottom_left_y = self.get_corners(grids) 93 | 94 | # penalize upside down images 95 | distance = top_left_y - bottom_left_y 96 | loss_values = F.maximum(distance, self.xp.zeros_like(distance)) 97 | up_down_loss = F.average(loss_values) 98 | 99 | # penalize images that are vertically mirrored 100 | distance = top_left_x - top_right_x 101 | loss_values = F.maximum(distance, self.xp.zeros_like(distance)) 102 | left_right_loss = F.average(loss_values) 103 | 104 | return up_down_loss + left_right_loss 105 | 106 | def calc_height_loss(self, height): 107 | # penalize bboxes that are not high enough to contain text (10 pixels) 108 | shifted_height = height - 10 109 | thresholded_height = F.minimum(shifted_height, self.xp.zeros_like(shifted_height)) 110 | thresholded_height *= -1 111 | 112 | return F.average(thresholded_height) 113 | 114 | def calc_area_loss(self, width, height): 115 | loc_area = width * height 116 | loc_ratio = loc_area / (self.image_size.width * self.image_size.height) 117 | return sum(loc_ratio) / max(len(loc_ratio), 1) 118 | 119 | def calc_overlap(self, left_1, width_1, left_2, width_2): 120 | radius_1 = width_1 / 2 121 | center_1 = left_1 + radius_1 122 | radius_2 = width_2 / 2 123 | center_2 = left_2 + radius_2 124 | 125 | center_distance = center_2 - center_1 126 | center_distance = F.maximum(center_distance, center_distance * -1) 127 | min_distance_for_no_overlap = radius_1 + radius_2 128 | return min_distance_for_no_overlap - center_distance 129 | 130 | def calc_intersection(self, top_left_x_1, width_1, top_left_x_2, width_2, top_left_y_1, height_1, top_left_y_2, height_2): 131 | width_overlap = self.calc_overlap( 132 | top_left_x_1, 133 | width_1, 134 | top_left_x_2, 135 | width_2 136 | ) 137 | 138 | height_overlap = self.calc_overlap( 139 | top_left_y_1, 140 | height_1, 141 | top_left_y_2, 142 | height_2 143 | ) 144 | 145 | width_overlap = F.maximum(width_overlap, self.xp.zeros_like(width_overlap)) 146 | height_overlap = F.maximum(height_overlap, self.xp.zeros_like(height_overlap)) 147 | 148 | return width_overlap * height_overlap 149 | 150 | def calc_iou_loss(self, grids1, grids2): 151 | top_left_x_1, top_right_x_1, _, top_left_y_1, _, bottom_left_y_1 = self.get_corners(grids1) 152 | top_left_x_2, top_right_x_2, _, top_left_y_2, _, bottom_left_y_2 = self.get_corners(grids2) 153 | 154 | width_1 = top_right_x_1 - top_left_x_1 155 | width_2 = top_right_x_2 - top_left_x_2 156 | height_1 = bottom_left_y_1 - top_left_y_1 157 | height_2 = bottom_left_y_2 - top_left_y_2 158 | intersection = self.calc_intersection(top_left_x_1, width_1, top_left_x_2, width_2, top_left_y_1, height_1, top_left_y_2, height_2) 159 | union = width_1 * height_1 + width_2 * height_2 - intersection 160 | iou = intersection / F.maximum(union, self.xp.ones_like(union)) 161 | 162 | return sum(iou) / len(iou) 163 | 164 | def calc_aspect_ratio_loss(self, width, height, label_lengths=None): 165 | # penalize aspect ratios that are higher than wide, and 
penalize aspect ratios that are tooo wide 166 | aspect_ratio = height / F.maximum(width, self.xp.ones_like(width)) 167 | # do not give an incentive to bboxes with a width that is 2x the height of the box 168 | aspect_loss = F.maximum(aspect_ratio - 0.5, self.xp.zeros_like(aspect_ratio)) 169 | 170 | # penalize very long bboxes (based on the underlying word), by assuming that a single letter 171 | # has a max width of its height, if the width of the bbox is too large it will be penalized 172 | if label_lengths is not None: 173 | max_width = label_lengths * height 174 | width_ratio = width - max_width 175 | width_threshold = F.maximum(width_ratio, self.xp.zeros_like(width_ratio)) 176 | aspect_loss = aspect_ratio + width_threshold 177 | 178 | return sum(aspect_loss) / len(aspect_loss) 179 | 180 | def label_to_char(self, label): 181 | return chr(self.char_map[str(label)]) 182 | 183 | def calc_loss(self, x, t): 184 | batch_predictions, _, grids = x 185 | self.xp = cuda.get_array_module(batch_predictions[0], t) 186 | 187 | # reshape labels 188 | batch_size = t.shape[0] 189 | t = F.reshape(t, (batch_size, self.num_timesteps, -1)) 190 | 191 | # reshape grids 192 | grid_shape = grids.shape 193 | if self.uses_original_data: 194 | grids = F.reshape(grids, (self.num_timesteps, batch_size, 4,) + grid_shape[1:]) 195 | else: 196 | grids = F.reshape(grids, (self.num_timesteps, batch_size, 1,) + grid_shape[1:]) 197 | losses = [] 198 | 199 | # with cuda.get_device_from_array(grids.data): 200 | # grid_list = F.separate(F.reshape(grids, (self.num_timesteps, -1,) + grids.shape[3:]), axis=0) 201 | # overlap_losses = [] 202 | # for grid_1, grid_2 in itertools.combinations(grid_list, 2): 203 | # overlap_losses.append(self.calc_iou_loss(grid_1, grid_2)) 204 | # losses.append(sum(overlap_losses) / max(len(overlap_losses), 1)) 205 | 206 | loss_weights = [1, 1.25, 2, 1.25] 207 | for i, (predictions, grid, labels) in enumerate(zip(batch_predictions, F.separate(grids, axis=0), F.separate(t, axis=1)), start=1): 208 | with cuda.get_device_from_array(getattr(predictions, 'data', predictions[0].data)): 209 | # adapt ctc weight depending on current prediction position and labels 210 | # if all labels are blank, we want this weight to be full weight! 
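                # note (descriptive comment): the weight chosen below depends only on the timestep index i, via loss_weights above;
                # it is not derived from the label content referred to in the preceding comment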
211 | overall_loss_weight = loss_weights[i - 1] 212 | loss = self.calc_actual_loss(predictions, grid, labels) 213 | # label_lengths = self.get_label_lengths(labels) 214 | 215 | for sub_grid in F.separate(grid, axis=1): 216 | width, height = self.get_bbox_side_lengths(sub_grid) 217 | loss += self.area_loss_factor * self.calc_area_loss(width, height) 218 | loss += self.aspect_ratio_loss_factor * self.calc_aspect_ratio_loss(width, height) 219 | loss += self.calc_direction_loss(sub_grid) 220 | loss += self.calc_height_loss(height) 221 | loss *= overall_loss_weight 222 | losses.append(loss) 223 | 224 | return sum(losses) / len(losses) 225 | 226 | def calc_actual_loss(self, predictions, grid, labels): 227 | raise NotImplementedError 228 | 229 | def scale_area_loss_factor(self, accuracy): 230 | self.area_loss_factor = self.base_area_loss_factor + self.area_scaling_factor * accuracy 231 | 232 | def calc_accuracy(self, x, t): 233 | batch_predictions, _, _ = x 234 | self.xp = cuda.get_array_module(batch_predictions[0], t) 235 | batch_size = t.shape[0] 236 | t = F.reshape(t, (batch_size, self.num_timesteps, -1)) 237 | accuracies = [] 238 | 239 | for predictions, labels in zip(batch_predictions, F.separate(t, axis=1)): 240 | if isinstance(predictions, list): 241 | predictions = F.concat([F.expand_dims(p, axis=0) for p in predictions], axis=0) 242 | with cuda.get_device_from_array(predictions.data): 243 | 244 | classification = F.softmax(predictions, axis=2) 245 | classification = classification.data 246 | classification = self.xp.argmax(classification, axis=2) 247 | classification = self.xp.transpose(classification, (1, 0)) 248 | 249 | words = self.strip_prediction(classification) 250 | labels = self.strip_prediction(labels.data) 251 | 252 | num_correct_words = 0 253 | for word, label in zip(words, labels): 254 | word = "".join(map(self.label_to_char, word)) 255 | label = "".join(map(self.label_to_char, label)) 256 | if word == label: 257 | num_correct_words += 1 258 | 259 | accuracy = num_correct_words / len(labels) 260 | accuracies.append(accuracy) 261 | 262 | overall_accuracy = sum(accuracies) / max(len(accuracies), 1) 263 | self.scale_area_loss_factor(overall_accuracy) 264 | return overall_accuracy 265 | -------------------------------------------------------------------------------- /chainer/metrics/lstm_per_step_metrics.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import chainer.functions as F 4 | import chainer.links as L 5 | 6 | from chainer import cuda 7 | 8 | from metrics.loss_metrics import LossMetrics 9 | 10 | 11 | class PerStepLSTMMetric(LossMetrics): 12 | 13 | def calc_actual_loss(self, predictions, grid, labels): 14 | pass 15 | 16 | def calc_loss(self, x, t): 17 | batch_predictions, _, grids = x 18 | self.xp = cuda.get_array_module(batch_predictions[0], t) 19 | 20 | # reshape labels 21 | batch_size = t.shape[0] 22 | 23 | # reshape grids 24 | grid_shape = grids.shape 25 | if self.uses_original_data: 26 | grids = F.reshape(grids, (self.num_timesteps, batch_size, 4,) + grid_shape[1:]) 27 | else: 28 | grids = F.reshape(grids, (self.num_timesteps, batch_size, 1,) + grid_shape[1:]) 29 | recognition_losses = [] 30 | 31 | for prediction, label in zip(batch_predictions, F.separate(t, axis=1)): 32 | recognition_loss = F.softmax_cross_entropy(prediction, label) 33 | recognition_losses.append(recognition_loss) 34 | 35 | losses = [sum(recognition_losses) / len(recognition_losses)] 36 | 37 | # with 
cuda.get_device_from_array(grids.data): 38 | # grid_list = F.separate(F.reshape(grids, (self.timesteps, -1,) + grids.shape[3:]), axis=0) 39 | # overlap_losses = [] 40 | # for grid_1, grid_2 in itertools.combinations(grid_list, 2): 41 | # overlap_losses.append(self.calc_iou_loss(grid_1, grid_2)) 42 | # losses.append(sum(overlap_losses) / len(overlap_losses)) 43 | 44 | for i, grid in enumerate(F.separate(grids, axis=0), start=1): 45 | with cuda.get_device_from_array(grid.data): 46 | grid_losses = [] 47 | for sub_grid in F.separate(grid, axis=1): 48 | width, height = self.get_bbox_side_lengths(sub_grid) 49 | grid_losses.append(self.area_loss_factor * self.calc_area_loss(width, height)) 50 | grid_losses.append(self.aspect_ratio_loss_factor * self.calc_aspect_ratio_loss(width, height)) 51 | grid_losses.append(self.calc_direction_loss(sub_grid)) 52 | grid_losses.append(self.calc_height_loss(height)) 53 | losses.append(sum(grid_losses)) 54 | 55 | return sum(losses) / len(losses) 56 | 57 | def calc_accuracy(self, x, t): 58 | batch_predictions, _, _ = x 59 | self.xp = cuda.get_array_module(batch_predictions[0], t) 60 | accuracies = [] 61 | for prediction, label in zip(batch_predictions, F.separate(t, axis=1)): 62 | recognition_accuracy = F.accuracy(prediction, label) 63 | accuracies.append(recognition_accuracy) 64 | return sum(accuracies) / len(accuracies) 65 | -------------------------------------------------------------------------------- /chainer/metrics/softmax_metrics.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | 3 | from metrics.loss_metrics import LossMetrics 4 | 5 | 6 | class SoftmaxMetrics(LossMetrics): 7 | 8 | def calc_actual_loss(self, predictions, grid, labels): 9 | losses = [] 10 | for char_prediction, char_gt in zip(F.separate(predictions, axis=0), F.separate(labels, axis=1)): 11 | losses.append(F.softmax_cross_entropy(char_prediction, char_gt)) 12 | return sum(losses) 13 | -------------------------------------------------------------------------------- /chainer/metrics/svhn_ctc_metrics.py: -------------------------------------------------------------------------------- 1 | from chainer import cuda 2 | import chainer.functions as F 3 | from metrics.loss_metrics import LossMetrics 4 | 5 | 6 | class SVHNCTCMetrics(LossMetrics): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.pred_num_timesteps = 2 * self.num_timesteps + 1 11 | 12 | def calc_actual_loss(self, predictions, grid, labels): 13 | predictions = F.separate(predictions, axis=0) 14 | return F.connectionist_temporal_classification(predictions, labels, blank_symbol=self.blank_symbol) 15 | 16 | def strip_prediction(self, predictions): 17 | # TODO Parallelize 18 | words = [] 19 | for prediction in predictions: 20 | blank_symbol_seen = False 21 | stripped_prediction = self.xp.full((1,), prediction[0], dtype=self.xp.int32) 22 | for char in prediction: 23 | if char == self.blank_symbol: 24 | blank_symbol_seen = True 25 | continue 26 | if char == stripped_prediction[-1] and not blank_symbol_seen: 27 | continue 28 | blank_symbol_seen = False 29 | stripped_prediction = self.xp.hstack((stripped_prediction, char.reshape(1, ))) 30 | words.append(stripped_prediction) 31 | return words 32 | 33 | 34 | -------------------------------------------------------------------------------- /chainer/metrics/svhn_softmax_metrics.py: -------------------------------------------------------------------------------- 1 | from chainer import 
cuda 2 | import chainer.functions as F 3 | from metrics.softmax_metrics import SoftmaxMetrics 4 | 5 | 6 | class SVHNSoftmaxMetrics(SoftmaxMetrics): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | def calc_loss(self, x, t): 12 | batch_predictions, _, _ = x 13 | 14 | # concat all individual predictions and slice for each time step 15 | batch_predictions = F.concat([F.expand_dims(p, axis=0) for p in batch_predictions], axis=0) 16 | 17 | self.xp = cuda.get_array_module(batch_predictions[0], t) 18 | batch_size = t.shape[0] 19 | t = F.reshape(t, (batch_size, self.num_timesteps, -1)) 20 | 21 | losses = [] 22 | for predictions, labels in zip(F.separate(batch_predictions, axis=0), F.separate(t, axis=1)): 23 | batch_size, num_chars, num_classes = predictions.shape 24 | predictions = F.reshape(predictions, (batch_size * num_chars, num_classes)) 25 | labels = F.reshape(labels, (-1,)) 26 | losses.append(F.softmax_cross_entropy(predictions, labels)) 27 | 28 | return sum(losses) 29 | 30 | def calc_accuracy(self, x, t): 31 | batch_predictions, _, _ = x 32 | 33 | # concat all individual predictions and slice for each time step 34 | batch_predictions = F.concat([F.expand_dims(p, axis=0) for p in batch_predictions], axis=0) 35 | 36 | self.xp = cuda.get_array_module(batch_predictions[0], t) 37 | batch_size = t.shape[0] 38 | t = F.reshape(t, (batch_size, self.num_timesteps, -1)) 39 | 40 | accuracies = [] 41 | with cuda.get_device_from_array(batch_predictions.data): 42 | for prediction, label in zip(F.separate(batch_predictions, axis=0), F.separate(t, axis=1)): 43 | classification = F.softmax(prediction, axis=2) 44 | classification = classification.data 45 | classification = self.xp.argmax(classification, axis=2) 46 | # classification = self.xp.transpose(classification, (1, 0)) 47 | 48 | words = self.strip_prediction(classification) 49 | labels = self.strip_prediction(label.data) 50 | 51 | num_correct_words = 0 52 | for word, label in zip(words, labels): 53 | word = "".join(map(self.label_to_char, word)) 54 | label = "".join(map(self.label_to_char, label)) 55 | if word == label: 56 | num_correct_words += 1 57 | 58 | accuracy = num_correct_words / len(labels) 59 | accuracies.append(accuracy) 60 | 61 | overall_accuracy = sum(accuracies) / max(len(accuracies), 1) 62 | self.scale_area_loss_factor(overall_accuracy) 63 | return overall_accuracy 64 | -------------------------------------------------------------------------------- /chainer/metrics/textrec_metrics.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | 3 | from chainer import cuda 4 | 5 | from metrics.loss_metrics import LossMetrics 6 | 7 | 8 | class TextRectMetrics(LossMetrics): 9 | 10 | def calc_loss(self, x, t): 11 | batch_predictions, _, grids = x 12 | self.xp = cuda.get_array_module(batch_predictions, t) 13 | 14 | loss = self.calc_actual_loss(batch_predictions, None, t) 15 | 16 | # reshape grids 17 | batch_size = t.shape[0] 18 | grids = grids[-1] 19 | grid_shape = grids.shape 20 | grids = F.reshape(grids, (-1, batch_size) + grid_shape[1:]) 21 | 22 | grid_losses = [] 23 | for grid in F.separate(grids, axis=0): 24 | with cuda.get_device_from_array(getattr(grid, 'data', grid[0].data)): 25 | grid_losses.append(self.calc_direction_loss(grid)) 26 | 27 | return loss + (sum(grid_losses) / len(grid_losses)) 28 | 29 | def calc_accuracy(self, x, t): 30 | batch_predictions, _, _ = x 31 | batch_predictions = F.concat([F.expand_dims(prediction, axis=0) 
for prediction in batch_predictions], axis=0) 32 | 33 | self.xp = cuda.get_array_module(batch_predictions[0], t) 34 | accuracies = [] 35 | 36 | with cuda.get_device_from_array(batch_predictions.data): 37 | classification = F.softmax(batch_predictions, axis=2) 38 | classification = classification.data 39 | classification = self.xp.argmax(classification, axis=2) 40 | classification = self.xp.transpose(classification, (1, 0)) 41 | 42 | words = self.strip_prediction(classification) 43 | labels = self.strip_prediction(t) 44 | 45 | num_correct_words = 0 46 | for word, label in zip(words, labels): 47 | word = "".join(map(self.label_to_char, word)) 48 | label = "".join(map(self.label_to_char, label)) 49 | if word == label: 50 | num_correct_words += 1 51 | 52 | accuracy = num_correct_words / len(labels) 53 | accuracies.append(accuracy) 54 | 55 | overall_accuracy = sum(accuracies) / max(len(accuracies), 1) 56 | self.scale_area_loss_factor(overall_accuracy) 57 | return overall_accuracy 58 | 59 | 60 | class TextRecCTCMetrics(TextRectMetrics): 61 | 62 | def __init__(self, *args, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | self.pred_num_timesteps = 2 * self.num_timesteps + 1 65 | 66 | def calc_actual_loss(self, predictions, grid, labels): 67 | predictions = F.separate(predictions, axis=0) 68 | return F.connectionist_temporal_classification(predictions, labels, blank_symbol=self.blank_symbol) 69 | 70 | def strip_prediction(self, predictions): 71 | # TODO Parallelize 72 | words = [] 73 | for prediction in predictions: 74 | blank_symbol_seen = False 75 | stripped_prediction = self.xp.full((1,), prediction[0], dtype=self.xp.int32) 76 | for char in prediction: 77 | if char == self.blank_symbol: 78 | blank_symbol_seen = True 79 | continue 80 | if char == stripped_prediction[-1] and not blank_symbol_seen: 81 | continue 82 | blank_symbol_seen = False 83 | stripped_prediction = self.xp.hstack((stripped_prediction, char.reshape(1, ))) 84 | words.append(stripped_prediction) 85 | return words 86 | 87 | 88 | class TextRecSoftmaxMetrics(TextRectMetrics): 89 | 90 | def calc_actual_loss(self, predictions, grid, labels): 91 | batch_size = labels.shape[0] 92 | labels = F.reshape(labels, (-1,)) 93 | 94 | predictions = F.concat([F.expand_dims(prediction, axis=1) for prediction in predictions], axis=1) 95 | predictions = F.reshape(predictions, (batch_size * self.num_timesteps, -1)) 96 | return F.softmax_cross_entropy(predictions, labels) 97 | -------------------------------------------------------------------------------- /chainer/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/models/__init__.py -------------------------------------------------------------------------------- /chainer/models/fsns_resnet.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | import chainer.links as L 3 | 4 | from chainer import Chain 5 | from chainer.links.model.vision.resnet import ResNetLayers 6 | 7 | 8 | class FSNSResNetLayers(ResNetLayers): 9 | 10 | @property 11 | def functions(self): 12 | functions = super().functions 13 | del functions['fc6'] 14 | del functions['prob'] 15 | return functions 16 | 17 | 18 | class FSNSRecognitionResnet(Chain): 19 | 20 | def __init__(self, target_shape, num_labels, num_timesteps, uses_original_data=False, dropout_ratio=0.5, use_dropout=False, use_blstm=False, use_attention=False): 21 
| super().__init__() 22 | with self.init_scope(): 23 | self.resnet = FSNSResNetLayers('', 152) 24 | self.fc1 = L.Linear(None, 512) 25 | self.lstm = L.LSTM(None, 512) 26 | if use_blstm: 27 | self.blstm = L.LSTM(None, 512) 28 | self.classifier = L.Linear(None, 134) 29 | 30 | self.target_shape = target_shape 31 | self.num_labels = num_labels 32 | self.num_timesteps = num_timesteps 33 | self.uses_original_data = uses_original_data 34 | self.vis_anchor = None 35 | self.use_dropout = use_dropout 36 | self.dropout_ratio = dropout_ratio 37 | self.use_blstm = use_blstm 38 | self.use_attention = use_attention 39 | 40 | def __call__(self, images, localizations): 41 | points = F.spatial_transformer_grid(localizations, self.target_shape) 42 | rois = F.spatial_transformer_sampler(images, points) 43 | 44 | h = self.resnet(rois, layers=['res5', 'pool5']) 45 | 46 | self.vis_anchor = h['res5'] 47 | h = h['pool5'] 48 | 49 | if self.uses_original_data: 50 | # merge data of all 4 individual images in channel dimension 51 | batch_size, num_channels = h.shape 52 | h = F.reshape(h, (batch_size // 4, 4 * num_channels)) 53 | 54 | h = F.relu(self.fc1(h)) 55 | 56 | # for each timestep of the localization net do the 'classification' 57 | h = F.reshape(h, (self.num_timesteps, -1, self.fc1.out_size)) 58 | overall_predictions = [] 59 | for timestep in F.separate(h, axis=0): 60 | # go 2x num_labels plus 1 timesteps because of ctc loss 61 | lstm_predictions = [] 62 | self.lstm.reset_state() 63 | if self.use_blstm: 64 | self.blstm.reset_state() 65 | 66 | for _ in range(self.num_labels): 67 | lstm_prediction = self.lstm(timestep) 68 | lstm_predictions.append(lstm_prediction) 69 | 70 | if self.use_blstm: 71 | blstm_predictions = [] 72 | for lstm_prediction in reversed(lstm_predictions): 73 | blstm_prediction = self.blstm(lstm_prediction) 74 | blstm_predictions.append(blstm_prediction) 75 | 76 | lstm_predictions = list(reversed(blstm_predictions)) 77 | 78 | final_lstm_predictions = [] 79 | for lstm_prediction in lstm_predictions: 80 | classified = self.classifier(lstm_prediction) 81 | final_lstm_predictions.append(F.expand_dims(classified, axis=0)) 82 | 83 | final_lstm_predictions = F.concat(final_lstm_predictions, axis=0) 84 | overall_predictions.append(final_lstm_predictions) 85 | 86 | return overall_predictions, rois, points 87 | -------------------------------------------------------------------------------- /chainer/models/ic_stn.py: -------------------------------------------------------------------------------- 1 | import chainer.links as L 2 | import chainer.functions as F 3 | 4 | from chainer import Chain, cuda, Deserializer 5 | 6 | from functions.rotation_droput import rotation_dropout 7 | from models.fsns import ResnetBlock 8 | 9 | 10 | class InverseCompositionalLocalizationNet(Chain): 11 | 12 | def __init__(self, dropout_ratio, num_timesteps, num_refinement_steps, target_shape, zoom=0.9, do_parameter_refinement=False): 13 | super().__init__() 14 | with self.init_scope(): 15 | self.conv0 = L.Convolution2D(None, 32, 3, pad=1) 16 | self.bn0 = L.BatchNormalization(32) 17 | self.conv0_1 = L.Convolution2D(None, 32, 3, pad=1) 18 | self.bn0_1 = L.BatchNormalization(32) 19 | self.rs1 = ResnetBlock(32) 20 | self.rs2 = ResnetBlock(48, filter_increase=True) 21 | self.rs3 = ResnetBlock(48) 22 | self.rs4 = ResnetBlock(32) 23 | self.rs5 = ResnetBlock(48, filter_increase=True) 24 | self.lstm = L.LSTM(None, 256) 25 | self.transform_2 = L.Linear(256, 6) 26 | self.refinement_transform = L.Linear(2352, 6) 27 | 28 | for 
transform_param_layer in [self.transform_2, self.refinement_transform]: 29 | # initialize transform 30 | transform_param_layer.W.data[...] = 0 31 | 32 | transform_bias = transform_param_layer.b.data 33 | transform_bias[[0, 4]] = zoom 34 | transform_bias[[2, 5]] = 0 35 | 36 | self.refinement_transform.b.data[[0, 4]] = 0.1 37 | 38 | self.dropout_ratio = dropout_ratio 39 | self.num_timesteps = num_timesteps 40 | self.num_refinement_steps = num_refinement_steps 41 | self.target_shape = target_shape 42 | self.do_parameter_refinement = do_parameter_refinement 43 | self.vis_anchor = None 44 | 45 | def do_transformation_param_refinement_step(self, images, transformation_params): 46 | transformation_params = self.remove_homogeneous_coordinates(transformation_params) 47 | points = F.spatial_transformer_grid(transformation_params, self.target_shape) 48 | rois = F.spatial_transformer_sampler(images, points) 49 | 50 | # rerun parts of the feature extraction for producing a refined version of the transformation params 51 | h = self.bn0_1(self.conv0_1(rois)) 52 | h = F.average_pooling_2d(F.relu(h), 2, stride=2) 53 | 54 | h = self.rs4(h) 55 | h = F.max_pooling_2d(h, 2, stride=2) 56 | 57 | h = self.rs5(h) 58 | h = F.max_pooling_2d(h, 2, stride=2) 59 | 60 | transformation_params = self.refinement_transform(h) 61 | transformation_params = F.reshape(transformation_params, (-1, 2, 3)) 62 | transformation_params = rotation_dropout(transformation_params, ratio=self.dropout_ratio) 63 | return transformation_params 64 | 65 | def to_homogeneous_coordinates(self, transformation_params): 66 | batch_size = transformation_params.shape[0] 67 | transformation_fill = self.xp.zeros((batch_size, 1, 3), dtype=transformation_params.dtype) 68 | transformation_fill[:, 0, 2] = 1 69 | transformation_params = F.concat((transformation_params, transformation_fill), axis=1) 70 | return transformation_params 71 | 72 | def remove_homogeneous_coordinates(self, transformation_params): 73 | # we can remove homogeneous axis, as it is still 0 0 1 74 | axes = F.split_axis(transformation_params, 3, axis=1, force_tuple=True) 75 | return F.concat(axes[:-1], axis=1) 76 | 77 | def __call__(self, images): 78 | self.lstm.reset_state() 79 | 80 | h = self.bn0(self.conv0(images)) 81 | h = F.average_pooling_2d(F.relu(h), 2, stride=2) 82 | 83 | h = self.rs1(h) 84 | h = F.max_pooling_2d(h, 2, stride=2) 85 | 86 | h = self.rs2(h) 87 | h = F.max_pooling_2d(h, 2, stride=2) 88 | 89 | h = self.rs3(h) 90 | # h = self.rs4(h) 91 | self.vis_anchor = h 92 | h = F.average_pooling_2d(h, 5) 93 | 94 | localizations = [] 95 | 96 | with cuda.get_device_from_array(h.data): 97 | 98 | for _ in range(self.num_timesteps): 99 | timestep_localizations = [] 100 | in_feature = h 101 | lstm_prediction = F.relu(self.lstm(in_feature)) 102 | transformed = self.transform_2(lstm_prediction) 103 | transformed = F.reshape(transformed, (-1, 2, 3)) 104 | transformation_params = rotation_dropout(transformed, ratio=self.dropout_ratio) 105 | timestep_localizations.append(transformation_params) 106 | 107 | # self.transform_2.disable_update() 108 | 109 | if self.do_parameter_refinement: 110 | transformation_params = self.to_homogeneous_coordinates(transformation_params) 111 | # refine the transformation parameters 112 | for _ in range(self.num_refinement_steps): 113 | transformation_deltas = self.do_transformation_param_refinement_step(images, transformation_params) 114 | transformation_deltas = self.to_homogeneous_coordinates(transformation_deltas) 115 | 116 | transformation_params = 
F.batch_matmul(transformation_params, transformation_deltas) 117 | # transformation_params = F.batch_matmul(transformation_deltas, transformation_params) 118 | timestep_localizations.append(transformation_params[:, :-1, :]) 119 | 120 | localizations.append(timestep_localizations) 121 | 122 | return [F.concat(loc, axis=0) for loc in zip(*localizations)] 123 | 124 | def serialize(self, serializer): 125 | super().serialize(serializer) 126 | # only run rest of method if we are deserializing 127 | if not issubclass(serializer.__class__, Deserializer): 128 | return 129 | 130 | # if extra transform params are uninitialized we initialize them with the pretrained version of the previous 131 | # iteration (if there is any) 132 | # first check if we are loading a pre-trained model 133 | if not any('conv0' in file for file in serializer.npz.files): 134 | # nothing to do if we are not loading a pre-trained model 135 | return 136 | # no need to do anything if we already trained a model with extra refinement 137 | if any('bn0_1' in file for file in serializer.npz.files): 138 | return 139 | 140 | # copy trained params 141 | params_to_copy = [(self.conv0_1, self.conv0), (self.bn0_1, self.bn0), (self.rs4, self.rs1), (self.rs5, self.rs2)] 142 | for target, source in params_to_copy: 143 | target.copyparams(source) 144 | 145 | 146 | -------------------------------------------------------------------------------- /chainer/models/svhn.py: -------------------------------------------------------------------------------- 1 | from chainer import Chain, cuda 2 | import chainer.functions as F 3 | import chainer.links as L 4 | 5 | from functions.rotation_droput import rotation_dropout 6 | from insights.visual_backprop import VisualBackprop 7 | from models.fsns import ResnetBlock 8 | 9 | 10 | class SVHNLocalizationNet(Chain): 11 | 12 | def __init__(self, dropout_ratio, num_timesteps, zoom=0.9): 13 | super(SVHNLocalizationNet, self).__init__() 14 | with self.init_scope(): 15 | self.conv0 = L.Convolution2D(None, 32, 3, pad=1) 16 | self.bn0 = L.BatchNormalization(32) 17 | self.rs1 = ResnetBlock(32) 18 | self.rs2 = ResnetBlock(48, filter_increase=True) 19 | self.rs3 = ResnetBlock(48) 20 | self.lstm = L.LSTM(None, 256) 21 | self.transform_2 = L.Linear(256, 6) 22 | 23 | # initialize transform 24 | self.transform_2.W.data[...] 
= 0 25 | 26 | transform_bias = self.transform_2.b.data 27 | transform_bias[[0, 4]] = zoom 28 | transform_bias[[2, 5]] = 0 29 | 30 | self.dropout_ratio = dropout_ratio 31 | self._train = True 32 | self.num_timesteps = num_timesteps 33 | self.vis_anchor = None 34 | 35 | self.width_encoding = None 36 | self.height_encoding = None 37 | 38 | def __call__(self, images): 39 | self.lstm.reset_state() 40 | 41 | h = self.bn0(self.conv0(images)) 42 | h = F.average_pooling_2d(F.relu(h), 2, stride=2) 43 | 44 | h = self.rs1(h) 45 | h = F.max_pooling_2d(h, 2, stride=2) 46 | 47 | h = self.rs2(h) 48 | h = F.max_pooling_2d(h, 2, stride=2) 49 | 50 | h = self.rs3(h) 51 | # h = self.rs4(h) 52 | self.vis_anchor = h 53 | h = F.average_pooling_2d(h, 5, stride=2) 54 | 55 | localizations = [] 56 | 57 | with cuda.get_device_from_array(h.data): 58 | for _ in range(self.num_timesteps): 59 | in_feature = h 60 | lstm_prediction = F.relu(self.lstm(in_feature)) 61 | transformed = self.transform_2(lstm_prediction) 62 | transformed = F.reshape(transformed, (-1, 2, 3)) 63 | localizations.append(rotation_dropout(transformed, ratio=self.dropout_ratio)) 64 | 65 | return F.concat(localizations, axis=0) 66 | 67 | 68 | class SVHNRecognitionNet(Chain): 69 | 70 | def __init__(self, target_shape, num_labels, num_timesteps, use_blstm=False): 71 | super(SVHNRecognitionNet, self).__init__() 72 | with self.init_scope(): 73 | self.data_bn = L.BatchNormalization(3) 74 | self.conv0 = L.Convolution2D(None, 32, 3, pad=1, stride=2) 75 | self.bn0 = L.BatchNormalization(32) 76 | self.conv1 = L.Convolution2D(32, 32, 3, pad=1) 77 | self.bn1 = L.BatchNormalization(32) 78 | self.rs1 = ResnetBlock(32) 79 | self.rs2 = ResnetBlock(64, filter_increase=True) 80 | self.rs3 = ResnetBlock(128, filter_increase=True) 81 | self.fc1 = L.Linear(None, 256) 82 | self.lstm = L.LSTM(None, 256) 83 | if use_blstm: 84 | self.blstm = L.LSTM(None, 256) 85 | self.classifier = L.Linear(None, 11) 86 | 87 | self._train = True 88 | self.target_shape = target_shape 89 | self.num_labels = num_labels 90 | self.num_timesteps = num_timesteps 91 | self.vis_anchor = None 92 | self.use_blstm = use_blstm 93 | 94 | def __call__(self, images, localizations): 95 | points = F.spatial_transformer_grid(localizations, self.target_shape) 96 | rois = F.spatial_transformer_sampler(images, points) 97 | 98 | h = self.data_bn(rois) 99 | h = F.relu(self.bn0(self.conv0(h))) 100 | h = F.average_pooling_2d(h, 2, stride=2) 101 | 102 | h = self.rs1(h) 103 | h = self.rs2(h) 104 | h = F.max_pooling_2d(h, 2, stride=2) 105 | h = self.rs3(h) 106 | self.vis_anchor = h 107 | 108 | h = F.average_pooling_2d(h, 5, stride=1) 109 | 110 | h = F.relu(self.fc1(h)) 111 | 112 | # for each timestep of the localization net do the 'classification' 113 | h = F.reshape(h, (self.num_timesteps, -1, self.fc1.out_size)) 114 | overall_predictions = [] 115 | for timestep in F.separate(h, axis=0): 116 | lstm_predictions = [] 117 | self.lstm.reset_state() 118 | if self.use_blstm: 119 | self.blstm.reset_state() 120 | 121 | for _ in range(self.num_labels): 122 | lstm_prediction = self.lstm(timestep) 123 | lstm_predictions.append(lstm_prediction) 124 | 125 | if self.use_blstm: 126 | blstm_predictions = [] 127 | for lstm_prediction in reversed(lstm_predictions): 128 | blstm_prediction = self.blstm(lstm_prediction) 129 | blstm_predictions.append(blstm_prediction) 130 | 131 | lstm_predictions = reversed(blstm_predictions) 132 | 133 | final_lstm_predictions = [] 134 | for lstm_prediction in lstm_predictions: 135 | classified = 
self.classifier(lstm_prediction) 136 | final_lstm_predictions.append(F.expand_dims(classified, axis=1)) 137 | 138 | final_lstm_predictions = F.concat(final_lstm_predictions, axis=1) 139 | overall_predictions.append(final_lstm_predictions) 140 | 141 | return overall_predictions, rois, points 142 | 143 | 144 | class SVHNCTCRecognitionNet(Chain): 145 | 146 | def __init__(self, target_shape, num_labels, num_timesteps): 147 | super(SVHNCTCRecognitionNet, self).__init__() 148 | with self.init_scope(): 149 | self.data_bn = L.BatchNormalization(3) 150 | self.conv0 = L.Convolution2D(None, 32, 3, pad=1) 151 | self.bn0 = L.BatchNormalization(32) 152 | self.rs1 = ResnetBlock(32) 153 | self.rs2 = ResnetBlock(64, filter_increase=True) 154 | self.rs3 = ResnetBlock(128, filter_increase=True) 155 | self.fc1 = L.Linear(None, 256) 156 | self.lstm = L.LSTM(None, 256) 157 | self.classifier = L.Linear(None, 11) 158 | 159 | self._train = True 160 | self.target_shape = target_shape 161 | self.num_labels = num_labels 162 | self.num_timesteps = num_timesteps 163 | self.vis_anchor = None 164 | 165 | def __call__(self, images, localizations): 166 | points = F.spatial_transformer_grid(localizations, self.target_shape) 167 | rois = F.spatial_transformer_sampler(images, points) 168 | 169 | # h = self.data_bn(rois) 170 | h = F.relu(self.bn0(self.conv0(rois))) 171 | h = F.average_pooling_2d(h, 2, stride=2) 172 | 173 | h = self.rs1(h) 174 | h = self.rs2(h) 175 | h = F.max_pooling_2d(h, 2, stride=2) 176 | h = self.rs3(h) 177 | self.vis_anchor = h 178 | 179 | h = F.average_pooling_2d(h, 5, stride=1) 180 | 181 | h = F.relu(self.fc1(h)) 182 | 183 | # for each timestep of the localization net do the 'classification' 184 | h = F.reshape(h, (self.num_timesteps * 2 + 1, -1, self.fc1.out_size)) 185 | overall_predictions = [] 186 | for timestep in F.separate(h, axis=0): 187 | # go 2x num_labels plus 1 timesteps because of ctc loss 188 | lstm_predictions = [] 189 | self.lstm.reset_state() 190 | for _ in range(self.num_labels): 191 | lstm_prediction = self.lstm(timestep) 192 | classified = self.classifier(lstm_prediction) 193 | lstm_predictions.append(classified) 194 | overall_predictions.append(lstm_predictions) 195 | 196 | return overall_predictions, rois, points 197 | 198 | 199 | class SVHNNet(Chain): 200 | 201 | def __init__(self, localization_net, recognition_net): 202 | super(SVHNNet, self).__init__() 203 | with self.init_scope(): 204 | self.localization_net = localization_net 205 | self.recognition_net = recognition_net 206 | 207 | def __call__(self, images): 208 | batch_size = images.shape[0] 209 | h = self.localization_net(images) 210 | new_batch_size = h.shape[0] 211 | batch_size_increase_factor = new_batch_size // batch_size 212 | images = F.concat([images for _ in range(batch_size_increase_factor)], axis=0) 213 | 214 | return self.recognition_net(images, h) 215 | -------------------------------------------------------------------------------- /chainer/models/text_recognition.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | import chainer.links as L 3 | 4 | from chainer import Chain 5 | 6 | from models.fsns import ResnetBlock 7 | 8 | 9 | class TextRecognitionNet(Chain): 10 | 11 | def __init__(self, target_shape, num_rois, label_size, use_blstm=False): 12 | super().__init__() 13 | with self.init_scope(): 14 | self.conv0 = L.Convolution2D(None, 32, 3, pad=1, stride=2) 15 | self.bn0 = L.BatchNormalization(32) 16 | self.conv1 = L.Convolution2D(32, 32, 3, 
pad=1) 17 | self.bn1 = L.BatchNormalization(32) 18 | self.rs1 = ResnetBlock(32) 19 | self.rs2 = ResnetBlock(64, filter_increase=True) 20 | self.rs3 = ResnetBlock(128, filter_increase=True) 21 | self.fc1 = L.Linear(None, 256) 22 | self.lstm = L.LSTM(None, 256) 23 | if use_blstm: 24 | self.blstm = L.LSTM(None, 256) 25 | self.classifier = L.Linear(None, label_size) 26 | 27 | self.use_blstm = use_blstm 28 | self.target_shape = target_shape 29 | self.num_rois = num_rois 30 | 31 | def __call__(self, images, localizations): 32 | self.lstm.reset_state() 33 | if self.use_blstm: 34 | self.blstm.reset_state() 35 | 36 | points = [F.spatial_transformer_grid(localization, self.target_shape) for localization in localizations] 37 | rois = [F.spatial_transformer_sampler(images, point) for point in points] 38 | 39 | h = F.relu(self.bn0(self.conv0(rois[-1]))) 40 | h = F.average_pooling_2d(h, 2, stride=2) 41 | 42 | h = self.rs1(h) 43 | h = self.rs2(h) 44 | h = F.max_pooling_2d(h, 2, stride=2) 45 | h = self.rs3(h) 46 | self.vis_anchor = h 47 | 48 | h = F.average_pooling_2d(h, 5, stride=1) 49 | 50 | h = F.relu(self.fc1(h)) 51 | 52 | # each timestep of the localization contains one character prediction, that needs to be classified 53 | overall_predictions = [] 54 | h = F.reshape(h, (self.num_rois, -1, self.fc1.out_size)) 55 | 56 | for timestep in F.separate(h, axis=0): 57 | lstm_state = self.lstm(timestep) 58 | 59 | prediction = self.classifier(lstm_state) 60 | overall_predictions.append(prediction) 61 | 62 | return overall_predictions, rois, points 63 | 64 | 65 | class TextRecNet(Chain): 66 | 67 | def __init__(self, localization_net, recognition_net): 68 | super(TextRecNet, self).__init__() 69 | with self.init_scope(): 70 | self.localization_net = localization_net 71 | self.recognition_net = recognition_net 72 | 73 | def __call__(self, images): 74 | batch_size = images.shape[0] 75 | h = self.localization_net(images) 76 | new_batch_size = h[-1].shape[0] 77 | batch_size_increase_factor = new_batch_size // batch_size 78 | images = F.concat([images for _ in range(batch_size_increase_factor)], axis=0) 79 | 80 | return self.recognition_net(images, h) 81 | -------------------------------------------------------------------------------- /chainer/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/optimizers/__init__.py -------------------------------------------------------------------------------- /chainer/optimizers/multi_net_optimizer.py: -------------------------------------------------------------------------------- 1 | from chainer.optimizer import GradientMethod 2 | 3 | 4 | class MultiNetOptimizer(GradientMethod): 5 | 6 | def __getattribute__(self, item): 7 | """Proxy all calls that can not be handled by this class to base optimizer""" 8 | try: 9 | v = object.__getattribute__(self, item) 10 | except AttributeError: 11 | v = getattr(object.__getattribute__(self, 'base_optimizer'), item) 12 | return v 13 | 14 | def __init__(self, base_optimizer, optimize_all_interval=10): 15 | self.base_optimizer = base_optimizer 16 | self.optimize_all_interval = optimize_all_interval 17 | super().__init__() 18 | self.extra_links = None 19 | self.hyperparam = self.base_optimizer.hyperparam 20 | 21 | def setup(self, link, extra_links=()): 22 | self.extra_links = extra_links 23 | super().setup(link) 24 | 25 | def update(self, lossfun=None, *args, **kwargs): 26 | """Updates parameters 
based on a loss function or computed gradients. 27 | This method runs in two ways. 28 | - If ``lossfun`` is given, then it is used as a loss function to 29 | compute gradients. 30 | - Otherwise, this method assumes that the gradients are already 31 | computed. 32 | In both cases, the computed gradients are used to update parameters. 33 | The actual update routines are defined by the update rule of each 34 | parameter. 35 | """ 36 | if lossfun is not None: 37 | use_cleargrads = getattr(self, '_use_cleargrads', True) 38 | loss = lossfun(*args, **kwargs) 39 | if use_cleargrads: 40 | self.target.cleargrads() 41 | else: 42 | self.target.zerograds() 43 | loss.backward() 44 | del loss 45 | 46 | self.reallocate_cleared_grads() 47 | 48 | self.call_hooks() 49 | 50 | self.t += 1 51 | self.base_optimizer.t = self.t 52 | 53 | if self.t % self.optimize_all_interval == 0: 54 | # update all params 55 | for param in self.target.params(): 56 | param.update() 57 | else: 58 | # only update params in extra links 59 | for link in self.extra_links: 60 | for param in link.params(): 61 | param.update() 62 | 63 | def create_update_rule(self): 64 | return self.base_optimizer.create_update_rule() 65 | -------------------------------------------------------------------------------- /chainer/text_recognition_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | 4 | import os 5 | 6 | import json 7 | from collections import OrderedDict 8 | 9 | import chainer 10 | from pprint import pprint 11 | 12 | import chainer.functions as F 13 | import numpy as np 14 | 15 | from PIL import Image 16 | from chainer import configuration 17 | 18 | from utils.datatypes import Size 19 | 20 | 21 | def get_class_and_module(log_data): 22 | if not isinstance(log_data, list): 23 | if 'InverseCompositional' in log_data: 24 | module_name = 'ic_stn.py' 25 | klass_name = log_data 26 | else: 27 | module_name = 'text_recognition.py' 28 | klass_name = log_data 29 | else: 30 | klass_name, module_name = log_data 31 | return klass_name, module_name 32 | 33 | 34 | def load_module(module_file): 35 | module_spec = importlib.util.spec_from_file_location("models.model", module_file) 36 | module = importlib.util.module_from_spec(module_spec) 37 | module_spec.loader.exec_module(module) 38 | return module 39 | 40 | 41 | def build_recognition_net(recognition_net_class, target_shape, args): 42 | return recognition_net_class( 43 | target_shape, 44 | num_rois=args.timesteps, 45 | label_size=52, 46 | ) 47 | 48 | 49 | def build_localization_net(localization_net_class, target_shape, args): 50 | return localization_net_class( 51 | args.dropout_ratio, 52 | args.timesteps, 53 | 0, 54 | target_shape, 55 | zoom=1.0, 56 | do_parameter_refinement=False 57 | ) 58 | 59 | 60 | def build_fusion_net(fusion_net_class, localization_net, recognition_net): 61 | return fusion_net_class(localization_net, recognition_net) 62 | 63 | 64 | def create_network(args, log_data): 65 | # Step 1: build network 66 | localization_net_class_name, localization_module_name = get_class_and_module(log_data['localization_net']) 67 | module = load_module(os.path.abspath(os.path.join(args.model_dir, localization_module_name))) 68 | localization_net_class = eval('module.{}'.format(localization_net_class_name)) 69 | localization_net = build_localization_net(localization_net_class, log_data['target_size'], args) 70 | 71 | recognition_net_class_name, recognition_module_name = get_class_and_module(log_data['recognition_net']) 72 | 
module = load_module(os.path.abspath(os.path.join(args.model_dir, recognition_module_name))) 73 | recognition_net_class = eval('module.{}'.format(recognition_net_class_name)) 74 | recognition_net = build_recognition_net(recognition_net_class, target_shape, args) 75 | 76 | fusion_net_class_name, fusion_module_name = get_class_and_module(log_data['fusion_net']) 77 | module = load_module(os.path.abspath(os.path.join(args.model_dir, fusion_module_name))) 78 | fusion_net_class = eval('module.{}'.format(fusion_net_class_name)) 79 | net = build_fusion_net(fusion_net_class, localization_net, recognition_net) 80 | 81 | if args.gpu >= 0: 82 | net.to_gpu(args.gpu) 83 | 84 | return net 85 | 86 | 87 | def load_image(image_file, xp, image_size): 88 | with Image.open(image_file) as the_image: 89 | the_image = the_image.convert('L') 90 | the_image = the_image.resize((image_size.width, image_size.height), Image.LANCZOS) 91 | image = xp.asarray(the_image, dtype=np.float32) 92 | image /= 255 93 | image = xp.broadcast_to(image, (3, image_size.height, image_size.width)) 94 | return image 95 | 96 | 97 | def strip_prediction(predictions, xp, blank_symbol): 98 | words = [] 99 | for prediction in predictions: 100 | blank_symbol_seen = False 101 | stripped_prediction = xp.full((1,), prediction[0], dtype=xp.int32) 102 | for char in prediction: 103 | if char == blank_symbol: 104 | blank_symbol_seen = True 105 | continue 106 | if char == stripped_prediction[-1] and not blank_symbol_seen: 107 | continue 108 | blank_symbol_seen = False 109 | stripped_prediction = xp.hstack((stripped_prediction, char.reshape(1, ))) 110 | words.append(stripped_prediction) 111 | return words 112 | 113 | 114 | def extract_bbox(bbox, image_size, target_shape, xp): 115 | bbox.data[...] = (bbox.data[...] 
+ 1) / 2 116 | bbox.data[0, :] *= image_size.width 117 | bbox.data[1, :] *= image_size.height 118 | 119 | x = xp.clip(bbox.data[0, :].reshape(target_shape), 0, image_size.width) 120 | y = xp.clip(bbox.data[1, :].reshape(target_shape), 0, image_size.height) 121 | 122 | top_left = (float(x[0, 0]), float(y[0, 0])) 123 | bottom_right = (float(x[-1, -1]), float(y[-1, -1])) 124 | 125 | return top_left, bottom_right 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser(description="Tool that loads model and predicts on a given image") 130 | parser.add_argument("model_dir", help="path to directory where model is saved") 131 | parser.add_argument("snapshot_name", help="name of the snapshot to load") 132 | parser.add_argument("image_path", help="path to the image that shall be evaluated") 133 | parser.add_argument("char_map", help="path to char map, that maps class id to character") 134 | parser.add_argument("--gpu", type=int, default=-1, help="id of gpu to use [default: use cpu]") 135 | 136 | args = parser.parse_args() 137 | # set standard args that should always hold true if using the supplied model 138 | args.log_name = 'log' 139 | args.dropout_ratio = 0.5 140 | args.blank_symbol = 0 141 | # max number of text regions in the image 142 | args.timesteps = 23 143 | # max number of characters per word 144 | args.num_labels = 1 145 | 146 | # open log and extract meta information 147 | with open(os.path.join(args.model_dir, args.log_name)) as the_log: 148 | log_data = json.load(the_log)[0] 149 | 150 | target_shape = Size._make(log_data['target_size']) 151 | image_size = Size._make(log_data['image_size']) 152 | 153 | xp = chainer.cuda.cupy if args.gpu >= 0 else np 154 | network = create_network(args, log_data) 155 | 156 | # load weights 157 | with np.load(os.path.join(args.model_dir, args.snapshot_name)) as f: 158 | chainer.serializers.NpzDeserializer(f).load(network) 159 | 160 | # load char map 161 | with open(args.char_map) as the_map: 162 | char_map = json.load(the_map) 163 | 164 | # load image 165 | image = load_image(args.image_path, xp, image_size) 166 | with configuration.using_config('train', False): 167 | predictions, crops, grids = network(image[xp.newaxis, ...]) 168 | 169 | # extract class scores for each word 170 | words = OrderedDict({}) 171 | 172 | predictions = F.concat([F.expand_dims(prediction, axis=0) for prediction in predictions], axis=0) 173 | 174 | classification = F.softmax(predictions, axis=2) 175 | classification = classification.data 176 | classification = xp.argmax(classification, axis=2) 177 | classification = xp.transpose(classification, (1, 0)) 178 | 179 | word = strip_prediction(classification, xp, args.blank_symbol)[0] 180 | 181 | word = "".join(map(lambda x: chr(char_map[str(x)]), word)) 182 | 183 | bboxes = [] 184 | for bbox in grids[0]: 185 | bbox = extract_bbox(bbox, image_size, target_shape, xp) 186 | bboxes.append(OrderedDict({ 187 | 'top_left': bbox[0], 188 | 'bottom_right': bbox[1] 189 | })) 190 | words[word] = bboxes 191 | 192 | pprint(words) 193 | 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /chainer/train_mnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import statistics 4 | from itertools import zip_longest 5 | 6 | import chainer 7 | import datetime 8 | 9 | import numpy as np 10 | from PIL import Image 11 | from chainer import cuda 12 | import chainer.functions as F 13 | from chainer.training 
import extensions 14 | 15 | from datasets.mnist_dataset import MNISTDataset, HDF5MnistDataset, FileBasedMNISTDataset 16 | from insights.bbox_plotter import BBOXPlotter 17 | from models.mnist import MNISTNet, MNISTLocalizationNet, MNISTRecognitionNet 18 | from models.text_detection_net import SmallLocalizationNet, TextDetectionNet 19 | from utils.datatypes import Size 20 | from utils.multi_accuracy_classifier import Classifier 21 | from utils.train_utils import add_default_arguments, get_fast_evaluator, AttributeUpdater, get_trainer, \ 22 | concat_and_pad_examples, TwoStateLearningRateShifter 23 | 24 | 25 | def mnist_loss(x, t): 26 | xp = cuda.get_array_module(x[0].data, t.data) 27 | batch_predictions, _, _ = x 28 | losses = [] 29 | 30 | for predictions, labels in zip(F.split_axis(batch_predictions, args.timesteps, axis=1), F.separate(t, axis=1)): 31 | batch_size, _, num_classes = predictions.data.shape 32 | predictions = F.reshape(F.flatten(predictions), (batch_size, num_classes)) 33 | losses.append(F.softmax_cross_entropy(predictions, labels)) 34 | 35 | return sum(losses) 36 | 37 | 38 | def mnist_accuracy(x, t): 39 | xp = cuda.get_array_module(x[0].data, t.data) 40 | batch_predictions, _, _ = x 41 | accuracies = [] 42 | 43 | for predictions, labels in zip(F.split_axis(batch_predictions, args.timesteps, axis=1), F.separate(t, axis=1)): 44 | batch_size, _, num_classes = predictions.data.shape 45 | predictions = F.reshape(F.flatten(predictions), (batch_size, num_classes)) 46 | accuracies.append(F.accuracy(predictions, labels)) 47 | 48 | return sum(accuracies) / max(len(accuracies), 1) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser(description="Tool to train a text detection network based on Spatial Transformers") 53 | parser.add_argument("train_data", help="path to train data file") 54 | parser.add_argument("val_data", help="path to validation data file") 55 | parser.add_argument("--timesteps", type=int, default=5, help="number of timesteps the GRU shall run [default: 5]") 56 | parser.add_argument("--alternative", action="store_true", default=False, help="use alternative implementation of spatial Transformers") 57 | parser.add_argument("-ds", dest='downsample_factor', type=int, default=2, help="downsample for image sampler") 58 | parser = add_default_arguments(parser) 59 | args = parser.parse_args() 60 | 61 | image_size = Size(width=200, height=200) 62 | 63 | localization_net = MNISTLocalizationNet(args.dropout_ratio, args.timesteps) 64 | recognition_net = MNISTRecognitionNet(image_size, args.dropout_ratio, downsample_factor=args.downsample_factor, use_alternative=args.alternative) 65 | net = MNISTNet(localization_net, recognition_net) 66 | 67 | model = Classifier(net, ('accuracy', ), lossfun=mnist_loss, accfun=mnist_accuracy) 68 | if args.gpu >= 0: 69 | cuda.get_device(args.gpu).use() 70 | model.to_gpu() 71 | 72 | # optimizer = chainer.optimizers.MomentumSGD(lr=args.learning_rate) 73 | optimizer = chainer.optimizers.Adam(alpha=1e-4) 74 | # lr_shifter = AttributeUpdater(0.1, trigger=(5, 'epoch')) 75 | # optimizer = chainer.optimizers.RMSprop(lr=args.learning_rate) 76 | # optimizer = chainer.optimizers.AdaDelta(rho=0.9) 77 | optimizer.setup(model) 78 | optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005)) 79 | 80 | train_dataset = FileBasedMNISTDataset(args.train_data) 81 | validation_dataset = FileBasedMNISTDataset(args.val_data) 82 | 83 | # train_dataset = MNISTDataset(args.train_data, "train") 84 | # validation_dataset = MNISTDataset(args.train_data, 
"valid") 85 | 86 | train_iterator = chainer.iterators.MultiprocessIterator(train_dataset, args.batch_size) 87 | validation_iterator = chainer.iterators.MultiprocessIterator(validation_dataset, args.batch_size) 88 | 89 | updater = chainer.training.StandardUpdater(train_iterator, optimizer, device=args.gpu) 90 | 91 | log_dir = os.path.join(args.log_dir, "{}_{}".format(args.log_name, datetime.datetime.now().isoformat())) 92 | 93 | fields_to_print = [ 94 | 'epoch', 95 | 'iteration', 96 | 'main/loss', 97 | 'main/accuracy', 98 | 'validation/main/loss', 99 | 'validation/main/accuracy', 100 | ] 101 | 102 | FastEvaluator = get_fast_evaluator((args.test_interval, 'iteration')) 103 | evaluator = FastEvaluator(validation_iterator, model, device=args.gpu, eval_func=lambda *args: model(*args), 104 | num_iterations=args.test_iterations, converter=concat_and_pad_examples) 105 | 106 | # take snapshot of model every 5 epochs 107 | model_snapshotter = extensions.snapshot_object(net, 'model_{.updater.iteration}.npz', trigger=(5, 'epoch')) 108 | 109 | # bbox plotter test 110 | test_image = validation_dataset.get_example(0)[0] 111 | bbox_plotter = BBOXPlotter(test_image, os.path.join(log_dir, 'boxes'), args.downsample_factor, send_bboxes=True) 112 | 113 | learning_rate_schedule = [ 114 | { 115 | "state": TwoStateLearningRateShifter.INTERVAL_BASED_SHIFT_STATE, 116 | "target_lr": 5e-4, 117 | "update_trigger": (10, 'epoch'), 118 | "stop_trigger": (70, 'epoch'), 119 | }, 120 | { 121 | "state": TwoStateLearningRateShifter.CONTINUOS_SHIFT_STATE, 122 | "target_lr": 5e-10, 123 | "update_trigger": (15, 'epoch'), 124 | "stop_trigger": (90, 'epoch'), 125 | }, 126 | ] 127 | 128 | # num_epochs = sum([phase["stop_trigger"][0] for phase in learning_rate_schedule]) 129 | # 130 | # lr_shifter = TwoStateLearningRateShifter(args.learning_rate, learning_rate_schedule) 131 | 132 | trainer = get_trainer( 133 | net, 134 | updater, 135 | log_dir, 136 | fields_to_print, 137 | epochs=args.epochs, 138 | snapshot_interval=args.snapshot_interval, 139 | print_interval=args.log_interval, 140 | extra_extensions=( 141 | evaluator, 142 | model_snapshotter, 143 | bbox_plotter, 144 | # lr_shifter, 145 | ) 146 | ) 147 | 148 | if args.resume is not None: 149 | print("resuming training from {}".format(args.resume)) 150 | chainer.serializers.load_npz(args.resume, trainer) 151 | 152 | trainer.run() 153 | -------------------------------------------------------------------------------- /chainer/train_svhn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | 5 | 6 | import chainer 7 | import datetime 8 | 9 | import numpy as np 10 | import shutil 11 | from chainer.training.updaters import MultiprocessParallelUpdater 12 | from chainer.training import extensions 13 | 14 | from commands.interactive_train import open_interactive_prompt 15 | from datasets.file_dataset import FileBasedDataset 16 | from datasets.sub_dataset import split_dataset, split_dataset_n_random 17 | from insights.svhn_bbox_plotter import SVHNBBoxPlotter 18 | from metrics.svhn_softmax_metrics import SVHNSoftmaxMetrics 19 | from models.svhn import SVHNLocalizationNet, SVHNRecognitionNet, SVHNNet 20 | from utils.baby_step_curriculum import BabyStepCurriculum 21 | from utils.datatypes import Size 22 | from utils.multi_accuracy_classifier import Classifier 23 | from utils.train_utils import add_default_arguments, get_fast_evaluator, get_trainer, \ 24 | concat_and_pad_examples 25 | 26 | if __name__ == 
"__main__": 27 | parser = argparse.ArgumentParser(description="Tool to train a text detection network based on Spatial Transformers") 28 | parser.add_argument('dataset_specification', 29 | help='path to json file that contains all datasets to use in a list of dicts') 30 | parser.add_argument("--blank-label", type=int, help="blank label to use during training") 31 | parser.add_argument("--char-map", help="path to char map") 32 | parser.add_argument("--send-bboxes", action='store_true', default=False, 33 | help="send predicted bboxes for each iteration") 34 | parser.add_argument("--port", type=int, default=1337, help="port to connect to for sending bboxes") 35 | parser.add_argument("--area-factor", type=float, default=0, help="factor for incorporating area loss") 36 | parser.add_argument("--area-scale-factor", type=float, default=2, help="area scale factor for changing area loss over time") 37 | parser.add_argument("--aspect-factor", type=float, default=0, help="for for incorporating aspect ratio loss") 38 | parser.add_argument("--load-localization", action='store_true', default=False, help="only load localization net") 39 | parser.add_argument("--load-recognition", action='store_true', default=False, help="only load recognition net") 40 | parser.add_argument("--is-trainer-snapshot", action='store_true', default=False, 41 | help="indicate that snapshot to load has been saved by trainer itself") 42 | parser.add_argument("--no-log", action='store_false', default=True, help="disable logging") 43 | parser.add_argument("--freeze-localization", action='store_true', default=False, 44 | help='freeze weights of localization net') 45 | parser.add_argument("--zoom", type=float, default=0.9, help="Zoom for initial bias of spatial transformer") 46 | parser.add_argument("--optimize-all-interval", type=int, default=5, 47 | help="interval in which to optimize the whole network instead of only a part") 48 | parser.add_argument("--use-dropout", action='store_true', default=False, help='use dropout in network') 49 | parser.add_argument("--test-image", help='path to an image that should be used by BBoxPlotter') 50 | parser = add_default_arguments(parser) 51 | args = parser.parse_args() 52 | 53 | image_size = Size(width=200, height=200) 54 | target_shape = Size(width=50, height=50) 55 | 56 | # attributes that need to be adjusted, once the Curriculum decides to use 57 | # a more difficult dataset 58 | # this is a 'map' of attribute name to path in trainer object 59 | attributes_to_adjust = [ 60 | ('num_timesteps', ['predictor', 'localization_net']), 61 | ('num_timesteps', ['predictor', 'recognition_net']), 62 | ('num_timesteps', ['lossfun', '__self__']), 63 | ('num_labels', ['predictor', 'recognition_net']), 64 | ] 65 | 66 | curriculum = BabyStepCurriculum( 67 | args.dataset_specification, 68 | FileBasedDataset, 69 | args.blank_label, 70 | args.gpus, 71 | attributes_to_adjust=attributes_to_adjust, 72 | trigger=(args.test_interval, 'iteration'), 73 | min_delta=0.1, 74 | ) 75 | 76 | train_dataset, validation_dataset = curriculum.load_dataset(0) 77 | train_dataset.resize_size = image_size 78 | validation_dataset.resize_size = image_size 79 | 80 | metrics = SVHNSoftmaxMetrics( 81 | args.blank_label, 82 | args.char_map, 83 | train_dataset.num_timesteps, 84 | image_size, 85 | area_loss_factor=args.area_factor, 86 | aspect_ratio_loss_factor=args.aspect_factor, 87 | area_scaling_factor=args.area_scale_factor, 88 | ) 89 | 90 | localization_net = SVHNLocalizationNet( 91 | args.dropout_ratio, 92 | 
train_dataset.num_timesteps, 93 | zoom=args.zoom, 94 | ) 95 | recognition_net = SVHNRecognitionNet( 96 | target_shape, 97 | train_dataset.get_label_length(train_dataset.num_timesteps, check_length=False), 98 | train_dataset.num_timesteps, 99 | ) 100 | net = SVHNNet(localization_net, recognition_net) 101 | 102 | model = Classifier(net, ('accuracy',), lossfun=metrics.calc_loss, accfun=metrics.calc_accuracy, 103 | provide_label_during_forward=False) 104 | 105 | if args.resume is not None: 106 | with np.load(args.resume) as f: 107 | if args.load_localization: 108 | if args.is_trainer_snapshot: 109 | chainer.serializers.NpzDeserializer(f)['/updater/model:main/predictor/localization_net'].load( 110 | localization_net) 111 | else: 112 | chainer.serializers.NpzDeserializer(f, strict=False)['localization_net'].load(localization_net) 113 | elif args.load_recognition: 114 | if args.is_trainer_snapshot: 115 | chainer.serializers.NpzDeserializer(f)['/updater/model:main/predictor/recognition_net'].load( 116 | recognition_net 117 | ) 118 | else: 119 | chainer.serializers.NpzDeserializer(f)['recognition_net'].load(recognition_net) 120 | else: 121 | if args.is_trainer_snapshot: 122 | chainer.serializers.NpzDeserializer(f)['/updater/model:main/predictor'].load(net) 123 | else: 124 | chainer.serializers.NpzDeserializer(f).load(net) 125 | 126 | optimizer = chainer.optimizers.Adam(alpha=args.learning_rate) 127 | optimizer.setup(model) 128 | optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005)) 129 | optimizer.add_hook(chainer.optimizer.GradientClipping(2)) 130 | 131 | # freeze localization net if user wants to do that 132 | if args.freeze_localization: 133 | localization_net.disable_update() 134 | 135 | if len(args.gpus) > 1: 136 | gpu_datasets = split_dataset_n_random(train_dataset, len(args.gpus)) 137 | if not len(gpu_datasets[0]) == len(gpu_datasets[-1]): 138 | adapted_second_split = split_dataset(gpu_datasets[-1], len(gpu_datasets[0]))[0] 139 | gpu_datasets[-1] = adapted_second_split 140 | else: 141 | gpu_datasets = [train_dataset] 142 | 143 | train_iterators = [chainer.iterators.MultiprocessIterator(dataset, args.batch_size) for dataset in gpu_datasets] 144 | validation_iterator = chainer.iterators.MultiprocessIterator(validation_dataset, args.batch_size, repeat=False) 145 | 146 | updater = MultiprocessParallelUpdater(train_iterators, optimizer, devices=args.gpus) 147 | 148 | log_dir = os.path.join(args.log_dir, "{}_{}".format(datetime.datetime.now().isoformat(), args.log_name)) 149 | args.log_dir = log_dir 150 | 151 | # backup current file 152 | if not os.path.exists(log_dir): 153 | os.makedirs(log_dir, exist_ok=True) 154 | shutil.copy(__file__, log_dir) 155 | 156 | # log all necessary configuration params 157 | report = { 158 | 'log_dir': log_dir, 159 | 'image_size': image_size, 160 | 'target_size': target_shape, 161 | 'localization_net': localization_net.__class__.__name__, 162 | 'recognition_net': recognition_net.__class__.__name__, 163 | 'fusion_net': net.__class__.__name__, 164 | } 165 | 166 | for argument in filter(lambda x: not x.startswith('_'), dir(args)): 167 | report[argument] = getattr(args, argument) 168 | 169 | # callback that logs report 170 | def log_postprocess(stats_cpu): 171 | # only log further information once and not every time we log our progress 172 | if stats_cpu['epoch'] == 0 and stats_cpu['iteration'] == args.log_interval: 173 | stats_cpu.update(report) 174 | 175 | 176 | fields_to_print = [ 177 | 'epoch', 178 | 'iteration', 179 | 'main/loss', 180 | 'main/accuracy', 181 | 
'lr', 182 | 'fast_validation/main/loss', 183 | 'fast_validation/main/accuracy', 184 | 'validation/main/loss', 185 | 'validation/main/accuracy', 186 | ] 187 | 188 | FastEvaluator = get_fast_evaluator((args.test_interval, 'iteration')) 189 | evaluator = ( 190 | FastEvaluator( 191 | validation_iterator, 192 | model, 193 | device=updater._devices[0], 194 | eval_func=lambda *args: model(*args), 195 | num_iterations=args.test_iterations, 196 | converter=concat_and_pad_examples 197 | ), 198 | (args.test_interval, 'iteration') 199 | ) 200 | epoch_validation_iterator = copy.copy(validation_iterator) 201 | epoch_validation_iterator._repeat = False 202 | epoch_evaluator = ( 203 | chainer.training.extensions.Evaluator( 204 | epoch_validation_iterator, 205 | model, 206 | device=updater._devices[0], 207 | converter=concat_and_pad_examples, 208 | ), 209 | (1, 'epoch') 210 | ) 211 | 212 | model_snapshotter = ( 213 | extensions.snapshot_object(net, 'model_{.updater.iteration}.npz'), (args.snapshot_interval, 'iteration')) 214 | 215 | # bbox plotter test 216 | if not args.test_image: 217 | test_image = validation_dataset.get_example(0)[0] 218 | else: 219 | test_image = train_dataset.load_image(args.test_image) 220 | 221 | bbox_plotter = (SVHNBBoxPlotter( 222 | test_image, 223 | os.path.join(log_dir, 'boxes'), 224 | target_shape, 225 | metrics, 226 | send_bboxes=args.send_bboxes, 227 | upstream_port=args.port, 228 | visualization_anchors=[["localization_net", "vis_anchor"], ["recognition_net", "vis_anchor"]] 229 | ), (1, 'iteration')) 230 | 231 | trainer = get_trainer( 232 | net, 233 | updater, 234 | log_dir, 235 | fields_to_print, 236 | curriculum=curriculum, 237 | epochs=args.epochs, 238 | snapshot_interval=args.snapshot_interval, 239 | print_interval=args.log_interval, 240 | extra_extensions=( 241 | evaluator, 242 | # epoch_evaluator, 243 | model_snapshotter, 244 | bbox_plotter, 245 | (curriculum, (args.test_interval, 'iteration')), 246 | # lr_shifter, 247 | ), 248 | postprocess=log_postprocess, 249 | do_logging=args.no_log, 250 | ) 251 | 252 | open_interactive_prompt( 253 | bbox_plotter=bbox_plotter[0], 254 | curriculum=curriculum, 255 | # lr_shifter=lr_shifter[0], 256 | ) 257 | 258 | trainer.run() 259 | -------------------------------------------------------------------------------- /chainer/utils/README.md: -------------------------------------------------------------------------------- 1 | # Training Utils for chainer 2 | 3 | This repository contains all files that can be reused during the training of a network with chainer. 
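4 | 
5 | The snippet below is a minimal sketch (it is not a file of this repository) of how the helpers in `train_utils.py` fit together, following the pattern of the `train_*.py` scripts: MNIST and the small `MLP` chain are placeholders for the real datasets and models, and the script is assumed to be run from the `chainer` directory so that the `utils` package is importable.
6 | 
7 | ```python
8 | # Minimal sketch: wire up add_default_arguments, get_fast_evaluator and
9 | # get_trainer the way the train_*.py scripts do. MNIST and the MLP are
10 | # placeholders, not part of this repository.
11 | import argparse
12 | 
13 | import chainer
14 | import chainer.functions as F
15 | import chainer.links as L
16 | 
17 | from utils.train_utils import (add_default_arguments, concat_and_pad_examples,
18 |                                get_fast_evaluator, get_trainer)
19 | 
20 | 
21 | class MLP(chainer.Chain):
22 |     def __init__(self):
23 |         super().__init__()
24 |         with self.init_scope():
25 |             self.l1 = L.Linear(None, 100)
26 |             self.l2 = L.Linear(100, 10)
27 | 
28 |     def __call__(self, x):
29 |         return self.l2(F.relu(self.l1(x)))
30 | 
31 | 
32 | if __name__ == "__main__":
33 |     parser = argparse.ArgumentParser(description="toy script for the training utils")
34 |     parser = add_default_arguments(parser)  # adds log_dir, --batch-size, --epochs, ...
35 |     args = parser.parse_args()
36 | 
37 |     train, test = chainer.datasets.get_mnist()
38 |     train_iterator = chainer.iterators.SerialIterator(train, args.batch_size)
39 |     test_iterator = chainer.iterators.SerialIterator(test, args.batch_size)
40 | 
41 |     model = L.Classifier(MLP())
42 |     optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
43 |     optimizer.setup(model)
44 | 
45 |     device = args.gpus[0] if args.gpus else -1
46 |     updater = chainer.training.StandardUpdater(train_iterator, optimizer, device=device)
47 | 
48 |     # the FastEvaluator only runs num_iterations batches, keeping validation cheap
49 |     FastEvaluator = get_fast_evaluator((args.test_interval, 'iteration'))
50 |     evaluator = FastEvaluator(test_iterator, model, device=device,
51 |                               num_iterations=args.test_iterations,
52 |                               converter=concat_and_pad_examples)
53 | 
54 |     trainer = get_trainer(
55 |         model,
56 |         updater,
57 |         args.log_dir,
58 |         ['epoch', 'iteration', 'main/loss', 'main/accuracy',
59 |          'fast_validation/main/loss', 'fast_validation/main/accuracy'],
60 |         epochs=args.epochs,
61 |         snapshot_interval=args.snapshot_interval,
62 |         print_interval=args.log_interval,
63 |         extra_extensions=((evaluator, (args.test_interval, 'iteration')),),
64 |     )
65 |     trainer.run()
66 | ```
67 | 
68 | Note that `get_trainer` expects every entry of `extra_extensions` to be an `(extension, trigger)` pair, which is why the evaluator is wrapped in a tuple together with its interval trigger.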
-------------------------------------------------------------------------------- /chainer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/chainer/utils/__init__.py -------------------------------------------------------------------------------- /chainer/utils/baby_step_curriculum.py: -------------------------------------------------------------------------------- 1 | import json 2 | import statistics 3 | from collections import deque 4 | 5 | import copy 6 | from chainer import cuda, Variable 7 | from chainer.training import extension 8 | import chainer.training.trigger as trigger_module 9 | 10 | from datasets.concatenated_dataset import ConcatenatedDataset 11 | from datasets.sub_dataset import split_dataset_random, split_dataset, split_dataset_n_random 12 | 13 | 14 | class BabyStepCurriculum(extension.Extension): 15 | 16 | def __init__(self, dataset_specification, dataset_class, blank_label, gpus, trigger=(1, 'epoch'), min_delta=0.01, attributes_to_adjust=(), maxlen=5, dataset_args=None): 17 | if dataset_args is None: 18 | dataset_args = {} 19 | self.dataset_class = dataset_class 20 | self.dataset_args = dataset_args 21 | self.trigger = trigger_module.get_trigger(trigger) 22 | self.maxlen = maxlen 23 | self.queue = deque(maxlen=self.maxlen) 24 | self.min_delta = min_delta 25 | self.attributes_to_adjust = attributes_to_adjust 26 | self.blank_label = blank_label 27 | self.gpus = gpus 28 | self.force_enlarge_dataset = False 29 | self.training_finished = False 30 | 31 | with open(dataset_specification) as specification: 32 | specification = json.load(specification) 33 | 34 | self.train_curriculum = {i: s['train'] for i, s in enumerate(specification)} 35 | self.validation_curriculum = {i: s['validation'] for i, s in enumerate(specification)} 36 | 37 | self.current_level = 0 38 | 39 | def load_dataset(self, level): 40 | train_dataset = self.dataset_class(self.train_curriculum[level], **self.dataset_args) 41 | validation_dataset = self.dataset_class(self.validation_curriculum[level], **self.dataset_args) 42 | 43 | return train_dataset, validation_dataset 44 | 45 | def training_converged(self): 46 | # check whether system already settled and we can enlarge train set 47 | reference_value = self.queue[self.maxlen-1] 48 | deltas = [] 49 | for value in self.queue: 50 | deltas.append(abs(value - reference_value)) 51 | 52 | mean = statistics.mean(deltas) 53 | return mean <= self.min_delta 54 | 55 | def adjust_attributes(self, model, dataset): 56 | for attribute_name, attribute_path in self.attributes_to_adjust: 57 | chain = model 58 | # find the correct chain/link in our model as provided by attribute path 59 | for path in attribute_path: 60 | chain = getattr(chain, path) 61 | # set the corresponding attribute of our chain/link, with the attribute provided by the given dataset 62 | setattr(chain, attribute_name, getattr(dataset, attribute_name)) 63 | 64 | def __call__(self, trainer): 65 | if self.force_enlarge_dataset: 66 | self.force_enlarge_dataset = False 67 | self.enlarge_dataset(trainer) 68 | 69 | if self.trigger(trainer): 70 | with cuda.get_device_from_id(trainer.updater.get_optimizer('main').target._device_id): 71 | loss = trainer.observation.get('fast_validation/main/loss', None) 72 | if loss is None: 73 | return 74 | queue_data = loss.data if isinstance(loss, Variable) else loss 75 | self.queue.append(float(queue_data)) 76 | if len(self.queue) >= self.maxlen: 
77 | if not self.training_converged(): 78 | return 79 | 80 | self.enlarge_dataset(trainer) 81 | 82 | def enlarge_dataset(self, trainer): 83 | print("enlarging datasets") 84 | # we can add new samples to the train dataset 85 | self.current_level += 1 86 | if self.current_level < len(self.train_curriculum): 87 | train_dataset, validation_dataset = self.load_dataset(self.current_level) 88 | else: 89 | # we have exhausted our train curriculum we need to stop! 90 | self.training_finished = True 91 | print("Training curriculum has finished. Terminating the training process.\n") 92 | return 93 | self.update_iterators(trainer, train_dataset, validation_dataset) 94 | self.adjust_attributes(trainer.updater.get_optimizer('main').target, train_dataset) 95 | self.queue.clear() 96 | 97 | def split_dataset(self, dataset): 98 | gpu_datasets = split_dataset_n_random(dataset, len(self.gpus)) 99 | if not len(gpu_datasets[0]) == len(gpu_datasets[-1]): 100 | adapted_second_split = split_dataset(gpu_datasets[-1], len(gpu_datasets[0]))[0] 101 | gpu_datasets[-1] = adapted_second_split 102 | return gpu_datasets 103 | 104 | def pad_dataset(self, old_dataset, new_dataset): 105 | old_dataset.pad_labels(new_dataset.num_timesteps, self.blank_label) 106 | return old_dataset 107 | 108 | def update_iterators(self, trainer, train_dataset, validation_dataset): 109 | train_iterators = getattr(trainer.updater, '_mpu_iterators', None) 110 | if train_iterators is None: 111 | train_iterators = [trainer.updater.get_iterator('main')] 112 | 113 | validation_iterator = trainer.get_extension('fast_validation').get_iterator('main') 114 | 115 | # pad old dataset 116 | for train_iterator in train_iterators: 117 | train_iterator.dataset = self.pad_dataset(train_iterator.dataset, train_dataset) 118 | validation_iterator.dataset = self.pad_dataset(validation_iterator.dataset, validation_dataset) 119 | 120 | # concatenate new dataset with old dataset 121 | new_train_datasets = [] 122 | if len(train_iterators) > 1: 123 | for iterator, dataset in zip(train_iterators, self.split_dataset(train_dataset)): 124 | new_train_datasets.append(ConcatenatedDataset(dataset, iterator.dataset)) 125 | else: 126 | new_train_datasets.append(ConcatenatedDataset(train_dataset, train_iterators[0].dataset)) 127 | new_validation_dataset = ConcatenatedDataset(validation_dataset, validation_iterator.dataset) 128 | 129 | # create new iterator 130 | new_train_iterators = [iterator.__class__( 131 | dataset, 132 | iterator.batch_size, 133 | ) for iterator, dataset in zip(train_iterators, new_train_datasets)] 134 | 135 | new_validation_iterator = validation_iterator.__class__( 136 | new_validation_dataset, 137 | validation_iterator.batch_size, 138 | ) 139 | 140 | # exchange iterators in trainer 141 | if hasattr(trainer.updater, '_mpu_iterators'): 142 | # for iterator, worker in zip(new_train_iterators, trainer.updater._workers): 143 | # worker.iterator = iterator 144 | trainer.updater._mpu_iterators = new_train_iterators 145 | trainer.updater._iterators['main'] = new_train_iterators[0] 146 | trainer.get_extension('fast_validation')._iterators['main'] = new_validation_iterator 147 | 148 | # in case we have a real validation extension, not just our fast evaluator we also need to change the iterator there 149 | try: 150 | validator = trainer.get_extension('validation') 151 | copy_of_new_validation_iterator = copy.copy(new_validation_iterator) 152 | copy_of_new_validation_iterator._repeat = False 153 | validator._iterators['main'] = copy_of_new_validation_iterator 154 | 
except KeyError: 155 | pass 156 | 157 | for iterator in [*train_iterators, validation_iterator]: 158 | iterator.finalize() 159 | 160 | -------------------------------------------------------------------------------- /chainer/utils/create_gif.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | from collections import namedtuple 6 | 7 | from PIL import Image 8 | from PIL import ImageChops 9 | from PIL.GifImagePlugin import getheader, getdata 10 | 11 | SUPPORTED_IMAGETYPES = [".png", ".jpg", ".jpeg"] 12 | ImageData = namedtuple("ImageData", ["file_name", "image"]) 13 | 14 | 15 | def intToBin(i): 16 | """ Integer to two bytes """ 17 | # devide in two parts (bytes) 18 | i1 = i % 256 19 | i2 = int( i/256) 20 | # make string (little endian) 21 | return chr(i1) + chr(i2) 22 | 23 | 24 | def create_loop_header(loops=0): 25 | if loops == 0 or loops == float('inf'): 26 | loops = 2 ** 16 - 1 27 | 28 | bb = "\x21\xFF\x0B" # application extension 29 | bb += "NETSCAPE2.0" 30 | bb += "\x03\x01" 31 | bb += intToBin(loops) 32 | bb += '\x00' # end 33 | return [bb.encode('utf-8')] 34 | 35 | 36 | def makedelta(fp, sequence): 37 | """Convert list of image frames to a GIF animation file""" 38 | 39 | frames = 0 40 | 41 | previous = None 42 | 43 | for im in sequence: 44 | 45 | # To specify duration, add the time in milliseconds to getdata(), 46 | # e.g. getdata(im, duration=1000) 47 | 48 | if not previous: 49 | 50 | # global header 51 | loops = 2 ** 16 - 1 52 | for s in getheader(im, info={"loop": loops})[0] + getdata(im, duration=10, loop=2 ** 16 - 1): 53 | fp.write(s) 54 | 55 | else: 56 | 57 | # delta frame 58 | delta = ImageChops.subtract_modulo(im, previous) 59 | 60 | bbox = delta.getbbox() 61 | 62 | if bbox: 63 | 64 | # compress difference 65 | for s in getdata(im.crop(bbox), offset=bbox[:2], duration=10): 66 | fp.write(s) 67 | 68 | else: 69 | # FIXME: what should we do in this case? 
70 | pass 71 | 72 | previous = im.copy() 73 | 74 | frames += 1 75 | 76 | fp.write(b";") 77 | 78 | return frames 79 | 80 | 81 | def make_gif(image_dir, dest_file, pattern="(\d+)"): 82 | sort_pattern = re.compile(pattern) 83 | 84 | image_files = filter(lambda x: os.path.splitext(x)[-1] in SUPPORTED_IMAGETYPES, os.listdir(image_dir)) 85 | images = [] 86 | 87 | try: 88 | print("loading images") 89 | for file_name in image_files: 90 | path = os.path.join(image_dir, file_name) 91 | images.append(ImageData._make((file_name, Image.open(path).convert('P')))) 92 | 93 | print("sorting images") 94 | images_sorted = sorted(images, key=lambda x: int(re.search(sort_pattern, x.file_name).group(1))) 95 | 96 | print("writing gif") 97 | with open(dest_file, "wb") as out_file: 98 | makedelta(out_file, [image.image for image in images_sorted]) 99 | 100 | finally: 101 | for image in images: 102 | image.image.close() 103 | 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser(description='Tool that creates a gif out of a number of given input images') 107 | parser.add_argument("image_dir", help="path to directory that contains all images that shall be converted to a gif") 108 | parser.add_argument("dest_file", help="path to destination gif file") 109 | parser.add_argument("--pattern", default="(\d+)", help="naming pattern to extract the ordering of the images") 110 | 111 | args = parser.parse_args() 112 | 113 | make_gif(args.image_dir, args.dest_file, args.pattern) -------------------------------------------------------------------------------- /chainer/utils/crop_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import tqdm 5 | 6 | from PIL import Image 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("source_dir", help="path to directory with source files") 12 | parser.add_argument("dest_dir", help="path to directory with cropped files") 13 | parser.add_argument("crop_window", type=int, nargs="*", default=[150, 0, 750, 300], help="crop window in (x, y, x, y) starting at top left and then bottom right") 14 | 15 | args = parser.parse_args() 16 | 17 | os.makedirs(args.dest_dir, exist_ok=True) 18 | for image_name in tqdm.tqdm(os.listdir(args.source_dir)): 19 | image_path = os.path.join(args.source_dir, image_name) 20 | with Image.open(image_path) as the_image: 21 | the_image = the_image.crop(args.crop_window) 22 | the_image.save(os.path.join(args.dest_dir, image_name)) 23 | -------------------------------------------------------------------------------- /chainer/utils/datatypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Size = namedtuple('Size', ['height', 'width']) 4 | -------------------------------------------------------------------------------- /chainer/utils/dict_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | 4 | import hunspell 5 | from collections import Counter 6 | from itertools import zip_longest 7 | 8 | from tqdm import tqdm 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description="tool that does extra dictionary correction on top of given predictions, including gt") 12 | parser.add_argument("evaluation_result", help="path to eval result file") 13 | parser.add_argument("dictionary", help="path to dictionary file") 14 | 
parser.add_argument("affix", help="path to affix file") 15 | parser.add_argument("--extra-dict", help="extra file containing words to add to dict") 16 | 17 | args = parser.parse_args() 18 | 19 | with open(args.evaluation_result) as eval_result_file: 20 | reader = csv.reader(eval_result_file) 21 | lines = [l for l in reader] 22 | 23 | hobj = hunspell.HunSpell(args.dictionary, args.affix) 24 | 25 | if args.extra_dict: 26 | with open(args.extra_dict) as extra_dict: 27 | for line in extra_dict: 28 | hobj.add(line) 29 | 30 | num_word_x_correct = Counter() 31 | num_word_x = Counter() 32 | correct_words = 0 33 | num_words = 0 34 | num_correct_lines = 0 35 | num_lines = 0 36 | 37 | for prediction, gt in tqdm(lines): 38 | corrected_words = [] 39 | for word in prediction.split(): 40 | corrected_word = word 41 | if not hobj.spell(word): 42 | suggestions = hobj.suggest(word) 43 | if len(suggestions) > 0: 44 | corrected_word = suggestions[0] 45 | 46 | corrected_words.append(corrected_word) 47 | 48 | line_correct = True 49 | for i, (corrected_word, gt_word) in enumerate(zip_longest(corrected_words, gt.split(), fillvalue=''), start=1): 50 | if corrected_word.lower() == gt_word.lower(): 51 | correct_words += 1 52 | num_word_x_correct[i] += 1 53 | else: 54 | line_correct = False 55 | 56 | num_words += 1 57 | num_word_x[i] += 1 58 | if line_correct: 59 | num_correct_lines += 1 60 | num_lines += 1 61 | 62 | print("Sequence Accuracy: {}".format(num_correct_lines / num_lines)) 63 | print("Word Accuracy: {}".format(correct_words / num_words)) 64 | print("Single word accuracies:") 65 | for i in range(1, len(num_word_x) + 1): 66 | print("Accuracy for Word {}: {}".format(i, num_word_x_correct[i] / num_word_x[i])) 67 | -------------------------------------------------------------------------------- /chainer/utils/intelligent_attribute_shifter.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | from collections import deque 3 | 4 | import math 5 | 6 | from chainer import cuda, Variable 7 | from chainer.training import extension 8 | import chainer.training.trigger as trigger_module 9 | 10 | 11 | class IntelligentAttributeShifter(extension.Extension): 12 | 13 | def __init__(self, shift, attr='lr', trigger=(1, 'iteration'), min_delta=0.1): 14 | self.shift = shift 15 | self.attr = attr 16 | self.trigger = trigger_module.get_trigger(trigger) 17 | self.queue = deque(maxlen=5) 18 | self.min_delta = min_delta 19 | self.force_shift = False 20 | 21 | def __call__(self, trainer): 22 | if self.force_shift: 23 | self.shift_attribute(trainer) 24 | self.force_shift = False 25 | return 26 | 27 | if self.trigger(trainer): 28 | with cuda.get_device_from_id(trainer.updater.get_optimizer('main').target._device_id): 29 | loss = trainer.observation.get('validation/main/loss', None) 30 | if loss is None: 31 | return 32 | queue_data = loss.data if isinstance(loss, Variable) else loss 33 | self.queue.append(float(queue_data)) 34 | if len(self.queue) == self.queue.maxlen: 35 | # check whether we need to shift attribute 36 | deltas = [] 37 | rotated_queue = self.queue.copy() 38 | rotated_queue.rotate(-1) 39 | rotated_queue.pop() 40 | for element_1, element_2 in zip(self.queue, rotated_queue): 41 | deltas.append(abs(element_1 - element_2)) 42 | 43 | delta = sum(deltas) / len(deltas) 44 | # if change over last 5 validations was lower than min change shift attribute 45 | if delta < self.min_delta: 46 | self.shift_attribute(trainer) 47 | 48 | def shift_attribute(self, trainer): 49 | 
print("Shifting attribute {}".format(self.attr)) 50 | optimizer = trainer.updater.get_optimizer('main') 51 | current_value = getattr(optimizer, self.attr) 52 | shifted_value = current_value * self.shift 53 | setattr(optimizer, self.attr, shifted_value) 54 | self.queue.clear() 55 | -------------------------------------------------------------------------------- /chainer/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from chainer.training.extensions import LogReport 5 | 6 | 7 | class Logger(LogReport): 8 | 9 | def __init__(self, model_files, log_dir, keys=None, trigger=(1, 'epoch'), postprocess=None, log_name='log'): 10 | super(Logger, self).__init__(keys=keys, trigger=trigger, postprocess=postprocess, log_name=log_name) 11 | self.backup_model(model_files, log_dir) 12 | 13 | def backup_model(self, model_files, log_dir): 14 | if not os.path.exists(log_dir): 15 | os.makedirs(log_dir, exist_ok=True) 16 | for model_file in model_files: 17 | shutil.copy(model_file, log_dir) 18 | -------------------------------------------------------------------------------- /chainer/utils/multi_accuracy_classifier.py: -------------------------------------------------------------------------------- 1 | from chainer import reporter 2 | from chainer.functions import accuracy 3 | from chainer.functions.loss import softmax_cross_entropy 4 | from chainer.links import Classifier as OriginalClassifier 5 | 6 | 7 | class Classifier(OriginalClassifier): 8 | """ 9 | Classifier that is able to log two different accuracies 10 | """ 11 | 12 | def __init__(self, predictor, accuracy_types, 13 | lossfun=softmax_cross_entropy.softmax_cross_entropy, accfun=accuracy, provide_label_during_forward=False): 14 | super(Classifier, self).__init__(predictor, lossfun=lossfun, accfun=accfun) 15 | assert type(accuracy_types) is tuple, "accuracy_types must be a tuple of strings" 16 | self.accuracy_types = accuracy_types 17 | self.provide_label_during_forward = provide_label_during_forward 18 | 19 | def __call__(self, *args): 20 | """Computes the loss value for an input and label pair. 21 | 22 | It also computes accuracy and stores it to the attribute. 23 | 24 | Args: 25 | args (list of ~chainer.Variable): Input minibatch. 26 | 27 | The all elements of ``args`` but last one are features and 28 | the last element corresponds to ground truth labels. 29 | It feeds features to the predictor and compare the result 30 | with ground truth labels. 31 | 32 | Returns: 33 | ~chainer.Variable: Loss value. 
34 | 35 | """ 36 | assert len(args) >= 2 37 | x = args[:-1] 38 | t = args[-1] 39 | self.y = None 40 | self.loss = None 41 | if self.provide_label_during_forward: 42 | self.y = self.predictor(*x, t) 43 | else: 44 | self.y = self.predictor(*x) 45 | self.loss = self.lossfun(self.y, t) 46 | reporter.report({'loss': self.loss}, self) 47 | if self.compute_accuracy: 48 | reported_accuracies = self.accfun(self.y, t) 49 | if len(self.accuracy_types) == 1: 50 | reported_accuracies = reported_accuracies, 51 | report = {accuracy_type: reported_accuracy 52 | for accuracy_type, reported_accuracy in zip(self.accuracy_types, reported_accuracies)} 53 | reporter.report(report, self) 54 | return self.loss 55 | -------------------------------------------------------------------------------- /chainer/utils/plotting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | class LogPlotter(object): 11 | 12 | def __init__(self, log_file): 13 | self.log_file = log_file 14 | self.train_iterations = {} 15 | self.test_iterations = {} 16 | 17 | def parse_log_file(self, start=0, end=None): 18 | with open(self.log_file) as log_file: 19 | log_data = json.load(log_file) 20 | if end is None: 21 | end = log_data[-1]['iteration'] 22 | 23 | events_to_plot = filter(lambda x: start <= x['iteration'] <= end, log_data) 24 | for event in events_to_plot: 25 | iteration = event.pop('iteration') 26 | self.train_iterations[iteration] = { 27 | key.rsplit('/')[-1]: event[key] for key in 28 | filter(lambda x: ('accuracy' in x or 'loss' in x) and 'validation' not in x, event) 29 | } 30 | 31 | test_keys = list(filter(lambda x: 'validation' in x, event)) 32 | if len(test_keys) > 0: 33 | self.test_iterations[iteration] = { 34 | key.rsplit('/')[-1]: event[key] for key in 35 | filter(lambda x: ('accuracy' in x or 'loss' in x), test_keys) 36 | } 37 | 38 | def plot(self, start=0, end=None): 39 | self.parse_log_file(start=start, end=end) 40 | 41 | metrics_to_plot = sorted(next(iter(self.train_iterations.values())).keys(), key=lambda x: x.rsplit('_')) 42 | fig, axes = plt.subplots(len(metrics_to_plot), sharex=True) 43 | 44 | x_train = list(sorted(self.train_iterations.keys())) 45 | x_test = list(sorted(self.test_iterations.keys())) 46 | 47 | for metric, axe in zip(metrics_to_plot, axes): 48 | axe.plot(x_train, [self.train_iterations[iteration][metric] for iteration in x_train], 'r.-', label='train') 49 | axe.plot(x_test, [self.test_iterations[iteration][metric] for iteration in x_test], 'g.-', label='test') 50 | 51 | axe.set_title(metric) 52 | 53 | box = axe.get_position() 54 | axe.set_position([box.x0, box.y0, box.width * 0.9, box.height]) 55 | 56 | axe.legend(bbox_to_anchor=(1, 0.5), loc='center left', fancybox=True, shadow=True) 57 | 58 | return fig 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser(description='tool to create plots of training') 63 | parser.add_argument("log_file", help="path to log file") 64 | parser.add_argument("-d", "--destination", dest='destination', help='where to save the resulting plot') 65 | parser.add_argument("-f", "--from", dest='start', default=0, type=int, help="start index from which you want to plot") 66 | parser.add_argument("-t", "--to", dest='end', type=int, help="index until which you want to plot (default: end)") 67 | 68 | args = parser.parse_args() 69 | 70 | plotter = LogPlotter(args.log_file) 71 | fig = 
plotter.plot(start=args.start, end=args.end) 72 | if args.destination is None: 73 | plt.show() 74 | else: 75 | fig.savefig(args.destination) -------------------------------------------------------------------------------- /chainer/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import copy 3 | import os 4 | import sys 5 | import time 6 | import datetime 7 | 8 | import six 9 | from chainer import reporter as reporter_module 10 | from chainer import variable 11 | from chainer.dataset import convert, concat_examples 12 | from chainer.training import extension 13 | from chainer.training import extensions 14 | import chainer.training.trigger as trigger_module 15 | from chainer.training.triggers import IntervalTrigger 16 | from chainer.training.extensions import Evaluator 17 | from chainer.training.extensions import util 18 | 19 | from .logger import Logger 20 | 21 | 22 | class AttributeUpdater(extension.Extension): 23 | 24 | def __init__(self, shift, attr='lr', trigger=(1, 'epoch')): 25 | self.shift = shift 26 | self.attr = attr 27 | self.trigger = trigger_module.get_trigger(trigger) 28 | 29 | def __call__(self, trainer): 30 | if self.trigger(trainer): 31 | optimizer = trainer.updater.get_optimizer('main') 32 | current_value = getattr(optimizer, self.attr) 33 | shifted_value = current_value * self.shift 34 | setattr(optimizer, self.attr, shifted_value) 35 | 36 | 37 | class TwoStateLearningRateShifter(extension.Extension): 38 | 39 | CONTINUOS_SHIFT_STATE = 0 40 | INTERVAL_BASED_SHIFT_STATE = 1 41 | 42 | def __init__(self, start_lr, states): 43 | self.start_lr = start_lr 44 | self.lr = start_lr 45 | self.states = states 46 | self.current_state = self.states.pop(0) 47 | self.start_epoch = 0 48 | self.start_iteration = 0 49 | self.set_triggers() 50 | 51 | def set_triggers(self): 52 | self.target_lr = self.current_state['target_lr'] 53 | self.update_trigger = trigger_module.get_trigger(self.current_state['update_trigger']) 54 | self.stop_trigger = trigger_module.get_trigger(self.current_state['stop_trigger']) 55 | self.phase_length, self.unit = self.current_state['stop_trigger'] 56 | 57 | def switch_state_if_necessary(self, trainer): 58 | if self.stop_trigger(trainer): 59 | if len(self.states) > 1: 60 | self.current_state = self.states.pop(0) 61 | self.set_triggers() 62 | self.start_lr = self.target_lr 63 | self.start_epoch = trainer.updater.epoch 64 | self.start_iteration = self.update_trigger.iteration 65 | 66 | def __call__(self, trainer): 67 | updater = trainer.updater 68 | optimizer = trainer.updater.get_optimizer('main') 69 | 70 | if self.update_trigger(trainer): 71 | if self.current_state['state'] == self.CONTINUOS_SHIFT_STATE: 72 | epoch = updater.epoch_detail 73 | 74 | if self.unit == 'iteration': 75 | interpolation_factor = (updater.iteration - self.start_iteration) / self.phase_length 76 | else: 77 | interpolation_factor = (epoch - self.start_epoch) / self.phase_length 78 | 79 | new_lr = (1 - interpolation_factor) * self.start_lr + interpolation_factor * self.target_lr 80 | self.lr = new_lr 81 | optimizer.lr = new_lr 82 | 83 | else: 84 | optimizer.lr = self.target_lr 85 | self.lr = optimizer.lr 86 | 87 | self.switch_state_if_necessary(trainer) 88 | 89 | 90 | class FastEvaluatorBase(Evaluator): 91 | 92 | def __init__(self, iterator, target, converter=convert.concat_examples, 93 | device=None, eval_hook=None, eval_func=None, num_iterations=200): 94 | super(FastEvaluatorBase, self).__init__( 95 | iterator, 96 | 
target, 97 | converter=converter, 98 | device=device, 99 | eval_hook=eval_hook, 100 | eval_func=eval_func 101 | ) 102 | self.num_iterations = num_iterations 103 | 104 | def evaluate(self): 105 | iterator = self._iterators['main'] 106 | target = self._targets['main'] 107 | eval_func = self.eval_func or target 108 | 109 | if self.eval_hook: 110 | self.eval_hook(self) 111 | it = copy.copy(iterator) 112 | summary = reporter_module.DictSummary() 113 | 114 | for _ in range(min(len(iterator.dataset) // iterator.batch_size, self.num_iterations)): 115 | batch = next(it, None) 116 | if batch is None: 117 | break 118 | 119 | observation = {} 120 | with reporter_module.report_scope(observation), chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 121 | in_arrays = self.converter(batch, self.device) 122 | if isinstance(in_arrays, tuple): 123 | eval_func(*in_arrays) 124 | elif isinstance(in_arrays, dict): 125 | eval_func(**in_arrays) 126 | else: 127 | eval_func(in_arrays) 128 | 129 | summary.add(observation) 130 | 131 | return summary.compute_mean() 132 | 133 | 134 | def get_fast_evaluator(trigger_interval): 135 | return type('FastEvaluator', (FastEvaluatorBase,), dict(trigger=trigger_interval, name='fast_validation')) 136 | 137 | 138 | class EarlyStopIntervalTrigger(IntervalTrigger): 139 | 140 | def __init__(self, period, unit, curriculum): 141 | super().__init__(period, unit) 142 | self.curriculum = curriculum 143 | 144 | def __call__(self, trainer): 145 | fire = super().__call__(trainer) 146 | if self.curriculum.training_finished is True: 147 | fire = True 148 | return fire 149 | 150 | 151 | def get_trainer(net, updater, log_dir, print_fields, curriculum=None, extra_extensions=(), epochs=10, snapshot_interval=20000, print_interval=100, postprocess=None, do_logging=True, model_files=()): 152 | if curriculum is None: 153 | trainer = chainer.training.Trainer( 154 | updater, 155 | (epochs, 'epoch'), 156 | out=log_dir, 157 | ) 158 | else: 159 | trainer = chainer.training.Trainer( 160 | updater, 161 | EarlyStopIntervalTrigger(epochs, 'epoch', curriculum), 162 | out=log_dir, 163 | ) 164 | 165 | # dump computational graph 166 | trainer.extend(extensions.dump_graph('main/loss')) 167 | 168 | # also observe learning rate 169 | observe_lr_extension = chainer.training.extensions.observe_lr() 170 | observe_lr_extension.trigger = (print_interval, 'iteration') 171 | trainer.extend(observe_lr_extension) 172 | 173 | # Take snapshots 174 | trainer.extend( 175 | extensions.snapshot(filename="trainer_snapshot"), 176 | trigger=lambda trainer: 177 | trainer.updater.is_new_epoch or 178 | (trainer.updater.iteration > 0 and trainer.updater.iteration % snapshot_interval == 0) 179 | ) 180 | 181 | if do_logging: 182 | # write all statistics to a file 183 | trainer.extend(Logger(model_files, log_dir, keys=print_fields, trigger=(print_interval, 'iteration'), postprocess=postprocess)) 184 | 185 | # print some interesting statistics 186 | trainer.extend(extensions.PrintReport( 187 | print_fields, 188 | log_report='Logger', 189 | )) 190 | 191 | # Progressbar!! 
192 | trainer.extend(extensions.ProgressBar(update_interval=1)) 193 | 194 | for extra_extension, trigger in extra_extensions: 195 | trainer.extend(extra_extension, trigger=trigger) 196 | 197 | return trainer 198 | 199 | 200 | def add_default_arguments(parser): 201 | parser.add_argument("log_dir", help='directory where generated models and logs shall be stored') 202 | parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, required=True, 203 | help="Number of images per training batch") 204 | parser.add_argument('-g', '--gpus', type=int, nargs="*", default=[], help="Ids of GPU to use [default: (use cpu)]") 205 | parser.add_argument('-e', '--epochs', type=int, default=20, help="Number of epochs to train [default: 20]") 206 | parser.add_argument('-r', '--resume', help="path to previously saved state of trained model from which training shall resume") 207 | parser.add_argument('-si', '--snapshot-interval', dest='snapshot_interval', type=int, default=20000, 208 | help="number of iterations after which a snapshot shall be taken [default: 20000]") 209 | parser.add_argument('-ln', '--log-name', dest='log_name', default='training', help="name of the log folder") 210 | parser.add_argument('-lr', '--learning-rate', dest='learning_rate', type=float, default=0.01, 211 | help="initial learning rate [default: 0.01]") 212 | parser.add_argument('-li', '--log-interval', dest='log_interval', type=int, default=100, 213 | help="number of iterations after which an update shall be logged [default: 100]") 214 | parser.add_argument('--lr-step', dest='learning_rate_step_size', type=float, default=0.1, 215 | help="Step size for decreasing learning rate [default: 0.1]") 216 | parser.add_argument('-t', '--test-interval', dest='test_interval', type=int, default=1000, 217 | help="number of iterations after which testing should be performed [default: 1000]") 218 | parser.add_argument('--test-iterations', dest='test_iterations', type=int, default=200, 219 | help="number of test iterations [default: 200]") 220 | parser.add_argument("-dr", "--dropout-ratio", dest='dropout_ratio', default=0.5, type=float, 221 | help="ratio for dropout layers") 222 | 223 | return parser 224 | 225 | 226 | def get_concat_and_pad_examples(padding=-10000): 227 | def concat_and_pad_examples(batch, device=None): 228 | return concat_examples(batch, device=device, padding=padding) 229 | 230 | return concat_and_pad_examples 231 | 232 | 233 | def concat_and_pad_examples(batch, device=None, padding=-10000): 234 | return concat_examples(batch, device=device, padding=padding) 235 | 236 | 237 | def get_definition_filepath(obj): 238 | return __import__(obj.__module__, fromlist=obj.__module__.split('.')[:1]).__file__ 239 | 240 | 241 | def get_definition_filename(obj): 242 | return os.path.basename(get_definition_filepath(obj)) 243 | 244 | -------------------------------------------------------------------------------- /datasets/fsns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/datasets/fsns/__init__.py -------------------------------------------------------------------------------- /datasets/fsns/change_file_names.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | 4 | from tqdm import tqdm 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description='tool that removes a given prefix from the file names in the given 
csv file') 8 | parser.add_argument("csv_file", help="path to file where paths shall be changed") 9 | parser.add_argument("prefix", help='prefix to remove') 10 | 11 | args = parser.parse_args() 12 | 13 | with open(args.csv_file) as csv_file: 14 | reader = csv.reader(csv_file, delimiter='\t') 15 | lines = [l for l in reader] 16 | 17 | for line in tqdm(lines[1:]): 18 | path = line[0] 19 | path = path[len(args.prefix):] 20 | line[0] = path 21 | 22 | with open(args.csv_file, 'w') as csv_file: 23 | writer = csv.writer(csv_file, delimiter='\t') 24 | writer.writerows(lines) 25 | -------------------------------------------------------------------------------- /datasets/fsns/download_fsns.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import urllib.request 4 | 5 | 6 | BASE_URL = "http://download.tensorflow.org/data/fsns-20160927/" 7 | 8 | SETS = [ 9 | ('test', 0, 64), 10 | ('train', 0, 512), 11 | ('validation', 0, 64), 12 | ] 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser(description='tool that downloads fsns dataset') 16 | parser.add_argument("destination_dir", help='path to destination directory') 17 | 18 | args = parser.parse_args() 19 | 20 | for set_name, start_part, num_parts in SETS: 21 | dest_dir = os.path.join(args.destination_dir, set_name) 22 | os.makedirs(dest_dir, exist_ok=True) 23 | 24 | for part in range(start_part, num_parts): 25 | file_name = "{set_name}-{part:0>5}-of-{num_parts:0>5}".format( 26 | set_name=set_name, 27 | part=part, 28 | num_parts=num_parts, 29 | ) 30 | 31 | url = "{base}{set_name}/{file_name}".format( 32 | base=BASE_URL, 33 | set_name=set_name, 34 | file_name=file_name, 35 | ) 36 | 37 | file_size = int(urllib.request.urlopen(url).info()['Content-Length']) 38 | 39 | if (not os.path.exists(os.path.join(dest_dir, file_name)) or 40 | os.stat(os.path.join(dest_dir, file_name)).st_size != file_size): 41 | print("downloading {}".format(file_name)) 42 | with urllib.request.urlopen(url) as url_data, open(os.path.join(dest_dir, file_name), 'wb') as f: 43 | file_size = int(url_data.info()['Content-Length']) 44 | 45 | downloaded = 0 46 | block_size = 8192 47 | while True: 48 | buffer = url_data.read(block_size) 49 | if not buffer: 50 | break 51 | 52 | downloaded += len(buffer) 53 | f.write(buffer) 54 | print("Got: {:>10} of {:>10} bytes".format(downloaded, file_size), end='\r') 55 | 56 | print("{}".format(" " * 100), end='\r') 57 | else: 58 | print('File already found at:{location}, Continuing...'.format( 59 | location=os.path.join(dest_dir, file_name))) 60 | continue 61 | -------------------------------------------------------------------------------- /datasets/fsns/extract_words.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import os 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def extract_words_from_gt(gt_file, extracted_words): 10 | with open(gt_file) as gt: 11 | reader = csv.reader(gt, delimiter='\t') 12 | lines = [l for l in reader] 13 | 14 | for line in tqdm(lines): 15 | labels = line[1:] 16 | text_line = ''.join([chr(char_map[x]) for x in labels]) 17 | words = text_line.strip(chr(char_map[args.blank_label])).split() 18 | 19 | for word in words: 20 | extracted_words.add(word) 21 | 22 | return extracted_words 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser(description="extract all words used in fsns dataset") 27 | parser.add_argument("gt", help='path 
to fsns gt dir') 28 | parser.add_argument("dest", help='destination dict') 29 | parser.add_argument("char_map", help='path to char map') 30 | parser.add_argument('--blank-label', default='133', help='class number of blank label') 31 | 32 | args = parser.parse_args() 33 | 34 | with open(args.char_map) as the_char_map: 35 | char_map = json.load(the_char_map) 36 | reverse_char_map = {v: k for k, v in char_map.items()} 37 | 38 | gt_files = filter(lambda x: os.path.splitext(x)[-1] == '.csv', os.listdir(args.gt)) 39 | 40 | words = set() 41 | for gt_file in gt_files: 42 | words = extract_words_from_gt(os.path.join(args.gt, gt_file), words) 43 | 44 | with open(args.dest, 'w') as destination: 45 | print(len(words), file=destination) 46 | for word in words: 47 | print(word, file=destination) 48 | -------------------------------------------------------------------------------- /datasets/fsns/fonts/DejaVuSansMono-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/datasets/fsns/fonts/DejaVuSansMono-Bold.ttf -------------------------------------------------------------------------------- /datasets/fsns/fonts/DejaVuSansMono.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/see/2ad5159027759a9f77dfc3b75686a5b8266c5282/datasets/fsns/fonts/DejaVuSansMono.ttf -------------------------------------------------------------------------------- /datasets/fsns/fsns_char_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": 9250, 3 | "1": 108, 4 | "2": 8216, 5 | "3": 233, 6 | "4": 116, 7 | "5": 101, 8 | "6": 105, 9 | "7": 110, 10 | "8": 115, 11 | "9": 120, 12 | "10": 103, 13 | "11": 117, 14 | "12": 111, 15 | "13": 49, 16 | "14": 56, 17 | "15": 55, 18 | "16": 48, 19 | "17": 8212, 20 | "18": 46, 21 | "19": 112, 22 | "20": 97, 23 | "21": 114, 24 | "22": 232, 25 | "23": 100, 26 | "24": 99, 27 | "25": 86, 28 | "26": 118, 29 | "27": 98, 30 | "28": 109, 31 | "29": 41, 32 | "30": 67, 33 | "31": 122, 34 | "32": 83, 35 | "33": 121, 36 | "34": 44, 37 | "35": 107, 38 | "36": 201, 39 | "37": 65, 40 | "38": 104, 41 | "39": 69, 42 | "40": 187, 43 | "41": 68, 44 | "42": 47, 45 | "43": 72, 46 | "44": 77, 47 | "45": 40, 48 | "46": 71, 49 | "47": 80, 50 | "48": 231, 51 | "49": 82, 52 | "50": 102, 53 | "51": 8221, 54 | "52": 50, 55 | "53": 106, 56 | "54": 124, 57 | "55": 78, 58 | "56": 54, 59 | "57": 176, 60 | "58": 53, 61 | "59": 84, 62 | "60": 79, 63 | "61": 85, 64 | "62": 51, 65 | "63": 37, 66 | "64": 57, 67 | "65": 113, 68 | "66": 90, 69 | "67": 66, 70 | "68": 75, 71 | "69": 119, 72 | "70": 87, 73 | "71": 58, 74 | "72": 52, 75 | "73": 76, 76 | "74": 70, 77 | "75": 93, 78 | "76": 239, 79 | "77": 73, 80 | "78": 74, 81 | "79": 228, 82 | "80": 238, 83 | "81": 59, 84 | "82": 224, 85 | "83": 234, 86 | "84": 88, 87 | "85": 252, 88 | "86": 89, 89 | "87": 244, 90 | "88": 61, 91 | "89": 43, 92 | "90": 92, 93 | "91": 123, 94 | "92": 125, 95 | "93": 95, 96 | "94": 81, 97 | "95": 339, 98 | "96": 241, 99 | "97": 42, 100 | "98": 33, 101 | "99": 220, 102 | "100": 226, 103 | "101": 199, 104 | "102": 338, 105 | "103": 251, 106 | "104": 63, 107 | "105": 36, 108 | "106": 235, 109 | "107": 171, 110 | "108": 8364, 111 | "109": 38, 112 | "110": 60, 113 | "111": 230, 114 | "112": 35, 115 | "113": 174, 116 | "114": 194, 117 | "115": 200, 118 | "116": 62, 119 | "117": 91, 120 | "118": 198, 121 | "119": 249, 
122 | "120": 206, 123 | "121": 212, 124 | "122": 255, 125 | "123": 192, 126 | "124": 202, 127 | "125": 64, 128 | "126": 207, 129 | "127": 169, 130 | "128": 203, 131 | "129": 217, 132 | "130": 163, 133 | "131": 376, 134 | "132": 219, 135 | "133": 32 136 | } 137 | -------------------------------------------------------------------------------- /datasets/fsns/render_text_on_signs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | 5 | import random 6 | 7 | import os 8 | import time 9 | from PIL import Image 10 | from PIL import ImageDraw 11 | from PIL import ImageFont 12 | from functools import lru_cache 13 | 14 | 15 | SUPPORTED_IMAGES = ['.jpg', '.png', '.jpeg'] 16 | FONT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), 'fonts')) 17 | FONTS = [os.path.abspath(os.path.join(FONT_DIR, font)) for font in os.listdir(FONT_DIR)] 18 | MAX_IMAGES_PER_DIR = 1000 19 | 20 | 21 | def get_image_paths(dir): 22 | all_images = [] 23 | for root, _, files in os.walk(dir): 24 | files = [os.path.join(root, f) for f in filter(lambda x: os.path.splitext(x)[-1].lower() in SUPPORTED_IMAGES, files)] 25 | all_images.extend(files) 26 | 27 | return all_images 28 | 29 | 30 | @lru_cache(maxsize=1024) 31 | def get_image(image_path): 32 | return Image.open(image_path) 33 | 34 | 35 | @lru_cache(maxsize=1024) 36 | def find_font_size(draw, text_lines, max_width, max_height, spacing): 37 | # start with a default font size that should be large enough to be too large 38 | font_size = 35 39 | 40 | # reload the font until the word fits or the font size would be too small 41 | while True: 42 | font = ImageFont.truetype(random.choice(FONTS), size=font_size, encoding='unic') 43 | text_width, text_height = draw.multiline_textsize(text_lines, font, spacing=spacing) 44 | 45 | if text_width <= max_width and text_height <= max_height: 46 | return font, (text_width, text_height) 47 | 48 | font_size -= 1 49 | 50 | if font_size <= 1: 51 | raise ValueError('Can not draw Text on given image') 52 | 53 | 54 | def random_crop(image, crop_width, crop_height): 55 | left = min(random.choice(range(image.width)), image.width - crop_width) 56 | top = min(random.choice(range(image.height)), image.height - crop_height) 57 | 58 | return image.crop(( 59 | left, 60 | top, 61 | left + crop_width, 62 | top + crop_height, 63 | )) 64 | 65 | 66 | def save_image(index, image, base_dest_dir, split_ratio): 67 | 68 | def get_subdir(): 69 | return os.path.join( 70 | "{:04d}".format(index // (MAX_IMAGES_PER_DIR * MAX_IMAGES_PER_DIR)), 71 | "{:04d}".format(index // MAX_IMAGES_PER_DIR) 72 | ) 73 | 74 | split_dir = "validation" if random.randint(1, 100) < split_ratio else "train" 75 | dest_dir = os.path.join(base_dest_dir, split_dir, get_subdir()) 76 | os.makedirs(dest_dir, exist_ok=True) 77 | 78 | image_path = "{}.png".format(os.path.join(dest_dir, str(index))) 79 | image.save(image_path) 80 | 81 | return image_path, split_dir == 'train' 82 | 83 | 84 | def get_labels(words, char_map, blank_label, max_length, max_textlines): 85 | all_labels = [] 86 | 87 | for word in words: 88 | labels = [char_map[ord(char)] for char in word] 89 | labels.extend([blank_label] * (max_length - len(labels))) 90 | all_labels.append(labels) 91 | 92 | if len(all_labels) < max_textlines: 93 | all_labels.extend([[blank_label] * max_length for _ in range(max_textlines - len(all_labels))]) 94 | 95 | return [label for labels in all_labels for label in labels] 96 | 97 | 98 | if __name__ == "__main__": 
99 | parser = argparse.ArgumentParser(description='Tool that renders text on street sign images and puts them into a nice context') 100 | parser.add_argument('wordlist', help='path to wordlist file') 101 | parser.add_argument('sign_dir', help='path to a directory containing empty signs to render text on') 102 | parser.add_argument('background_dir', help='path to a directory holding dirs/images of background images where signs will be placed on') 103 | parser.add_argument('destination', help='where to put the generated samples') 104 | parser.add_argument('char_map', help='path to json file that contains a char map') 105 | parser.add_argument('num_samples', type=int, help='number of samples to create') 106 | parser.add_argument('--blank-label', type=int, default=133, help='index of blank label [default: 133]') 107 | parser.add_argument('--max-length', type=int, default=10, help='max length of word on signs [default: 10]') 108 | parser.add_argument('--label-length', type=int, default=37, help='label length for each rendered text line [default: 37]') 109 | parser.add_argument('--image-size', type=int, default=150, help='size of resulting images [default: 150 x 150]') 110 | parser.add_argument('--split-ratio', type=int, default=15, help='percentage of samples to use for validation [default: 15]') 111 | parser.add_argument('--max-textlines', type=int, default=3, help='maximum number of text lines per rendered image [default: 3]') 112 | parser.add_argument('--min-textlines', type=int, default=1, help='minimum number of text lines per rendered image [default: 1]') 113 | 114 | args = parser.parse_args() 115 | 116 | print("loading sign images") 117 | sign_image_paths = get_image_paths(args.sign_dir) 118 | sign_images = [get_image(path) for path in sign_image_paths] 119 | 120 | print("getting background images") 121 | background_image_paths = get_image_paths(args.background_dir) 122 | 123 | print("loading wordlist") 124 | with open(args.wordlist) as the_wordlist: 125 | wordlist = [word.strip() for word in filter(lambda x: len(x) <= args.max_length and "'" not in x, the_wordlist)] 126 | 127 | print("opening char map") 128 | with open(args.char_map) as the_map: 129 | char_map = json.load(the_map) 130 | reverse_char_map = {v: k for k, v in char_map.items()} 131 | 132 | print("starting generation of samples") 133 | os.makedirs(args.destination, exist_ok=True) 134 | with open(os.path.join(args.destination, 'train.csv'), 'w') as train_labels, open(os.path.join(args.destination, 'val.csv'), 'w') as val_labels: 135 | train_writer = csv.writer(train_labels, delimiter='\t') 136 | val_writer = csv.writer(val_labels, delimiter='\t') 137 | 138 | i = 0 139 | start_time = time.time() 140 | while i <= args.num_samples: 141 | num_textlines = random.randint(args.min_textlines, args.max_textlines) 142 | words = [random.choice(wordlist) for _ in range(num_textlines)] 143 | 144 | sign_image = random.choice(sign_images).copy() 145 | draw = ImageDraw.Draw(sign_image) 146 | 147 | background_image = get_image(random.choice(background_image_paths)) 148 | background_image = random_crop(background_image, args.image_size, args.image_size) 149 | background_image = background_image.convert('RGBA') 150 | 151 | width, height = sign_image.size 152 | max_width = width * 0.7 153 | max_height = height * 0.7 154 | spacing = 5 155 | 156 | text_lines = '\n'.join(words) 157 | try: 158 | font, text_size = find_font_size(draw, text_lines, max_width, max_height, spacing) 159 | except ValueError: 160 | continue 161 | 162 | draw.multiline_text( 
163 | (width // 2 - text_size[0] // 2, height // 2 - text_size[1] // 2), 164 | text_lines, 165 | fill='white', 166 | font=font, 167 | spacing=spacing, 168 | align='center', 169 | ) 170 | 171 | paste_box = ( 172 | args.image_size // 2 - sign_image.width // 2, 173 | args.image_size // 2 - sign_image.height // 2, 174 | ) 175 | 176 | overlay_image = Image.new("RGBA", (args.image_size, args.image_size), (255, 255, 255, 0)) 177 | overlay_image.paste(sign_image, box=paste_box) 178 | 179 | background_image = Image.alpha_composite(background_image, overlay_image) 180 | 181 | sample_path, is_train = save_image(i, background_image, args.destination, args.split_ratio) 182 | labels = get_labels(words, reverse_char_map, args.blank_label, args.label_length, args.max_textlines) 183 | 184 | if is_train: 185 | train_writer.writerow((sample_path, *labels)) 186 | else: 187 | val_writer.writerow((sample_path, *labels)) 188 | 189 | i += 1 190 | if i % 1000 == 0: 191 | print("Generated {} samples, took {:4.5} seconds".format(i, time.time() - start_time)) 192 | start_time = time.time() 193 | -------------------------------------------------------------------------------- /datasets/fsns/slice_fsns_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | 4 | import os 5 | from PIL import Image 6 | 7 | 8 | def find_way_to_common_dir(common_dir, dir): 9 | dirs = [] 10 | base_dir = dir 11 | while True: 12 | base_dir, dirname = os.path.split(base_dir) 13 | dirs.append(dirname) 14 | if base_dir in common_dir: 15 | break 16 | 17 | return reversed(dirs) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser(description='Tool that slice one FSNS image to 4 slices') 22 | parser.add_argument('label_file', help='path to label file that has path to image + labels') 23 | parser.add_argument('destination_dir', help='path to dir where resulting images shall be put') 24 | parser.add_argument('-b', '--base-dir', help='path to base dir of every file in label file') 25 | 26 | args = parser.parse_args() 27 | 28 | label_file_name = os.path.basename(args.label_file) 29 | label_dir = os.path.dirname(args.label_file) 30 | 31 | with open(args.label_file) as label_file, open(os.path.join(args.destination_dir, label_file_name), 'w') as dest_file: 32 | reader = csv.reader(label_file, delimiter='\t') 33 | writer = csv.writer(dest_file, delimiter='\t') 34 | 35 | for idx, info in enumerate(reader): 36 | image_path = info[0] 37 | labels = info[1:] 38 | image = Image.open(image_path) 39 | image = image.convert('RGB') 40 | 41 | save_dir = os.path.join(args.destination_dir, *find_way_to_common_dir(label_dir, os.path.dirname(image_path))) 42 | os.makedirs(save_dir, exist_ok=True) 43 | 44 | image_name = os.path.splitext(os.path.basename(image_path))[0] 45 | 46 | for i in range(4): 47 | part_image = image.crop((i * 150, 0, (i + 1) * 150, 150)) 48 | file_name = "{}_{}.png".format(os.path.join(save_dir, image_name), i) 49 | part_image.save(file_name) 50 | writer.writerow((file_name, *labels)) 51 | 52 | print("done with {:6} files".format(idx), end='\r') 53 | -------------------------------------------------------------------------------- /datasets/fsns/swap_classes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | 4 | from tqdm import tqdm 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description='swap to classes in gt') 8 | parser.add_argument('gt_file', 
help='path to gt file where class labels shall be swapped') 9 | parser.add_argument('destination', help='path to file with swapped labels') 10 | parser.add_argument('class1', help='first label to swap') 11 | parser.add_argument('class2', help='second label to swap') 12 | 13 | args = parser.parse_args() 14 | 15 | lines = [] 16 | with open(args.gt_file) as gt_file: 17 | reader = csv.reader(gt_file, delimiter='\t') 18 | 19 | for line in tqdm(reader): 20 | new_line = [] 21 | for label in line: 22 | if label == args.class1: 23 | new_line.append(args.class2) 24 | elif label == args.class2: 25 | new_line.append(args.class1) 26 | else: 27 | new_line.append(label) 28 | lines.append(new_line) 29 | 30 | with open(args.destination, 'w') as destination: 31 | writer = csv.writer(destination, delimiter='\t') 32 | writer.writerows(lines) 33 | -------------------------------------------------------------------------------- /datasets/fsns/tfrecord_to_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | 4 | import os 5 | import re 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from PIL import Image 11 | 12 | 13 | FILENAME_PATTERN = re.compile(r'.+-(\d+)-of-(\d+)') 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser(description='tool that takes tfrecord files and extracts all images + labels from it') 18 | parser.add_argument('tfrecord_dir', help='path to directory containing tfrecord files') 19 | parser.add_argument('destination_dir', help='path to dir where resulting images shall be saved') 20 | parser.add_argument('stage', help='stage of training these files are for [e.g. train]') 21 | 22 | args = parser.parse_args() 23 | 24 | os.makedirs(args.destination_dir, exist_ok=True) 25 | 26 | tfrecord_files = os.listdir(args.tfrecord_dir) 27 | tfrecord_files = sorted(tfrecord_files, key=lambda x: int(FILENAME_PATTERN.match(x).group(1))) 28 | 29 | with open(os.path.join(args.destination_dir, '{}.csv'.format(args.stage)), 'w') as label_file: 30 | writer = csv.writer(label_file, delimiter='\t') 31 | 32 | for tfrecord_file in tfrecord_files: 33 | tfrecord_filename = os.path.join(args.tfrecord_dir, tfrecord_file) 34 | 35 | file_id = FILENAME_PATTERN.match(tfrecord_file).group(1) 36 | dest_dir = os.path.join(args.destination_dir, args.stage, file_id) 37 | os.makedirs(dest_dir, exist_ok=True) 38 | 39 | record_iterator = tf.python_io.tf_record_iterator(path=tfrecord_filename) 40 | 41 | for idx, string_record in enumerate(record_iterator): 42 | example = tf.train.Example() 43 | example.ParseFromString(string_record) 44 | 45 | labels = example.features.feature['image/class'].int64_list.value 46 | img_string = example.features.feature['image/encoded'].bytes_list.value[0] 47 | 48 | file_name = os.path.join(dest_dir, '{}.png'.format(idx)) 49 | with open(file_name, 'wb') as f: 50 | f.write(img_string) 51 | 52 | label_file_data = [file_name] 53 | label_file_data.extend(labels) 54 | writer.writerow(label_file_data) 55 | print("recovered {:0>6} files".format(idx), end='\r') 56 | -------------------------------------------------------------------------------- /datasets/fsns/transform_back_to_single_line.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import json 4 | 5 | import csv 6 | 7 | import itertools 8 | import tqdm as tqdm 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description="Tool that converts fsns gt to per 
word gt") 12 | parser.add_argument('fsns_gt', help='apth to original fsns gt file') 13 | parser.add_argument('char_map', help='path to fsns char map') 14 | parser.add_argument('destination', help='path to destination gt file') 15 | parser.add_argument('--max-words', type=int, default=6, help='max words per image') 16 | parser.add_argument('--max-chars', type=int, default=21, help='max characters per word') 17 | parser.add_argument('--blank-label', default='133', help='class number of blank label') 18 | parser.add_argument('--convert-to-single-line', action='store_true', default=False, help='convert gt to a single line gt') 19 | 20 | args = parser.parse_args() 21 | 22 | with open(args.char_map) as c_map: 23 | char_map = json.load(c_map) 24 | reverse_char_map = {v: k for k, v in char_map.items()} 25 | 26 | with open(args.fsns_gt) as fsns_gt: 27 | reader = csv.reader(fsns_gt, delimiter='\t') 28 | lines = [l for l in reader] 29 | 30 | text_lines = [] 31 | for line in tqdm.tqdm(lines): 32 | text = ''.join(map(lambda x: chr(char_map[x]), line[1:])) 33 | words = text.split(chr(char_map[args.blank_label])) 34 | words = filter(lambda x: len(x) > 0, words) 35 | 36 | text_line = ' '.join(words) 37 | 38 | # pad resulting data with blank label 39 | text_line += ''.join(chr(char_map[args.blank_label]) * (37 - len(text_line))) 40 | text_line = [reverse_char_map[ord(character)] for character in text_line] 41 | 42 | text_lines.append([line[0], *text_line]) 43 | 44 | with open(args.destination, 'w') as dest: 45 | writer = csv.writer(dest, delimiter='\t') 46 | writer.writerows(text_lines) 47 | -------------------------------------------------------------------------------- /datasets/fsns/transform_gt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | 5 | import itertools 6 | from tqdm import tqdm 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser(description="Tool that converts fsns gt to per word gt") 11 | parser.add_argument('fsns_gt', help='path to original fsns gt file') 12 | parser.add_argument('char_map', help='path to fsns char map') 13 | parser.add_argument('destination', help='path to destination gt file') 14 | parser.add_argument('--max-words', type=int, default=6, help='max words per image') 15 | parser.add_argument('--min-words', type=int, default=1, help='min words per image') 16 | parser.add_argument('--max-chars', type=int, default=21, help='max characters per word') 17 | parser.add_argument('--blank-label', default='133', help='class number of blank label') 18 | parser.add_argument('--word-gt', action='store_true', default=False, help='input gt is word level gt') 19 | 20 | args = parser.parse_args() 21 | 22 | with open(args.char_map) as c_map: 23 | char_map = json.load(c_map) 24 | reverse_char_map = {v: k for k, v in char_map.items()} 25 | 26 | with open(args.fsns_gt) as fsns_gt: 27 | reader = csv.reader(fsns_gt, delimiter='\t') 28 | lines = [l for l in reader] 29 | 30 | text_lines = [] 31 | for line in tqdm(lines): 32 | text = ''.join(map(lambda x: chr(char_map[x]), line[1:])) 33 | if args.word_gt: 34 | text = text.split(chr(char_map[args.blank_label])) 35 | text = filter(lambda x: len(x) > 0, text) 36 | else: 37 | text = text.strip(chr(char_map[args.blank_label])) 38 | text = text.split() 39 | 40 | words = [] 41 | for t in text: 42 | t = list(map(lambda x: reverse_char_map[ord(x)], t)) 43 | t.extend([args.blank_label] * (args.max_chars - len(t))) 44 | words.append(t) 45 | 46 | 
if len(words) > args.max_words or len(words) < args.min_words: 47 | continue 48 | elif any([len(word) > args.max_chars for word in words]): 49 | continue 50 | 51 | words.extend([[args.blank_label] * args.max_chars for _ in range(args.max_words - len(words))]) 52 | 53 | text_lines.append([line[0]] + list(itertools.chain(*words))) 54 | 55 | with open(args.destination, 'w') as dest: 56 | writer = csv.writer(dest, delimiter='\t') 57 | writer.writerow([args.max_words, args.max_chars]) 58 | writer.writerows(text_lines) 59 | -------------------------------------------------------------------------------- /datasets/svhn/create_svhn_csv_gt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import os 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description="take svhn gt json and create csv file") 8 | parser.add_argument('json_gt', help='path to svhn json gt') 9 | parser.add_argument('destination', help='path to resulting gt file in csv format') 10 | parser.add_argument('--max-length', type=int, help='maximum length of labels [default: take max length found]') 11 | 12 | args = parser.parse_args() 13 | 14 | with open(args.json_gt) as gt_file: 15 | gt = json.load(gt_file) 16 | 17 | # read information for all files 18 | base_dir = os.path.abspath(os.path.dirname(args.json_gt)) 19 | file_info = [] 20 | for image_data in gt: 21 | filename = os.path.join(base_dir, image_data['filename']) 22 | labels = [int(b['label']) if int(b['label']) != 10 else 0 for b in image_data['boxes']] 23 | file_info.append((filename, labels)) 24 | 25 | # determine max length of labels 26 | if args.max_length is None: 27 | max_length = max(map(lambda x: len(x[1]), file_info)) 28 | else: 29 | max_length = args.max_length 30 | 31 | # pad and filter labels 32 | filtered_infos = [] 33 | for file_path, labels in file_info: 34 | if len(labels) > max_length: 35 | continue 36 | elif len(labels) < max_length: 37 | values_to_pad = [10] * (max_length - len(labels)) 38 | labels.extend(values_to_pad) 39 | file_info = [file_path] + labels 40 | filtered_infos.append(file_info) 41 | 42 | # write to csv file 43 | with open(args.destination, 'w') as destination: 44 | writer = csv.writer(destination, delimiter='\t') 45 | writer.writerows(filtered_infos) 46 | -------------------------------------------------------------------------------- /datasets/svhn/create_svhn_dataset_4_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | 5 | import itertools 6 | import os 7 | import random 8 | 9 | import numpy as np 10 | from PIL import Image 11 | 12 | SUPPORTED_IMAGE_TYPES = ['.png'] 13 | 14 | 15 | class SVHNDatasetCreator: 16 | 17 | def __init__(self, image_dir, image_size, image_columns, image_rows, destination_dir, dataset_name, max_label_length, label_pad_value=10): 18 | self.image_dir = image_dir 19 | self.image_size = image_size 20 | self.image_columns = image_columns 21 | self.images_per_sample = image_columns * image_rows 22 | self.destination_dir = destination_dir 23 | self.dataset_name = dataset_name 24 | self.max_label_length = max_label_length 25 | 26 | self.destination_image_dir = os.path.join(destination_dir, dataset_name) 27 | os.makedirs(self.destination_image_dir, exist_ok=True) 28 | self.label_pad_value = label_pad_value 29 | 30 | self.interpolation_area = 0.15 31 | self.image_rows = image_rows 32 | self.dest_image_size = ( 33 | 
int(image_rows * self.image_size - (image_rows - 1) * self.image_size * self.interpolation_area), 34 | int(image_columns * self.image_size - (image_columns - 1) * self.image_size * self.interpolation_area), 35 | ) 36 | self.overlap_map = np.zeros(self.dest_image_size, dtype=np.uint32) 37 | 38 | def blend_at_intersection(self, images): 39 | pre_blend_images = [] 40 | for idx, image in enumerate(images): 41 | x_start, y_start = self.get_start_indices(idx) 42 | image_data = np.asarray(image) 43 | overlap_data = self.overlap_map[y_start:y_start + self.image_size, x_start:x_start + self.image_size] 44 | alpha_image = image_data.copy() 45 | alpha_image[:, :, 3] = image_data[:, :, 3] / overlap_data 46 | pre_blend_image = np.zeros(self.dest_image_size + (4,), dtype=np.uint8) 47 | pre_blend_image[y_start:y_start + self.image_size, x_start:x_start + self.image_size] = alpha_image 48 | pre_blend_images.append(Image.fromarray(pre_blend_image, 'RGBA')) 49 | 50 | dest_image = pre_blend_images[0] 51 | for blend_image in pre_blend_images[1:]: 52 | dest_image = Image.alpha_composite(dest_image, blend_image) 53 | 54 | return dest_image 55 | 56 | def interpolate_at_intersection(self, images): 57 | 58 | def interpolate_width(data): 59 | interpolation_start = data.shape[1] - num_interpolation_pixels 60 | for i in range(num_interpolation_pixels): 61 | data[:, interpolation_start + i, 3] *= (num_interpolation_pixels - i) / num_interpolation_pixels 62 | return data 63 | 64 | def interpolate_height(data): 65 | interpolation_start = data.shape[0] - num_interpolation_pixels 66 | for i in range(num_interpolation_pixels): 67 | data[interpolation_start + i, :, 3] *= (num_interpolation_pixels - i) / num_interpolation_pixels 68 | return data 69 | 70 | pre_blend_images = [] 71 | num_interpolation_pixels = int(self.image_size * self.interpolation_area) 72 | for y_idx in range(self.image_rows): 73 | for x_idx in range(self.image_columns): 74 | image_idx = y_idx * self.image_columns + x_idx 75 | image = images[image_idx] 76 | x_start, y_start = self.get_start_indices(y_idx, x_idx) 77 | image_data = np.asarray(image).copy().astype(np.float64) 78 | 79 | # create horizontal alpha mask 80 | if x_idx < self.image_columns - 1: 81 | image_data = interpolate_width(image_data) 82 | 83 | if x_idx > 0: 84 | image_data = np.fliplr(image_data) 85 | image_data = interpolate_width(image_data) 86 | image_data = np.fliplr(image_data) 87 | 88 | # create vertical alpha mask 89 | if y_idx < self.image_rows - 1: 90 | image_data = interpolate_height(image_data) 91 | 92 | if y_idx > 0: 93 | image_data = np.flipud(image_data) 94 | image_data = interpolate_height(image_data) 95 | image_data = np.flipud(image_data) 96 | 97 | pre_blend_image = np.zeros(self.dest_image_size + (4,), dtype=np.uint8) 98 | pre_blend_image[y_start:y_start + self.image_size, x_start:x_start + self.image_size] = image_data.astype(np.uint8) 99 | pre_blend_images.append(Image.fromarray(pre_blend_image, 'RGBA')) 100 | 101 | dest_image = pre_blend_images[0] 102 | for blend_image in pre_blend_images[1:]: 103 | dest_image = Image.alpha_composite(dest_image, blend_image) 104 | 105 | dest_image = dest_image.convert('RGB') 106 | return dest_image 107 | 108 | def get_start_indices(self, y_idx, x_idx): 109 | x_start = x_idx * self.image_size 110 | x_start = x_start - x_idx * self.image_size * self.interpolation_area if x_start > 0 else x_start 111 | y_start = y_idx * self.image_size 112 | y_start = y_start - y_idx * self.image_size * self.interpolation_area if y_start > 0 else y_start 
113 | 114 | return int(x_start), int(y_start) 115 | 116 | def create_sample(self, image_information): 117 | images = [] 118 | all_labels = [] 119 | self.overlap_map[:] = 0 120 | 121 | for y_idx in range(self.image_rows): 122 | x_idx = 0 123 | while x_idx < self.image_columns: 124 | image_info = random.choice(image_information) 125 | file_name = image_info['filename'] 126 | bboxes = image_info['boxes'] 127 | labels = [int(box['label']) - 1 for box in bboxes] 128 | if len(labels) > self.max_label_length: 129 | continue 130 | 131 | all_labels.append(labels + [self.label_pad_value] * (self.max_label_length - len(labels))) 132 | 133 | with Image.open(os.path.join(self.image_dir, file_name)) as image: 134 | image = image.resize((self.image_size, self.image_size)) 135 | image = image.convert('RGBA') 136 | images.append(image) 137 | 138 | x_start, y_start = self.get_start_indices(y_idx, x_idx) 139 | self.overlap_map[y_start:y_start + self.image_size, x_start:x_start + self.image_size] += 1 140 | x_idx += 1 141 | 142 | assert len(images) == self.images_per_sample 143 | return self.interpolate_at_intersection(images), all_labels 144 | 145 | def pad_labels(self, labels): 146 | longest_label_length = max(map(len, labels)) 147 | padded_labels = [] 148 | for label in labels: 149 | padded_label = label + [self.label_pad_value] * (longest_label_length - len(label)) 150 | padded_labels.append(padded_label) 151 | 152 | return padded_labels 153 | 154 | def create_dataset(self, num_samples, image_infos): 155 | all_labels = [] 156 | for i in range(num_samples): 157 | if (i + 1) % 1000 == 0: 158 | print(i + 1) 159 | 160 | sample, labels = self.create_sample(image_infos) 161 | all_labels.extend(labels) 162 | 163 | sample.save(os.path.join(self.destination_image_dir, '{}.png'.format(i))) 164 | 165 | with open(os.path.join(self.destination_dir, '{}.csv'.format(self.dataset_name)), 'w') as gt_file: 166 | writer = csv.writer(gt_file, delimiter='\t') 167 | 168 | # merge labels per image 169 | all_labels_concatenated = [] 170 | for idx in range(0, len(all_labels), self.images_per_sample): 171 | concatenated = list(itertools.chain(*all_labels[idx:idx + self.images_per_sample])) 172 | all_labels_concatenated.append(concatenated) 173 | assert len(all_labels_concatenated) == num_samples, "number of labels should be as large as number of generated samples" 174 | 175 | for idx, labels in enumerate(all_labels_concatenated): 176 | writer.writerow([os.path.join(os.path.abspath(self.destination_image_dir), '{}.png'.format(idx))] + labels) 177 | 178 | 179 | if __name__ == "__main__": 180 | parser = argparse.ArgumentParser(description='Tool that creates a four image svhn dataset for training') 181 | parser.add_argument('svhn_image_dir', help='path to directory containing svhn images') 182 | parser.add_argument('svhn_gt', help='path to JSON containing svhn GT') 183 | parser.add_argument('destination_dir', help='directory where generated samples shall be saved') 184 | parser.add_argument('num_samples', type=int, help='number of samples to create') 185 | parser.add_argument('max_label_length', type=int, help='maximum length of labels') 186 | parser.add_argument('--dataset-name', default='train', help='name of the data set [e.g. 
train]') 187 | parser.add_argument('--image-size', type=int, default=100, help='size that each source image shall be resized to') 188 | parser.add_argument('--image-columns', type=int, default=2, help='number of image columns per generated sample') 189 | parser.add_argument('--image-rows', type=int, default=2, help='number of image rows per generated sample') 190 | 191 | args = parser.parse_args() 192 | 193 | with open(args.svhn_gt) as gt_json: 194 | gt_data = json.load(gt_json) 195 | 196 | dataset_creator = SVHNDatasetCreator( 197 | args.svhn_image_dir, 198 | args.image_size, 199 | args.image_columns, 200 | args.image_rows, 201 | args.destination_dir, 202 | args.dataset_name, 203 | args.max_label_length, 204 | ) 205 | dataset_creator.create_dataset(args.num_samples, gt_data) 206 | -------------------------------------------------------------------------------- /datasets/svhn/filter_large_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from PIL import Image 5 | 6 | from create_svhn_dataset_4_images import SUPPORTED_IMAGE_TYPES 7 | 8 | 9 | def get_images(image_dir, images, min_width, min_height): 10 | for an_image in images: 11 | with Image.open(os.path.join(image_dir, an_image)) as image: 12 | width, height = image.size 13 | 14 | if width >= min_width and height >= min_height: 15 | yield an_image 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser(description='tool that filters all images larger than a given size from a directory') 19 | parser.add_argument('image_dir', help='path to dir containing images to filter') 20 | parser.add_argument('destination_dir', help='path to dir where filtered images shall be put') 21 | parser.add_argument('min_width', type=int, help='minimum width of images that shall be filtered') 22 | parser.add_argument('min_height', type=int, help='minimum height of images that shall be filtered') 23 | 24 | args = parser.parse_args() 25 | 26 | images = filter(lambda x: os.path.splitext(x)[-1] in SUPPORTED_IMAGE_TYPES, os.listdir(args.image_dir)) 27 | 28 | for image in get_images(args.image_dir, images, args.min_width, args.min_height): 29 | shutil.copy(os.path.join(args.image_dir, image), os.path.join(args.destination_dir, image)) 30 | -------------------------------------------------------------------------------- /datasets/svhn/prepare_svhn_crops.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import os 5 | import tqdm as tqdm 6 | from PIL import Image 7 | 8 | from create_svhn_dataset import BBox 9 | 10 | 11 | def merge_bboxes(bboxes): 12 | resulting_bbox = None 13 | for bbox in bboxes: 14 | if resulting_bbox is None: 15 | resulting_bbox = bbox 16 | continue 17 | 18 | max_right = max(resulting_bbox.left + resulting_bbox.width, bbox.left + bbox.width) 19 | max_bottom = max(resulting_bbox.top + resulting_bbox.height, bbox.top + bbox.height) 20 | 21 | resulting_bbox.top = min(resulting_bbox.top, bbox.top) 22 | resulting_bbox.left = min(resulting_bbox.left, bbox.left) 23 | resulting_bbox.width = max_right - resulting_bbox.left 24 | resulting_bbox.height = max_bottom - resulting_bbox.top 25 | resulting_bbox.label.extend(bbox.label) 26 | 27 | return resulting_bbox 28 | 29 | 30 | def enlarge_bbox(bbox, percentage=0.3): 31 | enlarge_width = bbox.width * percentage * 2 32 | enlarge_height = bbox.height * percentage * 2 33 | 34 | return BBox( 35 | label=bbox.label, 
36 | left=bbox.left - enlarge_width // 2, 37 | width=bbox.width + enlarge_width, 38 | top=bbox.top - enlarge_height // 2, 39 | height=bbox.height + enlarge_height, 40 | ) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser(description='tool that takes user defined crops around image regions of svhn data') 45 | parser.add_argument('svhn_json', help='path to svhn gt file') 46 | parser.add_argument('crop_size', type=int, help='crop size in pixels for each text region') 47 | parser.add_argument('dest_dir', help='destination dir for cropped images') 48 | parser.add_argument('stage_name', help='name of stage (e.g. train, or test)') 49 | parser.add_argument('--max-length', type=int, default=5, help='max length of labels [default: 5]') 50 | 51 | args = parser.parse_args() 52 | 53 | with open(args.svhn_json) as gt_file: 54 | gt = json.load(gt_file) 55 | 56 | # create dest dir if it does not exist 57 | os.makedirs(args.dest_dir, exist_ok=True) 58 | 59 | # read information for all files 60 | base_dir = os.path.abspath(os.path.dirname(args.svhn_json)) 61 | file_info = [] 62 | for image_data in tqdm.tqdm(gt): 63 | filename = os.path.join(base_dir, image_data['filename']) 64 | bboxes = [BBox( 65 | label=[int(b['label'])],  # keep the label in a list so merge_bboxes can extend it 66 | top=b['top'], 67 | height=b['height'], 68 | left=b['left'], 69 | width=b['width'], 70 | ) for b in image_data['boxes']] 71 | 72 | merged_bbox = merge_bboxes(bboxes) 73 | new_bbox = enlarge_bbox(merged_bbox) 74 | 75 | with Image.open(filename) as image: 76 | cropped_image = image.crop(( 77 | new_bbox.left, 78 | new_bbox.top, 79 | new_bbox.left + new_bbox.width, 80 | new_bbox.top + new_bbox.height, 81 | )) 82 | cropped_image = cropped_image.resize((args.crop_size, args.crop_size)) 83 | new_filename = os.path.join(args.dest_dir, image_data['filename']) 84 | cropped_image.save(new_filename) 85 | file_info.append((os.path.abspath(new_filename), new_bbox.label)) 86 | 87 | # pad and filter labels 88 | filtered_infos = [] 89 | for file_path, labels in file_info: 90 | if len(labels) > args.max_length: 91 | continue 92 | elif len(labels) < args.max_length: 93 | values_to_pad = [0] * (args.max_length - len(labels)) 94 | labels.extend(values_to_pad) 95 | row = [file_path] + labels 96 | filtered_infos.append(row) 97 | 98 | # write to csv file 99 | with open(os.path.join(os.path.dirname(args.dest_dir), "{}.csv".format(args.stage_name)), 'w') as destination: 100 | writer = csv.writer(destination, delimiter='\t') 101 | writer.writerows(filtered_infos) -------------------------------------------------------------------------------- /datasets/svhn/svhn_char_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": 9250, 3 | "1": 49, 4 | "2": 50, 5 | "3": 51, 6 | "4": 52, 7 | "5": 53, 8 | "6": 54, 9 | "7": 55, 10 | "8": 56, 11 | "9": 57, 12 | "10": 48 13 | } 14 | -------------------------------------------------------------------------------- /datasets/svhn/svhn_dataextract_tojson.py: -------------------------------------------------------------------------------- 1 | # This script extracts data from the SVHN digitStruct.mat full numbers files. The data can be downloaded from 2 | # the Street View House Number (SVHN) web site: http://ufldl.stanford.edu/housenumbers.
3 | # 4 | # This is an A2iA tweak (YG - 9 Jan 2014) of the script found here: 5 | # http://blog.grimwisdom.com/python/street-view-house-numbers-svhn-and-octave 6 | # 7 | # The digitStruct.mat files in the full numbers tars (train.tar.gz, test.tar.gz, and extra.tar.gz) 8 | # are only compatible with matlab. This Python program can be run at the command line and will generate 9 | # a json version of the dataset. 10 | # 11 | # Command line usage: 12 | # SVHN_dataextract.py [-f input] [-o output_without_extension] 13 | # > python SVHN_dataextract.py -f digitStruct.mat -o digitStruct 14 | # 15 | # Issues: 16 | # The ability to split into several files has been removed from the original 17 | # script. 18 | # 19 | 20 | import h5py 21 | import optparse 22 | from json import JSONEncoder 23 | 24 | parser = optparse.OptionParser() 25 | parser.add_option("-f", dest="fin", help="Matlab full number SVHN input file", default="digitStruct.mat") 26 | parser.add_option("-o", dest="filePrefix", help="name for the json output file", default="digitStruct") 27 | (options, args) = parser.parse_args() 28 | 29 | fin = options.fin 30 | 31 | # The DigitStructFile is just a wrapper around the h5py data. It basically references 32 | # inf: The input h5 matlab file 33 | # digitStructName The h5 ref to all the file names 34 | # digitStructBbox The h5 ref to all struct data 35 | class DigitStructFile: 36 | def __init__(self, inf): 37 | self.inf = h5py.File(inf, 'r') 38 | self.digitStructName = self.inf['digitStruct']['name'] 39 | self.digitStructBbox = self.inf['digitStruct']['bbox'] 40 | 41 | # getName returns the 'name' string for the n(th) digitStruct. 42 | def getName(self,n): 43 | return ''.join([chr(c[0]) for c in self.inf[self.digitStructName[n][0]].value]) 44 | 45 | # bboxHelper handles the coding difference when there is exactly one bbox or an array of bboxes. 46 | def bboxHelper(self,attr): 47 | if (len(attr) > 1): 48 | attr = [self.inf[attr.value[j].item()].value[0][0] for j in range(len(attr))] 49 | else: 50 | attr = [attr.value[0][0]] 51 | return attr 52 | 53 | # getBbox returns a dict of data for the n(th) bbox. 54 | def getBbox(self,n): 55 | bbox = {} 56 | bb = self.digitStructBbox[n].item() 57 | bbox['height'] = self.bboxHelper(self.inf[bb]["height"]) 58 | bbox['label'] = self.bboxHelper(self.inf[bb]["label"]) 59 | bbox['left'] = self.bboxHelper(self.inf[bb]["left"]) 60 | bbox['top'] = self.bboxHelper(self.inf[bb]["top"]) 61 | bbox['width'] = self.bboxHelper(self.inf[bb]["width"]) 62 | return bbox 63 | 64 | def getDigitStructure(self,n): 65 | s = self.getBbox(n) 66 | s['name']=self.getName(n) 67 | return s 68 | 69 | # getAllDigitStructure returns all the digitStruct from the input file. 70 | def getAllDigitStructure(self): 71 | return [self.getDigitStructure(i) for i in range(len(self.digitStructName))] 72 | 73 | # Return a restructured version of the dataset (one structure per boxed digit). 74 | # 75 | # Return a list of such dicts : 76 | # 'filename' : filename of the samples 77 | # 'boxes' : list of such dicts (one per digit) : 78 | # 'label' : 1 to 9 corresponding digits. 10 for digit '0' in image. 79 | # 'left', 'top' : position of bounding box 80 | # 'width', 'height' : dimension of bounding box 81 | # 82 | # Note: We may turn this into a generator, if memory issues arise.
83 | def getAllDigitStructure_ByDigit(self): 84 | pictDat = self.getAllDigitStructure() 85 | result = [] 86 | structCnt = 1 87 | for i in range(len(pictDat)): 88 | item = { 'filename' : pictDat[i]["name"] } 89 | figures = [] 90 | for j in range(len(pictDat[i]['height'])): 91 | figure = {} 92 | figure['height'] = pictDat[i]['height'][j] 93 | figure['label'] = pictDat[i]['label'][j] 94 | figure['left'] = pictDat[i]['left'][j] 95 | figure['top'] = pictDat[i]['top'][j] 96 | figure['width'] = pictDat[i]['width'][j] 97 | figures.append(figure) 98 | structCnt = structCnt + 1 99 | item['boxes'] = figures 100 | result.append(item) 101 | return result 102 | 103 | dsf = DigitStructFile(fin) 104 | dataset = dsf.getAllDigitStructure_ByDigit() 105 | fout = open(options.filePrefix + ".json",'w') 106 | fout.write(JSONEncoder(indent = True).encode(dataset)) 107 | fout.close() 108 | 109 | -------------------------------------------------------------------------------- /datasets/textrec/ctc_char_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": 9250, 3 | "1": 48, 4 | "2": 49, 5 | "3": 50, 6 | "4": 51, 7 | "5": 52, 8 | "6": 53, 9 | "7": 54, 10 | "8": 55, 11 | "9": 56, 12 | "10": 57, 13 | "11": 97, 14 | "12": 98, 15 | "13": 99, 16 | "14": 100, 17 | "15": 101, 18 | "16": 102, 19 | "17": 103, 20 | "18": 104, 21 | "19": 105, 22 | "20": 106, 23 | "21": 107, 24 | "22": 108, 25 | "23": 109, 26 | "24": 110, 27 | "25": 111, 28 | "26": 112, 29 | "27": 113, 30 | "28": 114, 31 | "29": 115, 32 | "30": 116, 33 | "31": 117, 34 | "32": 118, 35 | "33": 119, 36 | "34": 120, 37 | "35": 121, 38 | "36": 122, 39 | "37": 95, 40 | "38": 32, 41 | "39": 91, 42 | "40": 93, 43 | "41": 46, 44 | "42": 40, 45 | "43": 41, 46 | "44": 38, 47 | "45": 39, 48 | "46": 33, 49 | "47": 44, 50 | "48": 45, 51 | "49": 58, 52 | "50": 64, 53 | "51": 63 54 | } 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bleach==1.5.0 2 | chainer==3.2.0 3 | cupy==2.2.0 4 | cycler==0.10.0 5 | enum34==1.1.6 6 | fastrlock==0.3 7 | filelock==2.0.13 8 | h5py==2.7.1 9 | html5lib==0.9999999 10 | Markdown==2.6.10 11 | matplotlib==2.1.1 12 | numpy==1.13.3 13 | olefile==0.44 14 | Pillow==4.3.0 15 | protobuf==3.5.0.post1 16 | pyparsing==2.2.0 17 | python-dateutil==2.6.1 18 | pytz==2017.3 19 | six==1.11.0 20 | tensorflow==1.4.1 21 | tensorflow-tensorboard==0.4.0rc3 22 | tqdm==4.19.5 23 | Werkzeug==0.13 24 | chainerui==0.3.0 25 | -------------------------------------------------------------------------------- /utils/create_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | from collections import namedtuple 6 | 7 | import subprocess 8 | 9 | import shutil 10 | import tempfile 11 | 12 | 13 | SUPPORTED_IMAGETYPES = [".png", ".jpg", ".jpeg"] 14 | ImageData = namedtuple('ImageData', ['file_name', 'path']) 15 | 16 | 17 | def get_filter(pattern): 18 | def filter_function(x): 19 | return int(re.search(pattern, x.file_name).group(1)) 20 | return filter_function 21 | 22 | 23 | def make_video(image_dir, dest_file, batch_size=1000, start=None, end=None, pattern=r"(\d+)"): 24 | sort_pattern = re.compile(pattern) 25 | 26 | image_files = os.listdir(image_dir) 27 | 28 | image_files = filter(lambda x: os.path.splitext(x)[-1] in SUPPORTED_IMAGETYPES, image_files) 29 | images = [] 30 | 31 | print("loading 
images") 32 | for file_name in image_files: 33 | path = os.path.join(image_dir, file_name) 34 | images.append(ImageData(file_name=file_name, path=path)) 35 | 36 | extract_number = get_filter(sort_pattern) 37 | if end is None: 38 | end = extract_number(max(images, key=extract_number)) 39 | if start is None: 40 | start = 0 41 | 42 | print("sort and cut images") 43 | images_sorted = list(filter( 44 | lambda x: start <= extract_number(x) < end, 45 | sorted(images, key=extract_number))) 46 | 47 | print("creating temp file") 48 | temp_file = tempfile.NamedTemporaryFile(mode="w") 49 | video_dir = tempfile.mkdtemp() 50 | i = 1 51 | try: 52 | # create a bunch of videos and merge them later (saves memory) 53 | while i < len(images_sorted): 54 | image = images_sorted[i] 55 | if i % batch_size == 0: 56 | temp_file = create_video(i, temp_file, video_dir) 57 | else: 58 | print(image.path, file=temp_file) 59 | i += 1 60 | 61 | if i % batch_size != 0: 62 | print("creating last video") 63 | temp_file = create_video(i - 1, temp_file, video_dir) 64 | temp_file.close() 65 | 66 | # merge created videos 67 | process_args = [ 68 | 'ffmpeg', 69 | '-i concat:"{}"'.format( 70 | '|'.join(sorted( 71 | os.listdir(video_dir), 72 | key=lambda x: int(os.path.splitext(x.rsplit('/', 1)[-1])[0])) 73 | ) 74 | ), 75 | '-c copy {}'.format(os.path.abspath(dest_file)) 76 | ] 77 | print(' '.join(process_args)) 78 | subprocess.run(' '.join(process_args), shell=True, check=True, cwd=video_dir) 79 | finally: 80 | shutil.rmtree(video_dir) 81 | 82 | 83 | def create_video(i, temp_file, video_dir): 84 | process_args = [ 85 | 'convert', 86 | '-quality 100', 87 | '@{}'.format(temp_file.name), 88 | os.path.join(video_dir, "{}.mpeg".format(i)) 89 | ] 90 | print(' '.join(process_args)) 91 | temp_file.flush() 92 | subprocess.run(' '.join(process_args), shell=True, check=True) 93 | temp_file.close() 94 | temp_file = tempfile.NamedTemporaryFile(mode="w") 95 | return temp_file 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser(description='Tool that creates a gif out of a number of given input images') 100 | parser.add_argument("image_dir", help="path to directory that contains all images that shall be converted to a gif") 101 | parser.add_argument("dest_file", help="path to destination gif file") 102 | parser.add_argument("--pattern", default=r"(\d+)", help="naming pattern to extract the ordering of the images") 103 | parser.add_argument("--batch-size", "-b", default=1000, type=int, help="batch size for processing, [default=1000]") 104 | parser.add_argument("-e", "--end", type=int, help="maximum number of images to put in gif") 105 | parser.add_argument("-s", "--start", type=int, help="frame to start") 106 | 107 | args = parser.parse_args() 108 | 109 | make_video(args.image_dir, args.dest_file, batch_size=args.batch_size, start=args.start, end=args.end, pattern=args.pattern) 110 | -------------------------------------------------------------------------------- /utils/show_progress.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import threading 4 | import tkinter 5 | from socketserver import ThreadingMixIn, StreamRequestHandler, TCPServer 6 | 7 | import base64 8 | import numpy as np 9 | from PIL import Image, ImageTk 10 | from io import BytesIO 11 | 12 | 13 | class ProgressWindow: 14 | 15 | def __init__(self, root): 16 | self.frame = tkinter.Frame(root) 17 | self.frame.pack(fill=tkinter.BOTH, expand=tkinter.YES) 18 | 19 | self._image = None 
20 | self._sprite = None 21 | self.canvas = tkinter.Canvas( 22 | self.frame, 23 | width=850, 24 | height=400 25 | ) 26 | self.canvas.pack(fill=tkinter.BOTH, expand=tkinter.YES) 27 | 28 | @property 29 | def image(self): 30 | return self._image 31 | 32 | @image.setter 33 | def image(self, value): 34 | window_width = self.frame.winfo_width() 35 | window_height = self.frame.winfo_height() 36 | value = value.resize((window_width, window_height), Image.LANCZOS) 37 | image = ImageTk.PhotoImage(value) 38 | self._image = image 39 | self._sprite = self.canvas.create_image(value.width // 2, value.height // 2, image=self._image) 40 | self.canvas.config(width=value.width, height=value.height) 41 | 42 | 43 | class ImageDataHandler(StreamRequestHandler): 44 | 45 | def __init__(self, *args, **kwargs): 46 | self.window = kwargs.pop('window') 47 | super(ImageDataHandler, self).__init__(*args, **kwargs) 48 | 49 | def handle(self): 50 | data = self.rfile.read() 51 | data = json.loads(data.decode('utf-8')) 52 | data = BytesIO(base64.b64decode(data['image'])) 53 | image = Image.open(data) 54 | self.window.image = image 55 | 56 | 57 | class ImageServer(ThreadingMixIn, TCPServer): 58 | 59 | def __init__(self, *args, **kwargs): 60 | self.window = kwargs.pop('window') 61 | super(ImageServer, self).__init__(*args, **kwargs) 62 | 63 | def finish_request(self, request, client_address): 64 | self.RequestHandlerClass(request, client_address, self, window=self.window) 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser(description='Tool that shows the current progress images of a training run') 69 | parser.add_argument('--host', default='0.0.0.0', help='address to listen on') 70 | parser.add_argument('--port', type=int, default=1337, help='port to listen on') 71 | 72 | args = parser.parse_args() 73 | 74 | root = tkinter.Tk() 75 | window = ProgressWindow(root) 76 | 77 | print("starting server") 78 | server = ImageServer((args.host, args.port), ImageDataHandler, window=window) 79 | server_thread = threading.Thread(target=server.serve_forever) 80 | server_thread.daemon = True 81 | server_thread.start() 82 | 83 | print("starting window") 84 | root.mainloop() 85 | server.shutdown() 86 | server.server_close() 87 | 88 | 89 | 90 | 91 | --------------------------------------------------------------------------------
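A minimal client sketch for utils/show_progress.py: ImageDataHandler reads one JSON payload per TCP connection and expects the image bytes base64-encoded under the 'image' key, and its rfile.read() only returns once the sender closes the connection. The snippet below follows those assumptions; the helper name send_image, the file name plot.png, and the 127.0.0.1/1337 defaults are illustrative and not part of the repository.

import base64
import json
import socket


def send_image(path, host='127.0.0.1', port=1337):
    # encode the image the way ImageDataHandler expects it:
    # a JSON object with the base64-encoded file content under 'image'
    with open(path, 'rb') as image_file:
        payload = json.dumps({'image': base64.b64encode(image_file.read()).decode('ascii')})

    # one connection per image; closing the socket lets the handler's rfile.read() return
    with socket.create_connection((host, port)) as connection:
        connection.sendall(payload.encode('utf-8'))


if __name__ == "__main__":
    send_image('plot.png')  # illustrative file name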