├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples ├── data │ ├── Adobe Scan 23 Jan 2024.pdf │ ├── G38J5.jpg │ ├── download.png │ └── sideways.jpg └── examples.py ├── py_reform ├── __init__.py ├── core.py ├── models │ ├── __init__.py │ ├── base.py │ ├── deskew_model.py │ ├── uvdoc_model.py │ └── weights │ │ └── best_model.pkl └── utils │ ├── __init__.py │ ├── comparison.py │ ├── image.py │ └── pdf.py ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests └── test_basic.py /.gitignore: -------------------------------------------------------------------------------- 1 | # outputs 2 | examples/output/ 3 | examples/output/*.jpg 4 | examples/output/*.pdf 5 | examples/output/*.png 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Jupyter Notebook 55 | .ipynb_checkpoints 56 | 57 | # Environments 58 | .env 59 | .venv 60 | env/ 61 | venv/ 62 | ENV/ 63 | env.bak/ 64 | venv.bak/ 65 | 66 | # IDE specific files 67 | .idea/ 68 | .vscode/ 69 | *.swp 70 | *.swo 71 | 72 | # OS specific files 73 | .DS_Store 74 | Thumbs.db -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Jonathan Soma 4 | 5 | Permission is hereby granted, 
free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include requirements.txt 4 | recursive-include py_reform/models/weights *.pkl 5 | recursive-include examples *.py *.jpg *.pdf 6 | recursive-exclude examples/output * 7 | recursive-include tests *.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # py-reform: PDF & Image Dewarping Library 2 | 3 | A Python library for dewarping/straightening/reformatting document images and PDFs. 
4 | 5 | ![An example](examples/comparison.jpg) 6 | 7 | ## Features 8 | 9 | - Dewarp/straighten single images 10 | - Process entire PDFs or selected pages 11 | - Return PIL images for further processing 12 | - Save results as images or PDFs 13 | - Progress tracking with tqdm 14 | - Flexible error handling 15 | - Automatic EXIF orientation handling 16 | - Multiple dewarping models 17 | 18 | ## Installation 19 | 20 | ```bash 21 | pip install py-reform 22 | ``` 23 | 24 | ## Quick Start 25 | 26 | ### Process a Single Image 27 | 28 | ```python 29 | from py_reform import straighten 30 | 31 | # Process a single image 32 | straight_image = straighten("curved_page.jpg") 33 | straight_image.save("straight_page.jpg") 34 | ``` 35 | 36 | ### Process a PDF 37 | 38 | ```python 39 | from py_reform import straighten, save_pdf 40 | 41 | # Process a PDF (all pages) 42 | straight_pages = straighten("document.pdf") 43 | 44 | # Save processed pages as a new PDF 45 | save_pdf(straight_pages, "straight_document.pdf") 46 | ``` 47 | 48 | ### Process Specific PDF Pages 49 | 50 | ```python 51 | # Process specific PDF pages 52 | straight_pages = straighten("document.pdf", pages=[0, 2, 5]) 53 | ``` 54 | 55 | ### Choose a Different Dewarping Model 56 | 57 | By default we use [UVDoc](https://github.com/tanguymagne/UVDoc), which works for all sorts of problematic images. If you just need to rotate the image, though, use [deskew](https://github.com/sbrunner/deskew) instead. 
58 | 59 | ```python 60 | # Use the rotation-based deskew model 61 | straight_image = straighten("document.jpg", model="deskew") 62 | 63 | # Use the UVDoc model with custom parameters 64 | straight_image = straighten("document.jpg", model="uvdoc", device="cpu") 65 | 66 | # Configure deskew model parameters 67 | straight_image = straighten("document.jpg", model="deskew", max_angle=15.0, num_peaks=30) 68 | ``` 69 | 70 | ### Create Before/After Comparisons 71 | 72 | ```python 73 | from py_reform.utils import create_comparison 74 | 75 | straight_image = straighten("curved_page.jpg") 76 | 77 | # Create a side-by-side comparison 78 | comparison = create_comparison(["curved_page.jpg", straight_image]) 79 | comparison.save("comparison.jpg") 80 | ``` 81 | 82 | ### Error Handling 83 | 84 | ```python 85 | # Default: stop on error 86 | result = straighten("document.pdf", errors="raise") 87 | # Skip errors, log warning 88 | result = straighten("document.pdf", errors="ignore") 89 | # Use original on error with warning 90 | result = straighten("document.pdf", errors="warn") 91 | ``` 92 | 93 | ### Working with Image Orientation 94 | 95 | The library automatically handles EXIF orientation data in JPEG files, ensuring that images are correctly oriented before processing. 
You can also use these utilities directly: 96 | 97 | ```python 98 | from py_reform.utils import open_image, auto_rotate_image 99 | import PIL.Image 100 | 101 | # Open an image with automatic orientation correction 102 | img = open_image("photo.jpg") 103 | 104 | # Or correct orientation of an already opened image 105 | img = PIL.Image.open("photo.jpg") 106 | img = auto_rotate_image(img) 107 | ``` 108 | 109 | ## Available Models 110 | 111 | - [UVDoc](https://github.com/tanguymagne/UVDoc/) 112 | - [deskew](https://github.com/sbrunner/deskew) 113 | 114 | ## Examples 115 | 116 | See [examples/examples.py](examples/examples.py) 117 | 118 | ## Citation 119 | 120 | The UVDoc model is based on original work by Floor Verhoeven, Tanguy Magne, and Olga Sorkine-Hornung. If you use py-reform with the UVDoc model, please consider citing their work: 121 | 122 | ```bibtex 123 | @inproceedings{UVDoc, 124 | title={{UVDoc}: Neural Grid-based Document Unwarping}, 125 | author={Floor Verhoeven and Tanguy Magne and Olga Sorkine-Hornung}, 126 | booktitle = {SIGGRAPH ASIA, Technical Papers}, 127 | year = {2023}, 128 | url={https://doi.org/10.1145/3610548.3618174} 129 | } 130 | ``` 131 | 132 | Original UVDoc repository: [https://github.com/tanguymagne/UVDoc/](https://github.com/tanguymagne/UVDoc/) 133 | 134 | ## Anything else?? 135 | 136 | I'm pretty sure I wrote about *two lines of code for this*, the rest was all [Cursor](https://www.cursor.com/en) and [Claude 3.7 Sonnet](https://claude.ai/). My job was mostly making demands around pathlib and ditching OpenCV. 
-------------------------------------------------------------------------------- /examples/data/Adobe Scan 23 Jan 2024.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsoma/py-reform/da6e0ac5a71ca4fad3008390111194010ed4585c/examples/data/Adobe Scan 23 Jan 2024.pdf -------------------------------------------------------------------------------- /examples/data/G38J5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsoma/py-reform/da6e0ac5a71ca4fad3008390111194010ed4585c/examples/data/G38J5.jpg -------------------------------------------------------------------------------- /examples/data/download.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsoma/py-reform/da6e0ac5a71ca4fad3008390111194010ed4585c/examples/data/download.png -------------------------------------------------------------------------------- /examples/data/sideways.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsoma/py-reform/da6e0ac5a71ca4fad3008390111194010ed4585c/examples/data/sideways.jpg -------------------------------------------------------------------------------- /examples/examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple examples for using py-reform to dewarp document images and PDFs. 
3 | """ 4 | 5 | import sys 6 | import shutil 7 | from pathlib import Path 8 | 9 | # Add the parent directory to the path so we can import py_reform 10 | sys.path.insert(0, str(Path(__file__).parent.parent)) 11 | 12 | from py_reform import straighten, save_pdf 13 | from py_reform.utils import create_comparison, pdf_to_images 14 | import PIL.Image 15 | 16 | # Clean up and recreate output directory 17 | output_dir = Path("examples/output") 18 | if output_dir.exists(): 19 | # Remove all files in the output directory 20 | for file_path in output_dir.glob("*"): 21 | if file_path.is_file(): 22 | file_path.unlink() 23 | elif file_path.is_dir(): 24 | shutil.rmtree(file_path) 25 | print(f"Cleaned up existing output directory: {output_dir}") 26 | else: 27 | # Create output directory if it doesn't exist 28 | output_dir.mkdir(parents=True) 29 | print(f"Created output directory: {output_dir}") 30 | 31 | # ===================================================================== 32 | # Process a single image with the default model 33 | # ===================================================================== 34 | 35 | # Process an image and save the result 36 | image = straighten("examples/data/G38J5.jpg") 37 | image.save("examples/output/straightened.jpg") 38 | 39 | # Create a side-by-side comparison 40 | comparison = create_comparison( 41 | ["examples/data/G38J5.jpg", image], 42 | spacing=10 # Add 10px spacing between images 43 | ) 44 | comparison.save("examples/output/comparison.jpg") 45 | 46 | # Process an image with rotated EXIF data 47 | image = straighten("examples/data/sideways.jpg") 48 | image.save("examples/output/sideways_straightened.jpg") 49 | 50 | # ===================================================================== 51 | # Process an image with the deskew model 52 | # ===================================================================== 53 | 54 | # Process an image with the Deskew model 55 | deskew_image = straighten("examples/data/G38J5.jpg", model="deskew") 56 | 
deskew_image.save("examples/output/deskew_straightened.jpg") 57 | 58 | # ===================================================================== 59 | # Process an entire PDF 60 | # ===================================================================== 61 | 62 | # Process all pages in a PDF 63 | pages = straighten("examples/data/Adobe Scan 23 Jan 2024.pdf") 64 | save_pdf(pages, "examples/output/straightened.pdf") 65 | 66 | # Create comparisons for the first two pages 67 | original_pages = pdf_to_images("examples/data/Adobe Scan 23 Jan 2024.pdf", pages=[0, 1]) 68 | 69 | # Save the comparison images 70 | for i, (page, original) in enumerate(zip(pages, original_pages)): 71 | comparison = create_comparison([original, page]) 72 | comparison.save(f"examples/output/pdf_comparison_{i}.jpg") 73 | 74 | # ===================================================================== 75 | # Process specific PDF pages 76 | # ===================================================================== 77 | 78 | # Process only pages 0 and 1 (first and second pages) 79 | pages = straighten("examples/data/Adobe Scan 23 Jan 2024.pdf", pages=[0, 1]) 80 | save_pdf(pages, "examples/output/pages_0_1.pdf") 81 | 82 | # ===================================================================== 83 | # Process PDF with the Deskew model 84 | # ===================================================================== 85 | 86 | # Process a PDF with the Deskew model 87 | deskew_pages = straighten("examples/data/Adobe Scan 23 Jan 2024.pdf", model="deskew") 88 | save_pdf(deskew_pages, "examples/output/deskew_straightened.pdf") 89 | 90 | # ===================================================================== 91 | # Save PDF pages as individual images 92 | # ===================================================================== 93 | 94 | # Process a PDF and save each page as an individual image 95 | pages = straighten("examples/data/Adobe Scan 23 Jan 2024.pdf", pages=[0, 1]) 96 | for i, page in enumerate(pages): 97 | 
page.save(f"examples/output/page_{i}.jpg", "JPEG", quality=95) 98 | 99 | # ===================================================================== 100 | # NEW: Multi-image comparison examples 101 | # ===================================================================== 102 | 103 | # Process an image with different models 104 | original = PIL.Image.open("examples/data/G38J5.jpg") 105 | uvdoc_result = straighten("examples/data/G38J5.jpg", model="uvdoc") 106 | deskew_result = straighten("examples/data/G38J5.jpg", model="deskew") 107 | 108 | # Create a comparison with multiple images in a row 109 | multi_comparison = create_comparison( 110 | images=[original, uvdoc_result, deskew_result], 111 | labels=["Original", "UVDoc Model", "Deskew Model"], 112 | spacing=15, 113 | ) 114 | multi_comparison.save("examples/output/multi_model_comparison.jpg") 115 | print(f"Created multi-model comparison: examples/output/multi_model_comparison.jpg") 116 | -------------------------------------------------------------------------------- /py_reform/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | py_reform: A Python library for dewarping/straightening document images and PDFs 3 | """ 4 | 5 | from py_reform.core import straighten 6 | from py_reform.utils.pdf import save_pdf 7 | from py_reform.utils.image import auto_rotate_image, open_image 8 | 9 | __version__ = "0.1.3" 10 | -------------------------------------------------------------------------------- /py_reform/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core functionality for the py-reform library. 
# Error-handling modes accepted by straighten(); checked up front so that a
# typo is reported immediately instead of only after a page fails.
_ERROR_MODES = ("raise", "ignore", "warn")


