├── .gitattributes ├── .gitignore ├── Benchmark.ipynb ├── Proof Weighted Sampling.pdf ├── README.md ├── csrc ├── cpu │ ├── reservoir_sampling.cpp │ └── reservoir_sampling.h ├── cuda │ ├── reservoir_sampling.cu │ └── reservoir_sampling.cuh └── sampling.cpp └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | build 3 | *.so 4 | -------------------------------------------------------------------------------- /Benchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "from torch_sampling import choice" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "# Checking non-contiguous tensors" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "True\n", 32 | "False\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "x = torch.arange(10)\n", 38 | "y = x[::3]\n", 39 | "for t in [x, y]:\n", 40 | " print(t.is_contiguous())" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "Sampling with replacement:\n", 53 | "tensor([3, 3, 9])\n", 54 | "tensor([3, 9, 3])\n", 55 | "tensor([3, 9, 3])\n", 56 | "tensor([0, 3, 6])\n", 57 | "tensor([9, 3, 6])\n", 58 | "Sampling without replacement:\n", 59 | "tensor([0, 
3, 6])\n", 60 | "tensor([0, 3, 6])\n", 61 | "tensor([0, 3, 6])\n", 62 | "tensor([0, 3, 9])\n", 63 | "tensor([0, 3, 9])\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "k=3\n", 69 | "print(\"Sampling with replacement:\")\n", 70 | "for _ in range(5):\n", 71 | " print(choice(y, k, True))\n", 72 | "print(\"Sampling without replacement:\")\n", 73 | "for _ in range(5):\n", 74 | " print(choice(y, k, False))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "True\n", 87 | "False\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "x = torch.arange(10).cuda()\n", 93 | "y = x[::3]\n", 94 | "for t in [x, y]:\n", 95 | " print(t.is_contiguous())" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "Sampling with replacement:\n", 108 | "tensor([0, 0, 9], device='cuda:0')\n", 109 | "tensor([3, 3, 9], device='cuda:0')\n", 110 | "tensor([3, 6, 6], device='cuda:0')\n", 111 | "tensor([0, 3, 9], device='cuda:0')\n", 112 | "tensor([3, 0, 9], device='cuda:0')\n", 113 | "Sampling without replacement:\n", 114 | "tensor([0, 3, 6], device='cuda:0')\n", 115 | "tensor([9, 3, 6], device='cuda:0')\n", 116 | "tensor([0, 3, 9], device='cuda:0')\n", 117 | "tensor([0, 9, 6], device='cuda:0')\n", 118 | "tensor([0, 3, 9], device='cuda:0')\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "k=3\n", 124 | "print(\"Sampling with replacement:\")\n", 125 | "for _ in range(5):\n", 126 | " print(choice(y, k, True))\n", 127 | "print(\"Sampling without replacement:\")\n", 128 | "for _ in range(5):\n", 129 | " print(choice(y, k, False))" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "# Checking determinism" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | 
"execution_count": 6, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "tensor([5, 1, 6])\n", 149 | "tensor([5, 1, 6])\n", 150 | "tensor([5, 1, 6])\n", 151 | "tensor([5, 1, 6])\n", 152 | "tensor([5, 1, 6])\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "x = torch.arange(10)\n", 158 | "k=3\n", 159 | "for _ in range(5):\n", 160 | " torch.manual_seed(1234)\n", 161 | " print(choice(x, k, True))" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "tensor([5, 1, 9], device='cuda:0')\n", 174 | "tensor([5, 1, 9], device='cuda:0')\n", 175 | "tensor([5, 1, 9], device='cuda:0')\n", 176 | "tensor([5, 1, 9], device='cuda:0')\n", 177 | "tensor([5, 1, 9], device='cuda:0')\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "x = torch.arange(10).cuda()\n", 183 | "for _ in range(5):\n", 184 | " torch.manual_seed(1234)\n", 185 | " print(choice(x, k, True))" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "# Benchmarks against NumPy - Uniform" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 8, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "x = torch.arange(10**4)\n", 202 | "x_np = x.numpy()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 9, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "95.4 µs ± 489 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 215 | "126 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 216 | "48.4 µs ± 675 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 217 | "138 µs ± 1.42 µs per loop (mean ± std. dev. 
of 7 runs, 10000 loops each)\n" 218 | ] 219 | } 220 | ], 221 | "source": [ 222 | "k=9000\n", 223 | "%timeit choice(x, k, True)\n", 224 | "%timeit np.random.choice(x_np, k, True)\n", 225 | "%timeit choice(x, k, False)\n", 226 | "%timeit np.random.choice(x_np, k, False)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 10, 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "5.71 µs ± 15.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n", 239 | "17.3 µs ± 73.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n", 240 | "14.4 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n", 241 | "130 µs ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "k=100\n", 247 | "%timeit choice(x, k, True)\n", 248 | "%timeit np.random.choice(x_np, k, True)\n", 249 | "%timeit choice(x, k, False)\n", 250 | "%timeit np.random.choice(x_np, k, False)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 11, 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "name": "stdout", 260 | "output_type": "stream", 261 | "text": [ 262 | "53.2 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 263 | "71.9 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 264 | "72.5 µs ± 93.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 265 | "71.7 µs ± 258 ns per loop (mean ± std. dev. 
of 7 runs, 10000 loops each)\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "k=4500\n", 271 | "%timeit choice(x, k, True)\n", 272 | "%timeit np.random.choice(x_np, k, True)\n", 273 | "%timeit choice(x, k, False)\n", 274 | "%timeit np.random.choice(x_np, k, True)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "# Benchmarks against NumPy - Weighted" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 12, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "weights = torch.rand(10**4)\n", 291 | "weights /= weights.sum()\n", 292 | "weights_np = weights.numpy()" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 13, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "773 µs ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 305 | "872 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 306 | "373 µs ± 999 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 307 | "2.92 ms ± 6.24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "k=9000\n", 313 | "%timeit choice(x, k, True, weights)\n", 314 | "%timeit np.random.choice(x_np, k, True, weights_np)\n", 315 | "%timeit choice(x, k, False, weights)\n", 316 | "%timeit np.random.choice(x_np, k, False, weights_np)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 14, 322 | "metadata": {}, 323 | "outputs": [ 324 | { 325 | "name": "stdout", 326 | "output_type": "stream", 327 | "text": [ 328 | "407 µs ± 4.78 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 329 | "495 µs ± 8.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 330 | "295 µs ± 600 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 331 | "1.27 ms ± 4.14 µs per loop (mean ± std. dev. 
of 7 runs, 1000 loops each)\n" 332 | ] 333 | } 334 | ], 335 | "source": [ 336 | "k=4500\n", 337 | "%timeit choice(x, k, True, weights)\n", 338 | "%timeit np.random.choice(x_np, k, True, weights_np)\n", 339 | "%timeit choice(x, k, False, weights)\n", 340 | "%timeit np.random.choice(x_np, k, False, weights_np)" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 15, 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "name": "stdout", 350 | "output_type": "stream", 351 | "text": [ 352 | "43.7 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 353 | "110 µs ± 903 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 354 | "235 µs ± 3.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 355 | "172 µs ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 356 | ] 357 | } 358 | ], 359 | "source": [ 360 | "k=100\n", 361 | "%timeit choice(x, k, True, weights)\n", 362 | "%timeit np.random.choice(x_np, k, True, weights_np)\n", 363 | "%timeit choice(x, k, False, weights)\n", 364 | "%timeit np.random.choice(x_np, k, False, weights_np)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "# Check performance for multi-d tensors" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 16, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "x = torch.arange(10**4).view(-1, 2)\n", 381 | "n = x.size(0)\n", 382 | "k = 3\n", 383 | "idx = torch.arange(n)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 17, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "name": "stdout", 393 | "output_type": "stream", 394 | "text": [ 395 | "70.8 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 396 | "12.1 µs ± 103 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n", 397 | "19 µs ± 172 ns per loop (mean ± std. dev. 
of 7 runs, 100000 loops each)\n" 398 | ] 399 | } 400 | ], 401 | "source": [ 402 | "%timeit x[torch.randperm(n)[:k]]\n", 403 | "%timeit x[choice(idx, k, True)]\n", 404 | "%timeit x[choice(idx, k, False)]" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "# Checking the performance" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 18, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "x = torch.arange(10**4)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": {}, 426 | "source": [ 427 | "# Case 1: k big" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 19, 433 | "metadata": {}, 434 | "outputs": [ 435 | { 436 | "name": "stdout", 437 | "output_type": "stream", 438 | "text": [ 439 | "95.4 µs ± 75.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 440 | "45.8 µs ± 56.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 441 | "167 µs ± 3.14 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 442 | ] 443 | } 444 | ], 445 | "source": [ 446 | "k = 9000\n", 447 | "%timeit choice(x, k, True)\n", 448 | "%timeit choice(x, k, False)\n", 449 | "%timeit x[torch.randperm(x.numel())[:k]]" 450 | ] 451 | }, 452 | { 453 | "cell_type": "markdown", 454 | "metadata": {}, 455 | "source": [ 456 | "# Case 2: k small" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 20, 462 | "metadata": {}, 463 | "outputs": [ 464 | { 465 | "name": "stdout", 466 | "output_type": "stream", 467 | "text": [ 468 | "5.79 µs ± 5.46 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n", 469 | "14.3 µs ± 16.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n", 470 | "130 µs ± 209 ns per loop (mean ± std. dev. 
of 7 runs, 10000 loops each)\n" 471 | ] 472 | } 473 | ], 474 | "source": [ 475 | "k = 100\n", 476 | "%timeit choice(x, k, True)\n", 477 | "%timeit choice(x, k, False)\n", 478 | "%timeit x[torch.randperm(x.numel())[:k]]" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "# Case 3: k medium" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 21, 491 | "metadata": {}, 492 | "outputs": [ 493 | { 494 | "name": "stdout", 495 | "output_type": "stream", 496 | "text": [ 497 | "52.5 µs ± 32.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 498 | "72.9 µs ± 681 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", 499 | "151 µs ± 389 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 500 | ] 501 | } 502 | ], 503 | "source": [ 504 | "k = 4500\n", 505 | "%timeit choice(x, k, True)\n", 506 | "%timeit choice(x, k, False)\n", 507 | "%timeit x[torch.randperm(x.numel())[:k]]" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "# CUDA vs CPU" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 22, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "x_cpu = torch.arange(10**7)\n", 524 | "x_cuda = x_cpu.cuda()\n", 525 | "k = 10**4" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 23, 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "name": "stdout", 535 | "output_type": "stream", 536 | "text": [ 537 | "208 µs ± 1.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 538 | "27.1 µs ± 269 ns per loop (mean ± std. dev. 
of 7 runs, 10000 loops each)\n" 539 | ] 540 | } 541 | ], 542 | "source": [ 543 | "%timeit choice(x_cpu, k, True)\n", 544 | "%timeit choice(x_cuda, k, True);torch.cuda.synchronize()" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 24, 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "name": "stdout", 554 | "output_type": "stream", 555 | "text": [ 556 | "6.87 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", 557 | "3.33 ms ± 2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 558 | ] 559 | } 560 | ], 561 | "source": [ 562 | "%timeit choice(x_cpu, k, False)\n", 563 | "%timeit choice(x_cuda, k, False);torch.cuda.synchronize()" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": 25, 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [ 572 | "weights_cpu = torch.ones(10**7).double()\n", 573 | "weights_cuda = weights_cpu.cuda()" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 26, 579 | "metadata": {}, 580 | "outputs": [ 581 | { 582 | "name": "stdout", 583 | "output_type": "stream", 584 | "text": [ 585 | "335 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", 586 | "84 ms ± 618 µs per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" 587 | ] 588 | } 589 | ], 590 | "source": [ 591 | "%timeit choice(x_cpu, k, False, weights_cpu)\n", 592 | "%timeit choice(x_cuda, k, False, weights_cuda);torch.cuda.synchronize()" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": {}, 598 | "source": [ 599 | "# Checking distributions - Uniform" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 27, 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "x = torch.arange(10)" 609 | ] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": {}, 614 | "source": [ 615 | "### CPP Extension" 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": 28, 621 | "metadata": {}, 622 | "outputs": [ 623 | { 624 | "data": { 625 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASvklEQVR4nO3df6zd9X3f8ecrGBIgGebHrUVtr6aKlQ5FCiFXzBld1OGmAlLFaEoY0RY85Mr9g3VJmdTS/hNV2h9BqkqLNCFZcTqzJaSUBGFlKIMZump/QHv5EX4m44ZAbA/wDQHShGUJ6Xt/nI/LsWP7nut77jn2J8+HdHQ+38/38z2f97m6fvl7P+fHN1WFJKkvb5t2AZKk8TPcJalDhrskdchwl6QOGe6S1KFV0y4A4LzzzqsNGzZMuwxJOqk8/PDD362qmSPtOyHCfcOGDczNzU27DEk6qSR54Wj7XJaRpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOnRCfUNXJY8ON/20q8z7/2Y9MZV7pZOWZuyR1yHCXpA4Z7pLUIcNdkjrkC6qSThi+YD8+hvsy+Iso6URluEs6xLROWjRerrlLUodGOnNP8rvAbwEFPAFcB5wPfAk4F3gY+GRV/TjJ24HbgA8ArwD/qqqeH3/pkjQe0/xrZaWWWRc9c0+yFvj3wGxVvRc4BbgGuAm4uareDbwKbGuHbANebf03t3GSpAkadVlmFXB6klXAGcCLwGXAnW3/LuCq1t7Stmn7NyfJeMqVJI1i0WWZqtqf5I+B7wD/F7iXwTLMa1X1Zhu2D1jb2muBve3YN5O8zmDp5rtjrl0/R3xnkrQ0oyzLnM3gbPwC4BeBM4HLlztxku1J5pLMLSwsLPfhJElDRlmW+XXg21W1UFU/Ab4CXAqsbss0AOuA/a29H1gP0PafxeCF1UNU1Y6qmq2q2ZmZmWU+DUnSsFHC/TvApiRntLXzzcDTwAPAx9qYrcDdrb27bdP2319VNb6SJUmLWTTcq+ohBi+MPsLgbZBvA3YAvw/ckGSewZr6znbITuDc1n8DcOMK1C1JOoaR3udeVZ8BPnNY93PAJUcY+yPg48svTZJ0vPyEqiR1yO+WkY6hx08u6ueDZ+6S1CHDXZI65LLMScivZJW0GM/cJa
lDJ/2Zu2ex6pW/21oOz9wlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDo1wg+z1JHhu6fT/Jp5Ock+S+JM+2+7Pb+CS5Jcl8kseTXLzyT0OSNGyUy+x9s6ouqqqLgA8AbwB3Mbh83p6q2gjs4a3L6V0BbGy37cCtK1G4JOnolrossxn4VlW9AGwBdrX+XcBVrb0FuK0GHgRWJzl/LNVKkkay1HC/Bri9tddU1Yut/RKwprXXAnuHjtnX+g6RZHuSuSRzCwsLSyxDknQsI4d7ktOAjwJ/efi+qiqgljJxVe2oqtmqmp2ZmVnKoZKkRSzlzP0K4JGqerltv3xwuaXdH2j9+4H1Q8eta32SpAlZSrh/greWZAB2A1tbeytw91D/te1dM5uA14eWbyRJEzDSxTqSnAl8GPjtoe7PAnck2Qa8AFzd+u8BrgTmGbyz5rqxVStJGslI4V5VPwTOPazvFQbvnjl8bAHXj6U6SdJx8ROqktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdWikcE+yOsmdSb6R5JkkH0xyTpL7kjzb7s9uY5PkliTzSR5PcvHKPgVJ0uFGPXP/M+BrVfUrwPuAZ4AbgT1VtRHY07ZhcK3Vje22Hbh1rBVLkha1aLgnOQv4ELAToKp+XFWvAVuAXW3YLuCq1t4C3FYDDwKrD15IW5I0GaOcuV8ALAB/nuTRJJ9r11RdM3Th65eANa29Ftg7dPy+1neIJNuTzCWZW1hYOP5nIEn6GaOE+yrgYuDWqno/8EPeWoIB/uG6qbWUiatqR1XNVtXszMzMUg6VJC1ilHDfB+yrqofa9p0Mwv7lg8st7f5A278fWD90/LrWJ0makEXDvapeAvYmeU/r2gw8DewGtra+rcDdrb0buLa9a2YT8PrQ8o0kaQJWjTjud4AvJDkNeA64jsF/DHck2Qa8AFzdxt4DXAnMA2+0sZKkCRop3KvqMWD2CLs2H2FsAdcvsy5J0jL4CVVJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6NFK4J3k+yRNJHksy1/rOSXJfkmfb/dmtP0luSTKf5PEkF6/kE5Ak/aylnLn/i6q6qKoOXrTjRmBPVW0E9vDWRbOvADa223bg1nEVK0kazXKWZbYAu1p7F3DVUP9tNfAgsPrghbQlSZMxargXcG+Sh5Nsb31rhi58/RKwprXXAnuHjt3X+g6RZHuSuSRzCwsLx1G6JOloRr1A9q9W1f4kvwDcl+QbwzurqpLUUiauqh3ADoDZ2dklHStJOraRztyran+7PwDcBVwCvHxwuaXdH2jD9wPrhw5f1/okSROyaLgnOTPJuw62gd8AngR2A1vbsK3A3a29G7i2vWtmE/D60PKNJGkCRlmWWQPcleTg+C9W1deS/C1wR5JtwAvA1W38PcCVwDzwBnDd2KuWJB3TouFeVc8B7ztC/yvA5iP0F3D9WKqTJB0XP6EqSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SerQyOGe5JQkjyb5atu+IMlDSeaT/EWS01r/29v2fNu/YWVKlyQdzVLO3D8FPDO0fRNwc1W9G3gV2Nb6twGvtv6b2zhJ0gSNFO5J1gEfAT7XtgNcBtzZhuwCrmrtLW2btn9zGy9JmpBRz9z/FPg94O/b9rnAa1X1ZtveB6xt7bXAXoC2//U2/hBJtieZSzK3sLBwnOVLko5k0XBP8pvAgap6eJwTV9WOqpqtqtmZmZlxPrQk/dxbNcKYS4GPJrkSeAfwj4A/A1YnWdXOztcB+9v4/cB6YF+SVcBZwCtjr1ySdFSLnrlX1R
9U1bqq2gBcA9xfVf8aeAD4WBu2Fbi7tXe3bdr++6uqxlq1JOmYlvM+998Hbkgyz2BNfWfr3wmc2/pvAG5cXomSpKUaZVnmH1TVXwF/1drPAZccYcyPgI+PoTZJ0nHyE6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1aJRrqL4jyd8k+XqSp5L8Ueu/IMlDSeaT/EWS01r/29v2fNu/YWWfgiTpcKOcuf8/4LKqeh9wEXB5kk3ATcDNVfVu4FVgWxu/DXi19d/cxkmSJmiUa6hWVf2gbZ7abgVcBtzZ+ncBV7X2lrZN2785ScZWsSRpUSOtuSc5JcljwAHgPuBbwGtV9WYbsg9Y29prgb0Abf/rDK6xevhjbk8yl2RuYWFhec9CknSIkcK9qn5aVRcB6xhcN/VXljtxVe2oqtmqmp2ZmVnuw0mShizp3TJV9RrwAPBBYHWSgxfYXgfsb+39wHqAtv8s4JWxVCtJGsko75aZSbK6tU8HPgw8wyDkP9aGbQXubu3dbZu2//6qqnEWLUk6tlWLD+F8YFeSUxj8Z3BHVX01ydPAl5L8R+BRYGcbvxP4L0nmge8B16xA3ZKkY1g03KvqceD9R+h/jsH6++H9PwI+PpbqJEnHxU+oSlKHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KFRrsS0PskDSZ5O8lSST7X+c5Lcl+TZdn9260+SW5LMJ3k8ycUr/SQkSYca5cz9TeA/VNWFwCbg+iQXAjcCe6pqI7CnbQNcAWxst+3ArWOvWpJ0TIuGe1W9WFWPtPbfMbh+6lpgC7CrDdsFXNXaW4DbauBBBhfSPn/slUuSjmpJa+5JNjC45N5DwJqqerHteglY09prgb1Dh+1rfYc/1vYkc0nmFhYWlli2JOlYRg73JO8Evgx8uqq+P7yvqgqopUxcVTuqaraqZmdmZpZyqCRpESOFe5JTGQT7F6rqK6375YPLLe3+QOvfD6wfOnxd65MkTcgo75YJsBN4pqr+ZGjXbmBra28F7h7qv7a9a2YT8PrQ8o0kaQJWjTDmUuCTwBNJHmt9fwh8FrgjyTbgBeDqtu8e4EpgHngDuG6sFUuSFrVouFfV/wJylN2bjzC+gOuXWZckaRn8hKokdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdGuVKTJ9PciDJk0N95yS5L8mz7f7s1p8ktySZT/J4kotXsnhJ0pGNcub+n4HLD+u7EdhTVRuBPW0b4ApgY7ttB24dT5mSpKVYNNyr6q+B7x3WvQXY1dq7gKuG+m+rgQeB1Qcvoi1JmpzjXXNfM3TR65eANa29Ftg7NG5f65MkTdCyX1Bt10ytpR6XZHuSuSRzCwsLyy1DkjTkeMP95YPLLe3+QOvfD6wfGreu9f2MqtpRVbNVNTszM3OcZUiSjuR4w303sLW1twJ3D/Vf2941swl4fWj5RpI0IasWG5DkduDXgPOS7AM+A3wWuCPJNuAF4Oo2/B7gSmAeeAO4bgVqliQtYtFwr6pPHGXX5iOMLeD65RYlSVoeP6EqSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SerQioR7ksuTfDPJfJIbV2IOSdLRjT3ck5wC/CfgCuBC4BNJLhz3PJKko1uJM/dLgPmqeq6qfgx8CdiyAvNIko5i0WuoHoe1wN6h7X3APz18UJLtwPa2+YMk3zzO+c4Dvnucx46TdRzKOk6sGsA6DndC1JGbllXHLx1tx0qE+0iqagewY7mPk2SuqmbHUJJ1WEe3NVjHz1
8dK7Essx9YP7S9rvVJkiZkJcL9b4GNSS5IchpwDbB7BeaRJB3F2JdlqurNJP8O+O/AKcDnq+qpcc8zZNlLO2NiHYeyjrecCDWAdRyu6zpSVSvxuJKkKfITqpLUIcNdkjp0Uof7ifA1B0k+n+RAkienMX+rYX2SB5I8neSpJJ+aUh3vSPI3Sb7e6vijadQxVM8pSR5N8tUp1vB8kieSPJZkbop1rE5yZ5JvJHkmyQenUMN72s/h4O37ST49hTp+t/1+Ppnk9iTvmHQNrY5PtRqeWpGfQ1WdlDcGL9Z+C/hl4DTg68CFU6jjQ8DFwJNT/FmcD1zc2u8C/veUfhYB3tnapwIPAZum+HO5Afgi8NUp1vA8cN605h+qYxfwW619GrB6yvWcArwE/NKE510LfBs4vW3fAfzbKTz/9wJPAmcweGPL/wDePc45TuYz9xPiaw6q6q+B70163sNqeLGqHmntvwOeYfBLPOk6qqp+0DZPbbepvGKfZB3wEeBz05j/RJLkLAYnITsBqurHVfXadKtiM/CtqnphCnOvAk5PsopBuP6fKdTwT4CHquqNqnoT+J/AvxznBCdzuB/paw4mHmgnmiQbgPczOGuexvynJHkMOADcV1VTqQP4U+D3gL+f0vwHFXBvkofbV25MwwXAAvDnbZnqc0nOnFItB10D3D7pSatqP/DHwHeAF4HXq+reSdfB4Kz9nyc5N8kZwJUc+uHPZTuZw12HSfJO4MvAp6vq+9Oooap+WlUXMfhk8iVJ3jvpGpL8JnCgqh6e9NxH8KtVdTGDb0m9PsmHplDDKgZLh7dW1fuBHwJT+yru9uHGjwJ/OYW5z2bwF/4FwC8CZyb5N5Ouo6qeAW4C7gW+BjwG/HScc5zM4e7XHAxJciqDYP9CVX1l2vW0P/sfAC6fwvSXAh9N8jyD5brLkvzXKdRx8EyRqjoA3MVgOXHS9gH7hv6KupNB2E/LFcAjVfXyFOb+deDbVbVQVT8BvgL8synUQVXtrKoPVNWHgFcZvFY2NidzuPs1B02SMFhPfaaq/mSKdcwkWd3apwMfBr4x6Tqq6g+qal1VbWDwe3F/VU387CzJmUnedbAN/AaDP8cnqqpeAvYmeU/r2gw8Pek6hnyCKSzJNN8BNiU5o/272czgNaqJS/IL7f4fM1hv/+I4H39q3wq5XDX5rzk4oiS3A78GnJdkH/CZqto54TIuBT4JPNHWuwH+sKrumXAd5wO72gVb3gbcUVVTexviCWANcNcgQ1gFfLGqvjalWn4H+EI7EXoOuG4aRbT/5D4M/PY05q+qh5LcCTwCvAk8yvS+huDLSc4FfgJcP+4Xuf36AUnq0Mm8LCNJOgrDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXo/wO1zRjHfxZ2NAAAAABJRU5ErkJggg==\n", 626 | "text/plain": [ 627 | "
" 628 | ] 629 | }, 630 | "metadata": { 631 | "needs_background": "light" 632 | }, 633 | "output_type": "display_data" 634 | } 635 | ], 636 | "source": [ 637 | "k = 8\n", 638 | "samples = []\n", 639 | "for _ in range(1000):\n", 640 | " samples.extend(choice(x, k, True).numpy())\n", 641 | "plt.hist(samples)\n", 642 | "plt.xticks(range(x.numel()))\n", 643 | "plt.show()" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 29, 649 | "metadata": {}, 650 | "outputs": [ 651 | { 652 | "data": { 653 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASuklEQVR4nO3db4zd1X3n8fenGBIgWcyfqeW1nZpVrHRRpRBnxDpLN+ripgJSxahKKNltcJEr9wHbTZqVurRPspX6IEhVadGuWFlxuqabQB0ShJWiFGpoq30A7fAn/HOyTAjE9gKeECBNaDZx+t0H93hzcWzmjufOXPvk/ZKu7vmd3/nd871jz2d+c+Z3701VIUnqy09NugBJ0vgZ7pLUIcNdkjpkuEtShwx3SerQikkXAHDBBRfU+vXrJ12GJJ1SHnrooW9W1dSx9p0U4b5+/XpmZmYmXYYknVKSPHe8fS7LSFKHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtSh0Z6hWqS3wZ+AyjgceA6YDVwO3A+8BDwkar6fpI3AbcC7wZeAn61qp4df+mSlsL6G/5iYnM/+8n3T2zu3swb7knWAP8RuKiq/jHJbuAa4Ergpqq6Pcl/B7YBt7T7l6vq7UmuAW4EfnXJnsFPIL/5fjJM8t9Zp75R31tmBXBmkh8AZwHPA5cB/67t3wX8FwbhvqW1Ae4A/muS1BJ9np9BJ/VjUt/PPX4vzxvuVXUwyR8C3wD+EbiHwTLMK1V1uA07AKxp7TXA/nbs4SSvMli6+ebw4ybZDmwHeNvb3rb4Z6Ku+U2vpdTjSeIoyzLnMjgbvxB4BfgccPliJ66qHcAOgOnp6VPyU7r9tVnSyWqUZZlfBL5eVXMASb4AXAqsTLKinb2vBQ628QeBdcCBJCuAcxj8YVU65fgDXKeqUS6F/AawKclZSQJsBp4C7gc+2MZsBe5q7T1tm7b/vqVab5ckHdsoa+4PJrkDeBg4DDzCYDnlL4Dbk/xB69vZDtkJ/FmSWeBbDK6sUSc8k5VODSNdLVNVnwA+cVT3M8Alxxj7PeBDiy9NknSifIWqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalD84Z7knckeXTo9u0kH0tyXpJ7kzzd7s9t45Pk5iSzSR5LsnHpn4Ykadi84V5VX62qi6vqYuDdwGvAncANwN6q2gDsbdsAVwAb2m07cMtSFC5JOr6FLstsBr5WVc8BW4BdrX8XcFVrbwFurYEHgJVJVo+lWknSSBYa7tcAt7X2qqp6vrVfAFa19hpg/9AxB1rf6yTZnmQmyczc3NwCy5AkvZGRwz3JGcAHgM8dva+qCqiFTFxVO6pquqqmp6amFnKoJGke
CzlzvwJ4uKpebNsvHlluafeHWv9BYN3QcWtbnyRpmSwk3D/Mj5ZkAPYAW1t7K3DXUP+17aqZTcCrQ8s3kqRlsGKUQUnOBt4H/OZQ9yeB3Um2Ac8BV7f+u4ErgVkGV9ZcN7ZqJUkjGSncq+q7wPlH9b3E4OqZo8cWcP1YqpMknRBfoSpJHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdGinck6xMckeSryTZl+Q9Sc5Lcm+Sp9v9uW1sktycZDbJY0k2Lu1TkCQdbdQz9z8BvlRVPwu8E9gH3ADsraoNwN62DXAFsKHdtgO3jLViSdK85g33JOcA7wV2AlTV96vqFWALsKsN2wVc1dpbgFtr4AFgZZLVY69cknRco5y5XwjMAX+a5JEkn0pyNrCqqp5vY14AVrX2GmD/0PEHWt/rJNmeZCbJzNzc3Ik/A0nSjxkl3FcAG4FbqupdwHf50RIMAFVVQC1k4qraUVXTVTU9NTW1kEMlSfMYJdwPAAeq6sG2fQeDsH/xyHJLuz/U9h8E1g0dv7b1SZKWybzhXlUvAPuTvKN1bQaeAvYAW1vfVuCu1t4DXNuumtkEvDq0fCNJWgYrRhz3W8BnkpwBPANcx+AHw+4k24DngKvb2LuBK4FZ4LU2VpK0jEYK96p6FJg+xq7NxxhbwPWLrEuStAi+QlWSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KGRwj3Js0keT/JokpnWd16Se5M83e7Pbf1JcnOS2SSPJdm4lE9AkvTjFnLm/m+r6uKqOvKJTDcAe6tqA7C3bQNcAWxot+3ALeMqVpI0msUsy2wBdrX2LuCqof5ba+ABYGWS1YuYR5K0QKOGewH3JHkoyfbWt6qqnm/tF4BVrb0G2D907IHW9zpJtieZSTIzNzd3AqVLko5npA/IBn6+qg4m+Wng3iRfGd5ZVZWkFjJxVe0AdgBMT08v6FhJ0hsb6cy9qg62+0PAncAlwItHllva/aE2/CCwbujwta1PkrRM5g33JGcneeuRNvBLwBPAHmBrG7YVuKu19wDXtqtmNgGvDi3fSJKWwSjLMquAO5McGf/ZqvpSkr8HdifZBjwHXN3G3w1cCcwCrwHXjb1qSdIbmjfcq+oZ4J3H6H8J2HyM/gKuH0t1kqQT4itUJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUMjh3uS05I8kuSLbfvCJA8mmU3y50nOaP1vatuzbf/6pSldknQ8Czlz/yiwb2j7RuCmqno78DKwrfVvA15u/Te1cZKkZTRSuCdZC7wf+FTbDnAZcEcbsgu4qrW3tG3a/s1tvCRpmYx65v7HwO8A/9S2zwdeqarDbfsAsKa11wD7Adr+V9v410myPclMkpm5ubkTLF+SdCzzhnuSXwYOVdVD45y4qnZU1XRVTU9NTY3zoSXpJ96KEcZcCnwgyZXAm4F/BvwJsDLJinZ2vhY42MYfBNYBB5KsAM4BXhp75ZKk45r3zL2qfreq1lbVeuAa4L6q+vfA/cAH27CtwF2tvadt0/bfV1U11qolSW9oMde5/2fg40lmGayp72z9O4HzW//HgRsWV6IkaaFGWZb5/6rqr4G/bu1ngEuOMeZ7wIfGUJsk6QT5ClVJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUodG+YDsNyf5uyRfTvJkkt9v/RcmeTDJbJI/T3JG639T255t+9cv7VOQJB1tlDP3/wtcVlXvBC4GLk+yCbgR
uKmq3g68DGxr47cBL7f+m9o4SdIyGuUDsquqvtM2T2+3Ai4D7mj9u4CrWntL26bt35wkY6tYkjSvkdbck5yW5FHgEHAv8DXglao63IYcANa09hpgP0Db/yqDD9A++jG3J5lJMjM3N7e4ZyFJep2Rwr2qflhVFwNrGXwo9s8uduKq2lFV01U1PTU1tdiHkyQNWdDVMlX1CnA/8B5gZZIVbdda4GBrHwTWAbT95wAvjaVaSdJIRrlaZirJytY+E3gfsI9ByH+wDdsK3NXae9o2bf99VVXjLFqS9MZWzD+E1cCuJKcx+GGwu6q+mOQp4PYkfwA8Auxs43cCf5ZkFvgWcM0S1C1JegPzhntVPQa86xj9zzBYfz+6/3vAh8ZSnSTphPgKVUnqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SerQKJ+hui7J/UmeSvJkko+2/vOS3Jvk6XZ/butPkpuTzCZ5LMnGpX4SkqTXG+XM/TDwn6rqImATcH2Si4AbgL1VtQHY27YBrgA2tNt24JaxVy1JekPzhntVPV9VD7f2PwD7gDXAFmBXG7YLuKq1twC31sADwMokq8deuSTpuBa05p5kPYMPy34QWFVVz7ddLwCrWnsNsH/osAOt7+jH2p5kJsnM3NzcAsuWJL2RkcM9yVuAzwMfq6pvD++rqgJqIRNX1Y6qmq6q6ampqYUcKkmax0jhnuR0BsH+mar6Qut+8chyS7s/1PoPAuuGDl/b+iRJy2SUq2UC7AT2VdUfDe3aA2xt7a3AXUP917arZjYBrw4t30iSlsGKEcZcCnwEeDzJo63v94BPAruTbAOeA65u++4GrgRmgdeA68ZasSRpXvOGe1X9LyDH2b35GOMLuH6RdUmSFsFXqEpShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOjTKx+x9OsmhJE8M9Z2X5N4kT7f7c1t/ktycZDbJY0k2LmXxkqRjG+XM/X8Alx/VdwOwt6o2AHvbNsAVwIZ22w7cMp4yJUkLMW+4V9XfAt86qnsLsKu1dwFXDfXfWgMPACuTrB5XsZKk0Zzomvuqqnq+tV8AVrX2GmD/0LgDre/HJNmeZCbJzNzc3AmWIUk6lkX/QbV9IHadwHE7qmq6qqanpqYWW4YkaciJhvuLR5Zb2v2h1n8QWDc0bm3rkyQtoxMN9z3A1tbeCtw11H9tu2pmE/Dq0PKNJGmZrJhvQJLbgF8ALkhyAPgE8Elgd5JtwHPA1W343cCVwCzwGnDdEtQsSZrHvOFeVR8+zq7NxxhbwPWLLUqStDi+QlWSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6tCThnuTyJF9NMpvkhqWYQ5J0fGMP9ySnAf8NuAK4CPhwkovGPY8k6fiW4sz9EmC2qp6pqu8DtwNblmAeSdJxzPsB2SdgDbB/aPsA8K+OHpRkO7C9bX4nyVdPcL4LgG+e4LHjZB2vZx0nVw1gHUc7KerIjYuq42eOt2Mpwn0kVbUD2LHYx0kyU1XTYyjJOqyj2xqs4yevjqVYljkIrBvaXtv6JEnLZCnC/e+BDUkuTHIGcA2wZwnmkSQdx9iXZarqcJL/APwlcBrw6ap6ctzzDFn00s6YWMfrWcePnAw1gHUcres6UlVL8biSpAnyFaqS1CHDXZI6dEqH+8nwNgdJPp3kUJInJjF/q2FdkvuTPJXkySQfnVAdb07yd0m+3Or4/UnUMVTPaUkeSfLFCdbwbJLHkzyaZGaCdaxMckeSryTZl+Q9E6jhHe3rcOT27SQfm0Adv93+fz6R5LYkb17uGlodH201
PLkkX4eqOiVvDP5Y+zXgXwBnAF8GLppAHe8FNgJPTPBrsRrY2NpvBf73hL4WAd7S2qcDDwKbJvh1+TjwWeCLE6zhWeCCSc0/VMcu4Dda+wxg5YTrOQ14AfiZZZ53DfB14My2vRv49Qk8/58DngDOYnBhy18Bbx/nHKfymftJ8TYHVfW3wLeWe96jani+qh5u7X8A9jH4T7zcdVRVfadtnt5uE/mLfZK1wPuBT01i/pNJknMYnITsBKiq71fVK5Otis3A16rquQnMvQI4M8kKBuH6fyZQw78EHqyq16rqMPA3wK+Mc4JTOdyP9TYHyx5oJ5sk64F3MThrnsT8pyV5FDgE3FtVE6kD+GPgd4B/mtD8RxRwT5KH2ltuTMKFwBzwp22Z6lNJzp5QLUdcA9y23JNW1UHgD4FvAM8Dr1bVPctdB4Oz9n+T5PwkZwFX8voXfy7aqRzuOkqStwCfBz5WVd+eRA1V9cOqupjBK5MvSfJzy11Dkl8GDlXVQ8s99zH8fFVtZPAuqdcnee8EaljBYOnwlqp6F/BdYGJvxd1e3PgB4HMTmPtcBr/hXwj8c+DsJL+23HVU1T7gRuAe4EvAo8APxznHqRzuvs3BkCSnMwj2z1TVFyZdT/u1/37g8glMfynwgSTPMliuuyzJ/5xAHUfOFKmqQ8CdDJYTl9sB4MDQb1F3MAj7SbkCeLiqXpzA3L8IfL2q5qrqB8AXgH89gTqoqp1V9e6qei/wMoO/lY3NqRzuvs1BkyQM1lP3VdUfTbCOqSQrW/tM4H3AV5a7jqr63apaW1XrGfy/uK+qlv3sLMnZSd56pA38EoNfx5dVVb0A7E/yjta1GXhquesY8mEmsCTTfAPYlOSs9n2zmcHfqJZdkp9u929jsN7+2XE+/sTeFXKxavnf5uCYktwG/AJwQZIDwCeqaucyl3Ep8BHg8bbeDfB7VXX3MtexGtjVPrDlp4DdVTWxyxBPAquAOwcZwgrgs1X1pQnV8lvAZ9qJ0DPAdZMoov2Qex/wm5OYv6oeTHIH8DBwGHiEyb0NweeTnA/8ALh+3H/k9u0HJKlDp/KyjCTpOAx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KH/B9rDHYGV1Sh2AAAAAElFTkSuQmCC\n", 654 | "text/plain": [ 655 | "
" 656 | ] 657 | }, 658 | "metadata": { 659 | "needs_background": "light" 660 | }, 661 | "output_type": "display_data" 662 | } 663 | ], 664 | "source": [ 665 | "k = 8\n", 666 | "samples = []\n", 667 | "for _ in range(1000):\n", 668 | " samples.extend(choice(x, k, False).numpy())\n", 669 | "plt.hist(samples)\n", 670 | "plt.xticks(range(x.numel()))\n", 671 | "plt.show()" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 30, 677 | "metadata": {}, 678 | "outputs": [ 679 | { 680 | "data": { 681 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAANeklEQVR4nO3df6zd9V3H8edrLTh+RWDtmtrWFbVZrEsEvGHoJkHrJjCzon8QSJxIMN0fxYCaGLZ/0D+WsESnWaIkdeC6yA8rPwKZBMFKJP4xtlvGoNDhOiijtdA7mcAkcSt7+8f5Np6WW27vPfee770fn4/k5HzP93zP+b57c3ny7ffe822qCklSW97V9wCSpPln3CWpQcZdkhpk3CWpQcZdkhq0vO8BAFasWFHr16/vewxJWlJ27dr13apaOd1ziyLu69evZ3Jysu8xJGlJSfLi8Z7ztIwkNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNWhRfEJV0tutv/Efe9nvvps/1st+Nb+MuzSDviIrjcLTMpLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUID+hqlnxI/HS0mDcJf2/1+clJhbqwMXTMpLUoBnjnmRdkkeTPJvkmSTXd+vPTvJIkm9192d165Pk80n2JnkqyfkL/YeQJB3tRI7cDwN/VFUbgQuBrUk2AjcCO6tqA7CzewxwKbChu20Bbpn3qSVJ72jGuFfVwap6olt+A9gDrAE2A9u7zbYDl3fLm4Ev1cBXgDOTrJ73ySVJxzWrc+5J1gPnAY8Dq6rqYPfUy8CqbnkN8NLQy/Z36459ry1JJpNMTk1NzXJsSdI7OeG4JzkduAe4oapeH36uqgqo2ey4qrZV1URVTaxcuXI2L5UkzeCE4p7kJAZhv72q7u1Wv3LkdEt3f6hbfwBYN/Tytd06SdKYzPh77kkC3ArsqarPDT31AHA1cHN3f//Q+uuS3AV8EHht6PSN5oH/7JukmZzIh5g+BHwCeDrJk926TzOI+o4k1wIvAld0zz0IXAbsBd4ErpnXiSVJM5ox7lX1b0CO8/SmabYvYOuIc52wFj9ZJkmj8hOqktQgry0zAs99q0V+X7fBI3dJapBxl6QGGXdJapBxl6QGGXdJapC/LaMlwd/gkGbHI3dJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGzRj3JLclOZRk99C6P0lyIMmT3e2yoec+lWRvkueS/PpCDS5JOr4TOXL/InDJNOv/oqrO7W4PAiTZCFwJ/Fz
3mr9Osmy+hpUknZgZ415VjwGvnuD7bQbuqqr/qaoXgL3ABSPMJ0mag1HOuV+X5KnutM1Z3bo1wEtD2+zv1r1Nki1JJpNMTk1NjTCGJOlYc437LcBPA+cCB4E/n+0bVNW2qpqoqomVK1fOcQxJ0nTmFPeqeqWq3qqqHwF/w/+dejkArBvadG23TpI0RnOKe5LVQw9/EzjymzQPAFcm+bEk5wAbgK+ONqIkabaWz7RBkjuBi4EVSfYDNwEXJzkXKGAf8EmAqnomyQ7gWeAwsLWq3lqY0SVJxzNj3KvqqmlW3/oO238G+MwoQ0mSRuMnVCWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQTPGPcltSQ4l2T207uwkjyT5Vnd/Vrc+ST6fZG+Sp5Kcv5DDS5KmdyJH7l8ELjlm3Y3AzqraAOzsHgNcCmzobluAW+ZnTEnSbMwY96p6DHj1mNWbge3d8nbg8qH1X6qBrwBnJlk9X8NKkk7MXM+5r6qqg93yy8CqbnkN8NLQdvu7dW+TZEuSySSTU1NTcxxDkjSdkX+gWlUF1Bxet62qJqpqYuXKlaOOIUkaMte4v3LkdEt3f6hbfwBYN7Td2m6dJGmM5hr3B4Cru+WrgfuH1v9O91szFwKvDZ2+kSSNyfKZNkhyJ3AxsCLJfuAm4GZgR5JrgReBK7rNHwQuA/YCbwLXLMDMkqQZzBj3qrrqOE9tmmbbAraOOpQkaTR+QlWSGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBy0d5cZJ9wBvAW8DhqppIcjbw98B6YB9wRVV9b7QxJUmzMR9H7r9SVedW1UT3+EZgZ1VtAHZ2jyVJY7QQp2U2A9u75e3A5QuwD0nSOxg17gU8nGRXki3dulVVdbBbfhlYNd0Lk2xJMplkcmpqasQxJEnDRjrnDny4qg4keS/wSJJvDj9ZVZWkpnthVW0DtgFMTExMu40kaW5GOnKvqgPd/SHgPuAC4JUkqwG6+0OjDilJmp05xz3JaUnOOLIMfBTYDTwAXN1tdjVw/6hDSpJmZ5TTMquA+5IceZ87quqhJF8DdiS5FngRuGL0MSVJszHnuFfV88DPT7P+P4FNowwlSRqNn1CVpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYtWNyTXJLkuSR7k9y4UPuRJL3dgsQ9yTLgr4BLgY3AVUk2LsS+JElvt1BH7hcAe6vq+ar6AXAXsHmB9iVJOsbyBXrfNcBLQ4/3Ax8c3iDJFmBL9/D7SZ6b475WAN+d42vnk3MczTmOthjmWAwzgHMcJZ8daY73He+JhYr7jKpqG7Bt1PdJMllVE/MwknM4R9NzLIYZnGN8cyzUaZkDwLqhx2u7dZKkMViouH8N2JDknCQnA1cCDyzQviRJx1iQ0zJVdTjJdcA
/AcuA26rqmYXYF/NwameeOMfRnONoi2GOxTADOMexFmSOVNVCvK8kqUd+QlWSGmTcJalBSzrui+ESB0luS3Ioye4+9j80x7okjyZ5NskzSa7vYYZ3J/lqkm90M/zpuGc4Zp5lSb6e5Ms9zrAvydNJnkwy2eMcZya5O8k3k+xJ8os9zPD+7utw5PZ6kht6mOMPuu/P3UnuTPLucc/QzXF9N8MzC/J1qKoleWPwg9pvAz8FnAx8A9jYwxwXAecDu3v+eqwGzu+WzwD+fdxfDyDA6d3yScDjwIU9fk3+ELgD+HKPM+wDVvT5vdHNsR34vW75ZODMnudZBrwMvG/M+10DvACc0j3eAfxuD3/+DwC7gVMZ/GLLPwM/M5/7WMpH7oviEgdV9Rjw6rj3O80cB6vqiW75DWAPg2/kcc5QVfX97uFJ3a2Xn9gnWQt8DPhCH/tfTJL8OIODkFsBquoHVfVf/U7FJuDbVfViD/teDpySZDmDuP5HDzP8LPB4Vb1ZVYeBfwV+az53sJTjPt0lDsYas8UqyXrgPAZHzuPe97IkTwKHgEeqauwzdP4S+GPgRz3t/4gCHk6yq7vkRh/OAaaAv+1OU30hyWk9zXLElcCd495pVR0A/gz4DnAQeK2qHh73HAyO2n85yXuSnApcxtEf/BzZUo67ppHkdOAe4Iaqen3c+6+qt6rqXAafSr4gyQfGPUOS3wAOVdWuce97Gh+uqvMZXCF1a5KLephhOYNTh7dU1XnAfwO9XYa7+2Djx4F/6GHfZzH4G/45wE8ApyX57XHPUVV7gM8CDwMPAU8Cb83nPpZy3L3EwTGSnMQg7LdX1b19ztL9tf9R4JIedv8h4ONJ9jE4XferSf6uhzmOHClSVYeA+xicThy3/cD+ob9F3c0g9n25FHiiql7pYd+/BrxQVVNV9UPgXuCXepiDqrq1qn6hqi4Cvsfg52TzZinH3UscDEkSBudU91TV53qaYWWSM7vlU4CPAN8c9xxV9amqWltV6xl8X/xLVY396CzJaUnOOLIMfJTBX8fHqqpeBl5K8v5u1Sbg2XHPMeQqejgl0/kOcGGSU7v/ZjYx+PnU2CV5b3f/kwzOt98xn+/f21UhR1XjvcTBcSW5E7gYWJFkP3BTVd067jkYHK1+Ani6O+cN8OmqenCMM6wGtnf/WMu7gB1V1duvIS4Cq4D7Bg1hOXBHVT3U0yy/D9zeHQg9D1zTxxDd/+Q+Anyyj/1X1eNJ7gaeAA4DX6e/yxDck+Q9wA+BrfP9Q24vPyBJDVrKp2UkScdh3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhr0vztwKSjkUOCgAAAAAElFTkSuQmCC\n", 682 | "text/plain": [ 683 | "
" 684 | ] 685 | }, 686 | "metadata": { 687 | "needs_background": "light" 688 | }, 689 | "output_type": "display_data" 690 | } 691 | ], 692 | "source": [ 693 | "k = 2\n", 694 | "samples = []\n", 695 | "for _ in range(1000):\n", 696 | " samples.extend(choice(x, k, True).numpy())\n", 697 | "plt.hist(samples)\n", 698 | "plt.xticks(range(x.numel()))\n", 699 | "plt.show()" 700 | ] 701 | }, 702 | { 703 | "cell_type": "code", 704 | "execution_count": 31, 705 | "metadata": {}, 706 | "outputs": [ 707 | { 708 | "data": { 709 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAANgUlEQVR4nO3df6zd9V3H8edrLTh+RWDtmtqWFbVZrEsEvEF0k6B1E5hZ0T8IJG6VYLo/OgNqYtj+Qf9YwhKdZomS1IHrIrBVfgQyCYKVSPxjbC1jUCi4DspoLbSTCUwSt7K3f5xv9bTccn+ce8/33s+ej+Tmfs/3/Pi+e7l98r2fc89pqgpJUlve0fcAkqS5Z9wlqUHGXZIaZNwlqUHGXZIatLTvAQCWLVtWa9eu7XsMSVpUdu3a9d2qWj7ZdQsi7mvXrmXnzp19jyFJi0qSF050ncsyktQg4y5JDTLuktQg4y5JDTLuktQg4y5JDTLuktQg4y5JDTLuktSgBfEKVc3M2hv+sbdj77vpw70dW9L0eeYuSQ0y7pLUIOMuSQ1yzV3SgtHX80ktPpfkmbskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNchXqEo6Rp/vOqq5M+WZe5I1SR5O8nSSp5Jc1+0/O8lDSb7VfT6r258kn0uyN8kTSS6Y7z+EJOlY01mWOQL8cVWtBy4CtiRZD9wA7KiqdcCO7jLAZcC67mMzcPOcTy1JeltTxr2qDlbVY93268AeYBWwEdjW3WwbcEW3vRH4Yg18FTgzyco5n1ySdEIzWnNPshY4H3gUWFFVB7urXgJWdNurgBeH7ra/23cQSVqAWvzXzaYd9ySnA3cB11fVa0n+77qqqiQ1kwMn2cxg2YZzzjlnJnddMHzi6ceDb0OrxWhavwqZ5CQGYb+tqu7udr98dLml+3yo238AWDN099XdvmNU1daqmqiqieXLl892fknSJKbz2zIBbgH2VNVnh666D9jUbW8C7h3a/7Hut2YuAl4dWr6RJI3BdJZl3g98FHgyyePdvk8BNwHbk1wLvABc2V13P3A5sBd4A7hmTic+jksjPx787yzNzJRxr6p/A3KCqzdMcvsCtow4lyRpBL79gCQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoP8l5g0I75SVFocPHOXpAYZd0lqkMsy0gLlEphG4Zm7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg6aMe5JbkxxKsnto358mOZDk8e7j8qHrPplkb5Jnk/zmfA0uSTqx6Zy5fwG4dJL9f1lV53Uf
9wMkWQ9cBfx8d5+/SbJkroaVJE3PlHGvqkeAV6b5eBuBL1XV/1TV88Be4MIR5pMkzcIoa+6fSPJEt2xzVrdvFfDi0G32d/skSWM027jfDPwMcB5wEPiLmT5Aks1JdibZefjw4VmOIUmazKziXlUvV9WbVfUj4G/5/6WXA8CaoZuu7vZN9hhbq2qiqiaWL18+mzEkSScwq7gnWTl08beBo79Jcx9wVZKfSHIusA742mgjSpJmaulUN0hyB3AJsCzJfuBG4JIk5wEF7AM+DlBVTyXZDjwNHAG2VNWb8zO6JOlEpox7VV09ye5b3ub2nwY+PcpQkqTR+ApVSWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWqQcZekBhl3SWrQlHFPcmuSQ0l2D+07O8lDSb7VfT6r258kn0uyN8kTSS6Yz+ElSZObzpn7F4BLj9t3A7CjqtYBO7rLAJcB67qPzcDNczOmJGkmpox7VT0CvHLc7o3Atm57G3DF0P4v1sBXgTOTrJyrYSVJ0zPbNfcVVXWw234JWNFtrwJeHLrd/m7fWyTZnGRnkp2HDx+e5RiSpMmM/IRqVRVQs7jf1qqaqKqJ5cuXjzqGJGnIbOP+8tHllu7zoW7/AWDN0O1Wd/skSWM027jfB2zqtjcB9w7t/1j3WzMXAa8OLd9IksZk6VQ3SHIHcAmwLMl+4EbgJmB7kmuBF4Aru5vfD1wO7AXeAK6Zh5klSVOYMu5VdfUJrtowyW0L2DLqUJKk0fgKVUlqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYtHeXOSfYBrwNvAkeqaiLJ2cCXgbXAPuDKqvreaGNKkmZiLs7cf62qzquqie7yDcCOqloH7OguS5LGaD6WZTYC27rtbcAV83AMSdLbGDXuBTyYZFeSzd2+FVV1sNt+CVgx2R2TbE6yM8nOw4cPjziGJGnYSGvuwAeq6kCSdwMPJXlm+MqqqiQ12R2raiuwFWBiYmLS20iSZmekM/eqOtB9PgTcA1wIvJxkJUD3+dCoQ0qSZmbWcU9yWpIzjm4DHwJ2A/cBm7qbbQLuHXVISdLMjLIsswK4J8nRx7m9qh5I8nVge5JrgReAK0cfU5I0E7OOe1U9B/zCJPv/E9gwylCSpNH4ClVJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGzVvck1ya5Nkke5PcMF/HkSS91bzEPckS4K+By4D1wNVJ1s/HsSRJbzVfZ+4XAnur6rmq+gHwJWDjPB1LknScpfP0uKuAF4cu7wd+afgGSTYDm7uL30/y7CyPtQz47izvO5ec41jOcayFMMdCmAGc4xj5zEhzvOdEV8xX3KdUVVuBraM+TpKdVTUxByM5h3M0PcdCmME5xjfHfC3LHADWDF1e3e2TJI3BfMX968C6JOcmORm4Crhvno4lSTrOvCzL
VNWRJJ8A/glYAtxaVU/Nx7GYg6WdOeIcx3KOYy2EORbCDOAcx5uXOVJV8/G4kqQe+QpVSWqQcZekBi3quC+EtzhIcmuSQ0l293H8oTnWJHk4ydNJnkpyXQ8zvDPJ15J8s5vhz8Y9w3HzLEnyjSRf6XGGfUmeTPJ4kp09znFmkjuTPJNkT5Jf7mGG93Zfh6MfryW5voc5/rD7/tyd5I4k7xz3DN0c13UzPDUvX4eqWpQfDJ6o/Tbw08DJwDeB9T3McTFwAbC756/HSuCCbvsM4N/H/fUAApzebZ8EPApc1OPX5I+A24Gv9DjDPmBZn98b3RzbgN/vtk8Gzux5niXAS8B7xnzcVcDzwCnd5e3A7/Xw538fsBs4lcEvtvwz8LNzeYzFfOa+IN7ioKoeAV4Z93EnmeNgVT3Wbb8O7GHwjTzOGaqqvt9dPKn76OUZ+ySrgQ8Dn+/j+AtJkp9kcBJyC0BV/aCq/qvfqdgAfLuqXujh2EuBU5IsZRDX/+hhhp8DHq2qN6rqCPCvwO/M5QEWc9wne4uDscZsoUqyFjifwZnzuI+9JMnjwCHgoaoa+wydvwL+BPhRT8c/qoAHk+zq3nKjD+cCh4G/65apPp/ktJ5mOeoq4I5xH7SqDgB/DnwHOAi8WlUPjnsOBmftv5rkXUlOBS7n2Bd+jmwxx12TSHI6cBdwfVW9Nu7jV9WbVXUeg1clX5jkfeOeIclvAYeqate4jz2JD1TVBQzeIXVLkot7mGEpg6XDm6vqfOC/gd7ehrt7YeNHgH/o4dhnMfgJ/1zgp4DTkvzuuOeoqj3AZ4AHgQeAx4E35/IYiznuvsXBcZKcxCDst1XV3X3O0v3Y/zBwaQ+Hfz/wkST7GCzX/XqSv+9hjqNnilTVIeAeBsuJ47Yf2D/0U9SdDGLfl8uAx6rq5R6O/RvA81V1uKp+CNwN/EoPc1BVt1TVL1bVxcD3GDxPNmcWc9x9i4MhScJgTXVPVX22pxmWJzmz2z4F+CDwzLjnqKpPVtXqqlrL4PviX6pq7GdnSU5LcsbRbeBDDH4cH6uqegl4Mcl7u10bgKfHPceQq+lhSabzHeCiJKd2f2c2MHh+auySvLv7fA6D9fbb5/Lxe3tXyFHVeN/i4ISS3AFcAixLsh+4sapuGfccDM5WPwo82a15A3yqqu4f4wwrgW3dP9byDmB7VfX2a4gLwArgnkFDWArcXlUP9DTLHwC3dSdCzwHX9DFE9z+5DwIf7+P4VfVokjuBx4AjwDfo720I7kryLuCHwJa5fpLbtx+QpAYt5mUZSdIJGHdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QG/S+l+SaK0jr01gAAAABJRU5ErkJggg==\n", 710 | "text/plain": [ 711 | "
" 712 | ] 713 | }, 714 | "metadata": { 715 | "needs_background": "light" 716 | }, 717 | "output_type": "display_data" 718 | } 719 | ], 720 | "source": [ 721 | "k = 2\n", 722 | "samples = []\n", 723 | "for _ in range(1000):\n", 724 | " samples.extend(choice(x, k, False).numpy())\n", 725 | "plt.hist(samples)\n", 726 | "plt.xticks(range(x.numel()))\n", 727 | "plt.show()" 728 | ] 729 | }, 730 | { 731 | "cell_type": "markdown", 732 | "metadata": {}, 733 | "source": [ 734 | "### CUDA Extension" 735 | ] 736 | }, 737 | { 738 | "cell_type": "code", 739 | "execution_count": 32, 740 | "metadata": {}, 741 | "outputs": [], 742 | "source": [ 743 | "x = x.cuda()" 744 | ] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "execution_count": 33, 749 | "metadata": {}, 750 | "outputs": [ 751 | { 752 | "data": { 753 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASuUlEQVR4nO3df6zd9X3f8ecrGBIgGebHrUVtp2aKlQ5FCiFXzBld1OGmAlLFqEoY2RY85Mr9g3VJmdTS/pNV2h9BqkqLNiFZcTqzJaSUBGFlKA0zdNX+gPbyI/xyMm4IxPYA3xJwmqAsIX3vj/NxOTg291zfc8+xP3k+pKPz+X6+n+/5vM+V/brf+zk/vqkqJEl9ecu0C5AkjZ/hLkkdMtwlqUOGuyR1yHCXpA6tmnYBAOedd15t2LBh2mVI0knloYce+tuqmjnavhMi3Dds2MDc3Ny0y5Ckk0qS5461z2UZSeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nq0AnxCVVpMRtu/B9TmffZz3x4KvNKy+WZuyR1yHCXpA6d9Msy0/pzHab3J/vP4nOeFn/WOll55i5JHRrpzD3JbwO/ARTwOHAdcD7wReBc4CHgE1X1oyRvBW4D3g+8BPzLqnp2/KVrGqZ5JitpdIueuSdZC/x7YLaq3gOcAlwD3ATcXFXvAl4GtrVDtgEvt/6b2zhJ0gSNuua+Cjg9yY+BM4DngcuAf9X27wL+I3ArsKW1Ae4E/nOSVFWNqWZJnfItr+Oz6Jl7VR0A/hD4DoNQP8RgGeaVqnqtDdsPrG3ttcC+duxrbfy5Rz5uku1J5pLMLSwsLPd5SJKGjLIsczaDs/ELgJ8HzgQuX+7EVbWjqmaranZm5qiXAJQkHadR3i3zK8C3q2qhqn4MfBm4FFid5PCyzjrgQGsfANYDtP1nMXhhVZI0IaOE+3eATUnOSBJgM/AUcD/w0TZmK3B3a+9u27T997neLkmTNcqa+4MMXhh9mMHbIN8C7AB+F7ghyTyDNfWd7ZCdwLmt/wbgxhWoW5L0JkZ6t0xVfRr49BHdzwCXHGXsD4GPLb80SdLx8hOqktShk/67ZabJT2tqJfnvS8thuEv6mdfjF8S5LCNJHTLcJalDhrskdchwl6QOG
e6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktShUS6Q/e4kjw7dvpfkU0nOSXJvkqfb/dltfJLckmQ+yWNJLl75pyFJGjbKZfa+WVUXVdVFwPuBV4G7GFw+b09VbQT28Prl9K4ANrbbduDWlShcknRsS12W2Qx8q6qeA7YAu1r/LuCq1t4C3FYDDwCrk5w/lmolSSNZarhfA9ze2muq6vnWfgFY09prgX1Dx+xvfW+QZHuSuSRzCwsLSyxDkvRmRg73JKcBHwH+/Mh9VVVALWXiqtpRVbNVNTszM7OUQyVJi1jKmfsVwMNV9WLbfvHwcku7P9j6DwDrh45b1/okSROylHD/OK8vyQDsBra29lbg7qH+a9u7ZjYBh4aWbyRJEzDSBbKTnAl8CPjNoe7PAHck2QY8B1zd+u8BrgTmGbyz5rqxVStJGslI4V5VPwDOPaLvJQbvnjlybAHXj6U6SdJx8ROqktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOjRTuSVYnuTPJN5LsTfKBJOckuTfJ0+3+7DY2SW5JMp/ksSQXr+xTkCQdadQz9z8BvlpVvwi8F9gL3AjsqaqNwJ62DYMLaW9st+3ArWOtWJK0qEXDPclZwAeBnQBV9aOqegXYAuxqw3YBV7X2FuC2GngAWJ3k/LFXLkk6plHO3C8AFoA/TfJIks+2C2avqarn25gXgDWtvRbYN3T8/tb3Bkm2J5lLMrewsHD8z0CS9FNGCfdVwMXArVX1PuAHvL4EA/zDRbFrKRNX1Y6qmq2q2ZmZmaUcKklaxCjhvh/YX1UPtu07GYT9i4eXW9r9wbb/ALB+6Ph1rU+SNCGLhntVvQDsS/Lu1rUZeArYDWxtfVuBu1t7N3Bte9fMJuDQ0PKNJGkCVo047reAzyc5DXgGuI7BL4Y7kmwDngOubmPvAa4E5oFX21hJ0gSNFO5V9Sgwe5Rdm48ytoDrl1mXJGkZ/ISqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHRop3JM8m+TxJI8mmWt95yS5N8nT7f7s1p8ktySZT/JYkotX8glIkn7aUs7c/0VVXVRVhy/acSOwp6o2Ant4/aLZVwAb2207cOu4ipUkjWY5yzJbgF2tvQu4aqj/thp4AFh9+ELakqTJGDXcC/hakoeSbG99a4YufP0CsKa11wL7ho7d3/okSRMy6gWyf6mqDiT5OeDeJN8Y3llVlaSWMnH7JbEd4J3vfOdSDpUkLWKkM/eqOtDuDwJ3AZcALx5ebmn3B9vwA8D6ocPXtb4jH3NHVc1W1ezMzMzxPwNJ0k9ZNNyTnJnkHYfbwK8CTwC7ga1t2Fbg7tbeDVzb3jWzCTg0tHwjSZqAUZZl1gB3JTk8/gtV9dUkfwPckWQb8BxwdRt/D3AlMA+8Clw39qolSW9q0XCvqmeA9x6l/yVg81H6C7h+LNVJko6Ln1CVpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVo5HBPckqSR5J8pW1fkOTBJPNJ/izJaa3/rW17vu3fsDKlS5KOZSln7p8E9g5t3wTcXFXvAl4GtrX+bcDLrf/mNk6SNEEjhXuSdcCHgc+27QCXAXe2IbuAq1p7S9um7d/cxkuSJmTUM/c/Bn4H+Pu2fS7wSlW91rb3A2tbey2wD6DtP9TGv0GS7UnmkswtLCwcZ/mSpKNZNNyT/BpwsKoeGufEVbWjqmaranZmZmacDy1JP/NWjTDmUuAjSa4E3gb8I+BPgNVJVrWz83XAgTb+ALAe2J9kFXAW8NLYK5ckHdOiZ+5V9XtVta6qN
gDXAPdV1b8G7gc+2oZtBe5u7d1tm7b/vqqqsVYtSXpTy3mf++8CNySZZ7CmvrP17wTObf03ADcur0RJ0lKNsizzD6rqL4G/bO1ngEuOMuaHwMfGUJsk6Tj5CVVJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUodGuUD225L8dZKvJ3kyyR+0/guSPJhkPsmfJTmt9b+1bc+3/RtW9ilIko40ypn7/wMuq6r3AhcBlyfZBNwE3FxV7wJeBra18duAl1v/zW2cJGmCRrlAdlXV99vmqe1WwGXAna1/F3BVa29p27T9m5NkbBVLkhY10pp7klOSPAocBO4FvgW8UlWvtSH7gbWtvRbYB9D2H2JwAe0jH3N7krkkcwsLC8t7FpKkNxgp3KvqJ1V1EbCOwUWxf3G5E1fVjqqararZmZmZ5T6cJGnIkt4tU1WvAPcDHwBWJ1nVdq0DDrT2AWA9QNt/FvDSWKqVJI1klHfLzCRZ3dqnAx8C9jII+Y+2YVuBu1t7d9um7b+vqmqcRUuS3tyqxYdwPrArySkMfhncUVVfSfIU8MUk/wl4BNjZxu8E/luSeeC7wDUrULck6U0sGu5V9RjwvqP0P8Ng/f3I/h8CHxtLdZKk4+InVCWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHRrlMnvrk9yf5KkkTyb5ZOs/J8m9SZ5u92e3/iS5Jcl8kseSXLzST0KS9EajnLm/BvyHqroQ2ARcn+RC4EZgT1VtBPa0bYArgI3tth24dexVS5Le1KLhXlXPV9XDrf13DC6OvRbYAuxqw3YBV7X2FuC2GngAWJ3k/LFXLkk6piWtuSfZwOB6qg8Ca6rq+bbrBWBNa68F9g0dtr/1HflY25PMJZlbWFhYYtmSpDczcrgneTvwJeBTVfW94X1VVUAtZeKq2lFVs1U1OzMzs5RDJUmLGCnck5zKINg/X1Vfbt0vHl5uafcHW/8BYP3Q4etanyRpQkZ5t0yAncDeqvqjoV27ga2tvRW4e6j/2vaumU3AoaHlG0nSBKwaYcylwCeAx5M82vp+H/gMcEeSbcBzwNVt3z3AlcA88Cpw3VgrliQtatFwr6r/DeQYuzcfZXwB1y+zLknSMvgJVUnqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtSh0a5zN7nkhxM8sRQ3zlJ7k3ydLs/u/UnyS1J5pM8luTilSxeknR0o5y5/1fg8iP6bgT2VNVGYE/bBrgC2Nhu24Fbx1OmJGkpFg33qvor4LtHdG8BdrX2LuCqof7bauABYHWS88dVrCRpNMe75r6mqp5v7ReANa29Ftg3NG5/6/spSbYnmUsyt7CwcJxlSJKOZtkvqLYLYtdxHLejqmaranZmZma5ZUiShhxvuL94eLml3R9s/QeA9UPj1rU+SdIEHW+47wa2tvZW4O6h/mvbu2Y2AYeGlm8kSROyarEBSW4Hfhk4L8l+4NPAZ4A7kmwDngOubsPvAa4E5oFXgetWoGZJ0iIWDfeq+vgxdm0+ytgCrl9uUZKk5fETqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDq1IuCe5PMk3k8wnuXEl5pAkHdvYwz3JKcB/Aa4ALgQ+nuTCcc8jSTq2lThzvwSYr6pnqupHwBeBLSswjyTpGBa9hupxWAvsG9reD/zTIwcl2Q5sb5vfT/LN45zvPOBvj/PYcbKON7KOE6sGsI4jnRB15KZl1fELx9qxEuE+kqraAexY7uMkmauq2TGUZ
B3W0W0N1vGzV8dKLMscANYPba9rfZKkCVmJcP8bYGOSC5KcBlwD7F6BeSRJxzD2ZZmqei3JvwP+AjgF+FxVPTnueYYse2lnTKzjjazjdSdCDWAdR+q6jlTVSjyuJGmK/ISqJHXIcJekDp3U4X4ifM1Bks8lOZjkiWnM32pYn+T+JE8leTLJJ6dUx9uS/HWSr7c6/mAadQzVc0qSR5J8ZYo1PJvk8SSPJpmbYh2rk9yZ5BtJ9ib5wBRqeHf7ORy+fS/Jp6ZQx2+3f59PJLk9ydsmXUOr45OthidX5OdQVSfljcGLtd8C/jFwGvB14MIp1PFB4GLgiSn+LM4HLm7tdwD/Z0o/iwBvb+1TgQeBTVP8udwAfAH4yhRreBY4b1rzD9WxC/iN1j4NWD3lek4BXgB+YcLzrgW+DZzetu8A/u0Unv97gCeAMxi8seV/Au8a5xwn85n7CfE1B1X1V8B3Jz3vETU8X1UPt/bfAXsZ/COedB1VVd9vm6e221ResU+yDvgw8NlpzH8iSXIWg5OQnQBV9aOqemW6VbEZ+FZVPTeFuVcBpydZxSBc/+8UavgnwINV9WpVvQb8L+DXxznByRzuR/uag4kH2okmyQbgfQzOmqcx/ylJHgUOAvdW1VTqAP4Y+B3g76c0/2EFfC3JQ+0rN6bhAmAB+NO2TPXZJGdOqZbDrgFun/SkVXUA+EPgO8DzwKGq+tqk62Bw1v7Pk5yb5AzgSt744c9lO5nDXUdI8nbgS8Cnqup706ihqn5SVRcx+GTyJUneM+kakvwacLCqHpr03EfxS1V1MYNvSb0+yQenUMMqBkuHt1bV+4AfAFP7Ku724caPAH8+hbnPZvAX/gXAzwNnJvk3k66jqvYCNwFfA74KPAr8ZJxznMzh7tccDElyKoNg/3xVfXna9bQ/++8HLp/C9JcCH0nyLIPlusuS/Pcp1HH4TJGqOgjcxWA5cdL2A/uH/oq6k0HYT8sVwMNV9eIU5v4V4NtVtVBVPwa+DPyzKdRBVe2sqvdX1QeBlxm8VjY2J3O4+zUHTZIwWE/dW1V/NMU6ZpKsbu3TgQ8B35h0HVX1e1W1rqo2MPh3cV9VTfzsLMmZSd5xuA38KoM/xyeqql4A9iV5d+vaDDw16TqGfJwpLMk03wE2JTmj/b/ZzOA1qolL8nPt/p0M1tu/MM7Hn9q3Qi5XTf5rDo4qye3ALwPnJdkPfLqqdk64jEuBTwCPt/VugN+vqnsmXMf5wK52wZa3AHdU1dTehngCWAPcNcgQVgFfqKqvTqmW3wI+306EngGum0YR7Zfch4DfnMb8VfVgkjuBh4HXgEeY3tcQfCnJucCPgevH/SK3Xz8gSR06mZdlJEnHYLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDv1/FP0dDoB0gIYAAAAASUVORK5CYII=\n", 754 | "text/plain": [ 755 | "
" 756 | ] 757 | }, 758 | "metadata": { 759 | "needs_background": "light" 760 | }, 761 | "output_type": "display_data" 762 | } 763 | ], 764 | "source": [ 765 | "k = 8\n", 766 | "samples = []\n", 767 | "for _ in range(1000):\n", 768 | " samples.extend(choice(x, k,True).cpu().numpy())\n", 769 | "plt.hist(samples)\n", 770 | "plt.xticks(range(x.numel()))\n", 771 | "plt.show()" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": 34, 777 | "metadata": {}, 778 | "outputs": [ 779 | { 780 | "data": { 781 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASv0lEQVR4nO3df4yd1X3n8fenDCRAspgfU4vaTs0qVrqoUogzYp2lG3VxUwGpYrRKKFE3eJEr9w+2mzQrdWn/iSr1jyBVpUVasbLidE03gVInCCtFWVhDt1qp0A4/wi8ny4RAbC/gCQHShGZTp9/94x43F8dm7njuzLVP3i/p6p7nPOe553tH44+fOfe596aqkCT15acmXYAkafwMd0nqkOEuSR0y3CWpQ4a7JHVoatIFAFxwwQW1fv36SZchSaeUhx9++FtVNX2sfSdFuK9fv57Z2dlJlyFJp5Qkzx9vn8syktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUoZPiHarSyWr9jX8xsbmf+/QHJza3Tn2euUtSh0Y6c0/yW8CvAwU8AVwPXAjcAZwPPAx8rKp+kOQtwG3Ae4GXgV+tqufGX7omYVJnsp7FSouzYLgnWQP8R+Diqvr7JHcC1wJXATdX1R1J/iuwDbi13b9SVe9Mci1wE/Cry/YMJsigWzmTXB6ZFH+/tBSjLstMAWcmmQLOAl4ALgd2t/27gKtbe0vbpu3fnCTjKVeSNIoFz9yr6mCSPwC+Cfw9cC+DZZhXq+pwG3YAWNPaa4D97djDSV5jsHTzreHHTbId2A7wjne844SfwE/iGZ0kLWTBM/ck5zI4G78I+BngbOCKpU5cVTuqaqaqZqanj/lZ85KkEzTKC6q/BHyjquYBknwRuAxYlWSqnb2vBQ628QeBdcCBtoxzDoMXViXppNTjJa+jhPs3gU1JzmKwLLMZmAUeAD7M4IqZrcDdbfyetv3Xbf/9VVVjrvsnmktRkhYyypr7Q0l2A48Ah4FHgR3AXwB3JPn91rezHbIT+NMkc8C3GVxZI+kU4clDH0a6zr2qPgV86qjuZ4FLjzH2+8BHll6aJOlE+Q5VSeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHRvmC7HcleWzo9p0kn0hyXpL7kjzT7s9t45PkliRzSR5PsnH5n4YkadiC4V5VX6uqS6rqEuC9wOvAXcCNwN6q2gDsbdsAVwIb2m07cOtyFC5JOr7FLstsBr5eVc8DW4BdrX8XcHVrbwFuq4EHgVVJLhxLtZKkkSw23K8Fbm/t1VX1Qmu/CKxu7TXA/qFjDrS+N0iyPclsktn5+flFliFJejMjh3uSM4APAX9+9L6qKqAWM3F
V7aiqmaqamZ6eXsyhkqQFLObM/Urgkap6qW2/dGS5pd0fav0HgXVDx61tfZKkFbKYcP8oP1qSAdgDbG3trcDdQ/3XtatmNgGvDS3fSJJWwNQog5KcDXwA+I2h7k8DdybZBjwPXNP67wGuAuYYXFlz/diqlSSNZKRwr6rvAecf1fcyg6tnjh5bwA1jqU6SdEJ8h6okdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUMjhXuSVUl2J/lqkn1J3pfkvCT3JXmm3Z/bxibJLUnmkjyeZOPyPgVJ0tFGPXP/Y+DLVfVzwLuBfcCNwN6q2gDsbdsw+CLtDe22Hbh1rBVLkha0YLgnOQd4P7AToKp+UFWvAluAXW3YLuDq1t4C3FYDDwKrklw49solScc1ypn7RcA88CdJHk3ymfaF2aur6oU25kVgdWuvAfYPHX+g9b1Bku1JZpPMzs/Pn/gzkCT9mFHCfQrYCNxaVe8BvsePlmCAf/pS7FrMxFW1o6pmqmpmenp6MYdKkhYwSrgfAA5U1UNtezeDsH/pyHJLuz/U9h8E1g0dv7b1SZJWyILhXlUvAvuTvKt1bQaeBvYAW1vfVuDu1t4DXNeumtkEvDa0fCNJWgFTI477TeBzSc4AngWuZ/Afw51JtgHPA9e0sfcAVwFzwOttrCRpBY0U7lX1GDBzjF2bjzG2gBuWWJckaQl8h6okdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1aKRwT/JckieSPJZktvWdl+S+JM+0+3Nbf5LckmQuyeNJNi7nE5Ak/bjFnLn/m6q6pKqOfN3ejcDeqtoA7G3bAFcCG9ptO3DruIqVJI1mKcsyW4Bdrb0LuHqo/7YaeBBYleTCJcwjSVqkUcO9gHuTPJxke+tbXVUvtPaLwOrWXgPsHzr2QOt7gyTbk8wmmZ2fnz+B0iVJxzM14rhfqKqDSX4auC/JV4d3VlUlqcVMXFU7gB0AMzMzizpWkvTmRjpzr6qD7f4QcBdwKfDSkeWWdn+oDT8IrBs6fG3rkyStkAXDPcnZSd5+pA38MvAksAfY2oZtBe5u7T3Ade2qmU3Aa0PLN5KkFTDKssxq4K4kR8Z/vqq+nORvgTuTbAOeB65p4+8BrgLmgNeB68detSTpTS0Y7lX1LPDuY/S/DGw+Rn8BN4ylOknSCfEdqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDo0c7klOS/Joki+17YuSPJRkLsmfJTmj9b+lbc+1/euXp3RJ0vEs5sz948C+oe2bgJur6p3AK8C21r8NeKX139zGSZJW0EjhnmQt8EHgM207wOXA7jZkF3B1a29p27T9m9t4SdIKGfXM/Y+A3wb+sW2fD7xaVYfb9gFgTWuvAfYDtP2vtfFvkGR7ktkks/Pz8ydYviTpWBYM9yS/AhyqqofHOXFV7aiqmaqamZ6eHudDS9JPvKkRxlwGfCjJVcBbgX8G/DGwKslUOztfCxxs4w8C64ADSaaAc4CXx165JOm4Fjxzr6rfqaq1VbUeuBa4v6p+DXgA+HAbthW4u7X3tG3a/vurqsZatSTpTS3lOvf/DHwyyRyDNfWdrX8ncH7r/yRw49JKlCQt1ijLMv+kqv4S+MvWfha49Bhjvg98ZAy1SZJOkO9QlaQOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1aJQvyH5rkr9J8pUkTyX5vdZ/UZKHkswl+bMkZ7T+t7TtubZ//fI+BUnS0UY5c/9/wOV
V9W7gEuCKJJuAm4Cbq+qdwCvAtjZ+G/BK67+5jZMkraBRviC7quq7bfP0divgcmB3698FXN3aW9o2bf/mJBlbxZKkBY205p7ktCSPAYeA+4CvA69W1eE25ACwprXXAPsB2v7XGHyB9tGPuT3JbJLZ+fn5pT0LSdIbjBTuVfXDqroEWMvgS7F/bqkTV9WOqpqpqpnp6emlPpwkaciirpapqleBB4D3AauSTLVda4GDrX0QWAfQ9p8DvDyWaiVJIxnlapnpJKta+0zgA8A+BiH/4TZsK3B3a+9p27T991dVjbNoSdKbm1p4CBcCu5KcxuA/gzur6ktJngbuSPL7wKPAzjZ+J/CnSeaAbwPXLkPdkqQ3sWC4V9XjwHuO0f8sg/X3o/u/D3xkLNVJkk6I71CVpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVolK/ZW5fkgSRPJ3kqycdb/3lJ7kvyTLs/t/UnyS1J5pI8nmTjcj8JSdIbjXLmfhj4T1V1MbAJuCHJxcCNwN6q2gDsbdsAVwIb2m07cOvYq5YkvakFw72qXqiqR1r77xh8OfYaYAuwqw3bBVzd2luA22rgQWBVkgvHXrkk6bgWteaeZD2D71N9CFhdVS+0XS8Cq1t7DbB/6LADre/ox9qeZDbJ7Pz8/CLLliS9mZHDPcnbgC8An6iq7wzvq6oCajETV9WOqpqpqpnp6enFHCpJWsBI4Z7kdAbB/rmq+mLrfunIcku7P9T6DwLrhg5f2/okSStklKtlAuwE9lXVHw7t2gNsbe2twN1D/de1q2Y2Aa8NLd9IklbA1AhjLgM+BjyR5LHW97vAp4E7k2wDngeuafvuAa4C5oDXgevHWrEkaUELhntV/W8gx9m9+RjjC7hhiXVJkpbAd6hKUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjo0ytfsfTbJoSRPDvWdl+S+JM+0+3Nbf5LckmQuyeNJNi5n8ZKkYxvlzP2/AVcc1XcjsLeqNgB72zbAlcCGdtsO3DqeMiVJi7FguFfVXwHfPqp7C7CrtXcBVw/131YDDwKrklw4rmIlSaM50TX31VX1Qmu/CKxu7TXA/qFxB1qfJGkFLfkF1faF2LXY45JsTzKbZHZ+fn6pZUiShpxouL90ZLml3R9q/QeBdUPj1ra+H1NVO6pqpqpmpqenT7AMSdKxnGi47wG2tvZW4O6h/uvaVTObgNeGlm8kSStkaqEBSW4HfhG4IMkB4FPAp4E7k2wDngeuacPvAa4C5oDXgeuXoWZJ0gIWDPeq+uhxdm0+xtgCblhqUZKkpfEdqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktShZQn3JFck+VqSuSQ3LscckqTjG3u4JzkN+C/AlcDFwEeTXDzueSRJx7ccZ+6XAnNV9WxV/QC4A9iyDPNIko5jwS/IPgFrgP1D2weAf3n0oCTbge1t87tJvnaC810AfOsEjx0n63gj6zi5agDrONpJUUduWlIdP3u8HcsR7iOpqh3AjqU+TpLZqpoZQ0nWYR3d1mAdP3l1LMeyzEFg3dD22tYnSVohyxHufwtsSHJRkjOAa4E9yzCPJOk4xr4sU1WHk/wH4H8ApwGfraqnxj3PkCUv7YyJdbyRdfzIyVADWMfRuq4jVbUcjytJmiDfoSpJHTLcJalDp3S4nwwfc5Dks0kOJXlyEvO3GtYleSDJ00meSvLxCdXx1iR/k+QrrY7fm0QdQ/WcluTRJF+aYA3PJXkiyWNJZidYx6oku5N8Ncm+JO+bQA3vaj+HI7fvJPnEBOr4rfb7+WSS25O
8daVraHV8vNXw1LL8HKrqlLwxeLH268A/B84AvgJcPIE63g9sBJ6c4M/iQmBja78d+D8T+lkEeFtrnw48BGya4M/lk8DngS9NsIbngAsmNf9QHbuAX2/tM4BVE67nNOBF4GdXeN41wDeAM9v2ncC/n8Dz/3ngSeAsBhe2/E/gneOc41Q+cz8pPuagqv4K+PZKz3tUDS9U1SOt/XfAPga/xCtdR1XVd9vm6e02kVfsk6wFPgh8ZhLzn0ySnMPgJGQnQFX9oKpenWxVbAa+XlXPT2DuKeDMJFMMwvX/TqCGfwE8VFWvV9Vh4H8B/3acE5zK4X6sjzlY8UA72SRZD7yHwVnzJOY/LcljwCHgvqqaSB3AHwG/DfzjhOY/ooB7kzzcPnJjEi4C5oE/actUn0ly9oRqOeJa4PaVnrSqDgJ/AHwTeAF4raruXek6GJy1/+sk5yc5C7iKN775c8lO5XDXUZK8DfgC8Imq+s4kaqiqH1bVJQzemXxpkp9f6RqS/ApwqKoeXum5j+EXqmojg09JvSHJ+ydQwxSDpcNbq+o9wPeAiX0Ud3tz44eAP5/A3Ocy+Av/IuBngLOT/LuVrqOq9gE3AfcCXwYeA344zjlO5XD3Yw6GJDmdQbB/rqq+OOl62p/9DwBXTGD6y4APJXmOwXLd5Un++wTqOHKmSFUdAu5isJy40g4AB4b+itrNIOwn5Urgkap6aQJz/xLwjaqar6p/AL4I/KsJ1EFV7ayq91bV+4FXGLxWNjancrj7MQdNkjBYT91XVX84wTqmk6xq7TOBDwBfXek6qup3qmptVa1n8Htxf1Wt+NlZkrOTvP1IG/hlBn+Or6iqehHYn+RdrWsz8PRK1zHko0xgSab5JrApyVnt381mBq9RrbgkP93u38Fgvf3z43z8iX0q5FLVyn/MwTEluR34ReCCJAeAT1XVzhUu4zLgY8ATbb0b4Her6p4VruNCYFf7wpafAu6sqoldhngSWA3cNcgQpoDPV9WXJ1TLbwKfaydCzwLXT6KI9p/cB4DfmMT8VfVQkt3AI8Bh4FEm9zEEX0hyPvAPwA3jfpHbjx+QpA6dyssykqTjMNwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtSh/4/X/MaN0GgWaMAAAAASUVORK5CYII=\n", 782 | "text/plain": [ 783 | "
" 784 | ] 785 | }, 786 | "metadata": { 787 | "needs_background": "light" 788 | }, 789 | "output_type": "display_data" 790 | } 791 | ], 792 | "source": [ 793 | "k = 8\n", 794 | "samples = []\n", 795 | "for _ in range(1000):\n", 796 | " samples.extend(choice(x, k, False).cpu().numpy())\n", 797 | "plt.hist(samples)\n", 798 | "plt.xticks(range(x.numel()))\n", 799 | "plt.show()" 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "execution_count": 35, 805 | "metadata": {}, 806 | "outputs": [ 807 | { 808 | "data": { 809 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAARMElEQVR4nO3df4xlZX3H8fdHFvwBWtAdCQV0wa6k1LQLnVBahVBRC2hA/YNCWgRLu5hgo9XEIE2KbWKirWht2mJWoUAKK8iPSCy1UGokJgWdhRWWXwq4yG6X3VEsoBh14ds/7pl6GWbZmbl37t198n4lN/ec55xzn2+G4bPPPPf8SFUhSWrLi8ZdgCRp+Ax3SWqQ4S5JDTLcJalBhrskNWjZuAsAWL58ea1YsWLcZUjSbmXdunU/qKqJubbtEuG+YsUKpqamxl2GJO1Wkjyyo21Oy0hSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoN2iStUJe1aVpz3b2Ppd+Mn3j6WflvkyF2SGuTIfTc0rlEVOLKSdheO3CWpQYa7JDXIcJekBjnnrgXxLApp92C4D2CcX2xK0gvZabgnORi4HNgfKGBNVX02ySuBq4AVwEbg1Kr6UZIAnwVOAp4GzqqqO5amfEkaXItnoM1nzn078OGqOhw4Gjg3yeHAecAtVbUSuKVbBzgRWNm9VgMXDb1qSdIL2mm4V9WWmZF3VT0F3AccCJwCXNbtdhnwzm75FODy6rkN2DfJAUOvXJK0Qws6WybJCuAI4HZg/6ra0m16jN60DfSC/9G+wzZ1bbM/a3WSqSRT09PTCyxbkvRC5v2FapJ9gGuBD1bVk72p9Z6qqiS1kI6rag2wBmBycnJBx0qj5BlC2h3Na+SeZE96wX5FVV3XNW+dmW7p3rd17ZuBg/sOP6hrkySNyE7DvTv75WLgvqr6dN+mG4Azu+UzgS/3tb8nPUcDT/RN30iSRmA+0zJvBM4A7k6yvms7H/gEcHWSs4FHgFO7bTfSOw3yQXqnQr53qBXPwfPNJem5dhruVfUNIDvYfPwc+xdw7oB1Sc/hP+DSwnhvGUlqkLcfkHZR/rWiQThyl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDvHGYpF2GN0sbHkfuktSg+Txm75Ik25Js6Gu7Ksn67rVx5glNSVYk+Wnfts8tZfGSpLnNZ1rmUuAfgctnGqrqD2eWk1wIPNG3/0NVtWpYBUqSFm4+j9m7NcmKubZ1D88+FXjzcMuSJA1i0Dn3Y4CtVfXdvrZDktyZ5OtJjtnRgUlWJ5lKMjU9PT1gGZKkfoOG++nA2r71LcBrquoI4EPAlUleMdeBVbWmqiaranJiYmLAMiRJ/RYd7kmWAe8Grpppq
6qfVdUPu+V1wEPA6wctUpK0MIOM3N8C3F9Vm2Yakkwk2aNbPhRYCTw8WImSpIWaz6mQa4H/Bg5LsinJ2d2m03julAzAscBd3amR1wDvq6rHh1mwJGnn5nO2zOk7aD9rjrZrgWsHL0uSNAivUJWkBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalB83lYxyVJtiXZ0Nf2sSSbk6zvXif1bftokgeTPJDkD5aqcEnSjs1n5H4pcMIc7Z+pqlXd60aAJIfTe0LTb3TH/PPMY/ckSaOz03CvqluB+T4q7xTgi92Dsr8HPAgcNUB9kqRFGGTO/f1J7uqmbfbr2g4EHu3bZ1PX9jxJVieZSjI1PT09QBmSpNkWG+4XAa8DVgFbgAsX+gFVtaaqJqtqcmJiYpFlSJLmsqhwr6qtVfVMVT0LfJ5fTr1sBg7u2/Wgrk2SNEKLCvckB/StvguYOZPmBuC0JC9OcgiwEvjmYCVKkhZq2c52SLIWOA5YnmQTcAFwXJJVQAEbgXMAquqeJFcD9wLbgXOr6pmlKV2StCM7DfeqOn2O5otfYP+PAx8fpChJ0mC8QlWSGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1KCdhnuSS5JsS7Khr+3vktyf5K4k1yfZt2tfkeSnSdZ3r88tZfGSpLnNZ+R+KXDCrLabgTdU1W8C3wE+2rftoapa1b3eN5wyJUkLsdNwr6pbgcdntd1UVdu71duAg5agNknSIg1jzv1PgH/vWz8kyZ1Jvp7kmB0dlGR1kqkkU9PT00MoQ5I0Y6BwT/KXwHbgiq5pC/CaqjoC+BBwZZJXzHVsVa2pqsmqmpyYmBikDEnSLIsO9yRnAe8A/qiqCqCqflZVP+yW1wEPAa8fQp2SpAVYVLgnOQH4CHByVT3d1z6RZI9u+VBgJfDwMAqVJM3fsp3tkGQtcBywPMkm4AJ6Z8e8GLg5CcBt3ZkxxwJ/k+QXwLPA+6rq8Tk/WJK0ZHYa7lV1+hzNF+9g32uBawctSpI0GK9QlaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUHzCvcklyTZlmRDX9srk9yc5Lvd+35de5L8Q5IHk9yV5MilKl6SNLf5jtwvBU6Y1XYecEtVrQRu6dYBTqT3eL2VwGrgosHLlCQtxLzCvapuBWY/Lu8U4LJu+TLgnX3tl1fPbcC+SQ4YRrGSpPkZZM59/6ra0i0/BuzfLR8IPNq336au7TmSrE4ylWRqenp6gDIkSbMN5QvVqiqgFnjMmqqarKrJiYmJYZQhSeoMEu5bZ6ZbuvdtXftm4OC+/Q7q2iRJIzJIuN8AnNktnwl8ua/9Pd1ZM0cDT/RN30iSRmDZfHZKshY4DlieZBNwAfAJ4OokZwOPAKd2u98InAQ8CDwNvHfINUuSdmJe4V5Vp+9g0/Fz7FvAuYMUJUkajFeoSlKDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1KB53c99LkkOA67qazoU+CtgX+DPgJmnXp9fVTcuukJJ0oItOtyr6gFgFUCSPeg9J/V6ek9e+kxVfWooFUqSFmxY0zLHAw9V1SND+jxJ0gCGFe6nAWv71t+f5K4klyTZb0h9SJLmaeBwT7IXcDLwpa7pIuB19KZstgAX7uC41UmmkkxNT0/PtYskaZGGMXI/EbijqrYCVNXWqnqmqp4FPg8cNddBVbWmqiaranJiYmIIZUiSZgwj3E+nb0omyQF9294FbBhCH5KkBVj02TIASfYG3gqc09f8t0lWAQVsnLVNkjQCA
4V7Vf0EeNWstjMGqkiSNDCvUJWkBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNWigh3UAJNkIPAU8A2yvqskkrwSuAlbQexrTqVX1o0H7kiTNz7BG7r9fVauqarJbPw+4papWArd065KkEVmqaZlTgMu65cuAdy5RP5KkOQwj3Au4Kcm6JKu7tv2raku3/Biw/+yDkqxOMpVkanp6eghlSJJmDDznDrypqjYneTVwc5L7+zdWVSWp2QdV1RpgDcDk5OTztkuSFm/gkXtVbe7etwHXA0cBW5McANC9bxu0H0nS/A0U7kn2TvLymWXgbcAG4AbgzG63M4EvD9KPJGlhBp2W2R+4PsnMZ11ZVV9N8i3g6iRnA48Apw7YjyRpAQYK96p6GPitOdp/CBw/yGdLkhbPK1QlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lq0KLDPcnBSb6W5N4k9yT5QNf+sSSbk6zvXicNr1xJ0nwM8rCO7cCHq+qO7lF765Lc3G37TFV9avDyJEmLsehwr6otwJZu+akk9wEHDqswSdLiDWXOPckK4Ajg9q7p/UnuSnJJkv12cMzqJFNJpqanp4dRhiSpM3C4J9kHuBb4YFU9CVwEvA5YRW9kf+Fcx1XVmqqarKrJiYmJQcuQJPUZKNyT7Ekv2K+oqusAqmprVT1TVc8CnweOGrxMSdJCDHK2TICLgfuq6tN97Qf07fYuYMPiy5MkLcYgZ8u8ETgDuDvJ+q7tfOD0JKuAAjYC5wxUoSRpwQY5W+YbQObYdOPiy5EkDYNXqEpSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGrRk4Z7khCQPJHkwyXlL1Y8k6fmWJNyT7AH8E3AicDi9R+8dvhR9SZKeb6lG7kcBD1bVw1X1c+CLwClL1JckaZZBHpD9Qg4EHu1b3wT8Tv8OSVYDq7vVHyd5YID+lgM/GOD4YdgVagDrmM06dq0awDqeI58cqI7X7mjDUoX7TlXVGmDNMD4ryVRVTQ7js3bnGqzDOnb1GqxjdHUs1bTMZuDgvvWDujZJ0ggsVbh/C1iZ5JAkewGnATcsUV+SpFmWZFqmqrYneT/wH8AewCVVdc9S9NUZyvTOgHaFGsA6ZrOOX9oVagDrmG1J6khVLcXnSpLGyCtUJalBhrskNWi3Dvdd4RYHSS5Jsi3JhnH031fHwUm+luTeJPck+cCY6nhJkm8m+XZXx1+Po46ulj2S3JnkK2OsYWOSu5OsTzI1xjr2TXJNkvuT3Jfkd8dQw2Hdz2Hm9WSSD466jq6Wv+h+PzckWZvkJWOo4QNd//csyc+hqnbLF70vah8CDgX2Ar4NHD6GOo4FjgQ2jPnncQBwZLf8cuA7Y/p5BNinW94TuB04ekw/kw8BVwJfGeN/l43A8nH+bnR1XAb8abe8F7DvmOvZA3gMeO0Y+j4Q+B7w0m79auCsEdfwBmAD8DJ6J7b8J/Brw+xjdx657xK3OKiqW4HHR93vHHVsqao7uuWngPvo/RKPuo6qqh93q3t2r5F/a5/kIODtwBdG3feuJsmv0BuEXAxQVT+vqv8db1UcDzxUVY+Mqf9lwEuTLKMXsP8z4v5/Hbi9qp6uqu3A14F3D7OD3Tnc57rFwcjDbFeUZAVwBL1R8zj63yPJemAbcHNVjaOOvwc+Ajw7hr77FXBTknXdLTfG4RBgGviXbprqC0n2HlMtM04D1o6j46raDHwK+D6wBXiiqm4acRkbgGOSvCrJy4CTeO6FnwPbncNdc0iyD3At8MGqenIcNVTVM1W1it6VyUclecMo+0/yDmBbVa0bZb878KaqOpLeH
VLPTXLsGGpYRm/q8KKqOgL4CTC223B3FzaeDHxpTP3vR++v/EOAXwX2TvLHo6yhqu4DPgncBHwVWA88M8w+dudw9xYHsyTZk16wX1FV1427nu5P/68BJ4y46zcCJyfZSG+67s1J/nXENQD/P0qkqrYB19ObThy1TcCmvr+grqEX9uNyInBHVW0dU/9vAb5XVdNV9QvgOuD3Rl1EVV1cVb9dVccCP6L3PdnQ7M7h7i0O+iQJvTnV+6rq02OsYyLJvt3yS4G3AvePsoaq+mhVHVRVK+j9XvxXVY10ZAaQZO8kL59ZBt5G78/xkaqqx4BHkxzWNR0P3DvqOvqczpimZDrfB45O8rLu/5vj6X1HNVJJXt29v4befPuVw/z8sd0VclA1+lsczCnJWuA4YHmSTcAFVXXxqOugN1o9A7i7m+8GOL+qbhxxHQcAl3UPbHkRcHVVje1UxDHbH7i+lx8sA66sqq+OqZY/B67oBkIPA+8dRxHdP3JvBc4ZR/8AVXV7kmuAO4DtwJ2M51YE1yZ5FfAL4Nxhf8nt7QckqUG787SMJGkHDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUoP8DvMWgTtLp1z8AAAAASUVORK5CYII=\n", 810 | "text/plain": [ 811 | "
" 812 | ] 813 | }, 814 | "metadata": { 815 | "needs_background": "light" 816 | }, 817 | "output_type": "display_data" 818 | } 819 | ], 820 | "source": [ 821 | "k = 2\n", 822 | "samples = []\n", 823 | "for _ in range(1000):\n", 824 | " samples.extend(choice(x, k, True).cpu().numpy())\n", 825 | "plt.hist(samples)\n", 826 | "plt.xticks(range(x.numel()))\n", 827 | "plt.show()" 828 | ] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": 36, 833 | "metadata": {}, 834 | "outputs": [ 835 | { 836 | "data": { 837 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAANg0lEQVR4nO3df6zd9V3H8edrLTh+RWDtmtqWFbVZrEsEvEF0k6B1E5hZ0T8IJG6VYLo/OgNqYtj+Qf9YwhKdZomS1IHrIrBVfgQyCYKVSPxjbC1jUOhwHZTRWmgnE5gkbrC3f5xv42m55fbec8/53vvx+Uhu7vd8zzn3++7N5cm3n3PPt6kqJElteUffA0iS5p9xl6QGGXdJapBxl6QGGXdJatDSvgcAWLZsWa1du7bvMSRpUdm1a9f3qmr5dPctiLivXbuWnTt39j2GJC0qSZ4/3n0uy0hSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSgxbEO1Slmay94R97O/a+mz7c27GlufLMXZIaZNwlqUEuy2hW+lwekXTiPHOXpAYZd0lqkHGXpAa55i4tUH29vuGvfrbBM3dJapBxl6QGGXdJapBxl6QG+YLqIuQbiSTNxDN3SWrQjHFPsibJw0meTvJUkuu6/WcneSjJt7vPZ3X7k+RzSfYmeSLJBeP+Q0iSjnYiZ+5vAH9cVeuBi4AtSdYDNwA7qmodsKO7DXAZsK772AzcPO9TS5Le1oxxr6qDVfVYt/0asAdYBWwEtnUP2wZc0W1vBL5YA18Fzkyyct4nlyQd16xeUE2yFjgfeBRYUVUHu7teBFZ026uAF4aetr/bd3BoH0k2Mziz55xzzpnl2JJa5Lty588Jv6Ca5HTgLuD6qnp1+L6qKqBmc+Cq2lpVU1U1tXz58tk8VZI0gxM6c09yEoOw31ZVd3e7X0qysqoOdssuh7r9B4A1Q09f3e0bC//5NUl6qxP5bZkAtwB7quqzQ3fdB2zqtjcB9w7t/1j3WzMXAa8MLd9IkibgRM7c3w98FHgyyePdvk8BNwHbk1wLPA9c2d13P3A5sBd4HbhmXieWJM1oxrhX1b8BOc7dG6Z5fAFbRpxLkjQC36EqSQ0y7pLUIOMuSQ3yqpAj8OqMapE/120w7tIMjJ0WI5dlJKlBxl2SGmTcJalBrrlL+n+vxWtUeeYuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ2aMe5Jbk1yKMnuoX1/muRAkse7j8uH7vtkkr1Jnknym+MaXJJ0fCdy5v4F4NJp9v9lVZ
3XfdwPkGQ9cBXw891z/ibJkvkaVpJ0YmaMe1U9Arx8gl9vI/ClqvqfqnoO2AtcOMJ8kqQ5GGXN/RNJnuiWbc7q9q0CXhh6zP5u31sk2ZxkZ5Kdhw8fHmEMSdKx5hr3m4GfAc4DDgJ/MdsvUFVbq2qqqqaWL18+xzEkSdOZU9yr6qWqerOqfgz8Lf+39HIAWDP00NXdPknSBM0p7klWDt38beDIb9LcB1yV5CeSnAusA7422oiSpNlaOtMDktwBXAIsS7IfuBG4JMl5QAH7gI8DVNVTSbYDTwNvAFuq6s3xjC5JOp4Z415VV0+z+5a3efyngU+PMpQkaTS+Q1WSGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGjRj3JPcmuRQkt1D+85O8lCSb3efz+r2J8nnkuxN8kSSC8Y5vCRpeidy5v4F4NJj9t0A7KiqdcCO7jbAZcC67mMzcPP8jClJmo0Z415VjwAvH7N7I7Ct294GXDG0/4s18FXgzCQr52tYSdKJmeua+4qqOthtvwis6LZXAS8MPW5/t+8tkmxOsjPJzsOHD89xDEnSdEZ+QbWqCqg5PG9rVU1V1dTy5ctHHUOSNGSucX/pyHJL9/lQt/8AsGbocau7fZKkCZpr3O8DNnXbm4B7h/Z/rPutmYuAV4aWbyRJE7J0pgckuQO4BFiWZD9wI3ATsD3JtcDzwJXdw+8HLgf2Aq8D14xhZknSDGaMe1VdfZy7Nkzz2AK2jDqUJGk0vkNVkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQUtHeXKSfcBrwJvAG1U1leRs4MvAWmAfcGVVfX+0MSVJszEfZ+6/VlXnVdVUd/sGYEdVrQN2dLclSRM0jmWZjcC2bnsbcMUYjiFJehujxr2AB5PsSrK527eiqg522y8CK6Z7YpLNSXYm2Xn48OERx5AkDRtpzR34QFUdSPJu4KEk3xq+s6oqSU33xKraCmwFmJqamvYxkqS5GenMvaoOdJ8PAfcAFwIvJVkJ0H0+NOqQkqTZmXPck5yW5Iwj28CHgN3AfcCm7mGbgHtHHVKSNDujLMusAO5JcuTr3F5VDyT5OrA9ybXA88CVo48pSZqNOce9qp4FfmGa/f8JbBhlKEnSaHyHqiQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1aGxxT3JpkmeS7E1yw7iOI0l6q7HEPckS4K+By4D1wNVJ1o/jWJKktxrXmfuFwN6qeraqfgh8Cdg4pmNJko6xdExfdxXwwtDt/cAvDT8gyWZgc3fzB0memeOxlgHfm+Nz55NzHM05jrYQ5lgIM4BzHCWfGWmO9xzvjnHFfUZVtRXYOurXSbKzqqbmYSTncI6m51gIMzjH5OYY17LMAWDN0O3V3T5J0gSMK+5fB9YlOTfJycBVwH1jOp
Yk6RhjWZapqjeSfAL4J2AJcGtVPTWOYzEPSzvzxDmO5hxHWwhzLIQZwDmONZY5UlXj+LqSpB75DlVJapBxl6QGLeq4L4RLHCS5NcmhJLv7OP7QHGuSPJzk6SRPJbmuhxnemeRrSb7ZzfBnk57hmHmWJPlGkq/0OMO+JE8meTzJzh7nODPJnUm+lWRPkl/uYYb3dt+HIx+vJrm+hzn+sPv53J3kjiTvnPQM3RzXdTM8NZbvQ1Utyg8GL9R+B/hp4GTgm8D6Hua4GLgA2N3z92MlcEG3fQbw75P+fgABTu+2TwIeBS7q8XvyR8DtwFd6nGEfsKzPn41ujm3A73fbJwNn9jzPEuBF4D0TPu4q4DnglO72duD3evjzvw/YDZzK4Bdb/hn42fk8xmI+c18QlzioqkeAlyd93GnmOFhVj3XbrwF7GPwgT3KGqqofdDdP6j56ecU+yWrgw8Dn+zj+QpLkJxmchNwCUFU/rKr/6ncqNgDfqarnezj2UuCUJEsZxPU/epjh54BHq+r1qnoD+Ffgd+bzAIs57tNd4mCiMVuokqwFzmdw5jzpYy9J8jhwCHioqiY+Q+evgD8BftzT8Y8o4MEku7pLbvThXOAw8HfdMtXnk5zW0yxHXAXcMemDVtUB4M+B7wIHgVeq6sFJz8HgrP1Xk7wryanA5Rz9xs+RLea4axpJTgfuAq6vqlcnffyqerOqzmPwruQLk7xv0jMk+S3gUFXtmvSxp/GBqrqAwRVStyS5uIcZljJYOry5qs4H/hvo7TLc3RsbPwL8Qw/HPovB3/DPBX4KOC3J7056jqraA3wGeBB4AHgceHM+j7GY4+4lDo6R5CQGYb+tqu7uc5bur/0PA5f2cPj3Ax9Jso/Bct2vJ/n7HuY4cqZIVR0C7mGwnDhp+4H9Q3+LupNB7PtyGfBYVb3Uw7F/A3iuqg5X1Y+Au4Ff6WEOquqWqvrFqroY+D6D18nmzWKOu5c4GJIkDNZU91TVZ3uaYXmSM7vtU4APAt+a9BxV9cmqWl1Vaxn8XPxLVU387CzJaUnOOLINfIjBX8cnqqpeBF5I8t5u1wbg6UnPMeRqeliS6XwXuCjJqd1/MxsYvD41cUne3X0+h8F6++3z+fV7uyrkqGqylzg4riR3AJcAy5LsB26sqlsmPQeDs9WPAk92a94An6qq+yc4w0pgW/ePtbwD2F5Vvf0a4gKwArhn0BCWArdX1QM9zfIHwG3didCzwDV9DNH9T+6DwMf7OH5VPZrkTuAx4A3gG/R3GYK7krwL+BGwZb5f5PbyA5LUoMW8LCNJOg7jLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1KD/BS1bLB9yY/37AAAAAElFTkSuQmCC\n", 838 | "text/plain": [ 839 | "
" 840 | ] 841 | }, 842 | "metadata": { 843 | "needs_background": "light" 844 | }, 845 | "output_type": "display_data" 846 | } 847 | ], 848 | "source": [ 849 | "k = 2\n", 850 | "samples = []\n", 851 | "for _ in range(1000):\n", 852 | " samples.extend(choice(x, k, False).cpu().numpy())\n", 853 | "plt.hist(samples)\n", 854 | "plt.xticks(range(x.numel()))\n", 855 | "plt.show()" 856 | ] 857 | }, 858 | { 859 | "cell_type": "markdown", 860 | "metadata": {}, 861 | "source": [ 862 | "# Checking distributions - Weighted" 863 | ] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "execution_count": 37, 868 | "metadata": {}, 869 | "outputs": [], 870 | "source": [ 871 | "x = torch.arange(10)\n", 872 | "weights = torch.Tensor([1, 1, 5, 5, 5, 1, 7, 2, 2, 2]).float()" 873 | ] 874 | }, 875 | { 876 | "cell_type": "markdown", 877 | "metadata": {}, 878 | "source": [ 879 | "### CPP Extension" 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": 38, 885 | "metadata": {}, 886 | "outputs": [ 887 | { 888 | "data": { 889 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAOw0lEQVR4nO3df6zddX3H8efLFqaiGdVeG9aCl22NWWcyZA2y6QwbG/JjEbc/DCTTjrjUP+qi25Kl+g+bxgSTzS0mjoRJZ80Ew1RiI43QMTOzP0QKIhTQUbFIu0Lr6lBHMsW998f5dDmWe3vbe0/P99LP85GcnO95f7/n+3mfy+3rfO/nfM+XVBWSpD68aOgGJEnTY+hLUkcMfUnqiKEvSR0x9CWpIyuHbuB4Vq9eXbOzs0O3IUkvKPfdd993q2pmrnXLOvRnZ2fZvXv30G1I0gtKkifmW+f0jiR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdWRZfyNX0vPNbr1jsLH33XDVYGNrMjzSl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdWTD0k5yb5EtJHknycJL3tPorkuxK8li7X9XqSfLRJHuTPJjkwrF9bWrbP5Zk06l7WZKkuZzIkf5zwJ9V1QbgYmBLkg3AVuDuqloP3N0eA1wBrG+3zcCNMHqTAK4HXg9cBFx/9I1CkjQdC4Z+VR2sqvvb8g+AR4G1wNXA9rbZduCtbflq4JM18hXg7CTnAG8GdlXVkar6HrALuHyir0aSdFwnNaefZBZ4HXAPsKaqDrZVTwFr2vJa4Mmxp+1vtfnqx46xOcnuJLsPHz58Mu1JkhZwwqGf5GXAZ4H3VtX3x9dVVQE1iYaq6qaq2lhVG2dmZiaxS0lSc0Khn+QMRoH/qar6XCs/3aZtaPeHWv0AcO7Y09e12nx1SdKUnMjZOwFuBh6tqo+MrdoBHD0DZxPw+bH6O9pZPBcDz7RpoDuBy5Ksah/gXtZqkqQpWXkC27wBeDvwUJIHWu39wA3AbUneCTwBvK2t2wlcCewFngWuA6iqI0k+CNzbtvtAVR2ZyKuQJJ2QBUO/qv4NyDyrL51j+wK2zLOvbcC2k2lQkjQ5fiNXkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SerIyqEbkJZidusdg42974arBhtbWiyP9CWpI4a+JHXE6R1NxJDTLJJOnKF/mjF8JR2P0zuS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpIwuGfpJtSQ4l2TNW+4skB5I80G5Xjq17X5K9Sb6Z5M1j9ctbbW+SrZN/KZKkhZzIkf4ngMvnqP9NVV3QbjsBkmwArgF+uT3n75KsSLIC+BhwBbABuLZtK0maogW/nFVVX04ye4L7uxr4dFX9D/DtJHuBi9q6vVX1OECST7dtHznpjiVJi7aUOf13J3mwTf+sarW1wJNj2+xvtfnqz5Nkc5LdSXYfPnx4Ce1Jko612NC/EfgF4ALgIPDXk2qoqm6qqo1VtXFmZmZSu5Ukschr71TV00eXk/w98IX28ABw7tim61qN49QlSVOyqCP9JOeMPfw94OiZPTuAa5L8TJLzgfXAV4F7gfVJzk9yJqMPe3csvm1J0mIseKSf5FbgEmB1kv3A9cAlSS4ACtgHvAugqh5OchujD2ifA7ZU1U/aft4N3AmsALZV1cMTfzWSpOM6kbN3rp2jfPNxtv8Q8KE56juBnSfVnSRpovxGriR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4
kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqyIKhn2RbkkNJ9ozVXpFkV5LH2v2qVk+SjybZm+TBJBeOPWdT2/6xJJtOzcuRJB3PiRzpfwK4/JjaVuDuqloP3N0eA1wBrG+3zcCNMHqTAK4HXg9cBFx/9I1CkjQ9C4Z+VX0ZOHJM+Wpge1veDrx1rP7JGvkKcHaSc4A3A7uq6khVfQ/YxfPfSCRJp9hi5/TXVNXBtvwUsKYtrwWeHNtuf6vNV3+eJJuT7E6y+/Dhw4tsT5I0lyV/kFtVBdQEejm6v5uqamNVbZyZmZnUbiVJLD70n27TNrT7Q61+ADh3bLt1rTZfXZI0RYsN/R3A0TNwNgGfH6u/o53FczHwTJsGuhO4LMmq9gHuZa0mSZqilQttkORW4BJgdZL9jM7CuQG4Lck7gSeAt7XNdwJXAnuBZ4HrAKrqSJIPAve27T5QVcd+OCxJOsUWDP2qunaeVZfOsW0BW+bZzzZg20l1J0maKL+RK0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSMrh25Akpaz2a13DDLuvhuuOiX7NfQlLXtDBe/pyNCXdMIM3xc+5/QlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0JakjSwr9JPuSPJTkgSS7W+0VSXYleazdr2r1JPlokr1JHkxy4SRegCTpxE3iSP83q+qCqtrYHm8F7q6q9cDd7THAFcD6dtsM3DiBsSVJJ+FUTO9cDWxvy9uBt47VP1kjXwHOTnLOKRhfkjSPpYZ+AXcluS/J5lZbU1UH2/JTwJq2vBZ4cuy5+1tNkjQlS73K5hur6kCSVwG7knxjfGVVVZI6mR22N4/NAOedd94S25MkjVvSkX5VHWj3h4DbgYuAp49O27T7Q23zA8C5Y09f12rH7vOmqtpYVRtnZmaW0p4k6RiLDv0kZyV5+dFl4DJgD7AD2NQ22wR8vi3vAN7RzuK5GHhmbBpIkjQFS5neWQPcnuTofm6pqi8muRe4Lck7gSeAt7XtdwJXAnuBZ4HrljC2JGkRFh36VfU48Ctz1P8TuHSOegFbFjueJGnp/EauJHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1ZKlX2VzWZrfeMci4+264apBxJWkhHulLUkdO6yN96VQa6i9JaSkM/VPAMJC0XDm9I0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSNTD/0klyf5ZpK9SbZOe3xJ6tlUQz/JCuBjwBXABuDaJBum2YMk9WzaR/oXAXur6vGq+hHwaeDqKfcgSd1aOeXx1gJPjj3eD7x+fIMkm4HN7eEPk3xzCeOtBr67hOdPwnLoAezjWPbx05ZDH8uhB1gmfeTDS+rj1fOtmHboL6iqbgJumsS+kuyuqo2T2NcLuQf7sI8XQh/LoYce+pj29M4B4Nyxx+taTZI0BdMO/XuB9UnOT3ImcA2wY8o9SFK3pjq9U1XPJXk3cCewAthWVQ+fwiEnMk20RMuhB7CPY9nHT1sOfSyHHuA07yNVdSr2K0lahvxGriR1xNCXpI6clqG/HC71kGRbkkNJ9gwx/lgf5yb
5UpJHkjyc5D0D9fHiJF9N8vXWx18O0UfrZUWSryX5woA97EvyUJIHkuwesI+zk3wmyTeSPJrk1wbo4TXt53D09v0k7512H62XP2m/n3uS3JrkxQP08J42/sOn5OdQVafVjdEHxN8Cfh44E/g6sGGAPt4EXAjsGfjncQ5wYVt+OfDvA/08ArysLZ8B3ANcPNDP5E+BW4AvDPjfZR+wesjfjdbHduCP2vKZwNkD97MCeAp49QBjrwW+DbykPb4N+MMp9/BaYA/wUkYn2vwz8IuTHON0PNJfFpd6qKovA0emPe4cfRysqvvb8g+ARxn9ck+7j6qqH7aHZ7Tb1M8iSLIOuAr4+LTHXm6S/Cyjg5ObAarqR1X1X8N2xaXAt6rqiYHGXwm8JMlKRsH7H1Me/5eAe6rq2ap6DvhX4PcnOcDpGPpzXeph6iG3HCWZBV7H6Ch7iPFXJHkAOATsqqoh+vhb4M+B/x1g7HEF3JXkvnbpkSGcDxwG/qFNd308yVkD9XLUNcCtQwxcVQeAvwK+AxwEnqmqu6bcxh7gN5K8MslLgSv56S+0LtnpGPqaQ5KXAZ8F3ltV3x+ih6r6SVVdwOib2Bclee00x0/yu8ChqrpvmuPO441VdSGjK85uSfKmAXpYyWgK8saqeh3w38BglztvX9h8C/BPA42/itGswPnAzwFnJfmDafZQVY8CHwbuAr4IPAD8ZJJjnI6h76UejpHkDEaB/6mq+tzQ/bQphC8Bl0956DcAb0myj9G0328l+ccp9wD8/1ElVXUIuJ3RtOS07Qf2j/3F9RlGbwJDuQK4v6qeHmj83wa+XVWHq+rHwOeAX592E1V1c1X9alW9Cfgeo8/hJuZ0DH0v9TAmSRjN2T5aVR8ZsI+ZJGe35ZcAvwN8Y5o9VNX7qmpdVc0y+r34l6qa6pEcQJKzkrz86DJwGaM/66eqqp4Cnkzymla6FHhk2n2MuZaBpnaa7wAXJ3lp+3dzKaPPwKYqyava/XmM5vNvmeT+l91VNpeqpn+phzkluRW4BFidZD9wfVXdPO0+GB3dvh14qM2nA7y/qnZOuY9zgO3tf6TzIuC2qhrslMmBrQFuH+UKK4FbquqLA/Xyx8Cn2gHS48B1QzTR3vx+B3jXEOMDVNU9ST4D3A88B3yNYS7J8NkkrwR+DGyZ9IfrXoZBkjpyOk7vSJLmYehLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjvwf7KCmJNM4zuYAAAAASUVORK5CYII=\n", 890 | "text/plain": [ 891 | "
" 892 | ] 893 | }, 894 | "metadata": { 895 | "needs_background": "light" 896 | }, 897 | "output_type": "display_data" 898 | } 899 | ], 900 | "source": [ 901 | "k = 1\n", 902 | "samples = []\n", 903 | "for _ in range(10000):\n", 904 | " samples.extend(choice(x, k, True, weights).cpu().numpy())\n", 905 | "plt.hist(samples)\n", 906 | "plt.xticks(range(x.numel()))\n", 907 | "plt.show()" 908 | ] 909 | }, 910 | { 911 | "cell_type": "code", 912 | "execution_count": 39, 913 | "metadata": {}, 914 | "outputs": [ 915 | { 916 | "data": { 917 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAOyElEQVR4nO3df6zddX3H8efLFqaiGdXWhrWdl22NWWcyZA2y6QwbE/mxiNsfBpJpR1zqH3XRbclS/YdNY4LJ5hYTR9JJZ80Ew1RiI43QMTOzP0QuiFBAR8Ui7Qq9rg51JFPce3+cT5dD6e1te0/P99LP85GcnO95f7/n+3mfy+3rfO/nfM+XVBWSpD68aOgGJEnTY+hLUkcMfUnqiKEvSR0x9CWpI8uHbuB4Vq5cWTMzM0O3IUkvKPfee+/3qmrVsdYt6dCfmZlhdnZ26DYk6QUlyePzrXN6R5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOrKkv5Er6flmtt4+2Nj7brhqsLE1GR7pS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR1ZMPSTrEvy5SQPJ3koyXtb/RVJdid5tN2vaPUk+ViSvUkeSHLh2L42te0fTbLp9L0sSdKxnMiR/rPAn1XVBuBiYEuSDcBW4K6qWg/c1R4DXAGsb7fNwI0wepMArgdeD1wEXH/kjUKSNB0Lhn5VHayq+9ryD4FHgDXA1cCOttkO4G1t+WrgUzXyVeDcJOcBbwF2V9Xhqvo+sBu4fKKvRpJ0XCc1p59kBngdcDewuqoOtlVPAqvb8hrgibGn7W+1+epHj7E5yWyS2bm5uZNpT5K0gBMO/SQvAz4HvK+qfjC+rqoKqEk0VFXbqmpjVW1ctWrVJHYpSWpOKPSTnMUo8D9dVZ9v5afatA3t/lCrHwDWjT19bavNV5ckTcmJnL0T4Cbgkar66NiqncCRM3A2AV8Yq7+zncVzMfB0mwa6A7gsyYr2Ae5lrSZJmpLlJ7DNG4B3AA8mub/VPgDcANya5F3A48Db27pdwJXAXuAZ4DqAqjqc5EPAPW27D1bV4Ym8CknSCVkw9Kvq34DMs/rSY2xfwJZ59rUd2H4yDUqSJsdv5EpSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdWT50A1osma23j7IuPtuuGqQcYd6vTDca5YWwyN9SeqI
oS9JHTH0JakjzulrIoacW5d04jzSl6SOGPqS1BFDX5I6smDoJ9me5FCSPWO1v0hyIMn97Xbl2Lr3J9mb5FtJ3jJWv7zV9ibZOvmXIklayIkc6X8SuPwY9b+pqgvabRdAkg3ANcCvtOf8XZJlSZYBHweuADYA17ZtJUlTtODZO1X1lSQzJ7i/q4HPVNX/AN9Jshe4qK3bW1WPAST5TNv24ZPuWJJ0yhYzp/+eJA+06Z8VrbYGeGJsm/2tNl/9eZJsTjKbZHZubm4R7UmSjnaqoX8j8IvABcBB4K8n1VBVbauqjVW1cdWqVZParSSJU/xyVlU9dWQ5yd8DX2wPDwDrxjZd22ocpy5JmpJTOtJPct7Yw98DjpzZsxO4JsnPJDkfWA98DbgHWJ/k/CRnM/qwd+epty1JOhULHuknuQW4BFiZZD9wPXBJkguAAvYB7waoqoeS3MroA9pngS1V9dO2n/cAdwDLgO1V9dDEX40k6bhO5Oyda49Rvuk4238Y+PAx6ruAXSfVnSRpovxGriR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcWDP0k25McSrJnrPaKJLuTPNruV7R6knwsyd4kDyS5cOw5m9r2jybZdHpejiTpeE7kSP+TwOVH1bYCd1XVeuCu9hjgCmB9u20GboTRmwRwPfB64CLg+iNvFJKk6Vkw9KvqK8Dho8pXAzva8g7gbWP1T9XIV4Fzk5wHvAXYXVWHq+r7wG6e/0YiSTrNTnVOf3VVHWzLTwKr2/Ia4Imx7fa32nz150myOclsktm5ublTbE+SdCyL/iC3qgqoCfRyZH/bqmpjVW1ctWrVpHYrSeLUQ/+pNm1Duz/U6geAdWPbrW21+eqSpCk61dDfCRw5A2cT8IWx+jvbWTwXA0+3aaA7gMuSrGgf4F7WapKkKVq+0AZJbgEuAVYm2c/oLJwbgFuTvAt4HHh723wXcCWwF3gGuA6gqg4n+RBwT9vug1V19IfDkqTTbMHQr6pr51l16TG2LWDLPPvZDmw/qe4kSRPlN3IlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSPLh25AkhYys/X2wcbed8NVg419Ohj6kk7YkOGryTD0Jek4hnqjO11/YTinL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHVlU6CfZl+TBJPcnmW21VyTZneTRdr+i1ZPkY0n2JnkgyYWTeAGSpBM3iSP936qqC6pqY3u8FbirqtYDd7XHAFcA69ttM3DjBMaWJJ2E0zG9czWwoy3vAN42Vv9UjXwVODfJeadhfEnSPBYb+gXcmeTeJJtbbXVVHWzLTwKr2/Ia4Imx5+5vtedIsjnJbJLZubm5RbYnSRq32KtsvrGqDiR5FbA7yTfHV1ZVJamT2WFVbQO2AWzcuPGknitJOr5FHelX1YF2fwi4DbgIeOrItE27P9Q2PwCsG3v62laTJE3JKYd+knOSvPzIMnAZsAfYCWxqm20CvtCWdwLvbGfxXAw8PTYNJEmagsVM76wGbktyZD83V9WXktwD3JrkXcDjwNvb9ruAK4G9wDPAdYsYW5J0Ck459KvqMeBXj1H/T+DSY9QL2HKq40mSFs9v5EpSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcWe5VNqVszW28f
ugXppJ3RoT/UP8p9N1w1yLiStJAzOvSH4hGgpKXKOX1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6sjUQz/J5Um+lWRvkq3THl+SejbV0E+yDPg4cAWwAbg2yYZp9iBJPZv2kf5FwN6qeqyqfgx8Brh6yj1IUreWT3m8NcATY4/3A68f3yDJZmBze/ijJN9axHgrge8t4vmTsBR6APs4mn0811LoYyn0AEukj3xkUX28er4V0w79BVXVNmDbJPaVZLaqNk5iXy/kHuzDPl4IfSyFHnroY9rTOweAdWOP17aaJGkKph369wDrk5yf5GzgGmDnlHuQpG5NdXqnqp5N8h7gDmAZsL2qHjqNQ05kmmiRlkIPYB9Hs4/nWgp9LIUe4AzvI1V1OvYrSVqC/EauJHXE0JekjpyRob8ULvWQZHuSQ0n2DDH+WB/rknw5ycNJHkry3oH6eHGSryX5RuvjL4foo/WyLMnXk3xxwB72JXkwyf1JZgfs49wkn03yzSSPJPn1AXp4Tfs5HLn9IMn7pt1H6+VP2u/nniS3JHnxAD28t43/0Gn5OVTVGXVj9AHxt4FfAM4GvgFsGKCPNwEXAnsG/nmcB1zYll8O/PtAP48AL2vLZwF3AxcP9DP5U+Bm4IsD/nfZB6wc8nej9bED+KO2fDZw7sD9LAOeBF49wNhrgO8AL2mPbwX+cMo9vBbYA7yU0Yk2/wz80iTHOBOP9JfEpR6q6ivA4WmPe4w+DlbVfW35h8AjjH65p91HVdWP2sOz2m3qZxEkWQtcBXxi2mMvNUl+ltHByU0AVfXjqvqvYbviUuDbVfX4QOMvB16SZDmj4P2PKY//y8DdVfVMVT0L/Cvw+5Mc4EwM/WNd6mHqIbcUJZkBXsfoKHuI8ZcluR84BOyuqiH6+Fvgz4H/HWDscQXcmeTedumRIZwPzAH/0Ka7PpHknIF6OeIa4JYhBq6qA8BfAd8FDgJPV9WdU25jD/CbSV6Z5KXAlTz3C62LdiaGvo4hycuAzwHvq6ofDNFDVf20qi5g9E3si5K8dprjJ/ld4FBV3TvNcefxxqq6kNEVZ7ckedMAPSxnNAV5Y1W9DvhvYLDLnbcvbL4V+KeBxl/BaFbgfODngHOS/ME0e6iqR4CPAHcCXwLuB346yTHOxND3Ug9HSXIWo8D/dFV9fuh+2hTCl4HLpzz0G4C3JtnHaNrvt5P845R7AP7/qJKqOgTcxmhactr2A/vH/uL6LKM3gaFcAdxXVU8NNP7vAN+pqrmq+gnweeA3pt1EVd1UVb9WVW8Cvs/oc7iJORND30s9jEkSRnO2j1TVRwfsY1WSc9vyS4A3A9+cZg9V9f6qWltVM4x+L/6lqqZ6JAeQ5JwkLz+yDFzG6M/6qaqqJ4EnkrymlS4FHp52H2OuZaCpnea7wMVJXtr+3VzK6DOwqUryqnb/84zm82+e5P6X3FU2F6umf6mHY0pyC3AJsDLJfuD6qrpp2n0wOrp9B/Bgm08H+EBV7ZpyH+cBO9r/SOdFwK1VNdgpkwNbDdw2yhWWAzdX1ZcG6uWPgU+3A6THgOuGaKK9+b0ZePcQ4wNU1d1JPgvcBzwLfJ1hLsnwuSSvBH4CbJn0h+tehkGSOnImTu9IkuZh6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SO/B+QSai7sZQ+2wAAAABJRU5ErkJggg==\n", 918 | "text/plain": [ 919 | "
" 920 | ] 921 | }, 922 | "metadata": { 923 | "needs_background": "light" 924 | }, 925 | "output_type": "display_data" 926 | } 927 | ], 928 | "source": [ 929 | "k = 1\n", 930 | "samples = []\n", 931 | "for _ in range(10000):\n", 932 | " samples.extend(choice(x, k, False, weights).cpu().numpy())\n", 933 | "plt.hist(samples)\n", 934 | "plt.xticks(range(x.numel()))\n", 935 | "plt.show()" 936 | ] 937 | }, 938 | { 939 | "cell_type": "markdown", 940 | "metadata": {}, 941 | "source": [ 942 | "### CUDA Extension" 943 | ] 944 | }, 945 | { 946 | "cell_type": "code", 947 | "execution_count": 40, 948 | "metadata": {}, 949 | "outputs": [], 950 | "source": [ 951 | "x = x.cuda()\n", 952 | "weights = weights.cuda()" 953 | ] 954 | }, 955 | { 956 | "cell_type": "code", 957 | "execution_count": 41, 958 | "metadata": {}, 959 | "outputs": [ 960 | { 961 | "data": { 962 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAOxklEQVR4nO3df6zdd13H8eeLdhMYxBVamtlW7tSGWEkcsxlTkEwnYz8IQ/8gWyLUZVr+KApqYgr/TCEkI1FUElxSWaVENjKBhYY1bHUSiX8wdjfG1m3gyuhYa7deLA5wibD59o/zueZsu7e9vff0nNt+no/k5Hy/n+/nfD/vc3v7Ot9+vt/zbaoKSVIfXjTpAiRJ42PoS1JHDH1J6oihL0kdMfQlqSMrJ13AsaxevbqmpqYmXYYknVLuueee71XVmrm2LevQn5qaYnp6etJlSNIpJclj821zekeSOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjqyrL+RK+mFprbfNrGxD1x/xcTG1mh4pC9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTlu6CfZkOTLSR5K8mCS97b2VyTZm+SR9ryqtSfJx5LsT3J/kvOH9rWl9X8kyZaT97YkSXNZyJH+M8CfVtUm4EJgW5JNwHbgzqraCNzZ1gEuAza2x1bgBhh8SADXAa8HLgCum/2gkCSNx3FDv6oOV9W9bfmHwMPAOuBKYFfrtgt4e1u+EvhUDXwVODvJOcBbgL1VdbSqvg/sBS4d6buRJB3TCc3pJ5kCXgfcBaytqsNt0xPA2ra8Dnh86GUHW9t87ZKkMVlw6Cd5GfA54H1V9YPhbVVVQI2ioCRbk0wnmZ6ZmRnFLiVJzYJCP8kZDAL/01X1+db8ZJu2oT0fae2HgA1DL1/f2uZrf46q2lFVm6tq85o1a07kvUiSjmMhV+8EuBF4uKo+OrRpNzB7Bc4W4AtD7e9qV
/FcCDzVpoFuBy5JsqqdwL2ktUmSxmTlAvq8AXgn8ECS+1rbB4DrgVuSXAs8BryjbdsDXA7sB54GrgGoqqNJPgTc3fp9sKqOjuRdSJIW5LihX1X/BmSezRfP0b+AbfPsayew80QKlCSNjt/IlaSOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqyMpJF6DTw9T22yYy7oHrr5jIuNKpyiN9SeqIR/qnmUkdcUs6NXikL0kdMfQlqSOGviR1xNCXpI4c90Rukp3AW4EjVfXa1vbnwB8AM63bB6pqT9v2fuBa4Fngj6rq9tZ+KfC3wArgE1V1/Wjfino0yRPXXi6qU9FCjvQ/CVw6R/tfV9V57TEb+JuAq4Bfaq/5uyQrkqwAPg5cBmwCrm59JUljdNwj/ar6SpKpBe7vSuAzVfU/wHeS7AcuaNv2V9WjAEk+0/o+dMIVS5IWbSlz+u9Jcn+SnUlWtbZ1wONDfQ62tvnaXyDJ1iTTSaZnZmbm6iJJWqTFhv4NwM8D5wGHgb8aVUFVtaOqNlfV5jVr1oxqt5IkFvmN3Kp6cnY5yd8DX2yrh4ANQ13XtzaO0S5JGpNFHeknOWdo9beBfW15N3BVkp9Kci6wEfgacDewMcm5Sc5kcLJ39+LLliQtxkIu2bwZuAhYneQgcB1wUZLzgAIOAO8GqKoHk9zC4ATtM8C2qnq27ec9wO0MLtncWVUPjvzdSJKOaSFX71w9R/ONx+j/YeDDc7TvAfacUHWSpJHyG7mS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHjhv6SXYmOZJk31DbK5LsTfJIe17V2pPkY0n2J7k/yflDr9nS+j+SZMvJeTuSpGNZyJH+J4FLn9e2HbizqjYCd7Z1gMuAje2xFbgBBh8SwHXA64ELgOtmPygkSeNz3NCvqq8AR5/XfCWwqy3vAt4+1P6pGvgqcHaSc4C3AHur6mhVfR/Yyws/SCRJJ9li5/TXVtXhtvwEsLYtrwMeH+p3sLXN1/4CSbYmmU4yPTMzs8jyJElzWfKJ3KoqoEZQy+z+dlTV5qravGbNmlHtVpLE4kP/yTZtQ3s+0toPARuG+q1vbfO1S5LGaLGhvxuYvQJnC/CFofZ3tat4LgSeatNAtwOXJFnVTuBe0tokSWO08ngdktwMXASsTnKQwVU41wO3JLkWeAx4R+u+B7gc2A88DVwDUFVHk3wIuLv1+2BVPf/ksCTpJDtu6FfV1fNsuniOvgVsm2c/O4GdJ1SdJGmk/EauJHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHVk5aQLkKTlbGr7bRMZ98D1V5yU/Rr6khZsUgGo0XF6R5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSRJYV+kgNJHkhyX5Lp1vaKJHuTPNKeV7X2JPlYkv1J7k9y/ijegCRp4UZxpP8bVXVeVW1u69uBO6tqI3BnWwe4DNjYHluBG0YwtiTpBJyM6Z0rgV1teRfw9qH2T9XAV4Gzk5xzEsaXJM1jqaFfwB1J7kmytbWtrarDbfkJYG1bXgc8PvTag
63tOZJsTTKdZHpmZmaJ5UmShi31hmtvrKpDSV4F7E3yzeGNVVVJ6kR2WFU7gB0AmzdvPqHXSpKObUlH+lV1qD0fAW4FLgCenJ22ac9HWvdDwIahl69vbZKkMVl06Cc5K8nLZ5eBS4B9wG5gS+u2BfhCW94NvKtdxXMh8NTQNJAkaQyWMr2zFrg1yex+bqqqLyW5G7glybXAY8A7Wv89wOXAfuBp4JoljC1JWoRFh35VPQr88hzt/wlcPEd7AdsWO54kaen8Rq4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR1Z6r13NIep7bdNugRJmpNH+pLUEUNfkjri9I60SE7j6VR0Woe+fykl6bmc3pGkjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHVk7KGf5NIk30qyP8n2cY8vST0ba+gnWQF8HLgM2ARcnWTTOGuQpJ6N+0j/AmB/VT1aVT8GPgNcOeYaJKlbK8c83jrg8aH1g8Drhzsk2Qpsbas/SvKtJYy3GvjeEl4/CsuhBrCO57OO51oOdSyHGmCZ1JGPLKmOV8+3Ydyhf1xVtQPYMYp9JZmuqs2j2NepXIN1WMepUMdyqKGHOsY9vXMI2DC0vr61SZLGYNyhfzewMcm5Sc4ErgJ2j7kGSerWWKd3quqZJO8BbgdWADur6sGTOORIpomWaDnUANbxfNbxXMuhjuVQA5zmdaSqTsZ+JUnLkN/IlaSOGPqS1JHTMvSXw60ekuxMciTJvkmMP1THhiRfTvJQkgeTvHdCdbw4ydeSfKPV8ReTqKPVsiLJ15N8cYI1HEjyQJL7kkxPsI6zk3w2yTeTPJzkVydQw2vaz2H28YMk7xt3Ha2WP26/n/uS3JzkxROo4b1t/AdPys+hqk6rB4MTxN8Gfg44E/gGsGkCdbwJOB/YN+GfxznA+W355cC/T+jnEeBlbfkM4C7gwgn9TP4EuAn44gT/XA4Aqyf5u9Hq2AX8fls+Ezh7wvWsAJ4AXj2BsdcB3wFe0tZvAX5vzDW8FtgHvJTBhTb/DPzCKMc4HY/0l8WtHqrqK8DRcY87Rx2Hq+retvxD4GEGv9zjrqOq6kdt9Yz2GPtVBEnWA1cAnxj32MtNkp9mcHByI0BV/biq/muyVXEx8O2qemxC468EXpJkJYPg/Y8xj/+LwF1V9XRVPQP8K/A7oxzgdAz9uW71MPaQW46STAGvY3CUPYnxVyS5DzgC7K2qSdTxN8CfAf87gbGHFXBHknvarUcm4VxgBviHNt31iSRnTaiWWVcBN09i4Ko6BPwl8F3gMPBUVd0x5jL2Ab+e5JVJXgpcznO/0Lpkp2Poaw5JXgZ8DnhfVf1gEjVU1bNVdR6Db2JfkOS14xw/yVuBI1V1zzjHnccbq+p8Bnec3ZbkTROoYSWDKcgbqup1wH8DE7vdefvC5tuAf5rQ+KsYzAqcC/wMcFaS3x1nDVX1MPAR4A7gS8B9wLOjHON0DH1v9fA8Sc5gEPifrqrPT7qeNoXwZeDSMQ/9BuBtSQ4wmPb7zST/OOYagP8/qqSqjgC3MpiWHLeDwMGhf3F9lsGHwKRcBtxbVU9OaPzfAr5TVTNV9RPg88CvjbuIqrqxqn6lqt4EfJ/BebiROR1D31s9DEkSBnO2D1fVRydYx5okZ7fllwBvBr45zhqq6v1Vtb6qphj8XvxLVY31SA4gyVlJXj67DFzC4J/1Y1VVTwCPJ3lNa7oYeGjcdQy5mglN7TTfBS5M8tL29+ZiBufAxirJq9rzzzKYz79plPtfdnfZXKoa/60e5pTkZuAiYHWSg8B1VXXjuOtgcHT7TuCBNp8O8IGq2jPmOs4BdrX/SOdFwC1VNbFLJidsLXDrIFdYCdxUVV+aUC1/CHy6HSA9Clwzi
SLah9+bgXdPYnyAqroryWeBe4FngK8zmVsyfC7JK4GfANtGfXLd2zBIUkdOx+kdSdI8DH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUkf8DCxCvQ9j6NbAAAAAASUVORK5CYII=\n", 963 | "text/plain": [ 964 | "
" 965 | ] 966 | }, 967 | "metadata": { 968 | "needs_background": "light" 969 | }, 970 | "output_type": "display_data" 971 | } 972 | ], 973 | "source": [ 974 | "k = 1\n", 975 | "samples = []\n", 976 | "for _ in range(10000):\n", 977 | " samples.extend(choice(x, k, True, weights).cpu().numpy())\n", 978 | "plt.hist(samples)\n", 979 | "plt.xticks(range(x.numel()))\n", 980 | "plt.show()" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": 42, 986 | "metadata": {}, 987 | "outputs": [ 988 | { 989 | "data": { 990 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAOz0lEQVR4nO3df6zddX3H8efLFqaiGdVeG9YWL9sas85kyBpk0xk2JvLDiNsfBpJpR9jqH3XTbclS/YdNY4LJ5jYTR8Kks0bBMJXQQCN0zMzsD5GCCAV0VCzSrtC6OtSRTGHv/XE+dzmUe/vj3tNzbu/n+UhOzvf7+X7P9/M+l9vX+dzP93u+pKqQJPXhJZMuQJI0Poa+JHXE0Jekjhj6ktQRQ1+SOrJ80gUczcqVK2t6enrSZUjSKeW+++77flVNzbZtUYf+9PQ0u3btmnQZknRKSfLEXNuc3pGkjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4s6m/kSnqx6S13TKzvvdddPrG+NRqO9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI8cM/SRrk3wlySNJHk7y/tb+qiQ7kzzWnle09iT5RJI9SR5Mct7QsTa2/R9LsvHkvS1J0myOZ6T/HPBnVbUeuADYnGQ9sAW4u6rWAXe3dYBLgXXtsQm4HgYfEsC1wBuB84FrZz4oJEnjcczQr6oDVXV/W/4R8CiwGrgC2NZ22wa8sy1fAXymBr4GnJnkLOBtwM6qOlxVPwB2ApeM9N1Iko7qhOb0k0wDbwDuAVZV1YG26SlgVVteDTw59LJ9rW2udknSmBx36Cd5BfBF4ANV9cPhbVVVQI2ioCSbkuxKsuvQoUOjOKQkqTmu0E9yGoPA/1xVfak1P92mbWjPB1v7fmDt0MvXtLa52l+gqm6oqg1VtWFqaupE3osk6RiO5+qdADcCj1bVx4c2bQdmrsDZCNw21P6edhXPBcAzbRroTuDiJCvaCdyLW5skaUyWH8c+bwLeDTyU5IHW9iHgOuCWJNcATwDvatt2AJcBe4BngasBqupwko8A97b9PlxVh0fyLiRJx+WYoV9V/wZkjs0XzbJ/AZvnONZWYOuJFChJGh2/kStJHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1JHl
ky5AozW95Y6J9Lv3ussn0q+kE+NIX5I64khfI+FfGNKpwZG+JHXE0Jekjhj6ktQRQ1+SOnLME7lJtgJvBw5W1etb218Afwgcart9qKp2tG0fBK4Bngf+uKrubO2XAH8HLAM+VVXXjfatqEeTOoEMnkTWqel4RvqfBi6Zpf1vqurc9pgJ/PXAlcAvt9f8fZJlSZYBnwQuBdYDV7V9JUljdMyRflV9Ncn0cR7vCuDzVfU/wHeT7AHOb9v2VNXjAEk+3/Z95IQrliTN20Lm9N+X5MEkW5OsaG2rgSeH9tnX2uZqf5Ekm5LsSrLr0KFDs+0iSZqn+Yb+9cAvAOcCB4C/HlVBVXVDVW2oqg1TU1OjOqwkiXl+I7eqnp5ZTvIPwO1tdT+wdmjXNa2No7RLksZkXiP9JGcNrf4OsLstbweuTPIzSc4B1gFfB+4F1iU5J8npDE72bp9/2ZKk+TieSzZvBi4EVibZB1wLXJjkXKCAvcB7Aarq4SS3MDhB+xywuaqeb8d5H3Ang0s2t1bVwyN/N5Kkozqeq3eumqX5xqPs/1Hgo7O07wB2nFB1kqSR8hu5ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSR44Z+km2JjmYZPdQ26uS7EzyWHte0dqT5BNJ9iR5MMl5Q6/Z2PZ/LMnGk/N2JElHczwj/U8DlxzRtgW4u6rWAXe3dYBLgXXtsQm4HgYfEsC1wBuB84FrZz4oJEnjc8zQr6qvAoePaL4C2NaWtwHvHGr/TA18DTgzyVnA24CdVXW4qn4A7OTFHySSpJNsvnP6q6rqQFt+CljVllcDTw7tt6+1zdX+Ikk2JdmVZNehQ4fmWZ4kaTYLPpFbVQXUCGqZOd4NVbWhqjZMTU2N6rCSJOYf+k+3aRva88HWvh9YO7TfmtY2V7skaYzmG/rbgZkrcDYCtw21v6ddxXMB8EybBroTuDjJinYC9+LWJkkao+XH2iHJzcCFwMok+xhchXMdcEuSa4AngHe13XcAlwF7gGeBqwGq6nCSjwD3tv0+XFVHnhyWJJ1kxwz9qrpqjk0XzbJvAZvnOM5WYOsJVSdJGim/kStJHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSR5ZPugBJp47pLXdMpN+9110+kX6XIkf6ktQRR/qSFr1J/YUBS++vDEf6ktQRR/qSdBRL7TyGI31J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjqyoNBPsjfJQ0keSLKrtb0qyc4kj7XnFa09ST6RZE+SB5OcN4o3IEk6fqMY6f9mVZ1bVRva+hbg7qpaB9zd1gEuBda1xybg+hH0LUk6ASdjeucKYFtb3ga8c6j9MzXwNeDMJGedhP4lSXNYaOgXcFeS+5Jsam2rqupAW34KWNWWVwNPDr12X2uTJI3JQu+98+aq2p/kNcDOJN8a3lhVlaRO5IDtw2MTwNlnn73A8iRJwxY00q+q/e35IHArcD7w9My0TXs+2HbfD6wdevma1nbkMW+oqg1VtWFqamoh5UmSjjDv0E9yRpJXziwDFwO7ge3AxrbbRuC2trwdeE+7iucC4JmhaSBJ0hgsZHpnFXBrkpnj3FRVX05yL3BLkmuAJ4B3tf13AJcBe4BngasX0LckaR7mHfpV9TjwK7O0/ydw0SztBWyeb3+SpIXzG7mS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE
0Jekjhj6ktSRhd5lU7OY3nLHpEuQpFkt6dA3fCXphZZ06Esnk4MKnYqc05ekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHVk7KGf5JIk306yJ8mWcfcvST0ba+gnWQZ8ErgUWA9clWT9OGuQpJ6Ne6R/PrCnqh6vqp8AnweuGHMNktSt5WPubzXw5ND6PuCNwzsk2QRsaqs/TvLtBfS3Evj+Al4/CouhBrCOI1nHCy2GOhZDDbBI6sjHFlTHa+faMO7QP6aqugG4YRTHSrKrqjaM4lincg3WYR2nQh2LoYYe6hj39M5+YO3Q+prWJkkag3GH/r3AuiTnJDkduBLYPuYaJKlbY53eqarnkrwPuBNYBmytqodPYpcjmSZaoMVQA1jHkazjhRZDHYuhBljidaSqTsZxJUmLkN/IlaSOGPqS1JElGfqL4VYPSbYmOZhk9yT6H6pjbZKvJHkkycNJ3j+hOl6a5OtJvtnq+MtJ1NFqWZbkG0lun2ANe5M8lOSBJLsmWMeZSb6Q5FtJHk3yaxOo4XXt5zDz+GGSD4y7jlbLn7Tfz91Jbk7y0gnU8P7W/8Mn5edQVUvqweAE8XeAnwdOB74JrJ9AHW8BzgN2T/jncRZwXlt+JfDvE/p5BHhFWz4NuAe4YEI/kz8FbgJun+B/l73Aykn+brQ6tgF/0JZPB86ccD3LgKeA106g79XAd4GXtfVbgN8fcw2vB3YDL2dwoc0/A784yj6W4kh/Udzqoaq+Chwed7+z1HGgqu5vyz8CHmXwyz3uOqqqftxWT2uPsV9FkGQNcDnwqXH3vdgk+VkGg5MbAarqJ1X1X5OtiouA71TVExPqfznwsiTLGQTvf4y5/18C7qmqZ6vqOeBfgd8dZQdLMfRnu9XD2ENuMUoyDbyBwSh7Ev0vS/IAcBDYWVWTqONvgT8H/ncCfQ8r4K4k97Vbj0zCOcAh4B/bdNenkpwxoVpmXAncPImOq2o/8FfA94ADwDNVddeYy9gN/EaSVyd5OXAZL/xC64ItxdDXLJK8Avgi8IGq+uEkaqiq56vqXAbfxD4/yevH2X+StwMHq+q+cfY7hzdX1XkM7ji7OclbJlDDcgZTkNdX1RuA/wYmdrvz9oXNdwD/NKH+VzCYFTgH+DngjCS/N84aqupR4GPAXcCXgQeA50fZx1IMfW/1cIQkpzEI/M9V1ZcmXU+bQvgKcMmYu34T8I4kexlM+/1Wks+OuQbg/0eVVNVB4FYG05Ljtg/YN/QX1xcYfAhMyqXA/VX19IT6/23gu1V1qKp+CnwJ+PVxF1FVN1bVr1bVW4AfMDgPNzJLMfS91cOQJGEwZ/toVX18gnVMJTmzLb8MeCvwrXHWUFUfrKo1VTXN4PfiX6pqrCM5gCRnJHnlzDJwMYM/68eqqp4CnkzyutZ0EfDIuOsYchUTmtppvgdckOTl7d/NRQzOgY1Vkte057MZzOffNMrjL7q7bC5Ujf9WD7NKcjNwIbAyyT7g2qq6cdx1MBjdvht4qM2nA3yoqnaMuY6zgG3tf6TzEuCWqprYJZMTtgq4dZArLAduqqovT6iWPwI+1wZIjwNXT6KI9uH3VuC9k+gfoKruSfIF4H7gOeAbTOaWDF9M8mrgp8DmUZ9c9zYMktSRpTi9I0mag6EvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOvJ/rw62eHP99NoAAAAASUVORK5CYII=\n", 991 | "text/plain": [ 992 | "
" 993 | ] 994 | }, 995 | "metadata": { 996 | "needs_background": "light" 997 | }, 998 | "output_type": "display_data" 999 | } 1000 | ], 1001 | "source": [ 1002 | "k = 1\n", 1003 | "samples = []\n", 1004 | "for _ in range(10000):\n", 1005 | " samples.extend(choice(x, k, False, weights).cpu().numpy())\n", 1006 | "plt.hist(samples)\n", 1007 | "plt.xticks(range(x.numel()))\n", 1008 | "plt.show()" 1009 | ] 1010 | } 1011 | ], 1012 | "metadata": { 1013 | "kernelspec": { 1014 | "display_name": "Python [conda env:detectron2]", 1015 | "language": "python", 1016 | "name": "conda-env-detectron2-py" 1017 | }, 1018 | "language_info": { 1019 | "codemirror_mode": { 1020 | "name": "ipython", 1021 | "version": 3 1022 | }, 1023 | "file_extension": ".py", 1024 | "mimetype": "text/x-python", 1025 | "name": "python", 1026 | "nbconvert_exporter": "python", 1027 | "pygments_lexer": "ipython3", 1028 | "version": "3.8.1" 1029 | } 1030 | }, 1031 | "nbformat": 4, 1032 | "nbformat_minor": 2 1033 | } 1034 | -------------------------------------------------------------------------------- /Proof Weighted Sampling.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeviViana/torch_sampling/c0c09566f830a29f44dd878649802d3275847b0b/Proof Weighted Sampling.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reservoir sampling implementation for PyTorch 2 | 3 | Efficient implementation of [reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling) for PyTorch. 4 | The complexity of this implementation is `O(min(k, n - k))`. 5 | The main purpose of this repo is to offer a more efficient option 6 | for sampling without replacement than the common workaround 7 | adopted (which is basically permutation followed by indexing).
8 | 9 | ## Installing 10 | ```bash 11 | git clone https://github.com/LeviViana/torch_sampling 12 | cd torch_sampling 13 | python setup.py build_ext --inplace 14 | ``` 15 | ## Benchmark 16 | 17 | Run the `Benchmark.ipynb` for details. 18 | -------------------------------------------------------------------------------- /csrc/cpu/reservoir_sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "reservoir_sampling.h" 2 | 3 | 4 | /* Fill keys[i] with u_i^(1 / weights[i]), u_i ~ U(0, 1), for weighted top-k selection (Efraimidis-Spirakis keys). Entries with weights[i] <= 0 get key -1, which is below every positive-weight key (those lie in [0, 1]). Holds generator->mutex_ for all draws. */ void generate_keys( 5 | float *keys, 6 | float *weights, 7 | int n, 8 | #ifdef TORCH_1_6 9 | at::CPUGeneratorImpl* generator 10 | #else 11 | at::CPUGenerator* generator 12 | #endif 13 | ){ 14 | std::lock_guard lock(generator->mutex_); 15 | at::uniform_real_distribution standard_uniform(0.0, 1.0); 16 | 17 | for(int i = 0; i < n; i++){ 18 | float u = standard_uniform(generator); 19 | keys[i] = weights[i] > 0 ? (float) std::pow(u, 1 / weights[i]):-1; 20 | } 21 | 22 | } 23 | 24 | /* Reservoir pass (Algorithm R) over indices[0, n) with a reservoir of size k: afterwards indices[0, k) is a uniformly random k-subset and, because evicted entries are swapped to position i, indices[k, n) is its complement. NOTE(review): `random() % (i + 1)` has a small modulo bias when i + 1 does not divide the generator's output range. */ void reservoir_generator_cpu( 25 | int64_t *indices, 26 | int64_t n, 27 | int64_t k, 28 | #ifdef TORCH_1_6 29 | at::CPUGeneratorImpl* generator 30 | #else 31 | at::CPUGenerator* generator 32 | #endif 33 | ){ 34 | std::lock_guard lock(generator->mutex_); 35 | 36 | for(int i = k; i < n; i++){ 37 | int64_t z = generator->random() % (i + 1); 38 | if (z < k) { 39 | std::swap(indices[z], indices[i]); 40 | } 41 | } 42 | 43 | } 44 | 45 | /* Sample k rows of x along dim 0 without replacement. Empty weights: uniform reservoir sampling over row indices. Otherwise: weighted sampling that keeps the rows with the k largest keys from generate_keys. */ at::Tensor reservoir_sampling_cpu( 46 | at::Tensor& x, 47 | at::Tensor& weights, 48 | int64_t k 49 | ){ 50 | 51 | TORCH_CHECK( 52 | x.dim() > 0, 53 | "The input Tensor must have at least one dimension" 54 | ); 55 | 56 | int n = x.size(0); 57 | 58 | TORCH_CHECK( 59 | n >= k, 60 | "Cannot take a larger sample than population when 'replace=False'" 61 | ); 62 | 63 | auto options = x.options().dtype(at::kLong); 64 | #ifdef TORCH_1_6 65 | at::CPUGeneratorImpl* generator = at::get_generator_or_default( 66 | at::detail::getDefaultCPUGenerator(), 67 | at::detail::getDefaultCPUGenerator() 68 | ); 69 |
#else 70 | at::CPUGenerator* generator = at::get_generator_or_default( 71 | nullptr, 72 | at::detail::getDefaultCPUGenerator() 73 | ); 74 | #endif 75 | 76 | if (weights.numel() == 0){ // Uniform Sampling 77 | at::Tensor indices_n = at::arange({n}, options); 78 | 79 | // This is a trick to speed up the reservoir sampling. 80 | // It makes the worst case be k = n / 2. 81 | /* If 2k < n, run the reservoir with size n - k and return the complement indices[n - k, n), which is itself a uniform k-subset; either way the loop runs min(k, n - k) iterations. */ int split, begin, end; 82 | if(2 * k < n){ 83 | split = n - k; 84 | begin = n - k; 85 | end = n; 86 | } else { 87 | split = k; 88 | begin = 0; 89 | end = k; 90 | } 91 | 92 | reservoir_generator_cpu( 93 | indices_n.data_ptr(), 94 | n, 95 | split, 96 | generator 97 | ); 98 | 99 | return x.index_select( 100 | 0, 101 | indices_n.index_select( 102 | 0, 103 | at::arange(begin, end, options) 104 | ) 105 | ); 106 | 107 | } else { // Weighted Sampling 108 | 109 | // If the weights are contiguous floating points, then 110 | // the next step won't generate a copy. 111 | at::Tensor weights_contiguous = weights.contiguous().to(at::kFloat); 112 | 113 | TORCH_CHECK( 114 | weights_contiguous.device() == x.device(), 115 | "The weights must share the same device as the inputs." 116 | ); 117 | 118 | TORCH_CHECK( 119 | n == weights_contiguous.numel(), 120 | "The weights must have the same number of elements as the input's first dimension." 121 | ); 122 | 123 | TORCH_CHECK( 124 | weights_contiguous.dim() == 1, 125 | "The weights must 1-dimensional." 126 | ); 127 | 128 | TORCH_CHECK( 129 | weights_contiguous.nonzero().numel() >= k, 130 | "Cannot have less non-zero weights than the number of samples." 131 | ); 132 | 133 | TORCH_CHECK( 134 | /* Compare as a double: toLong() truncates toward zero, so a fractional negative minimum such as -0.5 became 0 and slipped past this check (and was counted as a non-zero weight above). */ weights_contiguous.min().item().toDouble() >= 0, 135 | "All the weights must be non-negative.
" 136 | ); 137 | 138 | at::Tensor keys = at::empty({n}, weights_contiguous.options()); 139 | 140 | generate_keys( 141 | keys.data_ptr(), 142 | weights_contiguous.data_ptr(), 143 | n, 144 | generator); 145 | 146 | return x.index_select(0, std::get<1>(keys.topk(k))); 147 | } 148 | } 149 | 150 | at::Tensor sampling_with_replacement_cpu( 151 | at::Tensor& x, 152 | at::Tensor& weights, 153 | int64_t k 154 | ){ 155 | 156 | TORCH_CHECK( 157 | x.dim() > 0, 158 | "The input Tensor must have at least one dimension" 159 | ); 160 | 161 | int n = x.size(0); 162 | at::Tensor samples; 163 | 164 | if (weights.numel() == 0){ // Uniform Sampling 165 | samples = at::randint(0, n, {k}, x.options().dtype(at::kLong)); 166 | } else { // Weighted Sampling 167 | 168 | TORCH_CHECK( 169 | weights.min().item().toLong() >= 0, 170 | "All the weights must be non-negative." 171 | ); 172 | 173 | TORCH_CHECK( 174 | n == weights.numel(), 175 | "The weights must have the same number of elements as the input's first dimension." 176 | ); 177 | 178 | TORCH_CHECK( 179 | weights.dim() == 1, 180 | "The weights must 1-dimensional." 181 | ); 182 | #ifdef TORCH_1_6 183 | at::CPUGeneratorImpl* generator = at::get_generator_or_default( 184 | at::detail::getDefaultCPUGenerator(), 185 | at::detail::getDefaultCPUGenerator() 186 | ); 187 | #else 188 | at::CPUGenerator* generator = at::get_generator_or_default( 189 | nullptr, 190 | at::detail::getDefaultCPUGenerator() 191 | ); 192 | #endif 193 | 194 | samples = at::empty({k}, x.options().dtype(at::kLong)); 195 | int64_t *samples_ptr = samples.data_ptr(); 196 | 197 | at::Tensor cdf = weights.cumsum(0).to(at::kFloat).clone(); 198 | float sum_cdf = cdf[-1].item().toFloat(); 199 | 200 | TORCH_CHECK( 201 | sum_cdf > 0.0, 202 | "The sum of all the weights must be strictly greater than zero.
203 | ); 204 | 205 | cdf /= sum_cdf; 206 | 207 | at::uniform_real_distribution standard_uniform(0.0, 1.0); 208 | 209 | float *cdf_ptr = cdf.data_ptr(); 210 | 211 | for(int i = 0; i < k; i++){ 212 | float u = standard_uniform(generator); 213 | auto ptr = std::lower_bound(cdf_ptr, cdf_ptr + n, u); 214 | samples_ptr[i] = std::distance(cdf_ptr, ptr); 215 | } 216 | 217 | } 218 | 219 | return x.index_select(0, samples); 220 | } 221 | 222 | at::Tensor choice_cpu( 223 | at::Tensor& input, 224 | int64_t k, 225 | bool replace, 226 | at::Tensor& weights 227 | ){ 228 | if (replace){ 229 | return sampling_with_replacement_cpu(input, weights, k); 230 | } else { 231 | return reservoir_sampling_cpu(input, weights, k); 232 | } 233 | } 234 | 235 | at::Tensor choice_cpu( 236 | at::Tensor& input, 237 | int64_t k, 238 | bool replace 239 | ){ 240 | at::Tensor weights = at::empty({0}, input.options().dtype(at::kFloat)); 241 | return choice_cpu(input, k, replace, weights); 242 | } 243 | -------------------------------------------------------------------------------- /csrc/cpu/reservoir_sampling.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #ifdef TORCH_1_6 4 | #include 5 | #else 6 | #include 7 | #endif 8 | #include 9 | 10 | #include 11 | 12 | at::Tensor choice_cpu(at::Tensor& input, int64_t k, bool replace, at::Tensor& weights); 13 | at::Tensor choice_cpu(at::Tensor& input, int64_t k, bool replace); 14 | -------------------------------------------------------------------------------- /csrc/cuda/reservoir_sampling.cu: -------------------------------------------------------------------------------- 1 | #include "reservoir_sampling.cuh" 2 | 3 | __global__ 4 | #ifdef __HIP_PLATFORM_HCC__ 5 | C10_LAUNCH_BOUNDS_1(512) 6 | #endif 7 | void generate_samples( 8 | int64_t *samples, 9 | int64_t k, 10 | int64_t n, 11 | std::pair seeds 12 | ){ 13 | int thread_id = blockIdx.x * blockDim.x + threadIdx.x; 14 | curandStatePhilox4_32_10_t 
state; 15 | curand_init(seeds.first, thread_id, seeds.second, &state); 16 | int64_t s = curand4(&state).x % (thread_id + k + 1); 17 | if (thread_id < n){ 18 | samples[thread_id] = s; 19 | } 20 | } 21 | 22 | __global__ 23 | #ifdef __HIP_PLATFORM_HCC__ 24 | C10_LAUNCH_BOUNDS_1(512) 25 | #endif 26 | void generate_keys( 27 | float *keys, 28 | float *weights, 29 | int64_t n, 30 | std::pair seeds 31 | ){ 32 | int thread_id = blockIdx.x * blockDim.x + threadIdx.x; 33 | curandStatePhilox4_32_10_t state; 34 | curand_init(seeds.first, thread_id, seeds.second, &state); 35 | float u = curand_uniform4(&state).x; 36 | if(thread_id < n){ 37 | keys[thread_id] = weights[thread_id] > 0 ? (float) __powf(u, (float) 1 / weights[thread_id]):-1; 38 | } 39 | } 40 | 41 | __global__ 42 | #ifdef __HIP_PLATFORM_HCC__ 43 | C10_LAUNCH_BOUNDS_1(512) 44 | #endif 45 | void sampling_with_replacement_kernel( 46 | int64_t *samples, 47 | float *cdf, 48 | int64_t n, 49 | int64_t k, 50 | std::pair seeds 51 | ){ 52 | int thread_id = blockIdx.x * blockDim.x + threadIdx.x; 53 | curandStatePhilox4_32_10_t state; 54 | curand_init(seeds.first, thread_id, seeds.second, &state); 55 | float u = curand_uniform4(&state).x; 56 | if(thread_id < k){ 57 | auto ptr = thrust::lower_bound(thrust::device, cdf, cdf + n, u); 58 | samples[thread_id] = thrust::distance(cdf, ptr); 59 | } 60 | } 61 | 62 | __global__ 63 | #ifdef __HIP_PLATFORM_HCC__ 64 | C10_LAUNCH_BOUNDS_1(512) 65 | #endif 66 | void generate_reservoir( 67 | int64_t *indices, 68 | int64_t *samples, 69 | int64_t nb_iterations, 70 | int64_t k 71 | ){ 72 | for(int i = 0; i < nb_iterations; i++){ 73 | int64_t z = samples[i]; 74 | if (z < k) { 75 | thrust::swap(indices[z], indices[i + k]); 76 | } 77 | } 78 | } 79 | 80 | at::Tensor reservoir_sampling_cuda( 81 | at::Tensor& x, 82 | at::Tensor& weights, 83 | int64_t k 84 | ){ 85 | 86 | TORCH_CHECK( 87 | x.dim() > 0, 88 | "The input Tensor must have at least one dimension" 89 | ); 90 | 91 | int n = x.size(0); 92 | 93 | 
TORCH_CHECK( 94 | n >= k, 95 | "Cannot take a larger sample than population when 'replace=False'" 96 | ); 97 | 98 | cudaDeviceProp* props = at::cuda::getCurrentDeviceProperties(); 99 | THAssert(props != NULL); 100 | int threadsPerBlock = props->maxThreadsPerBlock; 101 | 102 | auto options = x.options().dtype(at::kLong); 103 | dim3 threads(threadsPerBlock); 104 | 105 | #ifdef TORCH_1_6 106 | auto gen = at::get_generator_or_default( 107 | at::cuda::detail::getDefaultCUDAGenerator(), 108 | at::cuda::detail::getDefaultCUDAGenerator() 109 | ); 110 | #else 111 | auto gen = at::get_generator_or_default( 112 | nullptr, 113 | at::cuda::detail::getDefaultCUDAGenerator() 114 | ); 115 | #endif 116 | 117 | std::pair next_philox_seed; 118 | { 119 | // See Note [Acquire lock when using random generators] 120 | std::lock_guard lock(gen->mutex_); 121 | next_philox_seed = gen->philox_engine_inputs(4); 122 | } 123 | 124 | if (weights.numel() == 0){ // Uniform Sampling 125 | at::Tensor indices_n = at::arange({n}, options); 126 | 127 | // This is a trick to speed up the reservoir sampling. 128 | // It makes the worst case be k = n / 2. 
129 | int split, begin, end; 130 | if(2 * k < n){ 131 | split = n - k; 132 | begin = n - k; 133 | end = n; 134 | } else { 135 | split = k; 136 | begin = 0; 137 | end = k; 138 | } 139 | 140 | int nb_iterations = std::min(k, n - k); 141 | dim3 blocks((nb_iterations + threadsPerBlock - 1)/threadsPerBlock); 142 | 143 | at::Tensor samples = at::arange({nb_iterations}, options); 144 | 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 146 | 147 | generate_samples<<>>( 148 | samples.data_ptr(), 149 | split, 150 | n, 151 | next_philox_seed 152 | ); 153 | 154 | AT_CUDA_CHECK(cudaGetLastError()); 155 | 156 | // This must be done in a separeted kernel 157 | // since this algorithm isn't thread safe 158 | generate_reservoir<<<1, 1, 0, stream>>>( 159 | indices_n.data_ptr(), 160 | samples.data_ptr(), 161 | nb_iterations, 162 | split 163 | ); 164 | 165 | AT_CUDA_CHECK(cudaGetLastError()); 166 | 167 | return x.index_select( 168 | 0, 169 | indices_n.index_select( 170 | 0, 171 | at::arange(begin, end, options) 172 | ) 173 | ); 174 | 175 | } else { // Weighted Sampling 176 | 177 | // If the weights are contiguous floating points, then 178 | // the next step won't generate a copy. 179 | at::Tensor weights_contiguous = weights.contiguous().to(at::kFloat); 180 | 181 | TORCH_CHECK( 182 | weights_contiguous.device() == x.device(), 183 | "The weights must share the same device as the inputs." 184 | ); 185 | 186 | TORCH_CHECK( 187 | n == weights_contiguous.numel(), 188 | "The weights must have the same number of elements as the input's first dimension." 189 | ); 190 | 191 | TORCH_CHECK( 192 | weights_contiguous.dim() == 1, 193 | "The weights must 1-dimensional." 194 | ); 195 | 196 | TORCH_CHECK( 197 | weights_contiguous.nonzero().numel() >= k, 198 | "Cannot have less non-zero weights than the number of samples." 199 | ); 200 | 201 | TORCH_CHECK( 202 | weights_contiguous.min().item().toLong() >= 0, 203 | "All the weights must be non-negative." 
204 | ); 205 | 206 | at::Tensor keys = at::empty({n}, weights_contiguous.options()); 207 | dim3 all_blocks((n + threadsPerBlock - 1)/threadsPerBlock); 208 | 209 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 210 | 211 | generate_keys<<>>( 212 | keys.data_ptr(), 213 | weights_contiguous.data_ptr(), 214 | n, 215 | next_philox_seed 216 | ); 217 | 218 | AT_CUDA_CHECK(cudaGetLastError()); 219 | 220 | return x.index_select(0, std::get<1>(keys.topk(k))); 221 | } 222 | } 223 | 224 | at::Tensor sampling_with_replacement_cuda( 225 | at::Tensor& x, 226 | at::Tensor& weights, 227 | int64_t k 228 | ){ 229 | 230 | TORCH_CHECK( 231 | x.dim() > 0, 232 | "The input Tensor must have at least one dimension" 233 | ); 234 | 235 | int n = x.size(0); 236 | at::Tensor samples; 237 | 238 | if (weights.numel() == 0){ // Uniform Sampling 239 | samples = at::randint(0, n, {k}, x.options().dtype(at::kLong)); 240 | } else { // Weighted Sampling 241 | 242 | TORCH_CHECK( 243 | weights.min().item().toLong() >= 0, 244 | "All the weights must be non-negative." 245 | ); 246 | 247 | 248 | TORCH_CHECK( 249 | n == weights.numel(), 250 | "The weights must have the same number of elements as the input's first dimension." 251 | ); 252 | 253 | TORCH_CHECK( 254 | weights.dim() == 1, 255 | "The weights must 1-dimensional." 
256 | ); 257 | 258 | cudaDeviceProp* props = at::cuda::getCurrentDeviceProperties(); 259 | THAssert(props != NULL); 260 | int threadsPerBlock = props->maxThreadsPerBlock; 261 | 262 | #ifdef TORCH_1_6 263 | auto gen = at::get_generator_or_default( 264 | at::cuda::detail::getDefaultCUDAGenerator(), 265 | at::cuda::detail::getDefaultCUDAGenerator() 266 | ); 267 | #else 268 | auto gen = at::get_generator_or_default( 269 | nullptr, 270 | at::cuda::detail::getDefaultCUDAGenerator() 271 | ); 272 | #endif 273 | 274 | std::pair next_philox_seed; 275 | { 276 | // See Note [Acquire lock when using random generators] 277 | std::lock_guard lock(gen->mutex_); 278 | next_philox_seed = gen->philox_engine_inputs(4); 279 | } 280 | 281 | samples = at::empty({k}, x.options().dtype(at::kLong)); 282 | at::Tensor cdf = weights.cumsum(0).to(at::kFloat); 283 | float sum_cdf = cdf[-1].item().toFloat(); 284 | 285 | TORCH_CHECK( 286 | sum_cdf > 0.0, 287 | "The sum of all the weights must be strictly greater than zero." 
288 | ); 289 | 290 | cdf /= sum_cdf; 291 | 292 | dim3 threads(threadsPerBlock); 293 | dim3 blocks((k + threadsPerBlock - 1)/threadsPerBlock); 294 | 295 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 296 | 297 | sampling_with_replacement_kernel<<>>( 298 | samples.data_ptr(), 299 | cdf.data_ptr(), 300 | n, 301 | k, 302 | next_philox_seed 303 | ); 304 | 305 | AT_CUDA_CHECK(cudaGetLastError()); 306 | } 307 | 308 | return x.index_select(0, samples); 309 | } 310 | 311 | at::Tensor choice_cuda( 312 | at::Tensor& input, 313 | int64_t k, 314 | bool replace, 315 | at::Tensor& weights 316 | ){ 317 | if (replace){ 318 | return sampling_with_replacement_cuda(input, weights, k); 319 | } else { 320 | return reservoir_sampling_cuda(input, weights, k); 321 | } 322 | } 323 | 324 | at::Tensor choice_cuda( 325 | at::Tensor& input, 326 | int64_t k, 327 | bool replace 328 | ){ 329 | at::Tensor weights = at::empty({0}, input.options().dtype(at::kFloat)); 330 | return choice_cuda(input, k, replace, weights); 331 | } 332 | -------------------------------------------------------------------------------- /csrc/cuda/reservoir_sampling.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #ifdef TORCH_1_6 8 | #include 9 | #else 10 | #include 11 | #endif 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | //#include 20 | //#include 21 | 22 | //#include 23 | //#include 24 | 25 | at::Tensor choice_cuda( 26 | at::Tensor& input, 27 | int64_t k, 28 | bool replace 29 | ); 30 | 31 | at::Tensor choice_cuda( 32 | at::Tensor& input, 33 | int64_t k, 34 | bool replace, 35 | at::Tensor& weights 36 | ); 37 | -------------------------------------------------------------------------------- /csrc/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "cpu/reservoir_sampling.h" 2 | 3 | #ifdef WITH_CUDA 4 | 
#include "cuda/reservoir_sampling.cuh" 5 | #endif 6 | 7 | at::Tensor choice( 8 | at::Tensor& input, 9 | int64_t k, 10 | bool replace, 11 | at::Tensor& weights 12 | ){ 13 | 14 | if(input.type().is_cuda()){ 15 | #ifdef WITH_CUDA 16 | return choice_cuda(input, k, replace, weights); 17 | #else 18 | AT_ERROR("Not compiled with GPU support"); 19 | #endif 20 | }else{ 21 | return choice_cpu(input, k, replace, weights); 22 | } 23 | } 24 | 25 | at::Tensor choice( 26 | at::Tensor& input, 27 | int64_t k, 28 | bool replace 29 | ){ 30 | if(input.type().is_cuda()){ 31 | #ifdef WITH_CUDA 32 | return choice_cuda(input, k, replace); 33 | #else 34 | AT_ERROR("Not compiled with GPU support"); 35 | #endif 36 | }else{ 37 | return choice_cpu(input, k, replace); 38 | } 39 | } 40 | 41 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 42 | m.def( 43 | "choice", 44 | (at::Tensor (*)(at::Tensor&, int64_t, bool)) &choice, 45 | "Choice implementation." 46 | ); 47 | m.def( 48 | "choice", 49 | (at::Tensor (*)(at::Tensor&, int64_t, bool, at::Tensor&)) &choice, 50 | "Choice implementation." 
51 | ); 52 | } 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import torch 5 | from setuptools import setup 6 | from packaging import version 7 | from torch.utils.cpp_extension import CUDA_HOME 8 | from torch.utils.cpp_extension import CppExtension 9 | from torch.utils.cpp_extension import CUDAExtension 10 | from torch.utils.cpp_extension import BuildExtension 11 | 12 | this_dir = os.path.dirname(os.path.abspath(__file__)) 13 | extensions_dir = os.path.join(this_dir, "csrc") 14 | 15 | main_source = glob.glob(os.path.join(extensions_dir, "*.cpp")) 16 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 17 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 18 | 19 | sources = main_source + source_cpu 20 | extension = CppExtension 21 | define_macros = [] 22 | 23 | if torch.cuda.is_available() and CUDA_HOME is not None: 24 | extension = CUDAExtension 25 | sources += source_cuda 26 | define_macros += [("WITH_CUDA", None)] 27 | 28 | if version.parse(torch.__version__) >= version.parse('1.6.0'): 29 | define_macros += [("TORCH_1_6", None)] 30 | 31 | setup( 32 | name='torch_sampling', 33 | author="LeviViana", 34 | description="Efficient random sampling extension for Pytorch", 35 | ext_modules=[extension( 36 | 'torch_sampling', 37 | sources, 38 | define_macros=define_macros, 39 | include_dirs=[extensions_dir], 40 | )], 41 | cmdclass={'build_ext': BuildExtension}, 42 | ) 43 | --------------------------------------------------------------------------------