├── README-old.md ├── README.md ├── images ├── length.png └── style.png └── src ├── bash_trainval_mtmc_diy_yolov2.sh ├── bash_trainval_mtmc_resnet18_ft.sh ├── bash_val_mtmc_resnet18_ft.sh ├── diy_folder.py ├── diy_yolov2.py ├── folder_raw_torchvision_dataset.py ├── main_mtmc_resnet.py ├── main_stmc_resnet.py ├── mtmc_coordinator.py ├── mtmcconfig.py └── mtmcmodel.py /README-old.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 1. Pytorch-Multi-Task-Multi-class-Classification 4 | 5 | **MTMC-Pytorch:** 6 | MTMC-Pytorch = Multi-Task Multi-Class Classification Project using Pytorch. 7 | 8 | **Purpose:** 9 | The aim is to build a general-purpose solution for classification problems under the Pytorch framework, solving single-task multi-class and multi-task multi-class problems in batches. 10 | 11 | **Usage:** 12 | The only preparation required is to organize the samples into folders in the following layout; 13 | MTMC then automatically handles task parsing to obtain class labels, adaptive sample balancing, model training, model evaluation, and so on. 14 | 15 | ``` 16 | MLDataloader loads an MTMC dataset from the following directory tree. 17 | Make sure the train and val directory trees stay consistent. 18 | 19 | data_root_path 20 | ├── task_A 21 | │   ├── train 22 | │   │   ├── class_1 23 | │   │   ├── class_2 24 | │   │   ├── class_3 25 | │   │   └── class_4 26 | │   └── val 27 | │   ├── class_1 28 | │   ├── class_2 29 | │   ├── class_3 30 | │   └── class_4 31 | └── task_B 32 | ├── train 33 | │   ├── class_1 34 | │   ├── class_2 35 | │   └── class_3 36 | └── val 37 | ├── class_1 38 | ├── class_2 39 | └── class_3 40 | 41 | ``` 42 | 
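As an illustration of the auto-parsing step (the helper below is a hypothetical sketch, not code from this repo), the task/class structure of such a tree can be recovered in a few lines:

```
import os

# Hypothetical sketch: walk data_root_path/<task>/<train|val>/<class>/
# and build {task: {class_name: class_index}} from the layout above.
def parse_mtmc_tree(data_root_path, split='train'):
    tree = {}
    for task in sorted(os.listdir(data_root_path)):
        split_dir = os.path.join(data_root_path, task, split)
        if not os.path.isdir(split_dir):
            continue
        classes = sorted(d for d in os.listdir(split_dir)
                         if os.path.isdir(os.path.join(split_dir, d)))
        tree[task] = {c: i for i, c in enumerate(classes)}
    return tree

# e.g. {'task_A': {'class_1': 0, ..., 'class_4': 3}, 'task_B': {...}}
```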
43 | **Notes:** 44 | 1. It is general-purpose, not optimal for every single problem; 45 | 2. The aim is to integrate the many training tricks for classification problems; 46 | 3. The project is no longer updated; it stops where Gluon CV picks up;(https://github.com/dmlc/gluon-cv , even though it is not fully complete. Ref: Bag of Tricks for Image Classification with Convolutional Neural Networks https://arxiv.org/abs/1812.01187v2 ) 47 | 48 | # 2. Pytorch Version Info 49 | 50 | ``` 51 | $ conda list | grep torch 52 | pytorch 0.4.1 py36_cuda0.0_cudnn0.0_1 pytorch 53 | torchvision 0.2.1 py36_1 pytorch 54 | ``` 55 | 56 | # 3. Train-Val Logs 57 | 58 | 59 | ## 3.1 Category Classification (Trainval_log_2018-07-09) 60 | 61 | **Before the Argmax operation, the relation between network output value and accuracy is as follows:** 62 | 63 | ![](https://ws4.sinaimg.cn/large/006tKfTcgy1ft6v36bcjtj30k40dfgoc.jpg) 64 | 65 | **After the Softmax operation, the relation between network output value and accuracy is as follows:** 66 | 67 | ![](https://ws3.sinaimg.cn/large/006tKfTcgy1ft6xjm5m4yj30k40df0u5.jpg) 68 | 69 | **Category discrimination runs after detection; moreover, the post-Softmax distribution shows that accuracy converges well, so the category branch directly outputs the label corresponding to Argmax as its decision.** 70 | 71 | ## 3.2 Print-Pattern Classification 72 | 73 | Multi-class confidence-value prediction 74 | 75 | **Before the Argmax operation, the relation between network output value and accuracy is as follows:** 76 | 77 | ![](https://ws2.sinaimg.cn/large/006tKfTcgy1ft6x7qvvv8j30k40df76z.jpg) 78 | 79 | **After the Softmax operation, the relation between network output value and accuracy is as follows:** 80 | 81 | ![](https://ws2.sinaimg.cn/large/006tKfTcgy1ft6x7zo20bj30k40df76y.jpg) 82 | 83 | **The print-pattern problem has the following issues: the label domains overlap with one another, and undefined print types exist. 84 | When building the service, a confidence-value mapping is therefore applied inside the service:** 85 | 86 | ``` 87 | class_idx_dict = { 88 | "0" : "五角星", 89 | "1" : "人物", 90 | "2" : "几何", 91 | "3" : "动物鸟虫", 92 | "4" : "千鸟", 93 | "5" : "卡通", 94 | "6" : "复古", 95 | "7" : "大花", 96 | "8" : "字母数字汉字", 97 | "9" : "手绘", 98 | "10" : "斑马纹", 99 | "11" : "条纹", 100 | "12" : "格子", 101 | "13" : "植物风景", 102 | "14" : "波点", 103 | "15" : "渐变色", 104 | "16" : "爱心", 105 | "17" : "牛仔", 106 | "18" : "碎花", 107 | "19" : "纯色", 108 | "20" : "色块拼色", 109 | "21" : "豹纹蛇纹", 110 | "22" : "迷彩", 111 | "23" : "食物水果"} 112 | 113 | pvals = [0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.7, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98] 114 | 115 | ptmat_dict = { 116 | "五角星" : [1.30, 1.80, 2.30, 2.70, 3.00, 3.30, 3.50, 3.80, 4.00, 4.20, 4.40, 4.50, 4.70, 4.90, 5.00, 5.20, 5.30, 5.50, 5.60, 5.80, 5.90, 6.00, 6.20, 6.30, 6.50, 6.60, 6.70, 6.90, 7.00, 7.20, 7.30, 7.50, 7.60, 7.80, 7.90, 8.10, 8.30, 8.50, 8.70, 8.90, 9.10, 9.40, 9.70, 10.10, 11.20], 117 | ... 118 | ... 119 | "食物水果" : [2.40, 2.90, 3.30, 3.60, 3.90, 4.20, 4.50, 4.70, 4.90, 5.10, 5.30, 5.50, 5.70, 5.80, 6.00, 6.10, 6.30, 6.50, 6.60, 6.70, 6.90, 7.00, 7.20, 7.30, 7.50, 7.60, 7.70, 7.90, 8.00, 8.10, 8.30, 8.40, 8.60, 8.80, 8.90, 9.10, 9.30, 9.50, 9.70, 10.00, 10.30, 10.60, 11.00, 11.60, 12.70] 120 | } 121 | ``` 122 | 
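Given those tables, the mapping itself is a simple threshold lookup. A minimal sketch of one way to apply it (the helper below is hypothetical; it assumes a raw score at or above ptmat_dict[class][i] corresponds to a confidence of at least pvals[i]):

```
import bisect

# Hypothetical sketch: map a raw (pre-Softmax) output value for one class
# to a calibrated confidence via that class's threshold table.
def map_confidence(class_name, raw_score, pvals, ptmat_dict):
    thresholds = ptmat_dict[class_name]
    i = bisect.bisect_right(thresholds, raw_score)
    if i == 0:
        return 0.0  # below the lowest calibrated threshold
    return pvals[i - 1]

# e.g. a raw output of 6.4 for "五角星" sits between the 0.56 threshold
# (6.30) and the 0.58 threshold (6.50), so it maps to confidence 0.56.
```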
123 | **Using the raw network outputs instead of the post-Softmax values, we can obtain decisions in which two classes both carry high confidence.** 124 | 125 | **For example, the person-cartoon image below gets high confidence values under both the "人物" (person) and "卡通" (cartoon) classes:** 126 | 127 | ![](https://ws1.sinaimg.cn/large/006tNc79gy1ft4we0p3puj30l30g27c9.jpg) 128 | 129 | 130 | **Also, when the output is to be used as semi-structured information, normalized confidence values are a poor choice.** 131 | 132 | ## 3.3 Log files are saved as *.txt; opened in Excel they display like this: 133 | 134 | ![](https://ws4.sinaimg.cn/large/006tNc79ly1fz1bikc8edj30ql0p8wid.jpg) 135 | 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 1. Pytorch-Multi-Task-Multi-class-Classification 4 | 5 | **MTMC-Pytorch:** 6 | MTMC-Pytorch = Multi-Task Multi-Class Classification Project using Pytorch. 7 | 8 | **Purpose:** 9 | The aim is to build a general-purpose solution for classification problems under the Pytorch framework, solving single-task multi-class and multi-task multi-class problems in batches. 10 | 11 | **Notes:** 12 | 1. It is general-purpose, not optimal for every single problem; 13 | 2. The aim is to integrate the many training tricks for classification problems; 14 | 3. The project is no longer updated; it stops where Gluon CV picks up;(https://github.com/dmlc/gluon-cv , even though it is not fully complete. Ref: Bag of Tricks for Image Classification with Convolutional Neural Networks https://arxiv.org/abs/1812.01187v2 ) 15 | 16 | # 2. Pytorch Version Info 17 | 18 | ``` 19 | $ conda list | grep torch 20 | pytorch 0.4.1 py36_cuda0.0_cudnn0.0_1 pytorch 21 | torchvision 0.2.1 py36_1 pytorch 22 | ``` 23 | 24 | **Verified: pytorch 1.3.0 is also supported. --2020.02.29** 25 | 26 | # 3. Getting Started 27 | 28 | ## 3.1 Data Preparation 29 | 30 | Organize the samples into folders in the following layout: 31 | 32 | ``` 33 | MLDataloader loads an MTMC dataset from the following directory tree. 34 | Make sure the train and val directory trees stay consistent. 35 | 36 | data_root_path 37 | ├── task_A 38 | │   ├── train 39 | │   │   ├── class_1 40 | │   │   ├── class_2 41 | │   │   ├── class_3 42 | │   │   └── class_4 43 | │   └── val 44 | │   ├── class_1 45 | │   ├── class_2 46 | │   ├── class_3 47 | │   └── class_4 48 | └── task_B 49 | ├── train 50 | │   ├── class_1 51 | │   ├── class_2 52 | │   └── class_3 53 | └── val 54 | ├── class_1 55 | ├── class_2 56 | └── class_3 57 | 58 | ``` 59 | 60 | ## 3.2 Train-Val Logs 61 | 62 | MTMC automatically handles task parsing to obtain class labels, adaptive sample balancing, model training, and model evaluation. 63 | 64 | The steps you need to take are as follows: 65 | 66 | Step 1. Edit the ```/src/bash_trainval_mtmc_resnet18_ft.sh``` file to confirm the data path, model parameters, training parameters, etc. 67 | 68 | ``` 69 | DATA=../data/pants 70 | MAX_BASE_NUMBER=5000 71 | 72 | ARC=resnet18 73 | CLASS_NUM=24 # deprecated in mtmc 74 | 75 | # 336X224--S11X7--MP7X7--512*(11-7+1)=512*5=2560 76 | # 960:640 = 3:2 = 224*1.5:224 = 336:224 = 384:256 = 1.5:1 77 | DATALOADER_RESIZE_H=384 78 | DATALOADER_RESIZE_W=256 79 | INPUTLAYER_H=336 80 | INPUTLAYER_W=224 81 | FC_FEATURES=2560 82 | 83 | EPOCHS=120 84 | FC_EPOCHS=50 85 | 86 | BATCHSIZE=256 87 | WORKERS=8 88 | 89 | LEARNING_RATE=0.01 90 | WEIGHT_DECAY=0.0001 91 | 92 | TRAIN_LOG_FILENAME=$ARC"_train_`date +%Y%m%d_%H%M%S`".log 93 | VAL_LOG_FILENAME=$ARC"_val_`date +%Y%m%d_%H%M%S`".log 94 | 95 | python main_mtmc_resnet.py --data $DATA \ 96 | --dataloader_resize_h $DATALOADER_RESIZE_H \ 97 | --dataloader_resize_w $DATALOADER_RESIZE_W \ 98 | --inputlayer_h $INPUTLAYER_H \ 99 | --inputlayer_w $INPUTLAYER_W \ 100 | --fc_features $FC_FEATURES \ 101 | --max_base_number $MAX_BASE_NUMBER \ 102 | --arc $ARC \ 103 | --workers $WORKERS \ 104 | --pretrained \ 105 | --epochs $EPOCHS \ 106 | --fc_epochs $FC_EPOCHS \ 107 | --batch_size $BATCHSIZE \ 108 | --learning-rate $LEARNING_RATE \ 109 | --weight-decay $WEIGHT_DECAY \ 110 | 2>&1 | tee $TRAIN_LOG_FILENAME 111 | 112 | echo "Train... Done." 113 | 114 | python main_mtmc_resnet.py --data $DATA \ 115 | --dataloader_resize_h $DATALOADER_RESIZE_H \ 116 | --dataloader_resize_w $DATALOADER_RESIZE_W \ 117 | --inputlayer_h $INPUTLAYER_H \ 118 | --inputlayer_w $INPUTLAYER_W \ 119 | --fc_features $FC_FEATURES \ 120 | --arc $ARC \ 121 | --workers $WORKERS \ 122 | --evaluate \ 123 | --resume model_best_checkpoint_$ARC.pth.tar \ 124 | --batch_size $BATCHSIZE \ 125 | 2>&1 | tee $VAL_LOG_FILENAME 126 | 127 | echo "Val... Done." 128 | ``` 129 | 130 | Step 2. Run the ```/src/bash_trainval_mtmc_resnet18_ft.sh``` file 131 | 132 | ``` 133 | $ bash bash_trainval_mtmc_resnet18_ft.sh 134 | ``` 135 | 
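A note on the adaptive sample balancing mentioned above: every 20 epochs the trainer re-measures per-class top-1 precision on the train split and re-balances each class's sample count (the BaseNumberSpecifier update in ```diy_folder.py```, quoted further below). The rule reduces to a few lines:

```
# Sketch of the K = M + (N - M) * w update from BaseNumberSpecifier
# (diy_folder.py below), with w = 1 - precision and N = 2 * M.
def updated_class_count(precision, base_number):
    w = 1.0 - precision
    M, N = base_number, base_number * 2
    return int(M + w * (N - M))

print(updated_class_count(0.9, 5000))  # well-learned class -> 5500 samples
print(updated_class_count(0.5, 5000))  # struggling class   -> 7500 samples
```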
136 | Step 3. Check the training logs and results in ```/src``` & ```/src/vals```; the log files are saved as *.txt and, opened with Excel, display like this: 137 | 138 | Example task: pants attribute analysis, split into two tasks, pant-style classification and pant-length classification; 139 | 140 | **Style:** 141 | 142 | ![](images/style.png) 143 | 144 | **Length:** 145 | 146 | ![](images/length.png) -------------------------------------------------------------------------------- /images/length.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuluyeah/Pytorch-Multi-Task-Multi-class-Classification/d1fbeea6038e21ae786bbb4c33b8b8252166eb9b/images/length.png -------------------------------------------------------------------------------- /images/style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuluyeah/Pytorch-Multi-Task-Multi-class-Classification/d1fbeea6038e21ae786bbb4c33b8b8252166eb9b/images/style.png -------------------------------------------------------------------------------- /src/bash_trainval_mtmc_diy_yolov2.sh: -------------------------------------------------------------------------------- 1 | 2 | DATA=../data 3 | MAX_BASE_NUMBER=5000 4 | 5 | ARC=Yolov2_768x512 6 | CLASS_NUM=24 7 | 8 | EPOCHS=120 9 | FC_EPOCHS=50 10 | 11 | BATCHSIZE=128 12 | WORKERS=8 13 | 14 | LEARNING_RATE=0.01 15 | WEIGHT_DECAY=0.0001 16 | 17 | TRAIN_LOG_FILENAME=$ARC"_train_`date +%Y%m%d_%H%M%S`".log 18 | VAL_LOG_FILENAME=$ARC"_val_`date +%Y%m%d_%H%M%S`".log 19 | 20 | python main_mtmc_resnet.py --data $DATA \ 21 | --max_base_number $MAX_BASE_NUMBER \ 22 | --arc $ARC \ 23 | --workers $WORKERS \ 24 | --pretrained \ 25 | --epochs $EPOCHS \ 26 | --fc_epochs $FC_EPOCHS \ 27 | --batch_size $BATCHSIZE \ 28 | --learning-rate $LEARNING_RATE \ 29 | --weight-decay $WEIGHT_DECAY \ 30 | 2>&1 | tee $TRAIN_LOG_FILENAME 31 | 32 | echo "Train... Done." 33 | 34 | python main_mtmc_resnet.py --data $DATA \ 35 | --arc $ARC \ 36 | --workers $WORKERS \ 37 | --evaluate \ 38 | --resume model_best_checkpoint_$ARC.pth.tar \ 39 | --batch_size $BATCHSIZE \ 40 | 2>&1 | tee $VAL_LOG_FILENAME 41 | 42 | echo "Val... Done." 
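# Editor's note: ARC here doubles as an input-size spec. The 'yolo' branch in
# main_mtmc_resnet.py (further below) splits the trailing 'HxW' token of names
# like 'Yolov2_768x512' to recover an input height (768) and width (512).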
43 | 44 | -------------------------------------------------------------------------------- /src/bash_trainval_mtmc_resnet18_ft.sh: -------------------------------------------------------------------------------- 1 | 2 | DATA=../data 3 | MAX_BASE_NUMBER=5000 4 | 5 | ARC=resnet18 6 | CLASS_NUM=24 # deprecated in mtmc 7 | :' 8 | # Restnet:224X224--S7X7--MP7X7--512 9 | DATALOADER_RESIZE_H=256 10 | DATALOADER_RESIZE_W=256 11 | INPUTLAYER_H=224 12 | INPUTLAYER_W=224 13 | FC_FEATURES=512 14 | ' 15 | :' 16 | # Inception:320, 299 17 | DATALOADER_RESIZE_H=320 18 | DATALOADER_RESIZE_W=320 19 | INPUTLAYER_H=299 20 | INPUTLAYER_W=299 21 | FC_FEATURES=* 22 | ' 23 | # 336X224--S11X7--MP7X7--512*(11-7+1)=512*5=2560 24 | # 960:640 = 3:2 = 224*1.5:224 = 336:224 = 384:256 = 1.5:1 25 | DATALOADER_RESIZE_H=384 26 | DATALOADER_RESIZE_W=256 27 | INPUTLAYER_H=336 28 | INPUTLAYER_W=224 29 | FC_FEATURES=2560 30 | 31 | EPOCHS=120 32 | FC_EPOCHS=50 33 | 34 | BATCHSIZE=256 35 | WORKERS=8 36 | 37 | LEARNING_RATE=0.01 38 | WEIGHT_DECAY=0.0001 39 | 40 | TRAIN_LOG_FILENAME=$ARC"_train_`date +%Y%m%d_%H%M%S`".log 41 | VAL_LOG_FILENAME=$ARC"_val_`date +%Y%m%d_%H%M%S`".log 42 | 43 | python main_mtmc_resnet.py --data $DATA \ 44 | --dataloader_resize_h $DATALOADER_RESIZE_H \ 45 | --dataloader_resize_w $DATALOADER_RESIZE_W \ 46 | --inputlayer_h $INPUTLAYER_H \ 47 | --inputlayer_w $INPUTLAYER_W \ 48 | --fc_features $FC_FEATURES \ 49 | --max_base_number $MAX_BASE_NUMBER \ 50 | --arc $ARC \ 51 | --workers $WORKERS \ 52 | --pretrained \ 53 | --epochs $EPOCHS \ 54 | --fc_epochs $FC_EPOCHS \ 55 | --batch_size $BATCHSIZE \ 56 | --learning-rate $LEARNING_RATE \ 57 | --weight-decay $WEIGHT_DECAY \ 58 | 2>&1 | tee $TRAIN_LOG_FILENAME 59 | 60 | echo "Train... Done." 61 | 62 | python main_mtmc_resnet.py --data $DATA \ 63 | --dataloader_resize_h $DATALOADER_RESIZE_H \ 64 | --dataloader_resize_w $DATALOADER_RESIZE_W \ 65 | --inputlayer_h $INPUTLAYER_H \ 66 | --inputlayer_w $INPUTLAYER_W \ 67 | --fc_features $FC_FEATURES \ 68 | --arc $ARC \ 69 | --workers $WORKERS \ 70 | --evaluate \ 71 | --resume model_best_checkpoint_$ARC.pth.tar \ 72 | --batch_size $BATCHSIZE \ 73 | 2>&1 | tee $VAL_LOG_FILENAME 74 | 75 | echo "Val... Done." 
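# Editor's note: FC_FEATURES=2560 follows from the geometry comment above:
# resnet18 downsamples the 336x224 input by 32 to an 11x7 feature map; the
# 7x7 stride-1 pooling ('MP7X7') then leaves (11-7+1)x(7-7+1) = 5x1 positions,
# and 512 channels give 512*5*1 = 2560 inputs to the fc layer.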
76 | 77 | -------------------------------------------------------------------------------- /src/bash_val_mtmc_resnet18_ft.sh: -------------------------------------------------------------------------------- 1 | 2 | DATA=../data 3 | MAX_BASE_NUMBER=5000 4 | 5 | ARC=resnet18 6 | CLASS_NUM=24 # deprecated in mtmc 7 | :' 8 | # Restnet:224X224--S7X7--MP7X7--512 9 | DATALOADER_RESIZE_H=256 10 | DATALOADER_RESIZE_W=256 11 | INPUTLAYER_H=224 12 | INPUTLAYER_W=224 13 | FC_FEATURES=512 14 | ' 15 | :' 16 | # Inception:320, 299 17 | DATALOADER_RESIZE_H=320 18 | DATALOADER_RESIZE_W=320 19 | INPUTLAYER_H=299 20 | INPUTLAYER_W=299 21 | FC_FEATURES=* 22 | ' 23 | # 336X224--S11X7--MP7X7--512*(11-7+1)=512*5=2560 24 | # 960:640 = 3:2 = 224*1.5:224 = 336:224 = 384:256 = 1.5:1 25 | DATALOADER_RESIZE_H=384 26 | DATALOADER_RESIZE_W=256 27 | INPUTLAYER_H=336 28 | INPUTLAYER_W=224 29 | FC_FEATURES=2560 30 | 31 | EPOCHS=120 32 | FC_EPOCHS=50 33 | 34 | BATCHSIZE=256 35 | WORKERS=8 36 | 37 | LEARNING_RATE=0.01 38 | WEIGHT_DECAY=0.0001 39 | 40 | TRAIN_LOG_FILENAME=$ARC"_train_`date +%Y%m%d_%H%M%S`".log 41 | VAL_LOG_FILENAME=$ARC"_val_`date +%Y%m%d_%H%M%S`".log 42 | 43 | :' 44 | python main_mtmc_resnet.py --data $DATA \ 45 | --dataloader_resize_h $DATALOADER_RESIZE_H \ 46 | --dataloader_resize_w $DATALOADER_RESIZE_W \ 47 | --inputlayer_h $INPUTLAYER_H \ 48 | --inputlayer_w $INPUTLAYER_W \ 49 | --fc_features $FC_FEATURES \ 50 | --max_base_number $MAX_BASE_NUMBER \ 51 | --arc $ARC \ 52 | --workers $WORKERS \ 53 | --pretrained \ 54 | --epochs $EPOCHS \ 55 | --fc_epochs $FC_EPOCHS \ 56 | --batch_size $BATCHSIZE \ 57 | --learning-rate $LEARNING_RATE \ 58 | --weight-decay $WEIGHT_DECAY \ 59 | 2>&1 | tee $TRAIN_LOG_FILENAME 60 | ' 61 | echo "Train... Done." 62 | 63 | python main_mtmc_resnet.py --data $DATA \ 64 | --dataloader_resize_h $DATALOADER_RESIZE_H \ 65 | --dataloader_resize_w $DATALOADER_RESIZE_W \ 66 | --inputlayer_h $INPUTLAYER_H \ 67 | --inputlayer_w $INPUTLAYER_W \ 68 | --fc_features $FC_FEATURES \ 69 | --arc $ARC \ 70 | --workers $WORKERS \ 71 | --evaluate \ 72 | --resume model_best_checkpoint_$ARC.pth.tar \ 73 | --batch_size $BATCHSIZE \ 74 | 2>&1 | tee $VAL_LOG_FILENAME 75 | 76 | echo "Val... Done." 
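# Editor's note: in this val-only variant the training command above is wrapped
# in a :' ... ' block (the bash no-op ':' fed a quoted string, used as a block
# comment), so the script only restores model_best_checkpoint_$ARC.pth.tar via
# --resume and runs validation with --evaluate.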
77 | 78 | -------------------------------------------------------------------------------- /src/diy_folder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Nov 5 14:52:43 2018 5 | 6 | Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29) 7 | 8 | @author: pilgrim.bin@163.com 9 | """ 10 | 11 | import os 12 | import os.path 13 | import copy 14 | import math 15 | import random 16 | 17 | import torch.utils.data as data 18 | 19 | from PIL import Image 20 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 21 | 22 | '''-------------------------------------------------------------------''' 23 | 24 | def pil_loader(path): 25 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 26 | with open(path, 'rb') as f: 27 | img = Image.open(f) 28 | return img.convert('RGB') 29 | 30 | 31 | def accimage_loader(path): 32 | import accimage 33 | try: 34 | return accimage.Image(path) 35 | except IOError: 36 | # Potentially a decoding problem, fall back to PIL.Image 37 | return pil_loader(path) 38 | 39 | 40 | def default_loader(path): 41 | from torchvision import get_image_backend 42 | if get_image_backend() == 'accimage': 43 | return accimage_loader(path) 44 | else: 45 | return pil_loader(path) 46 | 47 | '''-------------------------------------------------------------------''' 48 | 49 | def has_file_allowed_extension(filename, extensions): 50 | """Checks if a file is an allowed extension. 51 | 52 | Args: 53 | filename (string): path to a file 54 | 55 | Returns: 56 | bool: True if the filename ends with a known image extension 57 | """ 58 | filename_lower = filename.lower() 59 | return any(filename_lower.endswith(ext) for ext in extensions) 60 | 61 | 62 | def find_classes(dir): 63 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 64 | classes.sort() 65 | class_to_idx = {classes[i]: i for i in range(len(classes))} 66 | return classes, class_to_idx 67 | 68 | 69 | def make_dataset(dir, class_to_idx, extensions): 70 | images = [] 71 | dir = os.path.expanduser(dir) 72 | for target in sorted(os.listdir(dir)): 73 | d = os.path.join(dir, target) 74 | if not os.path.isdir(d): 75 | continue 76 | 77 | for root, _, fnames in sorted(os.walk(d)): 78 | for fname in sorted(fnames): 79 | if has_file_allowed_extension(fname, extensions): 80 | path = os.path.join(root, fname) 81 | item = (path, class_to_idx[target]) 82 | images.append(item) 83 | 84 | return images 85 | 86 | # return all type filepath of this path 87 | def get_filelist(path): 88 | filelist = [] 89 | for root,dirs,filenames in os.walk(path): 90 | for fn in filenames: 91 | filelist.append(os.path.join(root,fn)) 92 | return filelist 93 | 94 | # return img filepath of this path 95 | def get_img_filelist(path): 96 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 97 | filelist = [] 98 | for root,dirs,filenames in os.walk(path): 99 | for fn in filenames: 100 | if has_file_allowed_extension(fn, IMG_EXTENSIONS): 101 | filelist.append(os.path.join(root,fn)) 102 | return filelist 103 | 104 | 105 | '''-------------------------------------------------------------------''' 106 | 107 | class DatasetFolderParsing(): 108 | """ 109 | Args: 110 | root (string): Root directory path. 111 | extensions (list[string]): A list of allowed extensions. 112 | 113 | Attributes: 114 | classes (list): List of the class names. 
115 | class_to_idx (dict): Dict with items (class_name, class_index). 116 | samples (list): List of (sample path, class_index) tuples 117 | """ 118 | 119 | def __init__(self, root): 120 | classes, class_to_idx = find_classes(root) 121 | print('---------------------------------') 122 | print("DatasetFolderParsing::class_to_idx = {}".format(class_to_idx)) 123 | extensions = IMG_EXTENSIONS 124 | class_to_samples = {key : get_img_filelist(os.path.join(root, key)) for key in class_to_idx} 125 | 126 | for key in class_to_samples.keys(): 127 | if not len(class_to_samples[key]) > 0: 128 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\{}".format(key) + "\n" 129 | "Supported extensions are: " + ",".join(extensions))) 130 | 131 | self.root = root 132 | self.extensions = extensions 133 | 134 | self.classes = classes 135 | self.class_to_idx = class_to_idx 136 | self.class_to_samples = class_to_samples 137 | 138 | 139 | 140 | class ImageFolder_SpecifiedNumber(data.Dataset): 141 | """A generic data loader where the images are arranged in this way: :: 142 | 143 | root/dog/xxx.png 144 | root/dog/xxy.png 145 | root/dog/xxz.png 146 | 147 | root/cat/123.png 148 | root/cat/nsdf3.png 149 | root/cat/asd932_.png 150 | 151 | Args: 152 | root (string): Root directory path. 153 | dbparser : DatasetFolderParsing instance. 154 | number_dict : specified number of each dict. 155 | transform (callable, optional): A function/transform that takes in an PIL image 156 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 157 | target_transform (callable, optional): A function/transform that takes in the 158 | target and transforms it. 159 | loader (callable, optional): A function to load an image given its path. 160 | 161 | Attributes: 162 | classes (list): List of the class names. 163 | class_to_idx (dict): Dict with items (class_name, class_index). 164 | imgs (list): List of (image path, class_index) tuples 165 | """ 166 | def __init__(self, dbparser, number_dict=None, 167 | transform=None, target_transform=None, 168 | loader=default_loader): 169 | ''' 170 | if not isinstance(dbparser, DatasetFolderParsing): 171 | raise(RuntimeError("dbparser must be DatasetFolderParsing instance!")) 172 | ''' 173 | 174 | if number_dict is None: 175 | raise(RuntimeError("number_dict cannot set as None!")) 176 | 177 | # check if keys right 178 | for key in number_dict.keys(): 179 | if not key in dbparser.classes: 180 | raise(RuntimeError("Unknown class = {}.".format(key))) 181 | 182 | super(ImageFolder_SpecifiedNumber, self).__init__() 183 | 184 | self.root = dbparser.root 185 | self.loader = loader 186 | self.extensions = dbparser.extensions 187 | 188 | self.classes = dbparser.classes 189 | self.class_to_idx = dbparser.class_to_idx 190 | 191 | samples = [] 192 | print("ImageFolder_SpecifiedNumber:") 193 | for key in number_dict.keys(): 194 | number = number_dict[key] 195 | this_class_samples = dbparser.class_to_samples[key] 196 | random.shuffle(this_class_samples) 197 | # 可改,改为队列式,同时保持随机性。改为每次向队列请求n个样本,队列根据自身长度返回相应的样本, 198 | # 尽可能保证外发样本的随机性是不够的,还要保证每一个样本参与训练的随机性是平均分布,而不是高斯分布。 199 | n_repeat = int(math.ceil(1. 
* number / len(this_class_samples))) 200 | sub_samples = [] 201 | for i in range(n_repeat): 202 | sub_samples += this_class_samples 203 | sub_samples = sub_samples[:number] 204 | sub_samples = [(s, self.class_to_idx[key]) for s in sub_samples] 205 | samples += sub_samples 206 | print("class = {}, n_repeat = {}".format(key, n_repeat)) 207 | 208 | random.shuffle(samples) 209 | random.shuffle(samples) 210 | self.samples = samples 211 | 212 | self.transform = transform 213 | self.target_transform = target_transform 214 | 215 | def __getitem__(self, index): 216 | """ 217 | Args: 218 | index (int): Index 219 | 220 | Returns: 221 | tuple: (sample, target) where target is class_index of the target class. 222 | """ 223 | path, target = self.samples[index] 224 | #sample = path # for test 225 | sample = self.loader(path) 226 | if self.transform is not None: 227 | sample = self.transform(sample) 228 | if self.target_transform is not None: 229 | target = self.target_transform(target) 230 | 231 | return sample, target 232 | 233 | def __len__(self): 234 | return len(self.samples) 235 | 236 | '''-------------------------------------------------------------------''' 237 | 238 | class BaseNumberSpecifier(): 239 | """ 240 | Args: 241 | class_to_number_dict (dict): Dict with items (class_name, number). 242 | base_number (int): base numner of each class. 243 | 244 | Attributes: 245 | class_to_number_dict (dict): Dict with items (class_name, number). 246 | """ 247 | 248 | def __init__(self, class_to_number_dict, base_number): 249 | 250 | if not len(class_to_number_dict) > 0: 251 | raise(RuntimeError("Error: len(class_to_number_dict) <= 0.")) 252 | if not base_number > 0: 253 | raise(RuntimeError("Error: base_number > 0.")) 254 | 255 | self.class_number = len(class_to_number_dict) 256 | self.base_number = base_number 257 | self.class_to_number_dict = copy.deepcopy(class_to_number_dict) 258 | self.class_to_prec_dict = None 259 | for key in self.class_to_number_dict.keys(): 260 | self.class_to_number_dict[key] = base_number 261 | 262 | def update(self, class_to_prec_dict): 263 | if not len(class_to_prec_dict) == len(self.class_to_number_dict): 264 | raise(RuntimeError("Error: len(class_to_prec_dict) != len(self.class_to_number_dict).")) 265 | self.class_to_prec_dict = copy.deepcopy(class_to_prec_dict) 266 | 267 | for key in class_to_prec_dict.keys(): 268 | ''' 269 | if class_to_prec_dict[key] < 0 or class_to_prec_dict[key] > 1: # =1 ? 
270 | raise(RuntimeError("Error: class_to_prec_dict[key] < 0 or class_to_prec_dict[key] > 1")) 271 | ''' 272 | # dummy protect 273 | if class_to_prec_dict[key] <= 0: 274 | class_to_prec_dict[key] = 0.001 275 | if class_to_prec_dict[key] >= 1: 276 | class_to_prec_dict[key] = 0.9999 277 | 278 | # 惩戒太过了 279 | ''' 280 | #weight_dict = {key : - math.log(class_to_prec_dict[key]) for key in class_to_prec_dict.keys()} 281 | #weight_dict = {key : 1 - (class_to_prec_dict[key]) for key in class_to_prec_dict.keys()} 282 | 283 | sum_weight = sum([weight_dict[key] for key in weight_dict.keys()]) 284 | number_ratio = float(self.class_number * self.base_number) / sum_weight 285 | for key in weight_dict.keys(): 286 | self.class_to_number_dict[key] = int(number_ratio * weight_dict[key]) 287 | ''' 288 | 289 | # w = 1 - p 290 | # w = (1 - p)**2 291 | # K = M + (N - M) * w 292 | weight_dict = {key : 1 - (class_to_prec_dict[key]) for key in class_to_prec_dict.keys()} 293 | # weight_dict = {key : (1 - (class_to_prec_dict[key]))**2 for key in class_to_prec_dict.keys()} 294 | M = self.base_number 295 | N = M * 2 # raw = 3 296 | for key in weight_dict.keys(): 297 | self.class_to_number_dict[key] = int(M + weight_dict[key] * (N - M)) 298 | 299 | if __name__ == '__main__': 300 | 301 | path = 'data' 302 | base_number = 5 303 | 304 | dbparser = DatasetFolderParsing(path) 305 | 306 | print('--------------base test----------------') 307 | number_dict = {key : 0 for key in dbparser.class_to_idx.keys()} 308 | base_number_specifier = BaseNumberSpecifier(number_dict, base_number) 309 | number_dict=base_number_specifier.class_to_number_dict 310 | dataloader = ImageFolder_SpecifiedNumber(dbparser, 311 | number_dict=number_dict) 312 | print("number_dict = {}".format(number_dict)) 313 | for d, target in dataloader: 314 | print("t-d = {}-{}".format(target, d)) 315 | 316 | 317 | print('--------------fake p test----------------') 318 | fake_p = {key : random.random() for key in dbparser.class_to_idx.keys()} 319 | print("fake_p = {}".format(fake_p)) 320 | base_number_specifier.update(fake_p) 321 | number_dict=base_number_specifier.class_to_number_dict 322 | dataloader = ImageFolder_SpecifiedNumber(dbparser, 323 | number_dict=number_dict) 324 | print("number_dict = {}".format(number_dict)) 325 | 326 | for d, target in dataloader: 327 | print("t-d = {}-{}".format(target, d)) 328 | 329 | -------------------------------------------------------------------------------- /src/diy_yolov2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Nov 5 14:52:43 2018 5 | 6 | Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29) 7 | 8 | @author: pilgrim.bin@163.com 9 | """ 10 | 11 | import torch.nn as nn 12 | 13 | cfg = { 14 | 'Yolov2_960x640':[[32, 3, 1], # <----960x640 15 | 'M', 16 | [64, 3, 1], 17 | 'M', 18 | [128, 3, 1], 19 | [64, 1, 0], 20 | [128, 3, 1], 21 | 'M', 22 | [256, 3, 1], 23 | [128, 1, 0], 24 | [256, 3, 1], 25 | 'M', 26 | [512, 3, 1], 27 | [256, 1, 0], 28 | [512, 3, 1], 29 | [256, 1, 0], 30 | [512, 3, 1], 31 | 'M', # 5*M = <----30x20 32 | [1024, 3, 0], 33 | [512, 1, 0], 34 | [1024, 3, 0], 35 | 'M', 36 | [1024, 3, 0], 37 | [512, 1, 0], 38 | [1024, 3, 0]], 39 | 40 | 'Yolov2_768x512_raw':[[32, 3, 1], # <----768x512 41 | 'M', 42 | [64, 3, 1], 43 | 'M', 44 | [128, 3, 1], 45 | [64, 1, 0], 46 | [128, 3, 1], 47 | 'M', 48 | [256, 3, 1], 49 | [128, 1, 0], 50 | [256, 3, 1], 51 | 'M', 52 | [512, 3, 1], 53 | [256, 1, 0], 54 | 
[512, 3, 1], 55 | 'M', # 5*M <----24x16 56 | [1024, 3, 1], 57 | [512, 1, 0], 58 | [1024, 3, 1], 59 | 'M', # 6*M <----12x8 60 | [1024, 3, 1], 61 | [512, 1, 0], 62 | [1024, 3, 1], 63 | 'M', # 7*M <----6x4 64 | [1024, 3, 1], 65 | [512, 1, 0], 66 | [512, 3, 1], 67 | 'M'], # 8*M <----3x2 68 | 69 | 'Yolov2_768x512_v2':[[32, 3, 1], # <----768x512 70 | 'M', 71 | [64, 3, 1], 72 | 'M', 73 | [128, 3, 1], 74 | [64, 1, 0], 75 | [128, 3, 1], 76 | 'M', 77 | [128, 3, 1], 78 | [64, 1, 0], 79 | [128, 3, 1], 80 | 'M', 81 | [256, 3, 1], 82 | [128, 1, 0], 83 | [256, 3, 1], 84 | 'M', # 5*M <----24x16 85 | [256, 3, 1], 86 | [128, 1, 0], 87 | [256, 3, 1], 88 | 'M', # 6*M <----12x8 89 | [512, 3, 1], 90 | [256, 1, 0], 91 | [512, 3, 1], 92 | 'M', # 7*M <----6x4 93 | [512, 3, 1], 94 | [256, 1, 0], 95 | [512, 3, 1], 96 | 'M'], # 8*M <----3x2 97 | 98 | # 'Yolov2_384x256_v3' --> batchsize=128, @Tesla P100-PCIE... cuMemory = 15743MiB / 16276MiB 99 | 'Yolov2_384x256_v3':[[32, 3, 1], # <----384x256 100 | 'M', 101 | [64, 3, 1], 102 | 'M', 103 | [128, 3, 1], 104 | [64, 1, 0], 105 | [128, 3, 1], 106 | 'M', 107 | [128, 3, 1], 108 | [64, 1, 0], 109 | [128, 3, 1], 110 | 'M', 111 | [256, 3, 1], 112 | [128, 1, 0], 113 | [256, 3, 1], 114 | 'M', # 5*M <----12x8 115 | [256, 3, 1], 116 | [128, 1, 0], 117 | [256, 3, 1], 118 | 'M', # 6*M <----6x4 119 | [512, 3, 1], 120 | [256, 1, 0], 121 | [512, 3, 1], 122 | 'M'] # 7*M <----3x2 123 | } 124 | 125 | def make_yolov2_backbone_layers(cfg, batch_norm=True): 126 | layers = [] 127 | in_channels = 3 128 | for v in cfg: 129 | if v == 'M': 130 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 131 | else: 132 | out_channels, kernel_size, padding = v 133 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) 134 | if batch_norm: 135 | layers += [conv2d, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)] 136 | else: 137 | layers += [conv2d, nn.ReLU(inplace=True)] 138 | in_channels = out_channels 139 | return nn.Sequential(*layers) 140 | 141 | 142 | class Yolov2Backbone(nn.Module): 143 | def __init__(self, features, num_classes=1000, init_weights=True): 144 | super(Yolov2Backbone, self).__init__() 145 | self.features = features 146 | self.classifier = nn.Linear(512 * 3 * 2, num_classes) 147 | if init_weights: 148 | self._initialize_weights() 149 | 150 | def forward(self, x): 151 | x = self.features(x) 152 | x = x.view(x.size(0), -1) 153 | x = self.classifier(x) 154 | return x 155 | 156 | def _initialize_weights(self): 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 160 | if m.bias is not None: 161 | nn.init.constant_(m.bias, 0) 162 | elif isinstance(m, nn.BatchNorm2d): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | elif isinstance(m, nn.Linear): 166 | nn.init.normal_(m.weight, 0, 0.01) 167 | nn.init.constant_(m.bias, 0) 168 | 169 | def diy_yolov2(pretrained=False, state_dict=None, **kwargs): 170 | if pretrained: 171 | kwargs['init_weights'] = False 172 | model = Yolov2Backbone(make_yolov2_backbone_layers(cfg['Yolov2_384x256_v3']), **kwargs) 173 | if pretrained: 174 | model.load_state_dict(state_dict) 175 | return model 176 | 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /src/folder_raw_torchvision_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | 5 | import 
os 6 | import os.path 7 | 8 | 9 | def has_file_allowed_extension(filename, extensions): 10 | """Checks if a file is an allowed extension. 11 | 12 | Args: 13 | filename (string): path to a file 14 | 15 | Returns: 16 | bool: True if the filename ends with a known image extension 17 | """ 18 | filename_lower = filename.lower() 19 | return any(filename_lower.endswith(ext) for ext in extensions) 20 | 21 | 22 | def find_classes(dir): 23 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 24 | classes.sort() 25 | class_to_idx = {classes[i]: i for i in range(len(classes))} 26 | return classes, class_to_idx 27 | 28 | 29 | def make_dataset(dir, class_to_idx, extensions): 30 | images = [] 31 | dir = os.path.expanduser(dir) 32 | for target in sorted(os.listdir(dir)): 33 | d = os.path.join(dir, target) 34 | if not os.path.isdir(d): 35 | continue 36 | 37 | for root, _, fnames in sorted(os.walk(d)): 38 | for fname in sorted(fnames): 39 | if has_file_allowed_extension(fname, extensions): 40 | path = os.path.join(root, fname) 41 | item = (path, class_to_idx[target]) 42 | images.append(item) 43 | 44 | return images 45 | 46 | 47 | class DatasetFolder(data.Dataset): 48 | """A generic data loader where the samples are arranged in this way: :: 49 | 50 | root/class_x/xxx.ext 51 | root/class_x/xxy.ext 52 | root/class_x/xxz.ext 53 | 54 | root/class_y/123.ext 55 | root/class_y/nsdf3.ext 56 | root/class_y/asd932_.ext 57 | 58 | Args: 59 | root (string): Root directory path. 60 | loader (callable): A function to load a sample given its path. 61 | extensions (list[string]): A list of allowed extensions. 62 | transform (callable, optional): A function/transform that takes in 63 | a sample and returns a transformed version. 64 | E.g, ``transforms.RandomCrop`` for images. 65 | target_transform (callable, optional): A function/transform that takes 66 | in the target and transforms it. 67 | 68 | Attributes: 69 | classes (list): List of the class names. 70 | class_to_idx (dict): Dict with items (class_name, class_index). 71 | samples (list): List of (sample path, class_index) tuples 72 | """ 73 | 74 | def __init__(self, root, loader, extensions, transform=None, target_transform=None): 75 | classes, class_to_idx = find_classes(root) 76 | samples = make_dataset(root, class_to_idx, extensions) 77 | if len(samples) == 0: 78 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" 79 | "Supported extensions are: " + ",".join(extensions))) 80 | 81 | self.root = root 82 | self.loader = loader 83 | self.extensions = extensions 84 | 85 | self.classes = classes 86 | self.class_to_idx = class_to_idx 87 | self.samples = samples 88 | 89 | self.transform = transform 90 | self.target_transform = target_transform 91 | 92 | def __getitem__(self, index): 93 | """ 94 | Args: 95 | index (int): Index 96 | 97 | Returns: 98 | tuple: (sample, target) where target is class_index of the target class. 
99 | """ 100 | path, target = self.samples[index] 101 | sample = self.loader(path) 102 | if self.transform is not None: 103 | sample = self.transform(sample) 104 | if self.target_transform is not None: 105 | target = self.target_transform(target) 106 | 107 | return sample, target 108 | 109 | def __len__(self): 110 | return len(self.samples) 111 | 112 | def __repr__(self): 113 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 114 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 115 | fmt_str += ' Root Location: {}\n'.format(self.root) 116 | tmp = ' Transforms (if any): ' 117 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 118 | tmp = ' Target Transforms (if any): ' 119 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 120 | return fmt_str 121 | 122 | 123 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 124 | 125 | 126 | def pil_loader(path): 127 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 128 | with open(path, 'rb') as f: 129 | img = Image.open(f) 130 | return img.convert('RGB') 131 | 132 | 133 | def accimage_loader(path): 134 | import accimage 135 | try: 136 | return accimage.Image(path) 137 | except IOError: 138 | # Potentially a decoding problem, fall back to PIL.Image 139 | return pil_loader(path) 140 | 141 | 142 | def default_loader(path): 143 | from torchvision import get_image_backend 144 | if get_image_backend() == 'accimage': 145 | return accimage_loader(path) 146 | else: 147 | return pil_loader(path) 148 | 149 | 150 | class ImageFolder(DatasetFolder): 151 | """A generic data loader where the images are arranged in this way: :: 152 | 153 | root/dog/xxx.png 154 | root/dog/xxy.png 155 | root/dog/xxz.png 156 | 157 | root/cat/123.png 158 | root/cat/nsdf3.png 159 | root/cat/asd932_.png 160 | 161 | Args: 162 | root (string): Root directory path. 163 | transform (callable, optional): A function/transform that takes in an PIL image 164 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 165 | target_transform (callable, optional): A function/transform that takes in the 166 | target and transforms it. 167 | loader (callable, optional): A function to load an image given its path. 168 | 169 | Attributes: 170 | classes (list): List of the class names. 171 | class_to_idx (dict): Dict with items (class_name, class_index). 
172 | imgs (list): List of (image path, class_index) tuples 173 | """ 174 | def __init__(self, root, transform=None, target_transform=None, 175 | loader=default_loader): 176 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 177 | transform=transform, 178 | target_transform=target_transform) 179 | self.imgs = self.samples 180 | -------------------------------------------------------------------------------- /src/main_mtmc_resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jun 12 17:07:39 2018 5 | 6 | Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29) 7 | 8 | reverse based on pytorch::examples/imagenet 9 | $ conda list | grep torch 10 | pytorch 0.4.1 py36_cuda0.0_cudnn0.0_1 pytorch 11 | torchvision 0.2.1 py36_1 pytorch 12 | 13 | @author: pilgrim.bin@163.com 14 | """ 15 | import argparse 16 | import os 17 | import shutil 18 | import time 19 | import copy 20 | import numpy as np 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.parallel 25 | import torch.backends.cudnn as cudnn 26 | import torch.distributed as dist 27 | import torch.optim 28 | import torch.utils.data 29 | import torch.utils.data.distributed 30 | #import torchvision.transforms as transforms 31 | #import torchvision.datasets as datasets 32 | import torchvision.models as models 33 | import mtmcmodel as mtmcmodel 34 | import mtmcconfig as mtmcconfig 35 | import diy_yolov2 36 | 37 | #import folder_diy 38 | from mtmc_coordinator import SLMCCoordinator 39 | 40 | tv_model_names = sorted(name for name in models.__dict__ 41 | if name.islower() and not name.startswith("__") 42 | and callable(models.__dict__[name])) 43 | 44 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 45 | 46 | # data 47 | parser.add_argument('--data', metavar='DIR', 48 | default='/Users/baiqi/data/pants', 49 | help='path to dataset') 50 | parser.add_argument('--max_base_number', default=5000, type=int, 51 | #help='base number of each class sample.') 52 | help='max_base_number is the base_number of the label with the fewest classes.') 53 | 54 | # net 55 | ''' 56 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 57 | choices=tv_model_names, 58 | help='model architecture: ' + 59 | ' | '.join(tv_model_names) + 60 | ' (default: resnet18)') 61 | ''' 62 | parser.add_argument('--arch', metavar='ARCH', default='yolov2', 63 | # default='resnet18', 64 | help='model architecture: tv.models or diy_model.') 65 | 66 | ''' 67 | parser.add_argument('--class_number', default=1000, type=int, metavar='N', 68 | help='number of class (default: 1000)') 69 | ''' 70 | # add for diy models 71 | parser.add_argument('--dataloader_resize_h', default=256, type=int) 72 | parser.add_argument('--dataloader_resize_w', default=256, type=int) 73 | parser.add_argument('--inputlayer_h', default=256, type=int) 74 | parser.add_argument('--inputlayer_w', default=256, type=int) 75 | parser.add_argument('--fc_features', default=512, type=int, help='net:input layer size and model framework defines the input-fc-features') 76 | 77 | 78 | # training params 79 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 80 | help='number of data loading workers (default: 4)') 81 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 82 | help='number of total epochs to run') 83 | parser.add_argument('--fc_epochs', default=50, type=int, metavar='N', 84 | help='number 
of epochs to update optimizer.') 85 | 86 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 87 | help='manual epoch number (useful on restarts)') 88 | parser.add_argument('--batch_size', default=256, type=int, 89 | metavar='N', help='mini-batch size (default: 256)') 90 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 91 | metavar='LR', help='initial learning rate') 92 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 93 | help='momentum') 94 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 95 | metavar='W', help='weight decay (default: 1e-4)') 96 | parser.add_argument('--print-freq', '-p', default=10, type=int, 97 | metavar='N', help='print frequency (default: 10)') 98 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 99 | help='path to latest checkpoint (default: none)') 100 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 101 | help='evaluate model on validation set') 102 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 103 | help='use pre-trained model') 104 | parser.add_argument('--world-size', default=1, type=int, 105 | help='number of distributed processes') 106 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 107 | help='url used to set up distributed training') 108 | parser.add_argument('--dist-backend', default='gloo', type=str, 109 | help='distributed backend') 110 | 111 | best_prec1 = 0 112 | 113 | # vals log 114 | vals_log_path = "vals_log" 115 | if not os.path.exists(vals_log_path): 116 | os.mkdir(vals_log_path) 117 | 118 | def cout_info_BaseNumberSpecifier(mtmcdataloader): 119 | for label in mtmcdataloader.labels: 120 | print('-------cout_info_BaseNumberSpecifier::{}--------'.format(label)) 121 | print("TRAIN-Label = {} top1_dict = {}".format(label, 122 | mtmcdataloader.slmcdataloader_dict[label].base_number_specifier.class_to_number_dict)) 123 | print("TRAIN-Label = {} number_dict = {}".format(label, 124 | mtmcdataloader.slmcdataloader_dict[label].base_number_specifier.class_to_prec_dict)) 125 | 126 | 127 | def main(): 128 | global args, best_prec1 129 | args = parser.parse_args() 130 | 131 | print('args = {}'.format(args)) 132 | 133 | args.distributed = args.world_size > 1 134 | print("args.distributed = {}".format(args.distributed)) 135 | if args.distributed: 136 | raise(RuntimeError("Distributed Mode is Unsupported Currently.")) 137 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 138 | world_size=args.world_size) 139 | train_sampler = None 140 | 141 | '''------------------------------------- 142 | # create dataloader 143 | ----------------------------------------''' 144 | # resize params 145 | dataresize = [args.dataloader_resize_h, args.dataloader_resize_w, args.inputlayer_h, args.inputlayer_w] 146 | # MTMCDataloader 147 | mtmcdataloader = mtmcconfig.MTMCDataloader( 148 | args.data, # data root path 149 | dataresize=dataresize, 150 | batch_size=args.batch_size, 151 | workers=args.workers, 152 | max_base_number=args.max_base_number) 153 | print('INFO: = mtmcdataloader.mtmc_tree = {}'.format(mtmcdataloader.mtmc_tree)) 154 | print('INFO: = mtmcdataloader.label_to_idx = {}'.format(mtmcdataloader.label_to_idx)) 155 | 156 | label_list = mtmcdataloader.label_to_idx.keys() 157 | #label_list.sort() # py2 158 | label_list = sorted(label_list) 159 | class_numbers = [] 160 | for label in label_list: 161 | 
class_numbers.append(len(mtmcdataloader.mtmc_tree[label])) 162 | 163 | '''------------------------------------- 164 | # create model 165 | ----------------------------------------''' 166 | if args.arch in tv_model_names: 167 | # using torchvision modles, resnet or inception 168 | if args.pretrained: 169 | print("=> using pre-trained model '{}'".format(args.arch)) 170 | model = models.__dict__[args.arch](pretrained=True) 171 | else: 172 | print("=> creating model '{}'".format(args.arch)) 173 | model = models.__dict__[args.arch]() 174 | # raw mc 175 | # fc_features = model.fc.in_features # it is 512 if using resnet18_224x224 176 | # model.fc = nn.Linear(fc_features, class_number) 177 | # new mtmc 178 | fc_features = args.fc_features 179 | model.fc = mtmcmodel.BuildMultiLabelModel(fc_features, class_numbers) 180 | elif 'yolo' in args.arch.lower(): # yolo_vx_h_w = yolov123_768x512 181 | g_inputlayer_heigth, g_inputlayer_width = args.arch.lower().split('_')[-1].split('x') 182 | g_inputlayer_heigth = int(g_inputlayer_heigth) 183 | g_inputlayer_width = int(g_inputlayer_width) 184 | model = diy_yolov2.diy_yolov2(pretrained=False, num_classes=1000, init_weights=True) 185 | fc_features = args.fc_features 186 | model.classifier = mtmcmodel.BuildMultiLabelModel(fc_features, class_numbers) 187 | else: 188 | raise(RuntimeError("Unknown model arch = {}.".format(args.arch))) 189 | 190 | if not args.distributed: 191 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 192 | model.features = torch.nn.DataParallel(model.features) 193 | model.cuda() 194 | else: 195 | model = torch.nn.DataParallel(model).cuda() 196 | else: 197 | model.cuda() 198 | model = torch.nn.parallel.DistributedDataParallel(model) 199 | 200 | 201 | '''------------------------------------- 202 | # set criterion & optimizer 203 | ----------------------------------------''' 204 | # define loss function (criterion) and optimizer 205 | criterion = nn.CrossEntropyLoss().cuda() 206 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 207 | momentum=args.momentum, 208 | weight_decay=args.weight_decay) 209 | 210 | # optionally resume from a checkpoint 211 | if args.resume: 212 | if os.path.isfile(args.resume): 213 | print("=> loading checkpoint '{}'".format(args.resume)) 214 | checkpoint = torch.load(args.resume) 215 | args.start_epoch = checkpoint['epoch'] 216 | best_prec1 = checkpoint['best_prec1'] 217 | model.load_state_dict(checkpoint['state_dict']) 218 | if args.start_epoch >= args.fc_epochs: 219 | optimizer = torch.optim.SGD(model.module.fc.parameters(), args.lr, 220 | momentum=args.momentum, 221 | weight_decay=args.weight_decay) 222 | optimizer.load_state_dict(checkpoint['optimizer']) 223 | print("=> loaded checkpoint '{}' (epoch {})" 224 | .format(args.resume, checkpoint['epoch'])) 225 | else: 226 | print("=> no checkpoint found at '{}'".format(args.resume)) 227 | 228 | cudnn.benchmark = True 229 | 230 | # Data loading code 231 | ''' 232 | traindir = os.path.join(args.data, 'train') 233 | valdir = os.path.join(args.data, 'val') 234 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 235 | train_dataset = datasets.ImageFolder( 236 | traindir, 237 | transforms.Compose([ 238 | transforms.RandomResizedCrop(224), 239 | transforms.RandomHorizontalFlip(), 240 | transforms.ToTensor(), 241 | normalize])) 242 | if args.distributed: 243 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 244 | else: 245 | train_sampler = None 246 | ''' 247 | 248 | if 
args.evaluate: 249 | prec1_avg, label_prec1_top1_dict_dict, label_top1_dict = mtmc_validate(mtmcdataloader, model, criterion, phase="VAL") 250 | print('label_prec1_top1_dict_dict = {}'.format(label_prec1_top1_dict_dict)) 251 | return 252 | 253 | for epoch in range(args.start_epoch, args.epochs): 254 | if args.distributed: 255 | train_sampler.set_epoch(epoch) 256 | adjust_learning_rate(optimizer, epoch) 257 | 258 | '''------------------------------------- 259 | # update criterion & optimizer 260 | # 可以分三阶段:FC, Base + FC, FC. 261 | ----------------------------------------''' 262 | # define loss function (criterion) and optimizer 263 | if epoch == args.fc_epochs: 264 | # criterion = nn.CrossEntropyLoss().cuda() 265 | optimizer = torch.optim.SGD(model.module.fc.parameters(), args.lr, 266 | momentum=args.momentum, 267 | weight_decay=args.weight_decay) 268 | 269 | # mtmc_coordinator 270 | max_lens_list = [] 271 | for label in label_list: 272 | dataloader = mtmcdataloader.slmcdataloader_dict[label] 273 | max_lens_list.append(dataloader.train_diy_loader.__len__()) 274 | print("SLMCCoordinator::max_lens_list = {}".format(max_lens_list)) 275 | coordinator = SLMCCoordinator(max_lens_list) 276 | 277 | # train for one epoch 278 | mtmctrain(mtmcdataloader, model, criterion, optimizer, epoch, coordinator) 279 | 280 | # evaluate on trainset 281 | ''' # no need to waste time testing trainset 282 | prec1_avg, label_prec1_top1_dict_dict, label_top1_dict = mtmc_validate(mtmcdataloader, model, criterion, phase="TRAIN") 283 | print('Train - label_prec1_top1_dict_dict = {}'.format(label_prec1_top1_dict_dict)) 284 | ''' 285 | 286 | # update sample number of each class 287 | if epoch > 0 and epoch % 20 == 0: 288 | prec1_avg, label_prec1_top1_dict_dict, label_top1_dict = mtmc_validate(mtmcdataloader, model, criterion, phase="TRAIN") 289 | print('Train - label_prec1_top1_dict_dict = {}'.format(label_prec1_top1_dict_dict)) 290 | mtmcdataloader.update_train_diy_loader(label_top1_dict) 291 | 292 | # cout_info_BaseNumberSpecifier(mtmcdataloader) 293 | cout_info_BaseNumberSpecifier(mtmcdataloader) 294 | 295 | # evaluate on validation set 296 | prec1_avg, label_prec1_top1_dict_dict, label_top1_dict = mtmc_validate(mtmcdataloader, model, criterion, phase="VAL") 297 | print('VAL - label_prec1_top1_dict_dict = {}'.format(label_prec1_top1_dict_dict)) 298 | 299 | # remember best prec@1 and save checkpoint 300 | is_best = prec1_avg > best_prec1 301 | best_prec1 = max(prec1_avg, best_prec1) 302 | save_checkpoint({ 303 | 'epoch': epoch + 1, 304 | 'arch': args.arch, 305 | 'state_dict': model.state_dict(), 306 | 'best_prec1': best_prec1, 307 | 'optimizer' : optimizer.state_dict(), 308 | }, is_best, 309 | filename='checkpoint_{}.pth.tar'.format(args.arch)) 310 | if epoch > 0 and epoch % 10 == 0: 311 | save_checkpoint({ 312 | 'epoch': epoch + 1, 313 | 'arch': args.arch, 314 | 'state_dict': model.state_dict(), 315 | 'best_prec1': best_prec1, 316 | 'optimizer' : optimizer.state_dict(), 317 | }, False, 318 | filename='checkpoint_{}_{}.pth.tar'.format(args.arch, epoch + 1)) 319 | 320 | 321 | def get_dict_key(dict, value): 322 | for k in dict.keys(): 323 | if dict[k] == value: 324 | return k 325 | return None 326 | 327 | 328 | # it goes wrong if using model = torch.nn.DataParallel(model).cuda() 329 | def backbone_zero_grad(model): 330 | model.avgpool.zero_grad() 331 | model.bn1.zero_grad() 332 | model.conv1.zero_grad() 333 | model.layer1.zero_grad() 334 | model.layer2.zero_grad() 335 | model.layer3.zero_grad() 336 | 
model.layer4.zero_grad() 337 | model.relu.zero_grad() 338 | model.maxpool.zero_grad() 339 | 340 | def mtmctrain(mtmcdataloader, model, criterion, optimizer, epoch, coordinator): 341 | # vals log 342 | val_measure_dict = copy.deepcopy(mtmcdataloader.label_to_idx) 343 | batch_time = AverageMeter() 344 | data_time = AverageMeter() 345 | for key in val_measure_dict.keys(): 346 | losses = AverageMeter() # 0 347 | top1 = AverageMeter() # 1 348 | topk = AverageMeter() # 2 349 | val_measure_dict[key] = [losses, top1, topk] 350 | 351 | 352 | # switch to train mode 353 | model.train() 354 | 355 | # enumerate_list of traindata loader 356 | enumerate_list = [] 357 | for label in mtmcdataloader.labels: 358 | enumerate_list.append(enumerate(mtmcdataloader.slmcdataloader_dict[label].train_diy_loader)) 359 | 360 | end = time.time() 361 | for i, flag in enumerate(coordinator.iter_flag_list): 362 | #i, (input, target) = enumerate_list[flag].next() # py2 363 | i, (input, target) = next(enumerate_list[flag]) 364 | label = get_dict_key(mtmcdataloader.label_to_idx, flag) 365 | 366 | # measure data loading time 367 | data_time.update(time.time() - end) 368 | 369 | target = target.cuda(non_blocking=True) 370 | 371 | # compute output 372 | output = model(input)[flag] 373 | loss = criterion(output, target) 374 | 375 | # measure accuracy and record loss 376 | [prec1, preck], class_to = accuracy(output, target, topk=(1, 2)) 377 | val_measure_dict[label][0].update(loss.item(), input.size(0)) 378 | val_measure_dict[label][1].update(prec1[0], input.size(0)) 379 | val_measure_dict[label][2].update(preck[0], input.size(0)) 380 | 381 | # compute gradient and do SGD step 382 | optimizer.zero_grad() 383 | loss.backward() 384 | optimizer.step() 385 | 386 | # measure elapsed time 387 | batch_time.update(time.time() - end) 388 | end = time.time() 389 | 390 | if i % args.print_freq == 0: 391 | print('Epoch: {} - [{}][{}/{}]\t' 392 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 393 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 394 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 395 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 396 | 'Prec@2 {topk.val:.3f} ({topk.avg:.3f})'.format( 397 | label, epoch, i, len(mtmcdataloader.slmcdataloader_dict[label].train_diy_loader), 398 | batch_time=batch_time, data_time=data_time, 399 | loss=val_measure_dict[label][0], 400 | top1=val_measure_dict[label][1], 401 | topk=val_measure_dict[label][2])) 402 | 403 | 404 | def cout_confmatrix_tofile(ConfusionMatrix, classes, filename): 405 | cm = ConfusionMatrix 406 | print('------ConfusionMatrix[target[idx], class_to[idx]]:') 407 | print(cm) 408 | print('------Recall = ConfusionMatrix ratio per class:') 409 | cmf = cm.astype(np.float) 410 | ss = cmf.sum(1,keepdims=True) 411 | for idx in range(len(ss)): 412 | cmf[idx,] /= ss[idx] 413 | print(cmf) 414 | 415 | # Precision 416 | print('------Precision matrix @ ConfusionMatrix :') 417 | cmf_prec = cm.astype(np.float) 418 | ss = cmf_prec.sum(0,keepdims=True) 419 | precs_top1 = [] 420 | for idx in range(len(ss[0])): 421 | cmf_prec[:,idx] /= ss[0][idx] 422 | precs_top1.append(cmf_prec[idx,idx]) 423 | print('precs_top1 = {}'.format(precs_top1)) 424 | 425 | with open(filename, 'w') as fp: # using 'wb' in py2 426 | # cm 427 | fp.write('ConfusionMatrix:\n') 428 | fp.write('gt\\pred:\t') 429 | [fp.write('{}\t'.format(c)) for c in classes] 430 | fp.write('\n') 431 | for h in range(len(classes)): 432 | fp.write('{}\t'.format(classes[h])) 433 | for w in range(len(classes)): 434 | 
fp.write('{}\t'.format(cm[h,w])) 435 | fp.write('\n') 436 | fp.write('--------------------------------------') 437 | fp.write('\n') 438 | 439 | # cmf 440 | fp.write('ConfusionMatrix classify ratio:\n') 441 | fp.write('gt\\pred:\t') 442 | [fp.write('{}\t'.format(c)) for c in classes] 443 | fp.write('\n') 444 | for h in range(len(classes)): 445 | fp.write('{}\t'.format(classes[h])) 446 | for w in range(len(classes)): 447 | fp.write('{0:.3f}\t'.format(cmf[h,w])) 448 | fp.write('\n') 449 | fp.write('--------------------------------------') 450 | fp.write('\n') 451 | 452 | # Recall 453 | fp.write('Recall:\n') 454 | for h in range(len(classes)): 455 | fp.write('{}\t{}\n'.format(classes[h], cmf[h,h])) 456 | fp.write('--------------------------------------') 457 | fp.write('\n') 458 | 459 | # Precision@1 460 | fp.write('Precision@1:\n') 461 | for h in range(len(classes)): 462 | fp.write('{}\t{}\n'.format(classes[h], cmf_prec[h,h])) 463 | fp.write('--------------------------------------') 464 | fp.write('\n') 465 | return precs_top1 466 | 467 | 468 | 469 | def slmc_validate(val_loader, model, model_idx, criterion, phase="VAL"): 470 | 471 | batch_time = AverageMeter() 472 | losses = AverageMeter() 473 | top1 = AverageMeter() 474 | topk = AverageMeter() 475 | 476 | # switch to evaluate mode 477 | model.eval() 478 | 479 | classes = val_loader.dataset.classes 480 | class_number = len(classes) 481 | ConfusionMatrix = np.zeros((class_number, class_number), dtype=int) 482 | 483 | with torch.no_grad(): 484 | end = time.time() 485 | for i, (input, target) in enumerate(val_loader): 486 | target = target.cuda(non_blocking=True) 487 | 488 | # compute output 489 | output = model(input)[model_idx] 490 | loss = criterion(output, target) 491 | 492 | # measure accuracy and record loss 493 | [prec1, preck], class_to = accuracy(output, target, topk=(1, 2)) 494 | losses.update(loss.item(), input.size(0)) 495 | top1.update(prec1[0], input.size(0)) 496 | topk.update(preck[0], input.size(0)) 497 | 498 | # measure elapsed time 499 | batch_time.update(time.time() - end) 500 | end = time.time() 501 | 502 | if i % args.print_freq == 0: 503 | print('Test-{0}: [{1}/{2}]\t' 504 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 505 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 506 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 507 | 'Prec@2 {topk.val:.3f} ({topk.avg:.3f})'.format( 508 | phase, i, len(val_loader), 509 | batch_time=batch_time, 510 | loss=losses, 511 | top1=top1, topk=topk)) 512 | 513 | # save result to ConfusionMatrix 514 | label = target.cpu().numpy() 515 | for idx in range(len(class_to)): 516 | ConfusionMatrix[label[idx], class_to[idx]] += 1 517 | 518 | print(' * {} Prec@1 {top1.avg:.3f} Prec@2 {topk.avg:.3f}' 519 | .format(phase, top1=top1, topk=topk)) 520 | 521 | filename = 'val_{}_ConfusionMatrix_'.format(phase) \ 522 | + time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time())) \ 523 | + '.txt' 524 | filename = os.path.join(vals_log_path, filename) 525 | precs_top1 = cout_confmatrix_tofile(ConfusionMatrix, classes, filename) 526 | 527 | top1_dict = {classes[i] : precs_top1[i] for i in range(len(classes))} 528 | return (top1.avg, top1_dict) 529 | 530 | 531 | def mtmc_validate(mtmcdataloader, model, criterion, phase="VAL"): 532 | label_prec1_top1_dict_dict = copy.deepcopy(mtmcdataloader.label_to_idx) 533 | label_top1_dict = copy.deepcopy(mtmcdataloader.label_to_idx) 534 | prec1_avg = 0 535 | for label in mtmcdataloader.label_to_idx.keys(): 536 | if phase == "VAL": 537 | val_loader = 
mtmcdataloader.slmcdataloader_dict[label].val_val_loader
538 | else: # 'TRAIN'
539 | val_loader = mtmcdataloader.slmcdataloader_dict[label].val_train_loader
540 | model_idx = mtmcdataloader.label_to_idx[label]
541 | prec1, top1_dict = slmc_validate(val_loader, model, model_idx, criterion, phase='_'.join([phase, label]))
542 | prec1_avg += prec1
543 | label_prec1_top1_dict_dict[label] = tuple([prec1, top1_dict])
544 | label_top1_dict[label] = top1_dict
545 | return (prec1_avg / float(len(label_prec1_top1_dict_dict.keys())), label_prec1_top1_dict_dict, label_top1_dict)
546 | 
547 | 
548 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
549 | torch.save(state, filename)
550 | if is_best:
551 | shutil.copyfile(filename, 'model_best_' + filename)
552 | 
553 | 
554 | class AverageMeter(object):
555 | """Computes and stores the average and current value"""
556 | def __init__(self):
557 | self.reset()
558 | 
559 | def reset(self):
560 | self.val = 0
561 | self.avg = 0
562 | self.sum = 0
563 | self.count = 0
564 | 
565 | def update(self, val, n=1):
566 | self.val = val
567 | self.sum += val * n
568 | self.count += n
569 | self.avg = self.sum / self.count
570 | 
571 | 
572 | def adjust_learning_rate(optimizer, epoch):
573 | """Sets the learning rate to the initial LR decayed by 10 every 40 epochs"""
574 | lr = args.lr * (0.1 ** (epoch // 40))
575 | for param_group in optimizer.param_groups:
576 | param_group['lr'] = lr
577 | 
578 | 
579 | def accuracy(output, target, topk=(1,)):
580 | """Computes the precision@k for the specified values of k"""
581 | with torch.no_grad():
582 | maxk = max(topk)
583 | batch_size = target.size(0)
584 | 
585 | _, pred = output.topk(maxk, 1, True, True)
586 | pred = pred.t()
587 | correct = pred.eq(target.view(1, -1).expand_as(pred))
588 | 
589 | class_to = pred[0].cpu().numpy()
590 | 
591 | res = []
592 | for k in topk:
593 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
594 | res.append(correct_k.mul_(100.0 / batch_size))
595 | return res, class_to
596 | 
597 | 
598 | if __name__ == '__main__':
599 | main()
600 | 
601 | 
-------------------------------------------------------------------------------- /src/main_stmc_resnet.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Jun 12 17:07:39 2018
5 | 
6 | Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29)
7 | 
8 | adapted from pytorch::examples/imagenet
9 | $ conda list | grep torch
10 | pytorch 0.4.1 py36_cuda0.0_cudnn0.0_1 pytorch
11 | torchvision 0.2.1 py36_1 pytorch
12 | 
13 | @author: pilgrim.bin@163.com
14 | """
15 | import argparse
16 | import os
17 | import shutil
18 | import time
19 | 
20 | import numpy as np
21 | 
22 | import torch
23 | import torch.nn as nn
24 | import torch.nn.parallel
25 | import torch.backends.cudnn as cudnn
26 | import torch.distributed as dist
27 | import torch.optim
28 | import torch.utils.data
29 | import torch.utils.data.distributed
30 | import torchvision.transforms as transforms
31 | import torchvision.datasets as datasets
32 | import torchvision.models as models
33 | 
34 | import diy_folder as folder_diy # the repo ships diy_folder.py (not folder_diy.py); the alias keeps the references below working
35 | 
36 | model_names = sorted(name for name in models.__dict__
37 | if name.islower() and not name.startswith("__")
38 | and callable(models.__dict__[name]))
39 | 
40 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
41 | 
42 | # data
43 | parser.add_argument('data', metavar='DIR',
44 | help='path to dataset')
45 | 
parser.add_argument('--base_number', default=10000, type=int, 46 | help='base number of each class sample ') 47 | 48 | # net 49 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 50 | choices=model_names, 51 | help='model architecture: ' + 52 | ' | '.join(model_names) + 53 | ' (default: resnet18)') 54 | parser.add_argument('--class_number', default=1000, type=int, metavar='N', 55 | help='number of class (default: 1000)') 56 | 57 | # training params 58 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 59 | help='number of data loading workers (default: 4)') 60 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 61 | help='number of total epochs to run') 62 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 63 | help='manual epoch number (useful on restarts)') 64 | parser.add_argument('-b', '--batch-size', default=256, type=int, 65 | metavar='N', help='mini-batch size (default: 256)') 66 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 67 | metavar='LR', help='initial learning rate') 68 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 69 | help='momentum') 70 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 71 | metavar='W', help='weight decay (default: 1e-4)') 72 | parser.add_argument('--print-freq', '-p', default=10, type=int, 73 | metavar='N', help='print frequency (default: 10)') 74 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 75 | help='path to latest checkpoint (default: none)') 76 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 77 | help='evaluate model on validation set') 78 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 79 | help='use pre-trained model') 80 | parser.add_argument('--world-size', default=1, type=int, 81 | help='number of distributed processes') 82 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 83 | help='url used to set up distributed training') 84 | parser.add_argument('--dist-backend', default='gloo', type=str, 85 | help='distributed backend') 86 | 87 | best_prec1 = 0 88 | 89 | # vals log 90 | vals_log_path = "vals_log" 91 | if not os.path.exists(vals_log_path): 92 | os.mkdir(vals_log_path) 93 | 94 | def main(): 95 | global args, best_prec1 96 | args = parser.parse_args() 97 | 98 | args.distributed = args.world_size > 1 99 | 100 | print("args.distributed = {}".format(args.distributed)) 101 | 102 | if args.distributed: 103 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 104 | world_size=args.world_size) 105 | 106 | # create model 107 | if args.pretrained: 108 | print("=> using pre-trained model '{}'".format(args.arch)) 109 | model = models.__dict__[args.arch](pretrained=True) 110 | else: 111 | print("=> creating model '{}'".format(args.arch)) 112 | model = models.__dict__[args.arch]() 113 | 114 | class_number = args.class_number 115 | fc_features = model.fc.in_features 116 | model.fc = nn.Linear(fc_features, class_number) 117 | 118 | if not args.distributed: 119 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 120 | model.features = torch.nn.DataParallel(model.features) 121 | model.cuda() 122 | else: 123 | model = torch.nn.DataParallel(model).cuda() 124 | else: 125 | model.cuda() 126 | model = torch.nn.parallel.DistributedDataParallel(model) 127 | 128 | # define loss function (criterion) and optimizer 129 | criterion = 
nn.CrossEntropyLoss().cuda() 130 | 131 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 132 | momentum=args.momentum, 133 | weight_decay=args.weight_decay) 134 | 135 | # optionally resume from a checkpoint 136 | if args.resume: 137 | if os.path.isfile(args.resume): 138 | print("=> loading checkpoint '{}'".format(args.resume)) 139 | checkpoint = torch.load(args.resume) 140 | args.start_epoch = checkpoint['epoch'] 141 | best_prec1 = checkpoint['best_prec1'] 142 | model.load_state_dict(checkpoint['state_dict']) 143 | optimizer.load_state_dict(checkpoint['optimizer']) 144 | print("=> loaded checkpoint '{}' (epoch {})" 145 | .format(args.resume, checkpoint['epoch'])) 146 | else: 147 | print("=> no checkpoint found at '{}'".format(args.resume)) 148 | 149 | cudnn.benchmark = True 150 | 151 | # Data loading code 152 | traindir = os.path.join(args.data, 'train') 153 | valdir = os.path.join(args.data, 'val') 154 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 155 | std=[0.229, 0.224, 0.225]) 156 | 157 | train_dataset = datasets.ImageFolder( 158 | traindir, 159 | transforms.Compose([ 160 | transforms.RandomResizedCrop(224), 161 | transforms.RandomHorizontalFlip(), 162 | transforms.ToTensor(), 163 | normalize, 164 | ])) 165 | 166 | if args.distributed: 167 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 168 | else: 169 | train_sampler = None 170 | ''' 171 | train_loader = torch.utils.data.DataLoader( 172 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 173 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 174 | ''' 175 | val_loader = torch.utils.data.DataLoader( 176 | datasets.ImageFolder(valdir, transforms.Compose([ 177 | transforms.Resize(224), # raw = 256, CenterCrop is not a good trick. 178 | transforms.CenterCrop(224), 179 | transforms.ToTensor(), 180 | normalize, 181 | ])), 182 | batch_size=args.batch_size, shuffle=False, 183 | num_workers=args.workers, pin_memory=True) 184 | 185 | val_trainset_loader = torch.utils.data.DataLoader( 186 | datasets.ImageFolder(traindir, transforms.Compose([ 187 | transforms.Resize(224), # raw = 256, CenterCrop is not a good trick. 188 | transforms.CenterCrop(224), 189 | transforms.ToTensor(), 190 | normalize, 191 | ])), 192 | batch_size=args.batch_size, shuffle=False, 193 | num_workers=args.workers, pin_memory=True) 194 | 195 | if args.evaluate: 196 | prec1, top1_dict = validate(val_loader, model, criterion, phase="VAL") 197 | return 198 | 199 | # diy dataloader 200 | base_number = args.base_number 201 | dbparser = folder_diy.DatasetFolderParsing(traindir) 202 | number_dict = {key : 0 for key in dbparser.class_to_idx.keys()} 203 | base_number_specifier = folder_diy.BaseNumberSpecifier(number_dict, base_number) 204 | train_diy_transforms = transforms.Compose([ 205 | transforms.Resize(224), # raw = 256, CenterCrop is not a good trick. 
206 | # transforms.RandomRotation(10), 207 | transforms.CenterCrop(224), 208 | # transforms.RandomHorizontalFlip(),# bad for charactor 209 | # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), 210 | transforms.ToTensor(), 211 | normalize]) 212 | 213 | for epoch in range(args.start_epoch, args.epochs): 214 | if args.distributed: 215 | train_sampler.set_epoch(epoch) 216 | adjust_learning_rate(optimizer, epoch) 217 | 218 | # diy dataloader 219 | print("number_dict = {}".format(number_dict)) 220 | train_diy_dataset = folder_diy.ImageFolder_SpecifiedNumber( 221 | dbparser, number_dict=number_dict, 222 | transform=train_diy_transforms) 223 | 224 | train_diy_loader = torch.utils.data.DataLoader( 225 | train_diy_dataset, 226 | batch_size=args.batch_size, shuffle=(train_sampler is None), 227 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 228 | 229 | # train for one epoch 230 | #train(train_loader, model, criterion, optimizer, epoch) 231 | train(train_diy_loader, model, criterion, optimizer, epoch) 232 | 233 | # evaluate on trainset 234 | prec1, top1_dict = validate(val_trainset_loader, model, criterion, phase="TRAIN") 235 | # update sample number of each class 236 | if epoch % 20 == 0: 237 | base_number_specifier.update(top1_dict) 238 | number_dict=base_number_specifier.class_to_number_dict 239 | print("TRAIN top1_dict = {}".format(top1_dict)) 240 | print("TRAIN number_dict = {}".format(number_dict)) 241 | print("TRAIN top1_dict.values() = {}".format(top1_dict.values())) 242 | print("TRAIN number_dict.values() = {}".format(number_dict.values())) 243 | 244 | # evaluate on validation set 245 | prec1, top1_dict = validate(val_loader, model, criterion, phase="VAL") 246 | print("VAL top1_dict = {}".format(top1_dict)) 247 | print("VAL top1_dict.values() = {}".format(top1_dict.values())) 248 | 249 | # remember best prec@1 and save checkpoint 250 | is_best = prec1 > best_prec1 251 | best_prec1 = max(prec1, best_prec1) 252 | save_checkpoint({ 253 | 'epoch': epoch + 1, 254 | 'arch': args.arch, 255 | 'state_dict': model.state_dict(), 256 | 'best_prec1': best_prec1, 257 | 'optimizer' : optimizer.state_dict(), 258 | }, is_best, 259 | filename='checkpoint_{}.pth.tar'.format(args.arch)) 260 | 261 | 262 | def train(train_loader, model, criterion, optimizer, epoch): 263 | batch_time = AverageMeter() 264 | data_time = AverageMeter() 265 | losses = AverageMeter() 266 | top1 = AverageMeter() 267 | top5 = AverageMeter() 268 | 269 | # switch to train mode 270 | model.train() 271 | 272 | end = time.time() 273 | for i, (input, target) in enumerate(train_loader): 274 | # measure data loading time 275 | data_time.update(time.time() - end) 276 | 277 | target = target.cuda(non_blocking=True) 278 | 279 | # compute output 280 | output = model(input) 281 | loss = criterion(output, target) 282 | 283 | # measure accuracy and record loss 284 | [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5)) 285 | losses.update(loss.item(), input.size(0)) 286 | top1.update(prec1[0], input.size(0)) 287 | top5.update(prec5[0], input.size(0)) 288 | 289 | # compute gradient and do SGD step 290 | optimizer.zero_grad() 291 | loss.backward() 292 | optimizer.step() 293 | 294 | # measure elapsed time 295 | batch_time.update(time.time() - end) 296 | end = time.time() 297 | 298 | if i % args.print_freq == 0: 299 | print('Epoch: [{0}][{1}/{2}]\t' 300 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 301 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 302 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 303 | 'Prec@1 
{top1.val:.3f} ({top1.avg:.3f})\t'
304 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
305 | epoch, i, len(train_loader), batch_time=batch_time,
306 | data_time=data_time, loss=losses, top1=top1, top5=top5))
307 | 
308 | 
309 | def cout_confmatrix_tofile(ConfusionMatrix, classes, filename):
310 | cm = ConfusionMatrix
311 | print('------ConfusionMatrix[target[idx], class_to[idx]]:')
312 | print(cm)
313 | print('------ConfusionMatrix ratio per class:')
314 | cmf = cm.astype(np.float)
315 | ss = cmf.sum(1,keepdims=True)
316 | for idx in range(len(ss)):
317 | cmf[idx,] /= ss[idx]
318 | print(cmf)
319 | 
320 | with open(filename, 'w') as fp: # text mode: writing str to a 'wb' handle fails under py3
321 | # cm
322 | fp.write('ConfusionMatrix:\n')
323 | fp.write('gt\\pred:\t')
324 | [fp.write('{}\t'.format(c)) for c in classes]
325 | fp.write('\n')
326 | for h in range(len(classes)):
327 | fp.write('{}\t'.format(classes[h]))
328 | for w in range(len(classes)):
329 | fp.write('{}\t'.format(cm[h,w]))
330 | fp.write('\n')
331 | fp.write('--------------------------------------')
332 | fp.write('\n')
333 | 
334 | # cmf
335 | fp.write('ConfusionMatrix classify ratio:\n')
336 | fp.write('gt\\pred:\t')
337 | [fp.write('{}\t'.format(c)) for c in classes]
338 | fp.write('\n')
339 | for h in range(len(classes)):
340 | fp.write('{}\t'.format(classes[h]))
341 | for w in range(len(classes)):
342 | fp.write('{0:.3f}\t'.format(cmf[h,w]))
343 | fp.write('\n')
344 | fp.write('--------------------------------------')
345 | fp.write('\n')
346 | 
347 | # Recall@1: the diagonal of the row-normalized confusion matrix is per-class recall
348 | fp.write('Recall@1:\n')
349 | for h in range(len(classes)):
350 | fp.write('{}\t{}\n'.format(classes[h], cmf[h,h]))
351 | fp.write('--------------------------------------')
352 | fp.write('\n')
353 | return cmf
354 | 
355 | 
356 | 
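The two normalizations in `cout_confmatrix_tofile` are the whole trick: row-normalizing the confusion matrix puts per-class recall on the diagonal, while column-normalizing would put per-class precision there (as the MTMC variant above does). A toy standalone illustration with invented numbers:

```
import numpy as np

cm = np.array([[8., 2.],
               [1., 9.]])  # rows = ground truth, columns = predictions

recall = cm / cm.sum(axis=1, keepdims=True)     # row-normalized, like cmf above
precision = cm / cm.sum(axis=0, keepdims=True)  # column-normalized

print(np.diag(recall))     # [0.8 0.9]      -> per-class recall
print(np.diag(precision))  # ~[0.889 0.818] -> per-class precision
```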
357 | def validate(val_loader, model, criterion, phase="VAL"):
358 | batch_time = AverageMeter()
359 | losses = AverageMeter()
360 | top1 = AverageMeter()
361 | top5 = AverageMeter()
362 | 
363 | # switch to evaluate mode
364 | model.eval()
365 | 
366 | classes = val_loader.dataset.classes
367 | class_number = len(classes)
368 | ConfusionMatrix = np.zeros((class_number, class_number), dtype=int)
369 | 
370 | with torch.no_grad():
371 | end = time.time()
372 | for i, (input, target) in enumerate(val_loader):
373 | target = target.cuda(non_blocking=True)
374 | 
375 | # compute output
376 | output = model(input)
377 | loss = criterion(output, target)
378 | 
379 | # measure accuracy and record loss
380 | [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5))
381 | losses.update(loss.item(), input.size(0))
382 | top1.update(prec1[0], input.size(0))
383 | top5.update(prec5[0], input.size(0))
384 | 
385 | # measure elapsed time
386 | batch_time.update(time.time() - end)
387 | end = time.time()
388 | 
389 | if i % args.print_freq == 0:
390 | print('Test-{0}: [{1}/{2}]\t'
391 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
392 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
393 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
394 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
395 | phase, i, len(val_loader),
396 | batch_time=batch_time,
397 | loss=losses,
398 | top1=top1, top5=top5))
399 | 
400 | # save result to ConfusionMatrix
401 | label = target.cpu().numpy()
402 | for idx in range(len(class_to)):
403 | ConfusionMatrix[label[idx], class_to[idx]] += 1
404 | 
405 | print(' * {} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
406 | .format(phase, top1=top1, top5=top5))
407 | 
408 | filename = 'val_{}_ConfusionMatrix_'.format(phase) \
409 | + time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time())) \
410 | + '.txt'
411 | filename = os.path.join(vals_log_path, filename)
412 | cmf = cout_confmatrix_tofile(ConfusionMatrix, classes, filename)
413 | 
414 | top1_dict = {classes[i] : cmf[i,i] for i in range(len(classes))}
415 | return (top1.avg, top1_dict)
416 | 
417 | 
418 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
419 | torch.save(state, filename)
420 | if is_best:
421 | shutil.copyfile(filename, 'model_best_' + filename)
422 | 
423 | 
424 | class AverageMeter(object):
425 | """Computes and stores the average and current value"""
426 | def __init__(self):
427 | self.reset()
428 | 
429 | def reset(self):
430 | self.val = 0
431 | self.avg = 0
432 | self.sum = 0
433 | self.count = 0
434 | 
435 | def update(self, val, n=1):
436 | self.val = val
437 | self.sum += val * n
438 | self.count += n
439 | self.avg = self.sum / self.count
440 | 
441 | 
442 | def adjust_learning_rate(optimizer, epoch):
443 | """Sets the learning rate to the initial LR decayed by 10 every 40 epochs"""
444 | lr = args.lr * (0.1 ** (epoch // 40))
445 | for param_group in optimizer.param_groups:
446 | param_group['lr'] = lr
447 | 
448 | 
449 | def accuracy(output, target, topk=(1,)):
450 | """Computes the precision@k for the specified values of k"""
451 | with torch.no_grad():
452 | maxk = max(topk)
453 | batch_size = target.size(0)
454 | 
455 | _, pred = output.topk(maxk, 1, True, True)
456 | pred = pred.t()
457 | correct = pred.eq(target.view(1, -1).expand_as(pred))
458 | 
459 | class_to = pred[0].cpu().numpy()
460 | 
461 | res = []
462 | for k in topk:
463 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
464 | res.append(correct_k.mul_(100.0 / batch_size))
465 | return res, class_to
466 | 
467 | 
468 | if __name__ == '__main__':
469 | main()
470 | 
471 | 
-------------------------------------------------------------------------------- /src/mtmc_coordinator.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Nov 5 14:52:43 2018
5 | 
6 | Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29)
7 | 
8 | @author: pilgrim.bin@163.com
9 | """
10 | 
11 | import random
12 | 
13 | class SLMCCoordinator():
14 | def __init__(self, max_lens_list):
15 | # input params
16 | self.max_lens_list = max_lens_list
17 | self.iter_flag_list = []
18 | for idx in range(len(max_lens_list)):
19 | self.iter_flag_list += [idx] * max_lens_list[idx]
20 | random.shuffle(self.iter_flag_list)
21 | 
22 | if __name__ == '__main__':
23 | max_lens_list = [12, 20, 15]
24 | coordinator = SLMCCoordinator(max_lens_list)
25 | 
26 | for i, iter_flag in enumerate(coordinator.iter_flag_list):
27 | print([i, iter_flag])
28 | 
-------------------------------------------------------------------------------- /src/mtmcconfig.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Nov 5 14:52:43 2018
5 | 
6 | Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29)
7 | 
8 | @author: pilgrim.bin@163.com
9 | """
10 | 
11 | import sys
12 | import os
13 | import argparse
14 | import copy
15 | import random
16 | 
17 | import torch
18 | #import torch.nn as nn
19 | #import torch.nn.parallel
20 | #import torch.backends.cudnn as cudnn
21 | #import torch.distributed as dist
22 | import torch.optim
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | 
import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | #import torchvision.models as models
28 | 
29 | # baiqi diy
30 | import diy_folder
31 | 
32 | '''
33 | MLDataloader loads an MTMC dataset laid out as the following directory tree.
34 | Make sure the train and val directory trees stay consistent.
35 | 
36 | data_root_path
37 | ├── task_A
38 | │ ├── train
39 | │ │ ├── class_1
40 | │ │ ├── class_2
41 | │ │ ├── class_3
42 | │ │ └── class_4
43 | │ └── val
44 | │ ├── class_1
45 | │ ├── class_2
46 | │ ├── class_3
47 | │ └── class_4
48 | └── task_B
49 | ├── train
50 | │ ├── class_1
51 | │ ├── class_2
52 | │ └── class_3
53 | └── val
54 | ├── class_1
55 | ├── class_2
56 | └── class_3
57 | '''
58 | 
59 | def raise_error_if_not_exists(path):
60 | if not os.path.exists(path):
61 | raise(RuntimeError("Dataset path = {} cannot be found.".format(path)))
62 | 
63 | # Single Label Multi-Class Dataloader
64 | class SLMCDataloader():
65 | def __init__(self, path, dataresize=[256,256,224,224], batch_size=32, workers=4, base_number=100):
66 | # input params
67 | self.batch_size = batch_size
68 | self.workers = workers
69 | self.base_number = base_number
70 | 
71 | # directory_check
72 | self.rootpath = path
73 | self.train_path = os.path.join(path, 'train')
74 | self.val_path = os.path.join(path, 'val')
75 | raise_error_if_not_exists(self.rootpath)
76 | raise_error_if_not_exists(self.train_path)
77 | raise_error_if_not_exists(self.val_path)
78 | 
79 | # input sizes: resnet=224, inception=299, diy_yolo
80 | resize_h, resize_w, inputlayer_h, inputlayer_w = dataresize
81 | 
82 | # init transforms
83 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
84 | self.train_diy_transforms = transforms.Compose([
85 | transforms.RandomAffine(degrees=[-15, 15], translate=[0.1, 0.1], scale=[0.9, 1.1]),
86 | transforms.Resize(size=(resize_h, resize_w)),
87 | # transforms.RandomRotation(10), # 5, 10
88 | transforms.CenterCrop(size=(inputlayer_h, inputlayer_w)),
89 | transforms.RandomHorizontalFlip(), # harmful for character/text classes
90 | transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
91 | transforms.ToTensor(),
92 | normalize])
93 | val_transforms = transforms.Compose([
94 | # transforms.Resize(size=(resize_h, resize_w)),
95 | # transforms.CenterCrop(size=(inputlayer_h, inputlayer_w)),
96 | transforms.Resize(size=(inputlayer_h, inputlayer_w)),
97 | transforms.ToTensor(),
98 | normalize])
99 | 
100 | # train_diy_dataloader
101 | self.dbparser = diy_folder.DatasetFolderParsing(self.train_path)
102 | number_dict = {key : 0 for key in self.dbparser.class_to_idx.keys()}
103 | self.base_number_specifier = diy_folder.BaseNumberSpecifier(number_dict, base_number)
104 | self.train_diy_dataset = diy_folder.ImageFolder_SpecifiedNumber(
105 | self.dbparser,
106 | number_dict=self.base_number_specifier.class_to_number_dict,
107 | transform=self.train_diy_transforms)
108 | self.train_sampler = None # not using distributed mode
109 | self.train_diy_loader = torch.utils.data.DataLoader(
110 | self.train_diy_dataset,
111 | batch_size=batch_size,
112 | shuffle=(self.train_sampler is None),
113 | num_workers=workers, pin_memory=True, sampler=self.train_sampler)
114 | 
115 | # val_*_loaders
116 | self.val_train_loader = torch.utils.data.DataLoader(
117 | datasets.ImageFolder(self.train_path, val_transforms),
118 | batch_size=batch_size,
119 | shuffle=False,
120 | num_workers=workers, pin_memory=True)
121 | self.val_val_loader = torch.utils.data.DataLoader(
122 | datasets.ImageFolder(self.val_path, val_transforms),
123 | batch_size=batch_size,
124 | shuffle=False,
125 | num_workers=workers, pin_memory=True)
126 | 
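The `update_train_diy_loader` method just below closes the feedback loop: per-class top-1 scores from a trainset evaluation go into `BaseNumberSpecifier`, which recomputes how many samples each class contributes, and the train loader is rebuilt. A minimal sketch of that loop; the `data/length` path is a placeholder and the flat 0.5 scores stand in for real validation output:

```
# Assumes mtmcconfig.py is importable; path and scores are placeholders.
from mtmcconfig import SLMCDataloader

loader = SLMCDataloader('data/length', batch_size=32, workers=4, base_number=128)

for epoch in range(3):
    for input, target in loader.train_diy_loader:
        pass  # one training epoch would run here

    fake_top1_dict = {cls: 0.5 for cls in loader.dbparser.class_to_idx}
    loader.update_train_diy_loader(fake_top1_dict)  # resamples classes for the next epoch
```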
127 | def update_train_diy_loader(self, top1_dict):
128 | self.base_number_specifier.update(top1_dict)
129 | number_dict = self.base_number_specifier.class_to_number_dict
130 | 
131 | # diy dataloader
132 | print("number_dict = {}".format(number_dict))
133 | self.train_diy_dataset = diy_folder.ImageFolder_SpecifiedNumber(
134 | self.dbparser,
135 | number_dict=self.base_number_specifier.class_to_number_dict,
136 | transform=self.train_diy_transforms)
137 | self.train_sampler = None # not using distributed mode
138 | self.train_diy_loader = torch.utils.data.DataLoader(
139 | self.train_diy_dataset,
140 | batch_size=self.batch_size,
141 | shuffle=(self.train_sampler is None),
142 | num_workers=self.workers, pin_memory=True, sampler=self.train_sampler)
143 | 
144 | 
145 | # max_base_number is the base_number granted to the task with the fewest classes; keep this value fixed.
146 | # The adaptive scaling below keeps the number of train batches balanced across tasks.
147 | class MTMCDataloader():
148 | def __init__(self, path, dataresize=[256,256,224,224], batch_size=32, workers=4, max_base_number=100):
149 | # directory_check
150 | self.rootpath = path
151 | raise_error_if_not_exists(self.rootpath)
152 | labels, label_to_idx = diy_folder.find_classes(self.rootpath)
153 | if len(labels) == 0:
154 | raise(RuntimeError("Dataset path = {} has no folder as task-label.".format(self.rootpath)))
155 | self.labels = labels
156 | self.label_to_idx = copy.deepcopy(label_to_idx)
157 | 
158 | # get_mtmc_tree
159 | self.get_mtmc_tree()
160 | 
161 | # N * slmcdataloader_dict
162 | self.slmcdataloader_dict = copy.deepcopy(label_to_idx)
163 | for label in self.slmcdataloader_dict.keys():
164 | print('------MTMCDataloader::create_SLMCDataloader({})'.format(label))
165 | self.slmcdataloader_dict[label] = SLMCDataloader(
166 | os.path.join(self.rootpath, label),
167 | dataresize = dataresize,
168 | batch_size=batch_size,
169 | workers=workers,
170 | base_number=self.get_suitable_base_number(label, max_base_number))
171 | 
172 | def get_mtmc_tree(self):
173 | self.mtmc_tree = copy.deepcopy(self.label_to_idx)
174 | self.min_class_number = sys.maxsize # no sys.maxint in py3
175 | for label in self.mtmc_tree.keys():
176 | classes, class_to_idx = diy_folder.find_classes(os.path.join(self.rootpath, label, 'train'))
177 | val_classes, val_class_to_idx = diy_folder.find_classes(os.path.join(self.rootpath, label, 'val'))
178 | if not class_to_idx == val_class_to_idx:
179 | print('train_class_to_idx = {}'.format(class_to_idx))
180 | print('val_class_to_idx = {}'.format(val_class_to_idx))
181 | raise(RuntimeError("train_class_to_idx != val_class_to_idx."))
182 | self.mtmc_tree[label] = copy.deepcopy(class_to_idx)
183 | if self.min_class_number > len(classes):
184 | self.min_class_number = len(classes)
185 | 
186 | def get_suitable_base_number(self, label, max_base_number):
187 | class_number = len(self.mtmc_tree[label].keys())
188 | return (max_base_number * self.min_class_number) // class_number # py3
189 | 
190 | 
191 | def update_train_diy_loader(self, label_top1_dict):
192 | for label in label_top1_dict.keys():
193 | print('------MTMCDataloader::update_train_diy_loader({})'.format(label))
194 | self.slmcdataloader_dict[label].update_train_diy_loader(label_top1_dict[label])
195 | 
196 | def make_fake_label_top1_dict(label_top1_dict):
197 | for label in label_top1_dict.keys():
198 | for c in label_top1_dict[label].keys():
199 | label_top1_dict[label][c] = 
random.random() 200 | 201 | 202 | if __name__ == '__main__': 203 | 204 | parser = argparse.ArgumentParser( 205 | description='python main.py --path=data' 206 | ) 207 | parser.add_argument( 208 | "--path", 209 | default='/Users/baiqi/data/pants', 210 | type=str, 211 | ) 212 | args = parser.parse_args() 213 | print('args.path = {}'.format(args.path)) 214 | path = args.path 215 | 216 | '''--------------------------------------''' 217 | 218 | '''SLMCDataloader--------------------------------------''' 219 | # SLMCDataloader 220 | slmcdataloader = SLMCDataloader(os.path.join(path, 'length'), batch_size=32, workers=4, base_number=128) 221 | ''' 222 | for i, (input, target) in enumerate(slmcdataloader.train_diy_loader): 223 | print(target) 224 | print(i) 225 | ''' 226 | print('sldataloader.train_diy_loader.__len__() = {}'.format(slmcdataloader.train_diy_loader.__len__())) 227 | print('sldataloader.val_train_loader.__len__() = {}'.format(slmcdataloader.val_train_loader.__len__())) 228 | print('sldataloader.val_val_loader.__len__() = {}'.format(slmcdataloader.val_val_loader.__len__())) 229 | 230 | 231 | '''MTMCDataloader--------------------------------------''' 232 | # MTMCDataloader 233 | mtmcdataloader = MTMCDataloader(path, batch_size=32, workers=4, max_base_number=100) 234 | for label in mtmcdataloader.label_to_idx.keys(): 235 | print('---main::label={}'.format(label)) 236 | slmcdataloader = mtmcdataloader.slmcdataloader_dict[label] 237 | print('sldataloader.train_diy_loader.__len__() = {}'.format(slmcdataloader.train_diy_loader.__len__())) 238 | print('sldataloader.val_train_loader.__len__() = {}'.format(slmcdataloader.val_train_loader.__len__())) 239 | print('sldataloader.val_val_loader.__len__() = {}'.format(slmcdataloader.val_val_loader.__len__())) 240 | 241 | # mtmcdataloader.update_train_diy_loader(label_top1_dict) 242 | for epoch in range(10): 243 | print('----------epoch = {}------------'.format(epoch)) 244 | label_top1_dict = copy.deepcopy(mtmcdataloader.mtmc_tree) 245 | make_fake_label_top1_dict(label_top1_dict) 246 | mtmcdataloader.update_train_diy_loader(label_top1_dict) 247 | for label in mtmcdataloader.label_to_idx.keys(): 248 | print('-----------main::label={}'.format(label)) 249 | slmcdataloader = mtmcdataloader.slmcdataloader_dict[label] 250 | print('sldataloader.train_diy_loader.__len__() = {}'.format(slmcdataloader.train_diy_loader.__len__())) 251 | print('sldataloader.val_train_loader.__len__() = {}'.format(slmcdataloader.val_train_loader.__len__())) 252 | print('sldataloader.val_val_loader.__len__() = {}'.format(slmcdataloader.val_val_loader.__len__())) 253 | class_to_number_dict = slmcdataloader.base_number_specifier.class_to_number_dict 254 | print('class_to_number_dict = {}'.format(class_to_number_dict)) 255 | 256 | 257 | -------------------------------------------------------------------------------- /src/mtmcmodel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Nov 5 14:52:43 2018 5 | 6 | Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29) 7 | 8 | @author: pilgrim.bin@163.com 9 | """ 10 | import torch.nn as nn 11 | from torch.nn import init 12 | from torch.nn import functional as F 13 | 14 | class MultiLabelModel(nn.Module): 15 | def __init__(self, basemodel_output, num_classes, basemodel=None): 16 | super(MultiLabelModel, self).__init__() 17 | self.basemodel = basemodel 18 | self.num_classes = num_classes 19 | 20 | # config 21 | 
self.cfg_normalize = False # not evaluated against other schemes; differs from the plain embedding path.
22 | self.cfg_has_embedding = True
23 | self.cfg_num_features = basemodel_output # is there a better number?
24 | self.cfg_dropout_ratio = 0. # 0. worked better than 0.8 on the attributes:pants problem
25 | 
26 | # diy head
27 | for index, num_class in enumerate(num_classes):
28 | if self.cfg_has_embedding:
29 | setattr(self, "EmbeddingFeature_FCLayer_" + str(index), nn.Linear(basemodel_output, self.cfg_num_features))
30 | setattr(self, "EmbeddingFeature_FCLayer_BN_" + str(index), nn.BatchNorm1d(self.cfg_num_features))
31 | feat = getattr(self, "EmbeddingFeature_FCLayer_" + str(index))
32 | feat_bn = getattr(self, "EmbeddingFeature_FCLayer_BN_" + str(index))
33 | init.kaiming_normal_(feat.weight, mode='fan_out')
34 | init.constant_(feat.bias, 0)
35 | init.constant_(feat_bn.weight, 1)
36 | init.constant_(feat_bn.bias, 0)
37 | if self.cfg_dropout_ratio > 0:
38 | setattr(self, "Dropout_" + str(index), nn.Dropout(self.cfg_dropout_ratio))
39 | setattr(self, "FullyConnectedLayer_" + str(index), nn.Linear(self.cfg_num_features, num_class))
40 | classifier = getattr(self, "FullyConnectedLayer_" + str(index))
41 | init.normal_(classifier.weight, std=0.001)
42 | init.constant_(classifier.bias, 0)
43 | 
44 | def forward(self, x):
45 | if self.basemodel is not None:
46 | x = self.basemodel.forward(x)
47 | outs = list()
48 | for index, num_class in enumerate(self.num_classes):
49 | y = x # each head must start from the shared base features; reusing x would feed head k+1 with head k's embedding
50 | if self.cfg_has_embedding:
51 | feat = getattr(self, "EmbeddingFeature_FCLayer_" + str(index))
52 | feat_bn = getattr(self, "EmbeddingFeature_FCLayer_BN_" + str(index))
53 | y = feat_bn(feat(y))
54 | if self.cfg_normalize:
55 | y = F.normalize(y) # getattr bug
56 | elif self.cfg_has_embedding:
57 | y = F.relu(y)
58 | if self.cfg_dropout_ratio > 0:
59 | dropout = getattr(self, "Dropout_" + str(index))
60 | y = dropout(y)
61 | classifier = getattr(self, "FullyConnectedLayer_" + str(index))
62 | out = classifier(y)
63 | outs.append(out)
64 | return outs
65 | 
66 | 
67 | def LoadPretrainedModel(model, pretrained_state_dict):
68 | model_dict = model.state_dict()
69 | union_dict = {k : v for k,v in pretrained_state_dict.items() if k in model_dict} # py3: items(), not iteritems()
70 | model_dict.update(union_dict)
71 | return model_dict
72 | 
73 | def BuildMultiLabelModel(basemodel_output, num_classes, basemodel=None):
74 | return MultiLabelModel(basemodel_output, num_classes, basemodel=basemodel)
75 | 
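`BuildMultiLabelModel` only forwards its arguments, so the interesting part is wiring a backbone to it. A minimal usage sketch, not from the repo: the torchvision setup and the two-task class counts are assumptions, and `nn.Identity` requires torch >= 1.1:

```
# Assumed wiring (not from the repo): strip the ImageNet head from a torchvision
# backbone and hang two invented task heads (4 and 3 classes) off its features.
import torch
import torch.nn as nn
import torchvision.models as models

backbone = models.resnet18(pretrained=False)
basemodel_output = backbone.fc.in_features  # 512 for resnet18
backbone.fc = nn.Identity()                 # backbone now returns pooled features

model = BuildMultiLabelModel(basemodel_output, num_classes=[4, 3], basemodel=backbone)

x = torch.randn(2, 3, 224, 224)
outs = model(x)
print([tuple(o.shape) for o in outs])  # [(2, 4), (2, 3)]
```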
76 | '''----------------------------------------------------------------------------------------------------'''
77 | 
78 | # original version of https://github.com/pangwong/pytorch-multi-label-classifier.git
79 | '''
80 | import torch.nn as nn
81 | 
82 | class MultiLabelModel(nn.Module):
83 | def __init__(self, basemodel, basemodel_output, num_classes):
84 | super(MultiLabelModel, self).__init__()
85 | self.basemodel = basemodel
86 | self.num_classes = num_classes
87 | for index, num_class in enumerate(num_classes):
88 | setattr(self, "FullyConnectedLayer_" + str(index), nn.Linear(basemodel_output, num_class))
89 | 
90 | def forward(self, x):
91 | x = self.basemodel.forward(x)
92 | outs = list()
93 | dir(self)
94 | for index, num_class in enumerate(self.num_classes):
95 | fun = eval("self.FullyConnectedLayer_" + str(index))
96 | out = fun(x)
97 | outs.append(out)
98 | return outs
99 | 
100 | def LoadPretrainedModel(model, pretrained_state_dict):
101 | model_dict = model.state_dict()
102 | union_dict = {k : v for k,v in pretrained_state_dict.iteritems() if k in model_dict}
103 | model_dict.update(union_dict)
104 | return model_dict
105 | 
106 | def BuildMultiLabelModel(basemodel, basemodel_output, num_classes):
107 | return MultiLabelModel(basemodel, basemodel_output, num_classes)
108 | 
109 | '''
110 | 
111 | 
112 | 
113 | 
--------------------------------------------------------------------------------
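For reference, the batch-interleaving idea behind `SLMCCoordinator` and `mtmctrain`, reduced to a standalone toy (all names and numbers invented): each task contributes one slot per batch, the slots are shuffled, and the training loop walks the shuffled list, drawing the next batch from whichever task owns the slot.

```
import random

# Toy stand-ins: task index -> that task's list of batches.
batches = {0: ['a0', 'a1'], 1: ['b0', 'b1', 'b2']}

# SLMCCoordinator's idea: one slot per batch, shuffled once per epoch.
iter_flag_list = [idx for idx, batch_list in batches.items() for _ in batch_list]
random.shuffle(iter_flag_list)

iterators = {idx: iter(batch_list) for idx, batch_list in batches.items()}
for flag in iter_flag_list:
    batch = next(iterators[flag])  # mirrors next(enumerate_list[flag]) in mtmctrain
    print(flag, batch)             # tasks interleave randomly and exhaust exactly
```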