Jekyll2019-10-21T23:08:16+00:00http://jkimmel.net/feed.xmlJacob C. Kimmelpersonal websiteState transitions in aged stem cells2019-08-24T00:00:00+00:002019-08-24T00:00:00+00:00http://jkimmel.net/aging_musc_dynamics<p><em>This post is adapted from a series of posts on Twitter, so please excuse the short form nature of some descriptions.</em></p>
<h2 id="muscle-stem-cell-activation-is-impaired-with-age">Muscle stem cell activation is impaired with age</h2>
<p>Old muscle stem cells (MuSCs) are bad at regeneration, partly because they don’t activate properly.
Does aging change the set of cell states in activation, or the transition rates between them?
In my final days as a graduate student, I explored this question with my excellent mentors <a href="http://cellgeometry.ucsf.edu">Wallace Marshall</a> & <a href="http://bracklab.com">Andrew Brack</a>.</p>
<p><a href="https://www.biorxiv.org/content/10.1101/739185v1">Check out our manuscript on this topic over on bioRxiv.</a></p>
<p>If aging changes a cellular response like stem cell activation, it might happen through two mechanisms.
Aging might change the set of cell states a cell transitions through (different paths), or it might change the rate of transitions (different speeds).
In biology, reality is often a weighted mixture of two models, so both of these mechanisms may be at play.
How can we determine the relative contribution of each model?</p>
<h2 id="measuring-the-trajectory-of-stem-cell-activation-in-aged-cells">Measuring the trajectory of stem cell activation in aged cells</h2>
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/model.jpg" alt="Cartoon schematic showing aged and young cells moving through an abstract two dimensional space in either different directions, or at different speeds." /></p>
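<p>We can estimate the path of activation by measuring many individual cells at a single timepoint: because cells activate asynchronously, a single snapshot captures cells at many different points along the path.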
To estimate these paths, we measured transcriptomes of aged and young muscle stem cells with scRNA-seq during activation.
As an added twist, I isolated cells from mice harboring <em>H2B-GFP<sup>+/-</sup>; rtTA<sup>+/-</sup></em> alleles that <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3605795/">allow us to label muscle stem cells with different proliferative histories.</a>
In previous work, Andrew’s lab has shown that MuSCs which divide rarely during development (label retaining cells, LRCs) are more regenerative than those that divide a lot (non-label retaining cells, nonLRCs).
By sorting these cells into separate tubes by FACS, we were able to associate each transcriptome in the single cell RNA-seq assay with both a cell age and proliferative history.</p>
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/scrnaseq_schematic.jpg" alt="Schematic of our single cell RNA sequencing experimental design. We took muscle stem cells from young and aged mice using FACS and performed sequencing at two time points: 1. immediately after isolation and 2. after 18 hours of culture." /></p>
<p>From the scRNA-seq data, we can fit a “pseudotime” trajectory to estimate the path of activation.
This trajectory inference method was pioneered by Cole Trapnell, now at the University of Washington.
We find this trajectory recapitulates some known myogenic biology, and also has some surprises.
The one that stood out most to me was the non-monotonic behavior of <em>Pax7</em>.
It has long been assumed that <em>Pax7</em> marks the most quiescent, least activated stem cells, so it was a bit shocking to see its expression rise again as cells activated.</p>
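<p>Pseudotime methods like Monocle are considerably more sophisticated, but the core idea of ordering cells along a path through expression space can be sketched in a few lines. The sketch below is a hypothetical stand-in, not the method used in the manuscript: it assumes a roughly linear trajectory and orders cells by their projection onto the first principal component. The function name <code>pc1_pseudotime</code> and its <code>root</code> argument (a cell assumed to sit at the start of the trajectory) are illustrative.</p>

```python
import numpy as np


def pc1_pseudotime(expression, root=0):
    """Order cells along the first principal component, rescaled to [0, 1].

    A simplified stand-in for trajectory inference; assumes the trajectory is
    roughly linear in expression space.
    expression: (n_cells, n_genes) array, e.g. log-normalized counts.
    root: index of a cell assumed to lie at the start of the trajectory.
    """
    centered = expression - expression.mean(axis=0)
    # The first right singular vector of the centered matrix is the first principal axis
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    scores = centered @ vt[0]
    # The sign of a principal component is arbitrary; orient it so the root cell comes first
    if scores[root] > np.median(scores):
        scores = -scores
    return (scores - scores.min()) / (scores.max() - scores.min())


# Toy data: five hypothetical cells moving along a line in a 3-gene space
cells = np.array([[0.0, 0.0, 1.0], [1.0, 0.5, 1.0], [2.0, 1.0, 1.0],
                  [3.0, 1.5, 1.0], [4.0, 2.0, 1.0]])
pt = pc1_pseudotime(cells, root=0)  # evenly spaced values from 0 to 1
```

<p>A real activation trajectory need not be linear, which is why dedicated tools fit curves or graphs through expression space instead of a single axis.</p>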
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/trajectory_fit.jpg" alt="A pseudotime trajectory fit to our single cell RNA sequencing data." /></p>
<p>On to the aging stuff we came for: young and aged cells are pretty evenly mixed along the same trajectory.
Most aging changes are pretty subtle by differential expression, further suggesting that the change in activation trajectories with age is modest.
This seems to suggest the “path” of activation is retained in aged cells.</p>
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/trajectory_age.jpg" alt="A pseudotime trajectory fit to our single cell RNA sequencing data." /></p>
<h2 id="measuring-state-transitions-in-single-cells">Measuring state transitions in single cells</h2>
<p>So if the paths are pretty similar across ages, are the rates different?
How can we even measure state transition rates in single cells?
In previous work with Wallace Marshall, <a href="http://jkimmel.net/heteromotility">I developed a tool to infer cell states from cell behavior captured by timelapse microscopy.</a>
We found we could measure state transitions during early myogenic activation in that first paper.
Measuring cell state transition rates in aged vs. young cells was an obvious next step.</p>
<p>So, to see if aging changes the rate of stem cell activation, we did just that.
After featurizing behavior and clustering, we find that aged and young MuSCs lie along an activation trajectory, as in the first paper.</p>
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/behavior.jpg" alt="Experimental schematic of our cell behavior experiment. Young and aged MuSCs were imaged by timelapse microscopy for 48 hours and Heteromotility was used to featurize behaviors." /></p>
<p>Aged and young cells again share an activation trajectory, but young cells are enriched in more activated clusters.
State transition rates are higher in young cells as well.
This suggests aging alters state transition rates.</p>
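<p>One simple way to quantify transition rates from tracked cells is to count how often each cell moves between behavior-state clusters in consecutive time windows. The sketch below assumes each tracked cell has already been assigned a sequence of integer cluster labels; the function names and the toy sequences are hypothetical and are not taken from the manuscript.</p>

```python
import numpy as np


def transition_matrix(state_sequences, n_states):
    """Estimate per-step transition probabilities from observed state sequences.

    state_sequences: list of integer sequences, one per tracked cell, giving the
    behavior-state cluster of that cell in consecutive time windows.
    Returns an (n_states, n_states) matrix P where P[i, j] is the fraction of
    observed steps starting in state i that ended in state j.
    """
    counts = np.zeros((n_states, n_states))
    for seq in state_sequences:
        for a, b in zip(seq[:-1], seq[1:]):
            counts[a, b] += 1
    row_sums = counts.sum(axis=1, keepdims=True)
    # Rows for states that were never left stay all-zero instead of dividing by zero
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)


def exit_rate(P):
    """Per-step probability of leaving each state."""
    return 1.0 - np.diag(P)


# Hypothetical toy sequences over three states ordered from quiescent (0) to activated (2)
young = [[0, 1, 2, 2], [0, 0, 1, 2]]
old = [[0, 0, 0, 1], [0, 0, 1, 1]]
P_young = transition_matrix(young, n_states=3)
P_old = transition_matrix(old, n_states=3)
```

<p>In this toy example, young cells leave the quiescent state 0 more often per step than old cells, which is the kind of difference a "slower transition rate" describes.</p>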
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/behavior_age.jpg" alt="Cell behaviors contrasted by age." /></p>
<p>scRNA-seq and cell behavior seem to be telling a similar story.
But are the states of activation they reveal the same?
To find out, we used cell behavior to investigate the non-monotonic change in Pax7 with activation that we found by single cell RNA-seq.</p>
<p>We set up a cell behavior experiment, and immediately stained cells for Pax7/MyoG after the experiment was done.
This allows us to map Pax7 levels to cell behavior states.
We find the same non-monotonic change in Pax7 that we found by scRNA-seq!</p>
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/behavior_stains.jpg" alt="Cell behaviors paired to immunohistochemistry." /></p>
<p>So, scRNA-seq & behavior suggest the activation trajectory is retained with age, but behavior indicates aged state transitions are slower.
Can we estimate transition rates from scRNA-seq too?
<a href="https://t.co/18jOk2X7DE?amp=1">In brilliant work from 2018,</a> La Manno <em>et. al.</em> showed that we can estimate state transitions from intronic reads in RNA-seq.
We find this inference method recapitulates the activation trajectory we found with pseudotime really well.</p>
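<p>The intuition behind RNA velocity can be shown with a simplified version of the steady-state model: for each gene, fit a ratio gamma relating unspliced to spliced counts, then call the residual the velocity. This is a hedged sketch; La Manno <em>et al.</em> fit gamma more carefully (for example, using cells at the extremes of expression), and packages such as velocyto and scVelo add normalization, neighbor pooling, and dynamical models. The function name is hypothetical.</p>

```python
import numpy as np


def steady_state_velocity(unspliced, spliced):
    """Simplified steady-state RNA velocity.

    For each gene, fit gamma by least squares through the origin (u ~ gamma * s)
    and define velocity v = u - gamma * s. Positive v means more unspliced RNA
    than the steady state predicts, i.e. the gene is being induced.
    unspliced, spliced: (n_cells, n_genes) arrays of normalized counts.
    Returns (velocity, gamma).
    """
    numerator = (unspliced * spliced).sum(axis=0)
    denominator = np.maximum((spliced ** 2).sum(axis=0), 1e-12)  # guard against all-zero genes
    gamma = numerator / denominator
    return unspliced - gamma * spliced, gamma
```

<p>Projecting these per-gene velocities into a low-dimensional embedding, as in the figure below, gives the arrows that summarize where each cell appears to be heading.</p>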
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/velocity.png" alt="RNA velocity vectors projected atop a PCA projection of our single cell RNA sequencing data." /></p>
<p>There were no obvious qualitative differences in the RNA velocity field between ages.
To make quantitative comparisons though, I turned to the classic dynamical systems technique of phase simulations.
A phase simulation places an imaginary point in a vector field and updates the position of the point over time based on the vectors in the neighborhood of the point.</p>
<p>This can reveal properties of the vector field that are hard to deduce qualitatively.
Imagine floating a leaf on top of a river to figure out how fast the water is flowing.
Here, I start phase points in the young/aged velocity fields, and update positions over time based on velocity of neighboring cells.</p>
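<p>A phase simulation of this kind can be sketched in a few lines. This is a minimal illustration rather than the code behind the manuscript: it assumes cell positions and velocity vectors live in the same low-dimensional space (e.g. PCA), and at each step moves the phase point along the mean velocity of its <code>k</code> nearest observed cells. The function name and the <code>k</code> and <code>step_size</code> defaults are hypothetical.</p>

```python
import numpy as np


def phase_simulation(positions, velocities, start, n_steps=100, k=10, step_size=0.1):
    """Advect a phase point through a velocity field defined on observed cells.

    positions, velocities: (n_cells, n_dims) arrays, e.g. cell coordinates and
    their velocity vectors in PCA space.
    start: (n_dims,) starting coordinate of the phase point.
    Returns an (n_steps + 1, n_dims) array with the point's path.
    """
    point = np.asarray(start, dtype=float)
    path = [point.copy()]
    for _ in range(n_steps):
        # Find the k observed cells closest to the current position
        dists = np.linalg.norm(positions - point, axis=1)
        neighbors = np.argsort(dists)[:k]
        # Move along the average local velocity, like a leaf carried by the current
        point = point + step_size * velocities[neighbors].mean(axis=0)
        path.append(point.copy())
    return np.array(path)
```

<p>Running many simulations from different starting points in the young and aged velocity fields, and comparing how far the points travel, is the quantitative comparison the visual inspection alone could not provide.</p>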
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/phase_sim.png" alt="RNA velocity phase simulations." /></p>
<p><a href="https://twitter.com/i/status/1163534186885464064">Watch an animation of this simulation process here.</a></p>
<p>At each time step, we can infer a pseudotime coordinate for the phase point using a simple regression model.
After a thousand or so simulations, we find that young cells progress more rapidly through the activation trajectory than aged cells.
This got me excited — two totally orthogonal measurement technologies telling us the same thing.</p>
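<p>The post only specifies "a simple regression model"; one hedged possibility is a linear least-squares map from cell coordinates to pseudotime, fit on the observed cells and then applied to each position along a simulated path. Linear regression here is an assumption, and <code>fit_pseudotime_regressor</code> is a hypothetical name.</p>

```python
import numpy as np


def fit_pseudotime_regressor(positions, pseudotime):
    """Fit a linear map (with intercept) from cell coordinates to pseudotime.

    positions: (n_cells, n_dims) coordinates, e.g. in PCA space.
    pseudotime: (n_cells,) pseudotime values for the same cells.
    Returns a function that predicts pseudotime for new points.
    """
    design = np.column_stack([positions, np.ones(len(positions))])  # append intercept column
    coef, *_ = np.linalg.lstsq(design, pseudotime, rcond=None)

    def predict(points):
        points = np.atleast_2d(points)
        return np.column_stack([points, np.ones(len(points))]) @ coef

    return predict
```

<p>Calling <code>predict</code> on every row of a simulated path gives pseudotime as a function of simulation step, so the average progress of phase points in young and aged fields can be plotted side by side.</p>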
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/phase_sim_result.png" alt="RNA velocity phase simulations." /></p>
<p>We also infer a “future” pseudotime coordinate for each cell from the velocity vector.
We found many cells are moving backwards!
This suggests activation is more like biased diffusion than a ball rolling downhill.</p>
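<p>To make the "biased diffusion" picture concrete, here is a toy one-dimensional random walk with a small forward drift and much larger noise. It is purely illustrative: the drift and noise values are invented, not estimated from the data. Nearly half of individual steps go backwards, yet the population still advances steadily.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1D walk along "pseudotime": small forward drift, large random noise
n_cells, n_steps = 1000, 200
drift, noise = 0.01, 0.1
steps = drift + noise * rng.standard_normal((n_cells, n_steps))
positions = steps.cumsum(axis=1)

# Fraction of individual steps that move backwards (roughly 46% for these settings)
frac_backward_steps = (steps < 0).mean()
# Mean final position across cells (close to drift * n_steps = 2.0)
mean_final_position = positions[:, -1].mean()
```

<p>A ball rolling downhill would correspond to <code>noise = 0</code>, where no step ever goes backwards.</p>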
<p><img src="http://jkimmel.net/assets/images/aging_musc_dynamics/backwards.jpg" alt="RNA velocity backwards motion." /></p>
<p>Perhaps more poetically, this reminds me of the difference between macroscopic motion and microscopic motion.
In macroscopic motion, like a ball rolling down a hill, inertia takes precedence and noise is negligible.
By contrast, noise often dominates the motion of microscopic particles, like a molecule of water diffusing across a glass.
It seems activating muscle stem cells more closely resemble that diffusing water molecule than a ball rolling down a hill, maybe to Waddington’s chagrin.</p>
<h2 id="conclusions">Conclusions</h2>
<p>In toto, both cell behavior and scRNA-seq indicate that aged MuSCs maintain youthful activation trajectories, but have dampened transition rates.</p>Murine Aging Cell Atlas2019-06-05T00:00:00+00:002019-06-05T00:00:00+00:00http://jkimmel.net/murine_cell_aging<p>Mammals are a constellation of distinct cell types, each with a specialized function and lifestyle.
Recent cell atlas efforts suggest there are more than 100 unique cell types in a single mouse.</p>
<p>Do these different cell types experience aging in different ways?
The diversity of cellular physiology in mammals – from short-lived spherical neutrophils to long-lived arboreal neurons – suggests that aging may manifest differently across cell identities.
However, it’s difficult to compare and contrast aging phenotypes measured in individual cell types using different technologies in different laboratories in an apples-to-apples manner.</p>
<p>Along with brilliant collaborators, I recently explored this question at Calico using single cell genomics to obtain comparable measurements of aging phenotypes across cell types.
Check out <a href="https://www.biorxiv.org/content/10.1101/657726v1">our paper on bioRxiv</a>, see <a href="http://mca.research.calicolabs.com/">our research website</a> where you can interact with the data, or read a brief description of some key highlights below.</p>
<h2 id="cartographing-aging-across-three-murine-tissues">Cartographing aging across three murine tissues</h2>
<p>To explore this question, <a href="http://mca.research.calicolabs.com">a collaborative team at Calico</a> leveraged single cell RNA-seq to simultaneously measure the gene expression state of many different cell types in the kidneys, lungs, and spleens of young and old C57BL/6 mice.
These simultaneous measurements allowed us to compare aging phenotypes across many cell types.</p>
<p><img src="http://mca.research.calicolabs.com/content/images/exp_design.png" alt="Experimental Design" /></p>
<h2 id="identifying-common-aging-phenotypes-across-cell-types">Identifying common aging phenotypes across cell types</h2>
<p>Before we could make any of these comparisons though, we first had to identify which mRNA abundance profiles corresponded to which cell types.
This is surprisingly tricky!
Some cell types have nice binary marker genes (e.g. CD3 defines the T cell compartment: if you have it, you are a T cell).
However, many others do not.</p>
<p>Traditionally, expert biologists will cluster single cell mRNA profiles and manually inspect the expression signatures to assign cell types.
Given that I’m not an expert in kidneys, lungs, or spleens, this didn’t seem like the most tractable approach at the outset.
Luckily, the <a href="https://tabula-muris.ds.czbiohub.org/"><em>Tabula Muris</em> consortium</a> recently released an expert annotated data set containing single cell mRNA profiles for every tissue in the mouse.
We trained a deep neural network on this corpus and used it to classify cell types in our own data.
Given these neural network guesses as starting points, the remaining manual confirmations of cell type identity were much easier.</p>
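<p>The workflow of training on an annotated reference, predicting labels for new cells, and flagging uncertain calls for manual review can be sketched as below. Note the swap: the original work used a deep neural network, while this sketch uses a much simpler multinomial logistic regression (softmax) classifier trained by gradient descent. The function names, learning rate, and iteration count are hypothetical.</p>

```python
import numpy as np


def _softmax(logits):
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)


def train_softmax_classifier(X, y, n_classes, lr=0.5, n_iter=500):
    """Multinomial logistic regression fit by full-batch gradient descent.

    X: (n_cells, n_features) reference profiles, e.g. log-normalized counts.
    y: (n_cells,) integer cell type labels from an annotated reference atlas.
    """
    y = np.asarray(y)
    W = np.zeros((X.shape[1], n_classes))
    b = np.zeros(n_classes)
    Y = np.eye(n_classes)[y]  # one-hot labels
    for _ in range(n_iter):
        # Per-sample gradient of cross-entropy with respect to the logits
        grad = _softmax(X @ W + b) - Y
        W -= lr * X.T @ grad / len(X)
        b -= lr * grad.mean(axis=0)
    return W, b


def predict_with_confidence(X, W, b):
    """Return predicted labels and the probability assigned to each prediction."""
    probs = _softmax(X @ W + b)
    return probs.argmax(axis=1), probs.max(axis=1)
```

<p>Cells whose confidence falls below some threshold (for example 0.9, an arbitrary choice) would be the ones to inspect by hand.</p>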
<p><img src="http://mca.research.calicolabs.com/content/images/web_subtypes.png" alt="Cell type latent space" /></p>
<h2 id="most-transcriptional-changes-are-cell-type-specific-only-a-few-are-common">Most transcriptional changes are cell type-specific, only a few are common</h2>
<p>Comparing differential expression with age, we found that most of the transcriptional changes are specific to one or just a few cell types.
Only a small subset of changes appears to occur in many (>5) cell types in the tissues we profiled.
This set of genes indicates a decrease in endoplasmic reticulum protein targeting with age, also seen in <em>S. cerevisiae</em>.
As seen in many other studies, we also find a common upregulation of inflammatory pathways.</p>
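<p>Finding changes shared across many cell types amounts to counting, for each gene, the number of cell types in which it is differentially expressed with age. A minimal sketch, assuming differential expression results have already been reduced to a set of significant genes per cell type; the gene and cell type names in the test are placeholders.</p>

```python
from collections import Counter


def common_aging_genes(de_genes_by_cell_type, min_cell_types=6):
    """Return genes differentially expressed with age in at least min_cell_types cell types.

    de_genes_by_cell_type: dict mapping cell type name -> iterable of DE gene names.
    min_cell_types=6 corresponds to "more than 5" cell types.
    """
    counts = Counter(
        gene
        for genes in de_genes_by_cell_type.values()
        for gene in set(genes)  # count each gene at most once per cell type
    )
    return sorted(gene for gene, n in counts.items() if n >= min_cell_types)
```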
<p><img src="http://mca.research.calicolabs.com/content/images/web_common_diffex.png" alt="Common differentially expressed genes" /></p>
<h2 id="comparing-aging-trajectories">Comparing aging trajectories</h2>
<p>The differential expression results above suggest that cell types age in different ways.
Can we quantify these “aging trajectories” and compare them across cell types?
We used non-negative matrix factorization (NMF) to summarize mRNA profiles at the level of gene expression program activity, and within this space we compute vectors that describe the difference between young and old cells in each cell type.
Comparing these vectors, we find that while similar cell types have similar aging trajectories (e.g. lymphocytes), dissimilar cell types have dissimilar trajectories (e.g. myeloid cells vs. endothelial cells).</p>
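<p>The idea of an "aging vector" in program space can be sketched with a basic NMF (Lee and Seung multiplicative updates for the squared-error objective), followed by the difference between mean program activity in old and young cells of one cell type, compared across cell types by cosine similarity. This is a hedged illustration; the manuscript's factorization settings and comparison metric may differ, and all function names are hypothetical.</p>

```python
import numpy as np


def nmf(X, n_programs, n_iter=500, seed=0, eps=1e-10):
    """Factor a non-negative matrix X ~ W @ H with multiplicative updates.

    X: (n_cells, n_genes) non-negative expression matrix.
    W: (n_cells, n_programs) program activity per cell.
    H: (n_programs, n_genes) gene loadings per program.
    """
    rng = np.random.default_rng(seed)
    W = rng.random((X.shape[0], n_programs))
    H = rng.random((n_programs, X.shape[1]))
    for _ in range(n_iter):
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
    return W, H


def aging_vector(W, is_old):
    """Mean program activity of old cells minus that of young cells (one cell type)."""
    is_old = np.asarray(is_old, dtype=bool)
    return W[is_old].mean(axis=0) - W[~is_old].mean(axis=0)


def cosine_similarity(a, b):
    """Similarity of two aging vectors' directions, from -1 to 1."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
```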
<p><img src="http://mca.research.calicolabs.com/content/images/web_aging_vec.png" alt="Aging trajectories" /></p>
<h2 id="measuring-aging-magnitudes">Measuring aging magnitudes</h2>
<p>Do some cell types age more dramatically than others?
How might we even measure that?
Comparing differences between discrete populations in high-dimensional spaces (like gene expression space) is dicey business.
Some of the simplistic metrics you might think up actually miss important differences that might arise between populations.
As just one example, the intuitive comparison of differences between the population averages misses differences that can arise in the covariance structure or modality of the populations.</p>
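<p>A tiny example shows how a comparison of means can miss a real difference. The two hypothetical populations below have identical means but very different spread, so the distance between their averages is exactly zero.</p>

```python
import numpy as np

# Two hypothetical 2D populations with the same mean but different spread
young = np.array([[-1.0, 0.0], [1.0, 0.0], [0.0, -1.0], [0.0, 1.0]])
old = 3.0 * young

mean_difference = np.linalg.norm(old.mean(axis=0) - young.mean(axis=0))  # 0.0
young_variance = young.var(axis=0)  # [0.5, 0.5]
old_variance = old.var(axis=0)      # [4.5, 4.5]
```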
<p>To account for all these types of variation that may arise with age, we instead leveraged discrete optimal transport distances.
Optimal transport distances (a.k.a. the earth-mover distance, the Wasserstein distance) measure the minimum amount of movement needed to make two equally sized samples match one another.</p>
<p>As an intuition, if we have a big, irregular pile of dirt and want to build an equally sized rectangular pile of dirt, an optimal transport distance between the two shapes describes the minimum amount of work (the amount of dirt moved times the distance it travels) needed to reshape the irregular pile into the rectangular one.</p>
<p>To use this distance with discrete samples of unequal size, we perform bootstrapping with random samples of equal size from young and old populations of cells.
As a null distribution, we compute distances between random samples of only young or only old cells, and normalize the Old-Young distance by the larger of these nulls.
Here’s what the optimal transport for a comparison of young and old spleen B cells looks like.</p>
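<p>For two equally sized point clouds with uniform weights, the discrete optimal transport problem reduces to finding the cheapest one-to-one matching, which SciPy's <code>linear_sum_assignment</code> solves exactly. The sketch below combines that with the bootstrap and null normalization described above. It is an assumption-laden illustration: the sample size, number of bootstraps, use of Euclidean distance, and averaging of ratios across bootstraps are hypothetical choices, not necessarily those in the paper.</p>

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def ot_distance(a, b):
    """Exact OT (earth mover's) distance between equally sized point clouds.

    With uniform weights the optimal plan is a one-to-one matching, so the
    distance is the mean Euclidean cost of the optimal assignment.
    """
    cost = cdist(a, b)  # pairwise Euclidean distances
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()


def normalized_aging_distance(young, old, n=50, n_boot=20, seed=0):
    """Bootstrapped Old-Young OT distance divided by the larger within-age null.

    Assumes each population has at least 2 * n cells, so each null compares
    two disjoint random samples drawn from the same age group.
    """
    rng = np.random.default_rng(seed)
    ratios = []
    for _ in range(n_boot):
        yi = rng.permutation(len(young))
        oi = rng.permutation(len(old))
        y1, y2 = young[yi[:n]], young[yi[n:2 * n]]
        o1, o2 = old[oi[:n]], old[oi[n:2 * n]]
        null = max(ot_distance(y1, y2), ot_distance(o1, o2))
        ratios.append(ot_distance(o1, y1) / null)
    return float(np.mean(ratios))
```

<p>A ratio near 1 means young and old cells are about as different as two random samples from one age group; larger values indicate a bigger change with age.</p>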
<p><img src="http://mca.research.calicolabs.com/content/images/animated_ot_transparent.gif" alt="OT animation" /></p>
<p>When we compute these distances across all cell types we observe, we see multi-fold differences between cell types.</p>
<p><img src="http://mca.research.calicolabs.com/content/images/web_aging_mag.png" alt="Aging magnitude" /></p>
<p>Using linear modeling, we find that cell type explains most of the variation in both aging trajectories and magnitudes, while tissue environment is a minority influence.
This suggests that for cells, who you are influences how you age more than where you live.</p>
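<p>One simple way to ask "how much variation does a factor explain?" is to fit a linear model with a one-hot encoding of that factor and compute the fraction of variance explained (R²). The sketch below fits each factor separately on invented aging magnitudes; the actual analysis may have modeled cell type and tissue jointly, and the function name is hypothetical.</p>

```python
import numpy as np


def r_squared(y, labels):
    """Fraction of variance in y explained by a categorical factor.

    Fits a linear model on a one-hot encoding of `labels` (equivalent to
    predicting each observation by its group mean).
    """
    y = np.asarray(y, dtype=float)
    levels = sorted(set(labels))
    design = np.array([[lab == lev for lev in levels] for lab in labels], dtype=float)
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    residuals = y - design @ coef
    return 1.0 - residuals.var() / y.var()


# Hypothetical aging magnitudes for two cell types found in two tissues
magnitude = np.array([1.0, 1.1, 3.0, 3.1])
cell_type = ["B", "B", "T", "T"]
tissue = ["lung", "spleen", "lung", "spleen"]
r2_cell_type = r_squared(magnitude, cell_type)  # close to 1
r2_tissue = r_squared(magnitude, tissue)        # close to 0
```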
<h2 id="dive-in">Dive in!</h2>
<p>We’ve opened up the data to the scientific community at our <a href="http://mca.research.calicolabs.com/">Calico Research website</a>.
I’d love to hear any thoughts on these results, or any results you dig out on your own <a href="mailto:jacobkimmel@gmail.com">by email.</a></p>Mammals are a constellation of distinct cell types, each with a specialized function and lifestyle. Recent cell atlas efforts suggest there are more than 100 unique cell types in a single mouse.Disentangling a Latent Space2019-04-27T00:00:00+00:002019-04-27T00:00:00+00:00http://jkimmel.net/disentangling_a_latent_space<h1 id="an-introduction-to-latent-spaces">An introduction to latent spaces</h1>
<p>High-dimensional data presents many analytical challenges and eludes human intuitions. These issues are often shorthanded as the <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality?oldformat=true">“Curse of Dimensionality.”</a> A common approach to address these issues is to find a lower dimensional representation of the high dimensional data. This general problem is known as <a href="https://en.wikipedia.org/wiki/Dimensionality_reduction?oldformat=true">dimensionality reduction</a>, including common techniques like principal component analysis (PCA).</p>
<p>In cell biology, the high-dimensional space may consist of many measurements, like transcriptomics data where each gene is a dimension. In the case of images, each pixel may be viewed as a non-negative dimension. The goal of dimensionality reduction is then to find some smaller number of dimensions that capture biological differences at a higher layer of abstraction, such as cell type or stem cell differentiation state. This smaller set of dimensions is known as a <strong>latent space</strong>. The idea of latent space is that each <strong>latent factor</strong> represents an underlying dimension of variation between samples that explains variation in multiple dimensions of the measurement space.</p>
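<p>PCA gives perhaps the simplest concrete latent space. The sketch below (function name hypothetical) projects samples onto the top principal components computed by SVD. In the toy data, a single hidden factor drives all 20 measured features, and one PCA dimension recovers it, which is exactly the sense in which one latent factor explains variation across many measurement dimensions.</p>

```python
import numpy as np


def pca_latent(X, k):
    """Project samples onto the top-k principal components.

    X: (n_samples, n_features), e.g. cells x genes.
    Returns (latent, loadings): the (n_samples, k) latent coordinates and the
    (k, n_features) loadings linking each latent factor to the measurements.
    """
    Xc = X - X.mean(axis=0)
    _, _, vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ vt[:k].T, vt[:k]


# Toy data: one hidden factor t drives all 20 features, plus a little noise
rng = np.random.default_rng(0)
t = rng.standard_normal(100)
weights = rng.standard_normal(20)
X = np.outer(t, weights) + 0.01 * rng.standard_normal((100, 20))
latent, loadings = pca_latent(X, k=1)
```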
<h2 id="disentangled-representations">Disentangled Representations</h2>
<p>In the ideal case, a latent space would help elucidate the rules of the underlying process that generated our measurement space. For instance, an ideal latent space would explain variation in cell geometry due to cell cycle state using a single dimension. By sampling cells that vary only along this dimension, we could build an understanding of how the cell cycle affects cell geometry, and what kind of variation in our data set is explained by this singular process. This type of latent space is known as a <strong>disentangled representation</strong>. More formally, a disentangled representation maps each latent factor to a <strong>generative factor</strong>. A generative factor is simply some parameter in the process or model that generated the measurement data.</p>
<p>The opposite of a disentangled representation is, as expected, an <strong>entangled representation</strong>. An entangled representation identifies latent factors that each map to more than one aspect of the generative process. In the cell geometry example above, an entangled latent factor may explain geometry variation due to the cell cycle, stage of filopodial motility, and environmental factors in a single dimension. This latent factor would certainly explain some variance in the data, but by conflating many aspects of the generative process, it would be incredibly difficult to understand what aspects of cell biology were causing which aspects of this variation.</p>
<h3 id="independence-is-necessary-but-not-sufficient">Independence is necessary, but not sufficient</h3>
<p>Notably, independence between latent dimensions is necessary but not sufficient for a disentangled representation. If latent dimensions are not independent, they can’t represent individual aspects of a generative process in an unconfounded way. So, we need independent dimensions. However, independent dimensions are not necessarily disentangled. Two dimensions that are perfectly orthogonal could still change in unison when a particular parameter of the generative process is varied. For instance, imagine a latent space generated from cell shape images where one dimension represents the cell membrane edge shape and another represents the nuclear shape. If the cell cycle state of measured cells changes, both of these dimensions would covary with the cell cycle state, such that the representation is still “entangled.” A disentangled representation would have only a single dimension that covaries with the cell cycle state.</p>
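<p>A small simulation makes "independent but entangled" concrete. The two generative factors below (a hypothetical cell cycle variable and a hypothetical cell size variable) are independent Gaussians. Rotating them by 45 degrees yields two latent dimensions that are still uncorrelated (and, because everything is Gaussian, independent), yet each one mixes both factors.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000
cell_cycle = rng.standard_normal(n)  # hypothetical generative factor 1
cell_size = rng.standard_normal(n)   # hypothetical generative factor 2, independent of 1

# A 45-degree rotation of the generative factors: independent latent dimensions,
# but each one depends on both factors
z1 = (cell_cycle + cell_size) / np.sqrt(2)
z2 = (cell_cycle - cell_size) / np.sqrt(2)

corr_z1_z2 = np.corrcoef(z1, z2)[0, 1]            # near 0: independent dimensions
corr_z1_cycle = np.corrcoef(z1, cell_cycle)[0, 1]  # near 0.71: entangled with the cell cycle
corr_z2_cycle = np.corrcoef(z2, cell_cycle)[0, 1]  # near 0.71 as well
```

<p>A disentangled representation would instead be the unrotated pair, where only one dimension tracks the cell cycle.</p>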
<h2 id="how-can-we-find-a-disentangled-representation">How can we find a disentangled representation?</h2>
<p>Disentangled representations have some intuitive advantages over their entangled counterparts, <a href="https://arxiv.org/abs/1206.5538">as outlined by Yoshua Bengio</a> in his seminal 2012 review. Matching a single generative factor to a single dimension allows for easy human interpretation. More abstractly, a disentangled representation may be viewed as a concise representation of the variation in data we care about most – the generative factors. A disentangled representation may also be useful for diverse downstream tasks, whereas an entangled representation may contain information to optimize the training objective that is difficult to utilize in downstream tasks <sup id="fnref:0"><a href="#fn:0" class="footnote">1</a></sup>.</p>
<p>However, there is no obvious route to finding a set of disentangled latent factors. Real world generative processes often have parameters with non-linear effects in the measurement space that are non-trivial to decompose. For instance, the expression of various genes over the <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3392685/">“lifespan” of a yeast cell</a> may not change linearly as a function of cellular age. To uncover this latent factor of age from a set of transcriptomic data of aging yeast, the latent space encoding method must be capable of disentangling these non-linear relationships.</p>
<h1 id="advances-in-disentangling">Advances in Disentangling</h1>
<p>I’ve been excited by a few recent papers adapting the <a href="http://jkimmel.net/variational_autoencoding">variational autoencoder</a> framework to generate disentangled representations. While variational autoencoders are somewhat complex, the modifications introduced to disentangle their latent spaces are remarkably simple. The general idea is that the objective function optimized by a variational autoencoder applies a penalty on the latent space encoded by a neural network to make it match a prior distribution, and that the strength and magnitude of this prior penalty can be changed to enforce less entangled representations.</p>
<h2 id="digging-into-the-vae-objective">Digging into the VAE Objective</h2>
<p>To understand how and why this works, I find it helpful to start from the beginning and recall what a VAE is trying to do in the first place. The VAE operates on two types of data: $\mathbf{x}$’s in the measurement space, and $\mathbf{z}$’s which represent points in the latent space we’re learning.</p>
<p>There are two main components to the network that transform between these two types of data points. The encoder $q(\mathbf{z} \vert \mathbf{x})$ estimates a distribution of possible $\mathbf{z}$ points given a data point in the measurement space. The decoder $p(\mathbf{x} \vert \mathbf{z})$ does the opposite, and estimates a point $\mathbf{\hat x}$ in the measurement space given a point $\mathbf{z}$ in the latent space.</p>
<p>This function below is the objective function of a VAE<sup id="fnref:1"><a href="#fn:1" class="footnote">2</a></sup>. A VAE seeks to minimize this objective by changing the parameters of the encoder $\phi$ and parameters of the decoder $\theta$.</p>
<script type="math/tex; mode=display">{L}(\mathbf{x}; \theta, \phi) = - \mathbb{E}[ \log_{q_\phi (\mathbf{z} \vert \mathbf{x})} p_\theta (\mathbf{x} \vert \mathbf{z})] + \mathbb{D}_{\text{KL}}( q_\phi (\mathbf{z} \vert \mathbf{x}) \vert \vert p(\mathbf{z}) )</script>
<p>While that looks hairy, there are basically two parts to this objective, each doing a particular task. Let’s break it down.</p>
<h3 id="reconstruction-error-pushes-the-latent-space-to-capture-meaningful-variation">Reconstruction error pushes the latent space to capture meaningful variation</h3>
<p>The first portion $-\mathbb{E}_{q_\phi (\mathbf{z} \vert \mathbf{x})}[ \log p_\theta (\mathbf{x} \vert \mathbf{z})]$ is the negative expected log likelihood of the data we observed in the measurement space $\mathbf{x}$, given latent codes $\mathbf{z}$ drawn from the encoder.</p>
<p>If the latent space is configured in a way that doesn’t capture much variation in our data, the decoder $p(x \vert z)$ will perform poorly and this log likelihood will be low. Vice-versa, a latent space that captures variation in $x$ will allow the decoder to reconstruct $\mathbf{\hat x}$ much better, and the log likelihood will be higher.</p>
<p>This is known as the <strong>reconstruction error</strong>. Since we want to minimize $L$, better reconstruction lowers $L$. In practice, we estimate reconstruction error with a measure of difference between the observed data $\mathbf{x}$ and the reconstructed data $\mathbf{\hat x}$, like binary cross-entropy. In order to get reasonable reconstructions, the latent space $q(\mathbf{z} \vert \mathbf{x})$ has to capture variation in the measurement data. The reconstruction error therefore acts as a pressure on the encoder network to capture meaningful variation in $\mathbf{x}$ within the latent variables $\mathbf{z}$. This portion of the objective is actually very similar to a “normal” autoencoder, simply optimizing how close we can make the reconstructed $\mathbf{\hat x}$ to the original $\mathbf{x}$ after forcing it through a smaller number of dimensions $\mathbf{z}$.</p>
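<p>For data scaled to [0, 1] (such as binary or normalized images), binary cross-entropy is the negative log likelihood of the data under independent Bernoulli outputs whose means are given by the decoder. A minimal NumPy sketch, assuming one sampled latent code per data point so that the expectation is estimated from a single draw; the function name is hypothetical.</p>

```python
import numpy as np


def bce_reconstruction_error(x, x_hat, eps=1e-7):
    """Binary cross-entropy between data x in [0, 1] and decoder output x_hat in (0, 1).

    Summed over features, averaged over samples. Equals the negative log
    likelihood of x under independent Bernoulli outputs with means x_hat.
    """
    x_hat = np.clip(x_hat, eps, 1 - eps)  # avoid log(0)
    per_sample = -(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat)).sum(axis=1)
    return per_sample.mean()
```

<p>A reconstruction of <code>[1, 0]</code> as <code>[0.9, 0.1]</code> scores about 0.21, while an uninformative <code>[0.5, 0.5]</code> scores about 1.39, so better reconstructions give a lower loss.</p>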
<h3 id="divergence-from-a-prior-distribution-enforces-certain-properties-on-the-latent-space">Divergence from a prior distribution enforces certain properties on the latent space</h3>
<p>The second part of the objective</p>
<script type="math/tex; mode=display">\mathbb{D}_{\text{KL}}( q (\mathbf{z} \vert \mathbf{x}) \vert \vert p(\mathbf{z}) )</script>
<p>measures how different the learned latent distribution $q(\mathbf{z} \vert \mathbf{x})$ is from a prior we have on the latent distribution $p(\mathbf{z})$. This difference is measured with the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence?oldformat=true">Kullback-Leibler divergence</a> <sup id="fnref:2"><a href="#fn:2" class="footnote">3</a></sup>, which is shorthanded <script type="math/tex">\mathbb{D}_\text{KL}</script>. The more similar the two are, the lower the value of $\mathbb{D}_{\text{KL}}$ and therefore the lower the value of the loss $L$. This portion of the loss therefore “pulls” our encoded latent distribution to match some expectations we set in the prior. By selecting a prior $p(\mathbf{z})$, we can therefore enforce some features we want our latent distribution $q(\mathbf{z} \vert \mathbf{x})$ to have.</p>
<p>This prior distribution is often set to an isotropic Gaussian with zero mean <script type="math/tex">\mu = 0</script> and a diagonal covariance matrix with unit variances $\Sigma = \mathbf{I}$, where $\textbf{I}$ is the identity matrix, so $p(\mathbf{z}) = \mathcal{N}(0, \textbf{I})$ <sup id="fnref:3"><a href="#fn:3" class="footnote">4</a></sup>. Because the prior has a diagonal covariance, this prior pulls the encoded latent space $q(\mathbf{z} \vert \mathbf{x})$ to have independent components.</p>
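<p>When the encoder outputs a diagonal Gaussian $q(\mathbf{z} \vert \mathbf{x}) = \mathcal{N}(\mu, \text{diag}(\sigma^2))$, a common choice, the KL divergence to $\mathcal{N}(0, \mathbf{I})$ has a closed form: one half of the sum over latent dimensions of $\mu^2 + \sigma^2 - \log \sigma^2 - 1$. The sketch assumes the encoder predicts log-variances, a frequent parameterization; the function name is hypothetical.</p>

```python
import numpy as np


def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ).

    mu, log_var: (n_samples, n_latent) encoder outputs.
    Summed over latent dimensions, averaged over samples.
    """
    kl = 0.5 * (mu ** 2 + np.exp(log_var) - log_var - 1.0).sum(axis=1)
    return kl.mean()
```

<p>The divergence is zero only when every sample's encoded mean is 0 and variance is 1, i.e. when the encoded distribution matches the prior exactly.</p>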
<h3 id="reviewing-the-objective">Reviewing the Objective</h3>
<p>Taking these two parts together, we see that part <strong>1</strong> of the objective optimizes $p(\mathbf{x} \vert \mathbf{z})$ and $q(\mathbf{z} \vert \mathbf{x})$ to recover as much information as possible about $\mathbf{x}$ after encoding to a smaller number of latent dimensions $\mathbf{z}$, and part <strong>2</strong> enforces some properties we desire onto the latent space $q(\mathbf{z} \vert \mathbf{x})$ based on desiderata we express in a prior distribution $p(\mathbf{z})$.</p>
<p>Notice that the objective doesn’t scale either of the reconstruction or divergence loss in any way.
Both are effectively multiplied by a coefficient of $1$ and simply summed to generate the objective.</p>
<h3 id="what-sort-of-latent-spaces-does-this-generate">What sort of latent spaces does this generate?</h3>
<p>The Gaussian prior $p(\mathbf{z}) = \mathcal{N}(0, \mathbf{I})$ pulls the dimensions of the latent space to be independent.</p>
<p>Why? The covariance matrix we specified for the prior is the identity matrix $\mathbf{I}$, where no dimension covaries with any others. However, it does not explicitly force the disentangling between generative factors that we so desire. As we outlined earlier, independence is necessary but not sufficient for disentanglement. The latent spaces generated with this unweighted Gaussian prior often map multiple generative factors to each of the latent dimensions, making them hard to interpret semantically.</p>
<p>We can see an example of this entangling between generative factors in a VAE trained on the dSprites data set. dSprites is a set of synthetic images of white objects moving across black backgrounds. Because the images are synthesized, we have a ground truth set of generative factors – object $x$ coordinate, object $y$ coordinate, shape, size, rotation – and we know the value of each generative factor for each image.</p>
<p>Borrowed from <a href="https://openreview.net/pdf?id=Sy2fzU9gl">Higgins 2017</a> Figure 7, here’s a visualization of the latent space learned for dSprites with a standard VAE on the right side.
Each column in the figure represents a <strong>latent space traversal</strong> – basically, latent vectors $\mathbf{z}$ are sampled with all but one dimension of $\mathbf{z}$ fixed, and the remaining dimension varied over a range. These vectors are then decoded using the trained VAE decoder. This lets us see what information is stored in each dimension.</p>
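<p>Building the latent vectors for a traversal is straightforward. In this sketch, <code>latent_traversal</code> is a hypothetical helper, and the decoding step is left as a comment because it depends on whatever trained decoder is at hand.</p>

```python
import numpy as np


def latent_traversal(z_base, dim, values):
    """Copies of z_base in which only dimension `dim` is swept over `values`.

    Each row is meant to be passed through a trained VAE decoder to visualize
    what information that latent dimension stores.
    """
    z = np.tile(np.asarray(z_base, dtype=float), (len(values), 1))
    z[:, dim] = values
    return z


traversal = latent_traversal(np.zeros(10), dim=2, values=np.linspace(-3, 3, 7))
# images = decoder(traversal)  # `decoder` stands in for a trained VAE decoder (not defined here)
```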
<p><img src="http://jkimmel.net/assets/images/disentangle/bvae_fig7.png" alt="VAE latent space traversals, Higgins 2017" /></p>
<p>If we look through each latent dimension for the standard VAE on the right side, we see that different generative factors are all mixed together in the model’s latent dimensions. The first dimension is some mixture of positions, shapes and scales. Likewise for the second and third columns. As a human, it’s pretty difficult to interpret what a higher value for dimension number $2$ in this model really means.</p>
<p>Curious readers will note that the columns on the left side of this figure seem to map much more directly to individual parameters we can interpret.
The first one is the $Y$ position, the second is $X$ position, &c.</p>
<p>How can we encourage our models to learn representations more like this one on the left?</p>
<h2 id="modifying-the-vae-objective">Modifying the VAE Objective</h2>
<p>Trying to encourage disentangled representations in VAEs is now a very active field of research, with many groups proposing related ideas.
One common theme explored by several methods to encourage disentanglement is the modification of the VAE objective, <a href="https://arxiv.org/abs/1812.05069">reviewed wonderfully by Tschannen <em>et al.</em></a>.
How might we modify the objective to encourage this elusive disentanglement property?</p>
<h3 id="beta-vae-obey-your-priors-young-latent-space">$\beta$-VAE: Obey your priors young latent space</h3>
<p>One strategy explored by <a href="https://openreview.net/pdf?id=Sy2fzU9gl">Higgins <em>et al.</em></a> and <a href="https://arxiv.org/pdf/1804.03599.pdf">Burgess <em>et al.</em></a> at DeepMind is to simply weight the KL term of the VAE objective more heavily. Recall that the KL term in the VAE objective encourages the latent distribution $q(z \vert x)$ to be similar to $p(z)$. If $p(z) = \mathcal{N}(0, \mathbf{I})$, weighting the KL more heavily puts more emphasis on matching the independence between latent dimensions implied by the prior.</p>
<p>The objective Higgins <em>et al.</em> propose is a beautifully simple modification to the VAE objective.</p>
<p>We go from:</p>
<script type="math/tex; mode=display">{L}(x; \theta, \phi) = - \mathbb{E}_{q_\phi (z \vert x)}[ \log p_\theta (x \vert z)] + \mathbb{D}_{\text{KL}}( q_\phi (z \vert x) \vert \vert p(z) )</script>
<p>to:</p>
<script type="math/tex; mode=display">{L}(x; \theta, \phi) = - \mathbb{E}_{q_\phi (z \vert x)}[ \log p_\theta (x \vert z)] + \beta \mathbb{D}_{\text{KL}}( q_\phi (z \vert x) \vert \vert p(z) )</script>
<p>Notice the difference? We added a $\beta$ coefficient in front of the KL term. Higgins <em>et al.</em> set this coefficient to $\beta > 1$ to encourage disentanglement and call their approach $\beta$-VAE.</p>
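<p>As a minimal NumPy sketch of this objective, assume a diagonal Gaussian encoder that outputs a mean <code class="highlighter-rouge">mu</code> and log-variance <code class="highlighter-rouge">logvar</code>, the $\mathcal{N}(0, \mathbf{I})$ prior, and a Bernoulli decoder suited to binary images like dSprites. Under those assumptions the KL term has a closed form, and $\beta$ simply scales it. The function names and toy inputs are illustrative, not code from Higgins <em>et al.</em></p>

```python
import numpy as np

def gaussian_kl(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dims."""
    return 0.5 * np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar, axis=-1)

def bernoulli_nll(x, x_prob, eps=1e-7):
    """Negative log-likelihood of binary pixels x under decoder probabilities x_prob."""
    x_prob = np.clip(x_prob, eps, 1 - eps)  # avoid log(0)
    return -np.sum(x * np.log(x_prob) + (1 - x) * np.log(1 - x_prob), axis=-1)

def beta_vae_loss(x, x_prob, mu, logvar, beta=4.0):
    """Per-sample beta-VAE loss: reconstruction NLL + beta * KL.

    x_prob is the decoder output for a single sample z ~ q(z|x), a
    one-sample estimate of the expectation. beta = 1 gives the standard VAE.
    """
    return bernoulli_nll(x, x_prob) + beta * gaussian_kl(mu, logvar)

# Toy inputs: two 16-pixel binary "images" and a 4-D latent space.
rng = np.random.default_rng(0)
x = (rng.random((2, 16)) > 0.5).astype(float)
x_prob = np.full((2, 16), 0.5)
mu, logvar = rng.normal(size=(2, 4)), np.zeros((2, 4))

l1 = beta_vae_loss(x, x_prob, mu, logvar, beta=1.0)
l4 = beta_vae_loss(x, x_prob, mu, logvar, beta=4.0)
# Only the KL term is reweighted: the losses differ by exactly 3 * KL.
print(np.allclose(l4 - l1, 3 * gaussian_kl(mu, logvar)))  # True
```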
<p>As simple as this modification is, the results are quite striking. If we revisit the dSprites data set above, we note that simply weighting the KL with $\beta = 4$ leads to dramatically more interpretable latent dimensions than $\beta = 1$. I found this result quite shocking – hyperparameters in the objective really, <em>really</em> matter!</p>
<p>Here’s another example from <a href="https://openreview.net/pdf?id=Sy2fzU9gl">Higgins 2017</a> using a human face dataset.
We see that $\beta$-VAE learns latent dimensions that specifically represent generative factors like azimuth or lighting condition, while a standard VAE objective ($\beta = 1$) tends to mix generative factors together in each latent dimension.</p>
<p><img src="http://jkimmel.net/assets/images/disentangle/bvae_fig3.png" alt="VAE latent space traversals, faces" /></p>
<h3 id="why-does-this-work">Why does this work?</h3>
<p>In a follow-up paper, Burgess <em>et al.</em> investigate why this seems to work so well.
They propose that we view $q(\mathbf{z} | \mathbf{x})$ as an <a href="https://arxiv.org/abs/physics/0004057"><strong>information bottleneck</strong></a>.
The basic idea here is that we want $\mathbf{z}$ to contain as much information as possible to improve performance on a task like reconstructing the input, while discarding any information in $\mathbf{x}$ that isn’t necessary to do well on the task.</p>
<p>If we take a look back at the VAE objective, we can convince ourselves that the KL divergence between the encoder $q(\mathbf{z} \vert \mathbf{x})$ and the prior $p(\mathbf{z})$ is actually an upper bound on how much information about $\mathbf{x}$ can pass through to $\mathbf{z}$ <sup id="fnref:4"><a href="#fn:4" class="footnote">5</a></sup>.
This “amount of information” is referred to in information theory as a <a href="https://www.wikiwand.com/en/Channel_capacity"><strong>channel capacity</strong></a>.</p>
<p>By increasing the cost of a high KL divergence, $\beta$-VAE reduces the amount of information that can pass through this bottleneck.
Given this constraint, Burgess <em>et al.</em> propose that the flexible encoder $q(\mathbf{z} \vert \mathbf{x})$ learns to map generative factors to individual latent dimensions as an efficient way to encode the information about $\mathbf{x}$ needed for reconstruction during decoding.</p>
<p>While this argument feels intuitive, there isn’t much quantitative data backing it.
The exact answer to why simply weighting the KL a bit more in the VAE objective gives such remarkable results is still, alas, an open question.</p>
<p>Based on this principle, Burgess <em>et al.</em> also propose letting more information pass through the bottleneck over the course of training.
The rationale here is that we can first use a small information bottleneck to learn a disentangled but incomplete representation.
After latent dimensions have associated with generative factors, we can allow more information into the bottleneck to improve performance on downstream tasks, like reconstruction or classification, while maintaining this disentanglement.</p>
<p>To do this, the authors suggest another elegant modification to the objective:</p>
<script type="math/tex; mode=display">{L}(x; \theta, \phi) = - \mathbb{E}_{q_\phi (z \vert x)}[ \log p_\theta (x \vert z)] + \beta \vert \mathbb{D}_{\text{KL}}( q_\phi (z \vert x) \vert \vert p(z) ) - C \vert</script>
<p>where $C$ is a target capacity that is gradually increased over the course of VAE training.
As $C$ increases, the KL divergence can grow correspondingly without being penalized, since only its deviation from $C$ adds to the loss.
The authors don’t provide direct comparisons between this new modification and $\beta$-VAE alone though, so it’s hard to know how much benefit this method provides.</p>
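<p>A sketch of this capacity-annealed objective, assuming a linear schedule for $C$. The values of <code class="highlighter-rouge">c_max</code>, <code class="highlighter-rouge">anneal_steps</code>, and <code class="highlighter-rouge">beta</code> below are placeholders, not the settings reported by Burgess <em>et al.</em>, and the reconstruction and KL terms are passed in as precomputed numbers.</p>

```python
import numpy as np

def capacity(step, c_max=25.0, anneal_steps=100_000):
    """Target capacity C: rises linearly from 0 to c_max, then stays there."""
    return c_max * min(step / anneal_steps, 1.0)

def capacity_loss(recon_nll, kl, step, beta=100.0, c_max=25.0, anneal_steps=100_000):
    """recon + beta * |KL - C(step)|: only the KL's distance from C is penalized."""
    return recon_nll + beta * np.abs(kl - capacity(step, c_max, anneal_steps))

# The same KL of 12.5 nats is heavily penalized at the start of training,
# but costs nothing once C has grown to 12.5 (halfway through annealing).
early = capacity_loss(0.0, 12.5, step=0)
mid = capacity_loss(0.0, 12.5, step=50_000)
print(early, mid)  # 1250.0 0.0
```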
<h2 id="how-do-we-measure-disentanglement">How do we measure disentanglement?</h2>
<p>You may have noticed that previous figures rely on qualitative evaluation of disentanglement.
Mostly, we’ve decoded latent vectors along each dimension and eye-balled the outputs to figure out if they map to a generative factor.
This kind of eye-balling makes for compelling figures, but it’s hard to rigorously compare “how disentangled” two latent spaces are using just this scheme.</p>
<p>Multiple quantitative metrics have also been proposed <sup id="fnref:5"><a href="#fn:5" class="footnote">6</a></sup>, but they can only be used in synthetic data sets where the generative factors are known – like dSprites, where we simulate images and know what parameters are used to generate each image.
Developing methods to measure disentanglement in a quantitative manner seems like an important research direction going forward.</p>
<p>In the case of biological data, we might imagine evaluating disentanglement on generative factors we know, like experimental conditions.
Imagine we’ve captured images of cells treated with different doses of a drug.
If the different doses of the drug mapped to a single dimension of the latent space, we may consider that representation to be more disentangled than a representation where drug dose is explained across many dimensions.</p>
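<p>For the drug dose example, one crude, hypothetical check is to correlate each latent dimension with the known dose and ask how much of that correlation a single dimension holds. This is not one of the published metrics cited above, it only detects linear relationships, and the data below are simulated purely for illustration.</p>

```python
import numpy as np

def factor_concentration(z, factor):
    """How concentrated a known factor is across latent dimensions.

    Returns (r, share): r is the absolute Pearson correlation of each latent
    dimension with the factor; share is the best dimension's fraction of the
    total, near 1.0 when one dimension carries the factor and near 1/d when
    it is spread evenly over d dimensions.
    """
    z = np.asarray(z, dtype=float)
    f = np.asarray(factor, dtype=float)
    zc = z - z.mean(axis=0)  # center each latent dimension
    fc = f - f.mean()        # center the factor
    r = np.abs(zc.T @ fc) / (np.linalg.norm(zc, axis=0) * np.linalg.norm(fc))
    return r, r.max() / r.sum()

# Simulated example: 4 drug doses, 50 cells each, 4 latent dimensions.
rng = np.random.default_rng(0)
dose = np.repeat([0.0, 1.0, 2.0, 3.0], 50)
noise = rng.normal(scale=0.1, size=(200, 4))
z_disentangled = noise.copy()
z_disentangled[:, 0] += dose                 # dose lives in dimension 0 only
z_entangled = noise + 0.5 * dose[:, None]    # dose spread over all dimensions

_, share_d = factor_concentration(z_disentangled, dose)
_, share_e = factor_concentration(z_entangled, dose)
print(share_d > share_e)  # True
```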
<p>Much of the promise in learning disentangled representations is in the potential for discovery of unknown generative factors.
In an imaging experiment like the one above, perhaps cell motility states map to a dimension – moving, just moved, not moving – even if we didn’t know how to measure those states explicitly beforehand.
Evaluating representations for their ability to disentangle these unknown generative factors seems like a difficult epistemic problem.
How do we evaluate the representation of something we don’t know to measure?
Research in this area may have to rely on qualitative evaluation of latent dimensions for the near future.
In some cases, biological priors may help us in evaluating disentanglement, as shown by work in Casey Greene’s group using gene set enrichment to evaluate representations <sup id="fnref:6"><a href="#fn:6" class="footnote">7</a></sup>.</p>
<h1 id="where-shall-we-venture">Where shall we venture?</h1>
<p>I’d love to see how these recent advances in representation learning translate to biological problems, where it’s sometimes difficult to even know if a representation is disentangled.
This seems intuitively the most useful to me in domains where our prior biological knowledge isn’t well structured.
In some domains like genomics, we have well structured ontologies and strong biological priors for associated gene sets derived from sequence information and decades of empirical observation.
Perhaps in that domain, explicit enforcement of those strong priors will lead to more useful representations <sup id="fnref:7"><a href="#fn:7" class="footnote">8</a></sup> than a VAE may be able to learn, even when encouraged to disentangle.</p>
<p>Cell imaging on the other hand has no such structured ontology of priors.
We don’t have organized expressions for the type of morphologies we expect to associate, the different types of cell geometry features are only vaguely defined, and the causal links between them even less so.
Whereas we understand that transcription factors have target genes, it remains unclear if nuclear geometry directly influences the mitochondrial network.
Imaging and other biological domains where we have less structured prior knowledge may therefore be the lowest hanging fruit for these representation learning schemes in biology.</p>
<h1 id="footnotes">Footnotes</h1>
<div class="footnotes">
<ol>
<li id="fn:0">
<p>For instance, information used to improve reconstructions in a VAE may not be useful for clustering the data in the latent space. This last point is notably hard to prove, as the specific representation that is best for any given task will depend on the task. <a href="https://arxiv.org/abs/1812.05069">See Tschannen <em>et al.</em> for a formal treatment of this topic</a>. <a href="#fnref:0" class="reversefootnote">↩</a></p>
</li>
<li id="fn:1">
<p>An objective function is also known as a loss function or energy criterion. <a href="#fnref:1" class="reversefootnote">↩</a></p>
</li>
<li id="fn:2">
<p>The Kullback-Leibler divergence is a fundamental concept that measures how one probability distribution differs from another. “The KL” is a divergence rather than a true distance: it is not symmetric, and it doesn’t obey the <a href="https://en.wikipedia.org/wiki/Triangle_inequality">triangle inequality</a>. It is bounded below by zero, which it reaches only when the two distributions are equal, and is unbounded above: <script type="math/tex">\mathbb{D}_\text{KL} \in [0, \infty)</script>. <a href="#fnref:2" class="reversefootnote">↩</a></p>
</li>
<li id="fn:3">
<p>The identity matrix $\mathbf{I}$ is a square matrix with $1$ on the diagonal and $0$ everywhere else. In mathematical notation, $\mathcal{N}(\mu, \Sigma)$ is shorthand for the <a href="https://en.wikipedia.org/wiki/Normal_distribution?oldformat=true#General_normal_distribution">Gaussian distribution</a> with mean $\mu$ and covariance $\Sigma$. <a href="#fnref:3" class="reversefootnote">↩</a></p>
</li>
<li id="fn:4">
<p>We can think of $\mathbf{z}$ as a “channel” through which information about $\mathbf{x}$ can flow to perform downstream tasks, like decoding and reconstructing $\mathbf{x}$ in an autoencoder.</p>
<p>If we think about how to minimize the KL, we realize that the KL will actually be minimized when $q(z_i \vert x_i) = p(\mathbf{z})$ for every single example.
This is true if we recall that the KL is $0$ when the two distributions it compares are equal.</p>
<p>If we set our prior to $p(\mathbf{z}) = \mathcal{N}(0, \mathbf{I})$ as above, this means that the KL would be minimized when $q(z_i \vert x_i) = \mathcal{N}(\boldsymbol{\mu}_i = \mathbf{0}, \Sigma_i = \mathbf{I})$ for every sample.
If every input is encoded to the same distribution, the encoding obviously contains no information about the input $\mathbf{x}$!</p>
<p>So, we can think of the value of the KL as a limit on how much information about $\mathbf{x}$ can pass through $\mathbf{z}$, since minimizing the KL forces us to pass no information about $\mathbf{x}$ in $\mathbf{z}$. <a href="#fnref:4" class="reversefootnote">↩</a></p>
</li>
<li id="fn:5">
<p>See <a href="https://openreview.net/pdf?id=Sy2fzU9gl">Higgins <em>et al.</em> 2017</a> and <a href="https://arxiv.org/pdf/1802.05983.pdf">Kim <em>et al.</em> 2018</a>. <a href="#fnref:5" class="reversefootnote">↩</a></p>
</li>
<li id="fn:6">
<p>See <a href="https://www.biorxiv.org/content/10.1101/395947v2">Taroni 2018</a>, <a href="https://www.biorxiv.org/content/10.1101/174474v2">Way 2017</a>, <a href="https://www.biorxiv.org/content/10.1101/174474v2">Way 2019</a> <a href="#fnref:6" class="reversefootnote">↩</a></p>
</li>
<li id="fn:7">
<p>See <a href="https://www.biorxiv.org/content/10.1101/116061v2.full">W. Mao’s great PLIER paper as an example.</a> <a href="#fnref:7" class="reversefootnote">↩</a></p>
</li>
</ol>
</div>An introduction to latent spacesVariational Autoencoding for Biologists2019-01-01T00:00:00+00:002019-01-01T00:00:00+00:00http://jkimmel.net/variational_autoencoding<p>Inspired by <a href="https://arxiv.org/pdf/1705.00092.pdf">Greg Johnson’s Integrated Cell paper</a> on generative modeling of cellular structure, I spent a couple days exploring variational autoencoders to derive useful latent spaces in biological data. I’ve found that I often learn best when preparing to teach. To that aim, I wrote a tutorial on VAEs in the form of a Colab notebook working through mathematical motivations and implementing a simple model. The tutorial goes on to play with this model on some of the Allen Institute for Cell Science data.</p>
<p><a href="https://drive.google.com/open?id=1VyyPD_T_ltY09b4zJFFuo91SM5Ka4DQO"><strong>Find the notebook here.</strong></a></p>
<p>VAEs may seem a bit far afield from biological data analysis at first blush. Without repeating too much of the tutorial here, I find a few specific properties of these models particularly interesting:</p>
<h2 id="1---vae-latent-spaces-are-continuous-and-allow-for-linear-interpolation">1 - VAE latent spaces are continuous and allow for linear interpolation.</h2>
<p>This means that operations like $z_\text{Metaphase} = (z_\text{Prophase} + z_\text{Anaphase})/2$ followed by decoding to the measurement space often yield sane outputs.<br />
<strong>Biological Use Case:</strong> Predict unseen intermediary data in a timecourse experiment.</p>
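<p>A minimal sketch of linear interpolation in a latent space. The vectors below are hypothetical encoder means, and each row of the returned path would be passed through a trained decoder to predict the intermediate measurement.</p>

```python
import numpy as np

def interpolate(z_a, z_b, n_steps=5):
    """Evenly spaced points on the straight line from z_a to z_b, endpoints included."""
    t = np.linspace(0.0, 1.0, n_steps)[:, None]  # (n_steps, 1) mixing weights
    return (1 - t) * np.asarray(z_a, dtype=float) + t * np.asarray(z_b, dtype=float)

# Hypothetical encoder means for a prophase and an anaphase cell.
z_prophase = np.array([0.2, -1.0, 0.5])
z_anaphase = np.array([1.4, 0.6, -0.3])

path = interpolate(z_prophase, z_anaphase, n_steps=3)
z_metaphase = path[1]  # the midpoint, (z_prophase + z_anaphase) / 2
print(np.allclose(z_metaphase, (z_prophase + z_anaphase) / 2))  # True
```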
<h2 id="2---vaes-are-generative-models-providing-a-lens-from-which-we-can-begin-to-disentangle-individual-generative-factors-beneath-observed-variables">2 - VAEs are generative models, providing a lens from which we can begin to disentangle individual generative factors beneath observed variables.</h2>
<p>High-dimensional biological data, like transcriptomes or images, is the emergent product of underlying <em>generative factors</em>. Think of a generative factor as a semantically meaningful parameter of the process that generated your data. In the case of cell biology, these factors may be aspects of cellular state, environmental variables like media conditions, or time in a dynamic process. <a href="https://arxiv.org/abs/1804.03599">Recent work from DeepMind</a> on $\beta$-VAEs has shown promise for finding latent dimensions that map to specific generative factors, making those dimensions interpretable.<br />
<strong>Biological Use Case:</strong> Learn a latent space where generative factors of interest like cell cycle state, differentiation state, &c., map uniquely to individual dimensions, allowing for estimates of covariance between generative factors and measurement space variables.</p>
<h2 id="3---vae-latent-spaces-provide-a-notion-of-variation">3 - VAE latent spaces provide a notion of variation.</h2>
<p>When encoding a given observation, VAEs provide not only a location in the latent space, but an estimate of variance around this mean. This estimate of variation provides a metric of mapping confidence that’s not only useful for estimating the likelihood of alternative outcomes, but can be used for more general tasks like anomaly detection.<br />
<strong>Biological Use Case:</strong> Run an image based screen and use a VAE model trained on control samples to estimate which perturbations deviate from the control distribution.</p>
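<p>One simple way to act on this use case, sketched under the assumption that each sample has already been encoded to its latent mean: fit a Gaussian to the control latent means and score every sample by its Mahalanobis distance from that distribution. This swaps in a deliberately simple scoring rule; it ignores the per-sample variances the encoder also reports, and the latent means below are simulated.</p>

```python
import numpy as np

def fit_control(z_control):
    """Mean and inverse covariance of latent means from control samples."""
    mu = z_control.mean(axis=0)
    cov = np.cov(z_control, rowvar=False)
    return mu, np.linalg.inv(cov)

def mahalanobis(z, mu, cov_inv):
    """Mahalanobis distance of each row of z from the control distribution."""
    d = z - mu
    return np.sqrt(np.einsum("ij,jk,ik->i", d, cov_inv, d))

# Simulated latent means: controls near the origin, one perturbation shifted.
rng = np.random.default_rng(0)
z_control = rng.normal(size=(500, 4))
z_perturbed = rng.normal(size=(100, 4)) + np.array([5.0, 0.0, 0.0, 0.0])

mu, cov_inv = fit_control(z_control)
score_control = mahalanobis(z_control, mu, cov_inv)
score_perturbed = mahalanobis(z_perturbed, mu, cov_inv)
print(score_perturbed.mean() > score_control.mean())  # True
```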
<p>Take a spin through the tutorial if you’re so inclined. As a teaser, you get to generate interesting figures like this:</p>
<p><img src="http://jkimmel.net/assets/images/vae_tutorial/vae_teaser.png" alt="VAE latent space decoded" /></p>
<p>If you see any issues or hiccups, please feel free to <a href="mailto:jacobkimmel@gmail.com">email me</a>.</p>Inspired by Greg Johnson’s Integrated Cell paper on generative modeling of cellular structure, I spent a couple days exploring variational autoencoders to derive useful latent spaces in biological data. I’ve found that I often learn best when preparing to teach. To that aim, I wrote a tutorial on VAEs in the form of a Colab notebook working through mathematical motivations and implementing a simple model. The tutorial goes on to play with this model on some of the Allen Institute for Cell Science data.Heteromotility data analysis with `hmR`2018-04-16T00:00:00+00:002018-04-16T00:00:00+00:00http://jkimmel.net/hmR_data_analysis<p><a href="https://jacobkimmel.github.io/heteromotility"><code class="highlighter-rouge">Heteromotility</code></a> extracts quantitative features of single cell behavior from cell tracking data. Analyzing this high dimensional data presents a challenge. A typical workflow incorporates various types of analysis, such as unsupervised clustering, dimensionality reduction, visualization, analysis of specific features, pseudotiming, and more.</p>
<p>Previously, <code class="highlighter-rouge">heteromotility</code> data analysis relied on a library of rather unwieldy functions released with the feature extraction tool itself. I’m excited to release <a href="https://github.com/jacobkimmel/hmR"><code class="highlighter-rouge">hmR</code></a> today to lend some sanity to this analysis process.</p>
<p><a href="https://github.com/jacobkimmel/hmR"><code class="highlighter-rouge">hmR</code></a> provides a set of clean semantics around single cell behavior data analysis. Inspired by the semantics of <a href="https://github.com/satijalab/seurat"><code class="highlighter-rouge">Seurat</code></a> in the single cell RNA-seq analysis field, <a href="https://github.com/jacobkimmel/hmR"><code class="highlighter-rouge">hmR</code></a> focuses analysis around a single data object that can be exported and transported across environments while maintaining all intermediates and final products of analysis.</p>
<p><a href="https://github.com/jacobkimmel/hmR"><code class="highlighter-rouge">hmR</code></a> carries users from raw <code class="highlighter-rouge">heteromotility</code> feature exports, all the way to biologically meaningful analysis in just a few simple commands.</p>
<p>As an example, it’s easy to produce visualizations of cell behavior state space in just a few lines with <code class="highlighter-rouge">hmR</code>.</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">library</span><span class="p">(</span><span class="n">hmR</span><span class="p">)</span><span class="w">
</span><span class="n">df</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">read.csv</span><span class="p">(</span><span class="s1">'path/to/motility_statistics.csv'</span><span class="p">)</span><span class="w">
</span><span class="n">mot</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">hmMakeObject</span><span class="p">(</span><span class="n">raw.data</span><span class="o">=</span><span class="n">df</span><span class="p">)</span><span class="w">
</span><span class="c1"># Perform hierarchical clustering</span><span class="w">
</span><span class="n">mot</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">hmHClust</span><span class="p">(</span><span class="n">mot</span><span class="p">,</span><span class="w"> </span><span class="n">k</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">3</span><span class="p">,</span><span class="w"> </span><span class="n">method</span><span class="o">=</span><span class="s1">'ward.D2'</span><span class="p">)</span><span class="w">
</span><span class="c1"># Run and plot PCA</span><span class="w">
</span><span class="n">mot</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">hmPCA</span><span class="p">(</span><span class="n">mot</span><span class="p">)</span><span class="w">
</span><span class="n">mot</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">hmPlotPCA</span><span class="p">(</span><span class="n">mot</span><span class="p">)</span><span class="w">
</span><span class="c1"># Run and plot tSNE</span><span class="w">
</span><span class="n">mot</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">hmTSNE</span><span class="p">(</span><span class="n">mot</span><span class="p">)</span><span class="w">
</span><span class="n">mot</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">hmPlotTSNE</span><span class="p">(</span><span class="n">mot</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p>Running a pseudotime analysis is just as simple:</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="w">
</span><span class="n">mot</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">hmPseudotime</span><span class="p">(</span><span class="n">mot</span><span class="p">)</span><span class="w">
</span><span class="n">hmPlotPseudotime</span><span class="p">(</span><span class="n">mot</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p><code class="highlighter-rouge">hmR</code> currently focuses on analysis of cell behavior data in the static context, with dynamic analysis (detailed balance breaking, <em>N</em>-dimensional probability flux analysis, statewise cell transition vectors, etc.) being handled by the original <code class="highlighter-rouge">heteromotility</code> analysis suite.</p>
<p>Give <code class="highlighter-rouge">hmR</code> a try with your single cell behavior data and let me know if I can be helpful!</p>
<p><a href="https://github.com/jacobkimmel/hmR"><strong>hmR Github</strong></a></p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">library</span><span class="p">(</span><span class="n">devtools</span><span class="p">)</span><span class="w">
</span><span class="n">devtools</span><span class="o">::</span><span class="n">install_github</span><span class="p">(</span><span class="s1">'jacobkimmel/hmR'</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>Heteromotility extracts quantitative features of single cell behavior from cell tracking data. Analyzing this high dimensional data presents a challenge. A typical workflow incorporates various types of analysis, such as unsupervised clustering, dimensionality reduction, visualization, analysis of specific features, pseudotiming, and more.Tiramisu: Fully Connected DenseNets in PyTorch2018-01-18T00:00:00+00:002018-01-18T00:00:00+00:00http://jkimmel.net/tiramisu_pytorch<h1 id="semantic-segmentation">Semantic Segmentation</h1>
<p><a href="">Image segmentation</a> is the first step in many image analysis tasks, spanning fields from human action recognition, to self-driving car automation, to cell biology.</p>
<p>Semantic segmentation approaches are the state of the art in the field. Semantic segmentation trains a model to predict the class of each individual pixel in an image, where classes might be something like “background,” “tree,” “lung,” or “cell.” In recent years, convolutional neural network (CNN) models have dominated standardized benchmarks for segmentation performance, such as the <a href="http://host.robots.ox.ac.uk:8080/leaderboard/displaylb.php?challengeid=11&compid=6">PASCAL VOC dataset</a>.</p>
<h1 id="patchwise-cnn-semantic-segmentation">Patchwise CNN Semantic Segmentation</h1>
<p>The earliest CNN semantic segmentation models performed “patchwise segmentation.” As the name implies, this involves splitting the image up into a series of patches, and classifying the center pixel of each patch based on the area around it. While effective for some tasks, this approach has several computational inefficiencies.</p>
<p>Given an image <code class="highlighter-rouge">I</code> that is <code class="highlighter-rouge">N x M</code> pixels, you need to generate <code class="highlighter-rouge">NM</code> patches for classification. The vast majority of each of these patches overlaps with <em>many</em> other patches you’re going to classify. This means that not only is memory wasted in representing the same parts of an image multiple times, but that the total computational burden is increased due to this overlap. Specifically, the total area in a given image is simply <code class="highlighter-rouge">N*M</code> pixels, but the total area represented and classified with a patchwise model is <code class="highlighter-rouge">N*M*PatchX*PatchY</code>.</p>
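<p>The patch bookkeeping can be sketched in NumPy. The image and patch sizes below are toy values, <code class="highlighter-rouge">patchwise_pixels</code> is a hypothetical helper, and <code class="highlighter-rouge">sliding_window_view</code> (NumPy 1.20 or later) builds the patches as a strided view; a model still has to process every one of those overlapping pixels.</p>

```python
import numpy as np

def patchwise_pixels(n, m, patch_x, patch_y):
    """Pixels represented when every pixel of an n x m image gets its own patch."""
    return n * m * patch_x * patch_y

# Toy 8 x 6 image with 3 x 3 patches centered on each pixel.
img = np.arange(8 * 6, dtype=float).reshape(8, 6)
px, py = 3, 3
# Zero-pad so that border pixels also get a full window.
padded = np.pad(img, ((px // 2, px - 1 - px // 2), (py // 2, py - 1 - py // 2)))
# One (px, py) window per original pixel, as a view (no copy).
patches = np.lib.stride_tricks.sliding_window_view(padded, (px, py))
print(patches.shape)  # (8, 6, 3, 3)
print(patches.size == patchwise_pixels(8, 6, px, py))  # True
```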
<p>As an example, imagine we have a <code class="highlighter-rouge">512 x 512</code> image and we use <code class="highlighter-rouge">32 x 32</code> patches. The total area we need to both represent and pass through our model is <code class="highlighter-rouge">512**2 * 32**2</code> = 268,435,456 pixels, or <code class="highlighter-rouge">32**2 = 1024</code> fold more than the original image itself!</p>Semantic SegmentationData Cookbook2018-01-12T00:00:00+00:002018-01-12T00:00:00+00:00http://jkimmel.net/data_cookbook<h1 id="data-cookbook">Data Cookbook</h1>
<p>An ever-growing collection of code blocks to perform useful data manipulation and plotting functions with standard Python libraries. This is mostly for my own self-reference, but possibly useful to others.</p>
<h1 id="bash">Bash</h1>
<h2 id="sed">sed</h2>
<h3 id="delete-a-line-matching-a-pattern">Delete a line matching a pattern</h3>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">sed</span> <span class="s1">'/some_string/d'</span> <span class="nv">$FILE</span>
</code></pre></div></div>
<h1 id="python">Python</h1>
<p>These code snacks describe useful features of Python 3+ that aren’t always emphasized.</p>
<h2 id="force-only-named-arguments-to-functions">Force only named arguments to functions</h2>
<p>In the below example, arguments following the splat <code class="highlighter-rouge">*</code> must be supplied
as named arguments.</p>
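<p>This is somewhat intuitive if you’re used to Python’s splat operator for
<code class="highlighter-rouge">*args</code> or <code class="highlighter-rouge">**kwargs</code>. Here, the lonely splat marks the spot where <code class="highlighter-rouge">*args</code> would “catch” extra positional arguments, but it has no name and accepts none:
every parameter after it must be passed by name, and extra positional arguments raise a <code class="highlighter-rouge">TypeError</code>.</p>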
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def function(positional, *, named_only0, named_only1):
    # do some things
    return

def only_takes_named_args(*, named_only0, named_only1):
    # do some things
    return
</code></pre></div></div>
<p>This is useful when defining functions that may have arguments added and removed
over time, explicitly preventing code from relying on the positional order.</p>
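<p>A short demonstration with a hypothetical <code class="highlighter-rouge">resize</code> function: passing the keyword-only parameters by name works, while passing them positionally raises a <code class="highlighter-rouge">TypeError</code>.</p>

```python
def resize(image, *, width, height):
    """Hypothetical example: width and height can only be passed by name."""
    return (width, height)

print(resize("cell.png", width=64, height=32))  # (64, 32)

try:
    resize("cell.png", 64, 32)  # positional width/height are rejected
    rejected = False
except TypeError as err:
    rejected = True
    print("TypeError:", err)
```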
<h1 id="numpy">Numpy</h1>
<h2 id="preserve-array-dimensionality-when-slicing">Preserve array dimensionality when slicing</h2>
<p>When slicing a plane <code class="highlighter-rouge">i</code> from a multidimensional array <code class="highlighter-rouge">A</code>, use a length-1 slice such as <code class="highlighter-rouge">A[i:i+1,...]</code> instead of an integer index to preserve the array dimensionality, keeping a singleton dimension of size <code class="highlighter-rouge">1</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">5</span><span class="p">,</span><span class="mi">5</span><span class="p">,</span><span class="mi">5</span><span class="p">))</span>
<span class="n">i</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">A</span><span class="p">[:,</span><span class="n">i</span><span class="p">,:]</span><span class="o">.</span><span class="n">shape</span> <span class="c1"># (5,5)
</span><span class="n">A</span><span class="p">[:,</span><span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">,:]</span><span class="o">.</span><span class="n">shape</span> <span class="c1"># (5,1,5)
</span></code></pre></div></div>
<h2 id="add-an-empty-dimension-by-indexing">Add an empty dimension by indexing</h2>
<p>You can add a singleton dimension of size <code class="highlighter-rouge">1</code> to an <code class="highlighter-rouge">np.ndarray</code> by passing <code class="highlighter-rouge">None</code> (an alias of <code class="highlighter-rouge">np.newaxis</code>) at the desired position while indexing.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">A</span><span class="p">[:,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">B</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># (3, 3, 1)
</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># (3, 3, 1)
</span>
<span class="n">np</span><span class="o">.</span><span class="nb">all</span><span class="p">(</span><span class="n">B</span> <span class="o">==</span> <span class="n">C</span><span class="p">)</span> <span class="c1"># True
</span></code></pre></div></div>
<h1 id="pandas">Pandas</h1>
<h2 id="split-a-column-by-a-text-delimiter">Split a column by a text delimiter</h2>
<p>Use the <code class="highlighter-rouge">.str.split</code> method of a string column (a <code class="highlighter-rouge">pd.Series</code>).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># split a string column by a '-' delimiter
# `.str` is a Series accessor, so select one column first ('col' is a placeholder name)
# split is a pd.DataFrame, with each delimited field in its own column
</span><span class="n">split</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s">'col'</span><span class="p">]</span><span class="o">.</span><span class="nb">str</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s">'-'</span><span class="p">,</span> <span class="n">expand</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>
<h2 id="replicate-each-row-in-a-dataframe-n-times">Replicate each row in a DataFrame <em>N</em> times</h2>
<p>Use the <code class="highlighter-rouge">.values</code> attribute of a DataFrame and <code class="highlighter-rouge">np.repeat</code></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">N</span> <span class="o">=</span> <span class="mi">3</span> <span class="c1"># times to replicate
</span><span class="n">newdf</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">df</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span>
<span class="n">newdf</span><span class="o">.</span><span class="n">columns</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">columns</span>
</code></pre></div></div>
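<p>A pandas-native alternative (a different technique from the <code class="highlighter-rouge">np.repeat</code> snippet) repeats the index labels instead of the raw values, so each column keeps its own dtype. A sketch on made-up data:</p>

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'id': ['a', 'b'], 'value': [1.5, 2.0]})  # mixed dtypes
N = 3  # times to replicate

# repeat each index label N times in a row, select those rows, then renumber
newdf = df.loc[df.index.repeat(N)].reset_index(drop=True)
```

<p>Because <code class="highlighter-rouge">Index.repeat</code> repeats each label consecutively, the row order matches <code class="highlighter-rouge">np.repeat(..., axis=0)</code>.</p>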
<h2 id="sort-a-dataframe-by-multiple-columns">Sort a DataFrame by multiple columns</h2>
<p>Use the <code class="highlighter-rouge">sort_values</code> method of DataFrames.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">df</span><span class="o">.</span><span class="n">sort_values</span><span class="p">([</span><span class="s">'a'</span><span class="p">,</span> <span class="s">'b'</span><span class="p">],</span> <span class="n">ascending</span><span class="o">=</span><span class="p">[</span><span class="bp">True</span><span class="p">,</span> <span class="bp">False</span><span class="p">])</span>
</code></pre></div></div>
<p><a href="https://stackoverflow.com/questions/17141558/how-to-sort-a-dataframe-in-python-pandas-by-two-or-more-columns">Credit</a></p>
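<p>A toy example (values made up) showing how the <code class="highlighter-rouge">ascending</code> list pairs with the column list: the first column sets the primary order, and later columns only break ties within it.</p>

```python
import pandas as pd

# toy data: ties in 'a' are broken by 'b'
df = pd.DataFrame({'a': [2, 1, 2, 1], 'b': [10, 20, 30, 40]})

# primary key 'a' ascending, secondary key 'b' descending
out = df.sort_values(['a', 'b'], ascending=[True, False])
print(out['b'].tolist())  # [40, 20, 30, 10]
```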
<h2 id="check-if-rows-are-equal-to-an-array-like-vector">Check if rows are equal to an array-like vector</h2>
<p>Given an array-like vector <code class="highlighter-rouge">v</code> with the same length as each row of a DataFrame <code class="highlighter-rouge">df</code>, check which rows in <code class="highlighter-rouge">df</code> are equal to <code class="highlighter-rouge">v</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">([[</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">],[</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">],[</span><span class="mi">4</span><span class="p">,</span><span class="mi">5</span><span class="p">]],</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'A'</span><span class="p">,</span> <span class="s">'B'</span><span class="p">])</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">])</span>
<span class="p">(</span><span class="n">df</span> <span class="o">==</span> <span class="n">v</span><span class="p">)</span><span class="o">.</span><span class="nb">all</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># True for rows where every column equals v
</span></code></pre></div></div>
<p><a href="https://stackoverflow.com/questions/24761133/pandas-check-if-row-exists-with-certain-values">Credit</a></p>
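<p>The same boolean mask answers the usual follow-up questions. A sketch on toy data, with a duplicate of <code class="highlighter-rouge">[0, 1]</code> added to show more than one match:</p>

```python
import numpy as np
import pandas as pd

df = pd.DataFrame([[0, 1], [2, 3], [4, 5], [0, 1]], columns=['A', 'B'])
v = np.array([0, 1])

mask = (df == v).all(axis=1)         # boolean Series, one entry per row
exists = mask.any()                  # does any row match v?
matches = df[mask]                   # the matching rows themselves
match_idx = df.index[mask].tolist()  # their index labels
```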
<h1 id="matplotlib--seaborn">Matplotlib / Seaborn</h1>
<h2 id="create-editable-uncropped-pdf-exports">Create editable, uncropped PDF exports</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib</span>
<span class="c1"># ensure text in PDF exports is editable.
</span><span class="n">matplotlib</span><span class="o">.</span><span class="n">rcParams</span><span class="p">[</span><span class="s">'pdf.fonttype'</span><span class="p">]</span> <span class="o">=</span> <span class="mi">42</span>
<span class="n">matplotlib</span><span class="o">.</span><span class="n">rcParams</span><span class="p">[</span><span class="s">'ps.fonttype'</span><span class="p">]</span> <span class="o">=</span> <span class="mi">42</span>
<span class="c1"># fit the saved bounding box to the figure's artists, so labels
# that extend past the "figsize" are not clipped.
# NOTE: this is different from `plt.tight_layout()`
# despite the similar name.
</span><span class="n">matplotlib</span><span class="o">.</span><span class="n">rcParams</span><span class="p">[</span><span class="s">'savefig.bbox'</span><span class="p">]</span> <span class="o">=</span> <span class="s">'tight'</span>
</code></pre></div></div>
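<p>If you would rather not change global defaults, one option is to scope the font settings with <code class="highlighter-rouge">matplotlib.rc_context</code> and pass <code class="highlighter-rouge">bbox_inches='tight'</code> to each <code class="highlighter-rouge">savefig</code> call. This sketch writes to an in-memory buffer only so that it runs without touching disk; in practice you would pass a filename.</p>

```python
import io
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

buf = io.BytesIO()
# these settings apply only inside the context manager
with matplotlib.rc_context({'pdf.fonttype': 42, 'ps.fonttype': 42}):
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.set_xlabel('a long label that runs past the edge of a small figure')
    # per-call equivalent of rcParams['savefig.bbox'] = 'tight'
    fig.savefig(buf, format='pdf', bbox_inches='tight')
plt.close(fig)
```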
<h2 id="rotate-seaborn-axis-labels">Rotate Seaborn axis labels</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">g</span> <span class="o">=</span> <span class="n">sns</span><span class="o">.</span><span class="n">barplot</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
<span class="n">g</span><span class="o">.</span><span class="n">set_xticklabels</span><span class="p">(</span><span class="n">g</span><span class="o">.</span><span class="n">get_xticklabels</span><span class="p">(),</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span>
</code></pre></div></div>
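<p>Since <code class="highlighter-rouge">sns.barplot</code> returns a matplotlib <code class="highlighter-rouge">Axes</code>, the rotation can also be set with <code class="highlighter-rouge">Axes.tick_params</code>, which avoids re-passing the existing labels. This sketch uses a plain matplotlib bar chart with made-up categories as a stand-in for the seaborn call, so it runs on its own.</p>

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.bar(['alpha', 'beta', 'gamma'], [3, 1, 2])  # stand-in for sns.barplot(...)
ax.tick_params(axis='x', labelrotation=45)     # rotate every x tick label in place
fig.canvas.draw()  # materialize the tick labels before inspecting them
```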
<h2 id="plot-a-line-with-a-continuous-color-variable">Plot a line with a continuous color variable</h2>
<p>Use a <code class="highlighter-rouge">LineCollection</code> from <code class="highlighter-rouge">matplotlib.collections</code> to draw the line as a set of short segments, each colored by the continuous variable.</p>
<p><a href="https://stackoverflow.com/questions/10252412/matplotlib-varying-color-of-line-to-capture-natural-time-parameterization-in-da/10253183#10253183">StackOverflow Credit</a></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">matplotlib.collections</span> <span class="kn">import</span> <span class="n">LineCollection</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="mi">100</span><span class="p">))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="mi">100</span><span class="p">))</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="c1"># your "time" variable
</span>
<span class="c1"># set up a list of (x,y) points
</span><span class="n">points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="p">,</span><span class="n">y</span><span class="p">])</span><span class="o">.</span><span class="n">transpose</span><span class="p">()</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">points</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># Out: (len(x), 1, 2)
</span>
<span class="c1"># set up a list of segments
</span><span class="n">segs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">points</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span><span class="n">points</span><span class="p">[</span><span class="mi">1</span><span class="p">:]],</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">segs</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># Out: (len(x)-1, 2, 2)
</span> <span class="c1"># see what we've done here -- we've mapped our (x,y)
</span> <span class="c1"># points to an array of segment start/end coordinates.
</span> <span class="c1"># segs[i,0,:] == segs[i-1,1,:]
</span>
<span class="c1"># make the collection of segments
</span><span class="n">lc</span> <span class="o">=</span> <span class="n">LineCollection</span><span class="p">(</span><span class="n">segs</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">get_cmap</span><span class="p">(</span><span class="s">'viridis'</span><span class="p">))</span>
<span class="n">lc</span><span class="o">.</span><span class="n">set_array</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="c1"># color the segments by our parameter
</span>
<span class="c1"># plot the collection
</span><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ax</span><span class="o">.</span><span class="n">add_collection</span><span class="p">(</span><span class="n">lc</span><span class="p">)</span> <span class="c1"># add the collection to the plot
</span><span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="nb">min</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="nb">max</span><span class="p">())</span> <span class="c1"># line collections don't auto-scale the plot
</span><span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="nb">min</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="nb">max</span><span class="p">())</span>
</code></pre></div></div>
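<p>The same steps can be wrapped in a reusable helper. This is a sketch, not the StackOverflow answer's code: the hypothetical <code class="highlighter-rouge">colored_line</code> function gives each segment the mean of its two endpoint values (so the color array has exactly one entry per segment), uses <code class="highlighter-rouge">ax.autoscale_view()</code> instead of setting limits by hand, and adds a colorbar for the color variable.</p>

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def colored_line(ax, x, y, c, cmap='viridis'):
    """Draw y vs. x as short segments colored by c; return the LineCollection."""
    points = np.column_stack([x, y]).reshape(-1, 1, 2)        # (N, 1, 2)
    segs = np.concatenate([points[:-1], points[1:]], axis=1)  # (N - 1, 2, 2)
    c = np.asarray(c, dtype=float)
    lc = LineCollection(segs, cmap=cmap)
    lc.set_array((c[:-1] + c[1:]) / 2)  # one value per segment: its midpoint
    ax.add_collection(lc)  # updates the data limits...
    ax.autoscale_view()    # ...and rescales the view to them
    return lc


t = np.linspace(0, 1, 100)
x, y = np.sin(2 * np.pi * t), np.cos(2 * np.pi * t)
fig, ax = plt.subplots()
lc = colored_line(ax, x, y, t)
fig.colorbar(lc, ax=ax, label='t')  # a LineCollection is a ScalarMappable
```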
<h2 id="add-a-label-to-heatmap-colorbars-in-seaborn">Add a label to heatmap colorbars in <code class="highlighter-rouge">seaborn</code></h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">seaborn</span><span class="o">.</span><span class="n">heatmap</span><span class="p">(</span><span class="n">data</span><span class="p">,</span>
<span class="n">cbar_kws</span><span class="o">=</span><span class="p">{</span><span class="s">'label'</span><span class="p">:</span> <span class="s">'colorbar title'</span><span class="p">})</span>
</code></pre></div></div>
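<p>Seaborn forwards <code class="highlighter-rouge">cbar_kws</code> to matplotlib's <code class="highlighter-rouge">Figure.colorbar</code>, so <code class="highlighter-rouge">'label'</code> becomes the colorbar's axis label. A plain-matplotlib sketch of the same idea, with random data and <code class="highlighter-rouge">imshow</code> standing in for the heatmap, which also shows relabeling an existing colorbar:</p>

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs headless
import numpy as np
import matplotlib.pyplot as plt

data = np.random.default_rng(0).random((5, 5))  # placeholder matrix
fig, ax = plt.subplots()
mesh = ax.imshow(data)
# the same keyword that seaborn passes through from cbar_kws
cbar = fig.colorbar(mesh, ax=ax, label='colorbar title')
# an existing colorbar can also be relabeled after the fact
cbar.set_label('a different title')
```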
<h2 id="remove-space-between-subplots">Remove space between subplots</h2>
<p>This is useful when plotting a grid of images.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">)</span>
<span class="n">fig</span><span class="o">.</span><span class="n">subplots_adjust</span><span class="p">(</span><span class="n">hspace</span><span class="o">=</span><span class="mf">0.020</span><span class="p">,</span>
<span class="n">wspace</span><span class="o">=</span><span class="mf">0.00005</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">H</span> <span class="o">*</span> <span class="n">W</span><span class="p">):</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">ax</span><span class="p">[</span><span class="n">i</span> <span class="o">//</span> <span class="n">W</span><span class="p">,</span> <span class="n">i</span> <span class="o">%</span> <span class="n">W</span><span class="p">]</span> <span class="c1"># row-major position in the grid</span>
    <span class="n">a</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">I</span><span class="p">)</span> <span class="c1"># `I` is the image to display</span>
    <span class="n">a</span><span class="o">.</span><span class="n">set_xticks</span><span class="p">([])</span> <span class="c1"># hide ticks on each Axes, not on the array of Axes</span>
    <span class="n">a</span><span class="o">.</span><span class="n">set_yticks</span><span class="p">([])</span>
</code></pre></div></div>
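<p>To close the gaps completely, one option is to set both spacings to zero and turn each Axes off, since <code class="highlighter-rouge">axis('off')</code> hides ticks, tick labels, and spines together. In this sketch the images are random placeholders, and the figure size is proportional to the grid shape so that square images tile without leftover whitespace.</p>

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs headless
import numpy as np
import matplotlib.pyplot as plt

H, W = 2, 3
images = np.random.default_rng(0).random((H * W, 16, 16))  # placeholder images

# figsize is (width, height), so (W, H) gives each square image a square cell
fig, axes = plt.subplots(H, W, figsize=(W, H))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
for a, img in zip(axes.flat, images):
    a.imshow(img)
    a.axis('off')  # hide ticks, tick labels, and spines
```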
<h2 id="remove-axis-spines-from-a-matplotlib-plot">Remove axis spines from a <code class="highlighter-rouge">matplotlib</code> plot</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">))</span> <span class="c1"># figsize is (width, height) in inches</span>
<span class="n">ax</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">spines</span><span class="p">[</span><span class="s">'right'</span><span class="p">]</span><span class="o">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
<span class="c1"># `.spines` keys are {'left', 'right', 'top', 'bottom'}
</span></code></pre></div></div>
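<p>To remove spines from every panel, either loop over <code class="highlighter-rouge">ax.flat</code> or change the <code class="highlighter-rouge">axes.spines.*</code> rcParams before the figure is created. A sketch of both options; the grid size is arbitrary.</p>

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

# option 1: hide the top and right spines on every Axes in a grid
fig, ax = plt.subplots(2, 2)
for a in ax.flat:
    for side in ('top', 'right'):
        a.spines[side].set_visible(False)

# option 2: change the default for Axes created inside this block
with plt.rc_context({'axes.spines.top': False, 'axes.spines.right': False}):
    fig2, ax2 = plt.subplots()
```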
<h2 id="animate-an-image">Animate an image</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">animation</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span><span class="mi">10</span><span class="p">))</span>
<span class="c1"># remove white frame around image
</span><span class="n">fig</span><span class="o">.</span><span class="n">subplots_adjust</span><span class="p">(</span><span class="n">left</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">bottom</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">top</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">wspace</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">hspace</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">im</span> <span class="o">=</span> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">animated</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">updatefig</span><span class="p">(</span><span class="n">idx</span><span class="p">):</span>
    <span class="n">im</span><span class="o">.</span><span class="n">set_array</span><span class="p">(</span><span class="n">new_data_iterable</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">im</span><span class="p">,)</span> <span class="c1"># blit=True expects an iterable of updated artists</span>
<span class="n">anim</span> <span class="o">=</span> <span class="n">animation</span><span class="o">.</span><span class="n">FuncAnimation</span><span class="p">(</span>
<span class="n">fig</span><span class="p">,</span> <span class="c1"># figure with initialized artists
</span> <span class="n">updatefig</span><span class="p">,</span> <span class="c1"># updating function
</span> <span class="n">frames</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="c1"># number of iterations, passes `range(0, frames)` to `updatefig`
</span> <span class="n">interval</span><span class="o">=</span><span class="mf">1e3</span><span class="o">/</span><span class="mi">30</span><span class="p">,</span> <span class="c1"># ms between frames, i.e. 1e3/FPS for a FPS argument
</span> <span class="n">blit</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="c1"># drawing optimization
</span>
<span class="c1"># in a Jupyter notebook, IPython's `HTML` object can display the animation inline
# (`to_html5_video` encodes with ffmpeg, which must be installed)
</span><span class="kn">from</span> <span class="nn">IPython.display</span> <span class="kn">import</span> <span class="n">HTML</span>
<span class="n">HTML</span><span class="p">(</span><span class="n">anim</span><span class="o">.</span><span class="n">to_html5_video</span><span class="p">())</span>
</code></pre></div></div>
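<p>A self-contained version of the same pattern, with random frames standing in for the placeholders <code class="highlighter-rouge">data</code> and <code class="highlighter-rouge">new_data_iterable</code> (an assumption for illustration). Because <code class="highlighter-rouge">blit=True</code>, the update function returns a tuple of the artists it changed.</p>

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs headless
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

rng = np.random.default_rng(0)
frames = rng.random((100, 32, 32))  # placeholder stack of 100 images

fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.subplots_adjust(left=0, bottom=0, right=1, top=1)  # no white frame
ax.set_axis_off()
im = ax.imshow(frames[0], animated=True)


def updatefig(idx):
    im.set_array(frames[idx])
    return (im,)  # blitting needs an iterable of the artists that changed


anim = animation.FuncAnimation(fig, updatefig, frames=len(frames),
                               interval=1e3 / 30, blit=True)
# anim.save('movie.gif', writer='pillow', fps=30)  # write to disk
```

<p>Saving with <code class="highlighter-rouge">writer='pillow'</code> requires the Pillow package; other writers, such as ffmpeg, need the corresponding external tool.</p>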
<h2 id="add-a-rowcolumn-color-legend-to-seaborn-clustermap">Add a row/column color legend to seaborn clustermap</h2>
<p><a href="http://dawnmy.github.io/2016/10/24/Plot-heatmaap-with-side-color-indicating-the-class-of-variables/">Credit</a></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># define some row clusters
</span><span class="n">row_clusters</span> <span class="o">=</span> <span class="n">get_row_clusters</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="c1"># np.ndarray, np.int
</span>
<span class="c1"># set up a LUT to assign a color to each label in `row_clusters`
# (labels are assumed to be integers 0-19, one per 'tab20' color).
# `sns.color_palette` returns a list, so index it one element at a time.
</span><span class="n">pal</span> <span class="o">=</span> <span class="n">sns</span><span class="o">.</span><span class="n">color_palette</span><span class="p">(</span><span class="s">'tab20'</span><span class="p">)</span>
<span class="c1"># make a clustermap
</span><span class="n">clmap</span> <span class="o">=</span> <span class="n">sns</span><span class="o">.</span><span class="n">clustermap</span><span class="p">(</span>
    <span class="o">...</span><span class="p">,</span>
    <span class="n">row_colors</span><span class="o">=</span><span class="p">[</span><span class="n">pal</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">row_clusters</span><span class="p">],</span>
<span class="p">)</span>
<span class="c1"># draw an invisible zero-size bar per cluster so the legend gets one entry per color
</span><span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">row_clusters</span><span class="p">):</span>
    <span class="n">clmap</span><span class="o">.</span><span class="n">ax_col_dendrogram</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
                                <span class="mi">0</span><span class="p">,</span>
                                <span class="n">color</span><span class="o">=</span><span class="n">pal</span><span class="p">[</span><span class="n">label</span><span class="p">],</span>
                                <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">,</span>
                                <span class="n">linewidth</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">clmap</span><span class="o">.</span><span class="n">ax_col_dendrogram</span><span class="o">.</span><span class="n">legend</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="s">"center"</span><span class="p">,</span> <span class="n">ncol</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">frameon</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
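<p>An alternative to drawing zero-size bars is to build the legend entries directly from <code class="highlighter-rouge">matplotlib.patches.Patch</code> proxies. This sketch uses matplotlib's <code class="highlighter-rouge">tab20</code> colormap in place of <code class="highlighter-rouge">sns.color_palette('tab20')</code> and made-up integer cluster labels; with a real clustermap, the handles would go to <code class="highlighter-rouge">clmap.ax_col_dendrogram.legend(handles=...)</code>.</p>

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs headless
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

row_clusters = np.array([0, 2, 1, 2, 0])  # hypothetical integer label per row
pal = plt.get_cmap('tab20').colors        # sequence of 20 RGB colors

# one color per row, suitable for the `row_colors` argument of sns.clustermap
row_colors = [pal[c] for c in row_clusters]

# one proxy patch per cluster; these carry the legend entries
handles = [Patch(facecolor=pal[c], label=str(c)) for c in np.unique(row_clusters)]

fig, ax = plt.subplots()
ax.legend(handles=handles, loc='center', ncol=5, frameon=False)
```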
<h2 id="add-a-second-set-of-xticklabels-to-a-seaborn-heatmap">Add a second set of xticklabels to a seaborn heatmap</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">...</span><span class="p">)</span>
<span class="n">sns</span><span class="o">.</span><span class="n">heatmap</span><span class="p">(</span>
<span class="o">...</span><span class="p">,</span>
<span class="n">ax</span><span class="o">=</span><span class="n">ax</span>
<span class="p">)</span>
<span class="c1"># clone the x-axis
</span><span class="n">ax2</span> <span class="o">=</span> <span class="n">ax</span><span class="o">.</span><span class="n">twiny</span><span class="p">()</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="n">ax</span><span class="o">.</span><span class="n">get_xlim</span><span class="p">())</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">ax</span><span class="o">.</span><span class="n">get_xticks</span><span class="p">())</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">set_xticklabels</span><span class="p">(</span><span class="n">SOME_NAMES_HERE</span><span class="p">)</span>
<span class="c1"># clean up the plotting aesthetics introduced by the second axis
</span><span class="n">ax2</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span> <span class="c1"># turn the grid off explicitly; `grid(b=None)` toggles it instead</span>
<span class="k">for</span> <span class="n">side</span> <span class="ow">in</span> <span class="p">[</span><span class="s">'top'</span><span class="p">,</span> <span class="s">'bottom'</span><span class="p">,</span> <span class="s">'right'</span><span class="p">,</span> <span class="s">'left'</span><span class="p">]:</span>
    <span class="n">ax</span><span class="o">.</span><span class="n">spines</span><span class="p">[</span><span class="n">side</span><span class="p">]</span><span class="o">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
    <span class="n">ax2</span><span class="o">.</span><span class="n">spines</span><span class="p">[</span><span class="n">side</span><span class="p">]</span><span class="o">.</span><span class="n">set_visible</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
<h1 id="latex">LaTeX</h1>
<p>I love LaTeX.
LaTeX does not love me back.
Here are some snippets to make our relationship more functional.</p>
<h2 id="use-ifthen-control-flow-in-a-latex-build">Use if/then control flow in a LaTeX build</h2>
<div class="language-latex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">\usepackage</span><span class="p">{</span>etoolbox<span class="p">}</span>
<span class="c">% provides \newtoggle, \settoggle, \toggletrue, \togglefalse, \iftoggle</span>
<span class="k">\newtoggle</span><span class="p">{</span>somevar<span class="p">}</span> <span class="c">% declare a new boolean toggle (initially false)</span>
<span class="k">\toggletrue</span><span class="p">{</span>somevar<span class="p">}</span> <span class="c">% set it to true</span>
<span class="k">\togglefalse</span><span class="p">{</span>somevar<span class="p">}</span> <span class="c">% or set it to false</span>
<span class="c">% run an if then</span>
<span class="k">\iftoggle</span><span class="p">{</span>somevar<span class="p">}{</span>
<span class="c">% do thing</span>
<span class="p">}{</span>
<span class="c">% else, do other thing or blank for nothing</span>
<span class="p">}</span>
</code></pre></div></div>
<h2 id="generate-a-custom-bibtex-style">Generate a custom BibTeX style</h2>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># run the interactive custom-bib menu, which outputs</span>
<span class="c"># some_name.dbj - instructions for making a `bst`</span>
<span class="c"># some_name.bst - compiled `bst`</span>
latex makebst
<span class="c"># to remake a `bst` from the `dbj`</span>
tex some_name.dbj <span class="c"># outputs some_name.bst</span>
</code></pre></div></div>
<h2 id="remove-numbers-or-citation-labels-from-reference-list">Remove numbers or citation labels from reference list</h2>
<p><a href="https://tex.stackexchange.com/questions/35369/replace-or-remove-bibliography-numbers">SE Credit</a></p>
<div class="language-latex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">\makeatletter</span>
<span class="k">\renewcommand\@</span>biblabel[1]<span class="p">{}</span>
<span class="k">\makeatother</span>
<span class="c">% we can also replace numbers with a common character, like a bullet</span>
<span class="k">\makeatletter</span>
<span class="k">\renewcommand\@</span>biblabel[1]<span class="p">{</span><span class="k">\textbullet</span><span class="p">}</span>
<span class="k">\makeatother</span>
</code></pre></div></div>
<h2 id="customize-figure-captions">Customize figure captions</h2>
<div class="language-latex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">\usepackage</span><span class="p">{</span>caption<span class="p">}</span>
<span class="c">% print the figure label and number (e.g. "Figure 1"), but suppress</span>
<span class="c">% both the separator and the caption text.</span>
<span class="c">% useful for separating figures and captions in journal proofs</span>
<span class="k">\captionsetup</span><span class="p">{</span>labelsep=none,textformat=empty<span class="p">}</span>
<span class="c">% use normal caption text, colon figure separator</span>
<span class="c">% e.g. "Figure 1: Caption text here"</span>
<span class="k">\captionsetup</span><span class="p">{</span>labelsep=colon,textformat=plain<span class="p">}</span>
</code></pre></div></div>
<h2 id="suppress-graphics">Suppress graphics</h2>
<p>Journals often want captions and figures separated in a final proof.
We can keep the captions but drop the graphics by redefining the <code class="highlighter-rouge">includegraphics</code> command to accept its optional and required arguments and output nothing.</p>
<div class="language-latex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">\renewcommand</span><span class="p">{</span><span class="k">\includegraphics</span><span class="p">}</span>[2][]<span class="p">{}</span>
</code></pre></div></div>Data CookbookConvolutional Gated Recurrent Units in PyTorch2017-11-22T00:00:00+00:002017-11-22T00:00:00+00:00http://jkimmel.net/pytorch_conv_gru<p>Deep neural networks can be incredibly powerful models, but the vanilla variety suffers from a fundamental limitation. DNNs are built in a purely feedforward fashion, with each layer feeding directly into the next. Once a forward pass is made, vanilla DNNs don’t retain any “memory” of the inputs they’ve seen before, outside the parameters of the model itself. In many circumstances, this is totally fine! The classic supervised image classification task is a good example. A model doesn’t need to “remember” anything about the inputs it saw previously, outside the parameters of the model, in order to demonstrate super-human performance. There is no temporal relationship between the examples shown to a simple supervised classification network competing on the traditional ImageNet or CIFAR datasets.</p>
<p>But what about situations where temporal relationships do exist, where “remembering” the last input you’ve seen is beneficial to understanding the current one? To allow neural networks to “remember,” <a href="https://en.wikipedia.org/wiki/Recurrent_neural_network?oldformat=true">recurrent units</a> have been developed that allow the network to store memory of previous inputs in a “hidden state” $h$. Recurrent units in the most general sense were demonstrated <a href="http://www.pnas.org/content/79/8/2554.abstract">as early as 1982</a>. However, the earliest recurrent neural networks [RNNs] were difficult to train on data with long-term temporal relationships due to the problem of vanishing gradients, <a href="http://ieeexplore.ieee.org/document/279181/">as explored by Bengio <em>et al.</em></a> <a href="https://en.wikipedia.org/wiki/Long_short-term_memory?oldformat=true">Long short-term memory units [LSTMs]</a> and their somewhat simpler relative <a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit?oldformat=true">gated recurrent units [GRUs]</a> have arisen as the recurrent units of choice to solve these issues, and allow standard training by backpropagation. Chris Olah has <a href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">an incredibly lucid explanation of how both of these units work</a>.</p>
<p>Both LSTMs and GRUs were originally conceived as fully connected layers, implementing transformations of the general form</p>
<script type="math/tex; mode=display">g = \sigma (W x_t + U h_{t-1} + b)</script>
<p>where $g$ is the output of a “gate” within the recurrent unit, $\sigma$ is the sigmoid function, $W$ and $U$ are parameterized weight matrices, $x_t$ is the input at time $t$, $h_{t-1}$ is the hidden state from the previous time point $t-1$, and $b$ is a bias.</p>
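<p>As a concrete, purely illustrative reading of this equation, here is a NumPy sketch of one gate computed on a flattened input. The sizes and random weights are arbitrary.</p>

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


rng = np.random.default_rng(0)
n_in, n_hidden = 16, 8                     # arbitrary sizes
x_t = rng.standard_normal(n_in)            # input at time t, e.g. a flattened 4x4 image
h_prev = np.zeros(n_hidden)                # hidden state h_{t-1}
W = rng.standard_normal((n_hidden, n_in))
U = rng.standard_normal((n_hidden, n_hidden))
b = np.zeros(n_hidden)

g = sigmoid(W @ x_t + U @ h_prev + b)      # one gate activation per hidden unit
```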
<p>In this form, any spatial relationships in the input $x_t$ are discarded: $x_t$ must be flattened into a vector before the matrix multiplication $W x_t$, which treats the pixels as an unordered list with no built-in notion of which pixels are neighbors. For image-based inputs, it is likely advantageous to preserve this information.</p>
<p><strong>Enter: convolutional gated recurrent units.</strong></p>
<p><a href="https://arxiv.org/abs/1511.06432">Ballas <em>et al.</em></a> have recently explored a convolutional form of the traditional gated recurrent unit to learn temporal relationships between images of a video. Their formulation of the convolutional GRU simply takes the standard linear GRU</p>
<p><script type="math/tex">z_t = \sigma_g(W_z x_t + U_z h_{t-1} + b_z)</script><br />
<script type="math/tex">r_t = \sigma_g(W_r x_t + U_r h_{t-1} + b_r)</script><br />
<script type="math/tex">h_t = z_t \circ h_{t-1} + (1 - z_t) \circ \sigma_h(W_h x_t + U_h(r_t \circ h_{t-1}) + b_h)</script></p>
<p>and replaces the matrix multiplications with convolutions</p>
<p><script type="math/tex">z_t = \sigma_g(W_z \star x_t + U_z \star h_{t-1} + b_z)</script><br />
<script type="math/tex">r_t = \sigma_g(W_r \star x_t + U_r \star h_{t-1} + b_r)</script><br />
<script type="math/tex">h_t = z_t \circ h_{t-1} + (1 - z_t) \circ \sigma_h(W_h \star x_t + U_h \star (r_t \circ h_{t-1}) + b_h)</script></p>
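<p>where $z_t$ is an update gate at time $t$, $r_t$ is a reset gate at time $t$, $h_t$ is the updated hidden state at time $t$, $\star$ denotes convolution, $\circ$ denotes elementwise multiplication, $\sigma_g$ is the sigmoid function, and $\sigma_h$ is the hyperbolic tangent.</p>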
<p>With this simple restatement, our GRU now preserves spatial information!</p>
<p>I was interested in using these units for some recent experiments, so I reimplemented them in <a href="https://pytorch.org">PyTorch</a>, borrowing heavily from <a href="https://gist.github.com/halochou/acbd669af86ecb8f988325084ba7a749">@halochou’s gist</a> and <a href="http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#RNN">the PyTorch RNN source.</a></p>
<p>My implementation is <a href="https://github.com/jacobkimmel/pytorch_convgru">available on GitHub as <code class="highlighter-rouge">pytorch_convgru</code></a>. It supports multi-cell layers with different hidden state depths and kernel sizes. The spatial dimensions of the input are preserved by zero padding within the module, so if you want to change the spatial dimensions, you can place an op that implements your desired transformation (such as a <code class="highlighter-rouge">.view()</code>) between two separate <code class="highlighter-rouge">ConvGRU</code> modules.</p>
<p>As an example, here we build a 3-cell ConvGRU with different hidden state depths and kernel sizes.</p>
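<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
<span class="kn">from</span> <span class="nn">convgru</span> <span class="kn">import</span> <span class="n">ConvGRU</span>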
<span class="c1"># Generate a ConvGRU with 3 cells
# input_size and hidden_sizes reflect feature map depths.
# Height and Width are preserved by zero padding within the module.
</span><span class="n">model</span> <span class="o">=</span> <span class="n">ConvGRU</span><span class="p">(</span><span class="n">input_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">hidden_sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">32</span><span class="p">,</span><span class="mi">64</span><span class="p">,</span><span class="mi">16</span><span class="p">],</span>
<span class="n">kernel_sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">n_layers</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">Variable</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">))</span>  <span class="c1"># (Batch, Channels, Height, Width)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># output is a list of hidden state tensors, one per cell
</span><span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">output</span><span class="p">))</span> <span class="c1"># list
</span>
<span class="c1"># final output size
</span><span class="k">print</span><span class="p">(</span><span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">size</span><span class="p">())</span> <span class="c1"># torch.Size([1, 16, 64, 64])
</span></code></pre></div></div>Generating Model Summaries in PyTorch2017-11-21T00:00:00+00:002017-11-21T00:00:00+00:00http://jkimmel.net/pytorch_generating_model_summaries<p>Some of the most common bugs I encounter when building deep neural network models are dimensionality mismatches, or simple implementation errors that lead to a model architecture different from the one I intended. Judging by the number of forum posts related to dimensionality errors, I guess I’m not the only one. While these bugs are often trivial to fix once found, the cryptic error messages produced when CUDA devices run out of memory (e.g. if you unintentionally multiply two huge matrices) aren’t always helpful in tracking them down.</p>
<p>To help with this, the <a href="https://keras.io"><code class="highlighter-rouge">keras</code></a> high-level neural network framework provides a <code class="highlighter-rouge">model.summary()</code> method that lists all the layers in the network and the dimensions of their output tensors. This sort of summary lets a user quickly glance through the structure of their model and identify where dimensionality mismatches may be occurring.</p>
<p>I’ve taken up <a href="https://pytorch.org"><code class="highlighter-rouge">pytorch</code></a> as my DNN lingua franca, but this is one feature I missed from “define-and-run” frameworks like <code class="highlighter-rouge">keras</code>. Since <code class="highlighter-rouge">pytorch</code> builds its computational graphs dynamically, the input and output dimensions of a given layer aren’t predefined the way they are in define-and-run frameworks. To get at this information and provide a tool similar to <code class="highlighter-rouge">model.summary()</code> in <code class="highlighter-rouge">keras</code>, we actually need to pass a sample input through each layer and record its output size on the other side!</p>
<p>This isn’t the most elegant way of doing things. I briefly considered implementing a method that identified the common layer types in a <code class="highlighter-rouge">pytorch</code> model, then computed the output dimensions from known properties of each layer. I decided against this approach, though, since it would require defining the effect of each layer on dimensionality <em>a priori</em>, so any custom layers, or layers added to <code class="highlighter-rouge">pytorch</code> in the future, would break the summary method for the whole model.</p>
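<p>For contrast, here is what the <em>a priori</em> approach looks like for a single layer type. This small pure Python sketch is not part of <code class="highlighter-rouge">pytorch_modelsummary</code>; it uses the standard output-size formula for a 2D convolution, <code class="highlighter-rouge">out = (in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1</code>, plus its parameter count, and assumes square kernels with the same padding on both spatial dimensions. It reproduces the sizes and parameter counts in the example summary below, but every new layer type would need its own rule like this.</p>

```python
def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2D convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv2d_param_count(in_channels, out_channels, kernel_size, bias=True):
    """Weights (out x in x k x k) plus one bias per output channel."""
    weights = out_channels * in_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


# The two layers of the example model, with a 256 x 256 input
h0 = conv2d_output_size(256, kernel_size=3, padding=5)  # conv0
h1 = conv2d_output_size(h0, kernel_size=3)              # conv1
print(h0, h1)  # 264 262
print(conv2d_param_count(1, 16, 3), conv2d_param_count(16, 32, 3))  # 160 4640
```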
<p>Instead, I implemented the inelegant solution described above of passing a sample input through the model and watching its dimensionality change. The simple tool is available as <a href="https://github.com/jacobkimmel/pytorch_modelsummary"><code class="highlighter-rouge">pytorch_modelsummary</code></a>. As with <a href="http://jkimmel.net/pytorch_estimating_model_size">the model size estimation tool</a> I described last week, the <code class="highlighter-rouge">pytorch_modelsummary</code> tool takes advantage of <code class="highlighter-rouge">pytorch</code>’s <code class="highlighter-rouge">volatile</code> Variables to minimize the memory expense of this forward pass. Model summaries are provided as a <code class="highlighter-rouge">pandas.DataFrame</code>, both for downstream analysis, and because <code class="highlighter-rouge">pandas</code> gives us pretty-printing “for free” :).</p>
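<p>The <code class="highlighter-rouge">volatile</code> flag belongs to pre-0.4 PyTorch; in PyTorch 0.4 it was removed in favor of the <code class="highlighter-rouge">torch.no_grad()</code> context manager. As a minimal sketch of the same idea in newer PyTorch (not the <code class="highlighter-rouge">pytorch_modelsummary</code> implementation itself), one can register a forward hook on every leaf module, run one sample input through the whole model under <code class="highlighter-rouge">torch.no_grad()</code>, and record each layer’s input and output shapes. The function name <code class="highlighter-rouge">summarize</code> is hypothetical, and the sketch assumes each layer takes and returns a single tensor.</p>

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def summarize(model, input_size):
    """Return one dict per leaf layer: name, type, input/output shapes, #params."""
    rows = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Called after module.forward runs; inputs is a tuple of positional args
            rows.append({
                "Name": name,
                "Type": type(module).__name__,
                "InSz": list(inputs[0].shape),
                "OutSz": list(output.shape),
                "Params": sum(p.numel() for p in module.parameters()),
            })
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        # no_grad: autograd records nothing, so no activations are kept for backward
        with torch.no_grad():
            model(torch.zeros(*input_size))
    finally:
        for handle in handles:
            handle.remove()  # leave the model unchanged
    return rows


# Same two layers as the example model, named to match its summary
model = nn.Sequential(OrderedDict([
    ("conv0", nn.Conv2d(1, 16, kernel_size=3, padding=5)),
    ("conv1", nn.Conv2d(16, 32, kernel_size=3)),
]))
for row in summarize(model, (1, 1, 256, 256)):
    print(row)
```

<p>Because the hooks fire during a normal forward pass of the whole model, this also works when <code class="highlighter-rouge">forward</code> does not simply chain submodules in registration order. Passing the returned list to <code class="highlighter-rouge">pandas.DataFrame</code> gives a pretty-printed table similar to the tool’s output.</p>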
<p>An example of using the model summary is provided below:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Define a model
</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="c1"># Define a simple model to summarize
</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Model</span><span class="p">,</span><span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv0</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv0</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">h</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span>
<span class="c1"># Summarize Model
</span><span class="kn">from</span> <span class="nn">pytorch_modelsummary</span> <span class="kn">import</span> <span class="n">ModelSummary</span>
<span class="n">ms</span> <span class="o">=</span> <span class="n">ModelSummary</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">input_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">))</span>
<span class="c1"># Prints
# ------
# Name Type InSz OutSz Params
# 0 conv0 Conv2d [1, 1, 256, 256] [1, 16, 264, 264] 160
# 1 conv1 Conv2d [1, 16, 264, 264] [1, 32, 262, 262] 4640
</span>
<span class="c1"># ms.summary is a Pandas DataFrame
</span><span class="k">print</span><span class="p">(</span><span class="n">ms</span><span class="o">.</span><span class="n">summary</span><span class="p">[</span><span class="s">'Params'</span><span class="p">])</span>
<span class="c1"># 0 160
# 1 4640
# Name: Params, dtype: int64
</span></code></pre></div></div>PyTorch Model Size Estimation2017-11-17T00:00:00+00:002017-11-17T00:00:00+00:00http://jkimmel.net/pytorch_estimating_model_size<p>When you’re building deep neural network models, running out of GPU memory is one of the most common issues you run into.</p>
<p>Adding capacity to your model by increasing the number of parameters can improve performance (or lead to overfitting!), but it also increases the model’s memory requirements. Likewise, increasing the minibatch size during typical gradient descent training improves the gradient estimates and leads to more predictable training results, but it also increases memory use, since intermediate values must be stored for every example in the batch.</p>
<p>I imagine that some years in the future, GPU memory will become so plentiful that this isn’t as common a constraint. However, in the big bright world of today, most of us are still stuck worrying about whether or not our models fit within the capacity of a typical consumer GPU.</p>
<p>I’ve really been loving <a href="https://pytorch.org">PyTorch</a> for deep neural network development recently. Unfortunately, estimating the size of a model in memory using PyTorch’s native tooling isn’t as easy as in some other frameworks.</p>
<p>To solve that, I built a simple tool – <a href="https://github.com/jacobkimmel/pytorch_modelsize"><code class="highlighter-rouge">pytorch_modelsize</code></a>.</p>
<p>Let’s walk through the logic of how we go about estimating the size of a model.</p>
<p>First, we’ll define a model to play with.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Define a model
</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Model</span><span class="p">,</span><span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv0</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv0</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">h</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span>
</code></pre></div></div>
<p>There are three main components that need to be stored in GPU memory during model training.</p>
<ol>
<li>Model parameters: the actual weights in your network</li>
<li>Input: the input itself has to be in there too!</li>
<li>Intermediate variables: the values passed between layers, along with their gradients</li>
</ol>
<p>How do we calculate how big our network will be, in human-readable megabytes, considering these three components?</p>
<p>Let’s walk through it step-by-step for an input with a batch size of <code class="highlighter-rouge">1</code>, image dimensions <code class="highlighter-rouge">32 x 32</code>, and <code class="highlighter-rouge">1</code> channel. By PyTorch convention, we format the data as <code class="highlighter-rouge">(Batch, Channels, Height, Width)</code>: <code class="highlighter-rouge">(1, 1, 32, 32)</code>.</p>
<p>Calculating the input size in bits is simple: it is the product of the dimension sizes, multiplied by the bit depth of the data. Most deep neural network models use single precision floating point numbers with a bit depth of <code class="highlighter-rouge">32</code>. Sometimes, calculations are done with half precision floats at only a <code class="highlighter-rouge">16</code> bit depth.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">bits</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">input_size</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span>
<span class="n">input_bits</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">input_size</span><span class="p">)</span><span class="o">*</span><span class="n">bits</span>
<span class="k">print</span><span class="p">(</span><span class="n">input_bits</span><span class="p">)</span> <span class="c1"># 32768
</span></code></pre></div></div>
<p>Calculating the size of the parameters is similarly simple. Here, we use the <code class="highlighter-rouge">.modules()</code> method of <code class="highlighter-rouge">torch.nn.Module</code>, skipping the first entry it yields, which is the model itself (including it would count every parameter twice).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mods</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">())</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[]</span>
<span class="c1"># mods[0] is the model itself, so start at 1 to avoid double counting
</span><span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">mods</span><span class="p">)):</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">mods</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    <span class="n">p</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
    <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)):</span>
        <span class="n">sizes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">p</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">size</span><span class="p">()))</span>
<span class="n">total_bits</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">)):</span>
    <span class="n">s</span> <span class="o">=</span> <span class="n">sizes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    <span class="c1"># use a separate name so `bits` (the bit depth, 32) is not overwritten
</span>    <span class="n">param_bits</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">s</span><span class="p">))</span><span class="o">*</span><span class="n">bits</span>
    <span class="n">total_bits</span> <span class="o">+=</span> <span class="n">param_bits</span>
<span class="k">print</span><span class="p">(</span><span class="n">total_bits</span><span class="p">)</span> <span class="c1"># 153600
</span></code></pre></div></div>
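<p>As a cross-check, PyTorch tensors report their own element count and per-element size in bytes through <code class="highlighter-rouge">Tensor.numel()</code> and <code class="highlighter-rouge">Tensor.element_size()</code>, so the parameter total can be computed in one pass over <code class="highlighter-rouge">model.parameters()</code>. This sketch redefines the same two-layer model so it runs on its own; the helper name <code class="highlighter-rouge">parameter_bits</code> is hypothetical. Using <code class="highlighter-rouge">element_size()</code> picks up the parameters’ actual dtype rather than assuming 32 bits.</p>

```python
import torch.nn as nn


class Model(nn.Module):
    """Same two-layer model as above."""

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 16, kernel_size=3, padding=5)
        self.conv1 = nn.Conv2d(16, 32, kernel_size=3)

    def forward(self, x):
        return self.conv1(self.conv0(x))


def parameter_bits(model):
    """Total bits used by all parameters, using each tensor's real dtype."""
    return sum(p.numel() * p.element_size() * 8 for p in model.parameters())


model = Model()
print(parameter_bits(model))  # 153600 (4800 float32 parameters)
print(parameter_bits(model.half()))  # 76800 after casting to float16
```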
<p>Calculating the size of intermediate variables in PyTorch is a bit trickier. Since PyTorch uses dynamic computational graphs, the output size of each layer in a network isn’t defined <em>a priori</em> like it is in “define-and-run” frameworks. To account for dimensionality changes in a general way that supports even custom layers, we need to actually run a sample through each layer and see how its size changes. This simple loop assumes the model applies its submodules one after another, in the order they were registered, as our example model does. Here, we’ll do that with a dummy variable with the <code class="highlighter-rouge">volatile=True</code> parameter set to use minimal resources for this probing sojourn.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">bits</span> <span class="o">=</span> <span class="mi">32</span>  <span class="c1"># bit depth, as above
</span><span class="n">input_</span> <span class="o">=</span> <span class="n">Variable</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="o">*</span><span class="n">input_size</span><span class="p">),</span> <span class="n">volatile</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">mods</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">())</span>
<span class="n">out_sizes</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">mods</span><span class="p">)):</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">mods</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    <span class="n">out</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">input_</span><span class="p">)</span>
    <span class="n">out_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">out</span><span class="o">.</span><span class="n">size</span><span class="p">()))</span>
    <span class="n">input_</span> <span class="o">=</span> <span class="n">out</span>
<span class="n">total_bits</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">out_sizes</span><span class="p">)):</span>
    <span class="n">s</span> <span class="o">=</span> <span class="n">out_sizes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    <span class="n">out_bits</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">s</span><span class="p">))</span><span class="o">*</span><span class="n">bits</span>
    <span class="n">total_bits</span> <span class="o">+=</span> <span class="n">out_bits</span>
<span class="c1"># multiply by 2
# we need to store values AND gradients
</span><span class="n">total_bits</span> <span class="o">*=</span> <span class="mi">2</span>
<span class="k">print</span><span class="p">(</span><span class="n">total_bits</span><span class="p">)</span> <span class="c1"># 4595712
</span></code></pre></div></div>
<p>As we see in this example, the majority of the memory is taken up by the intermediate variables and their gradient values.</p>
<p>It becomes obvious when working through this exercise why inference requires so much less memory than training. Storing gradients is expensive!</p>
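<p>To see how much of the total comes from storing gradients, here is a small pure Python sketch that redoes this post’s accounting (input, plus parameters, plus intermediate values, with intermediates doubled during training) from tensor shapes, and converts bits to megabytes by dividing by 8 and then by 1024<sup>2</sup>. The function names <code class="highlighter-rouge">estimate_bits</code> and <code class="highlighter-rouge">bits_to_megabytes</code> and the <code class="highlighter-rouge">training</code> flag are hypothetical, and the shapes are those of the 32 x 32 example, written out by hand (conv0 and conv1 produce 40 x 40 and 38 x 38 feature maps). Like the accounting above, it ignores costs outside these three components, such as optimizer state.</p>

```python
from math import prod


def estimate_bits(input_shape, param_shapes, activation_shapes, bits=32, training=True):
    """Bits for input + parameters + intermediate values.

    During training, intermediate values are counted twice (values and gradients),
    following the accounting in this post.
    """
    input_bits = prod(input_shape) * bits
    param_bits = sum(prod(s) for s in param_shapes) * bits
    activation_bits = sum(prod(s) for s in activation_shapes) * bits
    if training:
        activation_bits *= 2
    return input_bits + param_bits + activation_bits


def bits_to_megabytes(n_bits):
    """Convert bits to megabytes (bytes / 1024**2)."""
    return n_bits / 8 / 1024 ** 2


input_shape = (1, 1, 32, 32)
param_shapes = [(16, 1, 3, 3), (16,), (32, 16, 3, 3), (32,)]  # conv0, conv1 weights and biases
activation_shapes = [(1, 16, 40, 40), (1, 32, 38, 38)]       # conv0, conv1 outputs

train = estimate_bits(input_shape, param_shapes, activation_shapes, training=True)
infer = estimate_bits(input_shape, param_shapes, activation_shapes, training=False)
print(train, bits_to_megabytes(train))
print(infer, bits_to_megabytes(infer))
```

<p>Under this accounting, dropping the factor of 2 roughly halves the estimate for this model, because the intermediate values dominate the total.</p>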
<p>The tool linked above automates this whole process.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">pytorch_modelsize</span> <span class="kn">import</span> <span class="n">SizeEstimator</span>
<span class="n">se</span> <span class="o">=</span> <span class="n">SizeEstimator</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">input_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">32</span><span class="p">,</span><span class="mi">32</span><span class="p">))</span>
<span class="n">estimate</span> <span class="o">=</span> <span class="n">se</span><span class="o">.</span><span class="n">estimate_size</span><span class="p">()</span>
<span class="c1"># Returns
# (Size in Megabytes, Total Bits)
</span><span class="k">print</span><span class="p">(</span><span class="n">estimate</span><span class="p">)</span> <span class="c1"># (0.5694580078125, 4776960)
</span></code></pre></div></div>