├── README.md
└── main.py

/README.md:
--------------------------------------------------------------------------------
# softtopk
A differentiable top-k operator for JAX.

Blog: https://kexue.fm/archives/10373#%E4%BA%8C%E8%80%85%E5%85%BC%E4%B9%8B

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import jax.numpy as jnp
import jax.lax as lax


def softtopk(x, k):
    """Differentiable top-k operator for JAX.

    Returns soft scores p in (0, 1) along the last axis, with
    p.sum(axis=-1) equal to k (up to floating-point error): p[i] is
    close to 1 for the k largest entries of x and close to 0 otherwise.
    Refer: https://kexue.fm/archives/10373#%E4%BA%8C%E8%80%85%E5%85%BC%E4%B9%8B
    """
    # Sort ascending in float32 for numerical stability.
    x_sort = x.astype('float32').sort(axis=-1)
    # Prefix log-sum-exp of x_sort (positions <= j) and suffix
    # log-sum-exp of -x_sort (positions > j; the roll by one plus the
    # -inf fill make the suffix at the last position empty).
    lse1 = lax.cumlogsumexp(x_sort, axis=x.ndim - 1)
    lse2 = lax.cumlogsumexp(-x_sort, axis=x.ndim - 1, reverse=True)
    lse2 = jnp.roll(lse2, -1, axis=-1).at[..., -1].set(-jnp.inf)
    # km[j] = k minus the number of entries above position j.
    km = k - jnp.arange(x.shape[-1] - 1, -1, -1)
    # Closed-form candidate threshold lambda for each interval
    # [x_sort[j], x_sort[j + 1]], solving sum_i p_i = k under the
    # assumption that lambda lies in that interval.
    x_lamb = lse1 - jnp.log(jnp.sqrt(km**2 + jnp.exp(lse1 + lse2)) + km)
    # Keep the candidate that actually falls inside its own interval.
    x_sort_ = jnp.roll(x_sort, -1, axis=-1).at[..., -1].set(jnp.inf)
    idxs = ((x_lamb >= x_sort) & (x_lamb <= x_sort_)).argmax(axis=-1)
    lamb = jnp.take_along_axis(x_lamb, idxs[..., None], axis=-1)
    # p_i is the Laplace CDF evaluated at x_i - lambda, so each p_i is
    # in (0, 1) and the p_i sum to k by construction.
    p = (1 - jnp.exp(-jnp.abs(x - lamb))) * jnp.sign(x - lamb) * 0.5 + 0.5
    return p.astype(x.dtype)


x = jnp.array(np.random.randn(32, 128))
p = softtopk(x, 16)
print(p.sum(axis=-1))  # every row sums to (approximately) k = 16
--------------------------------------------------------------------------------
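
Since `softtopk` is composed of smooth primitives end to end, gradients pass through it with `jax.grad`. A minimal sketch, assuming `softtopk` from `main.py` above is in scope; the toy objective and shapes here are illustrative assumptions, not part of the repo:

```python
import jax
import numpy as np
import jax.numpy as jnp

# assumes: from main import softtopk


def loss(x):
    # Illustrative toy objective: push the soft top-16 mass onto the
    # first 16 columns. Differentiable because softtopk is.
    p = softtopk(x, 16)
    return -p[..., :16].sum()


x = jnp.array(np.random.randn(4, 32))
g = jax.grad(loss)(x)
print(g.shape)  # (4, 32): a gradient w.r.t. every input score
```

Since each score gets a p value via the Laplace CDF at `x_i - lambda`, scaling the inputs by a large factor widens the gaps around the threshold and sharpens p toward a hard 0/1 top-k mask, which gives a temperature-style knob for annealing.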