
Comments (10)

sethaxen commented on June 6, 2024

Sorry for the very late reply!

@sethaxen, thanks for the clarification. So do I understand this correctly: when estimating a model with a (log) posterior l(x), where x is somehow constrained (e.g. to the unit sphere), I could sample in some unconstrained space Z with a function x = g(z) and then use l(g(z)) + a correction for the transform?

Yes! This is correct.

I am OK to give up transformations being bijections in this package, but I want to understand it first, so suggestions for reading materials are welcome. In particular,

1. strictly speaking, this is an identification issue, and some samplers don't like that (don't know about NUTS though),

It's only an identification issue if the chosen g and correction induce non-identifiability in Z, and yes, that would then be a problem for NUTS. This is not a problem for points on a sphere, but it is for points on a hemisphere. See below.

2. usual convergence diagnostics (eg Rhat) on the raw `z` would be nonsensical.

I don't think this is any more nonsensical for this z than when working with bijectors. R-hat on z checks for consistency of between- and within-chain variance on Z, which is useful as a check, but ultimately we care more about convergence in terms of x anyway. R-hat on a parameter constrained to some manifold can be a bit strange to interpret and perhaps nonsensical (e.g. R-hat of the zero triangle of a Cholesky factor will be NaN), but that's just how it is. It probably makes more sense to check R-hat for transforms of manifold-valued parameters, e.g. if one wants to report the angle from a unit vector to some reference, R-hat of that angle would probably be more useful than R-hat of any of the coordinates of the unit vector.

This particular approach is I believe a direct consequence of the co-area formula in geometric measure theory, but I unfortunately haven't seen any very accessible explanations of it for this use. So here's a more intuitive explanation in terms of familiar operations. Suppose we have a density $\pi(x)$ for $x \in M$, where $M$ is some manifold, and we assume the density is defined with respect to the Hausdorff measure of $M$ (i.e. the volume measure).
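For reference, the precise statement behind this is the area formula from geometric measure theory: for an injective, differentiable map $\phi: \mathbb{R}^d \to \mathbb{R}^m$ with $d \le m$ and Jacobian $J(z)$,
$$\int_{\mathbb{R}^d} f(\phi(z)) \sqrt{\det\left(J(z)^\top J(z)\right)} \,\mathrm{d}z = \int_{\phi(\mathbb{R}^d)} f(y) \,\mathrm{d}\mathcal{H}^d(y),$$
so a density defined with respect to the $d$-dimensional Hausdorff measure on the image pulls back to a density on $\mathbb{R}^d$ with an extra $\frac{1}{2}\log\det(J^\top J)$ term; this is the non-square generalization of the usual $\log|\det J|$ correction.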

We know that discarding a coordinate in MCMC is equivalent to marginalizing out those coordinates in the target distribution. Similarly, we can augment our distribution. So let $\pi(w)$ for $w \in H$ be a density on manifold $H$. Then $\pi(x,w) = \pi(x) \pi(w)$ is a density on the product manifold $M \times H$.
Now, suppose we have a bijective map $\phi: \mathbb{R}^d \to M \times H$. We can use the usual change-of-variables formula to compute the logdetjac (with a tweak since the Jacobian is now non-square). This gives us the log-density
$$\log\pi(z) = \log\pi(x) + \log\pi(w) + \frac{1}{2}\log\det(J^\top J) = \log\pi(x) + \mathrm{correction}$$

In practice, we end up using a map $f = \Pi \circ \phi: \mathbb{R}^d \to M$, which consists of the bijective map followed by a projection, so that $w$ is automatically discarded. Note that we are free to choose any $\pi(w)$, and we should choose it for speed and to improve the geometry of $\pi(z)$, e.g. we might choose it to ensure that $\pi(z)$ is proper whenever $\pi(x)$ is and to ensure $\pi(z)$ is identifiable.
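As a minimal sketch of how this can be assembled in code (not this package's implementation; all names below are hypothetical placeholders), the correction can even be computed numerically from the non-square Jacobian:

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical placeholders supplied by the user:
#   phi     : maps z in R^d to a point on M × H embedded in R^m
#   logpi_x : log-density on M wrt its Hausdorff (volume) measure
#   logpi_w : log-density on H wrt its Hausdorff measure
#   unpack  : splits the embedded point into its (x, w) parts
function latent_logdensity(phi, logpi_x, logpi_w, unpack, z)
    J = ForwardDiff.jacobian(phi, z)   # m × d Jacobian, non-square in general
    logdetjac = logdet(J' * J) / 2     # the (1/2) log det(JᵀJ) term
    x, w = unpack(phi(z))
    return logpi_x(x) + logpi_w(w) + logdetjac
end
```

For the unit-sphere example below one would take `phi(z) = vcat(z ./ norm(z), norm(z))`, unpack the last coordinate as $w$, and use the Chi log-density with $n+1$ degrees of freedom for `logpi_w`; in that case the correction simplifies in closed form, so the numerical Jacobian is only for illustration.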

I have a few ideas for under what circumstances this approach is likely to be useful, but I've never seen a paper that discussed this approach in general terms.

Now for a few examples:

The unit sphere

Let $x \in \mathbb{S}^{n} \subset \mathbb{R}^{n+1}$ be a point on the $n$-sphere. We let $w>0$ and $z \in \mathbb{R}^{n+1}$ and define $\phi(z) = \begin{pmatrix} \frac{z^\top}{\lVert z \rVert} & \lVert z \rVert \end{pmatrix}^\top$. Our log-det-Jacobian ends up being $-n \log \lVert z \rVert = -n\log w$. Now, if we were to choose a uniform distribution on the positive reals for $\pi(w)$, which is improper, this would cause $\pi(z)$ to also be improper, and it would also cause non-identifiability. We could choose any number of proper priors for $w$, but if we choose the Chi distribution with $n+1$ degrees of freedom, then $\log\pi(w) = n\log w - \frac{1}{2} w^2 + \mathrm{constant}$, so we end up with
$$\mathrm{correction} = -n\log w + n\log w - \frac{1}{2} w^2 = - \frac{1}{2} \lVert z \rVert^2,$$
which is just a standard multivariate normal on $z$.

In this particular case, if $\pi(x)$ is unimodal, then so is $\pi(z)$. While multiple points $z$ map to the same $x$, the prior on $w$ effectively breaks the tie between those values, so that we keep identifiability. e.g. here are the density contours and draws from a VonMisesFisher([0, 1], 1) distribution in z space:

[figure: density contours and draws from a VonMisesFisher([0, 1], 1) distribution in z space]

I've noticed that with low-dimensional unit vectors it's much more likely (due to the curse of dimensionality) to get low $w$ values, so the geometry has high curvature when $\pi(x)$ is narrow. You can see that around the origin in the above contours. So sampling on the sphere with this approach generally requires a higher `delta` (target acceptance rate) for adaptation. Using a Chi distribution with more than $n+1$ degrees of freedom penalizes low $w$ values more heavily, which improves the geometry. For large $n$, this is not an issue, as there's almost no volume around $w=0$.
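A minimal sketch of this parameterization (an illustration of the derivation above, not the package's code), where `logpi_x` stands for a hypothetical user-supplied log-density on the sphere:

```julia
using LinearAlgebra

# x = z / ‖z‖ lies on the n-sphere; with the Chi(n+1) prior on w = ‖z‖ the whole
# correction collapses to the standard-normal term -‖z‖²/2 derived above.
sphere_transform(z) = z ./ norm(z)
sphere_logdensity(logpi_x, z) = logpi_x(sphere_transform(z)) - sum(abs2, z) / 2

# e.g. an (unnormalized) VonMisesFisher([0, 1], 1) target, as in the contours above
logpi_vmf(x) = 1.0 * dot([0.0, 1.0], x)
sphere_logdensity(logpi_vmf, randn(2))
```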

The unit hemisphere

Let the unit hemisphere be the unit sphere with the added constraint that $x_1 > 0$. We can augment this space with $w \ne 0$. Then let $\phi(z) = \begin{pmatrix} \frac{z^\top}{\operatorname{sign}(z_1) \lVert z \rVert} & \operatorname{sign}(z_1) \lVert z \rVert \end{pmatrix}^\top$. This gives us a logdetjac of $-n\log|w|$. Now for $\pi(w)$, we have some challenges. If we choose a uniform distribution, we're both improper and non-identifiable, since $w$ and $-w$ lead to the same densities. We might choose a normal distribution, but this puts a lot of density around 0, where we have a singularity we want to avoid. So we might choose $\log\pi(w) = n\log|w| - \frac{1}{2}(w-1)^2 + \mathrm{constant}$, so that the correction is $-\frac{1}{2}(w-1)^2$. This can work (I've tested it), but it causes high curvature around $z_1=0$, so it isn't as efficient to sample with HMC:

[figure: latent-space density contours for the sign-based hemisphere parameterization]

EDIT: a much better approach is to first transform the first coordinate of $z$ to the positive reals using exp and then apply the same transformation as for the unit sphere. i.e. let $\phi_1(z) = (\exp(z_1), z_2, \ldots, z_{n+1})^\top$, let $\phi_2(z) = \begin{pmatrix} \frac{z^\top}{\lVert z \rVert} & \lVert z \rVert \end{pmatrix}^\top$, and let $\phi = \phi_2 \circ \phi_1$. Then the log-det-Jacobian term is $z_1 - n\log w$, where $w = \lVert \phi_1(z) \rVert$. If we again choose the Chi distribution with $n+1$ degrees of freedom as the prior for $w$, the correction becomes $z_1 - \frac{1}{2}\lVert \phi_1(z) \rVert^2$, and we have a nice geometry for the latent distribution on $z$:
[figure: latent-space density contours for the exp-based hemisphere parameterization]
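A sketch of this second map, under the same assumptions as the sphere sketch above (again an illustration, not the package's implementation):

```julia
using LinearAlgebra

# Exponentiate z₁ so the first coordinate of x is positive, then normalize as for
# the sphere.  With the Chi(n+1) prior on w the correction is z₁ - ‖φ₁(z)‖²/2.
hemisphere_transform(z) = (v = vcat(exp(z[1]), z[2:end]); v ./ norm(v))

function hemisphere_logdensity(logpi_x, z)
    v = vcat(exp(z[1]), z[2:end])
    return logpi_x(v ./ norm(v)) + z[1] - sum(abs2, v) / 2
end
```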


tpapp commented on June 6, 2024

Wait, @sethaxen has a very elegant fix for this in #67 (which I was stupidly ignoring at the time, apologies); I am waiting for his permission to port the code.


tpapp commented on June 6, 2024

It is explained in the Stan manual.


jonalm commented on June 6, 2024

Dear @tpapp, the closest I can find in the Stan manual is the chapter on Unit Vectors, but I don't understand how that explains the implementation of

function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)

The specific issue I struggle to understand is that the range of the UnitVector(n) transform is only a half-sphere, since the last component of the transformed vector is always positive.

I'm trying to handle angles in an inference problem as described in the Stan manual (https://mc-stan.org/docs/2_18/stan-users-guide/unit-vectors-and-rotations.html), i.e. I was hoping to do something like

using TransformVariables  # provides UnitVector and transform
t = UnitVector(2)
cos_θ, sin_θ = transform(t, [1.234])  # map one unconstrained value to a point on the circle
θ = atan(sin_θ, cos_θ)

But because of the half-plane issue, the range of θ is [0,π] and not [-π,π].

Is this intended behavior? If so, do you have any suggestions on how to handle angles?

best
Jon Eriksen


tpapp commented on June 6, 2024

Sorry for closing it too hastily, and thanks for persisting. I can replicate the bug (I think the range in 2d is (-π/2, π/2) though).

I think that a constant is off in the calculations, and we should map from r ∈ (-1, 1) and use its absolute value and sign. However, I need to check the algebra. If you get to it first, don't hesitate to make a PR, or just send me notes and I will code it up.

(Incidentally, I think Stan just uses Marsaglia's method, with an extra df, so it is not much help if we want a bijection).


tpapp commented on June 6, 2024

This is actually a dup of #66, but I am not closing either one in favor of the other; I will think about a solution and close both at the same time.


tpapp commented on June 6, 2024

I did some reading about this, and doing it "uniformly" seems to be a hard problem. However, that is not needed for our purposes; we merely need a bijection. That said, having nice numerical properties is useful.

#67 is what Stan uses, but it is not a bijection.

I will test out the quick fix I mention above, and if that does not work try spherical coordinates.


sethaxen commented on June 6, 2024

I did some reading about this, and doing it "uniformly" seems to be a hard problem. However, that is not needed for our purposes; we merely need a bijection. That said, having nice numerical properties is useful.

#67 is what Stan uses, but it is not a bijection.

Strictly speaking, one does not need a bijection. All one needs is to draw samples in an unconstrained latent space, with a transformation to the constrained space and a log-density correction, so that the resulting transformed samples target the correct distribution. For bijective functions, that log-density correction is a logabsdetjac (more generally, logdetsqrtmetric), but there are corrections for non-bijective transformations, which is what Stan uses here. The caveat is that with a non-bijective transformation you can only define a right-inverse, so the latent unconstrained space must be the ground truth: rather than mapping from x in constrained space to z in latent space, drawing a sample, mapping back to x, and then mapping back to z, one samples z in latent space and maps from z to x only when computing the log-density or when returning a draw to the user, keeping the original z as the starting point for the next transition.
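To illustrate the right-inverse point with the sphere map discussed earlier in the thread (an illustration only; `f` and `f_rinv` are made-up names):

```julia
using LinearAlgebra

# f maps any nonzero latent z to a unit vector; many z give the same x, so only a
# right-inverse exists.  f_rinv picks the representative with ‖z‖ = 1.
f(z) = z ./ norm(z)
f_rinv(x) = x

x = f(randn(3))
f(f_rinv(x)) ≈ x   # true: f ∘ f_rinv is the identity on the sphere
z = randn(3)
f_rinv(f(z)) ≈ z   # false in general: the norm of z is not recoverable from x
```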

There are plenty of other cases where it makes sense to have non-bijective transformations, e.g. a user wants to sample a point in a disk. One way to do this is to sample a point on a sphere, with a non-bijective projection that discards one of the axes. The resulting distribution is non-uniform on the disk, so there's a log-density correction that makes it uniform.
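As a hedged aside (a derivation sketch, not something worked out in the thread): writing the upper hemisphere of $\mathbb{S}^2$ as a graph $x_3 = \sqrt{1 - x_1^2 - x_2^2}$ over the disk, the surface-area element is
$$\mathrm{d}A = \sqrt{1 + (\partial_{x_1} x_3)^2 + (\partial_{x_2} x_3)^2}\,\mathrm{d}x_1\,\mathrm{d}x_2 = \frac{\mathrm{d}x_1\,\mathrm{d}x_2}{\sqrt{1 - x_1^2 - x_2^2}},$$
so projecting a uniform distribution on the sphere onto the disk gives a density proportional to $1/|x_3|$, and adding $\log|x_3|$ to the log-density on the sphere makes the projected draws uniform on the disk.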

I will test out the quick fix I mention above, and if that does not work try spherical coordinates.

There is no single chart on the sphere that covers it completely. Every chart has singularities, and if the typical set is localized near a singularity, this will cause divergences. This is, I believe, why Stan chooses a non-bijective transformation here: the geometry then has no singularities and is well-behaved. It's actually least well-behaved, I think, for low-dimensional vectors, where in the latent space one can move a short distance away from the origin and suddenly a different step size is needed to cover the same distance on the surface of the sphere. But due to concentration of measure, the draws from a high-dimensional multivariate normal concentrate near the surface of a hypersphere anyway, so this parameterization actually produces a really nice geometry for sampling.


tpapp commented on June 6, 2024

@sethaxen, thanks for the clarification. So do I understand this correctly: when estimating a model with a (log) posterior l(x), where x is somehow constrained (e.g. to the unit sphere), I could sample in some unconstrained space Z with a function x = g(z) and then use l(g(z)) + a correction for the transform?

I am OK to give up transformations being bijections in this package, but I want to understand it first, so suggestions for reading materials are welcome. In particular,

  1. strictly speaking, this is an identification issue, and some samplers don't like that (don't know about NUTS though),
  2. usual convergence diagnostics (eg Rhat) on the raw z would be nonsensical.


tpapp commented on June 6, 2024

@sethaxen, thanks for the detailed answer (sorry to see that MathJax is kind of broken now, hopefully it gets fixed). And sorry for the late reply, I am still digesting this. What I still do not understand is

change-of-variables formula to compute the logdetjac (with a tweak since the Jacobian is now non-square)

i.e. where the $\frac{1}{2} \log \det(J^\top J)$ comes from for non-square Jacobians. Apologies if this is something super obvious; I have not seen this form of the theorem.
