├── image ├── dataset.png └── Pytorch-Dataloader.png ├── inference ├── images │ ├── 0 │ │ ├── vietocr_img_000040.jpg │ │ └── vietocr_img_000240.jpg │ └── 1 │ │ └── paper_img_002844.png └── inference.ipynb ├── ruff.toml ├── requirements.txt ├── data ├── __init__.py ├── augmentations.py ├── collate.py ├── vocab.txt ├── dali.py ├── vocab.py └── dataset.py ├── script ├── train.sh └── test.sh ├── models ├── seq_modules.py ├── __init__.py ├── backbones.py └── pred_modules.py ├── cli.py ├── inference.py ├── .gitignore ├── README.md ├── test.py ├── utils.py ├── LICENSE ├── main.py └── dataloader.py /image/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ducto489/lib_ocr/HEAD/image/dataset.png -------------------------------------------------------------------------------- /image/Pytorch-Dataloader.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ducto489/lib_ocr/HEAD/image/Pytorch-Dataloader.png -------------------------------------------------------------------------------- /inference/images/1/paper_img_002844.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ducto489/lib_ocr/HEAD/inference/images/1/paper_img_002844.png -------------------------------------------------------------------------------- /inference/images/0/vietocr_img_000040.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ducto489/lib_ocr/HEAD/inference/images/0/vietocr_img_000040.jpg -------------------------------------------------------------------------------- /inference/images/0/vietocr_img_000240.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ducto489/lib_ocr/HEAD/inference/images/0/vietocr_img_000240.jpg -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 120 2 | target-version = "py310" 3 | 4 | [lint] 5 | # Only ignore variables with names starting with "_". 
6 | dummy-variable-rgx = "^_.*$" 7 | 8 | [lint.isort] 9 | force-single-line = false 10 | lines-after-imports = 2 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu121 2 | --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda120 3 | --extra-index-url https://pypi.nvidia.com nvidia-dali-cuda120 4 | 5 | torch==2.5.1+cu121 6 | pytorch-lightning==2.4.0 7 | opencv-python 8 | wandb 9 | timm 10 | jsonargparse[signatures]>=4.27.7 11 | loguru 12 | torchvision 13 | pandas -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import OCRDataset 2 | from .collate import OCRCollator 3 | from .augmentations import data_transforms, data_transforms_2 4 | from .dataset import process_tgt 5 | from .dali import ExternalInputCallable, LightningWrapper 6 | from .vocab import Vocab 7 | 8 | __all__ = [ 9 | "OCRDataset", 10 | "OCRCollator", 11 | "data_transforms", 12 | "data_transforms_2", 13 | "process_tgt", 14 | "ExternalInputCallable", 15 | "LightningWrapper", 16 | "Vocab", 17 | ] 18 | -------------------------------------------------------------------------------- /script/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python cli.py fit \ 4 | --data.train_data_path "/home/qsvm/dataset/train" \ 5 | --data.val_data_path "/home/qsvm/dataset/val" \ 6 | --data.batch_size 64 \ 7 | --data.num_workers 8 \ 8 | --data.dali True \ 9 | --data.frac 1 \ 10 | --model.backbone_name "resnet18" \ 11 | --model.seq_name "bilstm" \ 12 | --model.pred_name "attn" \ 13 | --model.learning_rate 1e-4 \ 14 | --model.batch_max_length 200 \ 15 | --model.save_dir "checkpoints/run200" \ 16 | --trainer.max_epochs 10 \ 17 | --trainer.val_check_interval 0.5 \ 18 | --trainer.logger WandbLogger \ 19 | --trainer.logger.name "real-dali-200-fix-aug"\ 20 | --trainer.logger.project "OCR"\ 21 | --trainer.log_every_n_steps 16 -------------------------------------------------------------------------------- /models/seq_modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BiLSTM(nn.Module): 5 | def __init__(self, input_size, hidden_size, output_size): 6 | super().__init__() 7 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 8 | self.linear = nn.Linear(hidden_size * 2, output_size) 9 | 10 | def forward(self, input): 11 | """ 12 | input : visual feature [batch_size x T x input_size] 13 | output : contextual feature [batch_size x T x output_size] 14 | """ 15 | self.rnn.flatten_parameters() 16 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 17 | output = self.linear(recurrent) # batch_size x T x output_size 18 | return output 19 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import resnet18, resnet50, VGG 2 | from .seq_modules import BiLSTM 3 | from .pred_modules import CTC, Attention 4 | 5 | 6 | _backbone_factory = {"resnet18": resnet18, "resnet50": resnet50, "vgg": VGG} 7 | 8 | 9 | seq_factory = {"bilstm": 
BiLSTM} 10 | 11 | 12 | pred_factory = {"ctc": CTC, "attn": Attention} 13 | 14 | 15 | def get_module(backbone_name, seq_name, pred_name): 16 | print(f"backbone_name: {backbone_name}") 17 | print(f"backbone_fac: {_backbone_factory}") 18 | if backbone_name not in _backbone_factory: 19 | raise ValueError(f"Backbone {backbone_name} not found") 20 | backbone = _backbone_factory[backbone_name] 21 | 22 | if seq_name not in seq_factory: 23 | seq_module = None 24 | else: 25 | seq_module = seq_factory[seq_name] 26 | 27 | if pred_name not in pred_factory: 28 | raise ValueError(f"Prediction {pred_name} not found") 29 | pred_module = pred_factory[pred_name] 30 | 31 | return backbone, seq_module, pred_module 32 | -------------------------------------------------------------------------------- /data/augmentations.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import v2 2 | 3 | 4 | class Scaling: 5 | def __call__(self, image): 6 | w, h = image.size 7 | H = 100 8 | scale_ratio = H / h 9 | return v2.functional.resize(image, (100, int(w * scale_ratio))) 10 | 11 | 12 | data_transforms = { 13 | "train": v2.Compose( 14 | [ 15 | Scaling(), 16 | v2.ToTensor(), 17 | v2.Normalize((0.5,), (0.5,)), 18 | ] 19 | ), 20 | "val": v2.Compose( 21 | [ 22 | Scaling(), 23 | v2.ToTensor(), 24 | v2.Normalize((0.5,), (0.5,)), 25 | ] 26 | ), 27 | } 28 | 29 | data_transforms_2 = { 30 | "train": v2.Compose( 31 | [ 32 | Scaling(), 33 | v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.3), 34 | v2.RandomAffine(degrees=2, scale=(0.9, 1)), 35 | v2.RandomPerspective(distortion_scale=0.1, p=0.5), 36 | v2.ToTensor(), 37 | v2.GaussianNoise(), 38 | v2.Normalize((0.5,), (0.5,)), 39 | ] 40 | ), 41 | "val": v2.Compose( 42 | [ 43 | Scaling(), 44 | v2.ToTensor(), 45 | v2.Normalize((0.5,), (0.5,)), 46 | ] 47 | ), 48 | } 49 | -------------------------------------------------------------------------------- /data/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | 4 | 5 | class OCRCollator: 6 | def __init__(self): 7 | self.to_tensor = transforms.ToTensor() # Convert PIL image to tensor 8 | 9 | def __call__(self, batch): 10 | # Unpack batch into separate lists 11 | images, encoded_labels, lengths = zip(*batch) 12 | 13 | # Convert images to tensors and get their shapes 14 | # images = [self.to_tensor(img) for img in images] 15 | channels, height = images[0].shape[0], images[0].shape[1] 16 | widths = [img.shape[2] for img in images] 17 | 18 | # Create a tensor to hold the padded images 19 | max_width = max(widths) 20 | padded_imgs = torch.zeros((len(images), channels, height, max_width), dtype=torch.float32) 21 | 22 | # Copy each image into the tensor with padding 23 | for i, img in enumerate(images): 24 | orig_width = widths[i] 25 | padded_imgs[i, :, :, :orig_width] = img 26 | # add the same background for new padding 27 | # if not, the added background will be black 28 | if orig_width < max_width: 29 | padded_imgs[i, :, :, orig_width:] = img[:, :, -1].unsqueeze(2) 30 | 31 | padded_labels = torch.zeros((len(encoded_labels), encoded_labels[0].size(0)), dtype=torch.int64) 32 | for i, label in enumerate(encoded_labels): 33 | padded_labels[i] = label 34 | 35 | return { 36 | "data": padded_imgs, # Batched image tensor 37 | "label": padded_labels, # Padded encoded labels 38 | "length": torch.tensor(lengths).squeeze(), # Lengths of original labels 39 | } 40 | 
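As an aside, here is a minimal sketch of how `OCRDataset`, the transforms, and `OCRCollator` fit together on the plain PyTorch `DataLoader` path (the DALI pipeline in `data/dali.py` replaces this when `--data.dali True` is passed). The data path and batch size are placeholders, and the sketch assumes it is run from the repository root so the `data` and `utils` modules are importable:

```python
# Minimal sketch of the non-DALI data path. OCRDataset yields
# (image, encoded_label, length); OCRCollator right-pads the variable-width
# images and stacks labels/lengths into one batch dict.
from torch.utils.data import DataLoader

from data import OCRCollator, OCRDataset, data_transforms

dataset = OCRDataset(
    data_path="/path/to/your/data/directory/train",  # placeholder; expects images/ and tgt.csv inside
    batch_max_length=200,
    frac=1.0,
    pred_name="attn",
    transform=data_transforms["train"],
)
loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=OCRCollator())

batch = next(iter(loader))
# batch["data"]   -> float32 [B, 3, 100, max_width], padded with each image's last column
# batch["label"]  -> int64 encoded labels
# batch["length"] -> label lengths from the converter
```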
-------------------------------------------------------------------------------- /data/vocab.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | " 4 | # 5 | $ 6 | % 7 | & 8 | ' 9 | ( 10 | ) 11 | + 12 | , 13 | - 14 | . 15 | / 16 | 0 17 | 1 18 | 2 19 | 3 20 | 4 21 | 5 22 | 6 23 | 7 24 | 8 25 | 9 26 | : 27 | ; 28 | ? 29 | A 30 | B 31 | C 32 | D 33 | E 34 | F 35 | G 36 | H 37 | I 38 | J 39 | K 40 | L 41 | M 42 | N 43 | O 44 | P 45 | Q 46 | R 47 | S 48 | T 49 | U 50 | V 51 | W 52 | X 53 | Y 54 | Z 55 | [ 56 | \ 57 | ] 58 | ^ 59 | _ 60 | a 61 | b 62 | c 63 | d 64 | e 65 | f 66 | g 67 | h 68 | i 69 | j 70 | k 71 | l 72 | m 73 | n 74 | o 75 | p 76 | q 77 | r 78 | s 79 | t 80 | u 81 | v 82 | w 83 | x 84 | y 85 | z 86 | { 87 | | 88 | } 89 | ° 90 | ² 91 | À 92 | Á 93 | Â 94 | Ã 95 | È 96 | É 97 | Ê 98 | Ì 99 | Í 100 | Ò 101 | Ó 102 | Ô 103 | Õ 104 | Ù 105 | Ú 106 | Ý 107 | à 108 | á 109 | â 110 | ã 111 | è 112 | é 113 | ê 114 | ì 115 | í 116 | ò 117 | ó 118 | ô 119 | õ 120 | ù 121 | ú 122 | ý 123 | Ă 124 | ă 125 | Đ 126 | đ 127 | Ĩ 128 | ĩ 129 | Ũ 130 | ũ 131 | Ơ 132 | ơ 133 | Ư 134 | ư 135 | Ạ 136 | ạ 137 | Ả 138 | ả 139 | Ấ 140 | ấ 141 | Ầ 142 | ầ 143 | Ẩ 144 | ẩ 145 | Ẫ 146 | ẫ 147 | Ậ 148 | ậ 149 | Ắ 150 | ắ 151 | Ằ 152 | ằ 153 | Ẳ 154 | ẳ 155 | Ẵ 156 | ẵ 157 | Ặ 158 | ặ 159 | Ẹ 160 | ẹ 161 | Ẻ 162 | ẻ 163 | Ẽ 164 | ẽ 165 | Ế 166 | ế 167 | Ề 168 | ề 169 | Ể 170 | ể 171 | Ễ 172 | ễ 173 | Ệ 174 | ệ 175 | Ỉ 176 | ỉ 177 | Ị 178 | ị 179 | Ọ 180 | ọ 181 | Ỏ 182 | ỏ 183 | Ố 184 | ố 185 | Ồ 186 | ồ 187 | Ổ 188 | ổ 189 | Ỗ 190 | ỗ 191 | Ộ 192 | ộ 193 | Ớ 194 | ớ 195 | Ờ 196 | ờ 197 | Ở 198 | ở 199 | Ỡ 200 | ỡ 201 | Ợ 202 | ợ 203 | Ụ 204 | ụ 205 | Ủ 206 | ủ 207 | Ứ 208 | ứ 209 | Ừ 210 | ừ 211 | Ử 212 | ử 213 | Ữ 214 | ữ 215 | Ự 216 | ự 217 | Ỳ 218 | ỳ 219 | Ỵ 220 | ỵ 221 | Ỷ 222 | ỷ 223 | Ỹ 224 | ỹ 225 | € 226 | ™ 227 | -------------------------------------------------------------------------------- /cli.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.cli import LightningCLI 2 | from pytorch_lightning.callbacks import LearningRateMonitor 3 | from main import OCRModel 4 | from dataloader import OCRDataModule, DALI_OCRDataModule 5 | import os 6 | 7 | 8 | class OCRTrainingCLI(LightningCLI): 9 | def add_arguments_to_parser(self, parser): 10 | parser.add_argument("--save_dir", type=str, default="checkpoints", help="Directory to save checkpoints") 11 | parser.link_arguments("model.batch_max_length", "data.batch_max_length") 12 | parser.link_arguments("model.pred_name", "data.pred_name") 13 | parser.link_arguments("data.batch_size", "model.batch_size") 14 | parser.link_arguments("data.train_data_path", "model.train_data_path") 15 | parser.link_arguments("data.dali", "model.dali") 16 | 17 | def before_fit(self): 18 | save_dir = self.config.get("save_dir", "checkpoints") 19 | os.makedirs(save_dir, exist_ok=True) 20 | 21 | # Add LearningRateMonitor callback 22 | lr_monitor = LearningRateMonitor(logging_interval="step") 23 | self.trainer.callbacks.append(lr_monitor) 24 | 25 | 26 | def cli_main(): 27 | import sys 28 | 29 | # Check if --data.dali True is in the command line arguments 30 | use_dali = False 31 | for i, arg in enumerate(sys.argv): 32 | if arg == "--data.dali" and i + 1 < len(sys.argv) and sys.argv[i + 1].lower() == "true": 33 | use_dali = True 34 | break 35 | 36 | # Use DALI_OCRDataModule if --data.dali True, otherwise use OCRDataModule 37 | data_module = DALI_OCRDataModule if use_dali else 
OCRDataModule 38 | 39 | OCRTrainingCLI(OCRModel, data_module, save_config_kwargs={"overwrite": True}, seed_everything_default=42) 40 | 41 | 42 | if __name__ == "__main__": 43 | cli_main() 44 | -------------------------------------------------------------------------------- /script/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # python test.py \ 4 | # --checkpoint "/home/qsvm/temp/lib_ocr/checkpoints/run200/model_val_epoch_7_loss_0.7515_cer_0.0525_wer_0.1175.ckpt" \ 5 | # --data_root "/home/qsvm/dataset/eval/paper" \ 6 | # --batch_size 2 \ 7 | # --output_dir "results_lowercase" 8 | 9 | python test.py \ 10 | --checkpoint "/home/qsvm/temp/lib_ocr/checkpoints/run200/model_train_epoch_6.ckpt" \ 11 | --data_root "/home/qsvm/dataset/eval/paper" \ 12 | --batch_size 2 \ 13 | --output_dir "results_epoch_6" 14 | 15 | python test.py \ 16 | --checkpoint "/home/qsvm/temp/lib_ocr/checkpoints/run200/model_val_epoch_6_loss_0.7923_cer_0.0547_wer_0.1233.ckpt" \ 17 | --data_root "/home/qsvm/dataset/eval/paper" \ 18 | --batch_size 2 \ 19 | --output_dir "results_epoch_mid_6" 20 | 21 | python test.py \ 22 | --checkpoint "/home/qsvm/temp/lib_ocr/checkpoints/run200/model_train_epoch_4.ckpt" \ 23 | --data_root "/home/qsvm/dataset/eval/paper" \ 24 | --batch_size 2 \ 25 | --output_dir "results_epoch_4" 26 | 27 | python test.py \ 28 | --checkpoint "/home/qsvm/temp/lib_ocr/checkpoints/run200/model_val_epoch_3_loss_0.9336_cer_0.0679_wer_0.1515.ckpt" \ 29 | --data_root "/home/qsvm/dataset/eval/paper" \ 30 | --batch_size 2 \ 31 | --output_dir "results_epoch_mid_3" 32 | 33 | python test.py \ 34 | --checkpoint "/home/qsvm/temp/lib_ocr/checkpoints/run200/model_train_epoch_5.ckpt" \ 35 | --data_root "/home/qsvm/dataset/eval/paper" \ 36 | --batch_size 2 \ 37 | --output_dir "results_epoch_5" 38 | 39 | python test.py \ 40 | --checkpoint "/home/qsvm/temp/lib_ocr/checkpoints/run200/model_val_epoch_5_loss_0.8127_cer_0.0589_wer_0.1291.ckpt" \ 41 | --data_root "/home/qsvm/dataset/eval/paper" \ 42 | --batch_size 2 \ 43 | --output_dir "results_epoch_mid_5" -------------------------------------------------------------------------------- /data/dali.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import os 3 | import random 4 | 5 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class LightningWrapper(DALIGenericIterator): 11 | def __init__(self, pipelines, dataset_size, *args, **kwargs): 12 | super().__init__(pipelines=pipelines, *args, **kwargs) 13 | self.pipelines = pipelines 14 | self.dataset_size = dataset_size 15 | 16 | def __len__(self): 17 | return self.dataset_size 18 | 19 | def __next__(self): 20 | batch = super().__next__()[0] 21 | 22 | batch["data"] = batch["data"].permute(0, 3, 1, 2) 23 | batch["data"] = batch["data"].detach().clone() 24 | batch["label"] = batch["label"].detach().clone() 25 | return batch 26 | 27 | def __code__(self): 28 | return super().__code() 29 | 30 | 31 | class ExternalInputCallable(object): 32 | def __init__(self, steps_per_epoch, data_path, converter, images_names, labels, batch_size=32): 33 | self.data_path = data_path 34 | self.steps_per_epoch = steps_per_epoch 35 | self.converter = converter 36 | self.batch_size = batch_size 37 | 38 | self.images_names = images_names 39 | self.labels = labels 40 | 41 | self.data = list(zip(images_names, labels)) 42 | random.shuffle(self.data) 43 | 44 | def 
__call__(self, sample_info): 45 | idx = sample_info.idx_in_epoch 46 | if idx >= len(self.data): 47 | # Indicate end of the epoch 48 | raise StopIteration() 49 | image_name, label = self.data[idx % len(self.data)] 50 | image_path = os.path.join(self.data_path, image_name) 51 | 52 | with open(image_path, "rb") as f: 53 | file_bytes = f.read() 54 | 55 | image = np.frombuffer(file_bytes, dtype=np.uint8) 56 | encoded_label, length = self.converter.encode([label]) 57 | return image, torch.squeeze(encoded_label), length 58 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 2 | from nvidia.dali.pipeline import pipeline_def 3 | import nvidia.dali.types as types 4 | import nvidia.dali.fn as fn 5 | from pytorch_lightning import LightningDataModule 6 | import numpy as np 7 | 8 | 9 | class PredictLightningWrapper(DALIGenericIterator): 10 | def __init__(self, pipelines, *args, **kwargs): 11 | super().__init__(pipelines=pipelines, *args, **kwargs) 12 | self.pipelines = pipelines 13 | 14 | def __next__(self): 15 | batch = super().__next__()[0] 16 | 17 | batch["data"] = batch["data"].permute(0, 3, 1, 2) 18 | batch["data"] = batch["data"].detach().clone() 19 | return batch 20 | 21 | def __code__(self): 22 | return super().__code() 23 | 24 | 25 | class Inference(LightningDataModule): 26 | def __init__(self, image_path, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self.image_path = image_path 29 | 30 | self.MEAN = np.asarray([0.485, 0.456, 0.406])[None, None, :] 31 | self.STD = np.asarray([0.229, 0.224, 0.225])[None, None, :] 32 | self.SCALE = 1 / 255.0 33 | 34 | def predict_dataloader(self): 35 | predict_pipeline = self.get_dali_predict_pipeline() 36 | predict_pipeline.build() 37 | self.predict_dataloader = PredictLightningWrapper( 38 | pipelines=predict_pipeline, 39 | output_map=["data"], 40 | ) 41 | return self.predict_dataloader 42 | 43 | @pipeline_def( 44 | num_threads=8, 45 | batch_size=1, 46 | device_id=0, 47 | py_start_method="spawn", 48 | ) 49 | def get_dali_predict_pipeline(self): 50 | image, _ = fn.readers.file(file_root=self.image_path, shard_id=0, num_shards=1) 51 | image = fn.decoders.image(image, device="mixed", output_type=types.RGB) 52 | image = fn.resize(image, device="gpu", resize_y=100, dtype=types.FLOAT) 53 | image = fn.normalize( 54 | image, device="gpu", dtype=types.FLOAT, mean=self.MEAN / self.SCALE, stddev=self.STD, scale=self.SCALE 55 | ) 56 | image = fn.pad(image, fill_value=0) 57 | return image 58 | -------------------------------------------------------------------------------- /data/vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | 6 | class Vocab: 7 | def __init__(self): 8 | """ 9 | Initialize vocabulary from labels file 10 | Args: 11 | label_path: Path to labels file (CSV or JSON) 12 | """ 13 | 14 | def get_vocab_json(self, label_path): 15 | """Get vocabulary from JSON file""" 16 | with open(label_path, "r") as f: 17 | data = json.load(f) 18 | data = data["labels"] 19 | vocab = set() 20 | for label in data.values(): 21 | vocab.update(list(label)) 22 | return list(vocab) 23 | 24 | def get_vocab_csv(self, label_path): 25 | """Get vocabulary from CSV file with Unicode support""" 26 | # Try different encodings for Vietnamese text 27 | for encoding in ["utf-8", "utf-8-sig", 
"utf-16"]: 28 | try: 29 | df = pd.read_csv(label_path, encoding=encoding) 30 | break 31 | except UnicodeDecodeError: 32 | continue 33 | else: 34 | raise ValueError(f"Could not read CSV file with any of the supported encodings") 35 | 36 | if "label" not in df.columns: 37 | raise ValueError(f"CSV file must contain 'label' column. Found columns: {df.columns.tolist()}") 38 | 39 | vocab = set() 40 | # Convert labels to string to handle any numeric values 41 | for label in df["label"].astype(str): 42 | vocab.update(list(label)) 43 | 44 | # Add Vietnamese characters 45 | vi_char_path = os.path.join(os.path.dirname(__file__), "vi_char.txt") 46 | if os.path.exists(vi_char_path): 47 | try: 48 | with open(vi_char_path, "r", encoding="utf-8") as f: 49 | vi_chars = set(f.read().splitlines()) 50 | vocab.update(vi_chars) 51 | except Exception as e: 52 | raise ValueError(f"Error reading Vietnamese characters file: {str(e)}") 53 | 54 | return sorted(list(vocab)) # Sort for consistent ordering 55 | 56 | def get_vocab(self): 57 | char_path = os.path.join(os.path.dirname(__file__), "vocab.txt") 58 | with open(char_path, "r", encoding="utf-8") as f: 59 | vocab = set(f.read().splitlines()) 60 | return sorted(list(vocab)) 61 | -------------------------------------------------------------------------------- /models/backbones.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import timm 3 | 4 | 5 | class resnet50(nn.Module): 6 | def __init__(self, input_channels, output_channels=512): 7 | super().__init__() 8 | backbone = timm.create_model("resnet50", pretrained=True) 9 | core = list(backbone.children())[:-2] 10 | self.backbone = nn.Sequential(*core) 11 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) 12 | 13 | def forward(self, x): 14 | x = self.backbone(x) 15 | x = self.AdaptiveAvgPool(x.permute(0, 3, 1, 2)) 16 | x = x.squeeze(3) 17 | return x 18 | 19 | 20 | class resnet18(nn.Module): 21 | def __init__(self, input_channels, output_channels=512): 22 | super().__init__() 23 | backbone = timm.create_model("resnet18", pretrained=True) 24 | core = list(backbone.children())[:-2] 25 | self.backbone = nn.Sequential(*core) 26 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) 27 | 28 | def forward(self, x): 29 | x = self.backbone(x) 30 | x = self.AdaptiveAvgPool(x.permute(0, 3, 1, 2)) 31 | x = x.squeeze(3) 32 | return x 33 | 34 | 35 | class VGG(nn.Module): 36 | def __init__(self, input_channels, output_channels=512): 37 | super().__init__() 38 | self.output_channels = [ 39 | int(output_channels / 8), 40 | int(output_channels / 4), 41 | int(output_channels / 2), 42 | output_channels, 43 | ] # 64, 128, 256, 512 44 | self.ConvNet = nn.Sequential( # Input: 3 x 100 x 420 45 | nn.Conv2d(input_channels, self.output_channels[0], 3, 1, 0), 46 | nn.ReLU(True), # 64 x 98 x 418 47 | nn.MaxPool2d(2, 2), # 64 x 49 x 209 48 | nn.Conv2d(self.output_channels[0], self.output_channels[1], 3, 1, 0), 49 | nn.ReLU(True), # 128 x 48 x 208 50 | nn.MaxPool2d(2, 2), # 128 x 24 x 104 51 | nn.Conv2d(self.output_channels[1], self.output_channels[2], 3, 1, 0), 52 | nn.ReLU(True), # 256 x 22 x 102 53 | nn.Conv2d(self.output_channels[2], self.output_channels[2], 3, 1, 0), 54 | nn.ReLU(True), # 256 x 20 x 100 55 | nn.MaxPool2d(2, 2), # 256 x 10 x 50 56 | nn.Conv2d(self.output_channels[2], self.output_channels[3], 3, 1, 0), 57 | nn.ReLU(True), # 512 x 8 x 48 58 | nn.BatchNorm2d(self.output_channels[3]), 59 | nn.ReLU(True), 60 | nn.Conv2d(self.output_channels[3], 
self.output_channels[3], 3, 1, 0), 61 | nn.ReLU(True), # 512 x 6 x 46 62 | nn.BatchNorm2d(self.output_channels[3]), 63 | nn.ReLU(True), 64 | nn.MaxPool2d(2, 2), # 512 x 3 x 23 65 | nn.Conv2d(self.output_channels[3], self.output_channels[3], 2, 1, 0), 66 | nn.ReLU(True), # 512 x 1 x 21 67 | ) 68 | 69 | def forward(self, x): 70 | conv = self.ConvNet(x) # batch_size x 512 x 5 x 104 71 | # Reshape to (batch_size, sequence_length, channels) for CTC 72 | conv = conv.permute(0, 3, 1, 2) # [b, w, c, h] 73 | conv = conv.squeeze(-1) # [b, w, c] 74 | return conv # [batch_size, sequence_length, channels] 75 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch 4 | from torch.utils.data import Dataset 5 | import pandas as pd 6 | from loguru import logger 7 | from utils import CTCLabelConverter, AttnLabelConverter 8 | from data.vocab import Vocab 9 | 10 | 11 | class OCRDataset(Dataset): 12 | def __init__(self, data_path, batch_max_length, frac, pred_name="attn", transform=None): 13 | self.data_path = data_path 14 | self.transform = transform 15 | self.batch_max_length = batch_max_length 16 | 17 | images, labels = process_tgt(data_path, batch_max_length, frac=frac) 18 | self.data = list(zip(images, labels)) # list(zip(df['image_name'], df['label'])) 19 | logger.debug("Get Vocab") 20 | path = os.path.join(self.data_path, "tgt.csv") 21 | vocab = Vocab().get_vocab() 22 | logger.debug(f"{pred_name=}") 23 | if pred_name == "ctc": 24 | self.converter = CTCLabelConverter(vocab, device="cpu") 25 | else: 26 | self.converter = AttnLabelConverter(vocab, batch_max_length=self.batch_max_length, device="cpu") 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __getitem__(self, idx): 32 | image_name, label = self.data[idx] 33 | image_path = os.path.join(self.data_path, "images", image_name) 34 | 35 | image = Image.open(image_path).convert("RGB") 36 | 37 | if self.transform: 38 | image = self.transform(image) 39 | 40 | encoded_label, length = self.converter.encode([label]) 41 | 42 | return image, torch.squeeze(encoded_label), length 43 | 44 | 45 | def process_tgt(data_path, batch_max_length, frac): 46 | # Check for tgt.csv file 47 | tgt_path = os.path.join(data_path, "tgt.csv") 48 | if not os.path.exists(tgt_path): 49 | raise FileNotFoundError(f"Label file not found at {tgt_path}") 50 | 51 | # Try different encodings for Vietnamese text support 52 | encodings = ["utf-8", "utf-8-sig", "utf-16"] 53 | for encoding in encodings: 54 | try: 55 | df = pd.read_csv(tgt_path, encoding=encoding) 56 | break 57 | except UnicodeDecodeError: 58 | if encoding == encodings[-1]: 59 | raise UnicodeDecodeError(f"Failed to read CSV with encodings: {encodings}") 60 | continue 61 | 62 | # Validate required columns 63 | required_cols = ["image_name", "label"] 64 | missing_cols = [col for col in required_cols if col not in df.columns] 65 | if missing_cols: 66 | raise ValueError(f"Missing required columns in CSV: {missing_cols}") 67 | 68 | # Convert values to strings and filter by length 69 | df["image_name"] = df["image_name"].astype(str) 70 | df["label"] = df["label"].astype(str) 71 | 72 | # Filter out samples exceeding batch_max_length 73 | total_samples = len(df) 74 | df = df[df["label"].str.len() <= batch_max_length].sample(frac=frac, random_state=42) 75 | filtered_samples = total_samples - len(df) 76 | if filtered_samples > 0: 77 | print( 
78 | f"Filtered out {filtered_samples} samples ({filtered_samples / total_samples * 100:.2f}%) exceeding max length {batch_max_length}" 79 | ) 80 | 81 | return list(df["image_name"]), list(df["label"]) 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results*/ 2 | notebook/ 3 | experiment/ 4 | # images 5 | label_analysis/ 6 | analyze_labels.py 7 | label_distribution.png 8 | 9 | training_images/images 10 | training_images/labels.json 11 | debug/ 12 | lightning_logs/ 13 | wandb/ 14 | checkpoints/ 15 | 16 | config.yaml 17 | note.md 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | share/python-wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | cover/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | .pybuilder/ 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | # For a library or package, you might want to ignore these files since the code is 104 | # intended to run in multiple environments; otherwise, check them in: 105 | # .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # UV 115 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 116 | # This is especially recommended for binary packages to ensure reproducibility, and is more 117 | # commonly ignored for libraries. 118 | #uv.lock 119 | 120 | # poetry 121 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 122 | # This is especially recommended for binary packages to ensure reproducibility, and is more 123 | # commonly ignored for libraries. 124 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 125 | #poetry.lock 126 | 127 | # pdm 128 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
129 | #pdm.lock 130 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 131 | # in version control. 132 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 133 | .pdm.toml 134 | .pdm-python 135 | .pdm-build/ 136 | 137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 138 | __pypackages__/ 139 | 140 | # Celery stuff 141 | celerybeat-schedule 142 | celerybeat.pid 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | # Pyre type checker 172 | .pyre/ 173 | 174 | # pytype static type analyzer 175 | .pytype/ 176 | 177 | # Cython debug symbols 178 | cython_debug/ 179 | 180 | # PyCharm 181 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 182 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 183 | # and can be added to the global gitignore or merged into this file. For a more nuclear 184 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 185 | #.idea/ 186 | 187 | # PyPI configuration file 188 | .pypirc 189 | 190 | # Checkpoint 191 | *.ckpt 192 | images/dog/dog_2.jpg 193 | results.txt 194 | -------------------------------------------------------------------------------- /models/pred_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class CTC(nn.Module): 9 | def __init__(self, input_dim, num_classes): 10 | super().__init__() 11 | self.fc = nn.Linear(input_dim, num_classes + 1) # +1 for blank token 12 | 13 | def forward(self, x, text=None, is_train=None, batch_max_length=None): 14 | return self.fc(x) 15 | 16 | 17 | class Attention(nn.Module): 18 | def __init__(self, input_dim, num_classes, hidden_dim=256): 19 | super().__init__() 20 | self.input_dim = input_dim 21 | self.hidden_dim = hidden_dim 22 | self.num_classes = num_classes 23 | self.attention_cell = AttentionCell(input_dim, hidden_dim, num_classes) 24 | self.generator = nn.Linear(hidden_dim, num_classes) 25 | # Initialize weights 26 | nn.init.xavier_uniform_(self.generator.weight) 27 | if self.generator.bias is not None: 28 | nn.init.constant_(self.generator.bias, 0) 29 | 30 | def _char_to_onehot(self, input_char, onehot_dim=None): 31 | if onehot_dim is None: 32 | onehot_dim = self.num_classes 33 | input_char = input_char.unsqueeze(1) 34 | batch_size = input_char.size(0) 35 | onehot = torch.zeros(batch_size, onehot_dim, dtype=torch.float32, device=device) 36 | onehot = onehot.scatter_(1, input_char, 1) 37 | return onehot 38 | 39 | def forward(self, batch_H, batch_max_length, text=None, is_train=True): 40 | batch_size = batch_H.size(0) 41 | num_steps = batch_max_length + 1 42 | 43 | output_hiddens = torch.zeros(batch_size, num_steps, self.hidden_dim, dtype=torch.float32, device=device) 44 | hidden = torch.zeros(batch_size, self.hidden_dim, dtype=torch.float32, device=device) 45 | 46 | if is_train: 47 | for i in range(num_steps): 48 | 
char_onehots = self._char_to_onehot(text[:, i]) 49 | hidden = self.attention_cell(hidden, batch_H, char_onehots) 50 | output_hiddens[:, i, :] = hidden 51 | logits = self.generator(output_hiddens.view(-1, self.hidden_dim)) 52 | probs = logits.view(batch_size, num_steps, -1) # Let the loss function handle the softmax 53 | 54 | else: 55 | # Inference mode or when text is not provided 56 | probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=device) 57 | target = torch.zeros(batch_size, dtype=torch.long, device=device) 58 | for i in range(num_steps): 59 | char_onehots = self._char_to_onehot(target) 60 | hidden = self.attention_cell(hidden, batch_H, char_onehots) 61 | probs_step = self.generator(hidden) 62 | probs[:, i, :] = probs_step 63 | _, target = probs_step.max(dim=1) 64 | 65 | return probs # (batch_size, num_steps, num_classes) 66 | 67 | 68 | class AttentionCell(nn.Module): 69 | def __init__(self, input_dim, hidden_dim, output_dim): 70 | super().__init__() 71 | self.hidden_dim = hidden_dim 72 | self.i2h = nn.Linear(input_dim, hidden_dim, bias=False) 73 | self.h2h = nn.Linear(hidden_dim, hidden_dim) 74 | self.score = nn.Linear(hidden_dim, 1, bias=False) 75 | self.rnn = nn.GRUCell(input_dim + output_dim, hidden_dim) 76 | 77 | # Initialize weights 78 | nn.init.xavier_uniform_(self.i2h.weight) 79 | nn.init.xavier_uniform_(self.h2h.weight) 80 | if self.h2h.bias is not None: 81 | nn.init.constant_(self.h2h.bias, 0) 82 | nn.init.xavier_uniform_(self.score.weight) 83 | 84 | # Initialize GRU weights 85 | for name, param in self.rnn.named_parameters(): 86 | if "weight" in name: 87 | nn.init.orthogonal_(param) 88 | elif "bias" in name: 89 | nn.init.constant_(param, 0) 90 | 91 | def forward(self, prev_hidden, batch_H, char_onehots): 92 | # [batch_size, num_steps, input_dim] -> [batch_size, num_steps, hidden_dim] 93 | batch_H_proj = self.i2h(batch_H) 94 | prev_hidden_proj = self.h2h(prev_hidden).unsqueeze(1) 95 | 96 | # Scaled dot-product attention 97 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) 98 | 99 | # Apply attention with temperature scaling 100 | alpha = F.softmax(e, dim=1) # Equation 5, batch_size x num_steps x 1 101 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # Equation 3, batch_size x input_dim 102 | 103 | # Concatenate context with character embedding 104 | concat_context = torch.cat([context, char_onehots], dim=1) # batch_size x (input_dim + output_dim) 105 | 106 | # Update hidden state with GRU 107 | hidden = self.rnn(concat_context, prev_hidden) 108 | 109 | return hidden 110 | 111 | 112 | # Postprocessing: Greedy, Beam Search 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VietConizer: Vietnamese OCR with NVIDIA DALI 2 | 3 | This repository provides a flexible and high-performance Optical Character Recognition (OCR) system built with PyTorch Lightning and accelerated using NVIDIA DALI for data loading. 4 | 5 | Check out my detailed documentation to learn about data handling, training approaches, and the performance benefits of NVIDIA DALI: **[Accelerating OCR Training with NVIDIA DALI: A Practical Guide and Case Study](https://ducto489.github.io/projects/ocr-dali/)** 6 | 7 | 8 | ![clickbait image](/image/Pytorch-Dataloader.png) 9 | 10 | _**Left (PyTorch DataLoader):** The GPU frequently idles or is underutilized, indicating data bottlenecks. 
**Right (NVIDIA DALI):** The GPU maintains consistently high utilization. DALI keeps the L4 GPU working hard, reducing wasted cycles and speeding up training._ 11 | 12 | ## Overview 13 | 14 | This project implements an end-to-end training OCR pipeline, featuring: 15 | 16 | * **Multiple Architectures:** Easily swap CNN backbones (ResNet, VGG) and sequence modeling layers (BiLSTM). 17 | * **Prediction Methods:** Supports both CTC (Connectionist Temporal Classification) and Attention-based decoding. 18 | * **High-Performance Data Loading:** Integrates NVIDIA DALI for significantly faster data preprocessing and loading, especially beneficial on NVIDIA GPUs. 19 | * **Structured Training:** Leverages PyTorch Lightning for organized, reproducible training workflows, including logging (e.g., Wandb) and checkpointing. 20 | 21 | ## Table of Contents 22 | 23 | - [Acknowledge](#acknowledgements) 24 | - [Setup](#setup) 25 | - [Training](#training) 26 | - [Inference](#inference) 27 | - [Evaluation](#evaluation) 28 | 29 | ## Acknowledgements 30 | 31 | Special thanks to **Trong Tuan** ([@santapo](https://github.com/santapo)) and **Phuong** ([@mp1704](https://github.com/mp1704)) for their significant help to this project. 32 | 33 | ## Setup 34 | 35 | Follow these steps to set up your environment and download the necessary data. 36 | 37 | ### Get the Code 38 | 39 | ```bash 40 | git clone https://github.com/AnyGlow/lib_ocr.git 41 | cd lib_ocr 42 | ``` 43 | 44 | ### Install Dependencies 45 | 46 | ```bash 47 | conda create -n ocr python=3.11 48 | conda activate ocr 49 | 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | ### Download Data 54 | 55 | The primary dataset can be downloaded from [Hugging Face](https://huggingface.co/datasets/ducto489/ocr_datasets): 56 | 57 | ```bash 58 | pip install huggingface_hub 59 | 60 | huggingface-cli download ducto489/ocr_datasets ocr_dataset.zip --repo-type dataset --local-dir . 61 | unzip ocr_dataset.zip -d /path/to/your/data/directory 62 | ``` 63 | 64 | Replace `/path/to/your/data/directory` with the actual path where you want to store the data. 65 | 66 | ### Data Structure 67 | 68 | Organize the downloaded data into the following structure: 69 | 70 | ``` 71 | /path/to/your/data/directory/ 72 | ├── train/ 73 | │ ├── images/ # Directory containing training images 74 | │ │ ├── image1.jpg 75 | │ │ └── ... 76 | │ └── tgt.csv # CSV with image filenames and labels 77 | └── val/ 78 | ├── images/ # Directory containing validation images 79 | │ ├── image1.jpg 80 | │ └── ... 81 | └── tgt.csv # CSV with image filenames and labels 82 | ``` 83 | 84 | The `tgt.csv` should contain image names and their corresponding text labels. 85 | 86 | ## Training 87 | 88 | Use the `cli.py` script with the `fit` command to start training. Key parameters can be adjusted via command-line arguments. 
89 | 90 | **Example Training Command (modify paths and parameters as needed):** 91 | 92 | ```bash 93 | # Found in script/train.sh 94 | python cli.py fit \ 95 | --data.train_data_path "/path/to/your/data/directory/train" \ 96 | --data.val_data_path "/path/to/your/data/directory/val" \ 97 | --data.batch_size 64 \ 98 | --data.num_workers 8 \ 99 | --data.dali True \ 100 | --data.frac 1.0 \ 101 | --model.backbone_name "resnet18" \ 102 | --model.seq_name "bilstm" \ 103 | --model.pred_name "attn" \ 104 | --model.learning_rate 1e-4 \ 105 | --model.batch_max_length 200 \ 106 | --model.save_dir "checkpoints/my_experiment" \ 107 | ``` 108 | 109 | **Key Training Parameters:** 110 | 111 | * `--data.train_data_path`, `--data.val_data_path`: Paths to your training and validation data. 112 | * `--data.batch_size`, `--data.num_workers`: Configure data loading. 113 | * `--data.dali`: Set to `True` to use NVIDIA DALI. 114 | * `--data.frac`: Use a fraction of the data (e.g., `0.1` for 10%). 115 | * `--model.backbone_name`: CNN feature extractor (`resnet18`, `vgg`). 116 | * `--model.seq_name`: Sequence model (`bilstm`, `none`). 117 | * `--model.pred_name`: Prediction head (`ctc`, `attn`). 118 | * `--model.learning_rate`: Optimizer learning rate. 119 | * `--model.batch_max_length`: Max sequence length for padding/processing. 120 | * `--model.save_dir`: Where to save model checkpoints. 121 | * `--trainer.*`: PyTorch Lightning trainer configurations (epochs, validation frequency, logger, etc.). 122 | 123 | Refer to `python cli.py fit --help` for all available options. 124 | 125 | ## Inference 126 | 127 | Checkout the Jupyter notebook `inference/inference.ipynb` to run predictions on new images using a trained checkpoint from [Hugging Face](https://huggingface.co/ducto489/ocr_model). 128 | 129 | | Backbone | Time | 130 | |-----------------------------------------|---------------| 131 | | Our model: Resnet - Bilstm - Attention | 73ms @ A6000 | 132 | | VietOCR: VGG19-bn - Transformer | 565ms @ A6000 | 133 | | VietOCR: VGG19-bn - Seq2Seq | 30ms @ A6000 | 134 | 135 | ## Evaluation 136 | 137 | Use the `test.py` script to evaluate a trained model's performance on standard OCR benchmark datasets. 138 | 139 | **Example Evaluation Command:** 140 | 141 | ```bash 142 | # Found in script/test.sh 143 | python test.py \ 144 | --checkpoint "/path/to/your/checkpoints/my_experiment/model.ckpt" \ 145 | --data_root "/path/to/evaluation/datasets/root" \ 146 | --output_dir "evaluation_results" \ 147 | --batch_size 32 148 | ``` 149 | 150 | **Evaluation Parameters:** 151 | 152 | * `--checkpoint`: Path to the trained model checkpoint (`.ckpt`) file. 153 | * `--data_root`: The root directory containing benchmark datasets (e.g., IIIT5k, SVT, IC13, etc.). The script expects specific subdirectories for each benchmark. 154 | * `--output_dir`: Directory to save evaluation results. 155 | * `--batch_size`: Batch size for evaluation. 156 | 157 | The script will automatically detect and run evaluation on supported datasets found within the `--data_root`, such as: 158 | CUTE80, IC03_860, IC03_867, IC13_1015, IC13_857, IC15_1811, IC15_2077, IIIT5k_3000, SVT, SVTP. 
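For each benchmark, `test.py` writes a `<dataset>_results.txt` file with ground truth, prediction, and confidence, plus an overall `summary.txt`. If you want to script your own evaluation or prediction loop instead of using the notebook, the checkpoint can be loaded the same way `test.py` does; a minimal sketch (the checkpoint path is a placeholder):

```python
# Minimal sketch: load a trained checkpoint with the same arguments test.py uses.
import torch

from main import OCRModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = OCRModel.load_from_checkpoint(
    "/path/to/your/checkpoints/my_experiment/model.ckpt",  # placeholder path
    strict=False,
    batch_max_length=200,
    dali=True,
    pred_name="attn",
    map_location=device,
)
model.eval()
model.to(device)
```

From there, batches produced by `DALI_OCRDataModule` (see `evaluate_dataset` in `test.py`) or by the `Inference` module in `inference.py` can be fed through the model.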
159 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # OCR Inference Script 3 | 4 | import os 5 | import argparse 6 | import torch 7 | import numpy as np 8 | from pathlib import Path 9 | import torch.nn.functional as F 10 | 11 | from main import OCRModel 12 | from dataloader import DALI_OCRDataModule 13 | 14 | eval_data_path = [ 15 | "CUTE80_png", 16 | "IC03_860_png", 17 | "IC03_867_png", 18 | "IC13_1015_png", 19 | "IC13_857_png", 20 | "IC15_1811_png", 21 | "IC15_2077_png", 22 | "IIIT5k_3000_png", 23 | "SVT_png", 24 | "SVTP_png", 25 | ] 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="OCR Inference Script") 30 | parser.add_argument("--checkpoint", type=str, required=True, help="Path to the model checkpoint") 31 | parser.add_argument("--data_root", type=str, required=True, help="Root directory containing evaluation datasets") 32 | parser.add_argument("--output_dir", type=str, default="results", help="Directory to save results") 33 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") 34 | return parser.parse_args() 35 | 36 | 37 | def evaluate_dataset(model, dataset_path, args, device): 38 | # Initialize DALI_OCRDataModule for this dataset 39 | data_module = DALI_OCRDataModule( 40 | batch_max_length=200, 41 | frac=1.0, 42 | dali=True, 43 | train_data_path=dataset_path, 44 | val_data_path=dataset_path, 45 | batch_size=args.batch_size, 46 | num_workers=4, 47 | pred_name=model.pred_name, 48 | ) 49 | 50 | # Get validation dataloader for inference 51 | val_dataloader = data_module.val_dataloader() 52 | 53 | # Process images and gather results 54 | results = [] 55 | n_correct = 0 56 | total_samples = 0 57 | confidence_score_list = [] 58 | 59 | for batch in val_dataloader: 60 | images = batch["data"] 61 | labels = batch["label"] 62 | 63 | # Move to device 64 | images = images.to(device) 65 | 66 | with torch.no_grad(): 67 | # Forward pass based on prediction type 68 | if model.pred_name == "ctc": 69 | logits = model(images, text=labels) 70 | log_probs = logits.log_softmax(2).permute(1, 0, 2) 71 | preds = log_probs.argmax(2).permute(1, 0).detach().cpu() 72 | text = data_module.converter.decode(preds, None) 73 | else: # Attention model 74 | preds = model(images, text=labels[:, :-1]).to(device) 75 | pred_size = torch.LongTensor([preds.size(1)] * preds.size(0)) 76 | _, pred_index = preds.max(2) 77 | text = data_module.converter.decode(pred_index, pred_size) 78 | 79 | # Calculate confidence scores 80 | preds_prob = F.softmax(preds, dim=2) 81 | preds_max_prob, _ = preds_prob.max(dim=2) 82 | 83 | # Get ground truth 84 | ground_truth = data_module.converter.decode(labels, None) 85 | 86 | # Process predictions 87 | for i, (pred, gt, pred_max_prob) in enumerate(zip(text, ground_truth, preds_max_prob)): 88 | total_samples += 1 89 | 90 | if model.pred_name == "attn": 91 | # Handle end of sentence token for attention model 92 | eos_idx = "[EOS]" 93 | if eos_idx in gt: 94 | gt = gt[: gt.find(eos_idx)] 95 | if eos_idx in pred: 96 | pred_EOS = pred.find(eos_idx) 97 | pred = pred[:pred_EOS] 98 | pred_max_prob = pred_max_prob[:pred_EOS] 99 | 100 | if pred.lower() == gt.lower(): 101 | n_correct += 1 102 | 103 | # Calculate confidence score 104 | try: 105 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 106 | except: 107 | confidence_score = 0 # for empty pred case 108 | 109 | 
confidence_score_list.append(confidence_score) 110 | results.append(f"{gt}\t{pred}\t{confidence_score.item():.4f}") 111 | 112 | # Calculate overall accuracy 113 | accuracy = n_correct / float(total_samples) * 100 if total_samples > 0 else 0 114 | avg_confidence = sum(confidence_score_list) / len(confidence_score_list) if confidence_score_list else 0 115 | 116 | return { 117 | "results": results, 118 | "accuracy": accuracy, 119 | "avg_confidence": avg_confidence, 120 | "total_samples": total_samples, 121 | "correct_samples": n_correct, 122 | } 123 | 124 | 125 | def main(): 126 | args = parse_args() 127 | 128 | # Create output directory if it doesn't exist 129 | os.makedirs(args.output_dir, exist_ok=True) 130 | 131 | # Set device 132 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 133 | print(f"Using device: {device}") 134 | 135 | # Load model from checkpoint 136 | model = OCRModel.load_from_checkpoint( 137 | args.checkpoint, strict=False, batch_max_length=200, dali=True, map_location=device, pred_name="attn" 138 | ) 139 | model.eval() 140 | model.to(device) 141 | 142 | # Prepare to collect overall statistics 143 | all_results = {} 144 | total_correct = 0 145 | total_samples = 0 146 | 147 | # Evaluate each dataset 148 | for dataset in eval_data_path: 149 | dataset_path = os.path.join(args.data_root, dataset) 150 | if not os.path.exists(dataset_path): 151 | print(f"Warning: Dataset path {dataset_path} does not exist. Skipping.") 152 | continue 153 | 154 | print(f"\nEvaluating dataset: {dataset}") 155 | 156 | # Run evaluation 157 | eval_results = evaluate_dataset(model, dataset_path, args, device) 158 | 159 | # Save results 160 | output_file = os.path.join(args.output_dir, f"{dataset}_results.txt") 161 | with open(output_file, "w", encoding="utf-8") as f: 162 | f.write("Ground Truth\tPrediction\tConfidence\n") 163 | f.write("\n".join(eval_results["results"])) 164 | 165 | # Print statistics 166 | print(f"Dataset: {dataset}") 167 | print(f" Accuracy: {eval_results['accuracy']:.2f}%") 168 | print(f" Average confidence: {eval_results['avg_confidence']:.4f}") 169 | print(f" Samples: {eval_results['total_samples']}") 170 | print(f" Results saved to {output_file}") 171 | 172 | # Store results for summary 173 | all_results[dataset] = eval_results 174 | total_correct += eval_results["correct_samples"] 175 | total_samples += eval_results["total_samples"] 176 | 177 | # Calculate and print overall statistics 178 | if total_samples > 0: 179 | overall_accuracy = (total_correct / total_samples) * 100 180 | print("\n==== Overall Results ====") 181 | print(f"Total accuracy: {overall_accuracy:.2f}%") 182 | print(f"Total samples: {total_samples}") 183 | 184 | # Save summary to file 185 | summary_file = os.path.join(args.output_dir, "summary.txt") 186 | with open(summary_file, "w", encoding="utf-8") as f: 187 | f.write("Dataset\tAccuracy\tSamples\n") 188 | for dataset, results in all_results.items(): 189 | f.write(f"{dataset}\t{results['accuracy']:.2f}%\t{results['total_samples']}\n") 190 | f.write(f"\nOverall\t{overall_accuracy:.2f}%\t{total_samples}\n") 191 | 192 | print(f"Summary saved to {summary_file}") 193 | 194 | 195 | if __name__ == "__main__": 196 | main() 197 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | 6 | 7 | class 
CTCLabelConverter: 8 | def __init__(self, vocab, device="cuda"): 9 | # Add blank token at position 0 10 | self.SPACE = " " 11 | self.BLANK = "[blank]" 12 | self.UNKNOWN = "[UNK]" 13 | 14 | self.vocab = [self.BLANK] + list(vocab) + [self.UNKNOWN] 15 | self.dict = {char: idx for idx, char in enumerate(self.vocab)} 16 | self.device = device 17 | 18 | def encode(self, texts): 19 | """ 20 | Convert text strings to encoded tensor and lengths 21 | """ 22 | lengths = [len(text) for text in texts] 23 | 24 | # Convert characters to indices, using UNKNOWN token for unknown chars 25 | encoded = [] 26 | for text in texts: 27 | text_encoded = [] 28 | for char in text: 29 | if char in self.dict: 30 | text_encoded.append(self.dict[char]) 31 | else: 32 | text_encoded.append(self.dict[self.UNKNOWN]) 33 | encoded.append(text_encoded) 34 | 35 | # Create padded tensor 36 | max_length = max(lengths) 37 | batch_size = len(texts) 38 | batch_tensor = torch.zeros(batch_size, max_length).long() 39 | 40 | for i, text_encoded in enumerate(encoded): 41 | batch_tensor[i, : len(text_encoded)] = torch.tensor(text_encoded) 42 | 43 | batch_tensor = batch_tensor.to(self.device) 44 | lengths = torch.tensor(lengths).to(self.device) 45 | 46 | return batch_tensor, lengths 47 | 48 | def decode(self, text_indices, length): 49 | """ 50 | Convert encoded indices back to text strings 51 | """ 52 | texts = [] 53 | for indices, l in zip(text_indices, length): 54 | text = "".join([self.vocab[idx] for idx in indices[:l]]) 55 | texts.append(text) 56 | return texts 57 | 58 | def decode_v1(self, text_index, length): 59 | """convert text-index into text-label.""" 60 | texts = [] 61 | for index, l in enumerate(length): 62 | t = text_index[index, :] 63 | 64 | char_list = [] 65 | for i in range(l): 66 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 67 | char_list.append(self.vocab[t[i]]) 68 | text = "".join(char_list) 69 | 70 | texts.append(text) 71 | return texts 72 | 73 | 74 | class CTCLabelConverter_clovaai(object): 75 | """Convert between text-label and text-index""" 76 | 77 | def __init__(self, character, device): 78 | self.device = device 79 | # character (str): set of the possible characters. 80 | dict_character = list(character) 81 | 82 | self.dict = {} 83 | for i, char in enumerate(dict_character): 84 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 85 | self.dict[char] = i + 1 86 | 87 | self.character = ["[CTCblank]"] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 88 | 89 | def encode(self, text, batch_max_length=25): 90 | """convert text-label into text-index. 91 | input: 92 | text: text labels of each image. [batch_size] 93 | batch_max_length: max length of text label in the batch. 25 by default 94 | 95 | output: 96 | text: text index for CTCLoss. [batch_size, batch_max_length] 97 | length: length of each text. [batch_size] 98 | """ 99 | length = [len(s) for s in text] 100 | 101 | # The index used for padding (=0) would not affect the CTC loss calculation. 
102 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 103 | for i, t in enumerate(text): 104 | text = list(t) 105 | text = [self.dict[char] for char in text] 106 | batch_text[i][: len(text)] = torch.LongTensor(text) 107 | return (batch_text.to(self.device), torch.LongTensor(length).to(self.device)) 108 | 109 | def decode(self, text_index, length=None): 110 | """convert text-index into text-label.""" 111 | texts = [] 112 | for index, l in enumerate(length): 113 | t = text_index[index, :] 114 | 115 | char_list = [] 116 | for i in range(l): 117 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 118 | char_list.append(self.character[t[i]]) 119 | text = "".join(char_list) 120 | 121 | texts.append(text) 122 | return texts 123 | 124 | 125 | class AttnLabelConverter: 126 | def __init__(self, character, batch_max_length, device="cuda"): 127 | list_token = ["[GO]", "[EOS]"] 128 | self.character = list_token + list(character) 129 | self.device = device 130 | self.batch_max_length = batch_max_length 131 | self.dict = {char: idx for idx, char in enumerate(self.character)} 132 | 133 | def encode(self, text): 134 | """convert text-label into text-index.""" 135 | length = [len(s) + 1 for s in text] 136 | batch_max_length = self.batch_max_length 137 | batch_max_length += 1 138 | batch_text = torch.zeros(len(text), batch_max_length + 1).long() 139 | for i, t in enumerate(text): 140 | text = list(t) 141 | text.append("[EOS]") 142 | text_encoded = [] 143 | for char in text: 144 | try: 145 | text_encoded.append(self.dict[char]) 146 | except KeyError: 147 | # Handle unknown character by adding a placeholder index 148 | text_encoded.append(self.dict[" "]) 149 | batch_text[i][1 : 1 + len(text_encoded)] = torch.LongTensor(text_encoded) 150 | return (batch_text.to(self.device), torch.LongTensor(length).to(self.device)) 151 | 152 | def decode(self, text_index, length): 153 | """convert text-index into text-label.""" 154 | texts = [] 155 | for index in range(text_index.size(0)): 156 | text = "" 157 | for i in range(0, text_index.size(1)): 158 | char_index = text_index[index, i].item() # Use .item() to get Python int 159 | if char_index == self.dict["[EOS]"]: 160 | break # Stop decoding at [EOS] 161 | if char_index == self.dict["[GO]"]: # Skip [GO] token in output text 162 | continue 163 | text += self.character[char_index] 164 | texts.append(text) 165 | return texts 166 | 167 | 168 | def test_ctc_label_converter(text=["hel` lo", "world"]): 169 | vocab = "QWERTYUIOPASDFGHJKLZXCVBNMqwertyuiopasdfghjklzxcvbnm" 170 | converter = CTCLabelConverter(vocab, device="cpu") 171 | text_index, length = converter.encode(text) 172 | print("Original:", text) 173 | print("Encoded:", text_index) 174 | print("Decoded:", converter.decode_v1(text_index, length)) 175 | 176 | 177 | def test_ctc_label_converter_clovaai(): 178 | import string 179 | 180 | vocab = string.printable[:-6] 181 | converter = CTCLabelConverter_clovaai(vocab, device="cpu") 182 | text = ["hello ", "world"] 183 | text_index, length = converter.encode(text) 184 | print("Original:", text) 185 | print("Encoded:", text_index) 186 | print("Decoded:", converter.decode(text_index, length)) 187 | 188 | 189 | def test_attn_label_converter(text=["hel` lo", "world"]): 190 | vocab = Vocab() 191 | converter = AttnLabelConverter(vocab.get_vocab(), 10, device="cpu") 192 | text_index, _ = converter.encode(text) 193 | print("Original:", text) 194 | print("Encoded:", text_index) 195 | print("Decoded:", 
converter.decode(text_index, None)) 196 | import torch.nn as nn 197 | 198 | loss = nn.CrossEntropyLoss(ignore_index=0) 199 | # print(text[0][1:]) 200 | print(converter.encode([text[0][1:]])[0]) 201 | print(loss(converter.encode([text[0][1:]])[0], text_index)) 202 | 203 | 204 | class SentenceErrorRate(Metric): 205 | """Custom metric to compute Sentence Error Rate (SER). 206 | SER is the percentage of sentences that contain any errors.""" 207 | 208 | def __init__(self): 209 | super().__init__() 210 | self.add_state("incorrect", default=torch.tensor(0), dist_reduce_fx="sum") 211 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 212 | 213 | def update(self, preds: list, target: list): 214 | """Update state with predictions and targets. 215 | 216 | Args: 217 | preds: List of predicted sentences 218 | target: List of target sentences 219 | """ 220 | assert len(preds) == len(target), "Number of predictions and targets must match" 221 | 222 | # Count sentences with any errors 223 | incorrect = sum(1 for p, t in zip(preds, target) if p != t) 224 | self.incorrect += torch.tensor(incorrect) 225 | self.total += torch.tensor(len(preds)) 226 | 227 | def compute(self): 228 | """Compute the sentence error rate.""" 229 | return self.incorrect.float() / self.total 230 | 231 | 232 | if __name__ == "__main__": 233 | test_attn_label_converter(text=["hello "]) 234 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningModule 2 | import torch 3 | from torch import nn 4 | from loguru import logger 5 | from torchmetrics.text import CharErrorRate, WordErrorRate 6 | import time 7 | 8 | from models import get_module 9 | from utils import SentenceErrorRate 10 | 11 | from utils import CTCLabelConverter_clovaai, AttnLabelConverter 12 | from data.vocab import Vocab 13 | 14 | import os 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | class OCRModel(LightningModule): 20 | def __init__( 21 | self, 22 | batch_max_length, 23 | dali, 24 | backbone_name: str = "resnet18", 25 | seq_name: str = "bilstm", 26 | pred_name: str = "ctc", 27 | batch_size: int = 64, 28 | learning_rate: float = 1e-3, 29 | weight_decay: float = 1e-5, 30 | save_dir: str = "checkpoints", 31 | train_data_path: str = "./training_images/", 32 | ): 33 | super().__init__() 34 | self.backbone_name = backbone_name if backbone_name is not None else "resnet18" 35 | self.seq_name = seq_name 36 | self.pred_name = pred_name 37 | path = os.path.join(train_data_path, "tgt.csv") 38 | self.vocab = Vocab().get_vocab() 39 | if self.pred_name == "ctc": 40 | self.converter = CTCLabelConverter_clovaai(self.vocab, device="cuda") 41 | else: 42 | self.converter = AttnLabelConverter(self.vocab, batch_max_length=batch_max_length) 43 | self._build_model() 44 | 45 | # TODO: add optimizer, scheduler, loss, metrics 46 | if self.pred_name == "ctc": 47 | self.loss = nn.CTCLoss(blank=0, zero_infinity=True) 48 | else: 49 | self.loss = nn.CrossEntropyLoss(ignore_index=0) 50 | 51 | # Initialize metrics 52 | self.cer = CharErrorRate() 53 | self.wer = WordErrorRate() 54 | self.ser = SentenceErrorRate() 55 | self.val_predictions = [] 56 | self.val_targets = [] 57 | 58 | self.learning_rate = learning_rate 59 | self.batch_size = batch_size 60 | self.weight_decay = weight_decay 61 | self.batch_max_length = batch_max_length 62 | self.save_dir = save_dir 63 | self.dali = dali 64 | self.val_epoch_start_time = 0 # Because the sanity check fails 65 | logger.info(f"{self.dali=}") 66 | 67 | def _build_model(self): 68 | logger.info(f"{self.backbone_name}") 69 | backbone_cls, seq_module_cls, pred_module_cls = get_module(self.backbone_name, self.seq_name, self.pred_name) 70 | 71 | # Initialize backbone 72 | self.backbone = backbone_cls(input_channels=3, output_channels=512) 73 | 74 | # Initialize sequence module if provided 75 | self.seq_module = None 76 | if seq_module_cls is not None: 77 | self.seq_module = seq_module_cls( 78 | input_size=512, # ResNet backbone output channels 79 | hidden_size=256, # Hidden size for LSTM 80 | output_size=256, # Output size from sequence module 81 | ) 82 | 83 | # Initialize prediction module 84 | input_dim = ( 85 | 256 if self.seq_module else 512 # channels from VGG output 86 | ) # When no seq_module, use channel size from VGG 87 | self.pred_module = pred_module_cls( 88 | input_dim=input_dim, 89 | num_classes=len(self.converter.character), # len(self.vocab), # Number of classes including blank token 90 | ) 91 | 92 | def forward(self, x, text): 93 | x = self.backbone(x) 94 | if self.seq_module: 95 | x = self.seq_module(x) 96 | x = self.pred_module(x, text=text, is_train=self.training, batch_max_length=self.batch_max_length) 97 | return x 98 | 99 | def on_train_epoch_start(self): 100 | # Start timing for 
training epoch 101 | self.train_epoch_start_time = time.time() 102 | logger.debug("Starting training epoch") 103 | 104 | def training_step(self, batch, batch_idx): 105 | images = batch["data"] 106 | text_encoded = batch["label"] 107 | text_lengths = batch["length"] 108 | 109 | if self.pred_name == "ctc": 110 | # Forward pass 111 | logits = self(images, text=text_encoded) 112 | # Prepare CTC input 113 | log_probs = logits.log_softmax(2).permute(1, 0, 2) 114 | # Calculate input lengths 115 | preds_size = torch.LongTensor([logits.size(1)] * images.size(0)) 116 | # Calculate loss 117 | loss = self.loss(log_probs, text_encoded, preds_size, text_lengths) 118 | else: 119 | preds = self(images, text=text_encoded[:, :-1]).to(device) 120 | target = text_encoded[:, 1:].to(device) 121 | loss = self.loss(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 122 | 123 | # Log metrics 124 | self.log("train_loss", loss, prog_bar=True) 125 | return loss 126 | 127 | def validation_step(self, batch, batch_idx): 128 | images = batch["data"] 129 | text_encoded = batch["label"] 130 | text_lengths = batch["length"] 131 | # text_encoded, text_lengths = self.converter.encode(labels, batch_max_length=self.batch_max_length) 132 | 133 | labels = self.converter.decode(text_encoded, text_lengths) 134 | 135 | if self.pred_name == "ctc": 136 | # Forward pass 137 | logits = self(images, text=text_encoded) 138 | # Prepare CTC input 139 | log_probs = logits.log_softmax(2).permute(1, 0, 2) 140 | # Calculate input lengths 141 | input_lengths = torch.full( 142 | size=(logits.size(0),), 143 | fill_value=logits.size(1), 144 | dtype=torch.long, 145 | device=self.device, 146 | ) 147 | # Calculate loss 148 | loss = self.loss(log_probs, text_encoded, input_lengths, text_lengths) 149 | 150 | # Get predictions for metrics 151 | preds = log_probs.argmax(2).permute(1, 0).detach().cpu() 152 | pred_texts = self.converter.decode(preds) 153 | 154 | else: 155 | # Attention model validation 156 | preds = self(images, text=text_encoded[:, :-1]).to(device) 157 | target = text_encoded[:, 1:].to(device) # Shift target by 1 since we predict next char 158 | loss = self.loss(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 159 | 160 | # Get predictions for metrics 161 | pred_size = torch.LongTensor([preds.size(1)] * preds.size(0)) 162 | _, pred_index = preds.max(2) 163 | pred_texts = self.converter.decode(pred_index, pred_size) 164 | 165 | # Store predictions and targets for epoch end metrics 166 | self.val_predictions.extend(pred_texts) 167 | self.val_targets.extend(labels) 168 | 169 | self.log("val_loss", loss, prog_bar=True) 170 | return loss 171 | 172 | def predict_step(self, batch, batch_idx): 173 | images = batch["data"] 174 | 175 | if self.pred_name == "ctc": 176 | # Forward pass 177 | logits = self(images, text=None) 178 | # Prepare CTC input 179 | log_probs = logits.log_softmax(2).permute(1, 0, 2) 180 | 181 | # Get predictions for metrics 182 | preds = log_probs.argmax(2).permute(1, 0).detach().cpu() 183 | pred_texts = self.converter.decode(preds) 184 | 185 | else: 186 | # Attention model validation 187 | preds = self(images, text=None).to(device) 188 | # Get predictions for metrics 189 | _, pred_index = preds.max(2) 190 | pred_texts = self.converter.decode(pred_index, None) 191 | 192 | return pred_texts 193 | 194 | def on_train_start(self): 195 | # Log hyperparameters to the logger 196 | hyperparams = { 197 | "backbone_name": self.backbone_name, 198 | "seq_name": self.seq_name, 199 | "pred_name": self.pred_name, 200 
| "batch_size": self.batch_size, 201 | "learning_rate": self.learning_rate, 202 | "weight_decay": self.weight_decay, 203 | "batch_max_length": self.batch_max_length, 204 | "max_epochs": self.trainer.max_epochs, 205 | "dali": self.dali, 206 | } 207 | 208 | self.logger.log_hyperparams(hyperparams) 209 | logger.info(f"Logged hyperparameters: {hyperparams}") 210 | 211 | def on_train_epoch_end(self): 212 | # Calculate and log training epoch time 213 | train_epoch_time = time.time() - self.train_epoch_start_time 214 | logger.info(f"Training epoch {self.current_epoch} completed in {train_epoch_time:.2f} seconds") 215 | self.log("train_epoch_time", train_epoch_time) 216 | 217 | # Save model after each training epoch 218 | epoch = self.current_epoch 219 | save_path = f"{self.save_dir}/model_train_epoch_{epoch}.ckpt" 220 | self.trainer.save_checkpoint(save_path) 221 | logger.info(f"Saved model checkpoint after training epoch {epoch} to {save_path}") 222 | # Reset 223 | if self.dali: 224 | self.trainer.datamodule.train_dataloader.reset() 225 | 226 | def on_validation_epoch_start(self): 227 | # Reset stored predictions and targets 228 | self.val_predictions = [] 229 | self.val_targets = [] 230 | 231 | # Start timing for validation epoch 232 | self.val_epoch_start_time = time.time() 233 | logger.debug("Starting validation epoch") 234 | 235 | def on_validation_epoch_end(self): 236 | # Calculate and log validation epoch time 237 | val_epoch_time = time.time() - self.val_epoch_start_time 238 | logger.info(f"Validation epoch {self.current_epoch} completed in {val_epoch_time:.2f} seconds") 239 | self.log("val_epoch_time", val_epoch_time) 240 | 241 | # Calculate and log CER 242 | cer = self.cer(self.val_predictions, self.val_targets) 243 | self.log("val_cer", cer, prog_bar=True) 244 | logger.info(f"Validation CER: {cer:.4f}") 245 | 246 | # Calculate and log WER 247 | wer = self.wer(self.val_predictions, self.val_targets) 248 | self.log("val_wer", wer, prog_bar=True) 249 | logger.info(f"Validation WER: {wer:.4f}") 250 | 251 | # Calculate and log SER 252 | ser = self.ser(self.val_predictions, self.val_targets) 253 | self.log("val_ser", ser, prog_bar=True) 254 | logger.info(f"Validation SER: {ser:.4f}") 255 | 256 | # Save model after validation with metrics in filename 257 | epoch = self.current_epoch 258 | val_loss = self.trainer.callback_metrics.get("val_loss", 0) 259 | save_path = f"{self.save_dir}/model_val_epoch_{epoch}_loss_{val_loss:.4f}_cer_{cer:.4f}_wer_{wer:.4f}.ckpt" 260 | self.trainer.save_checkpoint(save_path) 261 | logger.info(f"Saved model checkpoint after validation epoch {epoch} to {save_path}") 262 | 263 | # Clear predictions and targets 264 | self.val_predictions = [] 265 | self.val_targets = [] 266 | # Reset 267 | if self.dali: 268 | self.trainer.datamodule.val_dataloader.reset() 269 | 270 | def evaluate(self, batch, batch_idx): 271 | # TODO: implement evaluation 272 | pass 273 | 274 | def configure_optimizers(self): 275 | # Group parameters by module for different learning rates 276 | param_groups = [ 277 | {"params": self.backbone.parameters(), "lr": self.learning_rate * 0.1}, 278 | {"params": self.pred_module.parameters(), "lr": self.learning_rate}, 279 | ] 280 | 281 | # Only add seq_module parameters if it exists 282 | if self.seq_module: 283 | param_groups.insert(1, {"params": self.seq_module.parameters(), "lr": self.learning_rate}) 284 | 285 | optimizer = torch.optim.AdamW(param_groups, weight_decay=self.weight_decay) 286 | 287 | # Use CosineAnnealingLR for learning rate scheduling 288 
| scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 289 | optimizer, 290 | T_max=self.trainer.max_epochs, 291 | eta_min=self.learning_rate / 100, # Minimum learning rate at the end of schedule 292 | ) 293 | 294 | return { 295 | "optimizer": optimizer, 296 | "lr_scheduler": { 297 | "scheduler": scheduler, 298 | "interval": "epoch", 299 | "frequency": 1, 300 | }, 301 | } 302 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from torch.utils.data import DataLoader 3 | from nvidia.dali.plugin.pytorch import LastBatchPolicy 4 | 5 | from data import ( 6 | OCRDataset, 7 | OCRCollator, 8 | data_transforms, 9 | data_transforms_2, 10 | process_tgt, 11 | Vocab, 12 | ExternalInputCallable, 13 | LightningWrapper, 14 | ) 15 | from utils import AttnLabelConverter, CTCLabelConverter 16 | from loguru import logger 17 | import os 18 | from nvidia.dali.pipeline import pipeline_def 19 | import nvidia.dali.types as types 20 | import nvidia.dali.fn as fn 21 | import numpy as np 22 | 23 | 24 | class OCRDataModule(LightningDataModule): 25 | def __init__( 26 | self, 27 | batch_max_length, 28 | frac, 29 | dali: bool = False, 30 | train_data_path: str = "./training_images/", 31 | val_data_path: str = "./validation_images/", 32 | batch_size: int = 32, 33 | num_workers: int = 4, 34 | pred_name: str = "attn", 35 | ): 36 | logger.debug(f"{train_data_path=}") 37 | logger.debug(f"{val_data_path=}") 38 | logger.debug(f"{batch_size=}") 39 | logger.debug(f"{num_workers=}") 40 | logger.debug(f"{batch_max_length=}") 41 | 42 | super().__init__() 43 | self.train_data_path = train_data_path 44 | self.val_data_path = val_data_path 45 | self.batch_size = batch_size 46 | self.num_workers = num_workers 47 | self.batch_max_length = batch_max_length 48 | self.dali = dali 49 | self.pred_name = pred_name 50 | self.frac = frac 51 | self.collator = OCRCollator() 52 | 53 | # Save hyperparameters for logging 54 | self.save_hyperparameters() 55 | 56 | def train_dataloader(self): 57 | self.train_data = OCRDataset( 58 | self.train_data_path, 59 | transform=data_transforms_2["train"], 60 | batch_max_length=self.batch_max_length, 61 | pred_name=self.pred_name, 62 | frac=self.frac, 63 | ) 64 | return DataLoader( 65 | self.train_data, 66 | batch_size=self.batch_size, 67 | num_workers=self.num_workers, 68 | pin_memory=True, 69 | collate_fn=self.collator, 70 | persistent_workers=True, 71 | shuffle=True, 72 | ) 73 | 74 | def val_dataloader(self): 75 | self.val_data = OCRDataset( 76 | self.val_data_path, 77 | transform=data_transforms_2["val"], 78 | batch_max_length=self.batch_max_length, 79 | pred_name=self.pred_name, 80 | frac=self.frac, 81 | ) 82 | return DataLoader( 83 | self.val_data, 84 | batch_size=self.batch_size, 85 | num_workers=self.num_workers, 86 | pin_memory=True, 87 | collate_fn=self.collator, 88 | persistent_workers=True, 89 | ) 90 | 91 | 92 | class DALI_OCRDataModule(LightningDataModule): 93 | def __init__( 94 | self, 95 | batch_max_length, 96 | frac, 97 | dali: bool = True, 98 | train_data_path: str = "./training_images/", 99 | val_data_path: str = "./validation_images/", 100 | batch_size: int = 32, 101 | num_workers: int = 4, 102 | pred_name: str = "attn", 103 | ): 104 | logger.debug(f"{train_data_path=}") 105 | logger.debug(f"{val_data_path=}") 106 | logger.debug(f"{batch_size=}") 107 | logger.debug(f"{num_workers=}") 108 | 
logger.debug(f"{batch_max_length=}") 109 | 110 | super().__init__() 111 | self.train_data_path = train_data_path 112 | self.val_data_path = val_data_path 113 | self.batch_size = batch_size 114 | self.num_workers = num_workers 115 | self.batch_max_length = batch_max_length 116 | self.dali = dali 117 | self.pred_name = pred_name 118 | 119 | # Save hyperparameters for logging 120 | self.save_hyperparameters() 121 | 122 | logger.debug("Get Vocab") 123 | path = os.path.join(self.train_data_path, "tgt.csv") 124 | vocab = Vocab().get_vocab() 125 | logger.debug(f"{pred_name=}") 126 | if pred_name == "ctc": 127 | self.converter = CTCLabelConverter(vocab, device="cpu") 128 | else: 129 | self.converter = AttnLabelConverter(vocab, batch_max_length=self.batch_max_length, device="cpu") 130 | logger.debug("Processing tgt.csv file") 131 | self.train_images_names, self.train_labels = process_tgt( 132 | self.train_data_path, batch_max_length=self.batch_max_length, frac=frac 133 | ) 134 | self.val_images_names, self.val_labels = process_tgt( 135 | self.val_data_path, batch_max_length=self.batch_max_length, frac=frac 136 | ) 137 | logger.debug("Done!") 138 | # logger.debug(f"{len(self.train_images_names)=}") 139 | # logger.debug(f"{self.batch_size=}") 140 | self.steps_per_epoch = len(self.train_images_names) // self.batch_size 141 | logger.debug(f"{self.steps_per_epoch=}") 142 | self.train_data_path = os.path.join(self.train_data_path, "images") 143 | self.val_data_path = os.path.join(self.val_data_path, "images") 144 | 145 | self.MEAN = np.asarray([0.485, 0.456, 0.406])[None, None, :] 146 | self.STD = np.asarray([0.229, 0.224, 0.225])[None, None, :] 147 | self.SCALE = 1 / 255.0 148 | 149 | def train_dataloader(self): 150 | logger.debug("Building train DALI pipelines...") 151 | train_pipeline = self.get_dali_train_pipeline_aug(batch_size=self.batch_size, num_threads=self.num_workers) 152 | train_pipeline.build() 153 | logger.debug("Train DALI pipelines built.") 154 | 155 | self.train_dataloader = LightningWrapper( 156 | pipelines=train_pipeline, 157 | output_map=["data", "label", "length"], 158 | dataset_size=self.steps_per_epoch, 159 | auto_reset=False, 160 | last_batch_policy=LastBatchPolicy.FILL, 161 | # dynamic_shape=True 162 | ) 163 | # self.train_dataloader.pipelines.run() 164 | return self.train_dataloader 165 | 166 | def val_dataloader(self): 167 | logger.debug("Building val DALI pipelines...") 168 | val_pipeline = self.get_dali_val_pipeline(batch_size=self.batch_size, num_threads=self.num_workers) 169 | val_pipeline.build() 170 | logger.debug("Val DALI pipelines built.") 171 | # self.val_dataloader = DALIClassificationIterator( 172 | # pipelines=val_pipeline, 173 | # auto_reset=True, 174 | # ) 175 | self.val_dataloader = LightningWrapper( 176 | pipelines=val_pipeline, 177 | output_map=["data", "label", "length"], 178 | dataset_size=len(self.val_images_names) // self.batch_size, 179 | auto_reset=False, 180 | last_batch_policy=LastBatchPolicy.FILL, 181 | # dynamic_shape=True 182 | ) 183 | return self.val_dataloader 184 | 185 | @pipeline_def( 186 | num_threads=8, 187 | batch_size=32, 188 | device_id=0, 189 | py_start_method="spawn", 190 | exec_dynamic=True, 191 | ) 192 | def get_dali_train_pipeline(self): 193 | # images, _ = fn.readers.file(file_root=self.val_data_path, files=self.val_data_path, random_shuffle=False, name="Reader") 194 | images, indices, length = fn.external_source( 195 | source=ExternalInputCallable( 196 | steps_per_epoch=self.steps_per_epoch, 197 | data_path=self.train_data_path, 198 
| converter=self.converter, 199 | images_names=self.train_images_names, 200 | labels=self.train_labels, 201 | batch_size=self.batch_size, 202 | ), 203 | num_outputs=3, 204 | batch=False, 205 | parallel=True, 206 | dtype=[types.UINT8, types.INT64, types.INT64], 207 | prefetch_queue_depth=8, 208 | ) 209 | images = fn.decoders.image(images, device="mixed", output_type=types.RGB) 210 | images = fn.resize(images, device="gpu", resize_y=100, dtype=types.FLOAT) 211 | images = fn.normalize( 212 | images, device="gpu", dtype=types.FLOAT, mean=self.MEAN / self.SCALE, stddev=self.STD, scale=self.SCALE 213 | ) 214 | # images = images.gpu() 215 | # images = fn.cast(images, dtype=types.FLOAT) 216 | images = fn.pad(images, fill_value=0) 217 | indices = fn.pad(indices, fill_value=0) 218 | length = fn.pad(length, fill_value=0) 219 | return images, indices.gpu(), length.gpu() 220 | 221 | @pipeline_def( 222 | num_threads=8, 223 | batch_size=32, 224 | device_id=0, 225 | py_start_method="spawn", 226 | exec_dynamic=True, 227 | ) 228 | def get_dali_val_pipeline(self): 229 | # images, _ = fn.readers.file(file_root=self.val_data_path, files=self.val_data_path, random_shuffle=False, name="Reader") 230 | images, indices, length = fn.external_source( 231 | source=ExternalInputCallable( 232 | steps_per_epoch=len(self.val_images_names) // self.batch_size, 233 | data_path=self.val_data_path, 234 | converter=self.converter, 235 | images_names=self.val_images_names, 236 | labels=self.val_labels, 237 | batch_size=self.batch_size, 238 | ), 239 | num_outputs=3, 240 | batch=False, 241 | parallel=True, 242 | dtype=[types.UINT8, types.INT64, types.INT64], 243 | prefetch_queue_depth=8, 244 | ) 245 | images = fn.decoders.image(images, device="mixed", output_type=types.RGB) 246 | images = fn.resize(images, device="gpu", resize_y=100, dtype=types.FLOAT) 247 | images = fn.normalize( 248 | images, device="gpu", dtype=types.FLOAT, mean=self.MEAN / self.SCALE, stddev=self.STD, scale=self.SCALE 249 | ) 250 | # images = images.gpu() 251 | # images = fn.cast(images, dtype=types.FLOAT) 252 | images = fn.pad(images, fill_value=0) 253 | indices = fn.pad(indices, fill_value=0) 254 | length = fn.pad(length, fill_value=0) 255 | return images, indices.gpu(), length.gpu() 256 | 257 | @pipeline_def( 258 | num_threads=8, 259 | batch_size=32, 260 | device_id=0, 261 | py_start_method="spawn", 262 | exec_dynamic=True, 263 | ) 264 | def get_dali_train_pipeline_aug(self): 265 | images, indices, length = fn.external_source( 266 | source=ExternalInputCallable( 267 | steps_per_epoch=self.steps_per_epoch, 268 | data_path=self.train_data_path, 269 | converter=self.converter, 270 | images_names=self.train_images_names, 271 | labels=self.train_labels, 272 | batch_size=self.batch_size, 273 | ), 274 | num_outputs=3, 275 | batch=False, 276 | parallel=True, 277 | dtype=[types.UINT8, types.INT64, types.INT64], 278 | prefetch_queue_depth=8, 279 | ) 280 | images = fn.decoders.image(images, device="mixed", output_type=types.RGB) 281 | images = fn.rotate( 282 | images, 283 | device="gpu", 284 | angle=fn.random.uniform(range=[-1, 1]), 285 | dtype=types.FLOAT, 286 | ) 287 | images = fn.resize(images, device="gpu", resize_y=100) 288 | images = fn.color_twist( 289 | images, 290 | brightness=fn.random.uniform(range=[0.8, 1.2]), 291 | contrast=fn.random.uniform(range=[0.8, 1.2]), 292 | saturation=fn.random.uniform(range=[0.8, 1.2]), 293 | hue=fn.random.uniform(range=[0, 0.3]), 294 | ) 295 | images = fn.warp_affine( 296 | images, 297 | 
matrix=fn.transforms.scale(scale=fn.random.uniform(range=[0.9, 1], shape=[2])), 298 | fill_value=0, 299 | inverse_map=False, 300 | ) 301 | images = fn.noise.gaussian(images, mean=0.0, stddev=fn.random.uniform(range=[-10, 10])) 302 | images = fn.normalize( 303 | images, device="gpu", dtype=types.FLOAT, mean=self.MEAN / self.SCALE, stddev=self.STD, scale=self.SCALE 304 | ) 305 | images = fn.pad(images, fill_value=0) 306 | indices = fn.pad(indices, fill_value=0) 307 | length = fn.pad(length, fill_value=0) 308 | return images, indices.gpu(), length.gpu() 309 | -------------------------------------------------------------------------------- /inference/inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "\n", 11 | "sys.path.append(\"..\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "/home/qsvm/miniconda3/envs/new_ocr/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 24 | " from .autonotebook import tqdm as notebook_tqdm\n", 25 | "/home/qsvm/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py:42: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.Output is equivalent up to float precision.\n", 26 | " warnings.warn(\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "from main import OCRModel\n", 32 | "from inference import Inference" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "/home/qsvm/miniconda3/envs/new_ocr/lib/python3.11/site-packages/huggingface_hub/commands/download.py:139: FutureWarning: Ignoring --local-dir-use-symlinks. Downloading to a local directory does not use symlinks anymore.\n", 45 | " warnings.warn(\n", 46 | "checkpoints/pretrained_viet/viet.ckpt\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "! 
huggingface-cli download ducto489/ocr_model viet.ckpt --repo-type model --local-dir checkpoints/pretrained_viet --local-dir-use-symlinks False" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "checkpoint = \"checkpoints/pretrained_viet/viet.ckpt\"" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 18, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stderr", 70 | "output_type": "stream", 71 | "text": [ 72 | "\u001b[32m2025-05-25 05:31:01.770\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmain\u001b[0m:\u001b[36m_build_model\u001b[0m:\u001b[36m68\u001b[0m - \u001b[1mresnet18\u001b[0m\n" 73 | ] 74 | }, 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "backbone_name: resnet18\n", 80 | "backbone_fac: {'resnet18': , 'resnet50': , 'vgg': }\n" 81 | ] 82 | }, 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "\u001b[32m2025-05-25 05:31:02.607\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmain\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mself.dali=True\u001b[0m\n" 88 | ] 89 | }, 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "Predictions: ['Chennai (Madras)']\n", 95 | "Batch inference time: 0.082s\n", 96 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 97 | "Batch inference time: 0.093s\n", 98 | "Predictions: ['CLEANED']\n", 99 | "Batch inference time: 0.073s\n", 100 | "Predictions: ['Chennai (Madras)']\n", 101 | "Batch inference time: 0.074s\n", 102 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 103 | "Batch inference time: 0.074s\n", 104 | "Predictions: ['CLEANED']\n", 105 | "Batch inference time: 0.072s\n", 106 | "Predictions: ['Chennai (Madras)']\n", 107 | "Batch inference time: 0.074s\n", 108 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 109 | "Batch inference time: 0.075s\n", 110 | "Predictions: ['CLEANED']\n", 111 | "Batch inference time: 0.075s\n", 112 | "Predictions: ['Chennai (Madras)']\n", 113 | "Batch inference time: 0.076s\n", 114 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 115 | "Batch inference time: 0.074s\n", 116 | "Predictions: ['CLEANED']\n", 117 | "Batch inference time: 0.091s\n", 118 | "Predictions: ['Chennai (Madras)']\n", 119 | "Batch inference time: 0.075s\n", 120 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 121 | "Batch inference time: 0.075s\n", 122 | "Predictions: ['CLEANED']\n", 123 | "Batch inference time: 0.072s\n", 124 | "Predictions: ['Chennai (Madras)']\n", 125 | "Batch inference time: 0.076s\n", 126 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 127 | "Batch inference time: 0.079s\n", 128 | "Predictions: ['CLEANED']\n", 129 | "Batch inference time: 0.073s\n", 130 | "Predictions: ['Chennai (Madras)']\n", 131 | "Batch inference time: 0.076s\n", 132 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 133 | "Batch inference time: 0.077s\n", 134 | "Predictions: ['CLEANED']\n", 135 | "Batch inference time: 0.073s\n", 136 | "Predictions: ['Chennai (Madras)']\n", 137 | "Batch inference time: 0.074s\n", 138 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 139 | "Batch inference 
time: 0.074s\n", 140 | "Predictions: ['CLEANED']\n", 141 | "Batch inference time: 0.074s\n", 142 | "Predictions: ['Chennai (Madras)']\n", 143 | "Batch inference time: 0.078s\n", 144 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 145 | "Batch inference time: 0.075s\n", 146 | "Predictions: ['CLEANED']\n", 147 | "Batch inference time: 0.072s\n", 148 | "Predictions: ['Chennai (Madras)']\n", 149 | "Batch inference time: 0.075s\n", 150 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 151 | "Batch inference time: 0.075s\n", 152 | "Predictions: ['CLEANED']\n", 153 | "Batch inference time: 0.071s\n", 154 | "Predictions: ['Chennai (Madras)']\n", 155 | "Batch inference time: 0.074s\n", 156 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 157 | "Batch inference time: 0.078s\n", 158 | "Predictions: ['CLEANED']\n", 159 | "Batch inference time: 0.074s\n", 160 | "Predictions: ['Chennai (Madras)']\n", 161 | "Batch inference time: 0.075s\n", 162 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 163 | "Batch inference time: 0.075s\n", 164 | "Predictions: ['CLEANED']\n", 165 | "Batch inference time: 0.071s\n", 166 | "Predictions: ['Chennai (Madras)']\n", 167 | "Batch inference time: 0.074s\n", 168 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 169 | "Batch inference time: 0.074s\n", 170 | "Predictions: ['CLEANED']\n", 171 | "Batch inference time: 0.072s\n", 172 | "Predictions: ['Chennai (Madras)']\n", 173 | "Batch inference time: 0.078s\n", 174 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 175 | "Batch inference time: 0.076s\n", 176 | "Predictions: ['CLEANED']\n", 177 | "Batch inference time: 0.077s\n", 178 | "Predictions: ['Chennai (Madras)']\n", 179 | "Batch inference time: 0.075s\n", 180 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 181 | "Batch inference time: 0.076s\n", 182 | "Predictions: ['CLEANED']\n", 183 | "Batch inference time: 0.072s\n", 184 | "Predictions: ['Chennai (Madras)']\n", 185 | "Batch inference time: 0.074s\n", 186 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 187 | "Batch inference time: 0.076s\n", 188 | "Predictions: ['CLEANED']\n", 189 | "Batch inference time: 0.075s\n", 190 | "Predictions: ['Chennai (Madras)']\n", 191 | "Batch inference time: 0.076s\n", 192 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 193 | "Batch inference time: 0.074s\n", 194 | "Predictions: ['CLEANED']\n", 195 | "Batch inference time: 0.071s\n", 196 | "Predictions: ['Chennai (Madras)']\n", 197 | "Batch inference time: 0.076s\n", 198 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 199 | "Batch inference time: 0.074s\n", 200 | "Predictions: ['CLEANED']\n", 201 | "Batch inference time: 0.071s\n", 202 | "Predictions: ['Chennai (Madras)']\n", 203 | "Batch inference time: 0.076s\n", 204 | "Predictions: ['because the sources of the pollution would be outside of urban areas']\n", 205 | "Batch inference time: 0.078s\n", 206 | "Predictions: ['CLEANED']\n", 207 | "Batch inference time: 0.073s\n" 208 | ] 209 | }, 210 | { 211 | "ename": "KeyboardInterrupt", 212 | "evalue": "", 213 | "output_type": "error", 214 | "traceback": [ 215 | 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 216 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 217 | "Cell \u001b[0;32mIn[18], line 13\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m predict_dataloader:\n\u001b[1;32m 12\u001b[0m batch_start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m---> 13\u001b[0m preds \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m batch_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m batch_start\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPredictions: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpreds\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", 218 | "File \u001b[0;32m~/temp2/lib_ocr/inference/../main.py:187\u001b[0m, in \u001b[0;36mOCRModel.predict_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 183\u001b[0m pred_texts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconverter\u001b[38;5;241m.\u001b[39mdecode(preds)\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# Attention model validation\u001b[39;00m\n\u001b[0;32m--> 187\u001b[0m preds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mimages\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# Get predictions for metrics\u001b[39;00m\n\u001b[1;32m 189\u001b[0m _, pred_index \u001b[38;5;241m=\u001b[39m preds\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;241m2\u001b[39m)\n", 219 | "File \u001b[0;32m~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", 220 | "File \u001b[0;32m~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this 
function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", 221 | "File \u001b[0;32m~/temp2/lib_ocr/inference/../main.py:96\u001b[0m, in \u001b[0;36mOCRModel.forward\u001b[0;34m(self, x, text)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mseq_module:\n\u001b[1;32m 95\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mseq_module(x)\n\u001b[0;32m---> 96\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpred_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_train\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_max_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_max_length\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", 222 | "File \u001b[0;32m~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
**kwargs)\n",
223 | "File ~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)\n-> 1747     return forward_call(*args, **kwargs)\n",
224 | "File ~/temp2/lib_ocr/inference/../models/pred_modules.py:60, in Attention.forward(self, batch_H, batch_max_length, text, is_train)\n---> 60     hidden = self.attention_cell(hidden, batch_H, char_onehots)\n",
225 | "File ~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)\n-> 1736     return self._call_impl(*args, **kwargs)\n",
226 | "File ~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)\n-> 1747     return forward_call(*args, **kwargs)\n",
227 | "File ~/temp2/lib_ocr/inference/../models/pred_modules.py:97, in AttentionCell.forward(self, prev_hidden, batch_H, char_onehots)\n---> 97     e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))\n",
228 | "File ~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)\n-> 1736     return self._call_impl(*args, **kwargs)\n",
229 | "File ~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)\n-> 1747     return forward_call(*args, **kwargs)\n",
230 | "File ~/miniconda3/envs/new_ocr/lib/python3.11/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)\n--> 125     return F.linear(input, self.weight, self.bias)\n",
231 | "KeyboardInterrupt: "
232 | ]
233 | }
234 | ],
235 | "source": [
236 | "inference = Inference(image_path=\"images\")\n",
237 | "predict_dataloader = inference.predict_dataloader()\n",
238 | "model = OCRModel.load_from_checkpoint(\n",
239 | "    checkpoint, strict=True, batch_max_length=200, dali=True, map_location=\"cuda\", pred_name=\"attn\"\n",
240 | ")\n",
241 | "model.eval()  # eval mode: disables dropout and uses running batch-norm statistics\n",
242 | "\n",
243 | "import time\n",
244 | "\n",
245 | "start_time = time.time()\n",
246 | "for batch in predict_dataloader:\n",
247 | "    batch_start = time.time()\n",
248 | "    preds = model.predict_step(batch, 0)  # batch_idx is hard-coded to 0 here\n",
249 | "    batch_time = time.time() - batch_start\n",
250 | "    print(f\"Predictions: {preds}\")\n",
251 | "    print(f\"Batch inference time: {batch_time:.3f}s\")\n",
252 | "\n",
253 | "total_time = time.time() - start_time\n",
254 | "print(f\"\\nTotal inference time: {total_time:.3f}s\")"
255 | ]
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {},
260 | "source": [
261 | "Test inference time with VietOCR"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 9,
267 | "metadata": {},
268 | "outputs": [
269 | {
270 | "name": "stdout",
271 | "output_type": "stream",
272 | "text": [
273 | "Model weight /tmp/vgg_transformer.pth exsits. Ignore download!\n",
274 | "VietOCR inference time: 0.565s\n"
275 | ]
276 | }
277 | ],
278 | "source": [
279 | "from vietocr.tool.predictor import Predictor\n",
280 | "from vietocr.tool.config import Cfg\n",
281 | "from PIL import Image\n",
282 | "\n",
283 | "# config = Cfg.load_config_from_file('config.yml')  # use your own config exported at training time if you changed any parameters\n",
284 | "config = Cfg.load_config_from_name('vgg_transformer')  # use the default pretrained config\n",
285 | "# config['weights'] = '/home/qsvm/.cache/torch/hub/checkpoints/vgg19_bn-c79401a0.pth'  # path to your trained weights, or leave commented out to use the released pretrained model\n",
286 | "config['device'] = 'cuda:0'  # device to run on: 'cuda:0', 'cuda:1', or 'cpu'\n",
287 | "\n",
288 | "detector = Predictor(config)\n",
289 | "\n",
290 | "img = './images/1/paper_img_002844.png'\n",
291 | "img = Image.open(img)\n",
292 | "# predict\n",
293 | "start_time = time.time()\n",
294 | "s = detector.predict(img, return_prob=False)  # set return_prob=True to also return the prediction probability\n",
295 | "inference_time = time.time() - start_time\n",
296 | "print(f\"VietOCR inference time: {inference_time:.3f}s\")"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 20,
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "name": "stderr",
306 | "output_type": "stream",
307 | "text": [
308 | "10935it [00:29, 376.59it/s]\n",
309 | "/home/qsvm/miniconda3/envs/new_ocr/lib/python3.11/site-packages/vietocr/tool/predictor.py:20: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
310 | "  model.load_state_dict(torch.load(weights, map_location=torch.device(device)))\n"
311 | ]
312 | },
313 | {
314 | "name": "stdout",
315 | "output_type": "stream",
316 | "text": [
317 | "VietOCR inference time: 0.030s\n"
318 | ]
319 | }
320 | ],
321 | "source": [
322 | "from vietocr.tool.predictor import Predictor\n",
323 | "from vietocr.tool.config import Cfg\n",
324 | "from PIL import Image\n",
325 | "\n",
326 | "# config = Cfg.load_config_from_file('config.yml')  # use your own config exported at training time if you changed any parameters\n",
327 | "config = Cfg.load_config_from_name('vgg_seq2seq')  # use the default pretrained config\n",
328 | "# config['weights'] = '/home/qsvm/.cache/torch/hub/checkpoints/vgg19_bn-c79401a0.pth'  # path to your trained weights, or leave commented out to use the released pretrained model\n",
329 | "config['device'] = 'cuda:0'  # device to run on: 'cuda:0', 'cuda:1', or 'cpu'\n",
330 | "\n",
331 | "detector = Predictor(config)\n",
332 | "\n",
333 | "img = './images/1/paper_img_002844.png'\n",
334 | "img = Image.open(img)\n",
335 | "# predict\n",
336 | "start_time = time.time()\n",
337 | "s = detector.predict(img, return_prob=False)  # set return_prob=True to also return the prediction probability\n",
338 | "inference_time = time.time() - start_time\n",
339 | "print(f\"VietOCR inference time: {inference_time:.3f}s\")"
340 | ]
341 | }
342 | ],
343 | "metadata": {
344 | "kernelspec": {
345 | "display_name": "new_ocr",
346 | "language": "python",
347 | "name": "python3"
348 | },
349 | "language_info": {
350 | "codemirror_mode": {
351 | "name": "ipython",
352 | "version": 3
353 | },
354 | "file_extension": ".py",
355 | "mimetype": "text/x-python",
356 | "name": "python",
357 | "nbconvert_exporter": "python",
358 | "pygments_lexer": "ipython3",
359 | "version": "3.11.11"
360 | }
361 | },
362 | "nbformat": 4,
363 | "nbformat_minor": 2
364 | }
365 |
--------------------------------------------------------------------------------