def straighten(
    source: Union[str, Path, PIL.Image.Image],
    pages: Optional[List[int]] = None,
    model: str = "uvdoc",
    errors: Literal["raise", "ignore", "warn"] = "raise",
    **model_params,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
    """
    Dewarp/straighten document images or PDF pages.

    Args:
        source: Path to image or PDF file, or PIL Image object
        pages: List of page indices to process (used for PDFs only; it is
            not consulted for single images)
        model: Name of the dewarping model to use
        errors: How to handle an error raised while an image is processed:
            "raise" re-raises it, "warn" logs a warning and keeps the
            unprocessed image, "ignore" logs a warning and drops the image
        **model_params: Additional parameters to pass to the model

    Returns:
        For a PDF source, a list of PIL Images (pages dropped by
        errors="ignore" are simply missing from the list). For a single
        image, the processed PIL Image, or an empty list when
        errors="ignore" and processing failed.

    Raises:
        ValueError: If ``errors`` is not one of "raise", "ignore", "warn".
    """
    if errors not in _ERROR_MODES:
        raise ValueError(f"Invalid error mode: {errors}")

    # Get the appropriate model
    dewarping_model = get_model(model, **model_params)

    # An already-loaded image: fix its EXIF orientation, then process it.
    if isinstance(source, PIL.Image.Image):
        return _process_single(dewarping_model, auto_rotate_image(source), errors)

    # Path() accepts both str and Path, so no type switch is needed.
    source_path = Path(source)

    if source_path.suffix.lower() == ".pdf":
        processed_images = []
        page_images = pdf_to_images(source_path, pages=pages)
        for img in tqdm(page_images, desc="Processing pages"):
            # Auto-rotate the image if it has orientation data
            img = auto_rotate_image(img)
            try:
                processed_images.append(dewarping_model.process(img))
            except Exception as e:
                fallback = _handle_error(e, img, errors)
                if fallback is not None:
                    processed_images.append(fallback)
        return processed_images

    # Single image file. Loading is kept separate from processing so that a
    # processing failure can fall back on the image that is already open
    # instead of reading the file from disk a second time.
    try:
        img = open_image(source_path)
    except Exception as e:
        # The orientation-aware loader failed; a plain open may still give
        # the error handler an image to fall back on. If this open fails as
        # well, the file is unreadable and the exception propagates in every
        # error mode.
        original_img = PIL.Image.open(source_path)
        result = _handle_error(e, original_img, errors)
        return [] if result is None else result

    return _process_single(dewarping_model, img, errors)


def _process_single(dewarping_model, image: PIL.Image.Image, errors):
    """Run the model on one image and apply the error policy on failure."""
    try:
        return dewarping_model.process(image)
    except Exception as e:
        result = _handle_error(e, image, errors)
        # "ignore" yields None; single-image callers have always received an
        # empty list in that case, so keep that contract.
        return [] if result is None else result


def _handle_error(
    error: Exception,
    original_image: PIL.Image.Image,
    error_mode: Literal["raise", "ignore", "warn"],
) -> Optional[PIL.Image.Image]:
    """
    Handle errors based on the specified error mode.

    Args:
        error: The exception raised while processing
        original_image: The unprocessed image, returned in "warn" mode
        error_mode: One of "raise", "warn" or "ignore"

    Returns:
        ``original_image`` in "warn" mode, ``None`` in "ignore" mode

    Raises:
        Exception: ``error`` itself, in "raise" mode
        ValueError: If ``error_mode`` is not a known mode
    """
    if error_mode == "raise":
        raise error
    if error_mode == "warn":
        logger.warning("Error during processing: %s. Using original image.", error)
        return original_image
    if error_mode == "ignore":
        logger.warning("Error during processing: %s. Skipping image.", error)
        return None
    raise ValueError(f"Invalid error mode: {error_mode}")
class DeskewModel(DewarpingModel):
    """
    Deskew model for document straightening.

    This model uses the deskew package to detect the skew angle of a document
    image and rotates the image by that angle. It's a simpler alternative to
    the UVDoc model for basic document straightening.
    """

    def __init__(self, max_angle: float = 45.0, num_peaks: int = 20, **kwargs):
        """
        Initialize the Deskew model.

        Args:
            max_angle: Maximum angle to consider for skew detection (in degrees)
            num_peaks: Number of peaks to consider in the Hough transform
            **kwargs: Additional parameters (ignored)

        Raises:
            ImportError: If the deskew package could not be imported
        """
        if not DESKEW_AVAILABLE:
            raise ImportError(
                "The deskew package is required for the Deskew model. "
                "Install it with 'pip install deskew'."
            )

        self.max_angle = max_angle
        self.num_peaks = num_peaks
        logger.info(
            "Initialized Deskew model with max_angle=%s, num_peaks=%s",
            max_angle,
            num_peaks,
        )

    def process(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """
        Process an image to straighten it using the Deskew algorithm.

        Args:
            image: The input image to straighten

        Returns:
            The straightened image, or ``image`` itself when no angle was
            detected or the detected angle is below 0.1 degrees
        """
        grayscale = self._to_grayscale(image)

        # Determine skew angle
        angle = determine_skew(
            grayscale, max_angle=self.max_angle, num_peaks=self.num_peaks
        )

        if angle is None or abs(angle) < 0.1:
            logger.info("No significant skew detected")
            return image

        logger.info("Detected skew angle: %.2f degrees", angle)

        # Rotate with PIL to avoid a scikit-image dependency. expand=True
        # grows the canvas so no content is cropped; the new corner areas are
        # filled with white.
        return image.rotate(
            angle,
            resample=PIL.Image.Resampling.BILINEAR,
            expand=True,
            fillcolor=self._white_fill(image),
        )

    @staticmethod
    def _to_grayscale(image: PIL.Image.Image) -> np.ndarray:
        """Return a 2-D array for skew detection, whatever the image mode."""
        img_np = np.array(image)
        if img_np.ndim == 3 and img_np.shape[2] >= 3:
            # Use simple averaging of the first three bands
            return np.mean(img_np[:, :, :3], axis=2).astype(np.uint8)
        if img_np.ndim == 3:
            # Fewer than three bands (e.g. luminance + alpha): use the first
            # band so the detector still receives a 2-D array.
            return img_np[:, :, 0]
        return img_np

    @staticmethod
    def _white_fill(image: PIL.Image.Image):
        """Return a white fill value with one entry per band of ``image``."""
        if image.mode == "CMYK":
            # CMYK is subtractive: zero ink on every channel is white.
            return (0, 0, 0, 0)
        band_count = len(image.getbands())
        if band_count == 1:
            return 255
        # A bare integer on a multi-band image is read as packed ink rather
        # than applied to every band, so spell out each band (alpha included,
        # which makes the fill opaque).
        return (255,) * band_count
def conv3x3(in_channels, out_channels, kernel_size, stride=1):
    """Build a 2D convolution padded by half its kernel size.

    Despite the name, the kernel is not fixed at 3x3: ``kernel_size`` is
    passed straight through to ``nn.Conv2d`` and the padding is
    ``kernel_size // 2``.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolution kernel
        stride: Convolution stride

    Returns:
        The configured ``nn.Conv2d`` layer
    """
    half_kernel = kernel_size // 2
    conv_layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
    )
    return conv_layer
class ResnetStraight(nn.Module):
    """ResNet backbone for UVDoc: three sequential stages of residual blocks.

    Only the first three entries of ``map_num``, ``block_nums`` and
    ``stride`` are read, because only ``layer1`` to ``layer3`` are built.

    Args:
        num_filter: Base channel count; stage ``i`` outputs
            ``num_filter * map_num[i]`` channels
        map_num: Per-stage channel multipliers
        BatchNorm: Normalisation layer class, called with a channel count
        block_nums: Number of residual blocks in each stage
        block: Residual block class used to build the stages
        kernel_size: Kernel size handed to every block
        stride: Stride of the first block of each stage
    """

    def __init__(
        self,
        num_filter,
        map_num,
        BatchNorm,
        block_nums=(3, 4, 6, 3),  # tuples: defaults must not be shared mutable lists
        block=ResidualBlockWithDilation,
        kernel_size=5,
        stride=(1, 1, 2, 2),
    ):
        super(ResnetStraight, self).__init__()
        # Channel count of the tensor entering the next stage; blocklayer()
        # advances it after building each stage.
        self.in_channels = num_filter * map_num[0]
        self.stride = stride
        self.relu = nn.ReLU(inplace=True)
        self.block_nums = block_nums
        self.kernel_size = kernel_size

        # The attribute names layer1..layer3 become part of the parameter
        # names of this module, so they are kept exactly as they were.
        self.layer1 = self.blocklayer(
            block,
            num_filter * map_num[0],
            self.block_nums[0],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[0],
        )
        self.layer2 = self.blocklayer(
            block,
            num_filter * map_num[1],
            self.block_nums[1],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[1],
        )
        self.layer3 = self.blocklayer(
            block,
            num_filter * map_num[2],
            self.block_nums[2],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[2],
        )

    def blocklayer(
        self, block, out_channels, block_nums, BatchNorm, kernel_size, stride=1
    ):
        """Build one stage of ``block_nums`` residual blocks.

        The first block maps ``self.in_channels`` to ``out_channels`` with the
        given stride; the remaining blocks keep the channel count and use the
        block's default stride. Updates ``self.in_channels`` to
        ``out_channels`` as a side effect.

        Returns:
            An ``nn.Sequential`` holding the blocks of the stage
        """
        # The skip connection needs its own projection whenever the first
        # block changes the resolution or the channel count.
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(
                    self.in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                ),
                BatchNorm(out_channels),
            )

        layers = [
            block(
                self.in_channels,
                out_channels,
                BatchNorm,
                kernel_size,
                stride,
                downsample,
                is_top=True,
            )
        ]
        self.in_channels = out_channels
        layers.extend(
            block(
                out_channels,
                out_channels,
                BatchNorm,
                kernel_size,
                is_activation=True,
                is_top=False,
            )
            for _ in range(1, block_nums)
        )

        return nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the three stages in order and return the result."""
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        return out3
| BatchNorm = nn.BatchNorm2d 240 | act_fn = nn.ReLU(inplace=True) 241 | map_num = [1, 2, 4, 8, 16] 242 | 243 | self.resnet_head = nn.Sequential( 244 | nn.Conv2d( 245 | self.in_channels, 246 | self.num_filter * map_num[0], 247 | bias=False, 248 | kernel_size=self.kernel_size, 249 | stride=2, 250 | padding=self.kernel_size // 2, 251 | ), 252 | BatchNorm(self.num_filter * map_num[0]), 253 | act_fn, 254 | nn.Conv2d( 255 | self.num_filter * map_num[0], 256 | self.num_filter * map_num[0], 257 | bias=False, 258 | kernel_size=self.kernel_size, 259 | stride=2, 260 | padding=self.kernel_size // 2, 261 | ), 262 | BatchNorm(self.num_filter * map_num[0]), 263 | act_fn, 264 | ) 265 | 266 | self.resnet_down = ResnetStraight( 267 | self.num_filter, 268 | map_num, 269 | BatchNorm, 270 | block_nums=[3, 4, 6, 3], 271 | block=ResidualBlockWithDilation, 272 | kernel_size=self.kernel_size, 273 | stride=self.stride, 274 | ) 275 | 276 | map_num_i = 2 277 | self.bridge_1 = nn.Sequential( 278 | dilated_conv_bn_act( 279 | self.num_filter * map_num[map_num_i], 280 | self.num_filter * map_num[map_num_i], 281 | act_fn, 282 | BatchNorm, 283 | dilation=1, 284 | ) 285 | ) 286 | 287 | self.bridge_2 = nn.Sequential( 288 | dilated_conv_bn_act( 289 | self.num_filter * map_num[map_num_i], 290 | self.num_filter * map_num[map_num_i], 291 | act_fn, 292 | BatchNorm, 293 | dilation=2, 294 | ) 295 | ) 296 | 297 | self.bridge_3 = nn.Sequential( 298 | dilated_conv_bn_act( 299 | self.num_filter * map_num[map_num_i], 300 | self.num_filter * map_num[map_num_i], 301 | act_fn, 302 | BatchNorm, 303 | dilation=5, 304 | ) 305 | ) 306 | 307 | self.bridge_4 = nn.Sequential( 308 | *[ 309 | dilated_conv_bn_act( 310 | self.num_filter * map_num[map_num_i], 311 | self.num_filter * map_num[map_num_i], 312 | act_fn, 313 | BatchNorm, 314 | dilation=d, 315 | ) 316 | for d in [8, 3, 2] 317 | ] 318 | ) 319 | 320 | self.bridge_5 = nn.Sequential( 321 | *[ 322 | dilated_conv_bn_act( 323 | self.num_filter * map_num[map_num_i], 324 | 
self.num_filter * map_num[map_num_i], 325 | act_fn, 326 | BatchNorm, 327 | dilation=d, 328 | ) 329 | for d in [12, 7, 4] 330 | ] 331 | ) 332 | 333 | self.bridge_6 = nn.Sequential( 334 | *[ 335 | dilated_conv_bn_act( 336 | self.num_filter * map_num[map_num_i], 337 | self.num_filter * map_num[map_num_i], 338 | act_fn, 339 | BatchNorm, 340 | dilation=d, 341 | ) 342 | for d in [18, 12, 6] 343 | ] 344 | ) 345 | 346 | self.bridge_concat = nn.Sequential( 347 | nn.Conv2d( 348 | self.num_filter * map_num[map_num_i] * 6, 349 | self.num_filter * map_num[2], 350 | bias=False, 351 | kernel_size=1, 352 | stride=1, 353 | padding=0, 354 | ), 355 | BatchNorm(self.num_filter * map_num[2]), 356 | act_fn, 357 | ) 358 | 359 | self.out_point_positions2D = nn.Sequential( 360 | nn.Conv2d( 361 | self.num_filter * map_num[2], 362 | self.num_filter * map_num[0], 363 | bias=False, 364 | kernel_size=self.kernel_size, 365 | stride=1, 366 | padding=self.kernel_size // 2, 367 | padding_mode="reflect", 368 | ), 369 | BatchNorm(self.num_filter * map_num[0]), 370 | nn.PReLU(), 371 | nn.Conv2d( 372 | self.num_filter * map_num[0], 373 | 2, 374 | kernel_size=self.kernel_size, 375 | stride=1, 376 | padding=self.kernel_size // 2, 377 | padding_mode="reflect", 378 | ), 379 | ) 380 | 381 | self.out_point_positions3D = nn.Sequential( 382 | nn.Conv2d( 383 | self.num_filter * map_num[2], 384 | self.num_filter * map_num[0], 385 | bias=False, 386 | kernel_size=self.kernel_size, 387 | stride=1, 388 | padding=self.kernel_size // 2, 389 | padding_mode="reflect", 390 | ), 391 | BatchNorm(self.num_filter * map_num[0]), 392 | nn.PReLU(), 393 | nn.Conv2d( 394 | self.num_filter * map_num[0], 395 | 3, 396 | kernel_size=self.kernel_size, 397 | stride=1, 398 | padding=self.kernel_size // 2, 399 | padding_mode="reflect", 400 | ), 401 | ) 402 | 403 | self._initialize_weights() 404 | 405 | def _initialize_weights(self): 406 | for m in self.modules(): 407 | if isinstance(m, nn.Conv2d): 408 | nn.init.xavier_normal_(m.weight, 
gain=0.2) 409 | if isinstance(m, nn.ConvTranspose2d): 410 | assert m.kernel_size[0] == m.kernel_size[1] 411 | nn.init.xavier_normal_(m.weight, gain=0.2) 412 | 413 | def forward(self, x): 414 | resnet_head = self.resnet_head(x) 415 | resnet_down = self.resnet_down(resnet_head) 416 | bridge_1 = self.bridge_1(resnet_down) 417 | bridge_2 = self.bridge_2(resnet_down) 418 | bridge_3 = self.bridge_3(resnet_down) 419 | bridge_4 = self.bridge_4(resnet_down) 420 | bridge_5 = self.bridge_5(resnet_down) 421 | bridge_6 = self.bridge_6(resnet_down) 422 | bridge_concat = torch.cat( 423 | [bridge_1, bridge_2, bridge_3, bridge_4, bridge_5, bridge_6], dim=1 424 | ) 425 | bridge = self.bridge_concat(bridge_concat) 426 | 427 | out_point_positions2D = self.out_point_positions2D(bridge) 428 | out_point_positions3D = self.out_point_positions3D(bridge) 429 | 430 | return out_point_positions2D, out_point_positions3D 431 | 432 | 433 | def bilinear_unwarping(warped_img, point_positions, img_size): 434 | """ 435 | Utility function that unwarps an image. 436 | Unwarp warped_img based on the 2D grid point_positions with a size img_size. 437 | 438 | Args: 439 | warped_img: torch.Tensor of shape BxCxHxW (dtype float) 440 | point_positions: torch.Tensor of shape Bx2xGhxGw (dtype float) 441 | img_size: tuple of int [w, h] 442 | """ 443 | upsampled_grid = F.interpolate( 444 | point_positions, 445 | size=(img_size[1], img_size[0]), 446 | mode="bilinear", 447 | align_corners=True, 448 | ) 449 | unwarped_img = F.grid_sample( 450 | warped_img, upsampled_grid.transpose(1, 2).transpose(2, 3), align_corners=True 451 | ) 452 | 453 | return unwarped_img 454 | 455 | 456 | class UVDocModel(DewarpingModel): 457 | """ 458 | UVDoc model for document dewarping. 459 | 460 | This model uses a deep learning approach to predict UV maps for document dewarping. 
def __init__(
    self,
    device: Optional[str] = None,
    model_path: Optional[Union[str, Path]] = None,
    img_size: Optional[Tuple[int, int]] = None,
    **kwargs,
):
    """
    Initialize the UVDoc model and load its weights.

    Args:
        device: Device to run the model on ('cpu' or 'cuda'). If None, CUDA is
            used when ``torch.cuda.is_available()`` reports it, else CPU.
        model_path: Path to a pre-trained model file. Defaults to
            ``DEFAULT_MODEL_PATH``.
        img_size: Input size for the model (width, height). Defaults to
            ``DEFAULT_IMG_SIZE``.
        **kwargs: Accepted for interface compatibility; not used here.

    Raises:
        ImportError: If PyTorch is not installed.
        RuntimeError: If the weights cannot be loaded.
    """
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required for the UVDoc model. "
            "Install it with 'pip install torch'."
        )

    if device is None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        self.device = device

    self.model_path = Path(model_path) if model_path is not None else DEFAULT_MODEL_PATH
    self.img_size = img_size or DEFAULT_IMG_SIZE

    logger.info(f"Initializing UVDoc model from {self.model_path} on {self.device}")
    self._load_model()


