
Comments (17)

yiyuezhuo commented on June 12, 2024

I can't see how Bijectors.bijector(::LKJ) = PDBijector() could even be a valid way to solve the problem. In Stan, corr_matrix[J] Omega; constrains the parameters to the "correlation matrix" space (where the diagonal elements are ones), while PDBijector constrains parameters to the general "covariance matrix" space (i.e., the space of positive definite matrices). The problem is hidden further by the fact that the logpdf defined for LKJ in Distributions.jl doesn't check whether its input is actually a correlation matrix.

As you may notice, the Omega values tend to infinity:

[trace plot: Omega entries diverging]

and sigma tends to 0:

[trace plot: sigma values collapsing toward 0]

This is a consequence of the LKJ density:

[screenshot: LKJ density formula]

Obviously, Omega will be driven to infinity to maximize the density, while sigma tends to 0 to neutralize the corresponding effect on Sigma.

The illusion that HMC works just shows that the warm-up phase has not ended yet. If you run HMC or NUTS long enough, the strangely large Omega values, or an outright failure, are inevitable.

So we must define a bijector dedicated to correlation matrices, like corr_matrix in Stan. I will send a PR later.
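
To make the idea concrete, here is a minimal sketch (not the Bijectors.jl API; the function name is made up) of the kind of transform Stan uses for corr_matrix: map K*(K-1)/2 unconstrained numbers to canonical partial correlations via tanh, build a Cholesky factor with unit-norm rows, and multiply it by its transpose. Note that, as written, it mutates an array, so it would not be Zygote-friendly.

using LinearAlgebra

function corr_from_unconstrained(y::AbstractVector, K::Integer)
    z = tanh.(y)                  # canonical partial correlations in (-1, 1)
    L = zeros(eltype(z), K, K)    # Cholesky factor of the correlation matrix
    L[1, 1] = 1
    k = 0
    for i in 2:K
        rem = one(eltype(z))      # 1 minus the sum of squares of row i so far
        for j in 1:i-1
            k += 1
            L[i, j] = z[k] * sqrt(rem)
            rem -= L[i, j]^2
        end
        L[i, i] = sqrt(rem)       # each row of L has unit norm
    end
    return L * L'                 # unit diagonal, positive (semi-)definite
end

# e.g. corr_from_unconstrained(randn(3), 3) is a 3x3 correlation matrix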


mohamed82008 commented on June 12, 2024

Just defining Bijectors.bijector(::LKJ) = PDBijector() should be enough I think.


joshualeond commented on June 12, 2024

Thanks @mohamed82008, I defined Bijectors.bijector(::LKJ) as you said and that's gotten me past the original ERROR: MethodError: no method matching bijector(::LKJ{Float64,Int64}) error. I think I may be up against a user error now though. Here's a short reproducible example attempting to use the LKJ after the bijector was defined:

using Turing, Bijectors, LinearAlgebra, Random
Random.seed!(666)

# generate data
sigma = [1,2,3]
Omega = [1 0.3 0.2;
        0.3 1 0.1;
        0.2 0.1 1]

Sigma = diagm(sigma) * Omega * diagm(sigma)
N = 100
J = 3
y = rand(MvNormal(zeros(J), Sigma), N)'

# model
@model correlation(J, N, y, Zero) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix
    # covariance matrix
    Sigma = diagm(sigma) * Omega * diagm(sigma)

    for i in 1:N
        y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
    end
end

Bijectors.bijector(::LKJ) = PDBijector()

# attempt to recover parameters
chain = sample(correlation(J, N, y, zeros(J)), NUTS(), 1000)

And the error:

ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.

Perhaps I misspecified something here?


devmotion commented on June 12, 2024

These issues with positive definiteness are nasty: even if the matrix is guaranteed to be positive (semi-)definite mathematically, it can easily happen that, due to numerical issues, it isn't positive (semi-)definite numerically (I've run into this problem multiple times when parameterizing MvNormal with estimated covariance matrices). Sometimes wrapping the matrix in Symmetric helps, but lately only https://github.com/timholy/PositiveFactorizations.jl could fix my numerical issues. However, I've never tried to use it together with Turing, so I'm not sure if AD works with that package.
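
For illustration, a minimal reproduction of that failure mode (with made-up numbers): a matrix that is positive definite mathematically, but whose two triangles differ in the last bits, already fails the exact Hermitian check that cholesky performs on a plain Matrix.

using LinearAlgebra

A = [2.0          0.5;
     0.5 + 1e-15  2.0]    # positive definite, but the triangles differ by ~1e-15

ishermitian(A)            # false
# cholesky(A)             # PosDefException: matrix is not Hermitian
cholesky(Symmetric(A))    # works: only the upper triangle is read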


devmotion commented on June 12, 2024

Maybe you could avoid that by using

using PDMats, LinearAlgebra
...
_Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
Sigma = PDMat(_Sigma, cholesky(_Sigma))
...

since it seems there exists an implementation of cholesky for Symmetric which might be more efficient (see, e.g., https://github.com/JuliaLang/julia/blob/7301dc61bdeb5d66e94e15bdfcd4c54f7c90f068/stdlib/LinearAlgebra/src/cholesky.jl#L217-L221). I'm wondering why that is not the default in PDMats 🤔


joshualeond commented on June 12, 2024

@yiyuezhuo There's another julia repo outside of the Turing org that may have some helpful code for reference: https://github.com/tpapp/TransformVariables.jl

Specifically the TransformVariables.CorrCholeskyFactor code.
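
For reference, a rough usage sketch based on my reading of the TransformVariables docs (treat the details, in particular the orientation of the factor, as unverified):

using TransformVariables, LinearAlgebra

t = CorrCholeskyFactor(3)     # transform to the Cholesky factor of a 3x3 correlation matrix
x = randn(dimension(t))       # 3*(3-1)/2 = 3 unconstrained numbers
U = transform(t, x)           # upper triangular factor
Omega = U' * U                # should have a unit diagonal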


torfjelde commented on June 12, 2024

It seems like the issue is with your Sigma = diagm(sigma) * Omega * diagm(sigma) line.

julia> # model
       @model correlation(J, N, y, Zero) = begin
           sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
           Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix
           @info isposdef(Omega)
           # covariance matrix
           Sigma = diagm(sigma) * Omega * diagm(sigma)
           @info Sigma
           @info isposdef(Sigma)

           for i in 1:N
               y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
           end
       end
DynamicPPL.ModelGen{var"###generator#300",(:J, :N, :y, :Zero),(),Tuple{}}(##generator#300, NamedTuple())

julia> 

julia> m = correlation(J, N, y, zeros(J));

julia> m()
[ Info: true
[ Info: [2.43118945431723 0.8857326262243734 -0.38161465118558274; 0.8857326262243734 0.496164416668245 -0.2301838107162483; -0.38161465118558274 -0.2301838107162483 0.28572768840562973]
[ Info: true

julia> m()
[ Info: true
[ Info: [289.05898695571943 -3.900716770274498 -45.83646433466277; -3.900716770274498 0.3260931391923026 2.375071240046279; -45.83646433466277 2.3750712400462786 74.5652927508229]
[ Info: false
ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.

EDIT: haha, nevermind, I'm stupid 🙃 The above is relevant, but I deleted the parts of my comment that were just me brain-farting like crazy.


joshualeond commented on June 12, 2024

Thanks for checking out my example @torfjelde! So the diagm(sigma) * Omega * diagm(sigma) is my attempt at reproducing the quad_form_diag available in Stan. I've seen the LKJ used as a prior on the correlation matrix in Stan like the following:

data {
  int<lower=1> N; // number of observations
  int<lower=1> J; // dimension of observations
  vector[J] y[N]; // observations
  vector[J] Zero; // a vector of Zeros (fixed means of observations)
}
parameters {
  corr_matrix[J] Omega; 
  vector<lower=0>[J] sigma; 
}
transformed parameters {
  cov_matrix[J] Sigma; 
  Sigma <- quad_form_diag(Omega, sigma); 
}
model {
  y ~ multi_normal(Zero,Sigma); // sampling distribution of the observations
  sigma ~ cauchy(0, 5); // prior on the standard deviations
  Omega ~ lkj_corr(1); // LKJ prior on the correlation matrix 
}

The Stan docs make it seem pretty straightforward, but perhaps there's more going on than I realize:

matrix quad_form_diag(matrix m, vector v)
The quadratic form using the column vector v as a diagonal matrix, i.e., diag_matrix(v) * m * diag_matrix(v).
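
In Julia that would correspond to something like the following one-liner (quad_form_diag here is just a hypothetical helper name, not an existing function in the packages above):

using LinearAlgebra

# Hypothetical helper mirroring Stan's quad_form_diag(m, v)
quad_form_diag(m::AbstractMatrix, v::AbstractVector) = Diagonal(v) * m * Diagonal(v)

# so quad_form_diag(Omega, sigma) is Diagonal(sigma) * Omega * Diagonal(sigma)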


devmotion commented on June 12, 2024

Completely unrelated, but I think you should use Diagonal instead of diagm, since the former does not actually allocate a matrix whereas the latter does. A simple benchmark:

julia> using BenchmarkTools, LinearAlgebra

julia> f(v) = Diagonal(v)

julia> g(v) = diagm(v)

julia> @btime f($(rand(100)));
  4.866 ns (1 allocation: 16 bytes)

julia> @btime g($(rand(100)));
  5.854 μs (3 allocations: 78.23 KiB)


torfjelde commented on June 12, 2024

Sometimes wrapping the matrix into Symmetric helps

@devmotion to the rescue!

julia> # model
       @model correlation(J, N, y, Zero) = begin
           sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
           Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix
           @info sigma
           @info Omega
           @info isposdef(Omega)
           # covariance matrix
           Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
           @info Sigma
           @info isposdef(Sigma)

           for i in 1:N
               y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
           end
       end
DynamicPPL.ModelGen{var"###generator#348",(:J, :N, :y, :Zero),(),Tuple{}}(##generator#348, NamedTuple())

julia> m = correlation(J, N, y, zeros(J));

julia> m()
[ Info: [6.554761110166019, 2.34510326436457, 0.6888832327192145]
[ Info: [1.0 0.47103043693813107 0.5681366777033788; 0.47103043693813107 1.0 -0.41754820600773; 0.5681366777033788 -0.41754820600773 1.0]
[ Info: true
[ Info: [42.96489321134486 7.24048754385414 2.5654012966083335; 7.24048754385414 5.499509320533363 -0.674550094605337; 2.5654012966083335 -0.674550094605337 0.4745601083216754]
[ Info: true

julia> m()
[ Info: [39.64435704629228, 0.9699293307273142, 3.1448223201647334]
[ Info: [1.0 0.051875690114511874 -0.052259251243831094; 0.051875690114511874 1.0 0.8063928903464262; -0.052259251243831094 0.8063928903464262 1.0]
[ Info: true
[ Info: [1571.675045613904 1.9947356925964466 -6.515393871749325; 1.9947356925964466 0.9407629066051357 2.4597042749565188; -6.515393871749325 2.4597042749565188 9.889907425406298]
[ Info: true

julia> m()
[ Info: [0.8388798227529882, 3.3189729374771266, 8.826661070228887]
[ Info: [1.0 0.7104582898219662 -0.4776195773069013; 0.7104582898219662 1.0 -0.07252323132333671; -0.4776195773069013 -0.07252323132333671 1.0]
[ Info: true
[ Info: [0.703719357022085 1.978071774380738 -3.536537920990547; 1.978071774380738 11.015581359705546 -2.124600640530144; -3.536537920990547 -2.124600640530144 77.90994564869416]
[ Info: true

julia> m()
[ Info: [2.052493460163989, 18.75614790558376, 3.1610610751936297]
[ Info: [1.0 -0.2180845855069199 -0.42228217488217845; -0.2180845855069199 1.0 0.09044611498906638; -0.42228217488217845 0.09044611498906638 1.0]
[ Info: true
[ Info: [4.212729404015944 -8.395574136610355 -2.73979089842532; -8.395574136610355 351.793084256134 5.362489474229928; -2.73979089842532 5.362489474229928 9.992307121104306]
[ Info: true

julia> m()
[ Info: [671.1873925742121, 0.3471925977453725, 12.611536505760812]
[ Info: [1.0 0.6777757171051204 -0.10065195388769885; 0.6777757171051204 1.0 0.6147688385377347; -0.10065195388769885 0.6147688385377347 1.0]
[ Info: true
[ Info: [450492.51595056953 157.94295267110346 -851.9890272445987; 157.94295267110346 0.12054269992918003 2.6918465834085405; -851.9890272445987 2.6918465834085405 159.05085303613762]
[ Info: true

julia> m()
[ Info: [0.8079920565082861, 4.041301143186047, 38.36495641353685]
[ Info: [1.0 0.35915995608688434 -0.19698184713779998; 0.35915995608688434 1.0 0.5033179678024089; -0.19698184713779998 0.5033179678024089 1.0]
[ Info: true
[ Info: [0.6528511633804894 1.1727790914573786 -6.1061575530419185; 1.1727790914573786 16.332114929916848 78.03660324156078; -6.1061575530419185 78.03660324156078 1471.8698806125824]
[ Info: true

I'm assuming that the reason this works is that there exists a more numerically stable method for symmetric matrices, and by wrapping the matrix in Symmetric you make sure to dispatch to the correct method :)


devmotion commented on June 12, 2024

Unfortunately, the simple (and slightly inefficient) reason is that in https://github.com/JuliaStats/PDMats.jl/blob/00804c3ca96a0839c03d25782a51028fe96fa725/src/pdmat.jl#L20 a new matrix is allocated in which the upper triangle is simply mirrored onto the lower one. Hence, if there was any numerical discrepancy between the two, it is gone afterwards. That's also the reason why it doesn't always fix the issues (in my experience).
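
A small sketch of that point (made-up numbers): constructing a dense matrix from the Symmetric wrapper copies the upper triangle onto the lower one, so last-bit discrepancies between the triangles vanish, but it cannot repair a matrix that is genuinely not positive definite.

using LinearAlgebra

A = [2.0          0.5;
     0.5 + 1e-15  2.0]
M = Matrix(Symmetric(A))           # upper triangle mirrored onto the lower one
issymmetric(A), issymmetric(M)     # (false, true)

B = Symmetric([1.0 2.0; 2.0 1.0])  # eigenvalues -1 and 3
isposdef(B)                        # false: wrapping cannot fix real indefiniteness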


joshualeond commented on June 12, 2024

Thanks for all the tips, I tried out your suggestion with the PDMat but ran into the following error:

ERROR: MethodError: no method matching PDMat{Float64,Symmetric{Float64,Array{Float64,2}}}(::Int64, ::Symmetric{Float64,Array{Float64,2}}, ::Cholesky{Float64,Array{Float64,2}})

If I removed the second argument in PDMat then it did sample with HMC but had some odd results:

using Turing, Distributions, LinearAlgebra, Random, Bijectors, PDMats
Bijectors.bijector(d::LKJ) = Bijectors.PDBijector()

Random.seed!(666)
# generate data
sigma = [1,2,3]
Omega = [1 0.3 0.2;
        0.3 1 0.1;
        0.2 0.1 1]

Sigma = Diagonal(sigma) * Omega * Diagonal(sigma)
N = 100
J = 3
y = rand(MvNormal(zeros(J), Sigma), N)'

# model
@model correlation(J, N, y, Zero) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    _Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
    Sigma = PDMat(_Sigma)

    for i in 1:N
        y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
    end
    return Sigma
end

chain = sample(correlation(J, N, y, zeros(J)), HMC(0.01, 5), 1000)
chain = sample(correlation(J, N, y, zeros(J)), NUTS(), 1000)
Summary Statistics
   parameters     mean       std  naive_se     mcse      ess   r_hat
  ───────────  ───────  ────────  ────────  ───────  ───────  ──────
  Omega[1, 1]  87.4067  183.9117    5.8158  50.4635   4.6011  1.2026
  Omega[1, 2]  14.2648   21.3900    0.6764   6.7411   4.1112  1.4768
  Omega[1, 3]   4.8683    9.4294    0.2982   2.8001   4.7358  1.2381
  Omega[2, 1]  14.2648   21.3900    0.6764   6.7411   4.1112  1.4768
  Omega[2, 2]  24.3924   35.9843    1.1379  11.4285   4.1922  1.4981
  Omega[2, 3]   0.8385    1.8374    0.0581   0.3172  23.5443  1.1229
  Omega[3, 1]   4.8683    9.4294    0.2982   2.8001   4.7358  1.2381
  Omega[3, 2]   0.8385    1.8374    0.0581   0.3172  23.5443  1.1229
  Omega[3, 3]  10.0522   13.8718    0.4387   4.4399   4.0161  1.5990
     sigma[1]   0.2820    0.2060    0.0065   0.0659   4.0161  1.8310
     sigma[2]   1.1579    0.8786    0.0278   0.2863   4.0161  2.4314
     sigma[3]   2.5682    2.3649    0.0748   0.7436   4.0161  1.8521

The sigmas aren't too far off but the correlation matrix Omega has some relatively large numbers on the diagonal that should be 1. Oddly enough, if I sample with NUTS I end up with the old error again:

ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.


devmotion commented on June 12, 2024

Thanks for all the tips, I tried out your suggestion with the PDMat but ran into the following error:

Ah, then that's probably the reason why PDMats doesn't use Symmetric directly 😄 So I guess just remove PDMat and pass the Symmetric matrix to MvNormal directly.

BTW, you should probably also remove Zero: if you don't pass a mean vector, MvNormal will automatically have a mean of zero (and use something that's more optimized than just zeros(J)).


joshualeond commented on June 12, 2024

Good point on the Zero vector; I've removed it along with the PDMat:

using Turing, Distributions, LinearAlgebra, Random, Bijectors
Bijectors.bijector(d::LKJ) = Bijectors.PDBijector()

Random.seed!(666)
# generate data
sigma = [1,2,3]
Omega = [1 0.3 0.2;
        0.3 1 0.1;
        0.2 0.1 1]

Sigma = Diagonal(sigma) * Omega * Diagonal(sigma)
N = 100
J = 3
y = rand(MvNormal(Sigma), N)'

# model
@model correlation(J, N, y) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))

    for i in 1:N
        y[i,:] ~ MvNormal(Sigma) # sampling distribution of the observations
    end
    return Sigma
end

chain = sample(correlation(J, N, y), HMC(0.01, 5), 1000)

With HMC, similar results as before. Not quite recovering the original parameters:

Summary Statistics
   parameters     mean      std  naive_se     mcse     ess   r_hat
  ───────────  ───────  ───────  ────────  ───────  ──────  ──────
  Omega[1, 1]  14.0667  15.5562    0.4919   4.8148  4.0161  2.0984
  Omega[1, 2]   3.0884   3.3027    0.1044   1.0634  4.0161  1.9885
  Omega[1, 3]   1.2705   1.4212    0.0449   0.4322  4.0541  1.9204
  Omega[2, 1]   3.0884   3.3027    0.1044   1.0634  4.0161  1.9885
  Omega[2, 2]   7.4755   9.5501    0.3020   3.0134  4.1477  1.5860
  Omega[2, 3]   0.4702   0.6166    0.0195   0.1384  5.8176  1.3960
  Omega[3, 1]   1.2705   1.4212    0.0449   0.4322  4.0541  1.9204
  Omega[3, 2]   0.4702   0.6166    0.0195   0.1384  5.8176  1.3960
  Omega[3, 3]  25.0593  37.7232    1.1929  11.9310  4.0702  1.5460
     sigma[1]   0.6272   0.5028    0.0159   0.1633  4.0161  2.1600
     sigma[2]   1.5641   1.0518    0.0333   0.3419  4.0161  2.0743
     sigma[3]   2.0635   1.7036    0.0539   0.5489  4.0161  2.5594

With NUTS:

ERROR: PosDefException: matrix is not positive definite; Cholesky factorization failed.

But sometimes when I sample with NUTS I'm actually seeing VERY large numbers, like 1e155 large.


joshualeond commented on June 12, 2024

I had originally opened an issue on another repo where @trappmartin ended up giving me some advice on this particular problem with the LKJ. I wanted to bring some of the info from that issue over to this one and close the original issue.

On the other issue Martin ended up restructuring the model specification like the following:

@model correlation(J, N, y) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    L = Diagonal(sigma) * Omega

    for i in 1:N
        y[i,:] ~ MvNormal(L*L') # sampling distribution of the observations
    end
    return L*L'
end

However, even with this I'm still seeing issues when sampling the posterior distribution. Here's an example of what sampling from the prior gives for this model:

sample(correlation(J, N, y), Prior(), 2000)
Summary Statistics
   parameters     mean       std  naive_se     mcse        ess   r_hat
  ───────────  ───────  ────────  ────────  ───────  ─────────  ──────
  Omega[1, 1]   1.0000    0.0000    0.0000   0.0000        NaN     NaN
  Omega[1, 2]  -0.0021    0.4940    0.0110   0.0138  1827.3121  1.0008
  Omega[1, 3]   0.0129    0.5078    0.0114   0.0121  2202.8359  0.9996
  Omega[2, 1]  -0.0021    0.4940    0.0110   0.0138  1827.3121  1.0008
  Omega[2, 2]   1.0000    0.0000    0.0000   0.0000        NaN     NaN
  Omega[2, 3]   0.0027    0.5036    0.0113   0.0109  2002.3309  0.9997
  Omega[3, 1]   0.0129    0.5078    0.0114   0.0121  2202.8359  0.9996
  Omega[3, 2]   0.0027    0.5036    0.0113   0.0109  2002.3309  0.9997
  Omega[3, 3]   1.0000    0.0000    0.0000   0.0000        NaN     NaN
     sigma[1]  30.3646  333.1480    7.4494   6.9035  2057.3155  0.9997
     sigma[2]  39.5240  864.4344   19.3293  18.8958  2018.9760  0.9999
     sigma[3]  36.1043  548.1510   12.2570  11.4607  2024.4883  0.9999

Quantiles
   parameters     2.5%    25.0%    50.0%    75.0%     97.5%
  ───────────  ───────  ───────  ───────  ───────  ────────
  Omega[1, 1]   1.0000   1.0000   1.0000   1.0000    1.0000
  Omega[1, 2]  -0.8740  -0.4035  -0.0009   0.4025    0.8551
  Omega[1, 3]  -0.8926  -0.4015   0.0441   0.4347    0.8704
  Omega[2, 1]  -0.8740  -0.4035  -0.0009   0.4025    0.8551
  Omega[2, 2]   1.0000   1.0000   1.0000   1.0000    1.0000
  Omega[2, 3]  -0.8760  -0.4089   0.0092   0.3951    0.8914
  Omega[3, 1]  -0.8926  -0.4015   0.0441   0.4347    0.8704
  Omega[3, 2]  -0.8760  -0.4089   0.0092   0.3951    0.8914
  Omega[3, 3]   1.0000   1.0000   1.0000   1.0000    1.0000
     sigma[1]   0.2368   2.2082   5.1726  12.5942  149.0269
     sigma[2]   0.2114   2.1148   4.9539  12.9399  117.8365
     sigma[3]   0.2138   1.9621   4.8745  11.4329  138.6144

So the prior samples look good, with 1s on the diagonal of the correlation matrix. However, after sampling with HMC we get results like those I've shown previously, with low ess, high r_hat, and estimates that don't respect the LKJ prior:

sample(correlation(J, N, y), HMC(0.01, 5), 2000)
Summary Statistics
   parameters    mean     std  naive_se    mcse      ess   r_hat
  ───────────  ──────  ──────  ────────  ──────  ───────  ──────
  Omega[1, 1]  0.5430  0.3717    0.0083  0.0797   8.0321  1.7048
  Omega[1, 2]  0.0781  0.0555    0.0012  0.0106   8.3276  1.1574
  Omega[1, 3]  0.0718  0.0511    0.0011  0.0069  10.2790  1.2537
  Omega[2, 1]  0.0781  0.0555    0.0012  0.0106   8.3276  1.1574
  Omega[2, 2]  0.5147  0.2778    0.0062  0.0601  10.2855  0.9997
  Omega[2, 3]  0.0614  0.0479    0.0011  0.0049  41.6326  1.0054
  Omega[3, 1]  0.0718  0.0511    0.0011  0.0069  10.2790  1.2537
  Omega[3, 2]  0.0614  0.0479    0.0011  0.0049  41.6326  1.0054
  Omega[3, 3]  2.3791  0.9044    0.0202  0.1800  19.8614  0.9999
     sigma[1]  2.7542  1.6814    0.0376  0.3638   8.0321  1.9013
     sigma[2]  5.1374  2.6200    0.0586  0.5586  10.9404  1.0535
     sigma[3]  1.3219  0.4738    0.0106  0.0955  15.5824  1.0331

Quantiles
   parameters     2.5%   25.0%   50.0%   75.0%    97.5%
  ───────────  ───────  ──────  ──────  ──────  ───────
  Omega[1, 1]   0.1534  0.2320  0.4527  0.7055   1.4918
  Omega[1, 2]   0.0227  0.0439  0.0614  0.0898   0.2528
  Omega[1, 3]  -0.0006  0.0356  0.0615  0.1008   0.1993
  Omega[2, 1]   0.0227  0.0439  0.0614  0.0898   0.2528
  Omega[2, 2]   0.1765  0.2989  0.4679  0.6571   1.3056
  Omega[2, 3]  -0.0212  0.0290  0.0547  0.0876   0.1764
  Omega[3, 1]  -0.0006  0.0356  0.0615  0.1008   0.1993
  Omega[3, 2]  -0.0212  0.0290  0.0547  0.0876   0.1764
  Omega[3, 3]   1.1109  1.7512  2.2240  2.8160   4.7351
     sigma[1]   0.6662  1.4098  2.1513  4.2178   6.2623
     sigma[2]   1.5756  3.1137  4.3573  6.7164  11.0704
     sigma[3]   0.5932  0.9867  1.2341  1.5655   2.4542

Things seem to become much less stable when using NUTS as well. Like I mentioned before, if it finishes sampling I end up with very large explosive parameters.

I was hoping that there was a simple solution for using the LKJ distribution with Turing but according to Martin it sounds like something that may be lower level:

I think it's mostly an issue related to the constraints of LKJ. And an issue on Turing or Bijectors is the best place for it.

If there's anything you'd like me to test and report back to you then I'm definitely willing to help.


jfb-h commented on June 12, 2024

Just to add to this: I tried fitting a hierarchical linear model with a multivariate prior using StatsBase.cor2cov() (I think I just got lucky, and this is not generally safe in terms of the PosDefException):

@model hlm(
    y, X, ll,
    ::Type{T}=Vector{Float64},
    ::Type{S}=Matrix{Float64}) where {T, S} = begin

    N, K = size(X)
    L = maximum(ll)
    τ = T(undef, K)
    β = S(undef, L, K)
    Ω = S(undef, K, K)

    τ .~ Exponential(2)
    d = Diagonal(τ)
    Ω ~ LKJ(K, 3)
    Σ = cor2cov(Ω, τ)
    for l in 1:L
        β[l,:] ~ MvNormal(zeros(K), Σ)
    end

    μ = reshape(sum(β[ll,:] .* X, dims = 2), N)
    σ ~ Exponential(.5)
    y ~ MvNormal(μ, σ)
end

Similar to @joshualeond's case, sampling from the prior works fine, but NUTS() goes nuts :) with exploding numbers and a bunch of rejected proposals due to numerical errors.


yiyuezhuo commented on June 12, 2024

His implementation looks fine, but it doesn't support Zygote (it's not mutation-free). Anyway, it's still inspiring.

