├── .gitattributes ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── poetry.lock ├── pyproject.toml ├── requirements.txt └── src ├── ambientcg_download.py ├── blender_render_samples_3d.py ├── bounding_box.py ├── bounding_box_math.py ├── calculate_bounding_boxes.py ├── font_rendering.py ├── generate_samples_2d.py ├── main.py ├── polyhaven_download.py ├── postprocess.py ├── prepare_data.py ├── shuffle_iter.py └── tests ├── test_assets ├── SilentReaction.ttf ├── coordinates0001.png ├── image0001.png ├── test_alpha_font_rendering_generate.png └── test_img_font_rendering_generate.png └── test_font_rendering.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | assets/ 2 | output/ 3 | fonts_filter.ipynb 4 | dev.ipynb 5 | tests.ipynb 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 111 | __pypackages__/ 112 | 113 | # Celery stuff 114 | celerybeat-schedule 115 | celerybeat.pid 116 | 117 | # SageMath parsed files 118 | *.sage.py 119 | 120 | # Environments 121 | .env 122 | .venv 123 | env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | 147 | # pytype static type analyzer 148 | .pytype/ 149 | 150 | # Cython debug symbols 151 | cython_debug/ 152 | 153 | # PyCharm 154 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 155 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 156 | # and can be added to the global gitignore or merged into this file. For a more nuclear 157 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 158 | #.idea/ 159 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/Blender_3D_document_rendering_pipeline"] 2 | path = src/Blender_3D_document_rendering_pipeline 3 | url = https://github.com/GbotHQ/Blender-3D-document-rendering-pipeline.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 GbotHQ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Handwritten OCR dataset rendering 2 | 3 | Generate and render annotated images with text on a virtual paper with folds and other types of damage using PIL and Blender 3D. 4 | 5 | ## Installation 6 | 7 | 1. Clone the repository: 8 | 9 | ``` 10 | git clone --recursive https://github.com/GbotHQ/ocr-dataset-rendering.git 11 | ``` 12 | 13 | 2. Install the required packages: 14 | 15 | ``` 16 | cd ocr-dataset-rendering 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | 3. Download Blender: 21 | 22 | ``` 23 | cd src/Blender_3D_document_rendering_pipeline 24 | bash download_blender_binary.sh 25 | cd ../ 26 | ``` 27 | 28 | ## Usage 29 | 30 | example usage: 31 | ``` 32 | rm -r ../output 33 | python main.py --n_samples 2 --blender_path "Blender_3D_document_rendering_pipeline/blender-3.4.0-linux-x64/blender" --output_dir ../output --device cpu --resolution_x 512 --resolution_y 512 --compression_level 9 34 | ``` 35 | ``` 36 | python main.py --help 37 | ``` 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ocr-dataset-rendering" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Your Name "] 6 | readme = "README.md" 7 | packages = [{include = "ocr-dataset-rendering"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.9" 11 | pillow = "^9.5.0" 12 | opencv-contrib-python = "^4.7.0.72" 13 | datasets = "^2.11.0" 14 | psutil = "^5.9.5" 15 | ipywidgets = "^8.0.6" 16 | gdown = "^4.7.1" 17 | fonttools = "^4.39.3" 18 | scikit-image = "^0.20.0" 19 | fire = "^0.5.0" 20 | pytest = "^7.3.1" 21 | p-tqdm = "^1.4.0" 22 | 23 | 24 | [tool.poetry.group.dev.dependencies] 25 | black = "^23.3.0" 26 | ipykernel = "^6.22.0" 27 | 28 | [build-system] 29 | requires = ["poetry-core"] 30 | build-backend = "poetry.core.masonry.api" 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.4 ; python_version >= "3.9" and python_version < "4.0" 2 | aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0" 3 | appnope==0.1.3 ; python_version >= "3.9" and python_version < "4.0" and platform_system == "Darwin" or python_version >= "3.9" and python_version < "4.0" and sys_platform == "darwin" 4 | asttokens==2.2.1 ; python_version >= "3.9" and python_version < "4.0" 5 | async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "4.0" 6 | attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0" 7 | backcall==0.2.0 ; python_version >= "3.9" and python_version < "4.0" 8 | beautifulsoup4==4.12.2 ; python_version >= "3.9" and python_version < "4.0" 9 | certifi==2022.12.7 ; python_version >= "3.9" and python_version < "4" 10 | cffi==1.15.1 ; python_version >= "3.9" and python_version < "4.0" and implementation_name == "pypy" 11 | charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0" 12 | colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows" 13 | comm==0.1.3 ; python_version >= "3.9" and python_version < "4.0" 14 | datasets==2.11.0 ; python_version >= "3.9" and python_version < "4.0" 15 | debugpy==1.6.7 ; python_version >= "3.9" and python_version < "4.0" 16 | decorator==5.1.1 ; python_version >= "3.9" and python_version < "4.0" 17 | dill==0.3.6 ; python_version >= "3.9" and python_version < "4.0" 18 | exceptiongroup==1.1.1 ; python_version >= "3.9" and python_version < "3.11" 19 | executing==1.2.0 ; python_version >= "3.9" and python_version < "4.0" 20 | filelock==3.12.0 ; python_version >= "3.9" and python_version < "4.0" 21 | fire==0.5.0 ; python_version >= "3.9" and python_version < "4.0" 22 | fonttools==4.39.3 ; python_version >= "3.9" and python_version < "4.0" 23 | frozenlist==1.3.3 ; python_version >= "3.9" and python_version < "4.0" 24 | fsspec==2023.4.0 ; python_version >= "3.9" and python_version < "4.0" 25 | fsspec[http]==2023.4.0 ; python_version >= "3.9" and python_version < "4.0" 26 | gdown==4.7.1 ; python_version >= "3.9" and python_version < "4.0" 27 | huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0" 28 | idna==3.4 ; python_version >= "3.9" and python_version < "4" 29 | imageio==2.28.1 ; python_version >= "3.9" and python_version < "4.0" 30 | importlib-metadata==6.6.0 ; python_version >= "3.9" and python_version < "3.10" 31 | iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "4.0" 32 | ipykernel==6.22.0 ; python_version >= "3.9" and python_version < "4.0" 33 | ipython==8.12.0 ; python_version >= "3.9" and python_version < "4.0" 34 | ipywidgets==8.0.6 ; python_version >= "3.9" and python_version < "4.0" 35 | jedi==0.18.2 ; python_version >= "3.9" and python_version < "4.0" 36 | jupyter-client==8.2.0 ; python_version >= "3.9" and python_version < "4.0" 37 | jupyter-core==5.3.0 ; python_version >= "3.9" and python_version < "4.0" 38 | jupyterlab-widgets==3.0.7 ; python_version >= "3.9" and python_version < "4.0" 39 | lazy-loader==0.2 ; python_version >= "3.9" and python_version < "4.0" 40 | matplotlib-inline==0.1.6 ; python_version >= "3.9" and python_version < "4.0" 41 | mega.py==1.0.8 ; python_version >= "3.9" and python_version < "4.0" 42 | multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0" 43 | multiprocess==0.70.14 ; python_version >= "3.9" and python_version < "4.0" 44 | nest-asyncio==1.5.6 ; python_version >= "3.9" and python_version < "4.0" 45 | networkx==3.1 ; python_version >= "3.9" and python_version < "4.0" 46 | numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0" 47 | opencv-contrib-python==4.7.0.72 ; python_version >= "3.9" and python_version < "4.0" 48 | p-tqdm==1.4.0 ; python_version >= "3.9" and python_version < "4.0" 49 | packaging==23.1 ; python_version >= "3.9" and python_version < "4.0" 50 | pandas==2.0.1 ; python_version >= "3.9" and python_version < "4.0" 51 | parso==0.8.3 ; python_version >= "3.9" and python_version < "4.0" 52 | pathos==0.3.0 ; python_version >= "3.9" and python_version < "4.0" 53 | pexpect==4.8.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform != "win32" 54 | pickleshare==0.7.5 ; python_version >= "3.9" and python_version < "4.0" 55 | pillow==9.5.0 ; python_version >= "3.9" and python_version < "4.0" 56 | platformdirs==3.3.0 ; python_version >= "3.9" and python_version < "4.0" 57 | pluggy==1.0.0 ; python_version >= "3.9" and python_version < "4.0" 58 | pox==0.3.2 ; python_version >= "3.9" and python_version < "4.0" 59 | ppft==1.7.6.6 ; python_version >= "3.9" and python_version < "4.0" 60 | prompt-toolkit==3.0.38 ; python_version >= "3.9" and python_version < "4.0" 61 | psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0" 62 | ptyprocess==0.7.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform != "win32" 63 | pure-eval==0.2.2 ; python_version >= "3.9" and python_version < "4.0" 64 | pyarrow==11.0.0 ; python_version >= "3.9" and python_version < "4.0" 65 | pycparser==2.21 ; python_version >= "3.9" and python_version < "4.0" and implementation_name == "pypy" 66 | pygments==2.15.1 ; python_version >= "3.9" and python_version < "4.0" 67 | pysocks==1.7.1 ; python_version >= "3.9" and python_version < "4" 68 | pytest==7.3.1 ; python_version >= "3.9" and python_version < "4.0" 69 | python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0" 70 | pytz==2023.3 ; python_version >= "3.9" and python_version < "4.0" 71 | pywavelets==1.4.1 ; python_version >= "3.9" and python_version < "4.0" 72 | pywin32==306 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "4.0" 73 | pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0" 74 | pyzmq==25.0.2 ; python_version >= "3.9" and python_version < "4.0" 75 | requests==2.28.2 ; python_version >= "3.9" and python_version < "4" 76 | requests[socks]==2.28.2 ; python_version >= "3.9" and python_version < "4" 77 | responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0" 78 | scikit-image==0.20.0 ; python_version >= "3.9" and python_version < "4.0" 79 | scipy==1.9.1 ; python_version == "3.9" 80 | scipy==1.9.3 ; python_version > "3.9" and python_version < "4.0" 81 | six==1.16.0 ; python_version >= "3.9" and python_version < "4.0" 82 | soupsieve==2.4.1 ; python_version >= "3.9" and python_version < "4.0" 83 | stack-data==0.6.2 ; python_version >= "3.9" and python_version < "4.0" 84 | termcolor==2.3.0 ; python_version >= "3.9" and python_version < "4.0" 85 | tifffile==2023.4.12 ; python_version >= "3.9" and python_version < "4.0" 86 | tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" 87 | tornado==6.3.1 ; python_version >= "3.9" and python_version < "4.0" 88 | tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" 89 | traitlets==5.9.0 ; python_version >= "3.9" and python_version < "4.0" 90 | typing-extensions==4.5.0 ; python_version >= "3.9" and python_version < "4.0" 91 | tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0" 92 | urllib3==1.26.15 ; python_version >= "3.9" and python_version < "4" 93 | wcwidth==0.2.6 ; python_version >= "3.9" and python_version < "4.0" 94 | widgetsnbextension==4.0.7 ; python_version >= "3.9" and python_version < "4.0" 95 | xxhash==3.2.0 ; python_version >= "3.9" and python_version < "4.0" 96 | yarl==1.9.2 ; python_version >= "3.9" and python_version < "4.0" 97 | zipp==3.15.0 ; python_version >= "3.9" and python_version < "3.10" 98 | -------------------------------------------------------------------------------- /src/ambientcg_download.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any 2 | import requests 3 | from pathlib import Path as pth 4 | import zipfile 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def find_string_containing_substring_in_list(strings: List[str], substring: str) -> str: 10 | return next((s for s in strings if substring in s), None) 11 | 12 | 13 | def is_list_of_substrings_in_list_of_strings( 14 | strings: List[str], substrings: List[str] 15 | ) -> bool: 16 | return all( 17 | bool(find_string_containing_substring_in_list(strings, k)) for k in substrings 18 | ) 19 | 20 | 21 | def download_url(url: str, out_path: pth) -> None: 22 | with requests.Session() as session: 23 | response = session.get(url, stream=True) 24 | response.raise_for_status() 25 | with open(out_path, "wb") as f: 26 | for chunk in response.iter_content(chunk_size=8192): 27 | f.write(chunk) 28 | 29 | 30 | class AmbientCGDownloader: 31 | def __init__(self, extension: str, resolution: str, texture_types: List[str]): 32 | self.extension: str = extension 33 | self.resolution: str = resolution 34 | self.texture_types: List[str] = texture_types 35 | self.texture_type_to_filename_map: Dict[str, str] = { 36 | "Color": "albedo", 37 | "Roughness": "roughness", 38 | "Displacement": "displacement", 39 | } 40 | self.api_endpoint: str = "https://ambientcg.com/api/v2/full_json" 41 | 42 | def _get_material_assets(self, params: Dict[str, Any]) -> Dict[str, Any]: 43 | with requests.Session() as session: 44 | response = session.get(self.api_endpoint, params=params) 45 | 46 | response.raise_for_status() 47 | data = response.json() 48 | return data 49 | 50 | def _get_material_assets_with_pagination( 51 | self, params: Dict[str, Any] 52 | ) -> List[Dict[str, Any]]: 53 | data_list: List[Dict[str, Any]] = [] 54 | offset = 0 55 | while offset < params["total"]: 56 | params["offset"] = offset 57 | data_list.append(self._get_material_assets(params)) 58 | offset += params["limit"] 59 | return data_list 60 | 61 | def _has_required_texture_types(self, file_info: Dict[str, Any]) -> bool: 62 | texture_type_substrings = [ 63 | f"{filename}.{self.extension}" for filename in self.texture_types 64 | ] 65 | return is_list_of_substrings_in_list_of_strings( 66 | file_info["zipContent"], texture_type_substrings 67 | ) 68 | 69 | def _is_matching_resolution_and_type(self, file_info: Dict[str, Any]) -> bool: 70 | return ( 71 | file_info["attribute"] 72 | == f"{self.resolution.upper()}-{self.extension.upper()}" 73 | ) 74 | 75 | def _get_matching_files(self, asset: Dict[str, Any]) -> List[Dict[str, str]]: 76 | return [ 77 | file_info 78 | for file_info in asset["downloadFolders"]["default"][ 79 | "downloadFiletypeCategories" 80 | ]["zip"]["downloads"] 81 | if self._is_matching_resolution_and_type(file_info) 82 | and self._has_required_texture_types(file_info) 83 | ] 84 | 85 | def _map_asset_to_dict( 86 | self, asset: Dict[str, Any], file_info: Dict[str, Any] 87 | ) -> Dict[str, str]: 88 | return { 89 | "filename": file_info["fileName"], 90 | "id": asset["assetId"], 91 | "url": file_info["downloadLink"], 92 | } 93 | 94 | def _get_matching_assets( 95 | self, assets: List[Dict[str, Any]] 96 | ) -> List[Dict[str, str]]: 97 | file_dicts = [] 98 | for asset in assets: 99 | matching_files = self._get_matching_files(asset) 100 | file_dicts.extend( 101 | self._map_asset_to_dict(asset, file_info) 102 | for file_info in matching_files 103 | ) 104 | return file_dicts 105 | 106 | def _download_material_zip(self, asset: Dict[str, str], out_dir: str) -> pth: 107 | out_dir = pth(out_dir) 108 | zip_path = out_dir / asset["filename"] 109 | download_url(asset["url"], zip_path) 110 | return zip_path 111 | 112 | def _extract_texture_from_zip( 113 | self, zip_file: zipfile.ZipFile, texture_type: str, out_dir: pth 114 | ) -> pth: 115 | try: 116 | filename = find_string_containing_substring_in_list( 117 | zip_file.namelist(), f"{texture_type}.{self.extension}" 118 | ) 119 | if not filename: 120 | return 121 | 122 | out_path = out_dir / filename 123 | zip_file.extract(filename, out_dir) 124 | return out_path 125 | except Exception as e: 126 | print(f"Error extracting texture from zip file: {e}") 127 | return 128 | 129 | def _extract_textures_from_zip(self, zip_path: pth, out_dir: pth) -> Dict[str, pth]: 130 | out_dir = pth(out_dir) 131 | textures: Dict[str, pth] = {} 132 | try: 133 | with zipfile.ZipFile(zip_path, "r") as zip_file: 134 | for texture_type in self.texture_types: 135 | mapped_texture_type = self.texture_type_to_filename_map[ 136 | texture_type 137 | ] 138 | filename = f"{zip_path.stem}_{mapped_texture_type}_{self.resolution}.{self.extension}" 139 | out_path = out_dir / filename 140 | if original_out_path := self._extract_texture_from_zip( 141 | zip_file, texture_type, out_dir 142 | ): 143 | original_out_path.rename(out_path) 144 | textures[mapped_texture_type] = out_path 145 | except Exception as e: 146 | print(f"Error extracting textures from zip file: {e}") 147 | 148 | return textures 149 | 150 | def _get_and_filter_material_assets(self) -> List[Dict[str, str]]: 151 | params: Dict[str, Any] = { 152 | "method": "", 153 | "type": "Material", 154 | "sort": "Alphabet", 155 | "include": "downloadData", 156 | "limit": 250, 157 | } 158 | 159 | response = self._get_material_assets(params) 160 | total_count = response["numberOfResults"] 161 | params["total"] = total_count 162 | 163 | data_list = self._get_material_assets_with_pagination(params) 164 | return self._get_matching_assets( 165 | [asset for data in data_list for asset in data["foundAssets"]] 166 | ) 167 | 168 | def download_and_extract_materials(self, out_dir: str) -> List[Dict[str, Any]]: 169 | out_dir = pth(out_dir) 170 | out_dir.mkdir(parents=True, exist_ok=True) 171 | assets = self._get_and_filter_material_assets() 172 | 173 | output_assets: List[Dict[str, Any]] = [] 174 | for asset in tqdm(assets): 175 | try: 176 | zip_path = self._download_material_zip(asset, out_dir) 177 | textures = self._extract_textures_from_zip(zip_path, out_dir) 178 | asset["textures"] = textures 179 | output_assets.append(asset) 180 | zip_path.unlink() 181 | except Exception as e: 182 | print(f"Error downloading or extracting material: {e}") 183 | 184 | return output_assets 185 | -------------------------------------------------------------------------------- /src/blender_render_samples_3d.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import subprocess 3 | from pathlib import Path as pth 4 | from shutil import rmtree 5 | 6 | 7 | def mkdir(path: pth): 8 | if path.is_dir(): 9 | rmtree(path) 10 | path.mkdir(parents=True) 11 | 12 | 13 | def run_blender_command( 14 | blender_binary_path: Union[str, pth], 15 | config_dir: Union[str, pth], 16 | output_dir: Union[str, pth], 17 | device: str, 18 | ): 19 | if not pth(output_dir).is_dir(): 20 | raise FileNotFoundError( 21 | f"Output directory {output_dir} does not exist, please create it first" 22 | ) 23 | 24 | out = subprocess.run( 25 | [ 26 | str(blender_binary_path), 27 | pth("./Blender_3D_document_rendering_pipeline/blender/scene.blend").resolve(), 28 | "--background", 29 | "--factory-startup", 30 | "--threads", 31 | "0", 32 | "--engine", 33 | "CYCLES", 34 | "--enable-autoexec", 35 | "--python", 36 | pth("./Blender_3D_document_rendering_pipeline/src/main.py").resolve(), 37 | "--", 38 | "--cycles-device", 39 | device, 40 | "--config_path", 41 | str(config_dir), 42 | ], 43 | capture_output=True, 44 | ) 45 | print(out.stdout.decode("utf-8")) 46 | if out.returncode != 0: 47 | raise Exception(out.stderr.decode("utf-8")) 48 | -------------------------------------------------------------------------------- /src/bounding_box.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import copy 3 | 4 | import numpy as np 5 | import cv2 as cv 6 | 7 | 8 | def remap_point(point, coords): 9 | point = np.array(point) 10 | 11 | # distance from point 12 | distance = np.amax(np.abs(point - coords), axis=-1) 13 | 14 | # take the pixel with the lowest distance 15 | return np.unravel_index(np.argmin(distance), distance.shape) 16 | 17 | 18 | class BaseBoundingBox: 19 | def __init__(self, dtype=np.int32): 20 | self.dtype = dtype 21 | self.points = np.array([], self.dtype) 22 | 23 | def __getitem__(self, index: int): 24 | return self.points[index] 25 | 26 | def to_simple(self): 27 | return self 28 | 29 | def to_quad(self): 30 | return self 31 | 32 | def astype(self, dtype): 33 | bbox = self.copy() 34 | bbox.dtype = dtype 35 | bbox.points = self.points.astype(dtype) 36 | return bbox 37 | 38 | def relative(self, size, dtype=np.float32): 39 | bbox = self.copy() 40 | bbox.dtype = dtype 41 | bbox.points = bbox.points.astype(bbox.dtype) / np.array(size, dtype)[None] 42 | return bbox 43 | 44 | def absolute(self, size, dtype=np.int32): 45 | bbox = self.copy() 46 | bbox.dtype = dtype 47 | bbox.points = bbox.points.astype(bbox.dtype) * np.array(size, dtype)[None] 48 | return bbox 49 | 50 | def get_size(self): 51 | simple = self.to_simple() 52 | return simple.get_size() 53 | 54 | def draw(self, img: np.ndarray, col: Tuple[int, int, int] = (255, 0, 0)) -> np.ndarray: 55 | return self.to_quad().draw(img, col) 56 | 57 | def copy(self): 58 | return copy.deepcopy(self) 59 | 60 | 61 | class SimpleBoundingBox(BaseBoundingBox): 62 | def __init__(self, p0: Tuple[int, int], p1: Tuple[int, int], dtype=np.int32): 63 | super().__init__(dtype) 64 | self.points = np.array((p0, p1), self.dtype) 65 | 66 | def xxyy(self): 67 | x0, y0, x1, y1 = self.points.ravel() 68 | return np.array((x0, x1, y0, y1)).tolist() 69 | 70 | def to_quad(self): 71 | x0, y0, x1, y1 = self.points.ravel() 72 | return QuadBoundingBox((x0, y0), (x1, y0), (x1, y1), (x0, y1), self.dtype) 73 | 74 | def get_size(self): 75 | return np.array(self.points[1] - self.points[0]) 76 | 77 | 78 | class QuadBoundingBox(BaseBoundingBox): 79 | def __init__(self, p0, p1, p2, p3, dtype=np.int32): 80 | super().__init__(dtype) 81 | self.points = np.array((p0, p1, p2, p3), self.dtype) 82 | 83 | def to_simple(self): 84 | return SimpleBoundingBox( 85 | np.amin(self.points, axis=0), np.amax(self.points, axis=0), self.dtype 86 | ) 87 | 88 | def remap(self, coords): 89 | bbox = self.copy() 90 | for i in range(4): 91 | bbox.points[i] = remap_point(self.points[i][::-1], coords) 92 | return bbox 93 | 94 | def draw(self, img: np.ndarray, col: Tuple[int, int, int] = (255, 0, 0)) -> np.ndarray: 95 | points = self.points[:, ::-1] 96 | img = img.copy() 97 | img = cv.polylines(img, (points,), True, col, 1, cv.LINE_AA) 98 | for i in range(4): 99 | img = cv.circle(img, points[i], 2, col, -1, cv.LINE_AA) 100 | return img 101 | 102 | -------------------------------------------------------------------------------- /src/bounding_box_math.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import cv2 as cv 5 | from PIL import Image, ImageDraw 6 | 7 | from bounding_box import SimpleBoundingBox, QuadBoundingBox 8 | 9 | 10 | def bbox_from_mask( 11 | img: np.ndarray, background_color: int, tolerance: int = 12 12 | ) -> SimpleBoundingBox: 13 | mask = np.abs(img.astype(np.int32) - background_color) > tolerance 14 | mask = (mask.any(0), mask.any(1)) 15 | 16 | bbox = [(np.argmax(k), k.size - np.argmax(k[::-1])) for k in mask] 17 | bbox = np.swapaxes(np.array(bbox, np.int32), 0, 1) 18 | bbox = SimpleBoundingBox(bbox[0], bbox[1], bbox.dtype) 19 | 20 | bbox.points = bbox.points[:, ::-1] 21 | return bbox 22 | 23 | 24 | def remap_value(value, low1, high1, low2, high2): 25 | return low2 + (value - low1) * (high2 - low2) / (high1 - low1) 26 | 27 | 28 | def perspective_transform_points(array: np.ndarray, matrix: np.ndarray) -> np.ndarray: 29 | return cv.perspectiveTransform(array.reshape(-1, 1, 2), matrix).reshape(-1, 2) 30 | 31 | 32 | def transform_quad_to_fit_points(quad, points): 33 | reference_rect = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32) 34 | 35 | quad = quad.astype(np.float32) 36 | points = points.astype(np.float32) 37 | 38 | # Compute the perspective transform from the quad to a rectangle 39 | transform_matrix = cv.getPerspectiveTransform(quad, reference_rect) 40 | transform_matrix_inverse = cv.getPerspectiveTransform(reference_rect, quad) 41 | 42 | # Transform using the matrix 43 | quad = perspective_transform_points(quad, transform_matrix) 44 | points = perspective_transform_points(points, transform_matrix) 45 | 46 | quad = quad.astype(np.float32) 47 | points = points.astype(np.float32) 48 | 49 | quad_transformed_min = np.amin(quad, axis=0) 50 | quad_transformed_max = np.amax(quad, axis=0) 51 | points_transformed_min = np.amin(points, axis=0) 52 | points_transformed_max = np.amax(points, axis=0) 53 | 54 | quad = remap_value( 55 | quad, 56 | quad_transformed_min[None], 57 | quad_transformed_max[None], 58 | points_transformed_min[None], 59 | points_transformed_max[None], 60 | ) 61 | 62 | return perspective_transform_points(quad, transform_matrix_inverse) 63 | 64 | 65 | def get_tight_character_bbox(char, font): 66 | mask = font.getmask(char) 67 | mask = np.array(mask, dtype=np.uint8).reshape(mask.size[::-1]) 68 | 69 | return mask, bbox_from_mask(mask, 0, 0) 70 | 71 | 72 | def get_img_rotation_matrix(shape, angle): 73 | res = np.array(shape[:2], np.float32)[::-1] 74 | half = (res / 2).astype(np.int32) 75 | 76 | rotation_matrix = cv.getRotationMatrix2D(half.tolist(), angle, 1.0) 77 | 78 | cos, sin = np.abs(rotation_matrix[0, :2]) 79 | new_res = (res * cos + res[::-1] * sin).astype(np.int32) 80 | rotation_matrix[:2, 2] += (new_res / 2) - half 81 | 82 | return rotation_matrix, new_res 83 | 84 | 85 | def img_bbox_rotate(bbox, shape, angle): 86 | bbox = bbox.copy() 87 | 88 | rotation_matrix, _ = get_img_rotation_matrix(shape, angle) 89 | 90 | points = bbox.points[:, ::-1] 91 | points = np.hstack((points, np.ones((points.shape[0], 1), dtype=points.dtype))) 92 | points = np.dot(points, rotation_matrix.T) 93 | 94 | bbox.points = points.astype(np.int32)[:, ::-1] 95 | 96 | return bbox 97 | 98 | 99 | def bbox_union(bbox1, bbox2): 100 | bbox1.points[0] = np.minimum(bbox1.points[0], bbox2.points[0]) 101 | bbox1.points[1] = np.maximum(bbox1.points[1], bbox2.points[1]) 102 | 103 | return bbox1 104 | 105 | 106 | def calculate_overall_bbox(line_bboxes): 107 | overall_bbox = line_bboxes[0].copy() 108 | for bbox in line_bboxes: 109 | overall_bbox = bbox_union(overall_bbox, bbox) 110 | return overall_bbox 111 | 112 | 113 | def calculate_line_bboxes(char_bboxes): 114 | line_bboxes = [] 115 | for k in char_bboxes: 116 | if len(line_bboxes) < k["line_index"] + 1: 117 | line_bboxes.append(k["bbox"].copy()) 118 | 119 | line_bbox = line_bboxes[k["line_index"]] 120 | 121 | line_bbox = bbox_union(line_bbox, k["bbox"]) 122 | 123 | line_bboxes[k["line_index"]] = line_bbox 124 | 125 | return line_bboxes 126 | 127 | 128 | def calculate_char_bboxes(xy, text, font): 129 | draw = ImageDraw.Draw(Image.new("RGB", (0, 0))) 130 | 131 | char_bboxes = [] 132 | for i, char in enumerate(text): 133 | char_bbox = draw.textbbox(xy, text[i], font=font) 134 | width = char_bbox[2] - char_bbox[0] 135 | height = char_bbox[3] - char_bbox[1] 136 | 137 | if width == 0 or height == 0: 138 | continue 139 | 140 | mask, offset_bbox = get_tight_character_bbox(char, font) 141 | offset_bbox_size = offset_bbox.get_size() 142 | 143 | line_index = text[: i + 1].count("\n") 144 | 145 | char_bbox = draw.textbbox(xy, "\n" * line_index + text[i], font=font) 146 | bottom = char_bbox[3] 147 | 148 | char_bbox = draw.textbbox(xy, text[: i + 1].split("\n")[-1], font=font) 149 | right = char_bbox[2] 150 | 151 | bottom -= mask.shape[0] - offset_bbox.points[1][0] 152 | right -= mask.shape[1] - offset_bbox.points[1][1] 153 | top = bottom - offset_bbox_size[0] 154 | left = right - offset_bbox_size[1] 155 | 156 | char_bboxes.append( 157 | { 158 | "bbox": SimpleBoundingBox((top, left), (bottom, right)), 159 | "mask": mask[ 160 | offset_bbox.points[0, 0] : offset_bbox.points[1, 0], 161 | offset_bbox.points[0, 1] : offset_bbox.points[1, 1], 162 | ], 163 | "char_index": i, 164 | "line_index": line_index, 165 | } 166 | ) 167 | return char_bboxes 168 | 169 | 170 | def calculate_precise_bbox( 171 | document_bbox: QuadBoundingBox, 172 | src_shape: Tuple[int, int], 173 | mask_warped: np.ndarray, 174 | coords_relative: np.ndarray, 175 | ) -> Tuple[QuadBoundingBox, QuadBoundingBox]: 176 | bbox = document_bbox 177 | bbox = bbox.relative(src_shape[:2]) 178 | bbox = bbox.remap(coords_relative[..., ::-1]) 179 | bbox.points = bbox.points.astype(np.int32) 180 | points = np.stack(np.where(mask_warped), -1) 181 | 182 | bbox.points = transform_quad_to_fit_points(bbox.points, points) 183 | bbox.points = np.rint(bbox.points).astype(np.int32) 184 | 185 | bbox_precise = bbox 186 | bbox_precise_relative = bbox.relative(coords_relative.shape[:2][::-1]) 187 | 188 | return bbox_precise, bbox_precise_relative 189 | 190 | 191 | def bbox_from_binary_mask(mask: np.ndarray) -> SimpleBoundingBox: 192 | mask = (mask.any(0), mask.any(1)) 193 | 194 | bbox = [(np.argmax(k), k.size - np.argmax(k[::-1])) for k in mask] 195 | bbox = np.swapaxes(np.array(bbox, np.int32), 0, 1) 196 | bbox = SimpleBoundingBox(bbox[0], bbox[1], bbox.dtype) 197 | 198 | bbox.points = bbox.points[:, ::-1] 199 | return bbox 200 | 201 | 202 | def perspective_transform_bboxes(bboxes, transform_matrix): 203 | points = [k.to_quad().points for k in bboxes] 204 | points = np.concatenate(points).astype(np.float32) 205 | 206 | points = points[:, ::-1] 207 | points = perspective_transform_points(points, transform_matrix) 208 | points = np.rint(points[:, ::-1]).astype(np.int32).reshape(-1, 4, 2) 209 | 210 | return [QuadBoundingBox(k[0], k[1], k[2], k[3]) for k in points] 211 | -------------------------------------------------------------------------------- /src/calculate_bounding_boxes.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 | from pathlib import Path as pth 3 | 4 | import numpy as np 5 | import cv2 as cv 6 | from scipy.ndimage import map_coordinates 7 | from PIL import ImageFont 8 | 9 | from bounding_box import SimpleBoundingBox 10 | from bounding_box_math import ( 11 | calculate_char_bboxes, 12 | calculate_line_bboxes, 13 | calculate_overall_bbox, 14 | img_bbox_rotate, 15 | get_tight_character_bbox, 16 | bbox_from_binary_mask, 17 | perspective_transform_bboxes, 18 | calculate_precise_bbox, 19 | ) 20 | 21 | 22 | def to_uint(img: np.ndarray, dtype=np.uint8): 23 | return (np.clip(img, 0, 1) * np.iinfo(dtype).max).astype(dtype) 24 | 25 | 26 | def to_float(img: np.ndarray, fdtype=np.float32): 27 | return img.astype(fdtype) / np.iinfo(img.dtype).max 28 | 29 | 30 | def imread_coords(path: Union[str, pth], src_shape: Tuple[int, int]): 31 | # unchanged to read as uint16 32 | coords_relative = to_float(cv.imread(path, cv.IMREAD_UNCHANGED)) 33 | alpha = coords_relative[..., 0, None] 34 | # flip y to match opencv coordinates 35 | coords_relative[..., 1] = 1 - coords_relative[..., 1] 36 | coords_relative = np.where(alpha < 1, -1, coords_relative[..., 1:]) 37 | 38 | coords_absolute = np.moveaxis(coords_relative.copy(), -1, 0) 39 | coords_absolute *= np.array( 40 | (src_shape[0] - 1, src_shape[1] - 1), coords_absolute.dtype 41 | )[:, None, None] 42 | 43 | return coords_relative, coords_absolute, alpha 44 | 45 | 46 | def calculate_document_bboxes(sample, font): 47 | char_bboxes = calculate_char_bboxes(sample.anchor, sample.text, font) 48 | 49 | for k in char_bboxes: 50 | left_offset = sample.line_offsets[k["line_index"]] 51 | points = k["bbox"].points.astype(np.float32) 52 | points[:, 1] += left_offset 53 | k["bbox"].points = points.astype(np.int32) 54 | 55 | overall_bbox = calculate_overall_bbox(calculate_line_bboxes(char_bboxes)) 56 | 57 | shape = sample.resolution_before_rotation[::-1] 58 | 59 | for k in char_bboxes: 60 | k["bbox"] = img_bbox_rotate( 61 | k["bbox"].to_quad(), shape, sample.text_rotation_angle 62 | ) 63 | 64 | overall_bbox = img_bbox_rotate( 65 | overall_bbox.to_quad(), shape, sample.text_rotation_angle 66 | ) 67 | 68 | padding_array = np.array(sample.padding[:2], np.int32)[None] 69 | for k in char_bboxes: 70 | k["bbox"].points += padding_array 71 | 72 | overall_bbox.points += padding_array 73 | 74 | return char_bboxes, overall_bbox 75 | 76 | 77 | def calculate_char_labels(sample, document_img_shape, char_bboxes, font): 78 | mask_combined = np.zeros(document_img_shape[:2], dtype=np.int32) 79 | 80 | for i, k in enumerate(char_bboxes): 81 | bbox = k["bbox"] 82 | char_index = k["char_index"] 83 | 84 | mask, mask_bbox = get_tight_character_bbox(sample.text[char_index], font) 85 | mask = mask[ 86 | mask_bbox.points[0, 0] : mask_bbox.points[1, 0], 87 | mask_bbox.points[0, 1] : mask_bbox.points[1, 1], 88 | ] 89 | 90 | h, w = np.array(mask.shape[:2]) - 1 91 | reference_rect = np.array([[0, 0], [h, 0], [h, w], [0, w]], dtype=np.float32) 92 | quad = bbox.points.astype(np.float32) 93 | 94 | transform_matrix = cv.getPerspectiveTransform( 95 | reference_rect[:, ::-1], quad[:, ::-1] 96 | ) 97 | mask = cv.warpPerspective(mask, transform_matrix, document_img_shape[:2][::-1]) 98 | 99 | mask_combined[mask > 64] = i + 1 100 | 101 | return mask_combined 102 | 103 | 104 | def calculate_render_bboxes( 105 | document_img_shape, 106 | labels, 107 | document_char_bboxes, 108 | document_overall_bbox, 109 | coords_relative, 110 | coords_absolute, 111 | resolution, 112 | ): 113 | labels_warped = map_coordinates( 114 | labels, np.rint(coords_absolute).astype(np.int32), cval=0 115 | ) 116 | 117 | if np.all(labels_warped == 0): 118 | return None, None, None 119 | 120 | overall_bbox_remapped, _ = calculate_precise_bbox( 121 | document_overall_bbox, document_img_shape, labels_warped != 0, coords_relative 122 | ) 123 | 124 | w, h = resolution 125 | reference_rect = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32) * ( 126 | h - 1, 127 | w - 1, 128 | ) 129 | reference_rect = reference_rect.astype(np.float32)[:, ::-1] 130 | quad = overall_bbox_remapped.points.astype(np.float32)[:, ::-1] 131 | 132 | transform_matrix = cv.getPerspectiveTransform(quad, reference_rect) 133 | transform_matrix_inverse = cv.getPerspectiveTransform(reference_rect, quad) 134 | labels_warped = cv.warpPerspective( 135 | labels_warped.astype(np.float32), 136 | transform_matrix, 137 | (w, h), 138 | flags=cv.INTER_NEAREST, 139 | ) 140 | 141 | new_char_bboxes = [] 142 | for i, k in enumerate(document_char_bboxes): 143 | mask = labels_warped == i + 1 144 | bbox = bbox_from_binary_mask(mask) 145 | bbox.points[1] -= 1 146 | k = k.copy() 147 | k["bbox"] = bbox 148 | new_char_bboxes.append(k) 149 | 150 | new_line_bboxes = calculate_line_bboxes(new_char_bboxes) 151 | 152 | # calculate overall bbox 153 | new_overall_bbox = np.concatenate([k["bbox"].points for k in new_char_bboxes], 0) 154 | new_overall_bbox = np.amin(new_overall_bbox, 0), np.amax(new_overall_bbox, 0) 155 | new_overall_bbox = SimpleBoundingBox(*new_overall_bbox).to_quad() 156 | 157 | # transform bboxes back 158 | new_overall_bbox = perspective_transform_bboxes( 159 | [new_overall_bbox], transform_matrix_inverse 160 | )[0] 161 | new_char_bboxes = perspective_transform_bboxes( 162 | [k["bbox"] for k in new_char_bboxes], transform_matrix_inverse 163 | ) 164 | new_line_bboxes = perspective_transform_bboxes( 165 | new_line_bboxes, transform_matrix_inverse 166 | ) 167 | 168 | return new_char_bboxes, new_line_bboxes, new_overall_bbox 169 | 170 | 171 | def bboxes_to_dict( 172 | text, 173 | output_img_shape, 174 | new_overall_bbox, 175 | new_line_bboxes, 176 | new_char_bboxes, 177 | original_char_bboxes, 178 | ): 179 | # copy data 180 | new_overall_bbox = new_overall_bbox.copy() 181 | new_line_bboxes = [k.copy() for k in new_line_bboxes] 182 | new_char_bboxes = [k.copy() for k in new_char_bboxes] 183 | original_char_bboxes = original_char_bboxes.copy() 184 | for k in original_char_bboxes: 185 | k["bbox"] = k["bbox"].copy() 186 | 187 | res = np.array(output_img_shape[:2], np.int32)[None] 188 | for bbox in new_char_bboxes: 189 | bbox.points = (bbox.points / res)[:, ::-1] 190 | for bbox in new_line_bboxes: 191 | bbox.points = (bbox.points / res)[:, ::-1] 192 | new_overall_bbox.points = (new_overall_bbox.points / res)[:, ::-1] 193 | 194 | def make_tl_tr_br_bl(points): 195 | return np.stack((points[0], points[3], points[2], points[1]), 0) 196 | 197 | for bbox in new_char_bboxes: 198 | bbox.points = make_tl_tr_br_bl(bbox.points) 199 | for bbox in new_line_bboxes: 200 | bbox.points = make_tl_tr_br_bl(bbox.points) 201 | new_overall_bbox.points = make_tl_tr_br_bl(new_overall_bbox.points) 202 | 203 | axis_aligned_overall_bbox = np.stack( 204 | (np.amin(new_overall_bbox.points, 0), np.amax(new_overall_bbox.points, 0)), 0 205 | ) 206 | axis_aligned_overall_bbox_xxyy = [ 207 | axis_aligned_overall_bbox[0, 0], 208 | axis_aligned_overall_bbox[1, 0], 209 | axis_aligned_overall_bbox[0, 1], 210 | axis_aligned_overall_bbox[1, 1], 211 | ] 212 | 213 | lines = text.split("\n") 214 | bounding_boxes = { 215 | "overall_bbox": new_overall_bbox.points.tolist(), 216 | "axis_aligned_overall_bbox": axis_aligned_overall_bbox.tolist(), 217 | "axis_aligned_overall_bbox_xxyy": axis_aligned_overall_bbox_xxyy, 218 | "lines": lines, 219 | "line_bboxes": [], 220 | "chars": [], 221 | "char_idx": [], 222 | "char_bboxes": [], 223 | } 224 | 225 | for i in {k["line_index"] for k in original_char_bboxes}: 226 | bounding_boxes["line_bboxes"].append(new_line_bboxes[i].points.tolist()) 227 | 228 | for i in {k["line_index"] for k in original_char_bboxes}: 229 | char_bboxes = [ 230 | (k[0]["char_index"], k[1].points.tolist()) 231 | for k in zip(original_char_bboxes, new_char_bboxes) 232 | if k[0]["line_index"] == i 233 | ] 234 | for k in char_bboxes: 235 | bounding_boxes["chars"].append(text[k[0]]) 236 | bounding_boxes["char_idx"].append(k[0]) 237 | bounding_boxes["char_bboxes"].append(k[1]) 238 | 239 | return bounding_boxes 240 | 241 | 242 | def calculate_bounding_boxes(sample): 243 | document_img = cv.imread(str(sample.image_path)) 244 | document_res = document_img.shape[:2] 245 | font = ImageFont.truetype(str(sample.font_path), sample.font_size) 246 | 247 | coords_relative, coords_absolute, _ = imread_coords( 248 | str(sample.output_coordinates_path), document_res 249 | ) 250 | 251 | document_char_bboxes, document_overall_bbox = calculate_document_bboxes( 252 | sample, font 253 | ) 254 | mask_combined = calculate_char_labels( 255 | sample, document_res, document_char_bboxes, font 256 | ) 257 | new_char_bboxes, new_line_bboxes, new_overall_bbox = calculate_render_bboxes( 258 | document_res, 259 | mask_combined, 260 | document_char_bboxes, 261 | document_overall_bbox, 262 | coords_relative, 263 | coords_absolute, 264 | sample.output_image_resolution, 265 | ) 266 | 267 | if new_overall_bbox is None: 268 | return None, None, None, None 269 | 270 | bounding_boxes = bboxes_to_dict( 271 | sample.text, 272 | coords_absolute.shape[1:], 273 | new_overall_bbox, 274 | new_line_bboxes, 275 | new_char_bboxes, 276 | document_char_bboxes, 277 | ) 278 | 279 | return bounding_boxes, new_char_bboxes, new_line_bboxes, new_overall_bbox 280 | -------------------------------------------------------------------------------- /src/font_rendering.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import concurrent.futures 3 | from random import uniform, randint, random, choice 4 | import textwrap 5 | from colorsys import hls_to_rgb 6 | 7 | from fontTools.ttLib import TTFont 8 | import numpy as np 9 | import cv2 as cv 10 | from PIL import Image, ImageFont, ImageDraw 11 | from PIL.Image import Resampling 12 | 13 | from bounding_box import SimpleBoundingBox 14 | 15 | 16 | def lerp(a, b, fac): 17 | return a + (b - a) * fac 18 | 19 | 20 | def pad_image( 21 | img: np.ndarray, padding: Tuple[int, int, int, int] = (0, 0, 0, 0), color: int = 255 22 | ): 23 | return np.pad( 24 | img, 25 | ((padding[0], padding[2]), (padding[1], padding[3]), (0, 0)), 26 | mode="constant", 27 | constant_values=color, 28 | ) 29 | 30 | 31 | def wrap_text_to_match_aspect_ratio(text: str, font: ImageFont, aspect_ratio: float): 32 | # calculate line character width based on aspect ratio 33 | n_characters = len(text) 34 | 35 | text_bbox = font.getbbox(text) 36 | text_bbox = SimpleBoundingBox(text_bbox[:2], text_bbox[2:], np.int32) 37 | character_size = text_bbox.get_size() 38 | character_size[0] /= n_characters 39 | 40 | text_area = character_size[0] * character_size[1] * n_characters 41 | 42 | size = np.ceil( 43 | np.array( 44 | (text_area**0.5 * aspect_ratio, text_area**0.5 / aspect_ratio), 45 | np.float32, 46 | ) 47 | / character_size 48 | ).astype(np.int32) 49 | 50 | # wrap text based on calculated width 51 | return textwrap.fill( 52 | text=text, width=size[0], break_long_words=False, break_on_hyphens=False 53 | ) 54 | 55 | 56 | def calculate_font_scale(text: str, font: ImageFont, resolution: int, draw: ImageDraw): 57 | # find font scale that roughly matches resolution 58 | text_bbox = draw.multiline_textbbox((0, 0), text, font=font) 59 | text_bbox = SimpleBoundingBox(text_bbox[:2], text_bbox[2:], np.int32) 60 | 61 | return resolution / np.amax(text_bbox.get_size()) 62 | 63 | 64 | def apply_colors(color0, color1, alpha): 65 | # to float 66 | alpha = alpha.astype(np.float32) / 255 67 | img = lerp(color0, color1, alpha) 68 | # to uint 69 | img = np.rint(img).astype(np.uint8) 70 | return cv.cvtColor(img, cv.COLOR_RGB2BGR) 71 | 72 | 73 | def draw_text_aligned(draw, xy, text, font, align, **kwargs): 74 | text_bbox_with_padding = draw.textbbox((0, 0), text, font=font) 75 | line_offsets = [] 76 | for i, line in enumerate(text.split("\n")): 77 | line_bbox_with_padding = draw.textbbox((0, 0), line, font=font) 78 | bottom = draw.textbbox((0, 0), "\n" * i + line, font=font)[3] 79 | 80 | if align == "left": 81 | left_offset = 0 82 | elif align == "center": 83 | left_offset = round( 84 | ((text_bbox_with_padding[2]) - (line_bbox_with_padding[2])) / 2 85 | ) 86 | elif align == "right": 87 | left_offset = text_bbox_with_padding[2] - line_bbox_with_padding[2] 88 | else: 89 | raise ValueError("Invalid align value") 90 | 91 | line_offsets.append(left_offset) 92 | 93 | top = bottom - line_bbox_with_padding[3] 94 | anchor = (xy[0] + left_offset, xy[1] + top) 95 | draw.text(anchor, line, font=font, align="left", **kwargs) 96 | 97 | return line_offsets 98 | 99 | 100 | def render_text_mask( 101 | text: str, 102 | font: ImageFont, 103 | resolution: int, 104 | text_aspect_ratio: float, 105 | align: str = "left", 106 | ) -> Tuple[str, np.ndarray]: 107 | text = wrap_text_to_match_aspect_ratio(text, font, text_aspect_ratio) 108 | 109 | # temp draw for getting font bbox 110 | draw = ImageDraw.Draw(Image.new("RGB", (0, 0))) 111 | font_scale = calculate_font_scale(text, font, resolution, draw) 112 | font_size = int(font.size * font_scale) 113 | # scale font 114 | font = font.font_variant(size=font_size) 115 | 116 | # calculate image resolution 117 | text_bbox = draw.multiline_textbbox((0, 0), text, font=font) 118 | text_bbox = SimpleBoundingBox(text_bbox[:2], text_bbox[2:], np.int32) 119 | # render text 120 | alpha = Image.new("RGB", text_bbox.get_size().tolist(), color=(255, 255, 255)) 121 | draw = ImageDraw.Draw(alpha) 122 | xy = (-text_bbox[0][:]).tolist() 123 | 124 | # draw.text(xy, text, fill=(0, 0, 0), font=font, align=align) 125 | line_offsets = draw_text_aligned(draw, xy, text, font, align, fill=(0, 0, 0)) 126 | 127 | return text, alpha, font_size, xy, line_offsets 128 | 129 | 130 | def is_char_supported_by_font(font: TTFont, char: str) -> bool: 131 | return any( 132 | ord(char) in k.cmap.keys() for k in font["cmap"].tables if hasattr(k, "cmap") 133 | ) 134 | 135 | 136 | def is_text_supported_by_font(text: str, font: TTFont): 137 | with concurrent.futures.ThreadPoolExecutor() as executor: 138 | return all(executor.map(lambda c: is_char_supported_by_font(font, c), text)) 139 | 140 | 141 | def hls_to_int_rgb(hue: float, lightness: float, saturation: float): 142 | col = hls_to_rgb(hue, lightness, saturation) 143 | return np.floor(np.array(col, np.float32) * 255).astype(np.int32) 144 | 145 | 146 | def generate(text: str, font: ImageFont, resolution: int): 147 | hue, lightness, saturation = [random() for _ in range(3)] 148 | # make saturation curve steeper 149 | saturation = (saturation**2) * (3 - 2 * saturation) 150 | font_color = hls_to_int_rgb(hue, lightness * 0.6, saturation * 0.86) 151 | 152 | text_rotation_angle = randint(-45, 45) 153 | 154 | text_aspect_ratio = uniform(0.5, 2) 155 | padding = [randint(0, 64) for _ in range(4)] 156 | alignment = choice(["left", "center", "right"]) 157 | 158 | background_color = (255, 255, 255) 159 | 160 | white = (255, 255, 255) 161 | 162 | text, alpha, font_size, xy, line_offsets = render_text_mask( 163 | text, font, resolution, text_aspect_ratio, alignment 164 | ) 165 | resolution_before_rotation = alpha.size 166 | 167 | # rotate image 168 | alpha = np.array( 169 | alpha.rotate( 170 | text_rotation_angle, 171 | resample=Resampling.BICUBIC, 172 | expand=True, 173 | fillcolor=white, 174 | ) 175 | ) 176 | 177 | # add padding 178 | alpha = pad_image(alpha, padding, white[0]) 179 | # apply colors 180 | img = apply_colors(font_color, background_color, alpha) 181 | 182 | return ( 183 | text, 184 | img, 185 | font_size, 186 | xy, 187 | line_offsets, 188 | padding, 189 | font_color, 190 | text_rotation_angle, 191 | resolution_before_rotation, 192 | ) 193 | -------------------------------------------------------------------------------- /src/generate_samples_2d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | from random import random, choice 3 | from pathlib import Path as pth 4 | from shutil import rmtree 5 | from io import BytesIO 6 | import threading 7 | import itertools 8 | import tempfile 9 | import re 10 | from functools import partial 11 | 12 | from p_tqdm import t_map, p_umap 13 | import numpy as np 14 | import cv2 as cv 15 | from fontTools.ttLib import TTFont 16 | from PIL import ImageFont 17 | 18 | from font_rendering import generate, is_text_supported_by_font 19 | from Blender_3D_document_rendering_pipeline.src import config 20 | 21 | 22 | class SampleInfo: 23 | def __init__( 24 | self, 25 | text: str, 26 | config: config.Config, 27 | anchor: Tuple[int, int], 28 | line_offsets, 29 | padding, 30 | image_path: pth, 31 | font_path: str, 32 | font_color: Tuple[int, int, int], 33 | font_size: int, 34 | text_rotation_angle: int, 35 | text_image_resolution: Tuple[int, int], 36 | output_image_resolution: Tuple[int, int], 37 | output_image_path: pth, 38 | output_coordinates_path: pth, 39 | compression_level: int, 40 | resolution_before_rotation: Tuple[int, int], 41 | ): 42 | self.text = text 43 | self.config = config 44 | self.anchor = anchor 45 | self.line_offsets = line_offsets 46 | self.padding = padding 47 | self.image_path = image_path 48 | self.font_path = font_path 49 | self.font_color = font_color 50 | self.font_size = font_size 51 | self.text_rotation_angle = text_rotation_angle 52 | self.text_image_resolution = text_image_resolution 53 | self.output_image_resolution = output_image_resolution 54 | self.output_image_path = output_image_path 55 | self.output_coordinates_path = output_coordinates_path 56 | self.compression_level = compression_level 57 | self.resolution_before_rotation = resolution_before_rotation 58 | 59 | 60 | def mkdir(path: pth): 61 | if path.is_dir(): 62 | rmtree(path) 63 | path.mkdir(parents=True) 64 | 65 | 66 | def assign_material_to_conf(material, conf): 67 | material = {k: str(material[k].resolve()) for k in material} 68 | 69 | conf.ground.albedo_tex = material["albedo"] 70 | conf.ground.roughness_tex = material["roughness"] 71 | conf.ground.displacement_tex = material["displacement"] 72 | 73 | 74 | def break_up_sample(text_sample): 75 | # make sure that punctuation has the correct spacing 76 | text_sample = re.sub(r"\s*([.,?!])\s*", r"\1 ", text_sample) 77 | 78 | text_sample = [k.strip() for k in text_sample.split(", ")] 79 | text_sample = [k for k in text_sample if k] 80 | 81 | # randomly merge back some adjecent strings 82 | i = 0 83 | while (i + 1) < len(text_sample): 84 | if random() < 0.15: 85 | text_sample[i] += f", {text_sample.pop(i + 1)}" 86 | i += 1 87 | 88 | return text_sample 89 | 90 | 91 | def break_up_samples(text_samples): 92 | text_samples = map(break_up_sample, text_samples) 93 | return list(itertools.chain.from_iterable(text_samples)) 94 | 95 | 96 | def get_text_and_font(shuffled_text_dataset_iter, random_font_iter): 97 | while True: 98 | # find a font that supports all characters in text 99 | text = next(shuffled_text_dataset_iter)["sentences"] 100 | text = choice(break_up_sample(text)) 101 | for _ in range(20): 102 | font_path = next(random_font_iter) 103 | with open(font_path, "rb") as f: 104 | font_file = f.read() 105 | # few fonts are be broken and will raise an exception 106 | try: 107 | if is_text_supported_by_font(text, TTFont(BytesIO(font_file))): 108 | break 109 | except Exception as e: 110 | print(e) 111 | continue 112 | else: 113 | # give up and try a different piece of text 114 | print("No font found that supports all characters in text") 115 | continue 116 | break 117 | 118 | return text, font_path 119 | 120 | 121 | def write_config( 122 | img: np.ndarray, 123 | device: str, 124 | root_dir: pth, 125 | output_dir: pth, 126 | hdri_path, 127 | material, 128 | output_image_resolution, 129 | compression_level: int, 130 | image_path, 131 | config_path: pth, 132 | ) -> config.Config: 133 | resolution = np.array(img.shape[:2], np.float32) 134 | paper_size = resolution[::-1] / np.mean(resolution) * 25 135 | 136 | conf = config.Config(device, project_root=root_dir) 137 | 138 | conf.hdri.texture_path = str(pth(hdri_path).resolve()) 139 | 140 | assign_material_to_conf(material, conf) 141 | 142 | conf.render.output_dir = str(output_dir) 143 | conf.render.resolution = output_image_resolution 144 | conf.render.cycles_samples = 2 145 | conf.render.compression_ratio = round(compression_level / 9 * 100) 146 | conf.paper.document_image_path = str(image_path.resolve(True)) 147 | conf.paper.size = paper_size.tolist() 148 | conf.ground.visible = random() < 0.6 149 | 150 | config.write_config(config_path, conf) 151 | 152 | return conf 153 | 154 | 155 | def save_images( 156 | img: np.ndarray, 157 | image_dir_path: pth, 158 | sample_id: str, 159 | compression_level: int, 160 | ): 161 | image_path = image_dir_path / f"{sample_id}.png" 162 | cv.imwrite(str(image_path), img, [cv.IMWRITE_PNG_COMPRESSION, compression_level]) 163 | 164 | return image_path 165 | 166 | 167 | def generate_sample( 168 | index: int, 169 | text: str, 170 | font_path: str, 171 | material, 172 | hdri_path, 173 | root_dir, 174 | output_dir, 175 | image_dir, 176 | config_dir, 177 | text_render_resolution, 178 | output_image_resolution, 179 | device, 180 | compression_level, 181 | ) -> Union[SampleInfo, None]: 182 | try: 183 | with open(font_path, "rb") as f: 184 | font_file = f.read() 185 | font = ImageFont.truetype(BytesIO(font_file), 42) 186 | except Exception as e: 187 | print(f"Error while generating sample {index}: {e}") 188 | return 189 | 190 | # few fonts are be broken and will raise an exception 191 | try: 192 | ( 193 | text, 194 | img, 195 | font_size, 196 | anchor, 197 | line_offsets, 198 | padding, 199 | font_color, 200 | text_rotation_angle, 201 | resolution_before_rotation, 202 | ) = generate(text, font, text_render_resolution) 203 | except Exception as e: 204 | print(f"Error while generating sample {index}: {e}") 205 | return 206 | 207 | sample_id = f"sample_{index:08d}" 208 | 209 | image_path = save_images(img, image_dir, sample_id, compression_level) 210 | 211 | config_path = config_dir / f"{sample_id}.json" 212 | 213 | conf = write_config( 214 | img, 215 | device, 216 | root_dir, 217 | output_dir, 218 | hdri_path, 219 | material, 220 | output_image_resolution, 221 | compression_level, 222 | image_path, 223 | config_path, 224 | ) 225 | 226 | out_dir = output_dir / f"{sample_id}" 227 | out_image_path = out_dir / "image0001.png" 228 | out_coordinates_path = out_dir / "coordinates0001.png" 229 | 230 | resolution = img.shape[:2] 231 | 232 | return SampleInfo( 233 | text, 234 | conf, 235 | anchor, 236 | line_offsets, 237 | padding, 238 | image_path, 239 | font_path, 240 | font_color, 241 | font_size, 242 | text_rotation_angle, 243 | resolution, 244 | output_image_resolution, 245 | out_image_path, 246 | out_coordinates_path, 247 | compression_level, 248 | resolution_before_rotation, 249 | ) 250 | 251 | 252 | def generate_samples( 253 | n_samples: int, 254 | device: str, 255 | output_image_resolution: Tuple[int, int], 256 | compression_level: int, 257 | root_dir: pth, 258 | output_dir: pth, 259 | config_dir: pth, 260 | image_dir: pth, 261 | random_font_iter, 262 | random_hdri_iter, 263 | random_material_iter, 264 | shuffled_dataset_iter, 265 | multiprocessing: bool = True, 266 | ): 267 | root_dir = pth(root_dir) 268 | output_dir = pth(output_dir) 269 | config_dir = pth(config_dir) 270 | image_dir = pth(image_dir) 271 | 272 | text_render_resolution = int(max(output_image_resolution)) 273 | 274 | mkdir(config_dir) 275 | 276 | texts = [] 277 | font_paths = [] 278 | materials = [] 279 | hdris = [] 280 | while len(texts) < n_samples: 281 | try: 282 | text, font_path = get_text_and_font(shuffled_dataset_iter, random_font_iter) 283 | 284 | # use textures from a single material 285 | # or combine textures from different materials for more diversity 286 | material = ( 287 | next(random_material_iter) 288 | # if random() < 0.2 289 | # else { 290 | # k: next(random_material_iter)[k] 291 | # for k in ["albedo", "roughness", "displacement"] 292 | # } 293 | ) 294 | 295 | hdri_path = next(random_hdri_iter) 296 | 297 | texts.append(text) 298 | font_paths.append(font_path) 299 | materials.append(material) 300 | hdris.append(hdri_path) 301 | except Exception as e: 302 | print(e) 303 | 304 | generate_sample_func = partial( 305 | generate_sample, 306 | root_dir=root_dir, 307 | output_dir=output_dir, 308 | image_dir=image_dir, 309 | config_dir=config_dir, 310 | text_render_resolution=text_render_resolution, 311 | output_image_resolution=output_image_resolution, 312 | device=device, 313 | compression_level=compression_level, 314 | ) 315 | 316 | if multiprocessing: 317 | generated_samples = p_umap( 318 | generate_sample_func, 319 | range(len(texts)), 320 | texts, 321 | font_paths, 322 | materials, 323 | hdris, 324 | desc="Generating 2D samples" 325 | ) 326 | else: 327 | generated_samples = t_map( 328 | generate_sample_func, 329 | range(len(texts)), 330 | texts, 331 | font_paths, 332 | materials, 333 | hdris, 334 | desc="Generating 2D samples" 335 | ) 336 | 337 | # remove failed samples 338 | generated_samples = [k for k in generated_samples if k] 339 | 340 | return generated_samples 341 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path as pth 2 | import json 3 | import tempfile 4 | from shutil import rmtree 5 | from typing import Union, Optional 6 | 7 | import fire 8 | 9 | from prepare_data import ( 10 | get_text_dataset, 11 | gdrive_download_and_extract, 12 | mega_download_and_extract, 13 | get_hdris, 14 | get_materials, 15 | get_fonts, 16 | DownloadError 17 | ) 18 | from generate_samples_2d import generate_samples 19 | from blender_render_samples_3d import run_blender_command 20 | from postprocess import postprocess_samples 21 | 22 | 23 | def download_from_grdrive_or_mega( 24 | path: Union[str, pth], 25 | gdrive_file_id: str, 26 | mega_file_id_and_key: str, 27 | desc: Optional[str]=None, 28 | ) -> pth: 29 | path = pth(path) 30 | if desc is None: 31 | desc = str(path.stem) 32 | 33 | print(f"Downloading {desc}...") 34 | 35 | # download from google drive, if failed, download from mega 36 | try: 37 | download_path = gdrive_download_and_extract(path, gdrive_file_id) 38 | if not download_path.is_dir(): 39 | raise DownloadError(f"{desc} from Google Drive") 40 | except Exception as e: 41 | print(e) 42 | print(f"Attempting to download {desc} from Mega.nz") 43 | 44 | download_path = mega_download_and_extract(path, mega_file_id_and_key) 45 | if not download_path.is_dir(): 46 | raise DownloadError(f"{desc} from Mega.nz") 47 | 48 | return download_path 49 | 50 | 51 | def download_and_prepare_data(): 52 | download_path = pth("../assets/").resolve() 53 | if not download_path.is_dir(): 54 | download_path.mkdir(parents=True, exist_ok=True) 55 | 56 | print("Preparing text dataset...") 57 | _, texts_iter = get_text_dataset() 58 | 59 | fonts_path = download_from_grdrive_or_mega( 60 | download_path / "fonts", 61 | "1-K8EE0QsXfxaAV-5uOE6lhTLGicRZbW2", 62 | "itg2TKoJ#UCOEQqX7pPwAf9pguWVUuksX7orWtzK5n4SdI6CQqGc" 63 | ) 64 | 65 | hdris_path = download_from_grdrive_or_mega( 66 | download_path / "hdris", 67 | "1BNCTqw5fenCK-D48-a7VQ234Aq3k45hu", 68 | "D5YWQKgR#fFRHe-HpbCc7-yAm3h5zxbR905o1hkrxKJDvQOsGOKk" 69 | ) 70 | 71 | materials_path = download_from_grdrive_or_mega( 72 | download_path / "materials", 73 | "1-5dz5DMce-braCrhVIsqB58PvcyB6qyy", 74 | "W0hGyCzB#-Gldvyt6uGt9D6iT8kDL-T4CKGNwIKD0Yc4jR2MAgxo" 75 | ) 76 | 77 | fonts_iter = get_fonts(fonts_path) 78 | hdris_iter = get_hdris(hdris_path) 79 | materials_iter = get_materials(materials_path) 80 | 81 | return texts_iter, fonts_iter, hdris_iter, materials_iter 82 | 83 | 84 | def main( 85 | n_samples: int, 86 | blender_path: str, 87 | output_dir: str, 88 | device: str, 89 | resolution_x: int = 512, 90 | resolution_y: int = 512, 91 | compression_level: int = 9, 92 | ): 93 | """ 94 | Runs the main pipeline for rendering handwritten text on a virtual piece of paper using Blender. 95 | It downloads and prepares data for rendering 3D documents using Blender. 96 | It generates samples, renders them using Blender, and post-processes them. 97 | The generated samples are saved in the specified output directory. 98 | 99 | Args: 100 | n_samples (int): The number of sample images to generate. 101 | blender_path (str): The path to the Blender executable. 102 | output_dir (str): The path to the directory where the rendered images will be saved. 103 | device (str): The device to use for rendering ('cpu', 'cuda' or 'optix'). 104 | output_image_resolution (Tuple[int, int], optional): The resolution of the output images. Defaults to (512, 512). 105 | compression_level (int, optional): The png compression level to use when saving the output images. 106 | Must be between 0 and 9, with 0 being no compression and 9 being maximum compression. 107 | Defaults to 9. 108 | 109 | Raises: 110 | ValueError: If the output directory is not empty. 111 | 112 | Returns: 113 | None 114 | """ 115 | 116 | device = device.upper() 117 | 118 | blender_path: pth = pth(blender_path) 119 | output_dir: pth = pth(output_dir) 120 | 121 | if device not in ["CPU", "CUDA", "OPTIX"]: 122 | raise ValueError(f"Invalid device: {device}") 123 | 124 | if not blender_path.is_file(): 125 | raise ValueError(f"Blender path {blender_path} is not a valid file") 126 | 127 | if not output_dir.is_dir(): 128 | output_dir.mkdir(parents=True, exist_ok=True) 129 | 130 | blender_path = blender_path.resolve() 131 | output_dir = output_dir.resolve() 132 | 133 | # Check if the output directory is empty 134 | if output_dir.is_dir() and list(output_dir.iterdir()): 135 | raise ValueError(f"Output directory {output_dir} is not empty") 136 | 137 | resolution = resolution_x, resolution_y 138 | 139 | texts, fonts, hdris, materials = download_and_prepare_data() 140 | 141 | root_dir = pth.cwd() / "Blender_3D_document_rendering_pipeline" 142 | 143 | temp_dir = pth(tempfile.mkdtemp()) 144 | print(f"Saving temporary files to: {temp_dir}") 145 | 146 | config_dir = temp_dir / "configs" 147 | image_dir = temp_dir / "images" 148 | config_dir.mkdir() 149 | image_dir.mkdir() 150 | 151 | print("Generating samples...") 152 | generated_samples = generate_samples( 153 | n_samples=n_samples, 154 | device=device, 155 | output_image_resolution=resolution, 156 | compression_level=compression_level, 157 | root_dir=root_dir, 158 | output_dir=output_dir, 159 | config_dir=config_dir, 160 | image_dir=image_dir, 161 | random_font_iter=fonts, 162 | random_hdri_iter=hdris, 163 | random_material_iter=materials, 164 | shuffled_dataset_iter=texts, 165 | ) 166 | 167 | print("Rendering samples using Blender...") 168 | output_dir.mkdir(parents=True, exist_ok=True) 169 | run_blender_command(blender_path, config_dir, output_dir, device) 170 | 171 | postprocess_samples(generated_samples) 172 | 173 | for k in generated_samples: 174 | output_dict = { 175 | "text": k.text, 176 | "bboxes": k.bounding_boxes, 177 | "font_path": str(k.font_path), 178 | "font_color": k.font_color.tolist(), 179 | "text_rotation_angle": k.text_rotation_angle, 180 | "resolution": k.output_image_resolution, 181 | } 182 | 183 | if k.bounding_boxes is None: 184 | print(f"Failed to calculate bounding boxes for {k.text}! Skipping") 185 | print(json.dumps(output_dict, indent=4)) 186 | sample_output_dir = k.output_image_path.parent 187 | rmtree(sample_output_dir) 188 | continue 189 | 190 | with open(k.output_image_path.with_suffix(".json"), "w") as f: 191 | json.dump(output_dict, f, indent=4) 192 | 193 | print("Cleaning up temporary files...") 194 | rmtree(temp_dir) 195 | 196 | 197 | if __name__ == "__main__": 198 | fire.Fire(main) 199 | -------------------------------------------------------------------------------- /src/polyhaven_download.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import requests 3 | from pathlib import Path as pth 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | polyhaven_api_endpoint = "https://api.polyhaven.com" 9 | 10 | polyhaven_texture_type_map = { 11 | "Diffuse": "albedo", 12 | "Rough": "roughness", 13 | "Displacement": "displacement", 14 | } 15 | 16 | 17 | def download_texture( 18 | name: str, 19 | texture_type: str, 20 | save_directory: Union[str, pth], 21 | resolution: str, 22 | file_extension: str, 23 | ) -> pth: 24 | """ 25 | Downloads a single texture from Polyhaven. 26 | """ 27 | save_directory = pth(save_directory) 28 | save_directory.mkdir(parents=True, exist_ok=True) 29 | 30 | mapped_texture_type = polyhaven_texture_type_map.get(texture_type, texture_type) 31 | 32 | file_path = ( 33 | save_directory / f"{name}_{mapped_texture_type}_{resolution}.{file_extension}" 34 | ) 35 | if file_path.is_file(): 36 | return file_path 37 | 38 | file_url = requests.get(f"{polyhaven_api_endpoint}/files/{name}").json() 39 | file_url = file_url[texture_type][resolution][file_extension]["url"] 40 | 41 | response = requests.get(file_url, allow_redirects=True) 42 | 43 | file_path.write_bytes(response.content) 44 | 45 | return file_path 46 | 47 | 48 | def download_all_hdris( 49 | save_directory: Union[str, pth] = "hdris", 50 | resolution: str = "2k", 51 | file_extension: str = "exr", 52 | ) -> List[pth]: 53 | """ 54 | Downloads all HDRIs from Polyhaven. 55 | """ 56 | names = requests.get("https://api.polyhaven.com/assets?t=hdris").json().keys() 57 | 58 | file_paths = [] 59 | for name in tqdm(names): 60 | try: 61 | file_paths.append( 62 | download_texture( 63 | name, "hdri", save_directory, resolution, file_extension 64 | ) 65 | ) 66 | except Exception as e: 67 | tqdm.write(f"Failed to download {name}: {e}") 68 | 69 | return file_paths 70 | 71 | 72 | def download_all_materials( 73 | save_directory: Union[str, pth] = "materials", 74 | resolution: str = "1k", 75 | file_extension: str = "jpg", 76 | texture_types: List[str] = None, 77 | ) -> List[pth]: 78 | """ 79 | Downloads all materials from Polyhaven. 80 | """ 81 | if texture_types is None: 82 | texture_types = ["Diffuse", "Rough", "Displacement"] 83 | 84 | names = requests.get("https://api.polyhaven.com/assets?t=textures").json().keys() 85 | 86 | file_paths = [] 87 | for name in tqdm(names): 88 | newly_downloaded = [] 89 | try: 90 | newly_downloaded.append( 91 | { 92 | k: download_texture( 93 | name, k, save_directory, resolution, file_extension 94 | ) 95 | for k in texture_types 96 | } 97 | ) 98 | except Exception as e: 99 | tqdm.write(f"Failed to download {name}: {e}") 100 | print(newly_downloaded) 101 | for item in newly_downloaded: 102 | for k in item.values(): 103 | print(f"deleting {k}") 104 | # k.close() 105 | 106 | file_paths += newly_downloaded 107 | 108 | return file_paths 109 | -------------------------------------------------------------------------------- /src/postprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Tuple, Union, List 2 | import math 3 | from random import random 4 | import concurrent.futures 5 | import multiprocessing 6 | from pathlib import Path as pth 7 | 8 | import numpy as np 9 | import cv2 as cv 10 | from p_tqdm import p_map 11 | 12 | from bounding_box import ( 13 | SimpleBoundingBox, 14 | ) 15 | from bounding_box_math import bbox_from_mask 16 | from calculate_bounding_boxes import calculate_bounding_boxes 17 | from generate_samples_2d import SampleInfo 18 | 19 | 20 | def random_gaussian_blur(img: np.ndarray) -> Tuple[np.ndarray, int]: 21 | radius = math.floor((random() ** 3) * 2) 22 | if radius > 0: 23 | img = cv.GaussianBlur(img, (radius * 2 + 1,) * 2, 0) 24 | return img, radius 25 | 26 | 27 | def make_grid(images: Iterable[np.ndarray]) -> np.ndarray: 28 | resolution = images[0].shape[:2] 29 | n_images = len(images) 30 | grid_size = n_images**0.5 31 | grid_size = math.ceil(grid_size) 32 | 33 | img = np.full( 34 | (resolution[0] * grid_size, resolution[1] * grid_size, 3), 0, dtype=np.uint8 35 | ) 36 | 37 | for x in range(grid_size): 38 | for y in range(grid_size): 39 | i = x * grid_size + y 40 | if i >= n_images: 41 | break 42 | img[ 43 | x * resolution[0] : (x + 1) * resolution[0], 44 | y * resolution[1] : (y + 1) * resolution[1], 45 | ] = images[i] 46 | else: 47 | continue 48 | break 49 | 50 | return img 51 | 52 | 53 | def calculate_mask(document_img: np.ndarray): 54 | mask = 1 - document_img.astype(np.float32) / 255 55 | mask = 1 - mask / np.max(mask, axis=(0, 1))[None, None] 56 | mask = np.clip(mask[:, :, 0], 0, 1) 57 | return (mask * 255).astype(np.uint8) 58 | 59 | 60 | def apply_random_gaussian_blur_to_sample( 61 | image_path: Union[str, pth], compression_level: int 62 | ): 63 | img = cv.imread(str(image_path)) 64 | img, gaussian_blur_radius = random_gaussian_blur(img) 65 | cv.imwrite( 66 | str(image_path), 67 | img, 68 | [cv.IMWRITE_PNG_COMPRESSION, compression_level], 69 | ) 70 | return gaussian_blur_radius 71 | 72 | 73 | def calculate_simple_bbox( 74 | mask_warped: np.ndarray, coords_relative: np.ndarray 75 | ) -> Tuple[SimpleBoundingBox, SimpleBoundingBox]: 76 | bbox = bbox_from_mask(mask_warped, 255) 77 | bbox_relative_xxyy = bbox.relative(coords_relative.shape[:2][::-1]).xxyy() 78 | 79 | return bbox, bbox_relative_xxyy 80 | 81 | 82 | def postprocess_samples(samples: List[SampleInfo]): 83 | bounding_boxes_list = p_map( 84 | calculate_bounding_boxes, samples, desc="Calculating bounding boxes" 85 | ) 86 | for sample, bounding_boxes in zip(samples, bounding_boxes_list): 87 | sample.bounding_boxes = bounding_boxes[0] 88 | -------------------------------------------------------------------------------- /src/prepare_data.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from random import randint 3 | from pathlib import Path as pth 4 | import zipfile 5 | import itertools 6 | 7 | import gdown 8 | from mega import Mega 9 | 10 | from datasets import load_dataset 11 | from shuffle_iter import ShuffleIterator, DatasetShuffleIterator 12 | 13 | 14 | zip_path = pth("../assets/") 15 | mega = Mega() 16 | m = mega.login() 17 | 18 | 19 | class DownloadError(IOError): 20 | def __init__(self, *args): 21 | super().__init__(*args) 22 | 23 | def __str__(self): 24 | return f"Failed to download {super().__str__()}" 25 | 26 | 27 | class MultipleDatasetShuffleIterator: 28 | def __init__(self, datasets: list): 29 | self.dataset_shuffle_iterators = [DatasetShuffleIterator(k) for k in datasets] 30 | 31 | def __next__(self): 32 | return next( 33 | self.dataset_shuffle_iterators[ 34 | randint(0, len(self.dataset_shuffle_iterators) - 1) 35 | ] 36 | ) 37 | 38 | 39 | def get_text_dataset(): 40 | datasets = [ 41 | load_dataset(k, split="train", streaming=True) 42 | for k in ( 43 | "ChristophSchuhmann/wikipedia-en-nov22-1-sentence-level", 44 | "ChristophSchuhmann/1-sentence-level-gutenberg-en_arxiv_pubmed_soda", 45 | ) 46 | ] 47 | shuffled_dataset_iters = MultipleDatasetShuffleIterator(datasets) 48 | 49 | return datasets, shuffled_dataset_iters 50 | 51 | 52 | def _download_and_extract(download_fn, path: Union[str, pth]) -> pth: 53 | path = pth(path) 54 | zip_path = path.with_suffix(".zip") 55 | 56 | if not zip_path.is_file(): 57 | download_fn(zip_path) 58 | 59 | if not path.is_dir(): 60 | with zipfile.ZipFile(zip_path, "r") as zip_ref: 61 | zip_ref.extractall(path) 62 | 63 | return path 64 | 65 | def gdrive_download_and_extract(path: Union[str, pth], file_id: str) -> pth: 66 | def download_fn(zip_path: pth) -> None: 67 | gdown.download(id=file_id, output=str(zip_path.resolve()), quiet=False) 68 | 69 | return _download_and_extract(download_fn, path) 70 | 71 | def mega_download_and_extract(path: Union[str, pth], file_id_and_key: str) -> pth: 72 | def download_fn(zip_path: pth) -> None: 73 | m.download_url(f"https://mega.nz/file/{file_id_and_key}", str(zip_path.parent.resolve()), str(zip_path.name)) 74 | 75 | return _download_and_extract(download_fn, path) 76 | 77 | 78 | def get_hdris(hdris_dir: Union[str, pth]) -> ShuffleIterator: 79 | hdris_dir = pth(hdris_dir) 80 | 81 | hdris = list(hdris_dir.glob("*.exr")) 82 | return ShuffleIterator(hdris) 83 | 84 | 85 | def get_materials(materials_dir: Union[str, pth]) -> ShuffleIterator: 86 | materials_dir = pth(materials_dir) 87 | 88 | texture_types = ["albedo", "roughness", "displacement"] 89 | 90 | materials = {} 91 | for file in materials_dir.glob("*.jpg"): 92 | for k in texture_types: 93 | if k not in file.name: 94 | continue 95 | name = file.name.split(k, 1)[0] 96 | material = materials.get(name, dict.fromkeys(texture_types)) 97 | material[k] = file 98 | materials[name] = material 99 | 100 | # remove materials with missing textures 101 | materials = {k: v for k, v in materials.items() if None not in v.values()} 102 | 103 | return ShuffleIterator(list(materials.values())) 104 | 105 | 106 | def get_fonts(fonts_dir: Union[str, pth]) -> ShuffleIterator: 107 | fonts_dir = pth(fonts_dir) 108 | 109 | fonts = list( 110 | itertools.chain.from_iterable( 111 | (fonts_dir / "fontcollection").glob(f"*.{k}") for k in ("ttf", "otf") 112 | ) 113 | ) 114 | return ShuffleIterator(fonts) 115 | -------------------------------------------------------------------------------- /src/shuffle_iter.py: -------------------------------------------------------------------------------- 1 | from random import shuffle 2 | 3 | 4 | class DatasetShuffleIterator: 5 | def __init__(self, dataset, **kwargs): 6 | self.dataset = dataset 7 | self.shuffled_dataset = dataset 8 | self.iterator = iter([]) 9 | self.kwargs = kwargs 10 | 11 | self.shuffle() 12 | 13 | def shuffle(self, **kwargs): 14 | self.shuffled_dataset = self.shuffled_dataset.shuffle(**kwargs) 15 | self.iterator = iter(self.shuffled_dataset) 16 | 17 | def __next__(self): 18 | try: 19 | return next(self.iterator) 20 | except StopIteration: 21 | self.shuffle(**self.kwargs) 22 | return next(self.iterator) 23 | 24 | 25 | class ShuffleIterator: 26 | def __init__(self, lst: list): 27 | self.index = 0 28 | self.lst = lst.copy() 29 | shuffle(self.lst) 30 | 31 | def __next__(self): 32 | if self.index >= len(self.lst): 33 | shuffle(self.lst) 34 | self.index = 0 35 | item = self.lst[self.index] 36 | self.index += 1 37 | return item 38 | -------------------------------------------------------------------------------- /src/tests/test_assets/SilentReaction.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GbotHQ/ocr-dataset-rendering/f38c3698daff026a697cf47bcc087a0401eb6dfc/src/tests/test_assets/SilentReaction.ttf -------------------------------------------------------------------------------- /src/tests/test_assets/coordinates0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GbotHQ/ocr-dataset-rendering/f38c3698daff026a697cf47bcc087a0401eb6dfc/src/tests/test_assets/coordinates0001.png -------------------------------------------------------------------------------- /src/tests/test_assets/image0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GbotHQ/ocr-dataset-rendering/f38c3698daff026a697cf47bcc087a0401eb6dfc/src/tests/test_assets/image0001.png -------------------------------------------------------------------------------- /src/tests/test_assets/test_alpha_font_rendering_generate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GbotHQ/ocr-dataset-rendering/f38c3698daff026a697cf47bcc087a0401eb6dfc/src/tests/test_assets/test_alpha_font_rendering_generate.png -------------------------------------------------------------------------------- /src/tests/test_assets/test_img_font_rendering_generate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GbotHQ/ocr-dataset-rendering/f38c3698daff026a697cf47bcc087a0401eb6dfc/src/tests/test_assets/test_img_font_rendering_generate.png -------------------------------------------------------------------------------- /src/tests/test_font_rendering.py: -------------------------------------------------------------------------------- 1 | from random import seed 2 | 3 | import numpy as np 4 | import cv2 as cv 5 | from PIL import ImageFont 6 | 7 | from font_rendering import generate 8 | from bounding_box_math import calculate_char_bboxes 9 | 10 | 11 | class TestFontRenderer: 12 | def setup_class(self): 13 | self.font = ImageFont.truetype("tests/test_assets/SilentReaction.ttf", 42) 14 | self.input_text = "The quick brown fox" 15 | 16 | self.img = cv.imread("tests/test_assets/test_img_font_rendering_generate.png") 17 | self.alpha = cv.imread( 18 | "tests/test_assets/test_alpha_font_rendering_generate.png" 19 | ) 20 | 21 | def generate_render_text_test_data(self): 22 | seed(0) 23 | 24 | ( 25 | text, 26 | img, 27 | font_size, 28 | xy, 29 | line_offsets, 30 | padding, 31 | font_color, 32 | text_rotation_angle, 33 | resolution_before_rotation, 34 | ) = generate(self.input_text, self.font) 35 | 36 | compression = [cv.IMWRITE_PNG_COMPRESSION, 7] 37 | cv.imwrite( 38 | "tests/test_assets/test_img_font_rendering_generate.png", img, compression 39 | ) 40 | test_dict = { 41 | "text": text, 42 | "font_size": font_size, 43 | "xy": xy, 44 | "line_offsets": line_offsets, 45 | "padding": padding, 46 | "font_color": font_color, 47 | "text_rotation_angle": text_rotation_angle, 48 | "resolution_before_rotation": resolution_before_rotation, 49 | } 50 | 51 | print(test_dict) 52 | 53 | def test_character_bounding_boxes(self): 54 | xy = [6, 21] 55 | bboxes = calculate_char_bboxes(xy, self.input_text, self.font) 56 | assert list(bboxes[0].keys()) == ["bbox", "mask", "char_index", "line_index"] 57 | 58 | def test_render_text(self): 59 | """Test if text rendering produces expected results""" 60 | seed(0) 61 | 62 | ( 63 | text, 64 | img, 65 | font_size, 66 | xy, 67 | line_offsets, 68 | padding, 69 | font_color, 70 | text_rotation_angle, 71 | resolution_before_rotation, 72 | ) = generate(self.input_text, self.font, 512) 73 | 74 | np.testing.assert_array_equal(img, self.img) 75 | 76 | assert text == "The quick brown\nfox" 77 | assert font_size == 82 78 | assert xy == [6, 21] 79 | assert line_offsets == [0, 208] 80 | assert padding == [62, 51, 38, 61] 81 | assert font_color.tolist() == [154, 77, 148] 82 | assert text_rotation_angle == -12 83 | assert resolution_before_rotation == (508, 192) 84 | --------------------------------------------------------------------------------