├── README.md ├── config.py └── main.py /README.md: -------------------------------------------------------------------------------- 1 | # MIAV 2 | Multi-scale interactive network with artery/vein discriminator for retinal vessel classification 3 | 4 | # Update Results 5 | For DRIVE-AV, Se, Sp, and Acc are 98.99%, 93.15%, and 97.79%, respectively 6 | 7 | # Open source datasets 8 | 9 | The 100 fundus images dataset has been released. 10 | 11 | [Open_BUA_AV](https://pan.baidu.com/s/1V0eChuEa6_ec0lVtP7Nu1w ) 12 | 13 | Password:00di 14 | 15 | The labeled data has been normalized. Users can handle it by themselves according to the usage situation. 16 | 17 | ## Author 18 | JingFeiHu GuangWu HuaWang 19 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __all__ = ["proj_root", "arg_config"] 4 | 5 | from collections import OrderedDict 6 | 7 | proj_root = os.path.dirname(__file__) 8 | datasets_root = "dataset" 9 | 10 | DRIVE_AV_train = os.path.join(datasets_root, "Eye_AV", "DRIVE_AV/training") 11 | DRIVE_AV_path = os.path.join(datasets_root, "Eye_AV", "DRIVE_AV/test") 12 | 13 | arg_config = { 14 | "model": "PCNet_ISE", # 实际使用的模型,需要在`network/__init__.py`中导入 15 | "Discriminator": "Discriminator", 16 | "info": "chan32-drive", # 关于本次实验的额外信息说明,这个会附加到本次试验的exp_name的结尾,如果为空,则不会附加内容。 17 | "use_amp": False, # 是否使用amp加速训练 18 | "resume_mode": "test", # the mode for resume parameters: ['train', 'test', ''] 19 | "save_pre": True, # 是否保留最终的预测结果 20 | "epoch_num": 8000, # 训练周期, 0: directly test model 21 | "lr": 0.0002, # 微调时缩小100倍 22 | "channel": 32, 23 | "xlsx_name": "", # the name of the record file 24 | # 数据集设置 25 | "rgb_data": { 26 | "tr_data_path": DRIVE_AV_train, 27 | "te_data_list": OrderedDict( 28 | { 29 | "DRIVE_AV": DRIVE_AV_path, 30 | }, 31 | ), 32 | }, 33 | # 训练过程中的监控信息 34 | "tb_update": 10, # >0 则使用tensorboard 35 | "print_freq": 10, # >0, 保存迭代过程中的信息 36 | "save_vis_freq": 100, # >0,保存可视化结果 37 | "save_middle_res": 100, # 保存和测试中间结果 38 | # img_prefix, gt_prefix,用在使用索引文件的时候的对应的扩展名 39 | "prefix": (".jpg", ".png"), 40 | "size_list": None, # 不使用多尺度训练 41 | "reduction": "mean", # 损失处理的方式,可选“mean”和“sum” 42 | # 优化器与学习率衰减 43 | "optim": "adam", # 自定义部分的学习率 44 | "weight_decay": 5e-4, # 微调时设置为0.0001 45 | "momentum": 0.9, 46 | "patch": True, 47 | "nesterov": False, 48 | "sche_usebatch": False, 49 | "lr_type": "poly", 50 | "warmup_epoch": 1, 51 | # depond on the special lr_type, only lr_type has 'warmup', when set it to 1, it means no warmup. 52 | "lr_decay": 0.9, # poly 53 | "use_bigt": True, # 训练时是否对真值二值化(阈值为0.5) 54 | "batch_size": 8, # 要是继续训练, 最好使用相同的batchsize 55 | "num_workers": 0, # 不要太大, 不然运行多个程序同时训练的时候, 会造成数据读入速度受影响 56 | "input_size": 256, 57 | 58 | # 损失函数控制 59 | "base_loss": True, 60 | "use_aux_loss": True, # 是否使用辅助损失 61 | "use_en_loss": False, 62 | "use_dice_loss": False, 63 | "dif_loss": True, 64 | "topo_loss": True, 65 | # GAN loss - [lsgan, hinge] you only choose one to True 66 | "lsgan_loss": True, 67 | "hinge_loss": False, 68 | 69 | # Noisy Labels控制 70 | "AFM": True, 71 | 72 | # Fca控制 73 | "Fca": False, 74 | 75 | # Dropout控制 76 | "Dropout": True, 77 | 78 | # 超参数设置 79 | "aux_weight": 0.8, # 利用辅助的损失,及CEL损失,可在loss下的CEL中查看 80 | "dice_weight": 0.01, # 鲁棒的Dice 81 | "en_loss": 0.1, # 连通性loss 82 | "dif_weight": 0.1, # bce loss应用在差分中 83 | "base_weight": 1, # 最基础的损失,我们采用的是BCE损失 84 | "topo_weight": 0.5, 85 | "gan_weight": 1, # 控制GAN损失的参数 86 | 87 | # 控制动静脉判别器的权重分配,和为1 88 | "fake_A_weight": 0.2, 89 | "fake_V_weight": 0.2, 90 | "fake_AV_weight": 0.6, 91 | } 92 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from datetime import datetime 3 | import torch 4 | 5 | from config import arg_config, proj_root 6 | from utils.misc import construct_exp_name, construct_path, construct_print, pre_mkdir, set_seed 7 | from utils.solver import Solver 8 | 9 | construct_print(f"{datetime.now()}: Initializing...") 10 | construct_print(f"Project Root: {proj_root}") 11 | init_start = datetime.now() 12 | 13 | exp_name = construct_exp_name(arg_config) 14 | path_config = construct_path( 15 | proj_root=proj_root, exp_name=exp_name, xlsx_name=arg_config["xlsx_name"], 16 | ) 17 | pre_mkdir(path_config) 18 | set_seed(seed=0, use_cudnn_benchmark=arg_config["size_list"] is not None) 19 | 20 | solver = Solver(exp_name, arg_config, path_config) 21 | construct_print(f"Total initialization time:{datetime.now() - init_start}") 22 | 23 | shutil.copy(f"{proj_root}/config.py", path_config["cfg_log"]) 24 | shutil.copy(f"{proj_root}/utils/solver.py", path_config["trainer_log"]) 25 | 26 | construct_print(f"{datetime.now()}: Start...") 27 | if arg_config["resume_mode"] == "test": 28 | solver.test() 29 | else: 30 | solver.train() 31 | construct_print(f"{datetime.now()}: End...") 32 | --------------------------------------------------------------------------------