├── .github └── workflows │ └── pytest.yml ├── .gitignore ├── README.md ├── assets ├── 18.png ├── 19.png ├── 19_mae.png ├── 19_masked.jpg └── 2.png ├── config ├── config.json ├── docaligner.gin ├── doctr.json ├── doctr_plus.json └── finetune.json ├── demo ├── background_segmentation.ipynb └── doc3d_dataloader.ipynb ├── docmae ├── __init__.py ├── data │ ├── __init__.py │ ├── augmentation │ │ ├── __init__.py │ │ ├── random_resized_crop.py │ │ └── replace_background.py │ ├── doc3d.py │ ├── doc3d_minio.py │ ├── docaligner.py │ └── list_dataset.py ├── datamodule │ ├── __init__.py │ ├── docaligner_module.py │ ├── mixed_module.py │ └── utils.py ├── fine_tune.py ├── inference.py ├── models │ ├── docmae.py │ ├── doctr.py │ ├── doctr_custom.py │ ├── doctr_plus.py │ ├── mae.py │ ├── rectification.py │ ├── transformer.py │ └── upscale.py ├── pretrain.py ├── pretrain_pl.py └── train.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── architecture.py ├── datasets.py ├── doc3d ├── LICENSE ├── bm │ └── 9_6-vc_Page_002-YT40001.mat ├── img │ └── 9_6-vc_Page_002-YT40001.png ├── tiny.txt └── uv │ └── 9_6-vc_Page_002-YT40001.exr └── dtd ├── images └── lacelike_0037.jpg └── labels └── tiny.txt /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/app.yaml 2 | name: PyTest 3 | on: [push] 4 | 5 | 6 | jobs: 7 | test: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - uses: actions/setup-python@v4 12 | with: 13 | python-version: "3.10" 14 | - uses: Gr1N/setup-poetry@v8 15 | - run: poetry --version 16 | - run: poetry install 17 | - run: poetry run pytest tests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DocMAE 2 | 3 | Unofficial implementation of **DocMAE: Document Image Rectification via Self-supervised Representation Learning** 4 | 5 | https://arxiv.org/abs/2304.10341 6 | 7 | ## TODO 8 | 9 | - [x] Document background segmentation network using U2-Net 10 | - [x] Synthetic data generation for self-supervised pre-training 11 | - [x] Pre-training 12 | - [ ] Fine-tuning for document rectification (in progress) 13 | - [ ] Evaluation 14 | - [ ] Code clean-up and documentation 15 | - [ ] Model release 16 | 17 | ## Demo 18 | 19 | See the Jupyter notebook at [demo/background_segmentation.ipynb](demo/background_segmentation.ipynb) 20 | 21 | ## Data 22 | 23 | ### Pre-training 24 | 25 | - 3,411,482 pages from ~1M documents in the DocILE dataset (https://github.com/rossumai/docile) 26 | - Rendered with the Doc3D renderer: https://github.com/Dawars/doc3D-renderer 27 | - 558 HDR environment lighting maps from https://hdri-haven.com/ 28 | 29 | Pre-training on 200k documents: 30 | 31 | ![MAE](assets/19_mae.png) 32 | 33 | #### Run training via: 34 | `python pretrain.py -c config/config.json` 35 | Visualize the trained model using https://github.com/NielsRogge/Transformers-Tutorials/blob/master/ViTMAE/ViT_MAE_visualization_demo.ipynb 36 | 37 | # Acknowledgement 38 | 39 | Test documents come from the DIR300 dataset: https://github.com/fh2019ustc/DocGeoNet -------------------------------------------------------------------------------- /assets/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/assets/18.png -------------------------------------------------------------------------------- /assets/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/assets/19.png -------------------------------------------------------------------------------- /assets/19_mae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/assets/19_mae.png -------------------------------------------------------------------------------- /assets/19_masked.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/assets/19_masked.jpg -------------------------------------------------------------------------------- /assets/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/assets/2.png -------------------------------------------------------------------------------- /config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_dir": "./train_masked", 3 | "validation_dir": "./val_masked", 4 | "remove_unused_columns": false, 5 | "label_names": ["pixel_values"], 6 | "dataset_name": "doc3d", 7 | "do_train": true, 8 | "do_eval": true, 9 | "auto_find_batch_size": false, 10 | "tf32": null, 11 | "fp16": true, 12 | "norm_pix_loss": false, 13 | "base_learning_rate": 1.5e-4, 14 | "lr_scheduler_type": "cosine", 15 | "weight_decay": 0.05, 16 | 
"num_train_epochs": 200, 17 | "warmup_ratio": 0.05, 18 | "per_device_train_batch_size": 32, 19 | "per_device_eval_batch_size": 32, 20 | "dataloader_num_workers": 8, 21 | "logging_strategy": "steps", 22 | "logging_steps": 10, 23 | "logging_first_step": true, 24 | "evaluation_strategy": "epoch", 25 | "save_strategy": "epoch", 26 | "load_best_model_at_end": true, 27 | "save_total_limit": 3, 28 | "seed": 1337, 29 | "max_train_samples": 2000, 30 | "max_eval_samples": 1000, 31 | "config_overrides":"image_size=288,num_hidden_layers=6,num_attention_heads=8,hidden_size=512,decoder_num_hidden_layers=4,decoder_num_attention_heads=8,decoder_hidden_size=512,intermediate_size=2048,hidden_dropout_prob=0.1,attention_probs_dropout_prob=0.1" 32 | } 33 | -------------------------------------------------------------------------------- /config/docaligner.gin: -------------------------------------------------------------------------------- 1 | import docmae.datamodule.docaligner_module 2 | train.datamodule=@DocAlignerDataModule() 3 | 4 | DocAlignerDataModule.data_dir="/home/dawars/datasets/DocAligner_result/" 5 | DocAlignerDataModule.batch_size=4 6 | DocAlignerDataModule.num_workers=4 7 | DocAlignerDataModule.crop=False 8 | 9 | import torchvision.transforms 10 | import kornia.augmentation as ka 11 | 12 | import docmae.datamodule.utils 13 | 14 | 15 | get_image_transforms.transform_list = [@float/ConvertImageDtype(), 16 | @RandomChoice(), 17 | @int/transforms.ConvertImageDtype() 18 | ] 19 | float/ConvertImageDtype.dtype = %torch.float32 20 | int/ConvertImageDtype.dtype = %torch.uint8 21 | 22 | RandomChoice.transforms = [@kornia.augmentation.RandomPlanckianJitter(), 23 | @kornia.augmentation.RandomPlasmaShadow(), 24 | @kornia.augmentation.RandomPlasmaBrightness(), 25 | @kornia.augmentation.RandomInvert(), 26 | @kornia.augmentation.RandomPosterize(), 27 | # @RandomSharpness(), 28 | @kornia.augmentation.RandomAutoContrast(), 29 | @kornia.augmentation.RandomEqualize(), 30 | @kornia.augmentation.RandomGaussianBlur(), 31 | @kornia.augmentation.RandomMotionBlur(), 32 | ] 33 | RandomChoice.p = [0.5, 0.25, 0.2, 0.05, 0.1, 0.05, 0.05, 0.1, 0.1,] 34 | 35 | kornia.augmentation.RandomGaussianBlur.kernel_size=(3, 5) 36 | kornia.augmentation.RandomGaussianBlur.sigma=(0.1, 1) 37 | kornia.augmentation.RandomGaussianBlur.p=1.0 38 | kornia.augmentation.RandomGaussianBlur.keepdim=True 39 | 40 | kornia.augmentation.RandomMotionBlur.kernel_size=3 41 | kornia.augmentation.RandomMotionBlur.angle=35.0 42 | kornia.augmentation.RandomMotionBlur.direction=0.5 43 | kornia.augmentation.RandomMotionBlur.p=1.0 44 | kornia.augmentation.RandomMotionBlur.keepdim=True 45 | 46 | kornia.augmentation.RandomPlanckianJitter.keepdim=True 47 | kornia.augmentation.RandomPlasmaShadow.keepdim=True 48 | kornia.augmentation.RandomPlasmaBrightness.keepdim=True 49 | kornia.augmentation.RandomInvert.keepdim=True 50 | kornia.augmentation.RandomPosterize.keepdim=True 51 | kornia.augmentation.RandomAutoContrast.keepdim=True 52 | kornia.augmentation.RandomEqualize.keepdim=True 53 | 54 | kornia.augmentation.RandomPlanckianJitter.p=1.0 55 | kornia.augmentation.RandomPlasmaShadow.p=1.0 56 | kornia.augmentation.RandomPlasmaBrightness.p=1.0 57 | kornia.augmentation.RandomInvert.p=1.0 58 | kornia.augmentation.RandomPosterize.p=1.0 59 | kornia.augmentation.RandomAutoContrast.p=1.0 60 | kornia.augmentation.RandomEqualize.p=1.0 61 | -------------------------------------------------------------------------------- /config/doctr.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset_path": "/home/dawars/datasets/doc3d/", 3 | "segmenter_ckpt": "u2netp.pth", 4 | "progress_bar": true, 5 | "training": { 6 | "seed": 1337, 7 | "num_devices": 1, 8 | "steps": 500000, 9 | "batch_size": 8, 10 | "num_workers": 8, 11 | "crop": true 12 | }, 13 | "model": { 14 | "segment_background": true, 15 | "hidden_dim": 256, 16 | "num_attn_layers": 6, 17 | "upscale_type": "raft" 18 | } 19 | } -------------------------------------------------------------------------------- /config/doctr_plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_path": "/home/dawars/datasets/doc3d/", 3 | "background_path": "/home/dawars/datasets/dtd/", 4 | "progress_bar": true, 5 | "training": { 6 | "seed": 1337, 7 | "num_devices": -1, 8 | "steps": 1000000, 9 | "batch_size": 12, 10 | "num_workers": 12, 11 | "crop": true, 12 | "mask_loss": false, 13 | "ocr_loss": false, 14 | "line_loss": false 15 | }, 16 | "model": { 17 | "segment_background": false, 18 | "hidden_dim": 256, 19 | "num_attn_layers": 6, 20 | "upscale_type": "raft", 21 | "extra_attention": true, 22 | "extra_skip": true, 23 | "add_pe_every_block": true, 24 | "no_pe_for_value": true 25 | } 26 | } -------------------------------------------------------------------------------- /config/finetune.json: -------------------------------------------------------------------------------- 1 | { 2 | "mae_path": "/home/dawars/projects/models/docmae_pretrain", 3 | "dataset_path": "/home/dawars/datasets/doc3d/", 4 | "epochs": 65, 5 | "batch_size": 64, 6 | "num_workers": 32, 7 | "use_minio": false, 8 | "hidden_dim": 512, 9 | "upscale_type": "interpolate", 10 | "freeze_backbone": true 11 | } -------------------------------------------------------------------------------- /docmae/__init__.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import sys 5 | 6 | 7 | def setup_logging(log_level, log_dir): 8 | """ 9 | To set up logging 10 | :param log_level: 11 | :param log_dir: 12 | :return: 13 | """ 14 | 15 | log_level = { 16 | "CRITICAL": logging.CRITICAL, 17 | "ERROR": logging.ERROR, 18 | "WARNING": logging.WARNING, 19 | "INFO": logging.INFO, 20 | "DEBUG": logging.DEBUG, 21 | }[log_level.upper()] 22 | root_logger = logging.getLogger() 23 | handler = logging.StreamHandler(sys.stdout) 24 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s") 25 | handler.setFormatter(formatter) 26 | root_logger.setLevel(log_level) 27 | root_logger.addHandler(handler) 28 | if log_dir: 29 | os.makedirs(log_dir, exist_ok=True) 30 | log_filename = datetime.datetime.now().strftime("%Y-%m-%d") + ".log" 31 | filehandler = logging.FileHandler(filename=os.path.join(log_dir, log_filename)) 32 | filehandler.setFormatter(formatter) 33 | root_logger.addHandler(filehandler) 34 | -------------------------------------------------------------------------------- /docmae/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/docmae/data/__init__.py -------------------------------------------------------------------------------- /docmae/data/augmentation/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/docmae/data/augmentation/__init__.py -------------------------------------------------------------------------------- /docmae/data/augmentation/random_resized_crop.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import warnings 4 | from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union 5 | 6 | import torch 7 | from torchvision import datapoints 8 | from torchvision.transforms import InterpolationMode, functional 9 | from torchvision.transforms.v2 import functional as TF 10 | from torchvision.transforms.v2._utils import _setup_size 11 | from torchvision.transforms.v2.functional._geometry import _check_interpolation 12 | from torchvision.transforms.v2.utils import query_spatial_size 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class RandomResizedCropWithUV(object): 18 | """Crop a random portion of the input and resize it to a given size. 19 | This version correctly handles forward and backward mapping. Originally from torchvision.transforms.v2 20 | 21 | If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, 22 | :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) 23 | it can have arbitrary number of leading batch dimensions. For example, 24 | the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. 25 | 26 | A crop of the original input is made: the crop has a random area (H * W) 27 | and a random aspect ratio. This crop is finally resized to the given 28 | size. This is popularly used to train the Inception networks. 29 | 30 | Args: 31 | size (int or sequence): expected output size of the crop, for each edge. If size is an 32 | int instead of sequence like (h, w), a square output size ``(size, size)`` is 33 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 34 | 35 | .. note:: 36 | In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. 37 | scale (tuple of float, optional): Specifies the lower and upper bounds for the random area of the crop, 38 | before resizing. The scale is defined with respect to the area of the original image. 39 | ratio (tuple of float, optional): lower and upper bounds for the random aspect ratio of the crop, before 40 | resizing. 41 | interpolation (InterpolationMode, optional): Desired interpolation enum defined by 42 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 43 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, 44 | ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. 45 | The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. 46 | antialias (bool, optional): Whether to apply antialiasing. 47 | It only affects **tensors** with bilinear or bicubic modes and it is 48 | ignored otherwise: on PIL images, antialiasing is always applied on 49 | bilinear or bicubic modes; on other modes (for PIL images and 50 | tensors), antialiasing makes no sense and this parameter is ignored. 51 | Possible values are: 52 | 53 | - ``True``: will apply antialiasing for bilinear or bicubic modes. 54 | Other mode aren't affected. This is probably what you want to use. 
55 | - ``False``: will not apply antialiasing for tensors on any mode. PIL 56 | images are still antialiased on bilinear or bicubic modes, because 57 | PIL doesn't support no antialias. 58 | - ``None``: equivalent to ``False`` for tensors and ``True`` for 59 | PIL images. This value exists for legacy reasons and you probably 60 | don't want to use it unless you really know what you are doing. 61 | 62 | The current default is ``None`` **but will change to** ``True`` **in 63 | v0.17** for the PIL and Tensor backends to be consistent. 64 | """ 65 | 66 | def __init__( 67 | self, 68 | size: Union[int, Sequence[int]], 69 | scale: Tuple[float, float] = (0.08, 1.0), 70 | ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), 71 | interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, 72 | antialias: Optional[Union[str, bool]] = "warn", 73 | ) -> None: 74 | super().__init__() 75 | self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") 76 | 77 | if not isinstance(scale, Sequence): 78 | raise TypeError("Scale should be a sequence") 79 | scale = cast(Tuple[float, float], scale) 80 | if not isinstance(ratio, Sequence): 81 | raise TypeError("Ratio should be a sequence") 82 | ratio = cast(Tuple[float, float], ratio) 83 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 84 | warnings.warn("Scale and ratio should be of kind (min, max)") 85 | 86 | self.scale = scale 87 | self.ratio = ratio 88 | self.interpolation = _check_interpolation(interpolation) 89 | self.antialias = antialias 90 | 91 | self._log_ratio = torch.log(torch.tensor(self.ratio)) 92 | 93 | def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: 94 | height, width = query_spatial_size(flat_inputs) 95 | area = height * width 96 | 97 | log_ratio = self._log_ratio 98 | for _ in range(10): 99 | target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() 100 | aspect_ratio = torch.exp( 101 | torch.empty(1).uniform_( 102 | log_ratio[0], # type: ignore[arg-type] 103 | log_ratio[1], # type: ignore[arg-type] 104 | ) 105 | ).item() 106 | 107 | w = int(round(math.sqrt(target_area * aspect_ratio))) 108 | h = int(round(math.sqrt(target_area / aspect_ratio))) 109 | 110 | if 0 < w <= width and 0 < h <= height: 111 | i = torch.randint(0, height - h + 1, size=(1,)).item() 112 | j = torch.randint(0, width - w + 1, size=(1,)).item() 113 | break 114 | else: 115 | # Fallback to central crop 116 | in_ratio = float(width) / float(height) 117 | if in_ratio < min(self.ratio): 118 | w = width 119 | h = int(round(w / min(self.ratio))) 120 | elif in_ratio > max(self.ratio): 121 | h = height 122 | w = int(round(h * max(self.ratio))) 123 | else: # whole image 124 | w = width 125 | h = height 126 | i = (height - h) // 2 127 | j = (width - w) // 2 128 | 129 | return dict(top=i, left=j, height=h, width=w) 130 | 131 | def __call__(self, sample) -> Any: 132 | image, bm, uv, mask = sample 133 | orig_size = image.shape[1:] 134 | 135 | if uv is None: 136 | params = {"top": 0, "left": 0, "height": orig_size[0], "width": orig_size[1]} 137 | mask_crop = TF.resized_crop( 138 | mask, **params, size=self.size, interpolation=InterpolationMode.NEAREST_EXACT, antialias=False 139 | ).squeeze() 140 | 141 | image_crop = TF.resized_crop( 142 | image[None], **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias 143 | )[0].clip(0, 255) 144 | bm_crop = TF.resized_crop( 145 | bm[None], **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias 146 | 
)[0] 147 | return ( 148 | datapoints.Image(image_crop), 149 | datapoints.Image((((bm_crop.permute(1, 2, 0) - 0.5) * 2).float().permute(2, 0, 1) + 1) / 2), 150 | uv, # None 151 | datapoints.Mask(mask_crop[None]), 152 | ) 153 | params = self._get_params([image, bm, uv, mask]) 154 | crop = True 155 | while crop: 156 | params = self._get_params([image, bm, uv, mask]) 157 | uv_crop = TF.resized_crop( 158 | uv, **params, size=self.size, interpolation=InterpolationMode.NEAREST_EXACT, antialias=False 159 | ) 160 | mask_crop = TF.resized_crop( 161 | mask, **params, size=self.size, interpolation=InterpolationMode.NEAREST_EXACT, antialias=False 162 | ).squeeze() 163 | # more than half of the image is filled 164 | if mask_crop.sum() > 10: 165 | crop = False 166 | else: 167 | logging.warning("Crop contains little content, recropping") 168 | # test values 169 | # params = {"top": 14, "left": 39, "height": 419, "width": 331} # full page crop 170 | # params = {"top": 200, "left": 0, "height": 201, "width": 446} # bottom half crop 171 | # params = {"top": 0, "left": 113, "height": 326, "width": 309} # top right corner 172 | # params = {"top": 200, "left": 200, "height": 248, "width": 248} # bottom right corner 173 | # params = {"top": 2, "left": 2, "height": 366, "width": 366} # top left corner 174 | # params = {"top": 27, "left": 67, "height": 100, "width": 100} # test 175 | 176 | image_crop = TF.resized_crop( 177 | image[None], **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias 178 | )[0].clip(0, 255) 179 | 180 | # flip uv Y 181 | uv_crop[1, mask_crop.bool()] = 1 - uv_crop[1, mask_crop.bool()] 182 | min_uv_w, min_uv_h = uv_crop[0, mask_crop.bool()].min(), uv_crop[1, mask_crop.bool()].min() 183 | max_uv_w, max_uv_h = uv_crop[0, mask_crop.bool()].max(), uv_crop[1, mask_crop.bool()].max() 184 | 185 | min_uv_h = min_uv_h * orig_size[0] 186 | max_uv_h = max_uv_h * orig_size[0] 187 | min_uv_w = min_uv_w * orig_size[1] 188 | max_uv_w = max_uv_w * orig_size[1] 189 | 190 | bm_crop = bm[:, min_uv_h.long() : max_uv_h.long() + 1, min_uv_w.long() : max_uv_w.long() + 1] 191 | bm_crop = functional.resize(bm_crop[None], self.size, interpolation=self.interpolation, antialias=self.antialias)[0] 192 | 193 | # normalized relative displacement for sampling 194 | bm_crop_norm = (bm_crop.permute(1, 2, 0) - 0.5) * 2 195 | # extend crop to include background 196 | min_crop_w = params["left"] 197 | min_crop_h = params["top"] 198 | max_crop_w = params["left"] + params["width"] 199 | max_crop_h = params["top"] + params["height"] 200 | 201 | # get center of crop in normalized coords [-1, 1] 202 | center_x = min_crop_w + (max_crop_w - min_crop_w) / 2 203 | center_y = min_crop_h + (max_crop_h - min_crop_h) / 2 204 | center_x_norm = 2 * center_x / orig_size[1] - 1 205 | center_y_norm = 2 * center_y / orig_size[0] - 1 206 | 207 | bm_crop_norm[..., 1] = bm_crop_norm[..., 1] - center_y_norm # h 208 | bm_crop_norm[..., 0] = bm_crop_norm[..., 0] - center_x_norm # w 209 | 210 | # rescale to [-1, 1] for crop 211 | bm_crop_norm[..., 1] = (bm_crop_norm[..., 1]) * orig_size[1] / (max_crop_h - min_crop_h) 212 | bm_crop_norm[..., 0] = (bm_crop_norm[..., 0]) * orig_size[0] / (max_crop_w - min_crop_w) 213 | 214 | """ 215 | import torch.nn.functional as F 216 | from matplotlib import pyplot as plt 217 | import matplotlib.patches as patches 218 | from copy import copy 219 | 220 | align_corners = False 221 | bm_crop_norm = bm_crop_norm.float()[None] 222 | 223 | image_crop_manual = image[:, min_crop_h : max_crop_h + 1, 
min_crop_w : max_crop_w + 1] 224 | image_crop_manual = functional.resize( 225 | image_crop_manual[None], self.size, interpolation=self.interpolation, antialias=self.antialias 226 | )[0] 227 | 228 | mask_crop_manual = mask[0, min_crop_h : max_crop_h + 1, min_crop_w : max_crop_w + 1] 229 | mask_crop_manual = functional.resize( 230 | mask_crop_manual[None], self.size, interpolation=InterpolationMode.NEAREST_EXACT, antialias=False 231 | )[0] 232 | 233 | zeros = torch.ones((448, 448, 1)) 234 | 235 | f, axrr = plt.subplots(3, 5) 236 | for ax in axrr: 237 | for a in ax: 238 | a.set_xticks([]) 239 | a.set_yticks([]) 240 | 241 | # scale bm to -1.0 to 1.0 242 | bm_norm = (bm - 0.5) * 2 243 | bm_norm = bm_norm.permute(1, 2, 0)[None].float() 244 | 245 | axrr[0][0].imshow(image.permute(1, 2, 0) / 255) 246 | axrr[0][0].title.set_text("full image") 247 | axrr[0][0].scatter((center_x), (center_y), c="b", s=1) 248 | axrr[0][0].scatter((((center_x_norm + 1) / 2) * orig_size[1]), (((center_y_norm + 1) / 2) * orig_size[0]), c="r", s=1) 249 | axrr[0][1].imshow(mask[0], cmap="gray") 250 | axrr[0][1].title.set_text("mask") 251 | axrr[0][2].imshow(torch.cat([uv.permute(1, 2, 0), zeros], dim=-1)) 252 | axrr[0][2].title.set_text("uv") 253 | axrr[0][3].imshow(torch.cat([bm_norm[0] * 0.5 + 0.5, zeros], dim=-1), cmap="gray") 254 | axrr[0][3].title.set_text("bm") 255 | axrr[0][4].imshow(F.grid_sample(image[None] / 255, bm_norm, align_corners=align_corners)[0].permute(1, 2, 0)) 256 | axrr[0][4].title.set_text("unwarped full doc") 257 | 258 | rect_patch_crop = patches.Rectangle( 259 | (min_crop_w, min_crop_h), 260 | max_crop_w - min_crop_w, 261 | max_crop_h - min_crop_h, 262 | linewidth=1, 263 | edgecolor="b", 264 | facecolor="none", 265 | ) 266 | axrr[0][0].add_patch(copy(rect_patch_crop)) 267 | axrr[0][1].add_patch(copy(rect_patch_crop)) 268 | axrr[0][2].add_patch(copy(rect_patch_crop)) 269 | rect_patch_uv = patches.Rectangle( 270 | (min_uv_w, min_uv_h), (max_uv_w - min_uv_w), (max_uv_h - min_uv_h), linewidth=1, edgecolor="g", facecolor="none" 271 | ) 272 | axrr[0][4].add_patch(copy(rect_patch_uv)) 273 | axrr[0][3].add_patch(copy(rect_patch_uv)) 274 | 275 | zeros = torch.ones_like(uv_crop.permute(1, 2, 0)) 276 | 277 | axrr[1][0].imshow(image_crop.permute(1, 2, 0) / 255) 278 | axrr[1][0].title.set_text("image crop") 279 | axrr[1][1].imshow(mask_crop, cmap="gray") 280 | axrr[1][1].title.set_text("mask crop") 281 | axrr[1][2].imshow(torch.cat([uv_crop.permute(1, 2, 0), zeros], dim=-1)) 282 | axrr[1][2].title.set_text("uv crop") 283 | axrr[1][3].title.set_text("bm crop manual") 284 | axrr[1][4].imshow( 285 | F.grid_sample(image[None] / 255, (bm_crop.permute(1, 2, 0).float() - 0.5)[None] * 2, align_corners=align_corners,)[ 286 | 0 287 | ].permute(1, 2, 0) 288 | ) 289 | axrr[1][4].title.set_text("unwarped crop from orig") 290 | 291 | axrr[2][0].imshow(image_crop_manual.permute(1, 2, 0) / 255) 292 | axrr[2][0].title.set_text("image crop manual") 293 | axrr[2][1].imshow(mask_crop_manual, cmap="gray") 294 | axrr[2][1].title.set_text("crop mask manual") 295 | axrr[2][2].imshow( 296 | F.grid_sample(mask_crop_manual[None][None], bm_crop_norm, mode="nearest", align_corners=align_corners)[0].permute( 297 | 1, 2, 0 298 | ), 299 | cmap="gray", 300 | ) 301 | axrr[2][2].title.set_text("mask unwarped manual") 302 | axrr[2][3].imshow( 303 | torch.cat([bm_crop_norm[0] * 0.5 + 0.5, torch.ones_like(bm_crop_norm)[0, ..., 0:1]], dim=-1).clip(0, 1), cmap="gray" 304 | ) 305 | axrr[2][3].title.set_text("bm crop manual") 306 | axrr[2][4].imshow( 
307 | F.grid_sample(image_crop_manual[None] / 255, bm_crop_norm, align_corners=align_corners)[0].permute(1, 2, 0) 308 | ) 309 | axrr[2][4].title.set_text("unwarped crop manual") 310 | 311 | plt.tight_layout() 312 | plt.show() 313 | """ 314 | return ( 315 | datapoints.Image(image_crop), 316 | datapoints.Image((bm_crop_norm.float().permute(2, 0, 1) + 1) / 2), 317 | # datapoints.Mask(uv_crop), 318 | None, 319 | datapoints.Mask(mask_crop[None]), 320 | ) 321 | -------------------------------------------------------------------------------- /docmae/data/augmentation/replace_background.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from pathlib import Path 4 | 5 | import cv2 6 | import kornia 7 | import torch 8 | from PIL import Image 9 | from torchvision import datapoints 10 | from torchvision.transforms.v2 import functional as F 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | def match_brightness(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 16 | """ 17 | Match brightness of source to target 18 | Args: 19 | source: 20 | target: 21 | """ 22 | source_mean = torch.mean(source.float()) 23 | target_mean = torch.mean(target.float()) 24 | 25 | ratio = source_mean / target_mean 26 | return source / ratio 27 | 28 | 29 | class ReplaceBackground(object): 30 | def __init__(self, data_root: Path, split: str, match_brightness=True): 31 | """ 32 | Replace the background of the image (where mask is 0) by a random image 33 | Args: 34 | data_root: Directory where the dtd dataset is extracted 35 | split: split name of subset of images (train1-10, val1-10, test1-10) 36 | match_brightness: whether to match the brightness of the background to the image 37 | """ 38 | self.data_root = data_root 39 | self.filenames = (data_root / "labels" / f"{split}.txt").read_text().strip().split("\n") 40 | self.match_brightness = match_brightness 41 | 42 | def __call__(self, sample): 43 | image, bm, uv, mask = sample 44 | shape = image.shape 45 | 46 | filename = random.choice(self.filenames) 47 | background = Image.open(self.data_root / "images" / filename).convert("RGB").resize(shape[1:]) 48 | background = F.to_image_tensor(background) 49 | if self.match_brightness: 50 | background = match_brightness(background, image) 51 | 52 | smooth_mask = kornia.filters.box_blur(mask[None].float(), 3) 53 | image = ((1 - smooth_mask) * background + smooth_mask * image)[0] 54 | 55 | return datapoints.Image(image), bm, uv, mask 56 | -------------------------------------------------------------------------------- /docmae/data/doc3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | 5 | import h5py 6 | from PIL import Image 7 | import cv2 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torchvision import datapoints 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 16 | 17 | 18 | class Doc3D(Dataset): 19 | def __init__(self, data_root: Path, split: str, transforms=None): 20 | """ 21 | Args: 22 | data_root: Directory where the doc3d dataset is extracted 23 | split: split name of subset of images 24 | transforms: optional transforms for data augmentation 25 | image_transforms: optional transforms for data augmentation only applied to rgb image 26 | """ 27 | 28 | self.data_root = data_root 29 | self.filenames = (data_root / 
f"{split}.txt").read_text().strip().split("\n") 30 | self.prefix_img = "img/" 31 | self.prefix_bm = "bm/" 32 | self.prefix_uv = "uv/" 33 | 34 | self.transforms = transforms 35 | 36 | def __len__(self): 37 | return len(self.filenames) 38 | 39 | def __getitem__(self, idx): 40 | filename = self.filenames[idx] 41 | 42 | image = Image.open(self.data_root / self.prefix_img / f"{filename}.png").convert("RGB") 43 | image = datapoints.Image(image) 44 | 45 | # backwards mapping 46 | h5file = h5py.File(self.data_root / self.prefix_bm / f"{filename}.mat", "r") 47 | bm = np.array(h5file.get("bm")) 48 | bm = bm.transpose((2, 1, 0)) 49 | 50 | bm = datapoints.Image((bm / image.shape[1:]).transpose((2, 0, 1))) # absolute back mapping [0, 1] 51 | 52 | # mask from uv 53 | # Decode the EXR data using OpenCV 54 | uv_mask = cv2.imread(str(self.data_root / self.prefix_uv / f"{filename}.exr"), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 55 | uv_mask = cv2.cvtColor(uv_mask, cv2.COLOR_BGR2RGB).transpose(2, 0, 1) # forward mapping 56 | uv = datapoints.Mask(uv_mask[:2]) 57 | mask = datapoints.Mask(uv_mask[2:3].astype(bool)) 58 | 59 | if self.transforms: 60 | image, bm, uv, mask = self.transforms(image, bm, uv, mask) 61 | 62 | item = {"image": image.byte(), "bm": bm, "mask": mask} 63 | if uv is not None: 64 | item["uv"] = uv 65 | return item 66 | -------------------------------------------------------------------------------- /docmae/data/doc3d_minio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | from io import BytesIO 5 | 6 | import h5py 7 | import urllib3 8 | from urllib3.exceptions import InsecureRequestWarning 9 | from minio import Minio 10 | from PIL import Image 11 | import cv2 12 | import numpy as np 13 | import torchvision 14 | from torch.utils.data import Dataset 15 | from torchvision import datapoints 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | warnings.filterwarnings("ignore", category=InsecureRequestWarning) 20 | torchvision.disable_beta_transforms_warning() 21 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 22 | 23 | 24 | class Doc3D(Dataset): 25 | def __init__(self, filenames, transforms=None): 26 | self.client = self.create_gini_data_minio_client() 27 | 28 | self.prefix_img = "computer_vision/rectification/doc3d/img/" 29 | self.prefix_bm = "computer_vision/rectification/doc3d/bm/" 30 | self.prefix_uv = "computer_vision/rectification/doc3d/uv/" 31 | self.filenames = filenames 32 | 33 | self.transforms = transforms 34 | 35 | @staticmethod 36 | def create_gini_data_minio_client(minio_url: str = None, access_key: str = None, access_secret: str = None): 37 | """ 38 | To create a minio client with values from following ENV VARs. 39 | url: PROD_MINIO_URL 40 | key: CVIE_MINIO_USER 41 | secret: CVIE_MINIO_PASSWORD 42 | 43 | Note: ENV Vars override the given values. 44 | Client REF: https://min.io/docs/minio/linux/developers/python/API.html 45 | 46 | Returns: 47 | a configured minio client 48 | """ 49 | minio_url = os.getenv("PROD_MINIO_URL", minio_url) 50 | minio_key = os.getenv("CVIE_MINIO_USER", access_key) 51 | minio_secret = os.getenv("CVIE_MINIO_PASSWORD", access_secret) 52 | if not minio_key or not minio_secret: 53 | LOGGER.warning("No minio credential is set. 
") 54 | return 55 | timeout = 60 56 | 57 | http_client = urllib3.PoolManager( 58 | timeout=urllib3.util.Timeout(connect=timeout, read=timeout), 59 | maxsize=10, 60 | cert_reqs="CERT_NONE", 61 | retries=urllib3.Retry(total=5, backoff_factor=0.2, status_forcelist=[500, 502, 503, 504]), 62 | ) 63 | # Note: we have to set up a client that uses ssl but no cert check involved. 64 | # with mc command it's mc --insecure, 65 | # but with python client, it's combination of secure=True (ensure https) and cert_req=CERT_NONE. 66 | return Minio(minio_url, access_key=minio_key, secret_key=minio_secret, secure=True, http_client=http_client) 67 | 68 | def __len__(self): 69 | return len(self.filenames) 70 | 71 | def __getitem__(self, idx): 72 | filename = self.filenames[idx] 73 | 74 | try: 75 | obj_img = self.client.get_object("cvie", self.prefix_img + filename + ".png") 76 | image = Image.open(obj_img).convert("RGB") 77 | finally: 78 | obj_img.close() 79 | obj_img.release_conn() 80 | image = datapoints.Image(image) 81 | 82 | # backwards mapping 83 | try: 84 | obj_bm = self.client.get_object("cvie", self.prefix_bm + filename + ".mat") 85 | h5file = h5py.File(BytesIO(obj_bm.data), "r") 86 | finally: 87 | obj_bm.close() 88 | obj_bm.release_conn() 89 | flow = np.array(h5file.get("bm")) 90 | flow = np.flip(flow, 0).copy() 91 | flow = datapoints.Image(flow) 92 | 93 | # mask from uv 94 | try: 95 | obj_uv = self.client.get_object("cvie", self.prefix_uv + filename + ".exr") 96 | exr_data = obj_uv.read() 97 | finally: 98 | obj_uv.close() 99 | obj_uv.release_conn() 100 | exr_array = np.asarray(bytearray(exr_data), dtype=np.uint8) 101 | 102 | # Decode the EXR data using OpenCV 103 | uv = cv2.imdecode(exr_array, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 104 | uv = cv2.cvtColor(uv, cv2.COLOR_BGR2RGB) 105 | mask = datapoints.Mask(uv[..., 2]) 106 | 107 | if self.transforms: 108 | image, flow, mask = self.transforms(image, flow, mask) 109 | 110 | return {"image": image, "bm": flow, "mask": mask} 111 | -------------------------------------------------------------------------------- /docmae/data/docaligner.py: -------------------------------------------------------------------------------- 1 | """ 2 | DocAligner https://github.com/ZZZHANG-jx/DocAligner 3 | @article{zhang2023docaligner, 4 | title={DocAligner: Annotating Real-world Photographic Document Images by Simply Taking Pictures}, 5 | author={Zhang, Jiaxin and Chen, Bangdong and Cheng, Hiuyi and Guo, Fengjun and Ding, Kai and Jin, Lianwen}, 6 | journal={arXiv preprint arXiv:2306.05749}, 7 | year={2023}} 8 | """ 9 | import logging 10 | from pathlib import Path 11 | 12 | from PIL import Image 13 | import numpy as np 14 | import torch 15 | from torch.utils.data import Dataset 16 | from torchvision import datapoints 17 | 18 | LOGGER = logging.getLogger(__name__) 19 | 20 | 21 | class DocAligner(Dataset): 22 | def __init__(self, data_root: Path, split: str, transforms=None): 23 | """ 24 | Args: 25 | data_root: Directory where the docaligner dataset is extracted 26 | split: split name of subset of images 27 | transforms: optional transforms for data augmentation 28 | """ 29 | 30 | self.data_root = data_root 31 | self.filenames = (data_root / f"{split}.txt").read_text().strip().split("\n") 32 | 33 | self.transforms = transforms 34 | 35 | def __len__(self): 36 | return len(self.filenames) 37 | 38 | def __getitem__(self, idx): 39 | filename = Path(self.filenames[idx]) 40 | 41 | image = Image.open(self.data_root / filename).convert("RGB").resize((1024, 1024)) 42 | image = 
datapoints.Image(image) 43 | 44 | # backwards mapping 45 | bm_raw = np.load(str(self.data_root / filename.with_suffix(".npy")).replace("origin", "grid3")).astype(float) 46 | bm = (bm_raw + 1) / 2 47 | bm = datapoints.Image(bm.transpose((2, 0, 1))) # absolute back mapping [0, 1] 48 | 49 | uv = None 50 | 51 | mask = Image.open(str(self.data_root / filename).replace("origin", "mask_new")).convert("1").resize((1024, 1024)) 52 | mask = datapoints.Mask(mask) 53 | 54 | if self.transforms: 55 | image, bm, uv, mask = self.transforms(image, bm, uv, mask) 56 | 57 | item = {"image": image.byte(), "bm": bm, "mask": mask} 58 | if uv is not None: 59 | item["uv"] = uv 60 | return item 61 | -------------------------------------------------------------------------------- /docmae/data/list_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from PIL import Image, ImageOps 4 | from torch.utils.data import Dataset 5 | import re 6 | 7 | 8 | def natural_sort(l): 9 | convert = lambda text: int(text) if text.isdigit() else text.lower() 10 | alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", key)] 11 | return sorted(l, key=alphanum_key) 12 | 13 | 14 | class ListDataset(Dataset): 15 | def __init__(self, data_root: Path, split: str, transforms=None): 16 | """ 17 | Args: 18 | data_root: Directory where the split files are located 19 | split: split name of subset of images 20 | transforms: optional transforms for data augmentation 21 | """ 22 | 23 | self.data_root = data_root 24 | 25 | self.filenames = natural_sort((data_root / f"{split}.txt").read_text().split()) 26 | 27 | self.transforms = transforms 28 | 29 | def __len__(self): 30 | return len(self.filenames) 31 | 32 | def __getitem__(self, idx): 33 | filename = self.filenames[idx] 34 | 35 | image = Image.open(self.data_root / filename).convert("RGB") 36 | image = ImageOps.exif_transpose(image) 37 | 38 | if self.transforms: 39 | image = self.transforms(image) 40 | 41 | return {"image": image} 42 | -------------------------------------------------------------------------------- /docmae/datamodule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/docmae/datamodule/__init__.py -------------------------------------------------------------------------------- /docmae/datamodule/docaligner_module.py: -------------------------------------------------------------------------------- 1 | """This datamodule is responsible for loading and setting up dataloaders for DocAligner dataset""" 2 | from pathlib import Path 3 | 4 | import gin 5 | import lightning as L 6 | import torch 7 | # We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that 8 | # some APIs may slightly change in the future 9 | import torchvision 10 | from torch.utils.data import DataLoader 11 | 12 | from docmae.datamodule.utils import get_image_transforms 13 | 14 | torchvision.disable_beta_transforms_warning() 15 | 16 | import torchvision.transforms.v2 as transforms 17 | 18 | from docmae.data.docaligner import DocAligner 19 | from docmae.data.augmentation.random_resized_crop import RandomResizedCropWithUV 20 | 21 | 22 | @gin.configurable 23 | class DocAlignerDataModule(L.LightningDataModule): 24 | train_dataset: torch.utils.data.Dataset 25 | val_dataset: torch.utils.data.Dataset 26 | 27 | def __init__(self, data_dir: str, batch_size: int, num_workers: int, 
crop: bool): 28 | """ 29 | Datamodule to set up DocAligner dataset 30 | Args: 31 | data_dir: DocAligner path 32 | batch_size: batch size 33 | num_workers: number of workers in dataloader 34 | crop: whether to crop document images using UV 35 | """ 36 | super().__init__() 37 | self.data_dir = data_dir 38 | self.batch_size = max(batch_size, 4) # use 4 when searching batch size 39 | self.num_workers = num_workers 40 | self.crop = crop 41 | 42 | self.train_transform = transforms.Compose( 43 | [ 44 | RandomResizedCropWithUV((288, 288), scale=(0.08, 1.0) if self.crop else (1.0, 1.0), antialias=True), 45 | # ReplaceBackground(Path(config["background_path"]), "train1"), 46 | transforms.ToImageTensor(), 47 | transforms.ToDtype(torch.float32), 48 | ] 49 | ) 50 | self.image_transforms = get_image_transforms(gin.REQUIRED) 51 | self.val_transform = transforms.Compose( 52 | [ 53 | RandomResizedCropWithUV((288, 288), scale=(0.08, 1.0) if self.crop else (1.0, 1.0), antialias=True), 54 | # ReplaceBackground(Path(config["background_path"]), "val1"), 55 | transforms.ToImageTensor(), 56 | transforms.ToDtype(torch.float32), 57 | ] 58 | ) 59 | 60 | def setup(self, stage: str): 61 | self.train_dataset = DocAligner(Path(self.data_dir), "train", self.train_transform) 62 | self.val_dataset = DocAligner(Path(self.data_dir), "val", self.val_transform) 63 | 64 | def on_before_batch_transfer(self, batch, dataloader_idx: int): 65 | if self.trainer.training and isinstance(batch, dict): # not example tensor 66 | with torch.no_grad(): 67 | batch["image"] = self.image_transforms(batch["image"]) 68 | return batch 69 | 70 | def train_dataloader(self): 71 | return DataLoader( 72 | self.train_dataset, 73 | self.batch_size, 74 | shuffle=True, 75 | num_workers=min(self.batch_size, self.num_workers), 76 | pin_memory=True, 77 | ) 78 | 79 | def val_dataloader(self): 80 | return DataLoader( 81 | self.val_dataset, 82 | self.batch_size, 83 | shuffle=False, 84 | num_workers=min(self.batch_size, self.num_workers), 85 | pin_memory=True, 86 | ) 87 | -------------------------------------------------------------------------------- /docmae/datamodule/mixed_module.py: -------------------------------------------------------------------------------- 1 | """This datamodule is responsible for loading and setting up dataloaders for all available dataset""" 2 | from pathlib import Path 3 | 4 | import gin 5 | import lightning as L 6 | import torch 7 | from torch.utils.data import DataLoader, ConcatDataset 8 | # We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that 9 | # some APIs may slightly change in the future 10 | import torchvision 11 | import torchvision.transforms as T 12 | 13 | torchvision.disable_beta_transforms_warning() 14 | 15 | import torchvision.transforms.v2 as transforms 16 | 17 | from docmae.data.augmentation.replace_background import ReplaceBackground 18 | from docmae.data.doc3d import Doc3D 19 | from docmae.datamodule.utils import get_image_transforms 20 | from docmae.data.docaligner import DocAligner 21 | from docmae.data.augmentation.random_resized_crop import RandomResizedCropWithUV 22 | 23 | 24 | @gin.configurable 25 | class MixedDataModule(L.LightningDataModule): 26 | train_docaligner: torch.utils.data.Dataset 27 | val_docaligner: torch.utils.data.Dataset 28 | train_doc3d: torch.utils.data.Dataset 29 | val_doc3d: torch.utils.data.Dataset 30 | 31 | def __init__(self, docaligner_dir: str, doc3d_dir: str, background_dir: str, batch_size: int, num_workers: int, crop: bool): 32 | """ 33 | 
Datamodule to set up DocAligner dataset 34 | Args: 35 | docaligner_dir: DocAligner path 36 | doc3d_dir: Doc3D path 37 | background_dir: Path for background images 38 | batch_size: batch size 39 | num_workers: number of workers in dataloader 40 | crop: whether to crop document images using UV 41 | """ 42 | super().__init__() 43 | self.docaligner_dir = Path(docaligner_dir) 44 | self.doc3d_dir = Path(doc3d_dir) 45 | self.background_dir = Path(background_dir) 46 | self.batch_size = max(batch_size, 4) # use 4 when searching batch size 47 | self.num_workers = num_workers 48 | self.crop = crop 49 | 50 | self.train_transform = transforms.Compose( 51 | [ 52 | RandomResizedCropWithUV((288, 288), scale=(0.08, 1.0) if self.crop else (1.0, 1.0), antialias=True), 53 | ReplaceBackground(self.background_dir, "train1"), 54 | transforms.ToImageTensor(), 55 | transforms.ToDtype(torch.float32), 56 | ] 57 | ) 58 | self.val_transform = transforms.Compose( 59 | [ 60 | RandomResizedCropWithUV((288, 288), scale=(0.08, 1.0) if self.crop else (1.0, 1.0), antialias=True), 61 | transforms.ToImageTensor(), 62 | transforms.ToDtype(torch.float32), 63 | ] 64 | ) 65 | self.train_transform_nocrop = transforms.Compose( 66 | [ 67 | RandomResizedCropWithUV((288, 288), scale=(1.0, 1.0), antialias=True), 68 | transforms.ToImageTensor(), 69 | transforms.ToDtype(torch.float32), 70 | ] 71 | ) 72 | self.val_transform_nocrop = transforms.Compose( 73 | [ 74 | RandomResizedCropWithUV((288, 288), scale=(1.0, 1.0), antialias=True), 75 | transforms.ToImageTensor(), 76 | transforms.ToDtype(torch.float32), 77 | ] 78 | ) 79 | self.image_transforms = get_image_transforms(gin.REQUIRED) 80 | 81 | def setup(self, stage: str): 82 | self.train_docaligner = DocAligner(Path(self.docaligner_dir), "train", self.train_transform_nocrop) 83 | self.val_docaligner = DocAligner(Path(self.docaligner_dir), "val", self.val_transform_nocrop) 84 | self.train_doc3d = Doc3D(Path(self.doc3d_dir), "train", self.train_transform) 85 | self.val_doc3d = Doc3D(Path(self.doc3d_dir), "val", self.val_transform) 86 | 87 | def on_before_batch_transfer(self, batch, dataloader_idx: int): 88 | if self.trainer.training and isinstance(batch, dict): # not example tensor 89 | with torch.no_grad(): 90 | batch["image"] = self.image_transforms(batch["image"]) 91 | return batch 92 | 93 | def train_dataloader(self): 94 | return DataLoader( 95 | ConcatDataset([self.train_doc3d, self.train_docaligner]), 96 | self.batch_size, 97 | shuffle=True, 98 | num_workers=min(self.batch_size, self.num_workers), 99 | pin_memory=True, 100 | ) 101 | 102 | def val_dataloader(self): 103 | return DataLoader( 104 | ConcatDataset([self.val_doc3d, self.val_docaligner]), 105 | self.batch_size, 106 | shuffle=False, 107 | num_workers=min(self.batch_size, self.num_workers), 108 | pin_memory=True, 109 | ) 110 | -------------------------------------------------------------------------------- /docmae/datamodule/utils.py: -------------------------------------------------------------------------------- 1 | from enum import EnumMeta 2 | 3 | import gin 4 | import torchvision.transforms as T 5 | from kornia import augmentation 6 | from torch import nn 7 | from torchvision import transforms 8 | 9 | 10 | def init_external_gin_configurables(): 11 | # Set torchvision transforms as gin configurable 12 | for name in transforms.transforms.__all__: 13 | attribute = getattr(transforms, name) 14 | if isinstance(attribute, EnumMeta): 15 | # enums like InterpolationMode don't work because external registration of enums is not 
possible 16 | continue 17 | gin.external_configurable(attribute, module="torchvision.transforms") 18 | 19 | # Set torchvision transforms as gin configurable 20 | for name in augmentation.__all__: 21 | attribute = getattr(augmentation, name) 22 | if isinstance(attribute, EnumMeta): 23 | # enums like InterpolationMode don't work because external registration of enums is not possible 24 | continue 25 | if not hasattr(attribute, "__call__"): 26 | continue 27 | gin.external_configurable(attribute, module="kornia.augmentation") 28 | 29 | 30 | @gin.configurable 31 | def get_image_transforms(transform_list: list[nn.Module]): 32 | return T.Compose(transform_list) 33 | -------------------------------------------------------------------------------- /docmae/fine_tune.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | import argparse 5 | import shutil 6 | from pathlib import Path 7 | 8 | import torch 9 | import lightning as L 10 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 11 | from lightning.pytorch.loggers import TensorBoardLogger 12 | from torch.utils.data import DataLoader 13 | import torchvision.transforms.v2 as transforms 14 | 15 | from transformers import ViTMAEConfig 16 | from transformers.models.vit_mae.modeling_vit_mae import ViTMAEDecoder, ViTMAEModel 17 | 18 | from docmae.models.docmae import DocMAE 19 | 20 | from docmae import setup_logging 21 | from docmae.data.augmentation.random_resized_crop import RandomResizedCropWithUV 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def parse_arguments(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("-c", "--config", type=str, help="config file for training parameters") 29 | parser.add_argument( 30 | "-ll", 31 | "--log-level", 32 | type=str, 33 | default="INFO", 34 | choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], 35 | help="config file for training parameters", 36 | ) 37 | parser.add_argument("-l", "--log-dir", type=str, default="", help="folder to store log files") 38 | parser.add_argument("-t", "--tensorboard-dir", type=str, default="", help="folder to store tensorboard logs") 39 | parser.add_argument("-m", "--model-output-dir", type=str, default="model", help="folder to store trained models") 40 | return parser.parse_args() 41 | 42 | 43 | def train(args, config: dict): 44 | train_transform = transforms.Compose( 45 | [ 46 | # transforms.RandomRotation((-10, 10)), 47 | RandomResizedCropWithUV((288, 288), scale=(0.08, 1.0), antialias=True), 48 | transforms.ToImageTensor(), 49 | transforms.ToDtype(torch.float32), 50 | ] 51 | ) 52 | 53 | if config["use_minio"]: 54 | from docmae.data.doc3d_minio import Doc3D 55 | 56 | train_files = (Path(config["dataset_path"]) / "train.txt").read_text().split() 57 | val_files = (Path(config["dataset_path"]) / "val.txt").read_text().split() 58 | train_dataset = Doc3D(train_files, train_transform) 59 | val_dataset = Doc3D(val_files, train_transform) 60 | else: 61 | from docmae.data.doc3d import Doc3D 62 | 63 | train_dataset = Doc3D(Path(config["dataset_path"]), "train", train_transform) 64 | val_dataset = Doc3D(Path(config["dataset_path"]), "val", train_transform) 65 | train_loader = DataLoader(train_dataset, config["batch_size"], shuffle=True, num_workers=config["num_workers"], pin_memory=True) 66 | val_loader = DataLoader(val_dataset, config["batch_size"], shuffle=False, num_workers=config["num_workers"], pin_memory=True) 67 | 68 | callback_list = [ 69 | 
LearningRateMonitor(logging_interval="step"), 70 | ModelCheckpoint( 71 | dirpath=args.model_output_dir, 72 | filename="epoch_{epoch:d}", 73 | monitor="val/loss", 74 | mode="min", 75 | save_top_k=2, 76 | ), 77 | ] 78 | 79 | logger = TensorBoardLogger(save_dir=args.tensorboard_dir, log_graph=False, default_hp_metric=False) 80 | 81 | trainer = L.Trainer( 82 | logger=logger, 83 | callbacks=callback_list, 84 | accelerator="gpu", 85 | devices=1, 86 | max_epochs=config["epochs"], 87 | num_sanity_val_steps=1, 88 | enable_progress_bar=False, 89 | ) 90 | 91 | pretrained_config = ViTMAEConfig.from_pretrained(config["mae_path"]) 92 | mae_encoder = ViTMAEModel.from_pretrained(config["mae_path"], mask_ratio=0) 93 | mae_decoder = ViTMAEDecoder(pretrained_config, mae_encoder.embeddings.num_patches) 94 | 95 | model = DocMAE(mae_encoder, mae_decoder, config) 96 | 97 | trainer.fit(model, train_loader, val_loader) 98 | 99 | 100 | def main(): 101 | args = parse_arguments() 102 | setup_logging(log_level=args.log_level, log_dir=args.log_dir) 103 | 104 | assert args.config.endswith(".json") 105 | 106 | # Save config for training traceability and load config parameters 107 | config_file = Path(args.model_output_dir) / "fine_tune_config.json" 108 | config = json.loads(Path(args.config).read_text()) 109 | shutil.copyfile(args.config, config_file) 110 | train(args, config) 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /docmae/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | from pathlib import Path 5 | 6 | from matplotlib import pyplot as plt 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torchvision.utils import flow_to_image 10 | import torchvision.transforms.v2 as transforms 11 | from lightning import Trainer 12 | from lightning.pytorch.callbacks import BasePredictionWriter 13 | 14 | from docmae import setup_logging 15 | from docmae.data.list_dataset import ListDataset 16 | from docmae.models.doctr_custom import DocTrOrig 17 | from docmae.models.doctr_plus import DocTrPlus 18 | from docmae.models.rectification import Rectification 19 | from docmae.models.upscale import UpscaleRAFT, UpscaleTransposeConv, UpscaleInterpolate 20 | from extractor import BasicEncoder 21 | 22 | 23 | class CustomWriter(BasePredictionWriter): 24 | def __init__(self, output_dir: str | Path, save_bm: bool, save_mask: bool, write_interval="batch"): 25 | super().__init__(write_interval) 26 | self.output_dir = Path(output_dir) 27 | self.output_dir.mkdir(exist_ok=True) 28 | self.save_bm = save_bm 29 | self.save_mask = save_mask 30 | 31 | def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx): 32 | rectified, bm, mask = prediction 33 | for i, idx in enumerate(batch_indices): 34 | rectified_ = rectified[i] 35 | 36 | plt.imsave(self.output_dir / f"{idx}_rect.jpg", rectified_.permute(1, 2, 0).clip(0, 255).cpu().numpy() / 255) 37 | if self.save_bm: 38 | bm_ = bm[i] 39 | plt.imsave(self.output_dir / f"{idx}_bm.jpg", flow_to_image(bm_).permute(1, 2, 0).numpy()) 40 | if self.save_mask and mask is not None: 41 | mask_ = mask[i] 42 | plt.imsave(self.output_dir / f"{idx}_mask.png", mask_[0].cpu().numpy(), cmap="gray") 43 | 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | def parse_arguments(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("-c", 
"--config", type=Path, help="Model config file") 51 | parser.add_argument( 52 | "-ll", 53 | "--log-level", 54 | type=str, 55 | default="INFO", 56 | choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], 57 | help="config file for training parameters", 58 | ) 59 | parser.add_argument("-l", "--log-dir", type=Path, default="", help="folder to store log files") 60 | parser.add_argument("-d", "--data-path", type=Path, default="", help="Dataset directory") 61 | parser.add_argument("-s", "--split", type=str, default="test", help="Dataset split") 62 | parser.add_argument("-o", "--output-path", type=Path, default="", help="Directory to save inference results") 63 | parser.add_argument("-m", "--ckpt-path", type=Path, default="", help="Checkpoint path") 64 | return parser.parse_args() 65 | 66 | 67 | def inference(args, config): 68 | model = DocTrPlus(config["model"]) 69 | 70 | hidden_dim = config["model"]["hidden_dim"] 71 | backbone = BasicEncoder(output_dim=hidden_dim, norm_fn="instance") 72 | upscale_type = config["model"]["upscale_type"] 73 | if upscale_type == "raft": 74 | upscale_module = UpscaleRAFT(8, hidden_dim) 75 | elif upscale_type == "transpose_conv": 76 | upscale_module = UpscaleTransposeConv(hidden_dim, hidden_dim // 2) 77 | elif upscale_type == "interpolate": 78 | upscale_module = UpscaleInterpolate(hidden_dim, hidden_dim // 2) 79 | else: 80 | raise NotImplementedError 81 | model = Rectification.load_from_checkpoint(args.ckpt_path, "cuda", model=model, backbone=backbone, upscale=upscale_module, config=config) 82 | 83 | inference_transform = transforms.Compose( 84 | [ 85 | transforms.ToImageTensor(), 86 | transforms.ToDtype(torch.float32), 87 | ] 88 | ) 89 | dataset = ListDataset(args.data_path, args.split, inference_transform) 90 | dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False) # todo implement collate for arbitrary images 91 | 92 | writer_callback = CustomWriter(args.output_path, save_bm=True, save_mask=True) 93 | 94 | trainer = Trainer( 95 | callbacks=[writer_callback], 96 | accelerator="cuda", 97 | # limit_predict_batches=160, 98 | ) 99 | trainer.predict(model, dataloader, return_predictions=False) 100 | 101 | 102 | def main(): 103 | args = parse_arguments() 104 | setup_logging(log_level=args.log_level, log_dir=args.log_dir) 105 | 106 | assert args.config.suffix == ".json" 107 | config = json.loads(args.config.read_text()) 108 | inference(args, config) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /docmae/models/docmae.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import L1Loss 6 | from torch.optim.lr_scheduler import OneCycleLR 7 | from torch.utils.tensorboard import SummaryWriter 8 | from torchvision.transforms import transforms 9 | from transformers.models.vit_mae.modeling_vit_mae import ViTMAEDecoder, ViTMAEModel 10 | 11 | from docmae.models.upscale import UpscaleRAFT, UpscaleTransposeConv, UpscaleInterpolate, coords_grid 12 | 13 | PATCH_SIZE = 16 14 | 15 | 16 | class DocMAE(L.LightningModule): 17 | tb_log: SummaryWriter 18 | 19 | def __init__( 20 | self, 21 | encoder: ViTMAEModel, 22 | decoder: ViTMAEDecoder, 23 | hparams, 24 | ): 25 | super().__init__() 26 | self.example_input_array = torch.rand(1, 3, 288, 288) 27 | self.coodslar = self.initialize_flow(self.example_input_array) 28 | 29 | 
self.encoder = encoder 30 | self.decoder = decoder 31 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 32 | 33 | self.P = PATCH_SIZE 34 | self.hidden_dim = hparams["hidden_dim"] 35 | self.upscale_type = hparams["upscale_type"] 36 | self.freeze_backbone = hparams["freeze_backbone"] 37 | 38 | if self.upscale_type == "raft": 39 | self.upscale_module = UpscaleRAFT(self.hidden_dim, hidden_dim=256) 40 | elif self.upscale_type == "transpose_conv": 41 | self.upscale_module = UpscaleTransposeConv(self.hidden_dim) 42 | elif self.upscale_type == "interpolate": 43 | self.upscale_module = UpscaleInterpolate() 44 | else: 45 | raise NotImplementedError 46 | self.loss = L1Loss() 47 | self.save_hyperparameters(hparams) 48 | if self.freeze_backbone: 49 | for p in self.encoder.parameters(): 50 | p.requires_grad = False 51 | for p in self.decoder.parameters(): 52 | p.requires_grad = False 53 | 54 | def on_fit_start(self): 55 | self.coodslar = self.coodslar.to(self.device) 56 | 57 | self.tb_log = self.logger.experiment 58 | additional_metrics = ["val/loss"] 59 | self.logger.log_hyperparams(self.hparams, {**{key: 0 for key in additional_metrics}}) 60 | 61 | def configure_optimizers(self): 62 | """ 63 | calling setup_optimizers for which all missing parameters must be registered with gin 64 | 65 | Returns: 66 | dictionary defining optimizer, learning rate scheduler and value to monitor as expected by pytorch lightning 67 | """ 68 | 69 | optimizer = torch.optim.Adam(self.parameters()) 70 | scheduler = OneCycleLR( 71 | optimizer, 1e-4, epochs=self.hparams["epochs"], steps_per_epoch=200_000 // self.hparams["batch_size"] 72 | ) 73 | return { 74 | "optimizer": optimizer, 75 | "lr_scheduler": scheduler, 76 | "monitor": "train/loss", 77 | } 78 | 79 | def forward(self, inputs): 80 | """ 81 | Runs inference: image_processing, encoder, decoder, layer norm, flow head 82 | Args: 83 | inputs: image tensor of shape [B, C, H, W] 84 | Returns: flow displacement 85 | """ 86 | inputs = self.normalize(inputs) 87 | bottleneck = self.encoder.forward(inputs) 88 | fmap = self.decoder( 89 | bottleneck.last_hidden_state, 90 | bottleneck.ids_restore, 91 | output_hidden_states=True, 92 | return_dict=True, 93 | ) 94 | 95 | last_hidden_state = fmap.hidden_states[-1][:, 1:, :] # remove CLS token 96 | fmap = last_hidden_state # layer norm 97 | # B x 18*18 x 512 98 | # -> B x 512 x 18 x 18 (B x 256 x 36 x 36) 99 | fmap = fmap.permute(0, 2, 1) 100 | fmap = fmap.reshape(-1, self.hidden_dim, 18, 18) 101 | upflow = self.flow(fmap) 102 | return upflow 103 | 104 | def training_step(self, batch): 105 | image = batch["image"] * batch["mask"].unsqueeze(1) / 255 106 | flow = batch["bm"] 107 | batch_size = len(image) 108 | 109 | # training image sanity check 110 | if self.global_step == 0: 111 | zeros = torch.zeros((batch_size, 1, 288, 288), device=self.device) 112 | def viz_flow(img): return (img / 448 - 0.5) * 2 113 | self.tb_log.add_images("train/image", image, global_step=self.global_step) 114 | self.tb_log.add_images("val/flow", torch.cat((viz_flow(flow), zeros), dim=1), global_step=self.global_step) 115 | 116 | dflow = self.forward(image) 117 | flow_pred = self.coodslar + dflow 118 | 119 | # log metrics 120 | loss = self.loss(flow, flow_pred) 121 | 122 | self.log("train/loss", loss, on_step=True, on_epoch=True, batch_size=batch_size) 123 | 124 | return loss 125 | 126 | def on_after_backward(self): 127 | global_step = self.global_step 128 | # if self.global_step % 100 == 0: 129 | # for name, param in 
self.model.named_parameters(): 130 | # self.tb_log.add_histogram(name, param, global_step) 131 | # if param.requires_grad: 132 | # self.tb_log.add_histogram(f"{name}_grad", param.grad, global_step) 133 | 134 | def on_validation_start(self): 135 | self.coodslar = self.coodslar.to(self.device) 136 | self.tb_log = self.logger.experiment 137 | 138 | def validation_step(self, val_batch, batch_idx): 139 | image = val_batch["image"] * val_batch["mask"].unsqueeze(1) / 255 140 | flow_target = val_batch["bm"] 141 | batch_size = len(image) 142 | 143 | # training image sanity check 144 | if self.global_step == 0: 145 | self.tb_log.add_images("train/image", image, global_step=self.global_step) 146 | dflow = self.forward(image) 147 | flow_pred = self.coodslar + dflow 148 | 149 | # log metrics 150 | loss = self.loss(flow_target, flow_pred) 151 | 152 | self.log("val/loss", loss, on_epoch=True, batch_size=batch_size) 153 | 154 | zeros = torch.zeros((batch_size, 1, 288, 288), device=self.device) 155 | def viz_flow(img): return (img / 448 - 0.5) * 2 156 | if batch_idx == 0 and self.global_step == 0: 157 | self.tb_log.add_images("val/image", image, global_step=self.global_step) 158 | self.tb_log.add_images("val/flow", torch.cat((viz_flow(flow_target), zeros), dim=1), global_step=self.global_step) 159 | 160 | if batch_idx == 0: 161 | self.tb_log.add_images("val/flow_pred", torch.cat((viz_flow(flow_pred), zeros), dim=1), global_step=self.global_step) 162 | 163 | bm_ = viz_flow(flow_pred) 164 | bm_ = bm_.permute((0, 2, 3, 1)) 165 | img_ = image 166 | uw = F.grid_sample(img_, bm_, align_corners=False) 167 | 168 | self.tb_log.add_images("val/unwarped", uw, global_step=self.global_step) 169 | 170 | def on_test_start(self): 171 | self.tb_log = self.logger.experiment 172 | 173 | def initialize_flow(self, img): 174 | N, C, H, W = img.shape 175 | coodslar = coords_grid(N, H, W).to(img.device) 176 | # coords0 = coords_grid(N, H // self.P, W // self.P).to(img.device) 177 | # coords1 = coords_grid(N, H // self.P, W // self.P).to(img.device) 178 | 179 | return coodslar # , coords0, coords1 180 | 181 | def flow(self, fmap): 182 | # convex upsample based on fmap 183 | flow_up = self.upscale_module(fmap) 184 | return flow_up 185 | -------------------------------------------------------------------------------- /docmae/models/doctr.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT Licence https://github.com/fh2019ustc/DocTr 3 | """ 4 | import copy 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms import transforms 11 | 12 | from docmae.models.transformer import build_position_encoding 13 | 14 | 15 | class AttnLayer(nn.Module): 16 | def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False): 17 | super().__init__() 18 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 19 | self.multihead_attn_list = nn.ModuleList( 20 | [copy.deepcopy(nn.MultiheadAttention(d_model, nhead, dropout=dropout)) for i in range(2)] 21 | ) 22 | # Implementation of Feedforward model 23 | self.linear1 = nn.Linear(d_model, dim_feedforward) 24 | self.dropout = nn.Dropout(dropout) 25 | self.linear2 = nn.Linear(dim_feedforward, d_model) 26 | 27 | self.norm1 = nn.LayerNorm(d_model) 28 | self.norm2_list = nn.ModuleList([copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]) 29 | 30 | self.norm3 = nn.LayerNorm(d_model) 31 | self.dropout1 
= nn.Dropout(dropout) 32 | self.dropout2_list = nn.ModuleList([copy.deepcopy(nn.Dropout(dropout)) for i in range(2)]) 33 | self.dropout3 = nn.Dropout(dropout) 34 | 35 | self.activation = _get_activation_fn(activation) 36 | self.normalize_before = normalize_before 37 | 38 | def with_pos_embed(self, tensor, pos: Optional[torch.Tensor]): 39 | return tensor if pos is None else tensor + pos 40 | 41 | def forward_post( 42 | self, 43 | tgt, # query embed 44 | memory_list, # imgf 45 | tgt_mask=None, 46 | memory_mask=None, 47 | tgt_key_padding_mask=None, 48 | memory_key_padding_mask=None, 49 | pos=None, 50 | memory_pos=None, 51 | ): 52 | q = k = self.with_pos_embed(tgt, pos) 53 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] 54 | tgt = tgt + self.dropout1(tgt2) 55 | tgt = self.norm1(tgt) 56 | for memory, multihead_attn, norm2, dropout2, m_pos in zip( 57 | memory_list, self.multihead_attn_list, self.norm2_list, self.dropout2_list, memory_pos 58 | ): 59 | tgt2 = multihead_attn( 60 | query=self.with_pos_embed(tgt, pos), 61 | key=self.with_pos_embed(memory, m_pos), 62 | value=memory, 63 | attn_mask=memory_mask, 64 | key_padding_mask=memory_key_padding_mask, 65 | )[0] 66 | tgt = tgt + dropout2(tgt2) 67 | tgt = norm2(tgt) 68 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 69 | tgt = tgt + self.dropout3(tgt2) 70 | tgt = self.norm3(tgt) 71 | return tgt 72 | 73 | def forward_pre( 74 | self, 75 | tgt, 76 | memory, 77 | tgt_mask=None, 78 | memory_mask=None, 79 | tgt_key_padding_mask=None, 80 | memory_key_padding_mask=None, 81 | pos=None, 82 | memory_pos=None, 83 | ): 84 | tgt2 = self.norm1(tgt) 85 | q = k = self.with_pos_embed(tgt2, pos) 86 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] 87 | tgt = tgt + self.dropout1(tgt2) 88 | tgt2 = self.norm2(tgt) 89 | tgt2 = self.multihead_attn( 90 | query=self.with_pos_embed(tgt2, pos), 91 | key=self.with_pos_embed(memory, memory_pos), 92 | value=memory, 93 | attn_mask=memory_mask, 94 | key_padding_mask=memory_key_padding_mask, 95 | )[0] 96 | tgt = tgt + self.dropout2(tgt2) 97 | tgt2 = self.norm3(tgt) 98 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 99 | tgt = tgt + self.dropout3(tgt2) 100 | return tgt 101 | 102 | def forward( 103 | self, 104 | tgt, 105 | memory_list, 106 | tgt_mask=None, 107 | memory_mask=None, 108 | tgt_key_padding_mask=None, 109 | memory_key_padding_mask=None, 110 | pos=None, 111 | memory_pos=None, 112 | ): 113 | if self.normalize_before: 114 | return self.forward_pre( 115 | tgt, memory_list, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, memory_pos 116 | ) 117 | return self.forward_post( 118 | tgt, memory_list, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, memory_pos 119 | ) 120 | 121 | 122 | def _get_clones(module, N): 123 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 124 | 125 | 126 | def _get_activation_fn(activation): 127 | """Return an activation function given a string""" 128 | if activation == "relu": 129 | return F.relu 130 | if activation == "gelu": 131 | return F.gelu 132 | if activation == "glu": 133 | return F.glu 134 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 135 | 136 | 137 | class TransDecoder(nn.Module): 138 | def __init__(self, num_attn_layers, hidden_dim=128): 139 | super(TransDecoder, self).__init__() 140 | attn_layer = AttnLayer(hidden_dim) 141 | self.layers = 
_get_clones(attn_layer, num_attn_layers) 142 | self.position_embedding = build_position_encoding(hidden_dim) 143 | 144 | def forward(self, imgf, query_embed): 145 | pos = self.position_embedding( 146 | torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda() 147 | ) # torch.Size([1, 128, 36, 36]) 148 | 149 | bs, c, h, w = imgf.shape 150 | imgf = imgf.flatten(2).permute(2, 0, 1) 151 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 152 | pos = pos.flatten(2).permute(2, 0, 1) 153 | 154 | for layer in self.layers: 155 | query_embed = layer(query_embed, [imgf], pos=pos, memory_pos=[pos]) 156 | query_embed = query_embed.permute(1, 2, 0).reshape(bs, c, h, w) 157 | 158 | return query_embed 159 | 160 | 161 | class TransEncoder(nn.Module): 162 | def __init__(self, num_attn_layers, hidden_dim=128): 163 | super(TransEncoder, self).__init__() 164 | attn_layer = AttnLayer(hidden_dim) 165 | self.layers = _get_clones(attn_layer, num_attn_layers) 166 | self.position_embedding = build_position_encoding(hidden_dim) 167 | 168 | def forward(self, imgf): 169 | pos = self.position_embedding( 170 | torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda() 171 | ) # torch.Size([1, 128, 36, 36]) 172 | bs, c, h, w = imgf.shape 173 | imgf = imgf.flatten(2).permute(2, 0, 1) 174 | pos = pos.flatten(2).permute(2, 0, 1) 175 | 176 | for layer in self.layers: 177 | imgf = layer(imgf, [imgf], pos=pos, memory_pos=[pos, pos]) 178 | imgf = imgf.permute(1, 2, 0).reshape(bs, c, h, w) 179 | 180 | return imgf 181 | 182 | 183 | class FlowHead(nn.Module): 184 | def __init__(self, input_dim=128, hidden_dim=256): 185 | super(FlowHead, self).__init__() 186 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 187 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 188 | self.relu = nn.ReLU(inplace=True) 189 | 190 | def forward(self, x): 191 | return self.conv2(self.relu(self.conv1(x))) 192 | 193 | 194 | class DocTr(nn.Module): 195 | def __init__(self, config): 196 | super(DocTr, self).__init__() 197 | self.num_attn_layers = config["num_attn_layers"] 198 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 199 | 200 | hdim = config["hidden_dim"] 201 | 202 | self.trans_encoder = TransEncoder(self.num_attn_layers, hidden_dim=hdim) 203 | self.trans_decoder = TransDecoder(self.num_attn_layers, hidden_dim=hdim) 204 | self.query_embed = nn.Embedding(1296, hdim) 205 | 206 | self.flow_head = FlowHead(hdim, hidden_dim=hdim) 207 | 208 | def forward(self, backbone_features): 209 | """ 210 | image: segmented image 211 | """ 212 | fmap = torch.relu(backbone_features) 213 | 214 | fmap = self.trans_encoder(fmap) 215 | fmap = self.trans_decoder(fmap, self.query_embed.weight) 216 | 217 | dflow = self.flow_head(fmap) 218 | 219 | return {"flow": dflow, "feature_map": fmap} 220 | -------------------------------------------------------------------------------- /docmae/models/doctr_custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is a modified version of DocTr where the transformer module is replaced by the original transformer model. 
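The attention layers keep a post-norm Transformer layout; positional encodings can be added either once before the first block or at every block, and optionally also to the values (see the `pos_encoding_before` / `pos_encoding_value` flags below).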
3 | MIT Licence https://github.com/fh2019ustc/DocTr 4 | """ 5 | import copy 6 | from typing import Optional 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from torchvision.transforms import transforms 12 | 13 | from docmae.models.transformer import build_position_encoding 14 | 15 | 16 | class SelfAttnLayer(nn.Module): 17 | def __init__( 18 | self, 19 | d_model, 20 | nhead=8, 21 | dim_feedforward=2048, 22 | dropout=0.1, 23 | activation="relu", 24 | extra_attention=False, 25 | ): 26 | super().__init__() 27 | self.extra_attention = extra_attention 28 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 29 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 30 | # Implementation of Feedforward model 31 | self.linear1 = nn.Linear(d_model, dim_feedforward) 32 | self.dropout = nn.Dropout(dropout) 33 | self.linear2 = nn.Linear(dim_feedforward, d_model) 34 | 35 | self.norm1 = nn.LayerNorm(d_model) 36 | self.norm2 = nn.LayerNorm(d_model) 37 | self.norm3 = nn.LayerNorm(d_model) 38 | self.dropout1 = nn.Dropout(dropout) 39 | self.dropout2 = nn.Dropout(dropout) 40 | self.dropout3 = nn.Dropout(dropout) 41 | 42 | self.activation = _get_activation_fn(activation) 43 | 44 | def forward_post(self, query, key, value, cross_kv=None, pos_list=[None, None, None]): 45 | q = self.with_pos_embed(query, pos_list[0]) 46 | k = self.with_pos_embed(key, pos_list[1]) 47 | v = self.with_pos_embed(value, pos_list[2]) 48 | 49 | tgt2 = self.self_attn(q, k, v, need_weights=False)[0] 50 | tgt = query + self.dropout1(tgt2) # query and key should be equal, using it for skip connection 51 | tgt = self.norm1(tgt) 52 | 53 | if self.extra_attention: 54 | if cross_kv is not None: 55 | q = tgt 56 | k = v = cross_kv 57 | else: 58 | q = k = v = tgt 59 | q = self.with_pos_embed(q, pos_list[0]) 60 | k = self.with_pos_embed(k, pos_list[1]) 61 | v = self.with_pos_embed(v, pos_list[2]) 62 | 63 | tgt2 = self.cross_attn(q, k, v, need_weights=False)[0] 64 | tgt = tgt + self.dropout2(tgt2) 65 | tgt = self.norm2(tgt) 66 | 67 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 68 | tgt = tgt + self.dropout3(tgt2) 69 | tgt = self.norm3(tgt) 70 | return tgt 71 | 72 | def forward(self, query, key, value, cross_kv, pos_list): 73 | return self.forward_post(query, key, value, cross_kv=cross_kv, pos_list=pos_list) 74 | 75 | def with_pos_embed(self, tensor, pos: Optional[torch.Tensor]): 76 | return tensor if pos is None else tensor + pos 77 | 78 | 79 | class CrossAttnLayer(nn.Module): 80 | def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu"): 81 | super().__init__() 82 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 83 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 84 | 85 | # Implementation of Feedforward model 86 | self.linear1 = nn.Linear(d_model, dim_feedforward) 87 | self.dropout = nn.Dropout(dropout) 88 | self.linear2 = nn.Linear(dim_feedforward, d_model) 89 | 90 | self.norm1 = nn.LayerNorm(d_model) 91 | self.norm2 = nn.LayerNorm(d_model) 92 | self.norm3 = nn.LayerNorm(d_model) 93 | self.dropout1 = nn.Dropout(dropout) 94 | self.dropout2 = nn.Dropout(dropout) 95 | self.dropout3 = nn.Dropout(dropout) 96 | 97 | self.activation = _get_activation_fn(activation) 98 | 99 | def forward_post( 100 | self, query, key, value, cross_kv, pos_list=[None, None, None] # query embedding # encoder features # positional encoding 101 | ): 102 | q = self.with_pos_embed(query, 
pos_list[0]) 103 | k = self.with_pos_embed(key, pos_list[1]) 104 | v = self.with_pos_embed(value, pos_list[2]) 105 | 106 | tgt2 = self.self_attn(q, k, v, need_weights=False)[0] 107 | tgt = q + self.dropout1(tgt2) 108 | tgt = self.norm1(tgt) 109 | 110 | q = tgt 111 | k, v = cross_kv 112 | q = self.with_pos_embed(q, pos_list[0]) 113 | k = self.with_pos_embed(k, pos_list[1]) 114 | v = self.with_pos_embed(v, pos_list[2]) 115 | 116 | tgt2 = self.cross_attn(q, k, v, need_weights=False)[0] 117 | tgt = tgt + self.dropout2(tgt2) 118 | tgt = self.norm2(tgt) 119 | 120 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 121 | tgt = tgt + self.dropout3(tgt2) 122 | tgt = self.norm3(tgt) 123 | return tgt 124 | 125 | def forward(self, query, key, value, cross_kv, pos_list): 126 | return self.forward_post(query, key, value, cross_kv, pos_list=pos_list) 127 | 128 | def with_pos_embed(self, tensor, pos: Optional[torch.Tensor]): 129 | return tensor if pos is None else tensor + pos 130 | 131 | 132 | def _get_clones(module, N): 133 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 134 | 135 | 136 | def _get_activation_fn(activation): 137 | """Return an activation function given a string""" 138 | if activation == "relu": 139 | return F.relu 140 | if activation == "gelu": 141 | return F.gelu 142 | if activation == "glu": 143 | return F.glu 144 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 145 | 146 | 147 | class TransDecoder(nn.Module): 148 | def __init__(self, num_attn_layers, hidden_dim=128, pos_encoding_before=True, pos_encoding_value=True): 149 | super(TransDecoder, self).__init__() 150 | self.pos_encoding_before = pos_encoding_before 151 | self.pos_encoding_value = pos_encoding_value 152 | attn_layer = CrossAttnLayer(hidden_dim) 153 | self.layers = _get_clones(attn_layer, num_attn_layers) 154 | position_embedding = build_position_encoding(hidden_dim) 155 | self.pos = position_embedding(torch.ones((1, 36, 36), dtype=torch.bool).cuda()) # torch.Size([1, 128, 36, 36]) 156 | self.pos = self.pos.flatten(2).permute(2, 0, 1) 157 | 158 | def forward(self, imgf, query_embed): 159 | bs, c, h, w = imgf.shape 160 | imgf = imgf.flatten(2).permute(2, 0, 1) 161 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 162 | 163 | cross_k = cross_v = imgf 164 | for i, layer in enumerate(self.layers): 165 | query = key = value = query_embed 166 | 167 | if self.pos_encoding_before: 168 | pos_list = [None] * 3 169 | if i == 0: 170 | query = query + self.pos 171 | key = key + self.pos 172 | # cross_k = cross_k + self.pos # in orig transformer PE is not added here at all 173 | if self.pos_encoding_value: 174 | value = value + self.pos 175 | # cross_v = cross_v + self.pos 176 | else: # add PE every block, also added to cross k,v 177 | if self.pos_encoding_value: 178 | pos_list = [self.pos] * 3 179 | else: 180 | pos_list = [self.pos, self.pos, None] 181 | query_embed = layer(query=query, key=key, value=value, cross_kv=[cross_k, cross_v], pos_list=pos_list) 182 | query_embed = query_embed.permute(1, 2, 0).reshape(bs, c, h, w) 183 | 184 | return query_embed 185 | 186 | 187 | class TransEncoder(nn.Module): 188 | def __init__( 189 | self, 190 | num_attn_layers, 191 | hidden_dim=128, 192 | extra_attention=False, 193 | extra_skip=False, 194 | pos_encoding_before=True, 195 | pos_encoding_value=True, 196 | ): 197 | super(TransEncoder, self).__init__() 198 | if not extra_attention: 199 | assert not extra_skip 200 | self.extra_skip = extra_skip 201 | self.pos_encoding_before = 
pos_encoding_before 202 | self.pos_encoding_value = pos_encoding_value 203 | attn_layer = SelfAttnLayer(hidden_dim, extra_attention=extra_attention) 204 | self.layers = _get_clones(attn_layer, num_attn_layers) 205 | position_embedding = build_position_encoding(hidden_dim) 206 | # run here because sin PE is not learned 207 | self.pos = position_embedding(torch.ones((1, 36, 36), dtype=torch.bool).cuda()) # torch.Size([1, 128, 36, 36]) 208 | self.pos = self.pos.flatten(2).permute(2, 0, 1) 209 | 210 | def forward(self, imgf): 211 | bs, c, h, w = imgf.shape 212 | imgf = imgf.flatten(2).permute(2, 0, 1) 213 | 214 | for i, layer in enumerate(self.layers): 215 | query = key = value = imgf 216 | 217 | if self.pos_encoding_before: 218 | pos_list = [None] * 3 219 | if i == 0: 220 | query = query + self.pos 221 | key = key + self.pos 222 | if self.pos_encoding_value: 223 | value = value + self.pos 224 | else: # add PE every block 225 | if self.pos_encoding_value: 226 | pos_list = [self.pos] * 3 227 | else: 228 | pos_list = [self.pos, self.pos, None] 229 | cross_kv = query if self.extra_skip else None # extra skip connection from block input before blockwise PE 230 | imgf = layer(query=query, key=key, value=value, cross_kv=cross_kv, pos_list=pos_list) 231 | imgf = imgf.permute(1, 2, 0).reshape(bs, c, h, w) 232 | 233 | return imgf 234 | 235 | 236 | class FlowHead(nn.Module): 237 | def __init__(self, input_dim=128, hidden_dim=256): 238 | super(FlowHead, self).__init__() 239 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 240 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 241 | self.relu = nn.ReLU(inplace=True) 242 | 243 | def forward(self, x): 244 | return self.conv2(self.relu(self.conv1(x))) 245 | 246 | 247 | class DocTrOrig(nn.Module): 248 | def __init__(self, config): 249 | super(DocTrOrig, self).__init__() 250 | self.num_attn_layers = config["num_attn_layers"] 251 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 252 | 253 | hdim = config["hidden_dim"] 254 | 255 | self.trans_encoder = TransEncoder( 256 | self.num_attn_layers, 257 | hidden_dim=hdim, 258 | extra_attention=config["extra_attention"], # corresponds to cross attention block in decoder 259 | extra_skip=config["extra_skip"], # k,v comes from block input (q) 260 | pos_encoding_before=not config["add_pe_every_block"], # only add PE once before encoder blocks 261 | pos_encoding_value=not config["no_pe_for_value"], 262 | ) 263 | self.trans_decoder = TransDecoder( 264 | self.num_attn_layers, 265 | hidden_dim=hdim, 266 | pos_encoding_before=not config["add_pe_every_block"], 267 | pos_encoding_value=not config["no_pe_for_value"], 268 | ) 269 | self.query_embed = nn.Embedding(1296, hdim) 270 | 271 | self.flow_head = FlowHead(hdim, hidden_dim=hdim) 272 | 273 | def forward(self, backbone_features): 274 | """ 275 | image: segmented image 276 | """ 277 | fmap = torch.relu(backbone_features) 278 | 279 | fmap = self.trans_encoder(fmap) 280 | fmap = self.trans_decoder(fmap, self.query_embed.weight) 281 | 282 | dflow = self.flow_head(fmap) 283 | 284 | return {"flow": dflow, "feature_map": fmap} 285 | -------------------------------------------------------------------------------- /docmae/models/doctr_plus.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of DocTr++ based on DocTr code 3 | MIT Licence https://github.com/fh2019ustc/DocTr 4 | """ 5 | import copy 6 | from typing import Optional 7 | 8 | import torch 9 | from torch 
import nn 10 | import torch.nn.functional as F 11 | from torchvision.transforms import transforms 12 | 13 | from docmae.models.transformer import build_position_encoding 14 | 15 | 16 | class SelfAttnLayer(nn.Module): 17 | def __init__( 18 | self, 19 | d_model, 20 | nhead=8, 21 | dim_feedforward=2048, 22 | dropout=0.1, 23 | activation="relu", 24 | extra_attention=False, 25 | ): 26 | super().__init__() 27 | self.extra_attention = extra_attention 28 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 29 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 30 | # Implementation of Feedforward model 31 | self.linear1 = nn.Linear(d_model, dim_feedforward) 32 | self.dropout = nn.Dropout(dropout) 33 | self.linear2 = nn.Linear(dim_feedforward, d_model) 34 | 35 | self.norm1 = nn.LayerNorm(d_model) 36 | self.norm2 = nn.LayerNorm(d_model) 37 | self.norm3 = nn.LayerNorm(d_model) 38 | self.dropout1 = nn.Dropout(dropout) 39 | self.dropout2 = nn.Dropout(dropout) 40 | self.dropout3 = nn.Dropout(dropout) 41 | 42 | self.activation = _get_activation_fn(activation) 43 | 44 | def forward_post(self, query, key, value, cross_kv=None, pos_list=[None, None, None]): 45 | q = self.with_pos_embed(query, pos_list[0]) 46 | k = self.with_pos_embed(key, pos_list[1]) 47 | v = self.with_pos_embed(value, pos_list[2]) 48 | 49 | tgt2 = self.self_attn(q, k, v, need_weights=False)[0] 50 | tgt = query + self.dropout1(tgt2) # query and key should be equal, using it for skip connection 51 | tgt = self.norm1(tgt) 52 | 53 | if self.extra_attention: 54 | if cross_kv is not None: 55 | q = tgt 56 | k = v = cross_kv 57 | else: 58 | q = k = v = tgt 59 | q = self.with_pos_embed(q, pos_list[0]) 60 | k = self.with_pos_embed(k, pos_list[1]) 61 | v = self.with_pos_embed(v, pos_list[2]) 62 | 63 | tgt2 = self.cross_attn(q, k, v, need_weights=False)[0] 64 | tgt = tgt + self.dropout2(tgt2) 65 | tgt = self.norm2(tgt) 66 | 67 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 68 | tgt = tgt + self.dropout3(tgt2) 69 | tgt = self.norm3(tgt) 70 | return tgt 71 | 72 | def forward(self, query, key, value, cross_kv, pos_list): 73 | return self.forward_post(query, key, value, cross_kv=cross_kv, pos_list=pos_list) 74 | 75 | def with_pos_embed(self, tensor, pos: Optional[torch.Tensor]): 76 | return tensor if pos is None else tensor + pos 77 | 78 | 79 | class CrossAttnLayer(nn.Module): 80 | def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu"): 81 | super().__init__() 82 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 83 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 84 | 85 | # Implementation of Feedforward model 86 | self.linear1 = nn.Linear(d_model, dim_feedforward) 87 | self.dropout = nn.Dropout(dropout) 88 | self.linear2 = nn.Linear(dim_feedforward, d_model) 89 | 90 | self.norm1 = nn.LayerNorm(d_model) 91 | self.norm2 = nn.LayerNorm(d_model) 92 | self.norm3 = nn.LayerNorm(d_model) 93 | self.dropout1 = nn.Dropout(dropout) 94 | self.dropout2 = nn.Dropout(dropout) 95 | self.dropout3 = nn.Dropout(dropout) 96 | 97 | self.activation = _get_activation_fn(activation) 98 | 99 | def forward_post( 100 | self, 101 | query, 102 | key, 103 | value, 104 | cross_kv, 105 | pos_list=[None, None, None] 106 | # query embedding # encoder features # positional encoding 107 | ): 108 | q = self.with_pos_embed(query, pos_list[0]) 109 | k = self.with_pos_embed(key, pos_list[1]) 110 | v = self.with_pos_embed(value, pos_list[2]) 
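# decoder block: self-attention over the query embeddings first, then cross-attention into the encoder feature map provided via cross_kv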
111 | 112 | tgt2 = self.self_attn(q, k, v, need_weights=False)[0] 113 | tgt = q + self.dropout1(tgt2) 114 | tgt = self.norm1(tgt) 115 | 116 | q = tgt 117 | k, v = cross_kv 118 | q = self.with_pos_embed(q, pos_list[0]) 119 | k = self.with_pos_embed(k, pos_list[1]) 120 | v = self.with_pos_embed(v, pos_list[2]) 121 | 122 | tgt2 = self.cross_attn(q, k, v, need_weights=False)[0] 123 | tgt = tgt + self.dropout2(tgt2) 124 | tgt = self.norm2(tgt) 125 | 126 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 127 | tgt = tgt + self.dropout3(tgt2) 128 | tgt = self.norm3(tgt) 129 | return tgt 130 | 131 | def forward(self, query, key, value, cross_kv, pos_list): 132 | return self.forward_post(query, key, value, cross_kv, pos_list=pos_list) 133 | 134 | def with_pos_embed(self, tensor, pos: Optional[torch.Tensor]): 135 | return tensor if pos is None else tensor + pos 136 | 137 | 138 | def _get_clones(module, N): 139 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 140 | 141 | 142 | def _get_activation_fn(activation): 143 | """Return an activation function given a string""" 144 | if activation == "relu": 145 | return F.relu 146 | if activation == "gelu": 147 | return F.gelu 148 | if activation == "glu": 149 | return F.glu 150 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 151 | 152 | 153 | class TransDecoder(nn.Module): 154 | def __init__(self, num_attn_layers, hidden_dim=128, pos_encoding_before=True, pos_encoding_value=True): 155 | super(TransDecoder, self).__init__() 156 | self.pos_encoding_before = pos_encoding_before 157 | self.pos_encoding_value = pos_encoding_value 158 | attn_layer = CrossAttnLayer(hidden_dim) 159 | self.layers = _get_clones(attn_layer, num_attn_layers) 160 | position_embedding = build_position_encoding(hidden_dim) 161 | self.pos = [ 162 | position_embedding(torch.ones((1, 36 // 2**i, 36 // 2**i), dtype=torch.bool, device="cuda")) 163 | .flatten(2) 164 | .permute(2, 0, 1) 165 | for i in reversed(range(3)) 166 | ] # torch.Size([1, 128, 36 / 2^i, 36]) 167 | 168 | def forward(self, imgf_list, query_embed): 169 | bs, c, h, w = imgf_list[0].shape 170 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 171 | 172 | for i, layer in enumerate(self.layers): 173 | if i in [0, 2, 4]: 174 | imgf = imgf_list[i // 2].flatten(2).permute(2, 0, 1) 175 | cross_k = cross_v = imgf 176 | 177 | query = key = value = query_embed 178 | 179 | pos = self.pos[i // 2] 180 | if self.pos_encoding_before: 181 | pos_list = [None] * 3 182 | if i == 0: 183 | query = query + pos 184 | key = key + pos 185 | # cross_k = cross_k + self.pos # in orig transformer PE is not added here at all 186 | if self.pos_encoding_value: 187 | value = value + pos 188 | # cross_v = cross_v + self.pos 189 | else: # add PE every block, also added to cross k,v 190 | if self.pos_encoding_value: 191 | pos_list = [pos] * 3 192 | else: 193 | pos_list = [pos, pos, None] 194 | query_embed = layer(query=query, key=key, value=value, cross_kv=[cross_k, cross_v], pos_list=pos_list) 195 | 196 | # the decoded embeddings of the first and second blocks are upsampled based on the bilinear interpolation 197 | if i in [1, 3]: 198 | query_embed = query_embed.permute(1, 2, 0).reshape(bs, c, h * 2 ** (i // 2), w * 2 ** (i // 2)) 199 | query_embed = nn.functional.interpolate(query_embed, scale_factor=2, mode="bilinear", align_corners=True) 200 | query_embed = query_embed.flatten(2).permute(2, 0, 1) 201 | 202 | scale_factor = 2 ** ((len(self.layers) - 1) // 2) # 4 203 | query_embed = 
query_embed.permute(1, 2, 0).reshape(bs, c, h * scale_factor, w * scale_factor) 204 | 205 | return query_embed 206 | 207 | 208 | class TransEncoder(nn.Module): 209 | def __init__( 210 | self, 211 | num_attn_layers, 212 | hidden_dim=128, 213 | extra_attention=False, 214 | extra_skip=False, 215 | pos_encoding_before=True, 216 | pos_encoding_value=True, 217 | ): 218 | super(TransEncoder, self).__init__() 219 | if not extra_attention: 220 | assert not extra_skip 221 | self.extra_skip = extra_skip 222 | self.pos_encoding_before = pos_encoding_before 223 | self.pos_encoding_value = pos_encoding_value 224 | attn_layer = SelfAttnLayer(hidden_dim, extra_attention=extra_attention) 225 | self.layers = _get_clones(attn_layer, num_attn_layers) 226 | self.conv_stride = nn.ModuleList( # todo what kernel size? depth wise separable? 227 | [ 228 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1, padding_mode="reflect") 229 | for _ in range(num_attn_layers // 2 - 1) 230 | ] 231 | ) 232 | position_embedding = build_position_encoding(hidden_dim) 233 | # run here because sin PE is not learned 234 | self.pos = [ 235 | position_embedding(torch.ones((1, 36 // 2**i, 36 // 2**i), dtype=torch.bool, device="cuda")) 236 | .flatten(2) 237 | .permute(2, 0, 1) 238 | for i in range(3) 239 | ] # torch.Size([1, 128, 36 / 2^i, 36]) 240 | 241 | def forward(self, imgf): 242 | bs, c, h, w = imgf.shape 243 | imgf = imgf.flatten(2).permute(2, 0, 1) 244 | 245 | outputs = [] 246 | 247 | for i, layer in enumerate(self.layers): 248 | query = key = value = imgf 249 | pos = self.pos[i // 2] 250 | if self.pos_encoding_before: 251 | pos_list = [None] * 3 252 | if i == 0: 253 | query = query + pos 254 | key = key + pos 255 | if self.pos_encoding_value: 256 | value = value + pos 257 | else: # add PE every block 258 | if self.pos_encoding_value: 259 | pos_list = [pos] * 3 260 | else: 261 | pos_list = [pos, pos, None] 262 | cross_kv = query if self.extra_skip else None # extra skip connection from block input before blockwise PE 263 | imgf = layer(query=query, key=key, value=value, cross_kv=cross_kv, pos_list=pos_list) 264 | if i in [1, 3, 5]: 265 | imgf = imgf.permute(1, 2, 0).reshape(bs, c, h // 2 ** (i // 2), w // 2 ** (i // 2)) 266 | outputs.append(imgf) # save output 267 | 268 | # downsampling after first and second block 269 | if i // 2 < len(self.conv_stride): 270 | imgf = self.conv_stride[i // 2](imgf) 271 | imgf = imgf.flatten(2).permute(2, 0, 1) 272 | 273 | return list(reversed(outputs)) 274 | 275 | 276 | class FlowHead(nn.Module): 277 | def __init__(self, input_dim=128, hidden_dim=256): 278 | super(FlowHead, self).__init__() 279 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 280 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 281 | self.relu = nn.ReLU(inplace=True) 282 | 283 | def forward(self, x): 284 | return self.conv2(self.relu(self.conv1(x))) 285 | 286 | 287 | class DocTrPlus(nn.Module): 288 | def __init__(self, config): 289 | super(DocTrPlus, self).__init__() 290 | self.num_attn_layers = config["num_attn_layers"] 291 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 292 | 293 | hdim = config["hidden_dim"] 294 | 295 | self.trans_encoder = TransEncoder( 296 | self.num_attn_layers, 297 | hidden_dim=hdim, 298 | extra_attention=config["extra_attention"], # corresponds to cross attention block in decoder 299 | extra_skip=config["extra_skip"], # k,v comes from block input (q) 300 | pos_encoding_before=not config["add_pe_every_block"], # only add 
PE once before encoder blocks 301 | pos_encoding_value=not config["no_pe_for_value"], 302 | ) 303 | self.trans_decoder = TransDecoder( 304 | self.num_attn_layers, 305 | hidden_dim=hdim, 306 | pos_encoding_before=not config["add_pe_every_block"], 307 | pos_encoding_value=not config["no_pe_for_value"], 308 | ) 309 | self.query_embed = nn.Embedding(9 * 9, hdim) # (288 / 32) ^2 310 | 311 | self.flow_head = FlowHead(hdim, hidden_dim=hdim) 312 | 313 | def forward(self, backbone_features): 314 | """ 315 | image: segmented image 316 | """ 317 | fmap = torch.relu(backbone_features) 318 | 319 | fmap = self.trans_encoder(fmap) 320 | fmap = self.trans_decoder(fmap, self.query_embed.weight) 321 | 322 | dflow = self.flow_head(fmap) 323 | 324 | return {"flow": dflow, "feature_map": fmap} 325 | -------------------------------------------------------------------------------- /docmae/models/mae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper class for Masked Auto-encoder taken from the huggingface library 3 | """ 4 | from typing import Optional 5 | 6 | import torch 7 | import lightning as L 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR 9 | from torch.utils.tensorboard import SummaryWriter 10 | from torchvision import transforms 11 | 12 | 13 | class MAE(L.LightningModule): 14 | tb_log: SummaryWriter 15 | 16 | def __init__(self, encoder, decoder, hparams, training: bool): 17 | super().__init__() 18 | self.example_input_array = torch.rand(1, 3, 288, 288) 19 | 20 | self.segmenter = torch.jit.load(hparams["segmenter_ckpt"]) 21 | self.segmenter = torch.jit.freeze(self.segmenter) 22 | self.segmenter = torch.jit.optimize_for_inference(self.segmenter) 23 | self.encoder = encoder 24 | self.decoder = decoder 25 | self.is_training = training 26 | 27 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 28 | 29 | self.save_hyperparameters(hparams) 30 | 31 | def on_fit_start(self): 32 | self.tb_log = self.logger.experiment 33 | additional_metrics = ["val/loss"] 34 | self.logger.log_hyperparams(self.hparams, {**{key: 0 for key in additional_metrics}}) 35 | 36 | def configure_optimizers(self): 37 | """ 38 | calling setup_optimizers for which all missing parameters must be registered with gin 39 | 40 | Returns: 41 | dictionary defining optimizer, learning rate scheduler and value to monitor as expected by pytorch lightning 42 | """ 43 | parameters = self.segmenter.named_parameters() 44 | for name, param in parameters: 45 | param.requires_grad = False 46 | 47 | num_epochs = self.hparams["num_train_epochs"] 48 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams["base_learning_rate"], weight_decay=0) 49 | # scheduler = ReduceLROnPlateau(optimizer) 50 | scheduler = OneCycleLR( 51 | optimizer, self.hparams["base_learning_rate"], epochs=num_epochs, steps_per_epoch=200_000 // num_epochs 52 | ) 53 | # scheduler = CosineAnnealingLR(optimizer, self.hparams["base_learning_rate"]) 54 | return { 55 | "optimizer": optimizer, 56 | "lr_scheduler": scheduler, 57 | "monitor": "train/loss", 58 | } 59 | 60 | def forward( 61 | self, 62 | pixel_values: Optional[torch.FloatTensor] = None, 63 | noise: Optional[torch.FloatTensor] = None, 64 | head_mask: Optional[torch.FloatTensor] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | ): 68 | 69 | outputs = self.encoder( 70 | pixel_values.float(), 71 | noise=noise, 72 | head_mask=head_mask, 
73 | output_attentions=output_attentions, 74 | output_hidden_states=output_hidden_states, 75 | return_dict=False, 76 | ) 77 | 78 | latent, ids_restore, mask = outputs 79 | 80 | if not self.is_training: 81 | return outputs 82 | 83 | decoder_outputs = self.decoder(latent, ids_restore.long()) 84 | logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels) 85 | 86 | output = (logits, mask, ids_restore) 87 | return output 88 | 89 | """Taken from transformers TODO""" 90 | 91 | def patchify(self, pixel_values): 92 | """ 93 | Args: 94 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 95 | Pixel values. 96 | 97 | Returns: 98 | `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: 99 | Patchified pixel values. 100 | """ 101 | patch_size, num_channels = self.encoder.config.patch_size, self.encoder.config.num_channels 102 | # sanity checks 103 | if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0): 104 | raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size") 105 | if pixel_values.shape[1] != num_channels: 106 | raise ValueError( 107 | "Make sure the number of channels of the pixel values is equal to the one set in the configuration" 108 | ) 109 | 110 | # patchify 111 | batch_size = pixel_values.shape[0] 112 | num_patches_one_direction = pixel_values.shape[2] // patch_size 113 | patchified_pixel_values = pixel_values.reshape( 114 | batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size 115 | ) 116 | patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values) 117 | patchified_pixel_values = patchified_pixel_values.reshape( 118 | batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels 119 | ) 120 | return patchified_pixel_values 121 | 122 | def unpatchify(self, patchified_pixel_values): 123 | """ 124 | Args: 125 | patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: 126 | Patchified pixel values. 127 | 128 | Returns: 129 | `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: 130 | Pixel values. 131 | """ 132 | patch_size, num_channels = self.encoder.config.patch_size, self.encoder.config.num_channels 133 | num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5) 134 | # sanity check 135 | if num_patches_one_direction**2 != patchified_pixel_values.shape[1]: 136 | raise ValueError("Make sure that the number of patches can be squared") 137 | 138 | # unpatchify 139 | batch_size = patchified_pixel_values.shape[0] 140 | patchified_pixel_values = patchified_pixel_values.reshape( 141 | batch_size, 142 | num_patches_one_direction, 143 | num_patches_one_direction, 144 | patch_size, 145 | patch_size, 146 | num_channels, 147 | ) 148 | patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values) 149 | pixel_values = patchified_pixel_values.reshape( 150 | batch_size, 151 | num_channels, 152 | num_patches_one_direction * patch_size, 153 | num_patches_one_direction * patch_size, 154 | ) 155 | return pixel_values 156 | 157 | def forward_loss(self, pixel_values, pred, mask): 158 | """ 159 | Args: 160 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 161 | Pixel values. 
162 | pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: 163 | Predicted pixel values. 164 | mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 165 | Tensor indicating which patches are masked (1) and which are not (0). 166 | 167 | Returns: 168 | `torch.FloatTensor`: Pixel reconstruction loss. 169 | """ 170 | target = self.patchify(pixel_values) 171 | # if self.encoder.config.norm_pix_loss: 172 | # mean = target.mean(dim=-1, keepdim=True) 173 | # var = target.var(dim=-1, keepdim=True) 174 | # target = (target - mean) / (var + 1.0e-6) ** 0.5 175 | 176 | loss = (pred - target) ** 2 177 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 178 | 179 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 180 | return loss 181 | 182 | def training_step(self, batch): 183 | self.encoder.train() 184 | self.decoder.train() 185 | 186 | image = batch["image"] 187 | batch_size = len(image) 188 | 189 | with torch.no_grad(): 190 | seg_image = self.normalize(image) 191 | seg_mask = self.segmenter(seg_image) 192 | seg_mask = (seg_mask > 0.5).double() 193 | 194 | seg_image = image * seg_mask 195 | 196 | # training image sanity check 197 | if self.global_step == 0: 198 | self.tb_log.add_images("train/image", image, global_step=self.global_step) 199 | self.tb_log.add_images("train/seg_mask", seg_mask, global_step=self.global_step) 200 | 201 | output = self.forward(seg_image) 202 | (logits, ids_restore, mask) = output 203 | loss = self.forward_loss(seg_image, logits, mask) 204 | 205 | self.log("train/loss", loss, on_step=True, on_epoch=True, batch_size=batch_size) 206 | 207 | return loss 208 | 209 | def on_after_backward(self): 210 | global_step = self.global_step 211 | # if self.global_step % 100 == 0: 212 | # for name, param in self.model.named_parameters(): 213 | # self.tb_log.add_histogram(name, param, global_step) 214 | # if param.requires_grad: 215 | # self.tb_log.add_histogram(f"{name}_grad", param.grad, global_step) 216 | 217 | def on_validation_start(self): 218 | self.tb_log = self.logger.experiment 219 | 220 | def validation_step(self, val_batch, batch_idx, dataloader_idx=0): 221 | self.encoder.eval() 222 | self.decoder.eval() 223 | 224 | image = val_batch["image"] 225 | batch_size = len(image) 226 | 227 | with torch.no_grad(): 228 | seg_image = self.normalize(image) 229 | seg_mask = self.segmenter(seg_image) 230 | seg_mask = (seg_mask > 0.5).double() 231 | 232 | with torch.device(self.device): 233 | seg_image = image * seg_mask 234 | inputs = self.image_processor(images=seg_image, return_tensors="pt") 235 | 236 | # val image sanity check 237 | if self.global_step == 0 and batch_idx == 0: 238 | self.tb_log.add_images("val/image", image, global_step=self.global_step) 239 | self.tb_log.add_images("val/seg_mask", seg_mask, global_step=self.global_step) 240 | 241 | output = self.forward(**inputs) 242 | (logits, ids_restore, mask) = output 243 | 244 | loss = self.forward_loss(inputs["pixel_values"], logits, mask) 245 | 246 | self.log(f"val/loss", loss, batch_size=batch_size, prog_bar=True) 247 | 248 | if batch_idx == 0: 249 | y = self.unpatchify(logits) 250 | 251 | # visualize the mask 252 | mask = mask.detach() 253 | mask = mask.unsqueeze(-1).repeat(1, 1, self.encoder.config.patch_size**2 * 3) # (N, H*W, p*p*3) 254 | mask = self.unpatchify(mask) # 1 is removing, 0 is keeping 255 | 256 | # masked image 257 | im_masked = seg_image * (1 - mask) + mask * 0.5 258 | 259 | # MAE reconstruction pasted with visible patches 260 | 
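# mask == 1 marks removed patches, so the prediction y only fills in the masked regions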
im_paste = seg_image * (1 - mask) + y * mask 261 | 262 | self.tb_log.add_images(f"val/masked", im_masked, global_step=self.global_step) 263 | self.tb_log.add_images(f"val/prediction", y, global_step=self.global_step) 264 | self.tb_log.add_images(f"val/reconstruction", im_paste, global_step=self.global_step) 265 | 266 | def on_test_start(self): 267 | self.tb_log = self.logger.experiment 268 | -------------------------------------------------------------------------------- /docmae/models/rectification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import lightning as L 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.nn import L1Loss 8 | from torch.optim.lr_scheduler import OneCycleLR 9 | from torch.utils.tensorboard import SummaryWriter 10 | from torchvision.transforms import transforms 11 | from torchvision.utils import flow_to_image 12 | 13 | from docmae.models.upscale import coords_grid 14 | 15 | 16 | class Rectification(L.LightningModule): 17 | tb_log: SummaryWriter 18 | 19 | def __init__( 20 | self, 21 | backbone: nn.Module, 22 | model: nn.Module, 23 | upscale: nn.Module, 24 | config, 25 | ): 26 | super().__init__() 27 | self.example_input_array = torch.rand(1, 3, 288, 288) 28 | 29 | self.backbone = backbone 30 | self.model = model 31 | self.upscale_module = upscale 32 | self.config = config 33 | hparams = config["model"] 34 | 35 | H, W = self.example_input_array.shape[2:] 36 | self.coodslar = coords_grid(1, H, W).to(self.example_input_array.device) 37 | 38 | self.segment_background = hparams["segment_background"] 39 | self.hidden_dim = hparams["hidden_dim"] 40 | 41 | self.loss = L1Loss() 42 | self.save_hyperparameters(hparams) 43 | 44 | def on_fit_start(self): 45 | self.tb_log = self.logger.experiment 46 | self.coodslar = self.coodslar.to(self.device) 47 | additional_metrics = ["val/loss"] 48 | self.logger.log_hyperparams(self.hparams, {**{key: 0 for key in additional_metrics}}) 49 | 50 | def configure_optimizers(self): 51 | """ 52 | calling setup_optimizers for which all missing parameters must be registered with gin 53 | 54 | Returns: 55 | dictionary defining optimizer, learning rate scheduler and value to monitor as expected by pytorch lightning 56 | """ 57 | 58 | optimizer = torch.optim.AdamW(self.parameters()) 59 | scheduler = { 60 | "scheduler": OneCycleLR(optimizer, max_lr=1e-4, pct_start=0.1, total_steps=self.config["training"]["steps"]), 61 | "interval": "step", 62 | } 63 | return [optimizer], [scheduler] 64 | 65 | def forward(self, inputs): 66 | """ 67 | Runs inference: image_processing, encoder, decoder, layer norm, flow head 68 | Args: 69 | inputs: image tensor of shape [B, C, H, W] 70 | Returns: flow displacement 71 | """ 72 | backbone_features = self.backbone(inputs) 73 | outputs = self.model(backbone_features) 74 | flow_up = self.upscale_module(**outputs) 75 | return flow_up 76 | 77 | def training_step(self, batch): 78 | image = batch["image"] / 255 79 | if self.segment_background: 80 | image = image * batch["mask"] 81 | 82 | bm_target = batch["bm"] * 287 83 | batch_size = len(image) 84 | 85 | flow_pred = self.forward(image) 86 | bm_pred = flow_pred + self.coodslar 87 | 88 | # training image sanity check 89 | if self.global_step == 0: 90 | ones = torch.ones((batch_size, 1, 288, 288)) 91 | self.tb_log.add_images("train/image", image.detach().cpu(), global_step=self.global_step) 92 | image_unwarped = F.grid_sample( 93 | image, ((bm_target / 287 - 
0.5) * 2).permute((0, 2, 3, 1)), align_corners=False 94 | ) 95 | self.tb_log.add_images("train/unwarped", image_unwarped.detach().cpu(), global_step=self.global_step) 96 | 97 | # self.tb_log.add_images( 98 | # "train/bm_target", torch.cat((bm_target.detach().cpu() / 287, ones), dim=1), global_step=self.global_step 99 | # ) 100 | self.tb_log.add_images( 101 | "train/flow_target", 102 | flow_to_image(bm_target.detach().cpu() - self.coodslar.detach().cpu()), 103 | global_step=self.global_step, 104 | ) 105 | 106 | # log metrics 107 | loss = self.loss(bm_target, bm_pred) 108 | if self.hparams.get("mask_loss_mult", 0) > 0: # MataDoc unwarped mask loss 109 | mask_target = batch["mask"] 110 | mask_target_unwarped = F.grid_sample( 111 | mask_target, ((bm_target / 287 - 0.5) * 2).permute((0, 2, 3, 1)), align_corners=False 112 | ) 113 | mask_pred_unwarped = F.grid_sample( 114 | mask_target, ((bm_pred / 287 - 0.5) * 2).permute((0, 2, 3, 1)), align_corners=False 115 | ) 116 | if self.global_step == 0: 117 | self.tb_log.add_images("train/mask_target", mask_target_unwarped.detach().cpu(), global_step=self.global_step) 118 | self.tb_log.add_images("train/mask_pred", mask_pred_unwarped.detach().cpu(), global_step=self.global_step) 119 | 120 | mask_loss = self.loss(mask_target_unwarped, mask_pred_unwarped) 121 | self.log("train/mask_loss", mask_loss, on_step=True, on_epoch=True, batch_size=batch_size) 122 | 123 | loss += self.hparams["mask_loss_mult"] * mask_loss 124 | 125 | self.log("train/loss", loss, on_step=True, on_epoch=True, batch_size=batch_size) 126 | 127 | return loss 128 | 129 | def on_after_backward(self): 130 | global_step = self.global_step 131 | # if self.global_step % 100 == 0: 132 | # for name, param in self.model.named_parameters(): 133 | # self.tb_log.add_histogram(name, param, global_step) 134 | # if param.requires_grad: 135 | # self.tb_log.add_histogram(f"{name}_grad", param.grad, global_step) 136 | 137 | def on_validation_start(self): 138 | self.tb_log = self.logger.experiment 139 | self.coodslar = self.coodslar.to(self.device) 140 | 141 | def validation_step(self, val_batch, batch_idx): 142 | image = val_batch["image"] / 255 143 | if self.segment_background: 144 | image = image * val_batch["mask"] 145 | bm_target = val_batch["bm"] * 287 146 | batch_size = len(image) 147 | 148 | flow_pred = self.forward(image) 149 | bm_pred = self.coodslar + flow_pred 150 | 151 | # log metrics 152 | loss = self.loss(bm_target, bm_pred) 153 | 154 | self.log("val/loss", loss, on_epoch=True, batch_size=batch_size) 155 | 156 | ones = torch.ones((batch_size, 1, 288, 288)) 157 | 158 | if batch_idx == 0 and self.global_step == 0: 159 | self.tb_log.add_images("val/image", image.detach().cpu(), global_step=self.global_step) 160 | # self.tb_log.add_images( 161 | # "val/bm", torch.cat((bm_target.detach().cpu() / 287, ones), dim=1), global_step=self.global_step 162 | # ) 163 | self.tb_log.add_images( 164 | "val/flow", 165 | flow_to_image((bm_target.detach().cpu() - self.coodslar.detach().cpu())), 166 | global_step=self.global_step, 167 | ) 168 | 169 | if batch_idx == 0: 170 | # self.tb_log.add_images( 171 | # "val/bm_pred", torch.cat((bm_pred.detach().cpu() / 287, ones), dim=1), global_step=self.global_step 172 | # ) 173 | self.tb_log.add_images("val/flow_pred", flow_to_image(flow_pred.detach().cpu()), global_step=self.global_step) 174 | 175 | # self.tb_log.add_images("val/bm_diff", flow_to_image(bm_target.cpu() - bm_pred.cpu()), global_step=self.global_step) 176 | 177 | bm_ = (bm_pred / 287 - 0.5) * 2 178 | bm_ = 
bm_.permute((0, 2, 3, 1)) 179 | img_ = image 180 | uw = F.grid_sample(img_, bm_, align_corners=False).detach().cpu() 181 | 182 | self.tb_log.add_images("val/unwarped", uw, global_step=self.global_step) 183 | 184 | def on_predict_start(self): 185 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 186 | if self.segment_background: 187 | self.segmenter = torch.jit.load(self.config["segmenter_ckpt"], map_location=self.device) 188 | self.segmenter = torch.jit.freeze(self.segmenter) 189 | self.segmenter = torch.jit.optimize_for_inference(self.segmenter) 190 | self.resize = transforms.Resize((288, 288), antialias=True) 191 | self.coodslar = self.coodslar.to(self.device) 192 | 193 | def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0): 194 | image_orig = batch["image"] 195 | b, c, h, w = image_orig.shape 196 | # resize to 288 197 | image = self.resize(image_orig) 198 | image /= 255 199 | if self.segment_background: 200 | mask = (self.segmenter(self.normalize(image)) > 0.5).to(torch.bool) 201 | image = image * mask 202 | else: 203 | mask = None 204 | 205 | flow_pred = self.forward(image) 206 | bm_pred = self.coodslar + flow_pred # rescale to original 207 | 208 | bm = (2 * (bm_pred / 286.8) - 1) * 0.99 # https://github.com/fh2019ustc/DocTr/issues/6 209 | 210 | # pytorch reimplementation 211 | # import kornia 212 | # bm = torch.nn.functional.interpolate(bm, image_orig.shape[2:]) 213 | # bm = kornia.filters.box_blur(bm, 3) 214 | # rectified = F.grid_sample(image_orig, bm.permute((0, 2, 3, 1)), align_corners=False, mode="bilinear") 215 | 216 | # doctr implementation 217 | bm = bm.cpu() 218 | bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow 219 | bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow 220 | bm0 = cv2.blur(bm0, (3, 3)) 221 | bm1 = cv2.blur(bm1, (3, 3)) 222 | lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).to(self.device).unsqueeze(0) # h * w * 2 223 | 224 | rectified = F.grid_sample(image_orig, lbl, align_corners=True) 225 | return rectified, bm, mask 226 | 227 | def on_test_start(self): 228 | self.tb_log = self.logger.experiment 229 | self.coodslar = self.coodslar.to(self.device) 230 | -------------------------------------------------------------------------------- /docmae/models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | from typing import Optional 9 | from torch import Tensor 10 | 11 | 12 | class NestedTensor(object): 13 | def __init__(self, tensors, mask: Optional[Tensor]): 14 | self.tensors = tensors 15 | self.mask = mask 16 | 17 | def to(self, device): 18 | # type: (Device) -> NestedTensor # noqa 19 | cast_tensor = self.tensors.to(device) 20 | mask = self.mask 21 | if mask is not None: 22 | assert mask is not None 23 | cast_mask = mask.to(device) 24 | else: 25 | cast_mask = None 26 | return NestedTensor(cast_tensor, cast_mask) 27 | 28 | def decompose(self): 29 | return self.tensors, self.mask 30 | 31 | def __repr__(self): 32 | return str(self.tensors) 33 | 34 | 35 | class PositionEmbeddingSine(nn.Module): 36 | """ 37 | This is a more standard version of the position embedding, very similar to the one 38 | used by the Attention is all you need paper, generalized to work on images. 
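Each spatial coordinate is mapped to sine/cosine features at geometrically spaced frequencies, PE(pos, 2i) = sin(pos / T^(2i/d)) and PE(pos, 2i+1) = cos(pos / T^(2i/d)), computed separately for x and y and concatenated along the channel dimension.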
39 | """ 40 | 41 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 42 | super().__init__() 43 | self.num_pos_feats = num_pos_feats 44 | self.temperature = temperature 45 | self.normalize = normalize 46 | if scale is not None and normalize is False: 47 | raise ValueError("normalize should be True if scale is passed") 48 | if scale is None: 49 | scale = 2 * math.pi 50 | self.scale = scale 51 | 52 | def forward(self, mask): 53 | assert mask is not None 54 | y_embed = mask.cumsum(1, dtype=torch.float32) 55 | x_embed = mask.cumsum(2, dtype=torch.float32) 56 | if self.normalize: 57 | eps = 1e-6 58 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 59 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 60 | 61 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32).cuda() 62 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 63 | 64 | pos_x = x_embed[:, :, :, None] / dim_t 65 | pos_y = y_embed[:, :, :, None] / dim_t 66 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 67 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 68 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 69 | # print(pos.shape) 70 | return pos 71 | 72 | 73 | class PositionEmbeddingLearned(nn.Module): 74 | """ 75 | Absolute pos embedding, learned. 76 | """ 77 | 78 | def __init__(self, num_pos_feats=256): 79 | super().__init__() 80 | self.row_embed = nn.Embedding(50, num_pos_feats) 81 | self.col_embed = nn.Embedding(50, num_pos_feats) 82 | self.reset_parameters() 83 | 84 | def reset_parameters(self): 85 | nn.init.uniform_(self.row_embed.weight) 86 | nn.init.uniform_(self.col_embed.weight) 87 | 88 | def forward(self, tensor_list: NestedTensor): 89 | x = tensor_list.tensors 90 | h, w = x.shape[-2:] 91 | i = torch.arange(w, device=x.device) 92 | j = torch.arange(h, device=x.device) 93 | x_emb = self.col_embed(i) 94 | y_emb = self.row_embed(j) 95 | pos = ( 96 | torch.cat( 97 | [ 98 | x_emb.unsqueeze(0).repeat(h, 1, 1), 99 | y_emb.unsqueeze(1).repeat(1, w, 1), 100 | ], 101 | dim=-1, 102 | ) 103 | .permute(2, 0, 1) 104 | .unsqueeze(0) 105 | .repeat(x.shape[0], 1, 1, 1) 106 | ) 107 | return pos 108 | 109 | 110 | def build_position_encoding(hidden_dim=512, position_embedding="sine"): 111 | N_steps = hidden_dim // 2 112 | if position_embedding in ("v2", "sine"): 113 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 114 | elif position_embedding in ("v3", "learned"): 115 | position_embedding = PositionEmbeddingLearned(N_steps) 116 | else: 117 | raise ValueError(f"not supported {position_embedding}") 118 | 119 | return position_embedding 120 | 121 | 122 | class ResidualBlock(nn.Module): 123 | def __init__(self, in_planes, planes, norm_fn="group", stride=1): 124 | super(ResidualBlock, self).__init__() 125 | 126 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 127 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 128 | self.relu = nn.ReLU(inplace=True) 129 | 130 | num_groups = planes // 8 131 | 132 | if norm_fn == "group": 133 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 134 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 135 | if not stride == 1: 136 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 137 | 138 | elif norm_fn == "batch": 139 | self.norm1 = nn.BatchNorm2d(planes) 140 | self.norm2 = 
nn.BatchNorm2d(planes) 141 | if not stride == 1: 142 | self.norm3 = nn.BatchNorm2d(planes) 143 | 144 | elif norm_fn == "instance": 145 | self.norm1 = nn.InstanceNorm2d(planes) 146 | self.norm2 = nn.InstanceNorm2d(planes) 147 | if not stride == 1: 148 | self.norm3 = nn.InstanceNorm2d(planes) 149 | 150 | elif norm_fn == "none": 151 | self.norm1 = nn.Sequential() 152 | self.norm2 = nn.Sequential() 153 | if not stride == 1: 154 | self.norm3 = nn.Sequential() 155 | 156 | if stride == 1: 157 | self.downsample = None 158 | 159 | else: 160 | self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 161 | 162 | def forward(self, x): 163 | y = x 164 | y = self.relu(self.norm1(self.conv1(y))) 165 | y = self.relu(self.norm2(self.conv2(y))) 166 | 167 | if self.downsample is not None: 168 | x = self.downsample(x) 169 | 170 | return self.relu(x + y) 171 | 172 | 173 | class BasicEncoder(nn.Module): 174 | def __init__(self, output_dim=128, norm_fn="batch"): 175 | super(BasicEncoder, self).__init__() 176 | self.norm_fn = norm_fn 177 | 178 | if self.norm_fn == "group": 179 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 180 | 181 | elif self.norm_fn == "batch": 182 | self.norm1 = nn.BatchNorm2d(64) 183 | 184 | elif self.norm_fn == "instance": 185 | self.norm1 = nn.InstanceNorm2d(64) 186 | 187 | elif self.norm_fn == "none": 188 | self.norm1 = nn.Sequential() 189 | 190 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 191 | self.relu1 = nn.ReLU(inplace=True) 192 | 193 | self.in_planes = 64 194 | self.layer1 = self._make_layer(64, stride=1) 195 | self.layer2 = self._make_layer(128, stride=2) 196 | self.layer3 = self._make_layer(192, stride=2) 197 | 198 | # output convolution 199 | self.conv2 = nn.Conv2d(192, output_dim, kernel_size=1) 200 | 201 | for m in self.modules(): 202 | if isinstance(m, nn.Conv2d): 203 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 204 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 205 | if m.weight is not None: 206 | nn.init.constant_(m.weight, 1) 207 | if m.bias is not None: 208 | nn.init.constant_(m.bias, 0) 209 | 210 | def _make_layer(self, dim, stride=1): 211 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 212 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 213 | layers = (layer1, layer2) 214 | 215 | self.in_planes = dim 216 | return nn.Sequential(*layers) 217 | 218 | def forward(self, x): 219 | x = self.conv1(x) 220 | x = self.norm1(x) 221 | x = self.relu1(x) 222 | 223 | x = self.layer1(x) 224 | x = self.layer2(x) 225 | x = self.layer3(x) 226 | 227 | x = self.conv2(x) 228 | 229 | return x 230 | -------------------------------------------------------------------------------- /docmae/models/upscale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def coords_grid(batch, ht, wd): 7 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij") 8 | coords = torch.stack(coords[::-1], dim=0).float() 9 | return coords[None].repeat(batch, 1, 1, 1) 10 | 11 | 12 | # Flow related code taken from https://github.com/fh2019ustc/DocTr/blob/main/GeoTr.py 13 | class UpscaleRAFT(nn.Module): 14 | """ 15 | Infers conv mask to upscale flow 16 | """ 17 | 18 | def __init__(self, patch_size: int, input_dim=512): 19 | super(UpscaleRAFT, self).__init__() 20 | self.P = patch_size 21 | 22 | self.mask = 
nn.Sequential( 23 | nn.Conv2d(input_dim, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, patch_size**2 * 9, 1, padding=0) 24 | ) 25 | 26 | def upsample_flow(self, flow, mask): 27 | N, _, H, W = flow.shape 28 | mask = mask.view(N, 1, 9, self.P, self.P, H, W) 29 | mask = torch.softmax(mask, dim=2) 30 | 31 | up_flow = F.unfold(self.P * flow, (3, 3), padding=1) 32 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 33 | 34 | up_flow = torch.sum(mask * up_flow, dim=2) 35 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 36 | 37 | return up_flow.reshape(N, 2, self.P * H, self.P * W) 38 | 39 | def forward(self, feature_map, flow): 40 | mask = 0.25 * self.mask(feature_map) # scale mask to balance gradients 41 | upflow = self.upsample_flow(flow, mask) 42 | return upflow 43 | 44 | 45 | class UpscaleTransposeConv(nn.Module): 46 | def __init__(self, input_dim=512, hidden_dim=256, mode="bilinear"): 47 | super().__init__() 48 | self.layers = [ 49 | expansion_block(input_dim, hidden_dim, hidden_dim // 2), 50 | expansion_block(hidden_dim // 2, hidden_dim // 4, 2, relu=False), 51 | nn.Upsample(scale_factor=2, mode=mode), 52 | ] 53 | 54 | self.layers = nn.Sequential(*self.layers) 55 | 56 | def forward(self, feature_map, **kwargs): 57 | return self.layers(feature_map) 58 | 59 | 60 | class UpscaleInterpolate(nn.Module): 61 | def __init__(self, input_dim=512, hidden_dim=256, mode="bilinear"): 62 | super().__init__() 63 | self.mode = mode 64 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 65 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 66 | self.relu = nn.ReLU(inplace=True) 67 | 68 | def forward(self, feature_map, **kwargs): 69 | flow = self.conv2(self.relu(self.conv1(feature_map))) 70 | 71 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) # to scale 8/16 72 | return 8 * F.interpolate(flow, size=new_size, mode=self.mode, align_corners=False) 73 | 74 | 75 | def expansion_block(in_channels, mid_channel, out_channels, relu=True): 76 | """Build block of two consecutive convolutions followed by upsampling. 77 | 78 | The following chain of layers is applied: 79 | [Conv2D -> ReLU -> BatchNorm -> Conv2D -> ReLU -> BatchNorm -> ConvTranspose2d -> ReLU] 80 | 81 | This block doubles the dimensions of input tensor, i.e. input is of size 82 | (rows, cols, in_channels) and output is of size (rows*2, cols*2, out_channels). 83 | 84 | Args: 85 | in_channels (int): Number of channels of input tensor. 86 | mid_channel (int): Number of channels of middle channel. 87 | out_channels (int): Number of channels of output tensor, i.e., number of filters. 88 | relu (bool): Indicates whether to apply ReLU after transposed convolution. 89 | 90 | Returns: 91 | block (nn.Sequential): Built expansive block. 92 | 93 | """ 94 | block = nn.Sequential( 95 | conv_relu_bn(in_channels, mid_channel), 96 | nn.ConvTranspose2d( 97 | in_channels=mid_channel, 98 | out_channels=out_channels, 99 | kernel_size=3, 100 | stride=2, 101 | padding=1, 102 | output_padding=1, 103 | ), 104 | ) 105 | if relu is True: 106 | block = nn.Sequential(block, nn.ReLU()) 107 | return block 108 | 109 | 110 | def conv_relu_bn(in_channels, out_channels, kernel_size=3, padding=1): 111 | """Build [Conv2D -> ReLu -> BatchNorm] block. 112 | 113 | Args: 114 | in_channels (int): Number of channels of input tensor. 115 | out_channels (int): Number of channels of output tensor, i.e., number of filters. 116 | kernel_size (int): Size of convolution filters, squared filters are assumed. 117 | padding (int): Amount of **zero** padding around input tensor. 
118 | 119 | Returns: 120 | block (nn.Sequential): [Conv2D -> ReLu -> BatchNorm] block. 121 | 122 | """ 123 | block = nn.Sequential( 124 | nn.Conv2d( 125 | in_channels=in_channels, 126 | out_channels=out_channels, 127 | kernel_size=kernel_size, 128 | padding=padding, 129 | ), 130 | nn.ReLU(), 131 | nn.BatchNorm2d(out_channels), 132 | ) 133 | return block 134 | -------------------------------------------------------------------------------- /docmae/pretrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | import glob 16 | import json 17 | import logging 18 | import os 19 | import sys 20 | from dataclasses import dataclass, field 21 | from pathlib import Path 22 | from typing import Optional 23 | 24 | import argparse 25 | import torch 26 | import datasets 27 | from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor 28 | from torchvision.transforms.functional import InterpolationMode 29 | 30 | import transformers 31 | from transformers import ( 32 | HfArgumentParser, 33 | Trainer, 34 | TrainingArguments, 35 | ViTImageProcessor, 36 | ViTMAEConfig, 37 | ViTMAEForPreTraining, 38 | ) 39 | from transformers.trainer_utils import get_last_checkpoint 40 | from transformers.utils import check_min_version, send_example_telemetry 41 | from transformers.utils.versions import require_version 42 | 43 | from docmae import setup_logging 44 | 45 | """ Pre-training a 🤗 ViT model as an MAE (masked autoencoder), as proposed in https://arxiv.org/abs/2111.06377.""" 46 | 47 | logger = logging.getLogger(__name__) 48 | 49 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 50 | check_min_version("4.29.0.dev0") 51 | 52 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") 53 | 54 | 55 | @dataclass 56 | class DataTrainingArguments: 57 | """ 58 | Arguments pertaining to what data we are going to input our model for training and eval. 59 | Using `HfArgumentParser` we can turn this class 60 | into argparse arguments to be able to specify them on 61 | the command line. 62 | """ 63 | 64 | dataset_name: Optional[str] = field(metadata={"help": "Name of a dataset from the datasets package"}) 65 | image_column_name: Optional[str] = field(default=None, metadata={"help": "The column name of the images in the files."}) 66 | train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."}) 67 | validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."}) 68 | max_train_samples: Optional[int] = field( 69 | default=None, 70 | metadata={ 71 | "help": ( 72 | "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." 
73 | ) 74 | }, 75 | ) 76 | max_eval_samples: Optional[int] = field( 77 | default=None, 78 | metadata={ 79 | "help": ( 80 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 81 | "value if set." 82 | ) 83 | }, 84 | ) 85 | 86 | def __post_init__(self): 87 | data_files = {} 88 | if self.train_dir is not None: 89 | data_files["train"] = self.train_dir 90 | if self.validation_dir is not None: 91 | data_files["val"] = self.validation_dir 92 | self.data_files = data_files if data_files else None 93 | 94 | 95 | @dataclass 96 | class ModelArguments: 97 | """ 98 | Arguments pertaining to which model/config/image processor we are going to pre-train. 99 | """ 100 | 101 | model_name_or_path: str = field( 102 | default=None, 103 | metadata={ 104 | "help": ("The model checkpoint for weights initialization. " 105 | "Don't set if you want to train a model from scratch.") 106 | }, 107 | ) 108 | config_name: Optional[str] = field( 109 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name_or_path"} 110 | ) 111 | config_overrides: Optional[str] = field( 112 | default=None, 113 | metadata={ 114 | "help": ( 115 | "Override some existing default config settings when a model is trained from scratch. Example: " 116 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 117 | ) 118 | }, 119 | ) 120 | cache_dir: Optional[str] = field( 121 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 122 | ) 123 | model_revision: str = field( 124 | default="main", 125 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 126 | ) 127 | image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."}) 128 | use_auth_token: bool = field( 129 | default=False, 130 | metadata={ 131 | "help": ( 132 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 133 | "with private models)." 
134 | ) 135 | }, 136 | ) 137 | mask_ratio: float = field( 138 | default=0.75, metadata={"help": "The ratio of the number of masked tokens in the input sequence."} 139 | ) 140 | norm_pix_loss: bool = field( 141 | default=True, metadata={"help": "Whether or not to train with normalized pixel values as target."} 142 | ) 143 | 144 | 145 | @dataclass 146 | class CustomTrainingArguments(TrainingArguments): 147 | base_learning_rate: float = field( 148 | default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."} 149 | ) 150 | 151 | 152 | def collate_fn(examples): 153 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 154 | return {"pixel_values": pixel_values} 155 | 156 | 157 | def parse_arguments(): 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument("-c", "--config", type=str, 160 | help="config file for training parameters") 161 | parser.add_argument("-ll", "--log-level", type=str, default="INFO", 162 | choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], 163 | help="config file for training parameters") 164 | parser.add_argument("-l", "--log-dir", type=str, default="", 165 | help="folder to store log files") 166 | parser.add_argument("-t", "--tensorboard-dir", type=str, default="", 167 | help="folder to store tensorboard logs") 168 | parser.add_argument("-m", "--model-output-dir", type=str, default="model", 169 | help="folder to store trained models") 170 | return parser.parse_args() 171 | 172 | 173 | def main(): 174 | args = parse_arguments() 175 | setup_logging(log_level=args.log_level, log_dir=args.log_dir) 176 | 177 | assert args.config.endswith(".json") 178 | 179 | # Save config for training traceability and load config parameters 180 | config_file = Path(args.model_output_dir) / "config.json" 181 | config = json.loads(Path(args.config).read_text()) 182 | 183 | config["logging_dir"] = args.tensorboard_dir 184 | config["output_dir"] = os.path.join(args.model_output_dir, "checkpoints") 185 | 186 | config_file.write_text(json.dumps(config)) 187 | train(str(config_file)) 188 | 189 | 190 | def train(config_file: str): 191 | # See all possible arguments in src/transformers/training_args.py 192 | # or by passing the --help flag to this script. 193 | # We now keep distinct sets of args, for a cleaner separation of concerns. 194 | 195 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments)) 196 | model_args, data_args, training_args = parser.parse_json_file(json_file=config_file) 197 | 198 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 199 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 200 | # send_example_telemetry("run_mae", model_args, data_args) 201 | 202 | # Setup logging 203 | logging.basicConfig( 204 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 205 | datefmt="%m/%d/%Y %H:%M:%S", 206 | handlers=[logging.StreamHandler(sys.stdout)], 207 | ) 208 | 209 | if training_args.should_log: 210 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
211 | transformers.utils.logging.set_verbosity_info() 212 | 213 | log_level = training_args.get_process_log_level() 214 | logger.setLevel(log_level) 215 | transformers.utils.logging.set_verbosity(log_level) 216 | transformers.utils.logging.enable_default_handler() 217 | transformers.utils.logging.enable_explicit_format() 218 | 219 | # Log on each process the small summary: 220 | logger.warning( 221 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 222 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 223 | ) 224 | logger.info(f"Training/evaluation parameters {training_args}") 225 | 226 | # Detecting last checkpoint. 227 | last_checkpoint = None 228 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 229 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 230 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 231 | raise ValueError( 232 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 233 | "Use --overwrite_output_dir to overcome." 234 | ) 235 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 236 | logger.info( 237 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 238 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 239 | ) 240 | 241 | # Initialize our dataset. 242 | ds_train = datasets.Dataset.from_dict( 243 | {"image": sorted(glob.glob(os.path.join(data_args.data_files["train"], "*.jpg")))}, 244 | split=datasets.Split.TRAIN, 245 | ).cast_column("image", datasets.Image(decode=True)) 246 | 247 | ds_val = datasets.Dataset.from_dict( 248 | {"image": sorted(glob.glob(os.path.join(data_args.data_files["val"], "*.jpg")))}, 249 | split=datasets.Split.VALIDATION, 250 | ).cast_column("image", datasets.Image(decode=True)) 251 | 252 | # combine dataset splits 253 | ds = datasets.DatasetDict() 254 | ds["train"] = ds_train 255 | ds["validation"] = ds_val 256 | 257 | # Load pretrained model and image processor 258 | # 259 | # Distributed training: 260 | # The .from_pretrained methods guarantee that only one local process can concurrently 261 | # download model & vocab. 
262 | config_kwargs = { 263 | "cache_dir": model_args.cache_dir, 264 | "revision": model_args.model_revision, 265 | "use_auth_token": True if model_args.use_auth_token else None, 266 | "size": {"height": 288, "width": 288}, 267 | } 268 | if model_args.config_name: 269 | config = ViTMAEConfig.from_pretrained(model_args.config_name, **config_kwargs) 270 | elif model_args.model_name_or_path: 271 | config = ViTMAEConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 272 | else: 273 | config = ViTMAEConfig() 274 | logger.warning("You are instantiating a new config instance from scratch.") 275 | if model_args.config_overrides is not None: 276 | logger.info(f"Overriding config: {model_args.config_overrides}") 277 | config.update_from_string(model_args.config_overrides) 278 | logger.info(f"New config: {config}") 279 | 280 | # adapt config 281 | config.update( 282 | { 283 | "mask_ratio": model_args.mask_ratio, 284 | "norm_pix_loss": model_args.norm_pix_loss, 285 | } 286 | ) 287 | 288 | # create image processor 289 | if model_args.image_processor_name: 290 | image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs) 291 | elif model_args.model_name_or_path: 292 | image_processor = ViTImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs) 293 | else: 294 | image_processor = ViTImageProcessor() 295 | 296 | # create model 297 | if model_args.model_name_or_path: 298 | model = ViTMAEForPreTraining.from_pretrained( 299 | model_args.model_name_or_path, 300 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 301 | config=config, 302 | cache_dir=model_args.cache_dir, 303 | revision=model_args.model_revision, 304 | use_auth_token=True if model_args.use_auth_token else None, 305 | ) 306 | else: 307 | logger.info("Training new model from scratch") 308 | model = ViTMAEForPreTraining(config) 309 | 310 | if training_args.do_train: 311 | column_names = ds_train.column_names 312 | else: 313 | column_names = ds_val.column_names 314 | 315 | if data_args.image_column_name is not None: 316 | image_column_name = data_args.image_column_name 317 | elif "image" in column_names: 318 | image_column_name = "image" 319 | elif "img" in column_names: 320 | image_column_name = "img" 321 | else: 322 | image_column_name = column_names[0] 323 | 324 | # transformations as done in original MAE paper 325 | # source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py 326 | if "shortest_edge" in image_processor.size: 327 | size = image_processor.size["shortest_edge"] 328 | else: 329 | size = (image_processor.size["height"], image_processor.size["width"]) 330 | transforms = Compose( 331 | [ 332 | Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), 333 | # RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC), 334 | # RandomHorizontalFlip(), 335 | ToTensor(), 336 | Normalize(mean=image_processor.image_mean, std=image_processor.image_std), 337 | ] 338 | ) 339 | 340 | def preprocess_images(examples): 341 | """Preprocess a batch of images by applying transforms.""" 342 | 343 | examples["pixel_values"] = [transforms(image) for image in examples[image_column_name]] 344 | return examples 345 | 346 | if training_args.do_train: 347 | if "train" not in ds: 348 | raise ValueError("--do_train requires a train dataset") 349 | if data_args.max_train_samples is not None: 350 | ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples)) 351 | # Set the training 
transforms 352 | ds["train"].set_transform(preprocess_images) 353 | 354 | if training_args.do_eval: 355 | if "validation" not in ds: 356 | raise ValueError("--do_eval requires a validation dataset") 357 | if data_args.max_eval_samples is not None: 358 | ds["validation"] = ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) 359 | # Set the validation transforms 360 | ds["validation"].set_transform(preprocess_images) 361 | 362 | # Compute absolute learning rate 363 | total_train_batch_size = ( 364 | training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size 365 | ) 366 | if training_args.base_learning_rate is not None: 367 | training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256 368 | 369 | # Initialize our trainer 370 | trainer = Trainer( 371 | model=model, 372 | args=training_args, 373 | train_dataset=ds["train"] if training_args.do_train else None, 374 | eval_dataset=ds["validation"] if training_args.do_eval else None, 375 | tokenizer=image_processor, 376 | data_collator=collate_fn, 377 | ) 378 | 379 | # Training 380 | if training_args.do_train: 381 | checkpoint = None 382 | if training_args.resume_from_checkpoint is not None: 383 | checkpoint = training_args.resume_from_checkpoint 384 | elif last_checkpoint is not None: 385 | checkpoint = last_checkpoint 386 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 387 | trainer.save_model() 388 | trainer.log_metrics("train", train_result.metrics) 389 | trainer.save_metrics("train", train_result.metrics) 390 | trainer.save_state() 391 | 392 | # Evaluation 393 | if training_args.do_eval: 394 | metrics = trainer.evaluate() 395 | trainer.log_metrics("eval", metrics) 396 | trainer.save_metrics("eval", metrics) 397 | 398 | # Write model card and (optionally) push to hub 399 | kwargs = { 400 | "tasks": "masked-auto-encoding", 401 | "dataset": data_args.dataset_name, 402 | "tags": ["masked-auto-encoding"], 403 | } 404 | if training_args.push_to_hub: 405 | trainer.push_to_hub(**kwargs) 406 | else: 407 | trainer.create_model_card(**kwargs) 408 | 409 | 410 | def _mp_fn(index): 411 | # For xla_spawn (TPUs) 412 | main() 413 | 414 | 415 | if __name__ == "__main__": 416 | main() 417 | -------------------------------------------------------------------------------- /docmae/pretrain_pl.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from pathlib import Path 5 | 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms as T 8 | import lightning as L 9 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 10 | from lightning.pytorch.loggers import TensorBoardLogger 11 | from transformers import ViTMAEConfig, ViTMAEModel, ViTImageProcessor 12 | from transformers.models.vit_mae.modeling_vit_mae import ViTMAEDecoder 13 | 14 | from docmae import setup_logging 15 | from docmae.data.list_dataset import ListDataset 16 | from docmae.models.mae import MAE 17 | from docmae.pretrain import parse_arguments 18 | 19 | """ Pre-training a 🤗 ViT model as an MAE (masked autoencoder), as proposed in https://arxiv.org/abs/2111.06377.""" 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def train(args, config_file: str): 25 | config = json.loads(Path(config_file).read_text()) 26 | L.seed_everything(config["seed"]) 27 | 28 | pretrained_config = ViTMAEConfig.from_pretrained(config_file) 29 | 
pretrained_config.image_size = 288 30 | encoder = ViTMAEModel(pretrained_config) 31 | decoder = ViTMAEDecoder(pretrained_config, encoder.embeddings.num_patches) 32 | encoder.mask_ratio = 0.75 33 | decoder.mask_ratio = 0.75 34 | 35 | callback_list = [ 36 | LearningRateMonitor(logging_interval="step"), 37 | ModelCheckpoint( 38 | dirpath=args.model_output_dir, 39 | filename="epoch_{epoch:d}", 40 | monitor="val/loss", 41 | mode="min", 42 | save_top_k=config["save_total_limit"], 43 | ), 44 | ] 45 | 46 | tb_logger = TensorBoardLogger(save_dir=args.tensorboard_dir, log_graph=False, default_hp_metric=False) 47 | num_epochs = config["num_train_epochs"] 48 | trainer = L.Trainer( 49 | logger=tb_logger, 50 | callbacks=callback_list, 51 | accelerator="cuda", 52 | max_epochs=num_epochs, 53 | enable_progress_bar=not config["disable_tqdm"], 54 | 55 | limit_train_batches=200_000 // num_epochs, 56 | limit_val_batches=10_000 // num_epochs, 57 | val_check_interval=10_000 // num_epochs, 58 | ) 59 | 60 | model = MAE(encoder, decoder, config, training=True) 61 | transforms = T.Compose([T.Resize(size=(288, 288)), T.ToTensor()]) 62 | 63 | dataset_train = ListDataset(Path(config["train_dir"]), "train", transforms) 64 | dataset_val = ListDataset(Path(config["validation_dir"]), "val", transforms) 65 | 66 | loader_train = DataLoader( 67 | dataset_train, 68 | batch_size=config["per_device_train_batch_size"], 69 | shuffle=True, 70 | num_workers=config["dataloader_num_workers"], 71 | pin_memory=True, 72 | ) 73 | loader_val = DataLoader( 74 | dataset_val, 75 | batch_size=config["per_device_eval_batch_size"], 76 | shuffle=False, 77 | num_workers=config["dataloader_num_workers"], 78 | pin_memory=True, 79 | ) 80 | 81 | trainer.fit(model, loader_train, loader_val) 82 | 83 | 84 | def main(): 85 | args = parse_arguments() 86 | setup_logging(log_level=args.log_level, log_dir=args.log_dir) 87 | 88 | assert args.config.endswith(".json") 89 | 90 | # Save config for training traceability and load config parameters 91 | config_file = Path(args.model_output_dir) / "config.json" 92 | config = json.loads(Path(args.config).read_text()) 93 | 94 | config["logging_dir"] = args.tensorboard_dir 95 | config["output_dir"] = os.path.join(args.model_output_dir, "checkpoints") 96 | 97 | config_file.write_text(json.dumps(config)) 98 | train(args, str(config_file)) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /docmae/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | import argparse 5 | import shutil 6 | from pathlib import Path 7 | 8 | import gin 9 | import gin.torch.external_configurables 10 | import torch 11 | import lightning as L 12 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 13 | from lightning.pytorch.loggers import TensorBoardLogger 14 | from lightning.pytorch.tuner import Tuner 15 | 16 | # We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that 17 | # some APIs may slightly change in the future 18 | import torchvision 19 | 20 | torchvision.disable_beta_transforms_warning() 21 | 22 | from docmae import setup_logging 23 | from docmae.models.transformer import BasicEncoder 24 | from docmae.models.upscale import UpscaleRAFT, UpscaleTransposeConv, UpscaleInterpolate 25 | from docmae.models.doctr import DocTr 26 | from docmae.models.doctr_custom import DocTrOrig 27 | from docmae.models.doctr_plus 
import DocTrPlus 28 | from docmae.models.rectification import Rectification 29 | from docmae.datamodule.utils import init_external_gin_configurables 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def parse_arguments(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("-c", "--config", required=True, type=Path, help="json config file for training parameters") 37 | parser.add_argument("-d", "--data-config", required=True, type=Path, help="gin config file for the dataloader") 38 | parser.add_argument( 39 | "-ll", 40 | "--log-level", 41 | type=str, 42 | default="INFO", 43 | choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], 44 | help="config file for training parameters", 45 | ) 46 | parser.add_argument("-l", "--log-dir", type=Path, default="", help="folder to store log files") 47 | parser.add_argument("-t", "--tensorboard-dir", type=Path, default="", help="folder to store tensorboard logs") 48 | parser.add_argument("-m", "--model-output-dir", type=Path, default="model", help="folder to store trained models") 49 | return parser.parse_args() 50 | 51 | 52 | @gin.configurable 53 | def train(args, config: dict, datamodule: L.LightningDataModule): 54 | L.seed_everything(config["training"]["seed"]) 55 | 56 | callback_list = [ 57 | LearningRateMonitor(logging_interval="step"), 58 | ModelCheckpoint( 59 | dirpath=args.model_output_dir, 60 | filename="epoch_{epoch:02d}", 61 | monitor="val/loss", 62 | mode="min", 63 | save_top_k=1, 64 | ), 65 | ] 66 | 67 | tb_logger = TensorBoardLogger(save_dir=args.tensorboard_dir, log_graph=False, default_hp_metric=False) 68 | 69 | trainer = L.Trainer( 70 | logger=tb_logger, 71 | callbacks=callback_list, 72 | accelerator="cuda", 73 | devices=max(torch.cuda.device_count(), config["training"]["num_devices"]), # use all gpus if config is -1 74 | max_epochs=config["training"].get("epochs", None), 75 | max_steps=config["training"].get("steps", -1), 76 | num_sanity_val_steps=1, 77 | enable_progress_bar=config["progress_bar"], 78 | ) 79 | 80 | hidden_dim = config["model"]["hidden_dim"] 81 | backbone = BasicEncoder(output_dim=hidden_dim, norm_fn="instance") 82 | model = DocTrPlus(config["model"]) 83 | upscale_type = config["model"]["upscale_type"] 84 | if upscale_type == "raft": 85 | upscale_module = UpscaleRAFT(8, hidden_dim) 86 | elif upscale_type == "transpose_conv": 87 | upscale_module = UpscaleTransposeConv(hidden_dim, hidden_dim // 2) 88 | elif upscale_type == "interpolate": 89 | upscale_module = UpscaleInterpolate(hidden_dim, hidden_dim // 2) 90 | else: 91 | raise NotImplementedError 92 | model = Rectification(backbone, model, upscale_module, config).cuda() 93 | 94 | # test export 95 | print(model.to_torchscript(method="trace")) 96 | 97 | tuner = Tuner(trainer) 98 | tuner.scale_batch_size(model, datamodule=datamodule, mode="power") 99 | 100 | trainer.fit(model, datamodule=datamodule) 101 | 102 | print(callback_list[1].best_model_path) 103 | 104 | 105 | def main(): 106 | args = parse_arguments() 107 | setup_logging(log_level=args.log_level, log_dir=args.log_dir) 108 | 109 | assert args.config.suffix == ".json" 110 | 111 | # Save config for training traceability and load config parameters 112 | config_file = args.model_output_dir / "config.json" 113 | config = json.loads(args.config.read_text()) 114 | shutil.copyfile(args.config, config_file) 115 | 116 | init_external_gin_configurables() 117 | gin.parse_config_file(args.data_config) 118 | Path(args.model_output_dir / "data_config.gin").write_text(gin.operative_config_str()) 119 | 120 
| train(args, config, datamodule=gin.REQUIRED) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "docmae" 3 | version = "0.1" 4 | description = "Unofficial implementation of DocMAE paper" 5 | authors = ["David Komorowicz "] 6 | packages = [ 7 | { include = "docmae" } # dir name containing code if differs from name 8 | ] 9 | 10 | [tool.poetry.dependencies] 11 | python = "~3.10" 12 | opencv-python-headless = "^4.5" 13 | Pillow-SIMD = "^9.0" 14 | torch = [ 15 | {markers = "sys_platform == 'linux'", url="https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp310-cp310-linux_x86_64.whl"}, 16 | {markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", url="https://download.pytorch.org/whl/cpu/torch-2.0.1-cp310-none-macosx_10_9_x86_64.whl"}, 17 | {markers = "sys_platform == 'darwin' and platform_machine == 'arm64'", url="https://download.pytorch.org/whl/cpu/torch-2.0.1-cp310-none-macosx_11_0_arm64.whl"}, 18 | {markers = "sys_platform == 'win32'", url="https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp310-cp310-win_amd64.whl"}, 19 | ] 20 | torchvision = [ 21 | {markers = "sys_platform == 'linux'", url="https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp310-cp310-linux_x86_64.whl"}, 22 | {markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", url="https://download.pytorch.org/whl/cpu/torchvision-0.15.2-cp310-cp310-macosx_10_9_x86_64.whl"}, 23 | {markers = "sys_platform == 'darwin' and platform_machine == 'arm64'", url="https://download.pytorch.org/whl/cpu/torchvision-0.15.2-cp310-cp310-macosx_10_9_x86_64.whl"}, 24 | {markers = "sys_platform == 'win32'", url="https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp310-cp310-win_amd64.whl"}, 25 | ] 26 | lightning = "^2.0.4" 27 | pydantic = "<2.0.0" # https://github.com/Lightning-AI/lightning/issues/18027 28 | torchmetrics = "^0.11.1" 29 | torchsummary = "^1.5.1" 30 | datasets = "^2.13" 31 | transformers = "^4.30" 32 | accelerate = "^0.20.1" 33 | 34 | matplotlib = "^3.6" 35 | numpy = "^1.24" 36 | scikit-image = "^0.19" 37 | scikit-learn = ">=1.0" 38 | tensorboard = "^2.3" 39 | minio = "^7.1.15" 40 | h5py = "^3.9.0" 41 | kornia = "^0.7.0" 42 | gin-config = "^0.5.0" 43 | 44 | [tool.poetry.group.test.dependencies] 45 | pytest = "^7.4.0" 46 | 47 | [build-system] 48 | requires = ["poetry-core>=1.0.0"] 49 | build-backend = "poetry.core.masonry.api" 50 | 51 | # optional scripts 52 | [tool.poetry.scripts] 53 | docmae-train = "docmae.train:main" 54 | docmae-pretrain = "docmae.pretrain_pl:main" 55 | docmae-finetune = "docmae.fine_tune:main" 56 | 57 | [tool.black] 58 | line-length = 128 59 | include = '\.pyi?$' 60 | exclude = ''' 61 | ( 62 | /( 63 | \.eggs # exclude a few common directories in the 64 | | \.git # root of the project 65 | | \.hg 66 | | \.mypy_cache 67 | | \.pytest_cache 68 | | \.tox 69 | | \.venv 70 | | venv 71 | | _build 72 | | buck-out 73 | | build 74 | | dist 75 | )/ 76 | ) 77 | ''' -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/tests/__init__.py -------------------------------------------------------------------------------- 
/tests/architecture.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import torch 5 | import torchvision 6 | from transformers import ViTMAEConfig, AutoImageProcessor, ViTImageProcessor, ViTMAEModel 7 | from transformers.models.vit_mae.modeling_vit_mae import ViTMAEDecoder 8 | 9 | from docmae.models.docmae import DocMAE 10 | 11 | 12 | def test_upscaling(): 13 | config_file = "./config/finetune.json" 14 | config = json.loads(Path(config_file).read_text()) 15 | 16 | pretrained_config = ViTMAEConfig.from_pretrained(config["mae_path"]) 17 | pretrained_config.mask_ratio = 0 18 | 19 | with torch.device("cuda"): 20 | mae_encoder = ViTMAEModel(pretrained_config) 21 | mae_decoder = ViTMAEDecoder(pretrained_config, mae_encoder.embeddings.num_patches) 22 | 23 | model = DocMAE(mae_encoder, mae_decoder, config).cuda() 24 | model.eval() 25 | 26 | x = torch.rand([4, 3, 288, 288]) 27 | x.requires_grad = True 28 | out = model.forward(x) # fmap 29 | 30 | i = 2 31 | loss = out[i].sum() 32 | 33 | loss.backward() 34 | 35 | mask = torch.zeros([4]) 36 | mask[i] = 1 37 | mask = mask.bool() 38 | 39 | assert torch.count_nonzero(x.grad[~mask]) == 0 40 | assert torch.count_nonzero(x.grad[i]) > 0 41 | 42 | -------------------------------------------------------------------------------- /tests/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from torchvision.transforms import InterpolationMode 6 | import torchvision.transforms.v2 as transforms 7 | 8 | from docmae.data.doc3d import Doc3D 9 | from docmae.data.docaligner import DocAligner 10 | from docmae.data.augmentation.random_resized_crop import RandomResizedCropWithUV 11 | from docmae.data.augmentation.replace_background import ReplaceBackground 12 | 13 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 14 | 15 | 16 | def _test_dataset(sample, size): 17 | 18 | # size 19 | assert sample["image"].shape == (3, size, size) 20 | assert sample["bm"].shape == (2, size, size) 21 | assert sample["uv"].shape == (2, size, size) 22 | assert sample["mask"].shape == (1, size, size) 23 | assert sample["mask"].dtype == torch.bool 24 | 25 | # range 26 | assert 0 <= sample["image"].min() <= sample["image"].max() <= 255 27 | assert 0 <= sample["bm"].min() <= sample["bm"].max() < 1 # [0, 1] 28 | assert 0 <= sample["uv"].min() <= sample["uv"].max() <= 1 29 | assert set(sample["mask"].numpy().flatten()) == {1.0, 0.0} 30 | 31 | 32 | def test_doc3d_raw_output(): 33 | dataset = Doc3D(Path("./tests/doc3d/"), "tiny") 34 | 35 | sample = dataset[0] 36 | _test_dataset(sample, 448) 37 | 38 | 39 | def test_docaligner_raw_output(): 40 | dataset = DocAligner(Path("/home/dawars/datasets/DocAligner_result/"), "tiny") 41 | 42 | sample = dataset[0] 43 | _test_dataset(sample, 1024) 44 | 45 | 46 | def test_transforms(): 47 | transform = transforms.Compose( 48 | [ 49 | transforms.Resize((288, 288), antialias=True), 50 | transforms.ToImageTensor(), 51 | transforms.ToDtype(torch.float32), 52 | ] 53 | ) 54 | 55 | dataset = Doc3D(Path("./tests/doc3d/"), "tiny", transform) 56 | 57 | sample = dataset[0] 58 | 59 | assert sample["image"].shape == (3, 288, 288) 60 | assert sample["bm"].shape == (2, 288, 288) 61 | assert sample["uv"].shape == (2, 288, 288) 62 | assert sample["mask"].shape == (1, 288, 288) 63 | 64 | # range 65 | assert 0 <= sample["image"].min() <= sample["image"].max() <= 255 66 | assert 0 <= 
sample["bm"].min() <= sample["bm"].max() < 1 # [0, 1] 67 | assert 0 <= sample["uv"].min() <= sample["uv"].max() <= 1 68 | assert set(sample["mask"].numpy().flatten()) == {1.0, 0.0} 69 | 70 | 71 | def test_crop(): 72 | transform = RandomResizedCropWithUV((288, 288), interpolation=InterpolationMode.BICUBIC, antialias=True) 73 | 74 | image = torch.randint(0, 255, (3, 448, 448)) 75 | bm = torch.rand((2, 448, 448)) 76 | uv = torch.rand((2, 448, 448)) 77 | mask = torch.randint(0, 2, (1, 448, 448)) # 0 or 1 78 | 79 | image, bm, uv, mask = transform((image, bm, uv, mask)) 80 | 81 | # size 82 | assert image.shape == (3, 288, 288) 83 | assert bm.shape == (2, 288, 288) 84 | assert uv.shape == (2, 288, 288) 85 | assert mask.shape == (1, 288, 288) 86 | assert mask.dtype == torch.bool or mask.dtype == torch.long 87 | 88 | # range 89 | assert 0 <= image.min() <= image.max() <= 255 90 | assert -0.9 <= bm.min() <= bm.max() <= 1.9 # [0, 1] but crop reaches outside 91 | assert 0 <= uv.min() <= uv.max() <= 1 92 | assert set(mask.numpy().flatten()) == {1.0, 0.0} 93 | 94 | 95 | def test_replace_background(): 96 | transform = ReplaceBackground(Path("./tests/dtd"), "tiny") 97 | 98 | image = torch.randint(0, 255, (3, 288, 288)) 99 | bm = torch.rand((2, 288, 288)) 100 | uv = torch.rand((2, 288, 288)) 101 | mask = torch.randint(0, 2, (1, 288, 288)).bool() 102 | 103 | image, bm, uv, mask = transform((image, bm, uv, mask)) 104 | 105 | # size 106 | assert image.shape == (3, 288, 288) 107 | assert bm.shape == (2, 288, 288) 108 | assert uv.shape == (2, 288, 288) 109 | assert mask.shape == (1, 288, 288) 110 | assert mask.dtype == torch.bool 111 | # range 112 | assert 0 <= image.min() <= image.max() <= 255 113 | assert -0.9 <= bm.min() <= bm.max() <= 1.9 # [0, 1] but crop reaches outside 114 | assert 0 <= uv.min() <= uv.max() <= 1 115 | assert set(mask.numpy().flatten()) == {1.0, 0.0} 116 | -------------------------------------------------------------------------------- /tests/doc3d/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 CVLab@StonyBrook 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /tests/doc3d/bm/9_6-vc_Page_002-YT40001.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/tests/doc3d/bm/9_6-vc_Page_002-YT40001.mat -------------------------------------------------------------------------------- /tests/doc3d/img/9_6-vc_Page_002-YT40001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/tests/doc3d/img/9_6-vc_Page_002-YT40001.png -------------------------------------------------------------------------------- /tests/doc3d/tiny.txt: -------------------------------------------------------------------------------- 1 | 9_6-vc_Page_002-YT40001 -------------------------------------------------------------------------------- /tests/doc3d/uv/9_6-vc_Page_002-YT40001.exr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/tests/doc3d/uv/9_6-vc_Page_002-YT40001.exr -------------------------------------------------------------------------------- /tests/dtd/images/lacelike_0037.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dawars/DocMAE/de4cd087f6e82991d9a757bcd05a866b6b2fc95e/tests/dtd/images/lacelike_0037.jpg -------------------------------------------------------------------------------- /tests/dtd/labels/tiny.txt: -------------------------------------------------------------------------------- 1 | lacelike_0037.jpg --------------------------------------------------------------------------------
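The rectification code above (training_step, validation_step and predict_step) relies throughout on one backward-map convention: the network predicts, for every output pixel of a 288 x 288 image, the input coordinate to sample from, stored in pixel units in [0, 287]; before calling F.grid_sample the map is normalized to [-1, 1] and permuted to (N, H, W, 2). The sketch below is a minimal, self-contained illustration of that convention only; the function name unwarp and the dummy identity map are illustrative and not part of the repository.

# Standalone sketch, not part of the DocMAE repository: condenses the
# backward-map ("bm") unwarping convention used by the rectification module.
import torch
import torch.nn.functional as F


def unwarp(image: torch.Tensor, bm: torch.Tensor) -> torch.Tensor:
    """Rectify `image` (N, C, 288, 288) with a backward map `bm` (N, 2, 288, 288) given in pixel coordinates."""
    grid = (bm / 287 - 0.5) * 2      # [0, 287] -> [-1, 1], matching the training/validation steps
    grid = grid.permute(0, 2, 3, 1)  # (N, 2, H, W) -> (N, H, W, 2), the layout grid_sample expects (x in channel 0, y in channel 1)
    return F.grid_sample(image, grid, align_corners=False)


if __name__ == "__main__":
    # Dummy data only: an identity backward map should (approximately) reproduce the input.
    image = torch.rand(1, 3, 288, 288)
    ys, xs = torch.meshgrid(torch.arange(288.0), torch.arange(288.0), indexing="ij")
    identity_bm = torch.stack([xs, ys], dim=0).unsqueeze(0)  # (1, 2, 288, 288), x first, then y
    out = unwarp(image, identity_bm)
    print(out.shape)  # torch.Size([1, 3, 288, 288])

With align_corners=False, as in training_step, the identity map reproduces the input only approximately because of half-pixel sampling offsets; predict_step instead uses align_corners=True together with the 286.8 and 0.99 scaling referenced in the linked DocTr issue.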