├── softmax_multilabel_output.cu
├── softmax_multilabel_output.cc
├── my_metrics.py
├── README.md
└── softmax_multilabel_output-inl.h

/softmax_multilabel_output.cu:
--------------------------------------------------------------------------------
/*!
 * Copyright (c) 2015 by Contributors
 * \file softmax_multilabel_output.cu
 * \brief
 * modified from softmax_output.cu by Bo Xin
 */

#include "./softmax_multilabel_output-inl.h"

namespace mxnet {
namespace op {
template<>
Operator *CreateOp<gpu>(SoftmaxMultilabelOutputParam param) {
  return new SoftmaxMultilabelOutputOp<gpu>(param);
}

}  // namespace op
}  // namespace mxnet

--------------------------------------------------------------------------------
/softmax_multilabel_output.cc:
--------------------------------------------------------------------------------
/*!
 * Copyright (c) 2015 by Contributors
 * \file softmax_multilabel_output.cc
 * \brief
 * modified from softmax_output.cc by Bo Xin
 */
#include "./softmax_multilabel_output-inl.h"

namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>(SoftmaxMultilabelOutputParam param) {
  return new SoftmaxMultilabelOutputOp<cpu>(param);
}

Operator *SoftmaxMultilabelOutputProp::CreateOperator(Context ctx) const {
  DO_BIND_DISPATCH(CreateOp, param_);
}

DMLC_REGISTER_PARAMETER(SoftmaxMultilabelOutputParam);

MXNET_REGISTER_OP_PROPERTY(SoftmaxMultilabelOutput, SoftmaxMultilabelOutputProp)
.describe("Perform a softmax_multilabel transformation on input, backprop with logloss.")
.add_argument("data", "Symbol", "Input data to softmax_multilabel.")
.add_arguments(SoftmaxMultilabelOutputParam::__FIELDS__());

}  // namespace op
}  // namespace mxnet

--------------------------------------------------------------------------------
/my_metrics.py:
--------------------------------------------------------------------------------
# a simple evaluation metric by Bo Xin

import numpy as np

# label: ground truth label matrix (numpy ndarray) of size n x k, where n is
#        the num of samples per batch and k is the num of labels per sample
#        (inferred below from label.shape[1]). Each entry is a label index in
#        [0, K-1], where K is the number of classes.
# pred_prob: predicted probability matrix (numpy ndarray) of size n x K.

# Accuracy1: fraction of samples whose k ground truth labels are exactly the
#            predicted top-k labels.
# Accuracy2: fraction of ground truth labels, over all samples, that are found
#            among the predicted top-k labels.

def Accuracy1(label, pred_prob):
    k = label.shape[1]  # num of labels per sample
    pred = np.argpartition(pred_prob, -k, axis=1)[:, -k:]  # top-k predictions
    t_score = np.zeros(label.shape)
    for i in range(k):
        for j in range(k):
            t_score[:, i] = t_score[:, i] + (label[:, i] == pred[:, j])
    return np.sum(np.sum(t_score, axis=1) == k) * 1.0 / pred.shape[0]

def Accuracy2(label, pred_prob):
    k = label.shape[1]  # num of labels per sample
    pred = np.argpartition(pred_prob, -k, axis=1)[:, -k:]  # top-k predictions
    t_score = np.zeros(label.shape)
    for i in range(k):
        for j in range(k):
            t_score[:, i] = t_score[:, i] + (label[:, i] == pred[:, j])
    return np.sum(t_score) * 1.0 / (pred.shape[0] * k)
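
# A minimal usage sketch (added illustration; the data below is synthetic,
# assuming n=4 samples, K=10 classes, and k=3 labels per sample).
if __name__ == "__main__":
    np.random.seed(0)
    n, K, k = 4, 10, 3
    # each row holds k distinct ground truth label indices in [0, K-1]
    label = np.array([np.random.choice(K, size=k, replace=False)
                      for _ in range(n)])
    pred_prob = np.random.rand(n, K)
    pred_prob /= pred_prob.sum(axis=1, keepdims=True)  # rows sum to 1
    print("Accuracy1:", Accuracy1(label, pred_prob))
    print("Accuracy2:", Accuracy2(label, pred_prob))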
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# multilabel-layer-mxnet

# what is it.
This is an MXNet operator layer for multilabel classification. It performs a softmax transformation on the input and backprops with multilabel log-loss.

We assume the num of labels for each sample is known and fixed, denoted k. The ground truth label matrix (of size n x k, where n is the num of samples per batch) then holds values in [0, K-1], where K is the number of classes.

(If k is not fixed across samples, one can build the ground truth label matrix of size n x km, where km is the largest possible num of labels for any sample, and set the unused entries to any value outside [0, K-1]; the backward pass skips such entries.)

# how to use it.
Put the 3 files (softmax_multilabel_output-inl.h, softmax_multilabel_output.cc, softmax_multilabel_output.cu) into mxnet/src/operator and recompile mxnet.

Then one can use it like any other operator. The following is a simple example.

    ...
    num_label = 3
    X = mx.sym.Variable("data")
    fc = mx.sym.FullyConnected(data=X, num_hidden=100, name="fc")
    msoftmax = mx.sym.SoftmaxMultilabelOutput(data=fc, name='msoftmax', num_label=num_label)
    executor = msoftmax.simple_bind(ctx=ctx, data=data_shape, grad_req='write')
    ...

The example above uses the low-level MXNet API, which is more flexible; please find more details here: https://github.com/dmlc/mxnet/blob/master/example/notebooks/simple_bind.ipynb

# some thoughts.
In theory, the ground truth matrix often appears as an n x K matrix of 0s and 1s. We assume k is usually much smaller than K, so the current n x k index representation saves space. A small sketch of both representations follows.
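
Here is a minimal numpy sketch (an added illustration; the values are made up, with n=2 samples, K=6 classes, k=2 labels, and -1 used as padding for the variable-k case):

    import numpy as np

    n, K, k = 2, 6, 2
    # index representation used by this layer: n x k, entries in [0, K-1]
    label = np.array([[0, 3],
                      [2, 5]])
    # equivalent one-hot representation: n x K, 0s and 1s
    onehot = np.zeros((n, K))
    onehot[np.arange(n)[:, None], label] = 1
    # variable-k case: pad with any value outside [0, K-1], e.g. -1,
    # which the backward pass skips
    label_var = np.array([[0, 3],
                          [2, -1]])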
--------------------------------------------------------------------------------
/softmax_multilabel_output-inl.h:
--------------------------------------------------------------------------------
/*!
 * Copyright (c) 2015 by Contributors
 * \file softmax_multilabel_output-inl.h
 * \brief
 * modified from softmax_output-inl.h by Bo Xin
 */
#ifndef MXNET_OPERATOR_SOFTMAX_MULTILABEL_OUTPUT_INL_H_
#define MXNET_OPERATOR_SOFTMAX_MULTILABEL_OUTPUT_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include "./operator_common.h"

namespace mxnet {
namespace op {

namespace softmaxmultilabelout_enum {
enum SoftmaxMultilabelOutputOpInputs {kData, kLabel};
enum SoftmaxMultilabelOutputOpOutputs {kOut};
}  // namespace softmaxmultilabelout_enum

struct SoftmaxMultilabelOutputParam : public dmlc::Parameter<SoftmaxMultilabelOutputParam> {
  float grad_scale;
  int num_label;
  DMLC_DECLARE_PARAMETER(SoftmaxMultilabelOutputParam) {
    DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f)
    .describe("Scale the gradient by a float factor");
    DMLC_DECLARE_FIELD(num_label).set_default(1)
    .describe("Set the number of labels");
  }
};

template<typename xpu>
class SoftmaxMultilabelOutputOp : public Operator {
 public:
  explicit SoftmaxMultilabelOutputOp(SoftmaxMultilabelOutputParam param) : param_(param) {}

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 2) << "SoftmaxMultilabelOutput Input: [data, label]";
    CHECK_EQ(out_data.size(), 1) << "SoftmaxMultilabelOutput Output: [output]";
    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 2> data = in_data[softmaxmultilabelout_enum::kData].FlatTo2D<xpu, real_t>(s);
    Tensor<xpu, 2> out = out_data[softmaxmultilabelout_enum::kOut].FlatTo2D<xpu, real_t>(s);
    Softmax(out, data);
  }

  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 2);
    CHECK_EQ(out_grad.size(), 1);
    CHECK_GE(in_grad.size(), 1);
    CHECK_GE(req.size(), 1);
    Stream<xpu> *s = ctx.get_stream<xpu>();

    Tensor<xpu, 2> label = in_data[softmaxmultilabelout_enum::kLabel].FlatTo2D<xpu, real_t>(s);
    Tensor<xpu, 2> out = out_data[softmaxmultilabelout_enum::kOut].FlatTo2D<xpu, real_t>(s);
    Tensor<xpu, 2> grad = in_grad[softmaxmultilabelout_enum::kData].FlatTo2D<xpu, real_t>(s);
    SoftmaxMultilabelGrad(grad, out, label);
    if (param_.grad_scale < 1.0) {
      grad *= param_.grad_scale;
    }
  }

 private:
  SoftmaxMultilabelOutputParam param_;
};  // class SoftmaxMultilabelOutputOp

// Declare Factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(SoftmaxMultilabelOutputParam param);

#if DMLC_USE_CXX11
class SoftmaxMultilabelOutputProp : public OperatorProperty {
 public:
  std::vector<std::string> ListArguments() const override {
    return {"data", "label"};
  }

  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.Init(kwargs);
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

  bool InferShape(std::vector<TShape> *in_shape,
                  std::vector<TShape> *out_shape,
                  std::vector<TShape> *aux_shape) const override {
    using namespace mshadow;
    CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
    const TShape &dshape = in_shape->at(0);
    if (dshape.ndim() == 0) return false;
    SHAPE_ASSIGN_CHECK(*in_shape, softmaxmultilabelout_enum::kLabel,
                       Shape2(dshape[0], param_.num_label));
    out_shape->clear();
    out_shape->push_back(dshape);
    return true;
  }

  OperatorProperty* Copy() const override {
    auto ptr = new SoftmaxMultilabelOutputProp();
    ptr->param_ = param_;
    return ptr;
  }

  std::string TypeString() const override {
    return "SoftmaxMultilabelOutput";
  }

  std::vector<int> DeclareBackwardDependency(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const override {
    return {in_data[softmaxmultilabelout_enum::kLabel], out_data[softmaxmultilabelout_enum::kOut]};
  }

  std::vector<std::pair<int, void*> > BackwardInplaceOption(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data,
    const std::vector<void*> &in_grad) const override {
    return {{out_data[softmaxmultilabelout_enum::kOut], in_grad[softmaxmultilabelout_enum::kData]}};
  }

  std::vector<std::pair<int, int> > ForwardInplaceOption(
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const override {
    return {{in_data[softmaxmultilabelout_enum::kData], out_data[softmaxmultilabelout_enum::kOut]}};
  }

  Operator* CreateOperator(Context ctx) const override;

 protected:
  SoftmaxMultilabelOutputParam param_;
};  // class SoftmaxMultilabelOutputProp

#endif  // DMLC_USE_CXX11

}  // namespace op
}  // namespace mxnet

// detailed implementation of both cpu and gpu
namespace mshadow {
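
// Added note: derivation of the gradient computed below (not in the original
// source). For a single ground truth label c, softmax followed by log-loss
// gives the familiar gradient over the logits z:
//   dL/dz_x = p_x - 1{x == c}.
// Here the loss is a sum of num_label such log-loss terms, one per ground
// truth label, so their gradients simply add up:
//   dst[y][x] = sum_i (p_x - 1{x == label[y][i]})
//             = (num valid labels) * p_x - (times x appears in label[y]).
// Labels outside [0, num_classes) contribute nothing, which is what makes the
// padding trick for variable label counts (see README) work.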

template<typename DType>
inline void SoftmaxMultilabelGrad(Tensor<cpu, 2, DType> dst,
                                  const Tensor<cpu, 2, DType> &src,
                                  const Tensor<cpu, 2, DType> &label) {
  for (index_t y = 0; y < dst.size(0); ++y) {
    for (index_t x = 0; x < dst.size(1); ++x) {
      dst[y][x] = 0.0f;
      for (index_t i = 0; i < label.size(1); ++i) {
        // use a signed int so that padding labels (e.g. -1) are detected
        const int k = static_cast<int>(label[y][i]);
        if (k >= 0 && k < static_cast<int>(dst.size(1))) {
          if (x == static_cast<index_t>(k)) {
            dst[y][x] += src[y][x] - 1.0f;
          } else {
            dst[y][x] += src[y][x];
          }
        }
      }
    }
  }
}

namespace cuda {
template<int x_bits, typename DstPlan, typename SrcPlan1, typename SrcPlan2>
__global__ void SoftmaxMultilabelGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label,
                                            index_t xmax, index_t lmax) {
  const unsigned x_size = 1 << x_bits;
  const int y = blockIdx.x;  // one block per row (sample)

  // each thread accumulates the gradient of the output columns it owns
  for (unsigned x = 0; x < xmax; x += x_size) {
    const unsigned xindex = x + threadIdx.x;
    if (xindex < xmax) {
      dst.REval(y, xindex) = 0.0f;
      for (unsigned i = 0; i < lmax; ++i) {
        const int k = static_cast<int>(label.Eval(y, i));
        if (k >= 0 && k < static_cast<int>(xmax)) {
          if (xindex == static_cast<unsigned>(k)) {
            dst.REval(y, xindex) += src.Eval(y, xindex) - 1.0f;
          } else {
            dst.REval(y, xindex) += src.Eval(y, xindex);
          }
        }
      }
    }
  }
}
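
// Added note on the launch configuration below (not in the original source):
// the grid has one block per row of dst (one sample per block), and the
// kBaseThreadNum threads of a block stride across the columns (classes).
// Each thread owns distinct columns, so no __syncthreads() is needed.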
template<typename DType>
inline void SoftmaxMultilabelGrad(Tensor<gpu, 2, DType> &dst,
                                  const Tensor<gpu, 2, DType> &src,
                                  const Tensor<gpu, 2, DType> &label) {
  dim3 dimBlock(kBaseThreadNum);
  dim3 dimGrid(dst.size(0));
  CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxMultilabelGrad: shape mismatch";
  CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxMultilabelGrad: label shape mismatch";
  CheckLaunchParam(dimGrid, dimBlock, "SoftmaxMultilabelGrad");
  cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
  SoftmaxMultilabelGradKernel<kBaseThreadBits>
      <<<dimGrid, dimBlock, 0, stream>>>
      (expr::MakePlan(dst),
       expr::MakePlan(src),
       expr::MakePlan(label),
       dst.size(1), label.size(1));
}
}  // namespace cuda

template<typename DType>
inline void SoftmaxMultilabelGrad(Tensor<gpu, 2, DType> dst,
                                  const Tensor<gpu, 2, DType> &src,
                                  const Tensor<gpu, 2, DType> &label) {
  cuda::SoftmaxMultilabelGrad(dst, src, label);
}

}  // namespace mshadow

#endif  // MXNET_OPERATOR_SOFTMAX_MULTILABEL_OUTPUT_INL_H_
--------------------------------------------------------------------------------