├── .gitignore
├── LICENSE
├── README.md
├── base
│   ├── __init__.py
│   ├── common_util.py
│   ├── driver.py
│   ├── meter.py
│   └── torch_utils
│       ├── __init__.py
│       ├── dl_util.py
│       └── scheduler_util.py
├── config
│   └── base.yaml
├── doc
│   └── encoder_arch.jpeg
├── examples
│   └── test_forward.py
├── experiment
│   ├── __init__.py
│   ├── base_experiment.py
│   └── docparser_experiment.py
├── logs
│   └── .gitignore
├── models
│   ├── __init__.py
│   ├── config.json
│   ├── configuration_docparser.py
│   ├── convnext.py
│   └── modeling_docparser.py
├── mydatasets
│   ├── __init__.py
│   └── docparser_dataset.py
├── requirements.txt
└── train
    └── train_experiment.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### JetBrains template
2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
4 |
5 | # User-specific stuff
6 | .idea/**/workspace.xml
7 | .idea/**/tasks.xml
8 | .idea/**/usage.statistics.xml
9 | .idea/**/dictionaries
10 | .idea/**/shelf
11 |
12 | # Generated files
13 | .idea/**/contentModel.xml
14 |
15 | # Sensitive or high-churn files
16 | .idea/**/dataSources/
17 | .idea/**/dataSources.ids
18 | .idea/**/dataSources.local.xml
19 | .idea/**/sqlDataSources.xml
20 | .idea/**/dynamic.xml
21 | .idea/**/uiDesigner.xml
22 | .idea/**/dbnavigator.xml
23 |
24 | # Gradle
25 | .idea/**/gradle.xml
26 | .idea/**/libraries
27 |
28 | # Gradle and Maven with auto-import
29 | # When using Gradle or Maven with auto-import, you should exclude module files,
30 | # since they will be recreated, and may cause churn. Uncomment if using
31 | # auto-import.
32 | # .idea/artifacts
33 | # .idea/compiler.xml
34 | # .idea/jarRepositories.xml
35 | # .idea/modules.xml
36 | # .idea/*.iml
37 | # .idea/modules
38 | # *.iml
39 | # *.ipr
40 |
41 | # CMake
42 | cmake-build-*/
43 |
44 | # Mongo Explorer plugin
45 | .idea/**/mongoSettings.xml
46 |
47 | # File-based project format
48 | *.iws
49 |
50 | # IntelliJ
51 | out/
52 |
53 | # mpeltonen/sbt-idea plugin
54 | .idea_modules/
55 |
56 | # JIRA plugin
57 | atlassian-ide-plugin.xml
58 |
59 | # Cursive Clojure plugin
60 | .idea/replstate.xml
61 |
62 | # Crashlytics plugin (for Android Studio and IntelliJ)
63 | com_crashlytics_export_strings.xml
64 | crashlytics.properties
65 | crashlytics-build.properties
66 | fabric.properties
67 |
68 | # Editor-based Rest Client
69 | .idea/httpRequests
70 |
71 | # Android studio 3.1+ serialized cache file
72 | .idea/caches/build_file_checksums.ser
73 |
74 | ### macOS template
75 | # General
76 | .DS_Store
77 | .AppleDouble
78 | .LSOverride
79 |
80 | # Icon must end with two \r
81 | Icon
82 |
83 | # Thumbnails
84 | ._*
85 |
86 | # Files that might appear in the root of a volume
87 | .DocumentRevisions-V100
88 | .fseventsd
89 | .Spotlight-V100
90 | .TemporaryItems
91 | .Trashes
92 | .VolumeIcon.icns
93 | .com.apple.timemachine.donotpresent
94 |
95 | # Directories potentially created on remote AFP share
96 | .AppleDB
97 | .AppleDesktop
98 | Network Trash Folder
99 | Temporary Items
100 | .apdisk
101 |
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Nuo Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DocParser: End-to-end OCR-free Information Extraction from Visually Rich Documents
2 |
 3 | This is an unofficial PyTorch implementation of DocParser.
4 |
 5 |
 6 | ![The architecture of DocParser's Encoder](doc/encoder_arch.jpeg)
 7 |
 8 | The architecture of DocParser's Encoder
 9 |
10 | ## News
11 | - **Sep 1st**, released the ConvNeXt weights [here](https://drive.google.com/drive/folders/1ZsvXgULWWm3sR6ZKxGlHmGO5-ESvXl1J?usp=drive_link). Please note that these weights were trained with a CTC head on an OCR task and can only be used to initialize the ConvNeXt part of DocParser during pretraining. They are NOT intended for fine-tuning on any downstream task.
12 | - **July 15th**, updated the training scripts for the Masked Document Reading Task and the model architecture.
13 |
14 | ## How to use
15 | ### 1. Set Up Environment
16 | ```shell
17 | pip install -r requirements.txt
18 | ```
19 |
20 | ### 2. Prepare Dataset
21 | The dataset should be processed into the following format (the `//` comments are illustrative only; see the appendix at the end of this README for a minimal sketch of producing such a record):
22 | ```json
23 | {
24 | "filepath": "path/to/image/folder", // path to image folder
25 | "filename": "file_name", // file name
26 | "extract_info": {
27 | "ocr_info": [
28 | {
29 | "chunk": "text1"
30 | },
31 | {
32 | "chunk": "text2"
33 | },
34 | {
35 | "chunk": "text3"
36 | }
37 | ]
38 | } // a list of ocr info of filepath/filename
39 | }
40 | ```
41 | ### 3. Start Training
42 | You can start training by running `train/train_experiment.py` directly, or from the command line:
43 |
44 | ```shell
45 | python train/train_experiment.py --config_file config/base.yaml
46 | ```
47 | The training script also supports DDP with huggingface/accelerate:
48 | ```shell
49 | accelerate launch train/train_experiment.py --config_file config/base.yaml --use_accelerate True
50 | ```
51 | ### 4. Notes
52 | The training script currently implements only the **Masked Document Reading Task** described in the paper. The decoder weights, tokenizer and processor are borrowed from [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base).
53 |
54 | Unfortunately, no DocParser pre-training weights are publicly available, and simply borrowing weights from Donut does not benefit DocParser on any downstream task. I am currently pretraining DocParser with the two-stage tasks described in the paper; once both pretraining stages are complete and the model performs well, I intend to release it on the Hugging Face Hub.
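55 |
56 | ### 5. Appendix: Building a Dataset Record
57 | The snippet below is a minimal, illustrative sketch (not part of the training code) of how a record in the format from Section 2 could be produced. The helper name `build_record`, the file names and the output path are assumptions for illustration only.
58 | ```python
59 | import json
60 |
61 |
62 | def build_record(filepath: str, filename: str, chunks: list) -> dict:
63 |     """Assemble one dataset record; `chunks` is a list of OCR text strings."""
64 |     return {
65 |         "filepath": filepath,  # path to the image folder
66 |         "filename": filename,  # image file name
67 |         "extract_info": {
68 |             "ocr_info": [{"chunk": chunk} for chunk in chunks],
69 |         },
70 |     }
71 |
72 |
73 | if __name__ == "__main__":
74 |     record = build_record("path/to/image/folder", "file_name.jpg", ["text1", "text2", "text3"])
75 |     with open("sample_record.json", "w", encoding="utf-8") as f:
76 |         json.dump(record, f, ensure_ascii=False, indent=2)
77 | ```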
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | # create: 2021/6/9
--------------------------------------------------------------------------------
/base/common_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/6/10
3 |
4 | import os
5 | import glob
6 | import codecs
7 | import json
8 | from natsort import natsorted
9 | from base.driver import logger, PROJECT_ROOT_PATH
10 |
11 |
12 | def get_file_list(folder_path: str, p_postfix: list = None, sub_dir: bool = True) -> list:
13 | assert os.path.exists(folder_path) and os.path.isdir(folder_path)
14 | if p_postfix is None:
15 | p_postfix = ['.jpg']
16 | if isinstance(p_postfix, str):
17 | p_postfix = [p_postfix]
18 | logger.info("begin to get files from:{}".format(folder_path))
19 | file_list = [
20 | x for x in glob.glob(folder_path + '/**/*.*', recursive=True)
21 | if os.path.splitext(x)[-1].lower() in p_postfix or '*'.lower() in p_postfix
22 | ]
23 | logger.info("success to get files from:{}".format(folder_path))
24 |
25 | return natsorted(file_list)
26 |
27 |
28 | def search_file_from_dir(file_dir, file_path_pref_list, base_pref="", exts=["pdf", "PDF"], save_pref=True):
29 | assert isinstance(file_path_pref_list, list)
30 | base_dir_name = os.path.basename(file_dir)
31 | logger.info("searching exts:{} from:{}".format(exts, file_dir))
32 | if len(base_pref.strip()) == 0:
33 | current_base_pref = ""
34 | else:
35 | current_base_pref = "{}_{}".format(base_pref, base_dir_name)
36 | for file_name in os.listdir(file_dir):
37 | file_path = os.path.join(file_dir, file_name)
38 | if os.path.isdir(file_path):
39 | search_file_from_dir(file_path, file_path_pref_list, current_base_pref, exts, save_pref)
40 | else:
41 | file_res = file_name.rsplit(".", 1)
42 | if len(file_res) == 2:
43 | file_pref, file_ext = file_res
44 | if file_ext.strip().lower() in exts:
45 | if save_pref:
46 | if len(current_base_pref.strip()) > 0:
47 | current_file_pref = "{}_{}".format(current_base_pref, file_pref)
48 | else:
49 | current_file_pref = "{}".format(file_pref)
50 | file_path_pref_list.append((file_path, current_file_pref))
51 | else:
52 | file_path_pref_list.append(file_path)
53 |
54 |
55 | def get_absolute_file_path(file_path):
56 | if file_path.startswith("/"):
57 | return file_path
58 | else:
59 | return os.path.join(PROJECT_ROOT_PATH, file_path)
60 |
61 |
62 | def get_file_path_list(path, ext=None):
63 | if not path.startswith('/'):
64 | path = os.path.join(PROJECT_ROOT_PATH, path)
65 | # print(path)
66 | assert os.path.exists(path), 'path not exist {}'.format(path)
67 | assert ext is not None, 'ext is None'
68 | if os.path.isfile(path):
69 | return [path]
70 | file_path_list = []
71 | for root, _, files in os.walk(path):
72 | for file in files:
73 | try:
74 | if file.rsplit('.')[-1].lower() in ext:
75 | file_path_list.append(os.path.join(root, file))
76 | except Exception as e:
77 | pass
78 | return file_path_list
79 |
80 |
81 | # load json data
82 | def load_json(data):
83 | if isinstance(data, dict):
84 | return [data]
85 | elif isinstance(data, list):
86 | file_path_list = data
87 | elif data.endswith('.json'):
88 | file_path_list = [data]
89 | else:
90 | file_path_list = get_file_path_list(data, '.json')
91 | json_data_list = list()
92 | for file_path in file_path_list:
93 | with codecs.open(file_path, "r", "utf-8") as fr:
94 | json_data = json.loads(fr.read())
95 | json_data_list.append(json_data)
96 | return json_data_list
97 |
98 |
99 | def save_params(save_dir, save_json, yml_name='config.yaml'):
100 | import yaml
101 | with open(os.path.join(save_dir, yml_name), 'w', encoding='utf-8') as f:
102 | yaml.dump(save_json, f, default_flow_style=False, encoding='utf-8', allow_unicode=True)
103 |
104 |
105 | def read_config(config_file):
106 | import anyconfig
107 | if os.path.exists(config_file):
108 | with open(config_file, "rb") as fr:
109 | config = anyconfig.load(fr)
110 | if 'base' in config:
111 | base_config_path = config['base']
112 | if not base_config_path.startswith('/'):
113 | base_config_path = os.path.join(PROJECT_ROOT_PATH, base_config_path)
114 | elif os.path.basename(config_file) == 'base.yaml':
115 | return config
116 | else:
117 | base_config_path = os.path.join(os.path.dirname(config_file), "base.yaml")
118 | base_config = read_config(base_config_path)
119 | merged_config = base_config.copy()
120 | merge_config(config, merged_config)
121 | return merged_config
122 | else:
123 | return {}
124 |
125 |
126 | def merge_config(config, base_config):
127 | for key, _ in config.items():
128 | if isinstance(config[key], dict) and key not in base_config:
129 | base_config[key] = config[key]
130 | elif isinstance(config[key], dict):
131 | merge_config(config[key], base_config[key])
132 | else:
133 | if key in base_config:
134 | base_config[key] = config[key]
135 | else:
136 | base_config.update({key: config[key]})
137 |
138 |
139 | def init_experiment_config(config_file, experiment_name):
140 | if not config_file.startswith("/"):
141 | config_file = get_absolute_file_path(config_file)
142 | input_config = read_config(config_file)
143 | experiment_base_config = read_config(os.path.join(PROJECT_ROOT_PATH, 'config', experiment_name.lower(),
144 | 'base.yaml'))
145 | merged_config = experiment_base_config.copy()
146 | merge_config(input_config, merged_config)
147 |
148 | base_config = read_config(os.path.join(PROJECT_ROOT_PATH, 'config',
149 | 'base.yaml'))
150 | final_merged_config = base_config.copy()
151 | merge_config(merged_config, final_merged_config)
152 | return final_merged_config
153 |
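154 |
155 | if __name__ == '__main__':
156 |     # Minimal sketch (illustration only) of the merge semantics implemented above:
157 |     # values in the override config replace or extend the base config in place.
158 |     demo_base = {"trainer": {"epochs": 5, "grad_clip": 1.0}, "name": "docparser-base"}
159 |     demo_override = {"trainer": {"epochs": 10}, "phase": "train"}
160 |     merge_config(demo_override, demo_base)
161 |     print(demo_base)  # {'trainer': {'epochs': 10, 'grad_clip': 1.0}, 'name': 'docparser-base', 'phase': 'train'}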
--------------------------------------------------------------------------------
/base/driver.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/6/9
3 |
4 | import os
5 | import logging
6 |
7 | PROJECT_ROOT_PATH = os.path.abspath(os.path.join(__file__, '../../'))
8 |
9 | DATA_ROOT = os.path.join(PROJECT_ROOT_PATH, "data")
10 | CACHE_ROOT = os.path.join(DATA_ROOT, "cache")
11 |
12 | logger = logging.getLogger()
13 | stream_handler = logging.StreamHandler()
14 |
15 | log_formatter = logging.Formatter(fmt='%(asctime)s\t%(levelname)s\t%(name)s %(filename)s:%(lineno)s - %(message)s',
16 | datefmt='%Y-%m-%d %H:%M:%S')
17 | stream_handler.setFormatter(log_formatter)
18 |
19 | logger.addHandler(stream_handler)
20 |
21 | logger.setLevel(logging.INFO)
22 |
--------------------------------------------------------------------------------
/base/meter.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | # create: 2021/7/16
4 |
5 |
6 | class AverageMeter:
7 | """Computes and stores the average and current value"""
8 |
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
--------------------------------------------------------------------------------
/base/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # datetime: 2022/4/28 11:50 AM
4 | # software: PyCharm
5 |
--------------------------------------------------------------------------------
/base/torch_utils/dl_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/6/17
3 |
4 | import json
5 | import math
6 | import torch
7 | import random
8 | import numpy as np
9 | from torch import nn
10 | import torch.optim as optim
11 | from base.driver import logger
12 | from collections import OrderedDict
13 | from torch.optim import lr_scheduler
14 | from timm.scheduler.cosine_lr import CosineLRScheduler
15 | from timm.scheduler.step_lr import StepLRScheduler
16 | from transformers.optimization import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, \
17 | get_linear_schedule_with_warmup
18 | from base.torch_utils.scheduler_util import LinearLRScheduler, get_cosine_schedule_by_epochs, \
19 | get_stairs_schedule_with_warmup
20 |
21 |
22 | def seed_all(random_seed):
23 | if random_seed is not None:
24 | random.seed(random_seed)
25 | np.random.seed(random_seed)
26 | torch.manual_seed(random_seed)
27 | torch.cuda.manual_seed_all(random_seed)
28 | torch.backends.cudnn.deterministic = True
29 |
30 |
31 | def print_network(net, verbose=False, name=""):
32 | num_params = 0
33 | for param in net.parameters():
34 | num_params += param.numel()
35 | if verbose:
36 | logger.info(net)
37 | if hasattr(net, 'flops'):
38 | flops = net.flops()
39 | logger.info(f"number of GFLOPs: {flops / 1e9}")
40 | logger.info('network:{} Total number of parameters: {}'.format(name, num_params))
41 |
42 |
43 | def check_keywords_in_name(name, keywords=()):
44 | isin = False
45 | for keyword in keywords:
46 | if keyword in name:
47 | isin = True
48 | return isin
49 |
50 |
51 | def get_grad_norm(parameters, norm_type=2):
52 | if isinstance(parameters, torch.Tensor):
53 | parameters = [parameters]
54 | parameters = list(filter(lambda p: p.grad is not None, parameters))
55 | norm_type = float(norm_type)
56 | total_norm = 0
57 | for p in parameters:
58 | param_norm = p.grad.data.norm(norm_type)
59 | total_norm += param_norm.item()**norm_type
60 | total_norm = total_norm**(1. / norm_type)
61 | return total_norm
62 |
63 |
64 | def set_params_optimizer(model, keyword=None, keywords=None, weight_decay=0.0, lr=None):
65 | if keywords is None:
66 | keywords = []
67 | param_dict = OrderedDict()
68 | no_decay_param_names = []
69 | for name, param in model.named_parameters():
70 | if not param.requires_grad:
71 | continue # frozen weights
 72 |         if (keyword is not None and keyword in name) or check_keywords_in_name(name, keywords):
73 | param_dict[name] = {"weight_decay": weight_decay}
74 | if lr is not None:
75 | lr = float(lr)
76 | param_dict[name].update({"lr": lr})
77 | else:
78 | no_decay_param_names.append(name)
79 | return param_dict, no_decay_param_names
80 |
81 |
82 | def get_optimizer(model,
83 | optimizer_type="adam",
84 | lr=0.001,
85 | beta1=0.9,
86 | beta2=0.999,
87 | no_decay_keys=None,
88 | weight_decay=0.0,
89 | layer_decay=None,
90 | eps=1e-8,
91 | momentum=0,
92 | params=None,
93 | **kwargs):
94 | assigner = None
95 | if layer_decay is not None:
96 | if layer_decay < 1.0:
97 | num_layers = kwargs.get('num_layers')
98 | assigner = LayerDecayValueAssigner(list(layer_decay**(num_layers + 1 - i) for i in range(num_layers + 2)))
99 |
100 | lr = float(lr)
101 | beta1, beta2 = float(beta1), float(beta2)
102 | weight_decay = float(weight_decay)
103 | momentum = float(momentum)
104 | eps = float(eps)
105 | freeze_params = kwargs.get('freeze_params', [])
106 | img_lr = float(kwargs.get('img_lr', lr))
107 | for name, param in model.named_parameters():
108 | freeze_flag = False
109 | for freeze_param in freeze_params:
110 | if freeze_param in name:
111 | freeze_flag = True
112 | break
113 | if freeze_flag:
114 | print("name={} param.requires_grad = False".format(name))
115 | param.requires_grad = False
116 |
117 | if params is None:
118 | if weight_decay:
119 | skip = {}
120 | if no_decay_keys is not None:
121 | skip = no_decay_keys
122 | elif hasattr(model, 'no_weight_decay'):
123 | skip = model.no_weight_decay()
124 | param_configs = get_parameter_groups(model, img_lr, weight_decay, skip, assigner)
125 | weight_decay = 0.
126 | else:
127 | param_configs = model.parameters()
128 | else:
129 | param_configs = params
130 | if optimizer_type == "sgd":
131 | optimizer = optim.SGD(param_configs, momentum=momentum, nesterov=True, lr=lr, weight_decay=weight_decay)
132 | elif optimizer_type == "adam":
133 | optimizer = optim.Adam(param_configs, lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=weight_decay)
134 | elif optimizer_type == "adadelta":
135 | optimizer = optim.Adadelta(param_configs, lr=lr, eps=eps, weight_decay=weight_decay)
136 | elif optimizer_type == "rmsprob":
137 | optimizer = optim.RMSprop(param_configs, lr=lr, eps=eps, weight_decay=weight_decay, momentum=momentum)
138 | elif optimizer_type == "adamw":
139 | optimizer = optim.AdamW(param_configs, lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=weight_decay)
140 | else:
141 |         raise NotImplementedError('optimizer type [{}] is not implemented'.format(optimizer_type))
142 | return optimizer
143 |
144 |
145 | def get_scheduler(optimizer,
146 | scheduler_type="linear",
147 | num_warmup_steps=0,
148 | num_training_steps=10000,
149 | last_epoch=-1,
150 | step_size=10,
151 | gamma=0.1,
152 | epochs=20,
153 | **kwargs):
154 | gamma = float(gamma)
155 | if scheduler_type == "cosine":
156 | scheduler = get_cosine_schedule_with_warmup(optimizer,
157 | num_warmup_steps=num_warmup_steps,
158 | num_training_steps=num_training_steps,
159 | last_epoch=last_epoch)
160 | elif scheduler_type == 'cosine_epoch':
161 | scheduler = get_cosine_schedule_by_epochs(optimizer, num_epochs=epochs, last_epoch=last_epoch)
162 | elif scheduler_type == "linear":
163 | scheduler = get_linear_schedule_with_warmup(optimizer,
164 | num_warmup_steps=num_warmup_steps,
165 | num_training_steps=num_training_steps,
166 | last_epoch=last_epoch)
167 | elif scheduler_type == "stairs":
168 | logger.info("current use stair scheduler")
169 | scheduler = get_stairs_schedule_with_warmup(optimizer,
170 | num_warmup_steps=num_warmup_steps,
171 | num_training_steps=num_training_steps,
172 | last_epoch=last_epoch,
173 | **kwargs)
174 | elif scheduler_type == "step":
175 | step_size = int(step_size)
176 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
177 | elif scheduler_type == "exponential":
178 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=gamma)
179 | """
180 | def exp_decay(epoch):
181 | initial_lrate = 0.1
182 | k = 0.1
183 | lrate = initial_lrate * exp(-k*t)
184 | return lrate
185 | """
186 |
187 | else:
188 | scheduler = get_constant_schedule_with_warmup(optimizer,
189 | num_warmup_steps=num_warmup_steps,
190 | last_epoch=last_epoch)
191 | return scheduler
192 |
193 |
194 | def get_scheduler2(optimizer,
195 | scheduler_type="cosine",
196 | num_warmup_steps=0,
197 | num_training_steps=10000,
198 | decay_steps=1000,
199 | decay_rate=0.1,
200 | lr_min=5e-6,
201 | warmup_lr=5e-7):
202 | lr_min = float(lr_min)
203 | warmup_lr = float(warmup_lr)
204 | decay_rate = float(decay_rate)
205 | if scheduler_type == "cosine":
206 | scheduler = CosineLRScheduler(optimizer,
207 | t_initial=num_training_steps,
208 | t_mul=1,
209 | lr_min=lr_min,
210 | warmup_lr_init=warmup_lr,
211 | cycle_limit=1,
212 | t_in_epochs=False)
213 | elif scheduler_type == "linear":
214 | scheduler = LinearLRScheduler(optimizer,
215 | t_initial=num_training_steps,
216 | lr_min_rate=0.01,
217 | warmup_lr_init=warmup_lr,
218 | warmup_t=num_warmup_steps,
219 | t_in_epochs=False)
220 | else:
221 | scheduler = StepLRScheduler(optimizer,
222 | decay_t=decay_steps,
223 | decay_rate=decay_rate,
224 | warmup_lr_init=warmup_lr,
225 | warmup_t=num_warmup_steps,
226 | t_in_epochs=False)
227 | return scheduler
228 |
229 |
230 | def one_cycle(y1=0.0, y2=1.0, steps=100):
231 | # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
232 | return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
233 |
234 |
235 | def get_scheduler_yolo(optimizer, cos_lr=True, lrf=0.1, epochs=20, last_epoch=-1, **kwargs):
236 | if cos_lr:
237 | lf = one_cycle(1, lrf, epochs) # cosine 1->lrf
238 | else:
239 | lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf # linear
240 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf,
241 | last_epoch=last_epoch) # plot_lr_scheduler(optimizer, scheduler, epochs)
242 | return scheduler, lf
243 |
244 |
245 | def get_tensorboard_texts(label_texts):
246 | new_labels = []
247 | for label_text in label_texts:
248 | new_labels.append(label_text.replace("/", "//").replace("<", "/<").replace(">", "/>"))
249 | return " \n".join(new_labels)
250 |
251 |
252 | def get_parameter_groups(model, img_lr, weight_decay, skip_list=(), assigner=None):
253 | parameter_group_names = {}
254 | parameter_group_vars = {}
255 |
256 | for name, param in model.named_parameters():
257 | if not param.requires_grad:
258 | continue # frozen weights
259 |         # TODO: is this generally applicable?
260 | if 'image_encoder' in name:
261 | if len(param.shape) == 1 or name.endswith(".bias") or name.split('.')[-1] in skip_list:
262 | group_name = "img_encoder_no_decay"
263 | this_weight_decay = 0.
264 | else:
265 | group_name = "img_encoder_decay"
266 | this_weight_decay = weight_decay
267 | if assigner is not None:
268 | layer_id = assigner.get_layer_id(name)
269 | group_name = "layer_%d_%s" % (layer_id, group_name)
270 | else:
271 | layer_id = None
272 |
273 | if group_name not in parameter_group_names:
274 | if assigner is not None:
275 | scale = assigner.get_scale(layer_id)
276 | else:
277 | scale = 1.
278 |
279 | parameter_group_names[group_name] = {
280 | "weight_decay": this_weight_decay,
281 | "params": [],
282 | "lr_scale": scale,
283 | "lr": img_lr
284 | }
285 | parameter_group_vars[group_name] = {
286 | "weight_decay": this_weight_decay,
287 | "params": [],
288 | "lr_scale": scale,
289 | "lr": img_lr
290 | }
291 | else:
292 | if len(param.shape) == 1 or name.endswith(".bias") or name.split('.')[-1] in skip_list:
293 | group_name = "no_decay"
294 | this_weight_decay = 0.
295 | else:
296 | group_name = "decay"
297 | this_weight_decay = weight_decay
298 | if assigner is not None:
299 | layer_id = assigner.get_layer_id(name)
300 | group_name = "layer_%d_%s" % (layer_id, group_name)
301 | else:
302 | layer_id = None
303 |
304 | if group_name not in parameter_group_names:
305 | if assigner is not None:
306 | scale = assigner.get_scale(layer_id)
307 | else:
308 | scale = 1.
309 |
310 | parameter_group_names[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale}
311 | parameter_group_vars[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale}
312 |
313 | parameter_group_vars[group_name]["params"].append(param)
314 | parameter_group_names[group_name]["params"].append(name)
315 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
316 | return list(parameter_group_vars.values())
317 |
318 |
319 | class LayerDecayValueAssigner(object):
320 |
321 | def __init__(self, values):
322 | self.values = values
323 |
324 | def get_scale(self, layer_id):
325 | return self.values[layer_id]
326 |
327 | def get_layer_id(self, var_name):
328 | return get_num_layer(var_name, len(self.values))
329 |
330 |
331 | def get_num_layer(var_name, num_max_layer):
332 | var_name = var_name.split('.', 1)[-1]
333 | if var_name.startswith("embeddings"):
334 | return 0
335 | elif var_name.startswith("encoder.layer"):
336 | layer_id = int(var_name.split('.')[2])
337 | return layer_id + 1
338 | else:
339 | return num_max_layer - 1
340 |
--------------------------------------------------------------------------------
/base/torch_utils/scheduler_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/10/9
3 | import math
4 | import torch
5 | from torch.optim import Optimizer
6 | from timm.scheduler.scheduler import Scheduler
7 |
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 |
11 | def get_stairs_schedule_with_warmup(optimizer,
12 | num_warmup_steps,
13 | num_training_steps,
14 | stair_num=2,
15 | min_scale=0.01,
16 | last_epoch=-1,
17 | **kwargs):
18 | """
 19 |     Create a stair schedule whose learning rate decreases from the initial lr set in the optimizer to 0, after
 20 |     a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
 21 |     The decrease happens in duplicated stairs, so more training steps are allocated to smaller learning rates.
 22 |     For example, with learning rate 4e-4, stair_num 3 and remain_steps 1400,
 23 |     the learning rate is:
24 | from 0-100, 4e-4
25 | from 100-200, decrease from 4e-4 to 2e-4
26 | from 200-400 2e-4
27 | from 400-600 decrease from 2e-4 to 1e-4
28 | from 600-1000 1e-4
29 | from 1000-1400 decrease from 1e-4 to 0
30 | as following:
31 | ___
32 | / \ ____
33 | / \ ___
34 | / \
35 | Args:
36 | optimizer (:class:`~torch.optim.Optimizer`):
37 | The optimizer for which to schedule the learning rate.
38 | num_warmup_steps (:obj:`int`):
39 | The number of steps for the warmup phase.
40 | num_training_steps (:obj:`int`):
41 | The total number of training steps.
42 | stair_num: int
43 | min_scale: min learning_rate ratio
44 | last_epoch (:obj:`int`, `optional`, defaults to -1):
45 | The index of the last epoch when resuming training.
46 |
47 | Return:
48 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
49 | """
50 |
51 | stair_num = int(stair_num)
52 | min_scale = float(min_scale)
53 | remain_step = max(1, num_training_steps - num_warmup_steps)
54 | unit_step = int(remain_step / (2**(stair_num + 1) - 2))
55 | remain_linear_step = remain_step - unit_step * int(2**stair_num - 1)
56 | stair_steps = [
57 | unit_step * int((3 * 2**(i / 2) - 2)) if i % 2 == 0 else unit_step * int(4 * 2**((i - 1) / 2) - 2)
58 | for i in range(2 * stair_num)
59 | ]
60 | stair_scales = [1.0 - (2**i - 1) / (2**stair_num - 1) for i in range(stair_num)]
61 |
62 | def lr_lambda(current_step: int):
63 | if current_step < num_warmup_steps:
64 | return float(current_step) / float(max(1, num_warmup_steps))
65 |
66 | current_remain_step = current_step - num_warmup_steps
67 | for i in range(stair_num * 2):
68 | if i == 0:
69 | prev_stair_step = 0
70 | else:
71 | prev_stair_step = stair_steps[i - 1]
72 | stair_step = stair_steps[i]
73 | if prev_stair_step <= current_remain_step <= stair_step:
74 | if i % 2 == 0:
75 | return max(min_scale, stair_scales[i // 2])
76 | else:
77 | prev_linear_step = unit_step * int(2**((i - 1) / 2) - 1)
78 | current_linear_step = current_remain_step - prev_stair_step + prev_linear_step
79 | linear_lr = float(remain_linear_step - current_linear_step) / float(remain_linear_step)
80 | return max(min_scale, linear_lr)
81 | return min_scale
82 |
83 | return LambdaLR(optimizer, lr_lambda, last_epoch)
84 |
85 |
86 | class LinearLRScheduler(Scheduler):
87 |
88 | def __init__(
89 | self,
90 | optimizer: torch.optim.Optimizer,
91 | t_initial: int,
92 | lr_min_rate: float,
93 | warmup_t=0,
94 | warmup_lr_init=0.,
95 | t_in_epochs=True,
96 | noise_range_t=None,
97 | noise_pct=0.67,
98 | noise_std=1.0,
99 | noise_seed=42,
100 | initialize=True,
101 | ) -> None:
102 | super().__init__(optimizer,
103 | param_group_field="lr",
104 | noise_range_t=noise_range_t,
105 | noise_pct=noise_pct,
106 | noise_std=noise_std,
107 | noise_seed=noise_seed,
108 | initialize=initialize)
109 |
110 | self.t_initial = t_initial
111 | self.lr_min_rate = lr_min_rate
112 | self.warmup_t = warmup_t
113 | self.warmup_lr_init = warmup_lr_init
114 | self.t_in_epochs = t_in_epochs
115 | if self.warmup_t:
116 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
117 | super().update_groups(self.warmup_lr_init)
118 | else:
119 | self.warmup_steps = [1 for _ in self.base_values]
120 |
121 | def _get_lr(self, t):
122 | if t < self.warmup_t:
123 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
124 | else:
125 | t = t - self.warmup_t
126 | total_t = self.t_initial - self.warmup_t
127 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
128 | return lrs
129 |
130 | def get_epoch_values(self, epoch: int):
131 | if self.t_in_epochs:
132 | return self._get_lr(epoch)
133 | else:
134 | return None
135 |
136 | def get_update_values(self, num_updates: int):
137 | if not self.t_in_epochs:
138 | return self._get_lr(num_updates)
139 | else:
140 | return None
141 |
142 |
143 | # update lr by epoch
144 | def get_cosine_schedule_by_epochs(optimizer: Optimizer, num_epochs: int, last_epoch: int = -1, **kwargs):
145 |
146 | def lr_lambda(epoch):
147 | lf = ((1 + math.cos(epoch * math.pi / num_epochs)) / 2) * 0.8 + 0.2 # cosine
148 | return lf
149 |
150 | return LambdaLR(optimizer, lr_lambda, last_epoch)
151 |
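152 |
153 | if __name__ == '__main__':
154 |     # Minimal sketch (illustration only): step through the stair schedule with a dummy
155 |     # optimizer; the step counts and learning rate are placeholder assumptions.
156 |     dummy_params = [torch.nn.Parameter(torch.zeros(1))]
157 |     dummy_optimizer = torch.optim.SGD(dummy_params, lr=4e-4)
158 |     stairs_scheduler = get_stairs_schedule_with_warmup(dummy_optimizer,
159 |                                                        num_warmup_steps=100,
160 |                                                        num_training_steps=1500,
161 |                                                        stair_num=3)
162 |     for _ in range(1500):
163 |         dummy_optimizer.step()
164 |         stairs_scheduler.step()
165 |     print("final lr:", stairs_scheduler.get_last_lr()[-1])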
--------------------------------------------------------------------------------
/config/base.yaml:
--------------------------------------------------------------------------------
1 | name: docparser-base # experiment name
2 | model:
3 | type: DocParser
4 | pretrained_model_name_or_path: /models
5 | # encoder
6 | image_size: [2560, 1920] # the input image size of docparser
7 | # decoder
8 | max_length: 1024 # the max input length of docparser
9 | decoder_layers: 2 # the decoder layer
10 |
11 | model_path: ~ # path to a certain checkpoint
12 | load_strict: true # whether to strictly load the checkpoint
13 |
14 | # training precision
15 | mixed_precision: "fp16" # one of ["no", "fp16", "bf16"]; use torch native amp
16 |
17 | tokenizer_args:
18 | pretrained_model_name_or_path: naver-clova-ix/donut-base # we borrow tokenizer & image processor from donut
19 | extra_args: {}
20 |
21 | predictor:
22 | img_paths:
23 | -
24 | save_dir: /data/data/cache
25 |
26 | trainer:
27 | start_global_step: -1 # start training from a certain global step; -1 means no starting global step is set
28 | resume_flag: false # whether to resume the training from a certain checkpoint
29 | random_seed: ~
30 | grad_clip: 1.0
31 | epochs: 5
32 |
33 | # tensorboard configuration
34 | save_dir: /logs/docparser
35 | tensorboard_dir: /logs/docparser/tensorboard
36 |
37 | # display configuration
38 | save_epoch_freq: 1
39 | save_step_freq: 800
40 | print_freq: 20
41 |
42 | # gradient configuration
43 | grad_accumulate: 1 # gradient accumulation
44 |
45 | # optimizer configuration
46 | optimizer:
47 | optimizer_type: "adamw"
48 | lr: 1.0e-04
49 | # layer_decay: 0.75
50 | weight_decay: 0.05
51 | beta1: 0.9
52 | beta2: 0.98
53 | eps: 1.0e-6
54 |
55 | # scheduler configuration
56 | scheduler:
57 | scheduler_type: "cosine"
58 | warmup_steps: 2000
59 | warmup_epochs: 0
60 |
61 | datasets:
62 | train:
63 | dataset:
64 | type: DocParser
65 | task_start_token:
66 | data_root:
67 | - # put your dataset path here
68 | num_workers: 0
69 | batch_size: 1 # global batch = bz * num_gpu * grad
70 | shuffle: true
71 | collate_fn:
72 | type: DataCollatorForDocParserDataset
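73 |
74 | # Illustrative only: read_config() in base/common_util.py merges any config that sets a
75 | # `base` key on top of the referenced file, so a hypothetical derived config could look like:
76 | # base: config/base.yaml
77 | # name: docparser-finetune
78 | # trainer:
79 | #   epochs: 10
80 | #   optimizer:
81 | #     lr: 5.0e-05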
--------------------------------------------------------------------------------
/doc/encoder_arch.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NormXU/DocParser-Pytorch/6e11ea5fc211b1ccc37f51b2b0f64baea68fcf83/doc/encoder_arch.jpeg
--------------------------------------------------------------------------------
/examples/test_forward.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: @time: 7/6/23 16:38
3 | import torch
4 | from transformers import VisionEncoderDecoderConfig, AutoConfig, VisionEncoderDecoderModel, AutoModel
5 | from models import DocParserModel, DocParserConfig
6 |
7 | if __name__ == '__main__':
8 | AutoConfig.register("docparser-swin", DocParserConfig)
9 | AutoModel.register(DocParserConfig, DocParserModel)
10 |
11 | config = VisionEncoderDecoderConfig.from_pretrained("../models/")
12 | model = VisionEncoderDecoderModel(config=config)
13 |
14 |     # test forward with dummy input
15 |     input_tensor = torch.ones(1, 3, 2560, 1920)
16 |     # the forward pass also needs decoder inputs; a single dummy token id is enough here
17 |     decoder_input_ids = torch.ones(1, 1, dtype=torch.long)
18 |     output = model(pixel_values=input_tensor, decoder_input_ids=decoder_input_ids)
19 |
--------------------------------------------------------------------------------
/experiment/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/6/8
3 | from .base_experiment import BaseExperiment
4 | from .docparser_experiment import DocParserExperiment
5 |
6 | def get_experiment_name(name):
7 | name_split = name.split("_")
8 | trainer_name = "".join([tmp_name[0].upper() + tmp_name[1:] for tmp_name in name_split])
9 | return "{}Experiment".format(trainer_name)
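10 |
11 |
12 | if __name__ == '__main__':
13 |     # e.g. "doc_parser" -> "DocParserExperiment", "docparser" -> "DocparserExperiment"
14 |     print(get_experiment_name("doc_parser"))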
--------------------------------------------------------------------------------
/experiment/base_experiment.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/7/2
3 | import copy
4 | import itertools
5 | import json
6 | import logging
7 | import os
8 | from contextlib import nullcontext
9 |
10 | import munch
11 | import torch
12 | from accelerate import Accelerator
13 | from torch import autocast
14 | from torch.utils.data import DataLoader
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | import mydatasets
18 | from base.common_util import get_absolute_file_path, merge_config, get_file_path_list
19 | from base.common_util import save_params
20 | from base.driver import log_formatter
21 | from base.driver import logger
22 | from base.torch_utils.dl_util import get_optimizer, get_scheduler, get_scheduler2, seed_all, get_grad_norm
23 | from mydatasets import get_dataset
24 |
25 |
26 | class BaseExperiment(object):
27 |
28 | def __init__(self, config):
29 | config = self._init_config(config)
30 | self.experiment_name = config["name"]
31 | self.args = munch.munchify(config)
32 | self.init_device(config)
33 | self.init_random_seed(config)
34 | self.init_model(config)
35 | self.init_dataset(config)
36 | self.init_trainer_args(config)
37 | self.init_predictor_args(config)
38 | self.prepare_accelerator()
39 |
40 | """
41 | Main Block
42 | """
43 |
44 | def predict(self, **kwargs):
45 | request_property = kwargs.get('request_property')
46 | pass
47 |
48 | def evaluate(self, **kwargs):
49 | pass
50 |
51 | def train(self, **kwargs):
52 | pass
53 |
54 | def _step_forward(self, batch, is_train=True, eval_model=None, **kwargs):
55 | pass
56 |
57 | def _step_backward(self, loss, **kwargs):
58 | if self.use_torch_amp:
59 | self.mixed_scaler.scale(loss).backward()
60 | else:
61 | if self.accelerator is not None:
62 | self.accelerator.backward(loss)
63 | else:
64 | loss = loss / self.args.trainer.grad_accumulate
65 | loss.backward()
66 |
67 | def _get_current_lr(self, ni, global_step=0, **kwargs):
68 | if self.args.trainer.scheduler_type == "scheduler2":
69 | current_lr = self.scheduler.get_update_values(global_step)[-1]
70 | else:
71 | current_lr = self.scheduler.get_last_lr()[-1]
72 | return current_lr
73 |
74 | def _step_optimizer(self, **kwargs):
75 | params_to_clip = (itertools.chain(self.model.parameters()))
76 | for param_group in self.optimizer.param_groups:
77 | if "lr_scale" in param_group:
78 | param_group["lr"] = param_group["lr"] * param_group["lr_scale"]
 79 |         grad_norm = None
 80 |         if self.use_torch_amp:
 81 |             if self.args.trainer.grad_clip is not None:
 82 |                 # Unscale the optimizer's gradients in-place before clipping
 83 |                 self.mixed_scaler.unscale_(self.optimizer)
 84 |                 grad_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, self.args.trainer.grad_clip)
 85 |             self.mixed_scaler.step(self.optimizer)
 86 |             # Updates the scale for next iteration.
 87 |             self.mixed_scaler.update()
 88 |         elif self.accelerator is not None:
 89 |             if self.args.trainer.grad_clip is not None and self.accelerator.sync_gradients:
 90 |                 grad_norm = self.accelerator.clip_grad_norm_(params_to_clip, self.args.trainer.grad_clip)
 91 |             self.optimizer.step()
 92 |         else:
 93 |             self.optimizer.step()
 94 |         if grad_norm is None:
 95 |             grad_norm = get_grad_norm(params_to_clip)
96 |
97 | self.optimizer.zero_grad()
98 | return grad_norm
99 |
100 | def _step_scheduler(self, global_step, **kwargs):
101 | if self.args.trainer.scheduler_type == "scheduler2":
102 | self.scheduler.step_update(global_step)
103 | else:
104 | self.scheduler.step()
105 |
106 | """
107 | Initialization Functions
108 | """
109 |
110 |     # Interdependencies between config fields can be handled in this function
111 | def _init_config(self, config):
112 | if 'trainer' in config and config.get('phase', 'train') == 'train':
113 | trainer_args = config["trainer"]
114 | trainer_args['save_dir'] = get_absolute_file_path(trainer_args.get("save_dir"))
115 | os.makedirs(trainer_args['save_dir'], exist_ok=True)
116 | save_params(trainer_args['save_dir'], config)
117 | train_log_path = os.path.join(trainer_args['save_dir'], "{}.log".format(config['name']))
118 | file_handler = logging.FileHandler(train_log_path)
119 | file_handler.setLevel(logging.INFO)
120 | file_handler.setFormatter(log_formatter)
121 | logger.addHandler(file_handler)
122 | return config
123 |
124 | def init_device(self, config):
125 |         # ADD: RUN_ON_GPU_IDs=-1 means CPU; with multiple visible GPUs, accelerate is used by default
126 | self.args.device = munch.munchify(config.get('device', {}))
127 | self.accelerator = None
128 | self.weight_dtype = torch.float32
129 | self.gradient_accumulate_scope = nullcontext
130 | self.precision_scope = nullcontext()
131 | self.use_torch_amp = False
132 |
133 | # accelerator configuration
134 | if config['use_accelerate']:
135 |             # If you define multiple visible GPUs, you are expected to use accelerate for DDP training
136 | self.accelerator = Accelerator(
137 | gradient_accumulation_steps=int(self.args.trainer.grad_accumulate),
138 | mixed_precision=self.args.model.mixed_precision)
139 | self.args.device.device_id = self.accelerator.device
140 | self.args.device.device_ids = []
141 | if self.accelerator.mixed_precision == "fp16":
142 | self.weight_dtype = torch.float16
143 | elif self.accelerator.mixed_precision == "bf16":
144 | self.weight_dtype = torch.bfloat16
145 | self.gradient_accumulate_scope = self.accelerator.accumulate
146 | self.args.device.is_master = self.accelerator.is_main_process
147 | self.args.device.is_distributed = self.accelerator.num_processes > 1
148 | elif os.environ.get("RUN_ON_GPU_IDs", 0) == str(-1):
149 | # load model with CPU
150 | self.args.device.device_id = torch.device("cpu")
151 | self.args.device.device_ids = [-1]
152 | self.args.device.is_master = True
153 | self.args.device.is_distributed = False
154 | else:
155 | # USE one GPU specified by user w/o using accelerate
156 | device_id = os.environ.get("RUN_ON_GPU_IDs", 0)
157 | self.args.device.device_id = torch.device("cuda:{}".format(device_id))
158 | self.args.device.device_ids = [int(device_id)]
159 | torch.cuda.set_device(int(device_id))
160 | self.args.device.is_master = True
161 | self.args.device.is_distributed = False
162 | if self.args.model.mixed_precision in ["fp16", "bf16"]:
163 |                 # ADD: renamed mixed_precision_flag to use_torch_amp
164 | self.use_torch_amp = True
165 | self.weight_dtype = torch.float16 if self.args.model.mixed_precision == "fp16" else torch.bfloat16
166 | self.precision_scope = autocast(device_type="cuda", dtype=self.weight_dtype)
167 | logger.info("device:{}, is_master:{}, device_ids:{}, is_distributed:{}".format(
168 | self.args.device.device_id, self.args.device.is_master, self.args.device.device_ids,
169 | self.args.device.is_distributed))
170 |
171 | def init_model(self, config):
172 | pass
173 |
174 | def init_dataset(self, config):
175 | if 'datasets' in config and config.get('phase', 'train') != 'predict':
176 | dataset_args = config.get("datasets")
177 | train_data_loader_args = dataset_args.get("train")
178 | if config.get('phase', 'train') == 'train':
179 | self.train_dataset = get_dataset(train_data_loader_args['dataset'])
180 | self.train_data_loader = self._get_data_loader_from_dataset(self.train_dataset,
181 | train_data_loader_args,
182 | phase='train')
183 | logger.info("success init train data loader len:{} ".format(len(self.train_data_loader)))
184 | eval_data_loader_args = dataset_args.get("eval")
185 | merged_eval_data_loader_args = train_data_loader_args.copy()
186 | merge_config(eval_data_loader_args, merged_eval_data_loader_args)
187 | self.eval_dataset = get_dataset(merged_eval_data_loader_args['dataset'])
188 | self.eval_data_loader = self._get_data_loader_from_dataset(self.eval_dataset,
189 | merged_eval_data_loader_args,
190 | phase='eval')
191 | logger.info("success init eval data loader len:{}".format(len(self.eval_data_loader)))
192 |
193 | def init_random_seed(self, config):
194 | if 'random_seed' in config['trainer']:
195 | seed_all(config['trainer']['random_seed'])
196 | else:
197 | logger.warning("random seed is missing")
198 |
199 | def init_predictor_args(self, config):
200 | if 'predictor' in config and config.get('phase', 'train') == 'predict':
201 | predictor_args = config["predictor"]
202 | self.args.predictor.img_paths = predictor_args.get("img_paths", None)
203 | if self.args.predictor.img_paths is None:
204 | self.args.predictor.img_paths = []
205 | img_dirs = predictor_args.get("img_dirs", None)
206 | if img_dirs:
207 | for img_dir in img_dirs:
208 | if img_dir:
209 | self.args.predictor.img_paths.extend(get_file_path_list(img_dir, ['jpg', 'png', 'jpeg']))
210 | if predictor_args['save_dir'] is None and 'model_path' in config['model'] and config['model'][
211 | 'model_path'] is not None:
212 | predictor_args['save_dir'] = os.path.join(os.path.dirname(config['model']['model_path']),
213 | 'test_results')
214 | self.args.predictor.save_dir = get_absolute_file_path(predictor_args["save_dir"])
215 | os.makedirs(self.args.predictor.save_dir, exist_ok=True)
216 |
217 | def init_trainer_args(self, config):
218 | if 'trainer' in config and config.get('phase', 'train') == 'train':
219 | trainer_args = config["trainer"]
220 | self._init_optimizer(trainer_args)
221 | self._init_scheduler(trainer_args)
222 | logger.info("current trainer epochs:{}, train_dataset_len:{}, data_loader_len:{}".format(
223 | self.args.trainer.epochs, len(self.train_dataset), len(self.train_data_loader)))
224 | self.mixed_scaler = torch.cuda.amp.GradScaler(enabled=True) if self.use_torch_amp else None
225 | self.args.trainer.best_eval_result = -1
226 | self.args.trainer.best_model_path = ''
227 | self.args.trainer.start_epoch = 0
228 | self.args.trainer.start_global_step = 0
229 | if self.args.trainer.resume_flag and 'model_path' in self.args.model and self.args.model.model_path is not None:
230 | resume_path = self.args.model.model_path.replace('.pth', '_resume.pth')
231 | if os.path.exists(resume_path):
232 | resume_checkpoint = torch.load(resume_path)
233 | self.optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])
234 | self.scheduler.load_state_dict(resume_checkpoint['scheduler_state_dict'])
235 | self.args.trainer.start_epoch = resume_checkpoint['epoch']
236 | self.args.trainer.start_global_step = resume_checkpoint['global_step']
237 | else:
238 | logger.warning("resume path {} doesn't exist: failed to resume!!".format(resume_path))
239 |
240 | if 'trainer' in config and config.get('phase', 'train') != 'predict':
241 | trainer_args = config["trainer"]
242 | self._init_criterion(trainer_args)
243 | # init tensorboard and log
244 | if "tensorboard_dir" in trainer_args and self.args.device.is_master:
245 | tensorboard_log_dir = get_absolute_file_path(trainer_args.get("tensorboard_dir"))
246 | os.makedirs(tensorboard_log_dir, exist_ok=True)
247 | self.writer = SummaryWriter(log_dir=tensorboard_log_dir, comment=self.experiment_name)
248 | else:
249 | self.writer = None
250 |
251 | def _init_optimizer(self, trainer_args, **kwargs):
252 | optimizer_args = trainer_args.get("optimizer")
253 | # ADD scale lr
254 | if optimizer_args["scale_lr"]:
255 | num_process = 1 if self.accelerator is None else self.accelerator.num_processes
256 | optimizer_args['lr'] = optimizer_args['lr'] * self.args.trainer.grad_accumulate * \
257 | self.train_data_loader.batch_size * num_process
258 | self.optimizer = get_optimizer(self.model, **optimizer_args)
259 |
260 | def _init_scheduler(self, trainer_args, **kwargs):
261 | scheduler_args = trainer_args.get("scheduler")
262 | self.args.trainer.scheduler_by_epoch = scheduler_args.get("scheduler_by_epoch", False)
263 | total_epoch_train_steps = len(self.train_data_loader)
264 | if scheduler_args["warmup_epochs"] > 0:
265 | warmup_steps = scheduler_args.get("warmup_epochs") * total_epoch_train_steps
266 | elif scheduler_args['warmup_steps'] > 0:
267 | warmup_steps = scheduler_args.get("warmup_steps")
268 | else:
269 | warmup_steps = 0
270 | self.args.trainer.scheduler.warmup_steps = warmup_steps
271 | num_training_steps = total_epoch_train_steps * self.args.trainer.epochs
272 | if self.accelerator is None:
273 | # accelerator will automatically take care of the grad accumulate in calculating total num_training steps,
274 | # or you need to calculate by yourself
275 | num_training_steps = num_training_steps // self.args.trainer.grad_accumulate
276 | if "scheduler_method" in scheduler_args and scheduler_args["scheduler_method"] == "get_scheduler2":
277 | self.scheduler = get_scheduler2(self.optimizer,
278 | num_training_steps=num_training_steps,
279 | num_warmup_steps=warmup_steps,
280 | **scheduler_args)
281 | self.args.trainer.scheduler_type = "scheduler2"
282 | else:
283 | self.scheduler = get_scheduler(self.optimizer,
284 | num_training_steps=num_training_steps,
285 | num_warmup_steps=warmup_steps,
286 | epochs=self.args.trainer.epochs,
287 | **scheduler_args)
288 | self.args.trainer.scheduler_type = "scheduler"
289 |
290 | logger.info(
291 | "success init optimizer and scheduler, optimizer:{}, scheduler:{}, scheduler_args:{}, warmup_steps:{},"
292 | "num_training_steps:{}, gradient_accumulator:{}".format(self.optimizer, self.scheduler, scheduler_args,
293 | warmup_steps, num_training_steps,
294 | self.args.trainer.grad_accumulate))
295 |
296 | def _init_criterion(self, trainer_args):
297 | pass
298 |
299 | def _init_metric(self, **kwargs):
300 | pass
301 |
302 | """
303 | Tool Functions
304 | """
305 |
306 | def load_model(self, checkpoint_path, strict=True, **kwargs):
307 | if os.path.exists(checkpoint_path) and os.path.isfile(checkpoint_path):
308 | state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
309 | if 'model_state_dict' in state_dict:
310 | model_state_dict = state_dict['model_state_dict']
311 | else:
312 | model_state_dict = state_dict
313 | self.model.load_state_dict(model_state_dict, strict=strict)
314 | logger.info("success load model:{}".format(checkpoint_path))
315 |
316 | def save_model(self, checkpoint_path, **save_kwargs):
317 | if self.accelerator is not None:
318 | unwrapped_model = self.accelerator.unwrap_model(self.model)
319 | if self.args.trainer.resume_flag:
320 | save_kwargs.update({
321 | 'model_state_dict': unwrapped_model.state_dict(),
322 | 'optimizer_state_dict': self.optimizer.state_dict(),
323 | 'scheduler_state_dict': self.scheduler.state_dict(),
324 | })
325 | self.accelerator.save(save_kwargs, checkpoint_path.replace('.pth', '.ckpt'))
326 | else:
327 | self.accelerator.save(unwrapped_model.state_dict(), checkpoint_path)
328 | else:
329 | if self.args.model.quantization_type == 'quantization_aware_training':
330 | self.model.eval()
331 | model_int8 = torch.quantization.convert(self.model)
332 | torch.save(model_int8.state_dict(), checkpoint_path)
333 | else:
334 | if self.args.trainer.resume_flag:
335 | save_kwargs.update({
336 | 'model_state_dict': self.model.state_dict(),
337 | 'optimizer_state_dict': self.optimizer.state_dict(),
338 | 'scheduler_state_dict': self.scheduler.state_dict(),
339 | })
340 | torch.save(save_kwargs, checkpoint_path.replace('.pth', '.ckpt'))
341 | else:
342 | torch.save(self.model.state_dict(), checkpoint_path)
343 | logger.info("model successfully saved to {}".format(checkpoint_path))
344 |
345 | def _get_data_loader_from_dataset(self, dataset, data_loader_args, phase="train"):
346 | num_workers = data_loader_args.get("num_workers", 0)
347 | batch_size = data_loader_args.get("batch_size", 1)
348 | if phase == "train" and data_loader_args.get('shuffle', True):
349 | shuffle = data_loader_args.get("shuffle", True)
350 | else:
351 | shuffle = data_loader_args.get("shuffle", False)
352 |         pin_memory = data_loader_args.get("pin_memory", False)
353 |
354 | collate_fn_args = data_loader_args.get("collate_fn")
355 | if collate_fn_args.get("type") is None:
356 | collate_fn = None
357 | else:
358 | collate_fn_type = collate_fn_args.get("type")
359 | collate_fn = getattr(mydatasets, collate_fn_type)(batch_size=batch_size, **collate_fn_args)
360 | data_loader = DataLoader(dataset,
361 | shuffle=shuffle,
362 | num_workers=num_workers,
363 | pin_memory=pin_memory,
364 | collate_fn=collate_fn,
365 | batch_size=batch_size)
366 | logger.info("use data loader with batch_size:{},num_workers:{}".format(batch_size, num_workers))
367 |
368 | return data_loader
369 |
370 | def prepare_accelerator(self):
371 | if self.accelerator is not None:
372 | self.model, self.optimizer, self.train_data_loader, self.scheduler = self.accelerator.prepare(
373 | self.model, self.optimizer, self.train_data_loader, self.scheduler)
374 |
375 | def _train_post_process(self):
376 | args = copy.deepcopy(self.args)
377 | args.model.model_path = args.trainer.best_model_path
378 | if 'base' in args:
379 | args.pop('base')
380 | args.device.pop('device_id')
381 | args.pop('trainer')
382 | args.phase = 'predict'
383 | save_params(self.args.trainer.save_dir, json.loads(json.dumps(args)), 'model_args.yaml')
384 | return os.path.join(self.args.trainer.save_dir, 'model_args.yaml')
385 |
386 | def _print_step_log(self, epoch, global_step, global_eval_step, loss_meter, norm_meter, batch_time, ni, **kwargs):
387 | current_lr = self._get_current_lr(ni, global_step)
388 | if self.args.device.is_master and self.args.trainer.print_freq > 0 and global_step % self.args.trainer.print_freq == 0:
389 | message = "experiment:{}; train, (epoch: {}, steps: {}, lr:{:e}, step_mean_loss:{}," \
390 | " average_loss:{}), time, (train_step_time: {:.5f}s, train_average_time: {:.5f}s);" \
391 | "(grad_norm_mean: {:.5f}, grad_norm_step: {:.5f})". \
392 | format(self.experiment_name, epoch, global_step, current_lr,
393 | loss_meter.val, loss_meter.avg, batch_time.val, batch_time.avg, norm_meter.avg,
394 | norm_meter.val)
395 | logger.info(message)
396 | if self.writer is not None:
397 | self.writer.add_scalar("{}_train/lr".format(self.experiment_name), current_lr, global_step)
398 | self.writer.add_scalar("{}_train/step_loss".format(self.experiment_name), loss_meter.val, global_step)
399 | self.writer.add_scalar("{}_train/average_loss".format(self.experiment_name), loss_meter.avg,
400 | global_step)
401 | if global_step > 0 and self.args.trainer.save_step_freq > 0 and global_step % self.args.trainer.save_step_freq == 0:
402 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step)
403 | logger.info(message)
404 | result = self.evaluate(global_eval_step=global_eval_step)
405 | global_eval_step, acc = result['global_eval_step'], result['acc']
406 |             # ADD: the is_master check was moved here
407 | if (not self.args.trainer.save_best or (self.args.trainer.save_best
408 | and acc > self.args.trainer.best_eval_result)) and self.args.device.is_master:
409 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}_acc{:.5f}.pth".format(
410 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg, acc)
411 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name)
412 |                 # ADD: remember to pass epoch and global_step so that resume checkpoints can be saved
413 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val)
414 | if acc > self.args.trainer.best_eval_result:
415 | self.args.trainer.best_eval_result = acc
416 | self.args.trainer.best_model_path = checkpoint_path
417 | return global_eval_step
418 |
419 | def _print_epoch_log(self, epoch, global_step, global_eval_step, loss_meter, ni, **kwargs):
420 | current_lr = self._get_current_lr(ni, global_step)
421 | if self.args.trainer.save_epoch_freq > 0 and epoch % self.args.trainer.save_epoch_freq == 0:
422 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step)
423 | logger.info(message)
424 | result = self.evaluate(global_eval_step=global_eval_step)
425 | global_eval_step, acc = result['global_eval_step'], result['acc']
426 | if (not self.args.trainer.save_best or (self.args.trainer.save_best
427 | and acc > self.args.trainer.best_eval_result)) and self.args.device.is_master:
428 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}_acc{:.5f}.pth".format(
429 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg, acc)
430 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name)
431 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val)
432 | if acc > self.args.trainer.best_eval_result:
433 | self.args.trainer.best_eval_result = acc
434 | self.args.trainer.best_model_path = checkpoint_path
435 | return global_eval_step
436 |
437 | def _print_eval_log(self, global_step, loss_meter, eval_metric, **kwargs):
438 | evaluate_report = eval_metric.get_report()
439 | acc = evaluate_report["acc"]
440 | message = "experiment:{}; eval,global_step:{}, (step_mean_loss:{},average_loss:{:.5f},evaluate_report:{})".format(
441 | self.experiment_name, global_step, loss_meter.val, loss_meter.avg, evaluate_report)
442 | logger.info(message)
443 | if self.writer is not None:
444 | self.writer.add_scalar("{}_eval/step_loss".format(self.experiment_name), loss_meter.val, global_step)
445 | self.writer.add_scalar("{}_eval/average_loss".format(self.experiment_name), loss_meter.avg, global_step)
446 | self.writer.add_scalar("{}_eval/acc".format(self.experiment_name), acc, global_step)
447 | return acc
448 |
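449 |
450 | # Illustrative usage sketch (an assumption, not necessarily how train/train_experiment.py wires it up):
451 | # a concrete subclass such as DocParserExperiment is built from a merged config and then drives training.
452 | #
453 | #     from base.common_util import init_experiment_config
454 | #     from experiment import DocParserExperiment
455 | #     config = init_experiment_config("config/base.yaml", "docparser")
456 | #     config.update({"use_accelerate": False, "phase": "train"})
457 | #     DocParserExperiment(config).train()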
--------------------------------------------------------------------------------
/experiment/docparser_experiment.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2023/6/2
3 | import os
4 | import re
5 | import time
6 |
7 | import munch
8 | import torch
9 | from PIL import Image
10 | from transformers import AutoTokenizer, DonutProcessor, VisionEncoderDecoderModel, \
11 | VisionEncoderDecoderConfig, DonutImageProcessor, AutoConfig, AutoModel
12 |
13 | from base.common_util import get_absolute_file_path
14 | from base.driver import logger
15 | from base.meter import AverageMeter
16 | from base.torch_utils.dl_util import get_optimizer
17 | from models.configuration_docparser import DocParserConfig
18 | from models.modeling_docparser import DocParserModel
19 | from mydatasets import get_dataset
20 | from .base_experiment import BaseExperiment
21 |
22 |
23 | class DocParserExperiment(BaseExperiment):
24 |
25 | def __init__(self, config):
26 | config = self._init_config(config)
27 | self.experiment_name = config["name"]
28 | self.args = munch.munchify(config)
29 | self.init_device(config)
30 | self.init_random_seed(config)
31 | self.init_model(config)
32 | self.init_dataset(config)
33 | self.init_trainer_args(config)
34 | self.init_predictor_args(config)
35 | self.prepare_accelerator()
36 |
37 | """
38 | Main Block
39 | """
40 |
41 | def predict(self, **kwargs):
42 | for img_path in self.args.predictor.img_paths:
43 | image = Image.open(img_path)
44 | if image.mode != "RGB":
45 | image = image.convert('RGB')
46 |
47 | pixel_values = self.processor(image, return_tensors="pt").pixel_values
48 | # prepare decoder inputs
49 | task_prompt = self.args.datasets.train.dataset.task_start_token
50 | decoder_input_ids = self.processor.tokenizer(task_prompt, add_special_tokens=False,
51 | return_tensors="pt").input_ids
52 | start = time.time()
53 | with torch.no_grad():
54 | outputs = self.model.generate(
55 | pixel_values.to(self.args.device.device_id),
56 | decoder_input_ids=decoder_input_ids.to(self.args.device.device_id),
57 | max_length=self.model.decoder.config.max_length,
58 | early_stopping=True,
59 | pad_token_id=self.processor.tokenizer.pad_token_id,
60 | eos_token_id=self.processor.tokenizer.eos_token_id,
61 | use_cache=True,
62 | num_beams=1,
63 | bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
64 | return_dict_in_generate=True,
65 | )
66 | sequence = self.processor.batch_decode(outputs.sequences)[0]
67 | batch_time = time.time() - start
68 | logger.info("batch inference time:{} s".format(batch_time))
69 | sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(
70 | self.processor.tokenizer.pad_token, "")
71 | sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
72 | print(self.processor.token2json(sequence))
73 |
74 | def train(self, **kwargs):
75 | batch_time = AverageMeter()
76 | loss_meter = AverageMeter()
77 | norm_meter = AverageMeter()
78 | global_step = self.args.trainer.start_epoch * len(self.train_data_loader)
79 | global_eval_step = 0
80 | ni = 0
81 | for epoch in range(self.args.trainer.start_epoch, self.args.trainer.epochs):
82 | self.optimizer.zero_grad()
83 | for i, batch in enumerate(self.train_data_loader):
84 | if global_step < self.args.trainer.start_global_step:
85 | global_step += 1
86 | continue
87 | start = time.time()
88 | self.model.train()
89 | ni = i + len(self.train_data_loader) * epoch # number integrated batches (since train start)
90 | with self.gradient_accumulate_scope(self.model):
91 | result = self._step_forward(batch)
92 | self._step_backward(result.loss)
93 | if self.accelerator is not None or ((i + 1) % self.args.trainer.grad_accumulate
94 | == 0) or ((i + 1) == len(self.train_data_loader)):
95 | grad_norm = self._step_optimizer()
96 | norm_meter.update(grad_norm)
97 | if not self.args.trainer.scheduler_by_epoch:
98 | self._step_scheduler(global_step)
99 | loss_meter.update(result['loss'].item(), self.args.datasets.train.batch_size)
100 | batch_time.update(time.time() - start)
101 | global_step += 1
102 | global_eval_step = self._print_step_log(epoch, global_step, global_eval_step, loss_meter, norm_meter,
103 | batch_time, ni)
104 | if self.args.trainer.scheduler_by_epoch:
105 | self._step_scheduler(global_step)
106 | global_eval_step = self._print_epoch_log(epoch, global_step, global_eval_step, loss_meter, ni)
107 | model_config_path = self._train_post_process()
108 | if self.args.device.is_master:
109 | self.writer.close()
110 | return {
111 | 'acc': self.args.trainer.best_eval_result,
112 | 'best_model_path': self.args.trainer.best_model_path,
113 | 'model_config_path': model_config_path,
114 | }
115 |
116 | def _step_forward(self, batch, is_train=True, eval_model=None, **kwargs):
117 | input_args_list = ['pixel_values', 'labels', 'decoder_input_ids']
118 | batch = {k: v.to(self.args.device.device_id) for k, v in batch.items() if k in input_args_list}
119 | # Runs the forward pass with auto-casting.
120 | with self.precision_scope:
121 | output = self.model(**batch)
122 | return output
123 |
124 | """
125 | Initialization Functions
126 | """
127 |
128 | def init_model(self, config):
129 | model_args = config["model"]
130 | tokenizer_args = model_args["tokenizer_args"]
131 | # we can borrow donut tokenizer & processor for docparser
132 | tokenizer = AutoTokenizer.from_pretrained(
133 | pretrained_model_name_or_path=tokenizer_args['pretrained_model_name_or_path']
134 | )
135 | image_processor = DonutImageProcessor(
136 | size={"height": model_args['image_size'][0], "width": model_args['image_size'][1]})
137 | self.processor = DonutProcessor(image_processor=image_processor,
138 | tokenizer=tokenizer)
139 |
140 | # model initialization
141 | AutoConfig.register("docparser-swin", DocParserConfig)
142 | AutoModel.register(DocParserConfig, DocParserModel)
143 | config = VisionEncoderDecoderConfig.from_pretrained(model_args["pretrained_model_name_or_path"])
144 | config.encoder.image_size = model_args['image_size']
145 | # a larger image size was used during pre-training, so we override it for fine-tuning;
146 | # we also update max_length of the decoder (for generation)
147 | config.decoder.max_length = model_args['max_length']
148 | config.decoder.decoder_layers = model_args['decoder_layers']
149 | model = VisionEncoderDecoderModel(config=config)
150 | logger.info("init weight from pretrained model:{}".format(model_args["pretrained_model_name_or_path"]))
151 | model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
152 | self.model = model
153 | self.model.to(self.args.device.device_id)
154 | if "model_path" in model_args and model_args['model_path'] is not None:
155 | model_path = get_absolute_file_path(model_args['model_path'])
156 | self.load_model(model_path, strict=model_args.get('load_strict', True))
157 | total = sum([param.nelement() for param in self.model.parameters()])
158 | logger.info("Number of parameters: %.2fM" % (total / 1e6))
159 |
160 | def _init_optimizer(self, trainer_args, **kwargs):
161 | optimizer_args = trainer_args.get("optimizer")
162 | if optimizer_args.get("scale_lr"):
163 | num_process = 1 if self.accelerator is None else self.accelerator.num_processes
164 | optimizer_args['lr'] = float(optimizer_args['lr']) * self.grad_accumulate * \
165 | self.train_data_loader.batch_size * num_process
166 | optimizer_args['img_lr'] = float(optimizer_args['img_lr']) * self.grad_accumulate * \
167 | self.train_data_loader.batch_size * num_process
168 | self.optimizer = get_optimizer(self.model, **optimizer_args)
169 |
170 | def init_dataset(self, config):
171 | if 'datasets' in config and config.get('phase', 'train') != 'predict':
172 | dataset_args = config.get("datasets")
173 | train_data_loader_args = dataset_args.get("train")
174 | if config.get('phase', 'train') == 'train':
175 | train_data_loader_args['dataset'].update({
176 | "donut_model": self.model,
177 | "processor": self.processor,
178 | "max_length": config['model']['max_length'],
179 | "phase": 'train',
180 | })
181 | if "cache_dir" not in train_data_loader_args['dataset']:
182 | train_data_loader_args['dataset'].update({
183 | "cache_dir": config['trainer']['save_dir']})
184 | self.train_dataset = get_dataset(train_data_loader_args['dataset'])
185 | self.train_data_loader = self._get_data_loader_from_dataset(self.train_dataset,
186 | train_data_loader_args,
187 | phase='train')
188 | logger.info("success init train data loader len:{} ".format(len(self.train_data_loader)))
189 |
190 | # set the task start token & pad token for the bart decoder;
191 | # do NOT move this: the start_token can only be set after dataset initialization, where the special
192 | # tokens are added to the vocab
193 | self.model.config.decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids(
194 | train_data_loader_args['dataset']['task_start_token'])
195 | self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id
196 |
197 | """
198 | Tool Functions
199 | """
200 |
201 | def _print_step_log(self, epoch, global_step, global_eval_step, loss_meter, norm_meter, batch_time, ni, **kwargs):
202 | current_lr = self._get_current_lr(ni, global_step)
203 | if self.args.device.is_master and self.args.trainer.print_freq > 0 and global_step % self.args.trainer.print_freq == 0:
204 | message = "experiment:{}; train, (epoch: {}, steps: {}, lr:{:e}, step_mean_loss:{}," \
205 | " average_loss:{}), time, (train_step_time: {:.5f}s, train_average_time: {:.5f}s);" \
206 | "(grad_norm_mean: {:.5f}, grad_norm_step: {:.5f})". \
207 | format(self.experiment_name, epoch, global_step, current_lr,
208 | loss_meter.val, loss_meter.avg, batch_time.val, batch_time.avg, norm_meter.avg,
209 | norm_meter.val)
210 | logger.info(message)
211 | if self.writer is not None:
212 | self.writer.add_scalar("{}_train/lr".format(self.experiment_name), current_lr, global_step)
213 | self.writer.add_scalar("{}_train/step_loss".format(self.experiment_name), loss_meter.val, global_step)
214 | self.writer.add_scalar("{}_train/average_loss".format(self.experiment_name), loss_meter.avg,
215 | global_step)
216 | if global_step > 0 and self.args.trainer.save_step_freq > 0 and self.args.device.is_master and global_step % self.args.trainer.save_step_freq == 0:
217 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step)
218 | logger.info(message)
219 | # result = self.evaluate(global_eval_step=global_eval_step)
220 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}.pth".format(
221 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg)
222 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name)
223 | tokenizer_path = os.path.join(self.args.trainer.save_dir, "tokenizer")
224 | os.makedirs(tokenizer_path, exist_ok=True)
225 | self.processor.tokenizer.save_pretrained(tokenizer_path)
226 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val)
227 | return global_eval_step
228 |
229 | def _print_epoch_log(self, epoch, global_step, global_eval_step, loss_meter, ni, **kwargs):
230 | current_lr = self._get_current_lr(ni, global_step)
231 | if self.args.trainer.save_epoch_freq > 0 and self.args.device.is_master and epoch % self.args.trainer.save_epoch_freq == 0:
232 | message = "experiment:{}; eval, (epoch: {}, steps: {});".format(self.experiment_name, epoch, global_step)
233 | logger.info(message)
234 | checkpoint_name = "{}_epoch{}_step{}_lr{:e}_average_loss{:.5f}.pth".format(
235 | self.experiment_name, epoch, global_step, current_lr, loss_meter.avg)
236 | checkpoint_path = os.path.join(self.args.trainer.save_dir, checkpoint_name)
237 | tokenizer_path = os.path.join(self.args.trainer.save_dir, "tokenizer")
238 | os.makedirs(tokenizer_path, exist_ok=True)
239 | self.save_model(checkpoint_path, epoch=epoch, global_step=global_step, loss=loss_meter.val)
240 | self.processor.tokenizer.save_pretrained(tokenizer_path)
241 | return global_eval_step
242 |
--------------------------------------------------------------------------------
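init_model above relies on transformers' Auto-class registration so that VisionEncoderDecoderConfig can resolve the custom "docparser-swin" encoder type declared in models/config.json. A minimal sketch of that pattern, with a placeholder pretrained path and max_length (in the experiment these come from model_args):

# Sketch: register the custom encoder, then build the encoder-decoder model as init_model does.
from transformers import AutoConfig, AutoModel, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

from models.configuration_docparser import DocParserConfig
from models.modeling_docparser import DocParserModel

AutoConfig.register("docparser-swin", DocParserConfig)  # make the encoder type resolvable
AutoModel.register(DocParserConfig, DocParserModel)     # map the config class to its model class

config = VisionEncoderDecoderConfig.from_pretrained("path/to/pretrained")  # placeholder path
config.encoder.image_size = [2560, 1920]  # matches the encoder image_size in models/config.json
config.decoder.max_length = 768           # placeholder; the experiment reads this from model_args
model = VisionEncoderDecoderModel(config=config)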
/logs/.gitignore:
--------------------------------------------------------------------------------
1 | ### JetBrains template
2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
4 |
5 | # User-specific stuff
6 | .idea/**/workspace.xml
7 | .idea/**/tasks.xml
8 | .idea/**/usage.statistics.xml
9 | .idea/**/dictionaries
10 | .idea/**/shelf
11 |
12 | # Generated files
13 | .idea/**/contentModel.xml
14 |
15 | # Sensitive or high-churn files
16 | .idea/**/dataSources/
17 | .idea/**/dataSources.ids
18 | .idea/**/dataSources.local.xml
19 | .idea/**/sqlDataSources.xml
20 | .idea/**/dynamic.xml
21 | .idea/**/uiDesigner.xml
22 | .idea/**/dbnavigator.xml
23 |
24 | # Gradle
25 | .idea/**/gradle.xml
26 | .idea/**/libraries
27 |
28 | # Gradle and Maven with auto-import
29 | # When using Gradle or Maven with auto-import, you should exclude module files,
30 | # since they will be recreated, and may cause churn. Uncomment if using
31 | # auto-import.
32 | # .idea/artifacts
33 | # .idea/compiler.xml
34 | # .idea/jarRepositories.xml
35 | # .idea/modules.xml
36 | # .idea/*.iml
37 | # .idea/modules
38 | # *.iml
39 | # *.ipr
40 |
41 | # CMake
42 | cmake-build-*/
43 |
44 | # Mongo Explorer plugin
45 | .idea/**/mongoSettings.xml
46 |
47 | # File-based project format
48 | *.iws
49 |
50 | # IntelliJ
51 | out/
52 |
53 | # mpeltonen/sbt-idea plugin
54 | .idea_modules/
55 |
56 | # JIRA plugin
57 | atlassian-ide-plugin.xml
58 |
59 | # Cursive Clojure plugin
60 | .idea/replstate.xml
61 |
62 | # Crashlytics plugin (for Android Studio and IntelliJ)
63 | com_crashlytics_export_strings.xml
64 | crashlytics.properties
65 | crashlytics-build.properties
66 | fabric.properties
67 |
68 | # Editor-based Rest Client
69 | .idea/httpRequests
70 |
71 | # Android studio 3.1+ serialized cache file
72 | .idea/caches/build_file_checksums.ser
73 |
74 | ### macOS template
75 | # General
76 | .DS_Store
77 | .AppleDouble
78 | .LSOverride
79 |
80 | # Icon must end with two \r
81 | Icon
82 |
83 | # Thumbnails
84 | ._*
85 |
86 | # Files that might appear in the root of a volume
87 | .DocumentRevisions-V100
88 | .fseventsd
89 | .Spotlight-V100
90 | .TemporaryItems
91 | .Trashes
92 | .VolumeIcon.icns
93 | .com.apple.timemachine.donotpresent
94 |
95 | # Directories potentially created on remote AFP share
96 | .AppleDB
97 | .AppleDesktop
98 | Network Trash Folder
99 | Temporary Items
100 | .apdisk
101 |
102 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: @time: 1/31/23 17:08
3 | from .configuration_docparser import DocParserConfig
4 | from .modeling_docparser import DocParserModel
--------------------------------------------------------------------------------
/models/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "VisionEncoderDecoderModel"
4 | ],
5 | "decoder": {
6 | "_name_or_path": "",
7 | "activation_dropout": 0.0,
8 | "activation_function": "gelu",
9 | "add_cross_attention": true,
10 | "add_final_layer_norm": true,
11 | "architectures": null,
12 | "attention_dropout": 0.0,
13 | "bad_words_ids": null,
14 | "bos_token_id": 0,
15 | "chunk_size_feed_forward": 0,
16 | "classifier_dropout": 0.0,
17 | "cross_attention_hidden_size": null,
18 | "d_model": 1024,
19 | "decoder_attention_heads": 16,
20 | "decoder_ffn_dim": 4096,
21 | "decoder_layerdrop": 0.0,
22 | "decoder_layers": 4,
23 | "decoder_start_token_id": null,
24 | "diversity_penalty": 0.0,
25 | "do_sample": false,
26 | "dropout": 0.1,
27 | "early_stopping": false,
28 | "encoder_attention_heads": 16,
29 | "encoder_ffn_dim": 4096,
30 | "encoder_layerdrop": 0.0,
31 | "encoder_layers": 12,
32 | "encoder_no_repeat_ngram_size": 0,
33 | "eos_token_id": 2,
34 | "exponential_decay_length_penalty": null,
35 | "finetuning_task": null,
36 | "forced_bos_token_id": null,
37 | "forced_eos_token_id": 2,
38 | "id2label": {
39 | "0": "LABEL_0",
40 | "1": "LABEL_1"
41 | },
42 | "init_std": 0.02,
43 | "is_decoder": true,
44 | "is_encoder_decoder": false,
45 | "label2id": {
46 | "LABEL_0": 0,
47 | "LABEL_1": 1
48 | },
49 | "length_penalty": 1.0,
50 | "max_length": 20,
51 | "max_position_embeddings": 1536,
52 | "min_length": 0,
53 | "model_type": "mbart",
54 | "no_repeat_ngram_size": 0,
55 | "num_beam_groups": 1,
56 | "num_beams": 1,
57 | "num_hidden_layers": 12,
58 | "num_return_sequences": 1,
59 | "output_attentions": false,
60 | "output_hidden_states": false,
61 | "output_scores": false,
62 | "pad_token_id": 1,
63 | "prefix": null,
64 | "problem_type": null,
65 | "pruned_heads": {},
66 | "remove_invalid_values": false,
67 | "repetition_penalty": 1.0,
68 | "return_dict": true,
69 | "return_dict_in_generate": false,
70 | "scale_embedding": true,
71 | "sep_token_id": null,
72 | "task_specific_params": null,
73 | "temperature": 1.0,
74 | "tf_legacy_loss": false,
75 | "tie_encoder_decoder": false,
76 | "tie_word_embeddings": true,
77 | "tokenizer_class": null,
78 | "top_k": 50,
79 | "top_p": 1.0,
80 | "torch_dtype": null,
81 | "torchscript": false,
82 | "transformers_version": "4.22.0.dev0",
83 | "typical_p": 1.0,
84 | "use_bfloat16": false,
85 | "use_cache": true,
86 | "vocab_size": 57525
87 | },
88 | "encoder": {
89 | "_name_or_path": "",
90 | "add_cross_attention": false,
91 | "architectures": null,
92 | "attention_probs_dropout_prob": 0.0,
93 | "bad_words_ids": null,
94 | "bos_token_id": null,
95 | "chunk_size_feed_forward": 0,
96 | "cross_attention_hidden_size": null,
97 | "conv_depth_num_layers": 3,
98 | "decoder_start_token_id": null,
99 | "depths": [
100 | 3,
101 | 6,
102 | 6,
103 | 2,
104 | 2,
105 | 2
106 | ],
107 | "diversity_penalty": 0.0,
108 | "do_sample": false,
109 | "drop_path_rate": 0.1,
110 | "early_stopping": false,
111 | "embed_dim": [
112 | 64,
113 | 128,
114 | 256,
115 | 512,
116 | 768,
117 | 1024
118 | ],
119 | "encoder_no_repeat_ngram_size": 0,
120 | "eos_token_id": null,
121 | "exponential_decay_length_penalty": null,
122 | "finetuning_task": null,
123 | "forced_bos_token_id": null,
124 | "forced_eos_token_id": null,
125 | "hidden_act": "gelu",
126 | "hidden_dropout_prob": 0.0,
127 | "hidden_size": 1024,
128 | "id2label": {
129 | "0": "LABEL_0",
130 | "1": "LABEL_1"
131 | },
132 | "image_size": [
133 | 2560,
134 | 1920
135 | ],
136 | "initializer_range": 0.02,
137 | "is_decoder": false,
138 | "is_encoder_decoder": false,
139 | "label2id": {
140 | "LABEL_0": 0,
141 | "LABEL_1": 1
142 | },
143 | "layer_norm_eps": 1e-05,
144 | "length_penalty": 1.0,
145 | "max_length": 20,
146 | "min_length": 0,
147 | "mlp_ratio": 4.0,
148 | "auto_map": {
149 | "AutoConfig": "configuration_docparser.DocParserConfig"
150 | },
151 | "model_type": "docparser-swin",
152 | "no_repeat_ngram_size": 0,
153 | "num_beam_groups": 1,
154 | "num_beams": 1,
155 | "num_channels": 3,
156 | "num_heads": [
157 | 4,
158 | 8,
159 | 16
160 | ],
161 | "pe_kernel_size": 3,
162 | "pe_stride_size": 2,
163 | "pe_hidden_size": 64,
164 | "pe_add_hidden_act": true,
165 | "num_layers": 3,
166 | "num_return_sequences": 1,
167 | "output_attentions": false,
168 | "output_hidden_states": false,
169 | "output_scores": false,
170 | "pad_token_id": null,
171 | "prefix": null,
172 | "problem_type": null,
173 | "pruned_heads": {},
174 | "qkv_bias": true,
175 | "remove_invalid_values": false,
176 | "repetition_penalty": 1.0,
177 | "return_dict": true,
178 | "return_dict_in_generate": false,
179 | "sep_token_id": null,
180 | "stride_size": [
181 | [
182 | 2,
183 | 1
184 | ],
185 | [
186 | 2,
187 | 1
188 | ],
189 | [
190 | 2,
191 | 2
192 | ]
193 | ],
194 | "task_specific_params": null,
195 | "temperature": 1.0,
196 | "tf_legacy_loss": false,
197 | "tie_encoder_decoder": false,
198 | "tie_word_embeddings": true,
199 | "tokenizer_class": null,
200 | "top_k": 50,
201 | "top_p": 1.0,
202 | "torch_dtype": null,
203 | "torchscript": false,
204 | "transformers_version": "4.22.0.dev0",
205 | "typical_p": 1.0,
206 | "use_absolute_embeddings": false,
207 | "use_bfloat16": false,
208 | "window_size": [
209 | [
210 | 5,
211 | 40
212 | ],
213 | [
214 | 5,
215 | 20
216 | ],
217 | [
218 | 10,
219 | 10
220 | ]
221 | ]
222 | },
223 | "is_encoder_decoder": true,
224 | "model_type": "vision-encoder-decoder",
225 | "tie_word_embeddings": false,
226 | "torch_dtype": "float32",
227 | "transformers_version": null
228 | }
--------------------------------------------------------------------------------
/models/configuration_docparser.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: @time: 7/6/23 10:08
3 | # coding=utf-8
4 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """ DocParser Swin Transformer model configuration"""
18 |
19 | from transformers.configuration_utils import PretrainedConfig
20 | from transformers.utils import logging
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 |
25 | class DocParserConfig(PretrainedConfig):
26 | model_type = "docparser-swin"
27 |
28 | attribute_map = {
29 | "num_attention_heads": "num_heads",
30 | "num_hidden_layers": "num_layers",
31 | }
32 |
33 | def __init__(
34 | self,
35 | image_size=224,
36 | patch_size=4,
37 | num_channels=3,
38 | embed_dim=96,
39 | depths=[2, 2, 6, 2],
40 | num_heads=[3, 6, 12, 24],
41 | window_size=7,
42 | mlp_ratio=4.0,
43 | qkv_bias=True,
44 | hidden_dropout_prob=0.0,
45 | attention_probs_dropout_prob=0.0,
46 | drop_path_rate=0.1,
47 | hidden_act="gelu",
48 | use_absolute_embeddings=False,
49 | initializer_range=0.02,
50 | layer_norm_eps=1e-5,
51 | **kwargs,
52 | ):
53 | super().__init__(**kwargs)
54 |
55 | self.image_size = image_size
56 | self.patch_size = patch_size
57 | self.num_channels = num_channels
58 | self.embed_dim = embed_dim
59 | self.depths = depths
60 | self.num_layers = len(depths)
61 | self.num_heads = num_heads
62 | self.window_size = window_size
63 | self.mlp_ratio = mlp_ratio
64 | self.qkv_bias = qkv_bias
65 | self.hidden_dropout_prob = hidden_dropout_prob
66 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
67 | self.drop_path_rate = drop_path_rate
68 | self.hidden_act = hidden_act
69 | self.use_absolute_embeddings = use_absolute_embeddings
70 | self.layer_norm_eps = layer_norm_eps
71 | self.initializer_range = initializer_range
72 |
--------------------------------------------------------------------------------
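DocParserConfig takes the same fields as the "encoder" block of models/config.json; list-valued image_size, depths, embed_dim, num_heads and window_size replace the scalar defaults, and extra keys such as conv_depth_num_layers or stride_size are kept on the config object via **kwargs. A small sketch using the values from models/config.json:

# Sketch: building the encoder config with the values from models/config.json.
from models.configuration_docparser import DocParserConfig

encoder_config = DocParserConfig(
    image_size=[2560, 1920],
    embed_dim=[64, 128, 256, 512, 768, 1024],
    depths=[3, 6, 6, 2, 2, 2],
    num_heads=[4, 8, 16],
    window_size=[[5, 40], [5, 20], [10, 10]],
    # the keys below are not named parameters; they are stored on the config via **kwargs
    conv_depth_num_layers=3,
    stride_size=[[2, 1], [2, 1], [2, 2]],
    pe_kernel_size=3,
    pe_stride_size=2,
    pe_hidden_size=64,
)
print(encoder_config.num_layers)  # 6, derived from len(depths)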
/models/convnext.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: @time: 7/6/23 10:28
3 |
4 | from functools import partial
5 |
6 | import torch
7 | from torch import nn, Tensor
8 | from torch.nn import functional as F
9 | from torchvision.ops.stochastic_depth import StochasticDepth
10 | from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
11 |
12 |
13 | class LayerNorm2d(nn.LayerNorm):
14 | def forward(self, x: Tensor) -> Tensor:
15 | x = x.permute(0, 2, 3, 1)
16 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
17 | x = x.permute(0, 3, 1, 2).contiguous()
18 | return x
19 |
20 |
21 | class Permute(torch.nn.Module):
22 | """This module returns a view of the tensor input with its dimensions permuted.
23 |
24 | Args:
25 | dims (List[int]): The desired ordering of dimensions
26 | """
27 |
28 | def __init__(self, dims: List[int]):
29 | super().__init__()
30 | self.dims = dims
31 |
32 | def forward(self, x: Tensor) -> Tensor:
33 | return torch.permute(x, self.dims).contiguous()
34 |
35 |
36 | class CNBlock(nn.Module):
37 | def __init__(
38 | self,
39 | dim,
40 | layer_scale: float,
41 | stochastic_depth_prob: float,
42 | norm_layer: Optional[Callable[..., nn.Module]] = None,
43 | ) -> None:
44 | super().__init__()
45 | if norm_layer is None:
46 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
47 |
48 | self.block = nn.Sequential(
49 | nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
50 | Permute([0, 2, 3, 1]),
51 | norm_layer(dim),
52 | nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
53 | nn.GELU(),
54 | nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
55 | Permute([0, 3, 1, 2]),
56 | )
57 | self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
58 | self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
59 |
60 | def forward(self, input: Tensor) -> Tensor:
61 | result = self.layer_scale * self.block(input)
62 | result = self.stochastic_depth(result)
63 | result += input
64 | return result
65 |
66 |
67 | class CNBlockConfig:
68 | # Stores information listed at Section 3 of the ConvNeXt paper
69 | def __init__(
70 | self,
71 | input_channels: int,
72 | out_channels: Optional[int],
73 | num_layers: int,
74 | stride: Union[Tuple[int, int], int],
75 | ) -> None:
76 | self.input_channels = input_channels
77 | self.out_channels = out_channels
78 | self.num_layers = num_layers
79 | self.stride = stride
80 |
81 | def __repr__(self) -> str:
82 | s = self.__class__.__name__ + "("
83 | s += "input_channels={input_channels}"
84 | s += ", out_channels={out_channels}"
85 | s += ", num_layers={num_layers}"
86 | s += ")"
87 | return s.format(**self.__dict__)
88 |
89 |
90 | class ConvNeXt(nn.Module):
91 | def __init__(
92 | self,
93 | block_setting: List[CNBlockConfig],
94 | stochastic_depth_prob: float = 0.0,
95 | layer_scale: float = 1e-6,
96 | block: Optional[Callable[..., nn.Module]] = None,
97 | norm_layer: Optional[Callable[..., nn.Module]] = None,
98 | **kwargs: Any,
99 | ) -> None:
100 | super().__init__()
101 |
102 | if not block_setting:
103 | raise ValueError("The block_setting should not be empty")
104 | elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
105 | raise TypeError("The block_setting should be List[CNBlockConfig]")
106 |
107 | if block is None:
108 | block = CNBlock
109 |
110 | if norm_layer is None:
111 | norm_layer = partial(LayerNorm2d, eps=1e-6)
112 |
113 | layers: List[nn.Module] = []
114 |
115 | total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
116 | stage_block_id = 0
117 | for cnf in block_setting:
118 | # Bottlenecks
119 | stage: List[nn.Module] = []
120 | for _ in range(cnf.num_layers):
121 | # adjust stochastic depth probability based on the depth of the stage block
122 | sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
123 | stage.append(block(cnf.input_channels, layer_scale, sd_prob))
124 | stage_block_id += 1
125 | layers.append(nn.Sequential(*stage))
126 | if cnf.out_channels is not None:
127 | # Downsampling
128 | layers.append(
129 | nn.Sequential(
130 | norm_layer(cnf.input_channels),
131 | nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=cnf.stride, stride=cnf.stride),
132 | )
133 | )
134 |
135 | self.features = nn.Sequential(*layers)
136 |
137 | for m in self.modules():
138 | if isinstance(m, (nn.Conv2d, nn.Linear)):
139 | nn.init.trunc_normal_(m.weight, std=0.02)
140 | if m.bias is not None:
141 | nn.init.zeros_(m.bias)
142 |
143 | def _forward_impl(self, x: Tensor) -> Tensor:
144 | x = self.features(x)
145 | return x
146 |
147 | def forward(self, x: Tensor) -> Tensor:
148 | return self._forward_impl(x)
149 |
150 |
151 | if __name__ == '__main__':
152 | channel_list = [64, 128, 256]
153 | num_layer_list = [3, 6, 6]
154 | stride = [(1, 2), (1, 2), (2, 2)]
155 | model = ConvNeXt(block_setting=[
156 | CNBlockConfig(input_channels=channel_list[i_layer],
157 | out_channels=channel_list[i_layer] * 2,
158 | num_layers=num_layer_list[i_layer],
159 | stride=stride[i_layer]
160 | )
161 | for i_layer in range(len(num_layer_list))
162 | ])
163 |
--------------------------------------------------------------------------------
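As a quick shape check for the __main__ example above: this ConvNeXt variant has no stem, so the input must already have 64 channels; each stage's CNBlocks preserve the spatial size and the trailing Conv2d downsamples by its stride, so strides (1, 2), (1, 2), (2, 2) halve the width twice and then halve both dimensions once, while the channels double per stage. A self-contained sketch (the input size is chosen only for illustration):

# Sketch: forward-pass shape check for the configuration built in the __main__ block above.
import torch

from models.convnext import ConvNeXt, CNBlockConfig

channels, layers, strides = [64, 128, 256], [3, 6, 6], [(1, 2), (1, 2), (2, 2)]
model = ConvNeXt(block_setting=[
    CNBlockConfig(input_channels=channels[i], out_channels=channels[i] * 2,
                  num_layers=layers[i], stride=strides[i])
    for i in range(3)
])

x = torch.randn(1, 64, 32, 128)  # (batch, channels, height, width); the first stage expects 64 channels
y = model(x)
print(y.shape)  # torch.Size([1, 512, 16, 16])
# width 128 -> 64 -> 32 -> 16, height 32 -> 32 -> 32 -> 16, channels 64 -> 128 -> 256 -> 512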
/models/modeling_docparser.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: @time: 7/5/23 16:14
3 | # coding=utf-8
4 |
5 |
6 | import collections.abc
7 |
8 | import math
9 | from dataclasses import dataclass
10 | from typing import Optional, Tuple, Union
11 |
12 | import torch
13 | import torch.utils.checkpoint
14 | from torch import nn
15 |
16 | from transformers.activations import ACT2FN
17 | from transformers.modeling_utils import PreTrainedModel
18 | from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
19 | from transformers.utils import (
20 | ModelOutput,
21 | add_code_sample_docstrings,
22 | add_start_docstrings,
23 | add_start_docstrings_to_model_forward,
24 | logging,
25 | )
26 | from .configuration_docparser import DocParserConfig
27 | from .convnext import ConvNeXt, CNBlockConfig
28 |
29 | logger = logging.get_logger(__name__)
30 |
31 | # General docstring
32 | _CONFIG_FOR_DOC = "DocParserConfig"
33 |
34 | # Base docstring
35 | _CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
36 | _EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
37 |
38 | DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
39 | "naver-clova-ix/donut-base",
40 | # See all Donut Swin models at https://huggingface.co/models?filter=donut
41 | ]
42 |
43 |
44 | @dataclass
45 | # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DocParser
46 | class DocParserEncoderOutput(ModelOutput):
47 | """
48 | DocParser encoder's outputs, with potential hidden states and attentions.
49 |
50 | Args:
51 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
52 | Sequence of hidden-states at the output of the last layer of the model.
53 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
54 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
55 | shape `(batch_size, sequence_length, hidden_size)`.
56 |
57 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
58 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
59 | Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
60 | sequence_length)`.
61 |
62 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
63 | heads.
64 | reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
65 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
66 | shape `(batch_size, hidden_size, height, width)`.
67 |
68 | Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
69 | include the spatial dimensions.
70 | """
71 |
72 | last_hidden_state: torch.FloatTensor = None
73 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
74 | attentions: Optional[Tuple[torch.FloatTensor]] = None
75 | reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
76 |
77 |
78 | @dataclass
79 | # Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DocParser
80 | class DocParserModelOutput(ModelOutput):
81 | """
82 | DocParser model's outputs that also contains a pooling of the last hidden states.
83 |
84 | Args:
85 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
86 | Sequence of hidden-states at the output of the last layer of the model.
87 | pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
88 | Average pooling of the last layer hidden-state.
89 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
90 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
91 | shape `(batch_size, sequence_length, hidden_size)`.
92 |
93 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
94 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
95 | Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
96 | sequence_length)`.
97 |
98 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
99 | heads.
100 | reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
101 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
102 | shape `(batch_size, hidden_size, height, width)`.
103 |
104 | Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
105 | include the spatial dimensions.
106 | """
107 |
108 | last_hidden_state: torch.FloatTensor = None
109 | pooler_output: Optional[torch.FloatTensor] = None
110 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111 | attentions: Optional[Tuple[torch.FloatTensor]] = None
112 | reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
113 |
114 |
115 | # Copied from transformers.models.swin.modeling_swin.window_partition
116 | def window_partition(input_feature, window_size):
117 | """
118 | Partitions the given input into windows.
119 | """
120 | batch_size, height, width, num_channels = input_feature.shape
121 | input_feature = input_feature.view(
122 | batch_size, height // window_size[0], window_size[0], width // window_size[1], window_size[1], num_channels
123 | )
124 | windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1],
125 | num_channels)
126 | return windows
127 |
128 |
129 | # Copied from transformers.models.swin.modeling_swin.window_reverse
130 | def window_reverse(windows, window_size, height, width):
131 | """
132 | Merges windows to produce higher resolution features.
133 | """
134 | num_channels = windows.shape[-1]
135 | windows = windows.view(-1, height // window_size[0], width // window_size[1], window_size[0], window_size[1],
136 | num_channels)
137 | windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
138 | return windows
139 |
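# Shape example (with the defaults in models/config.json): a 2560x1920 image gives a patch-embedding
# grid of (2560 // 32, 1920 // 8) = (80, 240); with the first-stage window_size of (5, 40),
# window_partition turns a (batch, 80, 240, C) feature map into (batch * 16 * 6, 5, 40, C) windows,
# and window_reverse maps those windows back to (batch, 80, 240, C).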
140 |
141 | class ConvBNLayer(nn.Module):
142 | def __init__(self,
143 | in_channels,
144 | out_channels,
145 | kernel_size,
146 | padding,
147 | stride_size):
148 | super().__init__()
149 | kernel_size = kernel_size if isinstance(kernel_size, collections.abc.Iterable) else (kernel_size, kernel_size)
150 | stride_size = stride_size if isinstance(stride_size, collections.abc.Iterable) else (stride_size, stride_size)
151 | padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding)
152 | self.conv = nn.Conv2d(
153 | in_channels=in_channels,
154 | out_channels=out_channels,
155 | kernel_size=kernel_size,
156 | padding=padding,
157 | stride=stride_size)
158 | self.norm = nn.BatchNorm2d(out_channels)
159 | self.act = nn.GELU()
160 |
161 | def forward(self, inputs):
162 | out = self.conv(inputs)
163 | out = self.norm(out)
164 | out = self.act(out)
165 | return out
166 |
167 |
168 | class DocParserPatchEmbeddings(nn.Module):
169 | """
170 | Construct the patch embeddings from pixel values with two strided ConvBN layers.
171 | """
172 |
173 | def __init__(self, config):
174 | super().__init__()
175 | image_size = config.image_size
176 | kernel_size, stride_size = config.pe_kernel_size, config.pe_stride_size
177 | num_channels, hidden_size = config.num_channels, config.pe_hidden_size
178 | self.grid_size = (image_size[0] // 32, image_size[1] // 8) # num patches for swin-part
179 | self.patch_embedding = nn.Sequential(
180 | ConvBNLayer(
181 | in_channels=num_channels,
182 | out_channels=hidden_size // 2,
183 | kernel_size=kernel_size,
184 | stride_size=stride_size,
185 | padding=1),
186 | ConvBNLayer(
187 | in_channels=hidden_size // 2,
188 | out_channels=hidden_size,
189 | kernel_size=kernel_size,
190 | stride_size=stride_size,
191 | padding=1),
192 | )
193 |
194 | def forward(
195 | self,
196 | pixel_values: Optional[torch.FloatTensor]):
197 | embeddings = self.patch_embedding(pixel_values)
198 | return embeddings
199 |
200 |
201 | # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
202 | class DocParserPatchMerging(nn.Module):
203 | """
204 | Patch Merging Layer.
205 |
206 | Args:
207 | input_resolution (`Tuple[int]`):
208 | Resolution of input feature.
209 | dim (`int`):
210 | Number of input channels.
211 | norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
212 | Normalization layer class.
213 | """
214 |
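# Note: unlike the original SwinPatchMerging, this layer only halves the width of the feature
# map (the slices below stride over the width dimension only), matching the
# (height, (width + 1) // 2) output dimensions computed in DocParserStage.forward.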
215 | def __init__(self, input_resolution: Tuple[int], dim: int, dim_out: int,
216 | norm_layer: nn.Module = nn.LayerNorm) -> None:
217 | super().__init__()
218 | self.input_resolution = input_resolution
219 | self.dim = dim
220 | self.reduction = nn.Linear(4 * dim, dim_out, bias=False)
221 | self.norm = norm_layer(4 * dim)
222 |
223 | def maybe_pad(self, input_feature, height, width):
224 | should_pad = (height % 2 == 1) or (width % 2 == 1)
225 | if should_pad:
226 | pad_values = (0, 0, 0, width % 2, 0, height % 2)
227 | input_feature = nn.functional.pad(input_feature, pad_values)
228 |
229 | return input_feature
230 |
231 | def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
232 | height, width = input_dimensions
233 | # `dim` is height * width
234 | batch_size, dim, num_channels = input_feature.shape
235 |
236 | input_feature = input_feature.view(batch_size, height, width, num_channels)
237 | # pad the input so that height and width are even, if needed
238 | input_feature = self.maybe_pad(input_feature, height, width)
239 | # [batch_size, height, width/2, num_channels]
240 | input_feature_0 = input_feature[:, :, 0::2, :]
241 | # [batch_size, height, width/2, num_channels]
242 | input_feature_1 = input_feature[:, :, 0::2, :]
243 | # [batch_size, height, width/2, num_channels]
244 | input_feature_2 = input_feature[:, :, 1::2, :]
245 | # [batch_size, height, width/2, num_channels]
246 | input_feature_3 = input_feature[:, :, 1::2, :]
247 | # batch_size height width/2 4*num_channels
248 | input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
249 | input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height*width/2 4*C
250 |
251 | input_feature = self.norm(input_feature)
252 | input_feature = self.reduction(input_feature)
253 |
254 | return input_feature
255 |
256 |
257 | # Copied from transformers.models.swin.modeling_swin.drop_path
258 | def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
259 | """
260 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
261 |
262 | Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
263 | however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
264 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
265 | layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
266 | argument.
267 | """
268 | if drop_prob == 0.0 or not training:
269 | return input
270 | keep_prob = 1 - drop_prob
271 | shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
272 | random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
273 | random_tensor.floor_() # binarize
274 | output = input.div(keep_prob) * random_tensor
275 | return output
276 |
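# Illustrative example: with drop_prob = 0.1, keep_prob is 0.9; during training each sample's
# residual-branch output is zeroed with probability 0.1 (the mask has one entry per sample,
# shape (batch, 1, ..., 1)), and the surviving samples are scaled by 1 / 0.9 so that the
# expected output matches evaluation mode, where drop_path is the identity.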
277 |
278 | # Copied from transformers.models.swin.modeling_swin.SwinDropPath
279 | class DocParserDropPath(nn.Module):
280 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
281 |
282 | def __init__(self, drop_prob: Optional[float] = None) -> None:
283 | super().__init__()
284 | self.drop_prob = drop_prob
285 |
286 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
287 | return drop_path(hidden_states, self.drop_prob, self.training)
288 |
289 | def extra_repr(self) -> str:
290 | return "p={}".format(self.drop_prob)
291 |
292 |
293 | # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DocParser
294 | class DocParserSelfAttention(nn.Module):
295 | def __init__(self, config, dim, num_heads, window_size):
296 | super().__init__()
297 | if dim % num_heads != 0:
298 | raise ValueError(
299 | f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
300 | )
301 |
302 | self.num_attention_heads = num_heads
303 | self.attention_head_size = int(dim / num_heads)
304 | self.all_head_size = self.num_attention_heads * self.attention_head_size
305 | self.window_size = (
306 | window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
307 | )
308 |
309 | self.relative_position_bias_table = nn.Parameter(
310 | torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
311 | )
312 |
313 | # get pair-wise relative position index for each token inside the window
314 | coords_h = torch.arange(self.window_size[0])
315 | coords_w = torch.arange(self.window_size[1])
316 | coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
317 | coords_flatten = torch.flatten(coords, 1)
318 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
319 | relative_coords = relative_coords.permute(1, 2, 0).contiguous()
320 | relative_coords[:, :, 0] += self.window_size[0] - 1
321 | relative_coords[:, :, 1] += self.window_size[1] - 1
322 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
323 | relative_position_index = relative_coords.sum(-1)
324 | self.register_buffer("relative_position_index", relative_position_index)
325 |
326 | self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
327 | self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
328 | self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
329 |
330 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
331 |
332 | def transpose_for_scores(self, x):
333 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
334 | x = x.view(new_x_shape)
335 | return x.permute(0, 2, 1, 3)
336 |
337 | def forward(
338 | self,
339 | hidden_states: torch.Tensor,
340 | attention_mask: Optional[torch.FloatTensor] = None,
341 | head_mask: Optional[torch.FloatTensor] = None,
342 | output_attentions: Optional[bool] = False,
343 | ) -> Tuple[torch.Tensor]:
344 | batch_size, dim, num_channels = hidden_states.shape
345 | mixed_query_layer = self.query(hidden_states)
346 |
347 | key_layer = self.transpose_for_scores(self.key(hidden_states))
348 | value_layer = self.transpose_for_scores(self.value(hidden_states))
349 | query_layer = self.transpose_for_scores(mixed_query_layer)
350 |
351 | # Take the dot product between "query" and "key" to get the raw attention scores.
352 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
353 |
354 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
355 |
356 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
357 | relative_position_bias = relative_position_bias.view(
358 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
359 | )
360 |
361 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
362 | attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
363 |
364 | if attention_mask is not None:
365 | # Apply the attention mask (precomputed for all layers in DocParserModel's forward() function)
366 | mask_shape = attention_mask.shape[0]
367 | attention_scores = attention_scores.view(
368 | batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
369 | )
370 | attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
371 | attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
372 |
373 | # Normalize the attention scores to probabilities.
374 | attention_probs = nn.functional.softmax(attention_scores, dim=-1)
375 |
376 | # This is actually dropping out entire tokens to attend to, which might
377 | # seem a bit unusual, but is taken from the original Transformer paper.
378 | attention_probs = self.dropout(attention_probs)
379 |
380 | # Mask heads if we want to
381 | if head_mask is not None:
382 | attention_probs = attention_probs * head_mask
383 |
384 | context_layer = torch.matmul(attention_probs, value_layer)
385 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
386 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
387 | context_layer = context_layer.view(new_context_layer_shape)
388 |
389 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
390 |
391 | return outputs
392 |
393 |
394 | # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
395 | class DocParserSelfOutput(nn.Module):
396 | def __init__(self, config, dim):
397 | super().__init__()
398 | self.dense = nn.Linear(dim, dim)
399 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
400 |
401 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
402 | hidden_states = self.dense(hidden_states)
403 | hidden_states = self.dropout(hidden_states)
404 |
405 | return hidden_states
406 |
407 |
408 | # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DocParser
409 | class DocParserAttention(nn.Module):
410 | def __init__(self, config, dim, num_heads, window_size):
411 | super().__init__()
412 | self.self = DocParserSelfAttention(config, dim, num_heads, window_size)
413 | self.output = DocParserSelfOutput(config, dim)
414 | self.pruned_heads = set()
415 |
416 | def prune_heads(self, heads):
417 | if len(heads) == 0:
418 | return
419 | heads, index = find_pruneable_heads_and_indices(
420 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
421 | )
422 |
423 | # Prune linear layers
424 | self.self.query = prune_linear_layer(self.self.query, index)
425 | self.self.key = prune_linear_layer(self.self.key, index)
426 | self.self.value = prune_linear_layer(self.self.value, index)
427 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
428 |
429 | # Update hyper params and store pruned heads
430 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
431 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
432 | self.pruned_heads = self.pruned_heads.union(heads)
433 |
434 | def forward(
435 | self,
436 | hidden_states: torch.Tensor,
437 | attention_mask: Optional[torch.FloatTensor] = None,
438 | head_mask: Optional[torch.FloatTensor] = None,
439 | output_attentions: Optional[bool] = False,
440 | ) -> Tuple[torch.Tensor]:
441 | self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
442 | attention_output = self.output(self_outputs[0], hidden_states)
443 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
444 | return outputs
445 |
446 |
447 | # Copied from transformers.models.swin.modeling_swin.SwinIntermediate
448 | class DocParserIntermediate(nn.Module):
449 | def __init__(self, config, dim):
450 | super().__init__()
451 | self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
452 | if isinstance(config.hidden_act, str):
453 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
454 | else:
455 | self.intermediate_act_fn = config.hidden_act
456 |
457 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
458 | hidden_states = self.dense(hidden_states)
459 | hidden_states = self.intermediate_act_fn(hidden_states)
460 | return hidden_states
461 |
462 |
463 | # Copied from transformers.models.swin.modeling_swin.SwinOutput
464 | class DocParserOutput(nn.Module):
465 | def __init__(self, config, dim):
466 | super().__init__()
467 | self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
468 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
469 |
470 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
471 | hidden_states = self.dense(hidden_states)
472 | hidden_states = self.dropout(hidden_states)
473 | return hidden_states
474 |
475 |
476 | # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DocParser
477 | class DocParserLayer(nn.Module):
478 | def __init__(self, config, dim, input_resolution, window_size, num_heads, shift_size=0):
479 | super().__init__()
480 | self.chunk_size_feed_forward = config.chunk_size_feed_forward
481 | self.shift_size = shift_size
482 | self.window_size = window_size
483 | self.input_resolution = input_resolution
484 | self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
485 | self.attention = DocParserAttention(config, dim, num_heads, window_size=self.window_size)
486 | self.drop_path = DocParserDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
487 | self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
488 | self.intermediate = DocParserIntermediate(config, dim)
489 | self.output = DocParserOutput(config, dim)
490 |
491 | def set_shift_and_window_size(self, input_resolution):
492 | if min(input_resolution) <= min(self.window_size):
493 | # if window size is larger than input resolution, we don't partition windows
494 | self.shift_size = 0
495 | self.window_size = min(input_resolution)
496 |
497 | def get_attn_mask(self, height, width, dtype):
498 | if isinstance(self.shift_size, list):
499 | # calculate attention mask for SW-MSA
500 | img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
501 | height_slices = (
502 | slice(0, -self.window_size[0]),
503 | slice(-self.window_size[0], -self.shift_size[0]),
504 | slice(-self.shift_size[0], None),
505 | )
506 | width_slices = (
507 | slice(0, -self.window_size[1]),
508 | slice(-self.window_size[1], -self.shift_size[1]),
509 | slice(-self.shift_size[1], None),
510 | )
511 | count = 0
512 | for height_slice in height_slices:
513 | for width_slice in width_slices:
514 | img_mask[:, height_slice, width_slice, :] = count
515 | count += 1
516 |
517 | mask_windows = window_partition(img_mask, self.window_size)
518 | mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1])
519 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
520 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
521 | else:
522 | attn_mask = None
523 | return attn_mask
524 |
525 | def maybe_pad(self, hidden_states, height, width):
526 | pad_right = (self.window_size[1] - width % self.window_size[1]) % self.window_size[1]
527 | pad_bottom = (self.window_size[0] - height % self.window_size[0]) % self.window_size[0]
528 | pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
529 | hidden_states = nn.functional.pad(hidden_states, pad_values)
530 | return hidden_states, pad_values
531 |
532 | def forward(
533 | self,
534 | hidden_states: torch.Tensor,
535 | input_dimensions: Tuple[int, int],
536 | head_mask: Optional[torch.FloatTensor] = None,
537 | output_attentions: Optional[bool] = False,
538 | always_partition: Optional[bool] = False,
539 | ) -> Tuple[torch.Tensor, torch.Tensor]:
540 | if not always_partition:
541 | self.set_shift_and_window_size(input_dimensions)
542 | else:
543 | pass
544 | height, width = input_dimensions
545 | batch_size, _, channels = hidden_states.size()
546 | shortcut = hidden_states
547 |
548 | hidden_states = self.layernorm_before(hidden_states)
549 |
550 | hidden_states = hidden_states.view(batch_size, height, width, channels)
551 |
552 | # pad hidden_states to multiples of window size
553 | hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
554 |
555 | _, height_pad, width_pad, _ = hidden_states.shape
556 | # cyclic shift
557 | if isinstance(self.shift_size, list):
558 | shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size[0], -self.shift_size[1]),
559 | dims=(1, 2))
560 | else:
561 | shifted_hidden_states = hidden_states
562 |
563 | # partition windows
564 | hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
565 | hidden_states_windows = hidden_states_windows.view(-1, self.window_size[0] * self.window_size[1], channels)
566 | attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
567 | if attn_mask is not None:
568 | attn_mask = attn_mask.to(hidden_states_windows.device)
569 |
570 | attention_outputs = self.attention(
571 | hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
572 | )
573 |
574 | attention_output = attention_outputs[0]
575 |
576 | attention_windows = attention_output.view(-1, self.window_size[0], self.window_size[1], channels)
577 | shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
578 |
579 | # reverse cyclic shift
580 | if isinstance(self.shift_size, list):
581 | attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size[0], self.shift_size[1]),
582 | dims=(1, 2))
583 | else:
584 | attention_windows = shifted_windows
585 |
586 | was_padded = pad_values[3] > 0 or pad_values[5] > 0
587 | if was_padded:
588 | attention_windows = attention_windows[:, :height, :width, :].contiguous()
589 |
590 | attention_windows = attention_windows.view(batch_size, height * width, channels)
591 |
592 | hidden_states = shortcut + self.drop_path(attention_windows)
593 |
594 | layer_output = self.layernorm_after(hidden_states)
595 | layer_output = self.intermediate(layer_output)
596 | layer_output = hidden_states + self.output(layer_output)
597 |
598 | layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
599 | return layer_outputs
600 |
601 |
602 | # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DocParser
603 | class DocParserStage(nn.Module):
604 | def __init__(self, config, dim, dim_out, input_resolution, depth, window_size, num_heads, drop_path, downsample):
605 | super().__init__()
606 | self.config = config
607 | self.dim = dim
608 | self.blocks = nn.ModuleList(
609 | [
610 | DocParserLayer(
611 | config=config,
612 | dim=dim,
613 | input_resolution=input_resolution,
614 | num_heads=num_heads,
615 | window_size=window_size,
616 | shift_size=0 if (i % 2 == 0) else [window_size[0] // 2, window_size[1] // 2],
617 | )
618 | for i in range(depth)
619 | ]
620 | )
621 |
622 | # patch merging layer
623 | if downsample is not None:
624 | self.downsample = downsample(input_resolution, dim=dim, dim_out=dim_out, norm_layer=nn.LayerNorm)
625 | else:
626 | self.downsample = None
627 |
628 | self.pointing = False
629 |
630 | def forward(
631 | self,
632 | hidden_states: torch.Tensor,
633 | input_dimensions: Tuple[int, int],
634 | head_mask: Optional[torch.FloatTensor] = None,
635 | output_attentions: Optional[bool] = False,
636 | always_partition: Optional[bool] = False,
637 | ) -> Tuple[torch.Tensor]:
638 | height, width = input_dimensions
639 | for i, layer_module in enumerate(self.blocks):
640 | layer_head_mask = head_mask[i] if head_mask is not None else None
641 |
642 | layer_outputs = layer_module(
643 | hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
644 | )
645 |
646 | hidden_states = layer_outputs[0]
647 |
648 | hidden_states_before_downsampling = hidden_states
649 | if self.downsample is not None:
650 | height_downsampled, width_downsampled = height, (width + 1) // 2
651 | output_dimensions = (height, width, height_downsampled, width_downsampled)
652 | hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
653 | else:
654 | output_dimensions = (height, width, height, width)
655 |
656 | stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
657 |
658 | if output_attentions:
659 | stage_outputs += layer_outputs[1:]
660 | return stage_outputs
661 |
662 |
663 | class DocParserConvNeXtEncoder(nn.Module):
664 | def __init__(self, config):
665 | super().__init__()
666 | self.config = config
667 | conv_depth_num_layers = config.conv_depth_num_layers
668 | conv_embed_dim = config.embed_dim[:conv_depth_num_layers]
669 | conv_depth = config.depths[:conv_depth_num_layers]
670 | stride_size = config.stride_size
671 |         # ConvNeXt Stage
672 | self.layers = ConvNeXt(block_setting=[
673 | CNBlockConfig(input_channels=conv_embed_dim[i_layer],
674 | out_channels=conv_embed_dim[i_layer] * 2,
675 | num_layers=conv_depth[i_layer],
676 | stride=stride_size[i_layer]
677 | )
678 | for i_layer in range(conv_depth_num_layers)],
679 | stochastic_depth_prob=0.1)
680 |
681 | def forward(
682 | self,
683 | hidden_states: torch.Tensor):
684 | return self.layers(hidden_states)
685 |
686 |
687 | # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DocParser
688 | class DocParserEncoder(nn.Module):
689 | def __init__(self, config, grid_size):
690 | super().__init__()
691 | self.num_layers = len(config.depths)
692 | self.config = config
693 | swin_depth_num_layers = self.num_layers - config.conv_depth_num_layers
694 | swin_embed_dim = config.embed_dim[swin_depth_num_layers:]
695 | swin_depth = config.depths[swin_depth_num_layers:]
696 |
697 | # Swin-ViT Stage
698 | dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
699 | self.layers = nn.ModuleList([
700 | DocParserStage(
701 | config=config,
702 | window_size=config.window_size[i_layer],
703 | dim=int(swin_embed_dim[i_layer]),
704 | dim_out=int(swin_embed_dim[i_layer + 1]) if i_layer < swin_depth_num_layers - 1 else int(
705 | swin_embed_dim[i_layer]),
706 | input_resolution=(grid_size[0], grid_size[1] // (2 ** i_layer)),
707 | depth=swin_depth[i_layer],
708 | num_heads=config.num_heads[i_layer],
709 | drop_path=dpr[sum(swin_depth[:i_layer]): sum(swin_depth[: i_layer + 1])],
710 | downsample=DocParserPatchMerging if (i_layer < swin_depth_num_layers - 1) else None,
711 | )
712 | for i_layer in range(swin_depth_num_layers)
713 | ])
714 | self.gradient_checkpointing = False
715 |
716 | def forward(
717 | self,
718 | hidden_states: torch.Tensor,
719 | input_dimensions: Tuple[int, int],
720 | head_mask: Optional[torch.FloatTensor] = None,
721 | output_attentions: Optional[bool] = False,
722 | output_hidden_states: Optional[bool] = False,
723 | output_hidden_states_before_downsampling: Optional[bool] = False,
724 | always_partition: Optional[bool] = False,
725 | return_dict: Optional[bool] = True,
726 | ) -> Union[Tuple, DocParserEncoderOutput]:
727 | all_hidden_states = () if output_hidden_states else None
728 | all_reshaped_hidden_states = () if output_hidden_states else None
729 | all_self_attentions = () if output_attentions else None
730 |
731 | if output_hidden_states:
732 | batch_size, _, hidden_size = hidden_states.shape
733 | # rearrange b (h w) c -> b c h w
734 | reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
735 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
736 | all_hidden_states += (hidden_states,)
737 | all_reshaped_hidden_states += (reshaped_hidden_state,)
738 |
739 | for i, layer_module in enumerate(self.layers):
740 | layer_head_mask = head_mask[i] if head_mask is not None else None
741 |
742 | if self.gradient_checkpointing and self.training:
743 |
744 | def create_custom_forward(module):
745 | def custom_forward(*inputs):
746 | return module(*inputs, output_attentions)
747 |
748 | return custom_forward
749 |
750 | layer_outputs = torch.utils.checkpoint.checkpoint(
751 | create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
752 | )
753 | else:
754 | layer_outputs = layer_module(
755 | hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
756 | )
757 |
758 | hidden_states = layer_outputs[0]
759 | hidden_states_before_downsampling = layer_outputs[1]
760 | output_dimensions = layer_outputs[2]
761 |
762 | input_dimensions = (output_dimensions[-2], output_dimensions[-1])
763 |
764 | if output_hidden_states and output_hidden_states_before_downsampling:
765 | batch_size, _, hidden_size = hidden_states_before_downsampling.shape
766 | # rearrange b (h w) c -> b c h w
767 | # here we use the original (not downsampled) height and width
768 | reshaped_hidden_state = hidden_states_before_downsampling.view(
769 | batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
770 | )
771 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
772 | all_hidden_states += (hidden_states_before_downsampling,)
773 | all_reshaped_hidden_states += (reshaped_hidden_state,)
774 | elif output_hidden_states and not output_hidden_states_before_downsampling:
775 | batch_size, _, hidden_size = hidden_states.shape
776 | # rearrange b (h w) c -> b c h w
777 | reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
778 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
779 | all_hidden_states += (hidden_states,)
780 | all_reshaped_hidden_states += (reshaped_hidden_state,)
781 |
782 | if output_attentions:
783 | all_self_attentions += layer_outputs[3:]
784 |
785 | if not return_dict:
786 | return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
787 |
788 | return DocParserEncoderOutput(
789 | last_hidden_state=hidden_states,
790 | hidden_states=all_hidden_states,
791 | attentions=all_self_attentions,
792 | reshaped_hidden_states=all_reshaped_hidden_states,
793 | )
794 |
795 |
796 | # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DocParser
797 | class DocParserPreTrainedModel(PreTrainedModel):
798 | """
799 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
800 | models.
801 | """
802 |
803 | config_class = DocParserConfig
804 | base_model_prefix = "swin"
805 | main_input_name = "pixel_values"
806 | supports_gradient_checkpointing = True
807 |
808 | def _init_weights(self, module):
809 | """Initialize the weights"""
810 | if isinstance(module, (nn.Linear, nn.Conv2d)):
811 | # Slightly different from the TF version which uses truncated_normal for initialization
812 | # cf https://github.com/pytorch/pytorch/pull/5617
813 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
814 | if module.bias is not None:
815 | module.bias.data.zero_()
816 | elif isinstance(module, nn.LayerNorm):
817 | module.bias.data.zero_()
818 | module.weight.data.fill_(1.0)
819 |
820 | def _set_gradient_checkpointing(self, module, value=False):
821 | if isinstance(module, DocParserEncoder):
822 | module.gradient_checkpointing = value
823 |
824 |
825 | SWIN_START_DOCSTRING = r"""
826 | This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
827 | it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
828 | behavior.
829 |
830 | Parameters:
831 | config ([`DocParserConfig`]): Model configuration class with all the parameters of the model.
832 | Initializing with a config file does not load the weights associated with the model, only the
833 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
834 | """
835 |
836 | SWIN_INPUTS_DOCSTRING = r"""
837 | Args:
838 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
839 | Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
840 | [`DonutImageProcessor.__call__`] for details.
841 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
842 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
843 |
844 | - 1 indicates the head is **not masked**,
845 | - 0 indicates the head is **masked**.
846 |
847 | output_attentions (`bool`, *optional*):
848 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
849 | tensors for more detail.
850 | output_hidden_states (`bool`, *optional*):
851 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
852 | more detail.
853 | return_dict (`bool`, *optional*):
854 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
855 | """
856 |
857 |
858 | @add_start_docstrings(
859 |     "The bare DocParser Model transformer outputting raw hidden-states without any specific head on top.",
860 | SWIN_START_DOCSTRING,
861 | )
862 | class DocParserModel(DocParserPreTrainedModel):
863 | def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
864 | super().__init__(config)
865 | self.config = config
866 | self.num_layers = len(config.depths)
867 |
868 | self.embeddings = DocParserPatchEmbeddings(config)
869 | self.convnext_encoder = DocParserConvNeXtEncoder(config)
870 | self.encoder = DocParserEncoder(config, self.embeddings.grid_size)
871 |
872 | self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
873 |
874 | # Initialize weights and apply final processing
875 | self.post_init()
876 |
877 | def get_input_embeddings(self):
878 | return self.embeddings.patch_embeddings
879 |
880 | def _prune_heads(self, heads_to_prune):
881 | """
882 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
883 | class PreTrainedModel
884 | """
885 | for layer, heads in heads_to_prune.items():
886 | self.encoder.layer[layer].attention.prune_heads(heads)
887 |
888 | @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
889 | @add_code_sample_docstrings(
890 | checkpoint=_CHECKPOINT_FOR_DOC,
891 | output_type=DocParserModelOutput,
892 | config_class=_CONFIG_FOR_DOC,
893 | modality="vision",
894 | expected_output=_EXPECTED_OUTPUT_SHAPE,
895 | )
896 | def forward(
897 | self,
898 | pixel_values: Optional[torch.FloatTensor] = None,
899 | head_mask: Optional[torch.FloatTensor] = None,
900 | output_attentions: Optional[bool] = None,
901 | output_hidden_states: Optional[bool] = None,
902 | return_dict: Optional[bool] = None,
903 | ) -> Union[Tuple, DocParserModelOutput]:
904 |         r"""
905 |         Returns:
906 |             [`DocParserModelOutput`] or a plain tuple of `torch.FloatTensor` if `return_dict=False`.
907 |         """
908 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
909 | output_hidden_states = (
910 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
911 | )
912 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
913 |
914 | if pixel_values is None:
915 | raise ValueError("You have to specify pixel_values")
916 |
917 | # Prepare head mask if needed
918 | # 1.0 in head_mask indicate we keep the head
919 | # attention_probs has shape bsz x n_heads x N x N
920 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
921 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
922 | head_mask = self.get_head_mask(head_mask, len(self.config.depths))
923 |
924 | embedding_output = self.embeddings(pixel_values)
925 |
926 | # ConvNext Stage
927 | encoder_outputs = self.convnext_encoder(embedding_output)
928 |
929 | # ConvNext to Swin-ViT Stage
930 | _, _, height, width = encoder_outputs.shape
931 | input_dimensions = (height, width)
932 | encoder_outputs = encoder_outputs.flatten(2).transpose(1, 2)
933 |
934 | # Swin-ViT Stage
935 | encoder_outputs = self.encoder(
936 | encoder_outputs,
937 | input_dimensions,
938 | head_mask=head_mask,
939 | output_attentions=output_attentions,
940 | output_hidden_states=output_hidden_states,
941 | return_dict=return_dict,
942 | )
943 |
944 | sequence_output = encoder_outputs[0]
945 |
946 | pooled_output = None
947 | if self.pooler is not None:
948 | pooled_output = self.pooler(sequence_output.transpose(1, 2))
949 | pooled_output = torch.flatten(pooled_output, 1)
950 |
951 | if not return_dict:
952 | output = (sequence_output, pooled_output) + encoder_outputs[1:]
953 |
954 | return output
955 |
956 | return DocParserModelOutput(
957 | last_hidden_state=sequence_output,
958 | pooler_output=pooled_output,
959 | hidden_states=encoder_outputs.hidden_states,
960 | attentions=encoder_outputs.attentions,
961 | reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
962 | )
963 |
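964 | 
965 | if __name__ == "__main__":
966 |     # Minimal smoke-test sketch appended for illustration (not part of the
967 |     # original module). Assumptions: models/config.json is a valid
968 |     # DocParserConfig file relative to the working directory, and a
969 |     # 2560 x 1920 input is compatible with the configured patch, stride and
970 |     # window sizes; adjust both to your checkpoint.
971 |     config = DocParserConfig.from_json_file("models/config.json")
972 |     model = DocParserModel(config, add_pooling_layer=True).eval()
973 |     pixel_values = torch.randn(1, 3, 2560, 1920)  # (batch, channels, height, width)
974 |     with torch.no_grad():
975 |         outputs = model(pixel_values, output_hidden_states=True)
976 |     # last_hidden_state: (batch, sequence_length, hidden_size) after the Swin stages
977 |     print(outputs.last_hidden_state.shape, outputs.pooler_output.shape)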
--------------------------------------------------------------------------------
/mydatasets/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/6/8
3 | from .docparser_dataset import DocParser, DataCollatorForDocParserDataset
4 |
5 |
6 | def get_dataset(dataset_args):
7 | dataset_type = dataset_args.get("type")
8 | dataset = eval(dataset_type)(**dataset_args)
9 | return dataset
10 |
11 |
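12 | # Illustrative sketch of how `get_dataset` is typically driven from the YAML
13 | # config: `type` selects the dataset class via eval(), and the remaining keys
14 | # are forwarded to its constructor. The values below are assumptions for
15 | # illustration (the real ones live in config/base.yaml), and `model` /
16 | # `processor` stand for an already-built Donut-style model and processor:
17 | #
18 | #   dataset_args = {
19 | #       "type": "DocParser",
20 | #       "data_root": ["/path/to/annotations.json"],
21 | #       "donut_model": model,
22 | #       "processor": processor,
23 | #       "max_length": 768,
24 | #       "phase": "train",
25 | #   }
26 | #   train_dataset = get_dataset(dataset_args)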
--------------------------------------------------------------------------------
/mydatasets/docparser_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: @time: 6/6/23 10:30
3 | """
4 | Dataloader for Pretraining Task of DocParser
5 |
6 | Masked Document Reading Step: after the knowledge transfer step, we
7 | pre-train our model on the task of document reading. In this pre-training phase,
8 | the model learns to predict the next textual token while conditioning on the
9 | previous textual tokens and the input image. To encourage joint reasoning, we
10 | mask several 32 × 32 blocks covering approximately fifteen percent of the
11 | input image; to predict the text situated within the masked regions, the model
12 | is then obliged to understand the surrounding textual context.
13 |
14 | """
15 | import os
16 | import os.path
17 | import random
18 | from dataclasses import dataclass
19 | from typing import Any, Dict, List, Tuple, Sequence
20 |
21 | import torch
22 | from PIL import Image, ImageFile
23 | from torch.utils.data import Dataset
24 | from tqdm import tqdm
25 | from transformers.modeling_utils import PreTrainedModel
26 |
27 | from base.common_util import load_json
28 |
29 | ImageFile.LOAD_TRUNCATED_IMAGES = True
30 |
31 |
32 | # Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
33 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
34 | """
35 | Shift input ids one token to the right.
36 | """
37 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
38 | shifted_input_ids[1:] = input_ids[:-1].clone()
39 | if decoder_start_token_id is None:
40 | raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
41 | shifted_input_ids[0] = decoder_start_token_id
42 |
43 | if pad_token_id is None:
44 | raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
45 | # replace possible -100 values in labels by `pad_token_id`
46 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
47 |
48 | return shifted_input_ids
49 |
50 |
51 | class DocParser(Dataset):
52 | """
53 |     Dataset for the DocParser pretraining task, loaded from annotation json files.
54 |     Each record consists of an image path (png/jpg/jpeg) and its ground-truth parse, and is converted into
55 |     input_tensor (vectorized image) and input_ids (tokenized string).
56 |
57 | Args:
58 |         data_root: list of annotation json files; each record provides the image location and its ground-truth parse
59 | ignore_id: ignore_index for torch.nn.CrossEntropyLoss
60 | task_start_token: the special token to be fed to the decoder to conduct the target task
61 | """
62 |
63 | def __init__(
64 | self,
65 | data_root: list,
66 | donut_model: PreTrainedModel,
67 | processor,
68 | max_length: int,
69 | phase: str = "train",
70 | ignore_id: int = -100,
71 | task_start_token: str = "",
72 | prompt_end_token: str = None,
73 | sort_json_key: bool = True,
74 | **kwargs
75 | ):
76 | super().__init__()
77 |
78 | self.donut_model = donut_model
79 | self.processor = processor
80 | self.max_length = max_length
81 | self.phase = phase
82 | self.ignore_id = ignore_id
83 | self.task_start_token = task_start_token
84 | self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
85 | self.sort_json_key = sort_json_key
86 | gt_info_list = []
87 | self.img_path_list = []
88 | print("processing json to token sequence...")
89 | for data_dir in data_root:
90 | for gt_info in load_json(data_dir):
91 | gt_info_list.extend(gt_info)
92 |
93 | self.dataset_length = len(gt_info_list)
94 | self.gt_token_sequences = []
95 | self.special_token_list = []
96 |
97 | for gt_info in tqdm(gt_info_list):
98 | gt_token_sequence = self.json2token(
99 | gt_info['extract_info'],
100 | update_special_tokens_for_json_key=self.phase == "train",
101 | sort_json_key=self.sort_json_key,
102 | ) + self.processor.tokenizer.eos_token
103 | self.gt_token_sequences.append(gt_token_sequence)
104 | self.img_path_list.append(os.path.join(gt_info['filepath'], gt_info['filename']))
105 |
106 | # add special token
107 | list_of_tokens = [self.task_start_token, self.prompt_end_token]
108 |
109 | self.add_tokens(list_of_tokens)
110 | self.donut_model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
111 |
112 | # patch config
113 |         self.height = self.processor.image_processor.size['height']
114 |         self.width = self.processor.image_processor.size['width']
115 |         self.num_patches = (self.height // 32) * (self.width // 32)
116 | self.mask_tensor = torch.zeros(3, 32, 32)
117 |
118 | def add_tokens(self, list_of_tokens: List[str]):
119 | """
120 | Add special tokens to tokenizer and resize the token embeddings of the decoder
121 | """
122 | newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
123 | if newly_added_num > 0:
124 | self.special_token_list.extend(list_of_tokens)
125 |
126 | def json2token(self, obj: Any,
127 | update_special_tokens_for_json_key: bool = True,
128 | sort_json_key: bool = True):
129 | """
130 | Convert an ordered JSON object into a token sequence
131 | """
132 | if type(obj) == dict:
133 | if len(obj) == 1 and "text_sequence" in obj:
134 | return obj["text_sequence"]
135 | else:
136 | output = ""
137 | if sort_json_key:
138 | keys = sorted(obj.keys(), reverse=True)
139 | else:
140 | keys = obj.keys()
141 | for k in keys:
142 | if update_special_tokens_for_json_key:
143 |                         list_of_tokens = [fr"<s_{k}>", fr"</s_{k}>"]
144 |                         # add extract token
145 |                         self.add_tokens(list_of_tokens)
146 |                     output += (
147 |                         fr"<s_{k}>"
148 |                         + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
149 |                         + fr"</s_{k}>"
150 | )
151 | return output
152 | elif type(obj) == list:
153 |             return r"<sep/>".join(
154 | [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
155 | )
156 | else:
157 | obj = str(obj)
158 | if f"<{obj}/>" in self.special_token_list:
159 | obj = f"<{obj}/>" # for categorical special tokens
160 | return obj
161 |
162 | def __len__(self) -> int:
163 | return self.dataset_length
164 |
165 | def __getitem__(self, idx: int):
166 | try:
167 | # pixel_tensor
168 | sample = Image.open(self.img_path_list[idx]).convert("RGB")
169 | input_tensor = self.processor(sample, random_padding=self.phase == "train",
170 | do_normalize=False,
171 | return_tensors="pt").pixel_values[0]
172 |
173 | # To encourage joint reasoning, we mask several 32 × 32 blocks
174 | # representing approximately fifteen percent of the input image.
175 | input_tensor = self.mask_document_patch(input_tensor)
176 |
177 | # input_ids
178 | processed_parse = self.gt_token_sequences[idx]
179 | input_ids = self.processor.tokenizer(
180 | processed_parse,
181 | add_special_tokens=False,
182 | max_length=self.max_length,
183 | padding="max_length",
184 | truncation=True,
185 | return_tensors="pt",
186 | )["input_ids"].squeeze(0)
187 |
188 | labels = input_ids.clone()
189 |             # the model doesn't need to predict pad tokens
190 |             labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id
191 |         except Exception:  # fall back to a random sample if the image or annotation fails to load
192 |             random_index = random.randrange(self.__len__())
193 |             return self.__getitem__(random_index)
194 |         return input_tensor, labels, processed_parse
195 |
196 | def mask_document_patch(self, pixel_values):
197 | patch_width = self.width // 32
198 | sample_idx_list = random.sample(list(range(self.num_patches)), int(self.num_patches * 0.15))
199 | for sample_id in sample_idx_list:
200 | row_id = sample_id // patch_width
201 | col_id = sample_id % patch_width
202 | pixel_values[:, row_id * 32: (row_id + 1) * 32, col_id * 32: (col_id + 1) * 32] = self.mask_tensor
203 | return self.processor(pixel_values, return_tensors="pt").pixel_values[0]
204 |
205 |
206 | @dataclass
207 | class DataCollatorForDocParserDataset(object):
208 | """Collate examples for supervised fine-tuning."""
209 |
210 | def __init__(self, **kwargs):
211 | pass
212 |
213 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
214 | batch = dict()
215 | # pixel_values
216 | images = [instance[0] for instance in instances]
217 | batch['pixel_values'] = torch.stack(images)
218 | # labels
219 | labels = [instance[1] for instance in instances]
220 | batch['labels'] = torch.stack(labels)
221 | # processed_parse
222 | batch['processed_parse'] = [instance[2] for instance in instances]
223 | return batch
224 |
225 |
226 |
227 |
228 |
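229 | if __name__ == "__main__":
230 |     # Standalone sketch of the 32 x 32 block masking described in the module
231 |     # docstring, independent of the processor and model. The 2560 x 1920
232 |     # resolution is an assumption for illustration; the real size comes from
233 |     # the image processor config.
234 |     height, width = 2560, 1920
235 |     pixels = torch.rand(3, height, width)
236 |     patches_per_row = width // 32
237 |     num_patches = (height // 32) * patches_per_row
238 |     num_masked = int(num_patches * 0.15)
239 |     for sample_id in random.sample(range(num_patches), num_masked):
240 |         row_id, col_id = divmod(sample_id, patches_per_row)
241 |         pixels[:, row_id * 32:(row_id + 1) * 32, col_id * 32:(col_id + 1) * 32] = 0.0
242 |     zero_frac = (pixels == 0).all(dim=0).float().mean().item()
243 |     print(f"masked {num_masked} of {num_patches} blocks (~{zero_frac:.1%} of pixels)")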
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | anyconfig
2 | accelerate
3 | munch
4 | torch==2.0.0
5 | torchvision==0.15.1
6 | transformers==4.28.1
--------------------------------------------------------------------------------
/train/train_experiment.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # create: 2021/6/10
3 | import os
4 | import sys
5 | import argparse
6 | import setproctitle
7 |
8 | PROJECT_ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9 | sys.path.append(PROJECT_ROOT_PATH)
10 | os.environ['RUN_ON_GPU_IDs'] = "0"
11 |
12 | import experiment
13 |
14 | from base.common_util import get_absolute_file_path, init_experiment_config
15 | from experiment import get_experiment_name
16 |
17 |
18 | def init_args():
19 | parser = argparse.ArgumentParser(description='trainer args')
20 | parser.add_argument(
21 | '--config_file',
22 | default='config/base.yaml',
23 | type=str,
24 | )
25 | parser.add_argument(
26 | '--experiment_name',
27 | default='DocParser',
28 | type=str,
29 | )
30 | parser.add_argument(
31 | '--phase',
32 | default='train',
33 | type=str,
34 | )
35 |     parser.add_argument(
36 |         '--use_accelerate',
37 |         action='store_true',
38 |         help='enable training with accelerate',
39 |     )
40 | args = parser.parse_args()
41 | os.environ['WORKSPACE'] = args.experiment_name
42 | return args
43 |
44 |
45 | def main(args):
46 | config = init_experiment_config(args.config_file, args.experiment_name)
47 | config.update({'phase': args.phase,
48 | 'use_accelerate': args.use_accelerate})
49 | experiment_instance = getattr(experiment, get_experiment_name(args.experiment_name))(config)
50 | if args.phase == 'train':
51 | experiment_instance.train()
52 | elif args.phase == 'predict':
53 | experiment_instance.predict()
54 | else:
55 | print("Unimplemented phase: {}".format(args.phase))
56 |
57 |
58 | if __name__ == '__main__':
59 | args = init_args()
60 | setproctitle.setproctitle("{} task for {}".format(args.experiment_name, args.config_file.split('/')[-1]))
61 | main(args)
62 |
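63 | # Example invocation (sketch, mirroring the argparse defaults above; the
64 | # experiment name must correspond to a section of config/base.yaml):
65 | #
66 | #   python train/train_experiment.py \
67 | #       --config_file config/base.yaml \
68 | #       --experiment_name DocParser \
69 | #       --phase train
70 | #
71 | # Add --use_accelerate to set use_accelerate=True in the experiment config.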
--------------------------------------------------------------------------------