├── PL_kNN.m
├── demo_data.mat
├── graph_construction.m
├── test_data_aug_gen.m
├── demo.m
├── LabelPropagationSettings.m
├── README.md
└── label_propagation.m


/PL_kNN.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwangwitsel/PLDA/HEAD/PL_kNN.m


--------------------------------------------------------------------------------
/demo_data.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwangwitsel/PLDA/HEAD/demo_data.mat


--------------------------------------------------------------------------------
/graph_construction.m:
--------------------------------------------------------------------------------
1 | function S = graph_construction(train_data, k)
2 | %GRAPH_CONSTRUCTION Symmetric kNN similarity graph with Gaussian edge weights.
3 | %   S = graph_construction(train_data, k)
4 | %
5 | %   Inputs:
6 | %     train_data - data matrix with one instance per row
7 | %     k          - number of nearest neighbors linked to each instance
8 | %
9 | %   Output:
10 | %     S - n-by-n matrix, n = size(train_data, 1). The directed weight from
11 | %         instance i to each of its k nearest neighbors j is
12 | %         exp(-d(i,j)^2 / sigma^2), where sigma is the mean distance to the
13 | %         k-th neighbor. S is symmetrized as S + S', so a mutual neighbor
14 | %         pair ends up with the sum of both directed weights.
15 | 
16 | n = size(train_data, 1);
17 | 
18 | kdtree = KDTreeSearcher(train_data);
19 | % Query k+1 neighbors and drop the first column (presumably the query point itself, at distance 0).
20 | [neighbor, distance] = knnsearch(kdtree, train_data, 'k', k+1);
21 | neighbor = neighbor(:, 2:k+1);
22 | distance = distance(:, 2:k+1);
23 | 
24 | % Kernel width: mean distance from an instance to its k-th neighbor.
25 | sigma = mean(distance(:, k));
26 | % If every k-th neighbor is at distance 0 the weights would be exp(-0/0) = NaN.
27 | sigma = max(sigma, eps);
28 | 
29 | % Write all n*k directed edges in one indexed assignment instead of a double loop.
30 | rows = repmat((1:n)', 1, k);
31 | S = zeros(n, n);
32 | S(sub2ind([n, n], rows, neighbor)) = exp(-distance .* distance / (sigma * sigma));
33 | 
34 | S = S + S'; % to ensure symmetric
35 | end
36 | 


--------------------------------------------------------------------------------
/test_data_aug_gen.m:
--------------------------------------------------------------------------------
1 | function test_data_aug = test_data_aug_gen(train_data, label_confidence, prototype, test_data, k)
2 | %TEST_DATA_AUG_GEN Append the augmentation features to the test instances.
3 | %   test_data_aug = test_data_aug_gen(train_data, label_confidence, prototype, test_data, k)
4 | %
5 | %   Each test instance gets a confidence vector as the kernel-weighted
6 | %   average of label_confidence over its k nearest training instances. The
7 | %   prototype row of the highest-confidence label is appended to its features.
8 | %
9 | %   Inputs:
10 | %     train_data       - training instances, one per row
11 | %     label_confidence - num_train-by-label_num confidence matrix
12 | %     prototype        - matrix with one row per label
13 | %     test_data        - test instances, one per row
14 | %     k                - number of training neighbors used per test instance
15 | %
16 | %   Output:
17 | %     test_data_aug - [test_data, prototype row of the predicted label]
18 | 
19 | [num_train, label_num] = size(label_confidence);
20 | num_test = size(test_data, 1);
21 | 
22 | [neighbor, distance] = knnsearch(train_data, test_data, 'k', k);
23 | 
24 | % Kernel width: mean distance from a test instance to its k-th training neighbor.
25 | sigma = mean(distance(:, k));
26 | % If every k-th neighbor is at distance 0 the weights would be exp(-0/0) = NaN.
27 | sigma = max(sigma, eps);
28 | 
29 | % Test-to-train similarity, filled in one indexed assignment instead of a double loop.
30 | rows = repmat((1:num_test)', 1, k);
31 | Spt = zeros(num_test, num_train);
32 | Spt(sub2ind([num_test, num_train], rows, neighbor)) = exp(-distance .* distance / (sigma * sigma));
33 | 
34 | % Row-normalize so every test instance distributes unit weight over its neighbors.
35 | Spt = Spt ./ repmat(sum(Spt, 2), 1, num_train);
36 | test_confidence = Spt * label_confidence;
37 | 
38 | % One-hot encode the highest-confidence label of every test instance.
39 | [~, test_pred_label] = max(test_confidence, [], 2);
40 | test_pred_target = zeros(num_test, label_num);
41 | test_pred_target(sub2ind(size(test_pred_target), 1:num_test, test_pred_label')) = 1;
42 | 
43 | test_aug_feature = test_pred_target * prototype;
44 | test_data_aug = [test_data, test_aug_feature];
45 | end
46 | 


--------------------------------------------------------------------------------
/demo.m:
--------------------------------------------------------------------------------
1 | % Demo: PL-kNN on the demo data, optionally with PLDA feature augmentation.
2 | % NOTE(review): demo_data.mat is expected to define train_data, train_p_target,
3 | % test_data and test_target -- these are the variables used below.
4 | load('demo_data.mat');
5 | %hyperparameter for PL-kNN
6 | k = 10;
7 | %hyperparameter for PLDA
8 | lambda = 0.01;
9 | % If we want to run PLDA for feature augmentation, set aug_flag = 1; otherwise set aug_flag = 0.
10 | aug_flag = 0;
11 | 
12 | % Standardize both sets with the TRAINING mean/std. Z-scoring the test set
13 | % with its own statistics would put train and test features on different
14 | % scales, which distorts the distances used by the kNN steps.
15 | [train_data, feat_mean, feat_std] = zscore(train_data);
16 | feat_std(feat_std == 0) = 1; % constant training features: avoid dividing by zero
17 | test_data = bsxfun(@rdivide, bsxfun(@minus, test_data, feat_mean), feat_std);
18 | 
19 | if aug_flag == 1
20 |     S = graph_construction(train_data, k);
21 |     [label_confidence, prototype] = label_propagation(train_data,train_p_target, S, lambda);
22 |     aug_feature = label_confidence * prototype;
23 |     train_data_aug = [train_data, aug_feature];
24 |     test_data_aug = test_data_aug_gen(train_data, label_confidence, prototype, test_data, k);
25 |     [accuracy,~] = PL_kNN(train_data_aug,train_p_target,test_data_aug,test_target,k);
26 | elseif aug_flag == 0
27 |     [accuracy,~] = PL_kNN(train_data,train_p_target,test_data,test_target,k);
28 | else
29 |     % Previously an unsupported value silently did nothing.
30 |     error('aug_flag must be 0 or 1.');
31 | end
32 | fprintf('classification accuracy: %.3f\n', accuracy);
33 | 


--------------------------------------------------------------------------------
/LabelPropagationSettings.m:
--------------------------------------------------------------------------------
1 | function [H, Aeq, beq, lb, ub, opts] = LabelPropagationSettings(S, train_p_target)
2 | %LABELPROPAGATIONSETTINGS Constant parts of the label propagation QP.
3 | %   The label propagation problem is formulated as a standard quadratic
4 | %   programming problem over f = F(:), the column-wise vectorization of the
5 | %   m-by-l confidence matrix F (instance i, label c -> index i + (c-1)*m).
6 | %
7 | %   Inputs:
8 | %     S              - m-by-m similarity matrix
9 | %     train_p_target - m-by-l candidate label matrix (instances in rows)
10 | %
11 | %   Outputs:
12 | %     H        - (m*l)-by-(m*l) sparse block-diagonal quadratic term
13 | %     Aeq, beq - equality constraints: the confidences of each instance sum to 1
14 | %     lb, ub   - bounds: 0 <= f <= train_p_target(:)
15 | %     opts     - quadprog options (interior-point-convex, no display)
16 | 
17 | [m, l] = size(train_p_target);
18 | total = m*l;
19 | 
20 | lb = sparse(total, 1);
21 | ub = reshape(train_p_target, total, 1);
22 | 
23 | % One identical block per label: 4*I - 4*D^(-1/2)*S*D^(-1/2) with D = diag(d).
24 | % The 1e-10 keeps the division finite when a row of S sums to zero.
25 | d = sum(S,2); % m by 1 vector
26 | H = kron(speye(l,l), 4*eye(m,m)-4*(S./(sqrt(d*d') + 1e-10))); % for small datasets
27 | 
28 | % Row i of Aeq selects columns i, i+m, ..., i+(l-1)*m, i.e. all labels of
29 | % instance i. Tiling the identity builds this without a per-row sparse update.
30 | Aeq = repmat(speye(m), 1, l);
31 | beq = ones(m, 1);
32 | opts = optimoptions('quadprog',...
33 |     'Algorithm','interior-point-convex','Display','off');
34 | end


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PLDA
2 | This repository is the official implementation of the PLDA approach from the paper "Partial Label Learning with Discrimination Augmentation". Technical details of the approach can be found in the paper.
3 | 
4 | ## Requirements
5 | - MATLAB, version 2014a or higher.
6 | - The code calls `knnsearch`, `pdist2` and `zscore` (Statistics Toolbox) as well as `quadprog` and `optimoptions` (Optimization Toolbox).
7 | 
8 | To start, create a directory of your choice and copy the code there.
9 | 
10 | Add the directory you just created to your MATLAB path.
11 | Then, run this command to enter the MATLAB environment:
12 | ```
13 | matlab
14 | ```
15 | ## Demo
16 | This repository provides a demo which shows the training and testing phases of the PLDA approach coupled with one of the state-of-the-art partial label learning (PLL) approaches, PL-kNN. A coupled version for another PLL approach can be implemented easily by replacing PL-kNN with the chosen PLL approach.
17 | 
18 | To run demo.m, run this command in the MATLAB command window:
19 | 
20 | ```
21 | demo
22 | ```
23 | 
24 | ## Citation
25 | ```
26 | @inproceedings{KDD22Wang,
27 |     author = {Wang, Wei and Zhang, Min-Ling},
28 |     title = {Partial label learning with discrimination augmentation},
29 |     booktitle = {Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
30 |     year = {2022},
31 |     pages = {1920--1928}
32 | }
33 | ```
34 | 


--------------------------------------------------------------------------------
/label_propagation.m:
--------------------------------------------------------------------------------
1 | function [F, prototype] = label_propagation(train_p_data, train_p_target, S, lambda)
2 | %LABEL_PROPAGATION Estimate label confidences and class prototypes.
3 | %   This function estimates Fp via label propagation (line 1 in Table 1).
4 | %   It alternates between solving a quadratic program for the confidence
5 | %   matrix F and recomputing the prototypes as confidence-weighted means of
6 | %   the training instances.
7 | %
8 | %   Inputs:
9 | %     train_p_data   - training instances, one per row
10 | %     train_p_target - label_num-by-p_data_num candidate label matrix
11 | %     S              - similarity matrix passed to LabelPropagationSettings
12 | %     lambda         - weight of the squared instance-to-prototype distances
13 | %                      in the linear term of the QP
14 | %
15 | %   Outputs:
16 | %     F         - p_data_num-by-label_num confidence matrix
17 | %     prototype - one row per label
18 | 
19 | lp_max_iter = 100;
20 | [label_num,p_data_num] = size(train_p_target);
21 | [H, Aeq, beq, lb, ub, opts] = LabelPropagationSettings(S, train_p_target');
22 | 
23 | % Initialize F uniformly over the candidate labels of every instance.
24 | F = train_p_target';
25 | ins_cap = sum(F,2);
26 | if any(ins_cap == 0)
27 |     % Such a row would become NaN below and makes sum(F,2) == 1 infeasible.
28 |     error('label_propagation:emptyCandidateSet', ...
29 |         'Every training instance needs at least one candidate label.');
30 | end
31 | F = F./repmat(ins_cap,1,label_num);
32 | prototype = compute_prototype(F, train_p_data);
33 | F_old = F;
34 | for i = 1:lp_max_iter
35 |     if mod(i,10)==0
36 |         fprintf('label propagation iteration: %d\n',i);
37 |     end
38 |     % Linear term: squared distance of every instance to every prototype.
39 |     cluster_error_mat = pdist2(train_p_data, prototype);
40 |     cluster_error_vec = reshape(cluster_error_mat.^2, p_data_num*label_num, 1);
41 |     f_vec = quadprog(H, lambda * cluster_error_vec, [], [], Aeq, beq, lb, ub, [], opts);
42 |     if isempty(f_vec)
43 |         % Fail with a clear message instead of an obscure reshape error.
44 |         error('label_propagation:quadprogFailed', ...
45 |             'quadprog returned no solution at iteration %d.', i);
46 |     end
47 |     F = reshape(f_vec, p_data_num, label_num);
48 |     prototype = compute_prototype(F, train_p_data);
49 |     % NOTE(review): this compares the Frobenius norms of F and F_old, not
50 |     % norm(F - F_old, 'fro'); F can change while its norm does not. Kept
51 |     % unchanged to preserve the published behavior -- confirm the intent.
52 |     if abs(norm(F,'fro')-norm(F_old,'fro')) < 1e-2
53 |         fprintf('label propagation iteration end at: %d\n',i);
54 |         break;
55 |     end
56 |     F_old = F;
57 | end
58 | end
59 | 
60 | function prototype = compute_prototype(F, train_p_data)
61 | % Confidence-weighted mean of the training instances for every label.
62 | label_cap = sum(F, 1);
63 | % A label with zero total confidence would give 0/0 = NaN; dividing by 1
64 | % instead leaves its prototype as the all-zero row.
65 | label_cap(label_cap <= 0) = 1;
66 | proto_conf = F ./ repmat(label_cap, size(F, 1), 1);
67 | prototype = proto_conf' * train_p_data;
68 | end
69 | 


--------------------------------------------------------------------------------