├── e-commercial ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── build.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── samplers.cpython-37.pyc │ │ ├── build_new.cpython-37.pyc │ │ ├── zipreader.cpython-37.pyc │ │ ├── saliency_loader.cpython-37.pyc │ │ ├── cached_image_folder.cpython-37.pyc │ │ └── saliency_loader_new.cpython-37.pyc │ ├── samplers.py │ ├── saliency_loader_new.py │ ├── zipreader.py │ ├── build_new.py │ ├── build.py │ ├── saliency_loader.py │ └── cached_image_folder.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── loss.cpython-37.pyc │ │ ├── build.cpython-37.pyc │ │ ├── metric.cpython-37.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── basemodel.cpython-37.pyc │ │ ├── imgproc.cpython-37.pyc │ │ ├── networks.cpython-37.pyc │ │ ├── swin_mlp.cpython-37.pyc │ │ ├── craft_utils.cpython-37.pyc │ │ ├── swin_transformer.cpython-37.pyc │ │ ├── saliency_detector.cpython-37.pyc │ │ └── sswin_transformer.cpython-37.pyc │ ├── build.py │ ├── imgproc.py │ ├── basemodel.py │ ├── metric.py │ ├── loss.py │ ├── craft_utils.py │ └── saliency_detector.py ├── .idea │ ├── .gitignore │ ├── vcs.xml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── modules.xml │ └── e-commercial.iml ├── __pycache__ │ ├── config.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── utils.cpython-37.pyc │ ├── config.cpython-310.pyc │ ├── optimizer.cpython-37.pyc │ └── lr_scheduler.cpython-37.pyc ├── configs │ ├── test.yaml │ ├── finetune.yaml │ └── sswin.yaml ├── logger.py ├── optimizer.py ├── compare.py ├── lr_scheduler.py ├── utils.py ├── config.py ├── main.py ├── train.py ├── metrics.py └── main_cnn.py ├── .gitattributes ├── LICENSE ├── README.md └── .history └── e-commercial ├── config_20240408005816.py ├── config_20240330210315.py └── config_20240408005800.py /e-commercial/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader -------------------------------------------------------------------------------- /e-commercial/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /e-commercial/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /e-commercial/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/__pycache__/utils.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /e-commercial/__pycache__/optimizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/optimizer.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/samplers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/samplers.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/build_new.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/build_new.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/zipreader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/zipreader.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/basemodel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/basemodel.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/imgproc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/imgproc.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/swin_mlp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/swin_mlp.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/craft_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/craft_utils.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/saliency_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/saliency_loader.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/swin_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/swin_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/cached_image_folder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/cached_image_folder.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/data/__pycache__/saliency_loader_new.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/data/__pycache__/saliency_loader_new.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/saliency_detector.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/saliency_detector.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/models/__pycache__/sswin_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leafy-lee/E-commercial-dataset/HEAD/e-commercial/models/__pycache__/sswin_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /e-commercial/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /e-commercial/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /e-commercial/configs/test.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 896 3 | MODEL: 4 | TYPE: sswin 5 | NAME: densenet 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /e-commercial/configs/finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 896 3 | MODEL: 4 | TYPE: sswin 5 | NAME: densenet 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /e-commercial/configs/sswin.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 896 3 | MODEL: 4 | TYPE: sswin 5 | NAME: sswin_patch4_window7 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /e-commercial/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /e-commercial/.idea/e-commercial.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /e-commercial/data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import torch 6 | 7 | 8 | class SubsetRandomSampler(torch.utils.data.Sampler): 9 | r"""Samples elements randomly from a given list of indices, without replacement. 10 | 11 | Arguments: 12 | indices (sequence): a sequence of indices 13 | """ 14 | 15 | def __init__(self, indices): 16 | self.epoch = 0 17 | self.indices = indices 18 | 19 | def __iter__(self): 20 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 21 | 22 | def __len__(self): 23 | return len(self.indices) 24 | 25 | def set_epoch(self, epoch): 26 | self.epoch = epoch 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 YiFei Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /e-commercial/models/build.py: -------------------------------------------------------------------------------- 1 | from .sswin_transformer import SSwinTransformer 2 | 3 | 4 | def build_model(config): 5 | model_type = config.MODEL.TYPE 6 | if model_type == 'sswin': 7 | model = SSwinTransformer(img_size=config.DATA.IMG_SIZE, 8 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 9 | in_chans=config.MODEL.SWIN.IN_CHANS, 10 | num_classes=config.MODEL.NUM_CLASSES, 11 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 12 | depths=config.MODEL.SWIN.DEPTHS, 13 | num_heads=config.MODEL.SWIN.NUM_HEADS, 14 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 15 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 16 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 17 | qk_scale=config.MODEL.SWIN.QK_SCALE, 18 | drop_rate=config.MODEL.DROP_RATE, 19 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 20 | ape=config.MODEL.SWIN.APE, 21 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 22 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 23 | head=config.HEAD) 24 | else: 25 | raise NotImplementedError(f"Unkown model: {model_type}") 26 | 27 | return model -------------------------------------------------------------------------------- /e-commercial/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import sys 7 | import logging 8 | import functools 9 | from termcolor import colored 10 | 11 | 12 | @functools.lru_cache() 13 | def create_logger(output_dir, dist_rank=0, name=''): 14 | # create logger 15 | logger = logging.getLogger(name) 16 | logger.setLevel(logging.DEBUG) 17 | logger.propagate = False 18 | 19 | # create formatter 20 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 21 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 22 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 23 | 24 | # create console handlers for master process 25 | if dist_rank == 0: 26 | console_handler = logging.StreamHandler(sys.stdout) 27 | console_handler.setLevel(logging.DEBUG) 28 | console_handler.setFormatter( 29 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 30 | logger.addHandler(console_handler) 31 | 32 | # create file handlers 33 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 34 | file_handler.setLevel(logging.DEBUG) 35 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 36 | logger.addHandler(file_handler) 37 | 38 | return logger 39 | -------------------------------------------------------------------------------- /e-commercial/optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | from torch import optim as optim 6 | 7 | 8 | def build_optimizer(config, model): 9 | """ 10 | Build optimizer, set weight decay of normalization to 0 by default. 
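Parameters returned by model.no_weight_decay() or matching no_weight_decay_keywords() are likewise excluded from weight decay.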
11 | """ 12 | skip = {} 13 | skip_keywords = {} 14 | if hasattr(model, 'no_weight_decay'): 15 | skip = model.no_weight_decay() 16 | if hasattr(model, 'no_weight_decay_keywords'): 17 | skip_keywords = model.no_weight_decay_keywords() 18 | parameters = set_weight_decay(model, skip, skip_keywords) 19 | 20 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 21 | optimizer = None 22 | if opt_lower == 'sgd': 23 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 24 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 25 | elif opt_lower == 'adamw': 26 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 27 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 28 | 29 | return optimizer 30 | 31 | 32 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 33 | has_decay = [] 34 | no_decay = [] 35 | 36 | for name, param in model.named_parameters(): 37 | if not param.requires_grad: 38 | continue # frozen weights 39 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 40 | check_keywords_in_name(name, skip_keywords): 41 | no_decay.append(param) 42 | # print(f"{name} has no weight decay") 43 | else: 44 | has_decay.append(param) 45 | return [{'params': has_decay}, 46 | {'params': no_decay, 'weight_decay': 0.}] 47 | 48 | 49 | def check_keywords_in_name(name, keywords=()): 50 | isin = False 51 | for keyword in keywords: 52 | if keyword in name: 53 | isin = True 54 | return isin 55 | -------------------------------------------------------------------------------- /e-commercial/models/imgproc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | from skimage import io 9 | import cv2 10 | 11 | def loadImage(img_file): 12 | img = io.imread(img_file) # RGB order 13 | if img.shape[0] == 2: img = img[0] 14 | if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 15 | if img.shape[2] == 4: img = img[:,:,:3] 16 | img = np.array(img) 17 | 18 | return img 19 | 20 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): 21 | # should be RGB order 22 | img = in_img.copy().astype(np.float32) 23 | 24 | img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32) 25 | img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32) 26 | return img 27 | 28 | def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): 29 | # should be RGB order 30 | img = in_img.copy() 31 | img *= variance 32 | img += mean 33 | img *= 255.0 34 | img = np.clip(img, 0, 255).astype(np.uint8) 35 | return img 36 | 37 | def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1): 38 | height, width, channel = img.shape 39 | 40 | # magnify image size 41 | target_size = mag_ratio * max(height, width) 42 | 43 | # set original image size 44 | if target_size > square_size: 45 | target_size = square_size 46 | 47 | ratio = target_size / max(height, width) 48 | 49 | target_h, target_w = int(height * ratio), int(width * ratio) 50 | proc = cv2.resize(img, (target_w, target_h), interpolation = interpolation) 51 | 52 | 53 | # make canvas and paste image 54 | target_h32, target_w32 = target_h, target_w 55 | if target_h % 32 != 0: 56 | target_h32 = target_h + (32 - target_h % 32) 57 | if target_w % 32 != 0: 58 | target_w32 = target_w + (32 - target_w % 32) 59 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32) 60 | resized[0:target_h, 0:target_w, :] = proc 61 | target_h, target_w = target_h32, target_w32 62 | 63 | size_heatmap = (int(target_w/2), int(target_h/2)) 64 | 65 | return resized, ratio, size_heatmap 66 | 67 | def cvt2HeatmapImg(img): 68 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 69 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 70 | return img 71 | -------------------------------------------------------------------------------- /e-commercial/compare.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms, utils 3 | from PIL import Image 4 | import argparse 5 | import os 6 | import glob 7 | from metrics import auc_judd, nss 8 | 9 | import torch.nn.functional as F 10 | 11 | # Read the data 12 | toTensor = transforms.Compose([ 13 | transforms.Resize(720), 14 | transforms.ToTensor(), 15 | ]) 16 | 17 | 18 | def read(imgDir: str) -> torch.Tensor: 19 | x = toTensor(Image.open(imgDir)) 20 | return x 21 | 22 | 23 | def calAll(path1: str, path2: str, name: str): 24 | aucAll = 0 25 | nssAll = 0 26 | img_list1 = glob.glob(f"{path1}/*.jpg") 27 | tbs = len(img_list1) 28 | with torch.no_grad(): 29 | for i, path in enumerate(img_list1): 30 | # imgList1 = [] 31 | # imgList2 = [] 32 | # print(f"[{i:05}|{tbs}] | {1}", end="\r") 33 | cn = path.split("/")[-1].split(".")[0] 34 | # breakpoint() 35 | img1 = read(path) 36 | img2 = read(f"{path2}/{cn}_fixMap.jpg") 37 | # if i in allwrong: 38 | # cnt += 1 39 | # print(f"weird case {i} jump {cnt=}") 40 | # continue 41 | # breakpoint() 42 | # imgList1.append(img1) 43 | # imgList2.append(img2) 44 | 
# img1 = torch.stack(imgList1).cuda() 45 | # img2 = torch.stack(imgList2).cuda() 46 | 47 | singauc = auc_judd(img1, img2) 48 | singnss = nss(img1, img2) 49 | print( 50 | f"[{i:05}|{tbs}]\t\t\t\tauc path2 {(singauc):.4f} | nss {(singnss):.4f}", end="\r") 51 | with open(f"record/{name}/record", "a") as f: 52 | f.write(f"[{i:05}|{tbs}]\t\t\t\tauc path2 {(singauc):.4f} | nss {(singnss):.4f}\n") 53 | 54 | aucAll += singauc 55 | nssAll += singnss 56 | print(f"{path1} {(aucAll / tbs):.4f} | {(nssAll / tbs):.4f}") 57 | with open(f"record/all_record.txt", "a") as f: 58 | f.write(f"{path1} | auc nss | {(aucAll / tbs):.4f} | {(nssAll / tbs):.4f}\n") 59 | # print(f"[INFO]: id ORI {(idCompAll / tbs):.4f} id {(idAll / tbs):.4f} psnr ORI {(qpsnrAll / tbs):.4f} psnr {(psnrCompAll / tbs):.4f} ") 60 | 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--path1', type=str) 64 | parser.add_argument('--path2', type=str, default="/mnt/hdd1/yifei/DATA/ECdata/ALLFIXATIONMAPS") 65 | args = parser.parse_args() 66 | 67 | PATH1 = args.path1 68 | PATH2 = args.path2 69 | dirlis = args.path1.split("/") 70 | name = f"{dirlis[-3]}_{dirlis[-2]}_{dirlis[-1]}" 71 | os.makedirs(f"record/{name}", exist_ok=True) 72 | 73 | calAll(PATH1, PATH2, name) 74 | -------------------------------------------------------------------------------- /e-commercial/models/basemodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | import torch.nn.functional as F 5 | 6 | 7 | def conv_1x1_bn(inp, oup): 8 | return nn.Sequential( 9 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 10 | nn.BatchNorm2d(oup), 11 | nn.ReLU6(inplace=True) 12 | ) 13 | 14 | 15 | class InvertedResidual(nn.Module): 16 | def __init__(self, inp, oup, stride, expand_ratio, omit_stride=False, 17 | no_res_connect=False, dropout=0., bn_momentum=0.1, 18 | batchnorm=None): 19 | super().__init__() 20 | self.out_channels = oup 21 | self.stride = stride 22 | self.omit_stride = omit_stride 23 | self.use_res_connect = not no_res_connect and\ 24 | self.stride == 1 and inp == oup 25 | self.dropout = dropout 26 | actual_stride = self.stride if not self.omit_stride else 1 27 | if batchnorm is None: 28 | def batchnorm(num_features): 29 | return nn.BatchNorm2d(num_features, momentum=bn_momentum) 30 | 31 | assert actual_stride in [1, 2] 32 | 33 | hidden_dim = round(inp * expand_ratio) 34 | if expand_ratio == 1: 35 | modules = [ 36 | # dw 37 | nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, 38 | groups=hidden_dim, bias=False), 39 | batchnorm(hidden_dim), 40 | nn.ReLU6(inplace=True), 41 | # pw-linear 42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 43 | batchnorm(oup), 44 | ] 45 | if self.dropout > 0: 46 | modules.append(nn.Dropout2d(self.dropout)) 47 | self.conv = nn.Sequential(*modules) 48 | else: 49 | modules = [ 50 | # pw 51 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 52 | batchnorm(hidden_dim), 53 | nn.ReLU6(inplace=True), 54 | # dw 55 | nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, 56 | groups=hidden_dim, bias=False), 57 | batchnorm(hidden_dim), 58 | nn.ReLU6(inplace=True), 59 | # pw-linear 60 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 61 | batchnorm(oup), 62 | ] 63 | if self.dropout > 0: 64 | modules.insert(3, nn.Dropout2d(self.dropout)) 65 | self.conv = nn.Sequential(*modules) 66 | self._initialize_weights() 67 | 68 | def forward(self, x): 69 | if self.use_res_connect: 70 | return x + self.conv(x) 71 | else: 72 | 
return self.conv(x) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # E-commercial-dataset 2 | A dataset of e-commercial images for saliency prediction. 3 | 4 | The dataset can be downloaded from https://www.dropbox.com/s/xsui782oy3kvjsm/E-commercial%20dataset.zip?dl=0. 5 | 6 | ## IMAGE 7 | 8 | Original images are saved in this folder as *.jpg files. 9 | 10 | ## FIXATION 11 | 12 | Fixation maps are saved as *\_fixPts.jpg, while saliency maps are saved as *\_fixMap.jpg. 13 | 14 | ## TEXT REGION 15 | 16 | The text detection results are stored as CSV files containing the affinity and region scores. 17 | 18 | # SSwin-transformer model added in this repo 19 | [![](https://img.shields.io/badge/pytorch-1.8.0-brightgreen)]() 20 | [![](https://img.shields.io/badge/CUDA-%E2%89%A510.2-lightgrey)]() 21 | [![](https://img.shields.io/badge/python-%E2%89%A53.7-orange)]() 22 | ## To-do list 23 | 1. [x] Add environment setup instructions (the Swin Transformer environment can be used as a temporary alternative) 24 | 2. [ ] Refactor the code for efficiency 25 | 26 | ### Environment preparation 27 | - Clone this repo: 28 | 29 | ```bash 30 | git clone https://github.com/leafy-lee/E-commercial-dataset.git 31 | cd e-commercial 32 | ``` 33 | 34 | - Create a conda virtual environment and activate it: 35 | 36 | ```bash 37 | conda create -n ecom python=3.7 -y 38 | conda activate ecom 39 | ``` 40 | 41 | - Install `CUDA>=10.2` with `cudnn>=7` following 42 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) 43 | - Install `PyTorch>=1.8.0` and `torchvision>=0.9.0` with `CUDA>=10.2`: 44 | 45 | ```bash 46 | conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch 47 | ``` 48 | 49 | - Install `timm==0.4.12`: 50 | 51 | ```bash 52 | pip install timm==0.4.12 53 | ``` 54 | 55 | - Install other requirements: 56 | 57 | ```bash 58 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy 59 | ``` 60 | ### Training 61 | 62 | To train the model, run: 63 | 64 | ```bash 65 | python train.py --batch-size 8 --cfg configs/sswin.yaml --data-path DATA/ECdata/ --dataset ecdata --head headname 66 | ``` 67 | 68 | 69 | 70 | ### Evaluation 71 | 72 | To evaluate a trained model, run: 73 | 74 | ```bash 75 | python main.py --eval --cfg config --resume True --finetune ckpt --data-path data_dir 76 | ``` 77 | 78 | ## Citation 79 | If you use this code, please cite: 80 | ``` 81 | @InProceedings{Jiang_2022_CVPR, 82 | author = {Jiang, Lai and Li, Yifei and Li, Shengxi and Xu, Mai and Lei, Se and Guo, Yichen and Huang, Bo}, 83 | title = {Does Text Attract Attention on E-Commerce Images: A Novel Saliency Prediction Dataset and Method}, 84 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 85 | month = {June}, 86 | year = {2022}, 87 | pages = {2088-2097} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /e-commercial/models/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def calCC(gtsAnns, resAnns, is_train): 6 | if is_train: 7 | gtsAnn = gtsAnns[0, ...].detach().clone() 8 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0) 9 | gtsAnn = gtsAnn.cpu().float().detach().numpy() 10 | resAnn = resAnns[0, ...].detach().clone() 11 | resAnn =
torch.cat([resAnn, resAnn, resAnn], 0) 12 | resAnn = resAnn.cpu().float().detach().numpy() 13 | fixationMap = gtsAnn - np.mean(gtsAnn) 14 | if np.max(fixationMap) > 0: 15 | fixationMap = fixationMap / np.std(fixationMap) 16 | salMap = resAnn - np.mean(resAnn) 17 | if np.max(salMap) > 0: 18 | salMap = salMap / np.std(salMap) 19 | return np.corrcoef(salMap.reshape(-1), fixationMap.reshape(-1))[0][1] 20 | else: 21 | cc = 0 22 | for idx, gtsAnn in enumerate(gtsAnns): 23 | gtsAnn = gtsAnn.detach().clone() 24 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0) 25 | gtsAnn = gtsAnn.cpu().float().detach().numpy() 26 | resAnn = resAnns[idx].detach().clone() 27 | resAnn = torch.cat([resAnn, resAnn, resAnn], 0) 28 | resAnn = resAnn.cpu().float().detach().numpy() 29 | fixationMap = gtsAnn - np.mean(gtsAnn) 30 | if np.max(fixationMap) > 0: 31 | fixationMap = fixationMap / np.std(fixationMap) 32 | salMap = resAnn - np.mean(resAnn) 33 | if np.max(salMap) > 0: 34 | salMap = salMap / np.std(salMap) 35 | cc += np.corrcoef(salMap.reshape(-1), fixationMap.reshape(-1))[0][1] 36 | return cc / gtsAnns.size()[0] 37 | 38 | 39 | def calKL(gtsAnns, resAnns, is_train, eps=1e-7): 40 | if is_train: 41 | gtsAnn = gtsAnns[0, ...].detach().clone() 42 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0) 43 | gtsAnn = gtsAnn.cpu().float().detach().numpy() 44 | resAnn = resAnns[0, ...].detach().clone() 45 | resAnn = torch.cat([resAnn, resAnn, resAnn], 0) 46 | resAnn = resAnn.cpu().float().detach().numpy() 47 | if np.sum(gtsAnn) > 0: 48 | gtsAnn = gtsAnn / np.sum(gtsAnn) 49 | if np.sum(resAnn) > 0: 50 | resAnn = resAnn / np.sum(resAnn) 51 | return np.sum(gtsAnn * np.log(eps + gtsAnn / (resAnn + eps))) 52 | else: 53 | kl = 0 54 | for idx, gtsAnn in enumerate(gtsAnns): 55 | gtsAnn = gtsAnn.detach().clone() 56 | gtsAnn = torch.cat([gtsAnn, gtsAnn, gtsAnn], 0) 57 | gtsAnn = gtsAnn.cpu().float().detach().numpy() 58 | resAnn = resAnns[idx].detach().clone() 59 | resAnn = torch.cat([resAnn, resAnn, resAnn], 0) 60 | resAnn = resAnn.cpu().float().detach().numpy() 61 | if np.sum(gtsAnn) > 0: 62 | gtsAnn = gtsAnn / np.sum(gtsAnn) 63 | if np.sum(resAnn) > 0: 64 | resAnn = resAnn / np.sum(resAnn) 65 | kl += np.sum(gtsAnn * np.log(eps + gtsAnn / (resAnn + eps))) 66 | return kl / gtsAnns.size()[0] -------------------------------------------------------------------------------- /e-commercial/data/saliency_loader_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch, csv 4 | from PIL import Image 5 | from torchvision import transforms 6 | from torch.utils.data import Dataset 7 | import torch.nn.functional as F 8 | import random 9 | import glob 10 | import numpy as np 11 | random.seed(1) 12 | globaltest = [] 13 | 14 | def ecommercedata(data_path, img_nums, is_train): 15 | global globaltest 16 | nums = [i for i in range(1, img_nums + 1)] 17 | if not globaltest: 18 | globaltest = random.sample(nums, img_nums // 10) 19 | test = globaltest 20 | train = list(set(nums) - set(test)) 21 | if is_train: 22 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums - (img_nums // 10)), 23 | lis=train) 24 | else: 25 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums // 10), 26 | lis=test) 27 | 28 | 29 | class EcommerceDataset(Dataset): 30 | """Face Landmarks dataset.""" 31 | 32 | def __init__(self, data_path, img_nums, transform=None, lis=[]): 33 | """ 34 | Args: 35 | data_path (string): Path to the imgs file with saliency. 
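lis (list): 1-based image indices assigned to this split; each index is used to build the corresponding image, fixation-map and OCR file names.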
36 | img_nums (int): Total number of images to index. 37 | transform (callable, optional): Optional transform to be applied 38 | on a sample. 39 | """ 40 | self.root_dir = data_path 41 | self.transform = transform 42 | self.data_len = img_nums 43 | self.lis = lis 44 | 45 | def __len__(self): 46 | return self.data_len 47 | 48 | def __getitem__(self, idx): 49 | if torch.is_tensor(idx): 50 | idx = idx.tolist() 51 | # print("idx", idx, "length", len(self.lis)) 52 | 53 | img_name = 'ALLSTIMULI/' + str(self.lis[idx]) + '.jpg' 54 | saliency_name = 'ALLFIXATIONMAPS/' + str(self.lis[idx]) + '_fixMap.jpg' 55 | ocr_aff_name = 'OCR/affinity/' + str(self.lis[idx]) + '.csv' 56 | ocr_reg_name = 'OCR/region/' + str(self.lis[idx]) + '.csv' 57 | img_file = os.path.join(self.root_dir, img_name) 58 | saliency_file = os.path.join(self.root_dir, saliency_name) 59 | ocr_aff_file = os.path.join(self.root_dir, ocr_aff_name) 60 | ocr_reg_file = os.path.join(self.root_dir, ocr_reg_name) 61 | 62 | # images 63 | img = Image.open(img_file) 64 | # img = np.array(img, dtype=np.float32) # h, w, c 65 | torch_img = transforms.functional.to_tensor(img) 66 | torch_img = transforms.Resize(896)(torch_img) 67 | # saliency 68 | saliency = Image.open(saliency_file) 69 | # saliency = np.array(saliency, dtype=np.float32) # h, w, c 70 | torch_saliency = transforms.functional.to_tensor(saliency) 71 | torch_saliency = transforms.Resize(896)(torch_saliency) 72 | # ocr 73 | csv_aff_content = np.loadtxt(open(ocr_aff_file, "rb"), delimiter=",") 74 | csv_reg_content = np.loadtxt(open(ocr_reg_file, "rb"), delimiter=",") 75 | # to make pytorch happy in transforms 76 | csv_aff_image = np.expand_dims(csv_aff_content, axis=0) 77 | csv_reg_image = np.expand_dims(csv_reg_content, axis=0) 78 | torch_aff = torch.from_numpy(csv_aff_image).float() 79 | torch_reg = torch.from_numpy(csv_reg_image).float() 80 | gh_label = transforms.Resize(896)(torch_aff) 81 | gah_label = transforms.Resize(896)(torch_reg) 82 | # print('The sizes of gh_label and gah_label are :', gh_label.size(), gah_label.size()) 83 | if self.transform: 84 | raise NotImplementedError("Not support any transform by far!") 85 | # sample = self.transform(sample) 86 | 87 | return torch_img, torch_saliency 88 | -------------------------------------------------------------------------------- /e-commercial/data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer Zipreader 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import zipfile 7 | import io 8 | import numpy as np 9 | from PIL import Image 10 | from PIL import ImageFile 11 | 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | 15 | def is_zip_path(img_or_path): 16 | """judge if this is a zip path""" 17 | return '.zip@' in img_or_path 18 | 19 | 20 | class ZipReader(object): 21 | """A class to read zipped files""" 22 | zip_bank = dict() 23 | 24 | def __init__(self): 25 | super(ZipReader, self).__init__() 26 | 27 | @staticmethod 28 | def get_zipfile(path): 29 | zip_bank = ZipReader.zip_bank 30 | if path not in zip_bank: 31 | zfile = zipfile.ZipFile(path, 'r') 32 | zip_bank[path] = zfile 33 | return zip_bank[path] 34 | 35 | @staticmethod 36 | def split_zip_style_path(path): 37 | pos_at = path.index('@') 38 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 39 | 40 | zip_path = path[0: pos_at] 41 | folder_path = path[pos_at + 1:] 42 | folder_path = 
str.strip(folder_path, '/') 43 | return zip_path, folder_path 44 | 45 | @staticmethod 46 | def list_folder(path): 47 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 48 | 49 | zfile = ZipReader.get_zipfile(zip_path) 50 | folder_list = [] 51 | for file_foler_name in zfile.namelist(): 52 | file_foler_name = str.strip(file_foler_name, '/') 53 | if file_foler_name.startswith(folder_path) and \ 54 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 55 | file_foler_name != folder_path: 56 | if len(folder_path) == 0: 57 | folder_list.append(file_foler_name) 58 | else: 59 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 60 | 61 | return folder_list 62 | 63 | @staticmethod 64 | def list_files(path, extension=None): 65 | if extension is None: 66 | extension = ['.*'] 67 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 68 | 69 | zfile = ZipReader.get_zipfile(zip_path) 70 | file_lists = [] 71 | for file_foler_name in zfile.namelist(): 72 | file_foler_name = str.strip(file_foler_name, '/') 73 | if file_foler_name.startswith(folder_path) and \ 74 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 75 | if len(folder_path) == 0: 76 | file_lists.append(file_foler_name) 77 | else: 78 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 79 | 80 | return file_lists 81 | 82 | @staticmethod 83 | def read(path): 84 | zip_path, path_img = ZipReader.split_zip_style_path(path) 85 | zfile = ZipReader.get_zipfile(zip_path) 86 | data = zfile.read(path_img) 87 | return data 88 | 89 | @staticmethod 90 | def imread(path): 91 | zip_path, path_img = ZipReader.split_zip_style_path(path) 92 | zfile = ZipReader.get_zipfile(zip_path) 93 | data = zfile.read(path_img) 94 | try: 95 | im = Image.open(io.BytesIO(data)) 96 | except: 97 | print("ERROR IMG LOADED: ", path_img) 98 | random_img = np.random.rand(224, 224, 3) * 255 99 | im = Image.fromarray(np.uint8(random_img)) 100 | return im 101 | -------------------------------------------------------------------------------- /e-commercial/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import torch 6 | from timm.scheduler.cosine_lr import CosineLRScheduler 7 | from timm.scheduler.step_lr import StepLRScheduler 8 | from timm.scheduler.scheduler import Scheduler 9 | 10 | 11 | def build_scheduler(config, optimizer, n_iter_per_epoch): 12 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 13 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 14 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 15 | 16 | lr_scheduler = None 17 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 18 | lr_scheduler = CosineLRScheduler( 19 | optimizer, 20 | t_initial=num_steps, 21 | t_mul=1., 22 | lr_min=config.TRAIN.MIN_LR, 23 | warmup_lr_init=config.TRAIN.WARMUP_LR, 24 | warmup_t=warmup_steps, 25 | cycle_limit=1, 26 | t_in_epochs=False, 27 | ) 28 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 29 | lr_scheduler = LinearLRScheduler( 30 | optimizer, 31 | t_initial=num_steps, 32 | lr_min_rate=0.01, 33 | warmup_lr_init=config.TRAIN.WARMUP_LR, 34 | warmup_t=warmup_steps, 35 | t_in_epochs=False, 36 | ) 37 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 38 | lr_scheduler = StepLRScheduler( 39 | optimizer, 40 | decay_t=decay_steps, 41 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 42 
| warmup_lr_init=config.TRAIN.WARMUP_LR, 43 | warmup_t=warmup_steps, 44 | t_in_epochs=False, 45 | ) 46 | 47 | return lr_scheduler 48 | 49 | 50 | class LinearLRScheduler(Scheduler): 51 | def __init__(self, 52 | optimizer: torch.optim.Optimizer, 53 | t_initial: int, 54 | lr_min_rate: float, 55 | warmup_t=0, 56 | warmup_lr_init=0., 57 | t_in_epochs=True, 58 | noise_range_t=None, 59 | noise_pct=0.67, 60 | noise_std=1.0, 61 | noise_seed=42, 62 | initialize=True, 63 | ) -> None: 64 | super().__init__( 65 | optimizer, param_group_field="lr", 66 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 67 | initialize=initialize) 68 | 69 | self.t_initial = t_initial 70 | self.lr_min_rate = lr_min_rate 71 | self.warmup_t = warmup_t 72 | self.warmup_lr_init = warmup_lr_init 73 | self.t_in_epochs = t_in_epochs 74 | if self.warmup_t: 75 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 76 | super().update_groups(self.warmup_lr_init) 77 | else: 78 | self.warmup_steps = [1 for _ in self.base_values] 79 | 80 | def _get_lr(self, t): 81 | if t < self.warmup_t: 82 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 83 | else: 84 | t = t - self.warmup_t 85 | total_t = self.t_initial - self.warmup_t 86 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 87 | return lrs 88 | 89 | def get_epoch_values(self, epoch: int): 90 | if self.t_in_epochs: 91 | return self._get_lr(epoch) 92 | else: 93 | return None 94 | 95 | def get_update_values(self, num_updates: int): 96 | if not self.t_in_epochs: 97 | return self._get_lr(num_updates) 98 | else: 99 | return None 100 | -------------------------------------------------------------------------------- /e-commercial/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | #################################################################### 7 | # -------------------------- Losses -------------------------- 8 | #################################################################### 9 | 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | class Maploss(nn.Module): 17 | def __init__(self, use_gpu = True): 18 | 19 | super(Maploss,self).__init__() 20 | 21 | def single_image_loss(self, pre_loss, loss_label): 22 | batch_size = pre_loss.shape[0] 23 | sum_loss = torch.mean(pre_loss.view(-1))*0 24 | pre_loss = pre_loss.view(batch_size, -1) 25 | loss_label = loss_label.view(batch_size, -1) 26 | internel = batch_size 27 | # print(loss_label.shape, pre_loss.shape) 28 | for i in range(batch_size): 29 | average_number = 0 30 | loss = torch.mean(pre_loss.view(-1)) * 0 31 | positive_pixel = len(pre_loss[i][(loss_label[i] >= 0.1)]) 32 | average_number += positive_pixel 33 | if positive_pixel != 0: 34 | posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= 0.1)]) 35 | sum_loss += posi_loss 36 | if len(pre_loss[i][(loss_label[i] < 0.1)]) < 3*positive_pixel: 37 | nega_loss = torch.mean(pre_loss[i][(loss_label[i] < 0.1)]) 38 | average_number += len(pre_loss[i][(loss_label[i] < 0.1)]) 39 | else: 40 | nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < 0.1)], 3*positive_pixel)[0]) 41 | average_number += 3*positive_pixel 42 | sum_loss += nega_loss 43 | else: 44 | nega_loss = torch.mean(torch.topk(pre_loss[i], 500)[0]) 45 | average_number += 500 46 | sum_loss += nega_loss 47 | #sum_loss += loss/average_number 
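# hard-negative balancing: every positive pixel (label >= 0.1) contributes to the loss, plus at most 3x as many of the hardest negatives (top-k by loss value); images with no positive pixels fall back to their 500 hardest pixels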
48 | 49 | return sum_loss 50 | 51 | 52 | 53 | def forward(self, gh_label, gah_label, p_gh, p_gah, mask): 54 | gh_label = gh_label 55 | gah_label = gah_label 56 | p_gh = p_gh 57 | p_gah = p_gah 58 | loss_fn = torch.nn.MSELoss(reduce=False, size_average=False) 59 | mask = mask.squeeze(1) 60 | 61 | assert p_gh.size() == gh_label.size() and p_gah.size() == gah_label.size() 62 | loss1 = loss_fn(p_gh, gh_label) 63 | loss2 = loss_fn(p_gah, gah_label) 64 | # print("loss1.shape, mask.shape", loss1.shape, mask.shape) 65 | # print("loss2.shape, mask.shape", loss2.shape, mask.shape) 66 | loss_g = torch.mul(loss1, mask) 67 | loss_a = torch.mul(loss2, mask) 68 | # print("loss shape", loss_g.shape, loss_a.shape, gah_label.shape, gh_label.shape) 69 | 70 | char_loss = self.single_image_loss(loss_g, gh_label) 71 | affi_loss = self.single_image_loss(loss_a, gah_label) 72 | return char_loss/loss_g.shape[0] + affi_loss/loss_a.shape[0] 73 | 74 | 75 | def NormAffine(mat, eps=1e-7, 76 | method='sum'): # tensor [batch_size, channels, image_height, image_width] normalize each fea map; 77 | matdim = len(mat.size()) 78 | if method == 'sum': 79 | tempsum = torch.sum(mat, dim=(matdim - 1, matdim - 2), keepdim=True) + eps 80 | out = mat / tempsum 81 | elif method == 'one': 82 | (tempmin, _) = torch.min(mat, dim=matdim - 1, keepdim=True) 83 | (tempmin, _) = torch.min(tempmin, dim=matdim - 2, keepdim=True) 84 | tempmat = mat - tempmin 85 | (tempmax, _) = torch.max(tempmat, dim=matdim - 1, keepdim=True) 86 | (tempmax, _) = torch.max(tempmax, dim=matdim - 2, keepdim=True) 87 | tempmax = tempmax + eps 88 | out = tempmat / tempmax 89 | else: 90 | raise NotImplementedError('Map method [%s] is not implemented' % method) 91 | return out 92 | 93 | 94 | def KL_loss(out, gt): 95 | assert out.size() == gt.size() 96 | out = NormAffine(out, eps=1e-7, method='sum') 97 | gt = NormAffine(gt, eps=1e-7, method='sum') 98 | loss = torch.sum(gt * torch.log(1e-7 + gt / (out + 1e-7))) 99 | loss = loss / out.size()[0] 100 | return loss 101 | -------------------------------------------------------------------------------- /e-commercial/data/build_new.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | from torchvision import datasets, transforms 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.data import Mixup 14 | from timm.data import create_transform 15 | from timm.data.transforms import _pil_interp 16 | 17 | from .cached_image_folder import CachedImageFolder 18 | from .samplers import SubsetRandomSampler 19 | from data.saliency_loader_new import ecommercedata 20 | 21 | 22 | def build_loader(config): 23 | config.defrost() 24 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 25 | config.freeze() 26 | print(f"local rank {config.LOCAL_RANK} successfully build train dataset") 27 | dataset_val, _ = build_dataset(is_train=False, config=config) 28 | print(f"local rank {config.LOCAL_RANK} successfully build val dataset") 29 | 30 | indices_train = np.arange(len(dataset_train)) 31 | sampler_train = SubsetRandomSampler(indices_train) 32 | indices_val = np.arange(len(dataset_val)) 33 | sampler_val = 
SubsetRandomSampler(indices_val) 34 | 35 | data_loader_train = torch.utils.data.DataLoader( 36 | dataset_train, sampler=sampler_train, 37 | batch_size=config.DATA.BATCH_SIZE, 38 | num_workers=config.DATA.NUM_WORKERS, 39 | pin_memory=config.DATA.PIN_MEMORY, 40 | drop_last=True, 41 | ) 42 | 43 | data_loader_val = torch.utils.data.DataLoader( 44 | dataset_val, sampler=sampler_val, 45 | batch_size=config.DATA.BATCH_SIZE, 46 | shuffle=False, 47 | num_workers=config.DATA.NUM_WORKERS, 48 | pin_memory=config.DATA.PIN_MEMORY, 49 | drop_last=False 50 | ) 51 | 52 | # setup mixup / cutmix 53 | mixup_fn = None 54 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None 55 | if mixup_active: 56 | mixup_fn = Mixup( 57 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 58 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 59 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 60 | 61 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 62 | 63 | 64 | def build_dataset(is_train, config): 65 | transform = build_transform(is_train, config) 66 | if config.DATA.DATASET == 'ecdata': 67 | dataset = ecommercedata(config.DATA.DATA_PATH, config.DATA.DATANUM, is_train) 68 | nb_classes = 1 69 | else: 70 | raise NotImplementedError("We only support ImageNet Now.") 71 | 72 | return dataset, nb_classes 73 | 74 | 75 | def build_transform(is_train, config): 76 | resize_im = config.DATA.IMG_SIZE > 32 77 | if is_train: 78 | # this should always dispatch to transforms_imagenet_train 79 | transform = create_transform( 80 | input_size=config.DATA.IMG_SIZE, 81 | is_training=True, 82 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 83 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 84 | re_prob=config.AUG.REPROB, 85 | re_mode=config.AUG.REMODE, 86 | re_count=config.AUG.RECOUNT, 87 | interpolation=config.DATA.INTERPOLATION, 88 | ) 89 | if not resize_im: 90 | # replace RandomResizedCropAndInterpolation with 91 | # RandomCrop 92 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 93 | return transform 94 | 95 | t = [] 96 | if resize_im: 97 | if config.TEST.CROP: 98 | size = int((256 / 224) * config.DATA.IMG_SIZE) 99 | t.append( 100 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 101 | # to maintain same ratio w.r.t. 
224 images 102 | ) 103 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 104 | else: 105 | t.append( 106 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 107 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 108 | ) 109 | 110 | t.append(transforms.ToTensor()) 111 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 112 | return transforms.Compose(t) 113 | -------------------------------------------------------------------------------- /e-commercial/data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import torch 7 | import numpy as np 8 | import torch.distributed as dist 9 | from torchvision import datasets, transforms 10 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 11 | from timm.data import Mixup 12 | from timm.data import create_transform 13 | from timm.data.transforms import _pil_interp 14 | 15 | from .cached_image_folder import CachedImageFolder 16 | from .samplers import SubsetRandomSampler 17 | from data.saliency_loader import ecommercedata 18 | from data.saliency_loader import finetunedata, folderimagedata 19 | 20 | 21 | def build_loader(config, logger): 22 | config.defrost() 23 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger) 24 | config.freeze() 25 | # print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 26 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger) 27 | # print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 28 | ''' 29 | num_tasks = dist.get_world_size() 30 | global_rank = dist.get_rank() 31 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 32 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 33 | sampler_train = SubsetRandomSampler(indices) 34 | else: 35 | sampler_train = torch.utils.data.DistributedSampler( 36 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 37 | ) 38 | 39 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 40 | sampler_val = SubsetRandomSampler(indices) 41 | ''' 42 | data_loader_train = torch.utils.data.DataLoader( 43 | dataset_train, 44 | batch_size=config.DATA.BATCH_SIZE, 45 | num_workers=config.DATA.NUM_WORKERS, 46 | pin_memory=config.DATA.PIN_MEMORY, 47 | drop_last=True, 48 | ) 49 | 50 | data_loader_val = torch.utils.data.DataLoader( 51 | dataset_val, 52 | batch_size=config.DATA.BATCH_SIZE, 53 | shuffle=False, 54 | num_workers=config.DATA.NUM_WORKERS, 55 | pin_memory=config.DATA.PIN_MEMORY, 56 | drop_last=False 57 | ) 58 | 59 | # setup mixup / cutmix 60 | mixup_fn = None 61 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 62 | if mixup_active: 63 | mixup_fn = Mixup( 64 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 65 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 66 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 67 | 68 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 69 | # return dataset_train, data_loader_train, mixup_fn 70 | 71 | 72 | def build_dataset(is_train, config, logger=None): 73 | transform = build_transform(is_train, config) 74 | if config.DATA.DATASET == 'ecdata': 75 | dataset = ecommercedata(config.DATA.DATA_PATH, config.DATA.DATANUM, is_train, logger=logger) 76 | nb_classes = 1 77 | elif config.DATA.DATASET == 'finetunedata': 78 | dataset = finetunedata(config.DATA.DATA_PATH) 79 | nb_classes = 1 80 | elif config.DATA.DATASET == 'folderimagedata': 81 | dataset = folderimagedata(config.DATA.DATA_PATH) 82 | nb_classes = 1 83 | else: 84 | raise NotImplementedError("We only support ImageNet Now.") 85 | 86 | return dataset, nb_classes 87 | 88 | 89 | def build_transform(is_train, config): 90 | resize_im = config.DATA.IMG_SIZE > 32 91 | if is_train: 92 | # this should always dispatch to transforms_imagenet_train 93 | transform = create_transform( 94 | input_size=config.DATA.IMG_SIZE, 95 | is_training=True, 96 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 97 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 98 | re_prob=config.AUG.REPROB, 99 | re_mode=config.AUG.REMODE, 100 | re_count=config.AUG.RECOUNT, 101 | interpolation=config.DATA.INTERPOLATION, 102 | ) 103 | if not resize_im: 104 | # replace RandomResizedCropAndInterpolation with 105 | # RandomCrop 106 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 107 | return transform 108 | 109 | t = [] 110 | if resize_im: 111 | if config.TEST.CROP: 112 | size = int((256 / 224) * config.DATA.IMG_SIZE) 113 | t.append( 114 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 115 | # to maintain same ratio w.r.t. 
224 images
116 | )
117 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
118 | else:
119 | t.append(
120 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
121 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
122 | )
123 |
124 | t.append(transforms.ToTensor())
125 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
126 | return transforms.Compose(t)
127 |
--------------------------------------------------------------------------------
/e-commercial/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # based on Swin Transformer
3 | # --------------------------------------------------------
4 |
5 |
6 | import os
7 | import torch
8 | import torch.distributed as dist
9 |
10 | try:
11 | # noinspection PyUnresolvedReferences
12 | from apex import amp
13 | except ImportError:
14 | amp = None
15 |
16 |
17 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
18 | logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
19 | if config.MODEL.RESUME.startswith('https'):
20 | checkpoint = torch.hub.load_state_dict_from_url(
21 | config.MODEL.RESUME, map_location='cpu', check_hash=True)
22 | else:
23 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
24 | msg = model.load_state_dict(checkpoint['model'], strict=False)
25 | logger.info(msg)
26 | max_accuracy = 0.0
27 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
28 | optimizer.load_state_dict(checkpoint['optimizer'])
29 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
30 | config.defrost()
31 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
32 | config.freeze()
33 | if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
34 | amp.load_state_dict(checkpoint['amp'])
35 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
36 | if 'max_accuracy' in checkpoint:
37 | max_accuracy = checkpoint['max_accuracy']
38 |
39 | del checkpoint
40 | torch.cuda.empty_cache()
41 | return max_accuracy
42 |
43 | def load_checkpoint_finetune(config, model, optimizer, lr_scheduler, logger):
44 | logger.info(f"==============> finetune....................")
45 | print("loading from ", config.MODEL.FINETUNE)
46 | if config.MODEL.FINETUNE.startswith('https'):
47 | checkpoint = torch.hub.load_state_dict_from_url(
48 | config.MODEL.FINETUNE, map_location='cpu', check_hash=True)
49 | else:
50 | checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')
51 | msg = model.load_state_dict(checkpoint['model'], strict=False)
52 | logger.info(msg)
53 | max_accuracy = 0.0
54 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
55 | optimizer.load_state_dict(checkpoint['optimizer'])
56 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
57 | config.defrost()
58 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
59 | config.freeze()
60 | logger.info(f"=> loaded successfully '{config.MODEL.FINETUNE}' (epoch {checkpoint['epoch']})")
61 |
62 | del checkpoint
63 | torch.cuda.empty_cache()
64 |
65 | def load_checkpoint_eval(config, model, optimizer, logger):
66 | logger.info(f"==============> evaluating....................")
67 | print("loading from ", config.MODEL.FINETUNE)
68 | if config.MODEL.FINETUNE.startswith('https'):
69 | checkpoint = torch.hub.load_state_dict_from_url(
70 | config.MODEL.FINETUNE, map_location='cpu', check_hash=True)
71 | else:
72 | checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')
73 |
74 | msg = model.load_state_dict(checkpoint['model'], strict=False)
75 | logger.info(msg)
76 | max_accuracy = 0.0
77 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
78 | config.defrost()
79 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
80 | config.freeze()
81 | logger.info(f"=> loaded successfully '{config.MODEL.FINETUNE}' (epoch {checkpoint['epoch']})")
82 |
83 | del checkpoint
84 | torch.cuda.empty_cache()
85 |
86 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
87 | save_state = {'model': model.state_dict(),
88 | 'optimizer': optimizer.state_dict(),
89 | 'lr_scheduler': lr_scheduler.state_dict(),
90 | 'max_accuracy': max_accuracy,
91 | 'epoch': epoch,
92 | 'config': config}
93 | if config.AMP_OPT_LEVEL != "O0":
94 | save_state['amp'] = amp.state_dict()
95 |
96 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
97 | logger.info(f"{save_path} saving......")
98 | torch.save(save_state, save_path)
99 | logger.info(f"{save_path} saved !!!")
100 |
101 |
102 | def get_grad_norm(parameters, norm_type=2):
103 | if isinstance(parameters, torch.Tensor):
104 | parameters = [parameters]
105 | parameters = list(filter(lambda p: p.grad is not None, parameters))
106 | norm_type = float(norm_type)
107 | total_norm = 0
108 | for p in parameters:
109 | param_norm = p.grad.data.norm(norm_type)
110 | total_norm += param_norm.item() ** norm_type
111 | total_norm = total_norm ** (1. / norm_type)
112 | return total_norm
113 |
114 |
115 | def auto_resume_helper(output_dir):
116 | checkpoints = os.listdir(output_dir)
117 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
118 | print(f"All checkpoints found in {output_dir}: {checkpoints}")
119 | if len(checkpoints) > 0:
120 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
121 | print(f"The latest checkpoint found: {latest_checkpoint}")
122 | resume_file = latest_checkpoint
123 | else:
124 | resume_file = None
125 | return resume_file
126 |
127 |
128 | def reduce_tensor(tensor):
129 | rt = tensor.clone()
130 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
131 | rt /= dist.get_world_size()
132 | return rt
133 |
--------------------------------------------------------------------------------
/e-commercial/data/saliency_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torch, csv
4 | from PIL import Image
5 | from torchvision import transforms
6 | from torch.utils.data import Dataset
7 | import torch.nn.functional as F
8 | import random
9 | import glob
10 | import numpy as np
11 | random.seed(1)
12 | globaltest = []
13 |
14 | def ecommercedata(data_path, img_nums, is_train, logger):
15 | global globaltest
16 | nums = [i for i in range(1, img_nums + 1)]
17 | if not globaltest:
18 | globaltest = random.sample(nums, img_nums // 10)
19 | test = globaltest
20 | logger.info(f"now globaltest is {globaltest}")
21 | train = list(set(nums) - set(test))
22 | if is_train:
23 | print(f"using {(img_nums - (img_nums // 10))} images for train")
24 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums - (img_nums // 10)),
25 | lis=train)
26 | else:
27 | print(f"using
{(img_nums // 10)} images for test") 28 | return EcommerceDataset(data_path=data_path, img_nums=(img_nums // 10), 29 | lis=test) 30 | 31 | 32 | class EcommerceDataset(Dataset): 33 | """Face Landmarks dataset.""" 34 | 35 | def __init__(self, data_path, img_nums, transform=None, lis=[]): 36 | """ 37 | Args: 38 | data_path (string): Path to the imgs file with saliency. 39 | img_nums (int): Total number of images to index. 40 | transform (callable, optional): Optional transform to be applied 41 | on a sample. 42 | """ 43 | self.root_dir = data_path 44 | self.transform = transform 45 | self.data_len = img_nums 46 | self.lis = lis 47 | 48 | def __len__(self): 49 | return self.data_len 50 | 51 | def __getitem__(self, idx): 52 | if torch.is_tensor(idx): 53 | idx = idx.tolist() 54 | # print("idx", idx, "length", len(self.lis)) 55 | 56 | img_name = 'ALLSTIMULI/' + str(self.lis[idx]) + '.jpg' 57 | saliency_name = 'ALLFIXATIONMAPS/' + str(self.lis[idx]) + '_fixMap.jpg' 58 | ocr_aff_name = 'OCR/affinity/' + str(self.lis[idx]) + '.csv' 59 | ocr_reg_name = 'OCR/region/' + str(self.lis[idx]) + '.csv' 60 | img_file = os.path.join(self.root_dir, img_name) 61 | saliency_file = os.path.join(self.root_dir, saliency_name) 62 | ocr_aff_file = os.path.join(self.root_dir, ocr_aff_name) 63 | ocr_reg_file = os.path.join(self.root_dir, ocr_reg_name) 64 | 65 | # images 66 | img = Image.open(img_file) 67 | # img = np.array(img, dtype=np.float32) # h, w, c 68 | torch_img = transforms.functional.to_tensor(img) 69 | torch_img = transforms.Resize(896)(torch_img) 70 | # saliency 71 | saliency = Image.open(saliency_file) 72 | # saliency = np.array(saliency, dtype=np.float32) # h, w, c 73 | torch_saliency = transforms.functional.to_tensor(saliency) 74 | torch_saliency = transforms.Resize(896)(torch_saliency) 75 | # ocr 76 | csv_aff_content = np.loadtxt(open(ocr_aff_file, "rb"), delimiter=",") 77 | csv_reg_content = np.loadtxt(open(ocr_reg_file, "rb"), delimiter=",") 78 | # to make pytorch happy in transforms 79 | csv_aff_image = np.expand_dims(csv_aff_content, axis=0) 80 | csv_reg_image = np.expand_dims(csv_reg_content, axis=0) 81 | torch_aff = torch.from_numpy(csv_aff_image).float() 82 | torch_reg = torch.from_numpy(csv_reg_image).float() 83 | gh_label = transforms.Resize(224)(torch_aff) 84 | gah_label = transforms.Resize(224)(torch_reg) 85 | # print('The sizes of gh_label and gah_label are :', gh_label.size(), gah_label.size()) 86 | if self.transform: 87 | raise NotImplementedError("Not support any transform by far!") 88 | # sample = self.transform(sample) 89 | 90 | return self.lis[idx], torch_img, torch_saliency, \ 91 | {'gh_label': gh_label, 'gah_label': gah_label, 'mask': torch.ones_like(gah_label)} 92 | 93 | 94 | class finetunedata(Dataset): 95 | """Face Landmarks dataset.""" 96 | 97 | def __init__(self, data_path, transform=None): 98 | """ 99 | Args: 100 | data_path (string): Path to the imgs file with saliency. 101 | img_nums (int): Total number of images to index. 102 | transform (callable, optional): Optional transform to be applied 103 | on a sample. 
104 | """ 105 | self.root_dir = data_path 106 | self.transform = transform 107 | path = os.path.join(self.root_dir, 'stimuli/') 108 | path2 = os.path.join(self.root_dir, 'fixation/') 109 | print("using imgs in %s"%path) 110 | imgs = [f for f in glob.glob(path+"*.jpg")] 111 | gts = [f for f in glob.glob(path2+"*.jpg")] 112 | print("imgs are like %s"%imgs[0]) 113 | self.data_len = len(imgs) 114 | print("using %d imgs in training"%self.data_len) 115 | self.lis = imgs 116 | self.gts = gts 117 | 118 | def __len__(self): 119 | return self.data_len 120 | 121 | def __getitem__(self, idx): 122 | if torch.is_tensor(idx): 123 | idx = idx.tolist() 124 | 125 | img_file = str(self.lis[idx]) 126 | saliency_file = str(self.gts[idx]) 127 | 128 | # images 129 | img = Image.open(img_file) 130 | # img = np.array(img, dtype=np.float32) # h, w, c 131 | torch_img = transforms.functional.to_tensor(img) 132 | torch_img = transforms.Resize((896,896))(torch_img) 133 | # saliency 134 | saliency = Image.open(saliency_file) 135 | # saliency = np.array(saliency, dtype=np.float32) # h, w, c 136 | torch_saliency = transforms.functional.to_tensor(saliency) 137 | torch_saliency = transforms.Resize((896,896))(torch_saliency) 138 | # print("traing", torch_img.shape, torch_saliency.shape) 139 | if torch_img.size()[0] == 1: 140 | torch_img = torch_img.expand([3,896,896]) 141 | if self.transform: 142 | raise NotImplementedError("Not support any transform by far!") 143 | # sample = self.transform(sample) 144 | 145 | return str(self.lis[idx]), torch_img, torch_saliency 146 | 147 | class folderimagedata(Dataset): 148 | """Face Landmarks dataset.""" 149 | 150 | def __init__(self, data_path, transform=None): 151 | """ 152 | Args: 153 | data_path (string): Path to the imgs file with saliency. 154 | img_nums (int): Total number of images to index. 155 | transform (callable, optional): Optional transform to be applied 156 | on a sample. 
157 | """ 158 | self.root_dir = data_path 159 | self.transform = transform 160 | print("using imgs in %s"%self.root_dir) 161 | supported = ["jpg", "jpeg"] 162 | imgs = [] 163 | for i in supported: 164 | types = [f for f in glob.glob(self.root_dir+"*."+i)] 165 | imgs += types 166 | print("imgs are like %s"%imgs[0]) 167 | self.data_len = len(imgs) 168 | print("using %d imgs in validate folder"%self.data_len) 169 | self.lis = imgs 170 | 171 | def __len__(self): 172 | return self.data_len 173 | 174 | def __getitem__(self, idx): 175 | if torch.is_tensor(idx): 176 | idx = idx.tolist() 177 | 178 | img_file = str(self.lis[idx]) 179 | 180 | # images 181 | img = Image.open(img_file) 182 | # img = np.array(img, dtype=np.float32) # h, w, c 183 | torch_img = transforms.functional.to_tensor(img) 184 | torch_img = transforms.Resize((896,896))(torch_img) 185 | print(torch_img.size()) 186 | if torch_img.size()[0] == 1: 187 | torch_img = torch_img.expand([3,896,896]) 188 | if self.transform: 189 | raise NotImplementedError("Not support any transform by far!") 190 | # sample = self.transform(sample) 191 | 192 | return str(self.lis[idx]), torch_img, torch_img[:1].unsqueeze(0) 193 | -------------------------------------------------------------------------------- /e-commercial/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import yaml 7 | import time 8 | from yacs.config import CfgNode as CN 9 | 10 | _C = CN() 11 | 12 | # Base config files 13 | _C.BASE = [''] 14 | 15 | # ----------------------------------------------------------------------------- 16 | # Data settings 17 | # ----------------------------------------------------------------------------- 18 | _C.DATA = CN() 19 | # Batch size for a single GPU, could be overwritten by command line argument 20 | _C.DATA.BATCH_SIZE = 128 21 | # Path to dataset, could be overwritten by command line argument 22 | _C.DATA.DATA_PATH = '' 23 | # Dataset name 24 | _C.DATA.DATASET = 'imagenet' 25 | # Input image size 26 | _C.DATA.IMG_SIZE = 896 27 | # Interpolation to resize image (random, bilinear, bicubic) 28 | _C.DATA.INTERPOLATION = 'bicubic' 29 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
30 | _C.DATA.PIN_MEMORY = True 31 | # Number of data loading threads 32 | _C.DATA.NUM_WORKERS = 8 33 | 34 | # ----------------------------------------------------------------------------- 35 | # Model settings 36 | # ----------------------------------------------------------------------------- 37 | _C.MODEL = CN() 38 | # Model type 39 | _C.MODEL.TYPE = 'sswin' 40 | # Model name 41 | _C.MODEL.NAME = 'sswin' 42 | # Checkpoint to resume, could be overwritten by command line argument 43 | _C.MODEL.RESUME = '' 44 | # Number of classes, overwritten in data preparation 45 | _C.MODEL.NUM_CLASSES = 1 46 | # Dropout rate 47 | _C.MODEL.DROP_RATE = 0.0 48 | # Drop path rate 49 | _C.MODEL.DROP_PATH_RATE = 0.1 50 | # Label Smoothing 51 | _C.MODEL.LABEL_SMOOTHING = 0.1 52 | 53 | _C.MODEL.FINETUNE = 'your_path/swin_tiny_patch4_window7_224/default/ckpt_epoch_49.pth' 54 | 55 | # Swin Transformer parameters 56 | _C.MODEL.SWIN = CN() 57 | _C.MODEL.SWIN.PATCH_SIZE = 4 58 | _C.MODEL.SWIN.IN_CHANS = 3 59 | _C.MODEL.SWIN.EMBED_DIM = 96 60 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 61 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 62 | _C.MODEL.SWIN.WINDOW_SIZE = 7 63 | _C.MODEL.SWIN.MLP_RATIO = 4. 64 | _C.MODEL.SWIN.QKV_BIAS = True 65 | _C.MODEL.SWIN.QK_SCALE = None 66 | _C.MODEL.SWIN.APE = False 67 | _C.MODEL.SWIN.PATCH_NORM = True 68 | 69 | # Swin MLP parameters 70 | _C.MODEL.SWIN_MLP = CN() 71 | _C.MODEL.SWIN_MLP.PATCH_SIZE = 4 72 | _C.MODEL.SWIN_MLP.IN_CHANS = 3 73 | _C.MODEL.SWIN_MLP.EMBED_DIM = 96 74 | _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] 75 | _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] 76 | _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 77 | _C.MODEL.SWIN_MLP.MLP_RATIO = 4. 78 | _C.MODEL.SWIN_MLP.APE = False 79 | _C.MODEL.SWIN_MLP.PATCH_NORM = True 80 | 81 | # ----------------------------------------------------------------------------- 82 | # Training settings 83 | # ----------------------------------------------------------------------------- 84 | _C.TRAIN = CN() 85 | _C.TRAIN.START_EPOCH = 0 86 | _C.TRAIN.EPOCHS = 50 87 | _C.TRAIN.WARMUP_EPOCHS = 20 88 | _C.TRAIN.WEIGHT_DECAY = 0.05 89 | _C.TRAIN.BASE_LR = 5e-4 90 | _C.TRAIN.WARMUP_LR = 5e-7 91 | _C.TRAIN.MIN_LR = 5e-6 92 | # Clip gradient norm 93 | _C.TRAIN.CLIP_GRAD = 5.0 94 | # Auto resume from latest checkpoint 95 | _C.TRAIN.AUTO_RESUME = True 96 | # Whether to use gradient checkpointing to save memory 97 | # could be overwritten by command line argument 98 | _C.TRAIN.USE_CHECKPOINT = False 99 | 100 | # LR scheduler 101 | _C.TRAIN.LR_SCHEDULER = CN() 102 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 103 | # Epoch interval to decay LR, used in StepLRScheduler 104 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 105 | # LR decay rate, used in StepLRScheduler 106 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 107 | 108 | # Optimizer 109 | _C.TRAIN.OPTIMIZER = CN() 110 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 111 | # Optimizer Epsilon 112 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 113 | # Optimizer Betas 114 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 115 | # SGD momentum 116 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 117 | 118 | # ----------------------------------------------------------------------------- 119 | # Augmentation settings 120 | # ----------------------------------------------------------------------------- 121 | _C.AUG = CN() 122 | # Color jitter factor 123 | _C.AUG.COLOR_JITTER = 0.4 124 | # Use AutoAugment policy. 
"v0" or "original" 125 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 126 | # Random erase prob 127 | _C.AUG.REPROB = 0.25 128 | # Random erase mode 129 | _C.AUG.REMODE = 'pixel' 130 | # Random erase count 131 | _C.AUG.RECOUNT = 1 132 | # Mixup alpha, mixup enabled if > 0 133 | _C.AUG.MIXUP = 0.8 134 | # Cutmix alpha, cutmix enabled if > 0 135 | _C.AUG.CUTMIX = 1.0 136 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 137 | _C.AUG.CUTMIX_MINMAX = None 138 | # Probability of performing mixup or cutmix when either/both is enabled 139 | _C.AUG.MIXUP_PROB = 1.0 140 | # Probability of switching to cutmix when both mixup and cutmix enabled 141 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 142 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 143 | _C.AUG.MIXUP_MODE = 'batch' 144 | 145 | # ----------------------------------------------------------------------------- 146 | # Testing settings 147 | # ----------------------------------------------------------------------------- 148 | _C.TEST = CN() 149 | # Whether to use center crop when testing 150 | _C.TEST.CROP = True 151 | 152 | # ----------------------------------------------------------------------------- 153 | # Misc 154 | # ----------------------------------------------------------------------------- 155 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 156 | # overwritten by command line argument 157 | _C.AMP_OPT_LEVEL = 'O0' 158 | # Path to output folder, overwritten by command line argument 159 | _C.OUTPUT = '' 160 | # Tag of experiment, overwritten by command line argument 161 | _C.TAG = 'default' 162 | # Frequency to save checkpoint 163 | _C.SAVE_FREQ = 1 164 | # Frequency to logging info 165 | _C.PRINT_FREQ = 10 166 | # Fixed random seed 167 | _C.SEED = 0 168 | # Perform evaluation only, overwritten by command line argument 169 | _C.EVAL_MODE = False 170 | # Test throughput only, overwritten by command line argument 171 | _C.THROUGHPUT_MODE = False 172 | # local rank for DistributedDataParallel, given by command line argument 173 | # _C.LOCAL_RANK = 0 174 | 175 | _C.HEAD = 'denseNet_15layer' 176 | 177 | 178 | def _update_config_from_file(config, cfg_file): 179 | config.defrost() 180 | with open(cfg_file, 'r') as f: 181 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 182 | 183 | for cfg in yaml_cfg.setdefault('BASE', ['']): 184 | if cfg: 185 | _update_config_from_file( 186 | config, os.path.join(os.path.dirname(cfg_file), cfg) 187 | ) 188 | print('=> merge config from {}'.format(cfg_file)) 189 | config.merge_from_file(cfg_file) 190 | config.freeze() 191 | 192 | 193 | def update_config(config, args): 194 | _update_config_from_file(config, args.cfg) 195 | 196 | config.defrost() 197 | if args.opts: 198 | config.merge_from_list(args.opts) 199 | 200 | # merge from specific arguments 201 | if args.batch_size: 202 | config.DATA.BATCH_SIZE = args.batch_size 203 | if args.data_path: 204 | config.DATA.DATA_PATH = args.data_path 205 | if args.resume: 206 | config.MODEL.RESUME = args.resume 207 | if args.use_checkpoint: 208 | config.TRAIN.USE_CHECKPOINT = True 209 | if args.output: 210 | config.OUTPUT = args.output 211 | if args.tag: 212 | config.TAG = args.tag 213 | if args.eval: 214 | config.EVAL_MODE = True 215 | if args.dataset: 216 | config.DATA.DATASET = args.dataset 217 | if args.datanum: 218 | config.DATA.DATANUM = args.datanum 219 | if args.num_epoch: 220 | config.TRAIN.EPOCHS = args.num_epoch 221 | 222 | # if args.loss: 223 | if args.finetune: 224 | # print("finetune") 225 | 
config.MODEL.FINETUNE = args.finetune
226 | # config.LOSS = args.loss
227 |
228 | # set local rank for distributed training
229 | # config.LOCAL_RANK = args.local_rank
230 |
231 | # output folder
232 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME + ".".join(str(time.asctime()).split()), config.TAG)
233 |
234 | config.freeze()
235 |
236 |
237 | def get_config(args):
238 | """Get a yacs CfgNode object with default values."""
239 | # Return a clone so that the defaults will not be altered
240 | # This is for the "local variable" use pattern
241 | config = _C.clone()
242 | update_config(config, args)
243 |
244 | return config
245 |
--------------------------------------------------------------------------------
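For reference, `update_config()` above reads a fixed set of attributes from the argument object it is given (cfg, opts, batch_size, data_path, resume, use_checkpoint, output, tag, eval, dataset, datanum, num_epoch, finetune), so whatever script drives it must supply every one of them. The snippet below is a minimal illustrative sketch, not code from this repository: the yaml path, dataset root, and image count are placeholders, and it assumes it is run from the e-commercial directory so that config.py is importable.

import argparse
from config import get_config  # e-commercial/config.py

# Hypothetical namespace mirroring the command-line arguments that
# update_config() expects; every attribute it touches must exist.
args = argparse.Namespace(
    cfg='configs/example.yaml',      # placeholder yaml, merged first via _update_config_from_file
    opts=None,                       # optional 'KEY VALUE' pairs for merge_from_list
    batch_size=8,                    # overrides DATA.BATCH_SIZE
    data_path='/path/to/ecommerce',  # placeholder dataset root
    resume='',                       # empty -> MODEL.RESUME is left untouched
    use_checkpoint=False,
    output='output',                 # OUTPUT becomes output/<MODEL.NAME><timestamp>/<TAG>
    tag='default',
    eval=False,
    dataset='ecdata',                # one of ecdata / finetunedata / folderimagedata
    datanum=1000,                    # placeholder image count, stored as DATA.DATANUM
    num_epoch=50,
    finetune='',                     # empty -> keep the default MODEL.FINETUNE path
)

config = get_config(args)            # clone defaults, merge the yaml, then apply the overrides
print(config.DATA.BATCH_SIZE, config.DATA.DATASET, config.OUTPUT)

Empty strings and False values are simply skipped by the `if args.xxx:` checks in update_config(), so only the fields you actually set end up overriding the yaml and the defaults.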
/e-commercial/data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import io 6 | import os 7 | import time 8 | import torch.distributed as dist 9 | import torch.utils.data as data 10 | from PIL import Image 11 | 12 | from .zipreader import is_zip_path, ZipReader 13 | 14 | 15 | def has_file_allowed_extension(filename, extensions): 16 | """Checks if a file is an allowed extension. 17 | Args: 18 | filename (string): path to a file 19 | Returns: 20 | bool: True if the filename ends with a known image extension 21 | """ 22 | filename_lower = filename.lower() 23 | return any(filename_lower.endswith(ext) for ext in extensions) 24 | 25 | 26 | def find_classes(dir): 27 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 28 | classes.sort() 29 | class_to_idx = {classes[i]: i for i in range(len(classes))} 30 | return classes, class_to_idx 31 | 32 | 33 | def make_dataset(dir, class_to_idx, extensions): 34 | images = [] 35 | dir = os.path.expanduser(dir) 36 | for target in sorted(os.listdir(dir)): 37 | d = os.path.join(dir, target) 38 | if not os.path.isdir(d): 39 | continue 40 | 41 | for root, _, fnames in sorted(os.walk(d)): 42 | for fname in sorted(fnames): 43 | if has_file_allowed_extension(fname, extensions): 44 | path = os.path.join(root, fname) 45 | item = (path, class_to_idx[target]) 46 | images.append(item) 47 | 48 | return images 49 | 50 | 51 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 52 | images = [] 53 | with open(ann_file, "r") as f: 54 | contents = f.readlines() 55 | for line_str in contents: 56 | path_contents = [c for c in line_str.split('\t')] 57 | im_file_name = path_contents[0] 58 | class_index = int(path_contents[1]) 59 | 60 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 61 | item = (os.path.join(img_prefix, im_file_name), class_index) 62 | 63 | images.append(item) 64 | 65 | return images 66 | 67 | 68 | class DatasetFolder(data.Dataset): 69 | """A generic data loader where the samples are arranged in this way: :: 70 | root/class_x/xxx.ext 71 | root/class_x/xxy.ext 72 | root/class_x/xxz.ext 73 | root/class_y/123.ext 74 | root/class_y/nsdf3.ext 75 | root/class_y/asd932_.ext 76 | Args: 77 | root (string): Root directory path. 78 | loader (callable): A function to load a sample given its path. 79 | extensions (list[string]): A list of allowed extensions. 80 | transform (callable, optional): A function/transform that takes in 81 | a sample and returns a transformed version. 82 | E.g, ``transforms.RandomCrop`` for images. 83 | target_transform (callable, optional): A function/transform that takes 84 | in the target and transforms it. 
85 | Attributes: 86 | samples (list): List of (sample path, class_index) tuples 87 | """ 88 | 89 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 90 | cache_mode="no"): 91 | # image folder mode 92 | if ann_file == '': 93 | _, class_to_idx = find_classes(root) 94 | samples = make_dataset(root, class_to_idx, extensions) 95 | # zip mode 96 | else: 97 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 98 | os.path.join(root, img_prefix), 99 | extensions) 100 | 101 | if len(samples) == 0: 102 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 103 | "Supported extensions are: " + ",".join(extensions))) 104 | 105 | self.root = root 106 | self.loader = loader 107 | self.extensions = extensions 108 | 109 | self.samples = samples 110 | self.labels = [y_1k for _, y_1k in samples] 111 | self.classes = list(set(self.labels)) 112 | 113 | self.transform = transform 114 | self.target_transform = target_transform 115 | 116 | self.cache_mode = cache_mode 117 | if self.cache_mode != "no": 118 | self.init_cache() 119 | 120 | def init_cache(self): 121 | assert self.cache_mode in ["part", "full"] 122 | n_sample = len(self.samples) 123 | global_rank = dist.get_rank() 124 | world_size = dist.get_world_size() 125 | 126 | samples_bytes = [None for _ in range(n_sample)] 127 | start_time = time.time() 128 | for index in range(n_sample): 129 | if index % (n_sample // 10) == 0: 130 | t = time.time() - start_time 131 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 132 | start_time = time.time() 133 | path, target = self.samples[index] 134 | if self.cache_mode == "full": 135 | samples_bytes[index] = (ZipReader.read(path), target) 136 | elif self.cache_mode == "part" and index % world_size == global_rank: 137 | samples_bytes[index] = (ZipReader.read(path), target) 138 | else: 139 | samples_bytes[index] = (path, target) 140 | self.samples = samples_bytes 141 | 142 | def __getitem__(self, index): 143 | """ 144 | Args: 145 | index (int): Index 146 | Returns: 147 | tuple: (sample, target) where target is class_index of the target class. 
148 | """ 149 | path, target = self.samples[index] 150 | sample = self.loader(path) 151 | if self.transform is not None: 152 | sample = self.transform(sample) 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return sample, target 157 | 158 | def __len__(self): 159 | return len(self.samples) 160 | 161 | def __repr__(self): 162 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 163 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 164 | fmt_str += ' Root Location: {}\n'.format(self.root) 165 | tmp = ' Transforms (if any): ' 166 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 167 | tmp = ' Target Transforms (if any): ' 168 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 169 | return fmt_str 170 | 171 | 172 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 173 | 174 | 175 | def pil_loader(path): 176 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 177 | if isinstance(path, bytes): 178 | img = Image.open(io.BytesIO(path)) 179 | elif is_zip_path(path): 180 | data = ZipReader.read(path) 181 | img = Image.open(io.BytesIO(data)) 182 | else: 183 | with open(path, 'rb') as f: 184 | img = Image.open(f) 185 | return img.convert('RGB') 186 | 187 | 188 | def accimage_loader(path): 189 | import accimage 190 | try: 191 | return accimage.Image(path) 192 | except IOError: 193 | # Potentially a decoding problem, fall back to PIL.Image 194 | return pil_loader(path) 195 | 196 | 197 | def default_img_loader(path): 198 | from torchvision import get_image_backend 199 | if get_image_backend() == 'accimage': 200 | return accimage_loader(path) 201 | else: 202 | return pil_loader(path) 203 | 204 | 205 | class CachedImageFolder(DatasetFolder): 206 | """A generic data loader where the images are arranged in this way: :: 207 | root/dog/xxx.png 208 | root/dog/xxy.png 209 | root/dog/xxz.png 210 | root/cat/123.png 211 | root/cat/nsdf3.png 212 | root/cat/asd932_.png 213 | Args: 214 | root (string): Root directory path. 215 | transform (callable, optional): A function/transform that takes in an PIL image 216 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 217 | target_transform (callable, optional): A function/transform that takes in the 218 | target and transforms it. 219 | loader (callable, optional): A function to load an image given its path. 220 | Attributes: 221 | imgs (list): List of (image path, class_index) tuples 222 | """ 223 | 224 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 225 | loader=default_img_loader, cache_mode="no"): 226 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 227 | ann_file=ann_file, img_prefix=img_prefix, 228 | transform=transform, target_transform=target_transform, 229 | cache_mode=cache_mode) 230 | self.imgs = self.samples 231 | 232 | def __getitem__(self, index): 233 | """ 234 | Args: 235 | index (int): Index 236 | Returns: 237 | tuple: (image, target) where target is class_index of the target class. 
238 | """ 239 | path, target = self.samples[index] 240 | image = self.loader(path) 241 | if self.transform is not None: 242 | img = self.transform(image) 243 | else: 244 | img = image 245 | if self.target_transform is not None: 246 | target = self.target_transform(target) 247 | 248 | return img, target 249 | -------------------------------------------------------------------------------- /e-commercial/models/craft_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | import cv2 9 | import math 10 | 11 | """ auxilary functions """ 12 | # unwarp corodinates 13 | def warpCoord(Minv, pt): 14 | out = np.matmul(Minv, (pt[0], pt[1], 1)) 15 | return np.array([out[0]/out[2], out[1]/out[2]]) 16 | """ end of auxilary functions """ 17 | 18 | 19 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text): 20 | # prepare data 21 | linkmap = linkmap.copy() 22 | textmap = textmap.copy() 23 | img_h, img_w = textmap.shape 24 | 25 | """ labeling method """ 26 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0) 27 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) 28 | 29 | text_score_comb = np.clip(text_score + link_score, 0, 1) 30 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4) 31 | 32 | det = [] 33 | mapper = [] 34 | for k in range(1,nLabels): 35 | # size filtering 36 | size = stats[k, cv2.CC_STAT_AREA] 37 | if size < 10: continue 38 | 39 | # thresholding 40 | if np.max(textmap[labels==k]) < text_threshold: continue 41 | 42 | # make segmentation map 43 | segmap = np.zeros(textmap.shape, dtype=np.uint8) 44 | segmap[labels==k] = 255 45 | segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area 46 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] 47 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] 48 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) 49 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1 50 | # boundary check 51 | if sx < 0 : sx = 0 52 | if sy < 0 : sy = 0 53 | if ex >= img_w: ex = img_w 54 | if ey >= img_h: ey = img_h 55 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter)) 56 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel, iterations=1) 57 | #kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 5)) 58 | #segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel1, iterations=1) 59 | 60 | 61 | # make box 62 | np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2) 63 | rectangle = cv2.minAreaRect(np_contours) 64 | box = cv2.boxPoints(rectangle) 65 | 66 | # align diamond-shape 67 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) 68 | box_ratio = max(w, h) / (min(w, h) + 1e-5) 69 | if abs(1 - box_ratio) <= 0.1: 70 | l, r = min(np_contours[:,0]), max(np_contours[:,0]) 71 | t, b = min(np_contours[:,1]), max(np_contours[:,1]) 72 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) 73 | 74 | # make clock-wise order 75 | startidx = box.sum(axis=1).argmin() 76 | box = np.roll(box, 4-startidx, 0) 77 | box = np.array(box) 78 | 79 | det.append(box) 80 | mapper.append(k) 81 | 82 | return det, labels, mapper 83 | 84 | def getPoly_core(boxes, labels, mapper, linkmap): 85 | # configs 86 | num_cp = 5 87 | max_len_ratio = 0.7 88 
| expand_ratio = 1.45 89 | max_r = 2.0 90 | step_r = 0.2 91 | 92 | polys = [] 93 | for k, box in enumerate(boxes): 94 | # size filter for small instance 95 | w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1) 96 | if w < 30 or h < 30: 97 | polys.append(None); continue 98 | 99 | # warp image 100 | tar = np.float32([[0,0],[w,0],[w,h],[0,h]]) 101 | M = cv2.getPerspectiveTransform(box, tar) 102 | word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST) 103 | try: 104 | Minv = np.linalg.inv(M) 105 | except: 106 | polys.append(None); continue 107 | 108 | # binarization for selected label 109 | cur_label = mapper[k] 110 | word_label[word_label != cur_label] = 0 111 | word_label[word_label > 0] = 1 112 | 113 | """ Polygon generation """ 114 | # find top/bottom contours 115 | cp = [] 116 | max_len = -1 117 | for i in range(w): 118 | region = np.where(word_label[:,i] != 0)[0] 119 | if len(region) < 2 : continue 120 | cp.append((i, region[0], region[-1])) 121 | length = region[-1] - region[0] + 1 122 | if length > max_len: max_len = length 123 | 124 | # pass if max_len is similar to h 125 | if h * max_len_ratio < max_len: 126 | polys.append(None); continue 127 | 128 | # get pivot points with fixed length 129 | tot_seg = num_cp * 2 + 1 130 | seg_w = w / tot_seg # segment width 131 | pp = [None] * num_cp # init pivot points 132 | cp_section = [[0, 0]] * tot_seg 133 | seg_height = [0] * num_cp 134 | seg_num = 0 135 | num_sec = 0 136 | prev_h = -1 137 | for i in range(0,len(cp)): 138 | (x, sy, ey) = cp[i] 139 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: 140 | # average previous segment 141 | if num_sec == 0: break 142 | cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec] 143 | num_sec = 0 144 | 145 | # reset variables 146 | seg_num += 1 147 | prev_h = -1 148 | 149 | # accumulate center points 150 | cy = (sy + ey) * 0.5 151 | cur_h = ey - sy + 1 152 | cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy] 153 | num_sec += 1 154 | 155 | if seg_num % 2 == 0: continue # No polygon area 156 | 157 | if prev_h < cur_h: 158 | pp[int((seg_num - 1)/2)] = (x, cy) 159 | seg_height[int((seg_num - 1)/2)] = cur_h 160 | prev_h = cur_h 161 | 162 | # processing last segment 163 | if num_sec != 0: 164 | cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec] 165 | 166 | # pass if num of pivots is not sufficient or segment widh is smaller than character height 167 | if None in pp or seg_w < np.max(seg_height) * 0.25: 168 | polys.append(None); continue 169 | 170 | # calc median maximum of pivot points 171 | half_char_h = np.median(seg_height) * expand_ratio / 2 172 | 173 | # calc gradiant and apply to make horizontal pivots 174 | new_pp = [] 175 | for i, (x, cy) in enumerate(pp): 176 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] 177 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] 178 | if dx == 0: # gradient if zero 179 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) 180 | continue 181 | rad = - math.atan2(dy, dx) 182 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) 183 | new_pp.append([x - s, cy - c, x + s, cy + c]) 184 | 185 | # get edge points to cover character heatmaps 186 | isSppFound, isEppFound = False, False 187 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0]) 188 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - 
pp[-2][0]) 189 | for r in np.arange(0.5, max_r, step_r): 190 | dx = 2 * half_char_h * r 191 | if not isSppFound: 192 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 193 | dy = grad_s * dx 194 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) 195 | cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) 196 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 197 | spp = p 198 | isSppFound = True 199 | if not isEppFound: 200 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 201 | dy = grad_e * dx 202 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) 203 | cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) 204 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 205 | epp = p 206 | isEppFound = True 207 | if isSppFound and isEppFound: 208 | break 209 | 210 | # pass if boundary of polygon is not found 211 | if not (isSppFound and isEppFound): 212 | polys.append(None); continue 213 | 214 | # make final polygon 215 | poly = [] 216 | poly.append(warpCoord(Minv, (spp[0], spp[1]))) 217 | for p in new_pp: 218 | poly.append(warpCoord(Minv, (p[0], p[1]))) 219 | poly.append(warpCoord(Minv, (epp[0], epp[1]))) 220 | poly.append(warpCoord(Minv, (epp[2], epp[3]))) 221 | for p in reversed(new_pp): 222 | poly.append(warpCoord(Minv, (p[2], p[3]))) 223 | poly.append(warpCoord(Minv, (spp[2], spp[3]))) 224 | 225 | # add to final result 226 | polys.append(np.array(poly)) 227 | 228 | return polys 229 | 230 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False): 231 | boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text) 232 | 233 | if poly: 234 | polys = getPoly_core(boxes, labels, mapper, linkmap) 235 | else: 236 | polys = [None] * len(boxes) 237 | 238 | return boxes, polys 239 | 240 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2): 241 | if len(polys) > 0: 242 | polys = np.array(polys) 243 | for k in range(len(polys)): 244 | if polys[k] is not None: 245 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) 246 | return polys 247 | -------------------------------------------------------------------------------- /e-commercial/models/saliency_detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | def define_salD(input_dim, actType='lrelu', normType='batch', netSalD='denseNet256_36layer'): 9 | if netSalD == 'CNN_3layer': 10 | net = SalDetector(input_dim, growthRate=32, EnBlocks=(0, 0, 0), feaChannel=64, reduction=0.5, bottleneck=True, 11 | actType=actType, normType=normType, useASPP=True, ASPPsacles=(0, 1, 4, 7)) 12 | elif netSalD == 'denseNet_15layer': 13 | net = SalDetector(input_dim, growthRate=32, EnBlocks=(2, 2, 2), feaChannel=64, reduction=0.5, bottleneck=True, 14 | actType=actType, normType=normType, useASPP=True, ASPPsacles=(0, 1, 4, 7)) 15 | elif netSalD == 'denseNet_20layer': 16 | net = SalDetector(input_dim, growthRate=32, EnBlocks=(2, 2, 2, 2), feaChannel=64, reduction=0.5, 17 | bottleneck=True, 18 | actType=actType, normType=normType, useASPP=True, ASPPsacles=(0, 1, 4, 7)) 19 | elif netSalD == 'denseNet_28layer': 20 | net = SalDetector(input_dim, growthRate=16, EnBlocks=(2, 4, 4, 2), feaChannel=64, reduction=0.5, 21 | bottleneck=True, 22 | actType=actType, normType=normType, useASPP=True, 
ASPPsacles=(0, 1, 4, 7)) 23 | else: 24 | raise NotImplementedError('Saliency detector model name [%s] is not recognized' % netSalD) 25 | return net 26 | 27 | 28 | class SalDetector(nn.Module): 29 | def __init__(self, inputdim=768, growthRate=32, EnBlocks=(2, 2, 2), feaChannel=64, reduction=0.5, bottleneck=True, 30 | actType='lrelu', normType='batch', useASPP=True, ASPPsacles=(0, 1, 4, 7)): 31 | super(SalDetector, self).__init__() 32 | norm_layer = get_norm_layer(normType) 33 | act_layer = get_activation_layer(actType) 34 | nChannels = inputdim 35 | layers = [] 36 | # layers = [nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)] 37 | # nChannels = growthRate 38 | for nLayers in EnBlocks: 39 | if nLayers == 0: 40 | CNNlayer = [norm_layer(nChannels), 41 | act_layer(), 42 | nn.Conv2d(nChannels, nChannels // 2, kernel_size=3, padding=1, bias=False)] 43 | layers += CNNlayer 44 | nChannels = nChannels // 2 45 | else: 46 | layers += [DenseBlock(nChannels, growthRate, nLayers, reduction, bottleneck, norm_layer, act_layer)] 47 | nChannels = int(math.floor((nChannels + nLayers * growthRate) * reduction)) 48 | if useASPP: 49 | layers += [ASPP(nChannels, feaChannel, scales=ASPPsacles)] 50 | else: 51 | layers += [nn.Conv2d(nChannels, feaChannel, kernel_size=3, padding=1, bias=False)] 52 | self.Encoder = nn.Sequential(*layers) 53 | layers2 = [DecoderBlock(feaChannel, feaChannel // 4, False, norm_layer, act_layer), 54 | DecoderBlock(feaChannel // 4, feaChannel // 16, True, norm_layer, act_layer), 55 | DecoderBlock(feaChannel // 16, 1, False, norm_layer, act_layer), 56 | ] 57 | self.Decoder = nn.Sequential(*layers2) 58 | 59 | def forward(self, x): 60 | encoderFea = self.Encoder(x) 61 | out = self.Decoder(encoderFea) 62 | # out = torch.squeeze(out) 63 | out = NormAffine(out, method='one') 64 | return out 65 | 66 | 67 | class DecoderBlock(nn.Module): 68 | def __init__(self, nChannels, nOutChannels, deConv=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU): 69 | super(DecoderBlock, self).__init__() 70 | interChannels = int(math.ceil(nChannels * math.sqrt(nOutChannels / nChannels))) 71 | if deConv: 72 | # output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size) 73 | layers = [norm_layer(nChannels), 74 | act_layer(), 75 | nn.ConvTranspose2d(nChannels, interChannels, kernel_size=2, stride=2, padding=0, output_padding=0, 76 | bias=False), 77 | norm_layer(interChannels), 78 | act_layer(), 79 | nn.Conv2d(interChannels, nOutChannels, kernel_size=1, bias=False) 80 | ] 81 | else: 82 | layers = [norm_layer(nChannels), 83 | act_layer(), 84 | nn.Conv2d(nChannels, interChannels, kernel_size=3, padding=1, bias=False), 85 | norm_layer(interChannels), 86 | act_layer(), 87 | nn.Conv2d(interChannels, nOutChannels, kernel_size=1, bias=False) 88 | ] 89 | self.model = nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | out = self.model(x) 93 | return out 94 | 95 | 96 | class Transition(nn.Module): 97 | def __init__(self, nChannels, nOutChannels, sn=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU): 98 | super(Transition, self).__init__() 99 | if sn: 100 | layers = [norm_layer(nChannels), 101 | act_layer(), 102 | nn.utils.spectral_norm(nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)) 103 | # nn.AvgPool2d(2) 104 | ] 105 | else: 106 | layers = [norm_layer(nChannels), 107 | act_layer(), 108 | nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) 109 | # nn.AvgPool2d(2) 110 | ] 111 | self.model = nn.Sequential(*layers) 112 | 
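# --- Added note: channel bookkeeping in the SalDetector encoder above (a rough
# sketch, assuming the 'denseNet_15layer' defaults: inputdim=768, growthRate=32,
# EnBlocks=(2, 2, 2), reduction=0.5). After each DenseBlock the channel count is
# updated as nChannels = floor((nChannels + nLayers * growthRate) * reduction):
#     768 -> floor((768 + 64) * 0.5) = 416
#     416 -> floor((416 + 64) * 0.5) = 240
#     240 -> floor((240 + 64) * 0.5) = 152
# The ASPP (or plain 3x3 conv) then projects 152 -> feaChannel (64), and the
# decoder narrows 64 -> 16 -> 4 -> 1, with a single 2x upsampling in the middle
# DecoderBlock, before NormAffine(..., method='one') maps each saliency map
# into [0, 1].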
113 | def forward(self, x): 114 | out = self.model(x) 115 | return out 116 | 117 | 118 | class DenseBlock(nn.Module): 119 | def __init__(self, nChannels, growthRate, nLayers, reduction, bottleneck=True, norm_layer=nn.BatchNorm2d, 120 | act_layer=nn.LeakyReLU, sn=False): 121 | super(DenseBlock, self).__init__() 122 | layers = [] 123 | for i in range(int(nLayers)): 124 | if bottleneck: 125 | layers += [Bottleneck(nChannels, growthRate, sn, norm_layer, act_layer)] 126 | else: 127 | layers += [SingleLayer(nChannels, growthRate, sn, norm_layer, act_layer)] 128 | nChannels += growthRate 129 | nOutChannels = int(math.floor(nChannels * reduction)) 130 | layers += [Transition(nChannels, nOutChannels, sn, norm_layer, act_layer)] 131 | self.model = nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | out = self.model(x) 135 | return out 136 | 137 | 138 | class Bottleneck(nn.Module): 139 | def __init__(self, nChannels, growthRate, sn=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU): 140 | super(Bottleneck, self).__init__() 141 | interChannels = 4 * growthRate 142 | if sn: 143 | layers = [norm_layer(nChannels), 144 | act_layer(), 145 | nn.utils.spectral_norm(nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)), 146 | norm_layer(interChannels), 147 | act_layer(), 148 | nn.utils.spectral_norm(nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)) 149 | ] 150 | else: 151 | layers = [norm_layer(nChannels), 152 | act_layer(), 153 | nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False), 154 | norm_layer(interChannels), 155 | act_layer(), 156 | nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) 157 | ] 158 | self.model = nn.Sequential(*layers) 159 | 160 | def forward(self, x): 161 | out = self.model(x) 162 | out = torch.cat((x, out), 1) 163 | return out 164 | 165 | 166 | class SingleLayer(nn.Module): 167 | def __init__(self, nChannels, growthRate, sn=False, norm_layer=nn.BatchNorm2d, act_layer=nn.LeakyReLU): 168 | super(SingleLayer, self).__init__() 169 | if sn: 170 | layers = [norm_layer(nChannels), 171 | act_layer(), 172 | nn.utils.spectral_norm(nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)) 173 | ] 174 | else: 175 | layers = [norm_layer(nChannels), 176 | act_layer(), 177 | nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) 178 | ] 179 | self.model = nn.Sequential(*layers) 180 | 181 | def forward(self, x): 182 | out = self.model(x) 183 | out = torch.cat((x, out), 1) 184 | return out 185 | 186 | 187 | class ASPP(nn.Module): 188 | def __init__(self, in_channel=512, depth=256, scales=(0, 1, 4, 7), sn=False): 189 | super(ASPP, self).__init__() 190 | self.scales = scales 191 | for dilate_rate in self.scales: 192 | if dilate_rate == -1: 193 | break 194 | if dilate_rate == 0: 195 | layers = [nn.AdaptiveAvgPool2d((1, 1))] 196 | if sn: 197 | layers += [nn.utils.spectral_norm(nn.Conv2d(in_channel, depth, 1, 1))] 198 | else: 199 | layers += [nn.Conv2d(in_channel, depth, 1, 1)] 200 | setattr(self, 'dilate_layer_{}'.format(dilate_rate), nn.Sequential(*layers)) 201 | elif dilate_rate == 1: 202 | if sn: 203 | layers = [nn.utils.spectral_norm(nn.Conv2d(in_channel, depth, 1, 1))] 204 | else: 205 | layers = [nn.Conv2d(in_channel, depth, 1, 1)] 206 | setattr(self, 'dilate_layer_{}'.format(dilate_rate), nn.Sequential(*layers)) 207 | else: 208 | if sn: 209 | layers = [nn.utils.spectral_norm( 210 | nn.Conv2d(in_channel, depth, 3, 1, dilation=dilate_rate, padding=dilate_rate))] 211 | else: 212 | layers 
= [nn.Conv2d(in_channel, depth, 3, 1, dilation=dilate_rate, padding=dilate_rate)] 213 | setattr(self, 'dilate_layer_{}'.format(dilate_rate), nn.Sequential(*layers)) 214 | self.conv_1x1_output = nn.Conv2d(depth * len(scales), depth, 1, 1) 215 | 216 | def forward(self, x): 217 | dilate_outs = [] 218 | for dilate_rate in self.scales: 219 | if dilate_rate == -1: 220 | return x 221 | if dilate_rate == 0: 222 | layer = getattr(self, 'dilate_layer_{}'.format(dilate_rate)) 223 | size = x.shape[2:] 224 | tempout = F.interpolate(layer(x), size=size, mode='bilinear', align_corners=True) 225 | dilate_outs.append(tempout) 226 | else: 227 | layer = getattr(self, 'dilate_layer_{}'.format(dilate_rate)) 228 | dilate_outs.append(layer(x)) 229 | out = self.conv_1x1_output(torch.cat(dilate_outs, dim=1)) 230 | return out 231 | 232 | 233 | #################################################################### 234 | # ------------------------- Basic Functions ------------------------- 235 | #################################################################### 236 | 237 | 238 | def get_norm_layer(layer_type='instance'): 239 | if layer_type == 'batch': 240 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 241 | elif layer_type == 'instance': 242 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 243 | elif layer_type == 'none': 244 | norm_layer = None 245 | else: 246 | raise NotImplementedError('normalization layer [%s] is not found' % layer_type) 247 | return norm_layer 248 | 249 | 250 | def get_activation_layer(layer_type='relu'): 251 | if layer_type == 'relu': 252 | nl_layer = functools.partial(nn.ReLU, inplace=True) 253 | elif layer_type == 'lrelu': 254 | nl_layer = functools.partial(nn.LeakyReLU, negative_slope=0.2, inplace=True) 255 | elif layer_type == 'elu': 256 | nl_layer = functools.partial(nn.ELU, inplace=True) 257 | elif layer_type == 'none': 258 | nl_layer = None 259 | else: 260 | raise NotImplementedError('activitation [%s] is not found' % layer_type) 261 | return nl_layer 262 | 263 | 264 | def NormAffine(mat, eps=1e-7, 265 | method='sum'): # tensor [batch_size, channels, image_height, image_width] normalize each fea map; 266 | matdim = len(mat.size()) 267 | if method == 'sum': 268 | tempsum = torch.sum(mat, dim=(matdim - 1, matdim - 2), keepdim=True) + eps 269 | out = mat / tempsum 270 | elif method == 'one': 271 | (tempmin, _) = torch.min(mat, dim=matdim - 1, keepdim=True) 272 | (tempmin, _) = torch.min(tempmin, dim=matdim - 2, keepdim=True) 273 | tempmat = mat - tempmin 274 | (tempmax, _) = torch.max(tempmat, dim=matdim - 1, keepdim=True) 275 | (tempmax, _) = torch.max(tempmax, dim=matdim - 2, keepdim=True) 276 | tempmax = tempmax + eps 277 | out = tempmat / tempmax 278 | else: 279 | raise NotImplementedError('Map method [%s] is not implemented' % method) 280 | return out 281 | -------------------------------------------------------------------------------- /e-commercial/main.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import time 7 | import argparse 8 | import datetime 9 | import numpy as np 10 | import torchvision 11 | import cv2 12 | from IPython import embed 13 | 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | 18 | from timm.utils import accuracy, AverageMeter 19 | from 
torchvision import transforms 20 | import torchvision 21 | import torch.nn.functional as F 22 | from PIL import Image 23 | 24 | from config import get_config 25 | from models import build_model 26 | from models import craft_utils, imgproc 27 | from models.loss import KL_loss, Maploss 28 | from models.metric import calCC, calKL 29 | from data import build_loader 30 | from lr_scheduler import build_scheduler 31 | from optimizer import build_optimizer 32 | from logger import create_logger 33 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, load_checkpoint_finetune, load_checkpoint_eval 34 | os.environ['CUDA_VISIBLE_DEVICES']="0" 35 | 36 | 37 | try: 38 | # noinspection PyUnresolvedReferences 39 | from apex import amp 40 | except ImportError: 41 | amp = None 42 | 43 | 44 | def parse_option(): 45 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 46 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 47 | parser.add_argument( 48 | "--opts", 49 | help="Modify config options by adding 'KEY VALUE' pairs. ", 50 | default=None, 51 | nargs='+', 52 | ) 53 | 54 | # easy config modification 55 | parser.add_argument('--batch-size', type=int, help="batch size") 56 | parser.add_argument('--data-path', type=str, help='path to dataset') 57 | # parser.add_argument('--eval-path', type=str, help='path to dataset') 58 | parser.add_argument('--resume', help='resume from checkpoint') 59 | parser.add_argument('--finetune', help='finetune from checkpoint') 60 | parser.add_argument('--use-checkpoint', action='store_true', 61 | help="whether to use gradient checkpointing to save memory") 62 | parser.add_argument('--output', default='./output', type=str, metavar='PATH', 63 | help='root of output folder, the full path is // (default: output)') 64 | parser.add_argument('--tag', help='tag of experiment') 65 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 66 | 67 | # about dataset 68 | parser.add_argument('--dataset', type=str, default='imagenet', help='name of dataset') 69 | parser.add_argument('--datanum', type=int, default=972, help='num of dataset') 70 | parser.add_argument('--num_epoch', type=int, default=50, help='num of epoch') 71 | parser.add_argument('--head', type=str, default='denseNet_15layer', help='head') 72 | 73 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 74 | 75 | args, unparsed = parser.parse_known_args() 76 | 77 | config = get_config(args) 78 | 79 | return args, config 80 | 81 | 82 | def main(config): 83 | # dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 84 | dataset_train, data_loader_train, mixup_fn = build_loader(config) 85 | 86 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 87 | model = build_model(config) 88 | model.cuda() 89 | logger.info(str(model)) 90 | 91 | optimizer = build_optimizer(config, model) 92 | model = torch.nn.DataParallel(model) 93 | model_without_ddp = model.module 94 | 95 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 96 | logger.info(f"number of params: {n_parameters}") 97 | 98 | if config.EVAL_MODE: 99 | print("test model") 100 | dataset_val, data_loader_val, _ = build_loader(config) 101 | load_checkpoint_eval(config, model_without_ddp, optimizer, logger) 102 | validate_article(config, dataset_train, model) 103 | return 104 | 105 | lr_scheduler = 
build_scheduler(config, optimizer, len(data_loader_train)) 106 | 107 | criterion = Maploss() 108 | 109 | 110 | 111 | max_accuracy = 0.0 112 | 113 | if config.TRAIN.AUTO_RESUME: 114 | resume_file = auto_resume_helper(config.OUTPUT) 115 | if resume_file: 116 | if config.MODEL.RESUME: 117 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 118 | config.defrost() 119 | config.MODEL.RESUME = resume_file 120 | config.freeze() 121 | logger.info(f'auto resuming from {resume_file}') 122 | else: 123 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 124 | 125 | 126 | ''' 127 | if config.MODEL.FINETUNE: 128 | print("FINETUNE") 129 | load_checkpoint_finetune(config, model_without_ddp, optimizer, lr_scheduler, logger) 130 | ''' 131 | if config.MODEL.RESUME: 132 | print("resuming") 133 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 134 | validate_article(config, data_loader_val, data_loader_train, model) 135 | # logger.info(f"loss of the network on the {len(dataset_val)} test images: {loss:.1f}%") 136 | if config.EVAL_MODE: 137 | print("eval mode") 138 | return 139 | 140 | 141 | logger.info("Start training") 142 | start_time = time.time() 143 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 144 | print("traning %d epoch" % epoch) 145 | #data_loader_train.sampler.set_epoch(epoch) 146 | 147 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) 148 | print("training %d epoch done" % epoch) 149 | if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 150 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) 151 | 152 | # validate(config, data_loader_val, model) 153 | # print("imgs saved") 154 | 155 | total_time = time.time() - start_time 156 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 157 | logger.info('Training time {}'.format(total_time_str)) 158 | 159 | 160 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): 161 | model.train() 162 | optimizer.zero_grad() 163 | 164 | num_steps = len(data_loader) 165 | batch_time = AverageMeter() 166 | ocr_loss_meter = AverageMeter() 167 | attn_loss_meter = AverageMeter() 168 | saliency_loss_meter = AverageMeter() 169 | saliency_kl_meter = AverageMeter() 170 | saliency_cc_meter = AverageMeter() 171 | loss_meter = AverageMeter() 172 | norm_meter = AverageMeter() 173 | 174 | start = time.time() 175 | end = time.time() 176 | cnt = 0 177 | test_loc = config.OUTPUT + '/test/' 178 | if not os.path.exists(test_loc): 179 | os.makedirs(test_loc) 180 | test_map = [] 181 | test_name = [] 182 | cnt = 1 183 | 184 | for idx, (name, samples, targets) in enumerate(data_loader): 185 | 186 | samples = samples.cuda(non_blocking=True) 187 | targets = targets.cuda(non_blocking=True) 188 | 189 | outputs, attn_loss = model(samples, targets) 190 | test_map = outputs 191 | test_name = name 192 | 193 | targets = transforms.Resize(56)(targets) 194 | 195 | saliency_loss = KL_loss(outputs, targets) 196 | ''' 197 | out1 = ocr_out[:, :, :, 0].cuda() 198 | out2 = ocr_out[:, :, :, 1].cuda() 199 | gah_label = ocr_target["gah_label"].resize_(out2.size()) 200 | gh_label = ocr_target["gh_label"].resize_(out1.size()) 201 | mask = ocr_target["mask"] 202 | ocr_loss = criterion(gh_label, gah_label, out2, out1, mask) 203 | 204 | # scale 205 | ocr_loss *= 3 206 | 207 | loss = saliency_loss 
+ attn_loss + ocr_loss 208 | saliency_kl = calKL(targets, outputs, True) 209 | saliency_cc = calCC(targets, outputs, True) 210 | 211 | attn_loss_meter.update(attn_loss, targets.size(0)) 212 | ocr_loss_meter.update(ocr_loss, targets.size(0)) 213 | saliency_loss_meter.update(saliency_loss, targets.size(0)) 214 | saliency_cc_meter.update(saliency_cc, targets.size(0)) 215 | saliency_kl_meter.update(saliency_kl, targets.size(0)) 216 | ''' 217 | # print(saliency_loss , attn_loss, attn_loss.shape,"here comes the bugs") 218 | if len(attn_loss.shape) == 0: 219 | attn_loss = attn_loss 220 | else: 221 | attn_loss = sum(attn_loss)/4 222 | if config.TRAIN.START_EPOCH == epoch and cnt: 223 | print("displaying attnloss", attn_loss) 224 | cnt = 0 225 | # print("now adding", attn_loss) 226 | loss = saliency_loss*2 + attn_loss 227 | attn_loss_meter.update(attn_loss, targets.size(0)) 228 | saliency_loss_meter.update(saliency_loss, targets.size(0)) 229 | optimizer.zero_grad() 230 | 231 | loss.backward() 232 | if config.TRAIN.CLIP_GRAD: 233 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 234 | else: 235 | grad_norm = get_grad_norm(model.parameters()) 236 | optimizer.step() 237 | lr_scheduler.step_update(epoch * num_steps + idx) 238 | 239 | torch.cuda.synchronize() 240 | 241 | loss_meter.update(loss.item(), targets.size(0)) 242 | norm_meter.update(grad_norm) 243 | batch_time.update(time.time() - end) 244 | end = time.time() 245 | 246 | if idx % config.PRINT_FREQ == 0: 247 | lr = optimizer.param_groups[0]['lr'] 248 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 249 | etas = batch_time.avg * (num_steps - idx) 250 | logger.info( 251 | f'Train: [{idx}/{len(data_loader)}]\t' 252 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 253 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 254 | f'Mem {memory_used:.0f}MB\t' 255 | f'attn_loss {attn_loss_meter.val:.4f} ({attn_loss_meter.avg:.4f})\t' 256 | f'saliency_loss {saliency_loss_meter.val:.4f} ({saliency_loss_meter.avg:.4f})' 257 | ) 258 | epoch_time = time.time() - start 259 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 260 | # print(test_name,test_map.shape,"testing for epoch") 261 | test_map = np.ascontiguousarray(test_map.detach().cpu().numpy()) 262 | test_map *= 255 263 | test_map = test_map.astype(np.uint8) 264 | # test_ans = np.ascontiguousarray(test_ans.detach().cpu().numpy()) 265 | for i, maps in enumerate(test_map): 266 | name_i = test_name[i].split("/")[-1] 267 | path = test_loc + str(epoch)+"_epoch_" + name_i 268 | # ans_path = test_loc + str(epoch)+"_epoch_test_" + name_i 269 | re = cv2.imwrite(path, maps[0]) 270 | # re = cv2.imwrite(ans_path, test_ans[i][0]) 271 | # print(path, "saved") 272 | print("saving batch_size images for browsing in epoch %d"%epoch) 273 | 274 | 275 | @torch.no_grad() 276 | def validate(config, data_loader, model): 277 | # waiting 278 | criterion = Maploss() 279 | saliency_loc = config.OUTPUT + '/ans/new_saliency/' 280 | if not os.path.exists(saliency_loc): 281 | os.makedirs(saliency_loc) 282 | save_loc = config.OUTPUT + '/ans/new_ocr/' 283 | if not os.path.exists(save_loc): 284 | os.makedirs(save_loc) 285 | model.eval() 286 | print("save loc", save_loc, "saliency loc", saliency_loc) 287 | klall = 0 288 | ccall = 0 289 | cnt = 0 290 | with torch.no_grad(): 291 | for idx, (idx_name, images, target, ocr_target) in enumerate(data_loader): 292 | idx_name = idx_name.tolist() 293 | cnt += 1 294 | # print(idx_name, "3") 
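# Added note: each batch below is scored with KL divergence (calKL) and CC
# (calCC) after resizing both prediction and target to 720x720; klall/ccall
# accumulate the per-batch values and are averaged after the loop. The
# per-batch saliency maps and CRAFT-style text boxes are also written to
# saliency_loc / save_loc for visual inspection.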
295 | images = images.cuda(non_blocking=True) 296 | targets = target.cuda(non_blocking=True) 297 | 298 | output, attn_loss, ocr_out = model(images, targets) 299 | 300 | target = transforms.Resize(56)(targets) 301 | loss = KL_loss(output, target) 302 | out1 = ocr_out[:, :, :, 0].cpu() 303 | out2 = ocr_out[:, :, :, 1].cpu() 304 | gah_label = ocr_target["gah_label"].resize_(out2.size()) 305 | gh_label = ocr_target["gh_label"].resize_(out1.size()) 306 | mask = ocr_target["mask"] 307 | 308 | ocr_loss = criterion(gh_label, gah_label, out2, out1, mask) 309 | target_metric = transforms.Resize(720)(targets) 310 | output_metric = transforms.Resize(720)(output) 311 | saliency_kl = calKL(target_metric, output_metric, False) 312 | saliency_cc = calCC(target_metric, output_metric, False) 313 | klall += saliency_kl 314 | ccall += saliency_cc 315 | 316 | saliency_map = np.array(F.interpolate(output, size=(720, 720), mode='bilinear', align_corners=False).cpu()) 317 | saliency_map = np.ascontiguousarray(saliency_map) 318 | saliency_map *= 255 319 | saliency_maps = saliency_map.astype(np.uint8) 320 | 321 | for i, saliency_map in enumerate(saliency_maps): 322 | path = saliency_loc + str(idx_name[i]) + '.jpg' 323 | re = cv2.imwrite(path, saliency_map[0]) 324 | 325 | ocr_outs = ocr_out.cpu() 326 | for i, ocr_out in enumerate(ocr_outs): 327 | score_text = ocr_out[:, :, 0].cpu().data.numpy() 328 | score_link = ocr_out[:, :, 1].cpu().data.numpy() 329 | 330 | boxes, polys = craft_utils.getDetBoxes(score_text, score_link, 0.7, 0.4, 0.4, False) 331 | for k in range(len(polys)): 332 | if polys[k] is None: polys[k] = boxes[k] 333 | 334 | render_img = score_text.copy() 335 | render_img = np.hstack((render_img, score_link)) 336 | ret_score_text = imgproc.cvt2HeatmapImg(render_img) 337 | 338 | image=cv2.imread("/mnt/disk4T/lyf/ECdata/ALLSTIMULI/"+str(idx_name[i]) + '.jpg') 339 | image=cv2.resize(image,(224,224)) 340 | img = np.array(image) 341 | boxes = polys 342 | res_img_file = save_loc + str(idx_name[i]) + '.jpg' 343 | res_file = save_loc + str(idx_name[i]) + '.txt' 344 | 345 | with open(res_file, 'w') as f: 346 | for j, box in enumerate(boxes): 347 | poly = np.array(box).astype(np.int32).reshape((-1)) 348 | strResult = ','.join([str(p) for p in poly]) + '\r\n' 349 | f.write(strResult) 350 | poly = poly.reshape(-1, 2) 351 | cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2) 352 | ptColor = (0, 255, 255) 353 | cv2.imwrite(res_img_file, img) 354 | print("average kl, cc is %f, %f"%(klall/cnt, ccall/cnt)) 355 | 356 | return 357 | 358 | from torchvision import transforms 359 | def validate_article(config, data_loader, model): 360 | # waiting 361 | criterion = Maploss() 362 | saliency_loc = config.OUTPUT + '/output/article_saliency/' 363 | if not os.path.exists(saliency_loc): 364 | os.makedirs(saliency_loc) 365 | model.eval() 366 | print("saliency loc", saliency_loc) 367 | times = 0 368 | cnt = 0 369 | print("testing") 370 | with torch.no_grad(): 371 | for (name, images, _) in data_loader: 372 | # print(name, images.shape, _.shape) 373 | images = images.cuda(non_blocking=True) 374 | images = images.unsqueeze(0) 375 | st = time.time() 376 | 377 | output, __ = model(images, _) 378 | ed = time.time() 379 | print("time", ed-st) 380 | output = transforms.Resize((720,720))(output) 381 | 382 | # saliency_map = F.interpolate(outputs, size=(720, 720), mode='bilinear', align_corners=False).cpu().numpy() 383 | saliency_map = np.ascontiguousarray(output.detach().cpu().numpy()) 384 | saliency_map *= 255 
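# Added note (assuming the SalDetector head above is the final stage of the
# built model): its forward pass ends with NormAffine(method='one'), so the
# output is already min-max normalised to [0, 1]; the scaling by 255 above and
# the uint8 cast below therefore yield a grayscale image suitable for
# cv2.imwrite.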
385 | saliency_maps = saliency_map.astype(np.uint8) 386 | tmp_time = ed-st 387 | times += len(saliency_maps)*tmp_time 388 | cnt += len(saliency_maps) 389 | # print(name, len(saliency_maps)) 390 | 391 | name_i = name.split("/")[-1] 392 | path = saliency_loc + name_i 393 | print(type(saliency_map[0]),os.path.exists(saliency_loc)) 394 | # embed() 395 | re = cv2.imwrite(path, saliency_map[0][0]) 396 | print(path, "saved", re) 397 | print(times/cnt) 398 | print("images saved") 399 | 400 | 401 | @torch.no_grad() 402 | def throughput(data_loader, model, logger): 403 | model.eval() 404 | 405 | for idx, (images, _) in enumerate(data_loader): 406 | images = images.cuda(non_blocking=True) 407 | batch_size = images.shape[0] 408 | for i in range(50): 409 | model(images) 410 | torch.cuda.synchronize() 411 | logger.info(f"throughput averaged with 30 times") 412 | tic1 = time.time() 413 | for i in range(30): 414 | model(images) 415 | torch.cuda.synchronize() 416 | tic2 = time.time() 417 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 418 | return 419 | 420 | 421 | if __name__ == '__main__': 422 | _, config = parse_option() 423 | 424 | if config.AMP_OPT_LEVEL != "O0": 425 | assert amp is not None, "amp not installed!" 426 | 427 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 428 | rank = int(os.environ["RANK"]) 429 | world_size = int(os.environ['WORLD_SIZE']) 430 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 431 | else: 432 | rank = -1 433 | world_size = -1 434 | # torch.cuda.set_device(config.LOCAL_RANK) 435 | # torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 436 | # torch.distributed.barrier() 437 | 438 | # seed = config.SEED + dist.get_rank() 439 | # torch.manual_seed(seed) 440 | # np.random.seed(seed) 441 | cudnn.benchmark = True 442 | 443 | # linear scale the learning rate according to total batch size, may not be optimal 444 | linear_scaled_lr = config.TRAIN.BASE_LR # * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 445 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR# * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 446 | linear_scaled_min_lr = config.TRAIN.MIN_LR# * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 447 | 448 | config.defrost() 449 | config.TRAIN.BASE_LR = linear_scaled_lr 450 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 451 | config.TRAIN.MIN_LR = linear_scaled_min_lr 452 | config.freeze() 453 | 454 | os.makedirs(config.OUTPUT, exist_ok=True) 455 | logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}") 456 | 457 | if True: 458 | path = os.path.join(config.OUTPUT, "config.json") 459 | with open(path, "w") as f: 460 | f.write(config.dump()) 461 | logger.info(f"Full config saved to {path}") 462 | 463 | # print config 464 | logger.info(config.dump()) 465 | # print(config) 466 | main(config) 467 | -------------------------------------------------------------------------------- /e-commercial/train.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import time 7 | import argparse 8 | import datetime 9 | import numpy as np 10 | import torchvision 11 | import cv2 12 | 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | 17 | from timm.utils import accuracy, AverageMeter 18 | from 
torchvision import transforms 19 | import torchvision 20 | import torch.nn.functional as F 21 | from PIL import Image 22 | 23 | from config import get_config 24 | from models import build_model 25 | from models import craft_utils, imgproc 26 | from models.loss import KL_loss, Maploss 27 | from models.metric import calCC, calKL 28 | from data import build_loader 29 | from lr_scheduler import build_scheduler 30 | from optimizer import build_optimizer 31 | from logger import create_logger 32 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, load_checkpoint_finetune, load_checkpoint_eval 33 | 34 | 35 | try: 36 | # noinspection PyUnresolvedReferences 37 | from apex import amp 38 | except ImportError: 39 | amp = None 40 | 41 | 42 | def parse_option(): 43 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 44 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 45 | parser.add_argument( 46 | "--opts", 47 | help="Modify config options by adding 'KEY VALUE' pairs. ", 48 | default=None, 49 | nargs='+', 50 | ) 51 | 52 | # easy config modification 53 | parser.add_argument('--batch-size', type=int, help="batch size") 54 | parser.add_argument('--data-path', type=str, help='path to dataset') 55 | # parser.add_argument('--eval-path', type=str, help='path to dataset') 56 | parser.add_argument('--resume', help='resume from checkpoint') 57 | parser.add_argument('--finetune', help='finetune from checkpoint') 58 | parser.add_argument('--use-checkpoint', action='store_true', 59 | help="whether to use gradient checkpointing to save memory") 60 | parser.add_argument('--output', default='./output', type=str, metavar='PATH', 61 | help='root of output folder, the full path is // (default: output)') 62 | parser.add_argument('--tag', help='tag of experiment') 63 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 64 | 65 | # about dataset 66 | parser.add_argument('--dataset', type=str, default='imagenet', help='name of dataset') 67 | parser.add_argument('--datanum', type=int, default=972, help='num of dataset') 68 | parser.add_argument('--num_epoch', type=int, default=50, help='num of epoch') 69 | parser.add_argument('--head', type=str, default='denseNet_15layer', help='head') 70 | 71 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 72 | 73 | args, unparsed = parser.parse_known_args() 74 | 75 | config = get_config(args) 76 | 77 | return args, config 78 | 79 | 80 | def main(config): 81 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, logger) 82 | # dataset_train, data_loader_train, mixup_fn = build_loader(config) 83 | 84 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 85 | model = build_model(config) 86 | model.cuda() 87 | logger.info(str(model)) 88 | 89 | optimizer = build_optimizer(config, model) 90 | model = torch.nn.DataParallel(model) 91 | model_without_ddp = model.module 92 | 93 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 94 | logger.info(f"number of params: {n_parameters}") 95 | 96 | if config.EVAL_MODE: 97 | print("test model") 98 | dataset_val, data_loader_val, _ = build_loader(config) 99 | load_checkpoint_eval(config, model_without_ddp, optimizer, logger) 100 | validate_article(config, dataset_train, model) 101 | return 102 | 103 | lr_scheduler = build_scheduler(config, optimizer, 
len(data_loader_train)) 104 | 105 | criterion = Maploss() 106 | 107 | 108 | 109 | max_accuracy = 0.0 110 | 111 | if config.TRAIN.AUTO_RESUME: 112 | resume_file = auto_resume_helper(config.OUTPUT) 113 | if resume_file: 114 | if config.MODEL.RESUME: 115 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 116 | config.defrost() 117 | config.MODEL.RESUME = resume_file 118 | config.freeze() 119 | logger.info(f'auto resuming from {resume_file}') 120 | else: 121 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 122 | 123 | 124 | ''' 125 | if config.MODEL.FINETUNE: 126 | print("FINETUNE") 127 | load_checkpoint_finetune(config, model_without_ddp, optimizer, lr_scheduler, logger) 128 | ''' 129 | if config.MODEL.RESUME: 130 | print("resuming") 131 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 132 | validate_article(config, data_loader_val, data_loader_train, model) 133 | # logger.info(f"loss of the network on the {len(dataset_val)} test images: {loss:.1f}%") 134 | if config.EVAL_MODE: 135 | print("eval mode") 136 | return 137 | 138 | 139 | logger.info("Start training") 140 | start_time = time.time() 141 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 142 | if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 143 | print("fake test first") 144 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) 145 | validate(config, data_loader_val, model, epoch) 146 | print("imgs saved") 147 | print("traning %d epoch" % epoch) 148 | #data_loader_train.sampler.set_epoch(epoch) 149 | 150 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) 151 | print("training %d epoch done" % epoch) 152 | 153 | 154 | total_time = time.time() - start_time 155 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 156 | logger.info('Training time {}'.format(total_time_str)) 157 | 158 | 159 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): 160 | model.train() 161 | optimizer.zero_grad() 162 | 163 | num_steps = len(data_loader) 164 | batch_time = AverageMeter() 165 | ocr_loss_meter = AverageMeter() 166 | attn_loss_meter = AverageMeter() 167 | saliency_loss_meter = AverageMeter() 168 | saliency_kl_meter = AverageMeter() 169 | saliency_cc_meter = AverageMeter() 170 | loss_meter = AverageMeter() 171 | norm_meter = AverageMeter() 172 | 173 | start = time.time() 174 | end = time.time() 175 | cnt = 0 176 | test_loc = config.OUTPUT + '/test/' 177 | if not os.path.exists(test_loc): 178 | os.makedirs(test_loc) 179 | test_map = [] 180 | test_name = None 181 | cnt = 1 182 | 183 | for idx, (idx_name, samples, targets, ocr_target) in enumerate(data_loader): 184 | 185 | samples = samples.cuda(non_blocking=True) 186 | targets = targets.cuda(non_blocking=True) 187 | 188 | outputs, attn_loss, ocr_out = model(samples, targets) 189 | # print(f"{ocr_out.shape}") 190 | test_map = outputs 191 | if test_name is None: 192 | test_name = idx_name 193 | 194 | targets = transforms.Resize(56)(targets) 195 | 196 | saliency_loss = KL_loss(outputs, targets) 197 | out1 = ocr_out[:, :, :, 0].cuda() 198 | out2 = ocr_out[:, :, :, 1].cuda() 199 | gah_label = ocr_target["gah_label"].resize_(out2.size()).cuda() 200 | gh_label = ocr_target["gh_label"].resize_(out1.size()).cuda() 201 | mask = ocr_target["mask"].cuda() 202 | ocr_loss = 
criterion(gh_label, gah_label, out2, out1, mask) 203 | 204 | # scale 205 | ocr_loss *= 3 206 | 207 | loss = saliency_loss + attn_loss + ocr_loss 208 | saliency_kl = calKL(targets, outputs, True) 209 | saliency_cc = calCC(targets, outputs, True) 210 | 211 | attn_loss_meter.update(attn_loss, targets.size(0)) 212 | ocr_loss_meter.update(ocr_loss, targets.size(0)) 213 | saliency_loss_meter.update(saliency_loss, targets.size(0)) 214 | saliency_cc_meter.update(saliency_cc, targets.size(0)) 215 | saliency_kl_meter.update(saliency_kl, targets.size(0)) 216 | # print(saliency_loss , attn_loss,"here comes the bugs") 217 | # breakpoint() 218 | if len(attn_loss.shape) == 0: 219 | attn_loss = attn_loss 220 | else: 221 | attn_loss = sum(attn_loss)/4 222 | if config.TRAIN.START_EPOCH == epoch and cnt: 223 | print("displaying attnloss", attn_loss) 224 | cnt = 0 225 | # print("now adding", attn_loss) 226 | loss = saliency_loss + attn_loss * 0.2 227 | attn_loss_meter.update(attn_loss, targets.size(0)) 228 | saliency_loss_meter.update(saliency_loss, targets.size(0)) 229 | optimizer.zero_grad() 230 | 231 | loss.backward() 232 | if config.TRAIN.CLIP_GRAD: 233 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 234 | else: 235 | grad_norm = get_grad_norm(model.parameters()) 236 | optimizer.step() 237 | lr_scheduler.step_update(epoch * num_steps + idx) 238 | 239 | torch.cuda.synchronize() 240 | 241 | loss_meter.update(loss.item(), targets.size(0)) 242 | norm_meter.update(grad_norm) 243 | batch_time.update(time.time() - end) 244 | end = time.time() 245 | 246 | if idx % config.PRINT_FREQ == 0: 247 | lr = optimizer.param_groups[0]['lr'] 248 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 249 | etas = batch_time.avg * (num_steps - idx) 250 | logger.info( 251 | f'Train: [{idx}/{len(data_loader)}]\t' 252 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 253 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 254 | f'Mem {memory_used:.0f}MB\t' 255 | f'attn_loss {attn_loss_meter.val:.4f} ({attn_loss_meter.avg:.4f})\t' 256 | f'saliency_loss {saliency_loss_meter.val:.4f} ({saliency_loss_meter.avg:.4f})' 257 | ) 258 | epoch_time = time.time() - start 259 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 260 | # print(test_name,test_map.shape,"testing for epoch") 261 | test_map = np.ascontiguousarray(test_map.detach().cpu().numpy()) 262 | test_map *= 255 263 | test_map = test_map.astype(np.uint8) 264 | # test_ans = np.ascontiguousarray(test_ans.detach().cpu().numpy()) 265 | for i, maps in enumerate(test_map): 266 | name_i = test_name[i].item() 267 | path = test_loc + str(epoch) + f"_epoch/{name_i}.jpg" 268 | # ans_path = test_loc + str(epoch)+"_epoch_test_" + name_i 269 | re = cv2.imwrite(path, maps[0]) 270 | # re = cv2.imwrite(ans_path, test_ans[i][0]) 271 | # print(path, "saved") 272 | print("saving batch_size images for browsing in epoch %d"%epoch) 273 | 274 | 275 | @torch.no_grad() 276 | def validate(config, data_loader, model, epoch): 277 | # waiting 278 | criterion = Maploss() 279 | saliency_loc = config.OUTPUT + f'/ans/{epoch}/new_saliency/' 280 | if not os.path.exists(saliency_loc): 281 | os.makedirs(saliency_loc) 282 | save_loc = config.OUTPUT + f'/ans/{epoch}/new_ocr/' 283 | if not os.path.exists(save_loc): 284 | os.makedirs(save_loc) 285 | model.eval() 286 | print("save loc", save_loc, "saliency loc", saliency_loc) 287 | klall = 0 288 | ccall = 0 289 | cnt = 0 290 | with torch.no_grad(): 291 | for idx, 
(idx_name, images, target, ocr_target) in enumerate(data_loader): 292 | if epoch == 0 and cnt > 1: 293 | break 294 | idx_name = idx_name.tolist() 295 | cnt += 1 296 | # print(idx_name, "3") 297 | images = images.cuda(non_blocking=True) 298 | targets = target.cuda(non_blocking=True) 299 | 300 | output, attn_loss, ocr_out = model(images, targets) 301 | 302 | target = transforms.Resize(56)(targets) 303 | loss = KL_loss(output, target) 304 | out1 = ocr_out[:, :, :, 0].cpu() 305 | out2 = ocr_out[:, :, :, 1].cpu() 306 | gah_label = ocr_target["gah_label"].resize_(out2.size()) 307 | gh_label = ocr_target["gh_label"].resize_(out1.size()) 308 | mask = ocr_target["mask"] 309 | 310 | ocr_loss = criterion(gh_label, gah_label, out2, out1, mask) 311 | target_metric = transforms.Resize(720)(targets) 312 | output_metric = transforms.Resize(720)(output) 313 | saliency_kl = calKL(target_metric, output_metric, False) 314 | saliency_cc = calCC(target_metric, output_metric, False) 315 | klall += saliency_kl 316 | ccall += saliency_cc 317 | 318 | saliency_map = np.array(F.interpolate(output, size=(720, 720), mode='bilinear', align_corners=False).cpu()) 319 | saliency_map = np.ascontiguousarray(saliency_map) 320 | saliency_map *= 255 321 | saliency_maps = saliency_map.astype(np.uint8) 322 | 323 | for i, saliency_map in enumerate(saliency_maps): 324 | path = saliency_loc + str(idx_name[i]) + '.jpg' 325 | re = cv2.imwrite(path, saliency_map[0]) 326 | 327 | ocr_outs = ocr_out.cpu() 328 | for i, ocr_out in enumerate(ocr_outs): 329 | score_text = ocr_out[:, :, 0].cpu().data.numpy() 330 | score_link = ocr_out[:, :, 1].cpu().data.numpy() 331 | 332 | boxes, polys = craft_utils.getDetBoxes(score_text, score_link, 0.7, 0.4, 0.4, False) 333 | for k in range(len(polys)): 334 | if polys[k] is None: polys[k] = boxes[k] 335 | 336 | render_img = score_text.copy() 337 | render_img = np.hstack((render_img, score_link)) 338 | ret_score_text = imgproc.cvt2HeatmapImg(render_img) 339 | 340 | image=cv2.imread("/mnt/hdd1/yifei/DATA/ECdata/ALLSTIMULI/"+str(idx_name[i]) + '.jpg') 341 | image=cv2.resize(image,(224,224)) 342 | img = np.array(image) 343 | boxes = polys 344 | res_img_file = save_loc + str(idx_name[i]) + '.jpg' 345 | res_file = save_loc + str(idx_name[i]) + '.txt' 346 | 347 | with open(res_file, 'w') as f: 348 | for j, box in enumerate(boxes): 349 | poly = np.array(box).astype(np.int32).reshape((-1)) 350 | strResult = ','.join([str(p) for p in poly]) + '\r\n' 351 | f.write(strResult) 352 | poly = poly.reshape(-1, 2) 353 | cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2) 354 | ptColor = (0, 255, 255) 355 | cv2.imwrite(res_img_file, img) 356 | logger.info("average kl, cc is %f, %f"%(klall/cnt, ccall/cnt)) 357 | 358 | return 359 | 360 | from torchvision import transforms 361 | def validate_article(config, data_loader, model): 362 | # waiting 363 | criterion = Maploss() 364 | saliency_loc = config.OUTPUT + '/output/article_saliency/' 365 | if not os.path.exists(saliency_loc): 366 | os.makedirs(saliency_loc) 367 | model.eval() 368 | print("saliency loc", saliency_loc) 369 | times = 0 370 | cnt = 0 371 | print("testing") 372 | with torch.no_grad(): 373 | for (name, images, _) in data_loader: 374 | # print(name, images.shape, _.shape) 375 | images = images.cuda(non_blocking=True) 376 | images = images.unsqueeze(0) 377 | st = time.time() 378 | 379 | output, __ = model(images, _) 380 | ed = time.time() 381 | print("time", ed-st) 382 | output = transforms.Resize((720,720))(output) 383 | 384 | # 
saliency_map = F.interpolate(outputs, size=(720, 720), mode='bilinear', align_corners=False).cpu().numpy() 385 | saliency_map = np.ascontiguousarray(output.detach().cpu().numpy()) 386 | saliency_map *= 255 387 | saliency_maps = saliency_map.astype(np.uint8) 388 | tmp_time = ed-st 389 | times += len(saliency_maps)*tmp_time 390 | cnt += len(saliency_maps) 391 | # print(name, len(saliency_maps)) 392 | 393 | name_i = name.split("/")[-1] 394 | path = saliency_loc + name_i 395 | print(type(saliency_map[0]),os.path.exists(saliency_loc)) 396 | # embed() 397 | re = cv2.imwrite(path, saliency_map[0][0]) 398 | print(path, "saved", re) 399 | print(times/cnt) 400 | print("images saved") 401 | 402 | 403 | @torch.no_grad() 404 | def throughput(data_loader, model, logger): 405 | model.eval() 406 | 407 | for idx, (images, _) in enumerate(data_loader): 408 | images = images.cuda(non_blocking=True) 409 | batch_size = images.shape[0] 410 | for i in range(50): 411 | model(images) 412 | torch.cuda.synchronize() 413 | logger.info(f"throughput averaged with 30 times") 414 | tic1 = time.time() 415 | for i in range(30): 416 | model(images) 417 | torch.cuda.synchronize() 418 | tic2 = time.time() 419 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 420 | return 421 | 422 | 423 | if __name__ == '__main__': 424 | _, config = parse_option() 425 | 426 | if config.AMP_OPT_LEVEL != "O0": 427 | assert amp is not None, "amp not installed!" 428 | 429 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 430 | rank = int(os.environ["RANK"]) 431 | world_size = int(os.environ['WORLD_SIZE']) 432 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 433 | else: 434 | rank = -1 435 | world_size = -1 436 | # torch.cuda.set_device(config.LOCAL_RANK) 437 | # torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 438 | # torch.distributed.barrier() 439 | 440 | # seed = config.SEED + dist.get_rank() 441 | # torch.manual_seed(seed) 442 | # np.random.seed(seed) 443 | cudnn.benchmark = True 444 | 445 | # linear scale the learning rate according to total batch size, may not be optimal 446 | linear_scaled_lr = config.TRAIN.BASE_LR # * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 447 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR# * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 448 | linear_scaled_min_lr = config.TRAIN.MIN_LR# * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 449 | 450 | config.defrost() 451 | config.TRAIN.BASE_LR = linear_scaled_lr 452 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 453 | config.TRAIN.MIN_LR = linear_scaled_min_lr 454 | config.freeze() 455 | 456 | os.makedirs(config.OUTPUT, exist_ok=True) 457 | logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}") 458 | 459 | if True: 460 | path = os.path.join(config.OUTPUT, "config.json") 461 | with open(path, "w") as f: 462 | f.write(config.dump()) 463 | logger.info(f"Full config saved to {path}") 464 | 465 | # print config 466 | logger.info(config.dump()) 467 | # print(config) 468 | main(config) 469 | -------------------------------------------------------------------------------- /e-commercial/metrics.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from skimage.transform import resize 5 | from numpy import random 6 | from functools import partial 7 | 8 | def loss_func(pre, gt_sal, gt_fix, args): 9 | 10 | kl = 
kld_loss(pre, gt_sal) 11 | cc1 = cc(pre, gt_sal) 12 | nss1 = nss(pre, gt_fix) 13 | 14 | return kl, cc1, nss1 15 | 16 | 17 | def kld_loss(s_map, gt): # 1,1280,720 1,1280,720 18 | assert s_map.size() == gt.size() 19 | batch_size = s_map.size(0) 20 | w = s_map.size(1) 21 | h = s_map.size(2) 22 | # 以下相当于对feature的每一个元素做了归一化 23 | sum_s_map = torch.sum(s_map.view(batch_size, -1), 1) # batch_size个和 24 | expand_s_map = sum_s_map.view(batch_size, 1, 1).expand(batch_size, w, h) # 1,1280,720,value=sum_s_map 25 | assert expand_s_map.size() == s_map.size() 26 | 27 | sum_gt = torch.sum(gt.view(batch_size, -1), 1) 28 | expand_gt = sum_gt.view(batch_size, 1, 1).expand(batch_size, w, h) 29 | assert expand_gt.size() == gt.size() 30 | 31 | s_map = s_map / (expand_s_map * 1.0) # 1,1280,720 32 | gt = gt / (expand_gt * 1.0) # 1,1280,720 33 | 34 | s_map = s_map.view(batch_size, -1) # 1,921600 35 | gt = gt.view(batch_size, -1) # B,921600 36 | 37 | eps = 2.2204e-16 38 | result = gt * torch.log(eps + gt / (s_map + eps)) # B,921600 39 | 40 | return torch.mean(torch.sum(result, 1)) # 返回的是batch个kl的平均值 41 | 42 | 43 | def cc(s_map, gt): # 1,1280,720 1,1280,720 44 | assert s_map.size() == gt.size() 45 | batch_size = s_map.size(0) 46 | w = s_map.size(1) 47 | h = s_map.size(2) 48 | 49 | mean_s_map = torch.mean(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h) 50 | std_s_map = torch.std(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h) 51 | 52 | mean_gt = torch.mean(gt.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h) 53 | std_gt = torch.std(gt.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h) 54 | 55 | s_map = (s_map - mean_s_map) / std_s_map 56 | gt = (gt - mean_gt) / std_gt 57 | 58 | ab = torch.sum((s_map * gt).view(batch_size, -1), 1) 59 | aa = torch.sum((s_map * s_map).view(batch_size, -1), 1) 60 | bb = torch.sum((gt * gt).view(batch_size, -1), 1) 61 | 62 | return torch.mean(ab / (torch.sqrt(aa * bb))) 63 | 64 | 65 | def nss(s_map, gt): 66 | # def cal_nss_wsj(s_map, gt): 67 | # print(">> s_map ", s_map, s_map.size()) 68 | # print(">> gt ", gt, gt.size()) 69 | 70 | # if s_map.size() != gt.size(): 71 | # s_map = s_map.cpu().squeeze(0).numpy() 72 | # s_map = torch.FloatTensor(cv2.resize(s_map, (gt.size(2), gt.size(1)))).unsqueeze(0) 73 | # s_map = s_map.cuda() 74 | # gt = gt.cuda() 75 | # print(s_map.size(), gt.size()) 76 | assert s_map.size() == gt.size() 77 | batch_size = s_map.size(0) 78 | w = s_map.size(1) 79 | h = s_map.size(2) 80 | mean_s_map = torch.mean(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h) 81 | std_s_map = torch.std(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, w, h) 82 | 83 | eps = 2.2204e-16 84 | s_map = (s_map - mean_s_map) / (std_s_map + eps) 85 | 86 | s_map = torch.sum((s_map * gt).view(batch_size, -1), 1) 87 | count = torch.sum(gt.view(batch_size, -1), 1) 88 | return torch.mean(s_map / count) 89 | 90 | 91 | def normalize_map(s_map): 92 | # normalize the salience map (as done in MIT code) 93 | batch_size = s_map.size(0) 94 | w = s_map.size(1) 95 | h = s_map.size(2) 96 | 97 | min_s_map = torch.min(s_map.view(batch_size, -1), 1)[0].view(batch_size, 1, 1).expand(batch_size, w, h) 98 | max_s_map = torch.max(s_map.view(batch_size, -1), 1)[0].view(batch_size, 1, 1).expand(batch_size, w, h) 99 | 100 | norm_s_map = (s_map - min_s_map) / (max_s_map - min_s_map * 1.0) 101 | return norm_s_map 102 | 103 | 104 | def similarity(s_map, gt): 105 | """ 106 
104 | def similarity(s_map, gt):
105 |     """
106 |     For single image metric
107 |     Size of Image - WxH or 1xWxH
108 |     gt is ground truth saliency map
109 |     """
110 |     batch_size = s_map.size(0)
111 |     w = s_map.size(1)
112 |     h = s_map.size(2)
113 | 
114 |     s_map = normalize_map(s_map)
115 |     gt = normalize_map(gt)
116 | 
117 |     sum_s_map = torch.sum(s_map.view(batch_size, -1), 1)
118 |     expand_s_map = sum_s_map.view(batch_size, 1, 1).expand(batch_size, w, h)
119 | 
120 |     assert expand_s_map.size() == s_map.size()
121 | 
122 |     sum_gt = torch.sum(gt.view(batch_size, -1), 1)
123 |     expand_gt = sum_gt.view(batch_size, 1, 1).expand(batch_size, w, h)
124 | 
125 |     s_map = s_map / (expand_s_map * 1.0)
126 |     gt = gt / (expand_gt * 1.0)
127 | 
128 |     s_map = s_map.view(batch_size, -1)
129 |     gt = gt.view(batch_size, -1)
130 |     return torch.mean(torch.sum(torch.min(s_map, gt), 1))
131 | 
132 | 
133 | def auc_judd(saliencyMap, fixationMap, jitter=True, toPlot=False, normalize=False):
134 |     # saliencyMap is the saliency map
135 |     # fixationMap is the human fixation map (binary matrix)
136 |     # jitter=True will add tiny non-zero random constant to all map locations to ensure
137 |     # ROC can be calculated robustly (to avoid uniform region)
138 |     # if toPlot=True, displays ROC curve
139 | 
140 |     # If there are no fixations to predict, return NaN
141 |     if saliencyMap.size() != fixationMap.size():
142 |         saliencyMap = saliencyMap.cpu().squeeze(0).numpy()
143 |         saliencyMap = torch.FloatTensor(cv2.resize(saliencyMap, (fixationMap.size(2), fixationMap.size(1)))).unsqueeze(
144 |             0)
145 |         # saliencyMap = saliencyMap.cuda()
146 |         # fixationMap = fixationMap.cuda()
147 |     if len(saliencyMap.size()) == 3:
148 |         saliencyMap = saliencyMap[0, :, :]
149 |         fixationMap = fixationMap[0, :, :]
150 |     # saliencyMap = saliencyMap.numpy()
151 |     # fixationMap = fixationMap.numpy()
152 |     saliencyMap = saliencyMap
153 |     fixationMap = fixationMap
154 |     if normalize:
155 |         saliencyMap = normalize_map(saliencyMap)
156 | 
157 |     if not fixationMap.any():
158 |         print('Error: no fixationMap')
159 |         score = float('nan')
160 |         return score
161 | 
162 |     # make the saliencyMap the size of the image of fixationMap
163 | 
164 |     if not np.shape(saliencyMap) == np.shape(fixationMap):
165 |         from scipy.misc import imresize  # NOTE: scipy.misc.imresize was removed in SciPy 1.3; newer environments need cv2.resize or skimage.transform.resize here
166 |         saliencyMap = imresize(saliencyMap, np.shape(fixationMap))
167 | 
168 |     # jitter saliency maps that come from saliency models that have a lot of zero values.
169 |     # If the saliency map is made with a Gaussian then it does not need to be jittered as
170 |     # the values are varied and there is not a large patch of the same value. In fact
171 |     # jittering breaks the ordering in the small values!
172 |     if jitter:
173 |         # jitter the saliency map slightly to disrupt ties of the same numbers
174 |         saliencyMap = saliencyMap.cpu() + np.random.random(np.shape(saliencyMap.cpu())) / 10 ** 7
175 | 
176 |     # normalize saliency map
177 |     saliencyMap = (saliencyMap - saliencyMap.min()) \
178 |                   / (saliencyMap.max() - saliencyMap.min())
179 | 
180 |     if np.isnan(saliencyMap).all():
181 |         print('NaN saliencyMap')
182 |         score = float('nan')
183 |         return score
184 | 
185 |     S = saliencyMap.flatten().cpu()
186 |     F = fixationMap.flatten().cpu()
187 | 
188 |     Sth = S[F > 0]  # sal map values at fixation locations
189 |     Nfixations = len(Sth)
190 |     Npixels = len(S)
191 | 
192 |     allthreshes = sorted(Sth, reverse=True)  # sort sal map values, to sweep through values
193 |     tp = np.zeros((Nfixations + 2))
194 |     fp = np.zeros((Nfixations + 2))
195 |     tp[0], tp[-1] = 0, 1
196 |     fp[0], fp[-1] = 0, 1
197 | 
198 |     for i in range(Nfixations):
199 |         thresh = allthreshes[i]
200 |         aboveth = (S >= thresh).sum()  # total number of sal map values above threshold
201 |         tp[i + 1] = float(i + 1) / Nfixations  # ratio sal map values at fixation locations
202 |         # above threshold
203 |         fp[i + 1] = float(aboveth - i) / (Npixels - Nfixations)  # ratio other sal map values
204 |         # above threshold
205 | 
206 |     score = np.trapz(tp, x=fp)
207 |     allthreshes = np.insert(allthreshes, 0, 0)
208 |     allthreshes = np.append(allthreshes, 1)
209 | 
210 |     if toPlot:
211 |         import matplotlib.pyplot as plt
212 |         fig = plt.figure()
213 |         ax = fig.add_subplot(1, 2, 1)
214 |         ax.matshow(saliencyMap, cmap='gray')
215 |         ax.set_title('SaliencyMap with fixations to be predicted')
216 |         [y, x] = np.nonzero(fixationMap)
217 |         s = np.shape(saliencyMap)
218 |         plt.axis((-.5, s[1] - .5, s[0] - .5, -.5))
219 |         plt.plot(x, y, 'ro')
220 | 
221 |         ax = fig.add_subplot(1, 2, 2)
222 |         plt.plot(fp, tp, '.b-')
223 |         ax.set_title('Area under ROC curve: ' + str(score))
224 |         plt.axis((0, 1, 0, 1))
225 |         plt.show()
226 | 
227 |     score = score
228 |     return score
229 | 
230 | 
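# A hedged usage sketch for auc_judd(). The tensor shapes and the synthetic binary fixation
# map below are assumptions made only for illustration.
def _auc_judd_sketch():
    sal = torch.rand(1, 224, 224)                    # predicted saliency map
    fix = (torch.rand(1, 224, 224) > 0.995).float()  # binary human fixation map
    # jitter=True breaks ties between equal saliency values before the ROC sweep
    return auc_judd(sal, fix, jitter=True)
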
231 | def auc_judd_npy(saliencyMap, fixationMap, jitter=True, toPlot=False, normalize=False):
232 | 
233 |     # If there are no fixations to predict, return NaN
234 |     # if saliencyMap.size() != fixationMap.size():
235 |     #     saliencyMap = saliencyMap.cpu().squeeze(0).numpy()
236 |     #     saliencyMap = torch.FloatTensor(cv2.resize(saliencyMap, (fixationMap.size(2), fixationMap.size(1)))).unsqueeze(
237 |     #         0)
238 |     #     saliencyMap = saliencyMap.cuda()
239 |     #     fixationMap = fixationMap.cuda()
240 |     # if len(saliencyMap.size()) == 3:
241 |     #     saliencyMap = saliencyMap[0, :, :]
242 |     #     fixationMap = fixationMap[0, :, :]
243 |     # saliencyMap = saliencyMap.numpy()
244 |     # fixationMap = fixationMap.numpy()
245 |     saliencyMap = saliencyMap
246 |     fixationMap = fixationMap
247 |     if normalize:
248 |         saliencyMap = normalize_map(saliencyMap)
249 | 
250 |     if not fixationMap.any():
251 |         print('Error: no fixationMap')
252 |         score = float('nan')
253 |         return score
254 | 
255 |     # make the saliencyMap the size of the image of fixationMap
256 | 
257 |     if not np.shape(saliencyMap) == np.shape(fixationMap):
258 |         from scipy.misc import imresize  # NOTE: removed in SciPy 1.3 (see auc_judd above)
259 |         saliencyMap = imresize(saliencyMap, np.shape(fixationMap))
260 | 
261 |     # jitter saliency maps that come from saliency models that have a lot of zero values.
262 |     # If the saliency map is made with a Gaussian then it does not need to be jittered as
263 |     # the values are varied and there is not a large patch of the same value. In fact
264 |     # jittering breaks the ordering in the small values!
265 |     if jitter:
266 |         # jitter the saliency map slightly to disrupt ties of the same numbers
267 |         saliencyMap = saliencyMap + np.random.random(np.shape(saliencyMap)) / 10 ** 7
268 | 
269 |     # normalize saliency map
270 |     saliencyMap = (saliencyMap - saliencyMap.min()) \
271 |                   / (saliencyMap.max() - saliencyMap.min())
272 | 
273 |     if np.isnan(saliencyMap).all():
274 |         print('NaN saliencyMap')
275 |         score = float('nan')
276 |         return score
277 | 
278 |     S = saliencyMap.flatten()  # .cpu()
279 |     F = fixationMap.flatten()  # .cpu()
280 | 
281 |     Sth = S[F > 0]  # sal map values at fixation locations
282 |     Nfixations = len(Sth)
283 |     Npixels = len(S)
284 | 
285 |     allthreshes = sorted(Sth, reverse=True)  # sort sal map values, to sweep through values
286 |     tp = np.zeros((Nfixations + 2))
287 |     fp = np.zeros((Nfixations + 2))
288 |     tp[0], tp[-1] = 0, 1
289 |     fp[0], fp[-1] = 0, 1
290 | 
291 |     for i in range(Nfixations):
292 |         thresh = allthreshes[i]
293 |         aboveth = (S >= thresh).sum()  # total number of sal map values above threshold
294 |         tp[i + 1] = float(i + 1) / Nfixations  # ratio sal map values at fixation locations
295 |         # above threshold
296 |         fp[i + 1] = float(aboveth - i) / (Npixels - Nfixations)  # ratio other sal map values
297 |         # above threshold
298 | 
299 |     score = np.trapz(tp, x=fp)
300 |     allthreshes = np.insert(allthreshes, 0, 0)
301 |     allthreshes = np.append(allthreshes, 1)
302 | 
303 | 
304 |     score = score
305 |     return score
306 | 
307 | 
308 | def auc_shuff(s_map, gt, other_map, splits=100, stepsize=0.1):
309 |     if len(s_map.size()) == 3:
310 |         s_map = s_map[0, :, :]
311 |         gt = gt[0, :, :]
312 |         other_map = other_map[0, :, :]
313 | 
314 |     s_map = s_map.cpu().numpy()
315 |     s_map = (s_map - np.min(s_map)) / (np.max(s_map) - np.min(s_map))
316 |     gt = gt.cpu().numpy()
317 |     other_map = other_map.cpu().numpy()
318 | 
319 |     num_fixations = np.sum(gt)
320 | 
321 |     x, y = np.where(other_map == 1)
322 |     other_map_fixs = []
323 |     for j in zip(x, y):
324 |         other_map_fixs.append(j[0] * other_map.shape[0] + j[1])
325 |     ind = len(other_map_fixs)
326 |     assert ind == np.sum(other_map), 'something is wrong in auc shuffle'
327 | 
328 |     num_fixations_other = min(ind, num_fixations)
329 | 
330 |     num_pixels = s_map.shape[0] * s_map.shape[1]
331 |     random_numbers = []
332 |     for i in range(0, splits):
333 |         temp_list = []
334 |         t1 = np.random.permutation(ind)
335 |         for k in t1:
336 |             temp_list.append(other_map_fixs[k])
337 |         random_numbers.append(temp_list)
338 | 
339 |     aucs = []
340 |     # for each split, calculate auc
341 |     for i in random_numbers:
342 |         r_sal_map = []
343 |         for k in i:
344 |             r_sal_map.append(s_map[k % s_map.shape[0] - 1, int(k / s_map.shape[0])])
345 |         # in these values, we need to find thresholds and calculate auc
346 |         thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
347 | 
348 |         r_sal_map = np.array(r_sal_map)
349 | 
350 |         # once the thresholds are obtained
351 |         thresholds = sorted(set(thresholds))
352 |         area = []
353 |         area.append((0.0, 0.0))
354 |         for thresh in thresholds:
355 |             # in the salience map, keep only those pixels with values above threshold
356 |             temp = np.zeros(s_map.shape)
357 |             temp[s_map >= thresh] = 1.0
358 |             num_overlap = np.where(np.add(temp, gt) == 2)[0].shape[0]
359 |             tp = num_overlap / (num_fixations * 1.0)
360 | 
361 |             # fp = (np.sum(temp) - num_overlap)/((np.shape(gt)[0] * np.shape(gt)[1]) - num_fixations)
362 |             # number of r_sal_map values above the threshold, divided by the number of random locations (= number of fixations)
363 |             fp = len(np.where(r_sal_map > thresh)[0]) / (num_fixations * 1.0)
364 | 
365 |             area.append((round(tp, 4), round(fp, 4)))
366 | 
367 |         area.append((1.0, 1.0))
368 |         area.sort(key=lambda x: x[0])
369 |         tp_list = [x[0] for x in area]
370 |         fp_list = [x[1] for x in area]
371 | 
372 |         aucs.append(np.trapz(np.array(tp_list), np.array(fp_list)))
373 | 
374 |     return np.mean(aucs)
375 | 
376 | 
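# A hedged sketch of how the shuffled-AUC inputs might be assembled. The argument
# `fixation_maps_from_other_images` is hypothetical and exists only for this illustration;
# the repository's own data pipeline is not shown here.
def _auc_shuff_sketch(pred, gt_fix, fixation_maps_from_other_images):
    # the "other map" is typically the union of binary fixation maps taken from other images
    other_map = torch.zeros_like(gt_fix)
    for m in fixation_maps_from_other_images:
        other_map = torch.clamp(other_map + m, max=1.0)
    return auc_shuff(pred, gt_fix, other_map)
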
377 | # Source: SalEMA
378 | def AUC_Borji(saliency_map, fixation_map, n_rep=100, step_size=0.1, rand_sampler=None):
379 |     '''
380 |     This measures how well the saliency map of an image predicts the ground truth human fixations on the image.
381 |     ROC curve created by sweeping through threshold values at fixed step size
382 |     until the maximum saliency map value.
383 |     True positive (tp) rate corresponds to the ratio of saliency map values above threshold
384 |     at fixation locations to the total number of fixation locations.
385 |     False positive (fp) rate corresponds to the ratio of saliency map values above threshold
386 |     at random locations to the total number of random locations
387 |     (as many random locations as fixations, sampled uniformly from ALL IMAGE PIXELS),
388 |     averaging over n_rep number of selections of random locations.
389 |     Parameters
390 |     ----------
391 |     saliency_map : real-valued matrix
392 |     fixation_map : binary matrix
393 |         Human fixation map.
394 |     n_rep : int, optional
395 |         Number of repeats for random sampling of non-fixated locations.
396 |     step_size : int, optional
397 |         Step size for sweeping through saliency map.
398 |     rand_sampler : callable
399 |         S_rand = rand_sampler(S, F, n_rep, n_fix)
400 |         Sample the saliency map at random locations to estimate false positive.
401 |         Return the sampled saliency values, S_rand.shape=(n_fix,n_rep)
402 |     Returns
403 |     -------
404 |     AUC : float, between [0,1]
405 |     '''
406 |     saliency_map = np.array(saliency_map.cpu(), copy=False)
407 |     fixation_map = np.array(fixation_map.cpu(), copy=False) > 0.5
408 |     # If there are no fixations to predict, return NaN
409 |     if not np.any(fixation_map):
410 |         print('no fixation to predict')
411 |         return np.nan
412 |     # Make the saliency_map the size of the fixation_map
413 |     if saliency_map.shape != fixation_map.shape:
414 |         saliency_map = resize(saliency_map, fixation_map.shape, order=3, mode='nearest')
415 |     # Normalize saliency map to have values between [0,1]
416 |     saliency_map = (saliency_map - np.min(saliency_map)) / (np.max(saliency_map) - np.min(saliency_map))
417 | 
418 |     S = saliency_map.ravel()
419 |     F = fixation_map.ravel()
420 |     S_fix = S[F]  # Saliency map values at fixation locations
421 |     n_fix = len(S_fix)
422 |     n_pixels = len(S)
423 |     # For each fixation, sample n_rep values from anywhere on the saliency map
424 |     if rand_sampler is None:
425 |         r = random.randint(0, n_pixels, [n_fix, n_rep])
426 |         S_rand = S[r]  # Saliency map values at random locations (including fixated locations!?
underestimated) 427 | else: 428 | S_rand = rand_sampler(S, F, n_rep, n_fix) 429 | # Calculate AUC per random split (set of random locations) 430 | auc = np.zeros(n_rep) * np.nan 431 | for rep in range(n_rep): 432 | thresholds = np.r_[0:np.max(np.r_[S_fix, S_rand[:,rep]]):step_size][::-1] 433 | tp = np.zeros(len(thresholds)+2) 434 | fp = np.zeros(len(thresholds)+2) 435 | tp[0] = 0; tp[-1] = 1 436 | fp[0] = 0; fp[-1] = 1 437 | for k, thresh in enumerate(thresholds): 438 | tp[k+1] = np.sum(S_fix >= thresh) / float(n_fix) 439 | fp[k+1] = np.sum(S_rand[:,rep] >= thresh) / float(n_fix) 440 | auc[rep] = np.trapz(tp, fp) 441 | return np.mean(auc) # Average across random splits 442 | 443 | 444 | def AUC_shuffled(saliency_map, fixation_map, other_map, n_rep=100, step_size=0.1): 445 | ''' 446 | This measures how well the saliency map of an image predicts the ground truth human fixations on the image. 447 | ROC curve created by sweeping through threshold values at fixed step size 448 | until the maximum saliency map value. 449 | True positive (tp) rate correspond to the ratio of saliency map values above threshold 450 | at fixation locations to the total number of fixation locations. 451 | False positive (fp) rate correspond to the ratio of saliency map values above threshold 452 | at random locations to the total number of random locations 453 | (as many random locations as fixations, sampled uniformly from fixation_map ON OTHER IMAGES), 454 | averaging over n_rep number of selections of random locations. 455 | Parameters 456 | ---------- 457 | saliency_map : real-valued matrix 458 | fixation_map : binary matrix 459 | Human fixation map. 460 | other_map : binary matrix, same shape as fixation_map 461 | A binary fixation map (like fixation_map) by taking the union of fixations from M other random images 462 | (Borji uses M=10). 463 | n_rep : int, optional 464 | Number of repeats for random sampling of non-fixated locations. 465 | step_size : int, optional 466 | Step size for sweeping through saliency map. 467 | Returns 468 | ------- 469 | AUC : float, between [0,1] 470 | ''' 471 | other_map = np.array(other_map.cpu(), copy=False) > 0.5 472 | 473 | if other_map.shape != fixation_map.shape: 474 | raise ValueError('other_map.shape != fixation_map.shape') 475 | 476 | # For each fixation, sample n_rep values (from fixated locations on other_map) on the saliency map 477 | def sample_other(other, S, F, n_rep, n_fix): 478 | fixated = np.nonzero(other)[0] 479 | indexer = list(map(lambda x: random.permutation(x)[:n_fix], np.tile(range(len(fixated)), [n_rep, 1]))) 480 | r = fixated[np.transpose(indexer)] 481 | S_rand = S[r] # Saliency map values at random locations (including fixated locations!? 
underestimated) 482 | return S_rand 483 | return AUC_Borji(saliency_map, fixation_map, n_rep, step_size, partial(sample_other, other_map.ravel())) 484 | 485 | 486 | 487 | -------------------------------------------------------------------------------- /e-commercial/main_cnn.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # based on Swin Transformer 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import time 7 | import argparse 8 | import datetime 9 | import numpy as np 10 | import torchvision 11 | import cv2 12 | 13 | # os.environ["CUDA_VISIBLE_DEVICES"] = '0, 2, 4, 7' 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | #import torch.distributed as dist 18 | 19 | from timm.utils import accuracy, AverageMeter 20 | from torchvision import transforms 21 | import torchvision 22 | import torch.nn.functional as F 23 | 24 | from config import get_config 25 | from models import build_model_new 26 | from models import craft_utils, imgproc 27 | from models.loss import KL_loss, Maploss 28 | from data import build_loader 29 | from lr_scheduler import build_scheduler 30 | from optimizer import build_optimizer 31 | from logger import create_logger 32 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 33 | 34 | try: 35 | # noinspection PyUnresolvedReferences 36 | from apex import amp 37 | except ImportError: 38 | amp = None 39 | 40 | 41 | def parse_option(): 42 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 43 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 44 | parser.add_argument( 45 | "--opts", 46 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 47 | default=None, 48 | nargs='+', 49 | ) 50 | 51 | # easy config modification 52 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 53 | parser.add_argument('--data-path', type=str, help='path to dataset') 54 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 55 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 56 | help='no: no cache, ' 57 | 'full: cache all data, ' 58 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 59 | parser.add_argument('--resume', help='resume from checkpoint') 60 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 61 | parser.add_argument('--use-checkpoint', action='store_true', 62 | help="whether to use gradient checkpointing to save memory") 63 | parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'], 64 | help='mixed precision opt level, if O0, no amp is used') 65 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 66 | help='root of output folder, the full path is // (default: output)') 67 | parser.add_argument('--tag', help='tag of experiment') 68 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 69 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 70 | 71 | # distributed training 72 | parser.add_argument("--local_rank", type=int, required=False, help='local rank for DistributedDataParallel') 73 | # parser.add_argument("--loss", type=str, required=True, help='debugging for different loss') 74 | 75 | # about dataset 76 | parser.add_argument('--dataset', type=str, default='imagenet', help='name of dataset') 77 | parser.add_argument('--head', type=str, default='denseNet_15layer', help='name of dataset') 78 | parser.add_argument('--datanum', type=int, default=972, help='num of dataset') 79 | 80 | parser.add_argument('--lr', type=float, default=5e-4, help='learning rate') 81 | 82 | args, unparsed = parser.parse_known_args() 83 | 84 | config = get_config(args) 85 | 86 | return args, config 87 | 88 | 89 | def main(config): 90 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) 91 | 92 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 93 | model = build_model_new(config) 94 | model.cuda() 95 | logger.info(str(model)) 96 | 97 | optimizer = build_optimizer(config, model) 98 | if config.AMP_OPT_LEVEL != "O0": 99 | model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) 100 | model = torch.nn.DataParallel(model) 101 | model_without_ddp = model.module 102 | 103 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 104 | logger.info(f"number of params: {n_parameters}") 105 | if hasattr(model_without_ddp, 'flops'): 106 | flops = model_without_ddp.flops() 107 | logger.info(f"number of GFLOPs: {flops / 1e9}") 108 | 109 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 110 | 111 | # criterion = KL_Loss() 112 | criterion = "" 113 | 114 | max_accuracy = 0.0 115 | 116 | if config.TRAIN.AUTO_RESUME: 117 | resume_file = auto_resume_helper(config.OUTPUT) 118 | if resume_file: 119 | if config.MODEL.RESUME: 120 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 121 | config.defrost() 122 | config.MODEL.RESUME = resume_file 123 | config.freeze() 124 | 
logger.info(f'auto resuming from {resume_file}')
125 |         else:
126 |             logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
127 | 
128 |     if config.MODEL.RESUME:
129 |         print("resuming")
130 |         max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
131 |         validate(config, data_loader_val, model)
132 |         # logger.info(f"loss of the network on the {len(dataset_val)} test images: {loss:.1f}%")
133 |         if config.EVAL_MODE:
134 |             print("eval mode")
135 |             return
136 | 
137 |     if config.THROUGHPUT_MODE:
138 |         throughput(data_loader_val, model, logger)
139 |         return
140 | 
141 |     logger.info("Start training")
142 |     start_time = time.time()
143 |     for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
144 |         print("training epoch %d" % epoch)
145 |         # data_loader_train.sampler.set_epoch(epoch)
146 | 
147 |         train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
148 |         print("training epoch %d done" % epoch)
149 |         if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
150 |             save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
151 | 
152 |         validate(config, data_loader_val, model)
153 |         print("imgs saved")
154 | 
155 |     total_time = time.time() - start_time
156 |     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
157 |     logger.info('Training time {}'.format(total_time_str))
158 | 
159 | 
160 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
161 |     model.train()
162 |     optimizer.zero_grad()
163 | 
164 |     num_steps = len(data_loader)
165 |     batch_time = AverageMeter()
166 |     ocr_loss_meter = AverageMeter()
167 |     attn_loss_meter = AverageMeter()
168 |     saliency_loss_meter = AverageMeter()
169 |     loss_meter = AverageMeter()
170 |     norm_meter = AverageMeter()
171 | 
172 |     start = time.time()
173 |     end = time.time()
174 |     cnt = 0
175 |     test_loc = config.OUTPUT + '/test/'
176 |     # print("saving in test_loc", test_loc)
177 |     # if not os.path.exists(test_loc):
178 |     #     os.makedirs(test_loc)
179 |     for idx, (idx_name, samples, targets, ocr_target) in enumerate(data_loader):
180 |         # print(samples.shape, targets.shape)
181 |         # idx_name = idx_name.tolist()
182 |         samples = samples.cuda(non_blocking=True)
183 |         targets = targets.cuda(non_blocking=True)
184 |         outputs = model(samples, targets)
185 |         # print("output.shape", outputs.shape)
186 |         targets = transforms.Resize(56)(targets)
187 |         # if epoch == 3 and cnt == 0:
188 |         #     torchvision.utils.save_image(samples, test_loc+str(epoch)+"_epoch_"+str(cnt)+'_image_'+str(idx_name)+'.jpg')
189 |         #     torchvision.utils.save_image(targets, test_loc+str(epoch)+"_epoch_"+str(cnt)+'_map_'+str(idx_name)+'.jpg')
190 |         # cnt += 1
191 |         loss = KL_loss(outputs, targets)
192 |         # print("looking loss", idx, loss)
193 |         # print("saliency mode", saliency_loss)
194 |         # ocr_target = {key: ocr_target[key].cuda(non_blocking=True) for key in ocr_target}
195 |         # if mixup_fn is not None:
196 |         #     samples, targets = mixup_fn(samples, targets)
197 | 
198 |         # outputs, attn_loss, ocr_out, feature = model(samples, targets)
199 | 
200 |         # print("output.size", output.shape)
201 |         # print("ocr_out.shape", ocr_out.shape)
202 |         # outputs, attn_loss = model(samples, targets)
203 |         # outputs = model(samples, targets)
204 |         '''
205 |         attn_loss_meter.update(attn_loss, targets.size(0))
206 |         ocr_loss_meter.update(ocr_loss, targets.size(0))
207 |         saliency_loss_meter.update(saliency_loss, targets.size(0))
208 |         #
print("detect loss from outside", saliency_loss, "\n", attn_loss, "\n", ocr_loss) 209 | 210 | if config.LOSS == 'ocr': 211 | ocr_out = model(samples, targets, config.LOSS) 212 | out1 = ocr_out[:, :, :, 0].cuda() 213 | out2 = ocr_out[:, :, :, 1].cuda() 214 | gah_label = ocr_target["gah_label"].resize_(out2.size()) 215 | gh_label = ocr_target["gh_label"].resize_(out1.size()) 216 | mask = ocr_target["mask"] 217 | ocr_loss = criterion[1](gh_label, gah_label, out2, out1, mask) 218 | print("ocr mode", ocr_loss) 219 | loss = ocr_loss 220 | elif config.LOSS == 'saliency': 221 | outputs = model(samples, targets, config.LOSS) 222 | targets = transforms.Resize(56)(targets) 223 | saliency_loss = criterion[0](outputs, targets) 224 | print("saliency mode", saliency_loss) 225 | loss = saliency_loss 226 | elif config.LOSS == 'attn': 227 | attn_loss = model(samples, targets, config.LOSS) 228 | print("attention mode", attn_loss) 229 | loss = attn_loss 230 | else: 231 | outputs, attn_loss, ocr_out, feature = model(samples, targets, config.LOSS) 232 | print("normal mode") 233 | loss = saliency_loss + attn_loss + ocr_loss 234 | ''' 235 | optimizer.zero_grad() 236 | 237 | if config.AMP_OPT_LEVEL != "O0": 238 | with amp.scale_loss(loss, optimizer) as scaled_loss: 239 | scaled_loss.backward() 240 | if config.TRAIN.CLIP_GRAD: 241 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 242 | else: 243 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 244 | else: 245 | loss.backward() 246 | if config.TRAIN.CLIP_GRAD: 247 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 248 | else: 249 | grad_norm = get_grad_norm(model.parameters()) 250 | optimizer.step() 251 | lr_scheduler.step_update(epoch * num_steps + idx) 252 | 253 | torch.cuda.synchronize() 254 | 255 | loss_meter.update(loss.item(), targets.size(0)) 256 | norm_meter.update(grad_norm) 257 | batch_time.update(time.time() - end) 258 | end = time.time() 259 | 260 | if idx % config.PRINT_FREQ == 0: 261 | lr = optimizer.param_groups[0]['lr'] 262 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 263 | etas = batch_time.avg * (num_steps - idx) 264 | logger.info( 265 | f'Train: [{idx}/{len(data_loader)}]\t' 266 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 267 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 268 | f'Mem {memory_used:.0f}MB\t' 269 | f'attn_loss {attn_loss_meter.val:.4f} ({attn_loss_meter.avg:.4f})\t' 270 | f'ocr_loss {ocr_loss_meter.val:.4f} ({ocr_loss_meter.avg:.4f})\t' 271 | f'saliency_loss {saliency_loss_meter.val:.4f} ({saliency_loss_meter.avg:.4f})' 272 | ) 273 | epoch_time = time.time() - start 274 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 275 | 276 | 277 | @torch.no_grad() 278 | def validate(config, data_loader, model): 279 | # waiting 280 | saliency_loc = config.OUTPUT + '/ans/saliency/' 281 | if not os.path.exists(saliency_loc): 282 | os.makedirs(saliency_loc) 283 | model.eval() 284 | with torch.no_grad(): 285 | for idx, (idx_name, images, target, ocr_target) in enumerate(data_loader): 286 | idx_name = idx_name.tolist() 287 | # print(idx_name, "3") 288 | images = images.cuda(non_blocking=True) 289 | target = target.cuda(non_blocking=True) 290 | # torchvision.utils.save_image(images[0], saliency_loc + str(idx_name) + '_ori.jpg') 291 | #print(images) 292 | #non = target.nonzero() 293 | #print(non, " none \n") 294 | 
#print(target[non[0][0]][non[0][1]][non[0][2]][non[0][3]]) 295 | # ocr_target = {key: ocr_target[key].cuda(non_blocking=True) for key in ocr_target} 296 | #output, attn_loss, ocr_out, feature = model(images, target) 297 | 298 | 299 | output = model(images, target) 300 | target = transforms.Resize(56)(target) 301 | loss = KL_loss(output, target) 302 | print("loss", idx_name, "validate", loss) 303 | saliency_map = np.array(F.interpolate(output, size=images.size()[2:], mode='bilinear', align_corners=False).cpu()) 304 | saliency_map = np.ascontiguousarray(saliency_map) 305 | saliency_map *= 255 306 | saliency_maps = saliency_map.astype(np.uint8) 307 | # embed() 308 | for i, saliency_map in enumerate(saliency_maps): 309 | print("saving", saliency_loc + str(idx_name[i]) + '.jpg') 310 | path = saliency_loc + str(idx_name[i]) + '.jpg' 311 | re = cv2.imwrite(path, saliency_map[0]) 312 | # print(re) 313 | ''' 314 | ocr_out = ocr_out.cpu() 315 | images = images.cpu() 316 | score_text = ocr_out[0, :, :, 0].cpu().data.numpy() 317 | score_link = ocr_out[0, :, :, 1].cpu().data.numpy() 318 | 319 | boxes, polys = craft_utils.getDetBoxes(score_text, score_link, 0.7, 0.4, 0.4, False) 320 | # print(boxes.size()) 321 | 322 | # boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) 323 | # polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) 324 | # print(boxes) 325 | for k in range(len(polys)): 326 | if polys[k] is None: polys[k] = boxes[k] 327 | 328 | render_img = score_text.copy() 329 | render_img = np.hstack((render_img, score_link)) 330 | ret_score_text = imgproc.cvt2HeatmapImg(render_img) 331 | # output, attn_loss = model(images, target) 332 | # output = model(images, target) 333 | image=cv2.imread("/home/your_path/experiment/EC/swin-transformer/ECdata/ALLSTIMULI/"+str(idx_name) + '.jpg') 334 | image=cv2.resize(image,(896,896)) 335 | img = np.array(images) 336 | boxes = polys 337 | # make result file list 338 | # filename, file_ext = os.path.splitext(os.path.basename(img_file)) 339 | # dirname=/temp_disk2/home/leise/ali/chineseocr1/CRAFT-Reimplementation-master/result/img/ 340 | # result directory 341 | # res_file = '/temp_disk2/leise/ali/CRAFT-Reimplementation-master/data/result1/txt/' + "res_" + filename + '.txt' 342 | # res_img_file = dirname + "res_" + filename + '.jpg' 343 | if not os.path.exists('/home/your_path/experiment/EC/swin-transformer/output/ans/ocr/box/'): 344 | os.makedirs('/home/your_path/experiment/EC/swin-transformer/output/ans/ocr/box/') 345 | os.makedirs('/home/your_path/experiment/EC/swin-transformer/output/ans/ocr/anchor/') 346 | res_img_file = r'/home/your_path/experiment/EC/swin-transformer/output/ans/ocr/box/' + str(idx_name) + '.jpg' 347 | res_file = r'/home/your_path/experiment/EC/swin-transformer/output/ans/ocr/anchor/' + str(idx_name) + '.txt' 348 | # if not os.path.isdir(dirname): 349 | # os.mkdir(dirname) 350 | 351 | with open(res_file, 'w') as f: 352 | for i, box in enumerate(boxes): 353 | poly = np.array(box).astype(np.int32).reshape((-1)) 354 | strResult = ','.join([str(p) for p in poly]) + '\r\n' 355 | print("str result", strResult) 356 | f.write(strResult) 357 | poly = poly.reshape(-1, 2) 358 | cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2) 359 | ptColor = (0, 255, 255) 360 | # if verticals is not None: 361 | # if verticals[i]: 362 | # ptColor = (255, 0, 0) 363 | 364 | # if texts is not None: 365 | # font = cv2.FONT_HERSHEY_SIMPLEX 366 | # font_scale = 0.5 367 | # cv2.putText(img, 
"{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1) 368 | # cv2.putText(img, "{}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255), thickness=1) 369 | 370 | # Save result image 371 | # img=cv2.addWeighted(ret_score_text,0.5,img,0.5,0) 372 | cv2.imwrite(res_img_file, img) 373 | ''' 374 | 375 | # print('time:%s' % ((T2 - T1)*1000)) 376 | # print('cc:%s' %np.mean(self.cc)) 377 | 378 | return 379 | 380 | 381 | @torch.no_grad() 382 | def throughput(data_loader, model, logger): 383 | model.eval() 384 | 385 | for idx, (images, _) in enumerate(data_loader): 386 | images = images.cuda(non_blocking=True) 387 | batch_size = images.shape[0] 388 | for i in range(50): 389 | model(images) 390 | torch.cuda.synchronize() 391 | logger.info(f"throughput averaged with 30 times") 392 | tic1 = time.time() 393 | for i in range(30): 394 | model(images) 395 | torch.cuda.synchronize() 396 | tic2 = time.time() 397 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 398 | return 399 | 400 | 401 | if __name__ == '__main__': 402 | _, config = parse_option() 403 | 404 | if config.AMP_OPT_LEVEL != "O0": 405 | assert amp is not None, "amp not installed!" 406 | 407 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 408 | rank = int(os.environ["RANK"]) 409 | world_size = int(os.environ['WORLD_SIZE']) 410 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 411 | else: 412 | rank = -1 413 | world_size = -1 414 | # torch.cuda.set_device(config.LOCAL_RANK) 415 | # torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 416 | # torch.distributed.barrier() 417 | 418 | # seed = config.SEED + dist.get_rank() 419 | # torch.manual_seed(seed) 420 | # np.random.seed(seed) 421 | cudnn.benchmark = True 422 | 423 | # linear scale the learning rate according to total batch size, may not be optimal 424 | linear_scaled_lr = config.TRAIN.BASE_LR # * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 425 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR# * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 426 | linear_scaled_min_lr = config.TRAIN.MIN_LR# * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 427 | # gradient accumulation also need to scale the learning rate 428 | if config.TRAIN.ACCUMULATION_STEPS > 1: 429 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 430 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 431 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 432 | config.defrost() 433 | config.TRAIN.BASE_LR = linear_scaled_lr 434 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 435 | config.TRAIN.MIN_LR = linear_scaled_min_lr 436 | config.freeze() 437 | 438 | os.makedirs(config.OUTPUT, exist_ok=True) 439 | logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}") 440 | 441 | if True: 442 | path = os.path.join(config.OUTPUT, "config.json") 443 | with open(path, "w") as f: 444 | f.write(config.dump()) 445 | logger.info(f"Full config saved to {path}") 446 | 447 | # print config 448 | logger.info(config.dump()) 449 | # print(config) 450 | main(config) 451 | --------------------------------------------------------------------------------