├── ICML.pdf ├── ICML.tex ├── README.md └── code ├── Flow-functions.ipynb ├── Single cell stuff.ipynb ├── Theano test.ipynb └── Theano-functions.ipynb /ICML.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thashim/population-diffusions/18586024f01ad1bc17ebfe0334a45a9d4916b16b/ICML.pdf -------------------------------------------------------------------------------- /ICML.tex: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %%%%%%%% ICML 2016 EXAMPLE LATEX SUBMISSION FILE %%%%%%%%%%%%%%%%% 3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4 | 5 | % Use the following line _only_ if you're still using LaTeX 2.09. 6 | %\documentstyle[icml2016,epsf,natbib]{article} 7 | % If you rely on Latex2e packages, like most moden people use this: 8 | \documentclass{article} 9 | 10 | % use Times 11 | \usepackage{times} 12 | % For figures 13 | \usepackage{graphicx} % more modern 14 | %\usepackage{epsfig} % less modern 15 | \usepackage[tight]{subfigure} 16 | \usepackage{amsmath} 17 | \usepackage{amsthm} 18 | \usepackage{amsfonts} 19 | \usepackage{thmtools} 20 | \renewcommand{\listtheoremname}{List of theorems and definitions} 21 | 22 | \declaretheorem[name=Theorem]{thm} 23 | \declaretheorem[name=Lemma]{lem} 24 | \declaretheorem[name=Corollary]{cor} 25 | \declaretheorem[name=Conjecture]{conj} 26 | \declaretheorem[name=Definition]{defn} 27 | \declaretheorem[name=Remark, style=remark]{remark} 28 | % For citations 29 | \usepackage{natbib} 30 | 31 | \newcommand{\red}[1]{\textcolor{red}{#1}} 32 | 33 | % For algorithms 34 | \usepackage{algorithm} 35 | \usepackage{algorithmic} 36 | 37 | % As of 2011, we use the hyperref package to produce hyperlinks in the 38 | % resulting PDF. If this breaks your system, please commend out the 39 | % following usepackage line and replace \usepackage{icml2016} with 40 | % \usepackage[nohyperref]{icml2016} above. 41 | \usepackage{hyperref} 42 | 43 | % Packages hyperref and algorithmic misbehave sometimes. We can fix 44 | % this with the following command. 45 | \newcommand{\theHalgorithm}{\arabic{algorithm}} 46 | 47 | % Employ the following version of the ``usepackage'' statement for 48 | % submitting the draft version of the paper for review. This will set 49 | % the note in the first column to ``Under review. Do not distribute.'' 50 | %\usepackage{icml2016} 51 | 52 | % Employ this version of the ``usepackage'' statement after the paper has 53 | % been accepted, when creating the final version. This will set the 54 | % note in the first column to ``Proceedings of the...'' 55 | \usepackage[accepted]{icml2016} 56 | 57 | 58 | % The \icmltitle you define below is probably too long as a header. 59 | % Therefore, a short form for the running title is supplied here: 60 | \icmltitlerunning{Learning Population-Level Diffusions with Generative Recurrent Networks} 61 | 62 | \begin{document} 63 | 64 | \twocolumn[ 65 | \icmltitle{Learning Population-Level Diffusions with Generative Recurrent Networks} 66 | 67 | % It is OKAY to include author information, even for blind 68 | % submissions: the style file will automatically remove it for you 69 | % unless you've provided the [accepted] option to the icml2016 70 | % package. 71 | \icmlauthor{Tatsunori B. Hashimoto}{thashim@mit.edu} 72 | \icmlauthor{David K. Gifford}{dkg@mit.edu} 73 | \icmlauthor{Tommi S. 
Jaakkola}{tommi@csail.mit.edu} 74 | 75 | 76 | % You may provide any keywords that you 77 | % find helpful for describing your paper; these are used to populate 78 | % the "keywords" metadata in the PDF but will not be shown in the document 79 | \icmlkeywords{boring formatting information, machine learning, ICML} 80 | 81 | \vskip 0.3in 82 | ] 83 | 84 | \newcommand{\oy}{\overline{y}} 85 | 86 | 87 | \begin{abstract} 88 | We estimate stochastic processes that govern the dynamics of evolving populations such as cell differentiation. The problem is challenging since longitudinal trajectory measurements of individuals in a population are rarely available due to experimental cost and/or privacy. We show that cross-sectional samples from an evolving population suffice for recovery within a class of processes even if samples are available only at a few distinct time points. We provide a stratified analysis of recoverability conditions, and establish that reversibility is sufficient for recoverability. For estimation, we derive a natural loss and regularization, and parameterize the processes as diffusive recurrent neural networks. We demonstrate the approach in the context of uncovering complex cellular dynamics known as the `epigenetic landscape' from existing biological assays. 89 | \end{abstract} 90 | 91 | 92 | \section{Motivation} 93 | 94 | Understanding the population dynamics of individuals over time is a fundamental problem in a variety of areas, from biology (gene expression of a cell population \citep{waddington1940organisers}), ecology (spatial distribution of animals \citep{tereshko2000reaction}), to census data (life expectancy \citep{manton2008cohort} and racially segregated housing \citep{bejan2007constructal}). In such areas, experimental cost or privacy concerns often prevent measurements of complete trajectories of individuals over time, and instead we observe samples from an evolving population over time (Fig. \ref{fig:problem}). 95 | 96 | For example, modeling the active life expectancy and disabilities of an individual over time is an area of substantial interest for healthcare statistics \citep{manton2008cohort}, but the expense and difficulty of collecting longitudinal health data has meant that much of the data is cross-sectional \citep{robine2004looking}. Our technique replaces longitudinal data with cross-sectional data for inferring the underlying dynamics behind continuous time-series. 97 | 98 | \begin{figure} 99 | \vspace{-7pt} 100 | \centering 101 | \includegraphics[scale=0.4]{fig/example_draws.png} 102 | \vspace{-7pt} 103 | \caption{In population-level inference we observe samples (colored points) drawn from the process at different times. The goal is to infer the dynamics (blue vectors). In this toy dataset each point can be thought of as a single cell and the x and y axes as gene expression levels of two genes.} 104 | \label{fig:problem} 105 | \vspace{-8pt} 106 | \end{figure} 107 | 108 | 109 | The framework we develop will be applicable to the general cross-sectional population inference problem, but in order to ground our discussion we will focus on a specific application in computational biology, where we seek to understand the process by which embryonic stem cells differentiate into mature cells. An individual cell's tendency to differentiate into a mature cell is thought to follow a `epigenetic landscape' much like a ball rolling down a hill. 
The local minima of this landscape represents cell states and the slope represents the rate of differentiation \citep{waddington1940organisers}. While more recent work has established the validity of modeling differentiation as a diffusion process \citep{hanna2009direct,morris2014mathematical}, direct inference of the epigenetic landscape has been limited to the dynamics of single genes \citep{sisan2012predicting} due to the difficulty of longitudinally tracking single cells. 110 | 111 | Our work establishes that no longitudinal tracking is necessary and population data alone can be used to recover the latent dynamics driving diffusions. This result allows cheap, high-throughput assays such as single cell RNA-seq to be used to infer the latent dynamics of tens to hundreds of genes. 112 | 113 | Analyzing the inference problem for population-level diffusions, we utilize the connection between partial differential equations, diffusion processes, and recurrent neural networks (RNN) to derive a principled loss function and estimation procedure. 114 | 115 | Our contributions are the following 116 | \begin{itemize} 117 | \item First, we rigorously study whether the dynamics of a diffusion can be recovered from cross-sectional observations, and establish the first identifiability results. 118 | \item Second, we show that a particular regularized recurrent neural network (RNN) with Wasserstein loss is a natural model for this problem and use this to construct a fast scalable initializer that exploits the connection between diffusions and RNNs. 119 | \item Finally, our method is verified to recover known dynamics from simulated data in the high-dimensional regime better than both parametric and local diffusion models, as well as predict the differentiation time-course on tens of genes for real RNA-seq data. 120 | \end{itemize} 121 | 122 | \section{Prior work} 123 | 124 | Population level inference of dynamics consists of observing samples drawn from a diffusion stopped at various times and inferring the forces driving the changes in the population (Fig. \ref{fig:problem}) which contrasts with inferring dynamics with trajectory data which tracks individuals longitudinally. Our work is distinct from existing approaches in that it considers sampled, multivariate, and non-stationary $(t < \infty)$ observations. 125 | 126 | \subsection{Population level inference} 127 | 128 | Inferring dynamics from population appears in three areas: In home-range estimation, one estimates the support of a two-dimensional continuous time series from the stationary distribution \citep{fleming2015rigorous}. Our work is distinguished by our focus on the high-dimensional $(d>2)$ and non-stationary settings. The stationary case is discussed in section \ref{sec:stationary}. 129 | 130 | Inverse problems in parabolic differential equations identify continuous, low-dimensional dynamics given noisy but complete measurements (rather than samples) along a known boundary \citep{tarantola2005inverse}. One-dimensional methods using plug-in kernel density estimates exist \citep{lund2014nonparametric} but do not generalize to greater than one dimension. 131 | 132 | Finally, estimation of discrete Markov chains using `macro' data is the discrete time and space equivalent of our problem. This is a classic problem in econometrics, and recovery properties \citep{van1983estimation}, estimation algorithms \citep{kalbfleisch1984least}, and the effect of noise \citep{bernstein2016consistently} are all well-known. 
The discrete solutions above observe multiple populations stopped at the same time points, which allows for the more general solutions. Our problem cannot be solved trivially via discretization: discretizing the space scales exponentially with dimension, and discretizing time results in a solution which is conceptually equivalent to the time-derivative model in section \ref{sec:manytime} and does not capture the underlying geometry of the problem. 133 | 134 | \subsection{Diffusive RNNs} 135 | 136 | Diffusive networks \citep{mineiro1998learning} connect diffusion processes and RNNs much like our work. Our work focuses on the specific problem of population-level diffusions (rather than full trajectory observations) and derives a new pre-training scheme based on contrastive divergence. Our work shows that the connection between recurrent network and diffusions such as those in \citep{mineiro1998learning} can be used to develop powerful inference techniques for general diffusions. 137 | 138 | \subsection{Computational biology} 139 | Pseudo-time analysis \citep{trapnell2014pseudo} models the differentiation of cells as measured by single-cell RNA-seq by assigning each cell to a differentiation path via bifurcations and a `pseudo-time' indicating its level of differentiation. Such analysis is driven by the desire to identify the cell-states and relevant marker genes during differentiation. Recent sophisticated methods can recover such bifurcations quite effectively \citep{setty2016wishbone, marco2014bifurcation}. 140 | 141 | Our work complements such analyses by showing that it is possible to recover quantitative parameters such as the underlying epigenetic landscape from few population measurements. Our results on identifiability of the epigenetic landscape will become more valuable as the number of captured cells in a single-cell RNA-seq experiment grows from hundreds \citep{klein2015droplet} to tens of thousands. 142 | 143 | Systems biology models of the epigenetic landscape have focused on constructing landscapes which recapitulate the qualitative properties of differentiation systems \citep{qiu2012understanding,bhattacharya2011deterministic}. Our work distinguished by a focus on data-driven identification of the epigenetic landscape. Existing data-driven models of the epigenetic landscape are for a single gene and either rely on longitudinal tracking \citep{sisan2012predicting} or require assuming that a particular cell population is stationary \citep{luo2013cell}. 144 | 145 | \section{Population-level behavior of diffusions} 146 | We will begin with a short overview of our notation, observation model, and mathematical background. 147 | 148 | A $d$-dimensional diffusion process $X(t)$ represents the state (such as gene expression) of an individual at time $t$. Formally we define $X(t)$ as a stochastic differential equation (SDE): 149 | \begin{equation}\label{eq:sde} 150 | dX(t) = \mu(X(t))dt + \sqrt{2\sigma^2} dW(t). 151 | \end{equation} 152 | Where $W(t)$ is the unit Brownian motion. This can be thought of as the continuous-time limit of the discrete stochastic process $Y(t)$ as $\Delta t\to 0$: 153 | \begin{equation}\label{eq:discr} 154 | Y(t+\Delta t) = Y(t)+\mu(Y(t))\Delta t + \sqrt{2\sigma^2\Delta t} Z(t) 155 | \end{equation} 156 | where $Z(t)$ are i.i.d standard Gaussians. The function $\mu(x)$ is called the \textbf{drift} and represents the force acting on an individual at a particular state $x$. In Fig. 
\ref{fig:problem}, the blue curves are $\mu(x)$ which result in $X(t)$ converging to one of four terminal states. The probability of observing $X(t)$ at any point $x$ at time $t$ is called the \textbf{marginal distribution} and corresponds to the colored points in Fig. \ref{fig:problem}. 157 | 158 | We define the population-level inference task as finding the drift function $\mu$ given distributions over the marginals. 159 | \begin{defn}[Population-level inference] 160 | 161 | Define the marginal distribution $\rho(t,x) = P(X(t)=x)$. 162 | 163 | A population-level inference problem on $X(t)$ given diffusion constant $\sigma$, time points $\mathcal{T}=\{0, t_1 \hdots t_n\}$, and samples $\mathcal{M}=\{m_0 \hdots m_n\}$ consists of identifying $\mu(x)$ from samples 164 | $\{x(t)_i \sim \rho(t,x) \mid i \in \{1\hdots m_t\}, t\in \mathcal{T}\}$. 165 | 166 | \end{defn} 167 | 168 | Fully general population level inference is impossible. Consider a process with the unit disk in $\mathbb{R}^2$ as $\rho(0,x)$, and the drift $\mu$ is a clockwise rotation. From a population standpoint, this would look identical to no drift at all. 169 | 170 | This raises the question: what restrictions on $\mu(x)$ are natural, and allow for the recovery of the underlying drift? Our paper considers \textbf{gradient flows} which are stochastic processes with drift defined as $\mu(x) = -\nabla \Psi(x)$ \footnote{For diffusion processes, the gradient flow condition is equivalent to reversibility \citep[Section 4.6]{pavliotis2014stochastic}.}. The \textbf{potential function} $\Psi(x)$ corresponds to the `epigenetic landscape' of our stochastic process. The force $\mu(x) = -\nabla \Psi(x)$ drives the process $X(t)$ toward regions of low $\Psi(x)$ much like a noisy gradient descent. 171 | 172 | A remarkable result on these gradient flows is that the marginal distribution $\rho(t,x)$ evolves by performing steepest descent on the relative entropy $D(\rho(t,x) || \exp(-\Psi(x)/\sigma^2))$ with respect to the 2-Wasserstein metric $W_2$. Formally, this is described by the Jordan-Kinderlehrer-Otto theorem \citep{jordan1998variational}: 173 | 174 | \begin{thm}[The JKO theorem]\label{thm:jko} 175 | Given a diffusion process defined by equation \ref{eq:sde} with $\mu(x) = -\nabla \Psi(x)$, then the marginal distribution $\rho(t,x)=P(X(t)=x)$ is approximated by the solution to the following recurrence equation for $\rho^{(t)}$ with $\rho^{(0)}=\rho(0,x)$. 176 | \begin{multline} 177 | \rho^{(t+\Delta t)} = \underset{\rho^{(t+\Delta t)}}{\text{argmin}} \quad W_2(\rho^{(t+\Delta t)}, \rho^{(t)})^2 \\ 178 | + \frac{\Delta t}{\sigma^2}D\left(\rho^{(t+\Delta t)}||\exp\left(\frac{-\Psi(x)}{\sigma^2}\right)\right). 179 | \end{multline} 180 | in the sense that $\lim_{\Delta t \to 0} \rho^{(t)}(x) \to \rho(t,x)$ 181 | \end{thm} 182 | 183 | This theorem is the conceptual core of our approach: the Wasserstein metric, which represents the probability of transforming one distribution to another via purely Brownian motion, will be our empirical loss \citep{adams2013large}; and the relative entropy $D(\rho||\exp(-\Psi(x)/\sigma^2))$ describing the tendency of the system to maximize entropy, will be our regularizer. 184 | 185 | 186 | 187 | 188 | \section{Recoverability of the potential $\Psi$}\label{sec:ident} 189 | 190 | Before we discuss our model, we must first establish that it is possible to asymptotically identify the true potential $\Psi(x)$ from sampled data. 
Otherwise the estimated $\Psi(x)$ will have limited value as a scientific and predictive tool. 191 | 192 | We consider recoverability in three regimes of increasing difficulty. First, in section \ref{sec:stationary}, we consider the stationary case of observing $\rho(\infty,x)$ which results in a closed-form estimator for $\Psi$, but requires unrealistic assumptions on our model. Next, in section \ref{sec:manytime} we consider a large number of observations across time, and show that exact identifiability is possible. However, this case requires a prohibitively large number of experiments to guarantee identifiability. Finally, in section \ref{sec:finiteobs} we will consider the most realistic case of observing a few observations across time, and discuss the conditions under which recovery of $\Psi$ is possible. 193 | 194 | \subsection{Stationary observations}\label{sec:stationary} 195 | 196 | In the stationary observation model, we are given samples from a fully mixed process $\rho(\infty,x)$. In this case, one time observation is sufficient to exactly identify the potential. 197 | This follows from representing the stochastic process in Eq. \ref{eq:sde} as a parabolic partial differential equation (PDE). 198 | \begin{thm}[Fokker-Planck \citep{jordan1998variational}] 199 | Given the SDE in equation \ref{eq:sde}, with drift $\mu(x) = -\nabla\Psi(x)$, the marginal distribution $\rho(t,x)$ fulfills: 200 | \begin{equation}\label{eq:fp} 201 | \frac{\partial \rho}{\partial t} = \text{div}(\rho(t,x)\nabla \Psi(x)) + \sigma^2\nabla^2 \rho(t,x) 202 | \end{equation} 203 | with given initial condition $\rho(0,x)$. 204 | \end{thm} 205 | 206 | Now in the stationary case, we can note that the ansatz $\rho(\infty,x) = \exp(-\Psi(x)/\sigma^2)$ gives: 207 | \[ 0 = \text{div}(\nabla \Psi(x) \rho(\infty,x))/\sigma^2 + \nabla^2 \rho(\infty,x)\] 208 | implying that $\exp(-\Psi(x)/\sigma^2)$ is the stationary distribution, and we can estimate the underlying drift as $\nabla\Psi(x) = -\nabla \log(\rho(\infty,x))\sigma^2$. The quantity $-\nabla \log(\rho(\infty,x))\sigma^2$ can be estimated from samples via one step of the mean-shift algorithm \citep[Eq. 41]{fukunaga1975estimation}. 209 | 210 | Although estimation of $\nabla \Psi(x)$ from the stationary distribution is tractable, it has two substantial drawbacks. First, it is difficult to collect samples from the exact stationary distribution $\rho(\infty,x)$; we often collect marginal distributions that are close, but not exactly equal to, the stationary distribution. Second, our estimator $-\nabla \log(\rho(\infty,x))$ is only accurate over regions of high density in $\rho(\infty,x)$ which may be distinct from our region of interest. For differentiation systems, this means we will only know the behavior of $\nabla \Psi(x)$ near the fully differentiated state, rather than over the entire differentiation timecourse. 211 | 212 | To make this drawback clear, consider the case where $\sigma^2$ is small. The stationary observations from $\exp(-\Psi(x)/\sigma^2)$ will concentrate around the global minimums of $\Psi(x)$ and will therefore only tell us about the local behavior of $\Psi(x)$ around the minima. On the other hand, observing a non-stationary sequence of distributions $\rho(0,x), \rho(t_1,x) \hdots$ does not have this drawback, as $\rho(0,x)$ may be initialized far from the minima of $\Psi(x)$ allowing us to observe how the distribution $\rho(0,x)$ converges to the minima of $\Psi(x)$. 
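As a concrete illustration, the one-step mean-shift estimator above can be written in a few lines; this is a minimal numpy sketch (the Gaussian kernel and the bandwidth $h$ are illustrative choices, not settings from our implementation):

```python
import numpy as np

def stationary_drift_estimate(samples, queries, sigma2, h=0.5):
    """One mean-shift step: estimates the drift -grad Psi(x) = sigma^2 * grad log rho(inf, x).

    samples: (m, d) draws from the stationary distribution rho(inf, x).
    queries: (q, d) points at which to estimate the drift.
    h: Gaussian kernel bandwidth (illustrative choice).
    """
    drifts = np.zeros(queries.shape, dtype=float)
    for j, x in enumerate(queries):
        w = np.exp(-np.sum((samples - x) ** 2, axis=1) / (2.0 * h ** 2))
        mean_shift = (w[:, None] * samples).sum(axis=0) / w.sum() - x
        # For a Gaussian KDE, grad log rho_hat(x) is approximately mean_shift / h^2.
        drifts[j] = sigma2 * mean_shift / h ** 2
    return drifts
```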
213 | 214 | \subsection{Many time observations}\label{sec:manytime} 215 | 216 | We show that sampling multiple nonstationary timepoints is identifiable, and avoids the drawbacks of a single stationary observation. Consider a observation scheme where we obtain $\rho(0,x), \rho(t_1,x) \hdots$ up to some time $t_n=T$ such that we can estimate one of two quantities reliably: 217 | \begin{itemize} 218 | \item \textbf{Short-time:} $\frac{\partial \rho}{\partial t}\bigg|_T \approx \sum_{i=1}^n\frac{\rho(t_{i},x)-\rho(t_0,x)}{t_{i}-t_{0}}$ 219 | \item \textbf{Time-integral:} $\int_0^T \rho(t,x)dt \approx \sum_{i=1}^n \rho(t_i,x)/n$ 220 | \end{itemize} 221 | 222 | In both of these cases, we can show that the underlying potential $\Psi(x)$ is identifiable via direct inversion of the Fokker-Planck operator. The time-integral model is particularly interesting, as it can be implemented in practice for single cell RNA-seq by collecting cells at uniform times across development \citep{klein2015droplet}. 223 | 224 | \begin{thm}[Uniqueness of Fokker-Planck like operators]\label{thm:uniqueop} 225 | Let $\Psi(x)$ be a continuously differentiable solution to the following elliptic PDE: 226 | \begin{equation}\label{eq:parab} 227 | f(x) = \nabla^2 \Psi(x) \tau(x) + \nabla \Psi(x) \nabla \tau(x) + \sigma^2 \nabla^2 \tau(x) 228 | \end{equation} 229 | subject to the constraint $\int \exp(-\Psi(x)/\sigma^2) dx = 1$. 230 | 231 | Equation \ref{eq:parab} is fulfilled in the short-time case with, $f=\frac{\partial \rho}{\partial t}$, $\tau = \rho$ and in the time-integral case, $f(x)=\rho(t_0,x)-\rho(t_n,x)$ and $\tau(x) = \int_0^T \rho(t,x)dt$. 232 | 233 | Additionally, the Fokker-Planck equation associated with $\rho(t,x)$ is constrained to domain $\Omega$ via a reflecting boundary. Formally, there exists a compact domain $\Omega$ with $\langle \nabla \Psi(x) \tau(x) + \sigma^2\nabla \tau(x) , n_x \rangle = 0$ for any boundary normal vector $n_x$ with $x \in \partial \Omega$. \footnote{This boundary condition is only necessary to keep the proof simple. We prove a relaxation in section \ref{sec:uniqueop2}.} 234 | 235 | Then $\Psi(x)$ is unique up to sets of measure zero in $\tau(x)$. 236 | \end{thm} 237 | \begin{proof} 238 | Consider any $\Psi_1(x)$ and $\Psi_2(x)$, then by linearity of the PDE, $\Psi'(x)=\Psi_1(x) - \Psi_2(x)$ must be a solution to the homogeneous elliptic PDE 239 | \[0 = \text{div}(\nabla \Psi'(x) \tau(x))=\nabla^2 \Psi'(x) \tau(x) + \nabla \Psi'(x) \nabla \tau(x).\] 240 | 241 | Consider the set $R_\epsilon=\{x:x\in\Omega, \Psi'(x)\leq \min_y \Psi'(y) + \epsilon\}$. By smoothness of $\Psi'$ and compactness of $\Omega$, for all $\epsilon > \epsilon_{min} = \min_y \Psi'(y)$ the region $R_\epsilon$ is compact. 242 | 243 | By construction, $\partial R_{\epsilon}$ can be decomposed into two parts: the boundary of the level set $\Psi'(x) = \min_y \Psi'(y)+\epsilon$ which we define as $\partial R_{\epsilon}^\circ$ and a possibly empty subset of the domain boundary $\partial \Omega$ defined as $\partial \Omega^\circ$. 
244 | 245 | By the divergence theorem we can integrate the elliptic PDE over any $R_{\epsilon}$: 246 | \begin{align*} 247 | \int_{x\in R_\epsilon} \text{div}(\nabla \Psi'(x)\tau(x))dx = \int_{x\in \partial \Omega\circ} \langle \nabla \Psi'(x) \tau(x) , n_x \rangle dx \\ 248 | + \int_{x\in \partial R_\epsilon^\circ} |\nabla\Psi'(x)|_2\tau(x) dx = 0 249 | \end{align*} 250 | By the boundary condition, for any $n_x$ with $x\in\partial \Omega$, $\langle \nabla \Psi_1(x) \tau + \sigma^2\nabla \tau , n_x \rangle = 0$ which implies that $\langle \nabla \Psi'(x) \tau , n_x \rangle = 0$ and therefore $\int_{x\in \partial R_\epsilon^\circ} |\nabla\Psi'(x)|_2\tau(x) dx = 0$. 251 | 252 | By construction, $\tau(x)>0$ over $\Omega$ and therefore $|\nabla \Psi'(x)| = 0$ for all $x \in \partial R_{\epsilon}^\circ$. The union of sets $\partial R_{\epsilon}^\circ$ contains all of $\Omega$ by construction, and therefore for $x\in \Omega$, $|\nabla \Psi'(x)| = |\nabla \Psi_1(x) - \nabla \Psi_2(x)| = 0$. 253 | Combined with the normalization constraint, $\int \exp(-\Psi(x)/\sigma^2) dx = 1$, this implies $\Psi_1(x) = \Psi_2(x)$. 254 | \end{proof} 255 | 256 | The proof of Thm. \ref{thm:uniqueop} illustrates that the recoverability depends critically on $\tau(x)>0$. Thus in the time-integral case, the regions which can be clearly recovered are those over which $\tau(x)=\int_0^T\rho(t,x)dt$ has large mass. Compared to the stationary situation, this is substantially better; we will get accurate estimates of $\Psi$ over the entire timecourse of $\rho(0,x) \hdots \rho(T,x)$. 257 | 258 | Finally, we ask whether $\Psi$ is recoverable when the time observations $\rho(0,x), \rho(t_1,x) \hdots $ are sufficiently few and separated in time such that both the short-time and time-integral assumptions are not valid. 259 | \subsection{Few time observations}\label{sec:finiteobs} 260 | 261 | In more realistic settings, we may get many samples, but very few time observations such that the time-integral uniqueness theorem does not hold. We analyze this case and establish two results: first, we establish exact identifiability in one dimension (Thm. \ref{thm:1d}) and give evidence for the conjecture in multiple dimensions (Cor. \ref{cor:hypo}). Next, we establish that a sufficiently mixed final time observation is sufficient for uniqueness (Thm. \ref{thm:entropic}) and derive a model constraint based on this theorem (Eq. \ref{eq:KL}). 262 | 263 | In one dimension, three time points are sufficient to recover the underlying potential function\footnote{The requirement of three marginal distributions is due to the more general nature of \citep[Problem 1]{gol2010uniqueness}. We believe only two marginals are necessary.}: 264 | \begin{thm}[1-D identifiability]\label{thm:1d} 265 | Assume there exists some $c$ such that $\sigma>c>0$; boundaries $a,b$ such that $\rho(t,a)=0$ and $\rho(t,b)=0$ for all $t$; and the marginal densities are Holder continuous with $\rho(t,x) \in H^{2+\lambda}$. 266 | 267 | Given $\rho(0,x), \rho(t_1,x), \rho(t_2,x)$ with $0\neq t_1\neq t_2 < \infty$, there exists a unique continuous potential $\Psi(x) \in C^1$ fulfilling the Fokker-Planck equation. 268 | \end{thm} 269 | \begin{proof} 270 | This is a special case of problem 1 considered in \citep{gol2010uniqueness} once we set $c(x,t,u)=1$, $f(x,t)=0$, $d(x,t,u)=0$, $b_1(x,t,u)=0$, $p(x)=d_1(x,t,u)=0$. The result follows from \citep[Theorem 1]{gol2010uniqueness}. 
271 | \end{proof} 272 | 273 | In the multivariate case, the adjoint technique used in \citep{gol2010uniqueness} no longer applies, and the equivalent result is an open problem conjectured to be true \citep{de2012note}. We believe this conjecture is true and show that for any finite number of candidate $\Psi$ which agrees at two marginals $\rho(0,x)$ and $\rho(t,x)$ we can identify the true potential using a third measurement. 274 | 275 | \begin{cor}[Finite identifiability of $\Psi$]\label{cor:hypo} 276 | Let $\Psi_0$ and $\Psi_1$ be candidate potentials such that given $\rho_0(0,x)=\rho_1(0,x)$ and 277 | \[\frac{\partial \rho_i}{\partial t} = \text{div}(\nabla \Psi_i(x) \rho_i(t,x)) + \sigma^2\nabla^2 \rho_i(t,x)\] 278 | such that $\rho_0(t,x) = \rho_1(t,x)$. Define $\rho_i(t_3,x)$ where $t_3 \sim T$ is a draw from $T$ defined as a random variable absolutely continuous with respect to the Lebesgue measure, then $\rho_1(t_3,x)=\rho_0(t_3,x)$ with probability one if and only if $\forall x$, $\Psi_1(x)=\Psi_0(x)$. 279 | \end{cor} 280 | \begin{proof} 281 | See Supp. section \ref{sec:hypo}. The statement reduces to short-time uniqueness studied in section \ref{sec:manytime}. 282 | \end{proof} 283 | 284 | In the case that the final marginal distribution $\rho(t_n,x)$ is sufficiently mixed, stationary identifiability allows us to derive an identifiability result regardless of the conjecture. 285 | 286 | \begin{thm}[Relative fisher information constraint]\label{thm:entropic} 287 | Let $\rho(0,x)$ and $\rho(t_n,x)$ be marginal distributions associated with the potential $\Psi$. Then, if the final time $\rho(t_n,x)$ is sufficiently mixed: 288 | \[-\frac{\partial }{\partial t}D(\rho(t_n,x)||\exp(-\Psi(x)/\sigma^2))\leq \epsilon,\] 289 | all $\hat{\Psi}$ which are consistent with $\rho(0,x)$ and $\rho(t_n,x)$ with similar mixing constraints: $-\frac{\partial }{\partial t}D(\rho(t_n,x)||\exp(-\hat{\Psi}(x)/\sigma^2))\leq \epsilon$ must imply similar drifts: 290 | \[\int |\nabla \Psi(x)-\nabla \hat{\Psi}(x)|^2\rho(t_n,x) dx \leq 4\epsilon.\] 291 | \end{thm} 292 | \begin{proof} 293 | This follows from a relative fisher information identity in \citep[Lemma 4.1]{markowich2000trend}. We reproduce an abbreviated proof for completeness. Since $\rho$ is the solution to the Fokker-Planck equation evolving according to $\Psi$, we can write $h_t(x) = \rho(t_n,x)/\exp(-\Psi(x)/\sigma^2)$, leading to 294 | \begin{align*} 295 | &-\frac{\partial D(\rho(t_n,x)||\exp(-\Psi(x)/\sigma^2))}{\partial t} \\ 296 | %&= -\frac{\partial}{\partial t} \int \exp(-\Psi(x)/\sigma^2) h_t(x) \log h_t(x) dx\\ 297 | &= \int \frac{\exp(-\Psi(x)/\sigma^2)}{h_t(x)} |\nabla h_t(x)|^2 dx \\ 298 | &= \int |\nabla \Psi(x)-\nabla \rho(t_n,x)|^2\rho(t_n,x) dx \leq \epsilon. 299 | \end{align*} 300 | Where the second equality follows via integration by parts on the Fokker-Planck equation. 301 | Applying the Minkowski inequality to the last line gives the desired identity. 302 | \end{proof} 303 | 304 | Theorem \ref{thm:entropic} implies that if we are willing to assume that $\rho(t_n,x)$ is close to mixed, and we can ensure that our estimated $\hat{\Psi}$ has a tight bound on $-\frac{\partial }{\partial t}D(\rho(t_n,x)||\exp(-\hat{\Psi}(x)/\sigma^2))$, then we can recover a good approximation to the true $\Psi$. 
In practice this assumption and constraint is straightforward to fulfill: experimental designs often track cell populations until they do not show substantial changes ($\rho(t_n,x)$ is close to mixed) and we can fit $\hat{\Psi}$ under the constraint that it is smooth with bounded gradient and 305 | \begin{equation}\label{eq:KL} 306 | D(\rho(t_n,x)||\exp(-\hat{\Psi}(x)/\sigma^2)) \leq \eta. 307 | \end{equation} 308 | Which implicitly bounds the mixedness in Thm. \ref{thm:entropic} by the JKO theorem (Thm. \ref{thm:jko}). Thus we have established a constraint (Eq. \ref{eq:KL}) and experimental condition (Thm. \ref{thm:entropic}) under which we can reliably recover the underlying dynamics even with few timepoints. 309 | 310 | %Without this constraint, inference from finite samples becomes inherently ill-posed. For example, if we have a time-course over which we only observe points in the unit cube, we can either hypothesize that: 1. The potential $\Psi(x)$ is large outside the unit cube, constraining all observed trajectories to lie within the unit cube. Or 2. $\Psi(x)$ forms a wall around the unit cube which maintains the observed trajectories within the cube, but there are also nontrivial energy minima outside the walls. If we assume the mixedness condition of theorem \ref{thm:entropic}, this immediately rules out such degeneracies, leaving only the first case as a possibility. 311 | 312 | %This constraint is necessary when performing inference from finite samples. Consider a set of observations $\rho(t,x)$ where $\rho(0,x)$ is the unit normal and we observe that the samples approach a normal distribution with standard deviation of two. We would like such a set of observations to define the underlying potential $\Psi(x)$ as a quadratic function, but this is impossible: we may have two potential wells, one at zero and another far away (say at 10000) such that the mixing rate is nearly zero, and thus we never observe samples in the second potential well. 313 | 314 | \section{Inference} 315 | We will show that a Wasserstein loss with an entropic regularization on a noisy RNN is natural for this model. 316 | 317 | \subsection{Loss function and regularization} 318 | 319 | To motivate the Wasserstein loss, consider the case where we observe full trajectories of a single stochastic process $X(t)$. Then one natural loss function is to consider the expected squared loss between the observed value $x_t$ and the predicted distribution of $X(t)$ under the model. 320 | 321 | The Wasserstein distance is exactly the analogous quantity to the $L_2$ distance when we switch from fully observed trajectories to populations of indistinguishable particles in a diffusion \citep[Section 3]{adams2013large}. We outline the intuition for this argument here: the squared loss for a diffusion arises from the fact that given $m_t$ trajectories from a diffusion with $x(t)=\{x(t)_0,x(t)_1\hdots x(t)_{m_t}\}$, then $\lim_{\hat{t}\to 0}-\hat{t}\log(P(X(\hat{t}+t)=x(\hat{t}+t)|X(t)=x(t))) = \frac{1}{4}\sum_{i=1}^{m_t} |x(t+\hat{t})_i-x(t)_i|_2^2$. The squared loss thus arises as the log-probability that Brownian motion transforms the predicted value $X(t)$ into the true value $x(t)$ in an infinitesimal time $\hat{t}$. 
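(To see where the factor of $\frac{1}{4}$ comes from, note that under pure Brownian motion with our scaling, $X(t+\hat{t})$ given $X(t)=x(t)$ is Gaussian with covariance $2\sigma^2\hat{t}I$, so the negative log-density is $\sum_{i}|x(t+\hat{t})_i-x(t)_i|_2^2/(4\sigma^2\hat{t})$ plus a term of order $\log(1/\hat{t})$; multiplying by $\hat{t}$ and letting $\hat{t}\to 0$ leaves only the quadratic term, which is the stated limit in units where $\sigma=1$.)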
322 | 323 | If we make the particles indistinguishable via a random permutation $\sigma \in S_{m_0}$, the above limit becomes: 324 | \begin{multline} 325 | \lim_{\hat{t}\to 0}-\hat{t}\log(P(X(t+\hat{t})=x(t+\hat{t})|X(t)=x(t))) = \\ 326 | \frac{1}{4}\inf_{\sigma \in S_{m_n}}\sum_{i=1}^{m_n} |x(t+\hat{t})_i-x(t)_{\sigma(i)}|_2^2. 327 | \end{multline} 328 | This is a special case of the Wasserstein metric, implying that for population inference, the natural analog to empirical squared loss minimization is empirical Wasserstein loss minimization. Thus at time $t_i$ we penalize $W_2(\hat{\rho}(t_i,x),\rho_{\Psi}(t_i,x))^2$ which is the Wasserstein distance between the empirical distribution $\hat{\rho}$ and the marginal distribution predicted by $\Psi$, $\rho_{\Psi}$. This loss is approximated via sampling and the Sinkhorn distance \citep{cuturi2013sinkhorn}. 329 | 330 | We regularize this loss function with an entropic regularizer. Thm. \ref{thm:entropic} states that if $\frac{\partial}{\partial t} D(\rho(t_n,x)||\exp(-\Psi(x)/\sigma^2))$ is small then we can recover any mixed potential. We fulfill this mixing constraint by controlling the relative entropy in Eq. \ref{eq:KL}, which we write as 331 | \[E_{X\sim\rho(t_n,x)}[\log(\rho(t_n,X))]+E_{X\sim\rho(t_i,x)}[\Psi(X)/\sigma^2] \leq \eta,\] 332 | where $\rho(t_n,x)$ is the unknown, true marginal distribution at time $t_n$. Removing constant terms not involving $\Psi(x)$ and replacing $\rho(t_n,x)$ with samples $x_j \sim \rho(t_n,x)$ gives us the regularizer: $\sum_{j=1}^{m_n} \Psi(x_j)/\sigma^2$. Converting this constraint into a regularization term with parameter $\tau$ and assuming that $\Psi$ is contained in a family of models $K$, our objective function is: 333 | \begin{equation}\label{eq:objective} 334 | \min_{\Psi \in K} \left[\sum_{i=1}^n W_2(\hat{\rho}(t_i,x),\rho_{\Psi}(t_i,x))^2\right] + \tau \sum_{j=1}^{m_n} \frac{\Psi(x_j)}{\sigma^2}. 335 | \end{equation} 336 | 337 | The similarity of Eq. \ref{eq:objective} to the JKO theorem (Thm. \ref{thm:jko}) is not coincidental. One interpretation of the JKO theorem is that $W_2$ is the natural metric over marginal distributions and likelihood is the natural measure of model fit over $\Psi$. 338 | 339 | 340 | \begin{figure*}[ht!] 341 | \centering 342 | \subfigure[Stationary pre-training improves both runtime and goodness of fit.]{\includegraphics[scale=0.35]{fig/runtimes.png}\label{fig:stationary}} 343 | %\subfigure[Loss of likelihood (y-axis) of pre-training only and no pre-training models compared to the full optimization model]{\includegraphics[scale=0.4]{fig/lhimprovement.png}\label{fig:goodness}} 344 | \quad 345 | \subfigure[RNN predictions are similar to the true dynamics on 50D data.]{\includegraphics[scale=0.35]{fig/toy_pred_2.png}\label{fig:goodness}} 346 | \quad 347 | \subfigure[Example prediction of baselines on same data]{\includegraphics[scale=0.35]{fig/toy_pred.png}\label{fig:goodness_2}} 348 | \vspace{-8pt} 349 | \caption{The pre-trained RNN captures the multimodal dynamics of the Styblinski flow even in 50-dimensions.} 350 | \vspace{-10pt} 351 | \end{figure*} 352 | 353 | 354 | \subsection{Diffusions as a recurrent network} 355 | 356 | Thus far we have abstractly considered all stochastic processes of the form: $dX(t) = -\nabla \Psi(x) dt + \sqrt{2\sigma^{2}} dW(t)$. 
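For intuition, a process of this form can be simulated forward using the first-order discretization of Eq. \ref{eq:discr}. A minimal numpy sketch, where \texttt{grad\_psi} stands in for any callable returning $\nabla\Psi$ at a batch of points (such as the parametrization introduced below):

```python
import numpy as np

def simulate_population(x0, grad_psi, sigma2, t_end, dt=0.01, rng=None):
    """Euler-Maruyama simulation of dX = -grad Psi(X) dt + sqrt(2 sigma^2) dW.

    x0: (m, d) initial population sample.
    grad_psi: callable mapping an (m, d) array to the (m, d) array of gradients.
    Returns the simulated population at time t_end.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)
    for _ in range(int(t_end / dt)):
        noise = rng.standard_normal(x.shape)
        x = x - grad_psi(x) * dt + np.sqrt(2.0 * sigma2 * dt) * noise
    return x
```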
357 | 358 | A natural way to parametrize $\Psi$ is to consider linearly separable potential functions, which we may write as: 359 | \[\Psi(x) = \sum_k h(w_kx+b_k)g_k,\] 360 | such that $h$ is some strictly increasing function. This represents $\Psi$ as the sum of energy barriers $h$ in the direction of vectors $w_k$, allowing us to fit our model via gradient descent, while maintaining interpretability of the parameters. 361 | 362 | Setting $h(x)=\log(1+\exp(x))$ parametrizes $\Psi(x)$ as the sum of nearly linear ramps and we obtain that the drift $\nabla \Psi$ is a one layer of a sigmoid neural network, where the linear terms are tied together much like an autoencoder: 363 | \[\sum_k \nabla h(w_kx+b_k)g_k = \sum_k h'(w_kx+b_k)g_kw_k^T\] 364 | 365 | Applying this to the first order time discretization in Eq. \ref{eq:discr}, a draw $\oy^t_i$ of our stochastic process can be simulated as: 366 | \begin{equation}\label{eq:simul} 367 | \oy^{t+dt}_i = \oy^t_i + \Delta t \sum_k h'(w_k\oy^t_i+b_k)w_kg_k + \sqrt{\Delta t\sigma^2 }z_{it} 368 | \end{equation} 369 | This can be interpreted as a type of RNN with noise based regularization. The network is generative and as $\Delta t\to 0$ the draws from this recurrent net converge to trajectories of the diffusion process $X$ above. \footnote{In practice, we set $\Delta t$ to be 0.1 which gives at least a ten time-steps between observations in our experiments and find anywhere from five to hundred time-steps between observations to be sufficient.} 370 | 371 | 372 | \begin{figure*} 373 | \centering 374 | \subfigure[Quadratic high dimensional flow]{\includegraphics[scale=0.4]{fig/wass_error_quadratic.png}\label{fig:quaddim}} 375 | \subfigure[Styblinski flow in high dimensions]{\includegraphics[scale=0.4]{fig/wass_error_himmelblau.png}\label{fig:himdim}} 376 | \subfigure[Gene expression (D4)]{\includegraphics[scale=0.17]{fig/accuracy_genes.png}\label{fig:genedim}} 377 | \vspace{-8pt} 378 | \caption{Held-out goodness of fit (lower is better), as measured by Wasserstein distance. `Oracle' represents the error from Monte Carlo sampling for the true gradient flow. The RNN parametrization performs best across a wide range of tasks.} 379 | \vspace{-8pt} 380 | \end{figure*} 381 | 382 | 383 | 384 | \subsection{Optimization} 385 | 386 | Optimizing the full objective function (Eq. \ref{eq:objective}) directly via backpropagation across time is slow and sensitive to the initialization. 387 | Exploiting the connection between RNNs and the diffusion, we can pre-train the model by optimizing the regularizer alone: $\sum_{j=1}^{m_n} \Psi(x_j)/\sigma^2$ under the constraint that $\int \exp(-\Psi(x)/\sigma^2)dx = 1$. We solve this optimization problem with contrastive divergence \citep{hinton2002training} using the first-order Euler scheme in Eq. \ref{eq:simul} to generate negative samples. 388 | 389 | After this initialization, we perform backpropagation over time on our objective function, with $\rho_{\Psi}$ approximated via Monte Carlo samples using Eq. \ref{eq:simul} and the Wasserstein error approximated using Sinkhorn distances. These stochastic gradients are then used in Adagrad to optimize $\Psi$\citep{duchi2011adaptive}. We implement the entire method in Theano, and code is available at \url{https://github.com/thashim/population-diffusions}. 390 | 391 | 392 | \section{Results} 393 | 394 | 395 | \begin{figure*}[ht!] 
396 | \vspace{-5pt} 397 | \subfigure[D0 and D7 distributions of Oct4 (y-axis) and Krt8 (x-axis)]{\includegraphics[scale=0.4]{fig/Barrier_Expr_scatter.png}\label{fig:datain}} 398 | \subfigure[Learned differentiation dynamics]{\includegraphics[scale=0.4]{fig/d0_d7_fitted_flow.png}\label{fig:dynamics}} 399 | \subfigure[Distributions of true Krt8 expression]{\includegraphics[scale=0.4]{fig/violinplot.png}\label{fig:krt8}} 400 | \vspace{-16pt} 401 | \caption{Observed data and learned model for single-cell RNA-seq data} 402 | \vspace{-8pt} 403 | \end{figure*} 404 | 405 | 406 | 407 | We now demonstrate the effectiveness of both the pre-training and the RNN parametrization. 408 | \footnote{Step-size is selected by grid search (see section \ref{sec:params} for other parameter settings). $\sigma$ is assumed known in the simulations, and fixed to the observed marginal variance for the RNA-seq data.} 409 | \subsection{Effectiveness of the stationary pre-training} 410 | The stationary pre-training via contrastive divergence results in substantially better training log-likelihoods in less than a third of the total time of the randomly initialized case (Fig. \ref{fig:stationary}) for the Himmelblau flow (Fig. \ref{fig:problem}). We control for initialization and runtime of both procedures by ensuring that the initial parameters of the pre-training match those of the random initialization, and by using shared code for both pre-training and backpropagation. 411 | 412 | 413 | \subsection{Learning high dimensional flows} 414 | 415 | One of the primary advantages of using recurrent networks and sums-of-ramps as a potential is that they behave well in high-dimensional estimation problems. We compare our RNN model against a linear $\Psi(x)$, the Ornstein-Uhlenbeck process (quadratic $\Psi(x)$), and a local sum-of-Gaussian potentials parametrization for $\Psi(x)$ (details in Sec. \ref{sec:altmethods}). 416 | 417 | In the first task (Fig. \ref{fig:quaddim}), we have a population evolving in $\mathbb{R}^d$ for $d\in \{2,10,50\}$ according to a unit quadratic potential $\Psi(x) = |x|_2^2$. The initial measurement is 500 samples drawn from a normal distribution with $1/2$ scale centered at $(5,0,0\hdots 0)$, and the final time measurement is 500 samples at $t=1$ with $\sigma=1.5$. This tests whether our model can recover a simple, high-dimensional potential function. In this case, the simple dynamics mean that the parametric models (Ornstein-Uhlenbeck and Linear flows) perform quite well. The RNN parametrization remains competitive with these models as the dimensionality increases, and substantially outperforms the local model (Fig. \ref{fig:quaddim}). 418 | 419 | In the second task (Fig. \ref{fig:himdim}), we consider a population over $d\in \{2,10,50\}$ with two of the dimensions evolving according to the Styblinski flow ($\Psi(x) = ||3x^3-32x+5||_2^2$), and the other dimensions set to zero. This tests whether our model can identify a complex low-dimensional potential embedded in a high-dimensional space. Example outputs in Fig. \ref{fig:goodness} and \ref{fig:goodness_2} demonstrate that our RNN model can capture the multi-modal dynamics embedded within a high-dimensional space. The quantitative error in Fig. \ref{fig:himdim} shows that the local and RNN methods perform best at low (2-10) dimensions, but the local method rapidly degenerates in higher dimensions. In both cases, our RNN approach produces substantially lower Wasserstein loss compared to both parametric and local approaches.
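For reference, the Wasserstein errors reported here (and the training loss) are computed between finite samples; an entropy-regularized Sinkhorn approximation in the spirit of \citep{cuturi2013sinkhorn} can be sketched as follows (the regularization \texttt{reg} and the iteration count are illustrative choices, not our experimental settings):

```python
import numpy as np

def sinkhorn_w2_squared(x, y, reg=0.05, n_iter=200):
    """Entropy-regularized approximation to the squared 2-Wasserstein distance.

    x: (m, d) and y: (n, d) samples with uniform weights.
    reg, n_iter: illustrative settings; a larger reg (or rescaling the cost)
    may be needed for numerical stability on poorly scaled data.
    """
    m, n = x.shape[0], y.shape[0]
    a, b = np.full(m, 1.0 / m), np.full(n, 1.0 / n)
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=2)  # squared distances
    K = np.exp(-cost / reg)
    u = np.ones(m)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    transport_plan = u[:, None] * K * v[None, :]
    return float(np.sum(transport_plan * cost))
```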
420 | 421 | \subsection{Analysis of Single-cell RNA-seq} 422 | 423 | In \citep{klein2015droplet}, an initially stable embryonic stem cell population (termed `D0' for day 0) begins to differentiate after removal of LIF (leukemia inhibitory factor), and single-cell RNA-seq measurements are made at two, four, and seven days after LIF removal. At each time point, the expression of 24175 genes is measured for several hundred cells (933 cells at D0, 303 at D2, 683 at D4, and 798 at D7). We apply standard normalization procedures \citep{hicks2015widespread} to correct for batch effects, and impute missing expression levels using nonnegative matrix factorization. Our task is to predict the gene expression at D4 given only the D0 and D7 expression values. 424 | 425 | Fitting our RNN model across the top five and ten most differential genes (as determined by the Wasserstein distance between the D0 and D7 distributions for each gene), our method performs best compared to the baselines (Fig. \ref{fig:genedim}), and is the only one to perform better than the trivial baseline of predicting the D4 gene expression using D7 data. We find that ten genes is the limit for accurate prediction with a few hundred cells; in higher dimensions the RNN begins to behave much like the linear model. As the number of captured cells in single-cell RNA-seq grows, our RNN model will be capable of modeling more complex multivariate potentials. 426 | 427 | We now focus on whether our model captures the qualitative dynamics of differentiation for the two main differentiation markers studied in \citep{klein2015droplet}: Keratin 8 (Krt8), an epithelial marker, and Oct4 (Pou5f1), an embryonic marker. Krt8 in particular shows two sub-populations at day 4 (Fig. \ref{fig:krt8}), suggesting that the epigenetic landscape may have multiple minima. 428 | 429 | Fitting our RNN on the two-dimensional problem shown in Fig. \ref{fig:datain}, we obtain a potential function with a single minimum (Fig. \ref{fig:dynamics}), demonstrating that differentiation is concentrated around a linear path connecting the D0 and D7 distributions. Surprisingly, this simple unimodal potential predicts a bimodal distribution for the D4 Krt8 expression shown in Fig. \ref{fig:krt8d4}, despite the lack of bimodality in either the input data (Fig. \ref{fig:datain}) or the potential (Fig. \ref{fig:dynamics}). \footnote{Similar qualitative results hold for D4 Krt8 expression under the five- and ten-dimensional versions (Supp. Fig. \ref{fig:5dkrt}, \ref{fig:5dmarg}, \ref{fig:10dmarg}, \ref{fig:10dcor}).} 430 | 431 | \begin{figure} 432 | \vspace{-5pt} 433 | \centering 434 | \includegraphics[scale=0.30]{fig/krt8d4.png} 435 | \vspace{-7pt} 436 | \caption{D4 predictions of Krt8 recapitulate bimodality} 437 | \label{fig:krt8d4} 438 | \vspace{-15pt} 439 | \end{figure} 440 | 441 | The bimodality arises from modeling the quantitative dynamics from D0 to D7, and provides evidence that even with as few as two time measurements, complex dynamics can be recovered from population-level observations. 442 | 443 | \section{Discussion} 444 | 445 | Our work establishes the problem of recovering an underlying potential function using samples from the population distribution. Using a variational interpretation of diffusions, we derive natural and scalable losses and regularizers. Finally, we demonstrate through multiple synthetic datasets and a real single-cell RNA-seq dataset that our model performs well in a high-dimensional setting.
446 | 447 | \section*{Acknowledgements} 448 | We would like to thank the reviewers for their helpful comments in revising the paper. 449 | 450 | This research was funded by the National Institute of Health under grants to D.G. and T.J. 1U01HG007037-01 and 1R01HG008363-01. 451 | 452 | \bibliography{ICML} 453 | \bibliographystyle{icml2016} 454 | \newpage 455 | \onecolumn 456 | \setcounter{section}{18} 457 | \renewcommand\thesection{\Alph{section}} 458 | \setcounter{figure}{0} 459 | \renewcommand\thefigure{\thesection.\arabic{figure}} 460 | 461 | 462 | \section{Supplemental results} 463 | 464 | \subsection{Hypothesis test proof}\label{sec:hypo} 465 | \begin{cor}[Hypothesis test for $\Psi$] 466 | Let $\Psi_0$ and $\Psi_1$ be candidate potentials such that given $\rho_0(0,x)=\rho_1(0,x)$ and 467 | \[\frac{\partial \rho_i}{\partial t} = \text{div}(\nabla \Psi_i(x) \rho_i(t,x)) + \sigma^2\nabla^2 \rho_i(t,x)\] 468 | fulfill $\rho_0(t,x) = \rho_1(t,x)$. Define $\rho_i(t_3,x)$ where $t_3 \sim T$ is a draw from $T$ defined as a random variable absolutely continuous with respect to the Lebesgue measure, then either 469 | \[P(\rho_1(t_3,x) = \rho_0(t_3,x))=1\] 470 | if $\forall x$, $\Psi_1(x)=\Psi_0(x)$, or 471 | \[P(\rho_1(t_3,x) = \rho_0(t_3,x))=0\] 472 | otherwise. 473 | \end{cor} 474 | \begin{proof} 475 | By theorem \ref{thm:uniqueop}, we know that if both $\frac{\partial\rho_1}{\partial t} = \frac{\partial\rho_0}{\partial t}$ and $\rho_1(t,x) = \rho_0(t,x)$ for any $t$, then $\Psi_1(x)=\Psi_0(x)$. Therefore if $\Psi_1(x) \neq \Psi_0(x)$, any $t$ such that $\rho_1(t,x)=\rho_0(t,x)$ must have distinct time derivatives. 476 | 477 | Now by Bolzano Weierstrass, if $\rho_1(t,x)=\rho_0(t,x)$ an infinite times over any finite time interval $[0,T]$, then there must be some accumulation point such that $\rho_1(t,x)=\rho_0(t,x)$ has a convergent subsequence. By differentiability of $\rho$ with respect to time, this implies $\frac{\partial\rho_0}{\partial t}$ at some $\rho_1(t,x)=\rho_0(t,x)$. Therefore, if $\Psi_1(x) \neq \Psi_0(x)$ there can only be a finite number of times such that $\rho_1(t,x) = \rho_0(t,x)$. This has measure zero over with respect to the Lebesgue measure, thus any random stopping time $t_3$ implies 478 | \[P(\rho_1(t_3,x) = \rho_0(t_3,x))=0.\] 479 | The other direction occurs by uniqueness of the solution to the Fokker Planck equation. 480 | 481 | 482 | \end{proof} 483 | 484 | \subsection{Boundary conditions for identifiability} 485 | \label{sec:uniqueop2} 486 | 487 | 488 | We prove the non-compact boundary condition, which replaces the boundary with some sequence of compact sets such that the probability of leaving the set limits to zero. 489 | 490 | \begin{thm}[Uniqueness of Fokker-Planck like operators]\label{thm:uniqueop2} 491 | Let $\Psi(x)$ be a $C^1$ solution to the following elliptic PDE: 492 | \begin{equation} 493 | \label{eq:parab2} 494 | f(x) = \nabla^2 \Psi(x) \tau(x) + \nabla \Psi(x) \nabla \tau(x) + \sigma^2 \nabla^2 \tau(x) 495 | \end{equation} 496 | subject to the constraint $\int \exp(-\Psi(x)/\sigma^2) dx = 1$, $\int \tau(x)dx < \infty$. 497 | 498 | Equation \ref{eq:parab2} is fulfilled in the short-time case with, $f=\frac{\partial \rho}{\partial t}$, $\tau = \rho$ and in the time-integral case, $f(x)=\rho(t_0,x)-\rho(t_n,x)$ and $\tau(x) = \int_0^T \rho(t,x)dt$. 
499 | 500 | In both cases, assume that the underlying Fokker-Planck boundary condition allows us to construct a sequence of compact sets $\Omega_n$ such that $\lim_{n\to\infty} \int_{x\in \Omega_n} \tau(x)dx = \int_{x\in\mathbb{R}^d} \tau(x)dx < \infty$ and $\lim_{n\to \infty} \int_{x\in \omega} f(x)\to 0$. 501 | 502 | Then $\Psi(x)$ is unique up to sets of measure zero of $\tau(x)$. 503 | \end{thm} 504 | \begin{proof} 505 | Consider any $\Psi_1(x)$ and $\Psi_2(x)$, then by linearity of the PDE $\Psi'(x)=\Psi_1(x) - \Psi_2(x)$ must be a solution to the homogeneous elliptic PDE 506 | \[0 = \text{div}(\nabla \Psi'(x) \tau(x))=\nabla^2 \Psi'(x) \tau(x) + \nabla \Psi'(x) \nabla \tau(x)\] 507 | 508 | Construct $R_{\epsilon,n} = \{x: x\in \Omega_n, \Psi'(x)\leq \epsilon\}$, which is the intersection of the level set of $\Psi'$ with $\Omega_n$. 509 | 510 | Expanding the limit boundary constraint on $f$ and taking the difference we obtain: 511 | \[\lim_{n\to \infty} \int_{x\in \partial \Omega_n} \langle \nabla \Psi'(x)\tau(x),n_x\rangle dx =0.\] 512 | 513 | Analogously to the reflecting boundary condition, define $\partial R_{\epsilon,n}^\circ$ as the boundary of the sublevel set and $\partial \Omega_{\epsilon,n}^\circ$ as the boundary of $\Omega_n$ such that the union of the two sets forms the boundary of $R_{\epsilon,n}$. 514 | 515 | Applying the divergence theorem over the decomposition of the boundary analogously to the other boundary condition: 516 | \begin{align}\label{eq:div1} 517 | &\lim_{n\to\infty}\int_{x\in R_{\epsilon,n}} \text{div}(\nabla \Psi'(x)\tau(x)) dx\\ 518 | \label{eq:div2} 519 | &= \lim_{n\to \infty}\int_{x\in \partial \Omega_n^\circ} \langle \nabla \Psi'(x) \tau(x) , n_x \rangle dx \\ 520 | &+ \lim_{n\to\infty}\int_{x\in \partial R_{\epsilon,n}^\circ} |\nabla\Psi'(x)|_2\tau(x) dx = 0 . 521 | \end{align} 522 | which implies via our boundary constraint 523 | \[\lim_{n\to\infty}\int_{x\in \partial R_{\epsilon,n}^\circ} |\nabla\Psi'(x)|_2\tau(x) dx = 0.\] 524 | This limit occurs uniformly in $\epsilon$, since the first line of Eq \ref{eq:div1} is exactly zero and Eq \ref{eq:div2} is uniformly bounded as 525 | \begin{align*} 526 | &\lim_{n\to \infty}\int_{x\in \partial \Omega_n^\circ} \langle \nabla \Psi'(x) \tau(x) , n_x \rangle dx \\ 527 | &\leq \lim_{n\to \infty}\int_{x\in \partial \Omega_n} \langle \nabla \Psi'(x) \tau(x) , n_x \rangle dx. 528 | \end{align*} 529 | 530 | 531 | Now assume that there exists some compact set $S$ of nonzero measure such that for all $x\in S$, $|\nabla \Psi(x)| \neq 0$. Since $\Psi$ is continuous the extreme value theorem implies the existence of some $\epsilon_{\text{min}} = \min_{x\in S}\Psi(x)$ and $\epsilon_{\text{max}}=\max_{x\in S}\Psi(x)$. Using the fact that any $x$ with $|\nabla\Psi'(x)|\neq 0$ must be a part of $\partial R_{\epsilon,n}^\circ$ for sufficient large $n$ and uniformity of our limit with respect to $\epsilon$ we obtain: 532 | \begin{align*} 533 | &\lim_{n\to\infty} \int_{x\in S} |\nabla \Psi'(x)|_2 \tau(x)dx\\ 534 | &=\lim_{n\to\infty} \int_{\epsilon_{\text{min}}}^{\epsilon_{\text{max}}} \int_{x\in \{\partial R_{\epsilon,n}^\circ \cap S\}} ||\nabla \Psi'(x)||_2 \tau(x) dx d\epsilon\\ 535 | &\leq \lim_{n \to \infty} \int_{\epsilon_{\text{min}}}^{\epsilon_{\text{max}}} \int_{x\in \partial R_{\epsilon,n}^\circ} ||\nabla \Psi'(x)||_2 \tau(x) dx d\epsilon =0. 
536 | \end{align*} 537 | Which is a contradiction, as this implies $\lim_{n \to \infty} |\nabla \Psi(x)| =0$ from the fact that $\tau(x)$ has a lower bound strictly greater than zero over $S$. 538 | 539 | Equicontinuity of $\nabla\Psi'(x)$ then implies $|\nabla \Psi'(x)|=0$ for all $x$, and therefore 540 | \[|\nabla \Psi'(x)| = |\nabla \Psi_1(x) - \nabla \Psi_2(x)| = 0.\] 541 | Combined with the normalization constraint, $\int \exp(-\Psi(x)/\sigma^2) dx = 1$, this implies $\Psi_1(x) = \Psi_2(x)$. 542 | \end{proof} 543 | 544 | \subsection{Details on parameters and methods} \label{sec:params} 545 | The following are the `free' hyperparameters of the model: 546 | \begin{itemize} 547 | \item $K$: The number of hidden layers (200 for simulated data, 500 for RNA-seq data) 548 | \item $\Delta t$: simulation timestep (0.01 for simulations, 0.1 for RNA-seq) 549 | \item $\tau$: regularization constant (0.7 for all data) 550 | \item $\epsilon$: step size of adagrad (Grid searched from starting with 0.1 for 10 steps with decaying powers of 2) 551 | \item $\gamma$: adagrad squared gradient decay rate (0.01, all experiments) 552 | \item NS: number of samples to draw from simulations (Fixed to be the same as the number of points at the first time point) 553 | \item burnin: number of steps of the first-order Euler scheme to burn-in for contrastive divergence (set to 50) 554 | \end{itemize} 555 | For initializing the contrastive divergence, $W$ is set to be i.i.d unit Gaussians, $b$ to draws from the $[-1,1]$ uniform, and $g$ to zero. 556 | 557 | \subsection{Alternative methods} \label{sec:altmethods} 558 | We fit the following baseline models: 559 | \begin{itemize} 560 | \item \textbf{Orstein-Uhlenbeck:} Quadratic potential with one parameter $\mu$, $\Psi(x) = (x-\mu)^2$ 561 | \item \textbf{Linear:} Linear potential with one parameter $w$, $\Psi(x) = xw^T$. 562 | \item \textbf{Local:} Sum of Gaussian potentials with three parameters $\mu$, $g$ and $b$, $\Psi(x) = g\exp(-(x-\mu)^2/b^2)$. 563 | \end{itemize} 564 | 565 | 566 | \subsection{High-dim gene expression} 567 | 568 | Applying our RNN model to the top 5 or 10 differentiating genes as measured by the Wasserstein distance between the marginal day 0 and 7 distributions results in qualitatively similar results. In order to fit the higher-complexity multivariate model, we modified a few hyperparameters ($K=2000$, initialization of $b$ as $b_i = ||w_i x_i||_2^2$, increasing NS to 1000, $\sigma=\sqrt{2}$ and using continuous contrastive divergence) and included all (non-heldout) data for pre-training. The parameter changes result in producing a similar goodness-of-fit to the higher dimensional versions of the problem with only a few hundred points. 
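As a sketch of the gene-selection step above, genes can be ranked by the distance between their marginal D0 and D7 distributions; the example below uses scipy's one-dimensional ($W_1$) Wasserstein distance as an illustrative stand-in, and the variable names are hypothetical:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def top_differential_genes(expr_d0, expr_d7, k=5):
    """Rank genes by the 1-D Wasserstein distance between D0 and D7 marginals.

    expr_d0: (cells_d0, genes) and expr_d7: (cells_d7, genes) expression matrices.
    Returns the indices of the k most differential genes.
    """
    n_genes = expr_d0.shape[1]
    dists = np.array([wasserstein_distance(expr_d0[:, g], expr_d7[:, g])
                      for g in range(n_genes)])
    return np.argsort(-dists)[:k]
```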
569 | 
570 | For example, the D4 nonstationary dynamics of Krt8 are recapitulated (Fig. \ref{fig:5dkrt}).
571 | \begin{figure}[h]
572 | \centering
573 | \includegraphics[scale=0.5]{fig/krt8-5dim.png}
574 | \caption{The 5-gene model prediction of Krt8 also reproduces the underlying bimodality of the data.}
575 | \label{fig:5dkrt}
576 | \end{figure}
577 | 
578 | Plotting the predicted marginal distributions for all 5 genes, we find that the RNN-based model substantially outperforms the other, parametric approaches to the same problem:
579 | \begin{figure}[h]
580 | \includegraphics[scale=0.5]{fig/krt-marginals-2.png}
581 | \caption{Predicted marginal distributions of the top 5 differentiating genes at day 4.}
582 | \label{fig:5dmarg}
583 | \end{figure}
584 | 
585 | The same trend holds as we increase the number of genes from 5 to 10, where the RNN again performs best among the alternatives.
586 | \begin{figure}[h]
587 | \includegraphics[scale=0.5]{fig/krt-marginals-10d-2.png}
588 | \caption{Predicted marginal distributions of the top 10 differentiating genes at day 4.}
589 | \label{fig:10dmarg}
590 | \end{figure}
591 | We find that as the dimensionality increases, the learned dynamics become increasingly unimodal, as all models struggle to identify the true dynamics from sparse, high-dimensional data.
592 | 
593 | Even in this setting, with only a few hundred examples in 10 dimensions, we can still effectively identify correlations and other relationships between genes in this non-equilibrium state.
594 | \begin{figure}[h]
595 | \includegraphics[scale=0.5]{fig/krt-cross-10d.png}
596 | \caption{Predicted against actual pairwise gene expression distributions at the D4 timepoint. The RNN models the correlational structure of the true dynamics.}
597 | \label{fig:10dcor}
598 | \end{figure}
599 | 
600 | \end{document}
601 | 
602 | 
603 | % This document was modified from the file originally made available by
604 | % Pat Langley and Andrea Danyluk for ICML-2K. This version was
605 | % created by Lise Getoor and Tobias Scheffer, it was slightly modified
606 | % from the 2010 version by Thorsten Joachims & Johannes Fuernkranz,
607 | % slightly modified from the 2009 version by Kiri Wagstaff and
608 | % Sam Roweis's 2008 version, which is slightly modified from
609 | % Prasad Tadepalli's 2007 version which is a lightly
610 | % changed version of the previous year's version by Andrew Moore,
611 | % which was in turn edited from those of Kristian Kersting and
612 | % Codrina Lauth. Alex Smola contributed to the algorithmic style files.
613 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | Code and paper repository for the ICML paper "Learning Population-Level Diffusions with Generative Recurrent Networks".
3 | 
4 | # Paper
5 | 
6 | The TeX source and PDF of the paper are in the root directory.
7 | 
8 | # Code
9 | 
10 | The codebase consists of four files:
11 | 
12 | 1. **Theano-functions** contains the main inference methods, loss functions, and hyperparameter optimizers.
13 | 2. **Theano-test** runs the simulated-data experiments.
14 | 3. **Single-cell-stuff** runs the single-cell RNA-seq experiments.
15 | 4. **Flow-functions** is a non-Theano version of the inference methods (deprecated; not used to generate results) plus some plotting functions used for the single-cell plots.
16 | 
17 | # Usage
18 | 
19 | To use the code on your own machine, modify Theano-functions to write to your own `.theanorc` (it currently writes to /cluster/thashim).
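For reference, here is a minimal sketch of the kind of change involved (illustrative only; the actual cell in Theano-functions may differ, and the paths below are placeholders):

```python
# Illustrative sketch, not the exact cell from Theano-functions:
# write the .theanorc under your own home directory instead of /cluster/thashim.
# Run this before `import theano`, since Theano reads the config at import time.
import os

theanorc_path = os.path.join(os.path.expanduser("~"), ".theanorc")
with open(theanorc_path, "w") as f:
    f.write("[global]\n")
    f.write("floatX = float32\n")
    f.write("device = gpu\n")  # use "cpu" if no GPU is available
    f.write("base_compiledir = %s\n" % os.path.expanduser("~/.theano"))
```

Setting the same options through the `THEANO_FLAGS` environment variable works as well.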
20 | 21 | Additionally, for the single-cell data, wget the data from [GSE65525](http://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE65525) and put it in a directory named `rnase' one level below. The second code block in single-cell-stuff shows expected output. 22 | -------------------------------------------------------------------------------- /code/Flow-functions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 78, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import random\n", 12 | "import numpy as np\n", 13 | "import scipy as sp\n", 14 | "import math\n", 15 | "from sklearn import metrics\n", 16 | "from sklearn import svm\n", 17 | "from sklearn import manifold\n", 18 | "from sklearn.datasets import *\n", 19 | "from sklearn.neighbors import NearestNeighbors\n", 20 | "import matplotlib as mpl\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import time\n", 23 | "import seaborn as sns\n", 24 | "%matplotlib inline" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Utility / plotting functions" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 897, 37 | "metadata": { 38 | "collapsed": true 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "def plot_flow(x,y,fun,ladj=5):\n", 43 | " u=np.zeros((x.shape[0],y.shape[0]))\n", 44 | " v=np.zeros((x.shape[0],y.shape[0]))\n", 45 | " nrm=np.zeros((x.shape[0],y.shape[0]))\n", 46 | " for i in xrange(x.shape[0]):\n", 47 | " ptv=np.vstack((np.full(y.shape[0],x[i]),y))\n", 48 | " flowtmp=fun(ptv)\n", 49 | " u[:,i]=flowtmp[0,:]\n", 50 | " v[:,i]=flowtmp[1,:]\n", 51 | " nrm[:,i]=np.sqrt(np.sum(flowtmp**2.0,0))\n", 52 | " plt.streamplot(x,y,u,v,density=1.0,linewidth=ladj*nrm/np.max(nrm))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 898, 58 | "metadata": { 59 | "collapsed": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "def euler_maruyama_dist(p, flow, dt, t, sd):\n", 64 | " pp = np.copy(p)\n", 65 | " n = int(t/dt)\n", 66 | " sqrtdt = np.sqrt(dt)\n", 67 | " for i in xrange(n):\n", 68 | " drift = flow(pp)\n", 69 | " pp = pp + drift*dt + np.random.normal(scale=sd,size=p.shape)*sqrtdt\n", 70 | " return pp" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "collapsed": true 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "def euler_maruyama_dist_traj(p, flow, dt, t, sd):\n", 82 | " pp = np.copy(p)\n", 83 | " n = int(t/dt)\n", 84 | " pset = np.zeros((pp.shape[0],pp.shape[1],n))\n", 85 | " sqrtdt = np.sqrt(dt)\n", 86 | " for i in xrange(n):\n", 87 | " drift = flow(pp)\n", 88 | " pp = pp + drift*dt + np.random.normal(scale=sd,size=p.shape)*sqrtdt\n", 89 | " pset[:,:,i]=pp\n", 90 | " return pset" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 899, 96 | "metadata": { 97 | "collapsed": true 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "def plot_w(W_matrix, b_vec, g_vec):\n", 102 | " uvw=W_matrix/np.sum(W_matrix**2,1)[:,np.newaxis]\n", 103 | " offsets = uvw*b_vec[:,np.newaxis]\n", 104 | " plt.quiver(offsets[:,0],offsets[:,1],W_matrix[:,0]*g_vec,W_matrix[:,1]*g_vec)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "### Plotting code for the output" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 894, 117 | "metadata": { 118 | "collapsed": 
true 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "def plot_flow_par(x,y,potfun,W,b,g):\n", 123 | " u=np.zeros((x.shape[0],y.shape[0]))\n", 124 | " v=np.zeros((x.shape[0],y.shape[0]))\n", 125 | " nrm=np.zeros((x.shape[0],y.shape[0]))\n", 126 | " for i in xrange(x.shape[0]):\n", 127 | " ptv=np.vstack((np.full(y.shape[0],x[i]),y))\n", 128 | " flowtmp=drift_fun(potfun,W,b,g,ptv)\n", 129 | " u[:,i]=flowtmp[0,:]\n", 130 | " v[:,i]=flowtmp[1,:]\n", 131 | " nrm[:,i]=np.sqrt(np.sum(flowtmp**2.0,0))\n", 132 | " #plt.quiver(x,y,u,v)\n", 133 | " plt.streamplot(x,y,u,v,density=1.0,linewidth=3*nrm/np.max(nrm))" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 895, 139 | "metadata": { 140 | "collapsed": true 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "def plot_flow_pot(pot,x,y,W,b,g):\n", 145 | " z=np.zeros((x.shape[0],y.shape[0]))\n", 146 | " for i in xrange(x.shape[0]):\n", 147 | " ptv=np.vstack((np.full(y.shape[0],x[i]),y))\n", 148 | " flowtmp= np.sum(pot.f(np.dot(W,ptv)+b[:,np.newaxis])*g[:,np.newaxis],0)\n", 149 | " z[:,i]=flowtmp\n", 150 | " plt.pcolormesh(x,y,np.exp(z))\n", 151 | " CS = plt.contour(x,y,z)\n", 152 | " plt.clabel(CS, inline=1, fontsize=10)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 896, 158 | "metadata": { 159 | "collapsed": true 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "def plot_flow_both(x,y,parin):\n", 164 | " plot_flow_pot(parin.potin,x,y,parin.W_matrix,parin.b_vec,parin.g_vec)\n", 165 | " plot_flow_par(x,y,parin.potin,parin.W_matrix,parin.b_vec,parin.g_vec)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "## Potential function" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "For algorithmic details, see: https://www.sharelatex.com/project/565221a8db798e5822aba651" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 268, 185 | "metadata": { 186 | "collapsed": true 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "def ilogit(x):\n", 191 | " return sp.special.expit(x)\n", 192 | " #return 1/(1+np.exp(-x))" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 618, 198 | "metadata": { 199 | "collapsed": true 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "class logitPotential:\n", 204 | " \"\"\"This function defines a sum-of-logits potential\"\"\"\n", 205 | " def f(self,x):\n", 206 | " return -1*ilogit(x)\n", 207 | " def fp(self,x):\n", 208 | " lx=ilogit(x)\n", 209 | " return -1*lx*(1-lx)\n", 210 | " def fpp(self,x):\n", 211 | " lx=ilogit(x)\n", 212 | " return -1*(lx*(1-lx)**2-lx**2*(1-lx))" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 263, 218 | "metadata": { 219 | "collapsed": true 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "class reluPotential:\n", 224 | " \"\"\"This function defines a potential as log(1+exp(x))\"\"\"\n", 225 | " def f(self,x):\n", 226 | " return -1*np.log(1+np.exp(x))\n", 227 | " def fp(self,x):\n", 228 | " return -1*ilogit(x)\n", 229 | " def fpp(self,x):\n", 230 | " lx = ilogit(x)\n", 231 | " return -1*lx*(1-lx)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 45, 237 | "metadata": { 238 | "collapsed": true 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "class quadraticPotential:\n", 243 | " \"\"\"This function defines a potential as x**2\"\"\"\n", 244 | " def f(self,x):\n", 245 | " return -x**2/2\n", 246 | " def 
fp(self,x):\n", 247 | " return -x\n", 248 | " def fpp(self,x):\n", 249 | " return np.zeros(x.shape)-1" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "# Backprop-related" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "## Simulating a SDE" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": { 270 | "collapsed": true 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "def drift_fun(pot,W,b,g,y):\n", 275 | " scalings = pot.fp(np.dot(W,y)+b[:,np.newaxis])*g[:,np.newaxis] #matrix, K by num_samp\n", 276 | " return np.dot(np.transpose(W),scalings)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 49, 282 | "metadata": { 283 | "collapsed": true 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "def drift_fun_single(pot,W,b,g,y):\n", 288 | " scalings = pot.fp(np.dot(W,y)+b[:,np.newaxis])*g[:,np.newaxis] #matrix, K by num_samp\n", 289 | " drift = np.zeros(y.shape)\n", 290 | " for i in xrange(drift.shape[1]):\n", 291 | " drift[:,i]=np.sum(W*scalings[:,i][:,np.newaxis],0)\n", 292 | " return drift" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": { 299 | "collapsed": true 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "def euler_maruyama_traj(p,num_samp,W_matrix,b_vec,g_vec,dt,time,sd,potfun):\n", 304 | " repflag = p.shape[1] < num_samp\n", 305 | " p_sub=np.random.choice(p.shape[1],size=num_samp,replace=repflag)\n", 306 | " pp = np.copy(p[:,p_sub])\n", 307 | " n = int(time/dt)\n", 308 | " ptraj = np.zeros((p.shape[0],num_samp,n))\n", 309 | " sqrtdt = np.sqrt(dt)\n", 310 | " for i in xrange(n):\n", 311 | " drift = drift_fun(potfun,W_matrix,b_vec,g_vec,pp)\n", 312 | " pp = pp + drift*dt + np.random.normal(scale=sd,size=(p.shape[0],num_samp))*sqrtdt\n", 313 | " ptraj[:,:,i]=pp\n", 314 | " return ptraj" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "## Loss function" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 6, 327 | "metadata": { 328 | "collapsed": true 329 | }, 330 | "outputs": [], 331 | "source": [ 332 | "def get_dist(yt, ytrue):\n", 333 | " ytnorm = np.sum(yt**2,0)\n", 334 | " ytruenorm = np.sum(ytrue**2,0)\n", 335 | " dotprod = np.dot(yt.T,ytrue)\n", 336 | " return np.add.outer(ytnorm,ytruenorm) - 2*dotprod" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 7, 342 | "metadata": { 343 | "collapsed": true 344 | }, 345 | "outputs": [], 346 | "source": [ 347 | "from scipy.optimize import fminbound\n", 348 | "def error_term(yt, ytrue, kern_sig, minv = 1e-4):\n", 349 | " distsq = get_dist(yt,ytrue)\n", 350 | " d=yt.shape[0]\n", 351 | " if kern_sig is None:\n", 352 | " train_size = int(0.2*yt.shape[1])+1\n", 353 | " indices = np.random.permutation(yt.shape[1])\n", 354 | " training_idx, test_idx = indices[:train_size], indices[train_size:]\n", 355 | " training, test = yt[:,training_idx], yt[:,test_idx]\n", 356 | " dist_train = get_dist(training,test)\n", 357 | " spo=fminbound(error_from_dmat, x1=minv, x2=max(np.max(dist_train),4.0*minv)/2.0, args=(dist_train, d), full_output=True)\n", 358 | " kern_sig = spo[0]\n", 359 | " expterm = np.exp(-distsq/(2*kern_sig))/kern_sig**(d/2.0)\n", 360 | " esum = np.sum(expterm,0)\n", 361 | " #print esum.shape\n", 362 | " errweight = expterm/esum\n", 363 | " grad_err = np.zeros(yt.shape)\n", 364 | " for i in 
xrange(errweight.shape[0]):\n", 365 | " grad_err[:,i]=np.sum(-2*(yt[:,i][:,np.newaxis]-ytrue)/kern_sig*errweight[i,],1)\n", 366 | " return grad_err, np.sum(np.log(esum))" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 28, 372 | "metadata": { 373 | "collapsed": true 374 | }, 375 | "outputs": [], 376 | "source": [ 377 | "def error_from_dmat(kern_sig, distsq, d):\n", 378 | " expterm = np.exp(-distsq/(2*kern_sig))/kern_sig**(d/2.0)\n", 379 | " fv = -1*np.sum(np.log(np.sum(expterm,0)))\n", 380 | " #print kern_sig, fv\n", 381 | " return fv" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "metadata": { 388 | "collapsed": true 389 | }, 390 | "outputs": [], 391 | "source": [ 392 | "from scipy.optimize import brent\n", 393 | "def optim_dmat(dmat,d):\n", 394 | " spo=sp.optimize.brent(error_from_dmat, args=(dmat, d))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "## Backprop + Error gradient" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 1, 407 | "metadata": { 408 | "collapsed": true 409 | }, 410 | "outputs": [], 411 | "source": [ 412 | "\"\"\"Given err at time t, yt-dt, produce err at t-dt, and backpropagate\"\"\"\n", 413 | "def backweight(pot, err, W_matrix,b_vec,g,ytp,dt):\n", 414 | " Wydot = np.dot(W_matrix,ytp)\n", 415 | " Wedot = np.dot(W_matrix,err)\n", 416 | " linterm = Wydot+b_vec\n", 417 | " pplin = pot.fp(linterm)\n", 418 | " pdlin = pot.fpp(linterm)\n", 419 | " scalings = g*pdlin*Wedot\n", 420 | " err_new = err + dt*np.sum(W_matrix*scalings[:,np.newaxis],0)\n", 421 | " dw = dt*(g*pplin)[:,np.newaxis]*err + (dt*g*pdlin)[:,np.newaxis]*W_matrix*np.dot(ytp,err)\n", 422 | " db = dt*g*pdlin*Wedot\n", 423 | " dg = dt*pplin*Wedot\n", 424 | " return [dw, db, dg, err_new]" 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "metadata": {}, 430 | "source": [ 431 | "##Old code" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": { 438 | "collapsed": true 439 | }, 440 | "outputs": [], 441 | "source": [ 442 | " \n", 443 | "#these functions deal with single yi/yt\n", 444 | "def weight_deriv(pot,err,W_matrix,b_vec,g,k,yi,dt):\n", 445 | " linterm = np.dot(W_matrix[k,:],yi)+b_vec[k]\n", 446 | " plin = pot.fp(linterm)\n", 447 | " pdlin = pot.fpp(linterm)\n", 448 | " dwk = dt*g[k]*plin*err + dt*g[k]*pdlin*W_matrix[k,:]*np.dot(yi,err)\n", 449 | " dbk = dt*g[k]*pdlin*np.dot(W_matrix[k,:],err)\n", 450 | " dgk = dt*plin*np.dot(W_matrix[k,:],err)\n", 451 | " return [dwk, dbk, dgk]\n", 452 | "\n", 453 | "def backprop_deriv(pot,err,W_matrix, b_vec,g,yt,dt):\n", 454 | " pdlin = pot.fpp(np.dot(W_matrix,yt)+b_vec)\n", 455 | " scalings = g*pdlin*np.dot(W_matrix,err)\n", 456 | " return err + dt*np.sum(W_matrix*scalings[:,np.newaxis],0)\n", 457 | "\n", 458 | "\"\"\"Given err at time t, yt-dt, produce w gradients. 
Takes a single y.\"\"\"\n", 459 | "def weight_deriv_all(pot,err,W_matrix,b_vec,g,yi,dt):\n", 460 | " Wydot = np.dot(W_matrix,yi)\n", 461 | " Wedot = np.dot(W_matrix,err)\n", 462 | " linterm = Wydot+b_vec\n", 463 | " pplin = pot.fp(linterm)\n", 464 | " pdlin = pot.fpp(linterm)\n", 465 | " dw = dt*(g*pplin)[:,np.newaxis]*err + (dt*g*pdlin)[:,np.newaxis]*W_matrix*np.dot(yi,err)\n", 466 | " db = dt*g*pdlin*Wedot\n", 467 | " dg = dt*pplin*Wedot\n", 468 | " return [dw, db, dg]" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "# Gradient descent helpers" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": { 482 | "collapsed": true 483 | }, 484 | "outputs": [], 485 | "source": [ 486 | "def backweight_all(pot,err_top, W_matrix, b_vec, g_vec, traj,dt):\n", 487 | " grad_mat = np.zeros(W_matrix.shape)\n", 488 | " grad_vec = np.zeros(b_vec.shape)\n", 489 | " grad_g = np.zeros(g_vec.shape)\n", 490 | " err_mat = np.zeros(traj.shape)\n", 491 | " for i in xrange(traj.shape[1]):\n", 492 | " err_cur = np.copy(err_top[:,i])\n", 493 | " err_mat[:,i,traj.shape[2]-1]=err_cur\n", 494 | " for t in xrange(traj.shape[2]-1):\n", 495 | " revt = traj.shape[2]-t-2\n", 496 | " dw, db, dg, err_cur = backweight(pot,err_cur, W_matrix, b_vec ,g_vec, traj[:,i,revt],dt)\n", 497 | " grad_mat+=dw\n", 498 | " grad_vec+=db\n", 499 | " grad_g+=dg\n", 500 | " err_mat[:,i,revt]=err_cur\n", 501 | " return grad_mat, grad_vec, grad_g, err_mat" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": {}, 507 | "source": [ 508 | "These classes carry the parameters around" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 1, 514 | "metadata": { 515 | "collapsed": false 516 | }, 517 | "outputs": [ 518 | { 519 | "ename": "NameError", 520 | "evalue": "name 'logitPotential' is not defined", 521 | "output_type": "error", 522 | "traceback": [ 523 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 524 | "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", 525 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdt\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdt\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 14\u001b[1;33m \u001b[1;32mclass\u001b[0m \u001b[0mparset\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 15\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mK\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mD\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpotin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlogitPotential\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscale\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpotin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpotin\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 526 | "\u001b[1;32m\u001b[0m in \u001b[0;36mparset\u001b[1;34m()\u001b[0m\n\u001b[0;32m 13\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[1;32mclass\u001b[0m \u001b[0mparset\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 15\u001b[1;33m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mK\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mD\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpotin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlogitPotential\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscale\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 16\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpotin\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpotin\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mW_matrix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mK\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mD\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mscale\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 527 | "\u001b[1;31mNameError\u001b[0m: name 'logitPotential' is not defined" 528 | ] 529 | } 530 | ], 531 | "source": [ 532 | "import copy\n", 533 | "class observed:\n", 534 | " def __init__(self, p_init, p_out):\n", 535 | " self.p_init=p_init\n", 536 | " self.p_out=p_out\n", 537 | "class hyperpars:\n", 538 | " def __init__(self, NS, eps, sd, sdkern, dt, time):\n", 539 | " self.eps=eps\n", 540 | " self.NS=NS\n", 541 | " self.sd=sd\n", 542 | " self.sdkern=sdkern\n", 543 | " self.dt=dt\n", 544 | " self.time=time\n", 545 | "class parset:\n", 546 | " def __init__(self, K, D, potin=logitPotential(), scale=1, muzero=None):\n", 547 | " if muzero is None:\n", 548 | " muzero = np.zeros(D)\n", 549 | " self.potin=potin\n", 550 | " self.W_matrix=np.random.randn(K,D)*scale\n", 551 | " offset = np.dot(self.W_matrix,muzero)\n", 552 | " self.b_vec=np.random.uniform(low=-1,high=1,size=K) - offset\n", 553 | " self.g_vec=np.zeros(K)\n", 554 | " self.W_sqsum=np.ones(self.W_matrix.shape)\n", 555 | " self.b_sqsum=np.ones(self.b_vec.shape)\n", 556 | " self.g_sqsum=np.ones(self.g_vec.shape)\n", 557 | " self.fvvec=[]\n", 558 | " self.tvec=[]\n", 559 | " self.tnow=0\n", 560 | " \n", 561 | " def gclip(self, grad, gmax=1e5):\n", 562 | " g_new = []\n", 563 | " for i in xrange(len(grad)):\n", 564 | " vnorm = np.sqrt(np.sum(grad[i]**2.0))\n", 565 | " sfactor = max(1, vnorm/gmax)\n", 566 | " g_new.append(np.copy(grad[i]/sfactor))\n", 567 | " return g_new\n", 568 | " \n", 569 | " def update(self, grad, eps_val, fv, tv, ada=1e-3):\n", 570 | " self.W_sqsum = self.W_sqsum + eps_val*ada*grad[0]**2\n", 571 | " self.b_sqsum = self.b_sqsum + eps_val*ada*grad[1]**2\n", 572 | " self.g_sqsum = self.g_sqsum + eps_val*ada*grad[2]**2\n", 573 | " self.W_matrix = self.W_matrix + eps_val*grad[0]/np.sqrt(self.W_sqsum)\n", 574 | " self.b_vec = self.b_vec + eps_val*grad[1]/np.sqrt(self.b_sqsum)\n", 575 | " self.g_vec = self.g_vec + eps_val*grad[2]/np.sqrt(self.g_sqsum)\n", 576 | " self.fvvec.append(fv)\n", 577 | " self.tnow=self.tnow+tv\n", 578 | " self.tvec.append(self.tnow)\n", 579 | " \n", 580 | " def reset_ada(self):\n", 581 | " self.W_sqsum=np.ones(self.W_matrix.shape)\n", 582 | 
" self.b_sqsum=np.ones(self.b_vec.shape)\n", 583 | " self.g_sqsum=np.ones(self.g_vec.shape)\n", 584 | " \n", 585 | " def copy(self):\n", 586 | " parnew = parset(K=self.b_vec.shape[0],D=self.W_matrix.shape[1],potin=self.potin)\n", 587 | " parnew.W_matrix = np.copy(self.W_matrix)\n", 588 | " parnew.b_vec = np.copy(self.b_vec)\n", 589 | " parnew.g_vec = np.copy(self.g_vec)\n", 590 | " parnew.fvvec = copy.copy(self.fvvec)\n", 591 | " parnew.tvec = copy.copy(self.tvec)\n", 592 | " parnew.tnow = self.tnow\n", 593 | " return parnew\n", 594 | " \n", 595 | "class observed_list:\n", 596 | " def __init__(self, p_list, t_list):\n", 597 | " self.p_list=p_list\n", 598 | " self.t_list=t_list" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": null, 604 | "metadata": { 605 | "collapsed": true 606 | }, 607 | "outputs": [], 608 | "source": [ 609 | "from IPython.html.widgets import FloatProgress\n", 610 | "from IPython.display import display\n", 611 | "from time import sleep\n", 612 | "\n", 613 | "def run_grad(datin,parin,hpars,maxit,debug=True):\n", 614 | " if debug:\n", 615 | " f = FloatProgress(min=0, max=maxit)\n", 616 | " display(f)\n", 617 | " for i in xrange(maxit):\n", 618 | " time_start = time.clock()\n", 619 | " W_mat = parin.W_matrix\n", 620 | " b_v = parin.b_vec\n", 621 | " g_v = parin.g_vec\n", 622 | " emtj=euler_maruyama_traj(datin.p_init,hpars.NS,W_mat,b_v,g_v,hpars.dt,hpars.time,hpars.sd,parin.potin)\n", 623 | " err_out, fval=error_term(emtj[:,:,emtj.shape[2]-1],datin.p_out,hpars.sdkern)\n", 624 | " gall = backweight_all(parin.potin,err_out, W_mat, b_v, g_v, emtj, hpars.dt)\n", 625 | " parin.update(gall,hpars.eps,fval,time.clock()-time_start)\n", 626 | " if debug:\n", 627 | " f.value = i\n", 628 | " if debug:\n", 629 | " print(fval)\n", 630 | " plt.figure(1)\n", 631 | " plt.plot(parin.fvvec)\n", 632 | " plt.figure(2)\n", 633 | " ind = emtj.shape[2]-1\n", 634 | " plt.scatter(emtj[:,:,ind][0,:],emtj[:,:,ind][1,:],c='red')\n", 635 | " plt.scatter(datin.p_out[0,:],datin.p_out[1,:])\n", 636 | " plt.quiver(emtj[:,:,ind][0,:],emtj[:,:,ind][1,:],err_out[0,:],err_out[1,:])\n", 637 | " return parin" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": null, 643 | "metadata": { 644 | "collapsed": true 645 | }, 646 | "outputs": [], 647 | "source": [ 648 | "def run_grad_list(datin_list,parin,hpars,maxit,debug=True, delta=False, ada_val=1e-3):\n", 649 | " if debug:\n", 650 | " f = FloatProgress(min=0, max=maxit)\n", 651 | " display(f)\n", 652 | " for i in xrange(maxit):\n", 653 | " W_mat = parin.W_matrix\n", 654 | " b_v = parin.b_vec\n", 655 | " g_v = parin.g_vec\n", 656 | " dW = np.zeros(W_mat.shape)\n", 657 | " db = np.zeros(b_v.shape)\n", 658 | " dg = np.zeros(g_v.shape)\n", 659 | " fv_tmp = 0\n", 660 | " time_start = time.clock()\n", 661 | " for j in xrange(len(datin_list.t_list)-1):\n", 662 | " if not delta:\n", 663 | " t_cur = datin_list.t_list[j+1] - datin_list.t_list[0]\n", 664 | " dat_cur = datin_list.p_list[j+1]\n", 665 | " dat_init = datin_list.p_list[0]\n", 666 | " else:\n", 667 | " t_cur = datin_list.t_list[j+1]-datin_list.t_list[j]\n", 668 | " dat_cur = datin_list.p_list[j+1]\n", 669 | " dat_init = datin_list.p_list[j]\n", 670 | " emtj=euler_maruyama_traj(dat_init,hpars.NS,W_mat,b_v,g_v,hpars.dt,t_cur,hpars.sd,parin.potin)\n", 671 | " err_out, fval=error_term(emtj[:,:,emtj.shape[2]-1],dat_cur,hpars.sdkern)\n", 672 | " gall = backweight_all(parin.potin,err_out, W_mat, b_v, g_v, emtj, hpars.dt)\n", 673 | " dW = dW + gall[0]\n", 674 | " db = db + 
gall[1]\n", 675 | " dg = dg + gall[2]\n", 676 | " fv_tmp = fv_tmp + fval\n", 677 | " parin.update([dW, db, dg],hpars.eps,fv_tmp,time.clock()-time_start, ada_val)\n", 678 | " if debug:\n", 679 | " f.value = i\n", 680 | " if debug:\n", 681 | " print(fval)\n", 682 | " plt.figure(1)\n", 683 | " plt.plot(parin.fvvec)\n", 684 | " for j in xrange(len(datin_list.t_list)-1):\n", 685 | " plt.figure(j+2)\n", 686 | " t_cur = datin_list.t_list[j+1]\n", 687 | " dat_cur = datin_list.p_list[j+1]\n", 688 | " dat_init = datin_list.p_list[0]\n", 689 | " emtj=euler_maruyama_traj(dat_init,hpars.NS,W_mat,b_v,g_v,hpars.dt,t_cur,hpars.sd,parin.potin)\n", 690 | " ind = emtj.shape[2]-1\n", 691 | " plt.scatter(emtj[:,:,ind][0,:],emtj[:,:,ind][1,:],c='red')\n", 692 | " plt.scatter(dat_cur[0],dat_cur[1])\n", 693 | " #plt.quiver(emtj[:,:,ind][0,:],emtj[:,:,ind][1,:],err_out[0,:],err_out[1,:])\n", 694 | " return parin" 695 | ] 696 | }, 697 | { 698 | "cell_type": "markdown", 699 | "metadata": {}, 700 | "source": [ 701 | "## Old code" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": null, 707 | "metadata": { 708 | "collapsed": true 709 | }, 710 | "outputs": [], 711 | "source": [ 712 | "def backprop_all(pot,err_top, W_matrix, b_vec, g_vec, traj,dt):\n", 713 | " err_mat = np.zeros(traj.shape)\n", 714 | " for i in xrange(traj.shape[1]):\n", 715 | " err_cur = np.copy(err_top[:,i])\n", 716 | " err_mat[:,i,traj.shape[2]-1]=err_cur\n", 717 | " for t in xrange(traj.shape[2]-1):\n", 718 | " revt = traj.shape[2]-t-2\n", 719 | " err_cur = backprop_deriv(pot,err_cur, W_matrix, b_vec ,g_vec, traj[:,i,revt],dt)\n", 720 | " err_mat[:,i,revt]=err_cur\n", 721 | " return err_mat" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": null, 727 | "metadata": { 728 | "collapsed": true 729 | }, 730 | "outputs": [], 731 | "source": [ 732 | "def grad_all(pot,err_all, W_matrix, b_vec, g_vec, traj,dt):\n", 733 | " grad_mat = np.zeros(W_matrix.shape)\n", 734 | " grad_vec = np.zeros(b_vec.shape)\n", 735 | " grad_g = np.zeros(g_vec.shape)\n", 736 | " for i in xrange(traj.shape[1]):\n", 737 | " for t in xrange(traj.shape[2]-1):\n", 738 | " dw, db, dg = webight_deriv_all(pot,err_all[:,i,t+1],W_matrix,b_vec,g_vec,traj[:,i,t],dt)\n", 739 | " grad_mat=grad_mat+dw\n", 740 | " grad_vec=grad_vec+db\n", 741 | " grad_g = grad_g+dg\n", 742 | " return grad_mat, grad_vec, grad_g" 743 | ] 744 | }, 745 | { 746 | "cell_type": "markdown", 747 | "metadata": {}, 748 | "source": [ 749 | "## Initialization at equilibrium" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 841, 755 | "metadata": { 756 | "collapsed": true 757 | }, 758 | "outputs": [], 759 | "source": [ 760 | "def logP(pot, W_matrix, b_vec, g_vec, x):\n", 761 | " \"\"\"x is a matrix of (dim, n_pts), return a vector of length n_pts of logp for each point.\"\"\"\n", 762 | " return np.sum(pot.f(np.dot(W_matrix,x)+b_vec[:,np.newaxis])*g_vec[:,np.newaxis],0)" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": 842, 768 | "metadata": { 769 | "collapsed": false 770 | }, 771 | "outputs": [], 772 | "source": [ 773 | "def MALA_chain(pot, W_matrix, b_vec, g_vec, state, k, dt, sd, burnin=0):\n", 774 | " ptraj = np.zeros((state.shape[0],state.shape[1],k))\n", 775 | " sqrtdt = np.sqrt(dt)\n", 776 | " acc_sum = 0\n", 777 | " for i in xrange(k+burnin):\n", 778 | " drift = drift_fun(pot, W_matrix, b_vec, g_vec, state)\n", 779 | " state_new = state + drift*dt + np.random.normal(scale=sd,size=state.shape)*sqrtdt\n", 780 | " 
if i >= burnin:\n", 781 | " drift_new = drift_fun(pot, W_matrix, b_vec, g_vec, state_new)\n", 782 | " lpdiff = logP(pot, W_matrix, b_vec, g_vec, state_new) - logP(pot, W_matrix, b_vec, g_vec, state)\n", 783 | " lq1 = -1.0/(2*dt*sd**2) * np.sum(((state_new-state) - drift*dt)**2,0)\n", 784 | " lq2 = -1.0/(2*dt*sd**2) * np.sum(((state-state_new) - drift_new*dt)**2,0)\n", 785 | " lqdiff = lq1-lq2\n", 786 | " accpr = np.exp(lpdiff - lqdiff)\n", 787 | " accept_ind = np.random.uniform(size=accpr.shape[0]) < accpr\n", 788 | " acc_sum = acc_sum+np.sum(accept_ind)\n", 789 | " state_new[:,np.nonzero(1-accept_ind)[0]] = state[:,np.nonzero(1-accept_ind)[0]]\n", 790 | " ptraj[:,:,i-burnin]=state_new\n", 791 | " state = state_new\n", 792 | " #print acc_sum/float(state.shape[1]*k)\n", 793 | " return ptraj" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": 843, 799 | "metadata": { 800 | "collapsed": true 801 | }, 802 | "outputs": [], 803 | "source": [ 804 | "def MALA_tester():\n", 805 | " W=np.eye(2)\n", 806 | " b_vec=np.zeros(2)\n", 807 | " g_vec=np.ones(2)\n", 808 | " state=np.zeros((2,3))+10\n", 809 | " return MALA_chain(quadraticPotential(),W,b_vec,g_vec,state,1000, 0.1, 1, burnin=100)" 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": 844, 815 | "metadata": { 816 | "collapsed": false, 817 | "scrolled": false 818 | }, 819 | "outputs": [], 820 | "source": [ 821 | "#mt=MALA_tester()\n", 822 | "#plt.hist(mt[0,0,:], bins=50,alpha=0.5,normed=True)\n", 823 | "#plt.hist(np.random.normal(size=mt.shape[2],scale=1.0),bins=50,alpha=0.5,normed=True)" 824 | ] 825 | }, 826 | { 827 | "cell_type": "code", 828 | "execution_count": null, 829 | "metadata": { 830 | "collapsed": true 831 | }, 832 | "outputs": [], 833 | "source": [ 834 | "def logPGrad(pot, W_matrix, b_vec, g_vec, x, factr=1.0):\n", 835 | " \"\"\"Derive the logP gradient for a vector of points x of size (dim, n_samples)\"\"\"\n", 836 | " Wdot = np.dot(W_matrix,x)\n", 837 | " linterm = Wdot+b_vec[:,np.newaxis] #linterm - size of K (num hidden units) by n_samples\n", 838 | " pplin = pot.fp(linterm) #size K by n_samples.\n", 839 | " d_base = pplin*g_vec[:,np.newaxis]\n", 840 | " dW = np.dot(d_base, np.transpose(x))*factr\n", 841 | " dg = np.sum(pot.f(linterm),1)*factr\n", 842 | " db = np.sum(d_base,1)*factr\n", 843 | " return dW, dg, db" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": null, 849 | "metadata": { 850 | "collapsed": true 851 | }, 852 | "outputs": [], 853 | "source": [ 854 | "def logP_cdopt(parin, samples, niter, stepsize, dt=0.01, burnin=10, ns=500, ctk=True):\n", 855 | " p_ind = np.random.randint(0,np.shape(samples)[1],ns)\n", 856 | " p_mat = samples[:,p_ind]\n", 857 | " n_dat = np.shape(samples)[1]\n", 858 | " factr = n_dat/float(ns)\n", 859 | " for i in xrange(niter):\n", 860 | " t_start = time.clock()\n", 861 | " dW, dg, db = logPGrad(parin.potin, parin.W_matrix, parin.b_vec, parin.g_vec, samples, factr=1.0)\n", 862 | " neg_samp = MALA_chain(parin.potin, parin.W_matrix, parin.b_vec, parin.g_vec, p_mat, 1, dt, np.sqrt(2), burnin=burnin)\n", 863 | " if ctk:\n", 864 | " p_mat = neg_samp[:,:,0]\n", 865 | " dW_neg, dg_neg ,db_neg = logPGrad(parin.potin, parin.W_matrix, parin.b_vec, parin.g_vec, neg_samp[:,:,0], factr=factr)\n", 866 | " parin.update([dW-dW_neg, db-db_neg, dg-dg_neg],stepsize/n_dat,0,time.clock()-t_start) \n", 867 | " return parin, neg_samp" 868 | ] 869 | }, 870 | { 871 | "cell_type": "code", 872 | "execution_count": 992, 873 | "metadata": { 874 | 
"collapsed": true 875 | }, 876 | "outputs": [], 877 | "source": [ 878 | "#parin=parset(potin=reluPotential(),K=100,D=2,scale=0.2)\n", 879 | "#samples=np.random.normal(size=(2,100))+5.0\n", 880 | "#W_opt, b_opt, g_opt, p_opt = logP_cdopt(parin,samples,niter=100,stepsize=0.1,burnin=20,ns=100)\n", 881 | "#x_test = np.linspace(-5,15,num=50)\n", 882 | "#y_test = np.linspace(-5,15,num=50)\n", 883 | "#plot_flow_both(x_test,y_test,parin)" 884 | ] 885 | }, 886 | { 887 | "cell_type": "markdown", 888 | "metadata": {}, 889 | "source": [ 890 | "## Evaluation code" 891 | ] 892 | }, 893 | { 894 | "cell_type": "code", 895 | "execution_count": null, 896 | "metadata": { 897 | "collapsed": true 898 | }, 899 | "outputs": [], 900 | "source": [ 901 | "def find_close_t(t,t_list):\n", 902 | " return np.nonzero(np.array(t_list)), HostFromGpu(GpuElemwise{exp,no_inplace}.0)]\n", 97 | "Looping 1000 times took 0.777099 seconds\n", 98 | "Result is [ 1.23178029 1.61879349 1.52278066 ..., 2.20771813 2.29967761\n", 99 | " 1.62323296]\n", 100 | "Used the gpu\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "from theano import function, config, shared, sandbox\n", 106 | "import theano.tensor as T\n", 107 | "import numpy\n", 108 | "import time\n", 109 | "\n", 110 | "vlen = 10 * 30 * 768 # 10 x #cores x # threads per core\n", 111 | "iters = 1000\n", 112 | "\n", 113 | "rng = numpy.random.RandomState(22)\n", 114 | "x = shared(numpy.asarray(rng.rand(vlen), config.floatX))\n", 115 | "f = function([], T.exp(x))\n", 116 | "print(f.maker.fgraph.toposort())\n", 117 | "t0 = time.time()\n", 118 | "for i in xrange(iters):\n", 119 | " r = f()\n", 120 | "t1 = time.time()\n", 121 | "print(\"Looping %d times took %f seconds\" % (iters, t1 - t0))\n", 122 | "print(\"Result is %s\" % (r,))\n", 123 | "if numpy.any([isinstance(x.op, T.Elemwise) for x in f.maker.fgraph.toposort()]):\n", 124 | " print('Used the cpu')\n", 125 | "else:\n", 126 | " print('Used the gpu')" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "## Pure Theano functions" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 6, 139 | "metadata": { 140 | "collapsed": true 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "import numpy\n", 145 | "import theano\n", 146 | "import theano.tensor as T\n", 147 | "rng = numpy.random" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### Helper fns" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 67, 160 | "metadata": { 161 | "collapsed": true 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "def Tilogit(x):\n", 166 | " return 1/(1+T.exp(-x))\n", 167 | "\n", 168 | "def T_relu_dprime(x):\n", 169 | " return -1*Tilogit(x)*(1-Tilogit(x))\n", 170 | "\n", 171 | "def T_relu_prime(x):\n", 172 | " return -1*Tilogit(x)\n", 173 | "\n", 174 | "def T_relu(x):\n", 175 | " return -1*T.log(1+T.exp(x))" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 78, 181 | "metadata": { 182 | "collapsed": true 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "def Tdrift_relu_prime(x_p, w, b, g):\n", 187 | " linterm = g*T_relu_dprime(T.dot(w,x_p)+b)\n", 188 | " return T.dot(w.T,linterm)\n", 189 | "\n", 190 | "def Tsimul_relu_prime(z, x_p, w, b, g, dt):\n", 191 | " return x_p + Tdrift_relu_prime(x_p,w,b,g)*dt + T.sqrt(dt)*z\n", 192 | "\n", 193 | "def Tpot_relu_prime(x_p, w, b, g):\n", 194 | " linterm = g*T_relu_prime(T.dot(w,x_p)+b)\n", 195 | " return T.sum(linterm, 
0)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 69, 201 | "metadata": { 202 | "collapsed": true 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "def Tdrift_relu(x_p, w, b, g):\n", 207 | " linterm = g*T_relu_prime(T.dot(w,x_p)+b)\n", 208 | " return T.dot(w.T,linterm)\n", 209 | "\n", 210 | "def Tsimul_relu(z, x_p, w, b, g, dt):\n", 211 | " return x_p + Tdrift_relu(x_p,w,b,g)*dt + T.sqrt(dt)*z\n", 212 | "\n", 213 | "def Tpot_relu(x_p, w, b, g):\n", 214 | " linterm = g*T_relu(T.dot(w,x_p)+b)\n", 215 | " return T.sum(linterm, 0)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 84, 221 | "metadata": { 222 | "collapsed": true 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "def Tpot_quad(x_p, w, b, g):\n", 227 | " sqdist = (x_p ** 2).sum(0).reshape((1, x_p.shape[1])) + (w ** 2).sum(1).reshape((w.shape[0], 1)) - 2 * T.dot(w,x_p)\n", 228 | " kernest = T.exp(-sqdist / b)*g\n", 229 | " return T.sum(kernest, 0)\n", 230 | "\n", 231 | "def Tdrift_quad(x_p, w, b, g):\n", 232 | " ksum = T.sum(Tpot_quad(x_p, w, b, g))\n", 233 | " driftterm = theano.gradient.grad(ksum, x_p)\n", 234 | " return driftterm\n", 235 | "\n", 236 | "def Tsimul_quad(z, x_p, w, b, g, dt):\n", 237 | " return x_p + Tdrift_quad(x_p,w,b,g)*dt + T.sqrt(dt)*z\n", 238 | " " 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": { 245 | "collapsed": true 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "def Tpot_lin(x_p, w, b, g):\n", 250 | " linterm = g*T.dot(w, x_p)\n", 251 | " return T.sum(linterm, 0)\n", 252 | "\n", 253 | "def Tdrift_lin(x_p, w, b, g):\n", 254 | " ksum = T.sum(Tpot_lin(x_p,w,b,g))\n", 255 | " driftterm = theano.gradient.grad(ksum, x_p)\n", 256 | " return driftterm\n", 257 | "\n", 258 | "def Tsimul_lin(z, x_p, w, b, g, dt):\n", 259 | " return x_p + Tdrift_lin(x_p,w,b,g)*dt + T.sqrt(dt)*z" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": { 266 | "collapsed": true 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "def Tpot_ou(x_p, w, b, g):\n", 271 | " sqdist = (x_p ** 2).sum(0).reshape((1, x_p.shape[1])) + (w ** 2).sum(1).reshape((w.shape[0], 1)) - 2 * T.dot(w,x_p)\n", 272 | " return T.sum(-sqdist * g, 0)\n", 273 | "\n", 274 | "def Tdrift_ou(x_p, w, b, g):\n", 275 | " ksum = T.sum(Tpot_ou(x_p, w, b, g))\n", 276 | " driftterm = theano.gradient.grad(ksum, x_p)\n", 277 | " return driftterm\n", 278 | "\n", 279 | "def Tsimul_ou(z, x_p, w, b, g, dt):\n", 280 | " return x_p + Tdrift_ou(x_p,w,b,g)*dt + T.sqrt(dt)*z" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": { 286 | "collapsed": false 287 | }, 288 | "source": [ 289 | "import numpy as np\n", 290 | "x_pt = np.array([[0,0],[1,1]]).T\n", 291 | "wt = np.array([[-1,-1],[2,2],[3,3]])\n", 292 | "gi = np.array([1,1,1])[:,np.newaxis]*10\n", 293 | "bi = np.array([1,1,1])[:,np.newaxis]\n", 294 | "zi = np.random.randn(2,2)*0\n", 295 | "\n", 296 | "quad_pot_val = Tpot_quad(xi, w, b, g)\n", 297 | "quad_pot_test = theano.function(inputs=[xi, w, b, g],outputs=quad_pot_val, allow_input_downcast=True)\n", 298 | "quad_pot_test(x_pt, wt, bi, gi)\n", 299 | "\n", 300 | "quad_drift_val = Tdrift_quad(xi, w, b, g)\n", 301 | "quad_drift_test = theano.function(inputs=[xi, w, b, g],outputs=quad_drift_val, allow_input_downcast=True)\n", 302 | "quad_drift_test(x_pt, wt, bi, gi)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 86, 308 | "metadata": { 309 | "collapsed": true 
310 | }, 311 | "outputs": [], 312 | "source": [ 313 | "#theano variables\n", 314 | "n_steps = T.iscalar('n_steps')\n", 315 | "dt = T.fscalar('dt')\n", 316 | "xi = T.matrix(\"xi\")\n", 317 | "z = T.tensor3(\"z\")\n", 318 | "zmat = T.matrix(\"zmat\")\n", 319 | "w = T.matrix(\"w\")\n", 320 | "b = T.TensorType(dtype='float32',broadcastable=(False,True))('b')\n", 321 | "g = T.TensorType(dtype='float32',broadcastable=(False,True))('g')\n", 322 | "err = T.matrix(\"err\")\n", 323 | "\n", 324 | "def theano_meta_factory(sim_fn, drift_fn, pot_fn, name):\n", 325 | " return {'potential':pot_factory(pot_fn),\n", 326 | " 'drift':drift_factory(drift_fn),\n", 327 | " 'trajectory':em_traj_factory(sim_fn),\n", 328 | " 'simulate':em_final_factory(sim_fn),\n", 329 | " 'backprop':em_lop_factory(sim_fn),\n", 330 | " 'potential_grad':em_pot_factory(pot_fn),\n", 331 | " 'name':name}\n", 332 | "\n", 333 | "def drift_factory(drift_fn):\n", 334 | " drift_val = drift_fn(xi, w, b, g)\n", 335 | " return theano.function(inputs=[xi, w, b, g], outputs=drift_val, allow_input_downcast=True,on_unused_input='ignore')\n", 336 | "\n", 337 | "def pot_factory(pot_fn):\n", 338 | " pot_val = pot_fn(xi, w, b, g)\n", 339 | " return theano.function(inputs=[xi, w, b, g], outputs=pot_val, allow_input_downcast=True,on_unused_input='ignore')\n", 340 | "\n", 341 | "def em_traj_factory(sim_fn):\n", 342 | " result, updates = theano.scan(fn = sim_fn, sequences = z, outputs_info = xi, non_sequences = [w, b, g, dt], n_steps = n_steps)\n", 343 | " em_traj_fun = theano.function(inputs = [z, xi, w, b, g, dt, n_steps], outputs= result, updates=updates, allow_input_downcast=True,on_unused_input='ignore')\n", 344 | " return em_traj_fun\n", 345 | "\n", 346 | "def em_final_factory(sim_fn):\n", 347 | " result, updates = theano.scan(fn = sim_fn, sequences = z, outputs_info = xi, non_sequences = [w, b, g, dt], n_steps = n_steps)\n", 348 | " res_final = result[-1]\n", 349 | " em_final_fun = theano.function(inputs = [z, xi, w, b, g, dt, n_steps], outputs=res_final, updates=updates, allow_input_downcast=True,on_unused_input='ignore')\n", 350 | " return em_final_fun\n", 351 | "\n", 352 | "def em_lop_factory(sim_fn):\n", 353 | " result, updates = theano.scan(fn = sim_fn, sequences = z, outputs_info = xi, non_sequences = [w, b, g, dt], n_steps = n_steps)\n", 354 | " gradval = theano.gradient.Lop(T.flatten(result[-1]), [w, b, g], T.flatten(err), disconnected_inputs='warn')\n", 355 | " gradfun = theano.function(inputs = [err, z, xi, w, b, g, dt, n_steps], outputs=gradval, updates=updates, allow_input_downcast=True,on_unused_input='ignore')\n", 356 | " return gradfun\n", 357 | "\n", 358 | "def em_pot_factory(pot_fn):\n", 359 | " pot_val = T.sum(pot_fn(xi, w, b, g))\n", 360 | " gradval = theano.gradient.grad(pot_val, [w, b, g], disconnected_inputs='warn')\n", 361 | " potfun = theano.function(inputs = [xi, w, b, g], outputs = pot_val, allow_input_downcast=True,on_unused_input='ignore')\n", 362 | " potgrad = theano.function(inputs = [xi, w, b, g], outputs = gradval, allow_input_downcast=True,on_unused_input='ignore')\n", 363 | " return potgrad, potfun" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 87, 369 | "metadata": { 370 | "collapsed": false 371 | }, 372 | "outputs": [], 373 | "source": [ 374 | "relu_pack = theano_meta_factory(Tsimul_relu,Tdrift_relu,Tpot_relu, 'ramp potential')\n", 375 | "local_pack = theano_meta_factory(Tsimul_quad,Tdrift_quad,Tpot_quad, 'local potential')\n", 376 | "logit_pack = 
theano_meta_factory(Tsimul_relu_prime,Tdrift_relu_prime,Tpot_relu_prime, 'logit potential')\n", 377 | "ou_pack = theano_meta_factory(Tsimul_ou, Tdrift_ou, Tpot_ou, 'Orstein-Uhlenbeck potential')\n", 378 | "lin_pack = theano_meta_factory(Tsimul_lin, Tdrift_lin, Tpot_lin, 'Linear potential')" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "## Structs" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "metadata": { 392 | "collapsed": true 393 | }, 394 | "outputs": [], 395 | "source": [ 396 | "import copy\n", 397 | "class observed:\n", 398 | " def __init__(self, p_init, p_out):\n", 399 | " self.p_init=p_init\n", 400 | " self.p_out=p_out\n", 401 | "class hyperpars:\n", 402 | " def __init__(self, NS, eps, sd, sdkern, dt, time):\n", 403 | " self.eps=eps\n", 404 | " self.NS=NS\n", 405 | " self.sd=sd\n", 406 | " self.sdkern=sdkern\n", 407 | " self.dt=dt\n", 408 | " self.time=time\n", 409 | "class parset:\n", 410 | " def __init__(self, K, D, potin=relu_pack, scale=1, muzero=None):\n", 411 | " if muzero is None:\n", 412 | " muzero = np.zeros(D)\n", 413 | " self.potin=potin\n", 414 | " self.W_matrix=np.random.randn(K,D)*scale\n", 415 | " if 'local' not in potin['name']:\n", 416 | " offset = np.dot(self.W_matrix,muzero)\n", 417 | " self.b_vec=np.random.uniform(low=-1,high=1,size=K) - offset\n", 418 | " else:\n", 419 | " self.b_vec=np.ones(K)*5.0\n", 420 | " self.g_vec=np.zeros(K)\n", 421 | " self.W_sqsum=np.ones(self.W_matrix.shape)\n", 422 | " self.b_sqsum=np.ones(self.b_vec.shape)\n", 423 | " self.g_sqsum=np.ones(self.g_vec.shape)\n", 424 | " self.fvvec=[]\n", 425 | " self.tvec=[]\n", 426 | " self.tnow=0\n", 427 | " \n", 428 | " def gclip(self, grad, gmax=1e5):\n", 429 | " g_new = []\n", 430 | " for i in xrange(len(grad)):\n", 431 | " vnorm = np.sqrt(np.sum(grad[i]**2.0))\n", 432 | " sfactor = max(1, vnorm/gmax)\n", 433 | " g_new.append(np.copy(grad[i]/sfactor))\n", 434 | " return g_new\n", 435 | " \n", 436 | " def update(self, grad, eps_val, fv, tv, ada=1e-3):\n", 437 | " self.W_sqsum = self.W_sqsum + eps_val*ada*grad[0]**2\n", 438 | " self.b_sqsum = self.b_sqsum + eps_val*ada*grad[1]**2\n", 439 | " self.g_sqsum = self.g_sqsum + eps_val*ada*grad[2]**2\n", 440 | " self.W_matrix = self.W_matrix + eps_val*grad[0]/np.sqrt(self.W_sqsum)\n", 441 | " self.b_vec = self.b_vec + eps_val*grad[1]/np.sqrt(self.b_sqsum)\n", 442 | " self.g_vec = self.g_vec + eps_val*grad[2]/np.sqrt(self.g_sqsum)\n", 443 | " self.fvvec.append(fv)\n", 444 | " self.tnow=self.tnow+tv\n", 445 | " self.tvec.append(self.tnow)\n", 446 | " \n", 447 | " def reset_ada(self):\n", 448 | " self.W_sqsum=np.ones(self.W_matrix.shape)\n", 449 | " self.b_sqsum=np.ones(self.b_vec.shape)\n", 450 | " self.g_sqsum=np.ones(self.g_vec.shape)\n", 451 | " \n", 452 | " def copy(self):\n", 453 | " parnew = parset(K=self.b_vec.shape[0],D=self.W_matrix.shape[1],potin=self.potin)\n", 454 | " parnew.W_matrix = np.copy(self.W_matrix)\n", 455 | " parnew.b_vec = np.copy(self.b_vec)\n", 456 | " parnew.g_vec = np.copy(self.g_vec)\n", 457 | " parnew.fvvec = copy.copy(self.fvvec)\n", 458 | " parnew.tvec = copy.copy(self.tvec)\n", 459 | " parnew.tnow = self.tnow\n", 460 | " return parnew\n", 461 | " \n", 462 | " def plot(self, xpair, ypair):\n", 463 | " xseq = np.linspace(xpair[0],xpair[1],num=50)\n", 464 | " yseq = np.linspace(ypair[0],ypair[1],num=50)\n", 465 | " plot_flow_pot(self.potin,xseq,yseq,self.W_matrix,self.b_vec,self.g_vec)\n", 466 | " 
plot_flow_par(xseq,yseq,self.potin,self.W_matrix,self.b_vec,self.g_vec)\n", 467 | " \n", 468 | " def simulate(self, init, ns, time, dt, sd):\n", 469 | " W_mat = self.W_matrix\n", 470 | " b_v = self.b_vec[:,np.newaxis]\n", 471 | " g_v = self.g_vec[:,np.newaxis]\n", 472 | " num_steps = int(time/dt)\n", 473 | " pp = p_samp(init,ns)\n", 474 | " z = rng.randn(num_steps,pp.shape[0],pp.shape[1])*sd\n", 475 | " return self.potin['simulate'](z, pp, W_mat, b_v, g_v, dt, num_steps)\n", 476 | " \n", 477 | "class observed_list:\n", 478 | " def __init__(self, p_list, t_list):\n", 479 | " self.p_list=p_list\n", 480 | " self.t_list=t_list" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "metadata": { 487 | "collapsed": true 488 | }, 489 | "outputs": [], 490 | "source": [ 491 | "def plot_flow_par(x,y,potfun,W,b,g):\n", 492 | " u=np.zeros((x.shape[0],y.shape[0]))\n", 493 | " v=np.zeros((x.shape[0],y.shape[0]))\n", 494 | " nrm=np.zeros((x.shape[0],y.shape[0]))\n", 495 | " for i in xrange(x.shape[0]):\n", 496 | " ptv=np.vstack((np.full(y.shape[0],x[i]),y))\n", 497 | " flowtmp=potfun['drift'](ptv,W,b[:,np.newaxis],g[:,np.newaxis])\n", 498 | " u[:,i]=flowtmp[0,:]\n", 499 | " v[:,i]=flowtmp[1,:]\n", 500 | " nrm[:,i]=np.sqrt(np.sum(flowtmp**2.0,0))\n", 501 | " #plt.quiver(x,y,u,v)\n", 502 | " plt.streamplot(x,y,u,v,density=1.0,linewidth=3*nrm/np.max(nrm))" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": null, 508 | "metadata": { 509 | "collapsed": true 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "def plot_flow_pot(pot,x,y,W,b,g):\n", 514 | " z=np.zeros((x.shape[0],y.shape[0]))\n", 515 | " for i in xrange(x.shape[0]):\n", 516 | " ptv=np.vstack((np.full(y.shape[0],x[i]),y))\n", 517 | " flowtmp= pot['potential'](ptv,W,b[:,np.newaxis],g[:,np.newaxis])\n", 518 | " z[:,i]=flowtmp\n", 519 | " plt.pcolormesh(x,y,np.exp(z))\n", 520 | " CS = plt.contour(x,y,z)\n", 521 | " plt.clabel(CS, inline=1, fontsize=10)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "metadata": {}, 527 | "source": [ 528 | "# Theano helpers" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": 9, 534 | "metadata": { 535 | "collapsed": true 536 | }, 537 | "outputs": [], 538 | "source": [ 539 | "def p_samp(p_in, num_samp):\n", 540 | " repflag = p_in.shape[1] < num_samp\n", 541 | " p_sub=np.random.choice(p_in.shape[1],size=num_samp,replace=repflag)\n", 542 | " return np.copy(p_in[:,p_sub])" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 10, 548 | "metadata": { 549 | "collapsed": false 550 | }, 551 | "outputs": [ 552 | { 553 | "ename": "NameError", 554 | "evalue": "name 'np' is not defined", 555 | "output_type": "error", 556 | "traceback": [ 557 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 558 | "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", 559 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mdef\u001b[0m \u001b[0mget_grad_logp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mparin\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msamples\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpp\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mburnin\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtheano_pack\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0msd\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mfactr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mW_mat\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mparin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mW_matrix\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mb_v\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mparin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mb_vec\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnewaxis\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mg_v\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mparin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mg_vec\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnewaxis\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 560 | "\u001b[1;31mNameError\u001b[0m: name 'np' is not defined" 561 | ] 562 | } 563 | ], 564 | "source": [ 565 | "def get_grad_logp(parin, samples, pp, burnin, theano_pack, dt, sd=np.sqrt(2)):\n", 566 | " factr = np.shape(samples)[1]/float(np.shape(pp)[1])\n", 567 | " W_mat = parin.W_matrix\n", 568 | " b_v = parin.b_vec[:,np.newaxis]\n", 569 | " g_v = parin.g_vec[:,np.newaxis]\n", 570 | " num_steps = burnin\n", 571 | " z = rng.randn(num_steps,pp.shape[0],pp.shape[1])*sd\n", 572 | " #run chain forward, get result\n", 573 | " result_final = theano_pack['simulate'](z, pp, W_mat, b_v, g_v, dt, num_steps)\n", 574 | " #logp with respect to input samples\n", 575 | " grad_pos = theano_pack['potential_grad'][0](samples,W_mat, b_v, g_v)\n", 576 | " pos_fv = theano_pack['potential_grad'][1](samples,W_mat,b_v,g_v)\n", 577 | " #logp with respect to contrastive divergence smaples\n", 578 | " grad_neg = theano_pack['potential_grad'][0](result_final,W_mat, b_v, g_v)\n", 579 | " neg_fv = theano_pack['potential_grad'][1](result_final,W_mat, b_v, g_v)\n", 580 | " fv_tot = pos_fv - factr*neg_fv\n", 581 | " dW = grad_pos[0]-grad_neg[0]*factr\n", 582 | " db = np.squeeze(grad_pos[1]-grad_neg[1]*factr)\n", 583 | " dg = np.squeeze(grad_pos[2]-grad_neg[2]*factr)\n", 584 | " return [[dW, db, dg],-1*fv_tot, result_final]" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": { 591 | "collapsed": true 592 | }, 593 | "outputs": [], 594 | "source": [ 595 | "def run_logp_theano(parin,samples, niter, stepsize,theano_pack,dt=0.01, burnin=10, ns=500, ctk=True, ada_val=0,sd=np.sqrt(2)):\n", 596 | " for i in xrange(niter):\n", 597 | " pp = p_samp(samples,ns)\n", 598 | " t_start = time.clock()\n", 599 | " gradin, fv_tot, result_final = get_grad_logp(parin, samples, pp, burnin, theano_pack, dt, sd=sd)\n", 600 | " if ctk:\n", 601 | " pp = 
result_final\n", 602 | " parin.update(gradin,stepsize/np.shape(samples)[1],fv_tot,time.clock()-t_start, ada_val)\n", 603 | " return parin, result_final" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": null, 609 | "metadata": { 610 | "collapsed": true 611 | }, 612 | "outputs": [], 613 | "source": [ 614 | "def get_grad_marginal(parin,pp,p_target,theano_pack,time,dt,sd,sdkern,lossfun):\n", 615 | " W_mat = parin.W_matrix\n", 616 | " b_v = parin.b_vec[:,np.newaxis]\n", 617 | " g_v = parin.g_vec[:,np.newaxis]\n", 618 | " num_steps = int(time / float(dt))\n", 619 | " z = rng.randn(num_steps,pp.shape[0],pp.shape[1])*sd\n", 620 | " result_final = theano_pack['simulate'](z, pp, W_mat, b_v, g_v, dt, num_steps)\n", 621 | " err_out, fval=lossfun(result_final,p_target,sdkern)\n", 622 | " gall = theano_pack['backprop'](err_out, z, pp, W_mat, b_v, g_v, dt, num_steps)\n", 623 | " gall[1]=np.squeeze(gall[1])\n", 624 | " gall[2]=np.squeeze(gall[2])\n", 625 | " return gall, fval, result_final, err_out" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": null, 631 | "metadata": { 632 | "collapsed": true 633 | }, 634 | "outputs": [], 635 | "source": [ 636 | "def run_grad_theano(datin,parin,hpars,maxit,theano_pack,lossfun,tau=0,burnin=10,ctk=True,debug=True,ada_val=0):\n", 637 | " if debug:\n", 638 | " f = FloatProgress(min=0, max=maxit)\n", 639 | " display(f)\n", 640 | " num_samp = hpars.NS\n", 641 | " for i in xrange(maxit):\n", 642 | " pp = p_samp(datin.p_init, num_samp)\n", 643 | " pneg = pp\n", 644 | " time_start = time.clock()\n", 645 | " gall, fval, result_final, err_out = get_grad_marginal(parin, pp, datin.p_out, theano_pack, hpars.time, hpars.dt, hpars.sd, hpars.sdkern,lossfun)\n", 646 | " if tau is not 0: #entropic regularization below.\n", 647 | " gall_logp, fv_logp, result_logp = get_grad_logp(parin, datin.p_out, pneg, burnin, theano_pack, hpars.dt, hpars.sd)\n", 648 | " if ctk:\n", 649 | " pneg = result_logp\n", 650 | " for j in xrange(3):\n", 651 | " gall[j] = gall[j]*(1-tau) + gall_logp[j]*tau\n", 652 | " parin.update(gall,hpars.eps,fval,time.clock()-time_start,ada_val/num_samp)\n", 653 | " if np.isneginf(fval):\n", 654 | " break\n", 655 | " if debug:\n", 656 | " f.value = i\n", 657 | " if debug:\n", 658 | " print(fval)\n", 659 | " plt.figure(1)\n", 660 | " plt.plot(parin.fvvec)\n", 661 | " plt.figure(2)\n", 662 | " plt.scatter(result_final[0,:],result_final[1,:],c='red')\n", 663 | " plt.scatter(datin.p_out[0,:],datin.p_out[1,:])\n", 664 | " plt.quiver(result_final[0,:],result_final[1,:],err_out[0,:],err_out[1,:])\n", 665 | " return parin" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": null, 671 | "metadata": { 672 | "collapsed": true 673 | }, 674 | "outputs": [], 675 | "source": [ 676 | "def run_grad_theano_list(datin_list,parin,hpars,maxit,theano_pack,lossfun, tau=0, burnin=10,ctk=True,delta=False,debug=True,ada_val=0):\n", 677 | " if debug:\n", 678 | " f = FloatProgress(min=0, max=maxit)\n", 679 | " display(f)\n", 680 | " num_samp = hpars.NS\n", 681 | " dlast = datin_list.p_list[len(datin_list.t_list)-1]\n", 682 | " for i in xrange(maxit):\n", 683 | " pneg = p_samp(dlast, num_samp)\n", 684 | " db = np.zeros(parin.b_vec.shape)\n", 685 | " dg = np.zeros(parin.b_vec.shape)\n", 686 | " dW = np.zeros(parin.W_matrix.shape)\n", 687 | " fv_tmp = 0\n", 688 | " time_start = time.clock()\n", 689 | " for j in xrange(len(datin_list.t_list)-1):\n", 690 | " if not delta:\n", 691 | " t_cur = datin_list.t_list[j+1] - 
datin_list.t_list[0]\n", 692 | " dat_cur = datin_list.p_list[j+1]\n", 693 | " dat_init = datin_list.p_list[0]\n", 694 | " else:\n", 695 | " t_cur = datin_list.t_list[j+1]-datin_list.t_list[j]\n", 696 | " dat_cur = datin_list.p_list[j+1]\n", 697 | " dat_init = datin_list.p_list[j]\n", 698 | " pp = p_samp(dat_init, num_samp)\n", 699 | " gall, fval, result_final, err_out = get_grad_marginal(parin, pp, dat_cur, theano_pack, t_cur, hpars.dt, hpars.sd, hpars.sdkern,lossfun)\n", 700 | " dW = dW + gall[0]\n", 701 | " db = db + gall[1]\n", 702 | " dg = dg + gall[2]\n", 703 | " fv_tmp = fv_tmp + fval\n", 704 | " gnew = [dW, db, dg]\n", 705 | " if tau != 0: #entropic regularization below.\n", 706 | " gall_logp, fv_logp, result_logp = get_grad_logp(parin, dlast, pneg, burnin, theano_pack, hpars.dt, hpars.sd)\n", 707 | " if ctk:\n", 708 | " pneg = result_logp\n", 709 | " for j in xrange(3):\n", 710 | " gnew[j] = gnew[j]*(1-tau) + gall_logp[j]*tau\n", 711 | " if np.isneginf(fv_tmp):\n", 712 | " break\n", 713 | " parin.update(gnew,hpars.eps,fv_tmp,time.clock()-time_start,ada_val/num_samp)\n", 714 | " if debug:\n", 715 | " f.value = i\n", 716 | " if debug:\n", 717 | " print(fval)\n", 718 | " plt.figure(1)\n", 719 | " plt.plot(parin.fvvec)\n", 720 | " for j in xrange(len(datin_list.t_list)-1):\n", 721 | " plt.figure(j+2)\n", 722 | " t_cur = datin_list.t_list[j+1]-datin_list.t_list[j]\n", 723 | " dat_cur = datin_list.p_list[j+1]\n", 724 | " dat_init = datin_list.p_list[j]\n", 725 | " W_mat = parin.W_matrix\n", 726 | " b_v = parin.b_vec[:,np.newaxis]\n", 727 | " g_v = parin.g_vec[:,np.newaxis]\n", 728 | " num_steps = int(t_cur / float(hpars.dt))\n", 729 | " z = rng.randn(num_steps,dat_init.shape[0],dat_init.shape[1])*hpars.sd\n", 730 | " result_final = theano_pack['simulate'](z, dat_init, W_mat, b_v, g_v, hpars.dt, num_steps)\n", 731 | " plt.scatter(result_final[0],result_final[1],c='red')\n", 732 | " plt.scatter(dat_cur[0],dat_cur[1])\n", 733 | " return parin" 734 | ] 735 | }, 736 | { 737 | "cell_type": "markdown", 738 | "metadata": {}, 739 | "source": [ 740 | "# Wasserstein loss stuff" 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": null, 746 | "metadata": { 747 | "collapsed": true 748 | }, 749 | "outputs": [], 750 | "source": [ 751 | "def checkmat(mat):\n", 752 | " is_finite = np.all(np.isfinite(mat))\n", 753 | " is_nontrivial = np.ptp(mat)>1e-5\n", 754 | " return is_finite and is_nontrivial" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": null, 760 | "metadata": { 761 | "collapsed": true 762 | }, 763 | "outputs": [], 764 | "source": [ 765 | "from sklearn import utils\n", 766 | "import hungarian\n", 767 | "\n", 768 | "def wasserstein_error(p_pred, p_true, sdkern):\n", 769 | " ptrue_resamp = p_samp(p_true, p_pred.shape[1])\n", 770 | " distsq = get_dist(p_pred,ptrue_resamp)\n", 771 | " #matching = utils.linear_assignment_._hungarian(distsq)\n", 772 | " distsq[np.isposinf(distsq)]=1e5\n", 773 | " if checkmat(distsq):\n", 774 | " matching = hungarian.lap(distsq)\n", 775 | " else:\n", 776 | " matching = [np.arange(p_pred.shape[1]), np.arange(p_pred.shape[1])]\n", 777 | " #m1=matching[0]\n", 778 | " m1=np.arange(len(matching[0]))\n", 779 | " #m2=matching[1]\n", 780 | " m2=matching[0]\n", 781 | " spts = p_pred[:,m1]\n", 782 | " dlts = ptrue_resamp[:,m2]-spts\n", 783 | " errs = np.sum(dlts**2.0,0)\n", 784 | " return dlts, -1*np.sum(errs)" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": null, 790 | "metadata": { 791 | 
"collapsed": true 792 | }, 793 | "outputs": [], 794 | "source": [ 795 | "def get_dist(yt, ytrue):\n", 796 | " ytnorm = np.sum(yt**2,0)\n", 797 | " ytruenorm = np.sum(ytrue**2,0)\n", 798 | " dotprod = np.dot(yt.T,ytrue)\n", 799 | " return np.add.outer(ytnorm,ytruenorm) - 2*dotprod" 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "execution_count": null, 805 | "metadata": { 806 | "collapsed": true 807 | }, 808 | "outputs": [], 809 | "source": [ 810 | "def sinkhorn(M, lamb, r, c, maxit=100):\n", 811 | " #Mp = np.array(M,dtype=np.float128)\n", 812 | " Mp=M\n", 813 | " K = np.exp(-lamb*(Mp))#-np.min(Mp)))\n", 814 | " rp = np.copy(r)\n", 815 | " cp = np.copy(c)\n", 816 | " for i in xrange(maxit):\n", 817 | " cp = 1.0/np.dot(rp,K)\n", 818 | " rp = 1.0/np.dot(K,cp)\n", 819 | " kn = rp[:,np.newaxis]*K*cp\n", 820 | " return cp, rp, kn#np.dot(np.dot(np.diag(rp),K),np.diag(cp))" 821 | ] 822 | }, 823 | { 824 | "cell_type": "code", 825 | "execution_count": 12, 826 | "metadata": { 827 | "collapsed": true 828 | }, 829 | "outputs": [], 830 | "source": [ 831 | "def sinkhorn_error(p_pred, p_true, sdkern, rep=0, numit=10):\n", 832 | " if sdkern is None:\n", 833 | " sdkern = 10.0\n", 834 | " ptrue_resamp = p_samp(p_true, p_pred.shape[1])\n", 835 | " distsq = get_dist(p_pred,ptrue_resamp)\n", 836 | " sko = sinkhorn(distsq, sdkern, np.ones(distsq.shape[0]),np.ones(distsq.shape[1]),numit)[2]\n", 837 | " sko = sko / sko.sum(axis=1,keepdims=True)\n", 838 | " if np.all(np.isfinite(sko)):\n", 839 | " targ = np.dot(ptrue_resamp,np.transpose(sko))\n", 840 | " dlts = targ - p_pred\n", 841 | " return dlts, -1*np.sum(dlts**2.0)\n", 842 | " else:\n", 843 | " if rep < 10:\n", 844 | " return sinkhorn_error(p_pred, p_true, sdkern/2.0, rep=rep+1)\n", 845 | " else:\n", 846 | " return p_pred, -float('Inf')" 847 | ] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "execution_count": null, 852 | "metadata": { 853 | "collapsed": true 854 | }, 855 | "outputs": [], 856 | "source": [ 857 | "def sinkhorn_hiprec(M, lamb, r, c, maxit=100):\n", 858 | " Mp = np.array(M,dtype=np.float128)\n", 859 | " K = np.exp(-lamb*(Mp-np.min(Mp)))\n", 860 | " rp = np.copy(r)\n", 861 | " cp = np.copy(c)\n", 862 | " for i in xrange(maxit):\n", 863 | " cp = 1.0/np.dot(rp,K)\n", 864 | " rp = 1.0/np.dot(K,cp)\n", 865 | " kn = rp[:,np.newaxis]*K*cp\n", 866 | " return cp, rp, kn#np.dot(np.dot(np.diag(rp),K),np.diag(cp))" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": null, 872 | "metadata": { 873 | "collapsed": true 874 | }, 875 | "outputs": [], 876 | "source": [ 877 | "def sinkhorn_error_hiprec(p_pred, p_true, sdkern, rep=0, numit=10):\n", 878 | " if sdkern is None:\n", 879 | " sdkern = 100.0\n", 880 | " ptrue_resamp = p_samp(p_true, p_pred.shape[1])\n", 881 | " distsq = get_dist(p_pred,ptrue_resamp)\n", 882 | " sko = sinkhorn_hiprec(distsq, sdkern, np.ones(distsq.shape[0]),np.ones(distsq.shape[1]),numit)[2]\n", 883 | " sko = sko / sko.sum(axis=1,keepdims=True)\n", 884 | " if np.all(np.isfinite(sko)):\n", 885 | " targ = np.dot(ptrue_resamp,np.transpose(sko))\n", 886 | " dlts = targ - p_pred\n", 887 | " return dlts, -1*np.sum(dlts**2.0)\n", 888 | " else:\n", 889 | " if rep < 10:\n", 890 | " return sinkhorn_error(p_pred, p_true, sdkern/2.0, rep=rep+1)\n", 891 | " else:\n", 892 | " return p_pred, -float('Inf')" 893 | ] 894 | }, 895 | { 896 | "cell_type": "code", 897 | "execution_count": null, 898 | "metadata": { 899 | "collapsed": true 900 | }, 901 | "outputs": [], 902 | "source": [] 903 | }, 904 | { 905 | "cell_type": 
"markdown", 906 | "metadata": {}, 907 | "source": [ 908 | "# Autorun script" 909 | ] 910 | }, 911 | { 912 | "cell_type": "code", 913 | "execution_count": null, 914 | "metadata": { 915 | "collapsed": true 916 | }, 917 | "outputs": [], 918 | "source": [ 919 | "def rescale_par(par,snew):\n", 920 | " parnew = par.copy()\n", 921 | " parnew.g_vec = np.copy(par.g_vec) * snew**2.0 / 2.0\n", 922 | " return parnew" 923 | ] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "execution_count": null, 928 | "metadata": { 929 | "collapsed": true 930 | }, 931 | "outputs": [], 932 | "source": [ 933 | "def pack_sim(parin, pp, t, dt, sd, theano_pack):\n", 934 | " num_steps = int(t / float(dt))\n", 935 | " z = rng.randn(num_steps,pp.shape[0],pp.shape[1])*sd\n", 936 | " return theano_pack['simulate'](z, pp, parin.W_matrix, parin.b_vec, parin.g_vec, dt, num_steps)" 937 | ] 938 | }, 939 | { 940 | "cell_type": "code", 941 | "execution_count": null, 942 | "metadata": { 943 | "collapsed": false 944 | }, 945 | "outputs": [], 946 | "source": [ 947 | "def run_all(data_in, time_in,theano_pack,tau=0,sdin = 1.0, Knum=100, dtin=0.01,burnin=100,lossfun=sinkhorn_error, n1=5, n2=10, eps_base=0.01, scale_base=1, debug=True):\n", 948 | " np.random.seed(0)\n", 949 | " data_last = data_in[-1]\n", 950 | " time_last = time_in[-1]\n", 951 | " NS = data_last.shape[1]\n", 952 | " best_err = -1e8\n", 953 | " best_par = None\n", 954 | " bct = 50\n", 955 | " powr = 2.0\n", 956 | " for j in xrange(n1):\n", 957 | " init_par=parset(potin=theano_pack,K=Knum,D=data_last.shape[0],scale=1.0)\n", 958 | " init_par, p_mat = run_logp_theano(init_par,data_last,400,eps_base/float(powr**j),theano_pack,dt=(time_last)/bct, burnin=bct, ns=NS, ctk=False, ada_val=0.0)\n", 959 | " grad, errval = lossfun(p_mat, data_last, None)\n", 960 | " if debug:\n", 961 | " print errval\n", 962 | " if errval > best_err:\n", 963 | " best_par = init_par.copy()\n", 964 | " best_err = errval\n", 965 | " fvbase = -1e8\n", 966 | " best_out = None\n", 967 | " best_eps = None\n", 968 | " ada_2 = 1/100.0\n", 969 | " for j in xrange(n2):\n", 970 | " #h_par= rescale_par(best_par, sdin)\n", 971 | " h_par = best_par.copy()\n", 972 | " epsin = eps_base/(10.0*float(powr**j))*scale_base\n", 973 | " #print epsin\n", 974 | " h_hyp=hyperpars(NS=NS,eps=epsin,sd=sdin,sdkern=None,dt=dtin,time=time_in[1]-time_in[0])\n", 975 | " if len(time_in) is 2:\n", 976 | " h_dat=observed(data_in[0], data_in[1])\n", 977 | " parout = run_grad_theano(h_dat,h_par,h_hyp,100,theano_pack,tau=tau,burnin=burnin,lossfun=lossfun,ada_val=ada_2, debug=False)\n", 978 | " else:\n", 979 | " hl_dat=observed_list(data_in,time_in)\n", 980 | " parout = run_grad_theano_list(hl_dat,h_par,h_hyp,100,theano_pack,tau=tau,burnin=burnin,lossfun=lossfun,ada_val=ada_2, debug=False, delta=True)\n", 981 | " if debug:\n", 982 | " print (parout.fvvec[-1], epsin)\n", 983 | " #pred_output = pack_sim(parout, data_in[0], data_last, )\n", 984 | " if parout.fvvec[-1] > fvbase:\n", 985 | " best_eps = epsin\n", 986 | " best_out = parout.copy()\n", 987 | " fvbase = parout.fvvec[-1]\n", 988 | " #best_out= rescale_par(best_par, sdin)\n", 989 | " best_out2 = best_par.copy()\n", 990 | " epsin = best_eps\n", 991 | " h_hyp=hyperpars(NS=NS,eps=epsin,sd=sdin,sdkern=None,dt=dtin,time=time_in[1]-time_in[0])\n", 992 | " if len(time_in) is 2:\n", 993 | " h_dat=observed(data_in[0], data_in[1])\n", 994 | " best_out2 = run_grad_theano(h_dat,best_out2,h_hyp,500,theano_pack,tau=tau,burnin=burnin,lossfun=lossfun,ada_val=ada_2, debug=debug)\n", 995 | " else:\n", 996 
| " hl_dat=observed_list(data_in,time_in)\n", 997 | " best_out2 = run_grad_theano_list(hl_dat,best_out2,h_hyp,500,theano_pack,tau=tau,burnin=burnin,lossfun=lossfun,ada_val=ada_2, debug=debug, delta=True)\n", 998 | " if not np.isfinite(best_out2.fvvec[-1]):\n", 999 | " best_out2 = best_out\n", 1000 | " return best_out2, best_par" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "markdown", 1005 | "metadata": {}, 1006 | "source": [ 1007 | "# Plot and simulation related" 1008 | ] 1009 | }, 1010 | { 1011 | "cell_type": "code", 1012 | "execution_count": null, 1013 | "metadata": { 1014 | "collapsed": true 1015 | }, 1016 | "outputs": [], 1017 | "source": [ 1018 | "def euler_maruyama_dist(p, flow, dt, t, sd):\n", 1019 | " pp = np.copy(p)\n", 1020 | " n = int(t/dt)\n", 1021 | " sqrtdt = np.sqrt(dt)\n", 1022 | " for i in xrange(n):\n", 1023 | " drift = flow(pp)\n", 1024 | " pp = pp + drift*dt + np.random.normal(scale=sd,size=p.shape)*sqrtdt\n", 1025 | " return pp" 1026 | ] 1027 | }, 1028 | { 1029 | "cell_type": "code", 1030 | "execution_count": null, 1031 | "metadata": { 1032 | "collapsed": true 1033 | }, 1034 | "outputs": [], 1035 | "source": [ 1036 | "def plot_flow(x,y,fun,ladj=5):\n", 1037 | " u=np.zeros((x.shape[0],y.shape[0]))\n", 1038 | " v=np.zeros((x.shape[0],y.shape[0]))\n", 1039 | " nrm=np.zeros((x.shape[0],y.shape[0]))\n", 1040 | " for i in xrange(x.shape[0]):\n", 1041 | " ptv=np.vstack((np.full(y.shape[0],x[i]),y))\n", 1042 | " flowtmp=fun(ptv)\n", 1043 | " u[:,i]=flowtmp[0,:]\n", 1044 | " v[:,i]=flowtmp[1,:]\n", 1045 | " nrm[:,i]=np.sqrt(np.sum(flowtmp**2.0,0))\n", 1046 | " plt.streamplot(x,y,u,v,density=1.0,linewidth=ladj*nrm/np.max(nrm))" 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "code", 1051 | "execution_count": null, 1052 | "metadata": { 1053 | "collapsed": true 1054 | }, 1055 | "outputs": [], 1056 | "source": [ 1057 | "from scipy.optimize import fminbound\n", 1058 | "def error_term(yt, ytrue, kern_sig, minv = 1e-4):\n", 1059 | " distsq = get_dist(yt,ytrue)\n", 1060 | " d=yt.shape[0]\n", 1061 | " if kern_sig is None:\n", 1062 | " train_size = int(0.2*yt.shape[1])+1\n", 1063 | " indices = np.random.permutation(yt.shape[1])\n", 1064 | " training_idx, test_idx = indices[:train_size], indices[train_size:]\n", 1065 | " training, test = yt[:,training_idx], yt[:,test_idx]\n", 1066 | " dist_train = get_dist(training,test)\n", 1067 | " spo=fminbound(error_from_dmat, x1=minv, x2=max(np.max(dist_train),4.0*minv)/2.0, args=(dist_train, d), full_output=True)\n", 1068 | " kern_sig = spo[0]\n", 1069 | " expterm = np.exp(-distsq/(2*kern_sig))/kern_sig**(d/2.0)\n", 1070 | " esum = np.sum(expterm,0)\n", 1071 | " #print esum.shape\n", 1072 | " errweight = expterm/esum\n", 1073 | " grad_err = np.zeros(yt.shape)\n", 1074 | " for i in xrange(errweight.shape[0]):\n", 1075 | " grad_err[:,i]=np.sum(-2*(yt[:,i][:,np.newaxis]-ytrue)/kern_sig*errweight[i,],1)\n", 1076 | " return grad_err, np.sum(np.log(esum))" 1077 | ] 1078 | }, 1079 | { 1080 | "cell_type": "code", 1081 | "execution_count": null, 1082 | "metadata": { 1083 | "collapsed": true 1084 | }, 1085 | "outputs": [], 1086 | "source": [ 1087 | "def error_from_dmat(kern_sig, distsq, d):\n", 1088 | " expterm = np.exp(-distsq/(2*kern_sig))/kern_sig**(d/2.0)\n", 1089 | " fv = -1*np.sum(np.log(np.sum(expterm,0)))\n", 1090 | " #print kern_sig, fv\n", 1091 | " return fv" 1092 | ] 1093 | } 1094 | ], 1095 | "metadata": { 1096 | "kernelspec": { 1097 | "display_name": "Python 2", 1098 | "language": "python", 1099 | "name": "python2" 1100 | }, 1101 | 
"language_info": { 1102 | "codemirror_mode": { 1103 | "name": "ipython", 1104 | "version": 2 1105 | }, 1106 | "file_extension": ".py", 1107 | "mimetype": "text/x-python", 1108 | "name": "python", 1109 | "nbconvert_exporter": "python", 1110 | "pygments_lexer": "ipython2", 1111 | "version": "2.7.11" 1112 | } 1113 | }, 1114 | "nbformat": 4, 1115 | "nbformat_minor": 0 1116 | } 1117 | --------------------------------------------------------------------------------