├── Attribute_test_acc.py ├── README.md ├── convert_imageset_multi_labels.cpp ├── data_process └── alignment.cpp ├── models └── val_crop.txt └── train └── LFWA ├── alexnet_256_lfw.prototxt ├── solver_lfwa.prototxt ├── solver_lfwa.prototxt~ ├── train_rnn_net_lfwa.sh └── train_rnn_net_lfwa.sh~ /Attribute_test_acc.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import math 7 | import cv2 8 | import numpy as np 9 | import multiprocessing 10 | from sklearn.metrics import confusion_matrix 11 | import matplotlib.pyplot as plt 12 | 13 | # this file should be run from {caffe_root} 14 | sys.path.append('./python') 15 | import caffe 16 | attribute_acceracy = np.zeros(40) 17 | 18 | 19 | eval_model_def = '/home/sanyuan/CaffeProject/FA_v1-attention/models/attentionRoi/deploy.prototxt' 20 | eval_model_weights = '/home/sanyuan/CaffeProject/FA_v1-attention/models/attentionRoi/alexnet_iter_70000.caffemodel' 21 | 22 | #caffe.set_mode_cpu() 23 | caffe.set_device(0) 24 | caffe.set_mode_gpu() 25 | net = caffe.Net(eval_model_def, eval_model_weights, caffe.TEST) 26 | input_shape = net.blobs['data'].data.shape 27 | sample_shape = net.blobs['data'].data.shape 28 | 29 | # Set transformer 30 | transformer = caffe.io.Transformer({'data': input_shape}) 31 | transformer.set_transpose('data', (2,0,1)) #from 256 256 3 to 3* 256 *256 32 | transformer.set_mean('data', np.array([90.1146, 103.035, 127.689])) #or load the mean image 33 | 34 | 35 | with open("/home/sanyuan/CaffeProject/FA_v1-attention/data/Attribute/val_crop.txt") as Test_list: 36 | lines = Test_list.readlines() 37 | for line in lines: 38 | img_name = line.split()[0] 39 | image_path = "/media/sanyuan/Sanyuan/Dataset/crop_by_4/"+img_name 40 | eval_image = cv2.imread(image_path, cv2.IMREAD_COLOR) 41 | #plt.figure() 42 | #plt.imshow(eval_image) 43 | #plt.show() 44 | rs_image = cv2.resize(eval_image, (input_shape[2],input_shape[3])) 45 | 46 | data = transformer.preprocess('data', rs_image) 47 | data=data*0.0078125 48 | xdata = np.zeros((1,3,256,256)) 49 | xdata[0,...] = data 50 | out = net.forward(data=xdata) 51 | prob_data = net.blobs['prob_data'] 52 | attribute_data =prob_data.data[0][0] 53 | for j in range(0,40): 54 | if(((attribute_data[j][0]>0.5)and(line.split()[j+1]=='0'))or((attribute_data[j][0]<0.5)and(line.split()[j+1]=='1'))): 55 | attribute_acceracy[j]+=1 56 | for j in range(0,40): 57 | print (attribute_acceracy[j]/20258.0) 58 | 59 | print "average acc:" 60 | print (sum(attribute_acceracy)/(40*20258.0)) 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Caffe_face_attribute_classification 2 | multi-task learning method for face attributes learning 3 | -------------------------------------------------------------------------------- /convert_imageset_multi_labels.cpp: -------------------------------------------------------------------------------- 1 | // This program converts a set of images to a lmdb/leveldb by storing them 2 | // as Datum proto buffers. 3 | // Usage: 4 | // convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME 5 | // 6 | // where ROOTFOLDER is the root folder that holds all the images, and LISTFILE 7 | // should be a list of files as well as their labels, in the format as 8 | // subfolder1/file1.JPEG 7 9 | // .... 10 | 11 | #include 12 | #include // NOLINT(readability/streams) 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "boost/scoped_ptr.hpp" 19 | #include "gflags/gflags.h" 20 | #include "glog/logging.h" 21 | 22 | #include "caffe/proto/caffe.pb.h" 23 | #include "caffe/util/db.hpp" 24 | #include "caffe/util/format.hpp" 25 | #include "caffe/util/io.hpp" 26 | #include "caffe/util/rng.hpp" 27 | 28 | using namespace caffe; // NOLINT(build/namespaces) 29 | using std::pair; 30 | using boost::scoped_ptr; 31 | 32 | DEFINE_bool(gray, false, 33 | "When this option is on, treat images as grayscale ones"); 34 | DEFINE_bool(shuffle, false, 35 | "Randomly shuffle the order of images and their labels"); 36 | DEFINE_string(backend, "lmdb", 37 | "The backend {lmdb, leveldb} for storing the result"); 38 | DEFINE_int32(resize_width, 0, "Width images are resized to"); 39 | DEFINE_int32(resize_height, 0, "Height images are resized to"); 40 | DEFINE_bool(check_size, false, 41 | "When this option is on, check that all the datum have the same size"); 42 | DEFINE_bool(encoded, false, 43 | "When this option is on, the encoded image will be save in datum"); 44 | DEFINE_string(encode_type, "", 45 | "Optional: What type should we encode the image as ('png','jpg',...)."); 46 | DEFINE_int32(label_num, 1, 47 | "Optional: How many numbers should we encode the label."); 48 | 49 | int main(int argc, char** argv) { 50 | #ifdef USE_OPENCV 51 | ::google::InitGoogleLogging(argv[0]); 52 | // Print output to stderr (while still logging) 53 | FLAGS_alsologtostderr = 1; 54 | 55 | #ifndef GFLAGS_GFLAGS_H_ 56 | namespace gflags = google; 57 | #endif 58 | 59 | gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n" 60 | "format used as input for Caffe.\n" 61 | "Usage:\n" 62 | " convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n" 63 | "The ImageNet dataset for the training demo is at\n" 64 | " http://www.image-net.org/download-images\n"); 65 | gflags::ParseCommandLineFlags(&argc, &argv, true); 66 | 67 | if (argc < 5) { 68 | gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset_multi_labels"); 69 | return 1; 70 | } 71 | 72 | const bool is_color = !FLAGS_gray; 73 | const bool check_size = FLAGS_check_size; 74 | const bool encoded = FLAGS_encoded; 75 | const string encode_type = FLAGS_encode_type; 76 | const int label_num = FLAGS_label_num; 77 | 78 | std::ifstream infile(argv[2]); 79 | if (!infile.good()) { 80 | std::cout<<"Can not open: "< > > lines; 85 | vector labels(label_num); 86 | 87 | std::string line; 88 | 89 | /*size_t pos; 90 | 91 | while (std::getline(infile, line)) { 92 | pos = line.find_last_of(' '); 93 | label = atoi(line.substr(pos + 1).c_str()); 94 | lines.push_back(std::make_pair(line.substr(0, pos), label)); 95 | } 96 | */ 97 | std::string filename; 98 | while (std::getline(infile, line)) { 99 | //std::cout<> filename; 102 | //std::cout<> labels[i]; 105 | //std::cout<(0, FLAGS_resize_height); 124 | int resize_width = std::max(0, FLAGS_resize_width); 125 | 126 | // Create new DB 127 | scoped_ptr db_image(db::GetDB(FLAGS_backend)); 128 | scoped_ptr db_labels(db::GetDB(FLAGS_backend)); 129 | db_image->Open(argv[3], db::NEW); 130 | db_labels->Open(argv[4], db::NEW); 131 | scoped_ptr txn_image(db_image->NewTransaction()); 132 | scoped_ptr txn_labels(db_labels->NewTransaction()); 133 | 134 | // Storing to db 135 | std::string root_folder(argv[1]); 136 | Datum datum_image; 137 | Datum datum_labels; 138 | int count = 0; 139 | int data_size_image = 0; 140 | int data_size_labels = 0; 141 | bool data_size_initialized = false; 142 | 143 | for (int line_id = 0; line_id < lines.size(); ++line_id) { 144 | bool status; 145 | std::string enc = encode_type; 146 | if (encoded && !enc.size()) { 147 | // Guess the encoding type from the file name 148 | string fn = lines[line_id].first; 149 | size_t p = fn.rfind('.'); 150 | if ( p == fn.npos ) 151 | LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'"; 152 | enc = fn.substr(p); 153 | std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower); 154 | } 155 | status = ReadImageToDatum(root_folder + lines[line_id].first, 156 | lines[line_id].second[0], resize_height, resize_width, is_color, 157 | enc, &datum_image); 158 | if (status == false) continue; 159 | 160 | datum_labels.set_height(1); 161 | datum_labels.set_width(1); 162 | datum_labels.set_channels(label_num); 163 | 164 | for (int index_label = 0; index_label < lines[line_id].second.size(); index_label++) 165 | { 166 | float tmp_float_value = lines[line_id].second[index_label]; 167 | //std::cout<Put(key_str_image, out_image); 200 | txn_labels->Put(key_str_labels, out_labels); 201 | 202 | if (++count % 1000 == 0) { 203 | // Commit db 204 | txn_image->Commit(); 205 | txn_labels->Commit(); 206 | 207 | txn_image.reset(db_image->NewTransaction()); 208 | txn_labels.reset(db_labels->NewTransaction()); 209 | LOG(INFO) << "Processed " << count << " files."; 210 | } 211 | } 212 | // write the last batch 213 | if (count % 1000 != 0) { 214 | txn_image->Commit(); 215 | txn_labels->Commit(); 216 | LOG(INFO) << "Processed " << count << " files."; 217 | } 218 | #else 219 | LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV."; 220 | #endif // USE_OPENCV 221 | return 0; 222 | } 223 | -------------------------------------------------------------------------------- /data_process/alignment.cpp: -------------------------------------------------------------------------------- 1 | //celeba aligment 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | using namespace cv; 8 | using namespace std; 9 | 10 | cv::Mat findNonReflectiveTransform(std::vector source_points, std::vector target_points, Mat& Tinv = Mat()) { 11 | assert(source_points.size() == target_points.size()); 12 | assert(source_points.size() >= 2); 13 | Mat U = Mat::zeros(target_points.size() * 2, 1, CV_64F); 14 | Mat X = Mat::zeros(source_points.size() * 2, 4, CV_64F); 15 | for (int i = 0; i < target_points.size(); i++) { 16 | U.at(i * 2, 0) = source_points[i].x; 17 | U.at(i * 2 + 1, 0) = source_points[i].y; 18 | X.at(i * 2, 0) = target_points[i].x; 19 | X.at(i * 2, 1) = target_points[i].y; 20 | X.at(i * 2, 2) = 1; 21 | X.at(i * 2, 3) = 0; 22 | X.at(i * 2 + 1, 0) = target_points[i].y; 23 | X.at(i * 2 + 1, 1) = -target_points[i].x; 24 | X.at(i * 2 + 1, 2) = 0; 25 | X.at(i * 2 + 1, 3) = 1; 26 | } 27 | Mat r = X.inv(DECOMP_SVD)*U; 28 | Tinv = (Mat_(3, 3) << r.at(0), -r.at(1), 0, 29 | r.at(1), r.at(0), 0, 30 | r.at(2), r.at(3), 1); 31 | Mat T = Tinv.inv(DECOMP_SVD); 32 | Tinv = Tinv(Rect(0, 0, 2, 3)).t(); 33 | return T(Rect(0, 0, 2, 3)).t(); 34 | } 35 | cv::Mat findSimilarityTransform(std::vector source_points, std::vector target_points, Mat& Tinv = Mat()) { 36 | Mat Tinv1, Tinv2; 37 | Mat trans1 = findNonReflectiveTransform(source_points, target_points, Tinv1); 38 | std::vector source_point_reflect; 39 | for (auto sp : source_points) { 40 | source_point_reflect.push_back(Point2d(-sp.x, sp.y)); 41 | } 42 | swap(source_point_reflect[0], source_point_reflect[1]); 43 | swap(source_point_reflect[3], source_point_reflect[4]); 44 | Mat trans2 = findNonReflectiveTransform(source_point_reflect, target_points, Tinv2); 45 | trans2.colRange(0, 1) *= -1; 46 | Tinv2.rowRange(0, 1) *= -1; 47 | std::vector trans_points1, trans_points2; 48 | transform(source_points, trans_points1, trans1); 49 | transform(source_points, trans_points2, trans2); 50 | swap(trans_points2[0], trans_points2[1]); 51 | swap(trans_points2[3], trans_points2[4]); 52 | double norm1 = norm(Mat(trans_points1), Mat(target_points), NORM_L2); 53 | double norm2 = norm(Mat(trans_points2), Mat(target_points), NORM_L2); 54 | Tinv = norm1 < norm2 ? Tinv1 : Tinv2; 55 | return norm1 < norm2 ? trans1 : trans2; 56 | } 57 | 58 | 59 | 60 | int main() { 61 | ifstream label_attribute("D:\\DataSet\\CelebA\\list_landmarks_celeba.txt"); 62 | string img_dir = "D:\\DataSet\\CelebA\\Img\\img_celeba.7z\\img_celeba.7z\\img_celeba\\"; 63 | string dstimg_dir = "E:\\Dataset\\crop_by_me\\"; 64 | string point_attribute; 65 | vector target_points = { {98.4, 102.4 },{ 147.7, 102.1 },{ 123.2, 103.4 },{ 102.97,159.3 },{ 143.8,159.1} }; 66 | 67 | Mat trans_inv; 68 | std::vector points; 69 | Mat trans; 70 | Mat cropImage; 71 | Mat image; 72 | 73 | while (getline(label_attribute, point_attribute)) 74 | { 75 | string buf; 76 | stringstream ss(point_attribute); 77 | vector tokens; 78 | while (ss >> buf) { 79 | tokens.push_back(buf); 80 | } 81 | 82 | for (int i = 1; i < 6; i++) { 83 | points.push_back(cv::Point2d(std::stoi(tokens[2 * i - 1]), std::stoi(tokens[2 * i]))); 84 | } 85 | trans = findSimilarityTransform(points,target_points, trans_inv); 86 | image = imread(img_dir + tokens[0]); 87 | warpAffine(image, cropImage, trans, Size(256, 256)); 88 | //imshow("Q", cropImage); 89 | //waitKey(1); 90 | imwrite(dstimg_dir+tokens[0], cropImage); 91 | points.clear(); 92 | } 93 | 94 | return 0; 95 | } 96 | -------------------------------------------------------------------------------- /train/LFWA/alexnet_256_lfw.prototxt: -------------------------------------------------------------------------------- 1 | name: "AlexNet" 2 | layer { 3 | name: "data" 4 | type: "Data" 5 | top: "data" 6 | include { 7 | phase: TRAIN 8 | } 9 | transform_param { 10 | mirror: true 11 | mean_value: 82.6054 12 | mean_value: 92.6743 13 | mean_value: 106.799 14 | crop_size: 256 15 | scale: 0.0078125 16 | } 17 | data_param { 18 | source: "examples/lfwa/attribute_train_img_lmdb" 19 | batch_size: 32 20 | backend: LMDB 21 | } 22 | } 23 | layer { 24 | name: "data" 25 | type: "Data" 26 | top: "labels" 27 | include { 28 | phase: TRAIN 29 | } 30 | data_param { 31 | source: "examples/lfwa/attribute_train_labels_lmdb" 32 | batch_size: 32 33 | backend: LMDB 34 | } 35 | } 36 | layer { 37 | name: "data" 38 | type: "Data" 39 | top: "data" 40 | include { 41 | phase: TEST 42 | } 43 | transform_param { 44 | mirror: true 45 | mean_value: 87.753 46 | mean_value: 92.8039 47 | mean_value: 106.842 48 | scale: 0.0078125 49 | crop_size: 256 50 | } 51 | data_param { 52 | source: "examples/lfwa/attribute_val_img_lmdb" 53 | batch_size: 16 54 | backend: LMDB 55 | } 56 | } 57 | layer { 58 | name: "data" 59 | type: "Data" 60 | top: "labels" 61 | include { 62 | phase: TEST 63 | } 64 | data_param { 65 | source: "examples/lfwa/attribute_val_labels_lmdb" 66 | batch_size: 16 67 | backend: LMDB 68 | } 69 | } 70 | layer { 71 | name: "conv1" 72 | type: "Convolution" 73 | bottom: "data" 74 | top: "conv1" 75 | param { 76 | lr_mult: 1 77 | decay_mult: 1 78 | } 79 | param { 80 | lr_mult: 2 81 | decay_mult: 0 82 | } 83 | convolution_param { 84 | num_output: 96 85 | kernel_size: 5 86 | stride: 2 87 | weight_filler { 88 | type: "gaussian" 89 | std: 0.01 90 | } 91 | bias_filler { 92 | type: "constant" 93 | value: 0 94 | } 95 | } 96 | } 97 | layer { 98 | name: "bn1_1" 99 | type: "BN" 100 | bottom: "conv1" 101 | top: "bn1_1" 102 | param { 103 | lr_mult: 1 104 | decay_mult: 0 105 | } 106 | param { 107 | lr_mult: 1 108 | decay_mult: 0 109 | } 110 | bn_param { 111 | slope_filler { 112 | type: "constant" 113 | value: 1 114 | } 115 | bias_filler { 116 | type: "constant" 117 | value: 0 118 | } 119 | } 120 | } 121 | layer { 122 | name: "relu1_1" 123 | type: "PReLU" 124 | bottom: "bn1_1" 125 | top: "bn1_1" 126 | } 127 | layer { 128 | name: "pool1" 129 | type: "Pooling" 130 | bottom: "bn1_1" 131 | top: "pool1" 132 | pooling_param { 133 | pool: MAX 134 | kernel_size: 3 135 | stride: 2 136 | } 137 | } 138 | layer { 139 | name: "conv2" 140 | type: "Convolution" 141 | bottom: "pool1" 142 | top: "conv2" 143 | param { 144 | lr_mult: 1 145 | decay_mult: 1 146 | } 147 | param { 148 | lr_mult: 2 149 | decay_mult: 0 150 | } 151 | convolution_param { 152 | num_output: 256 153 | pad: 2 154 | kernel_size: 5 155 | group: 2 156 | weight_filler { 157 | type: "gaussian" 158 | std: 0.01 159 | } 160 | bias_filler { 161 | type: "constant" 162 | value: 0.1 163 | } 164 | } 165 | } 166 | layer { 167 | name: "bn1_2" 168 | type: "BN" 169 | bottom: "conv2" 170 | top: "bn1_2" 171 | param { 172 | lr_mult: 1 173 | decay_mult: 0 174 | } 175 | param { 176 | lr_mult: 1 177 | decay_mult: 0 178 | } 179 | bn_param { 180 | slope_filler { 181 | type: "constant" 182 | value: 1 183 | } 184 | bias_filler { 185 | type: "constant" 186 | value: 0 187 | } 188 | } 189 | } 190 | layer { 191 | name: "relu1_2" 192 | type: "PReLU" 193 | bottom: "bn1_2" 194 | top: "bn1_2" 195 | } 196 | layer { 197 | name: "pool2" 198 | type: "Pooling" 199 | bottom: "bn1_2" 200 | top: "pool2" 201 | pooling_param { 202 | pool: MAX 203 | kernel_size: 3 204 | stride: 2 205 | } 206 | } 207 | layer { 208 | name: "conv3" 209 | type: "Convolution" 210 | bottom: "pool2" 211 | top: "conv3" 212 | param { 213 | lr_mult: 1 214 | decay_mult: 1 215 | } 216 | param { 217 | lr_mult: 2 218 | decay_mult: 0 219 | } 220 | convolution_param { 221 | num_output: 384 222 | pad: 1 223 | kernel_size: 3 224 | weight_filler { 225 | type: "gaussian" 226 | std: 0.01 227 | } 228 | bias_filler { 229 | type: "constant" 230 | value: 0 231 | } 232 | } 233 | } 234 | layer { 235 | name: "bn1_3" 236 | type: "BN" 237 | bottom: "conv3" 238 | top: "bn1_3" 239 | param { 240 | lr_mult: 1 241 | decay_mult: 0 242 | } 243 | param { 244 | lr_mult: 1 245 | decay_mult: 0 246 | } 247 | bn_param { 248 | slope_filler { 249 | type: "constant" 250 | value: 1 251 | } 252 | bias_filler { 253 | type: "constant" 254 | value: 0 255 | } 256 | } 257 | } 258 | layer { 259 | name: "relu1_3" 260 | type: "PReLU" 261 | bottom: "bn1_3" 262 | top: "bn1_3" 263 | } 264 | layer { 265 | name: "conv4" 266 | type: "Convolution" 267 | bottom: "bn1_3" 268 | top: "conv4" 269 | param { 270 | lr_mult: 1 271 | decay_mult: 1 272 | } 273 | param { 274 | lr_mult: 2 275 | decay_mult: 0 276 | } 277 | convolution_param { 278 | num_output: 384 279 | pad: 1 280 | kernel_size: 3 281 | group: 2 282 | weight_filler { 283 | type: "gaussian" 284 | std: 0.01 285 | } 286 | bias_filler { 287 | type: "constant" 288 | value: 0.1 289 | } 290 | } 291 | } 292 | layer { 293 | name: "bn1_4" 294 | type: "BN" 295 | bottom: "conv4" 296 | top: "bn1_4" 297 | param { 298 | lr_mult: 1 299 | decay_mult: 0 300 | } 301 | param { 302 | lr_mult: 1 303 | decay_mult: 0 304 | } 305 | bn_param { 306 | slope_filler { 307 | type: "constant" 308 | value: 1 309 | } 310 | bias_filler { 311 | type: "constant" 312 | value: 0 313 | } 314 | } 315 | } 316 | layer { 317 | name: "relu1_4" 318 | type: "PReLU" 319 | bottom: "bn1_4" 320 | top: "bn1_4" 321 | } 322 | layer { 323 | name: "conv5" 324 | type: "Convolution" 325 | bottom: "bn1_4" 326 | top: "conv5" 327 | param { 328 | lr_mult: 1 329 | decay_mult: 1 330 | } 331 | param { 332 | lr_mult: 2 333 | decay_mult: 0 334 | } 335 | convolution_param { 336 | num_output: 256 337 | pad: 1 338 | kernel_size: 3 339 | group: 2 340 | weight_filler { 341 | type: "gaussian" 342 | std: 0.01 343 | } 344 | bias_filler { 345 | type: "constant" 346 | value: 0.1 347 | } 348 | } 349 | } 350 | layer { 351 | name: "bn2_1" 352 | type: "BN" 353 | bottom: "conv5" 354 | top: "bn2_1" 355 | param { 356 | lr_mult: 1 357 | decay_mult: 0 358 | } 359 | param { 360 | lr_mult: 1 361 | decay_mult: 0 362 | } 363 | bn_param { 364 | slope_filler { 365 | type: "constant" 366 | value: 1 367 | } 368 | bias_filler { 369 | type: "constant" 370 | value: 0 371 | } 372 | } 373 | } 374 | layer { 375 | name: "relu2_1" 376 | type: "PReLU" 377 | bottom: "bn2_1" 378 | top: "bn2_1" 379 | } 380 | layer { 381 | name: "pool5" 382 | type: "Pooling" 383 | bottom: "bn2_1" 384 | top: "pool5" 385 | pooling_param { 386 | pool: MAX 387 | kernel_size: 3 388 | stride: 2 389 | } 390 | } 391 | layer { 392 | name: "fc6" 393 | type: "InnerProduct" 394 | bottom: "pool5" 395 | top: "fc6" 396 | param { 397 | lr_mult: 1 398 | decay_mult: 1 399 | } 400 | param { 401 | lr_mult: 2 402 | decay_mult: 0 403 | } 404 | inner_product_param { 405 | num_output: 4096 406 | weight_filler { 407 | type: "gaussian" 408 | std: 0.005 409 | } 410 | bias_filler { 411 | type: "constant" 412 | value: 0.1 413 | } 414 | } 415 | } 416 | layer { 417 | name: "bn2_2" 418 | type: "BN" 419 | bottom: "fc6" 420 | top: "bn2_2" 421 | param { 422 | lr_mult: 1 423 | decay_mult: 0 424 | } 425 | param { 426 | lr_mult: 1 427 | decay_mult: 0 428 | } 429 | bn_param { 430 | slope_filler { 431 | type: "constant" 432 | value: 1 433 | } 434 | bias_filler { 435 | type: "constant" 436 | value: 0 437 | } 438 | } 439 | } 440 | layer { 441 | name: "relu2_2" 442 | type: "PReLU" 443 | bottom: "bn2_2" 444 | top: "bn2_2" 445 | } 446 | layer { 447 | name: "fc7" 448 | type: "InnerProduct" 449 | bottom: "bn2_2" 450 | top: "fc7" 451 | param { 452 | lr_mult: 1 453 | decay_mult: 1 454 | } 455 | param { 456 | lr_mult: 2 457 | decay_mult: 0 458 | } 459 | inner_product_param { 460 | num_output: 4096 461 | weight_filler { 462 | type: "gaussian" 463 | std: 0.005 464 | } 465 | bias_filler { 466 | type: "constant" 467 | value: 0.1 468 | } 469 | } 470 | } 471 | layer { 472 | name: "bn2_3" 473 | type: "BN" 474 | bottom: "fc7" 475 | top: "bn2_3" 476 | param { 477 | lr_mult: 1 478 | decay_mult: 0 479 | } 480 | param { 481 | lr_mult: 1 482 | decay_mult: 0 483 | } 484 | bn_param { 485 | slope_filler { 486 | type: "constant" 487 | value: 1 488 | } 489 | bias_filler { 490 | type: "constant" 491 | value: 0 492 | } 493 | } 494 | } 495 | layer { 496 | name: "relu2_3" 497 | type: "PReLU" 498 | bottom: "bn2_3" 499 | top: "bn2_3" 500 | } 501 | layer { 502 | name: "attribute_ip2" 503 | type: "InnerProduct" 504 | bottom: "bn2_3" 505 | top: "attribute_ip2" 506 | param { 507 | lr_mult: 1 508 | decay_mult: 1e-5 509 | name: "attribute_ip2_w" 510 | } 511 | param { 512 | lr_mult: 1 513 | decay_mult: 0 514 | name: "attribute_ip2_b" 515 | } 516 | inner_product_param { 517 | num_output: 80 518 | weight_filler { 519 | type: "xavier" 520 | } 521 | bias_filler { 522 | type: "constant" 523 | } 524 | } 525 | } 526 | layer { 527 | name: "reshape_attribute" 528 | type: "Reshape" 529 | bottom: "attribute_ip2" 530 | top: "attribute_ip2_reshape" 531 | reshape_param { 532 | shape { 533 | dim: 0 # copy the dimension from below 534 | dim: 2 535 | dim: 40 536 | dim: -1 # infer it from the other dimensions 537 | } 538 | } 539 | } 540 | layer { 541 | name: "attribute_loss" 542 | type: "SoftmaxWithLoss" 543 | loss_weight: 1 544 | bottom: "attribute_ip2_reshape" 545 | bottom: "labels" 546 | top: "attribute_loss" 547 | softmax_param { 548 | axis: 1 549 | } 550 | # include: { phase: TRAIN } 551 | } 552 | layer { 553 | name: "attribute_accuracy" 554 | type: "Accuracy" 555 | bottom: "attribute_ip2_reshape" 556 | bottom: "labels" 557 | top: "attribute_accuracy" 558 | include: { phase: TEST } 559 | accuracy_param { 560 | axis: 1 561 | } 562 | } 563 | -------------------------------------------------------------------------------- /train/LFWA/solver_lfwa.prototxt: -------------------------------------------------------------------------------- 1 | net: "train/LFWA/alexnet_256_lfw.prototxt" 2 | test_iter: 1000 3 | test_interval: 1000 4 | base_lr: 0.001 5 | lr_policy: "multistep" 6 | gamma: 0.1 7 | stepvalue: 10000 8 | #for celeba the stepvalue:30000, the max_iter:40000 9 | display: 100 10 | max_iter: 20000 11 | momentum: 0.9 12 | weight_decay: 0.005 13 | snapshot: 10000 14 | snapshot_prefix: "train/LFWA/lfwa" 15 | solver_mode: GPU 16 | -------------------------------------------------------------------------------- /train/LFWA/solver_lfwa.prototxt~: -------------------------------------------------------------------------------- 1 | net: "train/LFWA/alexnet_256_lfw.prototxt" 2 | test_iter: 1000 3 | test_interval: 1000 4 | base_lr: 0.001 5 | lr_policy: "multistep" 6 | gamma: 0.1 7 | stepvalue: 10000 8 | display: 100 9 | max_iter: 20000 10 | momentum: 0.9 11 | weight_decay: 0.005 12 | snapshot: 10000 13 | snapshot_prefix: "./models/Attribute/lfwa/lfwa" 14 | solver_mode: GPU 15 | -------------------------------------------------------------------------------- /train/LFWA/train_rnn_net_lfwa.sh: -------------------------------------------------------------------------------- 1 | ./build/tools/caffe train \ 2 | --solver=train/LFWA/solver_lfwa.prototxt \ 3 | 2>&1 | tee train/LFWA/attribute.log 4 | # --weights=/home/sanyuan/CaffeProject/FA_v1-attention/models/Attribute/lfwa/step1/lfwa_iter_10000.caffemodel \ 5 | -------------------------------------------------------------------------------- /train/LFWA/train_rnn_net_lfwa.sh~: -------------------------------------------------------------------------------- 1 | ./build/tools/caffe train \ 2 | --solver=train/LFWA/solver_lfwa.prototxt \ 3 | 2>&1 | tee logs/attribute.log 4 | # --weights=/home/sanyuan/CaffeProject/FA_v1-attention/models/Attribute/lfwa/step1/lfwa_iter_10000.caffemodel \ 5 | --------------------------------------------------------------------------------