├── .gitignore
├── README.md
├── fedavg_test.py
├── gradient_disaggregation.py
├── images
│   └── grad_disaggregated.png
└── setup.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gradient Disaggregation: Breaking Privacy in Federated Learning by Reconstructing the User Participant Matrix
2 |
3 | ## Introduction
4 |
5 | We break the secure aggregation protocol of federated learning by showing that individual model updates may be recovered from sums given access to user summary metrics (specifically, participation frequency across training rounds). Our method, gradient disaggregation, observes multiple rounds of summed updates, then leverages summary metrics to recover individual updates. Read our paper here: https://arxiv.org/abs/2106.06089.
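The attack boils down to a linear system: across rounds the server observes G_agg = P·G, where P is the (unknown) 0/1 user participant matrix and the rows of G are the individual updates. The participation-frequency metrics constrain P enough to reconstruct it, and once P is known the individual updates follow from a least-squares solve. A minimal NumPy sketch of that final step (plain NumPy, not the library API; here P is assumed to already be recovered):

```python
import numpy as np

np.random.seed(0)
num_rounds, num_users, gradient_size = 100, 30, 50

P = np.random.choice([0, 1], size=(num_rounds, num_users))     # who participated in which round
G = np.random.uniform(-1, 1, size=(num_users, gradient_size))  # individual updates (the secret)
G_agg = P.dot(G)                                               # what secure aggregation reveals

# Once P is recovered, the sums no longer hide anything: solve P * G = G_agg.
G_recovered = np.linalg.lstsq(P, G_agg, rcond=None)[0]
print(np.allclose(G_recovered, G))  # True
```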
6 |
7 |
8 | ![Gradient disaggregation](images/grad_disaggregated.png)
9 |
10 |
11 | ## Requirements
12 |
13 | ```bash
14 | pip install scipy numpy
15 | ```
16 |
17 | ```bash
18 | python -m pip install -i https://pypi.gurobi.com gurobipy
19 | ```
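
To run the FedAvg example (fedavg_test.py) you will additionally need PyTorch and torchvision, e.g.:

```bash
pip install torch torchvision
```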
20 |
21 | ## Quickstart - Reconstructing Participant Matrix P
22 |
23 | Using the gradient_disaggregation code is fast and easy. Here we demonstrate recovering the participant matrix from aggregated dummy gradients.
24 |
25 | Source the setup script.
26 | ```bash
27 | source setup.sh
28 | ```
29 |
30 | Import the codebase, generate a participant matrix P, synthetic gradients G, aggregated gradients P*G, and per-user participation counts. Then call `reconstruct_participant_matrix` to recover P.
31 | ```python
32 | import gradient_disaggregation
33 | import numpy as np
34 |
35 | if __name__ == "__main__":
36 |
37 | num_users = 50
38 | num_rounds = 200
39 | gradient_size = 200
40 |
41 | # Trying to recover this!
42 | G = np.random.uniform(-1, 1, size=(num_users, gradient_size))
43 |
44 | # Trying to recover this! Tells which users participated in which rounds.
45 | P = np.random.choice([0, 1], size=(num_rounds, num_users), p=[.5,.5])
46 |
47 | # Given. Observed aggregated updates
48 | G_agg = P.dot(G)
49 |
50 |     # Given. User summary metrics: for each user, how many times they participated within each window of 10 rounds.
51 | constraints = gradient_disaggregation.compute_P_constraints(P, 10)
52 |
53 | # Disaggregate
54 | P_star = gradient_disaggregation.reconstruct_participant_matrix(G_agg, constraints, verbose=True, multiprocess=True)
55 |
56 | diff = np.sum(np.abs(P_star-P))
57 | if diff == 0:
58 | print("Exactly recovered P!")
59 | else:
60 | print("Failed to recover P!")
61 | ```
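
For reference, the `constraints` returned by `compute_P_constraints` are the user summary metrics in the form the solver expects: for each user, a list of `((round_start, round_end), count)` pairs giving how many rounds within each logging window that user participated in. A tiny sketch of the shape (printed counts are illustrative):

```python
import numpy as np
import gradient_disaggregation

P = np.random.choice([0, 1], size=(20, 3))  # 20 rounds, 3 users
constraints = gradient_disaggregation.compute_P_constraints(P, 10)
print(constraints[0])  # e.g. [((0, 10), 4), ((10, 20), 6)] for user 0
```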
62 |
63 | ## Reconstructing Aggregated FedAvg Updates
64 |
65 | Gradient disaggregation also works on noisy aggregated model updates (e.g., FedAvg) -- see fedavg_test.py.
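
Intuitively, when each observed round is a sum of noisy per-user updates (as with the local-SGD deltas in FedAvg), solving a least-squares problem against the recovered P averages much of the per-round noise away, so the disaggregated updates approximate rather than exactly equal the individual updates. A rough NumPy sketch of this effect, again assuming P has already been recovered (the library's `disaggregate_grads` additionally reconstructs P chunk-by-chunk over rounds and averages the resulting estimates):

```python
import numpy as np

np.random.seed(0)
num_rounds, num_users, gradient_size = 200, 20, 100

P = np.random.choice([0, 1], size=(num_rounds, num_users), p=[.8, .2])
G = np.random.normal(0, 1, size=(num_users, gradient_size))  # true per-user updates

# Each round aggregates a noisy version of the participating users' updates.
G_agg = np.stack([P[t].dot(G + np.random.normal(0, .2, size=G.shape))
                  for t in range(num_rounds)])

# Least squares with the recovered P largely averages the noise out.
G_est = np.linalg.lstsq(P, G_agg, rcond=None)[0]
print(np.mean(np.abs(G_est - G)))  # small relative to the noise scale (0.2)
```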
66 |
67 | ## Cite
68 |
69 | ```
70 | @inproceedings{lam2021gradient,
71 | title={Gradient Disaggregation: Breaking Privacy in Federated Learning by Reconstructing the User Participant Matrix},
72 | author={Lam, Maximilian and Wei, Gu-Yeon and Brooks, David and Reddi, Vijay Janapa and Mitzenmacher, Michael},
73 | booktitle={Proceedings of the 38th International Conference on Machine Learning},
74 | year={2021}
75 | }
76 | ```
77 |
78 |
--------------------------------------------------------------------------------
/fedavg_test.py:
--------------------------------------------------------------------------------
1 | import gradient_disaggregation
2 | from torch.autograd import grad
3 | from torchvision import models, datasets, transforms
4 | import copy
5 | import numpy as np
7 | import sys
8 | import torch
10 | import torch.nn as nn
12 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | import torchvision
17 | import torchvision.transforms as transforms
18 |
19 | torch.manual_seed(1234)
20 |
21 | class Net(nn.Module):
22 | def __init__(self, hidden_size=84):
23 | super(Net, self).__init__()
24 | self.conv1 = nn.Conv2d(3, 6, 5)
25 | self.pool = nn.MaxPool2d(2, 2)
26 | self.conv2 = nn.Conv2d(6, 16, 5)
27 | self.fc1 = nn.Linear(16 * 5 * 5, 120)
28 | self.fc2 = nn.Linear(120, hidden_size)
29 | self.fc3 = nn.Linear(hidden_size, 10)
30 |
31 | def forward(self, x):
32 | x = self.pool(F.relu(self.conv1(x)))
33 | x = self.pool(F.relu(self.conv2(x)))
34 | x = x.view(-1, 16 * 5 * 5)
35 | x = F.relu(self.fc1(x))
36 | x = F.relu(self.fc2(x))
37 | x = self.fc3(x)
38 | return x
39 |
40 | nclasses = 10
41 | net = Net()
42 | criterion = nn.CrossEntropyLoss()
43 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
44 | transform = transforms.Compose(
45 | [transforms.ToTensor(),
46 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
47 |
48 | def get_user_datasets(n_users, n_per_user):
49 | d = []
50 | dst = datasets.CIFAR10("~/.torch", download=True, transform=transform)
51 | loader = iter(torch.utils.data.DataLoader(dst, batch_size=1,
52 | shuffle=True))
53 | k = 0
54 | for i in range(n_users):
55 | user_data = []
56 | for j in range(n_per_user):
57 | element, label = next(loader)
58 | element.share_memory_()
59 | label.share_memory_()
60 | element, label = copy.deepcopy(element), copy.deepcopy(label)
61 | user_data.append((element, label))
62 | d.append(user_data)
63 | return d
64 |
65 | def get_params(n):
66 | all_vs = []
67 | for name, param in n.named_parameters():
68 | all_vs.append(param.detach().numpy().flatten())
69 | return np.concatenate(all_vs)
70 |
71 | def count_params(model):
72 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
73 |
74 | def get_batched_grad_fedavg(user_dataset, local_batchsize, net, epochs=10, momentum=0, lr=1e-2):
75 | copied_net = copy.deepcopy(net)
76 | cur_optimizer = optim.SGD(copied_net.parameters(), lr=lr, momentum=momentum)
77 |
78 | user_dataset_indices = np.random.choice(list(range(len(user_dataset))), size=len(user_dataset), replace=False)
79 | user_dataset = [user_dataset[i] for i in user_dataset_indices]
80 |
81 | for i in range(epochs):
82 | np.random.shuffle(user_dataset)
83 | for j in range(0, len(user_dataset), local_batchsize):
84 |
85 | end = min(j+local_batchsize, len(user_dataset))
86 |
87 | inputs = torch.cat([x[0] for x in user_dataset[j:end]], axis=0)
88 | labels = torch.cat([x[1] for x in user_dataset[j:end]], axis=0)
89 |
90 | cur_optimizer.zero_grad()
91 | outputs = copied_net(inputs)
92 | loss = criterion(outputs, labels)
93 | loss.backward()
94 |
95 | cur_optimizer.step()
96 |
97 | p_net = get_params(net)
98 | p_copied = get_params(copied_net)
99 | diff = p_copied - p_net
100 |
101 | return diff
102 |
103 | def get_batched_grad(user, batchsize, epochs, net, fedavg=True, momentum=0, lr=1e-2):
104 | return get_batched_grad_fedavg(user, batchsize, net, epochs=epochs, momentum=momentum, lr=lr)
105 |
106 | def aggregate_grads(P, user_datasets, batchsize, epochs, momentum=0, lr=1e-2):
107 |
108 | copied_net = copy.deepcopy(net)
109 |
110 | all_grads = []
111 | gradient = get_params(copied_net).shape[-1]
112 | for row in range(P.shape[0]):
113 | print("Aggregating... %d of %d" % (row, P.shape[0]))
114 | sys.stdout.flush()
115 | grads = np.zeros((gradient,))
116 | for col in range(P.shape[1]):
117 | if P[row,col] == 1:
118 | batched_grad = get_batched_grad(user_datasets[col], batchsize, epochs, copied_net, momentum=momentum, lr=lr)
119 | grads += batched_grad
120 | all_grads.append(grads)
121 | return np.stack(all_grads)
122 |
123 | if __name__ == "__main__":
124 | n_users = 20
125 | n_rounds = 60
126 | dataset_size_per_user = 64
127 | batchsize = 16
128 | epochs = 4
129 | granularity = 10
130 |
131 | user_datasets = get_user_datasets(n_users, dataset_size_per_user)
132 | P = np.random.choice([0, 1], size=(n_rounds, n_users), p=[.8, .2])
133 | G_agg = aggregate_grads(P, user_datasets, batchsize, epochs)
134 | constraints = gradient_disaggregation.compute_P_constraints(P, granularity)
135 |
136 | P_star = gradient_disaggregation.reconstruct_participant_matrix(G_agg, constraints, noisy=True, verbose=True)
137 |
138 | diff = np.sum(np.abs(P_star-P))
139 | if diff == 0:
140 | print("Exactly recovered P!")
141 | else:
142 | print("Failed to recover P!")
143 |
--------------------------------------------------------------------------------
/gradient_disaggregation.py:
--------------------------------------------------------------------------------
1 | # Please export MKL_NUM_THREADS=1 when disaggregating with multiprocess=True (setup.sh sets this)
2 |
3 | from numpy.linalg import svd
4 | import copy
5 | import numpy as np
6 | import random
7 | import scipy
8 | import scipy.linalg
9 | import sys
10 | import time
11 | import gurobipy as gp
12 | from gurobipy import GRB
13 | import multiprocessing as mp
14 | import os
15 |
16 | np.set_printoptions(threshold=sys.maxsize)
17 |
18 | def relative_err(A, B):
19 | return np.mean(np.abs(A-B)/(np.abs(B)+1e-9))
20 |
21 | def generate_A(r, n):
22 | return np.random.normal(0,1,size=(r,n))
23 |
24 | def generate_T(m, r, sparsity=.2):
25 | return np.float32(np.rint(np.random.uniform(0, 1, (m, r)) <= sparsity))
26 |
27 | def nullspace(A, atol=1e-13, rtol=0):
28 | A = np.atleast_2d(A)
29 | u, s, vh = svd(A)
30 | tol = max(atol, rtol * s[0])
31 | nnz = (s >= tol).sum()
32 | ns = vh[nnz:].conj().T
33 | return ns
34 |
35 | def compute_P_constraints(T, log_intervals):
36 | # From T, extract all log data that we need
37 | # - Number of participants per round
38 |     #    - Per-user participation sums across rounds
39 | r = T.shape[1]
40 | m = T.shape[0]
41 | round_points = np.sum(T, axis=1)
42 | participant_points = []
43 | for user_id in range(r):
44 | user_participation_vector = T[:,user_id]
45 | interval = 0
46 | individual_participant_points = []
47 | while interval < m:
48 | n_participated = np.sum(user_participation_vector[interval:interval+log_intervals])
49 | individual_participant_points.append(((interval,interval+log_intervals), n_participated))
50 | interval += log_intervals
51 | participant_points.append(individual_participant_points)
52 | return participant_points
53 |
54 | def disaggregate_grads(grads, participant_points, interval=None, gt=None, gt_P=None, verbose=False, noisy=True, override_dim_to_use=None):
55 |
56 | def intersects(x, y):
57 | return not (x[1] < y[0] or y[1] < x[0]) and (x[1] <= y[1] and x[0] >= y[0])
58 |
59 | # Params
60 | nrounds, ndim = grads.shape
61 | nusers = len(participant_points)
62 |
63 | # Try to determine interval from participant constraints
64 | max_interval = max([x[0][1]-x[0][0] for y in participant_points for x in y])
65 | interval = max(nusers*5,max_interval*10) if interval is None else interval
66 | interval = min(interval, grads.shape[0])
67 |
68 | # Calculate dimension based on matrix rank
69 | if override_dim_to_use is not None:
70 | ndim_to_use = override_dim_to_use
71 | else:
72 | ndim_to_use = min(ndim, max(nrounds, nusers))
73 | for ndim_to_use in range(max(nrounds,nusers), ndim, (ndim-max(nrounds,nusers))//10):
74 | if np.linalg.matrix_rank(grads[0:interval,0:ndim_to_use]) >= nusers:
75 | break
76 | ndim_to_use = ndim
77 |
78 | # Disaggregate gradients averaging over rounds
79 | disaggregated = 0
80 | c = 0
81 | all_P = []
82 | for total, row in list(enumerate(range(0, nrounds, interval))):
83 | if verbose:
84 | print("Set %d of %d" % (total, nrounds//interval))
85 | lower,upper = row,row+interval
86 | upper = min(upper, nrounds)
87 | cutoff_start = 0
88 | if upper-lower < interval:
89 | cutoff_start = np.abs((upper-interval)-lower)
90 | lower = upper-interval
91 |
92 | aggregated_chunk = grads[lower:upper,0:ndim_to_use]
93 | filtered_participant_points = [[y for y in x if intersects(y[0], (lower,upper))] for x in participant_points]
94 | filtered_participant_points = [[((y[0][0]-lower, y[0][1]-lower), y[1]) for y in x] for x in filtered_participant_points]
95 | Ps = reconstruct_participant_matrix(aggregated_chunk, filtered_participant_points, verbose=verbose, noisy=noisy)
96 | if len(Ps) == 1:
97 | P = Ps[0]
98 | all_P.append(P[cutoff_start:,:])
99 | disaggregated += np.linalg.lstsq(P, grads[lower:upper,:], rcond=None)[0]
100 | c += 1
101 | if gt_P is not None:
102 | ham_dist = np.sum(np.abs(gt_P[lower:upper,:]-P))
103 | if verbose:
104 | print("Ham distance of P Matrix: %f" % ham_dist)
105 | if gt is not None:
106 | rel_err = relative_err(disaggregated/c, gt)
107 | if verbose:
108 | print("Relative Error vs Ground Truth: %f" % rel_err)
109 |
110 | if c == 0:
111 | return float("inf")
112 | return disaggregated / c, np.concatenate(all_P, axis=0)
113 |
114 | def reconstruct_participant_matrix(D, participant_points,
115 | top_k=None, verbose=False, noisy=False, column_timelimit=None,
116 | exit_after_tle=False, multiprocess=True):
117 | # D - T*Grads where T is the user participant matrix
118 |     # participant_points - per user, a list of ((round_start, round_end), count)
119 |     #                      pairs giving how many rounds within that interval the
120 |     #                      user participated in (see compute_P_constraints)
122 |
123 | # Candidates for participant vectors
124 |
125 | if not multiprocess:
126 | participant_vectors = []
127 | n_possibilities_per_user = []
128 | statuses = []
129 | cached_space = None
130 | nrounds = D.shape[0]
131 | for i, individual_participant_point in enumerate(participant_points):
132 | if top_k is not None and len(participant_vectors) >= top_k:
133 | break
134 | if verbose:
135 | print("Reconstructing: %d of %d" % (i,len(participant_points)))
136 | p_vector, model, cached_space = compute_participant_candidate_vectors(D, individual_participant_point, nrounds, len(participant_points), noisy=noisy, timelimit=column_timelimit, cached_space=cached_space)
137 | if model.status == GRB.TIME_LIMIT and exit_after_tle:
138 | if verbose:
139 | print("Time limit exceeded")
140 | return []
141 | participant_vectors.append(p_vector)
142 | n_possibilities_per_user.append(len(p_vector))
143 | if verbose:
144 | print("Obtained candidate vector for user %d" % (i))
145 | sys.stdout.flush()
146 | if len(p_vector) <= 0:
147 | return []
148 |
149 | if verbose:
150 | print("# of candidate vectors per user: %s" % str(n_possibilities_per_user))
151 |
152 | # Naive method -- just stack and return
153 | naive = np.stack([x[0] for x in participant_vectors]).T
154 | return [np.rint(naive)]
155 |
156 | # Multiprocessing code
157 | nrounds = D.shape[0]
158 | nusers = len(participant_points)
159 |
160 | cached_space = compute_space(D, nusers, noisy=noisy)
161 | arguments = []
162 | for i, participant_point in enumerate(participant_points):
163 | arguments.append((i, None, participant_point, nrounds, nusers, noisy, column_timelimit, False,cached_space, verbose, exit_after_tle))
164 | if top_k is not None:
165 | arguments = arguments[:top_k]
166 |
167 | cores = mp.cpu_count()
168 | pool = mp.Pool(processes=cores)
169 | try:
170 | indices_and_participant_vectors = pool.imap_unordered(compute_participant_candidate_vectors_multicore_wrapper, arguments)
171 | indices_and_participant_vectors = sorted(indices_and_participant_vectors, key=lambda x:x[0])
172 | participant_vectors = [x[1] for x in indices_and_participant_vectors]
173 | except Exception:
174 | if verbose:
175 | print("Timelimit exceeded on one of the workers. Exiting")
176 | pool.close()
177 | pool.terminate()
178 | participant_vectors = [np.zeros((1,nrounds)) for i in range(nusers)]
179 | else:
180 | pool.close()
181 | pool.join()
182 |
183 | participant_vectors = np.concatenate([x for x in participant_vectors], axis=0).T
184 | return [np.rint(participant_vectors)]
185 |
186 | def compute_space(D, nusers, noisy=True):
187 |
188 | # We use SVD to handle noisy gradients
189 | if noisy:
190 | u,s,vh = np.linalg.svd(D, full_matrices=False)
191 | cutoff = sorted(list(s.flatten()), reverse=True)[nusers]
192 |         s[s <= cutoff] = 0
270 |     assert(np.linalg.matrix_rank(T) >= r)
271 |
272 | participant_points = compute_P_constraints(T, log_intervals)
273 |
274 | # Solve
275 | T_reconstructed_candidates = reconstruct_participant_matrix(D, participant_points, verbose=True)
276 | T_reconstructed = T_reconstructed_candidates[0]
277 | err = np.linalg.norm(T_reconstructed-T)
278 | print("Error between reconstructed vs truth: %f" %
279 | err)
280 | assert(err <= 1e-5)
281 |
282 | def test_recover_grad_sanity():
283 |
284 | # m - number of rounds
285 | # r - number of users
286 | # n - dimension of gradient
287 | # sparsity - percent participate per round
288 | # log_intervals - every log_intervals rounds will log how many times a user participated
289 | m = 5000
290 | r = 80
291 | n = 400
292 | sparsity=.1
293 | log_intervals = 2*int(1/sparsity)
294 |
295 | # Create data
296 | user_grads = np.random.normal(0, 1, size=(r,n))
297 |
298 | # Construct rounds
299 | P = []
300 | aggregated = []
301 | total = 0
302 | for round in range(m):
303 | row = generate_T(1, r, sparsity=sparsity)
304 | with_noise = user_grads + np.random.normal(0, .2, size=user_grads.shape)
305 | total += with_noise
306 | row_grads = row.dot(with_noise)
307 | P.append(row)
308 | aggregated.append(row_grads)
309 | P = np.concatenate(P, axis=0)
310 | aggregated = np.concatenate(aggregated, axis=0)
311 |
312 | print("Number of rounds: %d" % m)
313 | print("Number of users: %d" % r)
314 | print("Number of grad dims: %d" % n)
315 | print("Rank of participant matrix: %d" % np.linalg.matrix_rank(P))
316 | assert(np.linalg.matrix_rank(P) >= r)
317 |
318 | participant_points = compute_P_constraints(P, log_intervals)
319 |
320 | disaggregated, P = disaggregate_grads(aggregated, participant_points, verbose=True, gt=user_grads, gt_P=P)
321 | rel_err = relative_err(disaggregated, user_grads)
322 | print("Relative Error (percentage): %f" % rel_err)
323 |
324 |
325 | def test_recover_grad():
326 |
327 | # m - number of rounds
328 | # r - number of users
329 | # n - dimension of gradient
330 | # sparsity - percent participate per round
331 | # log_intervals - every log_intervals rounds will log how many times a user participated
332 | m = 400
333 | r = 40
334 | n = 400
335 | sparsity=.5
336 | log_intervals = int(1/sparsity)*5
337 | noise = .2
338 | mat_intervals = 100
339 |
340 | # Create data
341 | A = np.random.normal(1, .1, size=(r,n))
342 | Ds = []
343 | Ts = []
344 | for i in range(0, m, mat_intervals):
345 | random_noise = np.random.normal(0, noise, size=(r,n))
346 | T = generate_T(mat_intervals,r,sparsity=sparsity)
347 | Ts.append(T)
348 | Ds.append(T.dot(A+random_noise))
349 | D = np.concatenate(Ds, axis=0)
350 | T = np.concatenate(Ts, axis=0)
351 |
352 |
353 | print("Number of rounds: %d" % m)
354 | print("Number of users: %d" % r)
355 | print("Number of grad dims: %d" % n)
356 | print("Rank of participant matrix: %d" % np.linalg.matrix_rank(T))
357 | assert(np.linalg.matrix_rank(T) >= r)
358 |
359 | participant_points = compute_P_constraints(T, log_intervals)
360 |
361 | disaggregated, P = disaggregate_grads(D, participant_points, verbose=True, gt=A, interval=mat_intervals)
362 | rel_err = relative_err(disaggregated, A)
363 | print("Relative Error (percentage): %f" % rel_err)
364 |
365 | if __name__=="__main__":
366 |
367 | np.random.seed(0)
368 | #test_recover_p()
369 | #test_recover_grad()
370 | test_recover_grad_sanity()
371 |
--------------------------------------------------------------------------------
/images/grad_disaggregated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gdisag/gradient_disaggregation/416b0e3175d97d0f269baaf7bc54a93d241e992d/images/grad_disaggregated.png
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | toplevel=`git rev-parse --show-toplevel`
2 | export PYTHONPATH=$PYTHONPATH:$toplevel/
3 | export MKL_NUM_THREADS=1
4 |
5 |
--------------------------------------------------------------------------------