def _load_model(self):
    """Create the network, load its weights and switch it to eval mode.

    The checkpoint may be a bare state dict or a dict that stores the state
    dict under ``"model_state"`` or ``"state_dict"``.

    Raises:
        RuntimeError: If reading the checkpoint or loading the state dict
            fails; the original exception is chained as the cause.
    """
    device = torch.device(self.device)

    self.model = UVDocNet(num_filter=32, kernel_size=5)

    try:
        checkpoint = torch.load(
            self.model_path, map_location=device, weights_only=True
        )

        # Unwrap the known checkpoint layouts; otherwise assume the
        # checkpoint is the state dict itself.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ("model_state", "state_dict"):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break

        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()
        logger.info(f"Successfully loaded UVDoc model from {self.model_path}")
    except Exception as e:
        logger.error(f"Failed to load UVDoc model: {e}")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to load UVDoc model: {e}") from e


def process(self, image: PIL.Image.Image) -> PIL.Image.Image:
    """
    Process an image to dewarp it using the UVDoc model.

    The network runs on a copy resized to ``self.img_size``; the predicted
    2D grid is then upsampled and applied to the full-resolution image.

    Args:
        image: The input image to dewarp (any PIL mode)

    Returns:
        The dewarped image, in RGB mode and at the input resolution
    """
    device = torch.device(self.device)

    # Decide by mode rather than by array shape, so every non-RGB input
    # (grayscale, palette, LA, RGBA, CMYK, YCbCr, ...) becomes 3-channel RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Full-resolution pixels scaled to [0, 1]; this is what gets resampled.
    img_np = np.array(image).astype(np.float32) / 255.0

    # Network input: the same image resized to the model's input size.
    pil_resized = image.resize(
        (self.img_size[0], self.img_size[1]), PIL.Image.Resampling.BILINEAR
    )
    inp_np = np.array(pil_resized).astype(np.float32) / 255.0
    inp = torch.from_numpy(inp_np.transpose(2, 0, 1)).unsqueeze(0).to(device)

    with torch.no_grad():
        point_positions2D, _ = self.model(inp)

    size = img_np.shape[:2][::-1]  # (width, height)
    img_tensor = torch.from_numpy(img_np.transpose(2, 0, 1)).unsqueeze(0).to(device)

    unwarped = bilinear_unwarping(
        warped_img=img_tensor,
        point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
        img_size=size,
    )

    # Round instead of truncating: float error would otherwise turn e.g.
    # 199.99998 into 199. Clip so the cast to uint8 can never wrap around.
    unwarped_np = unwarped[0].detach().cpu().numpy().transpose(1, 2, 0)
    unwarped_np = np.clip(np.rint(unwarped_np * 255.0), 0, 255).astype(np.uint8)

    return PIL.Image.fromarray(unwarped_np)
