14 |
15 | It started as code for the paper:
16 |
17 | **MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition**
18 | (Accepted by ICCV 2023)
19 |
This project is a toolkit for the novel scenario of Incremental Multilingual Text Recognition (IMLTR). It supports a range of incremental learning methods and proposes a method better suited to IMLTR, the Multiplexed Routing Network (MRN), together with the corresponding dataset. The project provides an efficient framework for developing new methods and analyzing existing ones under the IMLTR task, and we hope it will help advance the IMLTR community.
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | ---
30 | ## Methods
31 | ### Incremental Learning Methods
* [x] Base: Baseline method that simply fine-tunes the parameters on each new task.
* [x] Joint: Upper-bound method in which the data of all tasks are trained at once (`Joint_mix` mixes all tasks' data within each batch; `Joint_loader` keeps a consistent proportion of data from each task in every batch).
34 | * [x] [EWC](https://arxiv.org/abs/1612.00796) `[PNAS2017]`: Overcoming catastrophic forgetting in neural networks
* [x] [LwF](https://arxiv.org/abs/1606.09282) `[ECCV2016]`: Learning without Forgetting
36 | * [x] [WA](https://arxiv.org/abs/1911.07053) `[CVPR2020]`: Maintaining Discrimination and Fairness in Class Incremental Learning
37 | * [x] [DER](https://arxiv.org/abs/2103.16788) `[CVPR2021]`: DER: Dynamically Expandable Representation for Class Incremental Learning
38 | * [x] [MRN](https://arxiv.org/abs/2305.14758) `[ICCV2023]`: MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition
39 |
You can edit the config file `config/crnn_mrn.py` to select a different incremental learning method or setting, e.g.:
41 | ```
42 | common=dict(
43 | il="mrn", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn
44 | memory="random", # None | random
45 | memory_num=2000,
46 | start_task = 0 # checkpoint start
47 | )
48 | ```
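As a minimal sketch (assuming `config/crnn_mrn.py` is a plain Python file defining top-level dicts such as `common`, as shown above; the repo's own config loader may differ), the settings can be read and switched programmatically:
```
import runpy

cfg = runpy.run_path("config/crnn_mrn.py")   # returns the module globals as a dict
print(cfg["common"]["il"])                    # e.g. "mrn"
cfg["common"]["il"] = "ewc"                   # switch to another incremental learning method
```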
49 |
50 | ### Text Recognition Methods
51 | * [x] [CRNN](https://ieeexplore.ieee.org/abstract/document/7801919) `[TPAMI2017]`: An End-to-End Trainable Neural Network for Image-Based Sequence Recognition and Its Application to Scene Text Recognition
52 | * [x] [TRBA](https://arxiv.org/abs/1904.01906) `[ICCV2019]`: What Is Wrong With Scene Text Recognition Model Comparisons? Dataset and Model Analysis
53 | * [x] [SVTR](https://arxiv.org/abs/2205.00159) `[IJCAI2022]`: SVTR: Scene Text Recognition with a Single Visual Model
54 |
You can also edit `config/crnn_mrn.py` to select a different text recognition model or setting, e.g.:
56 | ```
57 | """ Model Architecture """
58 | common=dict(
59 | batch_max_length = 25,
60 | imgH = 32,
61 | imgW = 256,
62 | )
63 | model=dict(
64 | model_name="TRBA",
65 | Transformation = "TPS", #None TPS
66 | FeatureExtraction = "ResNet", #VGG ResNet SVTR
67 | SequenceModeling = "BiLSTM", #None BiLSTM
68 | Prediction = "Attn", #CTC Attn
69 | num_fiducial=20,
70 | input_channel=4,
71 | output_channel=512,
72 | hidden_size=256,
73 | )
74 | ```
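The four `model` fields correspond to the standard four-stage STR pipeline (Transformation → FeatureExtraction → SequenceModeling → Prediction). The following is a self-contained toy sketch of that composition, not the repo's actual model code; the layer sizes only mirror the config values above:
```
import torch
import torch.nn as nn

class ToyFourStageRecognizer(nn.Module):
    def __init__(self, input_channel=4, output_channel=512, hidden_size=256, num_class=100):
        super().__init__()
        self.transformation = nn.Identity()                 # stands in for TPS rectification
        self.feature_extraction = nn.Sequential(            # toy VGG-style backbone
            nn.Conv2d(input_channel, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, output_channel, 3, 1, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, None)),                 # collapse height to 1, keep width as sequence
        )
        self.sequence_modeling = nn.LSTM(output_channel, hidden_size,
                                         bidirectional=True, batch_first=True)
        self.prediction = nn.Linear(2 * hidden_size, num_class)  # CTC-style per-step classifier

    def forward(self, image):
        x = self.transformation(image)
        f = self.feature_extraction(x)          # [B, C, 1, W']
        f = f.squeeze(2).permute(0, 2, 1)       # [B, W', C]
        seq, _ = self.sequence_modeling(f)      # [B, W', 2*hidden_size]
        return self.prediction(seq)             # [B, W', num_class]

logits = ToyFourStageRecognizer()(torch.randn(2, 4, 32, 256))  # imgH=32, imgW=256, 4-channel input
print(logits.shape)                                            # torch.Size([2, 64, 100])
```
Swapping `Transformation`, `FeatureExtraction`, `SequenceModeling` or `Prediction` in the config selects the corresponding real module (TPS, VGG/ResNet/SVTR, BiLSTM, CTC/Attn) instead of these stand-ins.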
75 |
76 |
77 | ## IMLTR Dataset
The dataset can be downloaded from [BaiduNetdisk](https://pan.baidu.com/s/1Qv4utVzWlLu8UPcBpItHbQ) (password: c07h).
79 |
80 | ```
81 | dataset
82 | ├── MLT17_IL
83 | │ ├── test_2017
84 | │ ├── train_2017
85 | ├── MLT19_IL
86 | │ ├── test_2019
87 | │ ├── train_2019
88 | ```
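Each split is stored as LMDB; the loader in `data/dataset.py` reads the `num-samples`, `label-%09d` and `image-%09d` keys, and `data/data_manage.py` appends a per-script sub-directory (e.g. `Chinese`) taken from `lan_list`. A small sketch to peek at one split (the per-script sub-directory layout is an assumption based on that loader code):
```
import lmdb

env = lmdb.open("dataset/MLT17_IL/train_2017/Chinese",
                readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    num_samples = int(txn.get("num-samples".encode()))
    first_label = txn.get("label-%09d".encode() % 1).decode("utf-8")  # lmdb indices start at 1
print(num_samples, first_label)
```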
89 |
Incremental MLT17: MLT17 has 68,613 training instances and 16,255 validation instances drawn from 6 scripts and 9 languages: Chinese, Japanese, Korean, Bangla, Arabic, Italian, English, French, and German (the last four use the Latin script). Because the test data are not publicly available, Incremental MLT17 uses the validation set for testing. Tasks are split by script and learned sequentially. Special symbols are discarded during preprocessing since they carry no linguistic meaning.
91 |
Incremental MLT19: MLT19 has 89,177 text instances covering 7 scripts. Since its test set is also inaccessible, Incremental MLT19 randomly splits the training instances 9:1 script by script for model training and testing. To be consistent with Incremental MLT17, we discard the Hindi script as well as special symbols. Statistics of the two datasets are shown below.
93 |
| Dataset | | Task1 (Chinese) | Task2 (Latin) | Task3 (Japanese) | Task4 (Korean) | Task5 (Arabic) | Task6 (Bangla) |
|---------|----------------|-------|-------|----------|--------|--------|--------|
| MLT17[^1^] | Train Instance | 2687 | 47411 | 4609 | 5631 | 3711 | 3237 |
| | Test Instance | 529 | 11073 | 1350 | 1230 | 983 | 713 |
| | Train Class | 1895 | 325 | 1620 | 1124 | 73 | 112 |
| MLT19[^2^] | Train Instance | 2897 | 52921 | 5324 | 6107 | 4230 | 3542 |
| | Test Instance | 322 | 5882 | 590 | 679 | 470 | 393 |
| | Train Class | 2086 | 220 | 1728 | 1160 | 73 | 102 |
104 |
[^1^]: Nayef, N., et al. (2017). ICDAR2017 Robust Reading Challenge on Multi-lingual Scene Text Detection and Script Identification (RRC-MLT 2017).
[^2^]: Nayef, N., et al. (2019). ICDAR2019 Robust Reading Challenge on Multi-lingual Scene Text Detection and Recognition (RRC-MLT-2019).
107 |
108 |
109 | ## Getting Started
110 | ### Dependency
- This work was tested with PyTorch 1.6.0, CUDA 10.1 and Python 3.6; the commands below set up a newer environment with Python 3.7, PyTorch 1.9.1 and CUDA 11.x.
112 | ```
conda create -n mrn python=3.7 -y
conda activate mrn
# install PyTorch 1.9.1 either via conda:
conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=11.3 -c pytorch -c conda-forge
# or via pip:
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
117 | ```
- Requirements:
119 | ```
120 | pip3 install lmdb pillow torchvision nltk natsort fire tensorboard tqdm opencv-python einops timm mmcv shapely scipy
121 | pip3 install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.1/index.html
122 | ```
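A quick sanity check (plain Python) that the installed versions line up:
```
import torch, torchvision, mmcv, lmdb
print(torch.__version__, torchvision.__version__, mmcv.__version__, lmdb.__version__)
print("CUDA available:", torch.cuda.is_available())
```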
123 |
124 | ## Training
125 | ```
126 | python3 tiny_train.py --config=config/crnn_mrn.py --exp_name CRNN_real
127 | ```
128 | ### Arguments
tiny_train.py (by default, the trained model is evaluated on the IMLTR datasets at the end of training).
* `--select_data`: folder paths to the training lmdb datasets, e.g. `["../dataset/MLT17_IL/train_2017", "../dataset/MLT19_IL/train_2019"]`
* `--valid_datas`: folder paths to the test lmdb datasets, e.g. `["../dataset/MLT17_IL/test_2017", "../dataset/MLT19_IL/test_2019"]`
* `--batch_ratio`: the ratio assigned to each selected dataset within a batch; the default is `1 / number of datasets`.
* `--Aug`: which augmentation to use |None|Blur|Crop|Rot| (see the parsing sketch below).
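The `--Aug` string is parsed token by token by `Text_augment` in `data/dataset.py`; a value such as `Blur5-Crop90-Rot15` enables Gaussian blur up to sigma 5, random crops down to 90% scale, and rotations up to 15 degrees. A small illustration that mirrors that parsing (the helper name here is ours, for illustration only):
```
# Mirrors how Text_augment in data/dataset.py splits an Aug string such as
# "Blur5-Crop90-Rot15" into individual augmentation settings.
def parse_aug(aug_string):
    settings = {}
    for aug in aug_string.split("-"):
        if aug.startswith("Blur"):
            settings["blur_sigma_max"] = float(aug.strip("Blur"))
        elif aug.startswith("Crop"):
            settings["crop_scale_min"] = float(aug.strip("Crop")) / 100
        elif aug.startswith("Rot"):
            settings["rotation_degrees"] = int(aug.strip("Rot"))
    return settings

print(parse_aug("Blur5-Crop90-Rot15"))
# {'blur_sigma_max': 5.0, 'crop_scale_min': 0.9, 'rotation_degrees': 15}
```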
134 |
135 | ### Config Detail
For detailed configuration changes, please edit the config file `config/crnn_mrn.py`:
137 | ```
138 | common=dict(
139 | exp_name="TRBA_MRN", # Where to store logs and models
140 | il="mrn", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn
141 | memory="random", # None | random
142 | memory_num=2000,
143 | batch_max_length = 25,
144 | imgH = 32,
145 | imgW = 256,
146 | manual_seed=111,
147 | start_task = 0
148 | )
149 |
150 | """ Model Architecture """
151 | model=dict(
152 | model_name="TRBA",
153 | Transformation = "TPS", #None TPS
154 | FeatureExtraction = "ResNet", #VGG ResNet
155 | SequenceModeling = "BiLSTM", #None BiLSTM
156 | Prediction = "Attn", #CTC Attn
157 | num_fiducial=20,
158 | input_channel=4,
159 | output_channel=512,
160 | hidden_size=256,
161 | )
162 |
163 |
164 |
165 | """ Optimizer """
166 | optimizer=dict(
    schedule="super", # "super" for super-convergence; 1 for none; [0.6, 0.8] to match the ASTER setting
168 | optimizer="adam",
169 | lr=0.0005,
170 | sgd_momentum=0.9,
171 | sgd_weight_decay=0.000001,
172 | milestones=[2000,4000],
173 | lrate_decay=0.1,
174 | rho=0.95,
175 | eps=1e-8,
176 | lr_drop_rate=0.1
177 | )
178 |
179 |
180 | """ Data processing """
181 | train = dict(
182 | saved_model="", # "path to model to continue training"
183 | Aug="None", # |None|Blur|Crop|Rot|ABINet
184 | workers=4,
185 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla"],
186 | valid_datas=[
187 | "../dataset/MLT17_IL/test_2017",
188 | "../dataset/MLT19_IL/test_2019"
189 | ],
190 | select_data=[
191 | "../dataset/MLT17_IL/train_2017",
192 | "../dataset/MLT19_IL/train_2019"
193 | ],
194 | batch_ratio="0.5-0.5",
195 | total_data_usage_ratio="1.0",
196 | NED=True,
197 | batch_size=256,
198 | num_iter=10000,
199 | val_interval=5000,
200 | log_multiple_test=None,
201 | grad_clip=5,
202 | )
203 |
204 | ```
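When `schedule` is a list of fractions (e.g. `[0.6, 0.8]`), the learning rate is multiplied by `lr_drop_rate` once training passes each fraction of `num_iter`; this mirrors `adjust_learning_rate` in `tools/utils.py`:
```
# Stepwise decay as implemented in tools/utils.py (adjust_learning_rate),
# written as a stand-alone helper for illustration.
def lr_at(iteration, lr=0.0005, schedule=(0.6, 0.8), lr_drop_rate=0.1, num_iter=10000):
    for milestone in schedule:
        lr *= lr_drop_rate if iteration >= milestone * num_iter else 1.0
    return lr

print(lr_at(5000), lr_at(7000), lr_at(9000))  # lr before, between, and after the two milestones
```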
205 |
206 | ### Data Analysis
The experimental results of each task are recorded in `data_any.txt`, which can be used for further analysis of the data.
208 |
209 |
210 | ## Acknowledgements
This implementation is based on these repositories:
212 | - [STR-Fewer-Labels](https://github.com/ku21fan/STR-Fewer-Labels)
213 | - [PyCIL: A Python Toolbox for Class-Incremental Learning](https://github.com/G-U-N/PyCIL)
214 |
215 | ## Citation
216 | Please consider citing this work in your publications if it helps your research.
217 | ```
@inproceedings{zheng2023mrn,
  title={MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition},
  author={Zheng, Tianlun and Chen, Zhineng and Huang, BingChen and Zhang, Wei and Jiang, Yu-Gang},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
224 | ```
225 |
226 | ## License
227 | This project is released under the Apache 2.0 license.
228 |
--------------------------------------------------------------------------------
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
import PIL.Image
4 | import numpy as np
5 | import torch
6 |
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 |
9 |
10 | class CTCLabelConverter(object):
11 | """Convert between text-label and text-index"""
12 |
13 | def __init__(self, character):
14 | # character (str): set of the possible characters.
15 | list_special_token = [
16 | "[PAD]",
17 | "[UNK]",
18 | " ",
19 | ] # [UNK] for unknown character, ' ' for space.
20 | list_character = list(character)
21 | dict_character = list_special_token + list_character
22 |
23 | self.dict = {}
24 | for i, char in enumerate(dict_character):
25 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss, not same with space ' '.
26 | # print(type(char))
27 | self.dict[char] = i + 1
28 |
29 | self.character = [
30 | "[CTCblank]"
31 | ] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0).
32 | print(f"# characters dict has: {len(self.character)}")
33 | # print(f"\n {self.character}\n")
34 |
35 | def encode(self, word_string, batch_max_length=25):
36 | """convert word_list (string) into word_index.
37 | input:
38 | word_string: word labels of each image. [batch_size]
39 | batch_max_length: max length of word in the batch. Default: 25
40 |
41 | output:
42 | word_index: word index list for CTCLoss. [batch_size, batch_max_length]
43 | word_length: length of each word. [batch_size]
44 | """
45 | word_length = [len(word) for word in word_string]
46 |
47 | # The index used for padding (=[PAD]) would not affect the CTC loss calculation.
48 | word_index = torch.LongTensor(len(word_string), batch_max_length).fill_(
49 | self.dict["[PAD]"]
50 | )
51 |
52 | for i, word in enumerate(word_string):
53 | word = list(word)
54 | word_idx = [
55 | self.dict[char] if char in self.dict else self.dict["[UNK]"]
56 | for char in word
57 | ]
58 | word_index[i][: len(word_idx)] = torch.LongTensor(word_idx)
59 |
60 | return (word_index.to(device), torch.IntTensor(word_length).to(device))
61 |
62 | def decode(self, word_index, word_length):
63 | """convert word_index into word_string"""
64 | word_string = []
65 | for idx, length in enumerate(word_length):
66 | word_idx = word_index[idx, :]
67 |
68 | char_list = []
69 | for i in range(length):
70 | # removing repeated characters and blank.
71 | if word_idx[i] != 0 and not (i > 0 and word_idx[i - 1] == word_idx[i]):
72 | char_list.append(self.character[word_idx[i]])
73 |
74 | word = "".join(char_list)
75 | word_string.append(word)
76 | return word_string
77 |
78 |
79 | class AttnLabelConverter(object):
80 | """Convert between text-label and text-index"""
81 |
82 | def __init__(self, character):
83 | # character (str): set of the possible characters.
84 | # [SOS] (start-of-sentence token) and [EOS] (end-of-sentence token) for the attention decoder.
85 | list_special_token = [
86 | "[UNK]",
87 | "[PAD]",
88 | "[SOS]",
89 | "[EOS]",
90 | " ",
91 | ] # [UNK] for unknown character, ' ' for space.
92 | list_character = list(character)
93 | self.character = list_special_token + list_character
94 |
95 | self.dict = {}
96 | for i, char in enumerate(self.character):
97 | # print(i, char)
98 | self.dict[char] = i
99 |
100 | print(f"# of tokens and characters: {len(self.character)}")
101 |
102 | def encode(self, word_string, batch_max_length=25):
103 | """convert word_list (string) into word_index.
104 | input:
105 | word_string: word labels of each image. [batch_size]
106 | batch_max_length: max length of word in the batch. Default: 25
107 |
108 | output:
109 | word_index : the input of attention decoder. [batch_size x (max_length+2)] +1 for [SOS] token and +1 for [EOS] token.
110 | word_length : the length of output of attention decoder, which count [EOS] token also. [batch_size]
111 | """
112 | word_length = [
113 | len(word) + 1 for word in word_string
114 | ] # +1 for [EOS] at end of sentence.
115 | batch_max_length += 1
116 |
117 | # additional batch_max_length + 1 for [SOS] at first step.
118 | word_index = torch.LongTensor(len(word_string), batch_max_length + 1).fill_(
119 | self.dict["[PAD]"]
120 | )
121 | word_index[:, 0] = self.dict["[SOS]"]
122 |
123 | for i, word in enumerate(word_string):
124 | word = list(word)
125 | word.append("[EOS]")
126 | word_idx = [
127 | self.dict[char] if char in self.dict else self.dict["[UNK]"]
128 | for char in word
129 | ]
130 | word_index[i][1 : 1 + len(word_idx)] = torch.LongTensor(
131 | word_idx
132 | ) # word_index[:, 0] = [SOS] token
133 |
134 | return (word_index.to(device), torch.IntTensor(word_length).to(device))
135 |
136 | def decode(self, word_index, word_length):
137 | """convert word_index into word_string"""
138 | word_string = []
139 | for idx, length in enumerate(word_length):
140 | word_idx = word_index[idx, :length]
141 | word = "".join([self.character[i] for i in word_idx])
142 | word_string.append(word)
143 | return word_string
144 |
145 |
146 | class Averager(object):
147 | """Compute average for torch.Tensor, used for loss average."""
148 |
149 | def __init__(self):
150 | self.reset()
151 |
152 | def add(self, v):
153 | count = v.data.numel()
154 | v = v.data.sum()
155 | self.n_count += count
156 | self.sum += v
157 |
158 | def reset(self):
159 | self.n_count = 0
160 | self.sum = 0
161 |
162 | def val(self):
163 | res = 0
164 | if self.n_count != 0:
165 | res = self.sum / float(self.n_count)
166 | return res
167 |
168 |
169 | def adjust_learning_rate(optimizer, iteration, opt):
170 | """Decay the learning rate based on schedule"""
171 | lr = opt.lr
172 | # stepwise lr schedule
173 | for milestone in opt.schedule:
174 | lr *= (
175 | opt.lr_drop_rate if iteration >= (float(milestone) * opt.num_iter) else 1.0
176 | )
177 | for param_group in optimizer.param_groups:
178 | param_group["lr"] = lr
179 |
180 |
181 | def tensor2im(image_tensor, imtype=np.uint8):
182 | image_numpy = image_tensor.cpu().float().numpy()
183 | if image_numpy.shape[0] == 1:
184 | image_numpy = np.tile(image_numpy, (3, 1, 1))
185 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
186 | return image_numpy.astype(imtype)
187 |
188 |
189 | def save_image(image_numpy, image_path):
190 | image_pil = PIL.Image.fromarray(image_numpy)
191 | image_pil.save(image_path)
192 |
193 | def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
194 | """Crop text region with their bounding box.
195 |
196 | Args:
197 | src_img (np.array): The original image.
198 | box (list[float | int]): Points of quadrangle.
199 | long_edge_pad_ratio (float): Box pad ratio for long edge
200 | corresponding to font size.
201 | short_edge_pad_ratio (float): Box pad ratio for short edge
202 | corresponding to font size.
203 | """
    assert all(isinstance(p, (float, int)) for p in box)
205 | assert len(box) == 8
206 | assert 0. <= long_edge_pad_ratio < 1.0
207 | assert 0. <= short_edge_pad_ratio < 1.0
208 |
209 | h, w = src_img.shape[:2]
210 | points_x = np.clip(np.array(box[0::2]), 0, w)
211 | points_y = np.clip(np.array(box[1::2]), 0, h)
212 |
213 | box_width = np.max(points_x) - np.min(points_x)
214 | box_height = np.max(points_y) - np.min(points_y)
215 | font_size = min(box_height, box_width)
216 |
217 | if box_height < box_width:
218 | horizontal_pad = long_edge_pad_ratio * font_size
219 | vertical_pad = short_edge_pad_ratio * font_size
220 | else:
221 | horizontal_pad = short_edge_pad_ratio * font_size
222 | vertical_pad = long_edge_pad_ratio * font_size
223 |
224 | left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
225 | top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
226 | right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
227 | bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)
228 |
229 | dst_img = src_img[top:bottom, left:right]
230 |
231 | return dst_img
232 |
def read_txt(path):
    """Read a text file and count the occurrences of each character."""
    lines = []
    char_dict = {}
    with open(path) as f:
        for line in f:
            lines.append(line.strip("\n"))
    for text in lines:
        for char in text:
            if char_dict.get(char, None) is None:
                char_dict[char] = 1
            else:
                char_dict[char] += 1
    return char_dict

def dict_total(path=".txt", path_a="_all.txt"):
    """Compute per-character accuracy and group-averaged accuracy by character frequency."""
    root = "/share/home/ztl/CIL_MLSTR/exp/base/"
    language = "Japanese"
    log_path = language + "test.txt"
    char_list = []
    true_char = read_txt(root + language + path)
    total_char = read_txt(root + language + path_a)
    for key, value in total_char.items():
        acc = true_char.get(key, 0) / total_char[key]
        char_list.append([key, value, acc])
        print([key, value, acc])
    pred_list = sorted(char_list, key=lambda entry: entry[1])
    start_i = 0
    acc = 0
    for i, entry in enumerate(pred_list):
        if i != 0 and entry[1] != pred_list[i - 1][1]:
            avg = acc / (i - start_i)
            str_log = "avg {} char is {:.2f} total {}\n".format(pred_list[i - 1][1], avg, i - start_i)
            print(str_log)
            with open(root + log_path, "a") as log:
                log.write(str_log)
            start_i = i
            acc = 0
        acc += entry[2]
    with open(root + log_path, "a") as log:
        for line in pred_list:
            log.write(str(line) + "\n")
# dict_total()
279 |
280 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import six
4 | import random
5 |
6 | from natsort import natsorted
import PIL.Image
import PIL.ImageFilter
8 | import lmdb
9 | import torch
10 | from torch.utils.data import Dataset, ConcatDataset
11 | import torchvision.transforms as transforms
12 |
13 | from data.transform import CVGeometry, CVDeterioration, CVColorJitter
14 |
15 | def hierarchical_dataset(root, opt, select_data="/", data_type="label", mode="train"):
16 | """select_data='/' contains all sub-directory of root directory"""
17 | dataset_list = []
18 | dataset_log = f"dataset_root: {root}\t dataset: {select_data}"
19 | print(dataset_log)
20 | dataset_log += "\n"
21 | for dirpath, dirnames, filenames in os.walk(root + "/"):
22 | if not dirnames:
23 | select_flag = False
24 | for selected_d in select_data:
25 | if selected_d in dirpath:
26 | select_flag = True
27 | break
28 |
29 | if select_flag:
30 | # if data_type == "label":
31 | dataset = LmdbDataset(dirpath, opt, mode=mode)
32 | # else:
33 | # dataset = LmdbDataset_unlabel(dirpath, opt)
34 | sub_dataset_log = f"sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}"
35 | print(sub_dataset_log)
36 | dataset_log += f"{sub_dataset_log}\n"
37 | dataset_list.append(dataset)
38 |
39 | concatenated_dataset = ConcatDataset(dataset_list)
40 |
41 | return concatenated_dataset, dataset_log
42 |
43 |
44 | class LmdbDataset(Dataset):
45 | def __init__(self, root, opt, mode="train"):
46 |
47 | self.root = root
48 | skip = 0
49 | self.opt = opt
50 | self.mode = mode
51 | self.env = lmdb.open(
52 | root,
53 | max_readers=32,
54 | readonly=True,
55 | lock=False,
56 | readahead=False,
57 | meminit=False,
58 | )
59 | if not self.env:
60 | print("cannot open lmdb from %s" % (root))
61 | sys.exit(0)
62 |
63 | with self.env.begin(write=False) as txn:
64 | self.nSamples = int(txn.get("num-samples".encode()))
65 | print(self.nSamples)
66 | self.filtered_index_list = []
67 | for index in range(self.nSamples):
68 | index += 1 # lmdb starts with 1
69 | label_key = "label-%09d".encode() % index
70 | # print(label_key)
71 | if txn.get(label_key)==None:
72 | skip+=1
73 | print("skip --- {}\n".format(skip))
74 | continue
75 | label = txn.get(label_key).decode("utf-8")
76 | # print(label)
77 |
78 | # length filtering
79 | length_of_label = len(label)
80 | if length_of_label > opt.batch_max_length:
81 | continue
82 |
83 | self.filtered_index_list.append(index)
84 |
85 | self.nSamples = len(self.filtered_index_list)
86 |
87 | def __len__(self):
88 | return self.nSamples
89 |
90 | def __getitem__(self, index):
91 | assert index <= len(self), "index range error"
92 | index = self.filtered_index_list[index]
93 |
94 | with self.env.begin(write=False) as txn:
95 | label_key = "label-%09d".encode() % index
96 | label = txn.get(label_key).decode("utf-8")
97 | img_key = "image-%09d".encode() % index
98 | imgbuf = txn.get(img_key)
99 | buf = six.BytesIO()
100 | buf.write(imgbuf)
101 | buf.seek(0)
102 |
103 | try:
104 | img = PIL.Image.open(buf).convert("RGBA")
105 |
106 | except IOError:
107 | print(f"Corrupted image for {index}")
108 | # make dummy image and dummy label for corrupted image.
109 | img = PIL.Image.new("RGBA", (self.opt.imgW, self.opt.imgH))
110 | label = "[dummy_label]"
111 |
112 | return (img, label)
113 |
114 |
115 | class RawDataset(Dataset):
116 | def __init__(self, root, opt):
117 | self.opt = opt
118 | self.image_path_list = []
119 | for dirpath, dirnames, filenames in os.walk(root):
120 | for name in filenames:
121 | _, ext = os.path.splitext(name)
122 | ext = ext.lower()
123 | if ext == ".jpg" or ext == ".jpeg" or ext == ".png":
124 | self.image_path_list.append(os.path.join(dirpath, name))
125 |
126 | self.image_path_list = natsorted(self.image_path_list)
127 | self.nSamples = len(self.image_path_list)
128 |
129 | def __len__(self):
130 | return self.nSamples
131 |
132 | def __getitem__(self, index):
133 |
134 | try:
135 | img = PIL.Image.open(self.image_path_list[index]).convert("RGBA")
136 |
137 | except IOError:
138 | print(f"Corrupted image for {index}")
139 | # make dummy image and dummy label for corrupted image.
140 | img = PIL.Image.new("RGBA", (self.opt.imgW, self.opt.imgH))
141 |
142 | return (img, self.image_path_list[index])
143 |
144 | class AlignCollate2(object):
145 | def __init__(self, opt, mode="train"):
146 | self.opt = opt
147 | self.mode = mode
148 |
149 | if opt.Aug == "None" or mode != "train":
150 | self.transform = ResizeNormalize((opt.imgW, opt.imgH))
151 | elif opt.Aug == "ABINet" and mode == "train":
152 | self.transform = transforms.Compose([
153 | CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
154 | CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
155 | CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
156 | transforms.Resize(
157 | (self.opt.imgH, self.opt.imgW), interpolation=PIL.Image.BICUBIC
158 | ),
159 | transforms.ToTensor(),
160 | ])
161 | else:
162 | self.transform = Text_augment(opt)
163 |
164 | def __call__(self, batch):
165 | b_info, index = zip(*batch)
166 | images, labels = zip(*b_info)
167 | image_tensors = [self.transform(image) for image in images]
168 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
169 |
170 | return image_tensors, labels , index
171 |
172 | class AlignCollate(object):
173 | def __init__(self, opt, mode="train"):
174 | self.opt = opt
175 | self.mode = mode
176 |
177 | if opt.Aug == "None" or mode != "train":
178 | self.transform = ResizeNormalize((opt.imgW, opt.imgH))
179 | elif opt.Aug == "ABINet" and mode == "train":
180 | self.transform = transforms.Compose([
181 | CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
182 | CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
183 | CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
184 | transforms.Resize(
185 | (self.opt.imgH, self.opt.imgW), interpolation=PIL.Image.BICUBIC
186 | ),
187 | transforms.ToTensor(),
188 | ])
189 | else:
190 | self.transform = Text_augment(opt)
191 |
192 | def __call__(self, batch):
193 | images, labels = zip(*batch)
194 | image_tensors = [self.transform(image) for image in images]
195 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
196 |
197 | return image_tensors, labels
198 |
199 | class GaussianBlur(object):
200 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
201 |
202 | def __init__(self, sigma=[0.1, 2.0]):
203 | self.sigma = sigma
204 |
205 | def __call__(self, image):
206 | sigma = random.uniform(self.sigma[0], self.sigma[1])
207 | image = image.filter(PIL.ImageFilter.GaussianBlur(radius=sigma))
208 | return image
209 |
210 |
211 | class RandomCrop(object):
212 | """RandomCrop,
213 | RandomResizedCrop of PyTorch 1.6 and torchvision 0.7.0 work weird with scale 0.90-1.0.
214 | i.e. you can not always make 90%~100% cropped image scale 0.90-1.0, you will get central cropped image instead.
215 | so we made RandomCrop (keeping aspect ratio version) then use Resize.
216 | """
217 |
218 | def __init__(self, scale=[1, 1]):
219 | self.scale = scale
220 |
221 | def __call__(self, image):
222 | width, height = image.size
223 | crop_ratio = random.uniform(self.scale[0], self.scale[1])
224 | crop_width = int(width * crop_ratio)
225 | crop_height = int(height * crop_ratio)
226 |
227 | x_start = random.randint(0, width - crop_width)
228 | y_start = random.randint(0, height - crop_height)
229 | image_crop = image.crop(
230 | (x_start, y_start, x_start + crop_width, y_start + crop_height)
231 | )
232 | return image_crop
233 |
234 |
235 | class ResizeNormalize(object):
236 | def __init__(self, size, interpolation=PIL.Image.BICUBIC):
237 | # CAUTION: it should be (width, height). different from size of transforms.Resize (height, width)
238 | self.size = size
239 | self.interpolation = interpolation
240 | self.toTensor = transforms.ToTensor()
241 |
242 | def __call__(self, image):
243 | image = image.resize(self.size, self.interpolation)
244 | image = self.toTensor(image)
245 | image.sub_(0.5).div_(0.5)
246 | return image
247 |
248 |
249 | class Text_augment(object):
250 | """Augmentation for Text recognition"""
251 |
252 | def __init__(self, opt):
253 | self.opt = opt
254 | augmentation = []
255 | aug_list = self.opt.Aug.split("-")
256 | for aug in aug_list:
257 | if aug.startswith("Blur"):
258 | maximum = float(aug.strip("Blur"))
259 | augmentation.append(
260 | transforms.RandomApply([GaussianBlur([0.1, maximum])], p=0.5)
261 | )
262 |
263 | if aug.startswith("Crop"):
264 | crop_scale = float(aug.strip("Crop")) / 100
265 | augmentation.append(RandomCrop(scale=(crop_scale, 1.0)))
266 |
267 | if aug.startswith("Rot"):
268 | degree = int(aug.strip("Rot"))
269 | augmentation.append(
270 | transforms.RandomRotation(
271 | degree, resample=PIL.Image.BICUBIC, expand=True
272 | )
273 | )
274 |
275 | augmentation.append(
276 | transforms.Resize(
277 | (self.opt.imgH, self.opt.imgW), interpolation=PIL.Image.BICUBIC
278 | )
279 | )
280 | augmentation.append(transforms.ToTensor())
281 | self.Augment = transforms.Compose(augmentation)
282 | print("Use Text_augment", augmentation)
283 |
284 | def __call__(self, image):
285 | image = self.Augment(image)
286 | image.sub_(0.5).div_(0.5)
287 |
288 | return image
289 |
290 |
291 | class MoCo_augment(object):
292 | """Take two random crops of one image as the query and key."""
293 |
294 | def __init__(self, opt):
295 | self.opt = opt
296 |
297 | # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
298 | augmentation = [
299 | transforms.RandomResizedCrop(
300 | (opt.imgH, opt.imgW), scale=(0.2, 1.0), interpolation=PIL.Image.BICUBIC
301 | ),
302 | transforms.RandomGrayscale(p=0.2),
303 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
304 | transforms.RandomHorizontalFlip(),
305 | transforms.ToTensor(),
306 | ]
307 |
308 | self.Augment = transforms.Compose(augmentation)
309 | print("Use MoCo_augment", augmentation)
310 |
311 | def __call__(self, x):
312 | q = self.Augment(x)
313 | k = self.Augment(x)
314 | q.sub_(0.5).div_(0.5)
315 | k.sub_(0.5).div_(0.5)
316 |
317 | return [q, k]
318 |
--------------------------------------------------------------------------------
/tools/crop_by_word.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import glob
4 | import os
5 | import os.path as osp
import math
import re
7 |
8 | import mmcv
9 | import numpy as np
10 | from shapely.geometry import Polygon
11 |
12 | def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
13 | """Crop text region with their bounding box.
14 |
15 | Args:
16 | src_img (np.array): The original image.
17 | box (list[float | int]): Points of quadrangle.
18 | long_edge_pad_ratio (float): Box pad ratio for long edge
19 | corresponding to font size.
20 | short_edge_pad_ratio (float): Box pad ratio for short edge
21 | corresponding to font size.
22 | """
23 | # assert utils.is_type_list(box, (float, int))
24 | assert len(box) == 8
25 | assert 0. <= long_edge_pad_ratio < 1.0
26 | assert 0. <= short_edge_pad_ratio < 1.0
27 |
28 | h, w = src_img.shape[:2]
29 | points_x = np.clip(np.array(box[0::2]), 0, w)
30 | points_y = np.clip(np.array(box[1::2]), 0, h)
31 |
32 | box_width = np.max(points_x) - np.min(points_x)
33 | box_height = np.max(points_y) - np.min(points_y)
34 | font_size = min(box_height, box_width)
35 |
36 | if box_height < box_width:
37 | horizontal_pad = long_edge_pad_ratio * font_size
38 | vertical_pad = short_edge_pad_ratio * font_size
39 | else:
40 | horizontal_pad = short_edge_pad_ratio * font_size
41 | vertical_pad = long_edge_pad_ratio * font_size
42 |
43 | left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
44 | top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
45 | right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
46 | bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)
47 |
48 | dst_img = src_img[top:bottom, left:right]
49 |
50 | return dst_img
51 |
52 | def test_crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
53 | pts1 = np.float32([[wordBB[0][0], wordBB[1][0]],
54 | [wordBB[0][3], wordBB[1][3]],
55 | [wordBB[0][1], wordBB[1][1]],
56 | [wordBB[0][2], wordBB[1][2]]])
57 | height = math.sqrt((wordBB[0][0] - wordBB[0][3]) ** 2 + (wordBB[1][0] - wordBB[1][3]) ** 2)
58 | width = math.sqrt((wordBB[0][0] - wordBB[0][1]) ** 2 + (wordBB[1][0] - wordBB[1][1]) ** 2)
59 |
60 | # Coord validation check
61 | if (height * width) <= 0:
62 | err_log = 'empty file : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
63 | err_file.write(err_log)
64 | # print(err_log)
        return None
66 | elif (height * width) > (img_height * img_width):
67 | err_log = 'too big box : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
68 | err_file.write(err_log)
69 | # print(err_log)
        return None
71 | else:
72 | valid = True
73 | for i in range(2):
74 | for j in range(4):
75 | if wordBB[i][j] < 0 or wordBB[i][j] > img.shape[1 - i]:
76 | valid = False
77 | break
78 | if not valid:
79 | break
80 | if not valid:
81 | err_log = 'invalid coord : {}\t{}\t{}\t{}\t{}\n'.format(
82 | image_name, txt[word_indx], wordBB, (width, height), (img_width, img_height))
83 | err_file.write(err_log)
84 | # print(err_log)
            return None
86 |
87 | pts2 = np.float32([[0, 0],
88 | [0, height],
89 | [width, 0],
90 | [width, height]])
91 |
    x_min = int(round(min(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3])))
    x_max = int(round(max(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3])))
    y_min = int(round(min(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3])))
    y_max = int(round(max(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3])))
96 | # print(x_min, x_max, y_min, y_max)
97 | # print(img.shape)
98 | # assert 1<0
99 | if len(img.shape) == 3:
100 | img_cropped = img[y_min:y_max:1, x_min:x_max:1, :]
101 | else:
        img_cropped = img[y_min:y_max:1, x_min:x_max:1]
    return img_cropped
103 |
104 | def list_to_file(filename, lines):
105 | """Write a list of strings to a text file.
106 |
107 | Args:
108 | filename (str): The output filename. It will be created/overwritten.
109 | lines (list(str)): Data to be written.
110 | """
111 | mmcv.mkdir_or_exist(os.path.dirname(filename))
112 | with open(filename, 'w', encoding='utf-8') as fw:
113 | for line in lines:
114 | fw.write(f'{line}\n')
115 |
116 | def list_from_file(filename, encoding='utf-8'):
117 | """Load a text file and parse the content as a list of strings. The
118 | trailing "\\r" and "\\n" of each line will be removed.
119 |
120 | Note:
121 | This will be replaced by mmcv's version after it supports encoding.
122 |
123 | Args:
124 | filename (str): Filename.
125 | encoding (str): Encoding used to open the file. Default utf-8.
126 |
127 | Returns:
128 | list[str]: A list of strings.
129 | """
130 | item_list = []
131 | with open(filename, 'r', encoding=encoding) as f:
132 | for line in f:
133 | item_list.append(line.rstrip('\n\r'))
134 | return item_list
135 |
136 |
137 | def load_img_info(file):
138 | """Load the information of one image.
139 |
140 | Args:
141 | files(tuple): The tuple of (img_file, groundtruth_file)
142 | dataset(str): Dataset name, icdar2015 or icdar2017
143 |
144 | Returns:
145 | img_info(dict): The dict of the img and annotation information
146 | """
147 | # assert isinstance(files, tuple)
148 | # assert isinstance(dataset, str)
149 | # assert dataset
150 |
151 | # img_file, gt_file = files
152 | # read imgs with ignoring orientations
153 | # img = mmcv.imread(img_file, 'unchanged')
154 | gt_file = file[1]
155 | img_file = file[0]
156 | img = mmcv.imread(img_file, 'unchanged')
157 |
158 | split_name = osp.basename(osp.dirname(img_file))
159 | img_info = dict(
160 | # remove img_prefix for filename
161 | file_name=img_file,
162 | height=img.shape[0],
163 | width=img.shape[1],)
164 | # img_file
165 | # print("gt_file{}".format(gt_file))
166 | gt_list = list_from_file(gt_file)
167 |
168 | anno_info = []
169 | # img_info = {}
170 | for line in gt_list:
171 | # each line has one ploygen (4 vetices), and others.
172 | # e.g., 695,885,866,888,867,1146,696,1143,Latin,9
173 | line = line.strip()
174 | strs = line.split(',')
175 | category_id = 1
176 | xy = [float(x) for x in strs[0:8]]
177 | coordinates = np.array(xy).reshape(-1, 2)
178 | polygon = Polygon(coordinates)
179 |
180 | area = polygon.area
181 | # convert to COCO style XYWH format
182 | min_x, min_y, max_x, max_y = polygon.bounds
183 | # bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
184 | bbox = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
185 | # bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
186 | anno = dict(word=strs[9], bbox=bbox)
187 | anno_info.append(anno)
188 | # print(anno)
189 | img_info.update(anno_info=anno_info)
190 | # print(img_info)
191 | return img_info
192 |
193 |
194 | def collect_files(img_dir, gt_dir):
195 | """Collect all images and their corresponding groundtruth files.
196 |
197 | Args:
198 | img_dir(str): The image directory
199 | gt_dir(str): The groundtruth directory
200 | split(str): The split of dataset. Namely: training or test
201 | Returns:
202 | files(list): The list of tuples (img_file, groundtruth_file)
203 | """
204 | assert isinstance(img_dir, str)
205 | assert img_dir
206 | assert isinstance(gt_dir, str)
207 | assert gt_dir
208 |
209 | # note that we handle png and jpg only. Pls convert others such as gif to
210 | # jpg or png offline
211 | suffixes = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG']
212 | # suffixes = ['.png']
213 |
214 | imgs_list = []
215 | for suffix in suffixes:
216 | imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix)))
217 |
218 | imgs_list = sorted(imgs_list)
219 | ann_list = sorted(
220 | [osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)])
221 |
222 | files = [(img_file, gt_file)
223 | for (img_file, gt_file) in zip(imgs_list, ann_list)]
224 | assert len(files), f'No images found in {img_dir}'
225 | print(f'Loaded {len(files)} images from {img_dir}')
226 |
227 | return files
228 |
229 |
230 | def collect_annotations(files, nproc=1):
231 | """Collect the annotation information.
232 |
233 | Args:
234 | files(list): The list of tuples (image_file, groundtruth_file)
235 | nproc(int): The number of process to collect annotations
236 | Returns:
237 | images(list): The list of image information dicts
238 | """
239 | assert isinstance(files, list)
240 | assert isinstance(nproc, int)
241 |
242 | if nproc > 1:
243 | images = mmcv.track_parallel_progress(
244 | load_img_info, files, nproc=nproc)
245 | else:
246 | images = mmcv.track_progress(load_img_info, files)
247 |
248 | return images
249 |
250 |
251 | def generate_ann(root_path, image_infos, out_dir):
252 | """Generate cropped annotations and label txt file.
253 |
254 | Args:
255 | root_path(str): The relative path of the totaltext file
256 | split(str): The split of dataset. Namely: training or test
257 | image_infos(list[dict]): A list of dicts of the img and
258 | annotation information
259 | """
260 |
261 | dst_image_root = osp.join(out_dir, 'imgs')
262 | dst_label_file = osp.join(out_dir, 'label.txt')
263 | os.makedirs(dst_image_root, exist_ok=True)
264 |
265 | lines = []
266 | for image_info in image_infos:
267 | index = 1
268 | src_img_path = image_info['file_name']
269 | image = mmcv.imread(src_img_path)
270 | # src_img_root = osp.splitext(image_info['file_name'])[0].split('/')[1]
271 | src_img_root = image_info['file_name'].split('/')[-1].split(".")[0]
272 |
273 | for anno in image_info['anno_info']:
274 | word = anno['word']
275 | dst_img = crop_img(image, anno['bbox'])
276 |
277 | # Skip invalid annotations
278 | if min(dst_img.shape) == 0:
279 | continue
280 |
281 | dst_img_name = f'{src_img_root}_{index}.png'
282 | index += 1
283 | dst_img_path = osp.join(dst_image_root, dst_img_name)
284 | mmcv.imwrite(dst_img, dst_img_path)
285 | lines.append(f'{osp.basename(dst_image_root)}/{dst_img_name} '
286 | f'{word}')
287 | # print(lines)
288 | # print("\n")
289 | list_to_file(dst_label_file, lines)
290 |
291 |
292 | def parse_args():
293 | parser = argparse.ArgumentParser(
294 | description='Convert SynthMLT annotations to COCO format')
295 | parser.add_argument('--root_path', help='SynthMLT root path')
296 | parser.add_argument('--lan', default="Hindi", help='languang for data')
297 | parser.add_argument('--out_dir', default="test",help='output path')
298 | # parser.add_argument(
299 | # '--split-list',
300 | # nargs='+',
301 | # help='a list of splits. e.g., "--split_list training test"')
302 |
303 | parser.add_argument(
304 | '--nproc', default=10, type=int, help='number of process')
305 | args = parser.parse_args()
306 | return args
307 |
308 | def unzip(root_path, lan):
309 | # root_path = "../dataset/SynthMLT/"
310 | img_path = "{}{}".format(root_path, lan)
311 | gt_path = "{}{}_gt".format(root_path, lan)
312 |
313 | if not os.path.exists(img_path):
314 | # os.system(f"rm -r {img_path}")
315 | cmd = "unzip -d {} {}.zip".format(img_path, img_path)
316 | os.system(cmd)
317 |
318 |
319 | if not os.path.exists(gt_path):
320 | # os.system(f"rm -r {gt_path}")
321 | cmd = "unzip -d {} {}.zip".format(gt_path, gt_path)
322 | os.system(cmd)
323 |
324 | def main():
325 | args = parse_args()
326 | unzip(args.root_path, args.lan)
327 | # root_path = args.root_path + args.lan
328 | # out_dir = args.root_path + args.out_dir if args.out_dir else args.root_path
329 | out_dir = args.out_dir
330 | root_path = args.root_path + args.lan
331 | mmcv.mkdir_or_exist(out_dir)
332 | out_dir = osp.join(out_dir, args.lan)
333 | print("save to {}\n".format(out_dir))
334 |
335 |
336 |
337 | # root_path = "../dataset/SynthMLT/"
338 | img_dir = "{}/{}/".format(root_path, args.lan)
339 | gt_dir = "{}_gt/{}/".format(root_path, args.lan)
340 |
341 | print(f'Converting SynthMLT to TXT\n')
342 | # print("img dir is {}\n".format(img_dir))
343 | # print("gt dir is {}\n".format(gt_dir))
344 | with mmcv.Timer( print_tmpl='It takes {}s to convert txt annotation'):
345 | files = collect_files(img_dir, gt_dir)
346 | # print("--------------------start------------\n{}".format(files))
347 | image_infos = collect_annotations(files, nproc=args.nproc)
348 | generate_ann(root_path, image_infos,out_dir)
349 | print(out_dir)
350 |
351 |
352 | if __name__ == '__main__':
353 | main()
354 |
--------------------------------------------------------------------------------
/data/data_manage.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy.random
3 | import torch
4 | from torch.utils.data import Dataset, ConcatDataset, Subset
5 | from data.dataset import AlignCollate, LmdbDataset, AlignCollate2, hierarchical_dataset
6 |
7 |
8 | class Dataset_Manager(object):
9 | def __init__(self,opt):
10 | self.data_list = []
11 | self.data_loader_list = []
12 | self.dataloader_iter_list = []
13 | self.select_data = None
14 | self.opt = opt
15 |
16 | def get_dataset(self, taski, memory="random",index_list=None):
17 | self.data_loader_list = []
18 | self.dataloader_iter_list = []
19 | memory_num = self.opt.memory_num
20 |
21 | dataset = self.create_dataset(data_list=self.select_data,taski=taski)
22 |
23 | if memory != None and self.opt.il=="mrn":
24 | # curr: num/(taski-1) mem: num/(taski-1)
25 | index_current = numpy.random.choice(range(len(dataset)),int(self.opt.memory_num/(taski)),replace=False)
26 | split_dataset = Subset(dataset,index_current.tolist())
27 | memory_data,index_list = self.rehearsal_memory(taski, random=False,total_num=self.opt.memory_num,index_array=index_list)
28 | self.create_dataloader_mix(IndexConcatDataset([memory_data,split_dataset]),self.opt.batch_size)
29 | print("taski is {} current dataset chose {}\n now dataset chose {}".format(taski,int(self.opt.memory_num/taski),len(memory_data)))
30 | elif memory == "test_ch":
31 | # curr: total mem: num/(taski-1) (repeat)
32 | # index_current = numpy.random.choice(range(len(dataset)),int(self.opt.memory_num/taski),replace=False)
33 | # split_dataset = Subset(dataset,index_current.tolist())
34 | memory_data,index_list = self.rehearsal_memory(taski, random=False,total_num=self.opt.memory_num,index_array=index_list,repeat=True)
35 | self.create_dataloader_mix(IndexConcatDataset([memory_data,dataset]),self.opt.batch_size)
36 | print("taski is {} current dataset chose {}\n now dataset chose {}".format(taski,int(self.opt.memory_num/taski),len(memory_data)))
37 | elif memory == "large":
38 | # curr: num mem: num
39 | index_current = numpy.random.choice(range(len(dataset)), memory_num, replace=False)
40 | split_dataset = Subset(dataset, index_current.tolist())
41 | memory_data, index_list = self.rehearsal_memory(taski, random=False, total_num=memory_num*taski, index_array=index_list)
42 | self.create_dataloader_mix(IndexConcatDataset([memory_data, split_dataset]), self.opt.batch_size)
43 | print("taski is {} current dataset chose {}\n now dataset chose {}".format(taski, int(memory_num),
44 | len(memory_data)))
45 | elif memory == "total":
46 | # curr : total mem : total(repeat)
47 | total_data_list = []
48 | total_data_list.append(dataset)
49 | for i in range(taski):
50 | dataset = self.create_dataset(data_list=self.select_data, taski=i)
51 | total_data_list.append(dataset)
52 | self.create_dataloader_mix(IndexConcatDataset(total_data_list), self.opt.batch_size)
            print("taski is {} total data list has {} datasets\n current dataset size {}".format(taski, len(total_data_list),
                                                                                                  len(dataset)))
55 | elif memory != None:
56 | memory_data,index_list = self.rehearsal_memory(taski, random=False,total_num=memory_num,index_array=index_list)
57 | self.create_dataloader(memory_data,(self.opt.batch_size)//2)
58 | self.create_dataloader(dataset,(self.opt.batch_size)//2)
59 | else:
60 | self.create_dataloader(dataset)
61 | return index_list
62 |
63 | def joint_start(
64 | self, opt, select_data, log, taski,total_task):
65 | self.opt = opt
66 | self.select_data = select_data
67 | dashed_line = "-" * 80
68 | print(dashed_line)
69 | log.write(dashed_line + "\n")
70 |
71 | dataset = self.create_dataset(data_list=self.select_data, taski=taski)
72 | if opt.il == "joint_mix":
73 | self.data_list.append(dataset)
74 | if taski == total_task-1:
75 | self.create_dataloader(ConcatDataset(self.data_list), int(self.opt.batch_size))
76 | elif opt.il == "joint_loader":
77 | self.create_dataloader(dataset, int(self.opt.batch_size // total_task))
78 |
79 |
80 | def init_start(
81 | self, opt, select_data, log, taski):
82 | self.opt = opt
83 | self.select_data = select_data
84 | self.data_loader_list = []
85 | self.dataloader_iter_list = []
86 | dashed_line = "-" * 80
87 | print(dashed_line)
88 | log.write(dashed_line + "\n")
89 | print(
90 | f"select_data: {select_data}\n"
91 | )
92 | log.write(
93 | f"select_data: {select_data}\n"
94 | )
95 | self.get_dataset(taski, memory=None)
96 |
97 | def rehearsal_memory(self,taski, random=False,total_num=2000,index_array=None,repeat=False):
98 | data_list = []
99 | select_data = self.select_data
100 | num_i = int(total_num/(taski))
101 | print("memory size is {}\n".format(num_i))
102 | for i in range(taski):
103 | dataset = self.create_dataset(data_list=select_data,taski=i,repeat=repeat)
104 | if random:
105 | index_list = numpy.random.choice(range(len(dataset)),num_i,replace=repeat)
106 | # print(random)
107 | else:
108 | index_list = index_array[i]
109 | split_dataset = Subset(dataset,index_list.tolist())
110 | data_list.append(split_dataset)
111 | return ConcatDataset(data_list), index_array
112 |
113 | def rehearsal_prev_model(self,taski,):
114 | select_data = self.select_data
115 | dataset = self.create_dataset(data_list=select_data,taski=taski-1,repeat=False)
116 | data_loader = torch.utils.data.DataLoader(
117 | dataset,
118 | batch_size=self.opt.batch_size,
119 | shuffle=False,
120 | num_workers=int(self.opt.workers),
121 | collate_fn=AlignCollate(self.opt),
122 | pin_memory=False,
123 | drop_last=False,
124 | )
125 | return data_loader,len(dataset)
126 |
127 | def create_dataset(self, data_list="/", taski=0, mode="train", repeat=True):
128 | """select_data is list for all dataset"""
129 | dataset_list = []
130 | for data_root in data_list:
131 | # print(dataset_log)
132 | # dataset_log += "\n"
133 | dataset = LmdbDataset(data_root + "/" + self.opt.lan_list[taski], self.opt, mode=mode)
134 | dataset_log = f"num samples: {len(dataset)}"
135 | print(dataset_log)
136 |
137 | # for faster training, we multiply small datasets itself.
138 | if len(dataset) < 50000 and repeat:
139 | multiple_times = int(50000 / len(dataset))
140 | dataset_self_multiple = [dataset] * multiple_times
141 | dataset = ConcatDataset(dataset_self_multiple)
142 | dataset_list.append(dataset)
143 | # if memory !=None:
144 | # dataset_list.append(memory_dataset)
145 |
146 | return ConcatDataset(dataset_list)
147 |
148 | def create_dataloader(self,dataset,batch_size=None):
149 | data_loader = torch.utils.data.DataLoader(
150 | dataset,
151 | batch_size=self.opt.batch_size if batch_size==None else batch_size,
152 | shuffle=True,
153 | num_workers=int(self.opt.workers),
154 | collate_fn=AlignCollate(self.opt),
155 | pin_memory=False,
156 | drop_last=False,
157 | )
158 | self.data_loader_list.append(data_loader)
159 | self.dataloader_iter_list.append(iter(data_loader))
160 |
161 | def create_dataloader_mix(self,dataset,batch_size=None):
162 | data_loader = torch.utils.data.DataLoader(
163 | dataset,
164 | batch_size=self.opt.batch_size if batch_size==None else batch_size,
165 | shuffle=True,
166 | num_workers=int(self.opt.workers),
167 | collate_fn=AlignCollate2(self.opt),
168 | pin_memory=False,
169 | drop_last=False,
170 | )
171 | self.data_loader_list.append(data_loader)
172 | self.dataloader_iter_list.append(iter(data_loader))
173 |
174 | def get_batch2(self):
175 | balanced_batch_images = []
176 | balanced_batch_labels = []
177 | balanced_batch_index = []
178 |
179 | for i, data_loader_iter in enumerate(self.dataloader_iter_list):
180 | try:
                image, label, index = next(data_loader_iter)
182 | balanced_batch_images.append(image)
183 | balanced_batch_labels += label
184 | balanced_batch_index.append(index)
185 | except StopIteration:
186 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, label, index = next(self.dataloader_iter_list[i])
188 | balanced_batch_images.append(image)
189 | balanced_batch_labels += label
190 | balanced_batch_index.append(index)
191 | except ValueError:
192 | pass
193 |
194 | balanced_batch_images = torch.cat(balanced_batch_images, 0)
195 |
196 | return balanced_batch_images, balanced_batch_labels, balanced_batch_index
197 |
198 | def get_batch(self):
199 | balanced_batch_images = []
200 | balanced_batch_labels = []
201 |
202 | for i, data_loader_iter in enumerate(self.dataloader_iter_list):
203 | try:
                image, label = next(data_loader_iter)
205 | balanced_batch_images.append(image)
206 | balanced_batch_labels += label
207 | except StopIteration:
208 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, label = next(self.dataloader_iter_list[i])
210 | balanced_batch_images.append(image)
211 | balanced_batch_labels += label
212 | except ValueError:
213 | pass
214 |
215 | balanced_batch_images = torch.cat(balanced_batch_images, 0)
216 |
217 | return balanced_batch_images, balanced_batch_labels
218 |
219 | class Val_Dataset(object):
220 | def __init__(self,val_datas,opt):
221 | self.data_loader_list = []
222 | self.dataset_list = []
223 | self.current_data = val_datas[-1]
224 | self.val_datas = val_datas
225 | self.opt = opt
226 | self.AlignCollate_valid = AlignCollate(self.opt, mode="test")
227 |
228 |
229 | def create_dataset(self,val_data=None):
230 | if val_data == None:
231 | val_data = self.current_data
232 | valid_dataset, valid_dataset_log = hierarchical_dataset(
233 | root=val_data, opt=self.opt, mode="test"
234 | )
235 | # print(valid_dataset_log)
236 | print("-" * 80)
237 | valid_loader = torch.utils.data.DataLoader(
238 | valid_dataset,
239 | batch_size=self.opt.batch_size,
240 | shuffle=True, # 'True' to check training progress with validation function.
241 | num_workers=int(self.opt.workers),
242 | collate_fn=self.AlignCollate_valid,
243 | pin_memory=False,
244 | )
245 | return valid_loader
246 |
247 | def create_list_dataset(self,valid_datas=None):
248 | if valid_datas==None:
249 | valid_datas = self.val_datas
250 | concat_data = []
251 | for val_data in valid_datas:
252 | valid_dataset, valid_dataset_log = hierarchical_dataset(
253 | root=val_data, opt=self.opt, mode="test")
254 | if len(valid_dataset) > 700:
255 | index_current = numpy.random.choice(range(len(valid_dataset)),700,replace=False)
256 | valid_dataset = Subset(valid_dataset,index_current.tolist())
257 | concat_data.append(valid_dataset)
258 | print(valid_dataset_log)
259 | print("-" * 80)
260 | val_data = ConcatDataset(concat_data)
261 | valid_loader = torch.utils.data.DataLoader(
262 | val_data,
263 | batch_size=self.opt.batch_size,
264 | shuffle=True, # 'True' to check training progress with validation function.
265 | num_workers=int(self.opt.workers),
266 | collate_fn=self.AlignCollate_valid,
267 | pin_memory=False,
268 | )
269 | return valid_loader
270 |
271 |
272 | class IndexConcatDataset(ConcatDataset):
273 | def __getitem__(self, idx):
274 | if idx < 0:
275 | if -idx > len(self):
276 | raise ValueError("absolute value of index should not exceed dataset length")
277 | idx = len(self) + idx
278 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
279 | if dataset_idx == 0:
280 | sample_idx = idx
281 | else:
282 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
283 | return self.datasets[dataset_idx][sample_idx],dataset_idx
284 |
285 | class DummyDataset(Dataset):
286 | def __init__(self, images, labels):
287 | assert len(images) == len(labels), 'Data size error!'
288 | self.images = images
289 | self.labels = labels
290 |
291 | def __len__(self):
292 | return len(self.images)
293 |
294 | def __getitem__(self, idx):
295 | image = self.images[idx]
296 | label = self.labels[idx]
297 |
298 | return (image, label)
--------------------------------------------------------------------------------
/modules/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 | from timm.models.layers import DropPath
6 |
7 | import math
8 | from torch import Tensor
9 | from torch.nn import init
10 | from torch.nn.modules.utils import _pair
11 | from torchvision.ops.deform_conv import deform_conv2d as deform_conv2d_tv
12 | from modules.dm_router import GatingMlpBlock
13 |
14 |
15 | class Mlp(nn.Module):
16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
17 | super().__init__()
18 | out_features = out_features or in_features
19 | hidden_features = hidden_features or in_features
20 | self.fc1 = nn.Linear(in_features, hidden_features)
21 | self.act = act_layer()
22 | self.fc2 = nn.Linear(hidden_features, out_features)
23 | self.drop = nn.Dropout(drop)
24 |
25 | def forward(self, x):
26 | x = self.fc1(x)
27 | x = self.act(x)
28 | x = self.drop(x)
29 | x = self.fc2(x)
30 | x = self.drop(x)
31 | return x
32 |
33 |
34 | class CycleFC(nn.Module):
    """Cycle fully-connected layer: a 1x1 deformable convolution whose fixed,
    staircase offsets sweep the spatial area given by `kernel_size`."""
37 |
38 | def __init__(
39 | self,
40 | in_channels: int,
41 | out_channels: int,
42 | kernel_size, # re-defined kernel_size, represent the spatial area of staircase FC
43 | stride: int = 1,
44 | padding: int = 0,
45 | dilation: int = 1,
46 | groups: int = 1,
47 | bias: bool = True,
48 | ):
49 | super(CycleFC, self).__init__()
50 |
51 | if in_channels % groups != 0:
52 | raise ValueError('in_channels must be divisible by groups')
53 | if out_channels % groups != 0:
54 | raise ValueError('out_channels must be divisible by groups')
55 | if stride != 1:
56 | raise ValueError('stride must be 1')
57 | if padding != 0:
58 | raise ValueError('padding must be 0')
59 |
60 | self.in_channels = in_channels
61 | self.out_channels = out_channels
62 | self.kernel_size = kernel_size
63 | self.stride = _pair(stride)
64 | self.padding = _pair(padding)
65 | self.dilation = _pair(dilation)
66 | self.groups = groups
67 |
68 | self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, 1, 1)) # kernel size == 1
69 |
70 | if bias:
71 | self.bias = nn.Parameter(torch.empty(out_channels))
72 | else:
73 | self.register_parameter('bias', None)
74 | self.register_buffer('offset', self.gen_offset())
75 |
76 | self.reset_parameters()
77 |
78 | def reset_parameters(self) -> None:
79 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
80 |
81 | if self.bias is not None:
82 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
83 | bound = 1 / math.sqrt(fan_in)
84 | init.uniform_(self.bias, -bound, bound)
85 |
86 | def gen_offset(self):
87 | """
88 | offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
89 | out_height, out_width]): offsets to be applied for each position in the
90 | convolution kernel.
91 | """
92 | offset = torch.empty(1, self.in_channels*2, 1, 1)
93 | start_idx = (self.kernel_size[0] * self.kernel_size[1]) // 2
94 | assert self.kernel_size[0] == 1 or self.kernel_size[1] == 1, self.kernel_size
95 | for i in range(self.in_channels):
96 | if self.kernel_size[0] == 1:
97 | offset[0, 2 * i + 0, 0, 0] = 0
98 | offset[0, 2 * i + 1, 0, 0] = (i + start_idx) % self.kernel_size[1] - (self.kernel_size[1] // 2)
99 | else:
100 | offset[0, 2 * i + 0, 0, 0] = (i + start_idx) % self.kernel_size[0] - (self.kernel_size[0] // 2)
101 | offset[0, 2 * i + 1, 0, 0] = 0
102 | return offset
103 |
104 | def forward(self, input: Tensor) -> Tensor:
105 | """
106 | Args:
107 | input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
108 | """
109 | B, C, H, W = input.size()
110 | return deform_conv2d_tv(input, self.offset.expand(B, -1, H, W), self.weight, self.bias, stride=self.stride,
111 | padding=self.padding, dilation=self.dilation)
112 |
113 | def extra_repr(self) -> str:
114 | s = self.__class__.__name__ + '('
115 | s += '{in_channels}'
116 | s += ', {out_channels}'
117 | s += ', kernel_size={kernel_size}'
118 | s += ', stride={stride}'
119 | s += ', padding={padding}' if self.padding != (0, 0) else ''
120 | s += ', dilation={dilation}' if self.dilation != (1, 1) else ''
121 | s += ', groups={groups}' if self.groups != 1 else ''
122 | s += ', bias=False' if self.bias is None else ''
123 | s += ')'
124 | return s.format(**self.__dict__)
125 |
126 |
127 | class CycleMLP(nn.Module):
128 | def __init__(self, dim, segment_dim=8, qkv_bias=False, taski=1,patch=63, proj_drop=0.):
129 | super().__init__()
130 | self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
131 |
132 | self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0)
133 | self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0)
134 |
135 | self.reweight = Mlp(dim, dim // 4, dim * 3)
136 |
137 | self.proj = nn.Linear(dim, dim)
138 | self.proj_drop = nn.Dropout(proj_drop)
139 |
140 | def forward(self, x):
141 | B, H, W, C = x.shape
142 | # B,C,H,W
143 | h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
144 | w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
145 | c = self.mlp_c(x)
146 |
147 | a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
148 | a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
149 |
150 | x = h * a[0] + w * a[1] + c * a[2]
151 |
152 | x = self.proj(x)
153 | x = self.proj_drop(x)
154 |
155 | return x
156 |
157 |
158 | class CycleBlock(nn.Module):
159 |
160 | def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
161 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=CycleMLP):
162 | super().__init__()
163 | self.norm1 = norm_layer(dim)
164 |         self.attn = mlp_fn(dim, qkv_bias=qkv_bias, proj_drop=attn_drop)  # CycleMLP has no qk_scale/attn_drop arguments
165 |
166 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
167 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
168 |
169 | self.norm2 = norm_layer(dim)
170 | mlp_hidden_dim = int(dim * mlp_ratio)
171 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
172 | self.skip_lam = skip_lam
173 |
174 | def forward(self, x):
175 | x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam
176 | x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam
177 | return x
178 |
179 |
180 | # class WeightedPermuteMLP(nn.Module):
181 | # def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
182 | # super().__init__()
183 | # self.segment_dim = segment_dim
184 | #
185 | # self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
186 | # self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias)
187 | # self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias)
188 | #
189 | # self.reweight = Mlp(dim, dim // 4, dim * 3)
190 | #
191 | # self.proj = nn.Linear(dim, dim)
192 | # self.proj_drop = nn.Dropout(proj_drop)
193 | #
194 | # def forward(self, x):
195 | # B, H, W, C = x.shape
196 | #
197 | # S = C // self.segment_dim
198 | # h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)
199 | # h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)
200 | #
201 | # w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
202 | # w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)
203 | #
204 | # c = self.mlp_c(x)
205 | # # B, C, H, W -> B, C,[ H, W ]
206 | # a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
207 | # a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
208 | #
209 | # x = h * a[0] + w * a[1] + c * a[2]
210 | #
211 | # x = self.proj(x)
212 | # x = self.proj_drop(x)
213 | #
214 | # return x
215 |
216 | class WeightedPermuteMLPv3(nn.Module):
217 | def __init__(self, dim, segment_dim=8, qkv_bias=False, taski=1,patch=63, proj_drop=0.,mlp="None"):
218 | super().__init__()
219 | self.segment_dim = segment_dim
220 | self.taski = taski
221 | self.patch = patch
222 | self.mlp = mlp
223 | self.mlp_c = nn.Sequential(
224 | nn.Linear(dim, dim, bias=qkv_bias),
225 | )
226 | if self.mlp != "taski":
227 | self.mlp_h = nn.Sequential(
228 | nn.Linear(taski * dim, taski * dim, bias=qkv_bias),
229 | # nn.Linear(dim, taski, bias=qkv_bias),
230 | )
231 | else:
232 | self.mlp_h = GatingMlpBlock(dim, dim, taski)
233 | # self.up_mlp = nn.Linear(dim // 2, dim, bias=qkv_bias)
234 |
235 | if self.mlp == "patch":
236 | self.mlp_w = GatingMlpBlock(dim, dim, patch)
237 | else:
238 | self.mlp_w = nn.Sequential(
239 | nn.Linear(patch*taski, patch*taski, bias=qkv_bias),
240 | # nn.Linear(dim, patch, bias=qkv_bias),
241 | )
242 | self.reweight = Mlp(dim, dim // 4, dim * 2)
243 |
244 | self.proj = nn.Linear(dim, dim)
245 | self.proj_drop = nn.Dropout(proj_drop)
246 |
247 | def forward(self, x):
248 | B, H, W, C = x.shape
249 | # print(x.shape)
250 |
251 | if self.mlp != "taski":
252 | # h = rearrange(x,'b i t (h k) -> b t k (i h)',h=64)
253 | h = rearrange(x, 'b i t c -> b t (i c)')
254 | h = self.mlp_h(h)
255 | h = rearrange(h,'b t (i c) -> b i t c',i=self.taski)
256 | else:
257 | h = self.mlp_h(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
258 | # h = self.up_mlp(h)
259 |
260 | # B,C, H,W -> B,H,W,C
261 | if self.mlp != "patch":
262 | w = rearrange(x,'b i t c -> b c (i t)')
263 | w = self.mlp_w(w)
264 | w = rearrange(w,'b c (i t) -> b i t c',t = self.patch)
265 | else:
266 | w = self.mlp_w(x)
267 |
268 | # B, C, H, W -> B, C,[ H, W ]
269 | a = (h + w).permute(0, 3, 1, 2).flatten(2).mean(2)
270 | a = self.reweight(a).reshape(B, C, 2).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
271 |
272 | x = h * a[0] + w * a[1]
273 |
274 | x = self.proj(x)
275 | x = self.proj_drop(x)
276 |
277 | return x
278 |
279 | # class WeightedPermuteMLPv2(nn.Module):
280 | # def __init__(self, dim, segment_dim=8, qkv_bias=False, taski=1,patch=63, proj_drop=0.):
281 | # super().__init__()
282 | # self.segment_dim = segment_dim
283 | #
284 | # self.mlp_c = nn.Sequential(
285 | # nn.Linear(dim, dim, bias=qkv_bias),
286 | # )
287 | # self.mlp_h = nn.Sequential(
288 | # nn.Linear(taski, dim, bias=qkv_bias),
289 | # nn.Linear(dim, taski, bias=qkv_bias),
290 | # )
291 | # self.mlp_w = nn.Sequential(
292 | # nn.Linear(patch, dim, bias=qkv_bias),
293 | # nn.Linear(dim, patch, bias=qkv_bias),
294 | # )
295 | # self.reweight = Mlp(dim, dim // 4, dim * 3)
296 | #
297 | # self.proj = nn.Linear(dim, dim)
298 | # self.proj_drop = nn.Dropout(proj_drop)
299 | #
300 | # def forward(self, x):
301 | # B, H, W, C = x.shape
302 | # # print(x.shape)
303 | #
304 | # h = x.permute(0,3,2,1)
305 | # h = self.mlp_h(h).permute(0, 3, 2 , 1)
306 | # # B,C, H,W -> B,H,W,C
307 | # w = x.permute(0, 3, 1, 2)
308 | # w = self.mlp_w(w).permute(0, 2, 3, 1)
309 | #
310 | # # S = C // self.segment_dim
311 | # # h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)
312 | # # h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)
313 | #
314 | # # w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
315 | # # w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)
316 | #
317 | # c = self.mlp_c(x)
318 | # # B, C, H, W -> B, C,[ H, W ]
319 | # a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
320 | # a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
321 | #
322 | # x = h * a[0] + w * a[1] + c * a[2]
323 | #
324 | # x = self.proj(x)
325 | # x = self.proj_drop(x)
326 | #
327 | # return x
328 |
329 | class PermutatorBlock(nn.Module):
330 |
331 | def __init__(self, dim, mlp_ratio=4., taski = 1, patch = 63, segment_dim=8, qkv_bias=False,
332 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=WeightedPermuteMLPv3):
333 | super().__init__()
334 | self.norm1 = norm_layer(dim)
335 | self.attn = mlp_fn(dim, segment_dim=segment_dim, taski=taski,patch=patch,qkv_bias=qkv_bias)
336 |
337 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
338 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
339 |
340 | self.norm2 = norm_layer(dim)
341 | mlp_hidden_dim = int(dim * mlp_ratio)
342 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
343 | self.skip_lam = skip_lam
344 |
345 | def forward(self, x):
346 | x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam
347 | x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam
348 | return x
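
A minimal usage sketch for the routing blocks above, assuming the repo's `modules/dm_router.py` (imported at the top of this file) is available; the batch/route/token/channel sizes below are illustrative only, with the last dimension matching `dim` and the two middle dimensions matching `taski` and `patch`.

```
import torch
from modules.mlp import PermutatorBlock

# Illustrative shapes: 2 routes (taski), 63 visual tokens (patch), 256 channels (dim).
block = PermutatorBlock(dim=256, taski=2, patch=63)
x = torch.randn(4, 2, 63, 256)   # (batch, route, token, channel)
out = block(x)                   # residual block keeps the input shape
print(out.shape)                 # torch.Size([4, 2, 63, 256])
```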
--------------------------------------------------------------------------------
/data/transform.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 |
5 | import cv2
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision import transforms
9 | from torchvision.transforms import Compose
10 |
11 |
12 | def sample_asym(magnitude, size=None):
13 | return np.random.beta(1, 4, size) * magnitude
14 |
15 |
16 | def sample_sym(magnitude, size=None):
17 | return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
18 |
19 |
20 | def sample_uniform(low, high, size=None):
21 | return np.random.uniform(low, high, size=size)
22 |
23 |
24 | def get_interpolation(type='random'):
25 | if type == 'random':
26 | choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA]
27 | interpolation = choice[random.randint(0, len(choice) - 1)]
28 | elif type == 'nearest':
29 | interpolation = cv2.INTER_NEAREST
30 | elif type == 'linear':
31 | interpolation = cv2.INTER_LINEAR
32 | elif type == 'cubic':
33 | interpolation = cv2.INTER_CUBIC
34 | elif type == 'area':
35 | interpolation = cv2.INTER_AREA
36 | else:
37 |         raise TypeError('Only nearest, linear, cubic and area interpolation types are supported!')
38 | return interpolation
39 |
40 |
41 | class CVRandomRotation(object):
42 | def __init__(self, degrees=15):
43 | assert isinstance(degrees, numbers.Number), "degree should be a single number."
44 | assert degrees >= 0, "degree must be positive."
45 | self.degrees = degrees
46 |
47 | @staticmethod
48 | def get_params(degrees):
49 | return sample_sym(degrees)
50 |
51 | def __call__(self, img):
52 | angle = self.get_params(self.degrees)
53 | src_h, src_w = img.shape[:2]
54 | M = cv2.getRotationMatrix2D(center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
55 | abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
56 | dst_w = int(src_h * abs_sin + src_w * abs_cos)
57 | dst_h = int(src_h * abs_cos + src_w * abs_sin)
58 | M[0, 2] += (dst_w - src_w) / 2
59 | M[1, 2] += (dst_h - src_h) / 2
60 |
61 | flags = get_interpolation()
62 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
63 |
64 |
65 | class CVRandomAffine(object):
66 | def __init__(self, degrees, translate=None, scale=None, shear=None):
67 | assert isinstance(degrees, numbers.Number), "degree should be a single number."
68 | assert degrees >= 0, "degree must be positive."
69 | self.degrees = degrees
70 |
71 | if translate is not None:
72 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
73 | "translate should be a list or tuple and it must be of length 2."
74 | for t in translate:
75 | if not (0.0 <= t <= 1.0):
76 | raise ValueError("translation values should be between 0 and 1")
77 | self.translate = translate
78 |
79 | if scale is not None:
80 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
81 | "scale should be a list or tuple and it must be of length 2."
82 | for s in scale:
83 | if s <= 0:
84 | raise ValueError("scale values should be positive")
85 | self.scale = scale
86 |
87 | if shear is not None:
88 | if isinstance(shear, numbers.Number):
89 | if shear < 0:
90 | raise ValueError("If shear is a single number, it must be positive.")
91 | self.shear = [shear]
92 | else:
93 | assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
94 | "shear should be a list or tuple and it must be of length 2."
95 | self.shear = shear
96 | else:
97 | self.shear = shear
98 |
99 | def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
100 | # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
101 | from numpy import sin, cos, tan
102 |
103 | if isinstance(shear, numbers.Number):
104 | shear = [shear, 0]
105 |
106 |         if not (isinstance(shear, (tuple, list)) and len(shear) == 2):
107 | raise ValueError(
108 | "Shear should be a single value or a tuple/list containing " +
109 | "two values. Got {}".format(shear))
110 |
111 | rot = math.radians(angle)
112 | sx, sy = [math.radians(s) for s in shear]
113 |
114 | cx, cy = center
115 | tx, ty = translate
116 |
117 | # RSS without scaling
118 | a = cos(rot - sy) / cos(sy)
119 | b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
120 | c = sin(rot - sy) / cos(sy)
121 | d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
122 |
123 | # Inverted rotation matrix with scale and shear
124 | # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
125 | M = [d, -b, 0,
126 | -c, a, 0]
127 | M = [x / scale for x in M]
128 |
129 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
130 | M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
131 | M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
132 |
133 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1
134 | M[2] += cx
135 | M[5] += cy
136 | return M
137 |
138 | @staticmethod
139 | def get_params(degrees, translate, scale_ranges, shears, height):
140 | angle = sample_sym(degrees)
141 | if translate is not None:
142 | max_dx = translate[0] * height
143 | max_dy = translate[1] * height
144 | translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy)))
145 | else:
146 | translations = (0, 0)
147 |
148 | if scale_ranges is not None:
149 | scale = sample_uniform(scale_ranges[0], scale_ranges[1])
150 | else:
151 | scale = 1.0
152 |
153 | if shears is not None:
154 | if len(shears) == 1:
155 | shear = [sample_sym(shears[0]), 0.]
156 | elif len(shears) == 2:
157 | shear = [sample_sym(shears[0]), sample_sym(shears[1])]
158 | else:
159 | shear = 0.0
160 |
161 | return angle, translations, scale, shear
162 |
163 | def __call__(self, img):
164 | src_h, src_w = img.shape[:2]
165 | angle, translate, scale, shear = self.get_params(
166 | self.degrees, self.translate, self.scale, self.shear, src_h)
167 |
168 | M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle, (0, 0), scale, shear)
169 | M = np.array(M).reshape(2, 3)
170 |
171 | startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)]
172 | project = lambda x, y, a, b, c: int(a * x + b * y + c)
173 | endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints]
174 |
175 | rect = cv2.minAreaRect(np.array(endpoints))
176 |         bbox = cv2.boxPoints(rect).astype(dtype=np.int32)
177 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
178 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
179 |
180 | dst_w = int(max_x - min_x)
181 | dst_h = int(max_y - min_y)
182 | M[0, 2] += (dst_w - src_w) / 2
183 | M[1, 2] += (dst_h - src_h) / 2
184 |
185 | # add translate
186 | dst_w += int(abs(translate[0]))
187 | dst_h += int(abs(translate[1]))
188 | if translate[0] < 0: M[0, 2] += abs(translate[0])
189 | if translate[1] < 0: M[1, 2] += abs(translate[1])
190 |
191 | flags = get_interpolation()
192 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
193 |
194 |
195 | class CVRandomPerspective(object):
196 | def __init__(self, distortion=0.5):
197 | self.distortion = distortion
198 |
199 | def get_params(self, width, height, distortion):
200 |         offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int32)
201 |         offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int32)
202 | topleft = (offset_w[0], offset_h[0])
203 | topright = (width - 1 - offset_w[1], offset_h[1])
204 | botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
205 | botleft = (offset_w[3], height - 1 - offset_h[3])
206 |
207 | startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
208 | endpoints = [topleft, topright, botright, botleft]
209 | return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32)
210 |
211 | def __call__(self, img):
212 | height, width = img.shape[:2]
213 | startpoints, endpoints = self.get_params(width, height, self.distortion)
214 | M = cv2.getPerspectiveTransform(startpoints, endpoints)
215 |
216 | # TODO: more robust way to crop image
217 | rect = cv2.minAreaRect(endpoints)
218 |         bbox = cv2.boxPoints(rect).astype(dtype=np.int32)
219 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
220 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
221 | min_x, min_y = max(min_x, 0), max(min_y, 0)
222 |
223 | flags = get_interpolation()
224 | img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
225 | img = img[min_y:, min_x:]
226 | return img
227 |
228 |
229 | class CVRescale(object):
230 |
231 | def __init__(self, factor=4, base_size=(128, 512)):
232 | """ Define image scales using gaussian pyramid and rescale image to target scale.
233 |
234 | Args:
235 |             factor: upper bound on the number of Gaussian-pyramid down-sampling steps; the actual count is sampled uniformly from [0, factor].
236 |             base_size: base size used to build the bottom layer of the pyramid.
237 | """
238 | if isinstance(factor, numbers.Number):
239 | self.factor = round(sample_uniform(0, factor))
240 | elif isinstance(factor, (tuple, list)) and len(factor) == 2:
241 | self.factor = round(sample_uniform(factor[0], factor[1]))
242 | else:
243 | raise Exception('factor must be number or list with length 2')
244 | # assert factor is valid
245 | self.base_h, self.base_w = base_size[:2]
246 |
247 | def __call__(self, img):
248 | if self.factor == 0: return img
249 | src_h, src_w = img.shape[:2]
250 | cur_w, cur_h = self.base_w, self.base_h
251 | scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation())
252 | for _ in range(self.factor):
253 | scale_img = cv2.pyrDown(scale_img)
254 | scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation())
255 | return scale_img
256 |
257 |
258 | class CVGaussianNoise(object):
259 | def __init__(self, mean=0, var=20):
260 | self.mean = mean
261 | if isinstance(var, numbers.Number):
262 | self.var = max(int(sample_asym(var)), 1)
263 | elif isinstance(var, (tuple, list)) and len(var) == 2:
264 | self.var = int(sample_uniform(var[0], var[1]))
265 | else:
266 | raise Exception('degree must be number or list with length 2')
267 |
268 | def __call__(self, img):
269 | noise = np.random.normal(self.mean, self.var ** 0.5, img.shape)
270 | img = np.clip(img + noise, 0, 255).astype(np.uint8)
271 | return img
272 |
273 |
274 | class CVMotionBlur(object):
275 | def __init__(self, degrees=12, angle=90):
276 | if isinstance(degrees, numbers.Number):
277 | self.degree = max(int(sample_asym(degrees)), 1)
278 | elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
279 | self.degree = int(sample_uniform(degrees[0], degrees[1]))
280 | else:
281 | raise Exception('degree must be number or list with length 2')
282 | self.angle = sample_uniform(-angle, angle)
283 |
284 | def __call__(self, img):
285 | M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1)
286 | motion_blur_kernel = np.zeros((self.degree, self.degree))
287 | motion_blur_kernel[self.degree // 2, :] = 1
288 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
289 | motion_blur_kernel = motion_blur_kernel / self.degree
290 | img = cv2.filter2D(img, -1, motion_blur_kernel)
291 | img = np.clip(img, 0, 255).astype(np.uint8)
292 | return img
293 |
294 |
295 | class CVGeometry(object):
296 | def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
297 | shear=(45, 15), distortion=0.5, p=0.5):
298 | self.p = p
299 | type_p = random.random()
300 | if type_p < 0.33:
301 | self.transforms = CVRandomRotation(degrees=degrees)
302 | elif type_p < 0.66:
303 | self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
304 | else:
305 | self.transforms = CVRandomPerspective(distortion=distortion)
306 |
307 | def __call__(self, img):
308 | if random.random() < self.p:
309 | img = np.array(img)
310 | return Image.fromarray(self.transforms(img))
311 | else:
312 | return img
313 |
314 |
315 | class CVDeterioration(object):
316 | def __init__(self, var, degrees, factor, p=0.5):
317 | self.p = p
318 | transforms = []
319 | if var is not None:
320 | transforms.append(CVGaussianNoise(var=var))
321 | if degrees is not None:
322 | transforms.append(CVMotionBlur(degrees=degrees))
323 | if factor is not None:
324 | transforms.append(CVRescale(factor=factor))
325 |
326 | random.shuffle(transforms)
327 | transforms = Compose(transforms)
328 | self.transforms = transforms
329 |
330 | def __call__(self, img):
331 | if random.random() < self.p:
332 | img = np.array(img)
333 | return Image.fromarray(self.transforms(img))
334 | else:
335 | return img
336 |
337 |
338 | class CVColorJitter(object):
339 | def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
340 | self.p = p
341 | self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast,
342 | saturation=saturation, hue=hue)
343 |
344 | def __call__(self, img):
345 | if random.random() < self.p:
346 | return self.transforms(img)
347 | else:
348 | return img
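
A minimal sketch of composing these augmentations into one pipeline; the parameter values are illustrative rather than the repo's training defaults, and the input is assumed to be an RGB PIL word crop (the file path is a placeholder).

```
from PIL import Image
from torchvision.transforms import Compose
from data.transform import CVGeometry, CVDeterioration, CVColorJitter

# Geometry warp, then quality degradation, then color jitter, each applied with probability p.
augment = Compose([
    CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.0), shear=(45, 15), distortion=0.5, p=0.5),
    CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
    CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
])

img = Image.open("word_crop.jpg").convert("RGB")  # placeholder path
img = augment(img)                                # still a PIL image
```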
--------------------------------------------------------------------------------
/il_modules/der.py:
--------------------------------------------------------------------------------
1 | import time
2 | from tqdm import tqdm
3 | import torch
4 | import torch.nn.init as init
5 | from il_modules.base import BaseLearner
6 | from modules.model import DERNet
7 | from test import validation
8 | from tools.utils import Averager, adjust_learning_rate
9 |
10 | EPSILON = 1e-8
11 |
12 | init_epoch = 200
13 | init_lr = 0.1
14 | init_milestones = [60, 120, 170]
15 | init_lr_decay = 0.1
16 | init_weight_decay = 0.0005
17 |
18 | epochs = 170
19 | lrate = 0.1
20 | milestones = [80, 120, 150]
21 | lrate_decay = 0.1
22 | batch_size = 128
23 | weight_decay = 2e-4
24 | num_workers = 8
25 | T = 2
26 |
27 |
28 | class DER(BaseLearner):
29 |
30 | def __init__(self, opt):
31 | super().__init__(opt)
32 | self.model = DERNet(opt)
33 |
34 | def after_task(self):
35 | self.model = self.model.module
36 | self._known_classes = self._total_classes
37 | # logging.info('Exemplar size: {}'.format(self.exemplar_size))
38 |
39 | def model_eval_and_train(self,taski):
40 | self.model.train()
41 | self.model.module.model[-1].train()
42 | if taski >= 1:
43 | for i in range(taski):
44 | self.model.module.model[i].eval()
45 |
46 | def change_model(self,):
47 | """ model configuration """
48 | # model.module.reset_class(opt, device)
49 | # self.model.update_fc(self.opt.output_channel, self._total_classes)
50 | self.model.update_fc(self.opt.hidden_size, self._total_classes)
51 | self.model.build_prediction(self.opt, self._total_classes)
52 | self.model.build_aux_prediction(self.opt, self._total_classes)
53 | # reset_class(self.model.module, self.device)
54 | # data parallel for multi-GPU
55 | self.model = torch.nn.DataParallel(self.model).to(self.device)
56 | self.model.train()
57 | # return self.model
58 |
59 | def build_model(self):
60 | """ model configuration """
61 | # self.model.update_fc(self.opt.output_channel, self._total_classes)
62 | self.model.update_fc(self.opt.hidden_size, self._total_classes)
63 | self.model.build_prediction(self.opt, self._total_classes)
64 | self.model.build_aux_prediction(self.opt, self._total_classes)
65 |
66 | # weight initialization
67 | for name, param in self.model.named_parameters():
68 | if "localization_fc2" in name:
69 | print(f"Skip {name} as it is already initialized")
70 | continue
71 | try:
72 | if "bias" in name:
73 | init.constant_(param, 0.0)
74 | elif "weight" in name:
75 | init.kaiming_normal_(param)
76 | except Exception as e: # for batchnorm.
77 | if "weight" in name:
78 | param.data.fill_(1)
79 | continue
80 |
81 | # data parallel for multi-GPU
82 | self.model = torch.nn.DataParallel(self.model).to(self.device)
83 | self.model.train()
84 |
85 | def incremental_train(self, taski, character, train_loader, valid_loader):
86 |
87 |         # previous task classes become the known classes
88 | # self._known_classes = self._total_classes
89 | self.character = character
90 | self.converter = self.build_converter()
91 | valid_loader = valid_loader.create_dataset()
92 |
93 | if taski > 0:
94 | self.change_model()
95 | else:
96 | self.criterion = self.build_criterion()
97 | self.build_model()
98 |
99 | # print opt config
100 | # self.print_config(self.opt)
101 | if taski > 0:
102 | for i in range(taski):
103 | for p in self.model.module.model[i].parameters():
104 | p.requires_grad = False
105 |
106 | # filter that only require gradient descent
107 | filtered_parameters = self.count_param()
108 |
109 | # setup optimizer
110 | self.build_optimizer(filtered_parameters)
111 |
112 | if self.opt.start_task > taski:
113 |
114 | if taski > 0:
115 | if self.opt.memory != None:
116 | self.build_rehearsal_memory(train_loader, taski)
117 | else:
118 | train_loader.get_dataset(taski, memory=self.opt.memory)
119 |
120 | # if self.opt.ch_list!=None:
121 | # name = self.opt.ch_list[taski]
122 | # else:
123 | name = self.opt.lan_list[taski]
124 | saved_best_model = f"./saved_models/{self.opt.exp_name}/{name}_{taski}_best_score.pth"
125 | # os.system(f'cp {saved_best_model} ./result/{opt.exp_name}/')
126 | self.model.load_state_dict(torch.load(f"{saved_best_model}"), strict=True)
127 | print(
128 | 'Task {} load checkpoint from {}.'.format(taski, saved_best_model)
129 | )
130 |
131 | else:
132 | print(
133 | 'Task {} start training for model ------{}------'.format(taski,self.opt.exp_name)
134 | )
135 | """ start training """
136 | self._train(0, taski, train_loader, valid_loader)
137 |
138 |
139 | def _train(self, start_iter,taski, train_loader, valid_loader):
140 | if taski == 0:
141 | self._init_train(start_iter,taski, train_loader, valid_loader)
142 | else:
143 | if self.opt.memory != None:
144 | self.build_rehearsal_memory(train_loader, taski)
145 | else:
146 | train_loader.get_dataset(taski, memory=self.opt.memory)
147 | self._update_representation(start_iter,taski, train_loader, valid_loader)
148 | self.model.module.weight_align(self._total_classes - self._known_classes)
149 |
150 | def _init_train(self,start_iter,taski, train_loader, valid_loader):
151 | # loss averager
152 | train_loss_avg = Averager()
153 | train_clf_loss = Averager()
154 | train_aux_loss = Averager()
155 | start_time = time.time()
156 | best_score = -1
157 |
158 | # training loop
159 | for iteration in tqdm(
160 | range(start_iter + 1, self.opt.num_iter + 1),
161 | total=self.opt.num_iter,
162 | position=0,
163 | leave=True,
164 | ):
165 | image_tensors, labels = train_loader.get_batch()
166 |
167 | image = image_tensors.to(self.device)
168 | labels_index, labels_length = self.converter.encode(
169 | labels, batch_max_length=self.opt.batch_max_length
170 | )
171 | batch_size = image.size(0)
172 |
173 | # default recognition loss part
174 | if "CTC" in self.opt.Prediction:
175 | preds = self.model(image)['logits']
176 | # preds = self.model(image)
177 | preds_size = torch.IntTensor([preds.size(1)] * batch_size)
178 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
179 | loss = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length)
180 | else:
181 | preds = self.model(image, labels_index[:, :-1])['logits'] # align with Attention.forward
182 | target = labels_index[:, 1:] # without [SOS] Symbol
183 | loss = self.criterion(
184 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)
185 | )
186 |
187 | self.model.zero_grad()
188 | loss.backward()
189 | torch.nn.utils.clip_grad_norm_(
190 | self.model.parameters(), self.opt.grad_clip
191 | ) # gradient clipping with 5 (Default)
192 | self.optimizer.step()
193 | train_loss_avg.add(loss)
194 |
195 | if "super" in self.opt.schedule:
196 | self.scheduler.step()
197 | else:
198 | adjust_learning_rate(self.optimizer, iteration, self.opt)
199 |
200 | # validation part.
201 | # To see training progress, we also conduct validation when 'iteration == 1'
202 | if iteration % self.opt.val_interval == 0 or iteration ==1:
203 | # for validation log
204 | self.val(valid_loader, self.opt, best_score, start_time, iteration,
205 | train_loss_avg, train_clf_loss, train_aux_loss, taski)
206 | train_loss_avg.reset()
207 |
208 | def _update_representation(self,start_iter, taski, train_loader, valid_loader):
209 | # loss averager
210 | train_loss_avg = Averager()
211 | train_clf_loss = Averager()
212 | train_aux_loss = Averager()
213 |
214 | self.model_eval_and_train(taski)
215 |
216 |
217 | start_time = time.time()
218 | best_score = -1
219 |
220 | # training loop
221 | for iteration in tqdm(
222 | range(start_iter + 1, self.opt.num_iter + 1),
223 | total=self.opt.num_iter,
224 | position=0,
225 | leave=True,
226 | ):
227 | image_tensors, labels = train_loader.get_batch()
228 |
229 | image = image_tensors.to(self.device)
230 | labels_index, labels_length = self.converter.encode(
231 | labels, batch_max_length=self.opt.batch_max_length
232 | )
233 | batch_size = image.size(0)
234 |
235 | # default recognition loss part
236 | if "CTC" in self.opt.Prediction:
237 | output = self.model(image)
238 | preds = output["logits"]
239 | aux_logits = output["aux_logits"]
240 | aux_targets = labels_index.clone()
241 | # aux_targets = torch.where(aux_targets - self._known_classes + 1 > 0,
242 | # aux_targets - self._known_classes + 1, 0)
243 |
244 | aux_preds_size = torch.IntTensor([aux_logits.size(1)] * batch_size)
245 | preds_size = torch.IntTensor([preds.size(1)] * batch_size)
246 | # B,T,C(max) -> T, B, C
247 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
248 | aux_preds_log_softmax = aux_logits.log_softmax(2).permute(1, 0, 2)
249 |
250 | loss_clf = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length)
251 | loss_aux = self.criterion(aux_preds_log_softmax, aux_targets, aux_preds_size, labels_length)
252 | else:
253 | output = self.model(image, labels_index[:, :-1]) # align with Attention.forward
254 | preds = output["logits"]
255 | aux_logits = output["aux_logits"]
256 | aux_targets = labels_index.clone()[:, 1:]
257 | target = labels_index[:, 1:] # without [SOS] Symbol
258 | loss_clf = self.criterion(
259 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)
260 | )
261 | loss_aux = self.criterion(
262 | aux_logits.view(-1, aux_logits.shape[-1]), aux_targets.contiguous().view(-1)
263 | )
264 | # loss = loss_clf + loss_aux
265 | loss = loss_clf
266 |
267 | self.model.zero_grad()
268 | loss.backward()
269 | torch.nn.utils.clip_grad_norm_(
270 | self.model.parameters(), self.opt.grad_clip
271 | ) # gradient clipping with 5 (Default)
272 | self.optimizer.step()
273 | train_loss_avg.add(loss)
274 | train_clf_loss.add(loss_clf)
275 | train_aux_loss.add(loss_aux)
276 |
277 | if "super" in self.opt.schedule:
278 | self.scheduler.step()
279 | else:
280 | adjust_learning_rate(self.optimizer, iteration, self.opt)
281 |
282 | # validation part.
283 | # To see training progress, we also conduct validation when 'iteration == 1'
284 | if iteration % self.opt.val_interval == 0 or iteration == 1:
285 | # for validation log
286 | self.val(valid_loader, self.opt, best_score, start_time, iteration,
287 | train_loss_avg, train_clf_loss, train_aux_loss, taski)
288 | train_loss_avg.reset()
289 | train_clf_loss.reset()
290 | train_aux_loss.reset()
291 |
292 | def val(self, valid_loader, opt, best_score, start_time, iteration,
293 | train_loss_avg,train_clf_loss, train_aux_loss, taski):
294 | self.model.eval()
295 | start_time = time.time()
296 | with torch.no_grad():
297 | (
298 | valid_loss,
299 | current_score,
300 | ned_score,
301 | preds,
302 | confidence_score,
303 | labels,
304 | infer_time,
305 | length_of_data,
306 | ) = validation(self.model, self.criterion, valid_loader, self.converter, opt)
307 | self.model.train()
308 |
309 | # keep best score (accuracy or norm ED) model on valid dataset
310 | # Do not use this on test datasets. It would be an unfair comparison
311 | # (training should be done without referring test set).
312 | if current_score > best_score:
313 | best_score = current_score
314 | # if opt.ch_list != None:
315 | # name = opt.ch_list[taski]
316 | # else:
317 | name = opt.lan_list[taski]
318 | torch.save(
319 | self.model.state_dict(),
320 | f"./saved_models/{opt.exp_name}/{name}_{taski}_best_score.pth",
321 | )
322 |
323 | # validation log: loss, lr, score (accuracy or norm ED), time.
324 | lr = self.optimizer.param_groups[0]["lr"]
325 | elapsed_time = time.time() - start_time
326 | valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f} \n "
327 |         if train_clf_loss is not None:
328 | valid_log += f"CLF_loss: {train_clf_loss.val():0.5f} , Aux_loss: {train_aux_loss.val():0.5f}\n"
329 | valid_log += f'{"":9s}Current_score: {current_score:0.2f}, Ned_score: {ned_score:0.2f}\n'
330 | valid_log += f'{"":9s}Current_lr: {lr:0.7f}, Best_score: {best_score:0.2f}\n'
331 | valid_log += f'{"":9s}Infer_time: {infer_time:0.2f}, Elapsed_time: {elapsed_time/length_of_data:0.4f}\n'
332 |
333 | # show some predicted results
334 | dashed_line = "-" * 80
335 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
336 | predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
337 | for gt, pred, confidence in zip(
338 | labels[:5], preds[:5], confidence_score[:5]
339 | ):
340 | if "Attn" in opt.Prediction:
341 | gt = gt[: gt.find("[EOS]")]
342 | pred = pred[: pred.find("[EOS]")]
343 |
344 | predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
345 | predicted_result_log += f"{dashed_line}"
346 | valid_log = f"{valid_log}\n{predicted_result_log}"
347 | print(valid_log)
348 | self.write_log(valid_log + "\n")
349 |
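
A minimal sketch of how the DER learner is driven across tasks; the names follow `tiny_train.py`, and `opt`, `train_loader`, and `valid_loader` are assumed to be the config namespace, `Dataset_Manager`, and `Val_Dataset` objects built there (the character list is also refreshed per task in the real loop).

```
from il_modules.der import DER

learner = DER(opt)
for taski, script in enumerate(opt.lan_list):
    # train the newly added branch for this script; earlier branches are frozen inside
    learner.incremental_train(taski, opt.character, train_loader, valid_loader)
    learner.after_task()  # unwrap DataParallel and promote total classes to known classes
```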
--------------------------------------------------------------------------------
/tiny_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import argparse
6 | from data.data_manage import Dataset_Manager, Val_Dataset
7 | from il_modules.base import BaseLearner
8 | from il_modules.der import DER
9 | from il_modules.mrn import MRN
10 | from il_modules.ewc import EWC
11 | from il_modules.joint import JointLearner
12 | from il_modules.lwf import LwF
13 | from il_modules.wa import WA
14 |
15 | print(os.getcwd())
16 | import torch
17 | import torch.backends.cudnn as cudnn
18 | import torch.utils.data
19 | import numpy as np
20 | from mmcv import Config
21 |
22 | from data.dataset import hierarchical_dataset, AlignCollate
23 | from test import validation
24 |
25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 |
27 | def write_data_log(line):
28 |     '''
29 |     Append a line of text to the data_any.txt run log.
30 | 
31 |     :param line: string to append; the caller supplies any trailing newline
32 |     :return: None
33 |     '''
34 | with open(f"data_any.txt", "a+") as log:
35 | log.write(line)
36 |
37 | def load_dict(path,char):
38 | ch_list = []
39 | character = []
40 | f = open(path + "/dict.txt")
41 | line = f.readline()
42 | while line:
43 | ch_list.append(line.strip("\n"))
44 | line = f.readline()
45 | f.close()
46 |
47 | for ch in ch_list:
48 | if char.get(ch, None) == None:
49 | char[ch] = 1
50 | for key, value in char.items():
51 | character.append(key)
52 |     print("dict has {} characters\n".format(len(character)))
53 | return character,char
54 |
55 |
56 | def build_arg(parser):
57 | parser.add_argument(
58 | "--config",
59 | default="config/crnn_mrn.py",
60 |         help="path to the training config file",
61 | )
62 | parser.add_argument(
63 | "--valid_datas",
64 |         default=["../dataset/MLT17_IL/test_2017", "../dataset/MLT19_IL/test_2019"],
65 | help="path to testing dataset",
66 | )
67 | parser.add_argument(
68 | "--select_data",
69 | type=str,
70 |         default=["../dataset/MLT17_IL/train_2017", "../dataset/MLT19_IL/train_2019"],
71 | help="select training data.",
72 | )
73 | parser.add_argument(
74 | "--workers", type=int, default=4, help="number of data loading workers"
75 | )
76 | parser.add_argument("--batch_size", type=int, default=128, help="input batch size")
77 | parser.add_argument(
78 | "--num_iter", type=int, default=20000, help="number of iterations to train for"
79 | )
80 | parser.add_argument(
81 | "--val_interval",
82 | type=int,
83 | default=5000,
84 | help="Interval between each validation",
85 | )
86 | parser.add_argument(
87 | "--log_multiple_test", action="store_true", help="log_multiple_test"
88 | )
89 | parser.add_argument(
90 | "--grad_clip", type=float, default=5, help="gradient clipping value. default=5"
91 | )
92 | """ Optimizer """
93 | parser.add_argument(
94 | "--optimizer", type=str, default="adam", help="optimizer |sgd|adadelta|adam|"
95 | )
96 | parser.add_argument(
97 | "--lr",
98 | type=float,
99 | default=0.0005,
100 | help="learning rate, default=1.0 for Adadelta, 0.0005 for Adam",
101 | )
102 | parser.add_argument(
103 | "--sgd_momentum", default=0.9, type=float, help="momentum for SGD"
104 | )
105 | parser.add_argument(
106 | "--sgd_weight_decay", default=0.000001, type=float, help="weight decay for SGD"
107 | )
108 | parser.add_argument(
109 | "--rho",
110 | type=float,
111 | default=0.95,
112 | help="decay rate rho for Adadelta. default=0.95",
113 | )
114 | parser.add_argument(
115 | "--eps", type=float, default=1e-8, help="eps for Adadelta. default=1e-8"
116 | )
117 | parser.add_argument(
118 | "--schedule",
119 | default="super",
120 | nargs="*",
121 |         help="learning rate schedule: 'super' (default) for super convergence, 1 for none, [0.6, 0.8] for the same setting as ASTER",
122 | )
123 | parser.add_argument(
124 | "--lr_drop_rate",
125 | type=float,
126 | default=0.1,
127 | help="lr_drop_rate. default is the same setting with ASTER",
128 | )
129 |
130 | """ Model Architecture """
131 | parser.add_argument("--model_name", type=str, required=False, help="CRNN|TRBA")
132 | parser.add_argument(
133 | "--num_fiducial",
134 | type=int,
135 | default=20,
136 | help="number of fiducial points of TPS-STN",
137 | )
138 | parser.add_argument(
139 | "--input_channel",
140 | type=int,
141 | default=3,
142 | help="the number of input channel of Feature extractor",
143 | )
144 | parser.add_argument(
145 | "--output_channel",
146 | type=int,
147 | default=512,
148 | help="the number of output channel of Feature extractor",
149 | )
150 | parser.add_argument(
151 | "--hidden_size", type=int, default=256, help="the size of the LSTM hidden state"
152 | )
153 |
154 | """ Data processing """
155 | parser.add_argument(
156 | "--batch_ratio",
157 | type=str,
158 | default="1.0",
159 | help="assign ratio for each selected data in the batch",
160 | )
161 | parser.add_argument(
162 | "--total_data_usage_ratio",
163 | type=str,
164 | default="1.0",
165 | help="total data usage ratio, this ratio is multiplied to total number of data.",
166 | )
167 | parser.add_argument(
168 | "--batch_max_length", type=int, default=25, help="maximum-label-length"
169 | )
170 | parser.add_argument(
171 | "--imgH", type=int, default=32, help="the height of the input image"
172 | )
173 | parser.add_argument(
174 | "--imgW", type=int, default=100, help="the width of the input image"
175 | )
176 | parser.add_argument(
177 | "--NED", action="store_true", help="For Normalized edit_distance"
178 | )
179 | parser.add_argument(
180 | "--Aug",
181 | type=str,
182 | default="None",
183 | help="whether to use augmentation |None|Blur|Crop|Rot|",
184 | )
185 | """ exp_name and etc """
186 | parser.add_argument("--exp_name", help="Where to store logs and models")
187 | parser.add_argument(
188 | "--manual_seed", type=int, default=111, help="for random seed setting"
189 | )
190 | parser.add_argument(
191 | "--saved_model", default="", help="path to model to continue training"
192 | )
193 | return parser
194 |
195 | def train(opt, log):
196 | # ["Latin", "Chinese", "Arabic", "Japanese", "Korean", "Bangla","Hindi","Symbols"]
197 | write_data_log(f"----------- {opt.exp_name} ------------\n")
198 | print(f"----------- {opt.exp_name} ------------\n")
199 |
200 | valid_datasets = train_datasets = [lan for lan in opt.lan_list]
201 |
202 | best_scores = []
203 | ned_scores = []
204 | valid_datas = []
205 | char = dict()
206 | """ final options """
207 | # print(opt)
208 | opt_log = "------------ Options -------------\n"
209 | args = vars(opt)
210 | for k, v in args.items():
211 | if str(k) == "character" and len(str(v)) > 500:
212 | opt_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n"
213 | opt_log += "---------------------------------------\n"
214 | # print(opt_log)
215 | log.write(opt_log)
216 | if opt.il == "lwf":
217 | learner = LwF(opt)
218 | elif opt.il == "wa":
219 | learner = WA(opt)
220 | elif opt.il == "ewc":
221 | learner = EWC(opt)
222 | elif opt.il == "der":
223 | learner = DER(opt)
224 | elif opt.il == "mrn":
225 | learner = MRN(opt)
226 | elif opt.il == "joint_mix" or opt.il == "joint_loader":
227 | learner = JointLearner(opt)
228 | else:
229 | learner = BaseLearner(opt)
230 |
231 | data_manager = Dataset_Manager(opt)
232 | for taski in range(len(train_datasets)):
233 | # train_data = os.path.join(opt.train_data, train_datasets[taski])
234 | for valid_data in opt.valid_datas:
235 | val_data = os.path.join(valid_data, valid_datasets[taski])
236 | valid_datas.append(val_data)
237 |
238 | valid_loader = Val_Dataset(valid_datas,opt)
239 | """dataset preparation"""
240 | select_data = opt.select_data
241 | AlignCollate_valid = AlignCollate(opt, mode="test")
242 |
243 | if opt.il =="joint_loader" or opt.il == "joint_mix":
244 | valid_datas = []
245 | char = {}
246 | for taski in range(len(train_datasets)):
247 | # char={}
248 | # train_data = os.path.join(opt.train_data, train_datasets[taski])
249 | for val_data in opt.valid_datas:
250 | valid_data = os.path.join(val_data, valid_datasets[taski])
251 | valid_datas.append(valid_data)
252 | data_manager.joint_start(opt, select_data, log, taski, len(train_datasets))
253 | for data_path in opt.select_data:
254 | opt.character, char = load_dict(data_path + f"/{opt.lan_list[taski]}", char)
255 | print(len(opt.character))
256 | best_scores,ned_scores = learner.incremental_train(0,opt.character, data_manager, valid_loader,AlignCollate_valid,valid_datas)
257 | """ Evaluation at the end of training """
258 | best_scores, ned_scores = learner.test(AlignCollate_valid, valid_datas, best_scores, ned_scores, 0)
259 | break
260 | if taski == 0:
261 | data_manager.init_start(opt, select_data, log, taski)
262 | train_loader = data_manager
263 |
264 | #-------load char to dict --------#
265 | for data_path in opt.select_data:
266 | if data_path=="/":
267 |                     opt.character, char = load_dict(data_path+f"/{opt.lan_list[taski]}",char)
268 | else:
269 | opt.character,tmp_char = load_dict(data_path+f"/{opt.lan_list[taski]}",char)
270 | # ----- incremental model start -------
271 |
272 | learner.incremental_train(taski, opt.character, train_loader, valid_loader)
273 |
274 | # ----- incremental model end -------
275 | """ Evaluation at the end of training """
276 | best_scores,ned_scores = learner.test(AlignCollate_valid,valid_datas,best_scores,ned_scores, taski)
277 | learner.after_task()
278 |
279 | write_data_log(f"----------- {opt.exp_name} ------------\n")
280 | print(f"----------- {opt.exp_name} ------------\n")
281 | if len(opt.valid_datas) == 1:
282 | print(
283 | 'ALL Average Incremental Accuracy: {:.2f} \n'.format(sum(best_scores)/len(best_scores))
284 | )
285 | write_data_log('ALL Average Acc: {:.2f} \n'.format(sum(best_scores)/len(best_scores)))
286 | elif len(opt.valid_datas) == 2:
287 | print(
288 | 'ALL Average 17 Acc: {:.2f} \n'.format(sum(best_scores) / len(best_scores))
289 | )
290 | print(
291 | 'ALL Average 19 Acc: {:.2f} \n'.format(sum(ned_scores) / len(ned_scores))
292 | )
293 | write_data_log('ALL 17 Acc: {:.2f} \n'.format(sum(best_scores) / len(best_scores)))
294 | write_data_log('ALL 19 Acc: {:.2f} \n'.format(sum(ned_scores) / len(ned_scores)))
295 |
296 | def val(model, criterion, valid_loader, converter, opt,optimizer,best_score,start_time,iteration,train_loss_avg,taski):
297 | with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log:
298 | model.eval()
299 | with torch.no_grad():
300 | (
301 | valid_loss,
302 | current_score,
303 | ned_score,
304 | preds,
305 | confidence_score,
306 | labels,
307 | infer_time,
308 | length_of_data,
309 | ) = validation(model, criterion, valid_loader, converter, opt)
310 | model.train()
311 |
312 | # keep best score (accuracy or norm ED) model on valid dataset
313 | # Do not use this on test datasets. It would be an unfair comparison
314 | # (training should be done without referring test set).
315 | if current_score > best_score:
316 | best_score = current_score
317 | # if opt.ch_list!=None:
318 | # name = opt.ch_list[taski]
319 | # else:
320 | name = opt.lan_list[taski]
321 | torch.save(
322 | model.state_dict(),
323 | f"./saved_models/{opt.exp_name}/{name}_{taski}_best_score.pth",
324 | )
325 |
326 | # validation log: loss, lr, score (accuracy or norm ED), time.
327 | lr = optimizer.param_groups[0]["lr"]
328 | elapsed_time = time.time() - start_time
329 | valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f} \n "
330 | # valid_log += f", Semi_loss: {semi_loss_avg.val():0.5f}\n"
331 | valid_log += f'{"":9s}Current_score: {current_score:0.2f}, Ned_score: {ned_score:0.2f}\n'
332 | valid_log += f'{"":9s}Current_lr: {lr:0.7f}, Best_score: {best_score:0.2f}\n'
333 | valid_log += f'{"":9s}Infer_time: {infer_time:0.2f}, Elapsed_time: {elapsed_time:0.2f}\n'
334 |
335 | # show some predicted results
336 | dashed_line = "-" * 80
337 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
338 | predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
339 | for gt, pred, confidence in zip(
340 | labels[:5], preds[:5], confidence_score[:5]
341 | ):
342 | if "Attn" in opt.Prediction:
343 | gt = gt[: gt.find("[EOS]")]
344 | pred = pred[: pred.find("[EOS]")]
345 |
346 | predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
347 | predicted_result_log += f"{dashed_line}"
348 | valid_log = f"{valid_log}\n{predicted_result_log}"
349 | print(valid_log)
350 | log.write(valid_log + "\n")
351 | write_data_log(f"Task {opt.lan_list[taski]} [{iteration}/{opt.num_iter}] : Score:{current_score:0.2f} LR:{lr:0.7f}\n")
352 |
353 |
354 |
355 | def test(AlignCollate_valid,valid_datas,model,criterion,converter,opt,best_scores,taski,log):
356 | print("---Start evaluation on benchmark testset----")
357 | """ keep evaluation model and result logs """
358 | os.makedirs(f"./result/{opt.exp_name}", exist_ok=True)
359 | os.makedirs(f"./evaluation_log", exist_ok=True)
360 | # if opt.ch_list != None:
361 | # name = opt.ch_list[taski]
362 | # else:
363 | name = opt.lan_list[taski]
364 | saved_best_model = f"./saved_models/{opt.exp_name}/{name}_{taski}_best_score.pth"
365 | # os.system(f'cp {saved_best_model} ./result/{opt.exp_name}/')
366 | model.load_state_dict(torch.load(f"{saved_best_model}"))
367 |
368 | task_accs = []
369 | for val_data in valid_datas:
370 | valid_dataset, valid_dataset_log = hierarchical_dataset(
371 | root=val_data, opt=opt, mode="test")
372 | valid_loader = torch.utils.data.DataLoader(
373 | valid_dataset,
374 | batch_size=opt.batch_size,
375 | shuffle=True, # 'True' to check training progress with validation function.
376 | num_workers=int(opt.workers),
377 | collate_fn=AlignCollate_valid,
378 | pin_memory=False,
379 | )
380 |
381 | model.eval()
382 | with torch.no_grad():
383 | (
384 | valid_loss,
385 | current_score,
386 | ned_score,
387 | preds,
388 | confidence_score,
389 | labels,
390 | infer_time,
391 | length_of_data,
392 | ) = validation(model, criterion, valid_loader, converter, opt)
393 |
394 | task_accs.append(current_score)
395 |
396 | best_scores.append(sum(task_accs) / len(task_accs))
397 |
398 | acc_log= f'Task {taski} Test Average Incremental Accuracy: {best_scores[taski]} \n Task {taski} Incremental Accuracy: {task_accs}'
399 | # acc_log = f'Task {taski} Test Average Incremental Accuracy: {best_scores[taski]} \n '
400 | # acc_log += f'Task {taski} Incremental Accuracy: {task_accs:.2f}'
401 | write_data_log(f'Task {taski} Avg Acc: {best_scores[taski]:0.2f} \n {task_accs}\n')
402 | print(acc_log)
403 | log.write(acc_log)
404 | return best_scores,log
405 |
406 |
407 | if __name__ == "__main__":
408 |
409 | parser = argparse.ArgumentParser()
410 | parser = build_arg(parser)
411 |
412 | arg = parser.parse_args()
413 | cfg = Config.fromfile(arg.config)
414 |
415 | opt={}
416 | opt.update(cfg.common)
417 | # opt.update(cfg.test)
418 | opt.update(cfg.model)
419 | opt.update(cfg.train)
420 | opt.update(cfg.optimizer)
421 |
422 | opt = argparse.Namespace(**opt)
423 |
424 | """ Seed and GPU setting """
425 | random.seed(opt.manual_seed)
426 | np.random.seed(opt.manual_seed)
427 | torch.manual_seed(opt.manual_seed)
428 | torch.cuda.manual_seed_all(opt.manual_seed) # if you are using multi-GPU.
429 | torch.cuda.manual_seed(opt.manual_seed)
430 |
431 |     cudnn.benchmark = True  # It speeds up training.
432 | cudnn.deterministic = True
433 |
434 | opt.gpu_name = "_".join(torch.cuda.get_device_name().split())
435 | if sys.platform == "linux":
436 | opt.CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"]
437 | else:
438 | opt.CUDA_VISIBLE_DEVICES = 0 # for convenience
439 | opt.num_gpu = torch.cuda.device_count()
440 |
441 | if sys.platform == "win32":
442 | opt.workers = 0
443 |
444 | """ directory and log setting """
445 | if not opt.exp_name:
446 | opt.exp_name = f"Seed{opt.manual_seed}-{opt.model_name}"
447 |
448 | os.makedirs(f"./saved_models/{opt.exp_name}", exist_ok=True)
449 | log = open(f"./saved_models/{opt.exp_name}/log_train.txt", "a")
450 | command_line_input = " ".join(sys.argv)
451 | print(
452 | f"Command line input: CUDA_VISIBLE_DEVICES={opt.CUDA_VISIBLE_DEVICES} python {command_line_input}"
453 | )
454 | log.write(
455 | f"Command line input: CUDA_VISIBLE_DEVICES={opt.CUDA_VISIBLE_DEVICES} python {command_line_input}\n"
456 | )
457 | os.makedirs(f"./tensorboard", exist_ok=True)
458 | # opt.writer = SummaryWriter(log_dir=f"./tensorboard/{opt.exp_name}")
459 |
460 | train(opt, log)
461 |
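
For reference, a minimal sketch of the config-merging step performed in the `__main__` block above; the config path is the repo default, and the printed fields are those defined in `config/crnn_mrn.py`.

```
import argparse
from mmcv import Config

cfg = Config.fromfile("config/crnn_mrn.py")
opt = {}
for group in (cfg.common, cfg.model, cfg.train, cfg.optimizer):
    opt.update(group)            # later groups override earlier keys, as in tiny_train.py
opt = argparse.Namespace(**opt)
print(opt.il, opt.model_name)    # e.g. "mrn" and the configured recognizer
```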
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import argparse
5 | import re
6 | from datetime import date
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import torch.utils.data
11 | import torch.nn.functional as F
12 | import numpy as np
13 | from mmcv import Config
14 | from nltk.metrics.distance import edit_distance
15 | from tqdm import tqdm
16 |
17 | from tools.utils import CTCLabelConverter, AttnLabelConverter, Averager
18 | from data.dataset import hierarchical_dataset, AlignCollate
19 | from modules.model import Model
20 |
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 |
24 | def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
25 |
26 | if opt.eval_type == "benchmark":
27 | """evaluation with 6 benchmark evaluation datasets"""
28 | eval_data_list = [
29 | "IIIT5k_3000",
30 | "SVT",
31 | "IC13_1015",
32 | "IC15_2077",
33 | "SVTP",
34 | "CUTE80",
35 | ]
36 | opt.eval_data = "data_CVPR2021/evaluation/benchmark/"
37 |
38 | elif opt.eval_type == "addition":
39 | """evaluation with 7 additionally collected evaluation datasets"""
40 | eval_data_list = [
41 | "5.COCO",
42 | "6.RCTW17",
43 | "7.Uber",
44 | "8.ArT",
45 | "9.LSVT",
46 | "10.MLT19",
47 | "11.ReCTS",
48 | ]
49 | opt.eval_data = "data_CVPR2021/evaluation/addition/"
50 | elif opt.eval_type == "IL_STR":
51 | """evaluation with IL_STR datasets"""
52 | eval_data_list = ["Latin", "Chinese", "Arabic", "Japanese", "Korean", "Bangla", "Hindi", "Symbols"]
53 |
54 | opt.eval_data = "../dataset/MLT2019/test_2019/"
55 |
56 | if calculate_infer_time:
57 | eval_batch_size = (
58 | 1 # batch_size should be 1 to calculate the GPU inference time per image.
59 | )
60 | else:
61 | eval_batch_size = opt.batch_size
62 |
63 | accuracy_list = []
64 | total_forward_time = 0
65 | total_eval_data_number = 0
66 | total_correct_number = 0
67 | log = open(f"./result/{opt.exp_name}/log_all_evaluation.txt", "a")
68 | dashed_line = "-" * 80
69 | print(dashed_line)
70 | log.write(dashed_line + "\n")
71 | for eval_data in eval_data_list:
72 | eval_data_path= opt.eval_data+eval_data
73 | # eval_data_path = os.path.join(opt.eval_data, eval_data)
74 | AlignCollate_eval = AlignCollate(opt, mode="test")
75 | eval_data, eval_data_log = hierarchical_dataset(
76 | root=eval_data_path, opt=opt, mode="test"
77 | )
78 | eval_loader = torch.utils.data.DataLoader(
79 | eval_data,
80 | batch_size=eval_batch_size,
81 | shuffle=False,
82 | num_workers=int(opt.workers),
83 | collate_fn=AlignCollate_eval,
84 | pin_memory=True,
85 | )
86 |
87 | _, accuracy_by_best_model, ned_score, _, _, _, infer_time, length_of_data = validation(
88 | model, criterion, eval_loader, converter, opt, tqdm_position=0
89 | )
90 | accuracy_list.append(f"{accuracy_by_best_model:0.2f}")
91 | total_forward_time += infer_time
92 | total_eval_data_number += len(eval_data)
93 | total_correct_number += accuracy_by_best_model * length_of_data
94 | log.write(eval_data_log)
95 | print(f"Acc {accuracy_by_best_model:0.2f}")
96 | log.write(f"Acc {accuracy_by_best_model:0.2f}\n")
97 | print(f"Ned {ned_score:0.2f}")
98 | log.write(f"Ned {ned_score:0.2f}\n")
99 | print(dashed_line)
100 | log.write(dashed_line + "\n")
101 |
102 | averaged_forward_time = total_forward_time / total_eval_data_number * 1000
103 | total_accuracy = total_correct_number / total_eval_data_number
104 | params_num = sum([np.prod(p.size()) for p in model.parameters()])
105 |
106 | eval_log = "accuracy: "
107 | for name, accuracy in zip(eval_data_list, accuracy_list):
108 | eval_log += f"{name}: {accuracy}\t"
109 | eval_log += f"total_accuracy: {total_accuracy:0.2f}\t"
110 | eval_log += f"averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.2f}"
111 | print(eval_log)
112 | log.write(eval_log + "\n")
113 |
114 | # for convenience
115 | print("\t".join(accuracy_list))
116 | print(f"Total_accuracy:{total_accuracy:0.2f}")
117 | log.write("\t".join(accuracy_list) + "\n")
118 | log.write(f"Total_accuracy:{total_accuracy:0.2f}" + "\n")
119 | log.close()
120 |
121 | # for convenience
122 | today = date.today()
123 | if opt.log_multiple_test:
124 | log_all_model = open(f"./evaluation_log/log_multiple_test_{today}.txt", "a")
125 | log_all_model.write("\t".join(accuracy_list) + "\n")
126 | else:
127 | log_all_model = open(
128 | f"./evaluation_log/log_all_model_evaluation_{today}.txt", "a"
129 | )
130 | log_all_model.write(
131 | f"./result/{opt.exp_name}\tTotal_accuracy:{total_accuracy:0.2f}\n"
132 | )
133 | log_all_model.write("\t".join(accuracy_list) + "\n")
134 | log_all_model.close()
135 |
136 | return total_accuracy, eval_data_list, accuracy_list
137 |
138 |
139 | def validation(model, criterion, eval_loader, converter, opt, val_choose="val",tqdm_position=1):
140 | """validation or evaluation"""
141 | n_correct = 0
142 | norm_ED = 0
143 | length_of_data = 0
144 | infer_time = 0
145 | valid_loss_avg = Averager()
146 |
147 | for i, (image_tensors, labels) in tqdm(
148 | enumerate(eval_loader),
149 | total=len(eval_loader),
150 | position=tqdm_position,
151 | leave=False,
152 | ):
153 | batch_size = image_tensors.size(0)
154 | length_of_data = length_of_data + batch_size
155 | image = image_tensors.to(device)
156 | # For max length prediction
157 | labels_index, labels_length = converter.encode(
158 | labels, batch_max_length=opt.batch_max_length
159 | )
160 |
161 | if "CTC" in opt.Prediction:
162 | start_time = time.time()
163 | if val_choose == "FF":
164 | preds = model(image, cross = False, is_train = False)
165 | elif val_choose == "TF":
166 | preds = model(image,cross = True, is_train = False)
167 | else:
168 | preds = model(image, is_train = False)
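    | # the model forward returns a dict: MRN/DER variants (3-4 keys) store predictions under 'logits', the base Model (2 keys) under 'predict'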
169 | if len(preds) == 3 or len(preds) == 4:
170 | preds = preds['logits']
171 | elif len(preds) == 2:
172 | preds = preds['predict']
173 | forward_time = time.time() - start_time
174 |
175 | # Calculate evaluation loss for CTC decoder.
176 | preds_size = torch.IntTensor([preds.size(1)] * batch_size)
177 | # permute 'preds' to use CTCloss format
178 | cost = criterion(
179 | preds.log_softmax(2).permute(1, 0, 2),
180 | labels_index,
181 | preds_size,
182 | labels_length,
183 | )
184 |
185 | else:
186 | text_for_pred = (
187 | torch.LongTensor(batch_size).fill_(converter.dict["[SOS]"]).to(device)
188 | )
189 |
190 | start_time = time.time()
191 | # preds = model(image, text_for_pred, is_train=False)
192 | if val_choose == "FF":
193 | preds = model(image, cross = False,text = text_for_pred, is_train = False)
194 | elif val_choose == "TF":
195 | preds = model(image,cross = True, text = text_for_pred, is_train = False)
196 | else:
197 | preds = model(image, text = text_for_pred, is_train=False)
198 | if len(preds) == 3:
199 | preds = preds['logits']
200 | elif len(preds) == 2:
201 | preds = preds['predict']
202 | forward_time = time.time() - start_time
203 |
204 | target = labels_index[:, 1:] # without [SOS] Symbol
205 | cost = criterion(
206 | preds.contiguous().view(-1, preds.shape[-1]),
207 | target.contiguous().view(-1),
208 | )
209 |
210 | # select max probability (greedy decoding) then decode index to character
211 | _, preds_index = preds.max(2)
212 | preds_size = torch.IntTensor([preds.size(1)] * preds_index.size(0)).to(device)
213 | preds_str = converter.decode(preds_index, preds_size)
214 |
215 | infer_time += forward_time
216 | valid_loss_avg.add(cost)
217 |
218 | # calculate accuracy & confidence score
219 | preds_prob = F.softmax(preds, dim=2)
220 | preds_max_prob, _ = preds_prob.max(dim=2)
221 | confidence_score_list = []
222 | for gt, prd, prd_max_prob in zip(labels, preds_str, preds_max_prob):
223 | if "Attn" in opt.Prediction:
224 | prd_EOS = prd.find("[EOS]")
225 | prd = prd[:prd_EOS] # prune after "end of sentence" token ([EOS])
226 | prd_max_prob = prd_max_prob[:prd_EOS]
227 |
228 | """
229 | In our experiment, if the model predicts at least one [UNK] token, we count the word prediction as incorrect.
230 | To ignore [UNK] tokens, uncomment the line below.
231 | prd = prd.replace('[UNK]', '')
232 | """
233 |
234 | # To evaluate a case-sensitive model in the alphanumeric, case-insensitive setting (same as ASTER).
235 | # gt = gt.lower()
236 | # prd = prd.lower()
237 | # alphanumeric_case_insensitive = "0123456789abcdefghijklmnopqrstuvwxyz"
238 | # out_of_alphanumeric_case_insensitive = f"[^{alphanumeric_case_insensitive}]"
239 | # gt = re.sub(out_of_alphanumeric_case_insensitive, "", gt)
240 | # prd = re.sub(out_of_alphanumeric_case_insensitive, "", prd)
241 |
242 |
243 | if opt.NED:
244 | # ICDAR2019 Normalized Edit Distance
245 | if len(gt) == 0 or len(prd) == 0:
246 | norm_ED += 0
247 | elif len(gt) > len(prd):
248 | norm_ED += 1 - edit_distance(prd, gt) / len(gt)
249 | else:
250 | norm_ED += 1 - edit_distance(prd, gt) / len(prd)
251 |
252 | if prd == gt:
253 | n_correct += 1
254 |
255 | # calculate confidence score (= product of prd_max_prob)
256 | try:
257 | confidence_score = prd_max_prob.cumprod(dim=0)[-1]
258 | except:
259 | confidence_score = 0 # for the empty-prediction case after pruning at the [EOS] token
260 | confidence_score_list.append(confidence_score)
261 |
262 | ned_score=None
263 |
264 | if opt.NED:
265 | # ICDAR2019 Normalized Edit Distance; the official leaderboard reports norm_ED as a percentage (= norm_ED * 100).
266 | ned_score = norm_ED / float(length_of_data) * 100
267 |
268 | score = n_correct / float(length_of_data) * 100 # accuracy
269 |
270 | return (
271 | valid_loss_avg.val(),
272 | score,
273 | ned_score,
274 | preds_str,
275 | confidence_score_list,
276 | labels,
277 | infer_time,
278 | length_of_data,
279 | )
280 |
281 |
282 | def test(opt):
283 | """model configuration"""
284 | opt.character = []
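    | # build the character set from the training data's dict.txt, one character per line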
285 | f = open(opt.train_data+"/dict.txt")
286 | line = f.readline()
287 | while line:
288 | opt.character.append(line.strip("\n"))
289 | # print(line)
290 | line = f.readline()
291 | f.close()
292 | if "CTC" in opt.Prediction:
293 | converter = CTCLabelConverter(opt.character)
294 | else:
295 | converter = AttnLabelConverter(opt.character)
296 | opt.sos_token_index = converter.dict["[SOS]"]
297 | opt.eos_token_index = converter.dict["[EOS]"]
298 | opt.num_class = len(converter.character)
299 |
300 | model = Model(opt)
301 | print(
302 | "model input parameters",
303 | opt.imgH,
304 | opt.imgW,
305 | opt.num_fiducial,
306 | opt.input_channel,
307 | opt.output_channel,
308 | opt.hidden_size,
309 | opt.num_class,
310 | opt.batch_max_length,
311 | opt.Transformation,
312 | opt.FeatureExtraction,
313 | opt.SequenceModeling,
314 | opt.Prediction,
315 | )
316 | model = torch.nn.DataParallel(model).to(device)
317 |
318 | # load model
319 | print("loading pretrained model from %s" % opt.saved_model)
320 | try:
321 | model.load_state_dict(torch.load(opt.saved_model, map_location=device))
322 | except:
323 | print(
324 | "*** pretrained model does not match strictly; retrying load_state_dict with strict=False ***"
325 | )
326 | # pretrained_state_dict = torch.load(opt.saved_model)
327 | # for name in pretrained_state_dict:
328 | # print(name)
329 | model.load_state_dict(
330 | torch.load(opt.saved_model, map_location=device), strict=False
331 | )
332 |
333 | opt.exp_name = "_".join(opt.saved_model.split("/")[1:])
334 | # print(model)
335 |
336 | """ keep evaluation model and result logs """
337 | os.makedirs(f"./result/{opt.exp_name}", exist_ok=True)
338 | # os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/')
339 |
340 | """ setup loss """
341 | if "CTC" in opt.Prediction:
342 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
343 | else:
344 | # ignore [PAD] token
345 | criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.dict["[PAD]"]).to(
346 | device
347 | )
348 |
349 | """ evaluation """
350 | model.eval()
351 | with torch.no_grad():
352 | if (
353 | opt.eval_type
354 | ): # evaluate every test set in eval_data_list at once; otherwise evaluate the single opt.eval_data directory
355 | benchmark_all_eval(model, criterion, converter, opt)
356 | else:
357 | log = open(f"./result/{opt.exp_name}/log_evaluation.txt", "a")
358 | AlignCollate_eval = AlignCollate(opt, mode="test")
359 | eval_data, eval_data_log = hierarchical_dataset(
360 | root=opt.eval_data, opt=opt, mode="test"
361 | )
362 | eval_loader = torch.utils.data.DataLoader(
363 | eval_data,
364 | batch_size=opt.batch_size,
365 | shuffle=False,
366 | num_workers=int(opt.workers),
367 | collate_fn=AlignCollate_eval,
368 | pin_memory=True,
369 | )
370 | _, score_by_best_model, ned_score,_, _, _, _, _ = validation(
371 | model, criterion, eval_loader, converter, opt
372 | )
373 | log.write(eval_data_log)
374 | print(f"best acc score {score_by_best_model:0.2f}")
375 | print(f"best ned score {ned_score:0.2f}")
376 | log.write(f"best acc score {score_by_best_model:0.2f}\n")
377 | log.write(f"best ned score {ned_score:0.2f}\n")
378 | log.close()
379 |
380 |
381 | if __name__ == "__main__":
382 | parser = argparse.ArgumentParser()
383 | parser.add_argument(
384 | "--config",
385 | default="config/crnn.py",
386 | help="path to the config file",
387 | )
388 | parser.add_argument("--eval_data", help="path to evaluation dataset")
389 | parser.add_argument(
390 | "--eval_type",
391 | type=str,
392 | help="if set, evaluate every test set in eval_data_list (benchmark_all_eval); otherwise evaluate a single eval_data directory",
393 | )
394 | parser.add_argument(
395 | "--workers", type=int, help="number of data loading workers", default=4
396 | )
397 | parser.add_argument("--batch_size", type=int, default=512, help="input batch size")
398 | parser.add_argument(
399 | "--saved_model", help="path to saved_model to evaluation"
400 | )
401 | parser.add_argument(
402 | "--log_multiple_test", action="store_true", help="log_multiple_test"
403 | )
404 | """ Data processing """
405 | parser.add_argument(
406 | "--batch_max_length", type=int, default=25, help="maximum-label-length"
407 | )
408 | parser.add_argument(
409 | "--imgH", type=int, default=32, help="the height of the input image"
410 | )
411 | parser.add_argument(
412 | "--imgW", type=int, default=100, help="the width of the input image"
413 | )
414 | parser.add_argument(
415 | "--character",
416 | type=str,
417 | default="0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
418 | help="character label",
419 | )
420 | parser.add_argument(
421 | "--NED", action="store_true", help="For Normalized edit_distance"
422 | )
423 | parser.add_argument(
424 | "--Aug",
425 | type=str,
426 | default="None",
427 | help="whether to use augmentation |None|Blur|Crop|Rot|",
428 | )
429 | # parser.add_argument(
430 | # "--semi",
431 | # type=str,
432 | # default="None",
433 | # help="whether to use semi-supervised learning |None|PL|MT|",
434 | # )
435 | """ Model Architecture """
436 | parser.add_argument("--model_name", type=str, help="CRNN|TRBA")
437 | parser.add_argument(
438 | "--num_fiducial",
439 | type=int,
440 | default=20,
441 | help="number of fiducial points of TPS-STN",
442 | )
443 | parser.add_argument(
444 | "--input_channel",
445 | type=int,
446 | default=3,
447 | help="the number of input channel of Feature extractor",
448 | )
449 | parser.add_argument(
450 | "--output_channel",
451 | type=int,
452 | default=512,
453 | help="the number of output channel of Feature extractor",
454 | )
455 | parser.add_argument(
456 | "--hidden_size", type=int, default=256, help="the size of the LSTM hidden state"
457 | )
458 |
459 | arg = parser.parse_args()
460 | cfg = Config.fromfile(arg.config)
461 | # optcfg.model
462 | # opt.update(arg)
463 | # cfg.merge_from_dict(cfg.model)
464 | # opt.merge_from_dict(cfg.train)
465 | # opt.merge_from_dict(cfg.optimizer)
466 |
467 | opt = {}
468 | opt.update(cfg.common)
469 | opt.update(cfg.model)
470 | opt.update(cfg.train)
471 | opt.update(cfg.optimizer)
472 | opt.update(cfg.test)
473 | opt = argparse.Namespace(**opt)
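    | # note: apart from --config, the CLI flags parsed above are not merged into opt (opt.update(arg) is left commented out); settings come from the config file sections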
474 | # opt.saved_model=cfg.test.saved_model
475 | # print(cfg.test.saved_model)
476 | if opt.model_name == "CRNN":
477 | opt.Transformation = "None"
478 | opt.FeatureExtraction = "VGG"
479 | opt.SequenceModeling = "BiLSTM"
480 | opt.Prediction = "CTC"
481 |
482 | elif opt.model_name == "TRBA":
483 | opt.Transformation = "TPS"
484 | opt.FeatureExtraction = "ResNet"
485 | opt.SequenceModeling = "BiLSTM"
486 | opt.Prediction = "Attn"
487 |
488 | elif opt.model_name == "RBA": # RBA
489 | opt.Transformation = "None"
490 | opt.FeatureExtraction = "ResNet"
491 | opt.SequenceModeling = "BiLSTM"
492 | opt.Prediction = "Attn"
493 |
494 | cudnn.benchmark = True
495 | cudnn.deterministic = True
496 | opt.num_gpu = torch.cuda.device_count()
497 | if opt.num_gpu > 1:
498 | print(
499 | "We recommend using 1 GPU. Check your GPU count; you may have missed CUDA_VISIBLE_DEVICES=0 or made a typo."
500 | )
501 | print("To use a multi-GPU setting, remove or comment out these lines")
502 | sys.exit()
503 |
504 | if sys.platform == "win32":
505 | opt.workers = 0
506 |
507 | os.makedirs(f"./evaluation_log", exist_ok=True)
508 |
509 | test(opt)
510 |
--------------------------------------------------------------------------------
/modules/model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from einops import rearrange
4 | import torch.nn as nn
5 |
6 | from modules.dm_router import DM_Router
7 | from modules.transformation import TPS_SpatialTransformerNetwork
8 | from modules.feature_extraction import (
9 | VGG_FeatureExtractor,
10 | RCNN_FeatureExtractor,
11 | ResNet_FeatureExtractor, SVTR_FeatureExtractor,
12 | )
13 | from modules.sequence_modeling import BidirectionalLSTM
14 | from modules.prediction import Attention
15 |
16 |
17 | class Model_Extractor(nn.Module):
18 | def __init__(self, opt):
19 | super(Model_Extractor, self).__init__()
20 | self.opt = opt
21 | self.stages = {
22 | "Trans": opt.Transformation,
23 | "Feat": opt.FeatureExtraction,
24 | "Seq": opt.SequenceModeling,
25 | "Pred": opt.Prediction,
26 | }
27 |
28 | """ Transformation """
29 | if opt.Transformation == "TPS":
30 | self.Transformation = TPS_SpatialTransformerNetwork(
31 | F=opt.num_fiducial,
32 | I_size=(opt.imgH, opt.imgW),
33 | I_r_size=(opt.imgH, opt.imgW),
34 | I_channel_num=opt.input_channel,
35 | )
36 | else:
37 | print("No Transformation module specified")
38 |
39 | """ FeatureExtraction """
40 | if opt.FeatureExtraction == "VGG":
41 | self.FeatureExtraction = VGG_FeatureExtractor(
42 | opt.input_channel, opt.output_channel
43 | )
44 | elif opt.FeatureExtraction == "RCNN":
45 | self.FeatureExtraction = RCNN_FeatureExtractor(
46 | opt.input_channel, opt.output_channel
47 | )
48 | elif opt.FeatureExtraction == "ResNet":
49 | self.FeatureExtraction = ResNet_FeatureExtractor(
50 | opt.input_channel, opt.output_channel
51 | )
52 | elif opt.FeatureExtraction == "SVTR":
53 | self.FeatureExtraction = SVTR_FeatureExtractor(
54 | opt.input_channel, opt.output_channel
55 | )
56 | else:
57 | raise Exception("No FeatureExtraction module specified")
58 | self.FeatureExtraction_output = opt.output_channel
59 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
60 | (None, 1)
61 | ) # Transform final (imgH/16-1) -> 1
62 |
63 | """Sequence modeling"""
64 | if opt.SequenceModeling == "BiLSTM":
65 | self.SequenceModeling = nn.Sequential(
66 | BidirectionalLSTM(
67 | self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size
68 | ),
69 | BidirectionalLSTM(
70 | opt.hidden_size, opt.hidden_size, opt.hidden_size
71 | ),
72 | )
73 | self.SequenceModeling_output = opt.hidden_size
74 | else:
75 | self.SequenceModeling = nn.Sequential(
76 | nn.Linear(
77 | self.FeatureExtraction_output, opt.hidden_size)
78 | )
79 | print("No SequenceModeling module specified")
80 | self.SequenceModeling_output = opt.hidden_size
81 |
82 | def forward(self, image):
83 | """Transformation stage"""
84 | if self.stages["Trans"] != "None":
85 | image = self.Transformation(image)
86 |
87 | """ Feature extraction stage """
88 | visual_feature = self.FeatureExtraction(image)
89 | visual_feature = visual_feature.permute(
90 | 0, 3, 1, 2
91 | ) # [b, c, h, w] -> [b, w, c, h]
92 | visual_feature = self.AdaptiveAvgPool(
93 | visual_feature
94 | ) # [b, w, c, h] -> [b, w, c, 1]
95 | visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c]
96 |
97 | """ Sequence modeling stage """
98 | contextual_feature = self.SequenceModeling(
99 | visual_feature
100 | ) # [b, num_steps, opt.hidden_size]
101 | return contextual_feature # [b, num_steps, opt.hidden_size]
102 |
103 |
104 |
105 | class Model(nn.Module):
106 | def __init__(self, opt):
107 | super(Model, self).__init__()
108 | self.opt = opt
109 | self.model = Model_Extractor(opt)
110 | self.SequenceModeling_output = self.model.SequenceModeling_output
111 | self.stages = {
112 | "Pred": opt.Prediction,
113 | }
114 | self.fc = None
115 | self.Prediction=None
116 |
117 | def reset_class(self, opt, device):
118 |
119 | """Prediction"""
120 | if opt.Prediction == "CTC":
121 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
122 | elif opt.Prediction == "Attn":
123 | self.Prediction = Attention(
124 | self.SequenceModeling_output, opt.hidden_size, opt.num_class
125 | )
126 | else:
127 | raise Exception("Prediction is neither CTC nor Attn")
128 |
129 | self.Prediction.to(device)
130 |
131 |
132 |
133 | def forward(self, image, text=None, is_train=True):
134 | """Backbone stage: transformation, feature extraction, and sequence modeling"""
135 | contextual_feature = self.model(image)
136 | """ Prediction stage """
137 | if self.stages["Pred"] == "CTC":
138 | prediction = self.Prediction(contextual_feature.contiguous())
139 | else:
140 | prediction = self.Prediction(
141 | contextual_feature.contiguous(),
142 | text,
143 | is_train,
144 | batch_max_length=self.opt.batch_max_length,
145 | )
146 |
147 | # return prediction # [b, num_steps, opt.num_class]
148 | return {"predict":prediction,"feature":contextual_feature}
149 |
150 | def update_fc(self, hidden_size, nb_classes,device=None):
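    | # grow the classifier to nb_classes, copying the previously learned class weights and biases into the new layer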
151 | fc = nn.Linear(hidden_size, nb_classes)
152 | if self.fc is not None:
153 | nb_output = self.fc.out_features
154 | weight = copy.deepcopy(self.fc.weight.data)
155 | bias = copy.deepcopy(self.fc.bias.data)
156 | fc.weight.data[:nb_output] = weight
157 | fc.bias.data[:nb_output] = bias
158 |
159 | # del self.fc
160 | self.fc = fc
161 |
162 | def new_fc(self, hidden_size, nb_classes):
163 | # print("new_fc")
164 | self.fc = nn.Linear(hidden_size, nb_classes)
165 |
166 | def weight_align(self, increment):
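    | # WA: rescale the newest `increment` class weights so their mean L2 norm matches the old classes' (gamma = mean_old / mean_new)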
167 | weights=self.fc.weight.data
168 | newnorm=(torch.norm(weights[-increment:,:],p=2,dim=1))
169 | oldnorm=(torch.norm(weights[:-increment,:],p=2,dim=1))
170 | meannew=torch.mean(newnorm)
171 | meanold=torch.mean(oldnorm)
172 | gamma=meanold/meannew
173 | print("weight_align: gamma =", gamma)
174 | self.fc.weight.data[-increment:,:]*=gamma
175 |
176 | def build_prediction(self,opt,num_class):
177 | """Prediction"""
178 | # print("build_prediction")
179 | if opt.Prediction == "CTC":
180 | # self.fc = nn.Linear(self.SequenceModeling_output, num_class)
181 | self.Prediction = self.fc
182 | # self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
183 | elif opt.Prediction == "Attn":
184 | # self.fc = nn.Linear(opt.hidden_size, num_class)
185 | self.Prediction = Attention(
186 | self.SequenceModeling_output, opt.hidden_size, num_class,self.fc
187 | )
188 | else:
189 | raise Exception("Prediction is neither CTC nor Attn")
190 |
191 | def copy(self):
192 | return copy.deepcopy(self)
193 |
194 | def freeze(self):
195 | for param in self.parameters():
196 | param.requires_grad = False
197 | self.eval()
198 |
199 | return self
200 |
201 |
202 |
203 | class DERNet(Model):
204 | def __init__(self, opt):
205 | super(DERNet,self).__init__(opt)
206 | self.model = nn.ModuleList()
207 | self.out_dim = None
208 | self.fc = None
209 | self.aux_fc=None
210 | self.task_sizes = []
211 |
212 | @property
213 | def feature_dim(self):
214 | if self.out_dim is None:
215 | return 0
216 | return self.out_dim*len(self.model)
217 |
218 | def extract_vector(self, x):
219 | features = [convnet(x) for convnet in self.model]
220 | features = torch.cat(features, 1)
221 | return features
222 |
223 | def forward(self, image, text=None, is_train=True):
224 | """Feature stage: run every task extractor and concatenate the features"""
225 | features = [convnet(image) for convnet in self.model]
226 | contextual_feature = torch.cat(features, -1)
227 |
228 | """ Prediction stage """
229 | if self.stages["Pred"] == "CTC":
230 | prediction = self.Prediction(contextual_feature.contiguous())
231 | else:
232 | prediction = self.Prediction(
233 | contextual_feature.contiguous(),
234 | text,
235 | is_train,
236 | batch_max_length=self.opt.batch_max_length,
237 | )
238 |
239 | """ Auxiliary prediction stage (newest extractor's features only) """
240 | if self.stages["Pred"] == "CTC":
241 | aux_logits = self.aux_Prediction(contextual_feature[:,:,-self.out_dim:].contiguous())
242 | else:
243 | aux_logits = self.aux_Prediction(
244 | contextual_feature[:,:,-self.out_dim:].contiguous(),
245 | text,
246 | is_train,
247 | batch_max_length=self.opt.batch_max_length,
248 | )
249 | # out=self.fc(features) #{logics: self.fc(features)}
250 | out = dict({"logits":prediction})
251 | # aux_logits=self.aux_fc(contextual_feature[:,-self.out_dim:])
252 |
253 | out.update({"aux_logits":aux_logits,"features":contextual_feature.contiguous()})
254 | return out # [b, num_steps, opt.num_class]
255 |
256 | def update_fc(self, hidden_size, nb_classes,device=None):
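    | # DER: add a new feature extractor for the new task (initialized from the previous one), rebuild the classifier while keeping old-class weights, and create an auxiliary classifier (aux_fc) over the newest extractor's features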
257 | if len(self.model)==0:
258 | self.model.append(Model_Extractor(self.opt))
259 | else:
260 | self.model.append(Model_Extractor(self.opt))
261 | self.model[-1].load_state_dict(self.model[-2].state_dict())
262 |
263 | if self.out_dim is None:
264 | self.out_dim=self.model[-1].SequenceModeling_output
265 | fc = nn.Linear(self.feature_dim if self.opt.Prediction=="CTC" else self.out_dim, nb_classes)
266 | if self.fc is not None:
267 | nb_output = self.fc.out_features
268 | weight = copy.deepcopy(self.fc.weight.data)
269 | bias = copy.deepcopy(self.fc.bias.data)
270 | fc.weight.data[:nb_output,:self.feature_dim-self.out_dim] = weight
271 | fc.bias.data[:nb_output] = bias
272 |
273 | del self.fc
274 | self.fc = fc
275 | # new_task_size = nb_classes - sum(self.task_sizes)
276 | # self.task_sizes.append(new_task_size)
277 |
278 | self.aux_fc= nn.Linear(self.out_dim,nb_classes)
279 |
280 | def build_prediction(self,opt,num_class):
281 | """Prediction"""
282 | # print("build_prediction")
283 | if opt.Prediction == "CTC":
284 | # self.fc = nn.Linear(self.SequenceModeling_output, num_class)
285 | self.Prediction = self.fc
286 | # self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
287 | elif opt.Prediction == "Attn":
288 | # self.fc = nn.Linear(opt.hidden_size, num_class)
289 | self.Prediction = Attention(
290 | self.feature_dim, opt.hidden_size, num_class,self.fc
291 | )
292 | else:
293 | raise Exception("Prediction is neither CTC nor Attn")
294 |
295 | def build_aux_prediction(self,opt,num_class):
296 | """Prediction"""
297 | if opt.Prediction == "CTC":
298 | # self.aux_fc = nn.Linear(self.SequenceModeling_output, num_class)
299 | self.aux_Prediction = self.aux_fc
300 | # self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
301 | elif opt.Prediction == "Attn":
302 | # self.aux_fc = nn.Linear(opt.hidden_size, num_class)
303 | self.aux_Prediction = Attention(
304 | self.SequenceModeling_output, opt.hidden_size, num_class,self.aux_fc
305 | )
306 | else:
307 | raise Exception("Prediction is neither CTC nor Attn")
308 |
309 | def freeze_conv(self):
310 | for param in self.model.parameters():
311 | param.requires_grad = False
312 | self.model.eval()
313 |
314 | class MRNNet(nn.Module):
315 | def __init__(self, opt):
316 | super(MRNNet, self).__init__()
317 | self.model = nn.ModuleList()
318 | self.out_dim=None
319 | self.fc = None
320 | self.opt = opt
321 | self.task_sizes = []
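    | # self.patch is the router's input length: the number of time steps (feature-map width) each backbone produces for the configured input size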
322 | if self.opt.FeatureExtraction == "VGG":
323 | self.patch = 63
324 | elif self.opt.FeatureExtraction == "SVTR":
325 | self.patch = 64
326 | elif self.opt.FeatureExtraction == "ResNet":
327 | self.patch = 65
328 | self.router = "dm-router" #dm-router
329 | self.layer_num = 1
330 | self.beta = 1
331 |
332 | @property
333 | def feature_dim(self):
334 | if self.out_dim is None:
335 | return 0
336 | return self.out_dim*len(self.model)
337 |
338 | def extract_vector(self, x):
339 | features = [convnet(x) for convnet in self.model]
340 | features = torch.cat(features, 1)
341 | return features
342 |
343 | def forward(self, image, cross = True,text=None, is_train=True):
344 | """Routing stage: cross=False uses the newest expert only; evaluation uses hard expert selection; training uses soft routing"""
345 | # features = [convnet(image) for convnet in self.model]
346 | if cross==False:
347 | features = self.model[-1](image,text,is_train)["predict"]
348 | index = None
349 | # elif is_train == False:
350 | # features, index = self.cross_test(image)
351 | elif is_train == False:
352 | features, index = self.cross_forward_expert(image, text, is_train)
353 | else:
354 | # features,index = self.cross_forwardv2(image)
355 | features, index = self.cross_forward(image,text,is_train)
356 | # out=self.fc(features) #{logics: self.fc(features)}
357 | out = dict({"logits":features,"index":index,"aux_logits":None})
358 |
359 | return out # [b, num_steps, opt.num_class]
360 |
361 | def pad_zeros_features(self,feature,total):
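    | # pad the class dimension from `know` up to `total` so earlier experts' predictions can be stacked with the newest one (the padding value is ones, per torch.ones)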
362 | B,T,know = feature.size()
363 | zero = torch.ones([B,T,total-know],dtype=torch.float).to(feature.device)
364 | return torch.cat([feature,zero],dim=-1)
365 |
366 | def cross_forward_expert(self, image, text=None, is_train=True):
367 | """Hard routing (inference): pick one expert per sample via argmax over the DM-Router scores"""
368 | features = [convnet(image,text,is_train) for convnet in self.model]
369 | route_info = torch.stack([feature["feature"] for feature in features], 1)
370 | route_info = self.dm_router(route_info)
371 | route_info = rearrange(route_info, 'b h w c -> b w (h c)')
372 | route_info = self.channel_route(route_info)
373 | # route_info = torch.cat([torch.max(feature,-1)[0] for feature in features],-1)
374 | index = self.route(route_info.permute(0, 2, 1).contiguous())
375 | # index = self.softargmax1d(torch.squeeze(index, -1),self.beta)
376 | index = torch.squeeze(index, -1)
377 | index = torch.max(index, -1)[1]
378 |
379 | # index [B,I]
380 | # route_info [B,T,I]
381 |
382 | # feature_array = torch.stack(features, 1)
383 | features = [feature["predict"] for feature in features]
384 | B, T, C = features[-1].size()
385 | list_len = len(features)
386 | normal_feat = []
387 | for i in range(list_len - 1):
388 | feat = self.pad_zeros_features(features[i], total=C)
389 | normal_feat.append(feat)
390 | normal_feat.append(features[-1])
391 | normal_feat = torch.stack(normal_feat, 0)
392 | # normal_feat [I,B,T,C]: select each sample's routed expert output by its argmax index
393 | output = torch.stack([normal_feat[index_one][i, :, :] for i, index_one in enumerate(index)], 0)
394 |
395 | return output.contiguous(),index
396 |
397 | def cross_forward(self, image, text=None, is_train=True):
398 | """Soft routing (training): combine padded expert predictions with softargmax routing weights"""
399 | features = [convnet(image,text,is_train) for convnet in self.model]
400 | route_info = torch.stack([feature["feature"] for feature in features], 1)
401 | route_info = self.dm_router(route_info)
402 | route_info = rearrange(route_info, 'b h w c -> b w (h c)')
403 | route_info = self.channel_route(route_info)
404 | # route_info = torch.cat([torch.max(feature,-1)[0] for feature in features],-1)
405 | index = self.route(route_info.permute(0, 2, 1).contiguous())
406 | index = self.softargmax1d(torch.squeeze(index, -1),self.beta)
407 | # index [B,I]
408 | # route_info [B,T,I]
409 |
410 | features = [feature["predict"] for feature in features]
411 | B, T, C = features[-1].size()
412 | list_len = len(features)
413 | normal_feat = []
414 | for i in range(list_len - 1):
415 | feat = self.pad_zeros_features(features[i], total=C)
416 | normal_feat.append(feat)
417 | normal_feat.append(features[-1])
418 | normal_feat = torch.stack(normal_feat, 0)
419 | # normal_feat [I,B,T,C] -> [T,C,B,I] -> [B,T,C,I]
420 | output = (normal_feat.permute(2, 3, 1, 0) * index).permute(2, 0, 1, 3).contiguous()
421 | # output = (normal_feat.permute(3,1,2,0) * route_info).permute(1,2,0,3).contiguous()
422 |
423 | return torch.sum(output, -1), index
424 |
425 | def build_fc(self, hidden_size, nb_classes):
426 | self.update_fc(hidden_size, nb_classes)
427 |
428 | def update_fc(self, hidden_size, nb_classes):
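    | # MRN: add a full recognizer (expert) for the new task, then (re)build the routing layers (channel_route, route, and the DM-Router) for the current number of experts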
429 | self.model.append(Model(self.opt))
430 | self.model[-1].new_fc(hidden_size,nb_classes)
431 | # self.model[-1].load_state_dict(self.model[-2].state_dict())
432 |
433 | if self.out_dim is None:
434 | self.out_dim=self.model[-1].SequenceModeling_output
435 |
436 |
437 | self.route = nn.Linear(self.patch , 1)
438 | self.channel_route = nn.Linear(self.feature_dim, len(self.model))
439 | # if self.router == "gmlp":
440 | # block = GatingMlpBlock(self.out_dim, self.out_dim * 2, self.patch)
441 | # elif self.router == "vip":
442 | # block = PermutatorBlock(self.out_dim, 2, taski = len(self.model), patch = self.patch)
443 | # el
444 | if self.router == "dm-router":
445 | block = DM_Router(self.out_dim, self.out_dim * 2, self.patch,len(self.model))
446 | else:
447 | block = nn.Linear(self.out_dim, self.out_dim )
448 | layers=[]
449 | for _ in range(self.layer_num):
450 | layers.append(block)
451 | print("mlp {} has {} layers".format(block, len(layers)))
452 | self.dm_router = nn.Sequential(*layers)
453 | # [b, num_steps * len] -> [b, len]
454 | # if self.fc is not None:
455 | # nb_output = self.fc.out_features
456 | # weight = copy.deepcopy(self.fc.weight.data)
457 | # bias = copy.deepcopy(self.fc.bias.data)
458 | # fc.weight.data[:nb_output,:self.feature_dim-self.out_dim] = weight
459 | # fc.bias.data[:nb_output] = bias
460 | #
461 | # del self.fc
462 | # self.fc = fc
463 | # fc = nn.Linear(self.feature_dim, nb_classes)
464 | def load_fc(self,input,output):
465 | fc = nn.Linear(input,output)
466 | if self.channel_route is not None:
467 | nb_output = self.channel_route.out_features
468 | weight = copy.deepcopy(self.channel_route.weight.data)
469 | bias = copy.deepcopy(self.channel_route.bias.data)
470 | fc.weight.data[:nb_output,:self.feature_dim-self.out_dim] = weight
471 | fc.bias.data[:nb_output] = bias
472 |
473 | del self.fc
474 | self.fc = fc
475 |
476 | def build_prediction(self,opt,num_class):
477 | """Prediction"""
478 | if opt.Prediction == "CTC" or opt.Prediction == "Attn":
479 | # self.fc = nn.Linear(self.SequenceModeling_output, num_class)
480 | # self.Prediction = self.fc
481 | self.model[-1].build_prediction(opt,num_class)
482 | else:
483 | raise Exception("Prediction is neither CTC nor Attn")
484 |
485 | def copy(self):
486 | return copy.deepcopy(self)
487 |
488 | def freeze(self):
489 | for param in self.parameters():
490 | param.requires_grad = False
491 | self.eval()
492 |
493 | return self
494 |
495 | def softargmax1d(self,input, beta=5):
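    | # temperature-scaled softmax (beta) used as a differentiable stand-in for argmax over the expert routing scores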
496 | return nn.functional.softmax(beta * input, dim=-1)
497 |
498 |
499 |
--------------------------------------------------------------------------------