├── trained_model └── README.txt ├── requirements.txt ├── WACV2025 ├── wacv25-1278-poster.pdf └── wacv25-1278-slides.pdf ├── utils ├── load_config.py ├── criterion.py ├── averager.py ├── utils_HDGE.py └── converter.py ├── config ├── DD.yaml ├── HDGE.yaml └── STR.yaml ├── modules ├── sequence_modeling.py ├── discriminators.py ├── generators.py ├── prediction.py ├── feature_extraction.py └── transformation.py ├── LICENSE ├── Dockerfile ├── source ├── ops.py ├── rand_aug.py ├── model.py ├── stratify.py ├── HDGE.py └── dataset.py ├── .gitignore ├── stage1_HDGE.py ├── README.md ├── stage1_DD.py ├── supervised_learning.py ├── test.py └── stage2_StrDA.py /trained_model/README.txt: -------------------------------------------------------------------------------- 1 | Store trained model here -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | lmdb 3 | tqdm 4 | nltk 5 | pyyaml 6 | pillow 7 | opencv-python 8 | -------------------------------------------------------------------------------- /WACV2025/wacv25-1278-poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KhaLee2307/StrDA/HEAD/WACV2025/wacv25-1278-poster.pdf -------------------------------------------------------------------------------- /WACV2025/wacv25-1278-slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KhaLee2307/StrDA/HEAD/WACV2025/wacv25-1278-slides.pdf -------------------------------------------------------------------------------- /utils/load_config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | def load_config(config_path): 5 | with open(config_path, "r") as file: 6 | config = yaml.safe_load(file) 7 | return config 8 | -------------------------------------------------------------------------------- /utils/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FocalLoss(nn.Module): 7 | def __init__(self, alpha=1, gamma=2): 8 | super(FocalLoss, self).__init__() 9 | self.alpha = alpha 10 | self.gamma = gamma 11 | 12 | def forward(self, inputs, targets): 13 | bce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float()) 14 | loss = self.alpha * (1 - torch.exp(-bce_loss)) ** self.gamma * bce_loss 15 | return loss 16 | -------------------------------------------------------------------------------- /utils/averager.py: -------------------------------------------------------------------------------- 1 | class Averager(object): 2 | """ Compute average for torch.Tensor, used for loss average. 
""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def add(self, v): 8 | count = v.data.numel() 9 | v = v.data.sum() 10 | self.n_count += count 11 | self.sum += v 12 | 13 | def reset(self): 14 | self.n_count = 0 15 | self.sum = 0 16 | 17 | def val(self): 18 | res = 0 19 | if self.n_count != 0: 20 | res = self.sum / float(self.n_count) 21 | return res 22 | -------------------------------------------------------------------------------- /config/DD.yaml: -------------------------------------------------------------------------------- 1 | # Data Processing 2 | imgH: 32 # the height of the input image 3 | imgW: 100 # the width of the input image 4 | 5 | # Model Architecture 6 | num_fiducial: 20 # number of fiducial points of TPS-STN" 7 | input_channel: 3 # the number of input channel of Feature extractor 8 | output_channel: 512 # the number of output channel of Feature extractor 9 | hidden_size: 256 # the size of the LSTM hidden state 10 | 11 | # Optimizer 12 | lr: 0.001 # learning rate, 0.001 for Adam 13 | weight_decay: 0.01 # weight decay, 0.01 for Adam 14 | 15 | # Training 16 | grad_clip: 5 # gradient clipping value 17 | workers: 8 # number of data loading workers 18 | 19 | # Experiment 20 | manual_seed: 111 # for random seed setting 21 | method: "DD" # select Domain Stratifying method, DD|HDGE 22 | -------------------------------------------------------------------------------- /config/HDGE.yaml: -------------------------------------------------------------------------------- 1 | # HDGE 2 | decay_epoch: 100 # epoch from which to start lr decay 3 | load_height: 48 # image height 4 | load_width: 160 # image width 5 | crop_height: 32 # image height to be cropped 6 | crop_width: 100 # image width to be cropped 7 | lamda: 10 # lamda for gradient penalty 8 | idt_coef: 0.5 # coefficient for identity loss 9 | ngf: 64 # of gen filters in first conv layer 10 | ndf: 64 # of discrim filters in first conv layer 11 | norm: "instance" # instance normalization or batch normalization 12 | 13 | # Optimizer 14 | lr: 0.001 # learning rate, 0.001 for Adam 15 | 16 | # Training 17 | workers: 8 # number of data loading workers 18 | 19 | # Experiment 20 | manual_seed: 111 # for random seed setting 21 | method: "HDGE" # select Domain Stratifying method, DD|HDGE 22 | -------------------------------------------------------------------------------- /config/STR.yaml: -------------------------------------------------------------------------------- 1 | # Data Processing 2 | batch_max_length: 25 # maximum-label-length 3 | imgH: 32 # the height of the input image 4 | imgW: 100 # the width of the input image 5 | character: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" # character label 6 | 7 | # Model Architecture 8 | num_fiducial: 20 # number of fiducial points of TPS-STN" 9 | input_channel: 3 # the number of input channel of Feature extractor 10 | output_channel: 512 # the number of output channel of Feature extractor 11 | hidden_size: 256 # the size of the LSTM hidden state 12 | 13 | # Optimizer 14 | lr: 0.001 # learning rate, 0.001 for Adam 15 | weight_decay: 0.01 # weight decay, 0.01 for Adam 16 | 17 | # Training 18 | grad_clip: 5 # gradient clipping value 19 | workers: 8 # number of data loading workers 20 | 21 | # Experiment 22 | manual_seed: 111 # for random seed setting 23 | -------------------------------------------------------------------------------- /modules/sequence_modeling.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | def __init__(self, input_size, hidden_size, output_size): 6 | super(BidirectionalLSTM, self).__init__() 7 | self.rnn = nn.LSTM( 8 | input_size, hidden_size, bidirectional=True, batch_first=True 9 | ) 10 | self.linear = nn.Linear(hidden_size * 2, output_size) 11 | 12 | def forward(self, input): 13 | """ 14 | input : visual feature [batch_size x T x input_size], T = num_steps. 15 | output : contextual feature [batch_size x T x output_size] 16 | """ 17 | self.rnn.flatten_parameters() 18 | recurrent, _ = self.rnn( 19 | input 20 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 21 | output = self.linear(recurrent) # batch_size x T x output_size 22 | return output 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lê Nhật Kha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # select Image 2 | FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | 5 | # set bash as current shell 6 | RUN chsh -s /bin/bash 7 | SHELL ["/bin/bash", "-c"] 8 | 9 | RUN apt-get update 10 | RUN apt-get install -y libicu-dev git wget bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion g++ gcc && \ 11 | apt-get clean && rm -rf /var/lib/apt/lists/* 12 | 13 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 14 | && mkdir /root/.conda \ 15 | && bash Miniconda3-latest-Linux-x86_64.sh -b \ 16 | && rm -f Miniconda3-latest-Linux-x86_64.sh 17 | ENV PATH=/root/miniconda3/bin:$PATH 18 | 19 | # init conda and update 20 | RUN echo "Running $(conda --version)" && \ 21 | conda init bash && . 
/root/.bashrc && \ 22 | conda update conda 23 | 24 | # set up conda environment 25 | RUN conda create -n strda python=3.8 26 | RUN echo "source activate strda" > ~/.bashrc 27 | ENV PATH /root/miniconda3/envs/strda/bin:$PATH 28 | 29 | # install dependencies 30 | RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 31 | RUN pip install opencv-python==4.4.0.46 Pillow==7.2.0 opencv-python-headless==4.5.1.48 lmdb tqdm nltk six pyyaml 32 | 33 | RUN apt-get update 34 | RUN apt-get install -y ffmpeg libsm6 libxext6 35 | 36 | # get repository 37 | WORKDIR /home 38 | -------------------------------------------------------------------------------- /modules/discriminators.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from torch import nn 4 | 5 | from source.ops import conv_norm_lrelu, get_norm_layer, init_network 6 | 7 | 8 | class NLayerDiscriminator(nn.Module): 9 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_bias=False): 10 | super(NLayerDiscriminator, self).__init__() 11 | dis_model = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1), 12 | nn.LeakyReLU(0.2, True)] 13 | nf_mult = 1 14 | nf_mult_prev = 1 15 | for n in range(1, n_layers): 16 | nf_mult_prev = nf_mult 17 | nf_mult = min(2**n, 8) 18 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=2, 19 | norm_layer= norm_layer, padding=1, bias=use_bias)] 20 | nf_mult_prev = nf_mult 21 | nf_mult = min(2**n_layers, 8) 22 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1, 23 | norm_layer= norm_layer, padding=1, bias=use_bias)] 24 | dis_model += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)] 25 | 26 | self.dis_model = nn.Sequential(*dis_model) 27 | 28 | def forward(self, input): 29 | return self.dis_model(input) 30 | 31 | 32 | def define_Dis(input_nc, ndf, n_layers_D=3, norm="batch", gpu_ids=[0]): 33 | dis_net = None 34 | norm_layer = get_norm_layer(norm_type=norm) 35 | if type(norm_layer) == functools.partial: 36 | use_bias = norm_layer.func == nn.InstanceNorm2d 37 | else: 38 | use_bias = norm_layer == nn.InstanceNorm2d 39 | 40 | dis_net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_bias=use_bias) 41 | 42 | return init_network(dis_net, gpu_ids) 43 | -------------------------------------------------------------------------------- /modules/generators.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from torch import nn 4 | 5 | from source.ops import ResidualBlock, conv_norm_relu, dconv_norm_relu, get_norm_layer, init_network 6 | 7 | 8 | class ResnetGenerator(nn.Module): 9 | def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True, num_blocks=6): 10 | super(ResnetGenerator, self).__init__() 11 | if type(norm_layer) == functools.partial: 12 | use_bias = norm_layer.func == nn.InstanceNorm2d 13 | else: 14 | use_bias = norm_layer == nn.InstanceNorm2d 15 | 16 | res_model = [nn.ReflectionPad2d(3), 17 | conv_norm_relu(input_nc, ngf * 1, 7, norm_layer=norm_layer, bias=use_bias), 18 | conv_norm_relu(ngf * 1, ngf * 2, 3, 2, 1, norm_layer=norm_layer, bias=use_bias), 19 | conv_norm_relu(ngf * 2, ngf * 4, 3, 2, 1, norm_layer=norm_layer, bias=use_bias)] 20 | 21 | for i in range(num_blocks): 22 | res_model += [ResidualBlock(ngf * 4, norm_layer, 
use_dropout, use_bias)] 23 | 24 | res_model += [dconv_norm_relu(ngf * 4, ngf * 2, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 25 | dconv_norm_relu(ngf * 2, ngf * 1, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 26 | nn.ReflectionPad2d(3), 27 | nn.Conv2d(ngf, output_nc, 7), 28 | nn.Tanh()] 29 | self.res_model = nn.Sequential(*res_model) 30 | 31 | def forward(self, x): 32 | return self.res_model(x) 33 | 34 | 35 | def define_Gen(input_nc, output_nc, ngf, norm="batch", use_dropout=False, gpu_ids=[0]): 36 | gen_net = None 37 | norm_layer = get_norm_layer(norm_type=norm) 38 | 39 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=9) 40 | 41 | return init_network(gen_net, gpu_ids) 42 | -------------------------------------------------------------------------------- /utils/utils_HDGE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | # to make directories 10 | def mkdir(paths): 11 | for path in paths: 12 | os.makedirs(path, exist_ok=True) 13 | 14 | 15 | # to make cuda tensor 16 | def cuda(xs): 17 | if torch.cuda.is_available(): 18 | if not isinstance(xs, (list, tuple)): 19 | return xs.cuda() 20 | else: 21 | return [x.cuda() for x in xs] 22 | 23 | 24 | # to save the checkpoint 25 | def save_checkpoint(state, save_path): 26 | torch.save(state, save_path) 27 | 28 | 29 | # to load the checkpoint 30 | def load_checkpoint(ckpt_path, map_location=None): 31 | ckpt = torch.load(ckpt_path, map_location=map_location) 32 | print(" [*] Loading checkpoint from %s succeed!" % ckpt_path) 33 | return ckpt 34 | 35 | 36 | # to store 50 generated image in a pool and sample from it when it is full 37 | # shrivastava et al’s strategy 38 | class Sample_from_Pool(object): 39 | def __init__(self, max_elements=50): 40 | self.max_elements = max_elements 41 | self.cur_elements = 0 42 | self.items = [] 43 | 44 | def __call__(self, in_items): 45 | return_items = [] 46 | for in_item in in_items: 47 | if self.cur_elements < self.max_elements: 48 | self.items.append(in_item) 49 | self.cur_elements = self.cur_elements + 1 50 | return_items.append(in_item) 51 | else: 52 | if np.random.ranf() > 0.5: 53 | idx = np.random.randint(0, self.max_elements) 54 | tmp = copy.copy(self.items[idx]) 55 | self.items[idx] = in_item 56 | return_items.append(tmp) 57 | else: 58 | return_items.append(in_item) 59 | return return_items 60 | 61 | 62 | class LambdaLR(): 63 | def __init__(self, epochs, offset, decay_epoch): 64 | self.epochs = epochs 65 | self.offset = offset 66 | self.decay_epoch = decay_epoch 67 | 68 | def step(self, epoch): 69 | return 1.0 - max(0, epoch + self.offset - self.decay_epoch)/(self.epochs - self.decay_epoch) 70 | 71 | 72 | def print_networks(nets, names): 73 | print("------------Number of Parameters---------------") 74 | i=0 75 | for net in nets: 76 | num_params = 0 77 | for param in net.parameters(): 78 | num_params += param.numel() 79 | print("[Network %s] Total number of parameters : %.3f M" % (names[i], num_params / 1e6)) 80 | i=i+1 81 | print("-----------------------------------------------") 82 | -------------------------------------------------------------------------------- /source/ops.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | 7 | 8 | def get_norm_layer(norm_type="instance"): 9 | if 
norm_type == "batch": 10 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 11 | elif norm_type == "instance": 12 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 13 | else: 14 | raise NotImplementedError("normalization layer [%s] is not found" % norm_type) 15 | return norm_layer 16 | 17 | 18 | def init_weights(net, gain=0.02): 19 | def init_func(m): 20 | classname = m.__class__.__name__ 21 | if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): 22 | init.normal_(m.weight.data, 0.0, gain) 23 | if hasattr(m, "bias") and m.bias is not None: 24 | init.constant_(m.bias.data, 0.0) 25 | elif classname.find("BatchNorm2d") != -1: 26 | init.normal_(m.weight.data, 1.0, gain) 27 | init.constant_(m.bias.data, 0.0) 28 | 29 | print("Network initialized with weights sampled from N(0,0.02).") 30 | net.apply(init_func) 31 | 32 | 33 | def init_network(net, gpu_ids=[]): 34 | if len(gpu_ids) > 0: 35 | assert(torch.cuda.is_available()) 36 | net.cuda(gpu_ids[0]) 37 | net = torch.nn.DataParallel(net, gpu_ids) 38 | init_weights(net) 39 | return net 40 | 41 | 42 | def conv_norm_lrelu(in_dim, out_dim, kernel_size, stride = 1, padding=0, 43 | norm_layer = nn.BatchNorm2d, bias = False): 44 | return nn.Sequential( 45 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias), 46 | norm_layer(out_dim), nn.LeakyReLU(0.2,True)) 47 | 48 | 49 | def conv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, 50 | norm_layer = nn.BatchNorm2d, bias = False): 51 | return nn.Sequential( 52 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias), 53 | norm_layer(out_dim), nn.ReLU(True)) 54 | 55 | 56 | def dconv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, output_padding=0, 57 | norm_layer = nn.BatchNorm2d, bias = False): 58 | return nn.Sequential( 59 | nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, 60 | padding, output_padding, bias = bias), 61 | norm_layer(out_dim), nn.ReLU(True)) 62 | 63 | 64 | class ResidualBlock(nn.Module): 65 | def __init__(self, dim, norm_layer, use_dropout, use_bias): 66 | super(ResidualBlock, self).__init__() 67 | res_block = [nn.ReflectionPad2d(1), 68 | conv_norm_relu(dim, dim, kernel_size=3, 69 | norm_layer= norm_layer, bias=use_bias)] 70 | if use_dropout: 71 | res_block += [nn.Dropout(0.5)] 72 | res_block += [nn.ReflectionPad2d(1), 73 | nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias), 74 | norm_layer(dim)] 75 | 76 | self.res_block = nn.Sequential(*res_block) 77 | 78 | def forward(self, x): 79 | return x + self.res_block(x) 80 | 81 | 82 | def set_grad(nets, requires_grad=False): 83 | for net in nets: 84 | for param in net.parameters(): 85 | param.requires_grad = requires_grad 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # ignore folder 165 | data/ 166 | trained_model/ 167 | stratify/ 168 | log/ 169 | -------------------------------------------------------------------------------- /modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class Attention(nn.Module): 9 | def __init__(self, input_size, hidden_size, num_class, num_char_embeddings=256): 10 | super(Attention, self).__init__() 11 | self.attention_cell = AttentionCell( 12 | input_size, hidden_size, num_char_embeddings 13 | ) 14 | self.hidden_size = hidden_size 15 | self.num_class = num_class 16 | self.generator = nn.Linear(hidden_size, num_class) 17 | self.char_embeddings = nn.Embedding(num_class, num_char_embeddings) 18 | 19 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 20 | """ 21 | input: 22 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 23 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS]. 24 | output: probability distribution at each step [batch_size x num_steps x num_class] 25 | """ 26 | batch_size = batch_H.size(0) 27 | num_steps = batch_max_length + 1 # +1 for [EOS] at end of sentence. 28 | 29 | output_hiddens = ( 30 | torch.FloatTensor(batch_size, num_steps, self.hidden_size) 31 | .fill_(0) 32 | .to(device) 33 | ) 34 | hidden = ( 35 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 36 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 37 | ) 38 | 39 | if is_train: 40 | for i in range(num_steps): 41 | char_embeddings = self.char_embeddings(text[:, i]) 42 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1}) 43 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 44 | output_hiddens[:, i, :] = hidden[ 45 | 0 46 | ] # LSTM hidden index (0: hidden, 1: Cell) 47 | probs = self.generator(output_hiddens) 48 | 49 | else: 50 | targets = text[0].expand(batch_size) # should be fill with [SOS] token 51 | probs = ( 52 | torch.FloatTensor(batch_size, num_steps, self.num_class) 53 | .fill_(0) 54 | .to(device) 55 | ) 56 | 57 | for i in range(num_steps): 58 | char_embeddings = self.char_embeddings(targets) 59 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 60 | probs_step = self.generator(hidden[0]) 61 | probs[:, i, :] = probs_step 62 | _, next_input = probs_step.max(1) 63 | targets = next_input 64 | 65 | return probs # batch_size x num_steps x num_class 66 | 67 | 68 | class AttentionCell(nn.Module): 69 | def __init__(self, input_size, hidden_size, num_embeddings): 70 | super(AttentionCell, self).__init__() 71 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 72 | self.h2h = nn.Linear( 73 | hidden_size, hidden_size 74 | ) # either i2i or h2h should have bias 75 | self.score = nn.Linear(hidden_size, 1, bias=False) 76 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 77 | self.hidden_size = hidden_size 78 | 79 | def forward(self, prev_hidden, batch_H, char_embeddings): 80 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 81 | batch_H_proj = self.i2h(batch_H) 82 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 83 | e = self.score( 84 | torch.tanh(batch_H_proj + 
prev_hidden_proj) 85 | ) # batch_size x num_encoder_step * 1 86 | 87 | alpha = F.softmax(e, dim=1) 88 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze( 89 | 1 90 | ) # batch_size x num_channel 91 | concat_context = torch.cat( 92 | [context, char_embeddings], 1 93 | ) # batch_size x (num_channel + num_embedding) 94 | cur_hidden = self.rnn(concat_context, prev_hidden) 95 | return cur_hidden, alpha 96 | -------------------------------------------------------------------------------- /source/rand_aug.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | 4 | import numpy as np 5 | import PIL 6 | import PIL.ImageOps 7 | import PIL.ImageDraw 8 | import PIL.ImageEnhance 9 | from PIL import Image 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | PARAMETER_MAX = 10 14 | 15 | 16 | def AutoContrast(img, **kwarg): 17 | return PIL.ImageOps.autocontrast(img) 18 | 19 | 20 | def Brightness(img, v, max_v, bias=0): 21 | v = _float_parameter(v, max_v) + bias 22 | return PIL.ImageEnhance.Brightness(img).enhance(v) 23 | 24 | 25 | def Color(img, v, max_v, bias=0): 26 | v = _float_parameter(v, max_v) + bias 27 | return PIL.ImageEnhance.Color(img).enhance(v) 28 | 29 | 30 | def Contrast(img, v, max_v, bias=0): 31 | v = _float_parameter(v, max_v) + bias 32 | return PIL.ImageEnhance.Contrast(img).enhance(v) 33 | 34 | 35 | def Cutout(img, v, max_v, bias=0): 36 | if v == 0: 37 | return img 38 | v = _float_parameter(v, max_v) + bias 39 | v = int(v * min(img.size)) 40 | return CutoutAbs(img, v) 41 | 42 | 43 | def CutoutAbs(img, v, **kwarg): 44 | w, h = img.size 45 | x0 = np.random.uniform(0, w) 46 | y0 = np.random.uniform(0, h) 47 | x0 = int(max(0, x0 - v / 2.)) 48 | y0 = int(max(0, y0 - v / 2.)) 49 | x1 = int(min(w, x0 + v)) 50 | y1 = int(min(h, y0 + v)) 51 | xy = (x0, y0, x1, y1) 52 | # gray 53 | color = (127, 127, 127) 54 | img = img.copy() 55 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 56 | return img 57 | 58 | 59 | def Equalize(img, **kwarg): 60 | return PIL.ImageOps.equalize(img) 61 | 62 | 63 | def Posterize(img, v, max_v, bias=0): 64 | v = _int_parameter(v, max_v) + bias 65 | return PIL.ImageOps.posterize(img, v) 66 | 67 | 68 | def Rotate(img, v, max_v, bias=0): 69 | v = _int_parameter(v, max_v) + bias 70 | if random.random() < 0.5: 71 | v = -v 72 | return img.rotate(v) 73 | 74 | 75 | def ShearX(img, v, max_v, bias=0): 76 | v = _float_parameter(v, max_v) + bias 77 | if random.random() < 0.5: 78 | v = -v 79 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 80 | 81 | 82 | def ShearY(img, v, max_v, bias=0): 83 | v = _float_parameter(v, max_v) + bias 84 | if random.random() < 0.5: 85 | v = -v 86 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 87 | 88 | 89 | def Solarize(img, v, max_v, bias=0): 90 | v = _int_parameter(v, max_v) + bias 91 | return PIL.ImageOps.solarize(img, 256 - v) 92 | 93 | 94 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 95 | v = _int_parameter(v, max_v) + bias 96 | if random.random() < 0.5: 97 | v = -v 98 | img_np = np.array(img).astype(int) 99 | img_np = img_np + v 100 | img_np = np.clip(img_np, 0, 255) 101 | img_np = img_np.astype(np.uint8) 102 | img = Image.fromarray(img_np) 103 | return PIL.ImageOps.solarize(img, threshold) 104 | 105 | 106 | def TranslateX(img, v, max_v, bias=0): 107 | v = _float_parameter(v, max_v) + bias 108 | if random.random() < 0.5: 109 | v = -v 110 | v = int(v * img.size[0]) 111 | return img.transform(img.size, 
PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 112 | 113 | 114 | def TranslateY(img, v, max_v, bias=0): 115 | v = _float_parameter(v, max_v) + bias 116 | if random.random() < 0.5: 117 | v = -v 118 | v = int(v * img.size[1]) 119 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 120 | 121 | 122 | def _float_parameter(v, max_v): 123 | return float(v) * max_v / PARAMETER_MAX 124 | 125 | 126 | def _int_parameter(v, max_v): 127 | return int(v * max_v / PARAMETER_MAX) 128 | 129 | 130 | def spatial_augment_pool(): 131 | augs = [(ShearX, 0.2, 0.1), (ShearY, 0.2, 0.1), (TranslateX, 0.2, 0.1), 132 | (TranslateY, 0.2, 0.1), (Rotate, 10, 10)] 133 | return augs 134 | 135 | 136 | def channel_augment_pool(): 137 | augs = [ 138 | (AutoContrast, None, None), 139 | (Brightness, 1.8, 0.1), 140 | (Color, 1.8, 0.1), 141 | (Contrast, 1.8, 0.1), 142 | (Equalize, None, None), 143 | (Posterize, 4, 4), 144 | (Solarize, 256, 0), 145 | (SolarizeAdd, 110, 0), 146 | ] 147 | return augs 148 | 149 | 150 | class Augmentor(object): 151 | 152 | def __init__(self, n, m, augment_type): 153 | assert n >= 1 154 | assert 1 <= m <= 10 155 | self.n = n 156 | self.m = m 157 | assert augment_type in ["channel", "spatial" 158 | ], "not augment type name %s" % augment_type 159 | if augment_type == "spatial": 160 | self.augment_pool = spatial_augment_pool() 161 | elif augment_type == "channel": 162 | self.augment_pool = channel_augment_pool() 163 | 164 | def __call__(self, img, prob=0.5): 165 | ops = random.choices(self.augment_pool, k=self.n) 166 | for op, max_v, bias in ops: 167 | v = np.random.randint(1, self.m) 168 | if random.random() < prob: 169 | img = op(img, v=v, max_v=max_v, bias=bias) 170 | return img 171 | -------------------------------------------------------------------------------- /utils/converter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | 5 | 6 | class CTCLabelConverter(object): 7 | """ Convert between text-label and text-index """ 8 | 9 | def __init__(self, character): 10 | # character (str): set of the possible characters. 11 | list_special_token = [ 12 | "[PAD]", 13 | "[UNK]", 14 | " ", 15 | ] # [UNK] for unknown character, " " for space. 16 | list_character = list(character) 17 | dict_character = list_special_token + list_character 18 | 19 | self.dict = {} 20 | for i, char in enumerate(dict_character): 21 | # NOTE: 0 is reserved for "CTCblank" token required by CTCLoss, not same with space " ". 22 | # print(i, char) 23 | self.dict[char] = i + 1 24 | 25 | self.character = [ 26 | "[CTCblank]" 27 | ] + dict_character # dummy "[CTCblank]" token for CTCLoss (index 0). 28 | print(f"# of tokens and characters: {len(self.character)}") 29 | 30 | def encode(self, word_string, batch_max_length=25): 31 | """ Convert word_list (string) into word_index. 32 | input: 33 | word_string: word labels of each image. [batch_size] 34 | batch_max_length: max length of word in the batch. Default: 25 35 | 36 | output: 37 | word_index: word index list for CTCLoss. [batch_size, batch_max_length] 38 | word_length: length of each word. [batch_size] 39 | """ 40 | word_length = [len(word) for word in word_string] 41 | 42 | # the index used for padding (=[PAD]) would not affect the CTC loss calculation. 
43 | word_index = torch.LongTensor(len(word_string), batch_max_length).fill_( 44 | self.dict["[PAD]"] 45 | ) 46 | 47 | for i, word in enumerate(word_string): 48 | word = list(word) 49 | word_idx = [ 50 | self.dict[char] if char in self.dict else self.dict["[UNK]"] 51 | for char in word 52 | ] 53 | word_index[i][: len(word_idx)] = torch.LongTensor(word_idx) 54 | 55 | return (word_index.to(device), torch.IntTensor(word_length).to(device)) 56 | 57 | def decode(self, word_index, word_length): 58 | """ Convert word_index into word_string """ 59 | word_string = [] 60 | for idx, length in enumerate(word_length): 61 | word_idx = word_index[idx, :] 62 | 63 | char_list = [] 64 | for i in range(length): 65 | # removing repeated characters and blank. 66 | if word_idx[i] != 0 and not (i > 0 and word_idx[i - 1] == word_idx[i]): 67 | char_list.append(self.character[word_idx[i]]) 68 | 69 | word = "".join(char_list) 70 | word_string.append(word) 71 | return word_string 72 | 73 | 74 | class AttnLabelConverter(object): 75 | """ Convert between text-label and text-index """ 76 | 77 | def __init__(self, character): 78 | # character (str): set of the possible characters. 79 | # [SOS] (start-of-sentence token) and [EOS] (end-of-sentence token) for the attention decoder. 80 | list_special_token = [ 81 | "[PAD]", 82 | "[UNK]", 83 | "[SOS]", 84 | "[EOS]", 85 | " ", 86 | ] # [UNK] for unknown character, " " for space. 87 | list_character = list(character) 88 | self.character = list_special_token + list_character 89 | 90 | self.dict = {} 91 | for i, char in enumerate(self.character): 92 | # print(i, char) 93 | self.dict[char] = i 94 | 95 | print(f"# of tokens and characters: {len(self.character)}") 96 | 97 | def encode(self, word_string, batch_max_length=25): 98 | """ Convert word_list (string) into word_index. 99 | input: 100 | word_string: word labels of each image. [batch_size] 101 | batch_max_length: max length of word in the batch. Default: 25 102 | 103 | output: 104 | word_index : the input of attention decoder. [batch_size x (max_length+2)] +1 for [SOS] token and +1 for [EOS] token. 105 | word_length : the length of output of attention decoder, which count [EOS] token also. [batch_size] 106 | """ 107 | word_length = [ 108 | len(word) + 1 for word in word_string 109 | ] # +1 for [EOS] at end of sentence. 110 | batch_max_length += 1 111 | 112 | # additional batch_max_length + 1 for [SOS] at first step. 
113 | word_index = torch.LongTensor(len(word_string), batch_max_length + 1).fill_( 114 | self.dict["[PAD]"] 115 | ) 116 | word_index[:, 0] = self.dict["[SOS]"] 117 | 118 | for i, word in enumerate(word_string): 119 | word = list(word) 120 | word.append("[EOS]") 121 | word_idx = [ 122 | self.dict[char] if char in self.dict else self.dict["[UNK]"] 123 | for char in word 124 | ] 125 | word_index[i][1 : 1 + len(word_idx)] = torch.LongTensor( 126 | word_idx 127 | ) # word_index[:, 0] = [SOS] token 128 | 129 | return (word_index.to(device), torch.IntTensor(word_length).to(device)) 130 | 131 | def decode(self, word_index, word_length): 132 | """ Convert word_index into word_string """ 133 | word_string = [] 134 | for idx, length in enumerate(word_length): 135 | word_idx = word_index[idx, :length] 136 | word = "".join([self.character[i] for i in word_idx]) 137 | word_string.append(word) 138 | return word_string 139 | -------------------------------------------------------------------------------- /stage1_HDGE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import argparse 5 | 6 | import numpy as np 7 | from PIL import ImageFile 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | 12 | import utils.utils_HDGE as utils 13 | from utils.load_config import load_config 14 | 15 | from modules.discriminators import define_Dis 16 | 17 | import source.HDGE as md 18 | from source.stratify import DomainStratifying 19 | from source.dataset import hierarchical_dataset 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | ImageFile.LOAD_TRUNCATED_IMAGES = True 24 | torch.multiprocessing.set_sharing_strategy("file_system") 25 | 26 | 27 | def main(args): 28 | dashed_line = "-" * 80 29 | 30 | # to make directories for saving results and trained models 31 | args.saved_path = f"stratify/{args.method}/{args.beta}_beta" 32 | os.makedirs(f"{args.saved_path}/{args.num_subsets}_subsets/", exist_ok=True) 33 | 34 | str_ids = args.gpu_ids.split(",") 35 | args.gpu_ids = [] 36 | for str_id in str_ids: 37 | id = int(str_id) 38 | if id >= 0: 39 | args.gpu_ids.append(id) 40 | # print(not args.no_dropout) 41 | 42 | # training part 43 | if args.train: 44 | print(dashed_line) 45 | model = md.HDGE(args) 46 | model.train(args) 47 | 48 | # inference part 49 | print(dashed_line) 50 | print("Start Inference") 51 | 52 | # load target domain data (raw) 53 | print("Load target domain data for inference...") 54 | target_data_raw, target_data_log = hierarchical_dataset(args.target_data, args, mode = "raw") 55 | print(target_data_log, end="") 56 | 57 | try: 58 | select_data = list(np.load(args.select_data)) 59 | except: 60 | print("\n [*][WARNING] NO available select_data!") 61 | print(" [*][WARNING] You are using all target domain data!\n") 62 | select_data = list(range(len(target_data_raw))) 63 | 64 | print(dashed_line) 65 | 66 | dis_source = define_Dis(input_nc=3, ndf=args.ndf, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids) 67 | dis_target = define_Dis(input_nc=3, ndf=args.ndf, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids) 68 | 69 | utils.print_networks([dis_source,dis_target], ["Da","Db"]) 70 | 71 | try: 72 | ckpt = utils.load_checkpoint("%s/HDGE_gen_dis.ckpt" % (args.checkpoint_dir)) 73 | dis_source.load_state_dict(ckpt["Da"]) 74 | dis_target.load_state_dict(ckpt["Db"]) 75 | 76 | print(dashed_line) 77 | # Domain Stratifying (Harmonic Domain Gap Estimator - HDGE) 78 | HDGE = 
DomainStratifying(args, select_data) 79 | HDGE.stratify_HDGE(target_data_raw, dis_source, dis_target, args.beta) 80 | 81 | print("\nAll information is saved in " + f"{args.saved_path}/") 82 | print("The trained weights are saved at " + f"{args.checkpoint_dir}/HDGE_gen_dis.ckpt") 83 | 84 | except: 85 | print("\n [*][WARNING] STOP Domain Stratifying!") 86 | print(" [*][WARNING] NO checkpoint!") 87 | print(" [*][WARNING] Please train the model first!") 88 | print(" [*][WARNING] Please check the checkpoint directory!\n") 89 | raise ValueError("NO checkpoint!") 90 | 91 | print(dashed_line) 92 | return 93 | 94 | 95 | if __name__ == "__main__": 96 | """ Argument """ 97 | parser = argparse.ArgumentParser() 98 | config = load_config("config/HDGE.yaml") 99 | parser.set_defaults(**config) 100 | 101 | parser.add_argument( 102 | "--source_data", default="data/train/synth/", help="path to source domain data", 103 | ) 104 | parser.add_argument( 105 | "--target_data", default="data/train/real/", help="path to target domain data", 106 | ) 107 | parser.add_argument( 108 | "--select_data", 109 | required=True, 110 | help="path to select data", 111 | ) 112 | parser.add_argument( 113 | "--checkpoint_dir", type=str, default="stratify/HDGE", help="models are saved here", 114 | ) 115 | parser.add_argument( 116 | "--batch_size", type=int, default=16, help="input batch size", 117 | ) 118 | parser.add_argument( 119 | "--batch_size_val", type=int, default=128, help="input batch size val", 120 | ) 121 | parser.add_argument( 122 | "--epochs", type=int, default=20, help="number of epochs to train for", 123 | ) 124 | parser.add_argument( 125 | "--no_dropout", action="store_true", help="no dropout for the generator", 126 | ) 127 | parser.add_argument( 128 | "--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2", 129 | ) 130 | """ Adaptation """ 131 | parser.add_argument( 132 | "--num_subsets", 133 | type=int, 134 | required=True, 135 | help="hyper-parameter n, number of subsets partitioned from target domain data", 136 | ) 137 | parser.add_argument( 138 | "--beta", 139 | type=float, 140 | required=True, 141 | help="hyper-parameter beta in HDGE formula, 0 2 |

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

# Stratified Domain Adaptation: A Progressive Self-Training Approach for Scene Text Recognition

[📰 Paper] [🖼️ Poster] [📚 Slides]

WACV 2025 Early Acceptance (Round 1)
8 | WACV 2025 Logo 9 | 10 | 11 | ## Introduction 12 | This is the official PyTorch implementation of the [StrDA paper](https://openaccess.thecvf.com/content/WACV2025/html/Le_Stratified_Domain_Adaptation_A_Progressive_Self-Training_Approach_for_Scene_Text_WACV_2025_paper.html), which was accepted at the main conference of the ***IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) 2025***. 13 | 14 | In this paper, we propose the Stratified Domain Adaptation (StrDA) approach, a progressive self-training framework for scene text recognition. By leveraging the gradual escalation of the domain gap with the Harmonic Domain Gap Estimator ($\mathrm{HDGE}$), we propose partitioning the target domain into a sequence of ordered subsets to progressively reduce the domain gap between each and the source domain. Progressive self-training is then applied sequentially to these subsets. Extensive experiments on STR benchmarks demonstrate that our approach enables the baseline STR models to progressively adapt to the target domain. This approach significantly improves the performance of the baseline model without using any human-annotated data and shows its superior effectiveness compared to existing UDA methods for the scene text recognition task. 15 | 16 | * **Keywords:** scene text recognition (STR), unsupervised domain adaptation (UDA), self-training (ST), optical character recognition (OCR) 17 | 18 | ## News 🚀🚀🚀 19 | - `2025/03/06`: 📜 We have uploaded the instructions for running the code. 20 | - `2025/03/03`: 💻 We have released the implementation of StrDA for TRBA and CRNN. 21 | - `2025/02/28`: 🗣️ We attended the conference, you can view the poster and slides [here](WACV2025). 22 | - `2025/08/30`: 🔥 Our paper has been accepted to [WACV'25](https://wacv2025.thecvf.com/) (Algorithms Track). 23 | 24 | ## Getting Started 25 | 26 | ### Installation 27 | 1. python>=3.8.16 28 | 2. Install PyTorch-cuda>=11.3 following [official instruction](https://pytorch.org/): 29 | 30 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 31 | 32 | 3. Install the necessary dependencies by running (`!pip install -r requirements.txt`): 33 | 34 | pip install opencv-python==4.4.0.46 Pillow==7.2.0 opencv-python-headless==4.5.1.48 lmdb tqdm nltk six pyyaml 35 | 36 | * You can also create the environment using `docker build -t StrDA .` 37 | 38 | ### Datasets 39 | Thanks to [ku21fan/STR-Fewer-Labels](https://github.com/ku21fan/STR-Fewer-Labels/blob/main/data.md), [baudm/parseq](https://github.com/baudm/parseq/blob/main/Datasets.md), and [Mountchicken/Union14M](https://github.com/Mountchicken/Union14M) for compiling and organizing the data. I highly recommend that you follow their guidelines to download the datasets and review the license of each dataset. 40 | 41 | ## Running the code 42 | *Please pay attention to the warnings when running the code (e.g., select_data for target domain data, checkpoint of HDGE, and trained weights of DD).* 43 | 44 | ### Training 45 | 46 | - First, you need a **source-trained STR model**. If you don’t have one, you can use `supervised_learning.py` to train an STR model with **source domain data (synthetic)**. 47 | - Next, you need to **filter the data**, removing samples that are too long (width > 25 times height) and save them to `select_data.npy` (to be updated later). Since the model only processes a maximum of 25 characters per word, these long samples could be **harmful** during pseudo-labeling. 
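A minimal sketch of this filtering step is shown below. It is not part of the repository: the `build_select_data` helper and the assumption that each dataset item yields a PIL image first are illustrative only — adapt it to however you iterate the raw target-domain data.

```
import numpy as np

def build_select_data(dataset, out_path="select_data.npy", max_ratio=25):
    """Save indices of samples whose width <= max_ratio * height."""
    keep = []
    for idx in range(len(dataset)):
        image = dataset[idx][0]        # assumed: each item is (PIL image, ...)
        w, h = image.size
        if w <= max_ratio * h:         # drop overly wide crops (width > 25x height)
            keep.append(idx)
    np.save(out_path, np.array(keep, dtype=np.int64))
    return keep
```

`stage1_HDGE.py`, for example, loads this file via `np.load(args.select_data)` and falls back to using all target-domain samples if it cannot be read.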
48 | - Then, you will **run Stage 1** using one of the two methods. The files containing data information for each subset will be saved in `stratify/{args.method}/` as `.npy` files. **Please check them carefully!** 49 | - Finally, **run Stage 2** to perform adaptation on the **target domain data** to **boost model performance**. Then, test the results using a wide range of benchmarks. 50 | 51 | *Note: The target domain data must remain **unchanged** throughout the experiment.* 52 | 53 | #### Supervised Learning 54 | 55 | CUDA_VISIBLE_DEVICES=0 python supervised_learning.py --model TRBA --aug 56 | 57 | #### Stage 1 (Domain Stratifying) 58 | 59 | There are 2 main methods with many settings: 60 | 1. Harmonic Domain Gap Estimator ($\mathrm{HDGE}$) 61 | 62 | CUDA_VISIBLE_DEVICES=0 python stage1_HDGE.py --select_data select_data.npy --num_subsets 5 --beta 0.7 --train 63 | 64 | 2. Domain Discriminator ($\mathrm{DD}$) 65 | 66 | CUDA_VISIBLE_DEVICES=0 python stage1_DD.py --select_data select_data.npy --num_subsets 5 --discriminator CRNN --train --aug 67 | 68 | *Note: For both methods, you only need to activate `--train` to train the model the first time. After that, you can stratify the data without retraining.* 69 | 70 | #### Stage 2 (Progressive Self-Training) 71 | 72 | CUDA_VISIBLE_DEVICES=0 python stage2_StrDA.py --saved_model trained_model/TRBA.pth --model TRBA --num_subsets 5 --method HDGE --beta 0.7 --aug 73 | 74 | *Note: If the method is HDGE, you must enter `--beta`. If the method is DD, you must select a `--discriminator`. Example:* 75 | 76 | CUDA_VISIBLE_DEVICES=0 python stage2_StrDA.py --saved_model trained_model/CRNN.pth --model CRNN --num_subsets 5 --method DD --discriminator CRNN --aug 77 | 78 | ### Testing 79 | 80 | CUDA_VISIBLE_DEVICES=0 python test.py --saved_model trained_model/TRBA.pth --model TRBA 81 | 82 | **Broader insight:** You can try this method with different STR models, on various source-target domain pairs (e.g., synthetic-handwritten/art text) and even more complex domain gap problems like medical image segmentation. Additionally, you can replace self-training with more advanced UDA techniques. 83 | 84 | ## Reference 85 | If you find our work useful for your research, please cite it and give us a star⭐! 86 | ``` 87 | @inproceedings{le2025stratified, 88 | title={Stratified Domain Adaptation: A Progressive Self-Training Approach for Scene Text Recognition}, 89 | author={Le, Kha Nhat and Nguyen, Hoang-Tuan and Tran, Hung Tien and Ngo, Thanh Duc}, 90 | booktitle={2025 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 91 | pages={8990--9000}, 92 | year={2025}, 93 | organization={IEEE} 94 | } 95 | ``` 96 | 97 | ## Acknowledgements 98 | This code is based on [STR-Fewer-Labels](https://github.com/ku21fan/STR-Fewer-Labels) by [Jeonghun Baek](https://github.com/ku21fan) and [cycleGAN-PyTorch](https://github.com/arnab39/cycleGAN-PyTorch) by [Arnab Mondal 99 | ](https://github.com/arnab39). Thanks for your contributions! 
100 | -------------------------------------------------------------------------------- /source/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from modules.transformation import TPS_SpatialTransformerNetwork 4 | from modules.feature_extraction import ResNet_FeatureExtractor, VGG_FeatureExtractor 5 | from modules.sequence_modeling import BidirectionalLSTM 6 | from modules.prediction import Attention 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args): 11 | super(Model, self).__init__() 12 | self.args = args 13 | 14 | self.stages = { 15 | "Trans": args.Transformation, 16 | "Feat": args.FeatureExtraction, 17 | "Seq": args.SequenceModeling, 18 | "Pred": args.Prediction, 19 | } 20 | 21 | """ Transformation """ 22 | if args.Transformation == "TPS": 23 | self.Transformation = TPS_SpatialTransformerNetwork( 24 | F=args.num_fiducial, 25 | I_size=(args.imgH, args.imgW), 26 | I_r_size=(args.imgH, args.imgW), 27 | I_channel_num=args.input_channel, 28 | ) 29 | else: 30 | print("No Transformation module specified") 31 | 32 | """ FeatureExtraction """ 33 | if args.FeatureExtraction == "VGG": 34 | self.FeatureExtraction = VGG_FeatureExtractor( 35 | args.input_channel, args.output_channel 36 | ) 37 | elif args.FeatureExtraction == "ResNet": 38 | self.FeatureExtraction = ResNet_FeatureExtractor( 39 | args.input_channel, args.output_channel 40 | ) 41 | else: 42 | raise Exception("No FeatureExtraction module specified") 43 | 44 | self.FeatureExtraction_output = args.output_channel 45 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d( 46 | (None, 1) 47 | ) # Transform final (imgH/16-1) -> 1 48 | 49 | """ Sequence modeling """ 50 | if args.SequenceModeling == "BiLSTM": 51 | self.SequenceModeling = nn.Sequential( 52 | BidirectionalLSTM( 53 | self.FeatureExtraction_output, args.hidden_size, args.hidden_size 54 | ), 55 | BidirectionalLSTM( 56 | args.hidden_size, args.hidden_size, args.hidden_size 57 | ), 58 | ) 59 | self.SequenceModeling_output = args.hidden_size 60 | else: 61 | print("No SequenceModeling module specified") 62 | self.SequenceModeling_output = self.FeatureExtraction_output 63 | 64 | """ Prediction """ 65 | if args.Prediction == "CTC": 66 | self.Prediction = nn.Linear(self.SequenceModeling_output, args.num_class) 67 | elif args.Prediction == "Attn": 68 | self.Prediction = Attention( 69 | self.SequenceModeling_output, args.hidden_size, args.num_class 70 | ) 71 | else: 72 | raise Exception("Prediction is neither CTC or Attn") 73 | 74 | def forward(self, image, text=None, is_train=True): 75 | """ Transformation stage """ 76 | if not self.stages["Trans"] == "None": 77 | image = self.Transformation(image) 78 | 79 | """ Feature extraction stage """ 80 | visual_feature = self.FeatureExtraction(image) 81 | visual_feature = visual_feature.permute( 82 | 0, 3, 1, 2 83 | ) # [b, c, h, w] -> [b, w, c, h] 84 | visual_feature = self.AdaptiveAvgPool( 85 | visual_feature 86 | ) # [b, w, c, h] -> [b, w, c, 1] 87 | visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] 88 | 89 | """ Sequence modeling stage """ 90 | if self.stages["Seq"] == "BiLSTM": 91 | contextual_feature = self.SequenceModeling( 92 | visual_feature 93 | ) # [b, num_steps, args.hidden_size] 94 | else: 95 | contextual_feature = visual_feature # for convenience. 
this is NOT contextually modeled by BiLSTM 96 | 97 | """ Prediction stage """ 98 | if self.stages["Pred"] == "CTC": 99 | prediction = self.Prediction(contextual_feature.contiguous()) 100 | else: 101 | prediction = self.Prediction( 102 | contextual_feature.contiguous(), 103 | text, 104 | is_train, 105 | batch_max_length=self.args.batch_max_length, 106 | ) 107 | 108 | return prediction # [b, num_steps, args.num_class] 109 | 110 | 111 | class BaselineClassifier(nn.Module): 112 | """ Baseline model for discriminaton method """ 113 | 114 | def __init__(self, args): 115 | super(BaselineClassifier, self).__init__() 116 | self.args = args 117 | 118 | self.stages = { 119 | "Trans": args.Transformation, 120 | "Feat": args.FeatureExtraction, 121 | "Seq": args.SequenceModeling, 122 | "Pred": args.Prediction, 123 | } 124 | 125 | """ Transformation """ 126 | if args.Transformation == "TPS": 127 | self.Transformation = TPS_SpatialTransformerNetwork( 128 | F=args.num_fiducial, 129 | I_size=(args.imgH, args.imgW), 130 | I_r_size=(args.imgH, args.imgW), 131 | I_channel_num=args.input_channel, 132 | ) 133 | else: 134 | print("No Transformation module specified") 135 | 136 | """ FeatureExtraction """ 137 | if args.FeatureExtraction == "VGG": 138 | self.FeatureExtraction = VGG_FeatureExtractor( 139 | args.input_channel, args.output_channel 140 | ) 141 | elif args.FeatureExtraction == "ResNet": 142 | self.FeatureExtraction = ResNet_FeatureExtractor( 143 | args.input_channel, args.output_channel 144 | ) 145 | else: 146 | raise Exception("No FeatureExtraction module specified") 147 | 148 | self.FeatureExtraction_output = args.output_channel 149 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d( 150 | (None, 1) 151 | ) # Transform final (imgH/16-1) -> 1 152 | 153 | """ Binary classifier """ 154 | self.AdaptiveAvgPool_2 = nn.AdaptiveAvgPool2d((None, 1)) 155 | self.Classifier_input = self.FeatureExtraction_output 156 | self.predict = nn.Linear(self.Classifier_input, 1) 157 | 158 | def forward(self, image, extract_feature = False): 159 | """ Transformation stage """ 160 | if not self.stages["Trans"] == "None": 161 | image = self.Transformation(image) 162 | 163 | """ Feature extraction stage """ 164 | visual_feature = self.FeatureExtraction(image) 165 | visual_feature = visual_feature.permute( 166 | 0, 3, 1, 2 167 | ) # [b, c, h, w] -> [b, w, c, h] 168 | visual_feature = self.AdaptiveAvgPool( 169 | visual_feature 170 | ) # [b, w, c, h] -> [b, w, c, 1] 171 | visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] 172 | 173 | visual_feature = visual_feature.permute(0, 2, 1) # [b, w, c] -> [b, c, w] 174 | visual_feature = self.AdaptiveAvgPool_2( 175 | visual_feature 176 | ) # [b, c, w] -> [b, c, 1] 177 | visual_feature = visual_feature.squeeze(2) # [b, c, 1] -> [b, c] 178 | 179 | """ Binary classifier """ 180 | output = self.predict(visual_feature) # [b, c] -> [b, class] 181 | 182 | if extract_feature == True: 183 | return output, visual_feature 184 | 185 | return output 186 | -------------------------------------------------------------------------------- /source/stratify.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import Subset 8 | 9 | from .dataset import Pseudolabel_Dataset, AlignCollateHDGE, get_dataloader 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | class DomainStratifying(object): 
15 | def __init__( 16 | self, args, select_data 17 | ): 18 | """ 19 | Stage 1: Domain Stratifying for Stratified Domain Adaptation using 2 main methods: 20 | - Domain Discriminator (DD) 21 | - Harmonic Domain Gap Estimator (HDGE) 22 | Each method gives one sample a distance d_i. Then, we sort them in ascending order. 23 | 24 | Parameters 25 | ---------- 26 | args: argparse.ArgumentParser().parse_args() 27 | argument 28 | select_data: list() 29 | the array of selected data 30 | """ 31 | 32 | self.args = args 33 | self.method = args.method 34 | self.saved_path = args.saved_path # the path to save the result of stratifying method 35 | self.num_subsets = args.num_subsets # the number of subsets 36 | self.remain_data = select_data # the number of remain data after selection steps 37 | self.k_number = len(select_data) // self.num_subsets # the number of data point per subset 38 | 39 | def save_subset(self, result): 40 | 41 | # sort result in ascending order 42 | distance = sorted(result, key=lambda x: x[1]) 43 | 44 | print("\n5-lowest distance:") 45 | print(distance[:5]) 46 | 47 | print("\n5-highest distance:") 48 | print(distance[-5:]) 49 | 50 | result_index = [u[0] for u in distance] 51 | result_distance = [u[1] for u in distance] 52 | 53 | if self.method == "DD": 54 | np.save(f"{self.saved_path}/{self.method}_{self.args.discriminator}_distance.npy", result_distance) 55 | np.save(f"{self.saved_path}/{self.method}_{self.args.discriminator}_index.npy", result_index) 56 | else: 57 | np.save(f"{self.saved_path}/{self.method}_{self.args.beta}_distance.npy", result_distance) 58 | np.save(f"{self.saved_path}/{self.method}_{self.args.beta}_index.npy", result_index) 59 | 60 | for iter in range(self.num_subsets // 2): 61 | # select k_number lowest distance 62 | add_source = [u for u in result_index[:self.k_number]] 63 | # select k_number highest distance 64 | add_target = [u for u in result_index[-self.k_number:]] 65 | # adjust result 66 | result_index = np.setdiff1d(result_index, add_source + add_target) 67 | 68 | # save work 69 | source = np.array(add_source, dtype=np.int32) 70 | target = np.array(add_target, dtype=np.int32) 71 | 72 | np.save(f"{self.saved_path}/{self.num_subsets}_subsets/subset_{iter + 1}.npy", source) 73 | np.save(f"{self.saved_path}/{self.num_subsets}_subsets/subset_{self.num_subsets - iter}.npy", target) 74 | 75 | if (self.num_subsets % 2 != 0): 76 | result_index = np.array(result_index, dtype=np.int32) 77 | np.save(f"{self.saved_path}/{self.num_subsets}_subsets/subset_{self.num_subsets // 2 + 1}.npy", result_index) 78 | 79 | def stratify_DD(self, adapt_data_raw, model): 80 | """ 81 | Select data point for each subset and save them 82 | 83 | Parameters 84 | ---------- 85 | adapt_data_raw: torch.utils.data.Dataset 86 | adapt data 87 | model: Model 88 | discriminator module for stratifying 89 | 90 | Return 91 | ---------- 92 | """ 93 | 94 | print("Start Domain Stratifying (Domain Discriminator - DD)...\n") 95 | 96 | unlabel_data_remain = Subset(adapt_data_raw, self.remain_data) 97 | 98 | # assign pseudo labels by the order of sample in dataset 99 | unlabel_data_remain = Pseudolabel_Dataset(unlabel_data_remain, self.remain_data) 100 | adapt_data_loader = get_dataloader(self.args, unlabel_data_remain, self.args.batch_size_val, shuffle=False) 101 | 102 | del unlabel_data_remain 103 | 104 | result = [] 105 | model.eval() 106 | with torch.no_grad(): 107 | for batch in tqdm(adapt_data_loader): 108 | image_tensors, index_unlabel = batch 109 | image = image_tensors.to(device) 110 | 111 | 
preds = model(image) 112 | preds_prob = F.sigmoid(preds).detach().cpu().squeeze().numpy().tolist() 113 | 114 | result.extend(list(zip(index_unlabel, preds_prob))) 115 | 116 | # sort result in ascending order 117 | result = sorted(result, key=lambda x: x[1]) 118 | 119 | self.save_subset(result) 120 | 121 | def stratify_HDGE(self, adapt_data_raw, dis_source, dis_target, beta): 122 | """ 123 | Select data point for each subset and save them 124 | 125 | Parameters 126 | ---------- 127 | adapt_data_raw: torch.utils.data.Dataset 128 | adapt data 129 | dis_source: Model 130 | discriminator of source module 131 | dis_target: Model 132 | discriminator of target module 133 | beta: float 134 | hyperparameter for HDGE method (default: 1) 135 | 136 | Return 137 | ---------- 138 | """ 139 | 140 | print("Start Domain Stratifying (Harmonic Domain Gap Estimator - HDGE)...\n") 141 | 142 | unlabel_data_remain = Subset(adapt_data_raw, self.remain_data) 143 | 144 | myAlignCollate = AlignCollateHDGE(self.args, infer=True) 145 | adapt_data_loader = torch.utils.data.DataLoader( 146 | unlabel_data_remain, 147 | batch_size=self.args.batch_size_val, 148 | shuffle=False, 149 | num_workers=self.args.workers, 150 | collate_fn=myAlignCollate, 151 | pin_memory=False, 152 | drop_last=False, 153 | ) 154 | 155 | del unlabel_data_remain 156 | 157 | dis_source = dis_source.to(device) 158 | dis_target = dis_target.to(device) 159 | 160 | source_loss = [] 161 | target_loss = [] 162 | 163 | dis_source.eval() 164 | dis_target.eval() 165 | with torch.no_grad(): 166 | for batch in tqdm(adapt_data_loader): 167 | image_tensors = batch 168 | image = image_tensors.to(device) 169 | 170 | source_dis = dis_source(image) 171 | target_dis = dis_target(image) 172 | 173 | real_label = torch.ones(source_dis.size()).to(device) 174 | 175 | # calculate MSE for each sample 176 | source_batch_loss = torch.mean((source_dis - real_label)**2, dim=(1,2,3)).cpu().squeeze().numpy().tolist() 177 | target_batch_loss = torch.mean((target_dis - real_label)**2, dim=(1,2,3)).cpu().squeeze().numpy().tolist() 178 | 179 | source_loss.extend(source_batch_loss) 180 | target_loss.extend(target_batch_loss) 181 | 182 | np.save(f"{self.saved_path}/source_loss.npy", source_loss) 183 | np.save(f"{self.saved_path}/target_loss.npy", target_loss) 184 | 185 | # calculate di 186 | def formula(source_loss, target_loss, beta=1): 187 | return(1 + (beta)**2)*source_loss*target_loss / ((beta**2)*source_loss + target_loss) 188 | 189 | distance = [formula(s_loss, t_loss, beta) for s_loss, t_loss in zip(source_loss, target_loss)] 190 | 191 | result = [[index, distance] for index, distance in zip(self.remain_data, distance)] 192 | 193 | self.save_subset(result) 194 | -------------------------------------------------------------------------------- /modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class VGG_FeatureExtractor(nn.Module): 5 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 6 | 7 | def __init__(self, input_channel, output_channel=512): 8 | super(VGG_FeatureExtractor, self).__init__() 9 | self.output_channel = [ 10 | int(output_channel / 8), 11 | int(output_channel / 4), 12 | int(output_channel / 2), 13 | output_channel, 14 | ] # [64, 128, 256, 512] 15 | self.ConvNet = nn.Sequential( 16 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), 17 | nn.ReLU(True), 18 | nn.MaxPool2d(2, 2), # 64x16x50 19 | nn.Conv2d(self.output_channel[0], 
self.output_channel[1], 3, 1, 1), 20 | nn.ReLU(True), 21 | nn.MaxPool2d(2, 2), # 128x8x25 22 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), 23 | nn.ReLU(True), # 256x8x25 24 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), 25 | nn.ReLU(True), 26 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 27 | nn.Conv2d( 28 | self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False 29 | ), 30 | nn.BatchNorm2d(self.output_channel[3]), 31 | nn.ReLU(True), # 512x4x25 32 | nn.Conv2d( 33 | self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False 34 | ), 35 | nn.BatchNorm2d(self.output_channel[3]), 36 | nn.ReLU(True), 37 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 38 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), 39 | nn.ReLU(True), 40 | ) # 512x1x24 41 | 42 | def forward(self, input): 43 | return self.ConvNet(input) 44 | 45 | 46 | class ResNet_FeatureExtractor(nn.Module): 47 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 48 | 49 | def __init__(self, input_channel, output_channel=512): 50 | super(ResNet_FeatureExtractor, self).__init__() 51 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 52 | 53 | def forward(self, input): 54 | return self.ConvNet(input) 55 | 56 | 57 | class BasicBlock(nn.Module): 58 | expansion = 1 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(BasicBlock, self).__init__() 62 | self.conv1 = self._conv3x3(inplanes, planes) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = self._conv3x3(planes, planes) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def _conv3x3(self, in_planes, out_planes, stride=1): 71 | "3x3 convolution with padding" 72 | return nn.Conv2d( 73 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 74 | ) 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | def __init__(self, input_channel, output_channel, block, layers): 96 | super(ResNet, self).__init__() 97 | 98 | self.output_channel_block = [ 99 | int(output_channel / 4), 100 | int(output_channel / 2), 101 | output_channel, 102 | output_channel, 103 | ] 104 | 105 | self.inplanes = int(output_channel / 8) 106 | self.conv0_1 = nn.Conv2d( 107 | input_channel, 108 | int(output_channel / 16), 109 | kernel_size=3, 110 | stride=1, 111 | padding=1, 112 | bias=False, 113 | ) 114 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 115 | self.conv0_2 = nn.Conv2d( 116 | int(output_channel / 16), 117 | self.inplanes, 118 | kernel_size=3, 119 | stride=1, 120 | padding=1, 121 | bias=False, 122 | ) 123 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 124 | self.relu = nn.ReLU(inplace=True) 125 | 126 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 127 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 128 | self.conv1 = nn.Conv2d( 129 | self.output_channel_block[0], 130 | self.output_channel_block[0], 131 | kernel_size=3, 132 | stride=1, 133 | padding=1, 134 | bias=False, 135 | ) 136 | self.bn1 = 
nn.BatchNorm2d(self.output_channel_block[0]) 137 | 138 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 139 | self.layer2 = self._make_layer( 140 | block, self.output_channel_block[1], layers[1], stride=1 141 | ) 142 | self.conv2 = nn.Conv2d( 143 | self.output_channel_block[1], 144 | self.output_channel_block[1], 145 | kernel_size=3, 146 | stride=1, 147 | padding=1, 148 | bias=False, 149 | ) 150 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 151 | 152 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 153 | self.layer3 = self._make_layer( 154 | block, self.output_channel_block[2], layers[2], stride=1 155 | ) 156 | self.conv3 = nn.Conv2d( 157 | self.output_channel_block[2], 158 | self.output_channel_block[2], 159 | kernel_size=3, 160 | stride=1, 161 | padding=1, 162 | bias=False, 163 | ) 164 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 165 | 166 | self.layer4 = self._make_layer( 167 | block, self.output_channel_block[3], layers[3], stride=1 168 | ) 169 | self.conv4_1 = nn.Conv2d( 170 | self.output_channel_block[3], 171 | self.output_channel_block[3], 172 | kernel_size=2, 173 | stride=(2, 1), 174 | padding=(0, 1), 175 | bias=False, 176 | ) 177 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 178 | self.conv4_2 = nn.Conv2d( 179 | self.output_channel_block[3], 180 | self.output_channel_block[3], 181 | kernel_size=2, 182 | stride=1, 183 | padding=0, 184 | bias=False, 185 | ) 186 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 187 | 188 | def _make_layer(self, block, planes, blocks, stride=1): 189 | downsample = None 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = nn.Sequential( 192 | nn.Conv2d( 193 | self.inplanes, 194 | planes * block.expansion, 195 | kernel_size=1, 196 | stride=stride, 197 | bias=False, 198 | ), 199 | nn.BatchNorm2d(planes * block.expansion), 200 | ) 201 | 202 | layers = [] 203 | layers.append(block(self.inplanes, planes, stride, downsample)) 204 | self.inplanes = planes * block.expansion 205 | for i in range(1, blocks): 206 | layers.append(block(self.inplanes, planes)) 207 | 208 | return nn.Sequential(*layers) 209 | 210 | def forward(self, x): 211 | x = self.conv0_1(x) 212 | x = self.bn0_1(x) 213 | x = self.relu(x) 214 | x = self.conv0_2(x) 215 | x = self.bn0_2(x) 216 | x = self.relu(x) 217 | 218 | x = self.maxpool1(x) 219 | x = self.layer1(x) 220 | x = self.conv1(x) 221 | x = self.bn1(x) 222 | x = self.relu(x) 223 | 224 | x = self.maxpool2(x) 225 | x = self.layer2(x) 226 | x = self.conv2(x) 227 | x = self.bn2(x) 228 | x = self.relu(x) 229 | 230 | x = self.maxpool3(x) 231 | x = self.layer3(x) 232 | x = self.conv3(x) 233 | x = self.bn3(x) 234 | x = self.relu(x) 235 | 236 | x = self.layer4(x) 237 | x = self.conv4_1(x) 238 | x = self.bn4_1(x) 239 | x = self.relu(x) 240 | x = self.conv4_2(x) 241 | x = self.bn4_2(x) 242 | x = self.relu(x) 243 | 244 | return x 245 | -------------------------------------------------------------------------------- /source/HDGE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import itertools 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch.autograd import Variable 10 | from torch.utils.data import Subset 11 | 12 | from .ops import set_grad 13 | from .dataset import AlignCollateHDGE, hierarchical_dataset 14 | 15 | import utils.utils_HDGE as utils 16 | 17 | from modules.generators import define_Gen 18 | from 
modules.discriminators import define_Dis 19 | 20 | 21 | class HDGE(object): 22 | def __init__(self,args): 23 | 24 | # define the network 25 | self.Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, norm=args.norm, 26 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 27 | self.Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, norm=args.norm, 28 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 29 | self.Da = define_Dis(input_nc=3, ndf=args.ndf, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids) 30 | self.Db = define_Dis(input_nc=3, ndf=args.ndf, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids) 31 | 32 | utils.print_networks([self.Gab,self.Gba,self.Da,self.Db], ["Gab","Gba","Da","Db"]) 33 | 34 | # define Loss criterias 35 | self.MSE = nn.MSELoss() 36 | self.L1 = nn.L1Loss() 37 | 38 | # optimizers 39 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.Gab.parameters(),self.Gba.parameters()), lr=args.lr, betas=(0.5, 0.999)) 40 | self.d_optimizer = torch.optim.Adam(itertools.chain(self.Da.parameters(),self.Db.parameters()), lr=args.lr, betas=(0.5, 0.999)) 41 | 42 | self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step) 43 | self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step) 44 | 45 | # to make directories for saving checkpoints 46 | os.makedirs(args.checkpoint_dir, exist_ok=True) 47 | 48 | # try loading checkpoint 49 | try: 50 | ckpt = utils.load_checkpoint("%s/HDGE_gen_dis.ckpt" % (args.checkpoint_dir)) 51 | self.start_epoch = ckpt["epoch"] 52 | self.Da.load_state_dict(ckpt["Da"]) 53 | self.Db.load_state_dict(ckpt["Db"]) 54 | self.Gab.load_state_dict(ckpt["Gab"]) 55 | self.Gba.load_state_dict(ckpt["Gba"]) 56 | self.d_optimizer.load_state_dict(ckpt["d_optimizer"]) 57 | self.g_optimizer.load_state_dict(ckpt["g_optimizer"]) 58 | except: 59 | print(" [*] No checkpoint!") 60 | self.start_epoch = 0 61 | 62 | def train(self,args): 63 | dashed_line = "-" * 80 64 | 65 | # load source domain data (raw) 66 | print(dashed_line) 67 | print("Load source domain data...") 68 | source_data, source_data_log = hierarchical_dataset(args.source_data, args, mode = "raw") 69 | print(source_data_log, end="") 70 | 71 | # load target domain data (raw) 72 | print(dashed_line) 73 | print("Load target domain data...") 74 | target_data, target_data_log = hierarchical_dataset(args.target_data, args, mode = "raw") 75 | print(target_data_log, end="") 76 | 77 | try: 78 | select_data = list(np.load(args.select_data)) 79 | except: 80 | print("\n [*][WARNING] NO available select_data!") 81 | print(" [*][WARNING] You are using all target domain data!\n") 82 | select_data = list(range(len(target_data))) 83 | 84 | target_data_adjust = Subset(target_data, select_data) 85 | 86 | myAlignCollate = AlignCollateHDGE(args) 87 | 88 | a_loader = torch.utils.data.DataLoader( 89 | source_data, 90 | batch_size=args.batch_size, 91 | shuffle=True, 92 | num_workers=args.workers, 93 | collate_fn=myAlignCollate, 94 | pin_memory=False, 95 | drop_last=True, 96 | ) 97 | b_loader = torch.utils.data.DataLoader( 98 | target_data_adjust, 99 | batch_size=args.batch_size, 100 | shuffle=True, 101 | num_workers=args.workers, 102 | collate_fn=myAlignCollate, 103 | pin_memory=False, 104 | drop_last=True, 105 | ) 106 | 107 | a_fake_sample = utils.Sample_from_Pool() 108 | b_fake_sample = utils.Sample_from_Pool() 109 | 110 | a_loader_iter = iter(a_loader) 111 
| 112 | print(dashed_line) 113 | print("Start Training Harmonic Domain Gap Estimator (HDGE)...\n") 114 | 115 | for epoch in range(self.start_epoch, args.epochs): 116 | 117 | for b_real in tqdm(b_loader): 118 | 119 | try: 120 | a_real = next(a_loader_iter) 121 | except StopIteration: 122 | del a_loader_iter 123 | a_loader_iter = iter(a_loader) 124 | a_real = next(a_loader_iter) 125 | 126 | # generator Computations 127 | set_grad([self.Da, self.Db], False) 128 | self.g_optimizer.zero_grad() 129 | 130 | a_real = Variable(a_real) 131 | b_real = Variable(b_real) 132 | a_real, b_real = utils.cuda([a_real, b_real]) 133 | 134 | # forward pass through generators 135 | a_fake = self.Gab(b_real) 136 | b_fake = self.Gba(a_real) 137 | 138 | a_recon = self.Gab(b_fake) 139 | b_recon = self.Gba(a_fake) 140 | 141 | a_idt = self.Gab(a_real) 142 | b_idt = self.Gba(b_real) 143 | 144 | # identity losses 145 | a_idt_loss = self.L1(a_idt, a_real) * args.lamda * args.idt_coef 146 | b_idt_loss = self.L1(b_idt, b_real) * args.lamda * args.idt_coef 147 | 148 | # adversarial losses 149 | a_fake_dis = self.Da(a_fake) 150 | b_fake_dis = self.Db(b_fake) 151 | 152 | real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size()))) 153 | 154 | a_gen_loss = self.MSE(a_fake_dis, real_label) 155 | b_gen_loss = self.MSE(b_fake_dis, real_label) 156 | 157 | # cycle consistency losses 158 | a_cycle_loss = self.L1(a_recon, a_real) * args.lamda 159 | b_cycle_loss = self.L1(b_recon, b_real) * args.lamda 160 | 161 | # total generators losses 162 | gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss 163 | 164 | # update generators 165 | gen_loss.backward() 166 | self.g_optimizer.step() 167 | 168 | # discriminator Computations 169 | set_grad([self.Da, self.Db], True) 170 | self.d_optimizer.zero_grad() 171 | 172 | # sample from history of generated images 173 | a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0])) 174 | b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0])) 175 | a_fake, b_fake = utils.cuda([a_fake, b_fake]) 176 | 177 | # forward pass through discriminators 178 | a_real_dis = self.Da(a_real) 179 | a_fake_dis = self.Da(a_fake) 180 | b_real_dis = self.Db(b_real) 181 | b_fake_dis = self.Db(b_fake) 182 | real_label = utils.cuda(Variable(torch.ones(a_real_dis.size()))) 183 | fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size()))) 184 | 185 | # discriminator losses 186 | a_dis_real_loss = self.MSE(a_real_dis, real_label) 187 | a_dis_fake_loss = self.MSE(a_fake_dis, fake_label) 188 | b_dis_real_loss = self.MSE(b_real_dis, real_label) 189 | b_dis_fake_loss = self.MSE(b_fake_dis, fake_label) 190 | 191 | # total discriminators losses 192 | a_dis_loss = (a_dis_real_loss + a_dis_fake_loss)*0.5 193 | b_dis_loss = (b_dis_real_loss + b_dis_fake_loss)*0.5 194 | 195 | # update discriminators 196 | a_dis_loss.backward() 197 | b_dis_loss.backward() 198 | self.d_optimizer.step() 199 | 200 | print(f"\nEpoch ({epoch+1}/{args.epochs}) | Gen Loss: %0.4f | Dis Loss: %0.4f\n" % (gen_loss,a_dis_loss+b_dis_loss)) 201 | 202 | # override the latest checkpoint 203 | utils.save_checkpoint({"epoch": epoch + 1, 204 | "Da": self.Da.state_dict(), 205 | "Db": self.Db.state_dict(), 206 | "Gab": self.Gab.state_dict(), 207 | "Gba": self.Gba.state_dict(), 208 | "d_optimizer": self.d_optimizer.state_dict(), 209 | "g_optimizer": self.g_optimizer.state_dict()}, 210 | "%s/HDGE_gen_dis.ckpt" % (args.checkpoint_dir)) 211 | 212 | # update learning rates 213 | 
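            # both schedulers are stepped once per epoch; per config/HDGE.yaml the learning rate
            # starts decaying at decay_epoch (the exact schedule is defined by utils_HDGE.LambdaLR)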
self.g_lr_scheduler.step() 214 | self.d_lr_scheduler.step() 215 | -------------------------------------------------------------------------------- /modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | class TPS_SpatialTransformerNetwork(nn.Module): 11 | """ Rectification Network of RARE, namely TPS based STN """ 12 | 13 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 14 | """ Based on RARE TPS 15 | input: 16 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 17 | I_size : (height, width) of the input image I 18 | I_r_size : (height, width) of the rectified image I_r 19 | I_channel_num : the number of channels of the input image I 20 | output: 21 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 22 | """ 23 | super(TPS_SpatialTransformerNetwork, self).__init__() 24 | self.F = F 25 | self.I_size = I_size 26 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 27 | self.I_channel_num = I_channel_num 28 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 29 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 30 | 31 | def forward(self, batch_I): 32 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 33 | # batch_size x n (= I_r_width x I_r_height) x 2 34 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) 35 | build_P_prime_reshape = build_P_prime.reshape( 36 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2] 37 | ) 38 | 39 | if torch.__version__ > "1.2.0": 40 | batch_I_r = F.grid_sample( 41 | batch_I, 42 | build_P_prime_reshape, 43 | padding_mode="border", 44 | align_corners=True, 45 | ) 46 | else: 47 | batch_I_r = F.grid_sample( 48 | batch_I, build_P_prime_reshape, padding_mode="border" 49 | ) 50 | 51 | return batch_I_r 52 | 53 | 54 | class LocalizationNetwork(nn.Module): 55 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 56 | 57 | def __init__(self, F, I_channel_num): 58 | super(LocalizationNetwork, self).__init__() 59 | self.F = F 60 | self.I_channel_num = I_channel_num 61 | self.conv = nn.Sequential( 62 | nn.Conv2d( 63 | in_channels=self.I_channel_num, 64 | out_channels=64, 65 | kernel_size=3, 66 | stride=1, 67 | padding=1, 68 | bias=False, 69 | ), 70 | nn.BatchNorm2d(64), 71 | nn.ReLU(True), 72 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 73 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), 74 | nn.BatchNorm2d(128), 75 | nn.ReLU(True), 76 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 77 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), 78 | nn.BatchNorm2d(256), 79 | nn.ReLU(True), 80 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 81 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), 82 | nn.BatchNorm2d(512), 83 | nn.ReLU(True), 84 | nn.AdaptiveAvgPool2d(1), # batch_size x 512 85 | ) 86 | 87 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 88 | self.localization_fc2 = nn.Linear(256, self.F * 2) 89 | 90 | # init fc2 in LocalizationNetwork 91 | self.localization_fc2.weight.data.fill_(0) 92 | """ see RARE paper Fig. 
6 (a) """ 93 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 94 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 95 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 96 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 97 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 98 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 99 | self.localization_fc2.bias.data = ( 100 | torch.from_numpy(initial_bias).float().view(-1) 101 | ) 102 | 103 | def forward(self, batch_I): 104 | """ 105 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 106 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 107 | """ 108 | batch_size = batch_I.size(0) 109 | features = self.conv(batch_I).view(batch_size, -1) 110 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view( 111 | batch_size, self.F, 2 112 | ) 113 | return batch_C_prime 114 | 115 | 116 | class GridGenerator(nn.Module): 117 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 118 | 119 | def __init__(self, F, I_r_size): 120 | """ Generate P_hat and inv_delta_C for later """ 121 | super(GridGenerator, self).__init__() 122 | self.eps = 1e-6 123 | self.I_r_height, self.I_r_width = I_r_size 124 | self.F = F 125 | self.C = self._build_C(self.F) # F x 2 126 | self.P = self._build_P(self.I_r_width, self.I_r_height) 127 | 128 | num_gpu = torch.cuda.device_count() 129 | if num_gpu > 1: 130 | # for multi-gpu, you may need register buffer 131 | self.register_buffer( 132 | "inv_delta_C", 133 | torch.tensor(self._build_inv_delta_C(self.F, self.C)).float(), 134 | ) # F+3 x F+3 135 | self.register_buffer( 136 | "P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() 137 | ) # n x F+3 138 | else: 139 | # for fine-tuning with different image width, you may use below instead of self.register_buffer 140 | self.inv_delta_C = ( 141 | torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().to(device) 142 | ) # F+3 x F+3 143 | self.P_hat = ( 144 | torch.tensor(self._build_P_hat(self.F, self.C, self.P)) 145 | .float() 146 | .to(device) 147 | ) # n x F+3 148 | 149 | def _build_C(self, F): 150 | """ Return coordinates of fiducial points in I_r; C """ 151 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 152 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 153 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 154 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 155 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 156 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 157 | return C # F x 2 158 | 159 | def _build_inv_delta_C(self, F, C): 160 | """ Return inv_delta_C which is needed to calculate T """ 161 | hat_C = np.zeros((F, F), dtype=float) # F x F 162 | for i in range(0, F): 163 | for j in range(i, F): 164 | r = np.linalg.norm(C[i] - C[j]) 165 | hat_C[i, j] = r 166 | hat_C[j, i] = r 167 | np.fill_diagonal(hat_C, 1) 168 | hat_C = (hat_C ** 2) * np.log(hat_C) 169 | # print(C.shape, hat_C.shape) 170 | delta_C = np.concatenate( # F+3 x F+3 171 | [ 172 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 173 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 174 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1), # 1 x F+3 175 | ], 176 | axis=0, 177 | ) 178 | inv_delta_C = np.linalg.inv(delta_C) 179 | return inv_delta_C # F+3 x F+3 180 | 181 | def _build_P(self, I_r_width, 
I_r_height): 182 | I_r_grid_x = ( 183 | np.arange(-I_r_width, I_r_width, 2) + 1.0 184 | ) / I_r_width # self.I_r_width 185 | I_r_grid_y = ( 186 | np.arange(-I_r_height, I_r_height, 2) + 1.0 187 | ) / I_r_height # self.I_r_height 188 | P = np.stack( # self.I_r_width x self.I_r_height x 2 189 | np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2 190 | ) 191 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 192 | 193 | def _build_P_hat(self, F, C, P): 194 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 195 | P_tile = np.tile( 196 | np.expand_dims(P, axis=1), (1, F, 1) 197 | ) # n x 2 -> n x 1 x 2 -> n x F x 2 198 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 199 | P_diff = P_tile - C_tile # n x F x 2 200 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 201 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 202 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 203 | return P_hat # n x F+3 204 | 205 | def build_P_prime(self, batch_C_prime): 206 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 207 | batch_size = batch_C_prime.size(0) 208 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 209 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 210 | batch_C_prime_with_zeros = torch.cat( 211 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1 212 | ) # batch_size x F+3 x 2 213 | batch_T = torch.bmm( 214 | batch_inv_delta_C, batch_C_prime_with_zeros 215 | ) # batch_size x F+3 x 2 216 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 217 | return batch_P_prime # batch_size x n x 2 218 | -------------------------------------------------------------------------------- /stage1_DD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | import numpy as np 8 | from PIL import ImageFile 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | from torch.utils.data import Subset 13 | 14 | from utils.averager import Averager 15 | from utils.criterion import FocalLoss 16 | from utils.load_config import load_config 17 | 18 | from source.model import BaselineClassifier 19 | from source.stratify import DomainStratifying 20 | from source.dataset import Pseudolabel_Dataset, hierarchical_dataset, get_dataloader 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | torch.multiprocessing.set_sharing_strategy("file_system") 26 | 27 | 28 | def main(args): 29 | dashed_line = "-" * 80 30 | 31 | # to make directories for saving models and files if not exist 32 | args.saved_path = f"stratify/{args.method}/{args.discriminator}" 33 | os.makedirs(f"{args.saved_path}/{args.num_subsets}_subsets/", exist_ok=True) 34 | 35 | print(dashed_line) 36 | # load source domain data (raw) 37 | print("Load source domain data...") 38 | source_data_raw, source_data_log = hierarchical_dataset(args.source_data, args, mode="raw") 39 | source_data = Pseudolabel_Dataset(source_data_raw, np.full(len(source_data_raw), 0)) 40 | print(source_data_log, end="") 41 | 42 | print(dashed_line) 43 | # load target domain data (raw) 44 | print("Load target domain data...") 45 | target_data_raw, target_data_log = hierarchical_dataset(args.target_data, args, mode="raw") 46 | print(target_data_log, end="") 47 | 48 | try: 49 | select_data = list(np.load(args.select_data)) 50 | except: 
51 | print("\n [*][WARNING] NO available select_data!") 52 | print(" [*][WARNING] You are using all target domain data!\n") 53 | select_data = list(range(len(target_data_raw))) 54 | 55 | print(dashed_line) 56 | 57 | # setup model 58 | print("Init model") 59 | model = BaselineClassifier(args) 60 | 61 | # load pretrained model (baseline) 62 | pretrained_state_dict = torch.load(args.saved_model) 63 | print(f"Load pretrained model from {args.saved_model}") 64 | 65 | try: 66 | model.load_state_dict(pretrained_state_dict) 67 | except: 68 | print("\n [*][WARNING] The pre-trained weights do not match the model! Carefully check!\n") 69 | state_dict = model.state_dict() 70 | for key in list(state_dict.keys()): 71 | if (("module." + key) in pretrained_state_dict.keys()): 72 | state_dict[key] = pretrained_state_dict["module." + key].data 73 | # else: 74 | # print(key) 75 | model.load_state_dict(state_dict) 76 | 77 | model = model.to(device) 78 | # print(model.state_dict()) 79 | 80 | # training part 81 | if (args.train == True): 82 | filtered_parameters = [] 83 | params_num = [] 84 | for p in filter(lambda p: p.requires_grad, model.parameters()): 85 | filtered_parameters.append(p) 86 | params_num.append(np.prod(p.size())) 87 | print(f"Trainable params num: {sum(params_num)}") 88 | 89 | print(dashed_line) 90 | 91 | # setup loss (not contain sigmoid function) 92 | criterion = FocalLoss().to(device) 93 | 94 | # load target data adjust (use select data) 95 | target_data_adjust_raw = Subset(target_data_raw, select_data) 96 | target_data_adjust = Pseudolabel_Dataset(target_data_adjust_raw, np.full(len(target_data_adjust_raw), 1)) 97 | 98 | # get dataloader 99 | source_loader = get_dataloader(args, source_data, args.batch_size, shuffle=True, aug=args.aug) 100 | target_loader = get_dataloader(args, target_data_adjust, args.batch_size, shuffle=True, aug=args.aug) 101 | 102 | # set up iter dataloader 103 | source_loader_iter = iter(source_loader) 104 | 105 | # set up optimizer 106 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 107 | 108 | # set up scheduler 109 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 110 | optimizer, 111 | max_lr=args.lr, 112 | cycle_momentum=False, 113 | div_factor=20, 114 | final_div_factor=1000, 115 | total_steps=args.epochs * (len(select_data) // args.batch_size + 1), 116 | ) 117 | 118 | # train 119 | train_loss_avg = Averager() 120 | 121 | print(dashed_line) 122 | print("Start Training Domain Discriminator (DD)...\n") 123 | 124 | for epoch in range(args.epochs): 125 | 126 | model.train() 127 | for batch in tqdm(target_loader): 128 | 129 | images_target_tensor, labels_target = batch 130 | 131 | try: 132 | images_source_tensor, labels_source = next(source_loader_iter) 133 | except StopIteration: 134 | del source_loader_iter 135 | source_loader_iter = iter(source_loader) 136 | images_source_tensor, labels_source = next(source_loader_iter) 137 | 138 | images_tensor = torch.cat((images_source_tensor, images_target_tensor), 0) 139 | labels = labels_source + labels_target 140 | images = images_tensor.to(device) 141 | preds = model(images) 142 | loss = criterion(preds, torch.Tensor(labels).view(-1,1).to(device)) 143 | 144 | # optimize 145 | model.zero_grad(set_to_none=True) 146 | loss.backward() 147 | torch.nn.utils.clip_grad_norm_( 148 | model.parameters(), args.grad_clip 149 | ) # gradient clipping with 5 (Default) 150 | optimizer.step() 151 | train_loss_avg.add(loss) 152 | scheduler.step() 153 | 154 | lr = 
optimizer.param_groups[0]["lr"] 155 | valid_log = f"\nEpoch {epoch + 1}/{args.epochs}:\n" 156 | valid_log += f"Train_loss: {train_loss_avg.val():0.5f}, Current_lr: {lr:0.7f}\n" 157 | print(valid_log) 158 | train_loss_avg.reset() 159 | 160 | torch.save( 161 | model.state_dict(), 162 | f"{args.saved_path}/DD_{args.discriminator}_discriminator.pth", 163 | ) 164 | 165 | print(dashed_line) 166 | 167 | try: 168 | model.load_state_dict( 169 | torch.load(f"{args.saved_path}/DD_{args.discriminator}_discriminator.pth") 170 | ) 171 | print(f"Load model from {args.saved_path}/DD_{args.discriminator}_discriminator.pth") 172 | except: 173 | print("\n [*][WARNING] NO checkpoint!") 174 | print(" [*][WARNING] You are using the baseline model!") 175 | print(" [*][WARNING] You haven't trained the discriminator yet!\n") 176 | 177 | model.eval() 178 | 179 | # Domain Stratifying (Domain Discriminator - DD) 180 | DD = DomainStratifying(args, select_data) 181 | DD.stratify_DD(target_data_raw, model) 182 | 183 | print("\nAll information is saved in " + f"{args.saved_path}/") 184 | print("The trained weights are saved at " + f"{args.saved_path}/DD_{args.discriminator}_discriminator.pth") 185 | print(dashed_line) 186 | 187 | return 188 | 189 | 190 | if __name__ == "__main__": 191 | """ Argument """ 192 | parser = argparse.ArgumentParser() 193 | config = load_config("config/DD.yaml") 194 | parser.set_defaults(**config) 195 | 196 | parser.add_argument( 197 | "--source_data", default="data/train/synth/", help="path to source domain data", 198 | ) 199 | parser.add_argument( 200 | "--target_data", default="data/train/real/", help="path to target domain data", 201 | ) 202 | parser.add_argument( 203 | "--select_data", 204 | required=True, 205 | help="path to select data file exp: select_data.npy", 206 | ) 207 | parser.add_argument( 208 | "--saved_model", 209 | required=True, 210 | help="path to pretrained model (backbone model)", 211 | ) 212 | parser.add_argument( 213 | "--batch_size", type=int, default=128, help="input batch size", 214 | ) 215 | parser.add_argument( 216 | "--batch_size_val", type=int, default=512, help="input batch size val", 217 | ) 218 | parser.add_argument( 219 | "--epochs", type=int, default=20, help="number of epochs to train for", 220 | ) 221 | """ Adaptation """ 222 | parser.add_argument( 223 | "--num_subsets", 224 | type=int, 225 | required=True, 226 | help="hyper-parameter n, number of subsets partitioned from target domain data", 227 | ) 228 | parser.add_argument( 229 | "--discriminator", 230 | type=str, 231 | required=True, 232 | help="choose discriminator, CRNN|TRBA", 233 | ) 234 | parser.add_argument( 235 | "--train", action="store_true", default=False, help="training or not", 236 | ) 237 | parser.add_argument( 238 | "--aug", action="store_true", default=False, help="augmentation or not", 239 | ) 240 | 241 | args = parser.parse_args() 242 | 243 | if args.discriminator == "CRNN": # CRNN = NVBC 244 | args.Transformation = "None" 245 | args.FeatureExtraction = "VGG" 246 | args.SequenceModeling = "None" 247 | args.Prediction = "CTC" 248 | 249 | elif args.discriminator == "TRBA": # TRBA 250 | args.Transformation = "TPS" 251 | args.FeatureExtraction = "ResNet" 252 | args.SequenceModeling = "None" 253 | args.Prediction = "None" 254 | 255 | """ Seed and GPU setting """ 256 | random.seed(args.manual_seed) 257 | np.random.seed(args.manual_seed) 258 | torch.manual_seed(args.manual_seed) 259 | torch.cuda.manual_seed(args.manual_seed) 260 | 261 | cudnn.benchmark = True # it fasten training 262 | 
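    # cudnn.benchmark lets cuDNN auto-tune convolution kernels for the fixed input size (faster),
    # while cudnn.deterministic below restricts it to deterministic kernels so that runs with the
    # same manual_seed are reproducible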
cudnn.deterministic = True 263 | 264 | if sys.platform == "win32": 265 | args.workers = 0 266 | 267 | args.gpu_name = "_".join(torch.cuda.get_device_name().split()) 268 | if sys.platform == "linux": 269 | args.CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"] 270 | else: 271 | args.CUDA_VISIBLE_DEVICES = 0 # for convenience 272 | 273 | command_line_input = " ".join(sys.argv) 274 | print( 275 | f"Command line input: CUDA_VISIBLE_DEVICES={args.CUDA_VISIBLE_DEVICES} python {command_line_input}" 276 | ) 277 | 278 | main(args) 279 | -------------------------------------------------------------------------------- /source/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import six 4 | import lmdb 5 | 6 | import PIL 7 | from PIL import Image 8 | 9 | import torch 10 | import torchvision.transforms 11 | from torch.utils.data import Dataset, ConcatDataset, DataLoader 12 | 13 | from .rand_aug import Augmentor 14 | 15 | _MEAN_IMAGENET = torch.tensor([0.485, 0.456, 0.406]) 16 | _STD_IMAGENET = torch.tensor([0.229, 0.224, 0.225]) 17 | 18 | 19 | def get_dataloader(args, dataset, batch_size, shuffle=False, aug=False): 20 | """ 21 | Get dataloader for each dataset 22 | 23 | Parameters 24 | ---------- 25 | args: argparse.ArgumentParser().parse_args() 26 | dataset: torch.utils.data.Dataset 27 | batch_size: int 28 | shuffle: boolean 29 | 30 | Returns 31 | ---------- 32 | data_loader: torch.utils.data.DataLoader 33 | """ 34 | 35 | myAlignCollate = AlignCollate(args, aug) 36 | 37 | data_loader = DataLoader( 38 | dataset, 39 | batch_size=batch_size, 40 | shuffle=shuffle, 41 | num_workers=args.workers, 42 | collate_fn=myAlignCollate, 43 | pin_memory=False, 44 | drop_last=False, 45 | ) 46 | return data_loader 47 | 48 | 49 | def hierarchical_dataset(root, args, mode="label", drop_data=[]): 50 | """ select_data="/" contains all sub-directory of root directory """ 51 | dataset_list = [] 52 | dataset_log = f"dataset_root: {root}\t dataset:\n" 53 | # print(dataset_log) 54 | 55 | listdir = list() 56 | for dirpath, dirnames, filenames in os.walk(root + "/"): 57 | if not dirnames: 58 | # print(dirpath) 59 | flag = True 60 | for u in drop_data: 61 | if u in dirpath: 62 | flag = False 63 | break 64 | if flag == True: 65 | listdir.append(dirpath) 66 | 67 | listdir.sort() 68 | 69 | for dirpath in listdir: 70 | if mode == "raw": 71 | # load data without label 72 | dataset = LmdbDataset_raw(dirpath, args) 73 | else: 74 | # load data with label 75 | dataset = LmdbDataset(dirpath, args) 76 | sub_dataset_log = f"sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}" 77 | # print(sub_dataset_log) 78 | dataset_log += f"{sub_dataset_log}\n" 79 | dataset_list.append(dataset) 80 | 81 | # concatenate many dataset 82 | concatenated_dataset = ConcatDataset(dataset_list) 83 | 84 | return concatenated_dataset, dataset_log 85 | 86 | 87 | class Pseudolabel_Dataset(Dataset): 88 | """ 89 | Assign pseudo labels to data 90 | 91 | Parameters 92 | ---------- 93 | unlabel_dataset: torch.utils.data.Dataset 94 | psudolabel_list: list(object) of pseudo labels 95 | """ 96 | 97 | def __init__(self, unlabel_dataset, psudolabel_list): 98 | self.unlabel_dataset = unlabel_dataset 99 | self.psudolabel_list = psudolabel_list 100 | self.nSamples= len(self.psudolabel_list) 101 | 102 | def __len__(self): 103 | return self.nSamples 104 | 105 | def __getitem__(self, index): 106 | label = self.psudolabel_list[index] 107 | img = self.unlabel_dataset[index] 
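        # the wrapped dataset is label-free (loaded with mode="raw"), so each image is paired with
        # an externally supplied pseudo label: a domain flag (0/1) in stage1_DD.py, or the sample's
        # original index when used for stratifying in source/stratify.py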
108 | return (img, label) 109 | 110 | 111 | class AlignCollate(object): 112 | """ Transform data to the same format """ 113 | def __init__(self, args, aug=False): 114 | self.args = args 115 | 116 | if aug==True: 117 | self.transform = Rand_augment() 118 | else: 119 | self.transform = torchvision.transforms.Compose([]) 120 | 121 | # resize image 122 | self.resize = ResizeNormalize(args) 123 | print("Labeled dataloader using Text_augment", self.transform) 124 | 125 | def __call__(self, batch): 126 | images, labels = zip(*batch) 127 | 128 | image_tensors = [self.resize(self.transform(image)) for image in images] 129 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 130 | 131 | return image_tensors, labels 132 | 133 | 134 | class AlignCollateHDGE(object): 135 | """ Transform data to the same format """ 136 | def __init__(self, args, infer=False): 137 | self.args = args 138 | 139 | # for transforming the input image 140 | if infer == False: 141 | transform = torchvision.transforms.Compose( 142 | [torchvision.transforms.RandomHorizontalFlip(), 143 | torchvision.transforms.Resize((args.load_height,args.load_width)), 144 | torchvision.transforms.RandomCrop((args.crop_height,args.crop_width)), 145 | torchvision.transforms.ToTensor(), 146 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 147 | else: 148 | transform = torchvision.transforms.Compose( 149 | [torchvision.transforms.Resize((args.crop_height,args.crop_width)), 150 | torchvision.transforms.ToTensor(), 151 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 152 | 153 | self.transform = transform 154 | 155 | def __call__(self, batch): 156 | images = batch 157 | 158 | image_tensors = [self.transform(image) for image in images] 159 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 160 | 161 | return image_tensors 162 | 163 | 164 | class LmdbDataset(Dataset): 165 | """ Load data from Lmdb file with label """ 166 | def __init__(self, root, args): 167 | 168 | self.root = root 169 | self.args = args 170 | self.env = lmdb.open( 171 | root, 172 | max_readers=32, 173 | readonly=True, 174 | lock=False, 175 | readahead=False, 176 | meminit=False, 177 | ) 178 | if not self.env: 179 | print("cannot open lmdb from %s" % (root)) 180 | sys.exit(0) 181 | 182 | with self.env.begin(write=False) as txn: 183 | self.nSamples = int(txn.get("num-samples".encode())) 184 | self.filtered_index_list = [] 185 | for index in range(self.nSamples): 186 | index += 1 # lmdb starts with 1 187 | label_key = "label-%09d".encode() % index 188 | label = txn.get(label_key).decode("utf-8") 189 | 190 | # length filtering 191 | length_of_label = len(label) 192 | if length_of_label > args.batch_max_length: 193 | continue 194 | 195 | self.filtered_index_list.append(index) 196 | 197 | self.nSamples = len(self.filtered_index_list) 198 | 199 | def __len__(self): 200 | return self.nSamples 201 | 202 | def __getitem__(self, index): 203 | assert index <= len(self), "index range error" 204 | index = self.filtered_index_list[index] 205 | 206 | with self.env.begin(write=False) as txn: 207 | label_key = "label-%09d".encode() % index 208 | label = txn.get(label_key).decode("utf-8") 209 | img_key = "image-%09d".encode() % index 210 | imgbuf = txn.get(img_key) 211 | buf = six.BytesIO() 212 | buf.write(imgbuf) 213 | buf.seek(0) 214 | 215 | try: 216 | img = PIL.Image.open(buf).convert("RGB") 217 | 218 | except IOError: 219 | print(f"Corrupted image for {index}") 220 | # make dummy image and dummy 
label for corrupted image. 221 | img = PIL.Image.new("RGB", (self.args.imgW, self.args.imgH)) 222 | label = "[dummy_label]" 223 | 224 | return (img, label) 225 | 226 | 227 | class LmdbDataset_raw(Dataset): 228 | """ Load data from Lmdb file without label """ 229 | def __init__(self, root, args): 230 | 231 | self.root = root 232 | self.args = args 233 | self.env = lmdb.open( 234 | root, 235 | max_readers=32, 236 | readonly=True, 237 | lock=False, 238 | readahead=False, 239 | meminit=False, 240 | ) 241 | if not self.env: 242 | print("cannot open lmdb from %s" % (root)) 243 | sys.exit(0) 244 | 245 | with self.env.begin(write=False) as txn: 246 | self.nSamples = int(txn.get("num-samples".encode())) 247 | self.index_list = [index + 1 for index in range(self.nSamples)] 248 | 249 | def __len__(self): 250 | return self.nSamples 251 | 252 | def __getitem__(self, index): 253 | assert index <= len(self), "index range error" 254 | index = self.index_list[index] 255 | 256 | with self.env.begin(write=False) as txn: 257 | img_key = "image-%09d".encode() % index 258 | imgbuf = txn.get(img_key) 259 | buf = six.BytesIO() 260 | buf.write(imgbuf) 261 | buf.seek(0) 262 | 263 | try: 264 | img = PIL.Image.open(buf).convert("RGB") 265 | 266 | except IOError: 267 | print(f"Corrupted image for {img_key}") 268 | # make dummy image for corrupted image. 269 | img = PIL.Image.new("RGB", (self.args.imgW, self.args.imgH)) 270 | 271 | return img 272 | 273 | 274 | class ResizeNormalize(object): 275 | 276 | def __init__(self, args): 277 | self.args = args 278 | _transforms = [] 279 | 280 | _transforms.append(torchvision.transforms.Resize((self.args.imgH, self.args.imgW), 281 | interpolation=torchvision.transforms.InterpolationMode.BICUBIC)) 282 | _transforms.append(torchvision.transforms.ToTensor()) 283 | _transforms.append(torchvision.transforms.Normalize(mean=_MEAN_IMAGENET, std=_STD_IMAGENET)) 284 | self._transforms = torchvision.transforms.Compose(_transforms) 285 | 286 | def __call__(self, image): 287 | image = self._transforms(image) 288 | 289 | return image 290 | 291 | 292 | class Weak_augment(object): 293 | 294 | def __init__(self): 295 | augmentation = [] 296 | augmentation.append( 297 | torchvision.transforms.ColorJitter(brightness=0.2, 298 | contrast=0.1, 299 | saturation=0.1, 300 | hue=0.05)) 301 | self.Augment = torchvision.transforms.Compose(augmentation) 302 | 303 | def __call__(self, image): 304 | image = self.Augment(image) 305 | 306 | return image 307 | 308 | 309 | class Rand_augment(object): 310 | 311 | def __init__(self): 312 | self.first_augmentor = Augmentor(2, 5, "spatial") 313 | self.augmentor = Augmentor(2, 10, "channel") 314 | 315 | def __call__(self, image): 316 | image = self.first_augmentor(image) 317 | image = self.augmentor(image) 318 | 319 | return image 320 | 321 | 322 | if __name__ == "__main__": 323 | image = Image.open("image.jpg") 324 | 325 | # weak augment 326 | weak_augment = Weak_augment() 327 | image_weak = weak_augment(image) 328 | image.save("weak_augment.jpg") 329 | 330 | # rand augment 331 | rand_augment = Rand_augment() 332 | image_weak = rand_augment(image) 333 | image.save("rand_augment.jpg") 334 | -------------------------------------------------------------------------------- /supervised_learning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | import numpy as np 8 | from PIL import ImageFile 9 | 10 | import torch 11 | import 
torch.backends.cudnn as cudnn 12 | 13 | from utils.averager import Averager 14 | from utils.converter import AttnLabelConverter, CTCLabelConverter 15 | from utils.load_config import load_config 16 | 17 | from source.model import Model 18 | from source.dataset import hierarchical_dataset, get_dataloader 19 | 20 | from test import validation 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | torch.multiprocessing.set_sharing_strategy("file_system") 26 | 27 | 28 | def main(args): 29 | dashed_line = "-" * 80 30 | main_log = "" 31 | 32 | # to make directories for saving model and log files if not exist 33 | os.makedirs("trained_model/", exist_ok=True) 34 | os.makedirs("log/", exist_ok=True) 35 | 36 | # load source domain data for supervised learning 37 | print(dashed_line) 38 | main_log = dashed_line + "\n" 39 | print("Load training data (source domain)...") 40 | main_log += "Load training data (source domain)...\n" 41 | 42 | train_data, train_data_log = hierarchical_dataset(args.train_data, args) 43 | 44 | train_loader = get_dataloader(args, train_data, args.batch_size, shuffle=True, aug=args.aug) 45 | 46 | print(train_data_log, end="") 47 | main_log += train_data_log 48 | 49 | # load validation data 50 | print(dashed_line) 51 | main_log += dashed_line + "\n" 52 | print("Load validation data...") 53 | main_log += "Load validation data...\n" 54 | 55 | valid_data, valid_data_log = hierarchical_dataset(args.valid_data, args) 56 | valid_loader = get_dataloader(args, valid_data, args.batch_size_val, shuffle=False) # "True" to check training progress with validation function. 57 | 58 | print(valid_data_log, end="") 59 | main_log += valid_data_log 60 | 61 | print(dashed_line) 62 | main_log += dashed_line + "\n" 63 | print("Init model") 64 | main_log += "Init model\n" 65 | 66 | """ Model configuration """ 67 | if args.Prediction == "CTC": 68 | converter = CTCLabelConverter(args.character) 69 | else: 70 | converter = AttnLabelConverter(args.character) 71 | args.sos_token_index = converter.dict["[SOS]"] 72 | args.eos_token_index = converter.dict["[EOS]"] 73 | args.num_class = len(converter.character) 74 | 75 | # setup model 76 | model = Model(args) 77 | 78 | # data parallel for multi-GPU 79 | model = torch.nn.DataParallel(model).to(device) 80 | model.train() 81 | 82 | # load pretrained model 83 | if args.saved_model != "": 84 | pretrained = torch.load(args.saved_model) 85 | model.load_state_dict(pretrained) 86 | torch.save( 87 | pretrained, 88 | f"trained_model/{args.model}_supervised.pth", 89 | ) 90 | print(f"Load pretrained model from {args.saved_model}") 91 | main_log += "Load pretrained model\n" 92 | 93 | # setup loss 94 | if args.Prediction == "CTC": 95 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 96 | else: 97 | # ignore [PAD] token 98 | criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.dict["[PAD]"]).to(device) 99 | 100 | # filter that only require gradient descent 101 | filtered_parameters = [] 102 | params_num = [] 103 | for p in filter(lambda p: p.requires_grad, model.parameters()): 104 | filtered_parameters.append(p) 105 | params_num.append(np.prod(p.size())) 106 | print(f"Trainable params num: {sum(params_num)}") 107 | main_log += f"Trainable params num: {sum(params_num)}\n" 108 | 109 | """ Final options """ 110 | print("------------ Options -------------") 111 | main_log += "------------ Options -------------\n" 112 | opt = vars(args) 113 | for k, v in opt.items(): 114 | if str(k) 
== "character" and len(str(v)) > 500: 115 | print(f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}") 116 | main_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n" 117 | else: 118 | print(f"{str(k)}: {str(v)}") 119 | main_log += f"{str(k)}: {str(v)}\n" 120 | print(dashed_line) 121 | main_log += dashed_line + "\n" 122 | print("Start Supervised Learning (Scene Text Recognition - STR)...\n") 123 | main_log += "Start Supervised Learning (Scene Text Recognition - STR)...\n" 124 | 125 | # set up optimizer 126 | optimizer = torch.optim.AdamW(filtered_parameters, lr=args.lr, weight_decay=args.weight_decay) 127 | 128 | # set up scheduler 129 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 130 | optimizer, 131 | max_lr=args.lr, 132 | cycle_momentum=False, 133 | div_factor=20, 134 | final_div_factor=1000, 135 | total_steps=(args.epochs * len(train_loader)), 136 | ) 137 | 138 | train_loss_avg = Averager() 139 | best_score = float("-inf") 140 | score_descent = 0 141 | 142 | # training loop 143 | for epoch in range(args.epochs): 144 | 145 | # training part 146 | model.train() 147 | for (images, labels) in tqdm(train_loader): 148 | batch_size = len(labels) 149 | 150 | images_tensor = images.to(device) 151 | labels_index, labels_length = converter.encode( 152 | labels, batch_max_length=args.batch_max_length 153 | ) 154 | 155 | if args.Prediction == "CTC": 156 | preds = model(images_tensor) 157 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 158 | preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2) 159 | loss = criterion(preds_log_softmax, labels_index, preds_size, labels_length) 160 | else: 161 | preds = model(images_tensor, labels_index[:, :-1]) # align with Attention.forward 162 | target = labels_index[:, 1:] # without [SOS] Symbol 163 | loss = criterion( 164 | preds.view(-1, preds.shape[-1]), target.contiguous().view(-1) 165 | ) 166 | 167 | model.zero_grad(set_to_none=True) 168 | loss.backward() 169 | torch.nn.utils.clip_grad_norm_( 170 | model.parameters(), args.grad_clip 171 | ) # gradient clipping with 5 (Default) 172 | optimizer.step() 173 | 174 | train_loss_avg.add(loss) 175 | 176 | scheduler.step() 177 | 178 | # valiation part 179 | model.eval() 180 | with torch.no_grad(): 181 | ( 182 | valid_loss, 183 | current_score, 184 | preds, 185 | confidence_score, 186 | labels, 187 | infer_time, 188 | length_of_data, 189 | ) = validation(model, criterion, valid_loader, converter, args) 190 | model.train() 191 | 192 | if (current_score >= best_score): 193 | score_descent = 0 194 | 195 | best_score = current_score 196 | torch.save( 197 | model.state_dict(), 198 | f"trained_model/{args.model}_supervised.pth", 199 | ) 200 | else: 201 | score_descent += 1 202 | 203 | # log 204 | lr = optimizer.param_groups[0]["lr"] 205 | valid_log = f"\nEpoch {epoch + 1}/{args.epochs}:\n" 206 | valid_log += f"Train_loss: {train_loss_avg.val():0.3f}, Valid_loss: {valid_loss:0.3f}, " 207 | valid_log += f"Current_lr: {lr:0.7f},\n" 208 | valid_log += f"Current_score: {current_score:0.2f}, Best_score: {best_score:0.2f}, " 209 | valid_log += f"Score_descent: {score_descent}\n" 210 | print(valid_log) 211 | 212 | main_log += valid_log 213 | main_log += "\n" + dashed_line + "\n" 214 | 215 | train_loss_avg.reset() 216 | 217 | # free cache 218 | torch.cuda.empty_cache() 219 | 220 | # save log 221 | print("Training is done!") 222 | main_log += "Training is done!" 
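    # note: the best-scoring checkpoint was already saved inside the training loop above;
    # the lines below only record its location and write the accumulated log to disk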
223 | main_log += f"Model is saved at trained_model/{args.model}_supervised.pth" 224 | print(main_log, file= open(f"log/{args.model}_supervised.txt", "w")) 225 | 226 | print(f"Model is saved at trained_model/{args.model}_supervised.pth") 227 | print(f"All information is saved at log/{args.model}_supervised.txt") 228 | print(dashed_line) 229 | 230 | return 231 | 232 | 233 | if __name__ == "__main__": 234 | """ Argument """ 235 | parser = argparse.ArgumentParser() 236 | config = load_config("config/STR.yaml") 237 | parser.set_defaults(**config) 238 | 239 | parser.add_argument( 240 | "--train_data", default="data/train/synth/", help="path to training dataset", 241 | ) 242 | parser.add_argument( 243 | "--valid_data", default="data/val/", help="path to validation dataset", 244 | ) 245 | parser.add_argument( 246 | "--saved_model", default="", help="path to pretrained model (to continue training)", 247 | ) 248 | parser.add_argument( 249 | "--batch_size", type=int, default=128, help="input batch size", 250 | ) 251 | parser.add_argument( 252 | "--batch_size_val", type=int, default=512, help="input batch size val", 253 | ) 254 | parser.add_argument( 255 | "--epochs", type=int, default=20, help="number of epochs to train for", 256 | ) 257 | parser.add_argument( 258 | "--val_interval", type=int, default=1000, help="interval between each validation", 259 | ) 260 | parser.add_argument( 261 | "--NED", action="store_true", help="for Normalized edit_distance", 262 | ) 263 | """ Model Architecture """ 264 | parser.add_argument( 265 | "--model", 266 | type=str, 267 | required=True, 268 | help="CRNN|TRBA", 269 | ) 270 | """ Training """ 271 | parser.add_argument( 272 | "--aug", action="store_true", default=False, help="augmentation or not", 273 | ) 274 | 275 | args = parser.parse_args() 276 | 277 | if args.model == "CRNN": # CRNN = NVBC 278 | args.Transformation = "None" 279 | args.FeatureExtraction = "VGG" 280 | args.SequenceModeling = "BiLSTM" 281 | args.Prediction = "CTC" 282 | 283 | elif args.model == "TRBA": # TRBA 284 | args.Transformation = "TPS" 285 | args.FeatureExtraction = "ResNet" 286 | args.SequenceModeling = "BiLSTM" 287 | args.Prediction = "Attn" 288 | 289 | """ Seed and GPU setting """ 290 | random.seed(args.manual_seed) 291 | np.random.seed(args.manual_seed) 292 | torch.manual_seed(args.manual_seed) 293 | torch.cuda.manual_seed(args.manual_seed) 294 | 295 | cudnn.benchmark = True # it fasten training 296 | cudnn.deterministic = True 297 | 298 | if sys.platform == "win32": 299 | args.workers = 0 300 | 301 | args.gpu_name = "_".join(torch.cuda.get_device_name().split()) 302 | if sys.platform == "linux": 303 | args.CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"] 304 | else: 305 | args.CUDA_VISIBLE_DEVICES = 0 # for convenience 306 | 307 | command_line_input = " ".join(sys.argv) 308 | print( 309 | f"Command line input: CUDA_VISIBLE_DEVICES={args.CUDA_VISIBLE_DEVICES} python {command_line_input}" 310 | ) 311 | 312 | main(args) 313 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import time 5 | import random 6 | import argparse 7 | from tqdm import tqdm 8 | 9 | import numpy as np 10 | from PIL import ImageFile 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.backends.cudnn as cudnn 15 | from nltk.metrics.distance import edit_distance 16 | 17 | from utils.averager import Averager 18 | from 
utils.converter import AttnLabelConverter, CTCLabelConverter 19 | from utils.load_config import load_config 20 | 21 | from source.model import Model 22 | from source.dataset import AlignCollate, hierarchical_dataset 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | ImageFile.LOAD_TRUNCATED_IMAGES = True 27 | torch.multiprocessing.set_sharing_strategy("file_system") 28 | 29 | 30 | def benchmark_all_eval(model, criterion, converter, args): 31 | """ Evaluation with 6 benchmark evaluation datasets """ 32 | eval_data_list = [ 33 | "IIIT5k", 34 | "SVT", 35 | "IC13_1015", 36 | "IC15_2077", 37 | "SVTP", 38 | "CUTE80", 39 | ] 40 | if (args.addition == True): 41 | eval_data_list = [ 42 | "COCOv1.4", 43 | "Uber", 44 | "ArT", 45 | "ReCTS", 46 | ] 47 | if (args.exception == True): 48 | eval_data_list = [ 49 | "IC13_857", 50 | "IC15_1811", 51 | ] 52 | if (args.union == True): 53 | eval_data_list = [ 54 | "artistic", 55 | "contextless", 56 | "curve", 57 | "general", 58 | ] 59 | 60 | accuracy_list = [] 61 | total_forward_time = 0 62 | total_eval_data_number = 0 63 | total_correct_number = 0 64 | 65 | dashed_line = "-" * 80 66 | print(dashed_line) 67 | 68 | for eval_data in eval_data_list: 69 | eval_data_path = os.path.join(args.eval_data, eval_data) 70 | AlignCollate_eval = AlignCollate(args) 71 | eval_data, eval_data_log = hierarchical_dataset( 72 | root=eval_data_path, args=args 73 | ) 74 | print(eval_data_log) 75 | eval_loader = torch.utils.data.DataLoader( 76 | eval_data, 77 | batch_size=args.batch_size_val, 78 | shuffle=False, 79 | num_workers=int(args.workers), 80 | collate_fn=AlignCollate_eval, 81 | pin_memory=True, 82 | ) 83 | 84 | _, accuracy_by_best_model, _, _, _, infer_time, length_of_data = validation( 85 | model, criterion, eval_loader, converter, args, tqdm_position=0 86 | ) 87 | accuracy_list.append(f"{accuracy_by_best_model:0.2f}") 88 | total_forward_time += infer_time 89 | total_eval_data_number += len(eval_data) 90 | total_correct_number += accuracy_by_best_model * length_of_data 91 | print(f"Acc {accuracy_by_best_model:0.2f}") 92 | print(dashed_line) 93 | 94 | averaged_forward_time = total_forward_time / total_eval_data_number * 1000 95 | total_accuracy = total_correct_number / total_eval_data_number 96 | params_num = sum([np.prod(p.size()) for p in model.parameters()]) 97 | 98 | eval_log = "Accuracy:\n" 99 | for name, accuracy in zip(eval_data_list, accuracy_list): 100 | eval_log += f"{name}: {accuracy} | " 101 | eval_log += f"\nTotal_accuracy: {total_accuracy:0.2f}\t" 102 | eval_log += f"Averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.2f}" 103 | print(eval_log) 104 | 105 | # for convenience 106 | print() 107 | print("\t".join(accuracy_list)) 108 | print(f"Total_accuracy: {total_accuracy:0.2f}") 109 | 110 | return total_accuracy, eval_data_list, accuracy_list 111 | 112 | 113 | def validation(model, criterion, eval_loader, converter, args, tqdm_position=1): 114 | """ Validation or evaluation """ 115 | n_correct = 0 116 | norm_ED = 0 117 | length_of_data = 0 118 | infer_time = 0 119 | valid_loss_avg = Averager() 120 | 121 | for i, (image_tensors, labels) in tqdm( 122 | enumerate(eval_loader), 123 | total=len(eval_loader), 124 | position=tqdm_position, 125 | leave=False, 126 | ): 127 | batch_size = image_tensors.size(0) 128 | length_of_data = length_of_data + batch_size 129 | image = image_tensors.to(device) 130 | # for max length prediction 131 | labels_index, labels_length = converter.encode( 132 | labels, 
batch_max_length=args.batch_max_length 133 | ) 134 | 135 | if "CTC" in args.Prediction: 136 | start_time = time.time() 137 | preds = model(image) 138 | forward_time = time.time() - start_time 139 | 140 | # calculate evaluation loss for the CTC decoder. 141 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 142 | # permute "preds" to use CTCloss format 143 | cost = criterion( 144 | preds.log_softmax(2).permute(1, 0, 2), 145 | labels_index, 146 | preds_size, 147 | labels_length, 148 | ) 149 | else: 150 | text_for_pred = ( 151 | torch.LongTensor(batch_size).fill_(converter.dict["[SOS]"]).to(device) 152 | ) 153 | 154 | start_time = time.time() 155 | preds = model(image, text_for_pred, is_train=False) 156 | forward_time = time.time() - start_time 157 | 158 | target = labels_index[:, 1:] # without [SOS] Symbol 159 | cost = criterion( 160 | preds.contiguous().view(-1, preds.shape[-1]), 161 | target.contiguous().view(-1), 162 | ) 163 | 164 | # select max probability (greedy decoding) then decode index to character 165 | _, preds_index = preds.max(2) 166 | preds_size = torch.IntTensor([preds.size(1)] * preds_index.size(0)).to(device) 167 | preds_str = converter.decode(preds_index, preds_size) 168 | 169 | infer_time += forward_time 170 | valid_loss_avg.add(cost) 171 | 172 | # calculate accuracy & confidence score 173 | preds_prob = F.softmax(preds, dim=2) 174 | preds_max_prob, _ = preds_prob.max(dim=2) 175 | confidence_score_list = [] 176 | for gt, prd, prd_max_prob in zip(labels, preds_str, preds_max_prob): 177 | if "Attn" in args.Prediction: 178 | prd_EOS = prd.find("[EOS]") 179 | prd = prd[:prd_EOS] # prune after "end of sentence" token ([EOS]) 180 | prd_max_prob = prd_max_prob[:prd_EOS] 181 | 182 | """ 183 | In our experiment, if the model predicts at least one [UNK] token, we count the word prediction as incorrect. 184 | To ignore the [UNK] token instead, use the line below. 185 | prd = prd.replace("[UNK]", "") 186 | """ 187 | 188 | # to evaluate a case-sensitive model in the alphanumeric, case-insensitive setting (same as ASTER) 189 | gt = gt.lower() 190 | prd = prd.lower() 191 | alphanumeric_case_insensitve = "0123456789abcdefghijklmnopqrstuvwxyz" 192 | out_of_alphanumeric_case_insensitve = f"[^{alphanumeric_case_insensitve}]" 193 | gt = re.sub(out_of_alphanumeric_case_insensitve, "", gt) 194 | prd = re.sub(out_of_alphanumeric_case_insensitve, "", prd) 195 | 196 | if args.NED: 197 | # ICDAR2019 Normalized Edit Distance 198 | if len(gt) == 0 or len(prd) == 0: 199 | norm_ED += 0 200 | elif len(gt) > len(prd): 201 | norm_ED += 1 - edit_distance(prd, gt) / len(gt) 202 | else: 203 | norm_ED += 1 - edit_distance(prd, gt) / len(prd) 204 | 205 | else: 206 | if prd == gt: 207 | n_correct += 1 208 | 209 | # calculate confidence score (= product of prd_max_prob) 210 | try: 211 | confidence_score = prd_max_prob.cumprod(dim=0)[-1] 212 | except: 213 | confidence_score = 0 # for the empty-prediction case, after pruning at the "end of sentence" token ([EOS]) 214 | confidence_score_list.append(confidence_score) 215 | 216 | if args.NED: 217 | # ICDAR2019 Normalized Edit Distance. On their web page, they report norm_ED as a percentage (= norm_ED * 100). 
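# NOTE: a minimal, standalone sketch of the per-sample normalized-edit-distance term accumulated above;
# the helper name "ned_score" is illustrative only and is not defined anywhere in this repository:
#
#     from nltk.metrics.distance import edit_distance
#
#     def ned_score(prd: str, gt: str) -> float:
#         # 0 for empty strings; otherwise 1 - edit distance normalized by the longer string
#         if len(gt) == 0 or len(prd) == 0:
#             return 0.0
#         return 1 - edit_distance(prd, gt) / max(len(gt), len(prd))
#
#     # e.g. ned_score("hel1o", "hello") == 1 - 1/5 == 0.8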
218 | score = norm_ED / float(length_of_data) * 100 219 | else: 220 | score = n_correct / float(length_of_data) * 100 # accuracy 221 | 222 | return ( 223 | valid_loss_avg.val(), 224 | score, 225 | preds_str, 226 | confidence_score_list, 227 | labels, 228 | infer_time, 229 | length_of_data, 230 | ) 231 | 232 | 233 | def test(args): 234 | """ Model configuration """ 235 | if "CTC" in args.Prediction: 236 | converter = CTCLabelConverter(args.character) 237 | else: 238 | converter = AttnLabelConverter(args.character) 239 | args.sos_token_index = converter.dict["[SOS]"] 240 | args.eos_token_index = converter.dict["[EOS]"] 241 | args.num_class = len(converter.character) 242 | 243 | model = Model(args) 244 | print( 245 | "model input parameters", 246 | args.imgH, 247 | args.imgW, 248 | args.num_fiducial, 249 | args.input_channel, 250 | args.output_channel, 251 | args.hidden_size, 252 | args.num_class, 253 | args.batch_max_length, 254 | args.Transformation, 255 | args.FeatureExtraction, 256 | args.SequenceModeling, 257 | args.Prediction, 258 | ) 259 | model = torch.nn.DataParallel(model).to(device) 260 | 261 | # load model 262 | print("loading pretrained model from %s" % args.saved_model) 263 | try: 264 | model.load_state_dict( 265 | torch.load(args.saved_model, map_location=device) 266 | ) 267 | except: 268 | print("\n [*][WARNING] The pre-trained weights do not match the model! Carefully check!\n") 269 | # pretrained_state_dict = torch.load(args.saved_model) 270 | # for name in pretrained_state_dict: 271 | # print(name) 272 | model.load_state_dict( 273 | torch.load(args.saved_model, map_location=device), strict=False 274 | ) 275 | 276 | """ Setup loss """ 277 | if "CTC" in args.Prediction: 278 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 279 | else: 280 | # ignore [PAD] token 281 | criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.dict["[PAD]"]).to( 282 | device 283 | ) 284 | 285 | """ Evaluation """ 286 | model.eval() 287 | with torch.no_grad(): 288 | # evaluate 6 benchmark evaluation datasets 289 | benchmark_all_eval(model, criterion, converter, args) 290 | 291 | 292 | if __name__ == "__main__": 293 | """ Argument """ 294 | parser = argparse.ArgumentParser() 295 | config = load_config("config/STR.yaml") 296 | parser.set_defaults(**config) 297 | 298 | parser.add_argument( 299 | "--eval_data", default="data/test/benchmark/", help="path to evaluation dataset", 300 | ) 301 | parser.add_argument( 302 | "--addition", action="store_true", default=False, help="test on addition data", 303 | ) 304 | parser.add_argument( 305 | "--exception", action="store_true", default=False, help="test on exception data", 306 | ) 307 | parser.add_argument( 308 | "--union", action="store_true", default=False, help="test on Union14M data", 309 | ) 310 | parser.add_argument( 311 | "--saved_model", 312 | required=True, 313 | help="path to saved_model to evaluation", 314 | ) 315 | parser.add_argument( 316 | "--batch_size_val", type=int, default=512, help="input batch size", 317 | ) 318 | parser.add_argument( 319 | "--NED", action="store_true", help="for Normalized edit_distance", 320 | ) 321 | """ Model Architecture """ 322 | parser.add_argument( 323 | "--model", 324 | type=str, 325 | required=True, 326 | help="CRNN|TRBA", 327 | ) 328 | 329 | args = parser.parse_args() 330 | 331 | if args.model == "CRNN": # CRNN = NVBC 332 | args.Transformation = "None" 333 | args.FeatureExtraction = "VGG" 334 | args.SequenceModeling = "BiLSTM" 335 | args.Prediction = "CTC" 336 | 337 | elif args.model == "TRBA": 
# TRBA 338 | args.Transformation = "TPS" 339 | args.FeatureExtraction = "ResNet" 340 | args.SequenceModeling = "BiLSTM" 341 | args.Prediction = "Attn" 342 | 343 | """ Seed and GPU setting """ 344 | random.seed(args.manual_seed) 345 | np.random.seed(args.manual_seed) 346 | torch.manual_seed(args.manual_seed) 347 | torch.cuda.manual_seed(args.manual_seed) 348 | 349 | cudnn.benchmark = True # it speeds up training 350 | cudnn.deterministic = True 351 | 352 | if sys.platform == "win32": 353 | args.workers = 0 354 | 355 | args.gpu_name = "_".join(torch.cuda.get_device_name().split()) 356 | if sys.platform == "linux": 357 | args.CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"] 358 | else: 359 | args.CUDA_VISIBLE_DEVICES = 0 # for convenience 360 | 361 | command_line_input = " ".join(sys.argv) 362 | print( 363 | f"Command line input: CUDA_VISIBLE_DEVICES={args.CUDA_VISIBLE_DEVICES} python {command_line_input}" 364 | ) 365 | 366 | test(args) 367 | -------------------------------------------------------------------------------- /stage2_StrDA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | import numpy as np 9 | from PIL import ImageFile 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.backends.cudnn as cudnn 14 | from torch.utils.data import Subset 15 | 16 | from utils.averager import Averager 17 | from utils.converter import AttnLabelConverter, CTCLabelConverter 18 | from utils.load_config import load_config 19 | 20 | from source.model import Model 21 | from source.dataset import Pseudolabel_Dataset, hierarchical_dataset, get_dataloader 22 | 23 | from test import validation 24 | 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | ImageFile.LOAD_TRUNCATED_IMAGES = True 28 | torch.multiprocessing.set_sharing_strategy("file_system") 29 | 30 | 31 | def pseudo_labeling(args, model, converter, target_data, adapting_list, round): 32 | """ Make predictions and return them """ 33 | 34 | # get adapt_data 35 | data = Subset(target_data, adapting_list) 36 | data = Pseudolabel_Dataset(data, adapting_list) 37 | dataloader = get_dataloader(args, data, args.batch_size_val, shuffle=False) 38 | 39 | model.eval() 40 | with torch.no_grad(): 41 | list_adapt_data = list() 42 | list_pseudo_data = list() 43 | list_pseudo_label = list() 44 | 45 | mean_conf = 0 46 | 47 | for (image_tensors, image_indexs) in tqdm(dataloader): 48 | batch_size = len(image_indexs) 49 | image = image_tensors.to(device) 50 | 51 | if args.Prediction == "CTC": 52 | preds = model(image) 53 | else: 54 | text_for_pred = ( 55 | torch.LongTensor(batch_size) 56 | .fill_(args.sos_token_index) 57 | .to(device) 58 | ) 59 | preds = model(image, text_for_pred, is_train=False) 60 | 61 | # select max probability (greedy decoding) then decode index to character 62 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 63 | _, preds_index = preds.max(2) 64 | preds_str = converter.decode(preds_index, preds_size) 65 | 66 | preds_prob = F.softmax(preds, dim=2) 67 | preds_max_prob, _ = preds_prob.max(dim=2) 68 | 69 | for pred, pred_max_prob, index in zip( 70 | preds_str, preds_max_prob, image_indexs 71 | ): 72 | if args.Prediction == "Attn": 73 | pred_EOS = pred.find("[EOS]") 74 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([EOS]) 75 | pred_max_prob = pred_max_prob[:pred_EOS] 76 | 77 | if ( 78 | "[PAD]" in pred 79 | or "[UNK]" 
in pred 80 | or "[SOS]" in pred 81 | ): 82 | list_pseudo_label.append(pred) 83 | continue 84 | 85 | # calculate confidence score (= product of pred_max_prob) 86 | if len(pred_max_prob.cumprod(dim=0)) > 0: 87 | confidence_score = pred_max_prob.cumprod(dim=0)[-1].item() 88 | else: 89 | list_pseudo_label.append(pred) 90 | continue 91 | 92 | list_adapt_data.append(index) 93 | list_pseudo_data.append(pred) 94 | 95 | mean_conf += confidence_score 96 | 97 | mean_conf /= (len(list_adapt_data)) 98 | # adjust mean_conf (round_down) 99 | mean_conf = int(mean_conf * 10) / 10 100 | 101 | # save pseudo-labels 102 | with open(f"stratify/{args.method}/pseudolabel_{round}.txt", "w") as file: 103 | for string in list_pseudo_label: 104 | file.write(string + "\n") 105 | 106 | # free cache 107 | torch.cuda.empty_cache() 108 | 109 | return list_adapt_data, list_pseudo_data, mean_conf 110 | 111 | 112 | def self_training(args, filtered_parameters, model, criterion, converter, relative_path, \ 113 | source_loader, valid_loader, adapting_loader, mean_conf, round=0): 114 | 115 | num_iter = (args.total_iter // args.val_interval) // args.num_subsets * args.val_interval 116 | 117 | if round == 1: 118 | num_iter += (args.total_iter // args.val_interval) % args.num_subsets * args.val_interval 119 | 120 | # set up iter dataloader 121 | source_loader_iter = iter(source_loader) 122 | adapting_loader_iter = iter(adapting_loader) 123 | 124 | # set up optimizer 125 | optimizer = torch.optim.AdamW(filtered_parameters, lr=args.lr, weight_decay=args.weight_decay) 126 | 127 | # set up scheduler 128 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 129 | optimizer, 130 | max_lr=args.lr, 131 | cycle_momentum=False, 132 | div_factor=20, 133 | final_div_factor=1000, 134 | total_steps=num_iter, 135 | ) 136 | 137 | train_loss_avg = Averager() 138 | source_loss_avg = Averager() 139 | adapting_loss_avg = Averager() 140 | best_score = float("-inf") 141 | score_descent = 0 142 | 143 | log = "-" * 80 +"\n" 144 | log += "Start Self-Training (Scene Text Recognition - STR)...\n" 145 | 146 | model.train() 147 | # training loop 148 | for iteration in tqdm( 149 | range(0, num_iter + 1), 150 | total=num_iter, 151 | position=0, 152 | leave=True, 153 | ): 154 | if (iteration % args.val_interval == 0 or iteration == num_iter): 155 | # validation part 156 | model.eval() 157 | with torch.no_grad(): 158 | ( 159 | valid_loss, 160 | current_score, 161 | preds, 162 | confidence_score, 163 | labels, 164 | infer_time, 165 | length_of_data, 166 | ) = validation(model, criterion, valid_loader, converter, args) 167 | 168 | if (current_score >= best_score): 169 | score_descent = 0 170 | 171 | best_score = current_score 172 | torch.save( 173 | model.state_dict(), 174 | f"trained_model/{relative_path}/{args.model}_round{round}.pth", 175 | ) 176 | else: 177 | score_descent += 1 178 | 179 | # log 180 | lr = optimizer.param_groups[0]["lr"] 181 | valid_log = f"\nValidation at {iteration}/{num_iter}:\n" 182 | valid_log += f"Train_loss: {train_loss_avg.val():0.4f}, Valid_loss: {valid_loss:0.4f}, " 183 | valid_log += f"Source_loss: {source_loss_avg.val():0.4f}, Adapting_loss: {adapting_loss_avg.val():0.4f},\n" 184 | valid_log += f"Current_lr: {lr:0.7f}, " 185 | valid_log += f"Current_score: {current_score:0.2f}, Best_score: {best_score:0.2f}, " 186 | valid_log += f"Score_descent: {score_descent}\n" 187 | print(valid_log) 188 | 189 | log += valid_log 190 | 191 | log += "\n" + "-" * 80 +"\n" 192 | 193 | train_loss_avg.reset() 194 | source_loss_avg.reset() 195 | 
adapting_loss_avg.reset() 196 | 197 | if iteration == num_iter: 198 | log += f"Stop training at iteration {iteration}!\n" 199 | print(f"Stop training at iteration {iteration}!\n") 200 | break 201 | 202 | # training part 203 | model.train() 204 | """ Loss of labeled data (source domain) """ 205 | try: 206 | images_source_tensor, labels_source = next(source_loader_iter) 207 | except StopIteration: 208 | del source_loader_iter 209 | source_loader_iter = iter(source_loader) 210 | images_source_tensor, labels_source = next(source_loader_iter) 211 | 212 | images_source = images_source_tensor.to(device) 213 | labels_source_index, labels_source_length = converter.encode( 214 | labels_source, batch_max_length=args.batch_max_length 215 | ) 216 | 217 | batch_source_size = len(labels_source) 218 | if args.Prediction == "CTC": 219 | preds_source = model(images_source) 220 | preds_source_size = torch.IntTensor([preds_source.size(1)] * batch_source_size) 221 | preds_source_log_softmax = preds_source.log_softmax(2).permute(1, 0, 2) 222 | loss_source = criterion(preds_source_log_softmax, labels_source_index, preds_source_size, labels_source_length) 223 | else: 224 | preds_source = model(images_source, labels_source_index[:, :-1]) # align with Attention.forward 225 | target_source = labels_source_index[:, 1:] # without [SOS] Symbol 226 | loss_source = criterion( 227 | preds_source.view(-1, preds_source.shape[-1]), target_source.contiguous().view(-1) 228 | ) 229 | 230 | """ Loss of pseudo-labeled data (target domain) """ 231 | try: 232 | images_adapting_tensor, labels_adapting = next(adapting_loader_iter) 233 | except StopIteration: 234 | del adapting_loader_iter 235 | adapting_loader_iter = iter(adapting_loader) 236 | images_adapting_tensor, labels_adapting = next(adapting_loader_iter) 237 | 238 | images_adapting = images_adapting_tensor.to(device) 239 | labels_adapting_index, labels_adapting_length = converter.encode( 240 | labels_adapting, batch_max_length=args.batch_max_length 241 | ) 242 | 243 | batch_adapting_size = len(labels_adapting) 244 | if args.Prediction == "CTC": 245 | preds_adapting = model(images_adapting) 246 | preds_adapting_size = torch.IntTensor([preds_adapting.size(1)] * batch_adapting_size) 247 | preds_adapting_log_softmax = preds_adapting.log_softmax(2).permute(1, 0, 2) 248 | loss_adapting = criterion(preds_adapting_log_softmax, labels_adapting_index, preds_adapting_size, labels_adapting_length) 249 | else: 250 | preds_adapting = model(images_adapting, labels_adapting_index[:, :-1]) # align with Attention.forward 251 | target_adapting = labels_adapting_index[:, 1:] # without [SOS] Symbol 252 | loss_adapting = criterion( 253 | preds_adapting.view(-1, preds_adapting.shape[-1]), target_adapting.contiguous().view(-1) 254 | ) 255 | 256 | loss = (1 - mean_conf) * loss_source + loss_adapting * mean_conf 257 | 258 | model.zero_grad(set_to_none=True) 259 | loss.backward() 260 | torch.nn.utils.clip_grad_norm_( 261 | model.parameters(), args.grad_clip 262 | ) # gradient clipping with 5 (Default) 263 | optimizer.step() 264 | 265 | train_loss_avg.add(loss) 266 | source_loss_avg.add(loss_source) 267 | adapting_loss_avg.add(loss_adapting) 268 | 269 | scheduler.step() 270 | 271 | model.eval() 272 | 273 | # save model 274 | # torch.save( 275 | # model.state_dict(), 276 | # f"trained_model/{relative_path}/{args.model}_round{round}.pth", 277 | # ) 278 | 279 | # save log 280 | log += f"Model is saved at trained_model/{relative_path}/{args.model}_round{round}.pth" 281 | print(log, file= 
open(f"log/{relative_path}/log_self_training_round{round}.txt", "w")) 282 | 283 | # free cache 284 | torch.cuda.empty_cache() 285 | 286 | 287 | def main(args): 288 | dashed_line = "-" * 80 289 | main_log = "" 290 | 291 | if args.method == "HDGE": 292 | if args.beta == -1: 293 | raise ValueError("Please set beta value for HDGE method.") 294 | relative_path = f"{args.method}/{args.beta}_beta/{args.num_subsets}_subsets" 295 | else: 296 | if args.discriminator == "": 297 | raise ValueError("Please set discriminator for DD method.") 298 | relative_path = f"{args.method}/{args.discriminator}/{args.num_subsets}_subsets" 299 | 300 | # to make directories for saving models and logs if not exist 301 | os.makedirs(f"log/{relative_path}/", exist_ok=True) 302 | os.makedirs(f"trained_model/{relative_path}/", exist_ok=True) 303 | 304 | # load source domain data 305 | print(dashed_line) 306 | main_log = dashed_line + "\n" 307 | print("Load source domain data...") 308 | main_log += "Load source domain data...\n" 309 | 310 | source_data, source_data_log = hierarchical_dataset(args.source_data, args) 311 | source_loader = get_dataloader(args, source_data, args.batch_size, shuffle=True, aug=args.aug) 312 | 313 | print(source_data_log, end="") 314 | main_log += source_data_log 315 | 316 | # load target domain data (raw) 317 | print(dashed_line) 318 | main_log += dashed_line + "\n" 319 | print("Load target domain data...") 320 | main_log += "Load target domain data...\n" 321 | 322 | target_data, target_data_log= hierarchical_dataset(args.target_data, args, mode="raw") 323 | 324 | print(target_data_log, end="") 325 | main_log += target_data_log 326 | 327 | # load validation data 328 | print(dashed_line) 329 | main_log += dashed_line + "\n" 330 | print("Load validation data...") 331 | main_log += "Load validation data...\n" 332 | 333 | valid_data, valid_data_log = hierarchical_dataset(args.valid_data, args) 334 | valid_loader = get_dataloader(args, valid_data, args.batch_size_val, shuffle=False) # "True" to check training progress with validation function. 335 | 336 | print(valid_data_log, end="") 337 | main_log += valid_data_log 338 | 339 | """ Model configuration """ 340 | if args.Prediction == "CTC": 341 | converter = CTCLabelConverter(args.character) 342 | else: 343 | converter = AttnLabelConverter(args.character) 344 | args.sos_token_index = converter.dict["[SOS]"] 345 | args.eos_token_index = converter.dict["[EOS]"] 346 | args.num_class = len(converter.character) 347 | 348 | # setup model 349 | print(dashed_line) 350 | main_log += dashed_line + "\n" 351 | print("Init model") 352 | main_log += "Init model\n" 353 | model = Model(args) 354 | 355 | # data parallel for multi-GPU 356 | model = torch.nn.DataParallel(model).to(device) 357 | model.train() 358 | 359 | # load pretrained model 360 | try: 361 | pretrained = torch.load(args.saved_model) 362 | model.load_state_dict(pretrained) 363 | except: 364 | raise ValueError("The pre-trained weights do not match the model! 
Carefully check!") 365 | 366 | torch.save( 367 | pretrained, 368 | f"trained_model/{relative_path}/{args.model}_round0.pth" 369 | ) 370 | print(f"Load pretrained model from {args.saved_model}") 371 | main_log += f"Load pretrained model from {args.saved_model}\n" 372 | 373 | # setup loss 374 | if args.Prediction == "CTC": 375 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 376 | else: 377 | # ignore [PAD] token 378 | criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.dict["[PAD]"]).to(device) 379 | 380 | # filter that only require gradient descent 381 | filtered_parameters = [] 382 | params_num = [] 383 | for p in filter(lambda p: p.requires_grad, model.parameters()): 384 | filtered_parameters.append(p) 385 | params_num.append(np.prod(p.size())) 386 | print(f"Trainable params num: {sum(params_num)}") 387 | main_log += f"Trainable params num: {sum(params_num)}\n" 388 | 389 | """ Final options """ 390 | print("------------ Options -------------") 391 | main_log += "------------ Options -------------\n" 392 | opt = vars(args) 393 | for k, v in opt.items(): 394 | if str(k) == "character" and len(str(v)) > 500: 395 | print(f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}") 396 | main_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n" 397 | else: 398 | print(f"{str(k)}: {str(v)}") 399 | main_log += f"{str(k)}: {str(v)}\n" 400 | print(dashed_line) 401 | main_log += dashed_line + "\n" 402 | print("Start Adapting (Scene Text Recognition - STR)...\n") 403 | main_log += "Start Adapting (Scene Text Recognition - STR)...\n" 404 | 405 | for round in range(args.num_subsets): 406 | 407 | print(f"Round {round+1}/{args.num_subsets}: \n") 408 | main_log += f"\nRound {round+1}/{args.num_subsets}: \n" 409 | 410 | # load best model of previous round 411 | print(f"- Load best model of round {round}.") 412 | main_log += f"- Load best model of round {round}. \n" 413 | model.load_state_dict( 414 | torch.load(f"trained_model/{relative_path}/{args.model}_round{round}.pth") 415 | ) 416 | 417 | # select subset 418 | try: 419 | adapting_list = list(np.load(f"stratify/{relative_path}/subset_{round + 1}.npy")) 420 | except: 421 | raise ValueError(f"stratify/{relative_path}/subset_{round + 1}.npy not found.") 422 | 423 | # assign pseudo labels 424 | print("- Pseudo labeling...\n") 425 | main_log += "- Pseudo labeling...\n" 426 | list_adapt_data, list_pseudo_data, mean_conf = pseudo_labeling( 427 | args, model, converter, target_data, adapting_list, round + 1 428 | ) 429 | 430 | print(f"- Number of adapting data: {len(list_adapt_data)}") 431 | main_log += f"- Number of adapting data: {len(list_adapt_data)} \n" 432 | print(f"- Mean of confidence score: {mean_conf}") 433 | main_log += f"- Mean of confidence scores: {mean_conf} \n" 434 | 435 | # restrict adapting data 436 | adapting_data = Subset(target_data, list_adapt_data) 437 | adapting_data = Pseudolabel_Dataset(adapting_data, list_pseudo_data) 438 | 439 | # get dataloader 440 | adapting_loader = get_dataloader(args, adapting_data, args.batch_size, shuffle=True, aug=args.aug) 441 | 442 | # self-training 443 | print(dashed_line) 444 | print("- Start Self-Training (Scene Text Recognition - STR)...") 445 | main_log += "\n- Start Self-Training (Scene Text Recognition - STR)..." 
446 | 447 | self_training_start = time.time() 448 | if (round >= args.checkpoint): 449 | self_training(args, filtered_parameters, model, criterion, converter, relative_path, \ 450 | source_loader, valid_loader, adapting_loader, mean_conf, round + 1) 451 | self_training_end = time.time() 452 | 453 | print(f"Processing time: {self_training_end - self_training_start}s") 454 | print(f"Model is saved at trained_model/{relative_path}/{args.model}_round{round + 1}.pth") 455 | print(f"Saved log for adapting round to: 'log/{relative_path}/log_self_training_round{round + 1}.txt'") 456 | 457 | main_log += f"\nProcessing time: {self_training_end - self_training_start}s" 458 | main_log += f"\nModel is saved at trained_model/{relative_path}/{args.model}_round{round + 1}.pth" 459 | main_log += f"\nSaved log for adapting round to: 'log/{relative_path}/log_self_training_round{round + 1}.txt'" 460 | main_log += "\n" + dashed_line + "\n" 461 | 462 | print(dashed_line * 3) 463 | 464 | # free cache 465 | torch.cuda.empty_cache() 466 | 467 | # save log 468 | print(main_log, file= open(f"log/{args.method}/log_StrDA.txt", "w")) 469 | 470 | return 471 | 472 | if __name__ == "__main__": 473 | """ Argument """ 474 | parser = argparse.ArgumentParser() 475 | config = load_config("config/STR.yaml") 476 | parser.set_defaults(**config) 477 | 478 | parser.add_argument( 479 | "--source_data", default="data/train/synth/", help="path to source dataset", 480 | ) 481 | parser.add_argument( 482 | "--target_data", default="data/train/real/", help="path to adaptation dataset", 483 | ) 484 | parser.add_argument( 485 | "--valid_data", default="data/val/", help="path to validation dataset", 486 | ) 487 | parser.add_argument( 488 | "--saved_model", 489 | required=True, 490 | help="path to source-trained model for adaptation", 491 | ) 492 | parser.add_argument( 493 | "--batch_size", type=int, default=128, help="input batch size", 494 | ) 495 | parser.add_argument( 496 | "--batch_size_val", type=int, default=512, help="input batch size val", 497 | ) 498 | parser.add_argument( 499 | "--total_iter", type=int, default=50000, help="number of iterations to train for", 500 | ) 501 | parser.add_argument( 502 | "--val_interval", type=int, default=500, help="interval between each validation", 503 | ) 504 | parser.add_argument( 505 | "--NED", action="store_true", help="for Normalized edit_distance", 506 | ) 507 | """ Model Architecture """ 508 | parser.add_argument( 509 | "--model", 510 | type=str, 511 | required=True, 512 | help="CRNN|TRBA", 513 | ) 514 | """ Adaptation """ 515 | parser.add_argument( 516 | "--num_subsets", 517 | type=int, 518 | required=True, 519 | help="hyper-parameter n, number of subsets partitioned from target domain data", 520 | ) 521 | parser.add_argument( 522 | "--method", 523 | required=True, 524 | help="select Domain Stratifying method, DD|HDGE", 525 | ) 526 | parser.add_argument("--discriminator", default="", help="for DD method, choose discriminator, CRNN|TRBA") 527 | parser.add_argument("--beta", type=float, default=-1, help="for HDGE method, hyper-parameter beta, 0