├── config ├── __init__.py ├── crnn_der.py ├── svtr_der.py ├── trba_der.py ├── trba_mrn.py ├── crnn_mrn.py └── svtr_mrn.py ├── data ├── __init__.py ├── dataset.py ├── data_manage.py └── transform.py ├── tools ├── __init__.py ├── utils.py └── crop_by_word.py ├── il_modules ├── __init__.py ├── joint.py ├── wa.py ├── lwf.py ├── ewc.py └── der.py ├── data.sh ├── train.sh ├── demo_image └── teaser.png ├── .idea ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── remote-mappings.xml ├── STR-Fewer-Labels.iml ├── STR-Fewer-Labels-main.iml └── deployment.xml ├── modules ├── sequence_modeling.py ├── dm_router.py ├── prediction.py ├── transformation.py ├── mlp.py └── model.py ├── LICENSE ├── .gitignore ├── README.md ├── tiny_train.py └── test.py /config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /il_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data.sh: -------------------------------------------------------------------------------- 1 | python tools/crop_by_word.py ../dataset/SynthMLT/ Latin txt 2 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python3 tiny_train.py --config=config/crnn_3090.py --exp_name CRNN_real 2 | -------------------------------------------------------------------------------- /demo_image/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/MRN/HEAD/demo_image/teaser.png -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /.idea/remote-mappings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/STR-Fewer-Labels.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/STR-Fewer-Labels-main.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | def __init__(self, input_size, hidden_size, output_size): 6 | super(BidirectionalLSTM, self).__init__() 7 | self.rnn = nn.LSTM( 8 | input_size, hidden_size, bidirectional=True, batch_first=True 9 | ) 10 | self.linear = nn.Linear(hidden_size * 2, output_size) 11 | 12 | def forward(self, input): 13 | """ 14 | input : visual feature [batch_size x T x input_size], T = num_steps. 15 | output : contextual feature [batch_size x T x output_size] 16 | """ 17 | self.rnn.flatten_parameters() 18 | recurrent, _ = self.rnn( 19 | input 20 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 21 | output = self.linear(recurrent) # batch_size x T x output_size 22 | return output 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Baek JeongHun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 36 | -------------------------------------------------------------------------------- /config/crnn_der.py: -------------------------------------------------------------------------------- 1 | common=dict( 2 | exp_name="CRNN_DER", # Where to store logs and models 3 | il="der", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 4 | memory="random", # None | random 5 | memory_num=2000, 6 | batch_max_length = 25, 7 | imgH = 32, 8 | imgW = 256, 9 | manual_seed=111, 10 | start_task = 0 11 | ) 12 | 13 | 14 | """ Model Architecture """ 15 | model=dict( 16 | model_name="CRNN", 17 | Transformation = "None", #None TPS 18 | FeatureExtraction = "VGG", #VGG ResNet 19 | SequenceModeling = "BiLSTM", #None BiLSTM 20 | Prediction = "CTC", #CTC Attn 21 | num_fiducial=20, 22 | input_channel=4, 23 | output_channel=512, 24 | hidden_size=256, 25 | ) 26 | 27 | 28 | """ Optimizer """ 29 | optimizer=dict( 30 | schedule="super", #default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER 31 | optimizer="adam", 32 | lr=0.0005, 33 | sgd_momentum=0.9, 34 | sgd_weight_decay=0.000001, 35 | milestones=[2000,4000], 36 | lrate_decay=0.1, 37 | rho=0.95, 38 | eps=1e-8, 39 | lr_drop_rate=0.1 40 | ) 41 | 42 | 43 | """ Data processing """ 44 | train = dict( 45 | saved_model="", # "path to model to continue training" 46 | Aug="None", # |None|Blur|Crop|Rot|ABINet 47 | workers=4, 48 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla"], 49 | valid_datas=[ 50 | "../dataset/MLT17_IL/test_2017", 51 | "../dataset/MLT19_IL/test_2019" 52 | ], 53 | select_data=[ 54 | "../dataset/MLT17_IL/train_2017", 55 | "../dataset/MLT19_IL/train_2019" 56 | ], 57 | batch_ratio="0.5-0.5", 58 | total_data_usage_ratio="1.0", 59 | NED=True, 60 | batch_size=256, 61 | num_iter=10000, 62 | val_interval=5000, 63 | log_multiple_test=None, 64 | grad_clip=5, 65 | ) 66 | 67 | 68 | -------------------------------------------------------------------------------- /config/svtr_der.py: -------------------------------------------------------------------------------- 1 | common=dict( 2 | exp_name="SVTR_DER", # Where to store logs and models 3 | il="der", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 4 | memory="random", # None | random 5 | memory_num=2000, 6 | batch_max_length = 25, 7 | imgH = 32, 8 | imgW = 256, 9 | manual_seed=111, 10 | start_task = 0 11 | ) 12 | 13 | 14 | """ Model Architecture """ 15 | model=dict( 16 | model_name="SVTR", 17 | Transformation = "None", #None TPS 18 | FeatureExtraction = "SVTR", #VGG ResNet 19 | SequenceModeling = "None", #None BiLSTM 20 | Prediction = "CTC", #CTC Attn 21 | num_fiducial=20, 22 | input_channel=4, 23 | output_channel=512, 24 | hidden_size=256, 25 | ) 26 | 27 | 28 | """ Optimizer """ 29 | optimizer=dict( 30 | schedule="super", #default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER 31 | optimizer="adam", 32 | lr=0.0005, 33 | sgd_momentum=0.9, 34 | sgd_weight_decay=0.000001, 35 | milestones=[2000,4000], 36 | lrate_decay=0.1, 37 | rho=0.95, 38 | eps=1e-8, 39 | lr_drop_rate=0.1 40 | ) 41 | 42 | 43 | """ Data processing """ 44 | train = dict( 45 | saved_model="", 
# "path to model to continue training" 46 | Aug="None", # |None|Blur|Crop|Rot|ABINet 47 | workers=4, 48 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla"], 49 | valid_datas=[ 50 | "../dataset/MLT17_IL/test_2017", 51 | "../dataset/MLT19_IL/test_2019" 52 | ], 53 | select_data=[ 54 | "../dataset/MLT17_IL/train_2017", 55 | "../dataset/MLT19_IL/train_2019" 56 | ], 57 | batch_ratio="0.5-0.5", 58 | total_data_usage_ratio="1.0", 59 | NED=True, 60 | batch_size=256, 61 | num_iter=10000, 62 | val_interval=5000, 63 | log_multiple_test=None, 64 | grad_clip=5, 65 | ) 66 | 67 | 68 | -------------------------------------------------------------------------------- /config/trba_der.py: -------------------------------------------------------------------------------- 1 | common=dict( 2 | exp_name="TRBA_DER", # Where to store logs and models 3 | il="der", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 4 | memory="random", # None | random 5 | memory_num=2000, 6 | batch_max_length = 25, 7 | imgH = 32, 8 | imgW = 256, 9 | manual_seed=111, 10 | start_task = 0 11 | ) 12 | 13 | 14 | """ Model Architecture """ 15 | model=dict( 16 | model_name="TRBA", 17 | Transformation = "TPS", #None TPS 18 | FeatureExtraction = "ResNet", #VGG ResNet 19 | SequenceModeling = "BiLSTM", #None BiLSTM 20 | Prediction = "Attn", #CTC Attn 21 | num_fiducial=20, 22 | input_channel=4, 23 | output_channel=512, 24 | hidden_size=256, 25 | ) 26 | 27 | 28 | """ Optimizer """ 29 | optimizer=dict( 30 | schedule="super", #default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER 31 | optimizer="adam", 32 | lr=0.0005, 33 | sgd_momentum=0.9, 34 | sgd_weight_decay=0.000001, 35 | milestones=[2000,4000], 36 | lrate_decay=0.1, 37 | rho=0.95, 38 | eps=1e-8, 39 | lr_drop_rate=0.1 40 | ) 41 | 42 | 43 | """ Data processing """ 44 | train = dict( 45 | saved_model="", # "path to model to continue training" 46 | Aug="None", # |None|Blur|Crop|Rot|ABINet 47 | workers=4, 48 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla"], 49 | valid_datas=[ 50 | "../dataset/MLT17_IL/test_2017", 51 | "../dataset/MLT19_IL/test_2019" 52 | ], 53 | select_data=[ 54 | "../dataset/MLT17_IL/train_2017", 55 | "../dataset/MLT19_IL/train_2019" 56 | ], 57 | batch_ratio="0.5-0.5", 58 | total_data_usage_ratio="1.0", 59 | NED=True, 60 | batch_size=256, 61 | num_iter=10000, 62 | val_interval=5000, 63 | log_multiple_test=None, 64 | grad_clip=5, 65 | ) 66 | 67 | 68 | -------------------------------------------------------------------------------- /config/trba_mrn.py: -------------------------------------------------------------------------------- 1 | common=dict( 2 | exp_name="TRBA_MRN", # Where to store logs and models 3 | il="mrn", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 4 | memory="random", # None | random 5 | memory_num=2000, 6 | batch_max_length = 25, 7 | imgH = 32, 8 | imgW = 256, 9 | manual_seed=111, 10 | start_task = 0 11 | ) 12 | 13 | 14 | """ Model Architecture """ 15 | model=dict( 16 | model_name="TRBA", 17 | Transformation = "TPS", #None TPS 18 | FeatureExtraction = "ResNet", #VGG ResNet 19 | SequenceModeling = "BiLSTM", #None BiLSTM 20 | Prediction = "Attn", #CTC Attn 21 | num_fiducial=20, 22 | input_channel=4, 23 | output_channel=512, 24 | hidden_size=256, 25 | ) 26 | 27 | 28 | 29 | """ Optimizer """ 30 | optimizer=dict( 31 | schedule="super", #default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER 32 | optimizer="adam", 33 | 
lr=0.0005, 34 | sgd_momentum=0.9, 35 | sgd_weight_decay=0.000001, 36 | milestones=[2000,4000], 37 | lrate_decay=0.1, 38 | rho=0.95, 39 | eps=1e-8, 40 | lr_drop_rate=0.1 41 | ) 42 | 43 | 44 | """ Data processing """ 45 | train = dict( 46 | saved_model="", # "path to model to continue training" 47 | Aug="None", # |None|Blur|Crop|Rot|ABINet 48 | workers=4, 49 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla"], 50 | valid_datas=[ 51 | "../dataset/MLT17_IL/test_2017", 52 | "../dataset/MLT19_IL/test_2019" 53 | ], 54 | select_data=[ 55 | "../dataset/MLT17_IL/train_2017", 56 | "../dataset/MLT19_IL/train_2019" 57 | ], 58 | batch_ratio="0.5-0.5", 59 | total_data_usage_ratio="1.0", 60 | NED=True, 61 | batch_size=256, 62 | num_iter=10000, 63 | val_interval=5000, 64 | log_multiple_test=None, 65 | grad_clip=5, 66 | ) 67 | 68 | 69 | -------------------------------------------------------------------------------- /config/crnn_mrn.py: -------------------------------------------------------------------------------- 1 | common=dict( 2 | exp_name="CRNN_MRN", # Where to store logs and models 3 | il="mrn", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 4 | memory="random", # None | random 5 | memory_num=2000, 6 | batch_max_length = 25, 7 | imgH = 32, 8 | imgW = 256, 9 | manual_seed=111, 10 | start_task = 1 11 | ) 12 | 13 | 14 | """ Model Architecture """ 15 | model=dict( 16 | model_name="CRNN", 17 | Transformation = "None", #None TPS 18 | FeatureExtraction = "VGG", #VGG ResNet 19 | SequenceModeling = "BiLSTM", #None BiLSTM 20 | Prediction = "CTC", #CTC Attn 21 | num_fiducial=20, 22 | input_channel=4, 23 | output_channel=512, 24 | hidden_size=256, 25 | ) 26 | 27 | 28 | """ Optimizer """ 29 | optimizer=dict( 30 | schedule="super", #default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER 31 | optimizer="adam", 32 | lr=0.0005, 33 | sgd_momentum=0.9, 34 | sgd_weight_decay=0.000001, 35 | milestones=[2000,4000], 36 | lrate_decay=0.1, 37 | rho=0.95, 38 | eps=1e-8, 39 | lr_drop_rate=0.1 40 | ) 41 | 42 | 43 | """ Data processing """ 44 | train = dict( 45 | saved_model="", # "path to model to continue training" 46 | Aug="None", # |None|Blur|Crop|Rot|ABINet 47 | workers=4, 48 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla",], 49 | valid_datas=[ 50 | "../dataset/MLT17_IL/test_2017", 51 | "../dataset/MLT19_IL/test_2019" 52 | ], 53 | select_data=[ 54 | "../dataset/MLT17_IL/train_2017", 55 | "../dataset/MLT19_IL/train_2019" 56 | ], 57 | # train_data="../dataset/MLT17/train_2017", # stash 58 | # valid_data="../dataset/MLT17/test_2017", # stash 59 | batch_ratio="0.5-0.5", 60 | total_data_usage_ratio="1.0", 61 | NED=True, 62 | batch_size=256, 63 | num_iter=10000, 64 | val_interval=5000, 65 | log_multiple_test=None, 66 | grad_clip=5, 67 | ) 68 | -------------------------------------------------------------------------------- /config/svtr_mrn.py: -------------------------------------------------------------------------------- 1 | common=dict( 2 | exp_name="CRNN_MRN", # Where to store logs and models 3 | il="mrn", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 4 | memory="random", # None | random 5 | # None | random | test (just for ems) |large | total | score | cof_max | rehearsal | 6 | memory_num=2000, 7 | batch_max_length = 25, 8 | imgH = 32, 9 | imgW = 256, 10 | manual_seed=111, 11 | start_task = 0 12 | ) 13 | 14 | 15 | """ Model Architecture """ 16 | model=dict( 17 | model_name="SVTR", 18 | Transformation = "None", 
#None TPS 19 | FeatureExtraction = "SVTR", #VGG ResNet 20 | SequenceModeling = "None", #None BiLSTM 21 | Prediction = "CTC", #CTC Attn 22 | num_fiducial=20, 23 | input_channel=4, 24 | output_channel=512, 25 | hidden_size=256, 26 | ) 27 | 28 | 29 | """ Optimizer """ 30 | optimizer=dict( 31 | schedule="super", #default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER 32 | optimizer="adam", 33 | lr=0.0005, 34 | sgd_momentum=0.9, 35 | sgd_weight_decay=0.000001, 36 | milestones=[2000,4000], 37 | lrate_decay=0.1, 38 | rho=0.95, 39 | eps=1e-8, 40 | lr_drop_rate=0.1 41 | ) 42 | 43 | 44 | """ Data processing """ 45 | train = dict( 46 | saved_model="", # "path to model to continue training" 47 | Aug="None", # |None|Blur|Crop|Rot|ABINet 48 | workers=4, 49 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla"], 50 | valid_datas=[ 51 | "../dataset/MLT17_IL/test_2017", 52 | "../dataset/MLT19_IL/test_2019" 53 | ], 54 | select_data=[ 55 | "../dataset/MLT17_IL/train_2017", 56 | "../dataset/MLT19_IL/train_2019" 57 | ], 58 | batch_ratio="0.5-0.5", 59 | total_data_usage_ratio="1.0", 60 | NED=True, 61 | batch_size=256, 62 | num_iter=10000, 63 | val_interval=5000, 64 | log_multiple_test=None, 65 | grad_clip=5, 66 | # FT="init", 67 | # self_pre="RotNet", # whether to use `RotNet` or `MoCo` pretrained model. 68 | # semi="None", #|None|PL|MT| 69 | # MT_C=1, 70 | # MT_alpha=0.999, 71 | # model_for_PseudoLabel="", 72 | ) 73 | 74 | -------------------------------------------------------------------------------- /modules/dm_router.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from einops import rearrange 3 | 4 | class SpatialDomainGating(nn.Module): 5 | def __init__(self, d_ffn, seq_len): 6 | super().__init__() 7 | 8 | self.norm = nn.LayerNorm(d_ffn // 2) 9 | self.proj = nn.Linear(seq_len, seq_len) 10 | 11 | def forward(self, x): 12 | u, v = x.chunk(2, dim=-1) 13 | v = self.norm(v) 14 | v = v.permute(0, 2, 1) 15 | v = self.proj(v) 16 | v = v.permute(0, 2, 1) 17 | return u * v 18 | 19 | class ChannelDomainGating(nn.Module): 20 | def __init__(self, d_ffn, seq_len): 21 | super().__init__() 22 | 23 | self.norm = nn.LayerNorm(d_ffn) 24 | self.proj = nn.Linear(seq_len, seq_len) 25 | 26 | def forward(self, x): 27 | # b w (i c) 28 | # x = u 29 | v = self.norm(x) 30 | v = v.permute(0, 2, 1) 31 | v = self.proj(v) 32 | v = v.permute(0, 2, 1) 33 | return x * v 34 | 35 | class DM_Router(nn.Module): 36 | def __init__(self, channel, d_ffn, patch,domain): 37 | super().__init__() 38 | self.patch = patch 39 | self.channel = channel 40 | self.norm = nn.LayerNorm(channel) 41 | self.proj_1 = nn.Linear(channel, d_ffn) 42 | self.activation = nn.GELU() 43 | self.spatial_gating = SpatialDomainGating(d_ffn, patch * domain) 44 | self.channel_gating = ChannelDomainGating(patch, domain * channel) 45 | self.proj_2 = nn.Linear(d_ffn//2, channel) 46 | self.proj_3 = nn.Linear(channel, channel) 47 | # self.route = nn.Linear(self.patch , 1) 48 | # self.channel_route = nn.Linear(self.feature_dim, domain) 49 | 50 | def forward(self, x): 51 | # if self.training and torch.equal(self.m.sample(), torch.zeros(1)): 52 | # return x 53 | # B, H, W, C = x.shape 54 | shorcut = x.clone() 55 | x = self.norm(x) 56 | x = self.proj_1(x) 57 | x = self.activation(x) 58 | x = rearrange(x,'b d p c -> b (d p) c') 59 | x = self.spatial_gating(x) 60 | x = self.proj_2(x) 61 | x = rearrange(x, 'b (d p) c -> b d p c',p=self.patch) 62 | x = x + shorcut 63 | 
x = rearrange(x,'b d p c -> b (d c) p',c=self.channel) 64 | x = self.channel_gating(x) 65 | x = rearrange(x, 'b (d c) p -> b d p c', c=self.channel) 66 | x = self.proj_3(x) 67 | return x + shorcut -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/tensorboard/* 2 | **/saved_models/* 3 | **/data_CVPR2021/* 4 | **/image_release/* 5 | **/image_test/* 6 | **/result/* 7 | *.mdb 8 | *.pth 9 | *.tar 10 | *.txt 11 | *.ipynb 12 | *.zip 13 | *.eps 14 | *.pdf 15 | 16 | ### Linux ### 17 | *~ 18 | 19 | # temporary files which can be created if a process still has a handle open of a deleted file 20 | .fuse_hidden* 21 | 22 | # KDE directory preferences 23 | .directory 24 | 25 | # Linux trash folder which might appear on any partition or disk 26 | .Trash-* 27 | 28 | # .nfs files are created when an open file is removed but is still being accessed 29 | .nfs* 30 | 31 | ### OSX ### 32 | # General 33 | .DS_Store 34 | .AppleDouble 35 | .LSOverride 36 | 37 | # Icon must end with two \r 38 | Icon 39 | 40 | # Thumbnails 41 | ._* 42 | 43 | # Files that might appear in the root of a volume 44 | .DocumentRevisions-V100 45 | .fseventsd 46 | .Spotlight-V100 47 | .TemporaryItems 48 | .Trashes 49 | .VolumeIcon.icns 50 | .com.apple.timemachine.donotpresent 51 | 52 | # Directories potentially created on remote AFP share 53 | .AppleDB 54 | .AppleDesktop 55 | Network Trash Folder 56 | Temporary Items 57 | .apdisk 58 | 59 | ### Python ### 60 | # Byte-compiled / optimized / DLL files 61 | __pycache__/ 62 | *.py[cod] 63 | *$py.class 64 | 65 | # C extensions 66 | *.so 67 | 68 | # Distribution / packaging 69 | .Python 70 | build/ 71 | develop-eggs/ 72 | dist/ 73 | downloads/ 74 | eggs/ 75 | .eggs/ 76 | lib/ 77 | lib64/ 78 | parts/ 79 | sdist/ 80 | var/ 81 | wheels/ 82 | *.egg-info/ 83 | .installed.cfg 84 | *.egg 85 | MANIFEST 86 | 87 | # PyInstaller 88 | # Usually these files are written by a python script from a template 89 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
90 | *.manifest 91 | *.spec 92 | 93 | # Installer logs 94 | pip-log.txt 95 | pip-delete-this-directory.txt 96 | 97 | # Unit test / coverage reports 98 | htmlcov/ 99 | .tox/ 100 | .coverage 101 | .coverage.* 102 | .cache 103 | nosetests.xml 104 | coverage.xml 105 | *.cover 106 | .hypothesis/ 107 | .pytest_cache/ 108 | 109 | # Translations 110 | *.mo 111 | *.pot 112 | 113 | # Django stuff: 114 | *.log 115 | local_settings.py 116 | db.sqlite3 117 | 118 | # Flask stuff: 119 | instance/ 120 | .webassets-cache 121 | 122 | # Scrapy stuff: 123 | .scrapy 124 | 125 | # Sphinx documentation 126 | docs/_build/ 127 | 128 | # PyBuilder 129 | target/ 130 | 131 | # Jupyter Notebook 132 | .ipynb_checkpoints 133 | 134 | # IPython 135 | profile_default/ 136 | ipython_config.py 137 | 138 | # pyenv 139 | .python-version 140 | 141 | # celery beat schedule file 142 | celerybeat-schedule 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | ### Python Patch ### 172 | .venv/ 173 | 174 | ### Python.VirtualEnv Stack ### 175 | # Virtualenv 176 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 177 | [Bb]in 178 | [Ii]nclude 179 | [Ll]ib 180 | [Ll]ib64 181 | [Ll]ocal 182 | [Ss]cripts 183 | pyvenv.cfg 184 | pip-selfcheck.json 185 | 186 | ### Windows ### 187 | # Windows thumbnail cache files 188 | Thumbs.db 189 | ehthumbs.db 190 | ehthumbs_vista.db 191 | 192 | # Dump file 193 | *.stackdump 194 | 195 | # Folder config file 196 | [Dd]esktop.ini 197 | 198 | # Recycle Bin used on file shares 199 | $RECYCLE.BIN/ 200 | 201 | # Windows Installer files 202 | *.cab 203 | *.msi 204 | *.msix 205 | *.msm 206 | *.msp 207 | 208 | # Windows shortcuts 209 | *.lnk 210 | -------------------------------------------------------------------------------- /il_modules/joint.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.utils.data 4 | from tqdm import tqdm 5 | 6 | from il_modules.base import BaseLearner 7 | from tools.utils import Averager, adjust_learning_rate 8 | 9 | class JointLearner(BaseLearner): 10 | 11 | def incremental_train(self,taski, character, train_loader, valid_loader,AlignCollate_valid,valid_datas): 12 | 13 | # pre task classes for know classes 14 | # self._known_classes = self._total_classes 15 | self.character = character 16 | self.converter = self.build_converter() 17 | valid_loader = valid_loader.create_list_dataset(valid_datas=valid_datas) 18 | 19 | if taski > 0: 20 | self.change_model() 21 | else: 22 | self.criterion = self.build_criterion() 23 | self.build_model() 24 | 25 | # filter that only require gradient descent 26 | filtered_parameters = self.count_param() 27 | 28 | # setup optimizer 29 | self.build_optimizer(filtered_parameters) 30 | 31 | # print opt config 32 | # self.print_config(self.opt) 33 | 34 | """ start training """ 35 | start_iter = 0 36 | if self.opt.saved_model != "": 37 | try: 38 | start_iter = int(self.saved_model.split("_")[-1].split(".")[0]) 39 | print(f"continue to train, start_iter: {start_iter}") 40 | except: 41 | pass 42 | 43 | return self._init_train(start_iter,taski, train_loader, valid_loader,AlignCollate_valid,valid_datas) 44 | 45 
| def _init_train(self,start_iter,taski, train_loader, valid_loader,AlignCollate_valid,valid_datas): 46 | # loss averager 47 | train_loss_avg = Averager() 48 | best_scores = [] 49 | ned_scores = [] 50 | 51 | 52 | start_time = time.time() 53 | best_score = -1 54 | 55 | # training loop 56 | for iteration in tqdm( 57 | range(start_iter + 1, self.opt.num_iter + 1), 58 | total=self.opt.num_iter, 59 | position=0, 60 | leave=True, 61 | ): 62 | image_tensors, labels = train_loader.get_batch() 63 | 64 | image = image_tensors.to(self.device) 65 | labels_index, labels_length = self.converter.encode( 66 | labels, batch_max_length=self.opt.batch_max_length 67 | ) 68 | batch_size = image.size(0) 69 | 70 | # default recognition loss part 71 | if "CTC" in self.opt.Prediction: 72 | preds = self.model(image)["predict"] 73 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 74 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 75 | loss = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length) 76 | else: 77 | preds = self.model(image, labels_index[:, :-1])["predict"] # align with Attention.forward 78 | target = labels_index[:, 1:] # without [SOS] Symbol 79 | loss = self.criterion( 80 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 81 | ) 82 | 83 | self.model.zero_grad() 84 | loss.backward() 85 | torch.nn.utils.clip_grad_norm_( 86 | self.model.parameters(), self.opt.grad_clip 87 | ) # gradient clipping with 5 (Default) 88 | self.optimizer.step() 89 | train_loss_avg.add(loss) 90 | 91 | if "super" in self.opt.schedule: 92 | self.scheduler.step() 93 | else: 94 | adjust_learning_rate(self.optimizer, iteration, self.opt) 95 | 96 | # validation part. 97 | # To see training progress, we also conduct validation when 'iteration == 1' 98 | if iteration % self.opt.val_interval == 0 or iteration == 1: 99 | # for validation log 100 | self.val(valid_loader, self.opt, best_score, start_time, iteration, 101 | train_loss_avg, taski) 102 | if iteration != 1: 103 | best_scores,ned_scores = self.test(AlignCollate_valid,valid_datas,best_scores,ned_scores,taski) 104 | self.model.train() 105 | train_loss_avg.reset() 106 | return best_scores,ned_scores -------------------------------------------------------------------------------- /il_modules/wa.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from il_modules.base import BaseLearner 6 | from tools.utils import Averager, adjust_learning_rate 7 | 8 | 9 | EPSILON = 1e-8 10 | 11 | 12 | init_epoch=200 13 | init_lr=0.1 14 | init_milestones=[60,120,170] 15 | init_lr_decay=0.1 16 | init_weight_decay=0.0005 17 | 18 | 19 | epochs = 170 20 | lrate = 0.1 21 | milestones = [60, 100,140] 22 | lrate_decay = 0.1 23 | batch_size = 128 24 | weight_decay=2e-4 25 | num_workers=8 26 | T=2 27 | 28 | 29 | class WA(BaseLearner): 30 | def __init__(self, opt): 31 | super().__init__(opt) 32 | self.taski = 0 33 | 34 | def after_task(self): 35 | if self.taski >0: 36 | self.model.module.weight_align(self._total_classes-self._known_classes) 37 | self.model = self.model.module 38 | self._old_network = self.model.copy().freeze() 39 | self._known_classes = self._total_classes 40 | 41 | def _update_representation(self,start_iter, taski, train_loader, valid_loader): 42 | self.taski = taski 43 | # loss averager 44 | train_loss_avg = Averager() 45 | # semi_loss_avg = Averager() 46 | 47 | start_time = time.time() 48 | best_score = -1 49 | 50 | # training loop 51 | 
for iteration in tqdm( 52 | range(start_iter + 1, self.opt.num_iter + 1), 53 | total=self.opt.num_iter, 54 | position=0, 55 | leave=True, 56 | ): 57 | image_tensors, labels = train_loader.get_batch() 58 | 59 | image = image_tensors.to(self.device) 60 | labels_index, labels_length = self.converter.encode( 61 | labels, batch_max_length=self.opt.batch_max_length 62 | ) 63 | batch_size = image.size(0) 64 | 65 | # default recognition loss part 66 | if "CTC" in self.opt.Prediction: 67 | start_index = 0 68 | preds = self.model(image)["predict"] 69 | old_preds = self._old_network(image)["predict"] 70 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 71 | # B,T,C(max) -> T, B, C 72 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 73 | loss_clf = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length) 74 | else: 75 | start_index = 1 76 | preds = self.model(image, labels_index[:, :-1],True)["predict"] # align with Attention.forward 77 | old_preds = self._old_network(image, labels_index[:, :-1],True)["predict"] 78 | target = labels_index[:, 1:] # without [SOS] Symbol 79 | loss_clf = self.criterion( 80 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 81 | ) 82 | 83 | loss_kd = _KD_loss( 84 | preds.view(-1, preds.shape[-1])[:, start_index: self._known_classes], 85 | old_preds.view(-1, old_preds.shape[-1])[:, start_index: self._known_classes], 86 | T, 87 | ) 88 | 89 | loss=loss_clf+ 2*loss_kd 90 | 91 | self.model.zero_grad() 92 | loss.backward() 93 | torch.nn.utils.clip_grad_norm_( 94 | self.model.parameters(), self.opt.grad_clip 95 | ) # gradient clipping with 5 (Default) 96 | self.optimizer.step() 97 | train_loss_avg.add(loss) 98 | 99 | if "super" in self.opt.schedule: 100 | self.scheduler.step() 101 | else: 102 | adjust_learning_rate(self.optimizer, iteration, self.opt) 103 | 104 | # validation part. 
105 | # To see training progress, we also conduct validation when 'iteration == 1' 106 | if iteration % self.opt.val_interval == 0 or iteration == 1: 107 | # for validation log 108 | self.val(valid_loader, self.opt, best_score, start_time, iteration, 109 | train_loss_avg, taski) 110 | train_loss_avg.reset() 111 | self.model.module.weight_align(self._total_classes - self._known_classes) 112 | 113 | def _KD_loss(pred, soft, T): 114 | pred = torch.log_softmax(pred / T, dim=1) 115 | soft = torch.softmax(soft / T, dim=1) 116 | return -1 * torch.mul(soft, pred).sum() / pred.shape[0] 117 | -------------------------------------------------------------------------------- /il_modules/lwf.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from il_modules.base import BaseLearner 6 | from tools.utils import Averager, adjust_learning_rate 7 | 8 | init_epoch = 200 9 | init_lr = 0.1 10 | init_milestones = [60, 120, 160] 11 | init_lr_decay = 0.1 12 | init_weight_decay = 0.0005 13 | 14 | 15 | epochs = 250 16 | lrate = 0.1 17 | milestones = [60, 120, 180, 220] 18 | lrate_decay = 0.1 19 | batch_size = 128 20 | weight_decay = 2e-4 21 | num_workers = 8 22 | T = 2 23 | lamda = 3 24 | 25 | 26 | class LwF(BaseLearner): 27 | def __init__(self, opt): 28 | super().__init__(opt) 29 | 30 | def after_task(self): 31 | self.model = self.model.module 32 | self._old_network = self.model.copy().freeze() 33 | self._known_classes = self._total_classes 34 | 35 | def _update_representation(self,start_iter, taski, train_loader, valid_loader): 36 | # loss averager 37 | train_loader.get_dataset(taski, memory=self.opt.memory) 38 | train_loss_avg = Averager() 39 | # semi_loss_avg = Averager() 40 | 41 | start_time = time.time() 42 | best_score = -1 43 | 44 | # training loop 45 | for iteration in tqdm( 46 | range(start_iter + 1, self.opt.num_iter + 1), 47 | total=self.opt.num_iter, 48 | position=0, 49 | leave=True, 50 | ): 51 | image_tensors, labels = train_loader.get_batch() 52 | 53 | image = image_tensors.to(self.device) 54 | labels_index, labels_length = self.converter.encode( 55 | labels, batch_max_length=self.opt.batch_max_length 56 | ) 57 | batch_size = image.size(0) 58 | 59 | # default recognition loss part 60 | if "CTC" in self.opt.Prediction: 61 | start_index = 0 62 | preds = self.model(image)["predict"] 63 | old_preds = self._old_network(image)["predict"] 64 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 65 | # B,T,C(max) -> T, B, C 66 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 67 | loss_clf = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length) 68 | else: 69 | start_index = 1 70 | preds = self.model(image, labels_index[:, :-1],True)["predict"] # align with Attention.forward 71 | old_preds = self._old_network(image, labels_index[:, :-1],True)["predict"] 72 | target = labels_index[:, 1:] # without [SOS] Symbol 73 | loss_clf = self.criterion( 74 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 75 | ) 76 | 77 | # fake_targets = self._total_classes - self._known_classes 78 | # loss_clf = F.cross_entropy( 79 | # preds_log_softmax[:, self._known_classes:], fake_targets 80 | # ) 81 | loss_kd = _KD_loss( 82 | preds.view(-1,preds.shape[-1])[:, start_index: self._known_classes], 83 | old_preds.view(-1,old_preds.shape[-1])[:, start_index: self._known_classes], 84 | T, 85 | ) 86 | 87 | loss = lamda * loss_kd + loss_clf 88 | 89 | self.model.zero_grad() 90 | loss.backward() 
91 | torch.nn.utils.clip_grad_norm_( 92 | self.model.parameters(), self.opt.grad_clip 93 | ) # gradient clipping with 5 (Default) 94 | self.optimizer.step() 95 | train_loss_avg.add(loss) 96 | 97 | if "super" in self.opt.schedule: 98 | self.scheduler.step() 99 | else: 100 | adjust_learning_rate(self.optimizer, iteration, self.opt) 101 | 102 | # validation part. 103 | # To see training progress, we also conduct validation when 'iteration == 1' 104 | if iteration % self.opt.val_interval == 0 or iteration == 1: 105 | # for validation log 106 | self.val(valid_loader, self.opt, best_score, start_time, iteration, 107 | train_loss_avg, taski) 108 | train_loss_avg.reset() 109 | 110 | 111 | def _KD_loss(pred, soft, T): 112 | pred = torch.log_softmax(pred / T, dim=1) 113 | soft = torch.softmax(soft / T, dim=1) 114 | return -1 * torch.mul(soft, pred).sum() / pred.shape[0] 115 | -------------------------------------------------------------------------------- /modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class Attention(nn.Module): 9 | def __init__(self, input_size, hidden_size, num_class, fc, num_char_embeddings=256): 10 | super(Attention, self).__init__() 11 | self.attention_cell = AttentionCell( 12 | input_size, hidden_size, num_char_embeddings 13 | ) 14 | self.hidden_size = hidden_size 15 | self.num_class = num_class 16 | self.generator = fc 17 | self.num_char_embeddings = num_char_embeddings 18 | # self.generator = nn.Linear(hidden_size, num_class) 19 | self.char_embeddings = nn.Embedding(num_class, num_char_embeddings) 20 | # self.char_embeddings = nn.Linear(1, num_char_embeddings) 21 | 22 | def _char_to_onehot(self, input_char, onehot_dim=38): 23 | input_char = input_char.unsqueeze(1) 24 | batch_size = input_char.size(0) 25 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 26 | one_hot = one_hot.scatter_(1, input_char, 1) 27 | return one_hot 28 | 29 | def minmax(self,a): 30 | min_a = torch.min(a) 31 | max_a = torch.max(a) 32 | n2 = (a - min_a) / (max_a - min_a) 33 | return n2 34 | 35 | def cut_unknown(self,index): 36 | return torch.where(index >= self.num_class, 0, index) 37 | 38 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 39 | """ 40 | input: 41 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 42 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS]. 43 | output: probability distribution at each step [batch_size x num_steps x num_class] 44 | """ 45 | batch_size = batch_H.size(0) 46 | num_steps = batch_max_length + 1 # +1 for [EOS] at end of sentence. 
47 | 48 | output_hiddens = ( 49 | torch.FloatTensor(batch_size, num_steps, self.hidden_size) 50 | .fill_(0) 51 | .to(device) 52 | ) 53 | hidden = ( 54 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 55 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 56 | ) 57 | # text = self.minmax(text) 58 | if is_train: 59 | for i in range(num_steps): 60 | # char_embeddings = self._char_to_onehot(text[:, i], onehot_dim=self.num_class) 61 | char_embeddings = self.char_embeddings(self.cut_unknown(text[:, i])) 62 | # char_embeddings = self.char_embeddings(text[:, i].unsqueeze(-1).float()) 63 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1}) 64 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 65 | output_hiddens[:, i, :] = hidden[ 66 | 0 67 | ] # LSTM hidden index (0: hidden, 1: Cell) 68 | probs = self.generator(output_hiddens) 69 | 70 | else: 71 | targets = text[0].expand(batch_size) # should be fill with [SOS] token 72 | probs = ( 73 | torch.FloatTensor(batch_size, num_steps, self.num_class) 74 | .fill_(0) 75 | .to(device) 76 | ) 77 | 78 | for i in range(num_steps): 79 | # char_embeddings = self._char_to_onehot(targets, onehot_dim=self.num_class) 80 | char_embeddings = self.char_embeddings(self.cut_unknown(targets)) 81 | # char_embeddings = self.char_embeddings(targets.unsqueeze(-1).float()) 82 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 83 | probs_step = self.generator(hidden[0]) 84 | probs[:, i, :] = probs_step 85 | _, next_input = probs_step.max(1) 86 | targets = next_input 87 | 88 | return probs # batch_size x num_steps x num_class 89 | 90 | 91 | class AttentionCell(nn.Module): 92 | def __init__(self, input_size, hidden_size, num_embeddings): 93 | super(AttentionCell, self).__init__() 94 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 95 | self.h2h = nn.Linear( 96 | hidden_size, hidden_size 97 | ) # either i2i or h2h should have bias 98 | self.score = nn.Linear(hidden_size, 1, bias=False) 99 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 100 | self.hidden_size = hidden_size 101 | 102 | def forward(self, prev_hidden, batch_H, char_embeddings): 103 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 104 | batch_H_proj = self.i2h(batch_H) 105 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 106 | e = self.score( 107 | torch.tanh(batch_H_proj + prev_hidden_proj) 108 | ) # batch_size x num_encoder_step * 1 109 | 110 | alpha = F.softmax(e, dim=1) 111 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze( 112 | 1 113 | ) # batch_size x num_channel 114 | concat_context = torch.cat( 115 | [context, char_embeddings], 1 116 | ) # batch_size x (num_channel + num_embedding) 117 | cur_hidden = self.rnn(concat_context, prev_hidden) 118 | return cur_hidden, alpha 119 | -------------------------------------------------------------------------------- /il_modules/ewc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from il_modules.base import BaseLearner 6 | from tools.utils import Averager, adjust_learning_rate 7 | 8 | init_epoch=200 9 | init_lr=0.1 10 | init_milestones=[60,120,170] 11 | init_lr_decay=0.1 12 | init_weight_decay=0.0005 13 | 14 | epochs = 180 15 | lrate = 0.1 16 | milestones = [70, 120,150] 17 | lrate_decay = 0.1 18 | batch_size = 128 19 | weight_decay=2e-4 20 | 
num_workers=4 21 | T=2 22 | lamda=1000 23 | fishermax=0.0001 24 | alpha = 0.5 25 | num_iter = 5000 26 | 27 | class EWC(BaseLearner): 28 | def __init__(self, opt): 29 | super().__init__(opt) 30 | self.fisher = None 31 | 32 | def after_task(self): 33 | self.model = self.model.module 34 | # self._old_network = self.model.copy().freeze() 35 | self._known_classes = self._total_classes 36 | 37 | def _train(self, start_iter, taski, train_loader, valid_loader): 38 | if taski == 0: 39 | self._init_train(start_iter,taski, train_loader, valid_loader) 40 | else: 41 | if self.opt.memory != None: 42 | self.build_rehearsal_memory(train_loader, taski) 43 | else: 44 | train_loader.get_dataset(taski, memory=self.opt.memory) 45 | self._update_representation(start_iter,taski, train_loader, valid_loader) 46 | # self._update_representation(start_iter,taski, train_loader, valid_loader) 47 | if self.fisher is None: 48 | self.fisher=self.getFisherDiagonal(train_loader) 49 | else: 50 | # alpha=self._known_classes/self._total_classes 51 | new_finsher=self.getFisherDiagonal(train_loader) 52 | f_list = list(self.fisher.values()) 53 | for i,(n,p) in enumerate(new_finsher.items()): 54 | new_finsher[n][:len(f_list[i])] = alpha * f_list[i] + (1 - alpha) * new_finsher[n][:len(f_list[i])] 55 | # new_finsher[n][:len(self.fisher[n])]=alpha*self.fisher[n]+(1-alpha)*new_finsher[n][:len(self.fisher[n])] 56 | self.fisher=new_finsher 57 | self.mean={n: p.clone().detach() for n, p in self.model.named_parameters() if p.requires_grad} 58 | 59 | def _update_representation(self,start_iter, taski, train_loader, valid_loader): 60 | # loss averager 61 | train_loss_avg = Averager() 62 | 63 | start_time = time.time() 64 | best_score = -1 65 | 66 | # training loop 67 | for iteration in tqdm( 68 | range(start_iter + 1, self.opt.num_iter + 1), 69 | total=self.opt.num_iter, 70 | position=0, 71 | leave=True, 72 | ): 73 | image_tensors, labels = train_loader.get_batch() 74 | 75 | image = image_tensors.to(self.device) 76 | labels_index, labels_length = self.converter.encode( 77 | labels, batch_max_length=self.opt.batch_max_length 78 | ) 79 | batch_size = image.size(0) 80 | 81 | # default recognition loss part 82 | if "CTC" in self.opt.Prediction: 83 | preds = self.model(image)["predict"] 84 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 85 | # B,T,C(max) -> T, B, C 86 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 87 | loss_clf = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length) 88 | else: 89 | preds = self.model(image, labels_index[:, :-1])["predict"] # align with Attention.forward 90 | target = labels_index[:, 1:] # without [SOS] Symbol 91 | loss_clf = self.criterion( 92 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 93 | ) 94 | 95 | loss_ewc = self.compute_ewc() 96 | loss = loss_clf + lamda * loss_ewc 97 | 98 | self.model.zero_grad() 99 | loss.backward() 100 | torch.nn.utils.clip_grad_norm_( 101 | self.model.parameters(), self.opt.grad_clip 102 | ) # gradient clipping with 5 (Default) 103 | self.optimizer.step() 104 | train_loss_avg.add(loss) 105 | 106 | if "super" in self.opt.schedule: 107 | self.scheduler.step() 108 | else: 109 | adjust_learning_rate(self.optimizer, iteration, self.opt) 110 | 111 | # validation part. 
112 | # To see training progress, we also conduct validation when 'iteration == 1' 113 | if iteration % self.opt.val_interval == 0 or iteration == 1: 114 | # for validation log 115 | self.val(valid_loader, self.opt, best_score, start_time, iteration, 116 | train_loss_avg, taski) 117 | train_loss_avg.reset() 118 | 119 | 120 | def compute_ewc(self): 121 | loss = 0 122 | # if len(self._multiple_gpus) > 1: 123 | for n, p in self.model.module.named_parameters(): 124 | if n in self.fisher.keys(): 125 | loss += torch.sum((self.fisher[n]) * (p[:len(self.mean[n])] - self.mean[n]).pow(2)) / 2 126 | return loss 127 | 128 | def getFisherDiagonal(self,train_loader): 129 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.named_parameters() 130 | if p.requires_grad} 131 | self.model.train() 132 | # optimizer = optim.SGD(self.model.parameters(),lr=lrate) 133 | for iteration in tqdm( 134 | range( 1, num_iter + 1), 135 | total= num_iter, 136 | position=0, 137 | leave=True, 138 | ): 139 | image_tensors, labels = train_loader.get_batch() 140 | image = image_tensors.to(self.device) 141 | labels_index, labels_length = self.converter.encode( 142 | labels, batch_max_length=self.opt.batch_max_length 143 | ) 144 | batch_size = image.size(0) 145 | 146 | # default recognition loss part 147 | if "CTC" in self.opt.Prediction: 148 | preds = self.model(image)["predict"] 149 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 150 | # B,T,C(max) -> T, B, C 151 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 152 | loss = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length) 153 | else: 154 | preds = self.model(image, labels_index[:, :-1])["predict"] # align with Attention.forward 155 | target = labels_index[:, 1:] # without [SOS] Symbol 156 | loss= self.criterion( 157 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 158 | ) 159 | self.optimizer.zero_grad() 160 | loss.backward() 161 | for n, p in self.model.named_parameters(): 162 | if p.grad is not None: 163 | fisher[n] += p.grad.pow(2).clone() 164 | for n,p in fisher.items(): 165 | fisher[n]=p/num_iter 166 | fisher[n]=torch.min(fisher[n],torch.tensor(fishermax)) 167 | return fisher 168 | -------------------------------------------------------------------------------- /modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | class TPS_SpatialTransformerNetwork(nn.Module): 10 | """Rectification Network of RARE, namely TPS based STN""" 11 | 12 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 13 | """Based on RARE TPS 14 | input: 15 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 16 | I_size : (height, width) of the input image I 17 | I_r_size : (height, width) of the rectified image I_r 18 | I_channel_num : the number of channels of the input image I 19 | output: 20 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 21 | """ 22 | super(TPS_SpatialTransformerNetwork, self).__init__() 23 | self.F = F 24 | self.I_size = I_size 25 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 26 | self.I_channel_num = I_channel_num 27 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 28 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 29 | 30 | def 
forward(self, batch_I): 31 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 32 | # batch_size x n (= I_r_width x I_r_height) x 2 33 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) 34 | build_P_prime_reshape = build_P_prime.reshape( 35 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2] 36 | ) 37 | 38 | if torch.__version__ > "1.2.0": 39 | batch_I_r = F.grid_sample( 40 | batch_I, 41 | build_P_prime_reshape, 42 | padding_mode="border", 43 | align_corners=True, 44 | ) 45 | else: 46 | batch_I_r = F.grid_sample( 47 | batch_I, build_P_prime_reshape, padding_mode="border" 48 | ) 49 | 50 | return batch_I_r 51 | 52 | 53 | class LocalizationNetwork(nn.Module): 54 | """Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)""" 55 | 56 | def __init__(self, F, I_channel_num): 57 | super(LocalizationNetwork, self).__init__() 58 | self.F = F 59 | self.I_channel_num = I_channel_num 60 | self.conv = nn.Sequential( 61 | nn.Conv2d( 62 | in_channels=self.I_channel_num, 63 | out_channels=64, 64 | kernel_size=3, 65 | stride=1, 66 | padding=1, 67 | bias=False, 68 | ), 69 | nn.BatchNorm2d(64), 70 | nn.ReLU(True), 71 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 72 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), 73 | nn.BatchNorm2d(128), 74 | nn.ReLU(True), 75 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 76 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), 77 | nn.BatchNorm2d(256), 78 | nn.ReLU(True), 79 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 80 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), 81 | nn.BatchNorm2d(512), 82 | nn.ReLU(True), 83 | nn.AdaptiveAvgPool2d(1), # batch_size x 512 84 | ) 85 | 86 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 87 | self.localization_fc2 = nn.Linear(256, self.F * 2) 88 | 89 | # Init fc2 in LocalizationNetwork 90 | self.localization_fc2.weight.data.fill_(0) 91 | """ see RARE paper Fig. 
6 (a) """ 92 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 93 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 94 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 95 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 96 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 97 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 98 | self.localization_fc2.bias.data = ( 99 | torch.from_numpy(initial_bias).float().view(-1) 100 | ) 101 | 102 | def forward(self, batch_I): 103 | """ 104 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 105 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 106 | """ 107 | batch_size = batch_I.size(0) 108 | features = self.conv(batch_I).view(batch_size, -1) 109 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view( 110 | batch_size, self.F, 2 111 | ) 112 | return batch_C_prime 113 | 114 | 115 | class GridGenerator(nn.Module): 116 | """Grid Generator of RARE, which produces P_prime by multipling T with P""" 117 | 118 | def __init__(self, F, I_r_size): 119 | """Generate P_hat and inv_delta_C for later""" 120 | super(GridGenerator, self).__init__() 121 | self.eps = 1e-6 122 | self.I_r_height, self.I_r_width = I_r_size 123 | self.F = F 124 | self.C = self._build_C(self.F) # F x 2 125 | self.P = self._build_P(self.I_r_width, self.I_r_height) 126 | 127 | num_gpu = torch.cuda.device_count() 128 | if num_gpu > 1: 129 | # for multi-gpu, you may need register buffer 130 | self.register_buffer( 131 | "inv_delta_C", 132 | torch.tensor(self._build_inv_delta_C(self.F, self.C)).float(), 133 | ) # F+3 x F+3 134 | self.register_buffer( 135 | "P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() 136 | ) # n x F+3 137 | else: 138 | # for fine-tuning with different image width, you may use below instead of self.register_buffer 139 | self.inv_delta_C = ( 140 | torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().to(device) 141 | ) # F+3 x F+3 142 | self.P_hat = ( 143 | torch.tensor(self._build_P_hat(self.F, self.C, self.P)) 144 | .float() 145 | .to(device) 146 | ) # n x F+3 147 | 148 | def _build_C(self, F): 149 | """Return coordinates of fiducial points in I_r; C""" 150 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 151 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 152 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 153 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 154 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 155 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 156 | return C # F x 2 157 | 158 | def _build_inv_delta_C(self, F, C): 159 | """Return inv_delta_C which is needed to calculate T""" 160 | hat_C = np.zeros((F, F), dtype=float) # F x F 161 | for i in range(0, F): 162 | for j in range(i, F): 163 | r = np.linalg.norm(C[i] - C[j]) 164 | hat_C[i, j] = r 165 | hat_C[j, i] = r 166 | np.fill_diagonal(hat_C, 1) 167 | hat_C = (hat_C ** 2) * np.log(hat_C) 168 | # print(C.shape, hat_C.shape) 169 | delta_C = np.concatenate( # F+3 x F+3 170 | [ 171 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 172 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 173 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1), # 1 x F+3 174 | ], 175 | axis=0, 176 | ) 177 | inv_delta_C = np.linalg.inv(delta_C) 178 | return inv_delta_C # F+3 x F+3 179 | 180 | def _build_P(self, I_r_width, 
I_r_height): 181 | I_r_grid_x = ( 182 | np.arange(-I_r_width, I_r_width, 2) + 1.0 183 | ) / I_r_width # self.I_r_width 184 | I_r_grid_y = ( 185 | np.arange(-I_r_height, I_r_height, 2) + 1.0 186 | ) / I_r_height # self.I_r_height 187 | P = np.stack( # self.I_r_width x self.I_r_height x 2 188 | np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2 189 | ) 190 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 191 | 192 | def _build_P_hat(self, F, C, P): 193 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 194 | P_tile = np.tile( 195 | np.expand_dims(P, axis=1), (1, F, 1) 196 | ) # n x 2 -> n x 1 x 2 -> n x F x 2 197 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 198 | P_diff = P_tile - C_tile # n x F x 2 199 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 200 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 201 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 202 | return P_hat # n x F+3 203 | 204 | def build_P_prime(self, batch_C_prime): 205 | """Generate Grid from batch_C_prime [batch_size x F x 2]""" 206 | batch_size = batch_C_prime.size(0) 207 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 208 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 209 | batch_C_prime_with_zeros = torch.cat( 210 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1 211 | ) # batch_size x F+3 x 2 212 | batch_T = torch.bmm( 213 | batch_inv_delta_C, batch_C_prime_with_zeros 214 | ) # batch_size x F+3 x 2 215 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 216 | return batch_P_prime # batch_size x n x 2 217 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # MRN: Multiplexed Routing Network
for Incremental Multilingual Text Recognition 4 | 5 | 6 | ![ICCV 2023](https://img.shields.io/badge/ICCV-2023-ff7c00) 7 | [![ArXiv preprint](http://img.shields.io/badge/ArXiv-2305-b31b1b)](https://arxiv.org/abs/2305.14758) 8 | [![Blog](http://img.shields.io/badge/Blog-Link-6790ac)](https://zhuanlan.zhihu.com/p/643948935) 9 | ![LICENSE](https://img.shields.io/badge/license-Apache--2.0-green?style=flat-square) 10 | 11 | [Method](#methods) | [IMLTR Dataset](#imltr-dataset) | [Getting Started](#getting-started) | [Citation](#citation) 12 | 13 |
14 | 15 | This repository provides the code for the paper: 16 | 17 | **MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition** 18 | (Accepted by ICCV 2023) 19 | 20 | This project is a toolkit for the novel scenario of Incremental Multilingual Text Recognition (IMLTR). It supports many incremental learning methods, proposes a method better suited to IMLTR, the Multiplexed Routing Network (MRN), and provides the corresponding dataset. The project offers an efficient framework for developing new methods and analyzing existing ones under the IMLTR task, and we hope it will advance the IMLTR community. 21 | 22 |
23 | 24 | image 25 | 26 | 27 |
28 | 29 | --- 30 | ## Methods 31 | ### Incremental Learning Methods 32 | * [x] Base: Baseline method that simply updates parameters on new tasks. 33 | * [x] Joint: Bound method in which data for all tasks is trained at once, serving as an upper bound for the other methods
(Joint_mix mixes the data of all tasks within each batch; Joint_loader keeps a consistent proportion of data from each task in each batch) 34 | * [x] [EWC](https://arxiv.org/abs/1612.00796) `[PNAS2017]`: Overcoming catastrophic forgetting in neural networks 35 | * [x] [LwF](https://arxiv.org/abs/1606.09282) `[ECCV2016]`: Learning without Forgetting 36 | * [x] [WA](https://arxiv.org/abs/1911.07053) `[CVPR2020]`: Maintaining Discrimination and Fairness in Class Incremental Learning 37 | * [x] [DER](https://arxiv.org/abs/2103.16788) `[CVPR2021]`: DER: Dynamically Expandable Representation for Class Incremental Learning 38 | * [x] [MRN](https://arxiv.org/abs/2305.14758) `[ICCV2023]`: MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition 39 | 40 | You can modify the config file `config/crnn_mrn.py` to switch between incremental learning methods or change their settings. 41 | ``` 42 | common=dict( 43 | il="mrn", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 44 | memory="random", # None | random 45 | memory_num=2000, 46 | start_task = 0 # checkpoint start 47 | ) 48 | ``` 49 | 50 | ### Text Recognition Methods 51 | * [x] [CRNN](https://ieeexplore.ieee.org/abstract/document/7801919) `[TPAMI2017]`: An End-to-End Trainable Neural Network for Image-Based Sequence Recognition and Its Application to Scene Text Recognition 52 | * [x] [TRBA](https://arxiv.org/abs/1904.01906) `[ICCV2019]`: What Is Wrong With Scene Text Recognition Model Comparisons? Dataset and Model Analysis 53 | * [x] [SVTR](https://arxiv.org/abs/2205.00159) `[IJCAI2022]`: SVTR: Scene Text Recognition with a Single Visual Model 54 | 55 | You can modify the config file `config/crnn_mrn.py` to switch between text recognition modules or change their settings. 56 | ``` 57 | """ Model Architecture """ 58 | common=dict( 59 | batch_max_length = 25, 60 | imgH = 32, 61 | imgW = 256, 62 | ) 63 | model=dict( 64 | model_name="TRBA", 65 | Transformation = "TPS", #None TPS 66 | FeatureExtraction = "ResNet", #VGG ResNet SVTR 67 | SequenceModeling = "BiLSTM", #None BiLSTM 68 | Prediction = "Attn", #CTC Attn 69 | num_fiducial=20, 70 | input_channel=4, 71 | output_channel=512, 72 | hidden_size=256, 73 | ) 74 | ``` 75 | 76 | 77 | ## IMLTR Dataset 78 | The dataset can be downloaded from [BaiduNetdisk](https://pan.baidu.com/s/1Qv4utVzWlLu8UPcBpItHbQ) (passwd: c07h). 79 | 80 | ``` 81 | dataset 82 | ├── MLT17_IL 83 | │   ├── test_2017 84 | │   ├── train_2017 85 | ├── MLT19_IL 86 | │   ├── test_2019 87 | │   ├── train_2019 88 | ``` 89 | 90 | Incremental MLT17: MLT17 has 68,613 training instances and 16,255 validation instances, drawn from 6 scripts and 9 languages: Chinese, Japanese, Korean, Bangla, Arabic, Italian, English, French, and German. The last four use Latin script. Incremental MLT17 uses the validation set for testing because the test data is unavailable. Tasks are split by script and modeled sequentially. Special symbols are discarded during preprocessing since they carry no linguistic meaning. 91 | 92 | Incremental MLT19: MLT19 has 89,177 text instances coming from 7 scripts. Since the test set is inaccessible, Incremental MLT19 randomly splits the training instances 9:1, script by script, for model training and testing. To be consistent with the Incremental MLT17 dataset, we discard the Hindi script as well as special symbols. Statistics of the two datasets are shown in the following table.
93 | 94 | | Dataset | Categories | | | | | | | 95 | |---------|----------------|-------|-------|----------|--------|--------|--------| 96 | | | | Task1 | Task2 | Task3 | Task4 | Task5 | Task6 | 97 | | | | Chinese | Latin | Japanese | Korean | Arabic | Bangla | 98 | | MLT17[^1] | Train Instance | 2687 | 47411 | 4609 | 5631 | 3711 | 3237 | 99 | | | Test Instance | 529 | 11073 | 1350 | 1230 | 983 | 713 | 100 | | | Train Class | 1895 | 325 | 1620 | 1124 | 73 | 112 | 101 | | MLT19[^2] | Train Instance | 2897 | 52921 | 5324 | 6107 | 4230 | 3542 | 102 | | | Test Instance | 322 | 5882 | 590 | 679 | 470 | 393 | 103 | | | Train Class | 2086 | 220 | 1728 | 1160 | 73 | 102 | 104 | 105 | [^1]: Nayef, N., et al. (2017). MLT 2017. 106 | [^2]: Nayef, N., et al. (2019). MLT 2019. 107 | 108 | 109 | ## Getting Started 110 | ### Dependency 111 | - This work was originally tested with PyTorch 1.6.0, CUDA 10.1 and Python 3.6; the commands below set up a Python 3.7 / PyTorch 1.9.1 (CUDA 11.x) environment. 112 | ``` 113 | conda create -n mrn python=3.7 -y 114 | conda activate mrn 115 | conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=11.3 -c pytorch -c conda-forge 116 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 117 | ``` 118 | - Requirements: 119 | ``` 120 | pip3 install lmdb pillow torchvision nltk natsort fire tensorboard tqdm opencv-python einops timm mmcv shapely scipy 121 | pip3 install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.1/index.html 122 | ``` 123 | 124 | ## Training 125 | ``` 126 | python3 tiny_train.py --config=config/crnn_mrn.py --exp_name CRNN_real 127 | ``` 128 | ### Arguments 129 | tiny_train.py (by default, evaluates the trained model on the IMLTR datasets at the end of training). 130 | * `--select_data`: folder path to training lmdb datasets.
`[" ../dataset/MLT17_IL/train_2017", "../dataset/MLT19_IL/train_2019"] ` 131 | * `--valid_datas`: folder path to testing lmdb dataset.
`[" ../dataset/MLT17_IL/test_2017", "../dataset/MLT19_IL/test_2019"] ` 132 | * `--batch_ratio`: assign ratio for each selected data in the batch. default is '1 / number of datasets'. 133 | * `--Aug`: whether to use augmentation |None|Blur|Crop|Rot| 134 | 135 | ### Config Detail 136 | For detailed configuration modifications please use the config file `config/crnn_mrn.py` 137 | ``` 138 | common=dict( 139 | exp_name="TRBA_MRN", # Where to store logs and models 140 | il="mrn", # joint_mix | joint_loader | base | lwf | wa | ewc | der | mrn 141 | memory="random", # None | random 142 | memory_num=2000, 143 | batch_max_length = 25, 144 | imgH = 32, 145 | imgW = 256, 146 | manual_seed=111, 147 | start_task = 0 148 | ) 149 | 150 | """ Model Architecture """ 151 | model=dict( 152 | model_name="TRBA", 153 | Transformation = "TPS", #None TPS 154 | FeatureExtraction = "ResNet", #VGG ResNet 155 | SequenceModeling = "BiLSTM", #None BiLSTM 156 | Prediction = "Attn", #CTC Attn 157 | num_fiducial=20, 158 | input_channel=4, 159 | output_channel=512, 160 | hidden_size=256, 161 | ) 162 | 163 | 164 | 165 | """ Optimizer """ 166 | optimizer=dict( 167 | schedule="super", #default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER 168 | optimizer="adam", 169 | lr=0.0005, 170 | sgd_momentum=0.9, 171 | sgd_weight_decay=0.000001, 172 | milestones=[2000,4000], 173 | lrate_decay=0.1, 174 | rho=0.95, 175 | eps=1e-8, 176 | lr_drop_rate=0.1 177 | ) 178 | 179 | 180 | """ Data processing """ 181 | train = dict( 182 | saved_model="", # "path to model to continue training" 183 | Aug="None", # |None|Blur|Crop|Rot|ABINet 184 | workers=4, 185 | lan_list=["Chinese","Latin","Japanese", "Korean", "Arabic", "Bangla"], 186 | valid_datas=[ 187 | "../dataset/MLT17_IL/test_2017", 188 | "../dataset/MLT19_IL/test_2019" 189 | ], 190 | select_data=[ 191 | "../dataset/MLT17_IL/train_2017", 192 | "../dataset/MLT19_IL/train_2019" 193 | ], 194 | batch_ratio="0.5-0.5", 195 | total_data_usage_ratio="1.0", 196 | NED=True, 197 | batch_size=256, 198 | num_iter=10000, 199 | val_interval=5000, 200 | log_multiple_test=None, 201 | grad_clip=5, 202 | ) 203 | 204 | ``` 205 | 206 | ### Data Analysis 207 | The experimental results of each task are recorded in `data_any.txt` and can be used for analysis of the data. 208 | 209 | 210 | ## Acknowledgements 211 | This implementation has been based on these repositories: 212 | - [STR-Fewer-Labels](https://github.com/ku21fan/STR-Fewer-Labels) 213 | - [PyCIL: A Python Toolbox for Class-Incremental Learning](https://github.com/G-U-N/PyCIL) 214 | 215 | ## Citation 216 | Please consider citing this work in your publications if it helps your research. 217 | ``` 218 | @article{zheng2023mrn, 219 | title={MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition}, 220 | author={Zheng, Tianlun and Chen, Zhineng and Huang, BingChen and Zhang, Wei and Jiang, Yu-Gang}, 221 | journal={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 222 | year={2023} 223 | } 224 | ``` 225 | 226 | ## License 227 | This project is released under the Apache 2.0 license. 
228 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import PIL 4 | import numpy as np 5 | import torch 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | class CTCLabelConverter(object): 11 | """Convert between text-label and text-index""" 12 | 13 | def __init__(self, character): 14 | # character (str): set of the possible characters. 15 | list_special_token = [ 16 | "[PAD]", 17 | "[UNK]", 18 | " ", 19 | ] # [UNK] for unknown character, ' ' for space. 20 | list_character = list(character) 21 | dict_character = list_special_token + list_character 22 | 23 | self.dict = {} 24 | for i, char in enumerate(dict_character): 25 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss, not same with space ' '. 26 | # print(type(char)) 27 | self.dict[char] = i + 1 28 | 29 | self.character = [ 30 | "[CTCblank]" 31 | ] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0). 32 | print(f"# characters dict has: {len(self.character)}") 33 | # print(f"\n {self.character}\n") 34 | 35 | def encode(self, word_string, batch_max_length=25): 36 | """convert word_list (string) into word_index. 37 | input: 38 | word_string: word labels of each image. [batch_size] 39 | batch_max_length: max length of word in the batch. Default: 25 40 | 41 | output: 42 | word_index: word index list for CTCLoss. [batch_size, batch_max_length] 43 | word_length: length of each word. [batch_size] 44 | """ 45 | word_length = [len(word) for word in word_string] 46 | 47 | # The index used for padding (=[PAD]) would not affect the CTC loss calculation. 48 | word_index = torch.LongTensor(len(word_string), batch_max_length).fill_( 49 | self.dict["[PAD]"] 50 | ) 51 | 52 | for i, word in enumerate(word_string): 53 | word = list(word) 54 | word_idx = [ 55 | self.dict[char] if char in self.dict else self.dict["[UNK]"] 56 | for char in word 57 | ] 58 | word_index[i][: len(word_idx)] = torch.LongTensor(word_idx) 59 | 60 | return (word_index.to(device), torch.IntTensor(word_length).to(device)) 61 | 62 | def decode(self, word_index, word_length): 63 | """convert word_index into word_string""" 64 | word_string = [] 65 | for idx, length in enumerate(word_length): 66 | word_idx = word_index[idx, :] 67 | 68 | char_list = [] 69 | for i in range(length): 70 | # removing repeated characters and blank. 71 | if word_idx[i] != 0 and not (i > 0 and word_idx[i - 1] == word_idx[i]): 72 | char_list.append(self.character[word_idx[i]]) 73 | 74 | word = "".join(char_list) 75 | word_string.append(word) 76 | return word_string 77 | 78 | 79 | class AttnLabelConverter(object): 80 | """Convert between text-label and text-index""" 81 | 82 | def __init__(self, character): 83 | # character (str): set of the possible characters. 84 | # [SOS] (start-of-sentence token) and [EOS] (end-of-sentence token) for the attention decoder. 85 | list_special_token = [ 86 | "[UNK]", 87 | "[PAD]", 88 | "[SOS]", 89 | "[EOS]", 90 | " ", 91 | ] # [UNK] for unknown character, ' ' for space. 
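        # With list_special_token prepended (see self.character below), the index layout is:
        # 0 = [UNK], 1 = [PAD], 2 = [SOS], 3 = [EOS], 4 = ' ' (space); dataset characters start at index 5.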
92 | list_character = list(character) 93 | self.character = list_special_token + list_character 94 | 95 | self.dict = {} 96 | for i, char in enumerate(self.character): 97 | # print(i, char) 98 | self.dict[char] = i 99 | 100 | print(f"# of tokens and characters: {len(self.character)}") 101 | 102 | def encode(self, word_string, batch_max_length=25): 103 | """convert word_list (string) into word_index. 104 | input: 105 | word_string: word labels of each image. [batch_size] 106 | batch_max_length: max length of word in the batch. Default: 25 107 | 108 | output: 109 | word_index : the input of attention decoder. [batch_size x (max_length+2)] +1 for [SOS] token and +1 for [EOS] token. 110 | word_length : the length of output of attention decoder, which count [EOS] token also. [batch_size] 111 | """ 112 | word_length = [ 113 | len(word) + 1 for word in word_string 114 | ] # +1 for [EOS] at end of sentence. 115 | batch_max_length += 1 116 | 117 | # additional batch_max_length + 1 for [SOS] at first step. 118 | word_index = torch.LongTensor(len(word_string), batch_max_length + 1).fill_( 119 | self.dict["[PAD]"] 120 | ) 121 | word_index[:, 0] = self.dict["[SOS]"] 122 | 123 | for i, word in enumerate(word_string): 124 | word = list(word) 125 | word.append("[EOS]") 126 | word_idx = [ 127 | self.dict[char] if char in self.dict else self.dict["[UNK]"] 128 | for char in word 129 | ] 130 | word_index[i][1 : 1 + len(word_idx)] = torch.LongTensor( 131 | word_idx 132 | ) # word_index[:, 0] = [SOS] token 133 | 134 | return (word_index.to(device), torch.IntTensor(word_length).to(device)) 135 | 136 | def decode(self, word_index, word_length): 137 | """convert word_index into word_string""" 138 | word_string = [] 139 | for idx, length in enumerate(word_length): 140 | word_idx = word_index[idx, :length] 141 | word = "".join([self.character[i] for i in word_idx]) 142 | word_string.append(word) 143 | return word_string 144 | 145 | 146 | class Averager(object): 147 | """Compute average for torch.Tensor, used for loss average.""" 148 | 149 | def __init__(self): 150 | self.reset() 151 | 152 | def add(self, v): 153 | count = v.data.numel() 154 | v = v.data.sum() 155 | self.n_count += count 156 | self.sum += v 157 | 158 | def reset(self): 159 | self.n_count = 0 160 | self.sum = 0 161 | 162 | def val(self): 163 | res = 0 164 | if self.n_count != 0: 165 | res = self.sum / float(self.n_count) 166 | return res 167 | 168 | 169 | def adjust_learning_rate(optimizer, iteration, opt): 170 | """Decay the learning rate based on schedule""" 171 | lr = opt.lr 172 | # stepwise lr schedule 173 | for milestone in opt.schedule: 174 | lr *= ( 175 | opt.lr_drop_rate if iteration >= (float(milestone) * opt.num_iter) else 1.0 176 | ) 177 | for param_group in optimizer.param_groups: 178 | param_group["lr"] = lr 179 | 180 | 181 | def tensor2im(image_tensor, imtype=np.uint8): 182 | image_numpy = image_tensor.cpu().float().numpy() 183 | if image_numpy.shape[0] == 1: 184 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 185 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 186 | return image_numpy.astype(imtype) 187 | 188 | 189 | def save_image(image_numpy, image_path): 190 | image_pil = PIL.Image.fromarray(image_numpy) 191 | image_pil.save(image_path) 192 | 193 | def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2): 194 | """Crop text region with their bounding box. 195 | 196 | Args: 197 | src_img (np.array): The original image. 198 | box (list[float | int]): Points of quadrangle. 
199 | long_edge_pad_ratio (float): Box pad ratio for long edge 200 | corresponding to font size. 201 | short_edge_pad_ratio (float): Box pad ratio for short edge 202 | corresponding to font size. 203 | """ 204 | assert utils.is_type_list(box, (float, int)) 205 | assert len(box) == 8 206 | assert 0. <= long_edge_pad_ratio < 1.0 207 | assert 0. <= short_edge_pad_ratio < 1.0 208 | 209 | h, w = src_img.shape[:2] 210 | points_x = np.clip(np.array(box[0::2]), 0, w) 211 | points_y = np.clip(np.array(box[1::2]), 0, h) 212 | 213 | box_width = np.max(points_x) - np.min(points_x) 214 | box_height = np.max(points_y) - np.min(points_y) 215 | font_size = min(box_height, box_width) 216 | 217 | if box_height < box_width: 218 | horizontal_pad = long_edge_pad_ratio * font_size 219 | vertical_pad = short_edge_pad_ratio * font_size 220 | else: 221 | horizontal_pad = short_edge_pad_ratio * font_size 222 | vertical_pad = long_edge_pad_ratio * font_size 223 | 224 | left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w) 225 | top = np.clip(int(np.min(points_y) - vertical_pad), 0, h) 226 | right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w) 227 | bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h) 228 | 229 | dst_img = src_img[top:bottom, left:right] 230 | 231 | return dst_img 232 | 233 | def read_txt(path): 234 | f = open(path) 235 | list = [] 236 | char_dict = {} 237 | line = f.readline() 238 | while line: 239 | list.append(line.strip("\n")) 240 | # print(line) 241 | line = f.readline() 242 | f.close() 243 | for str in list: 244 | for char in str: 245 | if char_dict.get(char, None) == None: 246 | char_dict[char] = 1 247 | else: 248 | char_dict[char] += 1 249 | return char_dict 250 | def dict_total(path='.txt', 251 | path_a='_all.txt'): 252 | root = '/share/home/ztl/CIL_MLSTR/exp/base/' 253 | language = "Japanese" 254 | path_ = language+'test.txt' 255 | char_list = [] 256 | true_char = read_txt(root + language + path) 257 | total_char = read_txt(root + language + path_a) 258 | for key, value in total_char.items(): 259 | acc = true_char.get(key, 0) / total_char[key] 260 | char_list.append([key,value,acc]) 261 | print([key,value,acc]) 262 | pred_list = sorted(char_list,key=lambda list: list[1]) 263 | start_i = 0 264 | for i,list in enumerate(pred_list): 265 | if i != 0 and list[1] != pred_list[i-1][1]: 266 | avg = acc / (i- start_i) 267 | # avg = acc / (i + 1) 268 | str_log = "avg {} char is {:.2f} total {}\n".format(pred_list[i-1][1],avg,i - start_i) 269 | print(str_log) 270 | with open(root + path_, "a") as log: 271 | log.write(str_log) 272 | start_i = i 273 | acc = 0 274 | acc += list[2] 275 | with open(root + path_, "a") as log: 276 | for line in pred_list: 277 | log.write(str(line) + "\n") 278 | # dict_total() 279 | 280 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import six 4 | import random 5 | 6 | from natsort import natsorted 7 | import PIL 8 | import lmdb 9 | import torch 10 | from torch.utils.data import Dataset, ConcatDataset 11 | import torchvision.transforms as transforms 12 | 13 | from data.transform import CVGeometry, CVDeterioration, CVColorJitter 14 | 15 | def hierarchical_dataset(root, opt, select_data="/", data_type="label", mode="train"): 16 | """select_data='/' contains all sub-directory of root directory""" 17 | dataset_list = [] 18 | dataset_log = f"dataset_root: {root}\t dataset: {select_data}" 19 | 
print(dataset_log) 20 | dataset_log += "\n" 21 | for dirpath, dirnames, filenames in os.walk(root + "/"): 22 | if not dirnames: 23 | select_flag = False 24 | for selected_d in select_data: 25 | if selected_d in dirpath: 26 | select_flag = True 27 | break 28 | 29 | if select_flag: 30 | # if data_type == "label": 31 | dataset = LmdbDataset(dirpath, opt, mode=mode) 32 | # else: 33 | # dataset = LmdbDataset_unlabel(dirpath, opt) 34 | sub_dataset_log = f"sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}" 35 | print(sub_dataset_log) 36 | dataset_log += f"{sub_dataset_log}\n" 37 | dataset_list.append(dataset) 38 | 39 | concatenated_dataset = ConcatDataset(dataset_list) 40 | 41 | return concatenated_dataset, dataset_log 42 | 43 | 44 | class LmdbDataset(Dataset): 45 | def __init__(self, root, opt, mode="train"): 46 | 47 | self.root = root 48 | skip = 0 49 | self.opt = opt 50 | self.mode = mode 51 | self.env = lmdb.open( 52 | root, 53 | max_readers=32, 54 | readonly=True, 55 | lock=False, 56 | readahead=False, 57 | meminit=False, 58 | ) 59 | if not self.env: 60 | print("cannot open lmdb from %s" % (root)) 61 | sys.exit(0) 62 | 63 | with self.env.begin(write=False) as txn: 64 | self.nSamples = int(txn.get("num-samples".encode())) 65 | print(self.nSamples) 66 | self.filtered_index_list = [] 67 | for index in range(self.nSamples): 68 | index += 1 # lmdb starts with 1 69 | label_key = "label-%09d".encode() % index 70 | # print(label_key) 71 | if txn.get(label_key)==None: 72 | skip+=1 73 | print("skip --- {}\n".format(skip)) 74 | continue 75 | label = txn.get(label_key).decode("utf-8") 76 | # print(label) 77 | 78 | # length filtering 79 | length_of_label = len(label) 80 | if length_of_label > opt.batch_max_length: 81 | continue 82 | 83 | self.filtered_index_list.append(index) 84 | 85 | self.nSamples = len(self.filtered_index_list) 86 | 87 | def __len__(self): 88 | return self.nSamples 89 | 90 | def __getitem__(self, index): 91 | assert index <= len(self), "index range error" 92 | index = self.filtered_index_list[index] 93 | 94 | with self.env.begin(write=False) as txn: 95 | label_key = "label-%09d".encode() % index 96 | label = txn.get(label_key).decode("utf-8") 97 | img_key = "image-%09d".encode() % index 98 | imgbuf = txn.get(img_key) 99 | buf = six.BytesIO() 100 | buf.write(imgbuf) 101 | buf.seek(0) 102 | 103 | try: 104 | img = PIL.Image.open(buf).convert("RGBA") 105 | 106 | except IOError: 107 | print(f"Corrupted image for {index}") 108 | # make dummy image and dummy label for corrupted image. 109 | img = PIL.Image.new("RGBA", (self.opt.imgW, self.opt.imgH)) 110 | label = "[dummy_label]" 111 | 112 | return (img, label) 113 | 114 | 115 | class RawDataset(Dataset): 116 | def __init__(self, root, opt): 117 | self.opt = opt 118 | self.image_path_list = [] 119 | for dirpath, dirnames, filenames in os.walk(root): 120 | for name in filenames: 121 | _, ext = os.path.splitext(name) 122 | ext = ext.lower() 123 | if ext == ".jpg" or ext == ".jpeg" or ext == ".png": 124 | self.image_path_list.append(os.path.join(dirpath, name)) 125 | 126 | self.image_path_list = natsorted(self.image_path_list) 127 | self.nSamples = len(self.image_path_list) 128 | 129 | def __len__(self): 130 | return self.nSamples 131 | 132 | def __getitem__(self, index): 133 | 134 | try: 135 | img = PIL.Image.open(self.image_path_list[index]).convert("RGBA") 136 | 137 | except IOError: 138 | print(f"Corrupted image for {index}") 139 | # make dummy image and dummy label for corrupted image. 
140 | img = PIL.Image.new("RGBA", (self.opt.imgW, self.opt.imgH)) 141 | 142 | return (img, self.image_path_list[index]) 143 | 144 | class AlignCollate2(object): 145 | def __init__(self, opt, mode="train"): 146 | self.opt = opt 147 | self.mode = mode 148 | 149 | if opt.Aug == "None" or mode != "train": 150 | self.transform = ResizeNormalize((opt.imgW, opt.imgH)) 151 | elif opt.Aug == "ABINet" and mode == "train": 152 | self.transform = transforms.Compose([ 153 | CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5), 154 | CVDeterioration(var=20, degrees=6, factor=4, p=0.25), 155 | CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25), 156 | transforms.Resize( 157 | (self.opt.imgH, self.opt.imgW), interpolation=PIL.Image.BICUBIC 158 | ), 159 | transforms.ToTensor(), 160 | ]) 161 | else: 162 | self.transform = Text_augment(opt) 163 | 164 | def __call__(self, batch): 165 | b_info, index = zip(*batch) 166 | images, labels = zip(*b_info) 167 | image_tensors = [self.transform(image) for image in images] 168 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 169 | 170 | return image_tensors, labels , index 171 | 172 | class AlignCollate(object): 173 | def __init__(self, opt, mode="train"): 174 | self.opt = opt 175 | self.mode = mode 176 | 177 | if opt.Aug == "None" or mode != "train": 178 | self.transform = ResizeNormalize((opt.imgW, opt.imgH)) 179 | elif opt.Aug == "ABINet" and mode == "train": 180 | self.transform = transforms.Compose([ 181 | CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5), 182 | CVDeterioration(var=20, degrees=6, factor=4, p=0.25), 183 | CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25), 184 | transforms.Resize( 185 | (self.opt.imgH, self.opt.imgW), interpolation=PIL.Image.BICUBIC 186 | ), 187 | transforms.ToTensor(), 188 | ]) 189 | else: 190 | self.transform = Text_augment(opt) 191 | 192 | def __call__(self, batch): 193 | images, labels = zip(*batch) 194 | image_tensors = [self.transform(image) for image in images] 195 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 196 | 197 | return image_tensors, labels 198 | 199 | class GaussianBlur(object): 200 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 201 | 202 | def __init__(self, sigma=[0.1, 2.0]): 203 | self.sigma = sigma 204 | 205 | def __call__(self, image): 206 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 207 | image = image.filter(PIL.ImageFilter.GaussianBlur(radius=sigma)) 208 | return image 209 | 210 | 211 | class RandomCrop(object): 212 | """RandomCrop, 213 | RandomResizedCrop of PyTorch 1.6 and torchvision 0.7.0 work weird with scale 0.90-1.0. 214 | i.e. you can not always make 90%~100% cropped image scale 0.90-1.0, you will get central cropped image instead. 215 | so we made RandomCrop (keeping aspect ratio version) then use Resize. 
216 | """ 217 | 218 | def __init__(self, scale=[1, 1]): 219 | self.scale = scale 220 | 221 | def __call__(self, image): 222 | width, height = image.size 223 | crop_ratio = random.uniform(self.scale[0], self.scale[1]) 224 | crop_width = int(width * crop_ratio) 225 | crop_height = int(height * crop_ratio) 226 | 227 | x_start = random.randint(0, width - crop_width) 228 | y_start = random.randint(0, height - crop_height) 229 | image_crop = image.crop( 230 | (x_start, y_start, x_start + crop_width, y_start + crop_height) 231 | ) 232 | return image_crop 233 | 234 | 235 | class ResizeNormalize(object): 236 | def __init__(self, size, interpolation=PIL.Image.BICUBIC): 237 | # CAUTION: it should be (width, height). different from size of transforms.Resize (height, width) 238 | self.size = size 239 | self.interpolation = interpolation 240 | self.toTensor = transforms.ToTensor() 241 | 242 | def __call__(self, image): 243 | image = image.resize(self.size, self.interpolation) 244 | image = self.toTensor(image) 245 | image.sub_(0.5).div_(0.5) 246 | return image 247 | 248 | 249 | class Text_augment(object): 250 | """Augmentation for Text recognition""" 251 | 252 | def __init__(self, opt): 253 | self.opt = opt 254 | augmentation = [] 255 | aug_list = self.opt.Aug.split("-") 256 | for aug in aug_list: 257 | if aug.startswith("Blur"): 258 | maximum = float(aug.strip("Blur")) 259 | augmentation.append( 260 | transforms.RandomApply([GaussianBlur([0.1, maximum])], p=0.5) 261 | ) 262 | 263 | if aug.startswith("Crop"): 264 | crop_scale = float(aug.strip("Crop")) / 100 265 | augmentation.append(RandomCrop(scale=(crop_scale, 1.0))) 266 | 267 | if aug.startswith("Rot"): 268 | degree = int(aug.strip("Rot")) 269 | augmentation.append( 270 | transforms.RandomRotation( 271 | degree, resample=PIL.Image.BICUBIC, expand=True 272 | ) 273 | ) 274 | 275 | augmentation.append( 276 | transforms.Resize( 277 | (self.opt.imgH, self.opt.imgW), interpolation=PIL.Image.BICUBIC 278 | ) 279 | ) 280 | augmentation.append(transforms.ToTensor()) 281 | self.Augment = transforms.Compose(augmentation) 282 | print("Use Text_augment", augmentation) 283 | 284 | def __call__(self, image): 285 | image = self.Augment(image) 286 | image.sub_(0.5).div_(0.5) 287 | 288 | return image 289 | 290 | 291 | class MoCo_augment(object): 292 | """Take two random crops of one image as the query and key.""" 293 | 294 | def __init__(self, opt): 295 | self.opt = opt 296 | 297 | # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978 298 | augmentation = [ 299 | transforms.RandomResizedCrop( 300 | (opt.imgH, opt.imgW), scale=(0.2, 1.0), interpolation=PIL.Image.BICUBIC 301 | ), 302 | transforms.RandomGrayscale(p=0.2), 303 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 304 | transforms.RandomHorizontalFlip(), 305 | transforms.ToTensor(), 306 | ] 307 | 308 | self.Augment = transforms.Compose(augmentation) 309 | print("Use MoCo_augment", augmentation) 310 | 311 | def __call__(self, x): 312 | q = self.Augment(x) 313 | k = self.Augment(x) 314 | q.sub_(0.5).div_(0.5) 315 | k.sub_(0.5).div_(0.5) 316 | 317 | return [q, k] 318 | -------------------------------------------------------------------------------- /tools/crop_by_word.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
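# This script crops word-level recognition samples from SynthMLT-style detection data:
# it expects <lan>.zip and <lan>_gt.zip under --root_path (unzipped automatically),
# parses each ground-truth line "x1,y1,...,x4,y4,<script>,<word>", crops a padded word
# region with crop_img(), saves the crops to <out_dir>/<lan>/imgs/ and writes
# "imgs/<name>_<i>.png <word>" lines to <out_dir>/<lan>/label.txt.
# Hypothetical usage (the paths are placeholders, not taken from the original docs):
#   python tools/crop_by_word.py --root_path ../dataset/SynthMLT/ --lan Latin --out_dir ../dataset/SynthMLT_word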
2 | import argparse 3 | import glob 4 | import os 5 | import os.path as osp 6 | import re 7 | 8 | import mmcv 9 | import numpy as np 10 | from shapely.geometry import Polygon 11 | 12 | def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2): 13 | """Crop text region with their bounding box. 14 | 15 | Args: 16 | src_img (np.array): The original image. 17 | box (list[float | int]): Points of quadrangle. 18 | long_edge_pad_ratio (float): Box pad ratio for long edge 19 | corresponding to font size. 20 | short_edge_pad_ratio (float): Box pad ratio for short edge 21 | corresponding to font size. 22 | """ 23 | # assert utils.is_type_list(box, (float, int)) 24 | assert len(box) == 8 25 | assert 0. <= long_edge_pad_ratio < 1.0 26 | assert 0. <= short_edge_pad_ratio < 1.0 27 | 28 | h, w = src_img.shape[:2] 29 | points_x = np.clip(np.array(box[0::2]), 0, w) 30 | points_y = np.clip(np.array(box[1::2]), 0, h) 31 | 32 | box_width = np.max(points_x) - np.min(points_x) 33 | box_height = np.max(points_y) - np.min(points_y) 34 | font_size = min(box_height, box_width) 35 | 36 | if box_height < box_width: 37 | horizontal_pad = long_edge_pad_ratio * font_size 38 | vertical_pad = short_edge_pad_ratio * font_size 39 | else: 40 | horizontal_pad = short_edge_pad_ratio * font_size 41 | vertical_pad = long_edge_pad_ratio * font_size 42 | 43 | left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w) 44 | top = np.clip(int(np.min(points_y) - vertical_pad), 0, h) 45 | right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w) 46 | bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h) 47 | 48 | dst_img = src_img[top:bottom, left:right] 49 | 50 | return dst_img 51 | 52 | def test_crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2): 53 | pts1 = np.float32([[wordBB[0][0], wordBB[1][0]], 54 | [wordBB[0][3], wordBB[1][3]], 55 | [wordBB[0][1], wordBB[1][1]], 56 | [wordBB[0][2], wordBB[1][2]]]) 57 | height = math.sqrt((wordBB[0][0] - wordBB[0][3]) ** 2 + (wordBB[1][0] - wordBB[1][3]) ** 2) 58 | width = math.sqrt((wordBB[0][0] - wordBB[0][1]) ** 2 + (wordBB[1][0] - wordBB[1][1]) ** 2) 59 | 60 | # Coord validation check 61 | if (height * width) <= 0: 62 | err_log = 'empty file : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB) 63 | err_file.write(err_log) 64 | # print(err_log) 65 | continue 66 | elif (height * width) > (img_height * img_width): 67 | err_log = 'too big box : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB) 68 | err_file.write(err_log) 69 | # print(err_log) 70 | continue 71 | else: 72 | valid = True 73 | for i in range(2): 74 | for j in range(4): 75 | if wordBB[i][j] < 0 or wordBB[i][j] > img.shape[1 - i]: 76 | valid = False 77 | break 78 | if not valid: 79 | break 80 | if not valid: 81 | err_log = 'invalid coord : {}\t{}\t{}\t{}\t{}\n'.format( 82 | image_name, txt[word_indx], wordBB, (width, height), (img_width, img_height)) 83 | err_file.write(err_log) 84 | # print(err_log) 85 | continue 86 | 87 | pts2 = np.float32([[0, 0], 88 | [0, height], 89 | [width, 0], 90 | [width, height]]) 91 | 92 | x_min = np.int(round(min(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3]))) 93 | x_max = np.int(round(max(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3]))) 94 | y_min = np.int(round(min(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3]))) 95 | y_max = np.int(round(max(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3]))) 96 | # print(x_min, x_max, y_min, y_max) 97 | # print(img.shape) 98 | # assert 1<0 99 | if len(img.shape) == 
3: 100 | img_cropped = img[y_min:y_max:1, x_min:x_max:1, :] 101 | else: 102 | img_cropped = img[y_min:y_max:1, x_min:x_max:1] 103 | 104 | def list_to_file(filename, lines): 105 | """Write a list of strings to a text file. 106 | 107 | Args: 108 | filename (str): The output filename. It will be created/overwritten. 109 | lines (list(str)): Data to be written. 110 | """ 111 | mmcv.mkdir_or_exist(os.path.dirname(filename)) 112 | with open(filename, 'w', encoding='utf-8') as fw: 113 | for line in lines: 114 | fw.write(f'{line}\n') 115 | 116 | def list_from_file(filename, encoding='utf-8'): 117 | """Load a text file and parse the content as a list of strings. The 118 | trailing "\\r" and "\\n" of each line will be removed. 119 | 120 | Note: 121 | This will be replaced by mmcv's version after it supports encoding. 122 | 123 | Args: 124 | filename (str): Filename. 125 | encoding (str): Encoding used to open the file. Default utf-8. 126 | 127 | Returns: 128 | list[str]: A list of strings. 129 | """ 130 | item_list = [] 131 | with open(filename, 'r', encoding=encoding) as f: 132 | for line in f: 133 | item_list.append(line.rstrip('\n\r')) 134 | return item_list 135 | 136 | 137 | def load_img_info(file): 138 | """Load the information of one image. 139 | 140 | Args: 141 | files(tuple): The tuple of (img_file, groundtruth_file) 142 | dataset(str): Dataset name, icdar2015 or icdar2017 143 | 144 | Returns: 145 | img_info(dict): The dict of the img and annotation information 146 | """ 147 | # assert isinstance(files, tuple) 148 | # assert isinstance(dataset, str) 149 | # assert dataset 150 | 151 | # img_file, gt_file = files 152 | # read imgs with ignoring orientations 153 | # img = mmcv.imread(img_file, 'unchanged') 154 | gt_file = file[1] 155 | img_file = file[0] 156 | img = mmcv.imread(img_file, 'unchanged') 157 | 158 | split_name = osp.basename(osp.dirname(img_file)) 159 | img_info = dict( 160 | # remove img_prefix for filename 161 | file_name=img_file, 162 | height=img.shape[0], 163 | width=img.shape[1],) 164 | # img_file 165 | # print("gt_file{}".format(gt_file)) 166 | gt_list = list_from_file(gt_file) 167 | 168 | anno_info = [] 169 | # img_info = {} 170 | for line in gt_list: 171 | # each line has one ploygen (4 vetices), and others. 172 | # e.g., 695,885,866,888,867,1146,696,1143,Latin,9 173 | line = line.strip() 174 | strs = line.split(',') 175 | category_id = 1 176 | xy = [float(x) for x in strs[0:8]] 177 | coordinates = np.array(xy).reshape(-1, 2) 178 | polygon = Polygon(coordinates) 179 | 180 | area = polygon.area 181 | # convert to COCO style XYWH format 182 | min_x, min_y, max_x, max_y = polygon.bounds 183 | # bbox = [min_x, min_y, max_x - min_x, max_y - min_y] 184 | bbox = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y] 185 | # bbox = [min_x, min_y, max_x - min_x, max_y - min_y] 186 | anno = dict(word=strs[9], bbox=bbox) 187 | anno_info.append(anno) 188 | # print(anno) 189 | img_info.update(anno_info=anno_info) 190 | # print(img_info) 191 | return img_info 192 | 193 | 194 | def collect_files(img_dir, gt_dir): 195 | """Collect all images and their corresponding groundtruth files. 196 | 197 | Args: 198 | img_dir(str): The image directory 199 | gt_dir(str): The groundtruth directory 200 | split(str): The split of dataset. 
Namely: training or test 201 | Returns: 202 | files(list): The list of tuples (img_file, groundtruth_file) 203 | """ 204 | assert isinstance(img_dir, str) 205 | assert img_dir 206 | assert isinstance(gt_dir, str) 207 | assert gt_dir 208 | 209 | # note that we handle png and jpg only. Pls convert others such as gif to 210 | # jpg or png offline 211 | suffixes = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG'] 212 | # suffixes = ['.png'] 213 | 214 | imgs_list = [] 215 | for suffix in suffixes: 216 | imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix))) 217 | 218 | imgs_list = sorted(imgs_list) 219 | ann_list = sorted( 220 | [osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)]) 221 | 222 | files = [(img_file, gt_file) 223 | for (img_file, gt_file) in zip(imgs_list, ann_list)] 224 | assert len(files), f'No images found in {img_dir}' 225 | print(f'Loaded {len(files)} images from {img_dir}') 226 | 227 | return files 228 | 229 | 230 | def collect_annotations(files, nproc=1): 231 | """Collect the annotation information. 232 | 233 | Args: 234 | files(list): The list of tuples (image_file, groundtruth_file) 235 | nproc(int): The number of process to collect annotations 236 | Returns: 237 | images(list): The list of image information dicts 238 | """ 239 | assert isinstance(files, list) 240 | assert isinstance(nproc, int) 241 | 242 | if nproc > 1: 243 | images = mmcv.track_parallel_progress( 244 | load_img_info, files, nproc=nproc) 245 | else: 246 | images = mmcv.track_progress(load_img_info, files) 247 | 248 | return images 249 | 250 | 251 | def generate_ann(root_path, image_infos, out_dir): 252 | """Generate cropped annotations and label txt file. 253 | 254 | Args: 255 | root_path(str): The relative path of the totaltext file 256 | split(str): The split of dataset. Namely: training or test 257 | image_infos(list[dict]): A list of dicts of the img and 258 | annotation information 259 | """ 260 | 261 | dst_image_root = osp.join(out_dir, 'imgs') 262 | dst_label_file = osp.join(out_dir, 'label.txt') 263 | os.makedirs(dst_image_root, exist_ok=True) 264 | 265 | lines = [] 266 | for image_info in image_infos: 267 | index = 1 268 | src_img_path = image_info['file_name'] 269 | image = mmcv.imread(src_img_path) 270 | # src_img_root = osp.splitext(image_info['file_name'])[0].split('/')[1] 271 | src_img_root = image_info['file_name'].split('/')[-1].split(".")[0] 272 | 273 | for anno in image_info['anno_info']: 274 | word = anno['word'] 275 | dst_img = crop_img(image, anno['bbox']) 276 | 277 | # Skip invalid annotations 278 | if min(dst_img.shape) == 0: 279 | continue 280 | 281 | dst_img_name = f'{src_img_root}_{index}.png' 282 | index += 1 283 | dst_img_path = osp.join(dst_image_root, dst_img_name) 284 | mmcv.imwrite(dst_img, dst_img_path) 285 | lines.append(f'{osp.basename(dst_image_root)}/{dst_img_name} ' 286 | f'{word}') 287 | # print(lines) 288 | # print("\n") 289 | list_to_file(dst_label_file, lines) 290 | 291 | 292 | def parse_args(): 293 | parser = argparse.ArgumentParser( 294 | description='Convert SynthMLT annotations to COCO format') 295 | parser.add_argument('--root_path', help='SynthMLT root path') 296 | parser.add_argument('--lan', default="Hindi", help='languang for data') 297 | parser.add_argument('--out_dir', default="test",help='output path') 298 | # parser.add_argument( 299 | # '--split-list', 300 | # nargs='+', 301 | # help='a list of splits. 
e.g., "--split_list training test"') 302 | 303 | parser.add_argument( 304 | '--nproc', default=10, type=int, help='number of process') 305 | args = parser.parse_args() 306 | return args 307 | 308 | def unzip(root_path, lan): 309 | # root_path = "../dataset/SynthMLT/" 310 | img_path = "{}{}".format(root_path, lan) 311 | gt_path = "{}{}_gt".format(root_path, lan) 312 | 313 | if not os.path.exists(img_path): 314 | # os.system(f"rm -r {img_path}") 315 | cmd = "unzip -d {} {}.zip".format(img_path, img_path) 316 | os.system(cmd) 317 | 318 | 319 | if not os.path.exists(gt_path): 320 | # os.system(f"rm -r {gt_path}") 321 | cmd = "unzip -d {} {}.zip".format(gt_path, gt_path) 322 | os.system(cmd) 323 | 324 | def main(): 325 | args = parse_args() 326 | unzip(args.root_path, args.lan) 327 | # root_path = args.root_path + args.lan 328 | # out_dir = args.root_path + args.out_dir if args.out_dir else args.root_path 329 | out_dir = args.out_dir 330 | root_path = args.root_path + args.lan 331 | mmcv.mkdir_or_exist(out_dir) 332 | out_dir = osp.join(out_dir, args.lan) 333 | print("save to {}\n".format(out_dir)) 334 | 335 | 336 | 337 | # root_path = "../dataset/SynthMLT/" 338 | img_dir = "{}/{}/".format(root_path, args.lan) 339 | gt_dir = "{}_gt/{}/".format(root_path, args.lan) 340 | 341 | print(f'Converting SynthMLT to TXT\n') 342 | # print("img dir is {}\n".format(img_dir)) 343 | # print("gt dir is {}\n".format(gt_dir)) 344 | with mmcv.Timer( print_tmpl='It takes {}s to convert txt annotation'): 345 | files = collect_files(img_dir, gt_dir) 346 | # print("--------------------start------------\n{}".format(files)) 347 | image_infos = collect_annotations(files, nproc=args.nproc) 348 | generate_ann(root_path, image_infos,out_dir) 349 | print(out_dir) 350 | 351 | 352 | if __name__ == '__main__': 353 | main() 354 | -------------------------------------------------------------------------------- /data/data_manage.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy.random 3 | import torch 4 | from torch.utils.data import Dataset, ConcatDataset, Subset 5 | from data.dataset import AlignCollate, LmdbDataset, AlignCollate2, hierarchical_dataset 6 | 7 | 8 | class Dataset_Manager(object): 9 | def __init__(self,opt): 10 | self.data_list = [] 11 | self.data_loader_list = [] 12 | self.dataloader_iter_list = [] 13 | self.select_data = None 14 | self.opt = opt 15 | 16 | def get_dataset(self, taski, memory="random",index_list=None): 17 | self.data_loader_list = [] 18 | self.dataloader_iter_list = [] 19 | memory_num = self.opt.memory_num 20 | 21 | dataset = self.create_dataset(data_list=self.select_data,taski=taski) 22 | 23 | if memory != None and self.opt.il=="mrn": 24 | # curr: num/(taski-1) mem: num/(taski-1) 25 | index_current = numpy.random.choice(range(len(dataset)),int(self.opt.memory_num/(taski)),replace=False) 26 | split_dataset = Subset(dataset,index_current.tolist()) 27 | memory_data,index_list = self.rehearsal_memory(taski, random=False,total_num=self.opt.memory_num,index_array=index_list) 28 | self.create_dataloader_mix(IndexConcatDataset([memory_data,split_dataset]),self.opt.batch_size) 29 | print("taski is {} current dataset chose {}\n now dataset chose {}".format(taski,int(self.opt.memory_num/taski),len(memory_data))) 30 | elif memory == "test_ch": 31 | # curr: total mem: num/(taski-1) (repeat) 32 | # index_current = numpy.random.choice(range(len(dataset)),int(self.opt.memory_num/taski),replace=False) 33 | # split_dataset = 
Subset(dataset,index_current.tolist()) 34 | memory_data,index_list = self.rehearsal_memory(taski, random=False,total_num=self.opt.memory_num,index_array=index_list,repeat=True) 35 | self.create_dataloader_mix(IndexConcatDataset([memory_data,dataset]),self.opt.batch_size) 36 | print("taski is {} current dataset chose {}\n now dataset chose {}".format(taski,int(self.opt.memory_num/taski),len(memory_data))) 37 | elif memory == "large": 38 | # curr: num mem: num 39 | index_current = numpy.random.choice(range(len(dataset)), memory_num, replace=False) 40 | split_dataset = Subset(dataset, index_current.tolist()) 41 | memory_data, index_list = self.rehearsal_memory(taski, random=False, total_num=memory_num*taski, index_array=index_list) 42 | self.create_dataloader_mix(IndexConcatDataset([memory_data, split_dataset]), self.opt.batch_size) 43 | print("taski is {} current dataset chose {}\n now dataset chose {}".format(taski, int(memory_num), 44 | len(memory_data))) 45 | elif memory == "total": 46 | # curr : total mem : total(repeat) 47 | total_data_list = [] 48 | total_data_list.append(dataset) 49 | for i in range(taski): 50 | dataset = self.create_dataset(data_list=self.select_data, taski=i) 51 | total_data_list.append(dataset) 52 | self.create_dataloader_mix(IndexConcatDataset(total_data_list), self.opt.batch_size) 53 | print("taski is {} current dataset chose {} lenth dataset\n now dataset chose {}".format(taski, len(total_data_list), 54 | len(dataset))) 55 | elif memory != None: 56 | memory_data,index_list = self.rehearsal_memory(taski, random=False,total_num=memory_num,index_array=index_list) 57 | self.create_dataloader(memory_data,(self.opt.batch_size)//2) 58 | self.create_dataloader(dataset,(self.opt.batch_size)//2) 59 | else: 60 | self.create_dataloader(dataset) 61 | return index_list 62 | 63 | def joint_start( 64 | self, opt, select_data, log, taski,total_task): 65 | self.opt = opt 66 | self.select_data = select_data 67 | dashed_line = "-" * 80 68 | print(dashed_line) 69 | log.write(dashed_line + "\n") 70 | 71 | dataset = self.create_dataset(data_list=self.select_data, taski=taski) 72 | if opt.il == "joint_mix": 73 | self.data_list.append(dataset) 74 | if taski == total_task-1: 75 | self.create_dataloader(ConcatDataset(self.data_list), int(self.opt.batch_size)) 76 | elif opt.il == "joint_loader": 77 | self.create_dataloader(dataset, int(self.opt.batch_size // total_task)) 78 | 79 | 80 | def init_start( 81 | self, opt, select_data, log, taski): 82 | self.opt = opt 83 | self.select_data = select_data 84 | self.data_loader_list = [] 85 | self.dataloader_iter_list = [] 86 | dashed_line = "-" * 80 87 | print(dashed_line) 88 | log.write(dashed_line + "\n") 89 | print( 90 | f"select_data: {select_data}\n" 91 | ) 92 | log.write( 93 | f"select_data: {select_data}\n" 94 | ) 95 | self.get_dataset(taski, memory=None) 96 | 97 | def rehearsal_memory(self,taski, random=False,total_num=2000,index_array=None,repeat=False): 98 | data_list = [] 99 | select_data = self.select_data 100 | num_i = int(total_num/(taski)) 101 | print("memory size is {}\n".format(num_i)) 102 | for i in range(taski): 103 | dataset = self.create_dataset(data_list=select_data,taski=i,repeat=repeat) 104 | if random: 105 | index_list = numpy.random.choice(range(len(dataset)),num_i,replace=repeat) 106 | # print(random) 107 | else: 108 | index_list = index_array[i] 109 | split_dataset = Subset(dataset,index_list.tolist()) 110 | data_list.append(split_dataset) 111 | return ConcatDataset(data_list), index_array 112 | 113 | def 
rehearsal_prev_model(self,taski,): 114 | select_data = self.select_data 115 | dataset = self.create_dataset(data_list=select_data,taski=taski-1,repeat=False) 116 | data_loader = torch.utils.data.DataLoader( 117 | dataset, 118 | batch_size=self.opt.batch_size, 119 | shuffle=False, 120 | num_workers=int(self.opt.workers), 121 | collate_fn=AlignCollate(self.opt), 122 | pin_memory=False, 123 | drop_last=False, 124 | ) 125 | return data_loader,len(dataset) 126 | 127 | def create_dataset(self, data_list="/", taski=0, mode="train", repeat=True): 128 | """select_data is list for all dataset""" 129 | dataset_list = [] 130 | for data_root in data_list: 131 | # print(dataset_log) 132 | # dataset_log += "\n" 133 | dataset = LmdbDataset(data_root + "/" + self.opt.lan_list[taski], self.opt, mode=mode) 134 | dataset_log = f"num samples: {len(dataset)}" 135 | print(dataset_log) 136 | 137 | # for faster training, we multiply small datasets itself. 138 | if len(dataset) < 50000 and repeat: 139 | multiple_times = int(50000 / len(dataset)) 140 | dataset_self_multiple = [dataset] * multiple_times 141 | dataset = ConcatDataset(dataset_self_multiple) 142 | dataset_list.append(dataset) 143 | # if memory !=None: 144 | # dataset_list.append(memory_dataset) 145 | 146 | return ConcatDataset(dataset_list) 147 | 148 | def create_dataloader(self,dataset,batch_size=None): 149 | data_loader = torch.utils.data.DataLoader( 150 | dataset, 151 | batch_size=self.opt.batch_size if batch_size==None else batch_size, 152 | shuffle=True, 153 | num_workers=int(self.opt.workers), 154 | collate_fn=AlignCollate(self.opt), 155 | pin_memory=False, 156 | drop_last=False, 157 | ) 158 | self.data_loader_list.append(data_loader) 159 | self.dataloader_iter_list.append(iter(data_loader)) 160 | 161 | def create_dataloader_mix(self,dataset,batch_size=None): 162 | data_loader = torch.utils.data.DataLoader( 163 | dataset, 164 | batch_size=self.opt.batch_size if batch_size==None else batch_size, 165 | shuffle=True, 166 | num_workers=int(self.opt.workers), 167 | collate_fn=AlignCollate2(self.opt), 168 | pin_memory=False, 169 | drop_last=False, 170 | ) 171 | self.data_loader_list.append(data_loader) 172 | self.dataloader_iter_list.append(iter(data_loader)) 173 | 174 | def get_batch2(self): 175 | balanced_batch_images = [] 176 | balanced_batch_labels = [] 177 | balanced_batch_index = [] 178 | 179 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 180 | try: 181 | image, label,index = data_loader_iter.next() 182 | balanced_batch_images.append(image) 183 | balanced_batch_labels += label 184 | balanced_batch_index.append(index) 185 | except StopIteration: 186 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 187 | image, label, index = self.dataloader_iter_list[i].next() 188 | balanced_batch_images.append(image) 189 | balanced_batch_labels += label 190 | balanced_batch_index.append(index) 191 | except ValueError: 192 | pass 193 | 194 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 195 | 196 | return balanced_batch_images, balanced_batch_labels, balanced_batch_index 197 | 198 | def get_batch(self): 199 | balanced_batch_images = [] 200 | balanced_batch_labels = [] 201 | 202 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 203 | try: 204 | image, label = data_loader_iter.next() 205 | balanced_batch_images.append(image) 206 | balanced_batch_labels += label 207 | except StopIteration: 208 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 209 | image, label = 
self.dataloader_iter_list[i].next() 210 | balanced_batch_images.append(image) 211 | balanced_batch_labels += label 212 | except ValueError: 213 | pass 214 | 215 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 216 | 217 | return balanced_batch_images, balanced_batch_labels 218 | 219 | class Val_Dataset(object): 220 | def __init__(self,val_datas,opt): 221 | self.data_loader_list = [] 222 | self.dataset_list = [] 223 | self.current_data = val_datas[-1] 224 | self.val_datas = val_datas 225 | self.opt = opt 226 | self.AlignCollate_valid = AlignCollate(self.opt, mode="test") 227 | 228 | 229 | def create_dataset(self,val_data=None): 230 | if val_data == None: 231 | val_data = self.current_data 232 | valid_dataset, valid_dataset_log = hierarchical_dataset( 233 | root=val_data, opt=self.opt, mode="test" 234 | ) 235 | # print(valid_dataset_log) 236 | print("-" * 80) 237 | valid_loader = torch.utils.data.DataLoader( 238 | valid_dataset, 239 | batch_size=self.opt.batch_size, 240 | shuffle=True, # 'True' to check training progress with validation function. 241 | num_workers=int(self.opt.workers), 242 | collate_fn=self.AlignCollate_valid, 243 | pin_memory=False, 244 | ) 245 | return valid_loader 246 | 247 | def create_list_dataset(self,valid_datas=None): 248 | if valid_datas==None: 249 | valid_datas = self.val_datas 250 | concat_data = [] 251 | for val_data in valid_datas: 252 | valid_dataset, valid_dataset_log = hierarchical_dataset( 253 | root=val_data, opt=self.opt, mode="test") 254 | if len(valid_dataset) > 700: 255 | index_current = numpy.random.choice(range(len(valid_dataset)),700,replace=False) 256 | valid_dataset = Subset(valid_dataset,index_current.tolist()) 257 | concat_data.append(valid_dataset) 258 | print(valid_dataset_log) 259 | print("-" * 80) 260 | val_data = ConcatDataset(concat_data) 261 | valid_loader = torch.utils.data.DataLoader( 262 | val_data, 263 | batch_size=self.opt.batch_size, 264 | shuffle=True, # 'True' to check training progress with validation function. 265 | num_workers=int(self.opt.workers), 266 | collate_fn=self.AlignCollate_valid, 267 | pin_memory=False, 268 | ) 269 | return valid_loader 270 | 271 | 272 | class IndexConcatDataset(ConcatDataset): 273 | def __getitem__(self, idx): 274 | if idx < 0: 275 | if -idx > len(self): 276 | raise ValueError("absolute value of index should not exceed dataset length") 277 | idx = len(self) + idx 278 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 279 | if dataset_idx == 0: 280 | sample_idx = idx 281 | else: 282 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 283 | return self.datasets[dataset_idx][sample_idx],dataset_idx 284 | 285 | class DummyDataset(Dataset): 286 | def __init__(self, images, labels): 287 | assert len(images) == len(labels), 'Data size error!' 
288 | self.images = images 289 | self.labels = labels 290 | 291 | def __len__(self): 292 | return len(self.images) 293 | 294 | def __getitem__(self, idx): 295 | image = self.images[idx] 296 | label = self.labels[idx] 297 | 298 | return (image, label) -------------------------------------------------------------------------------- /modules/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | from timm.models.layers import DropPath 6 | 7 | import math 8 | from torch import Tensor 9 | from torch.nn import init 10 | from torch.nn.modules.utils import _pair 11 | from torchvision.ops.deform_conv import deform_conv2d as deform_conv2d_tv 12 | from modules.dm_router import GatingMlpBlock 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class CycleFC(nn.Module): 35 | """ 36 | """ 37 | 38 | def __init__( 39 | self, 40 | in_channels: int, 41 | out_channels: int, 42 | kernel_size, # re-defined kernel_size, represent the spatial area of staircase FC 43 | stride: int = 1, 44 | padding: int = 0, 45 | dilation: int = 1, 46 | groups: int = 1, 47 | bias: bool = True, 48 | ): 49 | super(CycleFC, self).__init__() 50 | 51 | if in_channels % groups != 0: 52 | raise ValueError('in_channels must be divisible by groups') 53 | if out_channels % groups != 0: 54 | raise ValueError('out_channels must be divisible by groups') 55 | if stride != 1: 56 | raise ValueError('stride must be 1') 57 | if padding != 0: 58 | raise ValueError('padding must be 0') 59 | 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | self.kernel_size = kernel_size 63 | self.stride = _pair(stride) 64 | self.padding = _pair(padding) 65 | self.dilation = _pair(dilation) 66 | self.groups = groups 67 | 68 | self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, 1, 1)) # kernel size == 1 69 | 70 | if bias: 71 | self.bias = nn.Parameter(torch.empty(out_channels)) 72 | else: 73 | self.register_parameter('bias', None) 74 | self.register_buffer('offset', self.gen_offset()) 75 | 76 | self.reset_parameters() 77 | 78 | def reset_parameters(self) -> None: 79 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 80 | 81 | if self.bias is not None: 82 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 83 | bound = 1 / math.sqrt(fan_in) 84 | init.uniform_(self.bias, -bound, bound) 85 | 86 | def gen_offset(self): 87 | """ 88 | offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, 89 | out_height, out_width]): offsets to be applied for each position in the 90 | convolution kernel. 
91 | """ 92 | offset = torch.empty(1, self.in_channels*2, 1, 1) 93 | start_idx = (self.kernel_size[0] * self.kernel_size[1]) // 2 94 | assert self.kernel_size[0] == 1 or self.kernel_size[1] == 1, self.kernel_size 95 | for i in range(self.in_channels): 96 | if self.kernel_size[0] == 1: 97 | offset[0, 2 * i + 0, 0, 0] = 0 98 | offset[0, 2 * i + 1, 0, 0] = (i + start_idx) % self.kernel_size[1] - (self.kernel_size[1] // 2) 99 | else: 100 | offset[0, 2 * i + 0, 0, 0] = (i + start_idx) % self.kernel_size[0] - (self.kernel_size[0] // 2) 101 | offset[0, 2 * i + 1, 0, 0] = 0 102 | return offset 103 | 104 | def forward(self, input: Tensor) -> Tensor: 105 | """ 106 | Args: 107 | input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor 108 | """ 109 | B, C, H, W = input.size() 110 | return deform_conv2d_tv(input, self.offset.expand(B, -1, H, W), self.weight, self.bias, stride=self.stride, 111 | padding=self.padding, dilation=self.dilation) 112 | 113 | def extra_repr(self) -> str: 114 | s = self.__class__.__name__ + '(' 115 | s += '{in_channels}' 116 | s += ', {out_channels}' 117 | s += ', kernel_size={kernel_size}' 118 | s += ', stride={stride}' 119 | s += ', padding={padding}' if self.padding != (0, 0) else '' 120 | s += ', dilation={dilation}' if self.dilation != (1, 1) else '' 121 | s += ', groups={groups}' if self.groups != 1 else '' 122 | s += ', bias=False' if self.bias is None else '' 123 | s += ')' 124 | return s.format(**self.__dict__) 125 | 126 | 127 | class CycleMLP(nn.Module): 128 | def __init__(self, dim, segment_dim=8, qkv_bias=False, taski=1,patch=63, proj_drop=0.): 129 | super().__init__() 130 | self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias) 131 | 132 | self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0) 133 | self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0) 134 | 135 | self.reweight = Mlp(dim, dim // 4, dim * 3) 136 | 137 | self.proj = nn.Linear(dim, dim) 138 | self.proj_drop = nn.Dropout(proj_drop) 139 | 140 | def forward(self, x): 141 | B, H, W, C = x.shape 142 | # B,C,H,W 143 | h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 144 | w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 145 | c = self.mlp_c(x) 146 | 147 | a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2) 148 | a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2) 149 | 150 | x = h * a[0] + w * a[1] + c * a[2] 151 | 152 | x = self.proj(x) 153 | x = self.proj_drop(x) 154 | 155 | return x 156 | 157 | 158 | class CycleBlock(nn.Module): 159 | 160 | def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 161 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=CycleMLP): 162 | super().__init__() 163 | self.norm1 = norm_layer(dim) 164 | self.attn = mlp_fn(dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop) 165 | 166 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 167 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 168 | 169 | self.norm2 = norm_layer(dim) 170 | mlp_hidden_dim = int(dim * mlp_ratio) 171 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 172 | self.skip_lam = skip_lam 173 | 174 | def forward(self, x): 175 | x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam 176 | x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam 177 | return x 178 | 179 | 180 | # class WeightedPermuteMLP(nn.Module): 181 | # def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 182 | # super().__init__() 183 | # self.segment_dim = segment_dim 184 | # 185 | # self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias) 186 | # self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias) 187 | # self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias) 188 | # 189 | # self.reweight = Mlp(dim, dim // 4, dim * 3) 190 | # 191 | # self.proj = nn.Linear(dim, dim) 192 | # self.proj_drop = nn.Dropout(proj_drop) 193 | # 194 | # def forward(self, x): 195 | # B, H, W, C = x.shape 196 | # 197 | # S = C // self.segment_dim 198 | # h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S) 199 | # h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C) 200 | # 201 | # w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S) 202 | # w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C) 203 | # 204 | # c = self.mlp_c(x) 205 | # # B, C, H, W -> B, C,[ H, W ] 206 | # a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2) 207 | # a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2) 208 | # 209 | # x = h * a[0] + w * a[1] + c * a[2] 210 | # 211 | # x = self.proj(x) 212 | # x = self.proj_drop(x) 213 | # 214 | # return x 215 | 216 | class WeightedPermuteMLPv3(nn.Module): 217 | def __init__(self, dim, segment_dim=8, qkv_bias=False, taski=1,patch=63, proj_drop=0.,mlp="None"): 218 | super().__init__() 219 | self.segment_dim = segment_dim 220 | self.taski = taski 221 | self.patch = patch 222 | self.mlp = mlp 223 | self.mlp_c = nn.Sequential( 224 | nn.Linear(dim, dim, bias=qkv_bias), 225 | ) 226 | if self.mlp != "taski": 227 | self.mlp_h = nn.Sequential( 228 | nn.Linear(taski * dim, taski * dim, bias=qkv_bias), 229 | # nn.Linear(dim, taski, bias=qkv_bias), 230 | ) 231 | else: 232 | self.mlp_h = GatingMlpBlock(dim, dim, taski) 233 | # self.up_mlp = nn.Linear(dim // 2, dim, bias=qkv_bias) 234 | 235 | if self.mlp == "patch": 236 | self.mlp_w = GatingMlpBlock(dim, dim, patch) 237 | else: 238 | self.mlp_w = nn.Sequential( 239 | nn.Linear(patch*taski, patch*taski, bias=qkv_bias), 240 | # nn.Linear(dim, patch, bias=qkv_bias), 241 | ) 242 | self.reweight = Mlp(dim, dim // 4, dim * 2) 243 | 244 | self.proj = nn.Linear(dim, dim) 245 | self.proj_drop = nn.Dropout(proj_drop) 246 | 247 | def forward(self, x): 248 | B, H, W, C = x.shape 249 | # print(x.shape) 250 | 251 | if self.mlp != "taski": 252 | # h = rearrange(x,'b i t (h k) -> b t k (i h)',h=64) 253 | h = rearrange(x, 'b i t c -> b t (i c)') 254 | h = self.mlp_h(h) 255 | h = rearrange(h,'b t (i c) -> b i t c',i=self.taski) 256 | else: 257 | h = self.mlp_h(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3) 258 | # h = self.up_mlp(h) 259 | 260 | # B,C, H,W -> B,H,W,C 261 | if self.mlp != "patch": 262 | w = rearrange(x,'b i t c -> b c (i t)') 263 | w = self.mlp_w(w) 264 | w = rearrange(w,'b c (i t) -> 
b i t c',t = self.patch) 265 | else: 266 | w = self.mlp_w(x) 267 | 268 | # B, C, H, W -> B, C,[ H, W ] 269 | a = (h + w).permute(0, 3, 1, 2).flatten(2).mean(2) 270 | a = self.reweight(a).reshape(B, C, 2).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2) 271 | 272 | x = h * a[0] + w * a[1] 273 | 274 | x = self.proj(x) 275 | x = self.proj_drop(x) 276 | 277 | return x 278 | 279 | # class WeightedPermuteMLPv2(nn.Module): 280 | # def __init__(self, dim, segment_dim=8, qkv_bias=False, taski=1,patch=63, proj_drop=0.): 281 | # super().__init__() 282 | # self.segment_dim = segment_dim 283 | # 284 | # self.mlp_c = nn.Sequential( 285 | # nn.Linear(dim, dim, bias=qkv_bias), 286 | # ) 287 | # self.mlp_h = nn.Sequential( 288 | # nn.Linear(taski, dim, bias=qkv_bias), 289 | # nn.Linear(dim, taski, bias=qkv_bias), 290 | # ) 291 | # self.mlp_w = nn.Sequential( 292 | # nn.Linear(patch, dim, bias=qkv_bias), 293 | # nn.Linear(dim, patch, bias=qkv_bias), 294 | # ) 295 | # self.reweight = Mlp(dim, dim // 4, dim * 3) 296 | # 297 | # self.proj = nn.Linear(dim, dim) 298 | # self.proj_drop = nn.Dropout(proj_drop) 299 | # 300 | # def forward(self, x): 301 | # B, H, W, C = x.shape 302 | # # print(x.shape) 303 | # 304 | # h = x.permute(0,3,2,1) 305 | # h = self.mlp_h(h).permute(0, 3, 2 , 1) 306 | # # B,C, H,W -> B,H,W,C 307 | # w = x.permute(0, 3, 1, 2) 308 | # w = self.mlp_w(w).permute(0, 2, 3, 1) 309 | # 310 | # # S = C // self.segment_dim 311 | # # h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S) 312 | # # h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C) 313 | # 314 | # # w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S) 315 | # # w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C) 316 | # 317 | # c = self.mlp_c(x) 318 | # # B, C, H, W -> B, C,[ H, W ] 319 | # a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2) 320 | # a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2) 321 | # 322 | # x = h * a[0] + w * a[1] + c * a[2] 323 | # 324 | # x = self.proj(x) 325 | # x = self.proj_drop(x) 326 | # 327 | # return x 328 | 329 | class PermutatorBlock(nn.Module): 330 | 331 | def __init__(self, dim, mlp_ratio=4., taski = 1, patch = 63, segment_dim=8, qkv_bias=False, 332 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=WeightedPermuteMLPv3): 333 | super().__init__() 334 | self.norm1 = norm_layer(dim) 335 | self.attn = mlp_fn(dim, segment_dim=segment_dim, taski=taski,patch=patch,qkv_bias=qkv_bias) 336 | 337 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 338 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 339 | 340 | self.norm2 = norm_layer(dim) 341 | mlp_hidden_dim = int(dim * mlp_ratio) 342 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 343 | self.skip_lam = skip_lam 344 | 345 | def forward(self, x): 346 | x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam 347 | x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam 348 | return x -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision import transforms 9 | from torchvision.transforms import Compose 10 | 11 | 12 | def sample_asym(magnitude, size=None): 13 | return np.random.beta(1, 4, size) * magnitude 14 | 15 | 16 | def sample_sym(magnitude, size=None): 17 | return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude 18 | 19 | 20 | def sample_uniform(low, high, size=None): 21 | return np.random.uniform(low, high, size=size) 22 | 23 | 24 | def get_interpolation(type='random'): 25 | if type == 'random': 26 | choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA] 27 | interpolation = choice[random.randint(0, len(choice) - 1)] 28 | elif type == 'nearest': 29 | interpolation = cv2.INTER_NEAREST 30 | elif type == 'linear': 31 | interpolation = cv2.INTER_LINEAR 32 | elif type == 'cubic': 33 | interpolation = cv2.INTER_CUBIC 34 | elif type == 'area': 35 | interpolation = cv2.INTER_AREA 36 | else: 37 | raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!') 38 | return interpolation 39 | 40 | 41 | class CVRandomRotation(object): 42 | def __init__(self, degrees=15): 43 | assert isinstance(degrees, numbers.Number), "degree should be a single number." 44 | assert degrees >= 0, "degree must be positive." 45 | self.degrees = degrees 46 | 47 | @staticmethod 48 | def get_params(degrees): 49 | return sample_sym(degrees) 50 | 51 | def __call__(self, img): 52 | angle = self.get_params(self.degrees) 53 | src_h, src_w = img.shape[:2] 54 | M = cv2.getRotationMatrix2D(center=(src_w / 2, src_h / 2), angle=angle, scale=1.0) 55 | abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1]) 56 | dst_w = int(src_h * abs_sin + src_w * abs_cos) 57 | dst_h = int(src_h * abs_cos + src_w * abs_sin) 58 | M[0, 2] += (dst_w - src_w) / 2 59 | M[1, 2] += (dst_h - src_h) / 2 60 | 61 | flags = get_interpolation() 62 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) 63 | 64 | 65 | class CVRandomAffine(object): 66 | def __init__(self, degrees, translate=None, scale=None, shear=None): 67 | assert isinstance(degrees, numbers.Number), "degree should be a single number." 68 | assert degrees >= 0, "degree must be positive." 69 | self.degrees = degrees 70 | 71 | if translate is not None: 72 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 73 | "translate should be a list or tuple and it must be of length 2." 74 | for t in translate: 75 | if not (0.0 <= t <= 1.0): 76 | raise ValueError("translation values should be between 0 and 1") 77 | self.translate = translate 78 | 79 | if scale is not None: 80 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 81 | "scale should be a list or tuple and it must be of length 2." 
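            # (min, max) range; get_params later draws the actual scale uniformly from it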
82 | for s in scale: 83 | if s <= 0: 84 | raise ValueError("scale values should be positive") 85 | self.scale = scale 86 | 87 | if shear is not None: 88 | if isinstance(shear, numbers.Number): 89 | if shear < 0: 90 | raise ValueError("If shear is a single number, it must be positive.") 91 | self.shear = [shear] 92 | else: 93 | assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \ 94 | "shear should be a list or tuple and it must be of length 2." 95 | self.shear = shear 96 | else: 97 | self.shear = shear 98 | 99 | def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear): 100 | # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717 101 | from numpy import sin, cos, tan 102 | 103 | if isinstance(shear, numbers.Number): 104 | shear = [shear, 0] 105 | 106 | if not isinstance(shear, (tuple, list)) and len(shear) == 2: 107 | raise ValueError( 108 | "Shear should be a single value or a tuple/list containing " + 109 | "two values. Got {}".format(shear)) 110 | 111 | rot = math.radians(angle) 112 | sx, sy = [math.radians(s) for s in shear] 113 | 114 | cx, cy = center 115 | tx, ty = translate 116 | 117 | # RSS without scaling 118 | a = cos(rot - sy) / cos(sy) 119 | b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) 120 | c = sin(rot - sy) / cos(sy) 121 | d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) 122 | 123 | # Inverted rotation matrix with scale and shear 124 | # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 125 | M = [d, -b, 0, 126 | -c, a, 0] 127 | M = [x / scale for x in M] 128 | 129 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 130 | M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) 131 | M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) 132 | 133 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 134 | M[2] += cx 135 | M[5] += cy 136 | return M 137 | 138 | @staticmethod 139 | def get_params(degrees, translate, scale_ranges, shears, height): 140 | angle = sample_sym(degrees) 141 | if translate is not None: 142 | max_dx = translate[0] * height 143 | max_dy = translate[1] * height 144 | translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy))) 145 | else: 146 | translations = (0, 0) 147 | 148 | if scale_ranges is not None: 149 | scale = sample_uniform(scale_ranges[0], scale_ranges[1]) 150 | else: 151 | scale = 1.0 152 | 153 | if shears is not None: 154 | if len(shears) == 1: 155 | shear = [sample_sym(shears[0]), 0.] 
156 | elif len(shears) == 2: 157 | shear = [sample_sym(shears[0]), sample_sym(shears[1])] 158 | else: 159 | shear = 0.0 160 | 161 | return angle, translations, scale, shear 162 | 163 | def __call__(self, img): 164 | src_h, src_w = img.shape[:2] 165 | angle, translate, scale, shear = self.get_params( 166 | self.degrees, self.translate, self.scale, self.shear, src_h) 167 | 168 | M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle, (0, 0), scale, shear) 169 | M = np.array(M).reshape(2, 3) 170 | 171 | startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)] 172 | project = lambda x, y, a, b, c: int(a * x + b * y + c) 173 | endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints] 174 | 175 | rect = cv2.minAreaRect(np.array(endpoints)) 176 | bbox = cv2.boxPoints(rect).astype(dtype=np.int) 177 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() 178 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() 179 | 180 | dst_w = int(max_x - min_x) 181 | dst_h = int(max_y - min_y) 182 | M[0, 2] += (dst_w - src_w) / 2 183 | M[1, 2] += (dst_h - src_h) / 2 184 | 185 | # add translate 186 | dst_w += int(abs(translate[0])) 187 | dst_h += int(abs(translate[1])) 188 | if translate[0] < 0: M[0, 2] += abs(translate[0]) 189 | if translate[1] < 0: M[1, 2] += abs(translate[1]) 190 | 191 | flags = get_interpolation() 192 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) 193 | 194 | 195 | class CVRandomPerspective(object): 196 | def __init__(self, distortion=0.5): 197 | self.distortion = distortion 198 | 199 | def get_params(self, width, height, distortion): 200 | offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int) 201 | offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int) 202 | topleft = (offset_w[0], offset_h[0]) 203 | topright = (width - 1 - offset_w[1], offset_h[1]) 204 | botright = (width - 1 - offset_w[2], height - 1 - offset_h[2]) 205 | botleft = (offset_w[3], height - 1 - offset_h[3]) 206 | 207 | startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] 208 | endpoints = [topleft, topright, botright, botleft] 209 | return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32) 210 | 211 | def __call__(self, img): 212 | height, width = img.shape[:2] 213 | startpoints, endpoints = self.get_params(width, height, self.distortion) 214 | M = cv2.getPerspectiveTransform(startpoints, endpoints) 215 | 216 | # TODO: more robust way to crop image 217 | rect = cv2.minAreaRect(endpoints) 218 | bbox = cv2.boxPoints(rect).astype(dtype=np.int) 219 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() 220 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() 221 | min_x, min_y = max(min_x, 0), max(min_y, 0) 222 | 223 | flags = get_interpolation() 224 | img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE) 225 | img = img[min_y:, min_x:] 226 | return img 227 | 228 | 229 | class CVRescale(object): 230 | 231 | def __init__(self, factor=4, base_size=(128, 512)): 232 | """ Define image scales using gaussian pyramid and rescale image to target scale. 233 | 234 | Args: 235 | factor: the decayed factor from base size, factor=4 keeps target scale by default. 
236 | base_size: base size the build the bottom layer of pyramid 237 | """ 238 | if isinstance(factor, numbers.Number): 239 | self.factor = round(sample_uniform(0, factor)) 240 | elif isinstance(factor, (tuple, list)) and len(factor) == 2: 241 | self.factor = round(sample_uniform(factor[0], factor[1])) 242 | else: 243 | raise Exception('factor must be number or list with length 2') 244 | # assert factor is valid 245 | self.base_h, self.base_w = base_size[:2] 246 | 247 | def __call__(self, img): 248 | if self.factor == 0: return img 249 | src_h, src_w = img.shape[:2] 250 | cur_w, cur_h = self.base_w, self.base_h 251 | scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation()) 252 | for _ in range(self.factor): 253 | scale_img = cv2.pyrDown(scale_img) 254 | scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation()) 255 | return scale_img 256 | 257 | 258 | class CVGaussianNoise(object): 259 | def __init__(self, mean=0, var=20): 260 | self.mean = mean 261 | if isinstance(var, numbers.Number): 262 | self.var = max(int(sample_asym(var)), 1) 263 | elif isinstance(var, (tuple, list)) and len(var) == 2: 264 | self.var = int(sample_uniform(var[0], var[1])) 265 | else: 266 | raise Exception('degree must be number or list with length 2') 267 | 268 | def __call__(self, img): 269 | noise = np.random.normal(self.mean, self.var ** 0.5, img.shape) 270 | img = np.clip(img + noise, 0, 255).astype(np.uint8) 271 | return img 272 | 273 | 274 | class CVMotionBlur(object): 275 | def __init__(self, degrees=12, angle=90): 276 | if isinstance(degrees, numbers.Number): 277 | self.degree = max(int(sample_asym(degrees)), 1) 278 | elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: 279 | self.degree = int(sample_uniform(degrees[0], degrees[1])) 280 | else: 281 | raise Exception('degree must be number or list with length 2') 282 | self.angle = sample_uniform(-angle, angle) 283 | 284 | def __call__(self, img): 285 | M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1) 286 | motion_blur_kernel = np.zeros((self.degree, self.degree)) 287 | motion_blur_kernel[self.degree // 2, :] = 1 288 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree)) 289 | motion_blur_kernel = motion_blur_kernel / self.degree 290 | img = cv2.filter2D(img, -1, motion_blur_kernel) 291 | img = np.clip(img, 0, 255).astype(np.uint8) 292 | return img 293 | 294 | 295 | class CVGeometry(object): 296 | def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.), 297 | shear=(45, 15), distortion=0.5, p=0.5): 298 | self.p = p 299 | type_p = random.random() 300 | if type_p < 0.33: 301 | self.transforms = CVRandomRotation(degrees=degrees) 302 | elif type_p < 0.66: 303 | self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) 304 | else: 305 | self.transforms = CVRandomPerspective(distortion=distortion) 306 | 307 | def __call__(self, img): 308 | if random.random() < self.p: 309 | img = np.array(img) 310 | return Image.fromarray(self.transforms(img)) 311 | else: 312 | return img 313 | 314 | 315 | class CVDeterioration(object): 316 | def __init__(self, var, degrees, factor, p=0.5): 317 | self.p = p 318 | transforms = [] 319 | if var is not None: 320 | transforms.append(CVGaussianNoise(var=var)) 321 | if degrees is not None: 322 | transforms.append(CVMotionBlur(degrees=degrees)) 323 | if factor is not None: 324 | transforms.append(CVRescale(factor=factor)) 325 | 326 | random.shuffle(transforms) 327 | 
transforms = Compose(transforms) 328 | self.transforms = transforms 329 | 330 | def __call__(self, img): 331 | if random.random() < self.p: 332 | img = np.array(img) 333 | return Image.fromarray(self.transforms(img)) 334 | else: 335 | return img 336 | 337 | 338 | class CVColorJitter(object): 339 | def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5): 340 | self.p = p 341 | self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast, 342 | saturation=saturation, hue=hue) 343 | 344 | def __call__(self, img): 345 | if random.random() < self.p: 346 | return self.transforms(img) 347 | else: 348 | return img -------------------------------------------------------------------------------- /il_modules/der.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | import torch 4 | import torch.nn.init as init 5 | from il_modules.base import BaseLearner 6 | from modules.model import DERNet 7 | from test import validation 8 | from tools.utils import Averager, adjust_learning_rate 9 | 10 | EPSILON = 1e-8 11 | 12 | init_epoch = 200 13 | init_lr = 0.1 14 | init_milestones = [60, 120, 170] 15 | init_lr_decay = 0.1 16 | init_weight_decay = 0.0005 17 | 18 | epochs = 170 19 | lrate = 0.1 20 | milestones = [80, 120, 150] 21 | lrate_decay = 0.1 22 | batch_size = 128 23 | weight_decay = 2e-4 24 | num_workers = 8 25 | T = 2 26 | 27 | 28 | class DER(BaseLearner): 29 | 30 | def __init__(self, opt): 31 | super().__init__(opt) 32 | self.model = DERNet(opt) 33 | 34 | def after_task(self): 35 | self.model = self.model.module 36 | self._known_classes = self._total_classes 37 | # logging.info('Exemplar size: {}'.format(self.exemplar_size)) 38 | 39 | def model_eval_and_train(self,taski): 40 | self.model.train() 41 | self.model.module.model[-1].train() 42 | if taski >= 1: 43 | for i in range(taski): 44 | self.model.module.model[i].eval() 45 | 46 | def change_model(self,): 47 | """ model configuration """ 48 | # model.module.reset_class(opt, device) 49 | # self.model.update_fc(self.opt.output_channel, self._total_classes) 50 | self.model.update_fc(self.opt.hidden_size, self._total_classes) 51 | self.model.build_prediction(self.opt, self._total_classes) 52 | self.model.build_aux_prediction(self.opt, self._total_classes) 53 | # reset_class(self.model.module, self.device) 54 | # data parallel for multi-GPU 55 | self.model = torch.nn.DataParallel(self.model).to(self.device) 56 | self.model.train() 57 | # return self.model 58 | 59 | def build_model(self): 60 | """ model configuration """ 61 | # self.model.update_fc(self.opt.output_channel, self._total_classes) 62 | self.model.update_fc(self.opt.hidden_size, self._total_classes) 63 | self.model.build_prediction(self.opt, self._total_classes) 64 | self.model.build_aux_prediction(self.opt, self._total_classes) 65 | 66 | # weight initialization 67 | for name, param in self.model.named_parameters(): 68 | if "localization_fc2" in name: 69 | print(f"Skip {name} as it is already initialized") 70 | continue 71 | try: 72 | if "bias" in name: 73 | init.constant_(param, 0.0) 74 | elif "weight" in name: 75 | init.kaiming_normal_(param) 76 | except Exception as e: # for batchnorm. 
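                # kaiming_normal_ raises on the 1-D weights of BatchNorm layers,
                # so those weights are simply re-filled with 1 here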
77 | if "weight" in name: 78 | param.data.fill_(1) 79 | continue 80 | 81 | # data parallel for multi-GPU 82 | self.model = torch.nn.DataParallel(self.model).to(self.device) 83 | self.model.train() 84 | 85 | def incremental_train(self, taski, character, train_loader, valid_loader): 86 | 87 | # pre task classes for know classes 88 | # self._known_classes = self._total_classes 89 | self.character = character 90 | self.converter = self.build_converter() 91 | valid_loader = valid_loader.create_dataset() 92 | 93 | if taski > 0: 94 | self.change_model() 95 | else: 96 | self.criterion = self.build_criterion() 97 | self.build_model() 98 | 99 | # print opt config 100 | # self.print_config(self.opt) 101 | if taski > 0: 102 | for i in range(taski): 103 | for p in self.model.module.model[i].parameters(): 104 | p.requires_grad = False 105 | 106 | # filter that only require gradient descent 107 | filtered_parameters = self.count_param() 108 | 109 | # setup optimizer 110 | self.build_optimizer(filtered_parameters) 111 | 112 | if self.opt.start_task > taski: 113 | 114 | if taski > 0: 115 | if self.opt.memory != None: 116 | self.build_rehearsal_memory(train_loader, taski) 117 | else: 118 | train_loader.get_dataset(taski, memory=self.opt.memory) 119 | 120 | # if self.opt.ch_list!=None: 121 | # name = self.opt.ch_list[taski] 122 | # else: 123 | name = self.opt.lan_list[taski] 124 | saved_best_model = f"./saved_models/{self.opt.exp_name}/{name}_{taski}_best_score.pth" 125 | # os.system(f'cp {saved_best_model} ./result/{opt.exp_name}/') 126 | self.model.load_state_dict(torch.load(f"{saved_best_model}"), strict=True) 127 | print( 128 | 'Task {} load checkpoint from {}.'.format(taski, saved_best_model) 129 | ) 130 | 131 | else: 132 | print( 133 | 'Task {} start training for model ------{}------'.format(taski,self.opt.exp_name) 134 | ) 135 | """ start training """ 136 | self._train(0, taski, train_loader, valid_loader) 137 | 138 | 139 | def _train(self, start_iter,taski, train_loader, valid_loader): 140 | if taski == 0: 141 | self._init_train(start_iter,taski, train_loader, valid_loader) 142 | else: 143 | if self.opt.memory != None: 144 | self.build_rehearsal_memory(train_loader, taski) 145 | else: 146 | train_loader.get_dataset(taski, memory=self.opt.memory) 147 | self._update_representation(start_iter,taski, train_loader, valid_loader) 148 | self.model.module.weight_align(self._total_classes - self._known_classes) 149 | 150 | def _init_train(self,start_iter,taski, train_loader, valid_loader): 151 | # loss averager 152 | train_loss_avg = Averager() 153 | train_clf_loss = Averager() 154 | train_aux_loss = Averager() 155 | start_time = time.time() 156 | best_score = -1 157 | 158 | # training loop 159 | for iteration in tqdm( 160 | range(start_iter + 1, self.opt.num_iter + 1), 161 | total=self.opt.num_iter, 162 | position=0, 163 | leave=True, 164 | ): 165 | image_tensors, labels = train_loader.get_batch() 166 | 167 | image = image_tensors.to(self.device) 168 | labels_index, labels_length = self.converter.encode( 169 | labels, batch_max_length=self.opt.batch_max_length 170 | ) 171 | batch_size = image.size(0) 172 | 173 | # default recognition loss part 174 | if "CTC" in self.opt.Prediction: 175 | preds = self.model(image)['logits'] 176 | # preds = self.model(image) 177 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 178 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 179 | loss = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length) 180 | else: 181 | preds = 
self.model(image, labels_index[:, :-1])['logits'] # align with Attention.forward 182 | target = labels_index[:, 1:] # without [SOS] Symbol 183 | loss = self.criterion( 184 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 185 | ) 186 | 187 | self.model.zero_grad() 188 | loss.backward() 189 | torch.nn.utils.clip_grad_norm_( 190 | self.model.parameters(), self.opt.grad_clip 191 | ) # gradient clipping with 5 (Default) 192 | self.optimizer.step() 193 | train_loss_avg.add(loss) 194 | 195 | if "super" in self.opt.schedule: 196 | self.scheduler.step() 197 | else: 198 | adjust_learning_rate(self.optimizer, iteration, self.opt) 199 | 200 | # validation part. 201 | # To see training progress, we also conduct validation when 'iteration == 1' 202 | if iteration % self.opt.val_interval == 0 or iteration ==1: 203 | # for validation log 204 | self.val(valid_loader, self.opt, best_score, start_time, iteration, 205 | train_loss_avg, train_clf_loss, train_aux_loss, taski) 206 | train_loss_avg.reset() 207 | 208 | def _update_representation(self,start_iter, taski, train_loader, valid_loader): 209 | # loss averager 210 | train_loss_avg = Averager() 211 | train_clf_loss = Averager() 212 | train_aux_loss = Averager() 213 | 214 | self.model_eval_and_train(taski) 215 | 216 | 217 | start_time = time.time() 218 | best_score = -1 219 | 220 | # training loop 221 | for iteration in tqdm( 222 | range(start_iter + 1, self.opt.num_iter + 1), 223 | total=self.opt.num_iter, 224 | position=0, 225 | leave=True, 226 | ): 227 | image_tensors, labels = train_loader.get_batch() 228 | 229 | image = image_tensors.to(self.device) 230 | labels_index, labels_length = self.converter.encode( 231 | labels, batch_max_length=self.opt.batch_max_length 232 | ) 233 | batch_size = image.size(0) 234 | 235 | # default recognition loss part 236 | if "CTC" in self.opt.Prediction: 237 | output = self.model(image) 238 | preds = output["logits"] 239 | aux_logits = output["aux_logits"] 240 | aux_targets = labels_index.clone() 241 | # aux_targets = torch.where(aux_targets - self._known_classes + 1 > 0, 242 | # aux_targets - self._known_classes + 1, 0) 243 | 244 | aux_preds_size = torch.IntTensor([aux_logits.size(1)] * batch_size) 245 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 246 | # B,T,C(max) -> T, B, C 247 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 248 | aux_preds_log_softmax = aux_logits.log_softmax(2).permute(1, 0, 2) 249 | 250 | loss_clf = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length) 251 | loss_aux = self.criterion(aux_preds_log_softmax, aux_targets, aux_preds_size, labels_length) 252 | else: 253 | output = self.model(image, labels_index[:, :-1]) # align with Attention.forward 254 | preds = output["logits"] 255 | aux_logits = output["aux_logits"] 256 | aux_targets = labels_index.clone()[:, 1:] 257 | target = labels_index[:, 1:] # without [SOS] Symbol 258 | loss_clf = self.criterion( 259 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 260 | ) 261 | loss_aux = self.criterion( 262 | aux_logits.view(-1, aux_logits.shape[-1]), aux_targets.contiguous().view(-1) 263 | ) 264 | # loss = loss_clf + loss_aux 265 | loss = loss_clf 266 | 267 | self.model.zero_grad() 268 | loss.backward() 269 | torch.nn.utils.clip_grad_norm_( 270 | self.model.parameters(), self.opt.grad_clip 271 | ) # gradient clipping with 5 (Default) 272 | self.optimizer.step() 273 | train_loss_avg.add(loss) 274 | train_clf_loss.add(loss_clf) 275 | train_aux_loss.add(loss_aux) 276 | 277 | if 
"super" in self.opt.schedule: 278 | self.scheduler.step() 279 | else: 280 | adjust_learning_rate(self.optimizer, iteration, self.opt) 281 | 282 | # validation part. 283 | # To see training progress, we also conduct validation when 'iteration == 1' 284 | if iteration % self.opt.val_interval == 0 or iteration == 1: 285 | # for validation log 286 | self.val(valid_loader, self.opt, best_score, start_time, iteration, 287 | train_loss_avg, train_clf_loss, train_aux_loss, taski) 288 | train_loss_avg.reset() 289 | train_clf_loss.reset() 290 | train_aux_loss.reset() 291 | 292 | def val(self, valid_loader, opt, best_score, start_time, iteration, 293 | train_loss_avg,train_clf_loss, train_aux_loss, taski): 294 | self.model.eval() 295 | start_time = time.time() 296 | with torch.no_grad(): 297 | ( 298 | valid_loss, 299 | current_score, 300 | ned_score, 301 | preds, 302 | confidence_score, 303 | labels, 304 | infer_time, 305 | length_of_data, 306 | ) = validation(self.model, self.criterion, valid_loader, self.converter, opt) 307 | self.model.train() 308 | 309 | # keep best score (accuracy or norm ED) model on valid dataset 310 | # Do not use this on test datasets. It would be an unfair comparison 311 | # (training should be done without referring test set). 312 | if current_score > best_score: 313 | best_score = current_score 314 | # if opt.ch_list != None: 315 | # name = opt.ch_list[taski] 316 | # else: 317 | name = opt.lan_list[taski] 318 | torch.save( 319 | self.model.state_dict(), 320 | f"./saved_models/{opt.exp_name}/{name}_{taski}_best_score.pth", 321 | ) 322 | 323 | # validation log: loss, lr, score (accuracy or norm ED), time. 324 | lr = self.optimizer.param_groups[0]["lr"] 325 | elapsed_time = time.time() - start_time 326 | valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f} \n " 327 | if train_clf_loss !=None: 328 | valid_log += f"CLF_loss: {train_clf_loss.val():0.5f} , Aux_loss: {train_aux_loss.val():0.5f}\n" 329 | valid_log += f'{"":9s}Current_score: {current_score:0.2f}, Ned_score: {ned_score:0.2f}\n' 330 | valid_log += f'{"":9s}Current_lr: {lr:0.7f}, Best_score: {best_score:0.2f}\n' 331 | valid_log += f'{"":9s}Infer_time: {infer_time:0.2f}, Elapsed_time: {elapsed_time/length_of_data:0.4f}\n' 332 | 333 | # show some predicted results 334 | dashed_line = "-" * 80 335 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 336 | predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n" 337 | for gt, pred, confidence in zip( 338 | labels[:5], preds[:5], confidence_score[:5] 339 | ): 340 | if "Attn" in opt.Prediction: 341 | gt = gt[: gt.find("[EOS]")] 342 | pred = pred[: pred.find("[EOS]")] 343 | 344 | predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n" 345 | predicted_result_log += f"{dashed_line}" 346 | valid_log = f"{valid_log}\n{predicted_result_log}" 347 | print(valid_log) 348 | self.write_log(valid_log + "\n") 349 | -------------------------------------------------------------------------------- /tiny_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import argparse 6 | from data.data_manage import Dataset_Manager, Val_Dataset 7 | from il_modules.base import BaseLearner 8 | from il_modules.der import DER 9 | from il_modules.mrn import MRN 10 | from il_modules.ewc import EWC 11 | from il_modules.joint import JointLearner 12 | from il_modules.lwf import 
LwF 13 | from il_modules.wa import WA 14 | 15 | print(os.getcwd()) 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | import torch.utils.data 19 | import numpy as np 20 | from mmcv import Config 21 | 22 | from data.dataset import hierarchical_dataset, AlignCollate 23 | from test import validation 24 | 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | def write_data_log(line): 28 | ''' 29 | 30 | :param name: 31 | :param line: list of the string [a,b,c] 32 | :return: 33 | ''' 34 | with open(f"data_any.txt", "a+") as log: 35 | log.write(line) 36 | 37 | def load_dict(path,char): 38 | ch_list = [] 39 | character = [] 40 | f = open(path + "/dict.txt") 41 | line = f.readline() 42 | while line: 43 | ch_list.append(line.strip("\n")) 44 | line = f.readline() 45 | f.close() 46 | 47 | for ch in ch_list: 48 | if char.get(ch, None) == None: 49 | char[ch] = 1 50 | for key, value in char.items(): 51 | character.append(key) 52 | print("dict has {} number characters\n".format(len(character))) 53 | return character,char 54 | 55 | 56 | def build_arg(parser): 57 | parser.add_argument( 58 | "--config", 59 | default="config/crnn_mrn.py", 60 | help="path to validation dataset", 61 | ) 62 | parser.add_argument( 63 | "--valid_datas", 64 | default=[" ../dataset/MLT17_IL/test_2017", "../dataset/MLT19_IL/test_2019"], 65 | help="path to testing dataset", 66 | ) 67 | parser.add_argument( 68 | "--select_data", 69 | type=str, 70 | default=[" ../dataset/MLT17_IL/train_2017", "../dataset/MLT19_IL/train_2019"], 71 | help="select training data.", 72 | ) 73 | parser.add_argument( 74 | "--workers", type=int, default=4, help="number of data loading workers" 75 | ) 76 | parser.add_argument("--batch_size", type=int, default=128, help="input batch size") 77 | parser.add_argument( 78 | "--num_iter", type=int, default=20000, help="number of iterations to train for" 79 | ) 80 | parser.add_argument( 81 | "--val_interval", 82 | type=int, 83 | default=5000, 84 | help="Interval between each validation", 85 | ) 86 | parser.add_argument( 87 | "--log_multiple_test", action="store_true", help="log_multiple_test" 88 | ) 89 | parser.add_argument( 90 | "--grad_clip", type=float, default=5, help="gradient clipping value. default=5" 91 | ) 92 | """ Optimizer """ 93 | parser.add_argument( 94 | "--optimizer", type=str, default="adam", help="optimizer |sgd|adadelta|adam|" 95 | ) 96 | parser.add_argument( 97 | "--lr", 98 | type=float, 99 | default=0.0005, 100 | help="learning rate, default=1.0 for Adadelta, 0.0005 for Adam", 101 | ) 102 | parser.add_argument( 103 | "--sgd_momentum", default=0.9, type=float, help="momentum for SGD" 104 | ) 105 | parser.add_argument( 106 | "--sgd_weight_decay", default=0.000001, type=float, help="weight decay for SGD" 107 | ) 108 | parser.add_argument( 109 | "--rho", 110 | type=float, 111 | default=0.95, 112 | help="decay rate rho for Adadelta. default=0.95", 113 | ) 114 | parser.add_argument( 115 | "--eps", type=float, default=1e-8, help="eps for Adadelta. default=1e-8" 116 | ) 117 | parser.add_argument( 118 | "--schedule", 119 | default="super", 120 | nargs="*", 121 | help="(learning rate schedule. default is super for super convergence, 1 for None, [0.6, 0.8] for the same setting with ASTER", 122 | ) 123 | parser.add_argument( 124 | "--lr_drop_rate", 125 | type=float, 126 | default=0.1, 127 | help="lr_drop_rate. 
default is the same setting with ASTER", 128 | ) 129 | 130 | """ Model Architecture """ 131 | parser.add_argument("--model_name", type=str, required=False, help="CRNN|TRBA") 132 | parser.add_argument( 133 | "--num_fiducial", 134 | type=int, 135 | default=20, 136 | help="number of fiducial points of TPS-STN", 137 | ) 138 | parser.add_argument( 139 | "--input_channel", 140 | type=int, 141 | default=3, 142 | help="the number of input channel of Feature extractor", 143 | ) 144 | parser.add_argument( 145 | "--output_channel", 146 | type=int, 147 | default=512, 148 | help="the number of output channel of Feature extractor", 149 | ) 150 | parser.add_argument( 151 | "--hidden_size", type=int, default=256, help="the size of the LSTM hidden state" 152 | ) 153 | 154 | """ Data processing """ 155 | parser.add_argument( 156 | "--batch_ratio", 157 | type=str, 158 | default="1.0", 159 | help="assign ratio for each selected data in the batch", 160 | ) 161 | parser.add_argument( 162 | "--total_data_usage_ratio", 163 | type=str, 164 | default="1.0", 165 | help="total data usage ratio, this ratio is multiplied to total number of data.", 166 | ) 167 | parser.add_argument( 168 | "--batch_max_length", type=int, default=25, help="maximum-label-length" 169 | ) 170 | parser.add_argument( 171 | "--imgH", type=int, default=32, help="the height of the input image" 172 | ) 173 | parser.add_argument( 174 | "--imgW", type=int, default=100, help="the width of the input image" 175 | ) 176 | parser.add_argument( 177 | "--NED", action="store_true", help="For Normalized edit_distance" 178 | ) 179 | parser.add_argument( 180 | "--Aug", 181 | type=str, 182 | default="None", 183 | help="whether to use augmentation |None|Blur|Crop|Rot|", 184 | ) 185 | """ exp_name and etc """ 186 | parser.add_argument("--exp_name", help="Where to store logs and models") 187 | parser.add_argument( 188 | "--manual_seed", type=int, default=111, help="for random seed setting" 189 | ) 190 | parser.add_argument( 191 | "--saved_model", default="", help="path to model to continue training" 192 | ) 193 | return parser 194 | 195 | def train(opt, log): 196 | # ["Latin", "Chinese", "Arabic", "Japanese", "Korean", "Bangla","Hindi","Symbols"] 197 | write_data_log(f"----------- {opt.exp_name} ------------\n") 198 | print(f"----------- {opt.exp_name} ------------\n") 199 | 200 | valid_datasets = train_datasets = [lan for lan in opt.lan_list] 201 | 202 | best_scores = [] 203 | ned_scores = [] 204 | valid_datas = [] 205 | char = dict() 206 | """ final options """ 207 | # print(opt) 208 | opt_log = "------------ Options -------------\n" 209 | args = vars(opt) 210 | for k, v in args.items(): 211 | if str(k) == "character" and len(str(v)) > 500: 212 | opt_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n" 213 | opt_log += "---------------------------------------\n" 214 | # print(opt_log) 215 | log.write(opt_log) 216 | if opt.il == "lwf": 217 | learner = LwF(opt) 218 | elif opt.il == "wa": 219 | learner = WA(opt) 220 | elif opt.il == "ewc": 221 | learner = EWC(opt) 222 | elif opt.il == "der": 223 | learner = DER(opt) 224 | elif opt.il == "mrn": 225 | learner = MRN(opt) 226 | elif opt.il == "joint_mix" or opt.il == "joint_loader": 227 | learner = JointLearner(opt) 228 | else: 229 | learner = BaseLearner(opt) 230 | 231 | data_manager = Dataset_Manager(opt) 232 | for taski in range(len(train_datasets)): 233 | # train_data = os.path.join(opt.train_data, train_datasets[taski]) 234 | for valid_data in opt.valid_datas: 235 | val_data = 
os.path.join(valid_data, valid_datasets[taski]) 236 | valid_datas.append(val_data) 237 | 238 | valid_loader = Val_Dataset(valid_datas,opt) 239 | """dataset preparation""" 240 | select_data = opt.select_data 241 | AlignCollate_valid = AlignCollate(opt, mode="test") 242 | 243 | if opt.il =="joint_loader" or opt.il == "joint_mix": 244 | valid_datas = [] 245 | char = {} 246 | for taski in range(len(train_datasets)): 247 | # char={} 248 | # train_data = os.path.join(opt.train_data, train_datasets[taski]) 249 | for val_data in opt.valid_datas: 250 | valid_data = os.path.join(val_data, valid_datasets[taski]) 251 | valid_datas.append(valid_data) 252 | data_manager.joint_start(opt, select_data, log, taski, len(train_datasets)) 253 | for data_path in opt.select_data: 254 | opt.character, char = load_dict(data_path + f"/{opt.lan_list[taski]}", char) 255 | print(len(opt.character)) 256 | best_scores,ned_scores = learner.incremental_train(0,opt.character, data_manager, valid_loader,AlignCollate_valid,valid_datas) 257 | """ Evaluation at the end of training """ 258 | best_scores, ned_scores = learner.test(AlignCollate_valid, valid_datas, best_scores, ned_scores, 0) 259 | break 260 | if taski == 0: 261 | data_manager.init_start(opt, select_data, log, taski) 262 | train_loader = data_manager 263 | 264 | #-------load char to dict --------# 265 | for data_path in opt.select_data: 266 | if data_path=="/": 267 | opt.character = load_dict(data_path+f"/{opt.lan_list[taski]}",char) 268 | else: 269 | opt.character,tmp_char = load_dict(data_path+f"/{opt.lan_list[taski]}",char) 270 | # ----- incremental model start ------- 271 | 272 | learner.incremental_train(taski, opt.character, train_loader, valid_loader) 273 | 274 | # ----- incremental model end ------- 275 | """ Evaluation at the end of training """ 276 | best_scores,ned_scores = learner.test(AlignCollate_valid,valid_datas,best_scores,ned_scores, taski) 277 | learner.after_task() 278 | 279 | write_data_log(f"----------- {opt.exp_name} ------------\n") 280 | print(f"----------- {opt.exp_name} ------------\n") 281 | if len(opt.valid_datas) == 1: 282 | print( 283 | 'ALL Average Incremental Accuracy: {:.2f} \n'.format(sum(best_scores)/len(best_scores)) 284 | ) 285 | write_data_log('ALL Average Acc: {:.2f} \n'.format(sum(best_scores)/len(best_scores))) 286 | elif len(opt.valid_datas) == 2: 287 | print( 288 | 'ALL Average 17 Acc: {:.2f} \n'.format(sum(best_scores) / len(best_scores)) 289 | ) 290 | print( 291 | 'ALL Average 19 Acc: {:.2f} \n'.format(sum(ned_scores) / len(ned_scores)) 292 | ) 293 | write_data_log('ALL 17 Acc: {:.2f} \n'.format(sum(best_scores) / len(best_scores))) 294 | write_data_log('ALL 19 Acc: {:.2f} \n'.format(sum(ned_scores) / len(ned_scores))) 295 | 296 | def val(model, criterion, valid_loader, converter, opt,optimizer,best_score,start_time,iteration,train_loss_avg,taski): 297 | with open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") as log: 298 | model.eval() 299 | with torch.no_grad(): 300 | ( 301 | valid_loss, 302 | current_score, 303 | ned_score, 304 | preds, 305 | confidence_score, 306 | labels, 307 | infer_time, 308 | length_of_data, 309 | ) = validation(model, criterion, valid_loader, converter, opt) 310 | model.train() 311 | 312 | # keep best score (accuracy or norm ED) model on valid dataset 313 | # Do not use this on test datasets. It would be an unfair comparison 314 | # (training should be done without referring test set). 
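        # the running best checkpoint for the current task is written to
        # ./saved_models/{exp_name}/{language}_{taski}_best_score.pth and reloaded by test()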
315 | if current_score > best_score: 316 | best_score = current_score 317 | # if opt.ch_list!=None: 318 | # name = opt.ch_list[taski] 319 | # else: 320 | name = opt.lan_list[taski] 321 | torch.save( 322 | model.state_dict(), 323 | f"./saved_models/{opt.exp_name}/{name}_{taski}_best_score.pth", 324 | ) 325 | 326 | # validation log: loss, lr, score (accuracy or norm ED), time. 327 | lr = optimizer.param_groups[0]["lr"] 328 | elapsed_time = time.time() - start_time 329 | valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f} \n " 330 | # valid_log += f", Semi_loss: {semi_loss_avg.val():0.5f}\n" 331 | valid_log += f'{"":9s}Current_score: {current_score:0.2f}, Ned_score: {ned_score:0.2f}\n' 332 | valid_log += f'{"":9s}Current_lr: {lr:0.7f}, Best_score: {best_score:0.2f}\n' 333 | valid_log += f'{"":9s}Infer_time: {infer_time:0.2f}, Elapsed_time: {elapsed_time:0.2f}\n' 334 | 335 | # show some predicted results 336 | dashed_line = "-" * 80 337 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 338 | predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n" 339 | for gt, pred, confidence in zip( 340 | labels[:5], preds[:5], confidence_score[:5] 341 | ): 342 | if "Attn" in opt.Prediction: 343 | gt = gt[: gt.find("[EOS]")] 344 | pred = pred[: pred.find("[EOS]")] 345 | 346 | predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n" 347 | predicted_result_log += f"{dashed_line}" 348 | valid_log = f"{valid_log}\n{predicted_result_log}" 349 | print(valid_log) 350 | log.write(valid_log + "\n") 351 | write_data_log(f"Task {opt.lan_list[taski]} [{iteration}/{opt.num_iter}] : Score:{current_score:0.2f} LR:{lr:0.7f}\n") 352 | 353 | 354 | 355 | def test(AlignCollate_valid,valid_datas,model,criterion,converter,opt,best_scores,taski,log): 356 | print("---Start evaluation on benchmark testset----") 357 | """ keep evaluation model and result logs """ 358 | os.makedirs(f"./result/{opt.exp_name}", exist_ok=True) 359 | os.makedirs(f"./evaluation_log", exist_ok=True) 360 | # if opt.ch_list != None: 361 | # name = opt.ch_list[taski] 362 | # else: 363 | name = opt.lan_list[taski] 364 | saved_best_model = f"./saved_models/{opt.exp_name}/{name}_{taski}_best_score.pth" 365 | # os.system(f'cp {saved_best_model} ./result/{opt.exp_name}/') 366 | model.load_state_dict(torch.load(f"{saved_best_model}")) 367 | 368 | task_accs = [] 369 | for val_data in valid_datas: 370 | valid_dataset, valid_dataset_log = hierarchical_dataset( 371 | root=val_data, opt=opt, mode="test") 372 | valid_loader = torch.utils.data.DataLoader( 373 | valid_dataset, 374 | batch_size=opt.batch_size, 375 | shuffle=True, # 'True' to check training progress with validation function. 
376 | num_workers=int(opt.workers), 377 | collate_fn=AlignCollate_valid, 378 | pin_memory=False, 379 | ) 380 | 381 | model.eval() 382 | with torch.no_grad(): 383 | ( 384 | valid_loss, 385 | current_score, 386 | ned_score, 387 | preds, 388 | confidence_score, 389 | labels, 390 | infer_time, 391 | length_of_data, 392 | ) = validation(model, criterion, valid_loader, converter, opt) 393 | 394 | task_accs.append(current_score) 395 | 396 | best_scores.append(sum(task_accs) / len(task_accs)) 397 | 398 | acc_log= f'Task {taski} Test Average Incremental Accuracy: {best_scores[taski]} \n Task {taski} Incremental Accuracy: {task_accs}' 399 | # acc_log = f'Task {taski} Test Average Incremental Accuracy: {best_scores[taski]} \n ' 400 | # acc_log += f'Task {taski} Incremental Accuracy: {task_accs:.2f}' 401 | write_data_log(f'Task {taski} Avg Acc: {best_scores[taski]:0.2f} \n {task_accs}\n') 402 | print(acc_log) 403 | log.write(acc_log) 404 | return best_scores,log 405 | 406 | 407 | if __name__ == "__main__": 408 | 409 | parser = argparse.ArgumentParser() 410 | parser = build_arg(parser) 411 | 412 | arg = parser.parse_args() 413 | cfg = Config.fromfile(arg.config) 414 | 415 | opt={} 416 | opt.update(cfg.common) 417 | # opt.update(cfg.test) 418 | opt.update(cfg.model) 419 | opt.update(cfg.train) 420 | opt.update(cfg.optimizer) 421 | 422 | opt = argparse.Namespace(**opt) 423 | 424 | """ Seed and GPU setting """ 425 | random.seed(opt.manual_seed) 426 | np.random.seed(opt.manual_seed) 427 | torch.manual_seed(opt.manual_seed) 428 | torch.cuda.manual_seed_all(opt.manual_seed) # if you are using multi-GPU. 429 | torch.cuda.manual_seed(opt.manual_seed) 430 | 431 | cudnn.benchmark = True # It fasten training. 432 | cudnn.deterministic = True 433 | 434 | opt.gpu_name = "_".join(torch.cuda.get_device_name().split()) 435 | if sys.platform == "linux": 436 | opt.CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"] 437 | else: 438 | opt.CUDA_VISIBLE_DEVICES = 0 # for convenience 439 | opt.num_gpu = torch.cuda.device_count() 440 | 441 | if sys.platform == "win32": 442 | opt.workers = 0 443 | 444 | """ directory and log setting """ 445 | if not opt.exp_name: 446 | opt.exp_name = f"Seed{opt.manual_seed}-{opt.model_name}" 447 | 448 | os.makedirs(f"./saved_models/{opt.exp_name}", exist_ok=True) 449 | log = open(f"./saved_models/{opt.exp_name}/log_train.txt", "a") 450 | command_line_input = " ".join(sys.argv) 451 | print( 452 | f"Command line input: CUDA_VISIBLE_DEVICES={opt.CUDA_VISIBLE_DEVICES} python {command_line_input}" 453 | ) 454 | log.write( 455 | f"Command line input: CUDA_VISIBLE_DEVICES={opt.CUDA_VISIBLE_DEVICES} python {command_line_input}\n" 456 | ) 457 | os.makedirs(f"./tensorboard", exist_ok=True) 458 | # opt.writer = SummaryWriter(log_dir=f"./tensorboard/{opt.exp_name}") 459 | 460 | train(opt, log) 461 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import re 6 | from datetime import date 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.utils.data 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from mmcv import Config 14 | from nltk.metrics.distance import edit_distance 15 | from tqdm import tqdm 16 | 17 | from tools.utils import CTCLabelConverter, AttnLabelConverter, Averager 18 | from data.dataset import hierarchical_dataset, AlignCollate 19 | from modules.model 
import Model 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False): 25 | 26 | if opt.eval_type == "benchmark": 27 | """evaluation with 6 benchmark evaluation datasets""" 28 | eval_data_list = [ 29 | "IIIT5k_3000", 30 | "SVT", 31 | "IC13_1015", 32 | "IC15_2077", 33 | "SVTP", 34 | "CUTE80", 35 | ] 36 | opt.eval_data = "data_CVPR2021/evaluation/benchmark/" 37 | 38 | elif opt.eval_type == "addition": 39 | """evaluation with 7 additionally collected evaluation datasets""" 40 | eval_data_list = [ 41 | "5.COCO", 42 | "6.RCTW17", 43 | "7.Uber", 44 | "8.ArT", 45 | "9.LSVT", 46 | "10.MLT19", 47 | "11.ReCTS", 48 | ] 49 | opt.eval_data = "data_CVPR2021/evaluation/addition/" 50 | elif opt.eval_type == "IL_STR": 51 | """evaluation with IL_STR datasets""" 52 | eval_data_list = ["Latin", "Chinese", "Arabic", "Japanese", "Korean", "Bangla", "Hindi", "Symbols"] 53 | 54 | opt.eval_data = "../dataset/MLT2019/test_2019/" 55 | 56 | if calculate_infer_time: 57 | eval_batch_size = ( 58 | 1 # batch_size should be 1 to calculate the GPU inference time per image. 59 | ) 60 | else: 61 | eval_batch_size = opt.batch_size 62 | 63 | accuracy_list = [] 64 | total_forward_time = 0 65 | total_eval_data_number = 0 66 | total_correct_number = 0 67 | log = open(f"./result/{opt.exp_name}/log_all_evaluation.txt", "a") 68 | dashed_line = "-" * 80 69 | print(dashed_line) 70 | log.write(dashed_line + "\n") 71 | for eval_data in eval_data_list: 72 | eval_data_path= opt.eval_data+eval_data 73 | # eval_data_path = os.path.join(opt.eval_data, eval_data) 74 | AlignCollate_eval = AlignCollate(opt, mode="test") 75 | eval_data, eval_data_log = hierarchical_dataset( 76 | root=eval_data_path, opt=opt, mode="test" 77 | ) 78 | eval_loader = torch.utils.data.DataLoader( 79 | eval_data, 80 | batch_size=eval_batch_size, 81 | shuffle=False, 82 | num_workers=int(opt.workers), 83 | collate_fn=AlignCollate_eval, 84 | pin_memory=True, 85 | ) 86 | 87 | _, accuracy_by_best_model, ned_score, _, _, _, infer_time, length_of_data = validation( 88 | model, criterion, eval_loader, converter, opt, tqdm_position=0 89 | ) 90 | accuracy_list.append(f"{accuracy_by_best_model:0.2f}") 91 | total_forward_time += infer_time 92 | total_eval_data_number += len(eval_data) 93 | total_correct_number += accuracy_by_best_model * length_of_data 94 | log.write(eval_data_log) 95 | print(f"Acc {accuracy_by_best_model:0.2f}") 96 | log.write(f"Acc {accuracy_by_best_model:0.2f}\n") 97 | print(f"Ned {ned_score:0.2f}") 98 | log.write(f"Ned {ned_score:0.2f}\n") 99 | print(dashed_line) 100 | log.write(dashed_line + "\n") 101 | 102 | averaged_forward_time = total_forward_time / total_eval_data_number * 1000 103 | total_accuracy = total_correct_number / total_eval_data_number 104 | params_num = sum([np.prod(p.size()) for p in model.parameters()]) 105 | 106 | eval_log = "accuracy: " 107 | for name, accuracy in zip(eval_data_list, accuracy_list): 108 | eval_log += f"{name}: {accuracy}\t" 109 | eval_log += f"total_accuracy: {total_accuracy:0.2f}\t" 110 | eval_log += f"averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.2f}" 111 | print(eval_log) 112 | log.write(eval_log + "\n") 113 | 114 | # for convenience 115 | print("\t".join(accuracy_list)) 116 | print(f"Total_accuracy:{total_accuracy:0.2f}") 117 | log.write("\t".join(accuracy_list) + "\n") 118 | log.write(f"Total_accuracy:{total_accuracy:0.2f}" + "\n") 119 | log.close() 120 | 121 | 
# for convenience 122 | today = date.today() 123 | if opt.log_multiple_test: 124 | log_all_model = open(f"./evaluation_log/log_multiple_test_{today}.txt", "a") 125 | log_all_model.write("\t".join(accuracy_list) + "\n") 126 | else: 127 | log_all_model = open( 128 | f"./evaluation_log/log_all_model_evaluation_{today}.txt", "a" 129 | ) 130 | log_all_model.write( 131 | f"./result/{opt.exp_name}\tTotal_accuracy:{total_accuracy:0.2f}\n" 132 | ) 133 | log_all_model.write("\t".join(accuracy_list) + "\n") 134 | log_all_model.close() 135 | 136 | return total_accuracy, eval_data_list, accuracy_list 137 | 138 | 139 | def validation(model, criterion, eval_loader, converter, opt, val_choose="val",tqdm_position=1): 140 | """validation or evaluation""" 141 | n_correct = 0 142 | norm_ED = 0 143 | length_of_data = 0 144 | infer_time = 0 145 | valid_loss_avg = Averager() 146 | 147 | for i, (image_tensors, labels) in tqdm( 148 | enumerate(eval_loader), 149 | total=len(eval_loader), 150 | position=tqdm_position, 151 | leave=False, 152 | ): 153 | batch_size = image_tensors.size(0) 154 | length_of_data = length_of_data + batch_size 155 | image = image_tensors.to(device) 156 | # For max length prediction 157 | labels_index, labels_length = converter.encode( 158 | labels, batch_max_length=opt.batch_max_length 159 | ) 160 | 161 | if "CTC" in opt.Prediction: 162 | start_time = time.time() 163 | if val_choose == "FF": 164 | preds = model(image, cross = False, is_train = False) 165 | elif val_choose == "TF": 166 | preds = model(image,cross = True, is_train = False) 167 | else: 168 | preds = model(image, is_train = False) 169 | if len(preds) == 3 or len(preds) == 4: 170 | preds = preds['logits'] 171 | elif len(preds) == 2: 172 | preds = preds['predict'] 173 | forward_time = time.time() - start_time 174 | 175 | # Calculate evaluation loss for CTC deocder. 
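            # torch.nn.CTCLoss expects (T, B, C) log-probabilities plus per-sample input
            # and target lengths, while the model outputs (B, T, C); hence the
            # log_softmax(2).permute(1, 0, 2) below.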
176 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 177 | # permute 'preds' to use CTCloss format 178 | cost = criterion( 179 | preds.log_softmax(2).permute(1, 0, 2), 180 | labels_index, 181 | preds_size, 182 | labels_length, 183 | ) 184 | 185 | else: 186 | text_for_pred = ( 187 | torch.LongTensor(batch_size).fill_(converter.dict["[SOS]"]).to(device) 188 | ) 189 | 190 | start_time = time.time() 191 | # preds = model(image, text_for_pred, is_train=False) 192 | if val_choose == "FF": 193 | preds = model(image, cross = False,text = text_for_pred, is_train = False) 194 | elif val_choose == "TF": 195 | preds = model(image,cross = True, text = text_for_pred, is_train = False) 196 | else: 197 | preds = model(image, text = text_for_pred, is_train=False) 198 | if len(preds) == 3: 199 | preds = preds['logits'] 200 | elif len(preds) == 2: 201 | preds = preds['predict'] 202 | forward_time = time.time() - start_time 203 | 204 | target = labels_index[:, 1:] # without [SOS] Symbol 205 | cost = criterion( 206 | preds.contiguous().view(-1, preds.shape[-1]), 207 | target.contiguous().view(-1), 208 | ) 209 | 210 | # select max probabilty (greedy decoding) then decode index to character 211 | _, preds_index = preds.max(2) 212 | preds_size = torch.IntTensor([preds.size(1)] * preds_index.size(0)).to(device) 213 | preds_str = converter.decode(preds_index, preds_size) 214 | 215 | infer_time += forward_time 216 | valid_loss_avg.add(cost) 217 | 218 | # calculate accuracy & confidence score 219 | preds_prob = F.softmax(preds, dim=2) 220 | preds_max_prob, _ = preds_prob.max(dim=2) 221 | confidence_score_list = [] 222 | for gt, prd, prd_max_prob in zip(labels, preds_str, preds_max_prob): 223 | if "Attn" in opt.Prediction: 224 | prd_EOS = prd.find("[EOS]") 225 | prd = prd[:prd_EOS] # prune after "end of sentence" token ([EOS]) 226 | prd_max_prob = prd_max_prob[:prd_EOS] 227 | 228 | """ 229 | In our experiment, if the model predicts at least one [UNK] token, we count the word prediction as incorrect. 230 | To not take account of [UNK] token, use the below line. 231 | prd = prd.replace('[UNK]', '') 232 | """ 233 | 234 | # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. = same with ASTER 235 | # gt = gt.lower() 236 | # prd = prd.lower() 237 | # alphanumeric_case_insensitve = "0123456789abcdefghijklmnopqrstuvwxyz" 238 | # out_of_alphanumeric_case_insensitve = f"[^{alphanumeric_case_insensitve}]" 239 | # gt = re.sub(out_of_alphanumeric_case_insensitve, "", gt) 240 | # prd = re.sub(out_of_alphanumeric_case_insensitve, "", prd) 241 | 242 | 243 | if opt.NED: 244 | # ICDAR2019 Normalized Edit Distance 245 | if len(gt) == 0 or len(prd) == 0: 246 | norm_ED += 0 247 | elif len(gt) > len(prd): 248 | norm_ED += 1 - edit_distance(prd, gt) / len(gt) 249 | else: 250 | norm_ED += 1 - edit_distance(prd, gt) / len(prd) 251 | 252 | if prd == gt: 253 | n_correct += 1 254 | 255 | # calculate confidence score (= multiply of prd_max_prob) 256 | try: 257 | confidence_score = prd_max_prob.cumprod(dim=0)[-1] 258 | except: 259 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([EOS]) 260 | confidence_score_list.append(confidence_score) 261 | 262 | ned_score=None 263 | 264 | if opt.NED: 265 | # ICDAR2019 Normalized Edit Distance. In web page, they report % of norm_ED (= norm_ED * 100). 
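        # norm_ED accumulated 1 - edit_distance / len(longer string) per sample,
        # e.g. gt "hello" vs pred "helo" contributes 1 - 1/5 = 0.8; here it is
        # averaged over the split and reported in percent.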
266 | ned_score = norm_ED / float(length_of_data) * 100 267 | 268 | score = n_correct / float(length_of_data) * 100 # accuracy 269 | 270 | return ( 271 | valid_loss_avg.val(), 272 | score, 273 | ned_score, 274 | preds_str, 275 | confidence_score_list, 276 | labels, 277 | infer_time, 278 | length_of_data, 279 | ) 280 | 281 | 282 | def test(opt): 283 | """model configuration""" 284 | opt.character = [] 285 | f = open(opt.train_data+"/dict.txt") 286 | line = f.readline() 287 | while line: 288 | opt.character.append(line.strip("\n")) 289 | # print(line) 290 | line = f.readline() 291 | f.close() 292 | if "CTC" in opt.Prediction: 293 | converter = CTCLabelConverter(opt.character) 294 | else: 295 | converter = AttnLabelConverter(opt.character) 296 | opt.sos_token_index = converter.dict["[SOS]"] 297 | opt.eos_token_index = converter.dict["[EOS]"] 298 | opt.num_class = len(converter.character) 299 | 300 | model = Model(opt) 301 | print( 302 | "model input parameters", 303 | opt.imgH, 304 | opt.imgW, 305 | opt.num_fiducial, 306 | opt.input_channel, 307 | opt.output_channel, 308 | opt.hidden_size, 309 | opt.num_class, 310 | opt.batch_max_length, 311 | opt.Transformation, 312 | opt.FeatureExtraction, 313 | opt.SequenceModeling, 314 | opt.Prediction, 315 | ) 316 | model = torch.nn.DataParallel(model).to(device) 317 | 318 | # load model 319 | print("loading pretrained model from %s" % opt.saved_model) 320 | try: 321 | model.load_state_dict(torch.load(opt.saved_model, map_location=device)) 322 | except: 323 | print( 324 | "*** pretrained model not match strictly *** and thus load_state_dict with strict=False mode" 325 | ) 326 | # pretrained_state_dict = torch.load(opt.saved_model) 327 | # for name in pretrained_state_dict: 328 | # print(name) 329 | model.load_state_dict( 330 | torch.load(opt.saved_model, map_location=device), strict=False 331 | ) 332 | 333 | opt.exp_name = "_".join(opt.saved_model.split("/")[1:]) 334 | # print(model) 335 | 336 | """ keep evaluation model and result logs """ 337 | os.makedirs(f"./result/{opt.exp_name}", exist_ok=True) 338 | # os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/') 339 | 340 | """ setup loss """ 341 | if "CTC" in opt.Prediction: 342 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 343 | else: 344 | # ignore [PAD] token 345 | criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.dict["[PAD]"]).to( 346 | device 347 | ) 348 | 349 | """ evaluation """ 350 | model.eval() 351 | with torch.no_grad(): 352 | if ( 353 | opt.eval_type 354 | ): # evaluate 6 benchmark evaluation datasets or 7 additionally collected evaluation datasets 355 | benchmark_all_eval(model, criterion, converter, opt) 356 | else: 357 | log = open(f"./result/{opt.exp_name}/log_evaluation.txt", "a") 358 | AlignCollate_eval = AlignCollate(opt, mode="test") 359 | eval_data, eval_data_log = hierarchical_dataset( 360 | root=opt.eval_data, opt=opt, mode="test" 361 | ) 362 | eval_loader = torch.utils.data.DataLoader( 363 | eval_data, 364 | batch_size=opt.batch_size, 365 | shuffle=False, 366 | num_workers=int(opt.workers), 367 | collate_fn=AlignCollate_eval, 368 | pin_memory=True, 369 | ) 370 | _, score_by_best_model, ned_score,_, _, _, _, _ = validation( 371 | model, criterion, eval_loader, converter, opt 372 | ) 373 | log.write(eval_data_log) 374 | print(f"best acc score {score_by_best_model:0.2f}") 375 | print(f"best ned score {ned_score:0.2f}") 376 | log.write(f"best acc score{score_by_best_model:0.2f}\n") 377 | log.write(f"best ned score{ned_score:0.2f}\n") 378 | log.close() 
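
# Usage note (added for clarity): although several CLI flags are declared
# below, `opt` is rebuilt from the config file alone (cfg.common, cfg.model,
# cfg.train, cfg.optimizer, cfg.test); as written, the parsed values other
# than --config are not merged back in. In practice evaluation is driven by
# the config, e.g.
#
#     python3 test.py --config=config/crnn.py
#
# with saved_model, eval_data, eval_type, etc. supplied by the config's
# `test` section.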
379 | 380 | 381 | if __name__ == "__main__": 382 | parser = argparse.ArgumentParser() 383 | parser.add_argument( 384 | "--config", 385 | default="config/crnn.py", 386 | help="path to validation dataset", 387 | ) 388 | parser.add_argument("--eval_data", help="path to evaluation dataset") 389 | parser.add_argument( 390 | "--eval_type", 391 | type=str, 392 | help="evaluate 6 benchmark evaluation datasets or 7 additionally collected evaluation datasets |benchmark|addition|", 393 | ) 394 | parser.add_argument( 395 | "--workers", type=int, help="number of data loading workers", default=4 396 | ) 397 | parser.add_argument("--batch_size", type=int, default=512, help="input batch size") 398 | parser.add_argument( 399 | "--saved_model", help="path to saved_model to evaluation" 400 | ) 401 | parser.add_argument( 402 | "--log_multiple_test", action="store_true", help="log_multiple_test" 403 | ) 404 | """ Data processing """ 405 | parser.add_argument( 406 | "--batch_max_length", type=int, default=25, help="maximum-label-length" 407 | ) 408 | parser.add_argument( 409 | "--imgH", type=int, default=32, help="the height of the input image" 410 | ) 411 | parser.add_argument( 412 | "--imgW", type=int, default=100, help="the width of the input image" 413 | ) 414 | parser.add_argument( 415 | "--character", 416 | type=str, 417 | default="0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~", 418 | help="character label", 419 | ) 420 | parser.add_argument( 421 | "--NED", action="store_true", help="For Normalized edit_distance" 422 | ) 423 | parser.add_argument( 424 | "--Aug", 425 | type=str, 426 | default="None", 427 | help="whether to use augmentation |None|Blur|Crop|Rot|", 428 | ) 429 | # parser.add_argument( 430 | # "--semi", 431 | # type=str, 432 | # default="None", 433 | # help="whether to use semi-supervised learning |None|PL|MT|", 434 | # ) 435 | """ Model Architecture """ 436 | parser.add_argument("--model_name", type=str, help="CRNN|TRBA") 437 | parser.add_argument( 438 | "--num_fiducial", 439 | type=int, 440 | default=20, 441 | help="number of fiducial points of TPS-STN", 442 | ) 443 | parser.add_argument( 444 | "--input_channel", 445 | type=int, 446 | default=3, 447 | help="the number of input channel of Feature extractor", 448 | ) 449 | parser.add_argument( 450 | "--output_channel", 451 | type=int, 452 | default=512, 453 | help="the number of output channel of Feature extractor", 454 | ) 455 | parser.add_argument( 456 | "--hidden_size", type=int, default=256, help="the size of the LSTM hidden state" 457 | ) 458 | 459 | arg = parser.parse_args() 460 | cfg = Config.fromfile(arg.config) 461 | # optcfg.model 462 | # opt.update(arg) 463 | # cfg.merge_from_dict(cfg.model) 464 | # opt.merge_from_dict(cfg.train) 465 | # opt.merge_from_dict(cfg.optimizer) 466 | 467 | opt = {} 468 | opt.update(cfg.common) 469 | opt.update(cfg.model) 470 | opt.update(cfg.train) 471 | opt.update(cfg.optimizer) 472 | opt.update(cfg.test) 473 | opt = argparse.Namespace(**opt) 474 | # opt.saved_model=cfg.test.saved_model 475 | # print(cfg.test.saved_model) 476 | if opt.model_name == "CRNN": 477 | opt.Transformation = "None" 478 | opt.FeatureExtraction = "VGG" 479 | opt.SequenceModeling = "BiLSTM" 480 | opt.Prediction = "CTC" 481 | 482 | elif opt.model_name == "TRBA": 483 | opt.Transformation = "TPS" 484 | opt.FeatureExtraction = "ResNet" 485 | opt.SequenceModeling = "BiLSTM" 486 | opt.Prediction = "Attn" 487 | 488 | elif opt.model_name == "RBA": # RBA 489 | opt.Transformation = "None" 490 
| opt.FeatureExtraction = "ResNet" 491 | opt.SequenceModeling = "BiLSTM" 492 | opt.Prediction = "Attn" 493 | 494 | cudnn.benchmark = True 495 | cudnn.deterministic = True 496 | opt.num_gpu = torch.cuda.device_count() 497 | if opt.num_gpu > 1: 498 | print( 499 | "We recommend to use 1 GPU, check your GPU number, you would miss CUDA_VISIBLE_DEVICES=0 or typo" 500 | ) 501 | print("To use multi-gpu setting, remove or comment out these lines") 502 | sys.exit() 503 | 504 | if sys.platform == "win32": 505 | opt.workers = 0 506 | 507 | os.makedirs(f"./evaluation_log", exist_ok=True) 508 | 509 | test(opt) 510 | -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from einops import rearrange 4 | import torch.nn as nn 5 | 6 | from modules.dm_router import DM_Router 7 | from modules.transformation import TPS_SpatialTransformerNetwork 8 | from modules.feature_extraction import ( 9 | VGG_FeatureExtractor, 10 | RCNN_FeatureExtractor, 11 | ResNet_FeatureExtractor, SVTR_FeatureExtractor, 12 | ) 13 | from modules.sequence_modeling import BidirectionalLSTM 14 | from modules.prediction import Attention 15 | 16 | 17 | class Model_Extractor(nn.Module): 18 | def __init__(self, opt): 19 | super(Model_Extractor, self).__init__() 20 | self.opt = opt 21 | self.stages = { 22 | "Trans": opt.Transformation, 23 | "Feat": opt.FeatureExtraction, 24 | "Seq": opt.SequenceModeling, 25 | "Pred": opt.Prediction, 26 | } 27 | 28 | """ Transformation """ 29 | if opt.Transformation == "TPS": 30 | self.Transformation = TPS_SpatialTransformerNetwork( 31 | F=opt.num_fiducial, 32 | I_size=(opt.imgH, opt.imgW), 33 | I_r_size=(opt.imgH, opt.imgW), 34 | I_channel_num=opt.input_channel, 35 | ) 36 | else: 37 | print("No Transformation module specified") 38 | 39 | """ FeatureExtraction """ 40 | if opt.FeatureExtraction == "VGG": 41 | self.FeatureExtraction = VGG_FeatureExtractor( 42 | opt.input_channel, opt.output_channel 43 | ) 44 | elif opt.FeatureExtraction == "RCNN": 45 | self.FeatureExtraction = RCNN_FeatureExtractor( 46 | opt.input_channel, opt.output_channel 47 | ) 48 | elif opt.FeatureExtraction == "ResNet": 49 | self.FeatureExtraction = ResNet_FeatureExtractor( 50 | opt.input_channel, opt.output_channel 51 | ) 52 | elif opt.FeatureExtraction == "SVTR": 53 | self.FeatureExtraction = SVTR_FeatureExtractor( 54 | opt.input_channel, opt.output_channel 55 | ) 56 | else: 57 | raise Exception("No FeatureExtraction module specified") 58 | self.FeatureExtraction_output = opt.output_channel 59 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d( 60 | (None, 1) 61 | ) # Transform final (imgH/16-1) -> 1 62 | 63 | """Sequence modeling""" 64 | if opt.SequenceModeling == "BiLSTM": 65 | self.SequenceModeling = nn.Sequential( 66 | BidirectionalLSTM( 67 | self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size 68 | ), 69 | BidirectionalLSTM( 70 | opt.hidden_size, opt.hidden_size, opt.hidden_size 71 | ), 72 | ) 73 | self.SequenceModeling_output = opt.hidden_size 74 | else: 75 | self.SequenceModeling = nn.Sequential( 76 | nn.Linear( 77 | self.FeatureExtraction_output, opt.hidden_size) 78 | ) 79 | print("No SequenceModeling module specified") 80 | self.SequenceModeling_output = opt.hidden_size 81 | 82 | def forward(self, image): 83 | """Transformation stage""" 84 | if not self.stages["Trans"] == "None": 85 | image = self.Transformation(image) 86 | 87 | """ Feature extraction stage """ 88 
| visual_feature = self.FeatureExtraction(image) 89 | visual_feature = visual_feature.permute( 90 | 0, 3, 1, 2 91 | ) # [b, c, h, w] -> [b, w, c, h] 92 | visual_feature = self.AdaptiveAvgPool( 93 | visual_feature 94 | ) # [b, w, c, h] -> [b, w, c, 1] 95 | visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] 96 | 97 | """ Sequence modeling stage """ 98 | contextual_feature = self.SequenceModeling( 99 | visual_feature 100 | ) # [b, num_steps, opt.hidden_size] 101 | return contextual_feature # [b, num_steps, opt.num_class] 102 | 103 | 104 | 105 | class Model(nn.Module): 106 | def __init__(self, opt): 107 | super(Model, self).__init__() 108 | self.opt = opt 109 | self.model = Model_Extractor(opt) 110 | self.SequenceModeling_output = self.model.SequenceModeling_output 111 | self.stages = { 112 | "Pred": opt.Prediction, 113 | } 114 | self.fc = None 115 | self.Prediction=None 116 | 117 | def reset_class(self, opt, device): 118 | 119 | """Prediction""" 120 | if opt.Prediction == "CTC": 121 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 122 | elif opt.Prediction == "Attn": 123 | self.Prediction = Attention( 124 | self.SequenceModeling_output, opt.hidden_size, opt.num_class 125 | ) 126 | else: 127 | raise Exception("Prediction is neither CTC or Attn") 128 | 129 | self.Prediction.to(device) 130 | 131 | 132 | 133 | def forward(self, image, text=None, is_train=True): 134 | """Transformation stage""" 135 | contextual_feature = self.model(image) 136 | """ Prediction stage """ 137 | if self.stages["Pred"] == "CTC": 138 | prediction = self.Prediction(contextual_feature.contiguous()) 139 | else: 140 | prediction = self.Prediction( 141 | contextual_feature.contiguous(), 142 | text, 143 | is_train, 144 | batch_max_length=self.opt.batch_max_length, 145 | ) 146 | 147 | # return prediction # [b, num_steps, opt.num_class] 148 | return {"predict":prediction,"feature":contextual_feature} 149 | 150 | def update_fc(self, hidden_size, nb_classes,device=None): 151 | fc = nn.Linear(hidden_size, nb_classes) 152 | if self.fc is not None: 153 | nb_output = self.fc.out_features 154 | weight = copy.deepcopy(self.fc.weight.data) 155 | bias = copy.deepcopy(self.fc.bias.data) 156 | fc.weight.data[:nb_output] = weight 157 | fc.bias.data[:nb_output] = bias 158 | 159 | # del self.fc 160 | self.fc = fc 161 | 162 | def new_fc(self, hidden_size, nb_classes): 163 | # print("new_fc") 164 | self.fc = nn.Linear(hidden_size, nb_classes) 165 | 166 | def weight_align(self, increment): 167 | weights=self.fc.weight.data 168 | newnorm=(torch.norm(weights[-increment:,:],p=2,dim=1)) 169 | oldnorm=(torch.norm(weights[:-increment,:],p=2,dim=1)) 170 | meannew=torch.mean(newnorm) 171 | meanold=torch.mean(oldnorm) 172 | gamma=meanold/meannew 173 | print('alignweights,gamma=',gamma) 174 | self.fc.weight.data[-increment:,:]*=gamma 175 | 176 | def build_prediction(self,opt,num_class): 177 | """Prediction""" 178 | # print("build_prediction") 179 | if opt.Prediction == "CTC": 180 | # self.fc = nn.Linear(self.SequenceModeling_output, num_class) 181 | self.Prediction = self.fc 182 | # self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 183 | elif opt.Prediction == "Attn": 184 | # self.fc = nn.Linear(opt.hidden_size, num_class) 185 | self.Prediction = Attention( 186 | self.SequenceModeling_output, opt.hidden_size, num_class,self.fc 187 | ) 188 | else: 189 | raise Exception("Prediction is neither CTC or Attn") 190 | 191 | def copy(self): 192 | return copy.deepcopy(self) 193 | 194 | def 
freeze(self): 195 | for param in self.parameters(): 196 | param.requires_grad = False 197 | self.eval() 198 | 199 | return self 200 | 201 | 202 | 203 | class DERNet(Model): 204 | def __init__(self, opt): 205 | super(DERNet,self).__init__(opt) 206 | self.model = nn.ModuleList() 207 | self.out_dim = None 208 | self.fc = None 209 | self.aux_fc=None 210 | self.task_sizes = [] 211 | 212 | @property 213 | def feature_dim(self): 214 | if self.out_dim is None: 215 | return 0 216 | return self.out_dim*len(self.model) 217 | 218 | def extract_vector(self, x): 219 | features = [convnet(x) for convnet in self.model] 220 | features = torch.cat(features, 1) 221 | return features 222 | 223 | def forward(self, image, text=None, is_train=True): 224 | """Transformation stage""" 225 | features = [convnet(image) for convnet in self.model] 226 | contextual_feature = torch.cat(features, -1) 227 | 228 | """ Prediction stage """ 229 | if self.stages["Pred"] == "CTC": 230 | prediction = self.Prediction(contextual_feature.contiguous()) 231 | else: 232 | prediction = self.Prediction( 233 | contextual_feature.contiguous(), 234 | text, 235 | is_train, 236 | batch_max_length=self.opt.batch_max_length, 237 | ) 238 | 239 | """ Prediction stage """ 240 | if self.stages["Pred"] == "CTC": 241 | aux_logits = self.aux_Prediction(contextual_feature[:,:,-self.out_dim:].contiguous()) 242 | else: 243 | aux_logits = self.aux_Prediction( 244 | contextual_feature[:,:,-self.out_dim:].contiguous(), 245 | text, 246 | is_train, 247 | batch_max_length=self.opt.batch_max_length, 248 | ) 249 | # out=self.fc(features) #{logics: self.fc(features)} 250 | out = dict({"logits":prediction}) 251 | # aux_logits=self.aux_fc(contextual_feature[:,-self.out_dim:]) 252 | 253 | out.update({"aux_logits":aux_logits,"features":contextual_feature.contiguous()}) 254 | return out # [b, num_steps, opt.num_class] 255 | 256 | def update_fc(self, hidden_size, nb_classes,device=None): 257 | if len(self.model)==0: 258 | self.model.append(Model_Extractor(self.opt)) 259 | else: 260 | self.model.append(Model_Extractor(self.opt)) 261 | self.model[-1].load_state_dict(self.model[-2].state_dict()) 262 | 263 | if self.out_dim is None: 264 | self.out_dim=self.model[-1].SequenceModeling_output 265 | fc = nn.Linear(self.feature_dim if self.opt.Prediction=="CTC" else self.out_dim, nb_classes) 266 | if self.fc is not None: 267 | nb_output = self.fc.out_features 268 | weight = copy.deepcopy(self.fc.weight.data) 269 | bias = copy.deepcopy(self.fc.bias.data) 270 | fc.weight.data[:nb_output,:self.feature_dim-self.out_dim] = weight 271 | fc.bias.data[:nb_output] = bias 272 | 273 | del self.fc 274 | self.fc = fc 275 | # new_task_size = nb_classes - sum(self.task_sizes) 276 | # self.task_sizes.append(new_task_size) 277 | 278 | self.aux_fc= nn.Linear(self.out_dim,nb_classes) 279 | 280 | def build_prediction(self,opt,num_class): 281 | """Prediction""" 282 | # print("build_prediction") 283 | if opt.Prediction == "CTC": 284 | # self.fc = nn.Linear(self.SequenceModeling_output, num_class) 285 | self.Prediction = self.fc 286 | # self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 287 | elif opt.Prediction == "Attn": 288 | # self.fc = nn.Linear(opt.hidden_size, num_class) 289 | self.Prediction = Attention( 290 | self.feature_dim, opt.hidden_size, num_class,self.fc 291 | ) 292 | else: 293 | raise Exception("Prediction is neither CTC or Attn") 294 | 295 | def build_aux_prediction(self,opt,num_class): 296 | """Prediction""" 297 | if opt.Prediction == "CTC": 298 | # 
self.aux_fc = nn.Linear(self.SequenceModeling_output, num_class) 299 | self.aux_Prediction = self.aux_fc 300 | # self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 301 | elif opt.Prediction == "Attn": 302 | # self.aux_fc = nn.Linear(opt.hidden_size, num_class) 303 | self.aux_Prediction = Attention( 304 | self.SequenceModeling_output, opt.hidden_size, num_class,self.aux_fc 305 | ) 306 | else: 307 | raise Exception("Prediction is neither CTC or Attn") 308 | 309 | def freeze_conv(self): 310 | for param in self.model.parameters(): 311 | param.requires_grad = False 312 | self.model.eval() 313 | 314 | class MRNNet(nn.Module): 315 | def __init__(self, opt): 316 | super(MRNNet, self).__init__() 317 | self.model = nn.ModuleList() 318 | self.out_dim=None 319 | self.fc = None 320 | self.opt = opt 321 | self.task_sizes = [] 322 | if self.opt.FeatureExtraction == "VGG": 323 | self.patch = 63 324 | elif self.opt.FeatureExtraction == "SVTR": 325 | self.patch = 64 326 | elif self.opt.FeatureExtraction == "ResNet": 327 | self.patch = 65 328 | self.router = "dm-router" #dm-router 329 | self.layer_num = 1 330 | self.beta = 1 331 | 332 | @property 333 | def feature_dim(self): 334 | if self.out_dim is None: 335 | return 0 336 | return self.out_dim*len(self.model) 337 | 338 | def extract_vector(self, x): 339 | features = [convnet(x) for convnet in self.model] 340 | features = torch.cat(features, 1) 341 | return features 342 | 343 | def forward(self, image, cross = True,text=None, is_train=True): 344 | """Transformation stage""" 345 | # features = [convnet(image) for convnet in self.model] 346 | if cross==False: 347 | features = self.model[-1](image,text,is_train)["predict"] 348 | index = None 349 | # elif is_train == False: 350 | # features, index = self.cross_test(image) 351 | elif is_train == False: 352 | features, index = self.cross_forward_expert(image, text, is_train) 353 | else: 354 | # features,index = self.cross_forwardv2(image) 355 | features, index = self.cross_forward(image,text,is_train) 356 | # out=self.fc(features) #{logics: self.fc(features)} 357 | out = dict({"logits":features,"index":index,"aux_logits":None}) 358 | 359 | return out # [b, num_steps, opt.num_class] 360 | 361 | def pad_zeros_features(self,feature,total): 362 | B,T,know = feature.size() 363 | zero = torch.ones([B,T,total-know],dtype=torch.float).to(feature.device) 364 | return torch.cat([feature,zero],dim=-1) 365 | 366 | def cross_forward_expert(self, image, text=None, is_train=True): 367 | """Transformation stage""" 368 | features = [convnet(image,text,is_train) for convnet in self.model] 369 | route_info = torch.stack([feature["feature"] for feature in features], 1) 370 | route_info = self.dm_router(route_info) 371 | route_info = rearrange(route_info, 'b h w c -> b w (h c)') 372 | route_info = self.channel_route(route_info) 373 | # route_info = torch.cat([torch.max(feature,-1)[0] for feature in features],-1) 374 | index = self.route(route_info.permute(0, 2, 1).contiguous()) 375 | # index = self.softargmax1d(torch.squeeze(index, -1),self.beta) 376 | index = torch.squeeze(index, -1) 377 | index = torch.max(index, -1)[1] 378 | 379 | # index [B,I] 380 | # route_info [B,T,I] 381 | 382 | # feature_array = torch.stack(features, 1) 383 | features = [feature["predict"] for feature in features] 384 | B, T, C = features[-1].size() 385 | list_len = len(features) 386 | normal_feat = [] 387 | for i in range(list_len - 1): 388 | feat = self.pad_zeros_features(features[i], total=C) 389 | normal_feat.append(feat) 390 | 
normal_feat.append(features[-1]) 391 | normal_feat = torch.stack(normal_feat, 0) 392 | # normal_feat [I,B,T,C] -> [T,C,B,I] -> [B,T,C,I] 393 | output = torch.stack([normal_feat[index_one][i,:,:]for i,index_one in enumerate(index)],0) 394 | 395 | return output.contiguous(),index 396 | 397 | def cross_forward(self, image, text=None, is_train=True): 398 | """Transformation stage""" 399 | features = [convnet(image,text,is_train) for convnet in self.model] 400 | route_info = torch.stack([feature["feature"] for feature in features], 1) 401 | route_info = self.dm_router(route_info) 402 | route_info = rearrange(route_info, 'b h w c -> b w (h c)') 403 | route_info = self.channel_route(route_info) 404 | # route_info = torch.cat([torch.max(feature,-1)[0] for feature in features],-1) 405 | index = self.route(route_info.permute(0, 2, 1).contiguous()) 406 | index = self.softargmax1d(torch.squeeze(index, -1),self.beta) 407 | # index [B,I] 408 | # route_info [B,T,I] 409 | 410 | features = [feature["predict"] for feature in features] 411 | B, T, C = features[-1].size() 412 | list_len = len(features) 413 | normal_feat = [] 414 | for i in range(list_len - 1): 415 | feat = self.pad_zeros_features(features[i], total=C) 416 | normal_feat.append(feat) 417 | normal_feat.append(features[-1]) 418 | normal_feat = torch.stack(normal_feat, 0) 419 | # normal_feat [I,B,T,C] -> [T,C,B,I] -> [B,T,C,I] 420 | output = (normal_feat.permute(2, 3, 1, 0) * index).permute(2, 0, 1, 3).contiguous() 421 | # output = (normal_feat.permute(3,1,2,0) * route_info).permute(1,2,0,3).contiguous() 422 | 423 | return torch.sum(output, -1), index 424 | 425 | def build_fc(self, hidden_size, nb_classes): 426 | self.update_fc(hidden_size, nb_classes) 427 | 428 | def update_fc(self, hidden_size, nb_classes): 429 | self.model.append(Model(self.opt)) 430 | self.model[-1].new_fc(hidden_size,nb_classes) 431 | # self.model[-1].load_state_dict(self.model[-2].state_dict()) 432 | 433 | if self.out_dim is None: 434 | self.out_dim=self.model[-1].SequenceModeling_output 435 | 436 | 437 | self.route = nn.Linear(self.patch , 1) 438 | self.channel_route = nn.Linear(self.feature_dim, len(self.model)) 439 | # if self.router == "gmlp": 440 | # block = GatingMlpBlock(self.out_dim, self.out_dim * 2, self.patch) 441 | # elif self.router == "vip": 442 | # block = PermutatorBlock(self.out_dim, 2, taski = len(self.model), patch = self.patch) 443 | # el 444 | if self.router == "dm-router": 445 | block = DM_Router(self.out_dim, self.out_dim * 2, self.patch,len(self.model)) 446 | else: 447 | block = nn.Linear(self.out_dim, self.out_dim ) 448 | layers=[] 449 | for _ in range(self.layer_num): 450 | layers.append(block) 451 | print("mlp {} has {} layers".format(block, len(layers))) 452 | self.dm_router = nn.Sequential(*layers) 453 | # [b, num_steps * len] -> [b, len] 454 | # if self.fc is not None: 455 | # nb_output = self.fc.out_features 456 | # weight = copy.deepcopy(self.fc.weight.data) 457 | # bias = copy.deepcopy(self.fc.bias.data) 458 | # fc.weight.data[:nb_output,:self.feature_dim-self.out_dim] = weight 459 | # fc.bias.data[:nb_output] = bias 460 | # 461 | # del self.fc 462 | # self.fc = fc 463 | # fc = nn.Linear(self.feature_dim, nb_classes) 464 | def load_fc(self,input,output): 465 | fc = nn.Linear(input,output) 466 | if self.channel_route is not None: 467 | nb_output = self.channel_route.out_features 468 | weight = copy.deepcopy(self.channel_route.weight.data) 469 | bias = copy.deepcopy(self.channel_route.bias.data) 470 | 
fc.weight.data[:nb_output,:self.feature_dim-self.out_dim] = weight 471 | fc.bias.data[:nb_output] = bias 472 | 473 | del self.fc 474 | self.fc = fc 475 | 476 | def build_prediction(self,opt,num_class): 477 | """Prediction""" 478 | if opt.Prediction == "CTC" or opt.Prediction == "Attn": 479 | # self.fc = nn.Linear(self.SequenceModeling_output, num_class) 480 | # self.Prediction = self.fc 481 | self.model[-1].build_prediction(opt,num_class) 482 | else: 483 | raise Exception("Prediction is neither CTC or Attn") 484 | 485 | def copy(self): 486 | return copy.deepcopy(self) 487 | 488 | def freeze(self): 489 | for param in self.parameters(): 490 | param.requires_grad = False 491 | self.eval() 492 | 493 | return self 494 | 495 | def softargmax1d(self,input, beta=5): 496 | return nn.functional.softmax(beta * input, dim=-1) 497 | 498 | 499 | --------------------------------------------------------------------------------
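For readers tracing MRNNet.cross_forward above, the following is a minimal, illustrative sketch of the soft expert routing: per-task logits are padded to a common class width, a per-sample routing weight is produced by softargmax1d (a temperature-sharpened softmax over router scores), and the experts are combined by a weighted sum. The tensor sizes, the random router scores, and the zero padding below are assumptions for illustration only (the repository pads with ones via pad_zeros_features and derives the scores from DM_Router); this is not the repository's code path.

import torch
import torch.nn.functional as F

B, T = 2, 26                  # batch size, decoding steps (illustrative)
class_sizes = [40, 60, 80]    # per-task head widths; the newest task is widest
num_experts = len(class_sizes)

# Fake per-expert logits, padded on the class dim to the largest head.
experts = []
for c in class_sizes:
    logits = torch.randn(B, T, c)
    experts.append(F.pad(logits, (0, class_sizes[-1] - c)))
stacked = torch.stack(experts, 0)             # [I, B, T, C]

# Router scores per sample, turned into soft weights; this mirrors
# softargmax1d(x, beta) == softmax(beta * x).
beta = 1.0
scores = torch.randn(B, num_experts)
weights = F.softmax(beta * scores, dim=-1)    # [B, I]

# Weighted sum over experts, following the permute/sum in cross_forward.
mixed = (stacked.permute(2, 3, 1, 0) * weights)   # [T, C, B, I]
mixed = mixed.permute(2, 0, 1, 3).sum(-1)         # [B, T, C]
print(mixed.shape)  # torch.Size([2, 26, 80])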