├── e-commercial
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── build.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── samplers.cpython-37.pyc
│   │   │   ├── build_new.cpython-37.pyc
│   │   │   ├── zipreader.cpython-37.pyc
│   │   │   ├── saliency_loader.cpython-37.pyc
│   │   │   ├── cached_image_folder.cpython-37.pyc
│   │   │   └── saliency_loader_new.cpython-37.pyc
│   │   ├── samplers.py
│   │   ├── saliency_loader_new.py
│   │   ├── zipreader.py
│   │   ├── build_new.py
│   │   ├── build.py
│   │   ├── saliency_loader.py
│   │   └── cached_image_folder.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── loss.cpython-37.pyc
│   │   │   ├── build.cpython-37.pyc
│   │   │   ├── metric.cpython-37.pyc
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── basemodel.cpython-37.pyc
│   │   │   ├── imgproc.cpython-37.pyc
│   │   │   ├── networks.cpython-37.pyc
│   │   │   ├── swin_mlp.cpython-37.pyc
│   │   │   ├── craft_utils.cpython-37.pyc
│   │   │   ├── swin_transformer.cpython-37.pyc
│   │   │   ├── saliency_detector.cpython-37.pyc
│   │   │   └── sswin_transformer.cpython-37.pyc
│   │   ├── build.py
│   │   ├── imgproc.py
│   │   ├── basemodel.py
│   │   ├── metric.py
│   │   ├── loss.py
│   │   ├── craft_utils.py
│   │   └── saliency_detector.py
│   ├── .idea
│   │   ├── .gitignore
│   │   ├── vcs.xml
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── modules.xml
│   │   └── e-commercial.iml
│   ├── __pycache__
│   │   ├── config.cpython-37.pyc
│   │   ├── logger.cpython-37.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── config.cpython-310.pyc
│   │   ├── optimizer.cpython-37.pyc
│   │   └── lr_scheduler.cpython-37.pyc
│   ├── configs
│   │   ├── test.yaml
│   │   ├── finetune.yaml
│   │   └── sswin.yaml
│   ├── logger.py
│   ├── optimizer.py
│   ├── compare.py
│   ├── lr_scheduler.py
│   ├── utils.py
│   ├── config.py
│   ├── main.py
│   ├── train.py
│   ├── metrics.py
│   └── main_cnn.py
├── .gitattributes
├── LICENSE
├── README.md
└── .history
    └── e-commercial
        ├── config_20240408005816.py
        ├── config_20240330210315.py
        └── config_20240408005800.py
/e-commercial/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_loader
--------------------------------------------------------------------------------
/e-commercial/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/e-commercial/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/e-commercial/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/e-commercial/__pycache__/optimizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/optimizer.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/build.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/build.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/__pycache__/lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/samplers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/samplers.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/build.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/build.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/metric.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/build_new.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/build_new.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/zipreader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/zipreader.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/basemodel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/basemodel.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/imgproc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/imgproc.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/networks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/networks.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/swin_mlp.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/swin_mlp.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/craft_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/craft_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/saliency_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/saliency_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/swin_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/swin_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/cached_image_folder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/cached_image_folder.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/data/__pycache__/saliency_loader_new.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/saliency_loader_new.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/saliency_detector.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/saliency_detector.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/models/__pycache__/sswin_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/sswin_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/e-commercial/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/e-commercial/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/e-commercial/configs/test.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 896
3 | MODEL:
4 | TYPE: sswin
5 | NAME: densenet
6 | DROP_PATH_RATE: 0.2
7 | SWIN:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 3, 6, 12, 24 ]
11 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/e-commercial/configs/finetune.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 896
3 | MODEL:
4 | TYPE: sswin
5 | NAME: densenet
6 | DROP_PATH_RATE: 0.2
7 | SWIN:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 3, 6, 12, 24 ]
11 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/e-commercial/configs/sswin.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMG_SIZE: 896
3 | MODEL:
4 | TYPE: sswin
5 | NAME: sswin_patch4_window7
6 | DROP_PATH_RATE: 0.2
7 | SWIN:
8 | EMBED_DIM: 96
9 | DEPTHS: [ 2, 2, 6, 2 ]
10 | NUM_HEADS: [ 3, 6, 12, 24 ]
11 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
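
Each of these YAML files overrides only a few fields; the full default tree lives in `e-commercial/config.py` as a yacs `CfgNode` (yacs is pinned in the README requirements). A minimal sketch of the merge mechanism, with the default values below invented for illustration rather than copied from the repo's config.py:

```python
# Hedged sketch: hypothetical defaults merged with configs/sswin.yaml via yacs.
from yacs.config import CfgNode as CN

_C = CN()
_C.DATA = CN()
_C.DATA.IMG_SIZE = 224
_C.MODEL = CN()
_C.MODEL.TYPE = 'swin'
_C.MODEL.NAME = ''
_C.MODEL.DROP_PATH_RATE = 0.1
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7

config = _C.clone()
config.merge_from_file('configs/sswin.yaml')    # YAML values override the defaults
print(config.DATA.IMG_SIZE, config.MODEL.NAME)  # 896 sswin_patch4_window7
```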
/e-commercial/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/e-commercial/.idea/e-commercial.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/e-commercial/data/samplers.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | import torch
6 |
7 |
8 | class SubsetRandomSampler(torch.utils.data.Sampler):
9 | r"""Samples elements randomly from a given list of indices, without replacement.
10 |
11 | Arguments:
12 | indices (sequence): a sequence of indices
13 | """
14 |
15 | def __init__(self, indices):
16 | self.epoch = 0
17 | self.indices = indices
18 |
19 | def __iter__(self):
20 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
21 |
22 | def __len__(self):
23 | return len(self.indices)
24 |
25 | def set_epoch(self, epoch):
26 | self.epoch = epoch
27 |
--------------------------------------------------------------------------------
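
A usage sketch of the sampler above, assuming the `e-commercial` directory is on `sys.path`: it restricts a `DataLoader` to a fixed subset of indices and reshuffles that subset on every pass.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from data.samplers import SubsetRandomSampler  # assumes e-commercial/ on sys.path

dataset = TensorDataset(torch.arange(100).float())
sampler = SubsetRandomSampler(list(range(0, 100, 2)))  # even indices only
loader = DataLoader(dataset, sampler=sampler, batch_size=8)
for (batch,) in loader:
    pass  # each epoch visits the 50 selected indices in a fresh random order
```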
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 YiFei Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/e-commercial/models/build.py:
--------------------------------------------------------------------------------
1 | from .sswin_transformer import SSwinTransformer
2 |
3 |
4 | def build_model(config):
5 | model_type = config.MODEL.TYPE
6 | if model_type == 'sswin':
7 | model = SSwinTransformer(img_size=config.DATA.IMG_SIZE,
8 | patch_size=config.MODEL.SWIN.PATCH_SIZE,
9 | in_chans=config.MODEL.SWIN.IN_CHANS,
10 | num_classes=config.MODEL.NUM_CLASSES,
11 | embed_dim=config.MODEL.SWIN.EMBED_DIM,
12 | depths=config.MODEL.SWIN.DEPTHS,
13 | num_heads=config.MODEL.SWIN.NUM_HEADS,
14 | window_size=config.MODEL.SWIN.WINDOW_SIZE,
15 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
16 | qkv_bias=config.MODEL.SWIN.QKV_BIAS,
17 | qk_scale=config.MODEL.SWIN.QK_SCALE,
18 | drop_rate=config.MODEL.DROP_RATE,
19 | drop_path_rate=config.MODEL.DROP_PATH_RATE,
20 | ape=config.MODEL.SWIN.APE,
21 | patch_norm=config.MODEL.SWIN.PATCH_NORM,
22 | use_checkpoint=config.TRAIN.USE_CHECKPOINT,
23 | head=config.HEAD)
24 | else:
25 |         raise NotImplementedError(f"Unknown model: {model_type}")
26 |
27 | return model
--------------------------------------------------------------------------------
/e-commercial/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | import os
6 | import sys
7 | import logging
8 | import functools
9 | from termcolor import colored
10 |
11 |
12 | @functools.lru_cache()
13 | def create_logger(output_dir, dist_rank=0, name=''):
14 | # create logger
15 | logger = logging.getLogger(name)
16 | logger.setLevel(logging.DEBUG)
17 | logger.propagate = False
18 |
19 | # create formatter
20 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
21 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
22 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
23 |
24 | # create console handlers for master process
25 | if dist_rank == 0:
26 | console_handler = logging.StreamHandler(sys.stdout)
27 | console_handler.setLevel(logging.DEBUG)
28 | console_handler.setFormatter(
29 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
30 | logger.addHandler(console_handler)
31 |
32 | # create file handlers
33 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
34 | file_handler.setLevel(logging.DEBUG)
35 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
36 | logger.addHandler(file_handler)
37 |
38 | return logger
39 |
--------------------------------------------------------------------------------
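
Usage is a one-liner; the `output` directory below is a hypothetical example. Rank 0 gets a colored console handler, and every rank appends to its own `log_rank{N}.txt`:

```python
import os
from logger import create_logger  # assumes e-commercial/ on sys.path

os.makedirs('output', exist_ok=True)  # the dir must exist for the FileHandler
logger = create_logger(output_dir='output', dist_rank=0, name='ecom')
logger.info('training started')  # stdout (colored) + output/log_rank0.txt
```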
/e-commercial/optimizer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | from torch import optim as optim
6 |
7 |
8 | def build_optimizer(config, model):
9 | """
10 | Build optimizer, set weight decay of normalization to 0 by default.
11 | """
12 | skip = {}
13 | skip_keywords = {}
14 | if hasattr(model, 'no_weight_decay'):
15 | skip = model.no_weight_decay()
16 | if hasattr(model, 'no_weight_decay_keywords'):
17 | skip_keywords = model.no_weight_decay_keywords()
18 | parameters = set_weight_decay(model, skip, skip_keywords)
19 |
20 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
21 | optimizer = None
22 | if opt_lower == 'sgd':
23 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
24 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
25 | elif opt_lower == 'adamw':
26 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
27 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
28 |
29 | return optimizer
30 |
31 |
32 | def set_weight_decay(model, skip_list=(), skip_keywords=()):
33 | has_decay = []
34 | no_decay = []
35 |
36 | for name, param in model.named_parameters():
37 | if not param.requires_grad:
38 | continue # frozen weights
39 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
40 | check_keywords_in_name(name, skip_keywords):
41 | no_decay.append(param)
42 | # print(f"{name} has no weight decay")
43 | else:
44 | has_decay.append(param)
45 | return [{'params': has_decay},
46 | {'params': no_decay, 'weight_decay': 0.}]
47 |
48 |
49 | def check_keywords_in_name(name, keywords=()):
50 | isin = False
51 | for keyword in keywords:
52 | if keyword in name:
53 | isin = True
54 | return isin
55 |
--------------------------------------------------------------------------------
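
A quick sanity check of `set_weight_decay` on a toy model: biases and 1-D parameters such as normalization weights land in the zero-decay group.

```python
import torch.nn as nn
from optimizer import set_weight_decay  # assumes e-commercial/ on sys.path

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = set_weight_decay(model)
print(len(groups[0]['params']))  # 1: the Linear weight (gets weight decay)
print(len(groups[1]['params']))  # 3: Linear bias + LayerNorm weight and bias
```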
/e-commercial/models/imgproc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | from skimage import io
9 | import cv2
10 |
11 | def loadImage(img_file):
12 | img = io.imread(img_file) # RGB order
13 | if img.shape[0] == 2: img = img[0]
14 | if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
15 | if img.shape[2] == 4: img = img[:,:,:3]
16 | img = np.array(img)
17 |
18 | return img
19 |
20 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
21 | # should be RGB order
22 | img = in_img.copy().astype(np.float32)
23 |
24 | img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
25 | img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
26 | return img
27 |
28 | def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
29 | # should be RGB order
30 | img = in_img.copy()
31 | img *= variance
32 | img += mean
33 | img *= 255.0
34 | img = np.clip(img, 0, 255).astype(np.uint8)
35 | return img
36 |
37 | def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
38 | height, width, channel = img.shape
39 |
40 | # magnify image size
41 | target_size = mag_ratio * max(height, width)
42 |
43 | # set original image size
44 | if target_size > square_size:
45 | target_size = square_size
46 |
47 | ratio = target_size / max(height, width)
48 |
49 | target_h, target_w = int(height * ratio), int(width * ratio)
50 | proc = cv2.resize(img, (target_w, target_h), interpolation = interpolation)
51 |
52 |
53 | # make canvas and paste image
54 | target_h32, target_w32 = target_h, target_w
55 | if target_h % 32 != 0:
56 | target_h32 = target_h + (32 - target_h % 32)
57 | if target_w % 32 != 0:
58 | target_w32 = target_w + (32 - target_w % 32)
59 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
60 | resized[0:target_h, 0:target_w, :] = proc
61 | target_h, target_w = target_h32, target_w32
62 |
63 | size_heatmap = (int(target_w/2), int(target_h/2))
64 |
65 | return resized, ratio, size_heatmap
66 |
67 | def cvt2HeatmapImg(img):
68 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
69 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
70 | return img
71 |
--------------------------------------------------------------------------------
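
A sketch of the two main helpers on a synthetic image standing in for `loadImage(...)` output: `normalizeMeanVariance` expects RGB values in 0-255, and `resize_aspect_ratio` zero-pads the resized image up to multiples of 32.

```python
import numpy as np
import cv2
from models.imgproc import normalizeMeanVariance, resize_aspect_ratio  # assumes e-commercial/ on sys.path

img = (np.random.rand(450, 600, 3) * 255).astype(np.uint8)  # stand-in for loadImage(...)
norm = normalizeMeanVariance(img)
resized, ratio, size_heatmap = resize_aspect_ratio(
    norm, square_size=1280, interpolation=cv2.INTER_LINEAR)
print(resized.shape)   # (480, 608, 3): 450x600 padded up to multiples of 32
print(ratio)           # 1.0 (600 <= square_size with mag_ratio=1)
print(size_heatmap)    # (304, 240): half of (width, height)
```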
/e-commercial/compare.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms, utils
3 | from PIL import Image
4 | import argparse
5 | import os
6 | import glob
7 | from metrics import auc_judd, nss
8 |
9 | import torch.nn.functional as F
10 |
11 | # Read the data
12 | toTensor = transforms.Compose([
13 | transforms.Resize(720),
14 | transforms.ToTensor(),
15 | ])
16 |
17 |
18 | def read(imgDir: str) -> torch.Tensor:
19 | x = toTensor(Image.open(imgDir))
20 | return x
21 |
22 |
23 | def calAll(path1: str, path2: str, name: str):
24 | aucAll = 0
25 | nssAll = 0
26 | img_list1 = glob.glob(f"{path1}/*.jpg")
27 | tbs = len(img_list1)
28 | with torch.no_grad():
29 | for i, path in enumerate(img_list1):
30 | # imgList1 = []
31 | # imgList2 = []
32 | # print(f"[{i:05}|{tbs}] | {1}", end="\r")
33 | cn = path.split("/")[-1].split(".")[0]
34 | # breakpoint()
35 | img1 = read(path)
36 | img2 = read(f"{path2}/{cn}_fixMap.jpg")
37 | # if i in allwrong:
38 | # cnt += 1
39 | # print(f"weird case {i} jump {cnt=}")
40 | # continue
41 | # breakpoint()
42 | # imgList1.append(img1)
43 | # imgList2.append(img2)
44 | # img1 = torch.stack(imgList1).cuda()
45 | # img2 = torch.stack(imgList2).cuda()
46 |
47 | singauc = auc_judd(img1, img2)
48 | singnss = nss(img1, img2)
49 | print(
50 | f"[{i:05}|{tbs}]\t\t\t\tauc path2 {(singauc):.4f} | nss {(singnss):.4f}", end="\r")
51 | with open(f"record/{name}/record", "a") as f:
52 | f.write(f"[{i:05}|{tbs}]\t\t\t\tauc path2 {(singauc):.4f} | nss {(singnss):.4f}\n")
53 |
54 | aucAll += singauc
55 | nssAll += singnss
56 | print(f"{path1} {(aucAll / tbs):.4f} | {(nssAll / tbs):.4f}")
57 | with open(f"record/all_record.txt", "a") as f:
58 | f.write(f"{path1} | auc nss | {(aucAll / tbs):.4f} | {(nssAll / tbs):.4f}\n")
59 | # print(f"[INFO]: id ORI {(idCompAll / tbs):.4f} id {(idAll / tbs):.4f} psnr ORI {(qpsnrAll / tbs):.4f} psnr {(psnrCompAll / tbs):.4f} ")
60 |
61 |
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--path1', type=str)
64 | parser.add_argument('--path2', type=str, default="/mnt/hdd1/yifei/DATA/ECdata/ALLFIXATIONMAPS")
65 | args = parser.parse_args()
66 |
67 | PATH1 = args.path1
68 | PATH2 = args.path2
69 | dirlis = args.path1.split("/")
70 | name = f"{dirlis[-3]}_{dirlis[-2]}_{dirlis[-1]}"
71 | os.makedirs(f"record/{name}", exist_ok=True)
72 |
73 | calAll(PATH1, PATH2, name)
74 |
--------------------------------------------------------------------------------
/e-commercial/models/basemodel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | import torch.nn.functional as F
5 |
6 |
7 | def conv_1x1_bn(inp, oup):
8 | return nn.Sequential(
9 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
10 | nn.BatchNorm2d(oup),
11 | nn.ReLU6(inplace=True)
12 | )
13 |
14 |
15 | class InvertedResidual(nn.Module):
16 | def __init__(self, inp, oup, stride, expand_ratio, omit_stride=False,
17 | no_res_connect=False, dropout=0., bn_momentum=0.1,
18 | batchnorm=None):
19 | super().__init__()
20 | self.out_channels = oup
21 | self.stride = stride
22 | self.omit_stride = omit_stride
23 | self.use_res_connect = not no_res_connect and\
24 | self.stride == 1 and inp == oup
25 | self.dropout = dropout
26 | actual_stride = self.stride if not self.omit_stride else 1
27 | if batchnorm is None:
28 | def batchnorm(num_features):
29 | return nn.BatchNorm2d(num_features, momentum=bn_momentum)
30 |
31 | assert actual_stride in [1, 2]
32 |
33 | hidden_dim = round(inp * expand_ratio)
34 | if expand_ratio == 1:
35 | modules = [
36 | # dw
37 | nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1,
38 | groups=hidden_dim, bias=False),
39 | batchnorm(hidden_dim),
40 | nn.ReLU6(inplace=True),
41 | # pw-linear
42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
43 | batchnorm(oup),
44 | ]
45 | if self.dropout > 0:
46 | modules.append(nn.Dropout2d(self.dropout))
47 | self.conv = nn.Sequential(*modules)
48 | else:
49 | modules = [
50 | # pw
51 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
52 | batchnorm(hidden_dim),
53 | nn.ReLU6(inplace=True),
54 | # dw
55 | nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1,
56 | groups=hidden_dim, bias=False),
57 | batchnorm(hidden_dim),
58 | nn.ReLU6(inplace=True),
59 | # pw-linear
60 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
61 | batchnorm(oup),
62 | ]
63 | if self.dropout > 0:
64 | modules.insert(3, nn.Dropout2d(self.dropout))
65 | self.conv = nn.Sequential(*modules)
66 | self._initialize_weights()
67 |
68 | def forward(self, x):
69 | if self.use_res_connect:
70 | return x + self.conv(x)
71 | else:
72 | return self.conv(x)
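73 | 
74 |     def _initialize_weights(self):
75 |         # NOTE: assumed implementation; the method is called in __init__ but is
76 |         # missing from this snippet. Standard Kaiming/constant init, as in
77 |         # common MobileNetV2-style blocks.
78 |         for m in self.modules():
79 |             if isinstance(m, nn.Conv2d):
80 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
81 |                 if m.bias is not None:
82 |                     nn.init.zeros_(m.bias)
83 |             elif isinstance(m, nn.BatchNorm2d):
84 |                 nn.init.ones_(m.weight)
85 |                 nn.init.zeros_(m.bias)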
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # E-commercial-dataset
2 | A dataset of e-commercial images for saliency prediction.
3 |
4 | The dataset can be downloaded from https://www.dropbox.com/s/xsui782oy3kvjsm/E-commercial%20dataset.zip?dl=0.
5 |
6 | ## IMAGE
7 |
8 | Original images are saved under this path as *.jpg files.
9 |
10 | ## FIXATION
11 |
12 | Fixation maps are saved as *\_fixPts.jpg, while saliency maps are saved as *\_fixMap.jpg.
13 |
14 | ## TEXT REGION
15 |
16 | The text detection results are stored in CSV files: one with the affinity scores and one with the region scores.
17 |
18 | # SSwin-Transformer Model Added to This Repo
22 | ## To-do list
23 | 1. [x] Add environment setup (as a temporary alternative, you can use the same environment as Swin Transformer)
24 | 2. [ ] Refactor the code to be more efficient
25 |
26 | ### Environment preparing
27 | - Clone this repo:
28 |
29 | ```bash
30 | git clone https://github.com/leafy-lee/E-commercial-dataset.git
31 | cd e-commercial
32 | ```
33 |
34 | - Create a conda virtual environment and activate it:
35 |
36 | ```bash
37 | conda create -n ecom python=3.7 -y
38 | conda activate ecom
39 | ```
40 |
41 | - Install `CUDA>=10.2` with `cudnn>=7` following
42 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
43 | - Install `PyTorch>=1.8.0` and `torchvision>=0.9.0` with `CUDA>=10.2`:
44 |
45 | ```bash
46 | conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch
47 | ```
48 |
49 | - Install `timm==0.4.12`:
50 |
51 | ```bash
52 | pip install timm==0.4.12
53 | ```
54 |
55 | - Install other requirements:
56 |
57 | ```bash
58 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy
59 | ```
60 | ### Training
61 |
62 | To train the model, run:
63 |
64 | ```bash
65 | python train.py --batch-size 8 --cfg configs/sswin.yaml --data-path DATA/ECdata/ --dataset ecdata --head headname
66 | ```
67 |
68 |
69 |
70 | ### Evaluation
71 |
72 | To evaluate a trained model, run:
73 |
74 | ```bash
75 | python main.py --eval --cfg config --resume True --finetune ckpt --data-path data_dir
76 | ```
77 |
78 | ## Citation
79 | If you use this code, please cite:
80 | ```
81 | @InProceedings{Jiang_2022_CVPR,
82 | author = {Jiang, Lai and Li, Yifei and Li, Shengxi and Xu, Mai and Lei, Se and Guo, Yichen and Huang, Bo},
83 | title = {Does Text Attract Attention on E-Commerce Images: A Novel Saliency Prediction Dataset and Method},
84 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
85 | month = {June},
86 | year = {2022},
87 | pages = {2088-2097}
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
/e-commercial/models/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def calCC(gtsAnns, resAnns, is_train):
6 | if is_train:
7 | gtsAnn = gtsAnns[0, ...].detach().clone()
8 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0)
9 | gtsAnn = gtsAnn.cpu().float().detach().numpy()
10 | resAnn = resAnns[0, ...].detach().clone()
11 | resAnn = torch.cat([resAnn, resAnn, resAnn], 0)
12 | resAnn = resAnn.cpu().float().detach().numpy()
13 | fixationMap = gtsAnn - np.mean(gtsAnn)
14 | if np.max(fixationMap) > 0:
15 | fixationMap = fixationMap / np.std(fixationMap)
16 | salMap = resAnn - np.mean(resAnn)
17 | if np.max(salMap) > 0:
18 | salMap = salMap / np.std(salMap)
19 | return np.corrcoef(salMap.reshape(-1), fixationMap.reshape(-1))[0][1]
20 | else:
21 | cc = 0
22 | for idx, gtsAnn in enumerate(gtsAnns):
23 | gtsAnn = gtsAnn.detach().clone()
24 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0)
25 | gtsAnn = gtsAnn.cpu().float().detach().numpy()
26 | resAnn = resAnns[idx].detach().clone()
27 | resAnn = torch.cat([resAnn, resAnn, resAnn], 0)
28 | resAnn = resAnn.cpu().float().detach().numpy()
29 | fixationMap = gtsAnn - np.mean(gtsAnn)
30 | if np.max(fixationMap) > 0:
31 | fixationMap = fixationMap / np.std(fixationMap)
32 | salMap = resAnn - np.mean(resAnn)
33 | if np.max(salMap) > 0:
34 | salMap = salMap / np.std(salMap)
35 | cc += np.corrcoef(salMap.reshape(-1), fixationMap.reshape(-1))[0][1]
36 | return cc / gtsAnns.size()[0]
37 |
38 |
39 | def calKL(gtsAnns, resAnns, is_train, eps=1e-7):
40 | if is_train:
41 | gtsAnn = gtsAnns[0, ...].detach().clone()
42 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0)
43 | gtsAnn = gtsAnn.cpu().float().detach().numpy()
44 | resAnn = resAnns[0, ...].detach().clone()
45 | resAnn = torch.cat([resAnn, resAnn, resAnn], 0)
46 | resAnn = resAnn.cpu().float().detach().numpy()
47 | if np.sum(gtsAnn) > 0:
48 | gtsAnn = gtsAnn / np.sum(gtsAnn)
49 | if np.sum(resAnn) > 0:
50 | resAnn = resAnn / np.sum(resAnn)
51 | return np.sum(gtsAnn * np.log(eps + gtsAnn / (resAnn + eps)))
52 | else:
53 | kl = 0
54 | for idx, gtsAnn in enumerate(gtsAnns):
55 | gtsAnn = gtsAnn.detach().clone()
56 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0)
57 | gtsAnn = gtsAnn.cpu().float().detach().numpy()
58 | resAnn = resAnns[idx].detach().clone()
59 | resAnn = torch.cat([resAnn, resAnn, resAnn], 0)
60 | resAnn = resAnn.cpu().float().detach().numpy()
61 | if np.sum(gtsAnn) > 0:
62 | gtsAnn = gtsAnn / np.sum(gtsAnn)
63 | if np.sum(resAnn) > 0:
64 | resAnn = resAnn / np.sum(resAnn)
65 | kl += np.sum(gtsAnn * np.log(eps + gtsAnn / (resAnn + eps)))
66 | return kl / gtsAnns.size()[0]
--------------------------------------------------------------------------------
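
A sanity check, with shapes assumed from how the metrics index their inputs (`[batch, channel, H, W]` CPU tensors): comparing a map with itself gives CC close to 1 and KL close to 0.

```python
import torch
from models.metric import calCC, calKL  # assumes e-commercial/ on sys.path

sal = torch.rand(2, 1, 64, 64)          # [batch, channel, H, W] saliency maps
print(calCC(sal, sal, is_train=False))  # ~1.0
print(calKL(sal, sal, is_train=False))  # ~0.0
```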
/e-commercial/data/saliency_loader_new.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torch, csv
4 | from PIL import Image
5 | from torchvision import transforms
6 | from torch.utils.data import Dataset
7 | import torch.nn.functional as F
8 | import random
9 | import glob
10 | import numpy as np
11 | random.seed(1)
12 | globaltest = []
13 |
14 | def ecommercedata(data_path, img_nums, is_train):
15 | global globaltest
16 | nums = [i for i in range(1, img_nums + 1)]
17 | if not globaltest:
18 | globaltest = random.sample(nums, img_nums // 10)
19 | test = globaltest
20 | train = list(set(nums) - set(test))
21 | if is_train:
22 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums - (img_nums // 10)),
23 | lis=train)
24 | else:
25 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums // 10),
26 | lis=test)
27 |
28 |
29 | class EcommerceDataset(Dataset):
30 |     """E-commercial saliency dataset."""
31 |
32 | def __init__(self, data_path, img_nums, transform=None, lis=[]):
33 | """
34 | Args:
35 | data_path (string): Path to the imgs file with saliency.
36 | img_nums (int): Total number of images to index.
37 | transform (callable, optional): Optional transform to be applied
38 | on a sample.
39 | """
40 | self.root_dir = data_path
41 | self.transform = transform
42 | self.data_len = img_nums
43 | self.lis = lis
44 |
45 | def __len__(self):
46 | return self.data_len
47 |
48 | def __getitem__(self, idx):
49 | if torch.is_tensor(idx):
50 | idx = idx.tolist()
51 | # print("idx", idx, "length", len(self.lis))
52 |
53 | img_name = 'ALLSTIMULI/' + str(self.lis[idx]) + '.jpg'
54 | saliency_name = 'ALLFIXATIONMAPS/' + str(self.lis[idx]) + '_fixMap.jpg'
55 | ocr_aff_name = 'OCR/affinity/' + str(self.lis[idx]) + '.csv'
56 | ocr_reg_name = 'OCR/region/' + str(self.lis[idx]) + '.csv'
57 | img_file = os.path.join(self.root_dir, img_name)
58 | saliency_file = os.path.join(self.root_dir, saliency_name)
59 | ocr_aff_file = os.path.join(self.root_dir, ocr_aff_name)
60 | ocr_reg_file = os.path.join(self.root_dir, ocr_reg_name)
61 |
62 | # images
63 | img = Image.open(img_file)
64 | # img = np.array(img, dtype=np.float32) # h, w, c
65 | torch_img = transforms.functional.to_tensor(img)
66 | torch_img = transforms.Resize(896)(torch_img)
67 | # saliency
68 | saliency = Image.open(saliency_file)
69 | # saliency = np.array(saliency, dtype=np.float32) # h, w, c
70 | torch_saliency = transforms.functional.to_tensor(saliency)
71 | torch_saliency = transforms.Resize(896)(torch_saliency)
72 | # ocr
73 | csv_aff_content = np.loadtxt(open(ocr_aff_file, "rb"), delimiter=",")
74 | csv_reg_content = np.loadtxt(open(ocr_reg_file, "rb"), delimiter=",")
75 | # to make pytorch happy in transforms
76 | csv_aff_image = np.expand_dims(csv_aff_content, axis=0)
77 | csv_reg_image = np.expand_dims(csv_reg_content, axis=0)
78 | torch_aff = torch.from_numpy(csv_aff_image).float()
79 | torch_reg = torch.from_numpy(csv_reg_image).float()
80 | gh_label = transforms.Resize(896)(torch_aff)
81 | gah_label = transforms.Resize(896)(torch_reg)
82 | # print('The sizes of gh_label and gah_label are :', gh_label.size(), gah_label.size())
83 | if self.transform:
84 | raise NotImplementedError("Not support any transform by far!")
85 | # sample = self.transform(sample)
86 |
87 | return torch_img, torch_saliency
88 |
--------------------------------------------------------------------------------
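
A usage sketch with a placeholder data root: `ecommercedata` holds out a fixed random 10% of the images named `1.jpg` .. `N.jpg` as the test split, and each item is an (image, fixation map) pair resized so the shorter side is 896.

```python
from data.saliency_loader_new import ecommercedata  # assumes e-commercial/ on sys.path

root = '/path/to/ECdata'  # placeholder: must contain ALLSTIMULI/, ALLFIXATIONMAPS/, OCR/
train_set = ecommercedata(root, img_nums=1000, is_train=True)   # 900 images
val_set = ecommercedata(root, img_nums=1000, is_train=False)    # the 100 held-out images
img, saliency = train_set[0]  # tensors with the shorter side resized to 896
```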
/e-commercial/data/zipreader.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer Zipreader
3 | # --------------------------------------------------------
4 |
5 | import os
6 | import zipfile
7 | import io
8 | import numpy as np
9 | from PIL import Image
10 | from PIL import ImageFile
11 |
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 |
15 | def is_zip_path(img_or_path):
16 | """judge if this is a zip path"""
17 | return '.zip@' in img_or_path
18 |
19 |
20 | class ZipReader(object):
21 | """A class to read zipped files"""
22 | zip_bank = dict()
23 |
24 | def __init__(self):
25 | super(ZipReader, self).__init__()
26 |
27 | @staticmethod
28 | def get_zipfile(path):
29 | zip_bank = ZipReader.zip_bank
30 | if path not in zip_bank:
31 | zfile = zipfile.ZipFile(path, 'r')
32 | zip_bank[path] = zfile
33 | return zip_bank[path]
34 |
35 | @staticmethod
36 | def split_zip_style_path(path):
37 |         pos_at = path.find('@')  # find() returns -1 when missing, so the assert below can fire
38 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
39 |
40 | zip_path = path[0: pos_at]
41 | folder_path = path[pos_at + 1:]
42 | folder_path = str.strip(folder_path, '/')
43 | return zip_path, folder_path
44 |
45 | @staticmethod
46 | def list_folder(path):
47 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
48 |
49 | zfile = ZipReader.get_zipfile(zip_path)
50 | folder_list = []
51 | for file_foler_name in zfile.namelist():
52 | file_foler_name = str.strip(file_foler_name, '/')
53 | if file_foler_name.startswith(folder_path) and \
54 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \
55 | file_foler_name != folder_path:
56 | if len(folder_path) == 0:
57 | folder_list.append(file_foler_name)
58 | else:
59 | folder_list.append(file_foler_name[len(folder_path) + 1:])
60 |
61 | return folder_list
62 |
63 | @staticmethod
64 | def list_files(path, extension=None):
65 | if extension is None:
66 | extension = ['.*']
67 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
68 |
69 | zfile = ZipReader.get_zipfile(zip_path)
70 | file_lists = []
71 | for file_foler_name in zfile.namelist():
72 | file_foler_name = str.strip(file_foler_name, '/')
73 | if file_foler_name.startswith(folder_path) and \
74 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension:
75 | if len(folder_path) == 0:
76 | file_lists.append(file_foler_name)
77 | else:
78 | file_lists.append(file_foler_name[len(folder_path) + 1:])
79 |
80 | return file_lists
81 |
82 | @staticmethod
83 | def read(path):
84 | zip_path, path_img = ZipReader.split_zip_style_path(path)
85 | zfile = ZipReader.get_zipfile(zip_path)
86 | data = zfile.read(path_img)
87 | return data
88 |
89 | @staticmethod
90 | def imread(path):
91 | zip_path, path_img = ZipReader.split_zip_style_path(path)
92 | zfile = ZipReader.get_zipfile(zip_path)
93 | data = zfile.read(path_img)
94 | try:
95 | im = Image.open(io.BytesIO(data))
96 |         except Exception:
97 | print("ERROR IMG LOADED: ", path_img)
98 | random_img = np.random.rand(224, 224, 3) * 255
99 | im = Image.fromarray(np.uint8(random_img))
100 | return im
101 |
--------------------------------------------------------------------------------
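
Zip-style paths have the form `<archive>.zip@/<member>`, and opened archives are cached in `ZipReader.zip_bank`. The archive and member below are hypothetical:

```python
from data.zipreader import is_zip_path, ZipReader  # assumes e-commercial/ on sys.path

path = 'images.zip@/train/00001.jpg'  # hypothetical archive and member
assert is_zip_path(path)
raw = ZipReader.read(path)    # raw bytes of the member
img = ZipReader.imread(path)  # PIL.Image (falls back to a random image on decode errors)
```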
/e-commercial/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | import torch
6 | from timm.scheduler.cosine_lr import CosineLRScheduler
7 | from timm.scheduler.step_lr import StepLRScheduler
8 | from timm.scheduler.scheduler import Scheduler
9 |
10 |
11 | def build_scheduler(config, optimizer, n_iter_per_epoch):
12 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
13 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
14 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
15 |
16 | lr_scheduler = None
17 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
18 | lr_scheduler = CosineLRScheduler(
19 | optimizer,
20 | t_initial=num_steps,
21 | t_mul=1.,
22 | lr_min=config.TRAIN.MIN_LR,
23 | warmup_lr_init=config.TRAIN.WARMUP_LR,
24 | warmup_t=warmup_steps,
25 | cycle_limit=1,
26 | t_in_epochs=False,
27 | )
28 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
29 | lr_scheduler = LinearLRScheduler(
30 | optimizer,
31 | t_initial=num_steps,
32 | lr_min_rate=0.01,
33 | warmup_lr_init=config.TRAIN.WARMUP_LR,
34 | warmup_t=warmup_steps,
35 | t_in_epochs=False,
36 | )
37 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
38 | lr_scheduler = StepLRScheduler(
39 | optimizer,
40 | decay_t=decay_steps,
41 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
42 | warmup_lr_init=config.TRAIN.WARMUP_LR,
43 | warmup_t=warmup_steps,
44 | t_in_epochs=False,
45 | )
46 |
47 | return lr_scheduler
48 |
49 |
50 | class LinearLRScheduler(Scheduler):
51 | def __init__(self,
52 | optimizer: torch.optim.Optimizer,
53 | t_initial: int,
54 | lr_min_rate: float,
55 | warmup_t=0,
56 | warmup_lr_init=0.,
57 | t_in_epochs=True,
58 | noise_range_t=None,
59 | noise_pct=0.67,
60 | noise_std=1.0,
61 | noise_seed=42,
62 | initialize=True,
63 | ) -> None:
64 | super().__init__(
65 | optimizer, param_group_field="lr",
66 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
67 | initialize=initialize)
68 |
69 | self.t_initial = t_initial
70 | self.lr_min_rate = lr_min_rate
71 | self.warmup_t = warmup_t
72 | self.warmup_lr_init = warmup_lr_init
73 | self.t_in_epochs = t_in_epochs
74 | if self.warmup_t:
75 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
76 | super().update_groups(self.warmup_lr_init)
77 | else:
78 | self.warmup_steps = [1 for _ in self.base_values]
79 |
80 | def _get_lr(self, t):
81 | if t < self.warmup_t:
82 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
83 | else:
84 | t = t - self.warmup_t
85 | total_t = self.t_initial - self.warmup_t
86 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
87 | return lrs
88 |
89 | def get_epoch_values(self, epoch: int):
90 | if self.t_in_epochs:
91 | return self._get_lr(epoch)
92 | else:
93 | return None
94 |
95 | def get_update_values(self, num_updates: int):
96 | if not self.t_in_epochs:
97 | return self._get_lr(num_updates)
98 | else:
99 | return None
100 |
--------------------------------------------------------------------------------
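
A numeric sketch of `LinearLRScheduler` in per-step mode (`t_in_epochs=False`, as `build_scheduler` configures it): the LR ramps from `warmup_lr_init` to the base LR over `warmup_t` steps, then decays linearly toward `base_lr * lr_min_rate`.

```python
import torch
from lr_scheduler import LinearLRScheduler  # assumes e-commercial/ on sys.path

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = LinearLRScheduler(opt, t_initial=100, lr_min_rate=0.01,
                          warmup_t=10, warmup_lr_init=1e-4, t_in_epochs=False)
for step in (0, 10, 55, 100):
    print(step, sched._get_lr(step))
# 0 -> [0.0001] (warmup start), 10 -> [0.1] (base LR),
# 55 -> [0.0505] (halfway through decay), 100 -> [0.001] (base LR * lr_min_rate)
```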
/e-commercial/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | ####################################################################
7 | # -------------------------- Losses --------------------------
8 | ####################################################################
9 |
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 |
15 |
16 | class Maploss(nn.Module):
17 | def __init__(self, use_gpu = True):
18 |
19 | super(Maploss,self).__init__()
20 |
21 | def single_image_loss(self, pre_loss, loss_label):
22 | batch_size = pre_loss.shape[0]
23 | sum_loss = torch.mean(pre_loss.view(-1))*0
24 | pre_loss = pre_loss.view(batch_size, -1)
25 | loss_label = loss_label.view(batch_size, -1)
26 | internel = batch_size
27 | # print(loss_label.shape, pre_loss.shape)
28 | for i in range(batch_size):
29 | average_number = 0
30 | loss = torch.mean(pre_loss.view(-1)) * 0
31 | positive_pixel = len(pre_loss[i][(loss_label[i] >= 0.1)])
32 | average_number += positive_pixel
33 | if positive_pixel != 0:
34 | posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= 0.1)])
35 | sum_loss += posi_loss
36 | if len(pre_loss[i][(loss_label[i] < 0.1)]) < 3*positive_pixel:
37 | nega_loss = torch.mean(pre_loss[i][(loss_label[i] < 0.1)])
38 | average_number += len(pre_loss[i][(loss_label[i] < 0.1)])
39 | else:
40 | nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < 0.1)], 3*positive_pixel)[0])
41 | average_number += 3*positive_pixel
42 | sum_loss += nega_loss
43 | else:
44 | nega_loss = torch.mean(torch.topk(pre_loss[i], 500)[0])
45 | average_number += 500
46 | sum_loss += nega_loss
47 | #sum_loss += loss/average_number
48 |
49 | return sum_loss
50 |
51 |
52 |
53 | def forward(self, gh_label, gah_label, p_gh, p_gah, mask):
54 | gh_label = gh_label
55 | gah_label = gah_label
56 | p_gh = p_gh
57 | p_gah = p_gah
58 |         loss_fn = torch.nn.MSELoss(reduction='none')
59 | mask = mask.squeeze(1)
60 |
61 | assert p_gh.size() == gh_label.size() and p_gah.size() == gah_label.size()
62 | loss1 = loss_fn(p_gh, gh_label)
63 | loss2 = loss_fn(p_gah, gah_label)
64 | # print("loss1.shape, mask.shape", loss1.shape, mask.shape)
65 | # print("loss2.shape, mask.shape", loss2.shape, mask.shape)
66 | loss_g = torch.mul(loss1, mask)
67 | loss_a = torch.mul(loss2, mask)
68 | # print("loss shape", loss_g.shape, loss_a.shape, gah_label.shape, gh_label.shape)
69 |
70 | char_loss = self.single_image_loss(loss_g, gh_label)
71 | affi_loss = self.single_image_loss(loss_a, gah_label)
72 | return char_loss/loss_g.shape[0] + affi_loss/loss_a.shape[0]
73 |
74 |
75 | def NormAffine(mat, eps=1e-7,
76 | method='sum'): # tensor [batch_size, channels, image_height, image_width] normalize each fea map;
77 | matdim = len(mat.size())
78 | if method == 'sum':
79 | tempsum = torch.sum(mat, dim=(matdim - 1, matdim - 2), keepdim=True) + eps
80 | out = mat / tempsum
81 | elif method == 'one':
82 | (tempmin, _) = torch.min(mat, dim=matdim - 1, keepdim=True)
83 | (tempmin, _) = torch.min(tempmin, dim=matdim - 2, keepdim=True)
84 | tempmat = mat - tempmin
85 | (tempmax, _) = torch.max(tempmat, dim=matdim - 1, keepdim=True)
86 | (tempmax, _) = torch.max(tempmax, dim=matdim - 2, keepdim=True)
87 | tempmax = tempmax + eps
88 | out = tempmat / tempmax
89 | else:
90 | raise NotImplementedError('Map method [%s] is not implemented' % method)
91 | return out
92 |
93 |
94 | def KL_loss(out, gt):
95 | assert out.size() == gt.size()
96 | out = NormAffine(out, eps=1e-7, method='sum')
97 | gt = NormAffine(gt, eps=1e-7, method='sum')
98 | loss = torch.sum(gt * torch.log(1e-7 + gt / (out + 1e-7)))
99 | loss = loss / out.size()[0]
100 | return loss
101 |
--------------------------------------------------------------------------------
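
A sanity check of `KL_loss`, with shapes assumed as `[batch, channel, H, W]`: identical maps give a loss near 0, and the loss grows with mismatch.

```python
import torch
from models.loss import KL_loss  # assumes e-commercial/ on sys.path

gt = torch.rand(4, 1, 112, 112)    # ground-truth saliency maps
pred = torch.rand(4, 1, 112, 112)  # predictions
print(KL_loss(gt, gt).item())      # ~0.0
print(KL_loss(pred, gt).item())    # > 0
```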
/e-commercial/data/build_new.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import numpy as np
11 | from torchvision import datasets, transforms
12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13 | from timm.data import Mixup
14 | from timm.data import create_transform
15 | from timm.data.transforms import _pil_interp
16 |
17 | from .cached_image_folder import CachedImageFolder
18 | from .samplers import SubsetRandomSampler
19 | from data.saliency_loader_new import ecommercedata
20 |
21 |
22 | def build_loader(config):
23 | config.defrost()
24 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
25 | config.freeze()
26 | print(f"local rank {config.LOCAL_RANK} successfully build train dataset")
27 | dataset_val, _ = build_dataset(is_train=False, config=config)
28 | print(f"local rank {config.LOCAL_RANK} successfully build val dataset")
29 |
30 | indices_train = np.arange(len(dataset_train))
31 | sampler_train = SubsetRandomSampler(indices_train)
32 | indices_val = np.arange(len(dataset_val))
33 | sampler_val = SubsetRandomSampler(indices_val)
34 |
35 | data_loader_train = torch.utils.data.DataLoader(
36 | dataset_train, sampler=sampler_train,
37 | batch_size=config.DATA.BATCH_SIZE,
38 | num_workers=config.DATA.NUM_WORKERS,
39 | pin_memory=config.DATA.PIN_MEMORY,
40 | drop_last=True,
41 | )
42 |
43 | data_loader_val = torch.utils.data.DataLoader(
44 | dataset_val, sampler=sampler_val,
45 | batch_size=config.DATA.BATCH_SIZE,
46 | shuffle=False,
47 | num_workers=config.DATA.NUM_WORKERS,
48 | pin_memory=config.DATA.PIN_MEMORY,
49 | drop_last=False
50 | )
51 |
52 | # setup mixup / cutmix
53 | mixup_fn = None
54 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
55 | if mixup_active:
56 | mixup_fn = Mixup(
57 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
58 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
59 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
60 |
61 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
62 |
63 |
64 | def build_dataset(is_train, config):
65 | transform = build_transform(is_train, config)
66 | if config.DATA.DATASET == 'ecdata':
67 | dataset = ecommercedata(config.DATA.DATA_PATH, config.DATA.DATANUM, is_train)
68 | nb_classes = 1
69 | else:
70 |         raise NotImplementedError(f"Unsupported dataset: {config.DATA.DATASET}")
71 |
72 | return dataset, nb_classes
73 |
74 |
75 | def build_transform(is_train, config):
76 | resize_im = config.DATA.IMG_SIZE > 32
77 | if is_train:
78 | # this should always dispatch to transforms_imagenet_train
79 | transform = create_transform(
80 | input_size=config.DATA.IMG_SIZE,
81 | is_training=True,
82 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
83 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
84 | re_prob=config.AUG.REPROB,
85 | re_mode=config.AUG.REMODE,
86 | re_count=config.AUG.RECOUNT,
87 | interpolation=config.DATA.INTERPOLATION,
88 | )
89 | if not resize_im:
90 | # replace RandomResizedCropAndInterpolation with
91 | # RandomCrop
92 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
93 | return transform
94 |
95 | t = []
96 | if resize_im:
97 | if config.TEST.CROP:
98 | size = int((256 / 224) * config.DATA.IMG_SIZE)
99 | t.append(
100 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
101 | # to maintain same ratio w.r.t. 224 images
102 | )
103 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
104 | else:
105 | t.append(
106 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
107 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
108 | )
109 |
110 | t.append(transforms.ToTensor())
111 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
112 | return transforms.Compose(t)
113 |
--------------------------------------------------------------------------------
/e-commercial/data/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | import os
6 | import torch
7 | import numpy as np
8 | import torch.distributed as dist
9 | from torchvision import datasets, transforms
10 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11 | from timm.data import Mixup
12 | from timm.data import create_transform
13 | from timm.data.transforms import _pil_interp
14 |
15 | from .cached_image_folder import CachedImageFolder
16 | from .samplers import SubsetRandomSampler
17 | from data.saliency_loader import ecommercedata
18 | from data.saliency_loader import finetunedata, folderimagedata
19 |
20 |
21 | def build_loader(config, logger):
22 | config.defrost()
23 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger)
24 | config.freeze()
25 | # print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
26 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger)
27 | # print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
28 | '''
29 | num_tasks = dist.get_world_size()
30 | global_rank = dist.get_rank()
31 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
32 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
33 | sampler_train = SubsetRandomSampler(indices)
34 | else:
35 | sampler_train = torch.utils.data.DistributedSampler(
36 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
37 | )
38 |
39 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
40 | sampler_val = SubsetRandomSampler(indices)
41 | '''
42 | data_loader_train = torch.utils.data.DataLoader(
43 | dataset_train,
44 | batch_size=config.DATA.BATCH_SIZE,
45 | num_workers=config.DATA.NUM_WORKERS,
46 | pin_memory=config.DATA.PIN_MEMORY,
47 | drop_last=True,
48 | )
49 |
50 | data_loader_val = torch.utils.data.DataLoader(
51 | dataset_val,
52 | batch_size=config.DATA.BATCH_SIZE,
53 | shuffle=False,
54 | num_workers=config.DATA.NUM_WORKERS,
55 | pin_memory=config.DATA.PIN_MEMORY,
56 | drop_last=False
57 | )
58 |
59 | # setup mixup / cutmix
60 | mixup_fn = None
61 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
62 | if mixup_active:
63 | mixup_fn = Mixup(
64 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
65 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
66 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
67 |
68 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
69 | # return dataset_train, data_loader_train, mixup_fn
70 |
71 |
72 | def build_dataset(is_train, config, logger=None):
73 | transform = build_transform(is_train, config)
74 | if config.DATA.DATASET == 'ecdata':
75 | dataset = ecommercedata(config.DATA.DATA_PATH, config.DATA.DATANUM, is_train, logger=logger)
76 | nb_classes = 1
77 | elif config.DATA.DATASET == 'finetunedata':
78 | dataset = finetunedata(config.DATA.DATA_PATH)
79 | nb_classes = 1
80 | elif config.DATA.DATASET == 'folderimagedata':
81 | dataset = folderimagedata(config.DATA.DATA_PATH)
82 | nb_classes = 1
83 | else:
84 |         raise NotImplementedError(f"Unsupported dataset: {config.DATA.DATASET}")
85 |
86 | return dataset, nb_classes
87 |
88 |
89 | def build_transform(is_train, config):
90 | resize_im = config.DATA.IMG_SIZE > 32
91 | if is_train:
92 | # this should always dispatch to transforms_imagenet_train
93 | transform = create_transform(
94 | input_size=config.DATA.IMG_SIZE,
95 | is_training=True,
96 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
97 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
98 | re_prob=config.AUG.REPROB,
99 | re_mode=config.AUG.REMODE,
100 | re_count=config.AUG.RECOUNT,
101 | interpolation=config.DATA.INTERPOLATION,
102 | )
103 | if not resize_im:
104 | # replace RandomResizedCropAndInterpolation with
105 | # RandomCrop
106 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
107 | return transform
108 |
109 | t = []
110 | if resize_im:
111 | if config.TEST.CROP:
112 | size = int((256 / 224) * config.DATA.IMG_SIZE)
113 | t.append(
114 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
115 | # to maintain same ratio w.r.t. 224 images
116 | )
117 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
118 | else:
119 | t.append(
120 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
121 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
122 | )
123 |
124 | t.append(transforms.ToTensor())
125 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
126 | return transforms.Compose(t)
127 |
--------------------------------------------------------------------------------
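For reference, a minimal sketch of wiring build_loader above into a training script; the argparse namespace and every path/value here are illustrative placeholders, not repository defaults, and it assumes configs/sswin.yaml from the repo's configs folder is on disk.

import logging
import argparse
from config import get_config
from data import build_loader  # data/__init__.py re-exports build.build_loader

args = argparse.Namespace(
    cfg="configs/sswin.yaml",  # ships with the repo; any BASE configs are merged first
    opts=None, batch_size=2, data_path="/path/to/ec_root/",
    resume="", use_checkpoint=False, output="output", tag="demo",
    eval=False, dataset="ecdata", datanum=1000, num_epoch=50, finetune="")

logger = logging.getLogger("demo")
config = get_config(args)
dataset_train, dataset_val, loader_train, loader_val, mixup_fn = build_loader(config, logger)
# mixup_fn is non-None with the defaults, since AUG.MIXUP (0.8) and AUG.CUTMIX (1.0) are > 0.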
/e-commercial/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 |
6 | import os
7 | import torch
8 | import torch.distributed as dist
9 |
10 | try:
11 | # noinspection PyUnresolvedReferences
12 | from apex import amp
13 | except ImportError:
14 | amp = None
15 |
16 |
17 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
18 |     logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
19 | if config.MODEL.RESUME.startswith('https'):
20 | checkpoint = torch.hub.load_state_dict_from_url(
21 | config.MODEL.RESUME, map_location='cpu', check_hash=True)
22 | else:
23 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
24 | msg = model.load_state_dict(checkpoint['model'], strict=False)
25 | logger.info(msg)
26 | max_accuracy = 0.0
27 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
28 | optimizer.load_state_dict(checkpoint['optimizer'])
29 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
30 | config.defrost()
31 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
32 | config.freeze()
33 | if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
34 | amp.load_state_dict(checkpoint['amp'])
35 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
36 | if 'max_accuracy' in checkpoint:
37 | max_accuracy = checkpoint['max_accuracy']
38 |
39 | del checkpoint
40 | torch.cuda.empty_cache()
41 | return max_accuracy
42 |
43 | def load_checkpoint_finetune(config, model, optimizer, lr_scheduler, logger):
44 |     logger.info("==============> Finetuning....................")
45 |     print("loading from ", config.MODEL.FINETUNE)
46 |     if config.MODEL.FINETUNE.startswith('https'):
47 | checkpoint = torch.hub.load_state_dict_from_url(
48 | config.MODEL.FINETUNE, map_location='cpu', check_hash=True)
49 | else:
50 | checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')
51 | msg = model.load_state_dict(checkpoint['model'], strict=False)
52 | logger.info(msg)
53 | max_accuracy = 0.0
54 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
55 | optimizer.load_state_dict(checkpoint['optimizer'])
56 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
57 | config.defrost()
58 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
59 | config.freeze()
60 | logger.info(f"=> loaded successfully '{config.MODEL.FINETUNE}' (epoch {checkpoint['epoch']})")
61 |
62 | del checkpoint
63 | torch.cuda.empty_cache()
64 |
65 | def load_checkpoint_eval(config, model, optimizer, logger):
66 |     logger.info("==============> Evaluating....................")
67 |     print("loading from ", config.MODEL.FINETUNE)
68 |     if config.MODEL.FINETUNE.startswith('https'):
69 | checkpoint = torch.hub.load_state_dict_from_url(
70 | config.MODEL.FINETUNE, map_location='cpu', check_hash=True)
71 | else:
72 | checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')
74 | msg = model.load_state_dict(checkpoint['model'], strict=False)
75 | logger.info(msg)
76 | max_accuracy = 0.0
77 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
78 | config.defrost()
79 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
80 | config.freeze()
81 | logger.info(f"=> loaded successfully '{config.MODEL.FINETUNE}' (epoch {checkpoint['epoch']})")
82 |
83 | del checkpoint
84 | torch.cuda.empty_cache()
85 |
86 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
87 | save_state = {'model': model.state_dict(),
88 | 'optimizer': optimizer.state_dict(),
89 | 'lr_scheduler': lr_scheduler.state_dict(),
90 | 'max_accuracy': max_accuracy,
91 | 'epoch': epoch,
92 | 'config': config}
93 | if config.AMP_OPT_LEVEL != "O0":
94 | save_state['amp'] = amp.state_dict()
95 |
96 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
97 | logger.info(f"{save_path} saving......")
98 | torch.save(save_state, save_path)
99 | logger.info(f"{save_path} saved !!!")
100 |
101 |
102 | def get_grad_norm(parameters, norm_type=2):
103 | if isinstance(parameters, torch.Tensor):
104 | parameters = [parameters]
105 | parameters = list(filter(lambda p: p.grad is not None, parameters))
106 | norm_type = float(norm_type)
107 | total_norm = 0
108 | for p in parameters:
109 | param_norm = p.grad.data.norm(norm_type)
110 | total_norm += param_norm.item() ** norm_type
111 | total_norm = total_norm ** (1. / norm_type)
112 | return total_norm
113 |
114 |
115 | def auto_resume_helper(output_dir):
116 | checkpoints = os.listdir(output_dir)
117 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
118 |     print(f"All checkpoints found in {output_dir}: {checkpoints}")
119 | if len(checkpoints) > 0:
120 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
121 |         print(f"The latest checkpoint found: {latest_checkpoint}")
122 | resume_file = latest_checkpoint
123 | else:
124 | resume_file = None
125 | return resume_file
126 |
127 |
128 | def reduce_tensor(tensor):
129 | rt = tensor.clone()
130 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
131 | rt /= dist.get_world_size()
132 | return rt
133 |
--------------------------------------------------------------------------------
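A hedged sketch of how the helpers above compose into an auto-resume step. Only auto_resume_helper and load_checkpoint come from utils.py; the wrapper function name, its calling convention, and the config/model/optimizer objects are illustrative.

import os
from utils import auto_resume_helper, load_checkpoint

def maybe_resume(config, model, optimizer, lr_scheduler, logger):
    # Pick up the newest ckpt_epoch_*.pth in config.OUTPUT, if auto-resume is on.
    max_accuracy = 0.0
    if config.TRAIN.AUTO_RESUME and os.path.isdir(config.OUTPUT):
        resume_file = auto_resume_helper(config.OUTPUT)  # newest checkpoint path, or None
        if resume_file:
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            max_accuracy = load_checkpoint(config, model, optimizer, lr_scheduler, logger)
    return max_accuracy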
/e-commercial/data/saliency_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torch, csv
4 | from PIL import Image
5 | from torchvision import transforms
6 | from torch.utils.data import Dataset
7 | import torch.nn.functional as F
8 | import random
9 | import glob
10 | import numpy as np
11 | random.seed(1)
12 | globaltest = []
13 |
14 | def ecommercedata(data_path, img_nums, is_train, logger):
15 | global globaltest
16 | nums = [i for i in range(1, img_nums + 1)]
17 | if not globaltest:
18 | globaltest = random.sample(nums, img_nums // 10)
19 | test = globaltest
20 | logger.info(f"now globaltest is {globaltest}")
21 | train = list(set(nums) - set(test))
22 | if is_train:
23 |         print(f"using {(img_nums - (img_nums // 10))} images for training")
24 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums - (img_nums // 10)),
25 | lis=train)
26 | else:
27 |         print(f"using {(img_nums // 10)} images for testing")
28 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums // 10),
29 | lis=test)
30 |
31 |
32 | class EcommerceDataset(Dataset):
33 |     """E-commerce saliency dataset with fixation maps and OCR affinity/region labels."""
34 |
35 | def __init__(self, data_path, img_nums, transform=None, lis=[]):
36 | """
37 | Args:
38 |             data_path (string): Dataset root containing the ALLSTIMULI/, ALLFIXATIONMAPS/ and OCR/ subfolders.
39 | img_nums (int): Total number of images to index.
40 | transform (callable, optional): Optional transform to be applied
41 | on a sample.
42 | """
43 | self.root_dir = data_path
44 | self.transform = transform
45 | self.data_len = img_nums
46 | self.lis = lis
47 |
48 | def __len__(self):
49 | return self.data_len
50 |
51 | def __getitem__(self, idx):
52 | if torch.is_tensor(idx):
53 | idx = idx.tolist()
54 | # print("idx", idx, "length", len(self.lis))
55 |
56 | img_name = 'ALLSTIMULI/' + str(self.lis[idx]) + '.jpg'
57 | saliency_name = 'ALLFIXATIONMAPS/' + str(self.lis[idx]) + '_fixMap.jpg'
58 | ocr_aff_name = 'OCR/affinity/' + str(self.lis[idx]) + '.csv'
59 | ocr_reg_name = 'OCR/region/' + str(self.lis[idx]) + '.csv'
60 | img_file = os.path.join(self.root_dir, img_name)
61 | saliency_file = os.path.join(self.root_dir, saliency_name)
62 | ocr_aff_file = os.path.join(self.root_dir, ocr_aff_name)
63 | ocr_reg_file = os.path.join(self.root_dir, ocr_reg_name)
64 |
65 | # images
66 | img = Image.open(img_file)
67 | # img = np.array(img, dtype=np.float32) # h, w, c
68 | torch_img = transforms.functional.to_tensor(img)
69 | torch_img = transforms.Resize(896)(torch_img)
70 | # saliency
71 | saliency = Image.open(saliency_file)
72 | # saliency = np.array(saliency, dtype=np.float32) # h, w, c
73 | torch_saliency = transforms.functional.to_tensor(saliency)
74 | torch_saliency = transforms.Resize(896)(torch_saliency)
75 | # ocr
76 | csv_aff_content = np.loadtxt(open(ocr_aff_file, "rb"), delimiter=",")
77 | csv_reg_content = np.loadtxt(open(ocr_reg_file, "rb"), delimiter=",")
78 | # to make pytorch happy in transforms
79 | csv_aff_image = np.expand_dims(csv_aff_content, axis=0)
80 | csv_reg_image = np.expand_dims(csv_reg_content, axis=0)
81 | torch_aff = torch.from_numpy(csv_aff_image).float()
82 | torch_reg = torch.from_numpy(csv_reg_image).float()
83 | gh_label = transforms.Resize(224)(torch_aff)
84 | gah_label = transforms.Resize(224)(torch_reg)
85 | # print('The sizes of gh_label and gah_label are :', gh_label.size(), gah_label.size())
86 | if self.transform:
87 |             raise NotImplementedError("Custom transforms are not supported yet!")
88 | # sample = self.transform(sample)
89 |
90 | return self.lis[idx], torch_img, torch_saliency, \
91 | {'gh_label': gh_label, 'gah_label': gah_label, 'mask': torch.ones_like(gah_label)}
92 |
93 |
94 | class finetunedata(Dataset):
95 |     """Fine-tuning dataset of stimulus/fixation-map image pairs."""
96 |
97 | def __init__(self, data_path, transform=None):
98 | """
99 | Args:
100 |             data_path (string): Directory containing 'stimuli/' and
101 |                 'fixation/' subfolders with matching .jpg files.
102 | transform (callable, optional): Optional transform to be applied
103 | on a sample.
104 | """
105 | self.root_dir = data_path
106 | self.transform = transform
107 | path = os.path.join(self.root_dir, 'stimuli/')
108 | path2 = os.path.join(self.root_dir, 'fixation/')
109 | print("using imgs in %s"%path)
110 | imgs = [f for f in glob.glob(path+"*.jpg")]
111 | gts = [f for f in glob.glob(path2+"*.jpg")]
112 | print("imgs are like %s"%imgs[0])
113 | self.data_len = len(imgs)
114 | print("using %d imgs in training"%self.data_len)
115 | self.lis = imgs
116 | self.gts = gts
117 |
118 | def __len__(self):
119 | return self.data_len
120 |
121 | def __getitem__(self, idx):
122 | if torch.is_tensor(idx):
123 | idx = idx.tolist()
124 |
125 | img_file = str(self.lis[idx])
126 | saliency_file = str(self.gts[idx])
127 |
128 | # images
129 | img = Image.open(img_file)
130 | # img = np.array(img, dtype=np.float32) # h, w, c
131 | torch_img = transforms.functional.to_tensor(img)
132 | torch_img = transforms.Resize((896,896))(torch_img)
133 | # saliency
134 | saliency = Image.open(saliency_file)
135 | # saliency = np.array(saliency, dtype=np.float32) # h, w, c
136 | torch_saliency = transforms.functional.to_tensor(saliency)
137 | torch_saliency = transforms.Resize((896,896))(torch_saliency)
138 |         # print("training", torch_img.shape, torch_saliency.shape)
139 | if torch_img.size()[0] == 1:
140 | torch_img = torch_img.expand([3,896,896])
141 | if self.transform:
142 |             raise NotImplementedError("Custom transforms are not supported yet!")
143 | # sample = self.transform(sample)
144 |
145 | return str(self.lis[idx]), torch_img, torch_saliency
146 |
147 | class folderimagedata(Dataset):
148 |     """Dataset over a flat folder of .jpg/.jpeg images (no ground-truth saliency)."""
149 |
150 | def __init__(self, data_path, transform=None):
151 | """
152 | Args:
153 |             data_path (string): Path to a folder of images; must end with a
154 |                 path separator, since images are matched via data_path + "*.jpg".
155 | transform (callable, optional): Optional transform to be applied
156 | on a sample.
157 | """
158 | self.root_dir = data_path
159 | self.transform = transform
160 | print("using imgs in %s"%self.root_dir)
161 | supported = ["jpg", "jpeg"]
162 | imgs = []
163 | for i in supported:
164 | types = [f for f in glob.glob(self.root_dir+"*."+i)]
165 | imgs += types
166 | print("imgs are like %s"%imgs[0])
167 | self.data_len = len(imgs)
168 |         print("using %d imgs in validation folder" % self.data_len)
169 | self.lis = imgs
170 |
171 | def __len__(self):
172 | return self.data_len
173 |
174 | def __getitem__(self, idx):
175 | if torch.is_tensor(idx):
176 | idx = idx.tolist()
177 |
178 | img_file = str(self.lis[idx])
179 |
180 | # images
181 | img = Image.open(img_file)
182 | # img = np.array(img, dtype=np.float32) # h, w, c
183 | torch_img = transforms.functional.to_tensor(img)
184 | torch_img = transforms.Resize((896,896))(torch_img)
185 |         # print(torch_img.size())
186 | if torch_img.size()[0] == 1:
187 | torch_img = torch_img.expand([3,896,896])
188 | if self.transform:
189 |             raise NotImplementedError("Custom transforms are not supported yet!")
190 | # sample = self.transform(sample)
191 |
192 | return str(self.lis[idx]), torch_img, torch_img[:1].unsqueeze(0)
193 |
--------------------------------------------------------------------------------
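The dataset classes above assume a fixed on-disk layout. The sketch below (root path and image count are placeholders) shows the layout EcommerceDataset reads, as inferred from its __getitem__, and what one sample looks like.

# Layout under the dataset root:
#   <root>/ALLSTIMULI/<id>.jpg                stimulus image
#   <root>/ALLFIXATIONMAPS/<id>_fixMap.jpg    fixation (saliency) map
#   <root>/OCR/affinity/<id>.csv              OCR affinity map
#   <root>/OCR/region/<id>.csv                OCR region map
import logging
from data.saliency_loader import ecommercedata

logger = logging.getLogger("demo")
train_set = ecommercedata("/path/to/ec_root", img_nums=1000, is_train=True, logger=logger)
img_id, image, saliency, ocr = train_set[0]
# image and saliency are resized to 896 on the short side; the OCR maps
# ('gh_label', 'gah_label', 'mask') are resized to 224 on the short side.
print(img_id, image.shape, saliency.shape, ocr['gh_label'].shape)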
/e-commercial/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | import os
6 | import yaml
7 | import time
8 | from yacs.config import CfgNode as CN
9 |
10 | _C = CN()
11 |
12 | # Base config files
13 | _C.BASE = ['']
14 |
15 | # -----------------------------------------------------------------------------
16 | # Data settings
17 | # -----------------------------------------------------------------------------
18 | _C.DATA = CN()
19 | # Batch size for a single GPU, could be overwritten by command line argument
20 | _C.DATA.BATCH_SIZE = 128
21 | # Path to dataset, could be overwritten by command line argument
22 | _C.DATA.DATA_PATH = ''
23 | # Dataset name
24 | _C.DATA.DATASET = 'imagenet'
25 | # Input image size
26 | _C.DATA.IMG_SIZE = 896
27 | # Interpolation to resize image (random, bilinear, bicubic)
28 | _C.DATA.INTERPOLATION = 'bicubic'
29 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
30 | _C.DATA.PIN_MEMORY = True
31 | # Number of data loading threads
32 | _C.DATA.NUM_WORKERS = 8
33 |
34 | # -----------------------------------------------------------------------------
35 | # Model settings
36 | # -----------------------------------------------------------------------------
37 | _C.MODEL = CN()
38 | # Model type
39 | _C.MODEL.TYPE = 'sswin'
40 | # Model name
41 | _C.MODEL.NAME = 'sswin'
42 | # Checkpoint to resume, could be overwritten by command line argument
43 | _C.MODEL.RESUME = ''
44 | # Number of classes, overwritten in data preparation
45 | _C.MODEL.NUM_CLASSES = 1
46 | # Dropout rate
47 | _C.MODEL.DROP_RATE = 0.0
48 | # Drop path rate
49 | _C.MODEL.DROP_PATH_RATE = 0.1
50 | # Label Smoothing
51 | _C.MODEL.LABEL_SMOOTHING = 0.1
52 |
53 | _C.MODEL.FINETUNE = 'your_path/swin_tiny_patch4_window7_224/default/ckpt_epoch_49.pth'
54 |
55 | # Swin Transformer parameters
56 | _C.MODEL.SWIN = CN()
57 | _C.MODEL.SWIN.PATCH_SIZE = 4
58 | _C.MODEL.SWIN.IN_CHANS = 3
59 | _C.MODEL.SWIN.EMBED_DIM = 96
60 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
61 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
62 | _C.MODEL.SWIN.WINDOW_SIZE = 7
63 | _C.MODEL.SWIN.MLP_RATIO = 4.
64 | _C.MODEL.SWIN.QKV_BIAS = True
65 | _C.MODEL.SWIN.QK_SCALE = None
66 | _C.MODEL.SWIN.APE = False
67 | _C.MODEL.SWIN.PATCH_NORM = True
68 |
69 | # Swin MLP parameters
70 | _C.MODEL.SWIN_MLP = CN()
71 | _C.MODEL.SWIN_MLP.PATCH_SIZE = 4
72 | _C.MODEL.SWIN_MLP.IN_CHANS = 3
73 | _C.MODEL.SWIN_MLP.EMBED_DIM = 96
74 | _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
75 | _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
76 | _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
77 | _C.MODEL.SWIN_MLP.MLP_RATIO = 4.
78 | _C.MODEL.SWIN_MLP.APE = False
79 | _C.MODEL.SWIN_MLP.PATCH_NORM = True
80 |
81 | # -----------------------------------------------------------------------------
82 | # Training settings
83 | # -----------------------------------------------------------------------------
84 | _C.TRAIN = CN()
85 | _C.TRAIN.START_EPOCH = 0
86 | _C.TRAIN.EPOCHS = 50
87 | _C.TRAIN.WARMUP_EPOCHS = 20
88 | _C.TRAIN.WEIGHT_DECAY = 0.05
89 | _C.TRAIN.BASE_LR = 5e-4
90 | _C.TRAIN.WARMUP_LR = 5e-7
91 | _C.TRAIN.MIN_LR = 5e-6
92 | # Clip gradient norm
93 | _C.TRAIN.CLIP_GRAD = 5.0
94 | # Auto resume from latest checkpoint
95 | _C.TRAIN.AUTO_RESUME = True
96 | # Whether to use gradient checkpointing to save memory
97 | # could be overwritten by command line argument
98 | _C.TRAIN.USE_CHECKPOINT = False
99 |
100 | # LR scheduler
101 | _C.TRAIN.LR_SCHEDULER = CN()
102 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
103 | # Epoch interval to decay LR, used in StepLRScheduler
104 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
105 | # LR decay rate, used in StepLRScheduler
106 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
107 |
108 | # Optimizer
109 | _C.TRAIN.OPTIMIZER = CN()
110 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
111 | # Optimizer Epsilon
112 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
113 | # Optimizer Betas
114 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
115 | # SGD momentum
116 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
117 |
118 | # -----------------------------------------------------------------------------
119 | # Augmentation settings
120 | # -----------------------------------------------------------------------------
121 | _C.AUG = CN()
122 | # Color jitter factor
123 | _C.AUG.COLOR_JITTER = 0.4
124 | # Use AutoAugment policy. "v0" or "original"
125 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
126 | # Random erase prob
127 | _C.AUG.REPROB = 0.25
128 | # Random erase mode
129 | _C.AUG.REMODE = 'pixel'
130 | # Random erase count
131 | _C.AUG.RECOUNT = 1
132 | # Mixup alpha, mixup enabled if > 0
133 | _C.AUG.MIXUP = 0.8
134 | # Cutmix alpha, cutmix enabled if > 0
135 | _C.AUG.CUTMIX = 1.0
136 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
137 | _C.AUG.CUTMIX_MINMAX = None
138 | # Probability of performing mixup or cutmix when either/both is enabled
139 | _C.AUG.MIXUP_PROB = 1.0
140 | # Probability of switching to cutmix when both mixup and cutmix enabled
141 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
142 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
143 | _C.AUG.MIXUP_MODE = 'batch'
144 |
145 | # -----------------------------------------------------------------------------
146 | # Testing settings
147 | # -----------------------------------------------------------------------------
148 | _C.TEST = CN()
149 | # Whether to use center crop when testing
150 | _C.TEST.CROP = True
151 |
152 | # -----------------------------------------------------------------------------
153 | # Misc
154 | # -----------------------------------------------------------------------------
155 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
156 | # overwritten by command line argument
157 | _C.AMP_OPT_LEVEL = 'O0'
158 | # Path to output folder, overwritten by command line argument
159 | _C.OUTPUT = ''
160 | # Tag of experiment, overwritten by command line argument
161 | _C.TAG = 'default'
162 | # Frequency to save checkpoint
163 | _C.SAVE_FREQ = 1
164 | # Frequency to logging info
165 | _C.PRINT_FREQ = 10
166 | # Fixed random seed
167 | _C.SEED = 0
168 | # Perform evaluation only, overwritten by command line argument
169 | _C.EVAL_MODE = False
170 | # Test throughput only, overwritten by command line argument
171 | _C.THROUGHPUT_MODE = False
172 | # local rank for DistributedDataParallel, given by command line argument
173 | # _C.LOCAL_RANK = 0
174 |
175 | _C.HEAD = 'denseNet_15layer'
176 |
177 |
178 | def _update_config_from_file(config, cfg_file):
179 | config.defrost()
180 | with open(cfg_file, 'r') as f:
181 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
182 |
183 | for cfg in yaml_cfg.setdefault('BASE', ['']):
184 | if cfg:
185 | _update_config_from_file(
186 | config, os.path.join(os.path.dirname(cfg_file), cfg)
187 | )
188 | print('=> merge config from {}'.format(cfg_file))
189 | config.merge_from_file(cfg_file)
190 | config.freeze()
191 |
192 |
193 | def update_config(config, args):
194 | _update_config_from_file(config, args.cfg)
195 |
196 | config.defrost()
197 | if args.opts:
198 | config.merge_from_list(args.opts)
199 |
200 | # merge from specific arguments
201 | if args.batch_size:
202 | config.DATA.BATCH_SIZE = args.batch_size
203 | if args.data_path:
204 | config.DATA.DATA_PATH = args.data_path
205 | if args.resume:
206 | config.MODEL.RESUME = args.resume
207 | if args.use_checkpoint:
208 | config.TRAIN.USE_CHECKPOINT = True
209 | if args.output:
210 | config.OUTPUT = args.output
211 | if args.tag:
212 | config.TAG = args.tag
213 | if args.eval:
214 | config.EVAL_MODE = True
215 | if args.dataset:
216 | config.DATA.DATASET = args.dataset
217 | if args.datanum:
218 | config.DATA.DATANUM = args.datanum
219 | if args.num_epoch:
220 | config.TRAIN.EPOCHS = args.num_epoch
221 |
222 | # if args.loss:
223 | if args.finetune:
224 | # print("finetune")
225 | config.MODEL.FINETUNE = args.finetune
226 | # config.LOSS = args.loss
227 |
228 | # set local rank for distributed training
229 | # config.LOCAL_RANK = args.local_rank
230 |
231 | # output folder
232 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME + ".".join(str(time.asctime()).split()), config.TAG)
233 |
234 | config.freeze()
235 |
236 |
237 | def get_config(args):
238 | """Get a yacs CfgNode object with default values."""
239 | # Return a clone so that the defaults will not be altered
240 | # This is for the "local variable" use pattern
241 | config = _C.clone()
242 | update_config(config, args)
243 |
244 | return config
245 |
--------------------------------------------------------------------------------
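A hedged sketch of obtaining a config from the command line. One subtlety worth showing: _C defines no DATA.DATANUM default, so the attribute only exists after update_config injects it from --datanum, and the ecdata pipeline requires it. The parser and its defaults below are illustrative, not the repository's actual CLI.

import argparse
from config import get_config

parser = argparse.ArgumentParser('e-commercial demo')
parser.add_argument('--cfg', default='configs/sswin.yaml')
parser.add_argument('--opts', nargs='+', default=None)
parser.add_argument('--batch-size', type=int, default=2)
parser.add_argument('--data-path', default='/path/to/ec_root')
parser.add_argument('--resume', default='')
parser.add_argument('--use-checkpoint', action='store_true')
parser.add_argument('--output', default='output')
parser.add_argument('--tag', default='demo')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--dataset', default='ecdata')
parser.add_argument('--datanum', type=int, default=1000)  # no _C default; needed for ecdata
parser.add_argument('--num-epoch', type=int, default=50)
parser.add_argument('--finetune', default='')
args = parser.parse_args()

config = get_config(args)  # merges --cfg (and its BASE files), then the overrides above
print(config.DATA.DATASET, config.DATA.DATANUM, config.OUTPUT)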
/e-commercial/data/cached_image_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | import io
6 | import os
7 | import time
8 | import torch.distributed as dist
9 | import torch.utils.data as data
10 | from PIL import Image
11 |
12 | from .zipreader import is_zip_path, ZipReader
13 |
14 |
15 | def has_file_allowed_extension(filename, extensions):
16 | """Checks if a file is an allowed extension.
17 | Args:
18 | filename (string): path to a file
19 | Returns:
20 | bool: True if the filename ends with a known image extension
21 | """
22 | filename_lower = filename.lower()
23 | return any(filename_lower.endswith(ext) for ext in extensions)
24 |
25 |
26 | def find_classes(dir):
27 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
28 | classes.sort()
29 | class_to_idx = {classes[i]: i for i in range(len(classes))}
30 | return classes, class_to_idx
31 |
32 |
33 | def make_dataset(dir, class_to_idx, extensions):
34 | images = []
35 | dir = os.path.expanduser(dir)
36 | for target in sorted(os.listdir(dir)):
37 | d = os.path.join(dir, target)
38 | if not os.path.isdir(d):
39 | continue
40 |
41 | for root, _, fnames in sorted(os.walk(d)):
42 | for fname in sorted(fnames):
43 | if has_file_allowed_extension(fname, extensions):
44 | path = os.path.join(root, fname)
45 | item = (path, class_to_idx[target])
46 | images.append(item)
47 |
48 | return images
49 |
50 |
51 | def make_dataset_with_ann(ann_file, img_prefix, extensions):
52 | images = []
53 | with open(ann_file, "r") as f:
54 | contents = f.readlines()
55 | for line_str in contents:
56 | path_contents = [c for c in line_str.split('\t')]
57 | im_file_name = path_contents[0]
58 | class_index = int(path_contents[1])
59 |
60 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
61 | item = (os.path.join(img_prefix, im_file_name), class_index)
62 |
63 | images.append(item)
64 |
65 | return images
66 |
67 |
68 | class DatasetFolder(data.Dataset):
69 | """A generic data loader where the samples are arranged in this way: ::
70 | root/class_x/xxx.ext
71 | root/class_x/xxy.ext
72 | root/class_x/xxz.ext
73 | root/class_y/123.ext
74 | root/class_y/nsdf3.ext
75 | root/class_y/asd932_.ext
76 | Args:
77 | root (string): Root directory path.
78 | loader (callable): A function to load a sample given its path.
79 | extensions (list[string]): A list of allowed extensions.
80 | transform (callable, optional): A function/transform that takes in
81 | a sample and returns a transformed version.
82 | E.g, ``transforms.RandomCrop`` for images.
83 | target_transform (callable, optional): A function/transform that takes
84 | in the target and transforms it.
85 | Attributes:
86 | samples (list): List of (sample path, class_index) tuples
87 | """
88 |
89 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
90 | cache_mode="no"):
91 | # image folder mode
92 | if ann_file == '':
93 | _, class_to_idx = find_classes(root)
94 | samples = make_dataset(root, class_to_idx, extensions)
95 | # zip mode
96 | else:
97 | samples = make_dataset_with_ann(os.path.join(root, ann_file),
98 | os.path.join(root, img_prefix),
99 | extensions)
100 |
101 | if len(samples) == 0:
102 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
103 | "Supported extensions are: " + ",".join(extensions)))
104 |
105 | self.root = root
106 | self.loader = loader
107 | self.extensions = extensions
108 |
109 | self.samples = samples
110 | self.labels = [y_1k for _, y_1k in samples]
111 | self.classes = list(set(self.labels))
112 |
113 | self.transform = transform
114 | self.target_transform = target_transform
115 |
116 | self.cache_mode = cache_mode
117 | if self.cache_mode != "no":
118 | self.init_cache()
119 |
120 | def init_cache(self):
121 | assert self.cache_mode in ["part", "full"]
122 | n_sample = len(self.samples)
123 | global_rank = dist.get_rank()
124 | world_size = dist.get_world_size()
125 |
126 | samples_bytes = [None for _ in range(n_sample)]
127 | start_time = time.time()
128 | for index in range(n_sample):
129 | if index % (n_sample // 10) == 0:
130 | t = time.time() - start_time
131 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
132 | start_time = time.time()
133 | path, target = self.samples[index]
134 | if self.cache_mode == "full":
135 | samples_bytes[index] = (ZipReader.read(path), target)
136 | elif self.cache_mode == "part" and index % world_size == global_rank:
137 | samples_bytes[index] = (ZipReader.read(path), target)
138 | else:
139 | samples_bytes[index] = (path, target)
140 | self.samples = samples_bytes
141 |
142 | def __getitem__(self, index):
143 | """
144 | Args:
145 | index (int): Index
146 | Returns:
147 | tuple: (sample, target) where target is class_index of the target class.
148 | """
149 | path, target = self.samples[index]
150 | sample = self.loader(path)
151 | if self.transform is not None:
152 | sample = self.transform(sample)
153 | if self.target_transform is not None:
154 | target = self.target_transform(target)
155 |
156 | return sample, target
157 |
158 | def __len__(self):
159 | return len(self.samples)
160 |
161 | def __repr__(self):
162 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
163 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
164 | fmt_str += ' Root Location: {}\n'.format(self.root)
165 | tmp = ' Transforms (if any): '
166 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
167 | tmp = ' Target Transforms (if any): '
168 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
169 | return fmt_str
170 |
171 |
172 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
173 |
174 |
175 | def pil_loader(path):
176 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
177 | if isinstance(path, bytes):
178 | img = Image.open(io.BytesIO(path))
179 | elif is_zip_path(path):
180 | data = ZipReader.read(path)
181 | img = Image.open(io.BytesIO(data))
182 | else:
183 | with open(path, 'rb') as f:
184 | img = Image.open(f)
185 | return img.convert('RGB')
186 |
187 |
188 | def accimage_loader(path):
189 | import accimage
190 | try:
191 | return accimage.Image(path)
192 | except IOError:
193 | # Potentially a decoding problem, fall back to PIL.Image
194 | return pil_loader(path)
195 |
196 |
197 | def default_img_loader(path):
198 | from torchvision import get_image_backend
199 | if get_image_backend() == 'accimage':
200 | return accimage_loader(path)
201 | else:
202 | return pil_loader(path)
203 |
204 |
205 | class CachedImageFolder(DatasetFolder):
206 | """A generic data loader where the images are arranged in this way: ::
207 | root/dog/xxx.png
208 | root/dog/xxy.png
209 | root/dog/xxz.png
210 | root/cat/123.png
211 | root/cat/nsdf3.png
212 | root/cat/asd932_.png
213 | Args:
214 | root (string): Root directory path.
215 | transform (callable, optional): A function/transform that takes in an PIL image
216 | and returns a transformed version. E.g, ``transforms.RandomCrop``
217 | target_transform (callable, optional): A function/transform that takes in the
218 | target and transforms it.
219 | loader (callable, optional): A function to load an image given its path.
220 | Attributes:
221 | imgs (list): List of (image path, class_index) tuples
222 | """
223 |
224 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
225 | loader=default_img_loader, cache_mode="no"):
226 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
227 | ann_file=ann_file, img_prefix=img_prefix,
228 | transform=transform, target_transform=target_transform,
229 | cache_mode=cache_mode)
230 | self.imgs = self.samples
231 |
232 | def __getitem__(self, index):
233 | """
234 | Args:
235 | index (int): Index
236 | Returns:
237 | tuple: (image, target) where target is class_index of the target class.
238 | """
239 | path, target = self.samples[index]
240 | image = self.loader(path)
241 | if self.transform is not None:
242 | img = self.transform(image)
243 | else:
244 | img = image
245 | if self.target_transform is not None:
246 | target = self.target_transform(target)
247 |
248 | return img, target
249 |
--------------------------------------------------------------------------------
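Usage note (a minimal standalone sketch, not part of the repo): in init_cache above, cache_mode "part" shards the cached bytes across distributed ranks via `index % world_size == global_rank`, so N processes together hold exactly one in-memory copy of the dataset. The snippet below replays that sharding rule with made-up values for world_size and n_sample:

world_size, n_sample = 4, 10
for global_rank in range(world_size):
    cached = [i for i in range(n_sample) if i % world_size == global_rank]
    print(f'rank {global_rank} caches {len(cached)}/{n_sample} samples: {cached}')
# rank 0 caches 3/10 samples: [0, 4, 8]
# rank 1 caches 3/10 samples: [1, 5, 9]
# uncached indices keep only (path, target) and are read from disk on demand
--------------------------------------------------------------------------------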
/e-commercial/models/craft_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | import cv2
9 | import math
10 |
11 | """ auxiliary functions """
12 | # unwarp coordinates
13 | def warpCoord(Minv, pt):
14 | out = np.matmul(Minv, (pt[0], pt[1], 1))
15 | return np.array([out[0]/out[2], out[1]/out[2]])
16 | """ end of auxiliary functions """
17 |
18 |
19 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
20 | # prepare data
21 | linkmap = linkmap.copy()
22 | textmap = textmap.copy()
23 | img_h, img_w = textmap.shape
24 |
25 | """ labeling method """
26 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
27 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
28 |
29 | text_score_comb = np.clip(text_score + link_score, 0, 1)
30 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)
31 |
32 | det = []
33 | mapper = []
34 | for k in range(1,nLabels):
35 | # size filtering
36 | size = stats[k, cv2.CC_STAT_AREA]
37 | if size < 10: continue
38 |
39 | # thresholding
40 | if np.max(textmap[labels==k]) < text_threshold: continue
41 |
42 | # make segmentation map
43 | segmap = np.zeros(textmap.shape, dtype=np.uint8)
44 | segmap[labels==k] = 255
45 | segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area
46 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
47 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
48 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
49 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
50 | # boundary check
51 | if sx < 0 : sx = 0
52 | if sy < 0 : sy = 0
53 | if ex >= img_w: ex = img_w
54 | if ey >= img_h: ey = img_h
55 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
56 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel, iterations=1)
57 | #kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 5))
58 | #segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel1, iterations=1)
59 |
60 |
61 | # make box
62 | np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
63 | rectangle = cv2.minAreaRect(np_contours)
64 | box = cv2.boxPoints(rectangle)
65 |
66 | # align diamond-shape
67 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
68 | box_ratio = max(w, h) / (min(w, h) + 1e-5)
69 | if abs(1 - box_ratio) <= 0.1:
70 | l, r = min(np_contours[:,0]), max(np_contours[:,0])
71 | t, b = min(np_contours[:,1]), max(np_contours[:,1])
72 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
73 |
74 | # make clock-wise order
75 | startidx = box.sum(axis=1).argmin()
76 | box = np.roll(box, 4-startidx, 0)
77 | box = np.array(box)
78 |
79 | det.append(box)
80 | mapper.append(k)
81 |
82 | return det, labels, mapper
83 |
84 | def getPoly_core(boxes, labels, mapper, linkmap):
85 | # configs
86 | num_cp = 5
87 | max_len_ratio = 0.7
88 | expand_ratio = 1.45
89 | max_r = 2.0
90 | step_r = 0.2
91 |
92 | polys = []
93 | for k, box in enumerate(boxes):
94 | # size filter for small instance
95 | w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
96 | if w < 30 or h < 30:
97 | polys.append(None); continue
98 |
99 | # warp image
100 | tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
101 | M = cv2.getPerspectiveTransform(box, tar)
102 | word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
103 | try:
104 | Minv = np.linalg.inv(M)
105 |         except np.linalg.LinAlgError:  # singular perspective transform
106 | polys.append(None); continue
107 |
108 | # binarization for selected label
109 | cur_label = mapper[k]
110 | word_label[word_label != cur_label] = 0
111 | word_label[word_label > 0] = 1
112 |
113 | """ Polygon generation """
114 | # find top/bottom contours
115 | cp = []
116 | max_len = -1
117 | for i in range(w):
118 | region = np.where(word_label[:,i] != 0)[0]
119 | if len(region) < 2 : continue
120 | cp.append((i, region[0], region[-1]))
121 | length = region[-1] - region[0] + 1
122 | if length > max_len: max_len = length
123 |
124 | # pass if max_len is similar to h
125 | if h * max_len_ratio < max_len:
126 | polys.append(None); continue
127 |
128 | # get pivot points with fixed length
129 | tot_seg = num_cp * 2 + 1
130 | seg_w = w / tot_seg # segment width
131 | pp = [None] * num_cp # init pivot points
132 | cp_section = [[0, 0]] * tot_seg
133 | seg_height = [0] * num_cp
134 | seg_num = 0
135 | num_sec = 0
136 | prev_h = -1
137 | for i in range(0,len(cp)):
138 | (x, sy, ey) = cp[i]
139 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
140 | # average previous segment
141 | if num_sec == 0: break
142 | cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
143 | num_sec = 0
144 |
145 | # reset variables
146 | seg_num += 1
147 | prev_h = -1
148 |
149 | # accumulate center points
150 | cy = (sy + ey) * 0.5
151 | cur_h = ey - sy + 1
152 | cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
153 | num_sec += 1
154 |
155 | if seg_num % 2 == 0: continue # No polygon area
156 |
157 | if prev_h < cur_h:
158 | pp[int((seg_num - 1)/2)] = (x, cy)
159 | seg_height[int((seg_num - 1)/2)] = cur_h
160 | prev_h = cur_h
161 |
162 | # processing last segment
163 | if num_sec != 0:
164 | cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
165 |
166 |         # pass if the number of pivots is insufficient or the segment width is smaller than the character height
167 | if None in pp or seg_w < np.max(seg_height) * 0.25:
168 | polys.append(None); continue
169 |
170 |         # half character height: median of the per-segment maximum heights, scaled by expand_ratio
171 | half_char_h = np.median(seg_height) * expand_ratio / 2
172 |
173 |         # calc gradient and apply to make horizontal pivots
174 | new_pp = []
175 | for i, (x, cy) in enumerate(pp):
176 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
177 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
178 |             if dx == 0:  # gradient is zero
179 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
180 | continue
181 | rad = - math.atan2(dy, dx)
182 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
183 | new_pp.append([x - s, cy - c, x + s, cy + c])
184 |
185 | # get edge points to cover character heatmaps
186 | isSppFound, isEppFound = False, False
187 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
188 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
189 | for r in np.arange(0.5, max_r, step_r):
190 | dx = 2 * half_char_h * r
191 | if not isSppFound:
192 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
193 | dy = grad_s * dx
194 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
195 | cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
196 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
197 | spp = p
198 | isSppFound = True
199 | if not isEppFound:
200 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
201 | dy = grad_e * dx
202 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
203 | cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
204 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
205 | epp = p
206 | isEppFound = True
207 | if isSppFound and isEppFound:
208 | break
209 |
210 | # pass if boundary of polygon is not found
211 | if not (isSppFound and isEppFound):
212 | polys.append(None); continue
213 |
214 | # make final polygon
215 | poly = []
216 | poly.append(warpCoord(Minv, (spp[0], spp[1])))
217 | for p in new_pp:
218 | poly.append(warpCoord(Minv, (p[0], p[1])))
219 | poly.append(warpCoord(Minv, (epp[0], epp[1])))
220 | poly.append(warpCoord(Minv, (epp[2], epp[3])))
221 | for p in reversed(new_pp):
222 | poly.append(warpCoord(Minv, (p[2], p[3])))
223 | poly.append(warpCoord(Minv, (spp[2], spp[3])))
224 |
225 | # add to final result
226 | polys.append(np.array(poly))
227 |
228 | return polys
229 |
230 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
231 | boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
232 |
233 | if poly:
234 | polys = getPoly_core(boxes, labels, mapper, linkmap)
235 | else:
236 | polys = [None] * len(boxes)
237 |
238 | return boxes, polys
239 |
240 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
241 | if len(polys) > 0:
242 | polys = np.array(polys)
243 | for k in range(len(polys)):
244 | if polys[k] is not None:
245 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
246 | return polys
247 |
--------------------------------------------------------------------------------
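Smoke-test sketch for the detection pipeline above (synthetic inputs invented for illustration; assumes this file is importable as models.craft_utils and that numpy and OpenCV are installed):

import numpy as np
from models import craft_utils

# synthetic 64x64 score maps containing a single bright text region
textmap = np.zeros((64, 64), dtype=np.float32)
linkmap = np.zeros((64, 64), dtype=np.float32)
textmap[20:30, 10:50] = 0.9

boxes, polys = craft_utils.getDetBoxes(textmap, linkmap, text_threshold=0.7,
                                       link_threshold=0.4, low_text=0.4, poly=False)
# map detections back to image coordinates; the default ratio_net=2 accounts
# for the network producing score maps at half the input resolution
boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w=1.0, ratio_h=1.0)
print(len(boxes), 'box(es) found')  # expect 1 box around the bright region
--------------------------------------------------------------------------------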
/e-commercial/models/saliency_detector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | def define_salD(input_dim, actType='lrelu', normType='batch', netSalD='denseNet256_36layer'):  # NOTE: the default netSalD matches no branch below; pass a recognized name
9 |     if netSalD == 'CNN_3layer':
10 |         net = SalDetector(input_dim, growthRate=32, EnBlocks=(0, 0, 0), feaChannel=64, reduction=0.5, bottleneck=True,
11 |                           actType=actType, normType=normType, useASPP=True, ASPPscales=(0, 1, 4, 7))
12 |     elif netSalD == 'denseNet_15layer':
13 |         net = SalDetector(input_dim, growthRate=32, EnBlocks=(2, 2, 2), feaChannel=64, reduction=0.5, bottleneck=True,
14 |                           actType=actType, normType=normType, useASPP=True, ASPPscales=(0, 1, 4, 7))
15 |     elif netSalD == 'denseNet_20layer':
16 |         net = SalDetector(input_dim, growthRate=32, EnBlocks=(2, 2, 2, 2), feaChannel=64, reduction=0.5,
17 |                           bottleneck=True,
18 |                           actType=actType, normType=normType, useASPP=True, ASPPscales=(0, 1, 4, 7))
19 |     elif netSalD == 'denseNet_28layer':
20 |         net = SalDetector(input_dim, growthRate=16, EnBlocks=(2, 4, 4, 2), feaChannel=64, reduction=0.5,
21 |                           bottleneck=True,
22 |                           actType=actType, normType=normType, useASPP=True, ASPPscales=(0, 1, 4, 7))
23 | else:
24 | raise NotImplementedError('Saliency detector model name [%s] is not recognized' % netSalD)
25 | return net
26 |
27 |
28 | class SalDetector(nn.Module):
29 | def __init__(self, inputdim=768, growthRate=32, EnBlocks=(2, 2, 2), feaChannel=64, reduction=0.5, bottleneck=True,
30 |                  actType='lrelu', normType='batch', useASPP=True, ASPPscales=(0, 1, 4, 7)):
31 | super(SalDetector, self).__init__()
32 | norm_layer = get_norm_layer(normType)
33 | act_layer = get_activation_layer(actType)
34 | nChannels = inputdim
35 | layers = []
36 | # layers = [nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)]
37 | # nChannels = growthRate
38 | for nLayers in EnBlocks:
39 | if nLayers == 0:
40 | CNNlayer = [norm_layer(nChannels),
41 | act_layer(),
42 | nn.Conv2d(nChannels, nChannels // 2, kernel_size=3, padding=1, bias=False)]
43 | layers += CNNlayer
44 | nChannels = nChannels // 2
45 | else:
46 | layers += [DenseBlock(nChannels, growthRate, nLayers, reduction, bottleneck, norm_layer, act_layer)]
47 | nChannels = int(math.floor((nChannels + nLayers * growthRate) * reduction))
48 | if useASPP:
49 |             layers += [ASPP(nChannels, feaChannel, scales=ASPPscales)]
50 | else:
51 | layers += [nn.Conv2d(nChannels, feaChannel, kernel_size=3, padding=1, bias=False)]
52 | self.Encoder = nn.Sequential(*layers)
53 | layers2 = [DecoderBlock(feaChannel, feaChannel // 4, False, norm_layer, act_layer),
54 | DecoderBlock(feaChannel // 4, feaChannel // 16, True, norm_layer, act_layer),
55 | DecoderBlock(feaChannel // 16, 1, False, norm_layer, act_layer),
56 | ]
57 | self.Decoder = nn.Sequential(*layers2)
58 |
59 | def forward(self, x):
60 | encoderFea = self.Encoder(x)
61 | out = self.Decoder(encoderFea)
62 | # out = torch.squeeze(out)
63 | out = NormAffine(out, method='one')
64 | return out
65 |
66 |
67 | class DecoderBlock(nn.Module):
68 | def __init__(self, nChannels, nOutChannels, deConv=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU):
69 | super(DecoderBlock, self).__init__()
70 | interChannels = int(math.ceil(nChannels * math.sqrt(nOutChannels / nChannels)))
71 | if deConv:
72 | # output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
73 | layers = [norm_layer(nChannels),
74 | act_layer(),
75 | nn.ConvTranspose2d(nChannels, interChannels, kernel_size=2, stride=2, padding=0, output_padding=0,
76 | bias=False),
77 | norm_layer(interChannels),
78 | act_layer(),
79 | nn.Conv2d(interChannels, nOutChannels, kernel_size=1, bias=False)
80 | ]
81 | else:
82 | layers = [norm_layer(nChannels),
83 | act_layer(),
84 | nn.Conv2d(nChannels, interChannels, kernel_size=3, padding=1, bias=False),
85 | norm_layer(interChannels),
86 | act_layer(),
87 | nn.Conv2d(interChannels, nOutChannels, kernel_size=1, bias=False)
88 | ]
89 | self.model = nn.Sequential(*layers)
90 |
91 | def forward(self, x):
92 | out = self.model(x)
93 | return out
94 |
95 |
96 | class Transition(nn.Module):
97 | def __init__(self, nChannels, nOutChannels, sn=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU):
98 | super(Transition, self).__init__()
99 | if sn:
100 | layers = [norm_layer(nChannels),
101 | act_layer(),
102 | nn.utils.spectral_norm(nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False))
103 | # nn.AvgPool2d(2)
104 | ]
105 | else:
106 | layers = [norm_layer(nChannels),
107 | act_layer(),
108 | nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
109 | # nn.AvgPool2d(2)
110 | ]
111 | self.model = nn.Sequential(*layers)
112 |
113 | def forward(self, x):
114 | out = self.model(x)
115 | return out
116 |
117 |
118 | class DenseBlock(nn.Module):
119 | def __init__(self, nChannels, growthRate, nLayers, reduction, bottleneck=True, norm_layer=nn.BatchNorm2d,
120 | act_layer=nn.LeakyReLU, sn=False):
121 | super(DenseBlock, self).__init__()
122 | layers = []
123 | for i in range(int(nLayers)):
124 | if bottleneck:
125 | layers += [Bottleneck(nChannels, growthRate, sn, norm_layer, act_layer)]
126 | else:
127 | layers += [SingleLayer(nChannels, growthRate, sn, norm_layer, act_layer)]
128 | nChannels += growthRate
129 | nOutChannels = int(math.floor(nChannels * reduction))
130 | layers += [Transition(nChannels, nOutChannels, sn, norm_layer, act_layer)]
131 | self.model = nn.Sequential(*layers)
132 |
133 | def forward(self, x):
134 | out = self.model(x)
135 | return out
136 |
137 |
138 | class Bottleneck(nn.Module):
139 | def __init__(self, nChannels, growthRate, sn=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU):
140 | super(Bottleneck, self).__init__()
141 | interChannels = 4 * growthRate
142 | if sn:
143 | layers = [norm_layer(nChannels),
144 | act_layer(),
145 | nn.utils.spectral_norm(nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)),
146 | norm_layer(interChannels),
147 | act_layer(),
148 | nn.utils.spectral_norm(nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False))
149 | ]
150 | else:
151 | layers = [norm_layer(nChannels),
152 | act_layer(),
153 | nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False),
154 | norm_layer(interChannels),
155 | act_layer(),
156 | nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)
157 | ]
158 | self.model = nn.Sequential(*layers)
159 |
160 | def forward(self, x):
161 | out = self.model(x)
162 | out = torch.cat((x, out), 1)
163 | return out
164 |
165 |
166 | class SingleLayer(nn.Module):
167 | def __init__(self, nChannels, growthRate, sn=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU):
168 | super(SingleLayer, self).__init__()
169 | if sn:
170 | layers = [norm_layer(nChannels),
171 | act_layer(),
172 | nn.utils.spectral_norm(nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False))
173 | ]
174 | else:
175 | layers = [norm_layer(nChannels),
176 | act_layer(),
177 | nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
178 | ]
179 | self.model = nn.Sequential(*layers)
180 |
181 | def forward(self, x):
182 | out = self.model(x)
183 | out = torch.cat((x, out), 1)
184 | return out
185 |
186 |
187 | class ASPP(nn.Module):
188 | def __init__(self, in_channel=512, depth=256, scales=(0, 1, 4, 7), sn=False):
189 | super(ASPP, self).__init__()
190 | self.scales = scales
191 | for dilate_rate in self.scales:
192 | if dilate_rate == -1:
193 | break
194 | if dilate_rate == 0:
195 | layers = [nn.AdaptiveAvgPool2d((1, 1))]
196 | if sn:
197 | layers += [nn.utils.spectral_norm(nn.Conv2d(in_channel, depth, 1, 1))]
198 | else:
199 | layers += [nn.Conv2d(in_channel, depth, 1, 1)]
200 | setattr(self, 'dilate_layer_{}'.format(dilate_rate), nn.Sequential(*layers))
201 | elif dilate_rate == 1:
202 | if sn:
203 | layers = [nn.utils.spectral_norm(nn.Conv2d(in_channel, depth, 1, 1))]
204 | else:
205 | layers = [nn.Conv2d(in_channel, depth, 1, 1)]
206 | setattr(self, 'dilate_layer_{}'.format(dilate_rate), nn.Sequential(*layers))
207 | else:
208 | if sn:
209 | layers = [nn.utils.spectral_norm(
210 | nn.Conv2d(in_channel, depth, 3, 1, dilation=dilate_rate, padding=dilate_rate))]
211 | else:
212 | layers = [nn.Conv2d(in_channel, depth, 3, 1, dilation=dilate_rate, padding=dilate_rate)]
213 | setattr(self, 'dilate_layer_{}'.format(dilate_rate), nn.Sequential(*layers))
214 | self.conv_1x1_output = nn.Conv2d(depth * len(scales), depth, 1, 1)
215 |
216 | def forward(self, x):
217 | dilate_outs = []
218 | for dilate_rate in self.scales:
219 | if dilate_rate == -1:
220 | return x
221 | if dilate_rate == 0:
222 | layer = getattr(self, 'dilate_layer_{}'.format(dilate_rate))
223 | size = x.shape[2:]
224 | tempout = F.interpolate(layer(x), size=size, mode='bilinear', align_corners=True)
225 | dilate_outs.append(tempout)
226 | else:
227 | layer = getattr(self, 'dilate_layer_{}'.format(dilate_rate))
228 | dilate_outs.append(layer(x))
229 | out = self.conv_1x1_output(torch.cat(dilate_outs, dim=1))
230 | return out
231 |
232 |
233 | ####################################################################
234 | # ------------------------- Basic Functions -------------------------
235 | ####################################################################
236 |
237 |
238 | def get_norm_layer(layer_type='instance'):
239 | if layer_type == 'batch':
240 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
241 | elif layer_type == 'instance':
242 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
243 | elif layer_type == 'none':
244 | norm_layer = None
245 | else:
246 | raise NotImplementedError('normalization layer [%s] is not found' % layer_type)
247 | return norm_layer
248 |
249 |
250 | def get_activation_layer(layer_type='relu'):
251 | if layer_type == 'relu':
252 | nl_layer = functools.partial(nn.ReLU, inplace=True)
253 | elif layer_type == 'lrelu':
254 | nl_layer = functools.partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
255 | elif layer_type == 'elu':
256 | nl_layer = functools.partial(nn.ELU, inplace=True)
257 | elif layer_type == 'none':
258 | nl_layer = None
259 | else:
260 |         raise NotImplementedError('activation [%s] is not found' % layer_type)
261 | return nl_layer
262 |
263 |
264 | def NormAffine(mat, eps=1e-7, method='sum'):
265 |     # normalize each feature map of a [batch_size, channels, height, width] tensor
266 | matdim = len(mat.size())
267 | if method == 'sum':
268 | tempsum = torch.sum(mat, dim=(matdim - 1, matdim - 2), keepdim=True) + eps
269 | out = mat / tempsum
270 | elif method == 'one':
271 | (tempmin, _) = torch.min(mat, dim=matdim - 1, keepdim=True)
272 | (tempmin, _) = torch.min(tempmin, dim=matdim - 2, keepdim=True)
273 | tempmat = mat - tempmin
274 | (tempmax, _) = torch.max(tempmat, dim=matdim - 1, keepdim=True)
275 | (tempmax, _) = torch.max(tempmax, dim=matdim - 2, keepdim=True)
276 | tempmax = tempmax + eps
277 | out = tempmat / tempmax
278 | else:
279 | raise NotImplementedError('Map method [%s] is not implemented' % method)
280 | return out
281 |
--------------------------------------------------------------------------------
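Shape-check sketch for the saliency detector above (a hypothetical run: the 768-channel 24x24 feature map stands in for a backbone output and is invented for illustration; assumes this file is importable as models.saliency_detector):

import torch
from models.saliency_detector import define_salD

net = define_salD(input_dim=768, actType='lrelu', normType='batch',
                  netSalD='denseNet_15layer')
x = torch.randn(2, 768, 24, 24)   # [batch, channels, H, W]
with torch.no_grad():
    sal = net(x)
# the middle DecoderBlock deconvolves 2x, so spatial size doubles, and
# NormAffine(method='one') min-max rescales each map into [0, 1]
print(sal.shape)                               # torch.Size([2, 1, 48, 48])
print(float(sal.min()), float(sal.max()))      # ~0.0 and just under 1.0 (eps)
--------------------------------------------------------------------------------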
/e-commercial/main.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 | import os
6 | import time
7 | import argparse
8 | import datetime
9 | import numpy as np
10 | import torchvision
11 | import cv2
12 | from IPython import embed
13 |
14 |
15 | import torch
16 | import torch.backends.cudnn as cudnn
17 |
18 | from timm.utils import accuracy, AverageMeter
19 | from torchvision import transforms
21 | import torch.nn.functional as F
22 | from PIL import Image
23 |
24 | from config import get_config
25 | from models import build_model
26 | from models import craft_utils, imgproc
27 | from models.loss import KL_loss, Maploss
28 | from models.metric import calCC, calKL
29 | from data import build_loader
30 | from lr_scheduler import build_scheduler
31 | from optimizer import build_optimizer
32 | from logger import create_logger
33 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, load_checkpoint_finetune, load_checkpoint_eval
34 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
35 |
36 |
37 | try:
38 | # noinspection PyUnresolvedReferences
39 | from apex import amp
40 | except ImportError:
41 | amp = None
42 |
43 |
44 | def parse_option():
45 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
46 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
47 | parser.add_argument(
48 | "--opts",
49 | help="Modify config options by adding 'KEY VALUE' pairs. ",
50 | default=None,
51 | nargs='+',
52 | )
53 |
54 | # easy config modification
55 | parser.add_argument('--batch-size', type=int, help="batch size")
56 | parser.add_argument('--data-path', type=str, help='path to dataset')
57 | # parser.add_argument('--eval-path', type=str, help='path to dataset')
58 | parser.add_argument('--resume', help='resume from checkpoint')
59 | parser.add_argument('--finetune', help='finetune from checkpoint')
60 | parser.add_argument('--use-checkpoint', action='store_true',
61 | help="whether to use gradient checkpointing to save memory")
62 | parser.add_argument('--output', default='./output', type=str, metavar='PATH',
63 | help='root of output folder, the full path is