def create_comparison(
    images: Union[PIL.Image.Image, List[PIL.Image.Image], str, Path, List[Union[str, Path]]],
    labels: Optional[List[str]] = None,
    orientation: str = "horizontal",
    grid_size: Optional[Tuple[int, int]] = None,
    spacing: int = 10,
    label_height: int = 30,
    background_color: Union[str, Tuple[int, int, int]] = "white",
    text_color: Union[str, Tuple[int, int, int]] = "black",
    resize_mode: str = "fit",
    target_size: Optional[Tuple[int, int]] = None,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
    """
    Create a comparison image showing multiple images side by side or in a grid.

    Args:
        images: Images to compare: a single image (PIL image or path) or a
            list of images / paths.
        labels: Optional list of labels for each image
        orientation: How to arrange the images - "horizontal", "vertical", or "grid"
        grid_size: Optional tuple of (rows, cols) for grid layout. If not
            provided, it is calculated from the number of images.
        spacing: Pixels of space between images and around the border
        label_height: Height in pixels for the label area
        background_color: Color for the background/spacing
        text_color: Color for the label text
        resize_mode: How to handle different image sizes:
            - "fit": Resize every image to the smallest width and height found
            - "stretch": Resize every image to the average width and height
            - "none": Keep original sizes (cells use the largest dimensions)
        target_size: Optional size (width, height) applied to every image

    Returns:
        A single comparison image. A lone input image is returned unchanged.

    Raises:
        ValueError: If no images are given, ``resize_mode`` is invalid, or
            ``grid_size`` is too small for the images.
        TypeError: If ``images`` contains something other than images/paths.
    """
    pil_images = _ensure_image_list(images)
    if not pil_images:
        raise ValueError("No images provided to compare")

    if len(pil_images) == 1:
        return pil_images[0]

    layout = orientation.lower()

    # The plain two-image path draws no labels and does no resizing, so it is
    # only used when none of those options were asked for; otherwise they
    # would be silently dropped.
    no_extras = labels is None and grid_size is None and target_size is None
    if len(pil_images) == 2 and layout in ("horizontal", "vertical") and no_extras:
        return _create_single_comparison(
            pil_images[0], pil_images[1], orientation, spacing, background_color
        )

    if layout == "grid" or grid_size is not None:
        return _create_grid(
            pil_images,
            labels,
            grid_size,
            spacing,
            label_height,
            background_color,
            text_color,
            resize_mode,
            target_size,
        )

    return _create_row_or_column(
        pil_images,
        labels,
        orientation,
        spacing,
        label_height,
        background_color,
        text_color,
        resize_mode,
        target_size,
    )


def _normalize_labels(
    labels: Optional[List[str]], count: int, label_height: int
) -> Tuple[List[str], int]:
    """Pad/truncate ``labels`` to ``count``; no labels means zero label height."""
    if labels is None:
        return [""] * count, 0
    trimmed = list(labels[:count])
    trimmed += [""] * (count - len(trimmed))
    return trimmed, label_height


def _size_cells(
    images: List[PIL.Image.Image],
    resize_mode: str,
    target_size: Optional[Tuple[int, int]],
) -> Tuple[List[PIL.Image.Image], int, int]:
    """Return ``(images, cell_width, cell_height)`` without mutating the input list."""
    keep_sizes = resize_mode == "none"
    if target_size is not None:
        cell_width, cell_height = target_size
    elif resize_mode == "fit":
        cell_width = min(img.width for img in images)
        cell_height = min(img.height for img in images)
    elif resize_mode == "stretch":
        cell_width = sum(img.width for img in images) // len(images)
        cell_height = sum(img.height for img in images) // len(images)
    elif resize_mode == "none":
        cell_width = max(img.width for img in images)
        cell_height = max(img.height for img in images)
    else:
        raise ValueError(
            f"Invalid resize_mode: {resize_mode}. "
            f"Must be 'fit', 'stretch', or 'none'."
        )

    if keep_sizes:
        return list(images), cell_width, cell_height

    cell = (cell_width, cell_height)
    resized = [
        img if (img.width, img.height) == cell
        else img.resize(cell, PIL.Image.Resampling.LANCZOS)
        for img in images
    ]
    return resized, cell_width, cell_height


def _create_row_or_column(
    images: List[PIL.Image.Image],
    labels: Optional[List[str]] = None,
    orientation: str = "horizontal",
    spacing: int = 10,
    label_height: int = 30,
    background_color: Union[str, Tuple[int, int, int]] = "white",
    text_color: Union[str, Tuple[int, int, int]] = "black",
    resize_mode: str = "fit",
    target_size: Optional[Tuple[int, int]] = None,
) -> PIL.Image.Image:
    """Create a row or column of images with optional labels.

    A row is a 1 x n grid and a column an n x 1 grid, so the layout is
    delegated to ``_create_grid``. Any orientation other than "horizontal"
    produces a column.
    """
    count = len(images)
    grid_size = (1, count) if orientation.lower() == "horizontal" else (count, 1)
    return _create_grid(
        images,
        labels,
        grid_size,
        spacing,
        label_height,
        background_color,
        text_color,
        resize_mode,
        target_size,
    )


def _create_grid(
    images: List[PIL.Image.Image],
    labels: Optional[List[str]] = None,
    grid_size: Optional[Tuple[int, int]] = None,
    spacing: int = 10,
    label_height: int = 30,
    background_color: Union[str, Tuple[int, int, int]] = "white",
    text_color: Union[str, Tuple[int, int, int]] = "black",
    resize_mode: str = "fit",
    target_size: Optional[Tuple[int, int]] = None,
) -> PIL.Image.Image:
    """Create a grid of images with optional labels.

    Images fill the grid row by row. Each cell holds one image, centred, with
    its label drawn in a strip of ``label_height`` pixels underneath.

    Raises:
        ValueError: If ``images`` is empty, ``grid_size`` has fewer cells than
            there are images, or ``resize_mode`` is invalid.
    """
    if not images:
        raise ValueError("No images provided to arrange")

    count = len(images)
    if grid_size is None:
        # Near-square layout: columns first, then as many rows as needed.
        cols = math.ceil(math.sqrt(count))
        rows = math.ceil(count / cols)
    else:
        rows, cols = grid_size
        if rows * cols < count:
            raise ValueError(
                f"Grid size {grid_size} is too small for {count} images"
            )

    labels, label_height = _normalize_labels(labels, count, label_height)
    images, cell_width, cell_height = _size_cells(images, resize_mode, target_size)

    total_width = cols * cell_width + (cols + 1) * spacing
    total_height = rows * (cell_height + label_height) + (rows + 1) * spacing

    grid_img = PIL.Image.new("RGB", (total_width, total_height), background_color)
    draw = ImageDraw.Draw(grid_img)

    try:
        font = ImageFont.truetype("Arial", 12)
    except IOError:
        # Arial is not installed everywhere; fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    for idx, (img, label) in enumerate(zip(images, labels)):
        row, col = divmod(idx, cols)
        x = spacing + col * (cell_width + spacing)
        y = spacing + row * (cell_height + label_height + spacing)

        # Centre the image inside its cell.
        img_x = x + (cell_width - img.width) // 2
        img_y = y + (cell_height - img.height) // 2
        grid_img.paste(img, (img_x, img_y))

        if label and label_height > 0:
            text_width = draw.textlength(label, font=font)
            text_x = x + (cell_width - text_width) // 2
            text_y = y + cell_height + (label_height - 12) // 2  # ~font height
            draw.text((text_x, text_y), label, fill=text_color, font=font)

    return grid_img


def _ensure_image_list(
    images: Union[PIL.Image.Image, List[PIL.Image.Image], str, Path, List[Union[str, Path]]]
) -> List[PIL.Image.Image]:
    """Convert an image, a path, or a list/tuple of either into a new list of PIL images."""

    def load(item):
        if isinstance(item, PIL.Image.Image):
            return item
        if isinstance(item, (str, Path)):
            return PIL.Image.open(item)
        raise TypeError(f"Expected PIL.Image.Image or path, got {type(item)}")

    if isinstance(images, (str, Path, PIL.Image.Image)):
        return [load(images)]
    if isinstance(images, (list, tuple)):
        return [load(item) for item in images]
    raise TypeError(
        f"Expected PIL.Image.Image, list of images, or path, got {type(images)}"
    )
def auto_rotate_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """
    Automatically rotate an image based on EXIF orientation.

    The orientation is read through the public ``getexif()`` accessor rather
    than the private ``_getexif()``. When a rotation or flip is required, it
    is delegated to ``PIL.ImageOps.exif_transpose``, which also removes the
    orientation tag from the result so a second call does not rotate again.

    Args:
        image: PIL Image object

    Returns:
        PIL.Image.Image: The correctly oriented image. The input object itself
        is returned when no change is needed.
    """
    # Function-scope import: ImageOps ships with Pillow, already required here.
    from PIL import ImageOps

    orientation_tag = 0x0112  # EXIF "Orientation"

    read_exif = getattr(image, "getexif", None)
    if read_exif is None:
        return image

    # 1 (or a missing tag) means upright; 2..8 encode the flips and rotations.
    orientation = read_exif().get(orientation_tag)
    if orientation not in (2, 3, 4, 5, 6, 7, 8):
        return image

    return ImageOps.exif_transpose(image)


def open_image(image_path: Union[str, Path]) -> PIL.Image.Image:
    """
    Open an image file and automatically correct its orientation.

    Args:
        image_path: Path to the image file

    Returns:
        PIL.Image.Image: The opened image with correct orientation

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    path = Path(image_path)
    if not path.exists():
        raise FileNotFoundError(f"Image file not found: {path}")

    return auto_rotate_image(PIL.Image.open(path))
def pdf_to_images(
    pdf_path: Union[str, Path],
    pages: Optional[List[int]] = None,
    dpi: int = 300,
) -> List[PIL.Image.Image]:
    """
    Render pages of a PDF into PIL Images.

    Args:
        pdf_path: Path to the PDF file
        pages: 0-indexed page numbers to render; all pages when None.
            Out-of-range indices are logged and skipped.
        dpi: Resolution for the rendered images

    Returns:
        List of PIL Images, in the order the pages were requested

    Raises:
        ImportError: If pypdfium2 is not installed.
        FileNotFoundError: If ``pdf_path`` does not exist.
    """
    if not PDFIUM_AVAILABLE:
        raise ImportError(
            "PyPDFium2 is required for PDF processing. "
            "Install it with 'pip install pypdfium2'."
        )

    source = Path(pdf_path) if isinstance(pdf_path, str) else pdf_path
    if not source.exists():
        raise FileNotFoundError(f"PDF file not found: {source}")

    logger.info(f"Extracting images from PDF: {source}")

    document = pdfium.PdfDocument(source)
    try:
        page_count = len(document)
        wanted = range(page_count) if pages is None else pages

        rendered = []
        for index in wanted:
            if not 0 <= index < page_count:
                logger.warning(f"Page index {index} out of range, skipping")
                continue

            page = document[index]
            try:
                # pypdfium2 takes a scale factor relative to 72 DPI.
                bitmap = page.render(scale=dpi / 72.0, rotation=0)
                rendered.append(bitmap.to_pil())
            finally:
                # Close each page explicitly to prevent memory leaks.
                page.close()

        return rendered
    finally:
        document.close()


def _rgba_to_rgb(images: List[PIL.Image.Image]) -> List[PIL.Image.Image]:
    """Return the images with every RGBA one converted to RGB; others untouched."""
    return [img.convert("RGB") if img.mode == "RGBA" else img for img in images]


def save_pdf(
    images: List[PIL.Image.Image],
    output_path: Union[str, Path],
) -> Path:
    """
    Write a list of PIL Images to a multi-page PDF.

    Args:
        images: List of PIL Images to save, one per page
        output_path: Path where the PDF will be saved; missing parent
            directories are created

    Returns:
        Path to the saved PDF file

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("No images provided to save as PDF")

    if isinstance(output_path, str):
        output_path = Path(output_path)
    output_path.parent.mkdir(exist_ok=True, parents=True)

    logger.info(f"Saving {len(images)} images as PDF: {output_path}")

    # RGBA pages are flattened to RGB before writing.
    cover, *rest = _rgba_to_rgb(images)
    cover.save(
        output_path, "PDF", resolution=100.0, save_all=True, append_images=rest
    )

    return output_path


def image_to_pdf(
    images: List[PIL.Image.Image],
    output_path: Union[str, Path],
    compression: str = "jpeg",
    quality: int = 95,
) -> Path:
    """
    Write images to a PDF, forwarding compression settings to the writer.

    Args:
        images: List of PIL Images to save, one per page
        output_path: Path where the PDF will be saved; missing parent
            directories are created
        compression: Compression method ('jpeg', 'png', etc.)
        quality: Compression quality (1-100, higher is better)

    Returns:
        Path to the saved PDF file

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("No images provided to save as PDF")

    if isinstance(output_path, str):
        output_path = Path(output_path)
    output_path.parent.mkdir(exist_ok=True, parents=True)

    logger.info(
        f"Saving {len(images)} images as PDF with {compression} compression: {output_path}"
    )

    # RGBA pages are flattened to RGB before writing.
    cover, *rest = _rgba_to_rgb(images)
    cover.save(
        output_path,
        "PDF",
        resolution=100.0,
        save_all=True,
        append_images=rest,
        compression=compression,
        quality=quality,
    )

    return output_path
| # Development dependencies (optional) 9 | # pytest>=7.0.0 10 | # black>=22.0.0 11 | # isort>=5.10.0 12 | # mypy>=0.950 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup script for py-reform. 3 | """ 4 | 5 | from setuptools import setup, find_namespace_packages 6 | 7 | with open("README.md", "r", encoding="utf-8") as fh: 8 | long_description = fh.read() 9 | 10 | setup( 11 | name="py-reform", 12 | version="0.1.3", 13 | author="Jonathan Soma", 14 | author_email="jonathan.soma@gmail.com", 15 | description="A Python library for dewarping/straightening/reformatting document images and PDFs", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/jsoma/py-reform", 19 | packages=find_namespace_packages(), 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "Programming Language :: Python :: 3.8", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "License :: OSI Approved :: MIT License", 26 | "Operating System :: OS Independent", 27 | "Topic :: Scientific/Engineering :: Image Processing", 28 | "Topic :: Multimedia :: Graphics", 29 | ], 30 | python_requires=">=3.8", 31 | install_requires=[ 32 | "pillow>=9.0.0,<12.0.0", 33 | "numpy>=1.20.0,<2.0.0", 34 | "tqdm>=4.60.0,<5.0.0", 35 | "pypdfium2>=4.0.0,<5.0.0", 36 | "torch>=1.10.0,<3.0.0", 37 | "deskew>=1.0.0", 38 | ], 39 | extras_require={ 40 | "dev": [ 41 | "pytest>=7.0.0", 42 | "black>=22.0.0", 43 | "isort>=5.10.0", 44 | "mypy>=0.950", 45 | ], 46 | }, 47 | include_package_data=True, 48 | package_data={ 49 | "py_reform.models": ["weights/*.pkl"], 50 | }, 51 | ) -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic 
tests for py-reform. 3 | """ 4 | 5 | import sys 6 | import unittest 7 | from pathlib import Path 8 | 9 | # Add the parent directory to the path so we can import py_reform 10 | sys.path.insert(0, str(Path(__file__).parent.parent)) 11 | 12 | import PIL.Image 13 | from py_reform import straighten 14 | from py_reform.models import get_model 15 | 16 | class TestBasic(unittest.TestCase): 17 | """Basic tests for py-reform.""" 18 | 19 | def test_get_model(self): 20 | """Test that we can get a model.""" 21 | try: 22 | model = get_model("uvdoc") 23 | self.assertIsNotNone(model) 24 | except ImportError: 25 | # Skip if torch is not available 26 | self.skipTest("PyTorch not available") 27 | 28 | def test_straighten_image(self): 29 | """Test straightening an image.""" 30 | # Create a simple test image 31 | test_image = PIL.Image.new("RGB", (100, 100), color="white") 32 | 33 | # Process the image 34 | try: 35 | result = straighten(test_image) 36 | 37 | # Check that we got an image back 38 | self.assertIsInstance(result, PIL.Image.Image) 39 | 40 | # Check that the dimensions are the same 41 | self.assertEqual(result.size, test_image.size) 42 | except ImportError: 43 | # Skip if torch is not available 44 | self.skipTest("PyTorch not available") 45 | 46 | def test_invalid_model(self): 47 | """Test that we get an error for an invalid model.""" 48 | with self.assertRaises(ValueError): 49 | get_model("invalid_model") 50 | 51 | if __name__ == "__main__": 52 | unittest.main() --------------------------------------------------------------------------------