├── .gitignore ├── LICENSE ├── README.md ├── base ├── __init__.py ├── common_util.py ├── driver.py ├── meter.py └── torch_utils │ ├── __init__.py │ ├── dl_util.py │ └── scheduler_util.py ├── config └── base.yaml ├── doc └── encoder_arch.jpeg ├── examples └── test_forward.py ├── experiment ├── __init__.py ├── base_experiment.py └── docparser_experiment.py ├── logs └── .gitignore ├── models ├── __init__.py ├── config.json ├── configuration_docparser.py ├── convnext.py └── modeling_docparser.py ├── mydatasets ├── __init__.py └── docparser_dataset.py ├── requirements.txt └── train └── train_experiment.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### JetBrains template 2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 4 | 5 | # User-specific stuff 6 | .idea/**/workspace.xml 7 | .idea/**/tasks.xml 8 | .idea/**/usage.statistics.xml 9 | .idea/**/dictionaries 10 | .idea/**/shelf 11 | 12 | # Generated files 13 | .idea/**/contentModel.xml 14 | 15 | # Sensitive or high-churn files 16 | .idea/**/dataSources/ 17 | .idea/**/dataSources.ids 18 | .idea/**/dataSources.local.xml 19 | .idea/**/sqlDataSources.xml 20 | .idea/**/dynamic.xml 21 | .idea/**/uiDesigner.xml 22 | .idea/**/dbnavigator.xml 23 | 24 | # Gradle 25 | .idea/**/gradle.xml 26 | .idea/**/libraries 27 | 28 | # Gradle and Maven with auto-import 29 | # When using Gradle or Maven with auto-import, you should exclude module files, 30 | # since they will be recreated, and may cause churn. Uncomment if using 31 | # auto-import. 32 | # .idea/artifacts 33 | # .idea/compiler.xml 34 | # .idea/jarRepositories.xml 35 | # .idea/modules.xml 36 | # .idea/*.iml 37 | # .idea/modules 38 | # *.iml 39 | # *.ipr 40 | 41 | # CMake 42 | cmake-build-*/ 43 | 44 | # Mongo Explorer plugin 45 | .idea/**/mongoSettings.xml 46 | 47 | # File-based project format 48 | *.iws 49 | 50 | # IntelliJ 51 | out/ 52 | 53 | # mpeltonen/sbt-idea plugin 54 | .idea_modules/ 55 | 56 | # JIRA plugin 57 | atlassian-ide-plugin.xml 58 | 59 | # Cursive Clojure plugin 60 | .idea/replstate.xml 61 | 62 | # Crashlytics plugin (for Android Studio and IntelliJ) 63 | com_crashlytics_export_strings.xml 64 | crashlytics.properties 65 | crashlytics-build.properties 66 | fabric.properties 67 | 68 | # Editor-based Rest Client 69 | .idea/httpRequests 70 | 71 | # Android studio 3.1+ serialized cache file 72 | .idea/caches/build_file_checksums.ser 73 | 74 | ### macOS template 75 | # General 76 | .DS_Store 77 | .AppleDouble 78 | .LSOverride 79 | 80 | # Icon must end with two \r 81 | Icon 82 | 83 | # Thumbnails 84 | ._* 85 | 86 | # Files that might appear in the root of a volume 87 | .DocumentRevisions-V100 88 | .fseventsd 89 | .Spotlight-V100 90 | .TemporaryItems 91 | .Trashes 92 | .VolumeIcon.icns 93 | .com.apple.timemachine.donotpresent 94 | 95 | # Directories potentially created on remote AFP share 96 | .AppleDB 97 | .AppleDesktop 98 | Network Trash Folder 99 | Temporary Items 100 | .apdisk 101 | 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Nuo Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software 
without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DocParser: End-to-end OCR-free Information Extraction from Visually Rich Documents 2 | 3 | This is an unofficial Pytorch implementation of DocParser. 4 | 5 |
6 | ![DocParser encoder architecture](doc/encoder_arch.jpeg) 7 | The architecture of DocParser's Encoder 8 |
9 | 10 | ## News 11 | - **Sep 1st**, released the ConvNeXt weights [here](https://drive.google.com/drive/folders/1ZsvXgULWWm3sR6ZKxGlHmGO5-ESvXl1J?usp=drive_link). Please note that this checkpoint was trained with a CTC head on an OCR task and can only be used to initialize the ConvNeXt part of DocParser during pretraining. It is NOT intended for fine-tuning on any downstream task. 12 | - **July 15th**, updated the training scripts for the Masked Document Reading Task and the model architecture. 13 | 14 | ## How to use 15 | ### 1. Set Up Environment 16 | ```shell 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ### 2. Prepare Dataset 21 | The dataset should be processed into the following format: 22 | ```json 23 | { 24 | "filepath": "path/to/image/folder", // path to image folder 25 | "filename": "file_name", // file name 26 | "extract_info": { 27 | "ocr_info": [ 28 | { 29 | "chunk": "text1" 30 | }, 31 | { 32 | "chunk": "text2" 33 | }, 34 | { 35 | "chunk": "text3" 36 | } 37 | ] 38 | } // a list of OCR chunks for filepath/filename 39 | } 40 | ``` 41 | ### 3. Start Training 42 | You can start training from ```train/train_experiment.py``` or run 43 | 44 | ```shell 45 | python train/train_experiment.py --config_file config/base.yaml 46 | ``` 47 | The training script also supports DDP training with huggingface/accelerate: 48 | ```shell 49 | accelerate launch train/train_experiment.py --config_file config/base.yaml --use_accelerate True 50 | ``` 51 | ### 4. Notes 52 | The training script currently implements only the **Masked Document Reading Step** described in the paper. The decoder weights, tokenizer and processor are borrowed from [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base). 53 | 54 | Unfortunately, there are no DocParser pre-training weights publicly available, and simply borrowing weights from Donut-base fails to benefit DocParser on any downstream task. I am currently pretraining DocParser on the two-stage tasks described in the paper; once both pretraining stages are complete and the model performs well, I intend to release it on the Huggingface hub.
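
For illustration, below is a minimal sketch (not part of the original repo) of how a single training record in the format above could be assembled and saved with plain Python. The field names (`filepath`, `filename`, `extract_info`, `ocr_info`, `chunk`) follow the JSON shown above; the helper name, the output path, and the one-JSON-file-per-sample layout are assumptions — check `mydatasets/docparser_dataset.py` for the exact layout the dataset class expects.

```python
import json
import os


def build_record(image_dir: str, image_name: str, chunks: list) -> dict:
    """Assemble one sample in the format described above (hypothetical helper)."""
    return {
        "filepath": image_dir,    # path to the image folder
        "filename": image_name,   # image file name
        "extract_info": {
            # reading-order text chunks extracted from the page
            "ocr_info": [{"chunk": text} for text in chunks],
        },
    }


if __name__ == "__main__":
    record = build_record("path/to/image/folder", "sample_0001.jpg", ["text1", "text2", "text3"])
    # Writing one JSON file per sample is an assumption; adjust to whatever
    # mydatasets/docparser_dataset.py actually consumes.
    os.makedirs("data/train", exist_ok=True)
    with open(os.path.join("data/train", "sample_0001.json"), "w", encoding="utf-8") as f:
        json.dump(record, f, ensure_ascii=False, indent=2)
```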
-------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | # create: 2021/6/9 -------------------------------------------------------------------------------- /base/common_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/6/10 3 | 4 | import os 5 | import glob 6 | import codecs 7 | import json 8 | from natsort import natsorted 9 | from base.driver import logger, PROJECT_ROOT_PATH 10 | 11 | 12 | def get_file_list(folder_path: str, p_postfix: list = None, sub_dir: bool = True) -> list: 13 | assert os.path.exists(folder_path) and os.path.isdir(folder_path) 14 | if p_postfix is None: 15 | p_postfix = ['.jpg'] 16 | if isinstance(p_postfix, str): 17 | p_postfix = [p_postfix] 18 | logger.info("begin to get files from:{}".format(folder_path)) 19 | file_list = [ 20 | x for x in glob.glob(folder_path + '/**/*.*', recursive=True) 21 | if os.path.splitext(x)[-1].lower() in p_postfix or '*'.lower() in p_postfix 22 | ] 23 | logger.info("success to get files from:{}".format(folder_path)) 24 | 25 | return natsorted(file_list) 26 | 27 | 28 | def search_file_from_dir(file_dir, file_path_pref_list, base_pref="", exts=["pdf", "PDF"], save_pref=True): 29 | assert isinstance(file_path_pref_list, list) 30 | base_dir_name = os.path.basename(file_dir) 31 | logger.info("searching exts:{} from:{}".format(exts, file_dir)) 32 | if len(base_pref.strip()) == 0: 33 | current_base_pref = "" 34 | else: 35 | current_base_pref = "{}_{}".format(base_pref, base_dir_name) 36 | for file_name in os.listdir(file_dir): 37 | file_path = os.path.join(file_dir, file_name) 38 | if os.path.isdir(file_path): 39 | search_file_from_dir(file_path, file_path_pref_list, current_base_pref, exts, save_pref) 40 | else: 41 | file_res = file_name.rsplit(".", 1) 42 | if len(file_res) == 2: 43 | file_pref, file_ext = file_res 44 | if file_ext.strip().lower() in exts: 45 | if save_pref: 46 | if len(current_base_pref.strip()) > 0: 47 | current_file_pref = "{}_{}".format(current_base_pref, file_pref) 48 | else: 49 | current_file_pref = "{}".format(file_pref) 50 | file_path_pref_list.append((file_path, current_file_pref)) 51 | else: 52 | file_path_pref_list.append(file_path) 53 | 54 | 55 | def get_absolute_file_path(file_path): 56 | if file_path.startswith("/"): 57 | return file_path 58 | else: 59 | return os.path.join(PROJECT_ROOT_PATH, file_path) 60 | 61 | 62 | def get_file_path_list(path, ext=None): 63 | if not path.startswith('/'): 64 | path = os.path.join(PROJECT_ROOT_PATH, path) 65 | # print(path) 66 | assert os.path.exists(path), 'path not exist {}'.format(path) 67 | assert ext is not None, 'ext is None' 68 | if os.path.isfile(path): 69 | return [path] 70 | file_path_list = [] 71 | for root, _, files in os.walk(path): 72 | for file in files: 73 | try: 74 | if file.rsplit('.')[-1].lower() in ext: 75 | file_path_list.append(os.path.join(root, file)) 76 | except Exception as e: 77 | pass 78 | return file_path_list 79 | 80 | 81 | # load json data 82 | def load_json(data): 83 | if isinstance(data, dict): 84 | return [data] 85 | elif isinstance(data, list): 86 | file_path_list = data 87 | elif data.endswith('.json'): 88 | file_path_list = [data] 89 | else: 90 | file_path_list = get_file_path_list(data, '.json') 91 | json_data_list = list() 92 | for file_path in file_path_list: 93 | with codecs.open(file_path, "r", 
"utf-8") as fr: 94 | json_data = json.loads(fr.read()) 95 | json_data_list.append(json_data) 96 | return json_data_list 97 | 98 | 99 | def save_params(save_dir, save_json, yml_name='config.yaml'): 100 | import yaml 101 | with open(os.path.join(save_dir, yml_name), 'w', encoding='utf-8') as f: 102 | yaml.dump(save_json, f, default_flow_style=False, encoding='utf-8', allow_unicode=True) 103 | 104 | 105 | def read_config(config_file): 106 | import anyconfig 107 | if os.path.exists(config_file): 108 | with open(config_file, "rb") as fr: 109 | config = anyconfig.load(fr) 110 | if 'base' in config: 111 | base_config_path = config['base'] 112 | if not base_config_path.startswith('/'): 113 | base_config_path = os.path.join(PROJECT_ROOT_PATH, base_config_path) 114 | elif os.path.basename(config_file) == 'base.yaml': 115 | return config 116 | else: 117 | base_config_path = os.path.join(os.path.dirname(config_file), "base.yaml") 118 | base_config = read_config(base_config_path) 119 | merged_config = base_config.copy() 120 | merge_config(config, merged_config) 121 | return merged_config 122 | else: 123 | return {} 124 | 125 | 126 | def merge_config(config, base_config): 127 | for key, _ in config.items(): 128 | if isinstance(config[key], dict) and key not in base_config: 129 | base_config[key] = config[key] 130 | elif isinstance(config[key], dict): 131 | merge_config(config[key], base_config[key]) 132 | else: 133 | if key in base_config: 134 | base_config[key] = config[key] 135 | else: 136 | base_config.update({key: config[key]}) 137 | 138 | 139 | def init_experiment_config(config_file, experiment_name): 140 | if not config_file.startswith("/"): 141 | config_file = get_absolute_file_path(config_file) 142 | input_config = read_config(config_file) 143 | experiment_base_config = read_config(os.path.join(PROJECT_ROOT_PATH, 'config', experiment_name.lower(), 144 | 'base.yaml')) 145 | merged_config = experiment_base_config.copy() 146 | merge_config(input_config, merged_config) 147 | 148 | base_config = read_config(os.path.join(PROJECT_ROOT_PATH, 'config', 149 | 'base.yaml')) 150 | final_merged_config = base_config.copy() 151 | merge_config(merged_config, final_merged_config) 152 | return final_merged_config 153 | -------------------------------------------------------------------------------- /base/driver.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/6/9 3 | 4 | import os 5 | import logging 6 | 7 | PROJECT_ROOT_PATH = os.path.abspath(os.path.join(__file__, '../../')) 8 | 9 | DATA_ROOT = os.path.join(PROJECT_ROOT_PATH, "data") 10 | CACHE_ROOT = os.path.join(DATA_ROOT, "cache") 11 | 12 | logger = logging.getLogger() 13 | stream_handler = logging.StreamHandler() 14 | 15 | log_formatter = logging.Formatter(fmt='%(asctime)s\t%(levelname)s\t%(name)s %(filename)s:%(lineno)s - %(message)s', 16 | datefmt='%Y-%m-%d %H:%M:%S') 17 | stream_handler.setFormatter(log_formatter) 18 | 19 | logger.addHandler(stream_handler) 20 | 21 | logger.setLevel(logging.INFO) 22 | -------------------------------------------------------------------------------- /base/meter.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | # create: 2021/7/16 4 | 5 | 6 | class AverageMeter: 7 | """Computes and stores the average and current value""" 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | 
def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | -------------------------------------------------------------------------------- /base/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # datetime:2022/4/28 11:50 上午 4 | # software: PyCharm 5 | -------------------------------------------------------------------------------- /base/torch_utils/dl_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/6/17 3 | 4 | import json 5 | import math 6 | import torch 7 | import random 8 | import numpy as np 9 | from torch import nn 10 | import torch.optim as optim 11 | from base.driver import logger 12 | from collections import OrderedDict 13 | from torch.optim import lr_scheduler 14 | from timm.scheduler.cosine_lr import CosineLRScheduler 15 | from timm.scheduler.step_lr import StepLRScheduler 16 | from transformers.optimization import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, \ 17 | get_linear_schedule_with_warmup 18 | from base.torch_utils.scheduler_util import LinearLRScheduler, get_cosine_schedule_by_epochs, \ 19 | get_stairs_schedule_with_warmup 20 | 21 | 22 | def seed_all(random_seed): 23 | if random_seed is not None: 24 | random.seed(random_seed) 25 | np.random.seed(random_seed) 26 | torch.manual_seed(random_seed) 27 | torch.cuda.manual_seed_all(random_seed) 28 | torch.backends.cudnn.deterministic = True 29 | 30 | 31 | def print_network(net, verbose=False, name=""): 32 | num_params = 0 33 | for param in net.parameters(): 34 | num_params += param.numel() 35 | if verbose: 36 | logger.info(net) 37 | if hasattr(net, 'flops'): 38 | flops = net.flops() 39 | logger.info(f"number of GFLOPs: {flops / 1e9}") 40 | logger.info('network:{} Total number of parameters: {}'.format(name, num_params)) 41 | 42 | 43 | def check_keywords_in_name(name, keywords=()): 44 | isin = False 45 | for keyword in keywords: 46 | if keyword in name: 47 | isin = True 48 | return isin 49 | 50 | 51 | def get_grad_norm(parameters, norm_type=2): 52 | if isinstance(parameters, torch.Tensor): 53 | parameters = [parameters] 54 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 55 | norm_type = float(norm_type) 56 | total_norm = 0 57 | for p in parameters: 58 | param_norm = p.grad.data.norm(norm_type) 59 | total_norm += param_norm.item()**norm_type 60 | total_norm = total_norm**(1. 
/ norm_type) 61 | return total_norm 62 | 63 | 64 | def set_params_optimizer(model, keyword=None, keywords=None, weight_decay=0.0, lr=None): 65 | if keywords is None: 66 | keywords = [] 67 | param_dict = OrderedDict() 68 | no_decay_param_names = [] 69 | for name, param in model.named_parameters(): 70 | if not param.requires_grad: 71 | continue # frozen weights 72 | if keyword in name or check_keywords_in_name(name, keywords): 73 | param_dict[name] = {"weight_decay": weight_decay} 74 | if lr is not None: 75 | lr = float(lr) 76 | param_dict[name].update({"lr": lr}) 77 | else: 78 | no_decay_param_names.append(name) 79 | return param_dict, no_decay_param_names 80 | 81 | 82 | def get_optimizer(model, 83 | optimizer_type="adam", 84 | lr=0.001, 85 | beta1=0.9, 86 | beta2=0.999, 87 | no_decay_keys=None, 88 | weight_decay=0.0, 89 | layer_decay=None, 90 | eps=1e-8, 91 | momentum=0, 92 | params=None, 93 | **kwargs): 94 | assigner = None 95 | if layer_decay is not None: 96 | if layer_decay < 1.0: 97 | num_layers = kwargs.get('num_layers') 98 | assigner = LayerDecayValueAssigner(list(layer_decay**(num_layers + 1 - i) for i in range(num_layers + 2))) 99 | 100 | lr = float(lr) 101 | beta1, beta2 = float(beta1), float(beta2) 102 | weight_decay = float(weight_decay) 103 | momentum = float(momentum) 104 | eps = float(eps) 105 | freeze_params = kwargs.get('freeze_params', []) 106 | img_lr = float(kwargs.get('img_lr', lr)) 107 | for name, param in model.named_parameters(): 108 | freeze_flag = False 109 | for freeze_param in freeze_params: 110 | if freeze_param in name: 111 | freeze_flag = True 112 | break 113 | if freeze_flag: 114 | print("name={} param.requires_grad = False".format(name)) 115 | param.requires_grad = False 116 | 117 | if params is None: 118 | if weight_decay: 119 | skip = {} 120 | if no_decay_keys is not None: 121 | skip = no_decay_keys 122 | elif hasattr(model, 'no_weight_decay'): 123 | skip = model.no_weight_decay() 124 | param_configs = get_parameter_groups(model, img_lr, weight_decay, skip, assigner) 125 | weight_decay = 0. 
126 | else: 127 | param_configs = model.parameters() 128 | else: 129 | param_configs = params 130 | if optimizer_type == "sgd": 131 | optimizer = optim.SGD(param_configs, momentum=momentum, nesterov=True, lr=lr, weight_decay=weight_decay) 132 | elif optimizer_type == "adam": 133 | optimizer = optim.Adam(param_configs, lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=weight_decay) 134 | elif optimizer_type == "adadelta": 135 | optimizer = optim.Adadelta(param_configs, lr=lr, eps=eps, weight_decay=weight_decay) 136 | elif optimizer_type == "rmsprob": 137 | optimizer = optim.RMSprop(param_configs, lr=lr, eps=eps, weight_decay=weight_decay, momentum=momentum) 138 | elif optimizer_type == "adamw": 139 | optimizer = optim.AdamW(param_configs, lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=weight_decay) 140 | else: 141 | return NotImplementedError('learning rate policy [%s] is not implemented', optimizer_type) 142 | return optimizer 143 | 144 | 145 | def get_scheduler(optimizer, 146 | scheduler_type="linear", 147 | num_warmup_steps=0, 148 | num_training_steps=10000, 149 | last_epoch=-1, 150 | step_size=10, 151 | gamma=0.1, 152 | epochs=20, 153 | **kwargs): 154 | gamma = float(gamma) 155 | if scheduler_type == "cosine": 156 | scheduler = get_cosine_schedule_with_warmup(optimizer, 157 | num_warmup_steps=num_warmup_steps, 158 | num_training_steps=num_training_steps, 159 | last_epoch=last_epoch) 160 | elif scheduler_type == 'cosine_epoch': 161 | scheduler = get_cosine_schedule_by_epochs(optimizer, num_epochs=epochs, last_epoch=last_epoch) 162 | elif scheduler_type == "linear": 163 | scheduler = get_linear_schedule_with_warmup(optimizer, 164 | num_warmup_steps=num_warmup_steps, 165 | num_training_steps=num_training_steps, 166 | last_epoch=last_epoch) 167 | elif scheduler_type == "stairs": 168 | logger.info("current use stair scheduler") 169 | scheduler = get_stairs_schedule_with_warmup(optimizer, 170 | num_warmup_steps=num_warmup_steps, 171 | num_training_steps=num_training_steps, 172 | last_epoch=last_epoch, 173 | **kwargs) 174 | elif scheduler_type == "step": 175 | step_size = int(step_size) 176 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 177 | elif scheduler_type == "exponential": 178 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=gamma) 179 | """ 180 | def exp_decay(epoch): 181 | initial_lrate = 0.1 182 | k = 0.1 183 | lrate = initial_lrate * exp(-k*t) 184 | return lrate 185 | """ 186 | 187 | else: 188 | scheduler = get_constant_schedule_with_warmup(optimizer, 189 | num_warmup_steps=num_warmup_steps, 190 | last_epoch=last_epoch) 191 | return scheduler 192 | 193 | 194 | def get_scheduler2(optimizer, 195 | scheduler_type="cosine", 196 | num_warmup_steps=0, 197 | num_training_steps=10000, 198 | decay_steps=1000, 199 | decay_rate=0.1, 200 | lr_min=5e-6, 201 | warmup_lr=5e-7): 202 | lr_min = float(lr_min) 203 | warmup_lr = float(warmup_lr) 204 | decay_rate = float(decay_rate) 205 | if scheduler_type == "cosine": 206 | scheduler = CosineLRScheduler(optimizer, 207 | t_initial=num_training_steps, 208 | t_mul=1, 209 | lr_min=lr_min, 210 | warmup_lr_init=warmup_lr, 211 | cycle_limit=1, 212 | t_in_epochs=False) 213 | elif scheduler_type == "linear": 214 | scheduler = LinearLRScheduler(optimizer, 215 | t_initial=num_training_steps, 216 | lr_min_rate=0.01, 217 | warmup_lr_init=warmup_lr, 218 | warmup_t=num_warmup_steps, 219 | t_in_epochs=False) 220 | else: 221 | scheduler = StepLRScheduler(optimizer, 222 | decay_t=decay_steps, 223 | 
decay_rate=decay_rate, 224 | warmup_lr_init=warmup_lr, 225 | warmup_t=num_warmup_steps, 226 | t_in_epochs=False) 227 | return scheduler 228 | 229 | 230 | def one_cycle(y1=0.0, y2=1.0, steps=100): 231 | # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf 232 | return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 233 | 234 | 235 | def get_scheduler_yolo(optimizer, cos_lr=True, lrf=0.1, epochs=20, last_epoch=-1, **kwargs): 236 | if cos_lr: 237 | lf = one_cycle(1, lrf, epochs) # cosine 1->lrf 238 | else: 239 | lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf # linear 240 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf, 241 | last_epoch=last_epoch) # plot_lr_scheduler(optimizer, scheduler, epochs) 242 | return scheduler, lf 243 | 244 | 245 | def get_tensorboard_texts(label_texts): 246 | new_labels = [] 247 | for label_text in label_texts: 248 | new_labels.append(label_text.replace("/", "//").replace("<", "/<").replace(">", "/>")) 249 | return " \n".join(new_labels) 250 | 251 | 252 | def get_parameter_groups(model, img_lr, weight_decay, skip_list=(), assigner=None): 253 | parameter_group_names = {} 254 | parameter_group_vars = {} 255 | 256 | for name, param in model.named_parameters(): 257 | if not param.requires_grad: 258 | continue # frozen weights 259 | # TODO 是否通用? 260 | if 'image_encoder' in name: 261 | if len(param.shape) == 1 or name.endswith(".bias") or name.split('.')[-1] in skip_list: 262 | group_name = "img_encoder_no_decay" 263 | this_weight_decay = 0. 264 | else: 265 | group_name = "img_encoder_decay" 266 | this_weight_decay = weight_decay 267 | if assigner is not None: 268 | layer_id = assigner.get_layer_id(name) 269 | group_name = "layer_%d_%s" % (layer_id, group_name) 270 | else: 271 | layer_id = None 272 | 273 | if group_name not in parameter_group_names: 274 | if assigner is not None: 275 | scale = assigner.get_scale(layer_id) 276 | else: 277 | scale = 1. 278 | 279 | parameter_group_names[group_name] = { 280 | "weight_decay": this_weight_decay, 281 | "params": [], 282 | "lr_scale": scale, 283 | "lr": img_lr 284 | } 285 | parameter_group_vars[group_name] = { 286 | "weight_decay": this_weight_decay, 287 | "params": [], 288 | "lr_scale": scale, 289 | "lr": img_lr 290 | } 291 | else: 292 | if len(param.shape) == 1 or name.endswith(".bias") or name.split('.')[-1] in skip_list: 293 | group_name = "no_decay" 294 | this_weight_decay = 0. 295 | else: 296 | group_name = "decay" 297 | this_weight_decay = weight_decay 298 | if assigner is not None: 299 | layer_id = assigner.get_layer_id(name) 300 | group_name = "layer_%d_%s" % (layer_id, group_name) 301 | else: 302 | layer_id = None 303 | 304 | if group_name not in parameter_group_names: 305 | if assigner is not None: 306 | scale = assigner.get_scale(layer_id) 307 | else: 308 | scale = 1. 
309 | 310 | parameter_group_names[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale} 311 | parameter_group_vars[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale} 312 | 313 | parameter_group_vars[group_name]["params"].append(param) 314 | parameter_group_names[group_name]["params"].append(name) 315 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 316 | return list(parameter_group_vars.values()) 317 | 318 | 319 | class LayerDecayValueAssigner(object): 320 | 321 | def __init__(self, values): 322 | self.values = values 323 | 324 | def get_scale(self, layer_id): 325 | return self.values[layer_id] 326 | 327 | def get_layer_id(self, var_name): 328 | return get_num_layer(var_name, len(self.values)) 329 | 330 | 331 | def get_num_layer(var_name, num_max_layer): 332 | var_name = var_name.split('.', 1)[-1] 333 | if var_name.startswith("embeddings"): 334 | return 0 335 | elif var_name.startswith("encoder.layer"): 336 | layer_id = int(var_name.split('.')[2]) 337 | return layer_id + 1 338 | else: 339 | return num_max_layer - 1 340 | -------------------------------------------------------------------------------- /base/torch_utils/scheduler_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/10/9 3 | import math 4 | import torch 5 | from torch.optim import Optimizer 6 | from timm.scheduler.scheduler import Scheduler 7 | 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | 11 | def get_stairs_schedule_with_warmup(optimizer, 12 | num_warmup_steps, 13 | num_training_steps, 14 | stair_num=2, 15 | min_scale=0.01, 16 | last_epoch=-1, 17 | **kwargs): 18 | """ 19 | Create a stair schedule with a learning rate that from the initial lr set in the optimizer to 0, after 20 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 21 | and then with duplicate stairs, more train step will be allocated to a smaller learning rates. 22 | decrease stage like this with learning rate:4e-4, stair_num:3, remain_steps: 1400 23 | then learning rate is: 24 | from 0-100, 4e-4 25 | from 100-200, decrease from 4e-4 to 2e-4 26 | from 200-400 2e-4 27 | from 400-600 decrease from 2e-4 to 1e-4 28 | from 600-1000 1e-4 29 | from 1000-1400 decrease from 1e-4 to 0 30 | as following: 31 | ___ 32 | / \ ____ 33 | / \ ___ 34 | / \ 35 | Args: 36 | optimizer (:class:`~torch.optim.Optimizer`): 37 | The optimizer for which to schedule the learning rate. 38 | num_warmup_steps (:obj:`int`): 39 | The number of steps for the warmup phase. 40 | num_training_steps (:obj:`int`): 41 | The total number of training steps. 42 | stair_num: int 43 | min_scale: min learning_rate ratio 44 | last_epoch (:obj:`int`, `optional`, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | 47 | Return: 48 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
49 | """ 50 | 51 | stair_num = int(stair_num) 52 | min_scale = float(min_scale) 53 | remain_step = max(1, num_training_steps - num_warmup_steps) 54 | unit_step = int(remain_step / (2**(stair_num + 1) - 2)) 55 | remain_linear_step = remain_step - unit_step * int(2**stair_num - 1) 56 | stair_steps = [ 57 | unit_step * int((3 * 2**(i / 2) - 2)) if i % 2 == 0 else unit_step * int(4 * 2**((i - 1) / 2) - 2) 58 | for i in range(2 * stair_num) 59 | ] 60 | stair_scales = [1.0 - (2**i - 1) / (2**stair_num - 1) for i in range(stair_num)] 61 | 62 | def lr_lambda(current_step: int): 63 | if current_step < num_warmup_steps: 64 | return float(current_step) / float(max(1, num_warmup_steps)) 65 | 66 | current_remain_step = current_step - num_warmup_steps 67 | for i in range(stair_num * 2): 68 | if i == 0: 69 | prev_stair_step = 0 70 | else: 71 | prev_stair_step = stair_steps[i - 1] 72 | stair_step = stair_steps[i] 73 | if prev_stair_step <= current_remain_step <= stair_step: 74 | if i % 2 == 0: 75 | return max(min_scale, stair_scales[i // 2]) 76 | else: 77 | prev_linear_step = unit_step * int(2**((i - 1) / 2) - 1) 78 | current_linear_step = current_remain_step - prev_stair_step + prev_linear_step 79 | linear_lr = float(remain_linear_step - current_linear_step) / float(remain_linear_step) 80 | return max(min_scale, linear_lr) 81 | return min_scale 82 | 83 | return LambdaLR(optimizer, lr_lambda, last_epoch) 84 | 85 | 86 | class LinearLRScheduler(Scheduler): 87 | 88 | def __init__( 89 | self, 90 | optimizer: torch.optim.Optimizer, 91 | t_initial: int, 92 | lr_min_rate: float, 93 | warmup_t=0, 94 | warmup_lr_init=0., 95 | t_in_epochs=True, 96 | noise_range_t=None, 97 | noise_pct=0.67, 98 | noise_std=1.0, 99 | noise_seed=42, 100 | initialize=True, 101 | ) -> None: 102 | super().__init__(optimizer, 103 | param_group_field="lr", 104 | noise_range_t=noise_range_t, 105 | noise_pct=noise_pct, 106 | noise_std=noise_std, 107 | noise_seed=noise_seed, 108 | initialize=initialize) 109 | 110 | self.t_initial = t_initial 111 | self.lr_min_rate = lr_min_rate 112 | self.warmup_t = warmup_t 113 | self.warmup_lr_init = warmup_lr_init 114 | self.t_in_epochs = t_in_epochs 115 | if self.warmup_t: 116 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 117 | super().update_groups(self.warmup_lr_init) 118 | else: 119 | self.warmup_steps = [1 for _ in self.base_values] 120 | 121 | def _get_lr(self, t): 122 | if t < self.warmup_t: 123 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 124 | else: 125 | t = t - self.warmup_t 126 | total_t = self.t_initial - self.warmup_t 127 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 128 | return lrs 129 | 130 | def get_epoch_values(self, epoch: int): 131 | if self.t_in_epochs: 132 | return self._get_lr(epoch) 133 | else: 134 | return None 135 | 136 | def get_update_values(self, num_updates: int): 137 | if not self.t_in_epochs: 138 | return self._get_lr(num_updates) 139 | else: 140 | return None 141 | 142 | 143 | # update lr by epoch 144 | def get_cosine_schedule_by_epochs(optimizer: Optimizer, num_epochs: int, last_epoch: int = -1, **kwargs): 145 | 146 | def lr_lambda(epoch): 147 | lf = ((1 + math.cos(epoch * math.pi / num_epochs)) / 2) * 0.8 + 0.2 # cosine 148 | return lf 149 | 150 | return LambdaLR(optimizer, lr_lambda, last_epoch) 151 | -------------------------------------------------------------------------------- /config/base.yaml: 
-------------------------------------------------------------------------------- 1 | name: docparser-base # experiment name 2 | model: 3 | type: DocParser 4 | pretrained_model_name_or_path: /models 5 | # encoder 6 | image_size: [2560, 1920] # the input image size of docparser 7 | # decoder 8 | max_length: 1024 # the max input length of docparser 9 | decoder_layers: 2 # the decoder layer 10 | 11 | model_path: ~ # path to a certain checkpoint 12 | load_strict: true # whether to strictly load the checkpoint 13 | 14 | # training precision 15 | mixed_precision: "fp16" # "["no", "fp16", "bf16] # use torch native amp 16 | 17 | tokenizer_args: 18 | pretrained_model_name_or_path: naver-clova-ix/donut-base # we borrow tokenizer & image processor from donut 19 | extra_args: {} 20 | 21 | predictor: 22 | img_paths: 23 | - 24 | save_dir: /data/data/cache 25 | 26 | trainer: 27 | start_global_step: -1 # start training from a certain global step; -1 means no starting global step is set 28 | resume_flag: false # whether to resume the training from a certain checkpoint 29 | random_seed: ~ 30 | grad_clip: 1.0 31 | epochs: 5 32 | 33 | # tensorboard configuration 34 | save_dir: /logs/docparser 35 | tensorboard_dir: /logs/docparser/tensorboard 36 | 37 | # display configuration 38 | save_epoch_freq: 1 39 | save_step_freq: 800 40 | print_freq: 20 41 | 42 | # gradient configuration 43 | grad_accumulate: 1 # gradient accumulation 44 | 45 | # optimizer configuration 46 | optimizer: 47 | optimizer_type: "adamw" 48 | lr: 1.0e-04 49 | # layer_decay: 0.75 50 | weight_decay: 0.05 51 | beta1: 0.9 52 | beta2: 0.98 53 | eps: 1.0e-6 54 | 55 | # scheduler configuration 56 | scheduler: 57 | scheduler_type: "cosine" 58 | warmup_steps: 2000 59 | warmup_epochs: 0 60 | 61 | datasets: 62 | train: 63 | dataset: 64 | type: DocParser 65 | task_start_token: 66 | data_root: 67 | - # put your dataset path here 68 | num_workers: 0 69 | batch_size: 1 # global batch = bz * num_gpu * grad 70 | shuffle: true 71 | collate_fn: 72 | type: DataCollatorForDocParserDataset -------------------------------------------------------------------------------- /doc/encoder_arch.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NormXU/DocParser-Pytorch/6e11ea5fc211b1ccc37f51b2b0f64baea68fcf83/doc/encoder_arch.jpeg -------------------------------------------------------------------------------- /examples/test_forward.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: @time: 7/6/23 16:38 3 | import torch 4 | from transformers import VisionEncoderDecoderConfig, AutoConfig, VisionEncoderDecoderModel, AutoModel 5 | from models import DocParserModel, DocParserConfig 6 | 7 | if __name__ == '__main__': 8 | AutoConfig.register("docparser-swin", DocParserConfig) 9 | AutoModel.register(DocParserConfig, DocParserModel) 10 | 11 | config = VisionEncoderDecoderConfig.from_pretrained("../models/") 12 | model = VisionEncoderDecoderModel(config=config) 13 | 14 | # test forward with dummy input 15 | input_tensor = torch.ones(1, 3, 2560, 1920) 16 | output = model(input_tensor) 17 | -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/6/8 3 | from .base_experiment import BaseExperiment 4 | from .docparser_experiment import DocParserExperiment 5 | 6 | def 
get_experiment_name(name): 7 | name_split = name.split("_") 8 | trainer_name = "".join([tmp_name[0].upper() + tmp_name[1:] for tmp_name in name_split]) 9 | return "{}Experiment".format(trainer_name) -------------------------------------------------------------------------------- /experiment/base_experiment.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/7/2 3 | import copy 4 | import itertools 5 | import json 6 | import logging 7 | import os 8 | from contextlib import nullcontext 9 | 10 | import munch 11 | import torch 12 | from accelerate import Accelerator 13 | from torch import autocast 14 | from torch.utils.data import DataLoader 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | import mydatasets 18 | from base.common_util import get_absolute_file_path, merge_config, get_file_path_list 19 | from base.common_util import save_params 20 | from base.driver import log_formatter 21 | from base.driver import logger 22 | from base.torch_utils.dl_util import get_optimizer, get_scheduler, get_scheduler2, seed_all, get_grad_norm 23 | from mydatasets import get_dataset 24 | 25 | 26 | class BaseExperiment(object): 27 | 28 | def __init__(self, config): 29 | config = self._init_config(config) 30 | self.experiment_name = config["name"] 31 | self.args = munch.munchify(config) 32 | self.init_device(config) 33 | self.init_random_seed(config) 34 | self.init_model(config) 35 | self.init_dataset(config) 36 | self.init_trainer_args(config) 37 | self.init_predictor_args(config) 38 | self.prepare_accelerator() 39 | 40 | """ 41 | Main Block 42 | """ 43 | 44 | def predict(self, **kwargs): 45 | request_property = kwargs.get('request_property') 46 | pass 47 | 48 | def evaluate(self, **kwargs): 49 | pass 50 | 51 | def train(self, **kwargs): 52 | pass 53 | 54 | def _step_forward(self, batch, is_train=True, eval_model=None, **kwargs): 55 | pass 56 | 57 | def _step_backward(self, loss, **kwargs): 58 | if self.use_torch_amp: 59 | self.mixed_scaler.scale(loss).backward() 60 | else: 61 | if self.accelerator is not None: 62 | self.accelerator.backward(loss) 63 | else: 64 | loss = loss / self.args.trainer.grad_accumulate 65 | loss.backward() 66 | 67 | def _get_current_lr(self, ni, global_step=0, **kwargs): 68 | if self.args.trainer.scheduler_type == "scheduler2": 69 | current_lr = self.scheduler.get_update_values(global_step)[-1] 70 | else: 71 | current_lr = self.scheduler.get_last_lr()[-1] 72 | return current_lr 73 | 74 | def _step_optimizer(self, **kwargs): 75 | params_to_clip = (itertools.chain(self.model.parameters())) 76 | for param_group in self.optimizer.param_groups: 77 | if "lr_scale" in param_group: 78 | param_group["lr"] = param_group["lr"] * param_group["lr_scale"] 79 | grad_norm = None 80 | if self.args.trainer.grad_clip is not None: 81 | if self.use_torch_amp: 82 | # Unscales the gradients of optimizer's assigned params in-place 83 | # called only after all gradients for that optimizer’s assigned parameters have been accumulated 84 | self.mixed_scaler.unscale_(self.optimizer) 85 | grad_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, self.args.trainer.grad_clip) 86 | self.mixed_scaler.step(self.optimizer) 87 | # Updates the scale for next iteration. 
88 | self.mixed_scaler.update() 89 | if self.accelerator: 90 | if self.accelerator.sync_gradients: 91 | grad_norm = self.accelerator.clip_grad_norm_(params_to_clip, self.args.trainer.grad_clip) 92 | self.optimizer.step() 93 | if grad_norm is None: 94 | grad_norm = get_grad_norm(params_to_clip) 95 | self.optimizer.step() 96 | 97 | self.optimizer.zero_grad() 98 | return grad_norm 99 | 100 | def _step_scheduler(self, global_step, **kwargs): 101 | if self.args.trainer.scheduler_type == "scheduler2": 102 | self.scheduler.step_update(global_step) 103 | else: 104 | self.scheduler.step() 105 | 106 | """ 107 | Initialization Functions 108 | """ 109 | 110 | # config的联动关系可以写在这个函数中 111 | def _init_config(self, config): 112 | if 'trainer' in config and config.get('phase', 'train') == 'train': 113 | trainer_args = config["trainer"] 114 | trainer_args['save_dir'] = get_absolute_file_path(trainer_args.get("save_dir")) 115 | os.makedirs(trainer_args['save_dir'], exist_ok=True) 116 | save_params(trainer_args['save_dir'], config) 117 | train_log_path = os.path.join(trainer_args['save_dir'], "{}.log".format(config['name'])) 118 | file_handler = logging.FileHandler(train_log_path) 119 | file_handler.setLevel(logging.INFO) 120 | file_handler.setFormatter(log_formatter) 121 | logger.addHandler(file_handler) 122 | return config 123 | 124 | def init_device(self, config): 125 | # ADD RUN_ON_GPU_IDs=-1是cpu,多张默认走accelerator 126 | self.args.device = munch.munchify(config.get('device', {})) 127 | self.accelerator = None 128 | self.weight_dtype = torch.float32 129 | self.gradient_accumulate_scope = nullcontext 130 | self.precision_scope = nullcontext() 131 | self.use_torch_amp = False 132 | 133 | # accelerator configuration 134 | if config['use_accelerate']: 135 | # If you define multiple visible GPU, I suppose you to use accelerator to do ddp training 136 | self.accelerator = Accelerator( 137 | gradient_accumulation_steps=int(self.args.trainer.grad_accumulate), 138 | mixed_precision=self.args.model.mixed_precision) 139 | self.args.device.device_id = self.accelerator.device 140 | self.args.device.device_ids = [] 141 | if self.accelerator.mixed_precision == "fp16": 142 | self.weight_dtype = torch.float16 143 | elif self.accelerator.mixed_precision == "bf16": 144 | self.weight_dtype = torch.bfloat16 145 | self.gradient_accumulate_scope = self.accelerator.accumulate 146 | self.args.device.is_master = self.accelerator.is_main_process 147 | self.args.device.is_distributed = self.accelerator.num_processes > 1 148 | elif os.environ.get("RUN_ON_GPU_IDs", 0) == str(-1): 149 | # load model with CPU 150 | self.args.device.device_id = torch.device("cpu") 151 | self.args.device.device_ids = [-1] 152 | self.args.device.is_master = True 153 | self.args.device.is_distributed = False 154 | else: 155 | # USE one GPU specified by user w/o using accelerate 156 | device_id = os.environ.get("RUN_ON_GPU_IDs", 0) 157 | self.args.device.device_id = torch.device("cuda:{}".format(device_id)) 158 | self.args.device.device_ids = [int(device_id)] 159 | torch.cuda.set_device(int(device_id)) 160 | self.args.device.is_master = True 161 | self.args.device.is_distributed = False 162 | if self.args.model.mixed_precision in ["fp16", "bf16"]: 163 | # ADD mixed_precision_flag改为use_torch_amp 164 | self.use_torch_amp = True 165 | self.weight_dtype = torch.float16 if self.args.model.mixed_precision == "fp16" else torch.bfloat16 166 | self.precision_scope = autocast(device_type="cuda", dtype=self.weight_dtype) 167 | logger.info("device:{}, is_master:{}, 
device_ids:{}, is_distributed:{}".format( 168 | self.args.device.device_id, self.args.device.is_master, self.args.device.device_ids, 169 | self.args.device.is_distributed)) 170 | 171 | def init_model(self, config): 172 | pass 173 | 174 | def init_dataset(self, config): 175 | if 'datasets' in config and config.get('phase', 'train') != 'predict': 176 | dataset_args = config.get("datasets") 177 | train_data_loader_args = dataset_args.get("train") 178 | if config.get('phase', 'train') == 'train': 179 | self.train_dataset = get_dataset(train_data_loader_args['dataset']) 180 | self.train_data_loader = self._get_data_loader_from_dataset(self.train_dataset, 181 | train_data_loader_args, 182 | phase='train') 183 | logger.info("success init train data loader len:{} ".format(len(self.train_data_loader))) 184 | eval_data_loader_args = dataset_args.get("eval") 185 | merged_eval_data_loader_args = train_data_loader_args.copy() 186 | merge_config(eval_data_loader_args, merged_eval_data_loader_args) 187 | self.eval_dataset = get_dataset(merged_eval_data_loader_args['dataset']) 188 | self.eval_data_loader = self._get_data_loader_from_dataset(self.eval_dataset, 189 | merged_eval_data_loader_args, 190 | phase='eval') 191 | logger.info("success init eval data loader len:{}".format(len(self.eval_data_loader))) 192 | 193 | def init_random_seed(self, config): 194 | if 'random_seed' in config['trainer']: 195 | seed_all(config['trainer']['random_seed']) 196 | else: 197 | logger.warning("random seed is missing") 198 | 199 | def init_predictor_args(self, config): 200 | if 'predictor' in config and config.get('phase', 'train') == 'predict': 201 | predictor_args = config["predictor"] 202 | self.args.predictor.img_paths = predictor_args.get("img_paths", None) 203 | if self.args.predictor.img_paths is None: 204 | self.args.predictor.img_paths = [] 205 | img_dirs = predictor_args.get("img_dirs", None) 206 | if img_dirs: 207 | for img_dir in img_dirs: 208 | if img_dir: 209 | self.args.predictor.img_paths.extend(get_file_path_list(img_dir, ['jpg', 'png', 'jpeg'])) 210 | if predictor_args['save_dir'] is None and 'model_path' in config['model'] and config['model'][ 211 | 'model_path'] is not None: 212 | predictor_args['save_dir'] = os.path.join(os.path.dirname(config['model']['model_path']), 213 | 'test_results') 214 | self.args.predictor.save_dir = get_absolute_file_path(predictor_args["save_dir"]) 215 | os.makedirs(self.args.predictor.save_dir, exist_ok=True) 216 | 217 | def init_trainer_args(self, config): 218 | if 'trainer' in config and config.get('phase', 'train') == 'train': 219 | trainer_args = config["trainer"] 220 | self._init_optimizer(trainer_args) 221 | self._init_scheduler(trainer_args) 222 | logger.info("current trainer epochs:{}, train_dataset_len:{}, data_loader_len:{}".format( 223 | self.args.trainer.epochs, len(self.train_dataset), len(self.train_data_loader))) 224 | self.mixed_scaler = torch.cuda.amp.GradScaler(enabled=True) if self.use_torch_amp else None 225 | self.args.trainer.best_eval_result = -1 226 | self.args.trainer.best_model_path = '' 227 | self.args.trainer.start_epoch = 0 228 | self.args.trainer.start_global_step = 0 229 | if self.args.trainer.resume_flag and 'model_path' in self.args.model and self.args.model.model_path is not None: 230 | resume_path = self.args.model.model_path.replace('.pth', '_resume.pth') 231 | if os.path.exists(resume_path): 232 | resume_checkpoint = torch.load(resume_path) 233 | self.optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict']) 234 | 
self.scheduler.load_state_dict(resume_checkpoint['scheduler_state_dict']) 235 | self.args.trainer.start_epoch = resume_checkpoint['epoch'] 236 | self.args.trainer.start_global_step = resume_checkpoint['global_step'] 237 | else: 238 | logger.warning("resume path {} doesn't exist: failed to resume!!".format(resume_path)) 239 | 240 | if 'trainer' in config and config.get('phase', 'train') != 'predict': 241 | trainer_args = config["trainer"] 242 | self._init_criterion(trainer_args) 243 | # init tensorboard and log 244 | if "tensorboard_dir" in trainer_args and self.args.device.is_master: 245 | tensorboard_log_dir = get_absolute_file_path(trainer_args.get("tensorboard_dir")) 246 | os.makedirs(tensorboard_log_dir, exist_ok=True) 247 | self.writer = SummaryWriter(log_dir=tensorboard_log_dir, comment=self.experiment_name) 248 | else: 249 | self.writer = None 250 | 251 | def _init_optimizer(self, trainer_args, **kwargs): 252 | optimizer_args = trainer_args.get("optimizer") 253 | # ADD scale lr 254 | if optimizer_args["scale_lr"]: 255 | num_process = 1 if self.accelerator is None else self.accelerator.num_processes 256 | optimizer_args['lr'] = optimizer_args['lr'] * self.args.trainer.grad_accumulate * \ 257 | self.train_data_loader.batch_size * num_process 258 | self.optimizer = get_optimizer(self.model, **optimizer_args) 259 | 260 | def _init_scheduler(self, trainer_args, **kwargs): 261 | scheduler_args = trainer_args.get("scheduler") 262 | self.args.trainer.scheduler_by_epoch = scheduler_args.get("scheduler_by_epoch", False) 263 | total_epoch_train_steps = len(self.train_data_loader) 264 | if scheduler_args["warmup_epochs"] > 0: 265 | warmup_steps = scheduler_args.get("warmup_epochs") * total_epoch_train_steps 266 | elif scheduler_args['warmup_steps'] > 0: 267 | warmup_steps = scheduler_args.get("warmup_steps") 268 | else: 269 | warmup_steps = 0 270 | self.args.trainer.scheduler.warmup_steps = warmup_steps 271 | num_training_steps = total_epoch_train_steps * self.args.trainer.epochs 272 | if self.accelerator is None: 273 | # accelerator will automatically take care of the grad accumulate in calculating total num_training steps, 274 | # or you need to calculate by yourself 275 | num_training_steps = num_training_steps // self.args.trainer.grad_accumulate 276 | if "scheduler_method" in scheduler_args and scheduler_args["scheduler_method"] == "get_scheduler2": 277 | self.scheduler = get_scheduler2(self.optimizer, 278 | num_training_steps=num_training_steps, 279 | num_warmup_steps=warmup_steps, 280 | **scheduler_args) 281 | self.args.trainer.scheduler_type = "scheduler2" 282 | else: 283 | self.scheduler = get_scheduler(self.optimizer, 284 | num_training_steps=num_training_steps, 285 | num_warmup_steps=warmup_steps, 286 | epochs=self.args.trainer.epochs, 287 | **scheduler_args) 288 | self.args.trainer.scheduler_type = "scheduler" 289 | 290 | logger.info( 291 | "success init optimizer and scheduler, optimizer:{}, scheduler:{}, scheduler_args:{}, warmup_steps:{}," 292 | "num_training_steps:{}, gradient_accumulator:{}".format(self.optimizer, self.scheduler, scheduler_args, 293 | warmup_steps, num_training_steps, 294 | self.args.trainer.grad_accumulate)) 295 | 296 | def _init_criterion(self, trainer_args): 297 | pass 298 | 299 | def _init_metric(self, **kwargs): 300 | pass 301 | 302 | """ 303 | Tool Functions 304 | """ 305 | 306 | def load_model(self, checkpoint_path, strict=True, **kwargs): 307 | if os.path.exists(checkpoint_path) and os.path.isfile(checkpoint_path): 308 | state_dict = 
torch.load(checkpoint_path, map_location=torch.device("cpu")) 309 | if 'model_state_dict' in state_dict: 310 | model_state_dict = state_dict['model_state_dict'] 311 | else: 312 | model_state_dict = state_dict 313 | self.model.load_state_dict(model_state_dict, strict=strict) 314 | logger.info("success load model:{}".format(checkpoint_path)) 315 | 316 | def save_model(self, checkpoint_path, **save_kwargs): 317 | if self.accelerator is not None: 318 | unwrapped_model = self.accelerator.unwrap_model(self.model) 319 | if self.args.trainer.resume_flag: 320 | save_kwargs.update({ 321 | 'model_state_dict': unwrapped_model.state_dict(), 322 | 'optimizer_state_dict': self.optimizer.state_dict(), 323 | 'scheduler_state_dict': self.scheduler.state_dict(), 324 | }) 325 | self.accelerator.save(save_kwargs, checkpoint_path.replace('.pth', '.ckpt')) 326 | else: 327 | self.accelerator.save(unwrapped_model.state_dict(), checkpoint_path) 328 | else: 329 | if self.args.model.quantization_type == 'quantization_aware_training': 330 | self.model.eval() 331 | model_int8 = torch.quantization.convert(self.model) 332 | torch.save(model_int8.state_dict(), checkpoint_path) 333 | else: 334 | if self.args.trainer.resume_flag: 335 | save_kwargs.update({ 336 | 'model_state_dict': self.model.state_dict(), 337 | 'optimizer_state_dict': self.optimizer.state_dict(), 338 | 'scheduler_state_dict': self.scheduler.state_dict(), 339 | }) 340 | torch.save(save_kwargs, checkpoint_path.replace('.pth', '.ckpt')) 341 | else: 342 | torch.save(self.model.state_dict(), checkpoint_path) 343 | logger.info("model successfully saved to {}".format(checkpoint_path)) 344 | 345 | def _get_data_loader_from_dataset(self, dataset, data_loader_args, phase="train"): 346 | num_workers = data_loader_args.get("num_workers", 0) 347 | batch_size = data_loader_args.get("batch_size", 1) 348 | if phase == "train" and data_loader_args.get('shuffle', True): 349 | shuffle = data_loader_args.get("shuffle", True) 350 | else: 351 | shuffle = data_loader_args.get("shuffle", False) 352 | pin_memory = data_loader_args.get("shuffle", False) 353 | 354 | collate_fn_args = data_loader_args.get("collate_fn") 355 | if collate_fn_args.get("type") is None: 356 | collate_fn = None 357 | else: 358 | collate_fn_type = collate_fn_args.get("type") 359 | collate_fn = getattr(mydatasets, collate_fn_type)(batch_size=batch_size, **collate_fn_args) 360 | data_loader = DataLoader(dataset, 361 | shuffle=shuffle, 362 | num_workers=num_workers, 363 | pin_memory=pin_memory, 364 | collate_fn=collate_fn, 365 | batch_size=batch_size) 366 | logger.info("use data loader with batch_size:{},num_workers:{}".format(batch_size, num_workers)) 367 | 368 | return data_loader 369 | 370 | def prepare_accelerator(self): 371 | if self.accelerator is not None: 372 | self.model, self.optimizer, self.train_data_loader, self.scheduler = self.accelerator.prepare( 373 | self.model, self.optimizer, self.train_data_loader, self.scheduler) 374 | 375 | def _train_post_process(self): 376 | args = copy.deepcopy(self.args) 377 | args.model.model_path = args.trainer.best_model_path 378 | if 'base' in args: 379 | args.pop('base') 380 | args.device.pop('device_id') 381 | args.pop('trainer') 382 | args.phase = 'predict' 383 | save_params(self.args.trainer.save_dir, json.loads(json.dumps(args)), 'model_args.yaml') 384 | return os.path.join(self.args.trainer.save_dir, 'model_args.yaml') 385 | 386 | def _print_step_log(self, epoch, global_step, global_eval_step, loss_meter, norm_meter, batch_time, ni, **kwargs): 387 | 
current_lr = self._get_current_lr(ni, global_step) 388 | if self.args.device.is_master and self.args.trainer.print_freq > 0 and global_step % self.args.trainer.print_freq == 0: 389 | message = "experiment:{}; train, (epoch: {}, steps: {}, lr:{:e}, step_mean_loss:{}," \ 390 | " average_loss:{}), time, (train_step_time: {:.5f}s, train_average_time: {:.5f}s);" \ 391 | "(grad_norm_mean: {:.5f}, grad_norm_step: {:.5f})". \ 392 | format(self.experiment_name, epoch, global_step, current_lr, 393 | loss_meter.val, loss_meter.avg, batch_time.val, batch_time.avg, norm_meter.avg, 394 | norm_meter.val) 395 | logger.info(message) 396 | if self.writer is not None: 397 | self.writer.add_scalar("{}_train/lr".format(self.experiment_name), current_lr, global_step) 398 | self.writer.add_scalar("{}_train/step_loss".format(self.experiment_name), loss_meter.val, global_step) 399 | self.writer.add_scalar("{}_train/average_loss".format(self.experiment_name), loss_meter.avg, 400 | global_step) 401 | if global_step > 0 and self.args.trainer.save_step_freq > 0 and global_step % self.args.trainer.save_step_freq == 0: 402 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step) 403 | logger.info(message) 404 | result = self.evaluate(global_eval_step=global_eval_step) 405 | global_eval_step, acc = result['global_eval_step'], result['acc'] 406 | # ADD is_master判断移到这里 407 | if (not self.args.trainer.save_best or (self.args.trainer.save_best 408 | and acc > self.args.trainer.best_eval_result)) and self.args.device.is_master: 409 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}_acc{:.5f}.pth".format( 410 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg, acc) 411 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name) 412 | # ADD记得传epoch和global_step,resume才能存 413 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val) 414 | if acc > self.args.trainer.best_eval_result: 415 | self.args.trainer.best_eval_result = acc 416 | self.args.trainer.best_model_path = checkpoint_path 417 | return global_eval_step 418 | 419 | def _print_epoch_log(self, epoch, global_step, global_eval_step, loss_meter, ni, **kwargs): 420 | current_lr = self._get_current_lr(ni, global_step) 421 | if self.args.trainer.save_epoch_freq > 0 and epoch % self.args.trainer.save_epoch_freq == 0: 422 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step) 423 | logger.info(message) 424 | result = self.evaluate(global_eval_step=global_eval_step) 425 | global_eval_step, acc = result['global_eval_step'], result['acc'] 426 | if (not self.args.trainer.save_best or (self.args.trainer.save_best 427 | and acc > self.args.trainer.best_eval_result)) and self.args.device.is_master: 428 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}_acc{:.5f}.pth".format( 429 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg, acc) 430 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name) 431 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val) 432 | if acc > self.args.trainer.best_eval_result: 433 | self.args.trainer.best_eval_result = acc 434 | self.args.trainer.best_model_path = checkpoint_path 435 | return global_eval_step 436 | 437 | def _print_eval_log(self, global_step, loss_meter, eval_metric, **kwargs): 438 | evaluate_report = eval_metric.get_report() 439 | acc = evaluate_report["acc"] 
440 | message = "experiment:{}; eval,global_step:{}, (step_mean_loss:{},average_loss:{:.5f},evaluate_report:{})".format( 441 | self.experiment_name, global_step, loss_meter.val, loss_meter.avg, evaluate_report) 442 | logger.info(message) 443 | if self.writer is not None: 444 | self.writer.add_scalar("{}_eval/step_loss".format(self.experiment_name), loss_meter.val, global_step) 445 | self.writer.add_scalar("{}_eval/average_loss".format(self.experiment_name), loss_meter.avg, global_step) 446 | self.writer.add_scalar("{}_eval/acc".format(self.experiment_name), acc, global_step) 447 | return acc 448 | -------------------------------------------------------------------------------- /experiment/docparser_experiment.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2023/6/2 3 | import os 4 | import re 5 | import time 6 | 7 | import munch 8 | import torch 9 | from PIL import Image 10 | from transformers import AutoTokenizer, DonutProcessor, VisionEncoderDecoderModel, \ 11 | VisionEncoderDecoderConfig, DonutImageProcessor, AutoConfig, AutoModel 12 | 13 | from base.common_util import get_absolute_file_path 14 | from base.driver import logger 15 | from base.meter import AverageMeter 16 | from base.torch_utils.dl_util import get_optimizer 17 | from models.configuration_docparser import DocParserConfig 18 | from models.modeling_docparser import DocParserModel 19 | from mydatasets import get_dataset 20 | from .base_experiment import BaseExperiment 21 | 22 | 23 | class DocParserExperiment(BaseExperiment): 24 | 25 | def __init__(self, config): 26 | config = self._init_config(config) 27 | self.experiment_name = config["name"] 28 | self.args = munch.munchify(config) 29 | self.init_device(config) 30 | self.init_random_seed(config) 31 | self.init_model(config) 32 | self.init_dataset(config) 33 | self.init_trainer_args(config) 34 | self.init_predictor_args(config) 35 | self.prepare_accelerator() 36 | 37 | """ 38 | Main Block 39 | """ 40 | 41 | def predict(self, **kwargs): 42 | for img_path in self.args.predictor.img_paths: 43 | image = Image.open(img_path) 44 | if not image.mode == "RGB": 45 | image = image.convert('RGB') 46 | 47 | pixel_values = self.processor(image, return_tensors="pt").pixel_values 48 | # prepare decoder inputs 49 | task_prompt = self.args.datasets.train.dataset.task_start_token 50 | decoder_input_ids = self.processor.tokenizer(task_prompt, add_special_tokens=False, 51 | return_tensors="pt").input_ids 52 | start = time.time() 53 | with torch.no_grad(): 54 | outputs = self.model.generate( 55 | pixel_values.to(self.args.device.device_id), 56 | decoder_input_ids=decoder_input_ids.to(self.args.device.device_id), 57 | max_length=self.model.decoder.config.max_length, 58 | early_stopping=True, 59 | pad_token_id=self.processor.tokenizer.pad_token_id, 60 | eos_token_id=self.processor.tokenizer.eos_token_id, 61 | use_cache=True, 62 | num_beams=1, 63 | bad_words_ids=[[self.processor.tokenizer.unk_token_id]], 64 | return_dict_in_generate=True, 65 | ) 66 | sequence = self.processor.batch_decode(outputs.sequences)[0] 67 | batch_time = time.time() - start 68 | logger.info("batch inference time:{} s".format(batch_time)) 69 | sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace( 70 | self.processor.tokenizer.pad_token, "") 71 | sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token 72 | print(self.processor.token2json(sequence)) 73 | 74 | def train(self, **kwargs): 75 | 
batch_time = AverageMeter() 76 | loss_meter = AverageMeter() 77 | norm_meter = AverageMeter() 78 | global_step = self.args.trainer.start_epoch * len(self.train_data_loader) 79 | global_eval_step = 0 80 | ni = 0 81 | for epoch in range(self.args.trainer.start_epoch, self.args.trainer.epochs): 82 | self.optimizer.zero_grad() 83 | for i, batch in enumerate(self.train_data_loader): 84 | if global_step < self.args.trainer.start_global_step: 85 | global_step += 1 86 | continue 87 | start = time.time() 88 | self.model.train() 89 | ni = i + len(self.train_data_loader) * epoch # number integrated batches (since train start) 90 | with self.gradient_accumulate_scope(self.model): 91 | result = self._step_forward(batch) 92 | self._step_backward(result.loss) 93 | if self.accelerator is not None or ((i + 1) % self.args.trainer.grad_accumulate 94 | == 0) or ((i + 1) == len(self.train_data_loader)): 95 | grad_norm = self._step_optimizer() 96 | norm_meter.update(grad_norm) 97 | if not self.args.trainer.scheduler_by_epoch: 98 | self._step_scheduler(global_step) 99 | loss_meter.update(result['loss'].item(), self.args.datasets.train.batch_size) 100 | batch_time.update(time.time() - start) 101 | global_step += 1 102 | global_eval_step = self._print_step_log(epoch, global_step, global_eval_step, loss_meter, norm_meter, 103 | batch_time, ni) 104 | if self.args.trainer.scheduler_by_epoch: 105 | self._step_scheduler(global_step) 106 | global_eval_step = self._print_epoch_log(epoch, global_step, global_eval_step, loss_meter, ni) 107 | model_config_path = self._train_post_process() 108 | if self.args.device.is_master: 109 | self.writer.close() 110 | return { 111 | 'acc': self.args.trainer.best_eval_result, 112 | 'best_model_path': self.args.trainer.best_model_path, 113 | 'model_config_path': model_config_path, 114 | } 115 | 116 | def _step_forward(self, batch, is_train=True, eval_model=None, **kwargs): 117 | input_args_list = ['pixel_values', 'labels', 'decoder_input_ids'] 118 | batch = {k: v.to(self.args.device.device_id) for k, v in batch.items() if k in input_args_list} 119 | # Runs the forward pass with auto-casting. 
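# NOTE (added): `precision_scope` is provided by the BaseExperiment setup; it is assumed to be an
# AMP autocast context (e.g. torch.autocast) when mixed precision is enabled, or a no-op context
# otherwise. The batch tensors have already been moved to the target device in the line above.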
120 | with self.precision_scope: 121 | output = self.model(**batch) 122 | return output 123 | 124 | """ 125 | Initialization Functions 126 | """ 127 | 128 | def init_model(self, config): 129 | model_args = config["model"] 130 | tokenizer_args = model_args["tokenizer_args"] 131 | # we can borrow donut tokenizer & processor for docparser 132 | tokenizer = AutoTokenizer.from_pretrained( 133 | pretrained_model_name_or_path=tokenizer_args['pretrained_model_name_or_path'] 134 | ) 135 | image_processor = DonutImageProcessor( 136 | size={"height": model_args['image_size'][0], "width": model_args['image_size'][1]}) 137 | self.processor = DonutProcessor(image_processor=image_processor, 138 | tokenizer=tokenizer) 139 | 140 | # model initialization 141 | AutoConfig.register("docparser-swin", DocParserConfig) 142 | AutoModel.register(DocParserConfig, DocParserModel) 143 | config = VisionEncoderDecoderConfig.from_pretrained(model_args["pretrained_model_name_or_path"]) 144 | config.encoder.image_size = model_args['image_size'] 145 | # during pre-training, a larger image size was used; for fine-tuning, 146 | # we update max_length of the decoder (for generation) 147 | config.decoder.max_length = model_args['max_length'] 148 | config.decoder.decoder_layers = model_args['decoder_layers'] 149 | model = VisionEncoderDecoderModel(config=config) 150 | logger.info("init weight from pretrained model:{}".format(model_args["pretrained_model_name_or_path"])) 151 | model.decoder.resize_token_embeddings(len(self.processor.tokenizer)) 152 | self.model = model 153 | self.model.to(self.args.device.device_id) 154 | if "model_path" in model_args and model_args['model_path'] is not None: 155 | model_path = get_absolute_file_path(model_args['model_path']) 156 | self.load_model(model_path, strict=model_args.get('load_strict', True)) 157 | total = sum([param.nelement() for param in self.model.parameters()]) 158 | logger.info("Number of parameter: %.2fM" % (total / 1e6)) 159 | 160 | def _init_optimizer(self, trainer_args, **kwargs): 161 | optimizer_args = trainer_args.get("optimizer") 162 | if optimizer_args.get("scale_lr"): 163 | num_process = 1 if self.accelerator is None else self.accelerator.num_processes 164 | optimizer_args['lr'] = float(optimizer_args['lr']) * self.grad_accumulate * \ 165 | self.train_data_loader.batch_size * num_process 166 | optimizer_args['img_lr'] = float(optimizer_args['img_lr']) * self.grad_accumulate * \ 167 | self.train_data_loader.batch_size * num_process 168 | self.optimizer = get_optimizer(self.model, **optimizer_args) 169 | 170 | def init_dataset(self, config): 171 | if 'datasets' in config and config.get('phase', 'train') != 'predict': 172 | dataset_args = config.get("datasets") 173 | train_data_loader_args = dataset_args.get("train") 174 | if config.get('phase', 'train') == 'train': 175 | train_data_loader_args['dataset'].update({ 176 | "donut_model": self.model, 177 | "processor": self.processor, 178 | "max_length": config['model']['max_length'], 179 | "phase": 'train', 180 | }) 181 | if "cache_dir" not in train_data_loader_args['dataset']: 182 | train_data_loader_args['dataset'].update({ 183 | "cache_dir": config['trainer']['save_dir']}) 184 | self.train_dataset = get_dataset(train_data_loader_args['dataset']) 185 | self.train_data_loader = self._get_data_loader_from_dataset(self.train_dataset, 186 | train_data_loader_args, 187 | phase='train') 188 | logger.info("success init train data loader len:{} ".format(len(self.train_data_loader))) 189 | 190 | # set task start token & pad token for 
bart decoder; 191 | # Do NOT change it since you can only set the start_token after dataset initialization where special tokens 192 | # are added into vocab 193 | self.model.config.decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids( 194 | train_data_loader_args['dataset']['task_start_token']) 195 | self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id 196 | 197 | """ 198 | Tool Functions 199 | """ 200 | 201 | def _print_step_log(self, epoch, global_step, global_eval_step, loss_meter, norm_meter, batch_time, ni, **kwargs): 202 | current_lr = self._get_current_lr(ni, global_step) 203 | if self.args.device.is_master and self.args.trainer.print_freq > 0 and global_step % self.args.trainer.print_freq == 0: 204 | message = "experiment:{}; train, (epoch: {}, steps: {}, lr:{:e}, step_mean_loss:{}," \ 205 | " average_loss:{}), time, (train_step_time: {:.5f}s, train_average_time: {:.5f}s);" \ 206 | "(grad_norm_mean: {:.5f}, grad_norm_step: {:.5f})". \ 207 | format(self.experiment_name, epoch, global_step, current_lr, 208 | loss_meter.val, loss_meter.avg, batch_time.val, batch_time.avg, norm_meter.avg, 209 | norm_meter.val) 210 | logger.info(message) 211 | if self.writer is not None: 212 | self.writer.add_scalar("{}_train/lr".format(self.experiment_name), current_lr, global_step) 213 | self.writer.add_scalar("{}_train/step_loss".format(self.experiment_name), loss_meter.val, global_step) 214 | self.writer.add_scalar("{}_train/average_loss".format(self.experiment_name), loss_meter.avg, 215 | global_step) 216 | if global_step > 0 and self.args.trainer.save_step_freq > 0 and self.args.device.is_master and global_step % self.args.trainer.save_step_freq == 0: 217 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step) 218 | logger.info(message) 219 | # result = self.evaluate(global_eval_step=global_eval_step) 220 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}.pth".format( 221 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg) 222 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name) 223 | tokenizer_path = os.path.join(self.args.trainer.save_dir, "tokenizer") 224 | os.makedirs(tokenizer_path, exist_ok=True) 225 | self.processor.tokenizer.save_pretrained(tokenizer_path) 226 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val) 227 | return global_eval_step 228 | 229 | def _print_epoch_log(self, epoch, global_step, global_eval_step, loss_meter, ni, **kwargs): 230 | current_lr = self._get_current_lr(ni, global_step) 231 | if self.args.trainer.save_epoch_freq > 0 and self.args.device.is_master and epoch % self.args.trainer.save_epoch_freq == 0: 232 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step) 233 | logger.info(message) 234 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}.pth".format( 235 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg) 236 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name) 237 | tokenizer_path = os.path.join(self.args.trainer.save_dir, "tokenizer") 238 | os.makedirs(tokenizer_path, exist_ok=True) 239 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val) 240 | self.processor.tokenizer.save_pretrained(tokenizer_path) 241 | return global_eval_step 242 | -------------------------------------------------------------------------------- 
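For quick reference, the following is a minimal, standalone inference sketch that mirrors `DocParserExperiment.predict` above. It is only a sketch: the checkpoint directory, image path and `<s_docparser>` task prompt are placeholder assumptions, and it presumes the fine-tuned model and processor were exported with `save_pretrained()` rather than loaded from the raw `.pth` checkpoints written by the trainer.

import re

import torch
from PIL import Image
from transformers import AutoConfig, AutoModel, DonutProcessor, VisionEncoderDecoderModel

from models.configuration_docparser import DocParserConfig
from models.modeling_docparser import DocParserModel

# register the custom encoder, exactly as DocParserExperiment.init_model does
AutoConfig.register("docparser-swin", DocParserConfig)
AutoModel.register(DocParserConfig, DocParserModel)

checkpoint_dir = "path/to/exported_docparser"  # placeholder: directory produced by save_pretrained()
image_path = "path/to/document.jpg"            # placeholder test image
task_prompt = "<s_docparser>"                  # placeholder task start token

processor = DonutProcessor.from_pretrained(checkpoint_dir)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint_dir).eval()

image = Image.open(image_path).convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

with torch.no_grad():
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_length,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # drop the first task start token
print(processor.token2json(sequence))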
/logs/.gitignore: -------------------------------------------------------------------------------- 1 | ### JetBrains template 2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 4 | 5 | # User-specific stuff 6 | .idea/**/workspace.xml 7 | .idea/**/tasks.xml 8 | .idea/**/usage.statistics.xml 9 | .idea/**/dictionaries 10 | .idea/**/shelf 11 | 12 | # Generated files 13 | .idea/**/contentModel.xml 14 | 15 | # Sensitive or high-churn files 16 | .idea/**/dataSources/ 17 | .idea/**/dataSources.ids 18 | .idea/**/dataSources.local.xml 19 | .idea/**/sqlDataSources.xml 20 | .idea/**/dynamic.xml 21 | .idea/**/uiDesigner.xml 22 | .idea/**/dbnavigator.xml 23 | 24 | # Gradle 25 | .idea/**/gradle.xml 26 | .idea/**/libraries 27 | 28 | # Gradle and Maven with auto-import 29 | # When using Gradle or Maven with auto-import, you should exclude module files, 30 | # since they will be recreated, and may cause churn. Uncomment if using 31 | # auto-import. 32 | # .idea/artifacts 33 | # .idea/compiler.xml 34 | # .idea/jarRepositories.xml 35 | # .idea/modules.xml 36 | # .idea/*.iml 37 | # .idea/modules 38 | # *.iml 39 | # *.ipr 40 | 41 | # CMake 42 | cmake-build-*/ 43 | 44 | # Mongo Explorer plugin 45 | .idea/**/mongoSettings.xml 46 | 47 | # File-based project format 48 | *.iws 49 | 50 | # IntelliJ 51 | out/ 52 | 53 | # mpeltonen/sbt-idea plugin 54 | .idea_modules/ 55 | 56 | # JIRA plugin 57 | atlassian-ide-plugin.xml 58 | 59 | # Cursive Clojure plugin 60 | .idea/replstate.xml 61 | 62 | # Crashlytics plugin (for Android Studio and IntelliJ) 63 | com_crashlytics_export_strings.xml 64 | crashlytics.properties 65 | crashlytics-build.properties 66 | fabric.properties 67 | 68 | # Editor-based Rest Client 69 | .idea/httpRequests 70 | 71 | # Android studio 3.1+ serialized cache file 72 | .idea/caches/build_file_checksums.ser 73 | 74 | ### macOS template 75 | # General 76 | .DS_Store 77 | .AppleDouble 78 | .LSOverride 79 | 80 | # Icon must end with two \r 81 | Icon 82 | 83 | # Thumbnails 84 | ._* 85 | 86 | # Files that might appear in the root of a volume 87 | .DocumentRevisions-V100 88 | .fseventsd 89 | .Spotlight-V100 90 | .TemporaryItems 91 | .Trashes 92 | .VolumeIcon.icns 93 | .com.apple.timemachine.donotpresent 94 | 95 | # Directories potentially created on remote AFP share 96 | .AppleDB 97 | .AppleDesktop 98 | Network Trash Folder 99 | Temporary Items 100 | .apdisk 101 | 102 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: @time: 1/31/23 17:08 3 | from .configuration_docparser import DocParserConfig 4 | from .modeling_docparser import DocParserModel -------------------------------------------------------------------------------- /models/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "VisionEncoderDecoderModel" 4 | ], 5 | "decoder": { 6 | "_name_or_path": "", 7 | "activation_dropout": 0.0, 8 | "activation_function": "gelu", 9 | "add_cross_attention": true, 10 | "add_final_layer_norm": true, 11 | "architectures": null, 12 | "attention_dropout": 0.0, 13 | "bad_words_ids": null, 14 | "bos_token_id": 0, 15 | "chunk_size_feed_forward": 0, 16 | "classifier_dropout": 0.0, 17 | "cross_attention_hidden_size": null, 18 | 
"d_model": 1024, 19 | "decoder_attention_heads": 16, 20 | "decoder_ffn_dim": 4096, 21 | "decoder_layerdrop": 0.0, 22 | "decoder_layers": 4, 23 | "decoder_start_token_id": null, 24 | "diversity_penalty": 0.0, 25 | "do_sample": false, 26 | "dropout": 0.1, 27 | "early_stopping": false, 28 | "encoder_attention_heads": 16, 29 | "encoder_ffn_dim": 4096, 30 | "encoder_layerdrop": 0.0, 31 | "encoder_layers": 12, 32 | "encoder_no_repeat_ngram_size": 0, 33 | "eos_token_id": 2, 34 | "exponential_decay_length_penalty": null, 35 | "finetuning_task": null, 36 | "forced_bos_token_id": null, 37 | "forced_eos_token_id": 2, 38 | "id2label": { 39 | "0": "LABEL_0", 40 | "1": "LABEL_1" 41 | }, 42 | "init_std": 0.02, 43 | "is_decoder": true, 44 | "is_encoder_decoder": false, 45 | "label2id": { 46 | "LABEL_0": 0, 47 | "LABEL_1": 1 48 | }, 49 | "length_penalty": 1.0, 50 | "max_length": 20, 51 | "max_position_embeddings": 1536, 52 | "min_length": 0, 53 | "model_type": "mbart", 54 | "no_repeat_ngram_size": 0, 55 | "num_beam_groups": 1, 56 | "num_beams": 1, 57 | "num_hidden_layers": 12, 58 | "num_return_sequences": 1, 59 | "output_attentions": false, 60 | "output_hidden_states": false, 61 | "output_scores": false, 62 | "pad_token_id": 1, 63 | "prefix": null, 64 | "problem_type": null, 65 | "pruned_heads": {}, 66 | "remove_invalid_values": false, 67 | "repetition_penalty": 1.0, 68 | "return_dict": true, 69 | "return_dict_in_generate": false, 70 | "scale_embedding": true, 71 | "sep_token_id": null, 72 | "task_specific_params": null, 73 | "temperature": 1.0, 74 | "tf_legacy_loss": false, 75 | "tie_encoder_decoder": false, 76 | "tie_word_embeddings": true, 77 | "tokenizer_class": null, 78 | "top_k": 50, 79 | "top_p": 1.0, 80 | "torch_dtype": null, 81 | "torchscript": false, 82 | "transformers_version": "4.22.0.dev0", 83 | "typical_p": 1.0, 84 | "use_bfloat16": false, 85 | "use_cache": true, 86 | "vocab_size": 57525 87 | }, 88 | "encoder": { 89 | "_name_or_path": "", 90 | "add_cross_attention": false, 91 | "architectures": null, 92 | "attention_probs_dropout_prob": 0.0, 93 | "bad_words_ids": null, 94 | "bos_token_id": null, 95 | "chunk_size_feed_forward": 0, 96 | "cross_attention_hidden_size": null, 97 | "conv_depth_num_layers": 3, 98 | "decoder_start_token_id": null, 99 | "depths": [ 100 | 3, 101 | 6, 102 | 6, 103 | 2, 104 | 2, 105 | 2 106 | ], 107 | "diversity_penalty": 0.0, 108 | "do_sample": false, 109 | "drop_path_rate": 0.1, 110 | "early_stopping": false, 111 | "embed_dim": [ 112 | 64, 113 | 128, 114 | 256, 115 | 512, 116 | 768, 117 | 1024 118 | ], 119 | "encoder_no_repeat_ngram_size": 0, 120 | "eos_token_id": null, 121 | "exponential_decay_length_penalty": null, 122 | "finetuning_task": null, 123 | "forced_bos_token_id": null, 124 | "forced_eos_token_id": null, 125 | "hidden_act": "gelu", 126 | "hidden_dropout_prob": 0.0, 127 | "hidden_size": 1024, 128 | "id2label": { 129 | "0": "LABEL_0", 130 | "1": "LABEL_1" 131 | }, 132 | "image_size": [ 133 | 2560, 134 | 1920 135 | ], 136 | "initializer_range": 0.02, 137 | "is_decoder": false, 138 | "is_encoder_decoder": false, 139 | "label2id": { 140 | "LABEL_0": 0, 141 | "LABEL_1": 1 142 | }, 143 | "layer_norm_eps": 1e-05, 144 | "length_penalty": 1.0, 145 | "max_length": 20, 146 | "min_length": 0, 147 | "mlp_ratio": 4.0, 148 | "auto_map": { 149 | "AutoConfig": "configuration_docparser.DocParserConfig" 150 | }, 151 | "model_type": "docparser-swin", 152 | "no_repeat_ngram_size": 0, 153 | "num_beam_groups": 1, 154 | "num_beams": 1, 155 | "num_channels": 3, 156 | "num_heads": [ 
157 | 4, 158 | 8, 159 | 16 160 | ], 161 | "pe_kernel_size": 3, 162 | "pe_stride_size": 2, 163 | "pe_hidden_size": 64, 164 | "pe_add_hidden_act": true, 165 | "num_layers": 3, 166 | "num_return_sequences": 1, 167 | "output_attentions": false, 168 | "output_hidden_states": false, 169 | "output_scores": false, 170 | "pad_token_id": null, 171 | "prefix": null, 172 | "problem_type": null, 173 | "pruned_heads": {}, 174 | "qkv_bias": true, 175 | "remove_invalid_values": false, 176 | "repetition_penalty": 1.0, 177 | "return_dict": true, 178 | "return_dict_in_generate": false, 179 | "sep_token_id": null, 180 | "stride_size": [ 181 | [ 182 | 2, 183 | 1 184 | ], 185 | [ 186 | 2, 187 | 1 188 | ], 189 | [ 190 | 2, 191 | 2 192 | ] 193 | ], 194 | "task_specific_params": null, 195 | "temperature": 1.0, 196 | "tf_legacy_loss": false, 197 | "tie_encoder_decoder": false, 198 | "tie_word_embeddings": true, 199 | "tokenizer_class": null, 200 | "top_k": 50, 201 | "top_p": 1.0, 202 | "torch_dtype": null, 203 | "torchscript": false, 204 | "transformers_version": "4.22.0.dev0", 205 | "typical_p": 1.0, 206 | "use_absolute_embeddings": false, 207 | "use_bfloat16": false, 208 | "window_size": [ 209 | [ 210 | 5, 211 | 40 212 | ], 213 | [ 214 | 5, 215 | 20 216 | ], 217 | [ 218 | 10, 219 | 10 220 | ] 221 | ] 222 | }, 223 | "is_encoder_decoder": true, 224 | "model_type": "vision-encoder-decoder", 225 | "tie_word_embeddings": false, 226 | "torch_dtype": "float32", 227 | "transformers_version": null 228 | } -------------------------------------------------------------------------------- /models/configuration_docparser.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: @time: 7/6/23 10:08 3 | # coding=utf-8 4 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
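# Note (sketch): with the values in models/config.json, the "encoder" section corresponds to kwargs such as
#     DocParserConfig(
#         image_size=[2560, 1920],
#         embed_dim=[64, 128, 256, 512, 768, 1024],
#         depths=[3, 6, 6, 2, 2, 2],
#         num_heads=[4, 8, 16],
#         window_size=[[5, 40], [5, 20], [10, 10]],
#         conv_depth_num_layers=3,
#     )
# where the first conv_depth_num_layers entries of embed_dim/depths configure the ConvNeXt stages and the
# remaining entries configure the Swin stages (see DocParserConvNeXtEncoder and DocParserEncoder).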
17 | """ DocParser Swin Transformer model configuration""" 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class DocParserConfig(PretrainedConfig): 26 | model_type = "docparser-swin" 27 | 28 | attribute_map = { 29 | "num_attention_heads": "num_heads", 30 | "num_hidden_layers": "num_layers", 31 | } 32 | 33 | def __init__( 34 | self, 35 | image_size=224, 36 | patch_size=4, 37 | num_channels=3, 38 | embed_dim=96, 39 | depths=[2, 2, 6, 2], 40 | num_heads=[3, 6, 12, 24], 41 | window_size=7, 42 | mlp_ratio=4.0, 43 | qkv_bias=True, 44 | hidden_dropout_prob=0.0, 45 | attention_probs_dropout_prob=0.0, 46 | drop_path_rate=0.1, 47 | hidden_act="gelu", 48 | use_absolute_embeddings=False, 49 | initializer_range=0.02, 50 | layer_norm_eps=1e-5, 51 | **kwargs, 52 | ): 53 | super().__init__(**kwargs) 54 | 55 | self.image_size = image_size 56 | self.patch_size = patch_size 57 | self.num_channels = num_channels 58 | self.embed_dim = embed_dim 59 | self.depths = depths 60 | self.num_layers = len(depths) 61 | self.num_heads = num_heads 62 | self.window_size = window_size 63 | self.mlp_ratio = mlp_ratio 64 | self.qkv_bias = qkv_bias 65 | self.hidden_dropout_prob = hidden_dropout_prob 66 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 67 | self.drop_path_rate = drop_path_rate 68 | self.hidden_act = hidden_act 69 | self.use_absolute_embeddings = use_absolute_embeddings 70 | self.layer_norm_eps = layer_norm_eps 71 | self.initializer_range = initializer_range 72 | -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: @time: 7/6/23 10:28 3 | 4 | from functools import partial 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.nn import functional as F 9 | from torchvision.ops.stochastic_depth import StochasticDepth 10 | from typing import Any, Callable, List, Optional, Sequence, Tuple, Union 11 | 12 | 13 | class LayerNorm2d(nn.LayerNorm): 14 | def forward(self, x: Tensor) -> Tensor: 15 | x = x.permute(0, 2, 3, 1) 16 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 17 | x = x.permute(0, 3, 1, 2).contiguous() 18 | return x 19 | 20 | 21 | class Permute(torch.nn.Module): 22 | """This module returns a view of the tensor input with its dimensions permuted. 
23 | 24 | Args: 25 | dims (List[int]): The desired ordering of dimensions 26 | """ 27 | 28 | def __init__(self, dims: List[int]): 29 | super().__init__() 30 | self.dims = dims 31 | 32 | def forward(self, x: Tensor) -> Tensor: 33 | return torch.permute(x, self.dims).contiguous() 34 | 35 | 36 | class CNBlock(nn.Module): 37 | def __init__( 38 | self, 39 | dim, 40 | layer_scale: float, 41 | stochastic_depth_prob: float, 42 | norm_layer: Optional[Callable[..., nn.Module]] = None, 43 | ) -> None: 44 | super().__init__() 45 | if norm_layer is None: 46 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 47 | 48 | self.block = nn.Sequential( 49 | nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True), 50 | Permute([0, 2, 3, 1]), 51 | norm_layer(dim), 52 | nn.Linear(in_features=dim, out_features=4 * dim, bias=True), 53 | nn.GELU(), 54 | nn.Linear(in_features=4 * dim, out_features=dim, bias=True), 55 | Permute([0, 3, 1, 2]), 56 | ) 57 | self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) 58 | self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") 59 | 60 | def forward(self, input: Tensor) -> Tensor: 61 | result = self.layer_scale * self.block(input) 62 | result = self.stochastic_depth(result) 63 | result += input 64 | return result 65 | 66 | 67 | class CNBlockConfig: 68 | # Stores information listed at Section 3 of the ConvNeXt paper 69 | def __init__( 70 | self, 71 | input_channels: int, 72 | out_channels: Optional[int], 73 | num_layers: int, 74 | stride: Union[Tuple[int, int], int], 75 | ) -> None: 76 | self.input_channels = input_channels 77 | self.out_channels = out_channels 78 | self.num_layers = num_layers 79 | self.stride = stride 80 | 81 | def __repr__(self) -> str: 82 | s = self.__class__.__name__ + "(" 83 | s += "input_channels={input_channels}" 84 | s += ", out_channels={out_channels}" 85 | s += ", num_layers={num_layers}" 86 | s += ")" 87 | return s.format(**self.__dict__) 88 | 89 | 90 | class ConvNeXt(nn.Module): 91 | def __init__( 92 | self, 93 | block_setting: List[CNBlockConfig], 94 | stochastic_depth_prob: float = 0.0, 95 | layer_scale: float = 1e-6, 96 | block: Optional[Callable[..., nn.Module]] = None, 97 | norm_layer: Optional[Callable[..., nn.Module]] = None, 98 | **kwargs: Any, 99 | ) -> None: 100 | super().__init__() 101 | 102 | if not block_setting: 103 | raise ValueError("The block_setting should not be empty") 104 | elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])): 105 | raise TypeError("The block_setting should be List[CNBlockConfig]") 106 | 107 | if block is None: 108 | block = CNBlock 109 | 110 | if norm_layer is None: 111 | norm_layer = partial(LayerNorm2d, eps=1e-6) 112 | 113 | layers: List[nn.Module] = [] 114 | 115 | total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) 116 | stage_block_id = 0 117 | for cnf in block_setting: 118 | # Bottlenecks 119 | stage: List[nn.Module] = [] 120 | for _ in range(cnf.num_layers): 121 | # adjust stochastic depth probability based on the depth of the stage block 122 | sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) 123 | stage.append(block(cnf.input_channels, layer_scale, sd_prob)) 124 | stage_block_id += 1 125 | layers.append(nn.Sequential(*stage)) 126 | if cnf.out_channels is not None: 127 | # Downsampling 128 | layers.append( 129 | nn.Sequential( 130 | norm_layer(cnf.input_channels), 131 | nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=cnf.stride, stride=cnf.stride), 132 | ) 
133 | ) 134 | 135 | self.features = nn.Sequential(*layers) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, (nn.Conv2d, nn.Linear)): 139 | nn.init.trunc_normal_(m.weight, std=0.02) 140 | if m.bias is not None: 141 | nn.init.zeros_(m.bias) 142 | 143 | def _forward_impl(self, x: Tensor) -> Tensor: 144 | x = self.features(x) 145 | return x 146 | 147 | def forward(self, x: Tensor) -> Tensor: 148 | return self._forward_impl(x) 149 | 150 | 151 | if __name__ == '__main__': 152 | channel_list = [64, 128, 256] 153 | num_layer_list = [3, 6, 6] 154 | stride = [(1, 2), (1, 2), (2, 2)] 155 | model = ConvNeXt(block_setting=[ 156 | CNBlockConfig(input_channels=channel_list[i_layer], 157 | out_channels=channel_list[i_layer] * 2, 158 | num_layers=num_layer_list[i_layer], 159 | stride=stride[i_layer] 160 | ) 161 | for i_layer in range(len(num_layer_list)) 162 | ]) 163 | -------------------------------------------------------------------------------- /models/modeling_docparser.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: @time: 7/5/23 16:14 3 | # coding=utf-8 4 | 5 | 6 | import collections.abc 7 | 8 | import math 9 | from dataclasses import dataclass 10 | from typing import Optional, Tuple, Union 11 | 12 | import torch 13 | import torch.utils.checkpoint 14 | from torch import nn 15 | 16 | from transformers.activations import ACT2FN 17 | from transformers.modeling_utils import PreTrainedModel 18 | from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer 19 | from transformers.utils import ( 20 | ModelOutput, 21 | add_code_sample_docstrings, 22 | add_start_docstrings, 23 | add_start_docstrings_to_model_forward, 24 | logging, 25 | ) 26 | from .configuration_docparser import DocParserConfig 27 | from .convnext import ConvNeXt, CNBlockConfig 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | # General docstring 32 | _CONFIG_FOR_DOC = "DocParserConfig" 33 | 34 | # Base docstring 35 | _CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base" 36 | _EXPECTED_OUTPUT_SHAPE = [1, 49, 768] 37 | 38 | DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [ 39 | "naver-clova-ix/donut-base", 40 | # See all Donut Swin models at https://huggingface.co/models?filter=donut 41 | ] 42 | 43 | 44 | @dataclass 45 | # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DocParser 46 | class DocParserEncoderOutput(ModelOutput): 47 | """ 48 | DocParser encoder's outputs, with potential hidden states and attentions. 49 | 50 | Args: 51 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 52 | Sequence of hidden-states at the output of the last layer of the model. 53 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 54 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 55 | shape `(batch_size, sequence_length, hidden_size)`. 56 | 57 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 58 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 59 | Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, 60 | sequence_length)`. 
61 | 62 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 63 | heads. 64 | reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 65 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 66 | shape `(batch_size, hidden_size, height, width)`. 67 | 68 | Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to 69 | include the spatial dimensions. 70 | """ 71 | 72 | last_hidden_state: torch.FloatTensor = None 73 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 74 | attentions: Optional[Tuple[torch.FloatTensor]] = None 75 | reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 76 | 77 | 78 | @dataclass 79 | # Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DocParser 80 | class DocParserModelOutput(ModelOutput): 81 | """ 82 | DocParser model's outputs that also contains a pooling of the last hidden states. 83 | 84 | Args: 85 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 86 | Sequence of hidden-states at the output of the last layer of the model. 87 | pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): 88 | Average pooling of the last layer hidden-state. 89 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 90 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 91 | shape `(batch_size, sequence_length, hidden_size)`. 92 | 93 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 94 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 95 | Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, 96 | sequence_length)`. 97 | 98 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 99 | heads. 100 | reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 101 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 102 | shape `(batch_size, hidden_size, height, width)`. 103 | 104 | Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to 105 | include the spatial dimensions. 106 | """ 107 | 108 | last_hidden_state: torch.FloatTensor = None 109 | pooler_output: Optional[torch.FloatTensor] = None 110 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 111 | attentions: Optional[Tuple[torch.FloatTensor]] = None 112 | reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 113 | 114 | 115 | # Copied from transformers.models.swin.modeling_swin.window_partition 116 | def window_partition(input_feature, window_size): 117 | """ 118 | Partitions the given input into windows. 
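    Unlike the upstream Swin helper, `window_size` is indexed here as a (height, width) pair, so the
    partitioned windows may be rectangular.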
119 | """ 120 | batch_size, height, width, num_channels = input_feature.shape 121 | input_feature = input_feature.view( 122 | batch_size, height // window_size[0], window_size[0], width // window_size[1], window_size[1], num_channels 123 | ) 124 | windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], 125 | num_channels) 126 | return windows 127 | 128 | 129 | # Copied from transformers.models.swin.modeling_swin.window_reverse 130 | def window_reverse(windows, window_size, height, width): 131 | """ 132 | Merges windows to produce higher resolution features. 133 | """ 134 | num_channels = windows.shape[-1] 135 | windows = windows.view(-1, height // window_size[0], width // window_size[1], window_size[0], window_size[1], 136 | num_channels) 137 | windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) 138 | return windows 139 | 140 | 141 | class ConvBNLayer(nn.Module): 142 | def __init__(self, 143 | in_channels, 144 | out_channels, 145 | kernel_size, 146 | padding, 147 | stride_size): 148 | super().__init__() 149 | kernel_size = kernel_size if isinstance(kernel_size, collections.abc.Iterable) else (kernel_size, kernel_size) 150 | stride_size = stride_size if isinstance(stride_size, collections.abc.Iterable) else (stride_size, stride_size) 151 | padding = padding if isinstance(stride_size, collections.abc.Iterable) else (padding, padding) 152 | self.conv = nn.Conv2d( 153 | in_channels=in_channels, 154 | out_channels=out_channels, 155 | kernel_size=kernel_size, 156 | padding=padding, 157 | stride=stride_size) 158 | self.norm = nn.BatchNorm2d(out_channels) 159 | self.act = nn.GELU() 160 | 161 | def forward(self, inputs): 162 | out = self.conv(inputs) 163 | out = self.norm(out) 164 | out = self.act(out) 165 | return out 166 | 167 | 168 | class DocParserPatchEmbeddings(nn.Module): 169 | """ 170 | Construct the patch and position embeddings. Optionally, also the mask token. 171 | """ 172 | 173 | def __init__(self, config): 174 | super().__init__() 175 | image_size = config.image_size 176 | kernel_size, stride_size = config.pe_kernel_size, config.pe_stride_size 177 | num_channels, hidden_size = config.num_channels, config.pe_hidden_size 178 | self.grid_size = (image_size[0] // 32, image_size[1] // 8) # num patches for swin-part 179 | self.patch_embedding = nn.Sequential( 180 | ConvBNLayer( 181 | in_channels=num_channels, 182 | out_channels=hidden_size // 2, 183 | kernel_size=kernel_size, 184 | stride_size=stride_size, 185 | padding=1), 186 | ConvBNLayer( 187 | in_channels=hidden_size // 2, 188 | out_channels=hidden_size, 189 | kernel_size=kernel_size, 190 | stride_size=stride_size, 191 | padding=1), 192 | ) 193 | 194 | def forward( 195 | self, 196 | pixel_values: Optional[torch.FloatTensor]): 197 | embeddings = self.patch_embedding(pixel_values) 198 | return embeddings 199 | 200 | 201 | # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging 202 | class DocParserPatchMerging(nn.Module): 203 | """ 204 | Patch Merging Layer. 205 | 206 | Args: 207 | input_resolution (`Tuple[int]`): 208 | Resolution of input feature. 209 | dim (`int`): 210 | Number of input channels. 211 | norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): 212 | Normalization layer class. 
213 | """ 214 | 215 | def __init__(self, input_resolution: Tuple[int], dim: int, dim_out: int, 216 | norm_layer: nn.Module = nn.LayerNorm) -> None: 217 | super().__init__() 218 | self.input_resolution = input_resolution 219 | self.dim = dim 220 | self.reduction = nn.Linear(4 * dim, dim_out, bias=False) 221 | self.norm = norm_layer(4 * dim) 222 | 223 | def maybe_pad(self, input_feature, height, width): 224 | should_pad = (height % 2 == 1) or (width % 2 == 1) 225 | if should_pad: 226 | pad_values = (0, 0, 0, width % 2, 0, height % 2) 227 | input_feature = nn.functional.pad(input_feature, pad_values) 228 | 229 | return input_feature 230 | 231 | def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: 232 | height, width = input_dimensions 233 | # `dim` is height * width 234 | batch_size, dim, num_channels = input_feature.shape 235 | 236 | input_feature = input_feature.view(batch_size, height, width, num_channels) 237 | # pad input to be disible by width and height, if needed 238 | input_feature = self.maybe_pad(input_feature, height, width) 239 | # [batch_size, height, width/2, num_channels] 240 | input_feature_0 = input_feature[:, :, 0::2, :] 241 | # [batch_size, height, width/2, num_channels] 242 | input_feature_1 = input_feature[:, :, 0::2, :] 243 | # [batch_size, height, width/2, num_channels] 244 | input_feature_2 = input_feature[:, :, 1::2, :] 245 | # [batch_size, height, width/2, num_channels] 246 | input_feature_3 = input_feature[:, :, 1::2, :] 247 | # batch_size height width/2 4*num_channels 248 | input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) 249 | input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C 250 | 251 | input_feature = self.norm(input_feature) 252 | input_feature = self.reduction(input_feature) 253 | 254 | return input_feature 255 | 256 | 257 | # Copied from transformers.models.swin.modeling_swin.drop_path 258 | def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): 259 | """ 260 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 261 | 262 | Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, 263 | however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 264 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the 265 | layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the 266 | argument. 
267 | """ 268 | if drop_prob == 0.0 or not training: 269 | return input 270 | keep_prob = 1 - drop_prob 271 | shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 272 | random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) 273 | random_tensor.floor_() # binarize 274 | output = input.div(keep_prob) * random_tensor 275 | return output 276 | 277 | 278 | # Copied from transformers.models.swin.modeling_swin.SwinDropPath 279 | class DocParserDropPath(nn.Module): 280 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 281 | 282 | def __init__(self, drop_prob: Optional[float] = None) -> None: 283 | super().__init__() 284 | self.drop_prob = drop_prob 285 | 286 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 287 | return drop_path(hidden_states, self.drop_prob, self.training) 288 | 289 | def extra_repr(self) -> str: 290 | return "p={}".format(self.drop_prob) 291 | 292 | 293 | # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DocParser 294 | class DocParserSelfAttention(nn.Module): 295 | def __init__(self, config, dim, num_heads, window_size): 296 | super().__init__() 297 | if dim % num_heads != 0: 298 | raise ValueError( 299 | f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" 300 | ) 301 | 302 | self.num_attention_heads = num_heads 303 | self.attention_head_size = int(dim / num_heads) 304 | self.all_head_size = self.num_attention_heads * self.attention_head_size 305 | self.window_size = ( 306 | window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) 307 | ) 308 | 309 | self.relative_position_bias_table = nn.Parameter( 310 | torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) 311 | ) 312 | 313 | # get pair-wise relative position index for each token inside the window 314 | coords_h = torch.arange(self.window_size[0]) 315 | coords_w = torch.arange(self.window_size[1]) 316 | coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) 317 | coords_flatten = torch.flatten(coords, 1) 318 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 319 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 320 | relative_coords[:, :, 0] += self.window_size[0] - 1 321 | relative_coords[:, :, 1] += self.window_size[1] - 1 322 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 323 | relative_position_index = relative_coords.sum(-1) 324 | self.register_buffer("relative_position_index", relative_position_index) 325 | 326 | self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) 327 | self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) 328 | self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) 329 | 330 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 331 | 332 | def transpose_for_scores(self, x): 333 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 334 | x = x.view(new_x_shape) 335 | return x.permute(0, 2, 1, 3) 336 | 337 | def forward( 338 | self, 339 | hidden_states: torch.Tensor, 340 | attention_mask: Optional[torch.FloatTensor] = None, 341 | head_mask: Optional[torch.FloatTensor] = None, 342 | output_attentions: Optional[bool] = False, 343 | ) -> Tuple[torch.Tensor]: 344 | batch_size, dim, num_channels = hidden_states.shape 345 | 
mixed_query_layer = self.query(hidden_states) 346 | 347 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 348 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 349 | query_layer = self.transpose_for_scores(mixed_query_layer) 350 | 351 | # Take the dot product between "query" and "key" to get the raw attention scores. 352 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 353 | 354 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 355 | 356 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] 357 | relative_position_bias = relative_position_bias.view( 358 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 359 | ) 360 | 361 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 362 | attention_scores = attention_scores + relative_position_bias.unsqueeze(0) 363 | 364 | if attention_mask is not None: 365 | # Apply the attention mask is (precomputed for all layers in DocParserModel forward() function) 366 | mask_shape = attention_mask.shape[0] 367 | attention_scores = attention_scores.view( 368 | batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim 369 | ) 370 | attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) 371 | attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) 372 | 373 | # Normalize the attention scores to probabilities. 374 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 375 | 376 | # This is actually dropping out entire tokens to attend to, which might 377 | # seem a bit unusual, but is taken from the original Transformer paper. 378 | attention_probs = self.dropout(attention_probs) 379 | 380 | # Mask heads if we want to 381 | if head_mask is not None: 382 | attention_probs = attention_probs * head_mask 383 | 384 | context_layer = torch.matmul(attention_probs, value_layer) 385 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 386 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 387 | context_layer = context_layer.view(new_context_layer_shape) 388 | 389 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 390 | 391 | return outputs 392 | 393 | 394 | # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput 395 | class DocParserSelfOutput(nn.Module): 396 | def __init__(self, config, dim): 397 | super().__init__() 398 | self.dense = nn.Linear(dim, dim) 399 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 400 | 401 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 402 | hidden_states = self.dense(hidden_states) 403 | hidden_states = self.dropout(hidden_states) 404 | 405 | return hidden_states 406 | 407 | 408 | # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DocParser 409 | class DocParserAttention(nn.Module): 410 | def __init__(self, config, dim, num_heads, window_size): 411 | super().__init__() 412 | self.self = DocParserSelfAttention(config, dim, num_heads, window_size) 413 | self.output = DocParserSelfOutput(config, dim) 414 | self.pruned_heads = set() 415 | 416 | def prune_heads(self, heads): 417 | if len(heads) == 0: 418 | return 419 | heads, index = find_pruneable_heads_and_indices( 420 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 421 | ) 422 | 423 | # Prune linear layers 
424 | self.self.query = prune_linear_layer(self.self.query, index) 425 | self.self.key = prune_linear_layer(self.self.key, index) 426 | self.self.value = prune_linear_layer(self.self.value, index) 427 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 428 | 429 | # Update hyper params and store pruned heads 430 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 431 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 432 | self.pruned_heads = self.pruned_heads.union(heads) 433 | 434 | def forward( 435 | self, 436 | hidden_states: torch.Tensor, 437 | attention_mask: Optional[torch.FloatTensor] = None, 438 | head_mask: Optional[torch.FloatTensor] = None, 439 | output_attentions: Optional[bool] = False, 440 | ) -> Tuple[torch.Tensor]: 441 | self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) 442 | attention_output = self.output(self_outputs[0], hidden_states) 443 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 444 | return outputs 445 | 446 | 447 | # Copied from transformers.models.swin.modeling_swin.SwinIntermediate 448 | class DocParserIntermediate(nn.Module): 449 | def __init__(self, config, dim): 450 | super().__init__() 451 | self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) 452 | if isinstance(config.hidden_act, str): 453 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 454 | else: 455 | self.intermediate_act_fn = config.hidden_act 456 | 457 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 458 | hidden_states = self.dense(hidden_states) 459 | hidden_states = self.intermediate_act_fn(hidden_states) 460 | return hidden_states 461 | 462 | 463 | # Copied from transformers.models.swin.modeling_swin.SwinOutput 464 | class DocParserOutput(nn.Module): 465 | def __init__(self, config, dim): 466 | super().__init__() 467 | self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) 468 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 469 | 470 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 471 | hidden_states = self.dense(hidden_states) 472 | hidden_states = self.dropout(hidden_states) 473 | return hidden_states 474 | 475 | 476 | # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DocParser 477 | class DocParserLayer(nn.Module): 478 | def __init__(self, config, dim, input_resolution, window_size, num_heads, shift_size=0): 479 | super().__init__() 480 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 481 | self.shift_size = shift_size 482 | self.window_size = window_size 483 | self.input_resolution = input_resolution 484 | self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) 485 | self.attention = DocParserAttention(config, dim, num_heads, window_size=self.window_size) 486 | self.drop_path = DocParserDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() 487 | self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) 488 | self.intermediate = DocParserIntermediate(config, dim) 489 | self.output = DocParserOutput(config, dim) 490 | 491 | def set_shift_and_window_size(self, input_resolution): 492 | if min(input_resolution) <= min(self.window_size): 493 | # if window size is larger than input resolution, we don't partition windows 494 | self.shift_size = 0 495 | self.window_size = min(input_resolution) 496 | 497 | def get_attn_mask(self, height, width, dtype): 498 | if isinstance(self.shift_size, list): 499 
| # calculate attention mask for SW-MSA 500 | img_mask = torch.zeros((1, height, width, 1), dtype=dtype) 501 | height_slices = ( 502 | slice(0, -self.window_size[0]), 503 | slice(-self.window_size[0], -self.shift_size[0]), 504 | slice(-self.shift_size[0], None), 505 | ) 506 | width_slices = ( 507 | slice(0, -self.window_size[1]), 508 | slice(-self.window_size[1], -self.shift_size[1]), 509 | slice(-self.shift_size[1], None), 510 | ) 511 | count = 0 512 | for height_slice in height_slices: 513 | for width_slice in width_slices: 514 | img_mask[:, height_slice, width_slice, :] = count 515 | count += 1 516 | 517 | mask_windows = window_partition(img_mask, self.window_size) 518 | mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1]) 519 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 520 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 521 | else: 522 | attn_mask = None 523 | return attn_mask 524 | 525 | def maybe_pad(self, hidden_states, height, width): 526 | pad_right = (self.window_size[1] - width % self.window_size[1]) % self.window_size[1] 527 | pad_bottom = (self.window_size[0] - height % self.window_size[0]) % self.window_size[0] 528 | pad_values = (0, 0, 0, pad_right, 0, pad_bottom) 529 | hidden_states = nn.functional.pad(hidden_states, pad_values) 530 | return hidden_states, pad_values 531 | 532 | def forward( 533 | self, 534 | hidden_states: torch.Tensor, 535 | input_dimensions: Tuple[int, int], 536 | head_mask: Optional[torch.FloatTensor] = None, 537 | output_attentions: Optional[bool] = False, 538 | always_partition: Optional[bool] = False, 539 | ) -> Tuple[torch.Tensor, torch.Tensor]: 540 | if not always_partition: 541 | self.set_shift_and_window_size(input_dimensions) 542 | else: 543 | pass 544 | height, width = input_dimensions 545 | batch_size, _, channels = hidden_states.size() 546 | shortcut = hidden_states 547 | 548 | hidden_states = self.layernorm_before(hidden_states) 549 | 550 | hidden_states = hidden_states.view(batch_size, height, width, channels) 551 | 552 | # pad hidden_states to multiples of window size 553 | hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) 554 | 555 | _, height_pad, width_pad, _ = hidden_states.shape 556 | # cyclic shift 557 | if isinstance(self.shift_size, list): 558 | shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size[0], -self.shift_size[1]), 559 | dims=(1, 2)) 560 | else: 561 | shifted_hidden_states = hidden_states 562 | 563 | # partition windows 564 | hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) 565 | hidden_states_windows = hidden_states_windows.view(-1, self.window_size[0] * self.window_size[1], channels) 566 | attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype) 567 | if attn_mask is not None: 568 | attn_mask = attn_mask.to(hidden_states_windows.device) 569 | 570 | attention_outputs = self.attention( 571 | hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions 572 | ) 573 | 574 | attention_output = attention_outputs[0] 575 | 576 | attention_windows = attention_output.view(-1, self.window_size[0], self.window_size[1], channels) 577 | shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) 578 | 579 | # reverse cyclic shift 580 | if isinstance(self.shift_size, list) > 0: 581 | attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size[0], 
self.shift_size[1]), 582 | dims=(1, 2)) 583 | else: 584 | attention_windows = shifted_windows 585 | 586 | was_padded = pad_values[3] > 0 or pad_values[5] > 0 587 | if was_padded: 588 | attention_windows = attention_windows[:, :height, :width, :].contiguous() 589 | 590 | attention_windows = attention_windows.view(batch_size, height * width, channels) 591 | 592 | hidden_states = shortcut + self.drop_path(attention_windows) 593 | 594 | layer_output = self.layernorm_after(hidden_states) 595 | layer_output = self.intermediate(layer_output) 596 | layer_output = hidden_states + self.output(layer_output) 597 | 598 | layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) 599 | return layer_outputs 600 | 601 | 602 | # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DocParser 603 | class DocParserStage(nn.Module): 604 | def __init__(self, config, dim, dim_out, input_resolution, depth, window_size, num_heads, drop_path, downsample): 605 | super().__init__() 606 | self.config = config 607 | self.dim = dim 608 | self.blocks = nn.ModuleList( 609 | [ 610 | DocParserLayer( 611 | config=config, 612 | dim=dim, 613 | input_resolution=input_resolution, 614 | num_heads=num_heads, 615 | window_size=window_size, 616 | shift_size=0 if (i % 2 == 0) else [window_size[0] // 2, window_size[1] // 2], 617 | ) 618 | for i in range(depth) 619 | ] 620 | ) 621 | 622 | # patch merging layer 623 | if downsample is not None: 624 | self.downsample = downsample(input_resolution, dim=dim, dim_out=dim_out, norm_layer=nn.LayerNorm) 625 | else: 626 | self.downsample = None 627 | 628 | self.pointing = False 629 | 630 | def forward( 631 | self, 632 | hidden_states: torch.Tensor, 633 | input_dimensions: Tuple[int, int], 634 | head_mask: Optional[torch.FloatTensor] = None, 635 | output_attentions: Optional[bool] = False, 636 | always_partition: Optional[bool] = False, 637 | ) -> Tuple[torch.Tensor]: 638 | height, width = input_dimensions 639 | for i, layer_module in enumerate(self.blocks): 640 | layer_head_mask = head_mask[i] if head_mask is not None else None 641 | 642 | layer_outputs = layer_module( 643 | hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition 644 | ) 645 | 646 | hidden_states = layer_outputs[0] 647 | 648 | hidden_states_before_downsampling = hidden_states 649 | if self.downsample is not None: 650 | height_downsampled, width_downsampled = height, (width + 1) // 2 651 | output_dimensions = (height, width, height_downsampled, width_downsampled) 652 | hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) 653 | else: 654 | output_dimensions = (height, width, height, width) 655 | 656 | stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) 657 | 658 | if output_attentions: 659 | stage_outputs += layer_outputs[1:] 660 | return stage_outputs 661 | 662 | 663 | class DocParserConvNeXtEncoder(nn.Module): 664 | def __init__(self, config): 665 | super().__init__() 666 | self.config = config 667 | conv_depth_num_layers = config.conv_depth_num_layers 668 | conv_embed_dim = config.embed_dim[:conv_depth_num_layers] 669 | conv_depth = config.depths[:conv_depth_num_layers] 670 | stride_size = config.stride_size 671 | # ConNeXt Stage 672 | self.layers = ConvNeXt(block_setting=[ 673 | CNBlockConfig(input_channels=conv_embed_dim[i_layer], 674 | out_channels=conv_embed_dim[i_layer] * 2, 675 | num_layers=conv_depth[i_layer], 676 | stride=stride_size[i_layer] 677 | ) 678 | for i_layer in 
range(conv_depth_num_layers)], 679 | stochastic_depth_prob=0.1) 680 | 681 | def forward( 682 | self, 683 | hidden_states: torch.Tensor): 684 | return self.layers(hidden_states) 685 | 686 | 687 | # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DocParser 688 | class DocParserEncoder(nn.Module): 689 | def __init__(self, config, grid_size): 690 | super().__init__() 691 | self.num_layers = len(config.depths) 692 | self.config = config 693 | swin_depth_num_layers = self.num_layers - config.conv_depth_num_layers 694 | swin_embed_dim = config.embed_dim[swin_depth_num_layers:] 695 | swin_depth = config.depths[swin_depth_num_layers:] 696 | 697 | # Swin-ViT Stage 698 | dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] 699 | self.layers = nn.ModuleList([ 700 | DocParserStage( 701 | config=config, 702 | window_size=config.window_size[i_layer], 703 | dim=int(swin_embed_dim[i_layer]), 704 | dim_out=int(swin_embed_dim[i_layer + 1]) if i_layer < swin_depth_num_layers - 1 else int( 705 | swin_embed_dim[i_layer]), 706 | input_resolution=(grid_size[0], grid_size[1] // (2 ** i_layer)), 707 | depth=swin_depth[i_layer], 708 | num_heads=config.num_heads[i_layer], 709 | drop_path=dpr[sum(swin_depth[:i_layer]): sum(swin_depth[: i_layer + 1])], 710 | downsample=DocParserPatchMerging if (i_layer < swin_depth_num_layers - 1) else None, 711 | ) 712 | for i_layer in range(swin_depth_num_layers) 713 | ]) 714 | self.gradient_checkpointing = False 715 | 716 | def forward( 717 | self, 718 | hidden_states: torch.Tensor, 719 | input_dimensions: Tuple[int, int], 720 | head_mask: Optional[torch.FloatTensor] = None, 721 | output_attentions: Optional[bool] = False, 722 | output_hidden_states: Optional[bool] = False, 723 | output_hidden_states_before_downsampling: Optional[bool] = False, 724 | always_partition: Optional[bool] = False, 725 | return_dict: Optional[bool] = True, 726 | ) -> Union[Tuple, DocParserEncoderOutput]: 727 | all_hidden_states = () if output_hidden_states else None 728 | all_reshaped_hidden_states = () if output_hidden_states else None 729 | all_self_attentions = () if output_attentions else None 730 | 731 | if output_hidden_states: 732 | batch_size, _, hidden_size = hidden_states.shape 733 | # rearrange b (h w) c -> b c h w 734 | reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) 735 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) 736 | all_hidden_states += (hidden_states,) 737 | all_reshaped_hidden_states += (reshaped_hidden_state,) 738 | 739 | for i, layer_module in enumerate(self.layers): 740 | layer_head_mask = head_mask[i] if head_mask is not None else None 741 | 742 | if self.gradient_checkpointing and self.training: 743 | 744 | def create_custom_forward(module): 745 | def custom_forward(*inputs): 746 | return module(*inputs, output_attentions) 747 | 748 | return custom_forward 749 | 750 | layer_outputs = torch.utils.checkpoint.checkpoint( 751 | create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask 752 | ) 753 | else: 754 | layer_outputs = layer_module( 755 | hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition 756 | ) 757 | 758 | hidden_states = layer_outputs[0] 759 | hidden_states_before_downsampling = layer_outputs[1] 760 | output_dimensions = layer_outputs[2] 761 | 762 | input_dimensions = (output_dimensions[-2], output_dimensions[-1]) 763 | 764 | if output_hidden_states and output_hidden_states_before_downsampling: 
765 | batch_size, _, hidden_size = hidden_states_before_downsampling.shape 766 | # rearrange b (h w) c -> b c h w 767 | # here we use the original (not downsampled) height and width 768 | reshaped_hidden_state = hidden_states_before_downsampling.view( 769 | batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size 770 | ) 771 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) 772 | all_hidden_states += (hidden_states_before_downsampling,) 773 | all_reshaped_hidden_states += (reshaped_hidden_state,) 774 | elif output_hidden_states and not output_hidden_states_before_downsampling: 775 | batch_size, _, hidden_size = hidden_states.shape 776 | # rearrange b (h w) c -> b c h w 777 | reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) 778 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) 779 | all_hidden_states += (hidden_states,) 780 | all_reshaped_hidden_states += (reshaped_hidden_state,) 781 | 782 | if output_attentions: 783 | all_self_attentions += layer_outputs[3:] 784 | 785 | if not return_dict: 786 | return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) 787 | 788 | return DocParserEncoderOutput( 789 | last_hidden_state=hidden_states, 790 | hidden_states=all_hidden_states, 791 | attentions=all_self_attentions, 792 | reshaped_hidden_states=all_reshaped_hidden_states, 793 | ) 794 | 795 | 796 | # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DocParser 797 | class DocParserPreTrainedModel(PreTrainedModel): 798 | """ 799 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 800 | models. 801 | """ 802 | 803 | config_class = DocParserConfig 804 | base_model_prefix = "swin" 805 | main_input_name = "pixel_values" 806 | supports_gradient_checkpointing = True 807 | 808 | def _init_weights(self, module): 809 | """Initialize the weights""" 810 | if isinstance(module, (nn.Linear, nn.Conv2d)): 811 | # Slightly different from the TF version which uses truncated_normal for initialization 812 | # cf https://github.com/pytorch/pytorch/pull/5617 813 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 814 | if module.bias is not None: 815 | module.bias.data.zero_() 816 | elif isinstance(module, nn.LayerNorm): 817 | module.bias.data.zero_() 818 | module.weight.data.fill_(1.0) 819 | 820 | def _set_gradient_checkpointing(self, module, value=False): 821 | if isinstance(module, DocParserEncoder): 822 | module.gradient_checkpointing = value 823 | 824 | 825 | SWIN_START_DOCSTRING = r""" 826 | This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use 827 | it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and 828 | behavior. 829 | 830 | Parameters: 831 | config ([`DocParserConfig`]): Model configuration class with all the parameters of the model. 832 | Initializing with a config file does not load the weights associated with the model, only the 833 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 834 | """ 835 | 836 | SWIN_INPUTS_DOCSTRING = r""" 837 | Args: 838 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 839 | Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See 840 | [`DonutImageProcessor.__call__`] for details. 
841 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 842 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 843 | 844 | - 1 indicates the head is **not masked**, 845 | - 0 indicates the head is **masked**. 846 | 847 | output_attentions (`bool`, *optional*): 848 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 849 | tensors for more detail. 850 | output_hidden_states (`bool`, *optional*): 851 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 852 | more detail. 853 | return_dict (`bool`, *optional*): 854 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 855 | """ 856 | 857 | 858 | @add_start_docstrings( 859 | "The bare Donut Swin Model transformer outputting raw hidden-states without any specific head on top.", 860 | SWIN_START_DOCSTRING, 861 | ) 862 | class DocParserModel(DocParserPreTrainedModel): 863 | def __init__(self, config, add_pooling_layer=True, use_mask_token=False): 864 | super().__init__(config) 865 | self.config = config 866 | self.num_layers = len(config.depths) 867 | 868 | self.embeddings = DocParserPatchEmbeddings(config) 869 | self.convnext_encoder = DocParserConvNeXtEncoder(config) 870 | self.encoder = DocParserEncoder(config, self.embeddings.grid_size) 871 | 872 | self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None 873 | 874 | # Initialize weights and apply final processing 875 | self.post_init() 876 | 877 | def get_input_embeddings(self): 878 | return self.embeddings.patch_embeddings 879 | 880 | def _prune_heads(self, heads_to_prune): 881 | """ 882 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 883 | class PreTrainedModel 884 | """ 885 | for layer, heads in heads_to_prune.items(): 886 | self.encoder.layer[layer].attention.prune_heads(heads) 887 | 888 | @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) 889 | @add_code_sample_docstrings( 890 | checkpoint=_CHECKPOINT_FOR_DOC, 891 | output_type=DocParserModelOutput, 892 | config_class=_CONFIG_FOR_DOC, 893 | modality="vision", 894 | expected_output=_EXPECTED_OUTPUT_SHAPE, 895 | ) 896 | def forward( 897 | self, 898 | pixel_values: Optional[torch.FloatTensor] = None, 899 | head_mask: Optional[torch.FloatTensor] = None, 900 | output_attentions: Optional[bool] = None, 901 | output_hidden_states: Optional[bool] = None, 902 | return_dict: Optional[bool] = None, 903 | ) -> Union[Tuple, DocParserModelOutput]: 904 | r""" 905 | bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): 906 | Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
907 | """ 908 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 909 | output_hidden_states = ( 910 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 911 | ) 912 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 913 | 914 | if pixel_values is None: 915 | raise ValueError("You have to specify pixel_values") 916 | 917 | # Prepare head mask if needed 918 | # 1.0 in head_mask indicate we keep the head 919 | # attention_probs has shape bsz x n_heads x N x N 920 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 921 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 922 | head_mask = self.get_head_mask(head_mask, len(self.config.depths)) 923 | 924 | embedding_output = self.embeddings(pixel_values) 925 | 926 | # ConvNext Stage 927 | encoder_outputs = self.convnext_encoder(embedding_output) 928 | 929 | # ConvNext to Swin-ViT Stage 930 | _, _, height, width = encoder_outputs.shape 931 | input_dimensions = (height, width) 932 | encoder_outputs = encoder_outputs.flatten(2).transpose(1, 2) 933 | 934 | # Swin-ViT Stage 935 | encoder_outputs = self.encoder( 936 | encoder_outputs, 937 | input_dimensions, 938 | head_mask=head_mask, 939 | output_attentions=output_attentions, 940 | output_hidden_states=output_hidden_states, 941 | return_dict=return_dict, 942 | ) 943 | 944 | sequence_output = encoder_outputs[0] 945 | 946 | pooled_output = None 947 | if self.pooler is not None: 948 | pooled_output = self.pooler(sequence_output.transpose(1, 2)) 949 | pooled_output = torch.flatten(pooled_output, 1) 950 | 951 | if not return_dict: 952 | output = (sequence_output, pooled_output) + encoder_outputs[1:] 953 | 954 | return output 955 | 956 | return DocParserModelOutput( 957 | last_hidden_state=sequence_output, 958 | pooler_output=pooled_output, 959 | hidden_states=encoder_outputs.hidden_states, 960 | attentions=encoder_outputs.attentions, 961 | reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, 962 | ) 963 | -------------------------------------------------------------------------------- /mydatasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/6/8 3 | from .docparser_dataset import DocParser, DataCollatorForDocParserDataset 4 | 5 | 6 | def get_dataset(dataset_args): 7 | dataset_type = dataset_args.get("type") 8 | dataset = eval(dataset_type)(**dataset_args) 9 | return dataset 10 | 11 | -------------------------------------------------------------------------------- /mydatasets/docparser_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: @time: 6/6/23 10:30 3 | """ 4 | Dataloader for Pretraining Task of DocParser 5 | 6 | Masked Document Reading Step After the knowledge transfer step, we 7 | pre-train our model on the task of document reading. In this pre-training phase, 8 | the model learns to predict the next textual token while conditioning on the 9 | previous textual tokens and the input image. To encourage joint reasoning, we 10 | mask several 32 × 32 blocks representing approximately fifteen percent of the 11 | input image. In fact, in order to predict the text situated within the masked 12 | regions, the model is obliged to understand its textual context. 
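As a concrete illustration (these numbers are only an example, not the model's actual input size): a 1280 × 960 image holds (1280 / 32) × (960 / 32) = 40 × 30 = 1,200 such blocks, so roughly 180 of them are replaced by zero-valued 32 × 32 patches before the image is fed to the encoder.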
13 | 14 | """ 15 | import os 16 | import os.path 17 | import random 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Tuple, Sequence 20 | 21 | import torch 22 | from PIL import Image, ImageFile 23 | from torch.utils.data import Dataset 24 | from tqdm import tqdm 25 | from transformers.modeling_utils import PreTrainedModel 26 | 27 | from base.common_util import load_json 28 | 29 | ImageFile.LOAD_TRUNCATED_IMAGES = True 30 | 31 | 32 | # Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right 33 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 34 | """ 35 | Shift input ids one token to the right. 36 | """ 37 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 38 | shifted_input_ids[1:] = input_ids[:-1].clone() 39 | if decoder_start_token_id is None: 40 | raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") 41 | shifted_input_ids[0] = decoder_start_token_id 42 | 43 | if pad_token_id is None: 44 | raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") 45 | # replace possible -100 values in labels by `pad_token_id` 46 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 47 | 48 | return shifted_input_ids 49 | 50 | 51 | class DocParser(Dataset): 52 | """ 53 | DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets) 54 | Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt), 55 | and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string) 56 | 57 | Args: 58 | data_root: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl 59 | ignore_id: ignore_index for torch.nn.CrossEntropyLoss 60 | task_start_token: the special token to be fed to the decoder to conduct the target task 61 | """ 62 | 63 | def __init__( 64 | self, 65 | data_root: list, 66 | donut_model: PreTrainedModel, 67 | processor, 68 | max_length: int, 69 | phase: str = "train", 70 | ignore_id: int = -100, 71 | task_start_token: str = "", 72 | prompt_end_token: str = None, 73 | sort_json_key: bool = True, 74 | **kwargs 75 | ): 76 | super().__init__() 77 | 78 | self.donut_model = donut_model 79 | self.processor = processor 80 | self.max_length = max_length 81 | self.phase = phase 82 | self.ignore_id = ignore_id 83 | self.task_start_token = task_start_token 84 | self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token 85 | self.sort_json_key = sort_json_key 86 | gt_info_list = [] 87 | self.img_path_list = [] 88 | print("processing json to token sequence...") 89 | for data_dir in data_root: 90 | for gt_info in load_json(data_dir): 91 | gt_info_list.extend(gt_info) 92 | 93 | self.dataset_length = len(gt_info_list) 94 | self.gt_token_sequences = [] 95 | self.special_token_list = [] 96 | 97 | for gt_info in tqdm(gt_info_list): 98 | gt_token_sequence = self.json2token( 99 | gt_info['extract_info'], 100 | update_special_tokens_for_json_key=self.phase == "train", 101 | sort_json_key=self.sort_json_key, 102 | ) + self.processor.tokenizer.eos_token 103 | self.gt_token_sequences.append(gt_token_sequence) 104 | self.img_path_list.append(os.path.join(gt_info['filepath'], gt_info['filename'])) 105 | 106 | # add special token 107 | list_of_tokens = [self.task_start_token, self.prompt_end_token] 108 | 109 | self.add_tokens(list_of_tokens) 
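        # After registering the task-start / prompt-end tokens above, the decoder's
        # embedding matrix is resized below to the tokenizer's new vocabulary size.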
110 |         self.donut_model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
111 | 
112 |         # patch config
113 |         self.height, self.width = self.processor.image_processor.size['height'], self.processor.image_processor.size[
114 |             'width']
115 |         self.num_patches = self.height // 32 * self.width // 32
116 |         self.mask_tensor = torch.zeros(3, 32, 32)
117 | 
118 |     def add_tokens(self, list_of_tokens: List[str]):
119 |         """
120 |         Add special tokens to tokenizer and resize the token embeddings of the decoder
121 |         """
122 |         newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
123 |         if newly_added_num > 0:
124 |             self.special_token_list.extend(list_of_tokens)
125 | 
126 |     def json2token(self, obj: Any,
127 |                    update_special_tokens_for_json_key: bool = True,
128 |                    sort_json_key: bool = True):
129 |         """
130 |         Convert an ordered JSON object into a token sequence
131 |         """
132 |         if type(obj) == dict:
133 |             if len(obj) == 1 and "text_sequence" in obj:
134 |                 return obj["text_sequence"]
135 |             else:
136 |                 output = ""
137 |                 if sort_json_key:
138 |                     keys = sorted(obj.keys(), reverse=True)
139 |                 else:
140 |                     keys = obj.keys()
141 |                 for k in keys:
142 |                     if update_special_tokens_for_json_key:
143 |                         list_of_tokens = [fr"<s_{k}>", fr"</s_{k}>"]
144 |                         # add extract token
145 |                         self.add_tokens(list_of_tokens)
146 |                     output += (
147 |                         fr"<s_{k}>"
148 |                         + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
149 |                         + fr"</s_{k}>"
150 |                     )
151 |                 return output
152 |         elif type(obj) == list:
153 |             return r"<sep/>".join(
154 |                 [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
155 |             )
156 |         else:
157 |             obj = str(obj)
158 |             if f"<{obj}/>" in self.special_token_list:
159 |                 obj = f"<{obj}/>"  # for categorical special tokens
160 |             return obj
161 | 
162 |     def __len__(self) -> int:
163 |         return self.dataset_length
164 | 
165 |     def __getitem__(self, idx: int):
166 |         try:
167 |             # pixel_tensor
168 |             sample = Image.open(self.img_path_list[idx]).convert("RGB")
169 |             input_tensor = self.processor(sample, random_padding=self.phase == "train",
170 |                                           do_normalize=False,
171 |                                           return_tensors="pt").pixel_values[0]
172 | 
173 |             # To encourage joint reasoning, we mask several 32 × 32 blocks
174 |             # representing approximately fifteen percent of the input image.
175 | input_tensor = self.mask_document_patch(input_tensor) 176 | 177 | # input_ids 178 | processed_parse = self.gt_token_sequences[idx] 179 | input_ids = self.processor.tokenizer( 180 | processed_parse, 181 | add_special_tokens=False, 182 | max_length=self.max_length, 183 | padding="max_length", 184 | truncation=True, 185 | return_tensors="pt", 186 | )["input_ids"].squeeze(0) 187 | 188 | labels = input_ids.clone() 189 | labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id 190 | except: 191 | random_index = random.randrange(self.__len__()) 192 | return self.__getitem__(random_index) 193 | # model doesn't need to predict pad token 194 | return input_tensor, labels, processed_parse 195 | 196 | def mask_document_patch(self, pixel_values): 197 | patch_width = self.width // 32 198 | sample_idx_list = random.sample(list(range(self.num_patches)), int(self.num_patches * 0.15)) 199 | for sample_id in sample_idx_list: 200 | row_id = sample_id // patch_width 201 | col_id = sample_id % patch_width 202 | pixel_values[:, row_id * 32: (row_id + 1) * 32, col_id * 32: (col_id + 1) * 32] = self.mask_tensor 203 | return self.processor(pixel_values, return_tensors="pt").pixel_values[0] 204 | 205 | 206 | @dataclass 207 | class DataCollatorForDocParserDataset(object): 208 | """Collate examples for supervised fine-tuning.""" 209 | 210 | def __init__(self, **kwargs): 211 | pass 212 | 213 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 214 | batch = dict() 215 | # pixel_values 216 | images = [instance[0] for instance in instances] 217 | batch['pixel_values'] = torch.stack(images) 218 | # labels 219 | labels = [instance[1] for instance in instances] 220 | batch['labels'] = torch.stack(labels) 221 | # processed_parse 222 | batch['processed_parse'] = [instance[2] for instance in instances] 223 | return batch 224 | 225 | 226 | 227 | 228 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyconfig 2 | accelerate 3 | munch 4 | torch==2.0.0 5 | torchvision==0.15.1 6 | transformers==4.28.1 -------------------------------------------------------------------------------- /train/train_experiment.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # create: 2021/6/10 3 | import os 4 | import sys 5 | import argparse 6 | import setproctitle 7 | 8 | PROJECT_ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(PROJECT_ROOT_PATH) 10 | os.environ['RUN_ON_GPU_IDs'] = "0" 11 | 12 | import experiment 13 | 14 | from base.common_util import get_absolute_file_path, init_experiment_config 15 | from experiment import get_experiment_name 16 | 17 | 18 | def init_args(): 19 | parser = argparse.ArgumentParser(description='trainer args') 20 | parser.add_argument( 21 | '--config_file', 22 | default='config/base.yaml', 23 | type=str, 24 | ) 25 | parser.add_argument( 26 | '--experiment_name', 27 | default='DocParser', 28 | type=str, 29 | ) 30 | parser.add_argument( 31 | '--phase', 32 | default='train', 33 | type=str, 34 | ) 35 | parser.add_argument( 36 | '--use_accelerate', 37 | default=False, 38 | type=bool, 39 | ) 40 | args = parser.parse_args() 41 | os.environ['WORKSPACE'] = args.experiment_name 42 | return args 43 | 44 | 45 | def main(args): 46 | config = init_experiment_config(args.config_file, args.experiment_name) 47 | config.update({'phase': args.phase, 48 | 
'use_accelerate': args.use_accelerate}) 49 | experiment_instance = getattr(experiment, get_experiment_name(args.experiment_name))(config) 50 | if args.phase == 'train': 51 | experiment_instance.train() 52 | elif args.phase == 'predict': 53 | experiment_instance.predict() 54 | else: 55 | print("Unimplemented phase: {}".format(args.phase)) 56 | 57 | 58 | if __name__ == '__main__': 59 | args = init_args() 60 | setproctitle.setproctitle("{} task for {}".format(args.experiment_name, args.config_file.split('/')[-1])) 61 | main(args) 62 | --------------------------------------------------------------------------------
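A quick usage sketch (assuming the repository root as the working directory, since the default config path is relative): with the defaults wired into init_args, a training run can be launched as

    python train/train_experiment.py --config_file config/base.yaml --experiment_name DocParser --phase train

and passing --phase predict routes to experiment_instance.predict() instead. One caveat visible in the parser: --use_accelerate is declared with type=bool, so argparse converts any non-empty value, including the literal string "False", to True; leave the flag off entirely to keep the default of False.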