

wangzhishi commented on May 18, 2024

I agree with @steveyang90 that this is only applied to .sampling. Option 1 would trigger too many changes.

First, to keep the samples in draw order with the chain information preserved, we have to use permuted=False instead of permuted=True, but this changes the return type:

  • .extract(permuted=True) returns an ordered dictionary
  • .extract(permuted=False) returns an ndarray with 3 dimensions: iterations × chains × parameters

I found that .extract(pars=['param_1', ..., ], permuted=False) does return an ordered dictionary; however, compared with the return of .extract(permuted=True), each item has an extra dimension corresponding to the chain.
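To make the shape difference concrete without running Stan, the sketch below mocks what is described above for .extract(pars=..., permuted=False): an ordered dictionary whose arrays carry an extra chain axis, assumed here to be laid out as (iterations, chains, *param_dims), matching the shapes quoted later in this thread. The parameter names sigma and beta and the sizes are hypothetical.

```python
from collections import OrderedDict

import numpy as np

# Hypothetical sizes: 500 draws per chain, 4 chains, one vector parameter of size 8.
n_iter, n_chains, n_beta = 500, 4, 8
rng = np.random.default_rng(0)

# Mock of the permuted=False dictionary: axis 0 = draw, axis 1 = chain.
unpermuted = OrderedDict(
    sigma=rng.normal(size=(n_iter, n_chains)),          # scalar parameter
    beta=rng.normal(size=(n_iter, n_chains, n_beta)),   # vector parameter
)

for name, draws in unpermuted.items():
    # With permuted=True the chain axis is merged into the draw axis.
    merged_shape = (n_iter * n_chains,) + draws.shape[2:]
    print(name, draws.shape, "->", merged_shape)

# Because the chain axis is kept, the draws of one chain (here the second) stay in order.
chain_2_sigma = unpermuted["sigma"][:, 1]
```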

So, I'm thinking of the following concrete plan for the MCMC method, where fit is compiled_stan_file in our case:

```python
stan_extract = fit.extract(pars=fit._get_param_names(), permuted=False)
param_dims = fit._get_param_dims()
for idx, (key, val) in enumerate(stan_extract.items()):
    if len(param_dims[idx]) == 0:
        # scalar parameter: (iterations, chains) -> (chains * iterations,)
        # order='F' is important: it flattens the chains one after another
        stan_extract[key] = val.flatten(order='F')
    else:
        # vector/matrix parameter: (iterations, chains, *dims) -> (chains * iterations, *dims)
        stan_extract[key] = val.reshape((-1,) + val.shape[2:], order='F')
```

After this, stan_extract has exactly the same structure as the return of .extract(permuted=True), but the sample order is preserved. Say we have 4 chains with 500 samples each: we get 2000 samples ordered as [500 from chain 1, ..., 500 from chain 4], and within each chain the draw order is also preserved.
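The ordering claim can be checked on a toy array in which each value encodes its own chain and draw index (value = 1000 * chain + draw, an encoding chosen only for this check). This NumPy-only sketch applies the same flatten(order='F') and reshape(..., order='F') calls as the plan to mock data, not to a real Stan fit.

```python
import numpy as np

n_iter, n_chains, n_dim = 500, 4, 8

# Scalar parameter: shape (iterations, chains); each value encodes (chain, draw).
draw_idx, chain_idx = np.meshgrid(np.arange(n_iter), np.arange(n_chains), indexing="ij")
scalar = 1000 * chain_idx + draw_idx              # shape (500, 4)

flat = scalar.flatten(order="F")                  # Fortran order: the draw axis varies fastest
expected = np.concatenate([1000 * c + np.arange(n_iter) for c in range(n_chains)])
assert np.array_equal(flat, expected)             # all of chain 1, then chain 2, ...

# Vector parameter: shape (iterations, chains, 8); element k gets an offset of 100_000 * k.
vector = scalar[:, :, None] + 100_000 * np.arange(n_dim)
collapsed = vector.reshape(-1, n_dim, order="F")  # shape (2000, 8)
for k in range(n_dim):
    # Every column keeps the same chain-by-chain, draw-by-draw order.
    assert np.array_equal(collapsed[:, k], expected + 100_000 * k)
```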

Then we can use these ordered samples for diagnostic visualizations (of course, this needs a bit of processing to cut the samples back into chains).
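One way to cut the collapsed samples back into chains for diagnostics is a plain C-order reshape that splits the leading axis into (chain, draw). split_into_chains is a hypothetical helper, and it assumes every chain has the same number of draws, as in the 4 × 500 example.

```python
import numpy as np


def split_into_chains(samples: np.ndarray, n_chains: int) -> np.ndarray:
    """Reshape (n_chains * n_draws, *dims) samples into (n_chains, n_draws, *dims).

    Assumes the samples are stored chain by chain (all draws of chain 1,
    then all draws of chain 2, ...), with draws in sampling order.
    """
    if samples.shape[0] % n_chains != 0:
        raise ValueError("number of samples is not divisible by n_chains")
    n_draws = samples.shape[0] // n_chains
    # C-order reshape: the leading axis is split as (chain, draw), with draw varying fastest.
    return samples.reshape((n_chains, n_draws) + samples.shape[1:])


# Usage with mock data: 4 chains x 500 draws of a scalar and a size-8 vector parameter.
scalar_samples = np.arange(2000.0)
vector_samples = np.arange(2000.0 * 8).reshape(2000, 8)

scalar_chains = split_into_chains(scalar_samples, n_chains=4)   # shape (4, 500)
vector_chains = split_into_chains(vector_samples, n_chains=4)   # shape (4, 500, 8)
```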

This seems to require minimal changes to our code base.

from orbit.

steveyang90 commented on May 18, 2024

I think this makes sense, @wangzhishi.

What is if len(fit._get_param_dims()[idx]) == 0 checking for?

Also, you might find np.transpose() useful here
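Following the np.transpose() suggestion, the same chain-by-chain ordering can be obtained by swapping the draw and chain axes first and then doing an ordinary C-order reshape. This is a sketch on mock arrays; collapse_chains is a hypothetical helper, and the asserts only confirm that it agrees with the order='F' version from the plan.

```python
import numpy as np


def collapse_chains(draws: np.ndarray) -> np.ndarray:
    """Collapse (iterations, chains, *dims) into (chains * iterations, *dims), chain by chain."""
    # Move the chain axis to the front: (chains, iterations, *dims).
    axes = (1, 0) + tuple(range(2, draws.ndim))
    by_chain = np.transpose(draws, axes)
    # A C-order reshape now concatenates the chains one after another.
    return by_chain.reshape((-1,) + draws.shape[2:])


rng = np.random.default_rng(1)
scalar = rng.normal(size=(500, 4))
vector = rng.normal(size=(500, 4, 8))

assert np.array_equal(collapse_chains(scalar), scalar.flatten(order="F"))
assert np.array_equal(collapse_chains(vector), vector.reshape(-1, 8, order="F"))
```

A possible advantage of this form is that scalar, vector, and matrix parameters go through the same code path, since the trailing dimensions are carried through unchanged.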

wangzhishi commented on May 18, 2024

It distinguishes scalar parameters from vector parameters. For example, with 4 chains, the samples of a scalar parameter (dims []) and of a vector parameter of size 8 (dims [8]) have shapes (500, 4) and (500, 4, 8), respectively. My proposal is to collapse these into (2000,) and (2000, 8), which is consistent with the return of .extract(permuted=True).
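If relying on the private _get_param_dims() (and on its order matching the dictionary's) is a concern, the same scalar/vector distinction can be read off each array's shape: a scalar parameter arrives with exactly 2 axes, (iterations, chains). This is an alternative sketch on mocked arrays, not the code proposed in this thread; to_permuted_layout and the parameter names are hypothetical.

```python
from collections import OrderedDict
from typing import Mapping

import numpy as np


def to_permuted_layout(unpermuted: Mapping[str, np.ndarray]) -> "OrderedDict[str, np.ndarray]":
    """Collapse (iterations, chains, *dims) arrays into the (draws, *dims) layout of
    permuted=True, keeping the draws ordered chain by chain instead of shuffled."""
    out = OrderedDict()
    for name, val in unpermuted.items():
        if val.ndim == 2:
            # Scalar parameter, dims []: (500, 4) -> (2000,)
            out[name] = val.flatten(order="F")
        else:
            # Vector (or higher-rank) parameter: (500, 4, *dims) -> (2000, *dims)
            out[name] = val.reshape((-1,) + val.shape[2:], order="F")
    return out


# Mock input with 4 chains x 500 draws.
mock = OrderedDict(
    sigma=np.zeros((500, 4)),
    beta=np.zeros((500, 4, 8)),
)
collapsed = to_permuted_layout(mock)
for name, arr in collapsed.items():
    print(name, arr.shape)
```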
